diff --git a/src/llm_core.py b/src/llm_core.py index 00cff41..ca34437 100644 --- a/src/llm_core.py +++ b/src/llm_core.py @@ -8,6 +8,7 @@ import hashlib import threading from fastapi import HTTPException from typing import Optional, Dict, List +from src.model_context import get_context_length, DEFAULT_CONTEXT from urllib.parse import urlparse logger = logging.getLogger(__name__) @@ -238,7 +239,19 @@ def _build_ollama_payload( max_tokens: int, stream: bool = False, tools: Optional[List[Dict]] = None, + num_ctx: Optional[int] = None, ) -> Dict: + """Build the JSON payload for Ollama's /api/chat endpoint. + + ``num_ctx`` sets the input context window. Ollama defaults to 2048 + when the option is omitted, so a model with a larger advertised + window is silently truncated there, and a model with a smaller one + gets an oversized window it can't service. Pass the discovered + context length through ``num_ctx``; this builder only emits it when + the value is trusted (not the ``DEFAULT_CONTEXT`` fallback), so we + don't guess for unknown models but do tell Ollama the real window + when we know it — even if it's smaller than 2048. 
+ """ payload: Dict = { "model": model, "messages": _ollama_normalize_tool_messages(messages), @@ -249,6 +262,8 @@ def _build_ollama_payload( options["temperature"] = temperature if max_tokens and max_tokens > 0: options["num_predict"] = max_tokens + if num_ctx is not None and num_ctx > 0 and num_ctx != DEFAULT_CONTEXT: + options["num_ctx"] = num_ctx if options: payload["options"] = options if tools: @@ -675,7 +690,10 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL payload = _build_anthropic_payload(model, messages_copy, temperature, max_tokens) elif provider == "ollama": target_url = _normalize_ollama_url(url) - payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False) + payload = _build_ollama_payload( + model, messages_copy, temperature, max_tokens, + stream=False, num_ctx=get_context_length(url, model), + ) else: target_url = url payload = { @@ -790,7 +808,10 @@ async def llm_call_async( h = {"Content-Type": "application/json"} if headers: h.update(headers) - payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=False) + payload = _build_ollama_payload( + model, messages_copy, temperature, max_tokens, + stream=False, num_ctx=get_context_length(url, model), + ) else: target_url = url h = _provider_headers(provider, headers) @@ -888,7 +909,10 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl h = {"Content-Type": "application/json"} if headers: h.update(headers) - payload = _build_ollama_payload(model, messages_copy, temperature, max_tokens, stream=True, tools=tools) + payload = _build_ollama_payload( + model, messages_copy, temperature, max_tokens, + stream=True, tools=tools, num_ctx=get_context_length(url, model), + ) else: target_url = url payload = { diff --git a/tests/test_llm_core_ollama.py b/tests/test_llm_core_ollama.py index 59fa8ce..b334f26 100644 --- a/tests/test_llm_core_ollama.py +++ b/tests/test_llm_core_ollama.py 
@@ -113,3 +113,130 @@ def test_ollama_payload_tolerates_malformed_arguments(): payload = llm_core._build_ollama_payload("m", msgs, temperature=0.0, max_tokens=0) # Falls back to an empty object rather than raising. assert payload["messages"][0]["tool_calls"][0]["function"]["arguments"] == {} + + +# --------------------------------------------------------------------------- +# num_ctx threading (issue #909) +# +# Ollama defaults num_ctx to 2048 when the option is omitted, so prompts +# going to any Ollama backend are silently truncated there regardless of +# the model's actual capability. The builder must accept a discovered +# context length and emit options.num_ctx — but only when the value is +# trusted (not the DEFAULT_CONTEXT fallback), even if it is below 2048. +# --------------------------------------------------------------------------- + + +def test_build_ollama_payload_emits_num_ctx_when_known_and_large(): + """num_ctx passes through when the caller supplies a trusted value + larger than Ollama's 2048 default.""" + payload = llm_core._build_ollama_payload( + "kimi-k2", [{"role": "user", "content": "x"}], + temperature=0.5, max_tokens=100, num_ctx=131072, + ) + assert payload["options"]["num_ctx"] == 131072 + + +def test_build_ollama_payload_emits_num_ctx_for_small_known_models(): + """A model with a real context smaller than Ollama's 2048 default + would OOM if Ollama used its own default. Pass the real value.""" + payload = llm_core._build_ollama_payload( + "tiny-llm", [{"role": "user", "content": "x"}], + temperature=0.5, max_tokens=100, num_ctx=1024, + ) + assert payload["options"]["num_ctx"] == 1024 + + +def test_build_ollama_payload_omits_none_and_zero(): + """None means the caller didn't look it up; 0 is nonsensical. 
+ Both should be dropped, not emitted as a 0-context request.""" + for ctx in (None, 0): + payload = llm_core._build_ollama_payload( + "m", [{"role": "user", "content": "x"}], + temperature=0.5, max_tokens=100, num_ctx=ctx, + ) + assert "num_ctx" not in payload.get("options", {}), ( + f"num_ctx={ctx} should not be emitted" + ) + + +def test_build_ollama_payload_omits_default_context_fallback(): + """get_context_length returns DEFAULT_CONTEXT (128000) when it can't + discover the model's actual window. Emitting that as num_ctx would + lie to Ollama for unknown models, so the builder filters it out.""" + from src.model_context import DEFAULT_CONTEXT + payload = llm_core._build_ollama_payload( + "unknown-llm-9001", [{"role": "user", "content": "x"}], + temperature=0.5, max_tokens=100, num_ctx=DEFAULT_CONTEXT, + ) + assert "num_ctx" not in payload.get("options", {}) + + +def test_llm_call_threads_discovered_num_ctx(monkeypatch): + """When get_context_length returns a real, large value, it ends up + in the outgoing Ollama request as options.num_ctx (issue #909).""" + monkeypatch.setattr(llm_core, "get_context_length", + lambda url, model: 32768) + + seen = {} + + def fake_post(url, headers=None, json=None, timeout=None): + seen["json"] = json + request = httpx.Request("POST", url) + return httpx.Response( + 200, request=request, + json={"message": {"content": "OK"}, "done": True}, + ) + + monkeypatch.setattr(llm_core.httpx, "post", fake_post) + + llm_core.llm_call( + "https://ollama.com/api", + "kimi-k2", + [{"role": "user", "content": "Say OK"}], + temperature=0.2, + max_tokens=7, + ) + + assert seen["json"]["options"]["num_ctx"] == 32768 + + +def test_stream_llm_threads_discovered_num_ctx(monkeypatch): + """stream_llm goes through the same ollama branch and must also + pass num_ctx through to the streaming request body.""" + import asyncio + + seen = {} + + def spy_build_ollama_payload(*args, **kwargs): + seen["num_ctx"] = kwargs.get("num_ctx") + seen["stream"] = 
kwargs.get("stream") + return { + "model": "kimi-k2", + "messages": [{"role": "user", "content": "x"}], + "stream": True, + } + + monkeypatch.setattr(llm_core, "get_context_length", + lambda url, model: 32768) + monkeypatch.setattr(llm_core, "_build_ollama_payload", + spy_build_ollama_payload) + + # Short-circuit before the actual HTTP call: host is "dead" → yields + # an error SSE chunk and returns. The call to _build_ollama_payload + # still happens before the host check, so we can inspect it. + monkeypatch.setattr(llm_core, "_is_host_dead", lambda url: True) + + async def collect(): + return [chunk async for chunk in llm_core.stream_llm( + "https://ollama.com/api", + "kimi-k2", + [{"role": "user", "content": "Say OK"}], + temperature=0.2, + max_tokens=7, + )] + + out = asyncio.run(collect()) + + assert seen["num_ctx"] == 32768 + assert seen["stream"] is True + assert out # we got the SSE error chunk