Merge remote-tracking branch 'origin/dev'

This commit is contained in:
pewdiepie-archdaemon
2026-06-05 12:14:34 +09:00
154 changed files with 7750 additions and 2496 deletions

View File

@@ -177,6 +177,7 @@ TOOL_SECTIONS = {
<shell command>
```
Run any shell command. Output is returned to you. Use for: installing packages, checking files, git, curl, system info, etc.
NEVER use bash to create or change files — no `>`/`>>` redirects, no heredocs (`cat > f << 'EOF'`), no `tee`, `sed -i`, `awk -i`, no `python -c` that writes. To CREATE or fully rewrite a file use `write_file`; to change part of an existing file use `edit_file`. Those show a diff and are the ONLY allowed way to write files. (bash is for read-only inspection: `ls`, `cat` to READ, `grep`, `git status`/`git diff`, builds, installs.)
For LONG-running commands (package installs, pip/npm, ffmpeg, model downloads, training, builds — anything that may take more than ~20s), make the FIRST line `#!bg` to run it in the BACKGROUND. You get a job id back immediately and are automatically re-invoked with the full output when it finishes — so you never block the chat waiting. Example:
```bash
#!bg
@@ -220,6 +221,12 @@ Read a file and return its contents.""",
```
Write content to a file. First line is the path, rest is the content.""",
"edit_file": """\
```edit_file
{"path": "<file path>", "old_string": "<exact text to replace>", "new_string": "<replacement>", "replace_all": false}
```
Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
"create_document": """\
```create_document
<title>
@@ -236,7 +243,7 @@ old text to find
new replacement text
<<<END>>>
```
PREFERRED way to change an existing document. Find exact text and replace it. Multiple FIND/REPLACE blocks per call OK. Use this for any edit smaller than a full rewrite — adding a function, fixing a bug, tweaking a section, renaming things. **If a document is open in the editor, treat it as the user's current context: don't ask which file they mean, and don't create a new one — just edit_document the active one.** Do NOT re-send the whole file with update_document for small changes.""",
Edit a document OPEN IN THE EDITOR PANEL — NOT a file on disk. For files on disk (home folder, project files, any real path like ~/sweden.txt) use `edit_file` instead. Find exact text and replace it. Multiple FIND/REPLACE blocks per call OK. Use for any edit smaller than a full rewrite. **If a document is open in the editor, treat it as the user's current context: don't ask which file they mean, and don't create a new one — just edit_document the active one.** Do NOT re-send the whole file with update_document for small changes.""",
"update_document": """\
```update_document
@@ -462,13 +469,14 @@ _API_HOSTS = frozenset([
"api.together.xyz", "api.fireworks.ai",
"api.perplexity.ai", "api.x.ai",
"ollama.com", "api.venice.ai",
"api.githubcopilot.com",
# Local OpenAI-compatible endpoints (llama.cpp, vLLM, LM Studio, etc.).
# Without these, `_is_api_model` falls back to keyword sniffing on the
# model name, so well-behaved local servers don't get native tool
# schemas and the agent silently degrades to fenced-block parsing.
"localhost", "127.0.0.1", "host.docker.internal",
])
_MCP_KEYWORDS = frozenset(["browse", "browser", "website", "calendar", "event", "email",
_MCP_KEYWORDS = frozenset(["mcp", "browse", "browser", "website", "calendar", "event", "email",
"gmail", "screenshot", "navigate", "click", "miniflux", "rss", "feed"])
_ADMIN_SCHEMA_NAMES = frozenset([
"manage_session", "manage_skills", "manage_tasks",
@@ -1380,6 +1388,7 @@ async def stream_agent_loop(
owner: Optional[str] = None,
relevant_tools: Optional[Set[str]] = None,
fallbacks: Optional[List[tuple]] = None,
workspace: Optional[str] = None,
_is_teacher_run: bool = False,
) -> AsyncGenerator[str, None]:
"""Streaming agent loop generator.
@@ -1546,6 +1555,27 @@ async def stream_agent_loop(
compact=_is_api_model,
owner=owner,
)
if workspace:
# PREPEND (not append) so it dominates the large base prompt — appended
# at the end, small models ignored it and asked the user for code. The
# folder IS the project; the agent must explore it, not ask.
_ws_note = (
f"## ACTIVE WORKSPACE — READ FIRST\n"
f"The user is working in this folder: {workspace}\n"
f"It IS the project. bash/python run with cwd set here and "
f"read_file/write_file are confined to it (paths outside are rejected).\n"
f"When the user says \"the code\" / \"this project\" / \"the workspace\" "
f"or asks to review/find/edit something WITHOUT a path, they mean THIS "
f"folder. Do NOT ask the user for code or a path, and do NOT read a file "
f"literally named \"workspace\". ALWAYS start by exploring it yourself: "
f"run `bash` → `git ls-files` (or `ls -R`) to see the files, then "
f"read_file the relevant ones by path RELATIVE to the workspace."
)
if messages and messages[0].get("role") == "system":
messages[0]["content"] = _ws_note + "\n\n" + (messages[0].get("content") or "")
else:
messages.insert(0, {"role": "system", "content": _ws_note})
logger.info("[workspace] active for this turn: %s", workspace)
prep_timings["prompt_build"] = time.time() - _t2
_t3 = time.time()
@@ -1658,6 +1688,11 @@ async def stream_agent_loop(
_doc_opened = False # whether doc_stream_open was sent
_doc_last_len = 0 # last content length sent
# Set when the loop runs out of rounds while the agent was still actively
# using tools — i.e. it was cut off, not finished. Drives a "Continue" event
# so the user can resume instead of the turn silently stalling.
_exhausted_rounds = False
for round_num in range(1, max_rounds + 1):
round_response = ""
round_reasoning = "" # reasoning_content deltas (DeepSeek-thinking, vLLM --reasoning-parser)
@@ -2167,6 +2202,7 @@ async def stream_agent_loop(
disabled_tools=disabled_tools,
owner=owner,
progress_cb=_push_progress,
workspace=workspace,
)
finally:
# Sentinel so the drainer knows to stop.
@@ -2282,6 +2318,9 @@ async def stream_agent_loop(
if result.get("images"):
img = result["images"][0]
tool_output_data["screenshot"] = f"data:{img['mimeType']};base64,{img['data']}"
# Forward a file-write diff for inline before/after rendering
if "diff" in result:
tool_output_data["diff"] = result["diff"]
yield f'data: {json.dumps(tool_output_data)}\n\n'
# Native document tools open in the editor + carry the REAL doc id.
@@ -2324,6 +2363,10 @@ async def stream_agent_loop(
if result.get("doc_id"):
tool_event["doc_id"] = result["doc_id"]
tool_event["doc_title"] = result.get("title", "")
# Persist the file-write/edit diff so it re-renders on reload — without
# this the diff shows live but vanishes from saved history.
if result.get("diff"):
tool_event["diff"] = result["diff"]
tool_events.append(tool_event)
if block.tool_type in _VERIFIER_EFFECTFUL_TOOLS:
_effectful_used = True
@@ -2348,6 +2391,20 @@ async def stream_agent_loop(
# Separator in accumulated response
full_response += "\n\n"
else:
# The for-loop completed every allowed round WITHOUT an early `break`
# (a `break` fires on "done", budget, or error). Reaching this `else`
# means the agent kept working until it ran out of rounds — so offer
# Continue instead of stopping silently. This catches ALL exhaustion
# paths, including a verifier `continue` on the final round (the old
# bottom-of-loop flag missed those).
_exhausted_rounds = True
# If the loop hit the round cap while still working, tell the client so it
# can show a "Continue" affordance instead of the turn just stopping.
if _exhausted_rounds:
logger.info("[agent] round cap (%d) reached mid-task — emitting rounds_exhausted", max_rounds)
yield f'data: {json.dumps({"type": "rounds_exhausted", "rounds": max_rounds})}\n\n'
# If the response is completely empty and no tools were executed,
# yield a fallback message so the user is not left hanging.

View File

@@ -26,7 +26,8 @@ MAX_OUTPUT_CHARS = 10_000
MAX_READ_CHARS = 20_000
# Tool types that trigger execution
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file",
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
"grep", "glob", "ls",
"create_document", "update_document", "edit_document",
"search_chats",
"chat_with_model", "create_session", "list_sessions",

View File

@@ -9,6 +9,7 @@ from src.constants import (
SESSIONS_FILE, DEFAULT_HOST, OPENAI_API_KEY
)
from src.memory import MemoryManager
from src.memory_provider import MemoryProviderRegistry, NativeMemoryProvider
from services.memory.skills import SkillsManager
from core.session_manager import SessionManager
from core.models import set_session_manager
@@ -73,6 +74,10 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
logger.warning(f"MemoryVectorStore DEGRADED: {e}")
memory_vector = None
memory_provider_registry = MemoryProviderRegistry([
NativeMemoryProvider(memory_manager, memory_vector),
])
# Initialize processors
chat_processor = ChatProcessor(memory_manager, personal_docs_manager, memory_vector=memory_vector, skills_manager=skills_manager)
research_handler = ResearchHandler()
@@ -99,6 +104,7 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
return {
"memory_manager": memory_manager,
"memory_vector": memory_vector,
"memory_provider_registry": memory_provider_registry,
"skills_manager": skills_manager,
"session_manager": session_manager,
"upload_handler": upload_handler,

View File

@@ -265,6 +265,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
existing.all_day = all_day
existing.is_utc = row_is_utc
existing.rrule = rrule
existing.origin = "caldav"
else:
new_ev = CalendarEvent(
uid=uid_val,
@@ -277,6 +278,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
all_day=all_day,
is_utc=row_is_utc,
rrule=rrule,
origin="caldav",
)
db.add(new_ev)
pending[uid_val] = new_ev
@@ -286,8 +288,13 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
# Prune locally-cached CalDAV events that vanished
# upstream (only within our sync window — events outside
# the window aren't in `objs`, so we'd false-delete them).
# Only rows we previously pulled from the server (origin=="caldav")
# are prunable; locally-created events (agent / email triage / a
# UI event whose write-back failed) carry origin NULL and must
# never be deleted just because the server didn't return them.
stale = db.query(CalendarEvent).filter(
CalendarEvent.calendar_id == local_cal.id,
CalendarEvent.origin == "caldav",
CalendarEvent.dtstart >= start,
CalendarEvent.dtstart <= end,
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),

253
src/copilot.py Normal file
View File

@@ -0,0 +1,253 @@
# src/copilot.py
"""GitHub Copilot provider support.
Copilot exposes an OpenAI-compatible API at ``https://api.githubcopilot.com``
(``/chat/completions`` + ``/models``). Authentication is a GitHub OAuth
**device flow**: the user authorises a device code in their browser and we
receive a long-lived ``access_token`` that is sent directly as
``Authorization: Bearer <token>`` — there is no separate Copilot-token
exchange and no refresh (mirrors how editors / opencode talk to Copilot).
The only provider-specific wrinkle beyond the bearer token is a handful of
required request headers (API version, intent, an editor-style User-Agent,
and ``x-initiator`` for agent-vs-user request accounting). Those live in
:func:`copilot_headers`.
This module holds the constants + pure helpers; the HTTP device-flow calls
live in :mod:`routes.copilot_routes` so they can be auth-gated.
"""
import os
from typing import Dict, List, Optional
from urllib.parse import urlparse
import httpx
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# NOTE: every ``os.environ.get`` below is evaluated once, when this module is
# imported; each ``ODYSSEUS_COPILOT_*`` variable overrides the literal default
# given as the second argument.

# GitHub OAuth client id used for the device flow. Copilot's token endpoint
# only accepts client ids that GitHub has allow-listed for Copilot access, so
# we reuse the public VS Code client id (the de-facto standard third-party
# clients use). Override via env if you register your own allow-listed app.
COPILOT_CLIENT_ID = os.environ.get(
"ODYSSEUS_COPILOT_CLIENT_ID", "01ab8ac9400c4e429b23"
)
# Dated API version header required by the Copilot API (models + chat).
# Sent as ``X-GitHub-Api-Version`` by ``copilot_headers``.
COPILOT_API_VERSION = os.environ.get(
"ODYSSEUS_COPILOT_API_VERSION", "2026-06-01"
)
# Public Copilot API base. GitHub Enterprise uses ``copilot-api.<domain>``.
COPILOT_BASE = "https://api.githubcopilot.com"
# Copilot wants an editor-like User-Agent + integration id. These identify the
# client to GitHub; keep them stable.
COPILOT_USER_AGENT = os.environ.get(
"ODYSSEUS_COPILOT_USER_AGENT", "Odysseus/1.0"
)
# Sent as the ``Copilot-Integration-Id`` header.
COPILOT_INTEGRATION_ID = os.environ.get(
"ODYSSEUS_COPILOT_INTEGRATION_ID", "vscode-chat"
)
# Sent as the ``Editor-Version`` header.
COPILOT_EDITOR_VERSION = os.environ.get(
"ODYSSEUS_COPILOT_EDITOR_VERSION", "Odysseus/1.0"
)
# OAuth scope requested during the device flow.
COPILOT_SCOPE = "read:user"
# Default GitHub host for the device flow (public github.com).
GITHUB_HOST = "github.com"
def device_code_url(host: str = GITHUB_HOST) -> str:
    """Return the HTTPS URL of the OAuth device-code endpoint on ``host``.

    ``host`` is interpolated as-is, so it is expected to be a bare host name
    (no scheme, no trailing slash).
    """
    endpoint_path = "/login/device/code"
    return f"https://{host}{endpoint_path}"
def access_token_url(host: str = GITHUB_HOST) -> str:
    """Return the HTTPS URL of the OAuth access-token endpoint on ``host``.

    ``host`` is interpolated as-is, so it is expected to be a bare host name
    (no scheme, no trailing slash).
    """
    endpoint_path = "/login/oauth/access_token"
    return f"https://{host}{endpoint_path}"
def normalize_domain(url: str) -> str:
    """Strip scheme/trailing slash from a GitHub Enterprise URL or domain.

    Surrounding whitespace is removed, a leading ``http://`` or ``https://``
    is dropped regardless of letter case, and trailing slashes are trimmed.
    Only a scheme at the very start is removed, so text later in the value is
    never rewritten. ``None`` or an empty value yields ``""``.
    """
    domain = (url or "").strip()
    lowered = domain.lower()
    # Match the scheme as a prefix only; a plain ``str.replace`` would also
    # eat "http://" appearing anywhere else in the value.
    for scheme in ("https://", "http://"):
        if lowered.startswith(scheme):
            domain = domain[len(scheme):]
            break
    return domain.rstrip("/")
def enterprise_base(enterprise_url: Optional[str]) -> str:
    """Return the Copilot API base for a deployment.

    A falsy ``enterprise_url`` selects the public base (``COPILOT_BASE``).
    Otherwise the value is passed through ``normalize_domain`` and the result
    is ``https://copilot-api.<domain>``.
    """
    if enterprise_url:
        domain = normalize_domain(enterprise_url)
        return f"https://copilot-api.{domain}"
    return COPILOT_BASE
def is_copilot_base(url: Optional[str]) -> bool:
    """Report whether ``url`` points at a Copilot API host.

    The decision is made on the parsed host name only (lower-cased, with any
    trailing dot removed):

    * ``githubcopilot.com`` or any ``*.githubcopilot.com`` sub-domain counts
      as the public API;
    * any host beginning with ``copilot-api.`` counts as an enterprise API.

    Falsy input, input that cannot be parsed, and URLs without a host name
    all yield ``False``.
    """
    if not url:
        return False

    try:
        hostname = (urlparse(url).hostname or "").lower().rstrip(".")
    except Exception:
        # Anything unparsable is simply "not Copilot".
        return False

    if not hostname:
        return False

    is_public = hostname == "githubcopilot.com" or hostname.endswith(".githubcopilot.com")
    is_enterprise = hostname.startswith("copilot-api.")
    return is_public or is_enterprise
def copilot_headers(
    api_key: Optional[str],
    *,
    agent: bool = False,
    vision: bool = False,
) -> Dict[str, str]:
    """Build the Copilot-specific request headers.

    Args:
        api_key: the GitHub device-flow access token. When truthy it is sent
            as ``Authorization: Bearer <token>``; otherwise no Authorization
            header is included.
        agent: request originates from the agent loop (a tool-driven turn)
            rather than a direct user message. Selects the ``x-initiator``
            value (``"agent"`` vs ``"user"``).
        vision: the request carries an image part; adds
            ``Copilot-Vision-Request: true``.

    Returns:
        A new header dict; the fixed headers come first, followed by the
        optional Authorization and vision headers.
    """
    initiator = "agent" if agent else "user"
    fixed: Dict[str, str] = {
        "X-GitHub-Api-Version": COPILOT_API_VERSION,
        "Openai-Intent": "conversation-edits",
        "User-Agent": COPILOT_USER_AGENT,
        "Editor-Version": COPILOT_EDITOR_VERSION,
        "Copilot-Integration-Id": COPILOT_INTEGRATION_ID,
        "x-initiator": initiator,
    }

    optional: Dict[str, str] = {}
    if api_key:
        optional["Authorization"] = f"Bearer {api_key}"
    if vision:
        optional["Copilot-Vision-Request"] = "true"

    return {**fixed, **optional}
# ---------------------------------------------------------------------------
# Device-flow OAuth (pure HTTP; orchestration lives in routes.copilot_routes)
# ---------------------------------------------------------------------------
def _oauth_post_headers() -> Dict[str, str]:
    """Return the headers used for the device-flow POST requests.

    Both the request body and the expected response are JSON, and the
    configured ``COPILOT_USER_AGENT`` identifies the client.
    """
    json_mime = "application/json"
    headers: Dict[str, str] = {}
    headers["Accept"] = json_mime
    headers["Content-Type"] = json_mime
    headers["User-Agent"] = COPILOT_USER_AGENT
    return headers
def request_device_code(host: str = GITHUB_HOST, *, timeout: float = 10.0) -> Dict:
    """Start the device flow with a blocking POST to ``device_code_url(host)``.

    Returns the decoded JSON body; per the upstream contract noted by the
    original author this is GitHub's
    ``{device_code, user_code, verification_uri, expires_in, interval}``.

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status`` on a 4xx/5xx reply.
    """
    payload = {"client_id": COPILOT_CLIENT_ID, "scope": COPILOT_SCOPE}
    response = httpx.post(
        device_code_url(host),
        headers=_oauth_post_headers(),
        json=payload,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
def poll_access_token(host: str, device_code: str, *, timeout: float = 10.0) -> Dict:
    """Poll the access-token endpoint once and return the decoded JSON body.

    This performs a single blocking POST; looping and back-off are left to
    the caller. GitHub returns HTTP 200 with an ``error`` field
    (``authorization_pending``/``slow_down``) while the user hasn't
    authorised yet, or ``{access_token, ...}`` once they have.

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status`` on a 4xx/5xx reply.
    """
    payload = {
        "client_id": COPILOT_CLIENT_ID,
        "device_code": device_code,
        "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
    }
    response = httpx.post(
        access_token_url(host),
        headers=_oauth_post_headers(),
        json=payload,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
def fetch_models(base: str, token: str, *, timeout: float = 15.0) -> List[Dict]:
    """Fetch Copilot's model catalogue, filtered to picker-enabled models.

    Issues a blocking GET to ``<base>/models`` with the Copilot headers and
    reads the ``data`` array of the JSON reply. Entries without an ``id`` are
    dropped. Each remaining entry becomes ``{id, tool_calls, vision}``, where
    the two flags are taken from ``capabilities.supports``.

    If at least one entry has a truthy ``model_picker_enabled`` only those
    entries are returned; otherwise the full list is returned (defensive
    against API-shape drift).

    Raises:
        httpx.HTTPStatusError: via ``raise_for_status`` on a 4xx/5xx reply.
    """
    endpoint = base.rstrip("/") + "/models"
    response = httpx.get(endpoint, headers=copilot_headers(token), timeout=timeout)
    response.raise_for_status()
    body = response.json() or {}
    entries = body.get("data") or []

    # Pair each public record with its picker flag so the flag never has to
    # be stored in (and later removed from) the record itself.
    catalogue = []
    for entry in entries:
        model_id = entry.get("id")
        if not model_id:
            continue
        capabilities = entry.get("capabilities") or {}
        supports = capabilities.get("supports") or {}
        record = {
            "id": model_id,
            "tool_calls": bool(supports.get("tool_calls")),
            "vision": bool(supports.get("vision")),
        }
        catalogue.append((record, bool(entry.get("model_picker_enabled"))))

    picker_enabled = [record for record, in_picker in catalogue if in_picker]
    if picker_enabled:
        return picker_enabled
    return [record for record, _ in catalogue]
# ---------------------------------------------------------------------------
# Per-request header flags
# ---------------------------------------------------------------------------
# Content-part ``type`` values that mark a message as carrying an image.
_IMAGE_PART_TYPES = ("image_url", "input_image", "image")


def request_flags(messages) -> tuple:
    """Derive ``(agent, vision)`` from an OpenAI-style message list.

    Mirrors opencode's logic:

    * ``agent`` — the last message is *not* a plain user message (i.e. it's a
      tool result / assistant follow-up), so Copilot should treat the request
      as agent-initiated for request accounting.
    * ``vision`` — any message carries an image content part.

    Entries that are not dicts are tolerated everywhere: they never count as
    image-bearing, and a non-dict (or empty) last entry is reported as
    user-initiated rather than raising. ``None`` or an empty list yields
    ``(False, False)``.
    """
    msgs = messages or []
    last = msgs[-1] if msgs else None

    # Guard the role lookup the same way the image scan guards each message,
    # so a malformed last entry cannot raise AttributeError.
    agent = isinstance(last, dict) and bool(last) and last.get("role") != "user"

    vision = any(
        isinstance(part, dict) and part.get("type") in _IMAGE_PART_TYPES
        for message in msgs
        if isinstance(message, dict) and isinstance(message.get("content"), list)
        for part in message["content"]
    )
    return agent, vision
def apply_request_headers(headers: Dict[str, str], messages) -> Dict[str, str]:
    """Stamp the per-request Copilot flags onto ``headers`` in place.

    ``x-initiator`` is always (re)written from the outgoing messages;
    ``Copilot-Vision-Request`` is added only when an image part is present
    and is never removed. The same ``headers`` object is returned for
    convenience.
    """
    is_agent, has_image = request_flags(messages)
    initiator = "user"
    if is_agent:
        initiator = "agent"
    headers["x-initiator"] = initiator
    if has_image:
        headers["Copilot-Vision-Request"] = "true"
    return headers

View File

@@ -196,6 +196,8 @@ class DeepResearcher:
max_content_chars: int = 15000,
max_report_tokens: int = 8192,
extraction_timeout: int = 90,
planning_timeout: int = 90,
query_timeout: int = 120,
extraction_concurrency: int = 3,
min_rounds: int = 2,
max_empty_rounds: int = 2,
@@ -215,6 +217,8 @@ class DeepResearcher:
self.max_content_chars = max_content_chars
self.max_report_tokens = max_report_tokens
self.extraction_timeout = min(3600, max(15, int(extraction_timeout or 90)))
self.planning_timeout = min(3600, max(15, int(planning_timeout or 90)))
self.query_timeout = min(3600, max(15, int(query_timeout or 120)))
self.extraction_concurrency = min(12, max(1, int(extraction_concurrency or 3)))
self.min_rounds = min_rounds
self.max_empty_rounds = max_empty_rounds
@@ -395,7 +399,7 @@ class DeepResearcher:
[{"role": "user", "content": prompt}],
temperature=0.3,
max_tokens=1024,
timeout=30,
timeout=getattr(self, "planning_timeout", 90),
)
# Try to parse as JSON for structured plan
parsed = self._parse_json_object(response)
@@ -478,6 +482,7 @@ class DeepResearcher:
[{"role": "user", "content": prompt}],
temperature=0.5,
max_tokens=4096,
timeout=getattr(self, "query_timeout", 120),
)
queries = self._parse_json_array(response)
# Deduplicate

View File

@@ -194,6 +194,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
headers["x-api-key"] = api_key
headers["anthropic-version"] = "2023-06-01"
return headers
if provider == "copilot":
from src.copilot import copilot_headers
return copilot_headers(api_key)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if provider == "openrouter":

View File

@@ -67,7 +67,7 @@ _host_health_lock = threading.Lock()
_model_activity: Dict[str, float] = {}
def _model_activity_key(url: str, model: str) -> str:
return f"{(url or '').strip().rstrip()}|{(model or '').strip()}"
return f"{(url or '').strip()}|{(model or '').strip()}"
def note_model_activity(url: str, model: str):
"""Record that a real upstream request used this endpoint/model."""
@@ -317,6 +317,9 @@ def _detect_provider(url: str) -> str:
return "openrouter"
if _host_match(url, "groq.com"):
return "groq"
from src.copilot import is_copilot_base
if is_copilot_base(url):
return "copilot"
return "openai"
@@ -327,6 +330,14 @@ def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str
if provider == "openrouter":
h.setdefault("HTTP-Referer", "https://github.com/pewdiepie-archdaemon/odysseus")
h.setdefault("X-OpenRouter-Title", "Odysseus")
if provider == "copilot":
# Ensure the Copilot-required headers are present even when the caller
# didn't pass pre-built headers (e.g. model listing). build_headers()
# already injects these for the live chat path; setdefault keeps any
# request-specific values (x-initiator/vision) the caller set.
from src.copilot import copilot_headers
for k, v in copilot_headers(None).items():
h.setdefault(k, v)
return h
@@ -340,6 +351,8 @@ def _provider_label(url: str) -> str:
if _host_match(url, "openai.com"): return "OpenAI"
if _host_match(url, "openrouter.ai"): return "OpenRouter"
if _host_match(url, "groq.com"): return "Groq"
from src.copilot import is_copilot_base
if is_copilot_base(url): return "GitHub Copilot"
if _host_match(url, "mistral.ai"): return "Mistral"
if _host_match(url, "deepseek.com"): return "DeepSeek"
if _host_match(url, "googleapis.com"): return "Google"
@@ -481,7 +494,7 @@ def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=Fa
chat_messages = []
for m in messages:
if m.get("role") == "system":
system_parts.append(m["content"])
system_parts.append(m.get("content") or "")
elif m.get("role") == "tool":
# Convert OpenAI tool result to Anthropic format
chat_messages.append({
@@ -884,7 +897,7 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
non_sys = []
for m in messages_copy:
if m.get("role") == "system":
sys_parts.append(m["content"])
sys_parts.append(m.get('content') or '')
else:
non_sys.append(m)
if sys_parts:
@@ -911,6 +924,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
)
else:
target_url = url
if provider == "copilot":
from src.copilot import apply_request_headers
apply_request_headers(h, messages_copy)
payload = {
"model": model,
"messages": messages_copy,
@@ -1028,7 +1044,7 @@ async def llm_call_async(
non_sys = []
for m in messages_copy:
if m.get("role") == "system":
sys_parts.append(m["content"])
sys_parts.append(m.get('content') or '')
else:
non_sys.append(m)
if sys_parts:
@@ -1058,6 +1074,9 @@ async def llm_call_async(
else:
target_url = url
h = _provider_headers(provider, headers)
if provider == "copilot":
from src.copilot import apply_request_headers
apply_request_headers(h, messages_copy)
payload = {
"model": model,
"messages": messages_copy,
@@ -1088,6 +1107,9 @@ async def llm_call_async(
f"LLM async call to {target_url} failed in {duration:.2f}s "
f"(attempt {attempt}): HTTP {r.status_code} {friendly}"
)
if r.status_code in (429, 502, 503, 504) and attempt < max_retries:
await asyncio.sleep(LLMConfig.RETRY_DELAY)
continue
raise HTTPException(r.status_code, friendly)
logger.info(f"LLM async call to {target_url} succeeded in {duration:.2f}s (attempt {attempt})")
_clear_host_dead(target_url)
@@ -1109,7 +1131,9 @@ async def llm_call_async(
duration = time.time() - start
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
logger.warning(f"LLM async connect to {target_url} failed after {duration:.2f}s: {e}{_tail}")
raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
if _cooled or attempt >= max_retries:
raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
await asyncio.sleep(LLMConfig.RETRY_DELAY)
except (httpx.RequestError, httpx.HTTPStatusError) as e:
duration = time.time() - start
logger.warning(f"LLM async call attempt {attempt} failed after {duration:.2f}s: {e}")
@@ -1138,7 +1162,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
non_sys = []
for m in messages_copy:
if m.get("role") == "system":
sys_parts.append(m["content"])
sys_parts.append(m.get('content') or '')
else:
non_sys.append(m)
if sys_parts:
@@ -1177,6 +1201,9 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
if tools:
payload["tools"] = tools
h = _provider_headers(provider, headers)
if provider == "copilot":
from src.copilot import apply_request_headers
apply_request_headers(h, messages_copy)
# Short connect timeout: a reachable peer answers SYN in <100ms even on
# Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
@@ -1358,6 +1385,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
# can detect thinking-in-progress (some models output </think> but no <think>)
_thinking_model = _supports_thinking(model)
_first_content_sent = False
_in_think_tag = False # True while consuming <think>…</think> content
_think_open_stripped = False # opening <think> tag already removed
def _emit_tool_calls():
"""Build the tool_calls event string if any were accumulated."""
@@ -1439,14 +1468,53 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n'
content = delta.get("content") or ""
if content:
# Some thinking backends start normal content with a
# stray closing tag. Repair only that shape; do not
# wrap every first token for model families like
# MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and content.lstrip().lower().startswith("</think"):
content = "<think>" + content
_first_content_sent = True
yield f'data: {json.dumps({"delta": content})}\n\n'
stripped = content.lstrip()
# Auto-detect <think>…</think> in content stream.
# Covers Qwen3-derived models (Qwopus, QwQ forks) whose
# names don't match _THINKING_MODEL_PATTERNS but still
# emit literal <think> markup via llama.cpp --jinja.
if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"):
_thinking_model = True
_in_think_tag = True
if _in_think_tag:
close_idx = content.lower().find("</think>")
if close_idx != -1:
# Split: up-to-</think> → thinking, remainder → content
think_part = content[:close_idx]
if not _think_open_stripped:
# Strip the opening <think[...] > from the first chunk.
# Use a dedicated flag — _first_content_sent stays False
# throughout the think block, so it must not be reused.
tag_end = think_part.lower().find(">")
if tag_end != -1:
think_part = think_part[tag_end + 1:]
_think_open_stripped = True
regular_part = content[close_idx + len("</think>"):]
_in_think_tag = False
if think_part:
yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
if regular_part:
_first_content_sent = True
yield f'data: {json.dumps({"delta": regular_part})}\n\n'
else:
# Still inside <think>: route to thinking channel
if not _think_open_stripped:
# Strip the opening <think[...] > tag (first chunk only)
tag_end = stripped.lower().find(">")
if tag_end != -1:
content = stripped[tag_end + 1:]
_think_open_stripped = True
if content:
yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
else:
# Some thinking backends start normal content with a
# stray closing tag. Repair only that shape; do not
# wrap every first token for model families like
# MiniMax, which often stream ordinary answers.
if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
content = "<think>" + content
_first_content_sent = True
yield f'data: {json.dumps({"delta": content})}\n\n'
# Native tool calls — accumulate across chunks
for tc in delta.get("tool_calls") or []:
if tc is None:

View File

@@ -8,6 +8,7 @@ Each server exposes tools that are made available to the agent loop.
import json
import logging
import os
import re
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@@ -30,6 +31,64 @@ def _format_mcp_connection_error(name: str, command: str = "", args: Optional[Li
return raw_error
# Caps for rendering untrusted MCP tool schemas into the agent prompt (issue #2660).
# MCP servers are third-party/user-added, so field names and parameter counts are
# untrusted input — bound them so an odd or hostile schema cannot distort the prompt.
_MCP_PARAM_MAX = 12 # max params rendered per tool
_MCP_TOKEN_MAX = 40 # max chars per rendered name / type token
_MCP_HINT_MAX = 300 # total-length backstop for the whole hint
def _sanitize_schema_token(value: Any, limit: int = _MCP_TOKEN_MAX) -> str:
"""Make an untrusted JSON-Schema token safe to splice into the prompt.
Replaces control chars / newlines with a space, collapses whitespace, and
length-caps the result, so a weird field name or type cannot inject newlines
or run on. Normal short identifiers pass through unchanged.
"""
text = re.sub(r"[\x00-\x1f\x7f]+", " ", str(value))
text = re.sub(r"\s+", " ", text).strip()
if len(text) > limit:
text = text[:limit].rstrip() + ""
return text
def _format_mcp_params(input_schema: Any) -> str:
"""Render an MCP tool's JSON-Schema inputs as a compact prompt hint.
Without this the agent only sees a tool's name + description and has to
guess its arguments (issue #2509). Produces e.g.
` Args (JSON): {"path": string (required), "limit": integer}` — names,
coarse types, and required-ness, kept short so it stays prompt-friendly.
Returns "" when there are no parameters.
MCP servers are third-party, so names/types are sanitized and the parameter
count + total length are capped (issue #2660); normal schemas are unaffected.
"""
if not isinstance(input_schema, dict):
return ""
props = input_schema.get("properties")
if not isinstance(props, dict) or not props:
return ""
required = set(input_schema.get("required") or [])
parts = []
for pname, pinfo in list(props.items())[:_MCP_PARAM_MAX]:
pinfo = pinfo if isinstance(pinfo, dict) else {}
ptype = pinfo.get("type") or "any"
if isinstance(ptype, list):
ptype = "|".join(str(x) for x in ptype)
tag = f'"{_sanitize_schema_token(pname)}": {_sanitize_schema_token(ptype)}'
if pname in required:
tag += " (required)"
parts.append(tag)
extra = len(props) - len(parts)
if extra > 0:
parts.append(f"…+{extra} more")
hint = " Args (JSON): {" + ", ".join(parts) + "}"
if len(hint) > _MCP_HINT_MAX:
hint = hint[:_MCP_HINT_MAX - 1].rstrip() + ""
return hint
class McpManager:
"""Manages MCP server connections and tool routing."""
@@ -43,7 +102,9 @@ class McpManager:
self._sessions: Dict[str, Any] = {}
# server_id -> exit stack (for cleanup)
self._stacks: Dict[str, Any] = {}
# Tracking updates to tools/connections for RAG indexing
# server_id -> background connect task (HTTP transport / OAuth)
self._connect_tasks: Dict[str, Any] = {}
# Tracking updates to tools/connections for RAG indexing / prompt cache
self._generation = 0
async def connect_server(
@@ -56,12 +117,14 @@ class McpManager:
env: Optional[Dict[str, str]] = None,
url: Optional[str] = None,
) -> bool:
"""Connect to an MCP server via stdio or SSE transport."""
"""Connect to an MCP server via stdio, SSE, or Streamable HTTP transport."""
try:
if transport == "stdio":
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
elif transport == "sse":
res = await self._connect_sse(server_id, name, url)
elif transport == "http":
res = await self._start_http_connect(server_id, name, url)
else:
logger.error(f"Unknown MCP transport: {transport}")
res = False
@@ -184,8 +247,101 @@ class McpManager:
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
return False
async def _start_http_connect(self, server_id: str, name: str, url: str, wait: float = 8.0) -> bool:
    """Begin a Streamable HTTP connect in the background.

    Starts ``_connect_http`` as a task and waits at most ``wait`` seconds.

    Args:
        server_id: Key used for the connection, task and auth-URL registries.
        name: Display name stored in the connection status.
        url: Streamable HTTP endpoint of the MCP server.
        wait: Upper bound, in seconds, on how long this call blocks.

    Returns:
        True if the connect finished successfully within ``wait`` (the
        cached-token path). False otherwise: the connect failed, was
        cancelled, or is still running, in which case the status is
        'needs_auth' and the flow continues in the background.
    """
    import asyncio

    # A reconnect must not leave the previous background flow running; it
    # would keep publishing status/sessions for an attempt nobody tracks.
    previous = self._connect_tasks.pop(server_id, None)
    if previous is not None and not previous.done():
        previous.cancel()

    self._connections[server_id] = {"status": "connecting", "name": name, "transport": "http"}
    task = asyncio.create_task(self._connect_http(server_id, name, url))
    self._connect_tasks[server_id] = task
    done, _ = await asyncio.wait({task}, timeout=wait)
    if task in done:
        if task.cancelled():
            # Cancelled from outside (e.g. disconnect_server). result() would
            # raise CancelledError, which `except Exception` does not catch
            # and which callers would mistake for their own cancellation.
            return False
        try:
            return task.result()
        except Exception as e:
            self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
            return False
    # Still running → either awaiting authorization, or discovery/DCR is
    # still in flight. If _on_redirect already published needs_auth+auth_url,
    # leave it; otherwise mark needs_auth (auth_url filled in once it fires).
    from src.mcp_oauth import pop_auth_url
    cur = self._connections.get(server_id, {})
    if cur.get("status") != "needs_auth":
        self._connections[server_id] = {
            "status": "needs_auth", "name": name, "transport": "http",
            "auth_url": pop_auth_url(server_id),
        }
    return False
async def _connect_http(self, server_id: str, name: str, url: str) -> bool:
    """Connect to a Streamable HTTP MCP server (with automatic OAuth).

    On success the session, exit stack and tool list are stored under
    ``server_id``, the status becomes 'connected', and the generation counter
    is bumped. On failure the status becomes 'error' and any transport or
    session contexts that were already entered are closed.

    Args:
        server_id: Key for the session/stack/tool/connection registries.
        name: Display name stored in the connection status and logs.
        url: Streamable HTTP endpoint of the MCP server.

    Returns:
        True when connected and tools were listed, False on any handled error.
    """
    stack = None
    handed_over = False  # True once self._stacks owns the exit stack.
    try:
        from mcp import ClientSession
        from mcp.client.streamable_http import streamablehttp_client
        from contextlib import AsyncExitStack
        from src.mcp_oauth import build_provider, clear_auth_url

        def _on_redirect(auth_url):
            # Publish needs_auth the moment the URL is known, independent of
            # how long discovery/DCR took (may exceed the bounded start wait).
            self._connections[server_id] = {
                "status": "needs_auth", "name": name, "transport": "http",
                "auth_url": auth_url,
            }

        provider = build_provider(server_id, url, on_redirect=_on_redirect)
        stack = AsyncExitStack()
        transport = await stack.enter_async_context(streamablehttp_client(url, auth=provider))
        read_stream, write_stream, _get_session_id = transport
        session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
        await session.initialize()
        tools_result = await session.list_tools()
        tools = [
            {
                "name": tool.name,
                "description": tool.description or "",
                "input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
            }
            for tool in tools_result.tools
        ]
        self._sessions[server_id] = session
        self._stacks[server_id] = stack
        handed_over = True
        self._tools[server_id] = tools
        self._connections[server_id] = {
            "status": "connected", "name": name, "transport": "http",
            "tool_count": len(tools),
        }
        clear_auth_url(server_id)
        # Tools changed (this can complete after connect_server already
        # returned, via the background OAuth flow), so bump the generation
        # to invalidate the tool-prompt cache.
        self._generation += 1
        logger.info(f"MCP server connected: {name} ({server_id}) - {len(tools)} tools via http")
        return True
    except ImportError:
        logger.warning("MCP package not installed. Install with: pip install mcp")
        self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
        return False
    except Exception as e:
        logger.error(f"Failed to connect HTTP MCP server {name} ({server_id}): {e}")
        self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
        return False
    finally:
        # Runs for handled errors and for cancellation alike: if the stack
        # was never handed to self._stacks, nothing else will ever close the
        # transport/session contexts that were entered on it.
        if stack is not None and not handed_over:
            try:
                await stack.aclose()
            except Exception as close_err:
                logger.debug(f"Error closing failed HTTP MCP connection {name} ({server_id}): {close_err}")
async def disconnect_server(self, server_id: str):
"""Disconnect from an MCP server."""
# Cancel any in-flight HTTP/OAuth background connect so it stops
# publishing status for a server that may be getting deleted.
task = self._connect_tasks.pop(server_id, None)
if task is not None and not task.done():
task.cancel()
try:
from src.mcp_oauth import clear_auth_url
clear_auth_url(server_id)
except Exception:
pass
stack = self._stacks.pop(server_id, None)
if stack:
try:
@@ -376,6 +532,7 @@ class McpManager:
"name": tool["name"],
"qualified_name": f"mcp__{server_id}__{tool['name']}",
"description": tool.get("description", ""),
"input_schema": tool.get("input_schema") or {},
"is_disabled": tool["name"] in disabled,
})
return result
@@ -439,7 +596,11 @@ class McpManager:
for t in server_tools:
# Truncate long descriptions
desc = t['description'][:120] + '...' if len(t['description']) > 120 else t['description']
lines.append(f" - {t['qualified_name']}: {desc}")
# Include the tool's declared inputs so the model calls it with
# real argument names instead of guessing from the description
# alone (issue #2509).
args_hint = _format_mcp_params(t.get("input_schema"))
lines.append(f" - {t['qualified_name']}: {desc}{args_hint}")
result = "\n".join(lines)
self._cached_prompt_desc = result

193
src/mcp_oauth.py Normal file
View File

@@ -0,0 +1,193 @@
"""mcp_oauth.py — generic OAuth for remote (Streamable HTTP) MCP servers.
Bridges the mcp SDK's OAuthClientProvider (RFC 9728 discovery, Dynamic Client
Registration, authorization-code + PKCE, token refresh) to Odysseus's web
callback route. Tokens and the dynamic registration persist per-server,
encrypted, so the interactive flow runs only once.
"""
import asyncio
import json
import logging
import os
import time
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse, parse_qs
logger = logging.getLogger(__name__)
# OAuth redirect URI registered with every authorization server via DCR. Loopback
# is allowed for native/desktop clients (RFC 8252); remote users finish via the
# paste-back flow. Deployments not reachable at http://localhost:7000 (custom
# port, reverse proxy, or public domain) must set OAUTH_REDIRECT_BASE_URL (or
# APP_PUBLIC_URL) to their externally reachable origin so the redirect lands back
# on Odysseus. APP_PORT is intentionally not used: it is only the Docker host
# port-map; the app always listens on 7000 inside the container.
_REDIRECT_BASE = (
os.environ.get("OAUTH_REDIRECT_BASE_URL")
or os.environ.get("APP_PUBLIC_URL")
or "http://localhost:7000"
).rstrip("/")
REDIRECT_URI = f"{_REDIRECT_BASE}/api/mcp/oauth/callback"
# How long the background connect waits for the user to authorize before giving up.
AUTH_WAIT_SECONDS = 300
_pending: Dict[str, asyncio.Future] = {} # state -> Future[(code, state)]
_pending_ts: Dict[str, float] = {} # state -> monotonic timestamp, for pruning
_auth_urls: Dict[str, str] = {} # server_id -> authorization URL
def _prune_stale() -> None:
    """Forget flows older than the authorization window.

    Keeps the module-level registries bounded when a user never finishes the
    browser step: every state registered more than ``AUTH_WAIT_SECONDS`` ago
    is removed, and its future is cancelled if it is still pending.
    """
    checked_at = time.monotonic()
    expired = [
        state
        for state, started in _pending_ts.items()
        if checked_at - started > AUTH_WAIT_SECONDS
    ]
    for state in expired:
        waiter = _pending.pop(state, None)
        _pending_ts.pop(state, None)
        if waiter is None or waiter.done():
            continue
        waiter.cancel()
def _discard_pending(state: Optional[str]) -> None:
    """Remove a state's future and timestamp; a ``None`` state is a no-op."""
    if state is not None:
        for registry in (_pending, _pending_ts):
            registry.pop(state, None)
def register_pending(state: str) -> asyncio.Future:
    """Create and register the future that will receive ``(code, state)``.

    Stale flows are pruned first. Must be called from a running event loop,
    since the future is created on it.
    """
    _prune_stale()
    waiter = asyncio.get_running_loop().create_future()
    _pending[state], _pending_ts[state] = waiter, time.monotonic()
    return waiter
def resolve_pending(state: str, code: str) -> bool:
    """Deliver an authorization code to the flow waiting on ``state``.

    Returns True when a pending, unfinished future was found and resolved
    with ``(code, state)``; False when the state is unknown or already done.
    """
    waiter = _pending.get(state)
    if waiter is None or waiter.done():
        return False
    waiter.set_result((code, state))
    return True
def pop_auth_url(server_id: str) -> Optional[str]:
    """Return the authorization URL recorded for ``server_id``, if any.

    Despite the name, the entry is only read, not removed.
    """
    url = _auth_urls.get(server_id)
    return url
def clear_auth_url(server_id: str) -> None:
    """Forget the authorization URL for ``server_id``; missing entries are fine."""
    if server_id in _auth_urls:
        del _auth_urls[server_id]
class DbTokenStorage:
    """SDK TokenStorage backed by the encrypted McpServer.oauth_tokens column.

    The column holds one JSON object with the keys ``tokens`` and
    ``client_info``. A value that is not valid JSON, or not a JSON object, is
    treated as "nothing stored" so a damaged row triggers a fresh
    authorization instead of failing every connect attempt.
    """

    def __init__(self, server_id: str, session_factory=None):
        """Bind the storage to one server row.

        Args:
            server_id: Primary key of the McpServer row.
            session_factory: Callable returning a DB session; defaults to
                ``core.database.SessionLocal`` (imported lazily).
        """
        self.server_id = server_id
        if session_factory is None:
            from core.database import SessionLocal
            session_factory = SessionLocal
        self._sf = session_factory

    @staticmethod
    def _decode(raw) -> dict:
        """Parse the stored JSON; return {} for empty, corrupt or non-object data."""
        if not raw:
            return {}
        try:
            data = json.loads(raw)
        except (TypeError, ValueError) as e:
            logging.getLogger(__name__).warning(
                "MCP OAuth: stored oauth_tokens is not valid JSON (%s); ignoring it", e
            )
            return {}
        return data if isinstance(data, dict) else {}

    def _load(self) -> dict:
        """Return the decoded oauth_tokens object, or {} if the row is missing."""
        from core.database import McpServer
        db = self._sf()
        try:
            srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
            return self._decode(srv.oauth_tokens) if srv else {}
        finally:
            db.close()

    def _update(self, key: str, value: dict) -> None:
        """Load, set one key, and persist the oauth_tokens JSON in a single
        session/commit (avoids the load+save double round-trip per write).

        Does nothing when the server row no longer exists.
        """
        from core.database import McpServer
        db = self._sf()
        try:
            srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
            if srv is None:
                return
            data = self._decode(srv.oauth_tokens)
            data[key] = value
            srv.oauth_tokens = json.dumps(data)
            db.commit()
        finally:
            db.close()

    async def get_tokens(self):
        """Return the stored OAuthToken, or None when none is stored."""
        from mcp.shared.auth import OAuthToken
        data = self._load().get("tokens")
        return OAuthToken.model_validate(data) if data else None

    async def set_tokens(self, tokens) -> None:
        """Persist the token model under the ``tokens`` key."""
        self._update("tokens", json.loads(tokens.model_dump_json()))

    async def get_client_info(self):
        """Return the stored dynamic client registration, or None."""
        from mcp.shared.auth import OAuthClientInformationFull
        data = self._load().get("client_info")
        return OAuthClientInformationFull.model_validate(data) if data else None

    async def set_client_info(self, client_info) -> None:
        """Persist the client registration under the ``client_info`` key."""
        self._update("client_info", json.loads(client_info.model_dump_json()))
def build_provider(server_id: str, url: str, on_redirect=None):
    """Build the OAuthClientProvider that routes the browser flow through the
    Odysseus callback route.

    Args:
        server_id: Key for token storage and the auth-URL registry.
        url: MCP server URL handed to the provider.
        on_redirect: Optional sync callable taking the authorization URL. It
            is invoked as soon as the URL is known; an exception it raises is
            logged and otherwise ignored.

    Returns:
        The configured ``OAuthClientProvider``.
    """
    from mcp.client.auth import OAuthClientProvider
    from mcp.shared.auth import OAuthClientMetadata

    def _state_of(authorization_url: Optional[str]) -> Optional[str]:
        # The `state` query parameter is what ties a callback to its flow.
        if not authorization_url:
            return None
        values = parse_qs(urlparse(authorization_url).query).get("state")
        return values[0] if values else None

    async def redirect_handler(authorization_url: str) -> None:
        state = _state_of(authorization_url)
        if state:
            register_pending(state)
        _auth_urls[server_id] = authorization_url
        if on_redirect is not None:
            try:
                on_redirect(authorization_url)
            except Exception as e:
                logger.warning(f"MCP OAuth on_redirect callback failed: {e}")
        logger.info(f"MCP OAuth: server {server_id} awaiting authorization (state={state})")

    async def callback_handler() -> Tuple[str, Optional[str]]:
        state = _state_of(_auth_urls.get(server_id))
        waiter = _pending.get(state)
        if waiter is None:
            raise RuntimeError("No pending OAuth flow for this server")
        try:
            code, returned_state = await asyncio.wait_for(waiter, timeout=AUTH_WAIT_SECONDS)
            return code, returned_state
        finally:
            # Whether authorized, timed out or cancelled, the flow is over.
            _discard_pending(state)
            _auth_urls.pop(server_id, None)

    metadata = OAuthClientMetadata(
        client_name="Odysseus",
        redirect_uris=[REDIRECT_URI],
        grant_types=["authorization_code", "refresh_token"],
        response_types=["code"],
        # Leave scope unset: the SDK applies the MCP scope-selection strategy and
        # overwrites this from the server's WWW-Authenticate / protected-resource
        # metadata before building the auth URL. Hardcoding an OIDC scope here
        # would break the many MCP servers that are not OpenID providers.
        scope=None,
        token_endpoint_auth_method="none",
    )
    return OAuthClientProvider(
        server_url=url,
        client_metadata=metadata,
        storage=DbTokenStorage(server_id),
        redirect_handler=redirect_handler,
        callback_handler=callback_handler,
    )

320
src/memory_provider.py Normal file
View File

@@ -0,0 +1,320 @@
"""Memory provider interfaces for native and external memory systems."""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Optional
@dataclass
class MemoryRecord:
    """A single memory entry in a shape shared by all providers."""

    # Required identity and content.
    id: str
    text: str
    # Optional descriptive fields with their defaults.
    timestamp: int = 0
    category: str = "fact"
    source: str = "unknown"
    owner: Optional[str] = None
    session_id: Optional[str] = None
    # Free-form extras; each record gets its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class MemorySearchHit:
    """One recall result: the memory, the provider that returned it, and an
    optional relevance score."""

    memory: MemoryRecord
    provider_id: str
    # None when the provider has no score for this hit.
    score: Optional[float] = None
class MemoryProvider(ABC):
    """Contract every Odysseus memory provider implements.

    The native memory provider should always be available. External providers
    can add recall/write behavior and their own tools without replacing the
    built-in local memory baseline.
    """

    provider_id = "unknown"
    display_name = "Unknown"
    enabled = True

    async def initialize(self) -> None:
        """Set up provider resources; the default does nothing."""

    async def shutdown(self) -> None:
        """Tear down provider resources; the default does nothing."""

    @abstractmethod
    async def remember(
        self,
        text: str,
        *,
        owner: Optional[str] = None,
        session_id: Optional[str] = None,
        category: str = "fact",
        source: str = "user",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> MemoryRecord:
        """Persist ``text`` as a memory and return the stored record."""

    @abstractmethod
    async def recall(
        self,
        query: str,
        *,
        owner: Optional[str] = None,
        top_k: int = 5,
    ) -> List[MemorySearchHit]:
        """Return the provider's memories relevant to ``query``."""

    @abstractmethod
    async def list_memories(
        self,
        *,
        owner: Optional[str] = None,
        limit: int = 100,
    ) -> List[MemoryRecord]:
        """Return memories visible to ``owner``."""

    @abstractmethod
    async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
        """Remove the memory with ``memory_id`` if the provider allows it."""

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Tool schemas this provider contributes; none by default."""
        return []

    async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Run a provider-defined tool; the default raises KeyError for any name."""
        raise KeyError(f"Provider {self.provider_id} does not expose tool {name}")
class NativeMemoryProvider(MemoryProvider):
    """Adapter exposing the built-in memory manager (and optional vector
    store) through the MemoryProvider contract."""

    provider_id = "native"
    display_name = "Odysseus native memory"

    # Entry keys that map onto MemoryRecord fields (or are deliberately
    # dropped); every other key is carried over into ``metadata``.
    _CORE_FIELDS = {
        "id", "text", "timestamp", "source", "category",
        "uses", "owner", "session_id", "metadata",
    }

    def __init__(self, memory_manager, memory_vector=None):
        self.memory_manager = memory_manager
        self.memory_vector = memory_vector

    def _vector_available(self) -> bool:
        """True when a vector store is set and does not report ``healthy`` as false."""
        store = self.memory_vector
        if not store:
            return False
        return bool(getattr(store, "healthy", True))

    def _to_record(self, entry: Dict[str, Any]) -> MemoryRecord:
        """Convert a stored entry dict into a MemoryRecord.

        Non-core keys become metadata; a nested ``metadata`` dict is merged on
        top of them and wins on conflicts.
        """
        extras = {k: v for k, v in entry.items() if k not in self._CORE_FIELDS}
        nested = entry.get("metadata")
        if isinstance(nested, dict):
            extras.update(nested)
        return MemoryRecord(
            id=entry.get("id", ""),
            text=entry.get("text", ""),
            timestamp=entry.get("timestamp", 0),
            category=entry.get("category", "fact"),
            source=entry.get("source", "unknown"),
            owner=entry.get("owner"),
            session_id=entry.get("session_id"),
            metadata=extras,
        )

    def _hit(self, entry: Dict[str, Any], score=None) -> MemorySearchHit:
        """Wrap an entry as a search hit attributed to this provider."""
        return MemorySearchHit(
            memory=self._to_record(entry),
            provider_id=self.provider_id,
            score=score,
        )

    async def remember(
        self,
        text: str,
        *,
        owner: Optional[str] = None,
        session_id: Optional[str] = None,
        category: str = "fact",
        source: str = "user",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> MemoryRecord:
        """Create an entry, append it to the saved memories, index it in the
        vector store when available, and return it as a record."""
        entry = self.memory_manager.add_entry(
            text, source=source, category=category, owner=owner,
        )
        if session_id:
            entry["session_id"] = session_id
        if metadata:
            entry["metadata"] = dict(metadata)  # copy; don't alias the caller's dict

        everything = self.memory_manager.load_all()
        everything.append(entry)
        self.memory_manager.save(everything)

        if self._vector_available():
            self.memory_vector.add(entry["id"], entry["text"])
        return self._to_record(entry)

    async def recall(
        self,
        query: str,
        *,
        owner: Optional[str] = None,
        top_k: int = 5,
    ) -> List[MemorySearchHit]:
        """Return relevant memories: vector hits when there are any, otherwise
        the memory manager's own relevance ranking (scores are None there)."""
        candidates = self.memory_manager.load(owner=owner)
        if self._vector_available():
            known = {m.get("id"): m for m in candidates}
            vector_hits: List[MemorySearchHit] = []
            for found in self.memory_vector.search(query, k=top_k):
                if not isinstance(found, dict):
                    continue
                ref = found.get("memory_id")
                # Results without a memory_id are used as the entry itself.
                entry = known.get(ref) if ref else found
                if not entry:
                    continue
                if owner is not None and entry.get("owner") != owner:
                    continue
                vector_hits.append(self._hit(entry, found.get("score")))
            if vector_hits:
                return vector_hits

        ranked = self.memory_manager.get_relevant_memories(
            query, candidates, max_items=top_k,
        )
        return [self._hit(entry) for entry in ranked]

    async def list_memories(
        self,
        *,
        owner: Optional[str] = None,
        limit: int = 100,
    ) -> List[MemoryRecord]:
        """Return up to ``limit`` of the owner's memories as records."""
        visible = self.memory_manager.load(owner=owner)
        return [self._to_record(entry) for entry in visible[:limit]]

    async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
        """Delete entries whose id is ``memory_id`` (and whose owner matches,
        when ``owner`` is given). Returns False, saving nothing, if none matched."""
        kept = []
        removed_id = None
        for entry in self.memory_manager.load_all():
            is_target = entry.get("id") == memory_id
            owner_ok = owner is None or entry.get("owner") == owner
            if is_target and owner_ok:
                removed_id = entry.get("id")
            else:
                kept.append(entry)
        if removed_id is None:
            return False
        self.memory_manager.save(kept)
        if self._vector_available():
            self.memory_vector.remove(removed_id)
        return True
class MemoryProviderRegistry:
    """Holds the native provider plus any external memory providers, keyed by
    ``provider_id``, and routes provider-defined tool calls."""

    def __init__(self, providers: Optional[Iterable[MemoryProvider]] = None):
        self._providers: Dict[str, MemoryProvider] = {}
        if providers:
            for provider in providers:
                self.register(provider)

    def register(self, provider: MemoryProvider) -> None:
        """Add a provider; raises ValueError if its id is already taken."""
        key = provider.provider_id
        if key in self._providers:
            raise ValueError(f"Memory provider already registered: {provider.provider_id}")
        self._providers[key] = provider

    def get(self, provider_id: str) -> MemoryProvider:
        """Return the provider with this id (KeyError if unknown)."""
        return self._providers[provider_id]

    def all(self) -> List[MemoryProvider]:
        """Every registered provider, in registration order."""
        return [*self._providers.values()]

    def active(self) -> List[MemoryProvider]:
        """Registered providers whose ``enabled`` flag is truthy."""
        return [p for p in self.all() if p.enabled]

    def _owned_tools(self):
        """Yield ``(tool_name, provider, schema)`` for every tool of every
        active provider; raises ValueError when two providers share a name."""
        owner_ids: Dict[str, str] = {}
        for provider in self.active():
            for schema in provider.get_tool_schemas():
                tool_name = self._tool_name(schema)
                if tool_name in owner_ids:
                    raise ValueError(
                        f"Memory tool name conflict: {tool_name} from "
                        f"{provider.provider_id} already exposed by "
                        f"{owner_ids[tool_name]}"
                    )
                owner_ids[tool_name] = provider.provider_id
                yield tool_name, provider, schema

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """All tool schemas exposed by active providers."""
        return [schema for _, _, schema in self._owned_tools()]

    async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
        """Dispatch a tool call to the active provider exposing ``name``;
        raises KeyError when no active provider does."""
        routes = {tool: provider for tool, provider, _ in self._owned_tools()}
        provider = routes.get(name)
        if provider:
            return await provider.handle_tool_call(name, arguments)
        raise KeyError(f"No active memory provider exposes tool {name}")

    @staticmethod
    def _tool_name(schema: Dict[str, Any]) -> str:
        """Read the tool name from ``schema["name"]`` or, failing that,
        ``schema["function"]["name"]``; raises ValueError if neither is usable."""
        if not isinstance(schema, dict):
            raise ValueError("Memory provider tool schema must be a dict")
        nested = schema.get("function")
        candidates = (
            schema.get("name"),
            nested.get("name") if isinstance(nested, dict) else None,
        )
        for candidate in candidates:
            if isinstance(candidate, str) and candidate:
                return candidate
        raise ValueError("Memory provider tool schema is missing a tool name")

View File

@@ -7,7 +7,7 @@ Provides token estimation for context usage tracking.
import logging
import sys
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse
@@ -208,27 +208,32 @@ KNOWN_CONTEXT_WINDOWS = {
# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------
_context_cache: Dict[str, int] = {}
_context_cache: Dict[Tuple[str, str], int] = {}
def get_context_length(endpoint_url: str, model: str) -> int:
    """Get the context window size for a model.

    Queries /v1/models on the endpoint and looks for context_length
    or context_window fields. Caches result per (endpoint, model).
    Falls back to DEFAULT_CONTEXT if unavailable.

    Args:
        endpoint_url: Endpoint the model is served from; part of the cache key.
        model: Model id as known to that endpoint.

    Returns:
        The context window size reported by ``_query_context_length``.
    """
    configured_kind = _configured_endpoint_kind(endpoint_url)
    is_local = _is_local_endpoint(endpoint_url)
    # Key on (endpoint_url, model): the same model id can be served by two
    # different remote endpoints with different real context windows (e.g. a
    # capped proxy vs. the full provider), so caching by model id alone would
    # serve one endpoint's window for the other (issue #2603).
    cache_key = (endpoint_url, model)
    if not is_local and cache_key in _context_cache:
        return _context_cache[cache_key]
    ctx = _query_context_length(endpoint_url, model)
    # Only cache non-default values to allow retry on next request.
    # Local endpoints can restart with a different --max-model-len while keeping
    # the same model id, so always re-query them instead of serving stale cache.
    if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
        _context_cache[cache_key] = ctx
    logger.info(f"Context length for {model}: {ctx}")
    return ctx
@@ -282,6 +287,16 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
except Exception:
pass
# GitHub Copilot's /models requires auth + X-GitHub-Api-Version headers that
# aren't available here; an unauthenticated probe just 400s. All Copilot
# picker models are major API models covered by the known-context table, so
# rely on that instead of a doomed network call.
from src.copilot import is_copilot_base
if is_copilot_base(endpoint_url):
if known:
logger.info(f"Using known context window for {model}: {known}")
return known or DEFAULT_CONTEXT
models_url = endpoint_url.replace("/chat/completions", "/models")
try:
r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)

View File

@@ -722,6 +722,18 @@ class ResearchHandler:
minimum=1,
maximum=12,
)
_planning_timeout = _bounded_int(
get_setting("research_planning_timeout_seconds", _extraction_timeout),
default=_extraction_timeout,
minimum=15,
maximum=3600,
)
_query_timeout = _bounded_int(
get_setting("research_query_timeout_seconds", _extraction_timeout),
default=_extraction_timeout,
minimum=15,
maximum=3600,
)
researcher = DeepResearcher(
llm_endpoint=llm_endpoint,
@@ -732,6 +744,8 @@ class ResearchHandler:
max_time=max_time,
max_report_tokens=_max_report_tokens,
extraction_timeout=_extraction_timeout,
planning_timeout=_planning_timeout,
query_timeout=_query_timeout,
extraction_concurrency=_extraction_concurrency,
progress_callback=progress_callback,
search_provider=search_provider,

View File

@@ -1,141 +1,12 @@
"""Search analytics, metrics tracking, and exception hierarchy."""
"""Compatibility re-export shim for the live analytics module.
import json
import logging
from collections import Counter
from pathlib import Path
from typing import Dict, Any
The real implementation lives in :mod:`services.search.analytics`, which is
what the search runtime imports. Alias this module to that implementation so
mutable module state such as ``ANALYTICS_FILE`` cannot drift out of sync.
"""
from .cache import cache_metrics
import sys
logger = logging.getLogger(__name__)
from services.search import analytics as _analytics
# Dedicated error logger with file handler
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
_error_handler.setLevel(logging.WARNING)
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
error_logger = logging.getLogger("search_engine_error")
error_logger.addHandler(_error_handler)
error_logger.propagate = False
# Analytics file
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
# ----------------------------------------------------------------------
# Custom exception hierarchy
# ----------------------------------------------------------------------
class SearchEngineError(Exception):
    """Root of the search-engine exception hierarchy."""


class NetworkError(SearchEngineError):
    """A network request failed (e.g. timeout or DNS error)."""


class ParseError(SearchEngineError):
    """HTML or other fetched content could not be parsed."""


class RateLimitError(SearchEngineError):
    """The remote service answered with a rate limit (HTTP 429)."""
# ----------------------------------------------------------------------
# Analytics helpers
# ----------------------------------------------------------------------
def _default_analytics() -> Dict[str, Any]:
    """Build a new analytics document with every counter at zero and an
    empty per-query pattern table."""
    counters = (
        "total_queries",
        "successful_queries",
        "failed_queries",
        "cache_hits",
        "cache_misses",
    )
    document: Dict[str, Any] = dict.fromkeys(counters, 0)
    document["query_patterns"] = {}
    return document
def _load_analytics() -> Dict[str, Any]:
    """Read the analytics document from ``ANALYTICS_FILE``.

    A missing file is created with defaults. An unreadable file yields
    defaults (with a warning) and is left untouched.
    """
    if not ANALYTICS_FILE.exists():
        fresh = _default_analytics()
        _save_analytics(fresh)
        return fresh
    try:
        with open(ANALYTICS_FILE, "r", encoding="utf-8") as handle:
            stored = json.load(handle)
        # Layer the stored values over the defaults: files from an older
        # schema (or a partial write) may lack counters that _record_query
        # indexes directly.
        document = _default_analytics()
        if isinstance(stored, dict):
            document.update(stored)
        return document
    except Exception as e:
        logger.warning(f"Failed to load analytics file: {e}")
        return _default_analytics()
def _save_analytics(data: Dict[str, Any]) -> None:
    """Write the analytics document to ``ANALYTICS_FILE`` as indented JSON.

    Best effort: a failure is logged as a warning and not raised.
    """
    try:
        with open(ANALYTICS_FILE, "w", encoding="utf-8") as handle:
            json.dump(data, handle, indent=2)
    except Exception as e:
        logger.warning(f"Failed to write analytics file: {e}")
def _record_query(query: str, success: bool, cache_hit: bool) -> None:
    """Fold one query execution into the persisted analytics and the
    in-process cache metrics."""
    stats = _load_analytics()

    stats["total_queries"] += 1
    outcome_key = "successful_queries" if success else "failed_queries"
    stats[outcome_key] += 1

    if cache_hit:
        persisted_key, runtime_key = "cache_hits", "hits"
    else:
        persisted_key, runtime_key = "cache_misses", "misses"
    stats[persisted_key] += 1
    cache_metrics[runtime_key] += 1

    per_query = stats["query_patterns"]
    record = per_query.get(query, {"count": 0, "successes": 0})
    record["count"] += 1
    if success:
        record["successes"] += 1
    per_query[query] = record

    _save_analytics(stats)
def get_search_stats() -> Dict[str, Any]:
    """Summarize the persisted analytics plus the runtime cache metrics.

    Rates use a denominator of 1 when nothing has been recorded, so they are
    0 rather than a division error.
    """
    stats = _load_analytics()

    def count(key: str):
        return stats.get(key, 0)

    success_rate = count("successful_queries") / (count("total_queries") or 1)
    cache_lookups = (count("cache_hits") + count("cache_misses")) or 1
    cache_hit_rate = count("cache_hits") / cache_lookups

    frequency = Counter({
        query: info["count"]
        for query, info in stats.get("query_patterns", {}).items()
    })
    top_queries = [query for query, _ in frequency.most_common(5)]

    return {
        "most_common_queries": top_queries,
        "success_rate": success_rate,
        "cache_hit_rate": cache_hit_rate,
        "total_queries": count("total_queries"),
        "successful_queries": count("successful_queries"),
        "failed_queries": count("failed_queries"),
        "cache_hits": count("cache_hits"),
        "cache_misses": count("cache_misses"),
        "cache_evictions": cache_metrics["evictions"],
        "runtime_cache_hits": cache_metrics["hits"],
        "runtime_cache_misses": cache_metrics["misses"],
    }
sys.modules[__name__] = _analytics

View File

@@ -1,57 +1,11 @@
"""Search and content caching with LRU eviction."""
"""Compatibility wrapper for the canonical services.search.cache module.
import hashlib
import logging
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict
``src.search.cache`` stays importable for older agent/deep-research code, but the
implementation now lives in ``services.search.cache`` so the two cannot drift.
"""
logger = logging.getLogger(__name__)
import sys
# Cache directories
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
SEARCH_CACHE_DIR = CACHE_DIR / "search"
CONTENT_CACHE_DIR = CACHE_DIR / "content"
CACHE_MAX_ENTRIES = 1000
from services.search import cache as _cache
# Create cache directories
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# Track cache size for LRU eviction
search_cache_index: Dict[str, datetime] = {}
content_cache_index: Dict[str, datetime] = {}
# Cache metrics (shared across modules)
cache_metrics = {"hits": 0, "misses": 0, "evictions": 0}
def generate_cache_key(data: str) -> str:
    """Return the hex SHA-256 digest of ``data`` (UTF-8 encoded) for use as a
    cache key."""
    digest = hashlib.sha256()
    digest.update(data.encode("utf-8"))
    return digest.hexdigest()
def cleanup_cache(cache_dir: Path, cache_index: Dict[str, datetime], max_age: timedelta):
    """Drop expired or orphaned cache entries, then trim to the size cap.

    Phase 1 removes index entries older than ``max_age`` or without a
    ``*.cache`` file on disk (deleting the file when present). Phase 2, if the
    index still exceeds ``CACHE_MAX_ENTRIES``, evicts the oldest entries.
    Every removal increments ``cache_metrics["evictions"]``.
    """
    now = datetime.now()
    on_disk = {path.name.split(".")[0]: path for path in cache_dir.glob("*.cache")}

    stale_keys = []
    for key, stamped in list(cache_index.items()):
        expired = now - stamped > max_age
        if not expired and key in on_disk:
            continue
        stale_keys.append(key)
        if key in on_disk:
            on_disk[key].unlink(missing_ok=True)
    for key in stale_keys:
        cache_index.pop(key, None)
        cache_metrics["evictions"] += 1

    overflow = len(cache_index) - CACHE_MAX_ENTRIES
    if overflow > 0:
        oldest_first = sorted(cache_index.items(), key=lambda item: item[1])
        for key, _ in oldest_first[:overflow]:
            cache_index.pop(key, None)
            (cache_dir / f"{key}.cache").unlink(missing_ok=True)
            cache_metrics["evictions"] += 1
sys.modules[__name__] = _cache

View File

@@ -1,419 +1,11 @@
"""Webpage content fetching with caching, PDF extraction, and summarization helpers."""
"""Compatibility wrapper for the canonical services.search.content module.
import copy
import io
import ipaddress
import json
import os
import re
import logging
import socket
from datetime import datetime, timedelta
from typing import List
from urllib.parse import urljoin, urlparse
``src.search.content`` stays importable for older agent/deep-research code, but the
implementation now lives in ``services.search.content`` so the two cannot drift.
"""
import httpx
from bs4 import BeautifulSoup
import sys
from .analytics import RateLimitError, error_logger
from .cache import (
CONTENT_CACHE_DIR,
content_cache_index,
generate_cache_key,
cleanup_cache,
)
from services.search import content as _content
logger = logging.getLogger(__name__)
_PRIVATE_NETWORKS = (
ipaddress.ip_network("0.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
)
def _is_private_address(addr: ipaddress._BaseAddress) -> bool:
    """Tell whether ``addr`` must be treated as non-public.

    IPv4-mapped IPv6 addresses are judged by their embedded IPv4 address.
    """
    if isinstance(addr, ipaddress.IPv6Address):
        embedded = addr.ipv4_mapped
        if embedded is not None:
            addr = embedded
    if addr.is_private or addr.is_loopback or addr.is_link_local:
        return True
    if addr.is_reserved or addr.is_multicast or addr.is_unspecified:
        return True
    # Explicit deny-list on top of the stdlib classification.
    return any(addr in network for network in _PRIVATE_NETWORKS)
def _resolve_hostname_ips(hostname: str) -> List[ipaddress._BaseAddress]:
    """Resolve ``hostname`` and return its IPv4/IPv6 addresses, one per
    ``getaddrinfo`` result (duplicates are kept)."""
    inet_families = (socket.AF_INET, socket.AF_INET6)
    return [
        ipaddress.ip_address(sockaddr[0])
        for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None)
        if family in inet_families
    ]
def _public_http_url(url: str) -> bool:
parsed = urlparse(url)
if parsed.scheme not in ("http", "https") or not parsed.hostname:
return False
host = parsed.hostname.strip().lower()
if host in ("localhost", "metadata.google.internal", "metadata"):
return False
if host.endswith((".local", ".localhost", ".internal", ".lan", ".intranet")):
return False
try:
return not _is_private_address(ipaddress.ip_address(host))
except ValueError:
pass
try:
ips = _resolve_hostname_ips(host)
except OSError:
return False
# Fail closed: a hostname that resolves to nothing is treated as
# non-public (an empty all(...) would otherwise return True).
return bool(ips) and all(not _is_private_address(ip) for ip in ips)
def _get_public_url(url: str, *, headers: dict, timeout: int) -> httpx.Response:
    """GET *url*, following redirects by hand so every hop is re-validated.

    Automatic redirect following is disabled; each ``Location`` target is
    resolved against the current URL and checked with ``_public_http_url``
    before it is requested.

    Raises:
        httpx.RequestError: the initial URL or a redirect target is not
            public, or more than 8 requests were needed.
    """
    if not _public_http_url(url):
        raise httpx.RequestError(f"Blocked non-public URL: {url}")

    redirect_statuses = (301, 302, 303, 307, 308)
    max_requests = 8
    target = url
    with httpx.Client(headers=headers, timeout=timeout, follow_redirects=False) as client:
        requests_made = 0
        while requests_made < max_requests:
            requests_made += 1
            response = client.get(target)
            if response.status_code in redirect_statuses:
                next_location = response.headers.get("location")
            else:
                next_location = None
            if not next_location:
                # Final response, or a redirect status without a Location.
                return response
            target = urljoin(target, next_location)
            if not _public_http_url(target):
                raise httpx.RequestError(f"Blocked redirect to non-public URL: {target}")
    raise httpx.RequestError("Too many redirects")
# PDF extraction (optional dependency)
try:
from pdfminer.high_level import extract_text as pdf_extract_text
except ImportError:
pdf_extract_text = None # type: ignore
# ----------------------------------------------------------------------
# HTML extraction helpers
# ----------------------------------------------------------------------
def _extract_meta(soup: BeautifulSoup) -> dict:
    """Return the page's meta description and keywords ("" when absent).

    A ``<meta>`` tag qualifies when its ``name`` attribute contains the field
    name, case-insensitively.
    """
    extracted = {}
    for field in ("description", "keywords"):
        tag = soup.find("meta", attrs={"name": re.compile(field, re.I)})
        if tag and tag.get("content"):
            extracted[field] = tag["content"].strip()
        else:
            extracted[field] = ""
    return extracted
def _extract_og_image(soup: BeautifulSoup) -> str:
    """Pick a representative image URL from the page's meta tags.

    Candidates are gathered in priority order (Open Graph, Twitter card,
    thumbnail). The first absolute http(s) URL that does not end in
    ``.svg``/``.ico`` is returned; "" when none qualifies.
    """
    # (attribute, value) pairs to look up, most reliable first.
    lookups = [
        ("property", "og:image"),
        ("property", "og:image:url"),
        ("property", "og:image:secure_url"),
        ("name", "twitter:image"),
        ("name", "thumbnail"),
    ]
    candidates = []
    for attr_name, attr_value in lookups:
        tag = soup.find("meta", attrs={attr_name: attr_value})
        if tag and tag.get("content", "").strip():
            candidates.append(tag["content"].strip())

    return next(
        (
            candidate
            for candidate in candidates
            if candidate.startswith(("https://", "http://"))
            and not candidate.endswith((".svg", ".ico"))
        ),
        "",
    )
def _extract_lists(soup: BeautifulSoup) -> List[List[str]]:
    """Collect the item texts of every non-empty ``<ul>``/``<ol>`` on the page."""
    collected = []
    for container in soup.find_all(["ul", "ol"]):
        entries = []
        for item in container.find_all("li"):
            entries.append(item.get_text(separator=" ", strip=True))
        if entries:
            collected.append(entries)
    return collected
def _extract_tables(soup: BeautifulSoup) -> List[List[List[str]]]:
    """Return each table as a list of rows, each row a list of cell texts.

    Rows without ``<td>``/``<th>`` cells and tables without rows are dropped.
    """

    def _rows_of(table):
        # Cell texts per <tr>, keeping only rows that have at least one cell.
        per_row = (
            [cell.get_text(separator=" ", strip=True) for cell in tr.find_all(["td", "th"])]
            for tr in table.find_all("tr")
        )
        return [cells for cells in per_row if cells]

    return [rows for rows in map(_rows_of, soup.find_all("table")) if rows]
def _extract_code_blocks(soup: BeautifulSoup) -> List[str]:
    """Return the non-empty text of every ``<pre>`` and ``<code>`` element."""
    texts = (
        element.get_text(separator=" ", strip=True)
        for element in soup.find_all(["pre", "code"])
    )
    return [text for text in texts if text]
def _detect_js_frameworks(soup: BeautifulSoup) -> bool:
    """Heuristically decide whether the page is driven by a JS framework.

    True when any ``<script>`` mentions a known framework name in its ``src``
    or inline text, or when the markup carries a ``data-reactroot`` or
    ``ng-app`` attribute. Plain substring matching, so this is deliberately
    naive.
    """
    markers = [
        "react", "angular", "vue", "svelte", "next", "nuxt",
        "ember", "backbone", "jquery", "polymer", "mithril",
    ]

    def _mentions_framework(text):
        return any(marker in text for marker in markers)

    for script in soup.find_all("script"):
        if _mentions_framework(script.get("src", "").lower()):
            return True
        inline = script.string
        if inline and _mentions_framework(inline.lower()):
            return True

    return bool(
        soup.find(attrs={"data-reactroot": True}) or soup.find(attrs={"ng-app": True})
    )
def _empty_result(url: str, error: str = "") -> dict:
"""Build a standard failure result dict."""
return {
"url": url,
"title": "",
"content": "",
"lists": [],
"tables": [],
"code_blocks": [],
"meta_description": "",
"meta_keywords": "",
"js_rendered": False,
"js_message": "",
"success": False,
"error": error,
}
# ----------------------------------------------------------------------
# Main content fetcher
# ----------------------------------------------------------------------
def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> dict:
    """Fetch a page (HTML or PDF) and extract its content, with a 2-hour file cache.

    Args:
        url: Page to fetch. It is requested through ``_get_public_url``.
        timeout: Request timeout passed to the HTTP client.
        retry_attempt: Attempt number, used only in log/error messages.

    Returns:
        A result dict with ``url``, ``title``, ``content``, ``lists``,
        ``tables``, ``code_blocks``, ``meta_description``, ``meta_keywords``,
        ``og_image``, ``js_rendered``, ``js_message``, ``success`` and
        ``error``. Network errors and rate limits return ``_empty_result``.
        A cache hit returns the stored dict as it was written.
    """
    cache_key = generate_cache_key(url)
    cache_file = CONTENT_CACHE_DIR / f"{cache_key}.cache"

    # Check cache
    if cache_file.exists():
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                cached_data = json.load(f)
            timestamp = datetime.fromisoformat(cached_data["timestamp"])
            if datetime.now() - timestamp < timedelta(hours=2):
                logger.debug(f"Content cache hit for URL: {url}")
                return cached_data["data"]
            else:
                # Expired: drop the file and its index entry, then refetch.
                cache_file.unlink(missing_ok=True)
                content_cache_index.pop(cache_key, None)
        except Exception as e:
            # Unreadable/corrupt cache entry: discard it and fall through.
            logger.warning(f"Failed to read content cache for {url}: {e}")
            cache_file.unlink(missing_ok=True)
            content_cache_index.pop(cache_key, None)

    # Fetch
    try:
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.5",
            "Accept-Encoding": "gzip, deflate",
            "Connection": "keep-alive",
        }
        response = _get_public_url(url, headers=headers, timeout=timeout)
        if response.status_code == 429:
            raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
        # NOTE(review): raise_for_status() raises httpx.HTTPStatusError, which
        # is not an httpx.RequestError, so other 4xx/5xx responses propagate to
        # the caller rather than becoming an _empty_result. Confirm callers
        # rely on that before changing it.
        response.raise_for_status()
    except httpx.RequestError as e:
        error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
        return _empty_result(url, f"NetworkError: {e}")
    except RateLimitError as e:
        error_logger.error(str(e))
        return _empty_result(url, str(e))

    # PDF handling
    content_type = response.headers.get("Content-Type", "").lower()
    if "application/pdf" in content_type or url.lower().endswith(".pdf"):
        if pdf_extract_text is None:
            logger.error("pdfminer.six is not installed; cannot extract PDF text.")
            pdf_text = ""
        else:
            try:
                pdf_bytes = io.BytesIO(response.content)
                pdf_text = pdf_extract_text(pdf_bytes)
            except Exception as e:
                logger.warning(f"PDF extraction failed for {url}: {e}")
                pdf_text = ""

        result = {
            "url": url,
            "title": os.path.basename(url),
            "content": pdf_text,
            "lists": [],
            "tables": [],
            "code_blocks": [],
            "meta_description": "",
            "meta_keywords": "",
            # Same key set as the HTML result below, so consumers can read
            # result["og_image"] for either content type.
            "og_image": "",
            "js_rendered": False,
            "js_message": "",
            "success": bool(pdf_text),
            "error": "" if pdf_text else "Failed to extract PDF text",
        }
        _cache_result(cache_file, cache_key, result, url)
        return result

    # HTML handling
    try:
        soup = BeautifulSoup(response.text, "html.parser")
    except Exception as e:
        error_logger.error(f"ParseError parsing HTML from {url} (attempt {retry_attempt}): {e}")
        result = _empty_result(url, f"ParseError: {e}")
        _cache_result(cache_file, cache_key, result, url)
        return result

    title_tag = soup.find("title")
    title_text = title_tag.get_text(strip=True) if title_tag else ""

    meta_info = _extract_meta(soup)
    og_image = _extract_og_image(soup)

    js_rendered = _detect_js_frameworks(soup)
    js_message = "Page appears to be rendered by a JavaScript framework; content may be incomplete." if js_rendered else ""

    # Main textual content (heuristic): prefer semantic / "content"-classed
    # containers to skip nav/footer/boilerplate; tuned for article pages.
    main_content = ""
    content_areas = soup.find_all(
        ["main", "article", "section", "div"],
        class_=re.compile("content|main|body|article|post|entry|text", re.I),
    )
    if content_areas:
        # Only the first three matching containers contribute text.
        for area in content_areas[:3]:
            main_content += area.get_text(separator=" ", strip=True) + " "
    main_content = re.sub(r"\s+", " ", main_content).strip()

    # The class heuristic can latch onto a small wrapper and miss the real
    # content (app/landing pages, or SSR sites whose body isn't in a
    # "content"-classed div, so these came back nearly empty before). When the
    # heuristic returns nothing OR suspiciously little, fall back to the full
    # <body>, stripping scripts/styles (so JSON/JS doesn't leak into the text)
    # plus nav/header/footer/aside (boilerplate), and keep whichever yields
    # more readable text.
    THIN_CONTENT_CHARS = 600  # below this the heuristic likely missed the page
    if len(main_content) < THIN_CONTENT_CHARS:
        body = soup.find("body")
        if body:
            # Strip from a copy so the later list/table/code extractors still
            # see the original soup unmodified.
            body_copy = copy.copy(body)
            for _noise in body_copy.find_all(
                ["script", "style", "noscript", "template", "nav", "header", "footer", "aside"]
            ):
                _noise.extract()
            body_text = re.sub(r"\s+", " ", body_copy.get_text(separator=" ", strip=True)).strip()
            if len(body_text) > len(main_content):
                main_content = body_text

    result = {
        "url": url,
        "title": title_text,
        "content": main_content,
        "lists": _extract_lists(soup),
        "tables": _extract_tables(soup),
        "code_blocks": _extract_code_blocks(soup),
        "meta_description": meta_info.get("description", ""),
        "meta_keywords": meta_info.get("keywords", ""),
        "og_image": og_image,
        "js_rendered": js_rendered,
        "js_message": js_message,
        "success": True,
        "error": "",
    }

    _cache_result(cache_file, cache_key, result, url)
    return result
def _cache_result(cache_file, cache_key: str, result: dict, url: str):
    """Persist *result* to *cache_file* and record it in the cache index.

    Best-effort: any failure is logged as a warning and swallowed, so a cache
    problem never fails the fetch that produced the result.
    """
    try:
        payload = {
            "timestamp": datetime.now().isoformat(),
            "data": result,
        }
        with open(cache_file, "w", encoding="utf-8") as handle:
            json.dump(payload, handle)
        content_cache_index[cache_key] = datetime.now()
        # Prune the cache directory using the same 2-hour window.
        max_age = timedelta(hours=2)
        cleanup_cache(CONTENT_CACHE_DIR, content_cache_index, max_age)
    except Exception as e:
        logger.warning(f"Failed to write content cache for {url}: {e}")
# ----------------------------------------------------------------------
# Content summarization helpers
# ----------------------------------------------------------------------
def extract_key_points(text: str) -> List[str]:
    """Return the text of every bullet (``-``, ``*``, ``•``) or numbered line.

    A line qualifies when its marker is followed by whitespace; the marker and
    surrounding whitespace are stripped from the returned item.
    """
    item_patterns = (
        re.compile(r"^\s*[-*•]\s+(.*)"),
        re.compile(r"^\s*\d+[\.\)]\s+(.*)"),
    )
    found: List[str] = []
    for raw_line in text.splitlines():
        hit = next(
            (m for m in (pattern.match(raw_line) for pattern in item_patterns) if m),
            None,
        )
        if hit is not None:
            found.append(hit.group(1).strip())
    return found
def get_tldr(text: str, max_sentences: int = 3) -> str:
    """Return the first *max_sentences* sentences of *text* joined by spaces.

    Sentences are split on whitespace that follows ``.``, ``!`` or ``?``.
    """
    pieces = []
    for chunk in re.split(r"(?<=[.!?])\s+", text):
        if chunk:
            pieces.append(chunk.strip())
    return " ".join(pieces[:max_sentences])
def extract_quotes(text: str) -> List[str]:
    """Return quoted excerpts of at least 15 characters found in *text*."""
    # The backreference forces the closing quote to be the same character as
    # the opening one, so `"text'` (double open, single close) is not a quote.
    quote_pattern = re.compile(r'(["\'])([^"\']{15,}?)\1')
    excerpts = []
    for match in quote_pattern.finditer(text):
        excerpts.append(match.group(2).strip())
    return excerpts
def extract_statistics(text: str) -> List[str]:
    """Return numeric mentions in *text*, each with an optional trailing unit.

    A mention is a comma-grouped number (``1,000,000``) or a plain digit run
    (``50000``), optionally with a decimal part, optionally followed by a
    percent/per-mille marker or a single alphabetic word.
    """
    number_with_unit = re.compile(
        r"\b(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?",
        re.IGNORECASE,
    )
    mentions = []
    for match in number_with_unit.finditer(text):
        # strip() drops whitespace consumed by \s* when no unit followed.
        mentions.append(match.group(0).strip())
    return mentions
sys.modules[__name__] = _content

View File

@@ -1,141 +1,11 @@
"""Query enhancement, entity extraction, and cache duration helpers."""
"""Compatibility wrapper for the canonical services.search.query module.
import re
import logging
from datetime import timedelta
from typing import Dict, List, Optional, Tuple
``src.search.query`` stays importable for older agent/deep-research code, but the
implementation now lives in ``services.search.query`` so the two cannot drift.
"""
logger = logging.getLogger(__name__)
import sys
from services.search import query as _query
# ----------------------------------------------------------------------
# Query processing helpers
# ----------------------------------------------------------------------
def _detect_question_type(query: str) -> Optional[str]:
"""Return the leading question word if present (who, what, when, where, why, how)."""
if not isinstance(query, str):
return None
q = query.strip().lower()
for word in ("who", "what", "when", "where", "why", "how"):
# Require a whole-word match: a bare prefix mis-flags ordinary queries
# like "whatsapp pricing" (-> what) or "however ..." (-> how), which
# then get spurious boost terms OR-appended in enhance_query.
if q == word or q.startswith(word + " "):
return word
return None
def _extract_entities(query: str) -> Dict[str, List[str]]:
    """Extract capitalized words ("names") and date mentions ("dates").

    A leading question word is removed first so it is not reported as a name.
    Dates are four-digit years starting with 19/20, followed by any
    "Month day, year" matches.
    """
    remainder = query
    leading_word = _detect_question_type(query)
    if leading_word:
        remainder = re.sub(rf"^{leading_word}\b", "", remainder, flags=re.I).strip()

    names = re.findall(r"\b[A-Z][a-zA-Z]+\b", remainder)
    dates = re.findall(r"\b(?:19|20)\d{2}\b", remainder)
    dates += re.findall(
        r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b",
        remainder,
        flags=re.I,
    )
    return {"names": names, "dates": dates}
def _split_multi_part(query: str) -> List[str]:
"""Split a query into sub-queries on common conjunctions."""
if not isinstance(query, str):
return []
parts = re.split(r"\s+and\s+|\s+or\s+|;", query, flags=re.I)
return [p.strip() for p in parts if p.strip()]
def _extract_site_filter(query: str) -> Tuple[str, Optional[str]]:
"""Detect a 'site:example.com' token. Returns (query_without_token, site_or_None)."""
if not isinstance(query, str):
return "", None
match = re.search(r"\bsite:([^\s]+)", query, flags=re.I)
if match:
site = match.group(1)
new_query = re.sub(r"\bsite:[^\s]+", "", query, flags=re.I).strip()
return new_query, site
return query, None
def _boost_entities_in_query(base_query: str, entities: Dict[str, List[str]]) -> str:
"""Append extracted entities to the query using OR to increase relevance."""
parts = [base_query]
if entities.get("names"):
parts.append(" OR ".join(f'"{n}"' for n in entities["names"]))
if entities.get("dates"):
parts.append(" OR ".join(f'"{d}"' for d in entities["dates"]))
return " ".join(parts)
def enhance_query(original_query: str) -> Tuple[str, Optional[str]]:
    """Rewrite a raw query into a boosted boolean query.

    Steps: extract any ``site:`` filter, split the rest into sub-queries,
    append each sub-query's extracted entities, add a boost keyword for
    who/when/where/why/how questions, AND the parenthesized sub-queries
    together, and re-append the site filter.

    Returns ``(final_query, site)``; ``site`` is None when no filter was given.
    """
    if not isinstance(original_query, str):
        original_query = ""

    remaining, site = _extract_site_filter(original_query)

    # Boost keyword per question word ("what" has none).
    boost_by_type = {
        "who": "person",
        "when": "date",
        "where": "location",
        "why": "reason",
        "how": "method",
    }

    clauses: List[str] = []
    for part in _split_multi_part(remaining):
        boost = boost_by_type.get(_detect_question_type(part))
        clause = _boost_entities_in_query(part, _extract_entities(part))
        if boost:
            clause = f'({clause}) OR ({boost})'
        clauses.append(clause)

    final_query = " AND ".join(f"({clause})" for clause in clauses)
    if site:
        final_query = f"{final_query} site:{site}"
    return final_query, site
def build_enhanced_query(query: str, time_filter: str = None) -> str:
    """Enhance *query* and optionally append an ``after:`` time filter.

    ``time_filter`` may be "day", "week", "month" or "year"; any other value
    is ignored.
    """
    enhanced, _site = enhance_query(query)
    suffix = None
    if time_filter:
        suffix = {"day": "d", "week": "w", "month": "m", "year": "y"}.get(time_filter)
    if suffix:
        enhanced = f"{enhanced} after:{suffix}"
        logger.info(f"Added time filter '{time_filter}' to query")
    logger.info(f"Enhanced query: '{query}' -> '{enhanced}'")
    return enhanced
# ----------------------------------------------------------------------
# Cache duration helpers
# ----------------------------------------------------------------------
def _is_news_query(query: str) -> bool:
"""Lightweight heuristic to decide if a query is news-oriented."""
if not isinstance(query, str):
return False
news_terms = {"news", "latest", "breaking", "today", "today's", "current", "updates", "happening"}
tokens = set(re.findall(r"\b\w+\b", query.lower()))
return bool(tokens & news_terms)
def _cache_duration_for_query(query: str) -> timedelta:
    """Return the cache lifetime for *query*: 30 minutes for news, else 24 hours."""
    return timedelta(minutes=30) if _is_news_query(query) else timedelta(hours=24)
sys.modules[__name__] = _query

View File

@@ -85,6 +85,11 @@ DEFAULT_SETTINGS = {
"research_search_provider": "",
"research_max_tokens": 16384,
"research_extraction_timeout_seconds": 90,
# Lightweight planning/query LLM calls happen before any search starts.
# Keep them separately tunable so slow local backends are not capped by
# the old 30s/60s per-call defaults.
"research_planning_timeout_seconds": 90,
"research_query_timeout_seconds": 90,
"research_extraction_concurrency": 3,
# Hard wall-clock cap on a single deep-research run. The previous 600s
# (10 min) default cut off slow local / edge LLMs mid-synthesis; 1800s
@@ -95,6 +100,7 @@ DEFAULT_SETTINGS = {
# Tune via Settings or by editing data/settings.json.
"research_run_timeout_seconds": 1800,
"agent_max_tool_calls": 0,
"agent_max_rounds": 20, # per-message agent step cap (clamped 1..200)
"agent_input_token_budget": 6000,
# Ceiling on the *auto-derived* input budget that #1230 introduced. Has
# no effect when `agent_input_token_budget` is explicitly set (the user's

View File

@@ -15,18 +15,33 @@ from __future__ import annotations
import re
_THINK_TAG_NAME = r"(?:think(?:ing)?|thought)"
# Closed reasoning blocks. Multi-pass loop in `strip_think` handles nested
# `<think><think>...</think></think>` patterns some models emit.
_THINK_CLOSED_RE = re.compile(r"<think(?:ing)?>[\s\S]*?</think(?:ing)?>\s*", re.IGNORECASE)
_THINK_CLOSED_RE = re.compile(rf"<{_THINK_TAG_NAME}(?:\s+[^>]*)?>[\s\S]*?</{_THINK_TAG_NAME}>\s*", re.IGNORECASE)
# Orphan opening or closing tags that survive after the closed-pass.
_THINK_TAG_RE = re.compile(r"</?think(?:ing)?[^>]*>\s*", re.IGNORECASE)
_THINK_TAG_RE = re.compile(rf"</?{_THINK_TAG_NAME}[^>]*>\s*", re.IGNORECASE)
# Dangling opener anywhere in the response with no closer — strip everything
# from `<think>` to the end of string.
_THINK_OPEN_RE = re.compile(r"<think(?:ing)?>[\s\S]*$", re.IGNORECASE)
_THINK_OPEN_RE = re.compile(rf"<{_THINK_TAG_NAME}(?:\s+[^>]*)?>[\s\S]*$", re.IGNORECASE)
# Streaming models occasionally emit `<thinking time="0.42">`-style attributes.
# Normalize to a plain `<think>` so the regexes above catch them.
_THINK_ATTR_RE = re.compile(r"<think(?:ing)?\s+[^>]*>", re.IGNORECASE)
_THINK_ATTR_CLOSE_RE = re.compile(r"</think(?:ing)?\s+[^>]*>", re.IGNORECASE)
_THINK_ATTR_RE = re.compile(rf"<{_THINK_TAG_NAME}\s+[^>]*>", re.IGNORECASE)
_THINK_ATTR_CLOSE_RE = re.compile(rf"</{_THINK_TAG_NAME}\s+[^>]*>", re.IGNORECASE)
_GEMMA_THOUGHT_OPEN_RE = re.compile(r"<\|channel>thought\s*\n?[\s\S]*$", re.IGNORECASE)
_GEMMA_RESPONSE_CHANNEL_RE = re.compile(
r"<\|channel>response\s*\n?([\s\S]*?)<channel\|>",
re.IGNORECASE,
)
_GEMMA_RESPONSE_OPEN_RE = re.compile(r"<\|channel>response\s*\n?", re.IGNORECASE)
_GEMMA_CHANNEL_CLOSE_RE = re.compile(r"<channel\|>", re.IGNORECASE)
_THOUGHT_TAG_OPEN_RE = re.compile(r"<thought(\s+[^>]*)?>", re.IGNORECASE)
_THOUGHT_TAG_CLOSE_RE = re.compile(r"</thought>", re.IGNORECASE)
_GEMMA_THOUGHT_CHANNEL_CAPTURE_RE = re.compile(
r"<\|channel>thought\s*\n?([\s\S]*?)<channel\|>\s*",
re.IGNORECASE,
)
# Qwen and a few other models prefix the response with a "Thinking Process:"
# block before the real answer.
_QWEN_THINKING_RE = re.compile(
@@ -78,6 +93,30 @@ def _strip_reasoning_prose(text: str) -> str:
return "\n\n".join(keep).strip() if keep else text
def normalize_thinking_markup(text: str) -> str:
    """Rewrite supported reasoning wrappers into plain `<think>` markup.

    The chat UI and persistence layer already understand `<think>...</think>`.
    This converts `<thought>...</thought>` tags and Gemma-style
    `<|channel>thought\\n...<channel|>` blocks into that form, drops thought
    channels that are empty, unwraps response channels, and removes any
    leftover channel markers. Falsy input is returned unchanged.
    """
    if not text:
        return text

    def _open_think(match):
        # Keep any attributes that were on the <thought ...> opener.
        attributes = match.group(1) or ""
        return "<think" + attributes + ">"

    def _wrap_gemma_thought(match):
        body = match.group(1).strip()
        if not body:
            return ""
        return f"<think>{body}</think>\n"

    normalized = _THOUGHT_TAG_OPEN_RE.sub(_open_think, text)
    normalized = _THOUGHT_TAG_CLOSE_RE.sub("</think>", normalized)
    normalized = _GEMMA_THOUGHT_CHANNEL_CAPTURE_RE.sub(_wrap_gemma_thought, normalized)
    normalized = _GEMMA_RESPONSE_CHANNEL_RE.sub(lambda m: m.group(1), normalized)
    # Remove unpaired response openers and channel closers that remain.
    for leftover_pattern in (_GEMMA_RESPONSE_OPEN_RE, _GEMMA_CHANNEL_CLOSE_RE):
        normalized = leftover_pattern.sub("", normalized)
    return normalized
def strip_think(text: str, *, prose: bool = False, prompt_echo: bool = True) -> str:
"""Strip `<think>` blocks from model output.
@@ -92,13 +131,21 @@ def strip_think(text: str, *, prose: bool = False, prompt_echo: bool = True) ->
"The user asks:" / "We need to" leaked prompt echoes.
Robust to:
* closed `<think>...</think>` (any depth, both `<think>` and `<thinking>`)
* dangling unclosed `<think>...`
* closed `<think>...</think>` (any depth, plus `<thinking>`/`<thought>`)
* dangling unclosed `<think>...` / `<thought>...`
* stray opener/closer tags
* `<think time="0.42">`-style attributes
* Gemma 4 `<|channel>thought...<channel|>` wrappers
"""
if not text:
return ""
# Gemma 4 thinking-capable models use channel control tokens rather than
# XML tags when the runtime does not split reasoning into a separate field.
# The thought channel can be empty in non-thinking mode; either way it is
# not user-facing content. A response channel, when present, is only a
# wrapper around the final answer.
text = normalize_thinking_markup(text)
text = _GEMMA_THOUGHT_OPEN_RE.sub("", text)
# Normalize attributes so the closed/open regexes can catch them.
text = _THINK_ATTR_RE.sub("<think>", text)
text = _THINK_ATTR_CLOSE_RE.sub("</think>", text)

View File

@@ -12,14 +12,127 @@ import collections
import json
import logging
import os
import pathlib
import sys
import time
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
from src.tool_security import is_public_blocked_tool, owner_is_admin_or_single_user
# Persistent working directory for agent subprocesses.
# Resolves to <repo_root>/data, which is the bind-mounted volume in Docker
# (/app/data) and the local data directory for manual installs.
# Using this as cwd and HOME prevents the agent from silently creating files
# in ephemeral container layers that are lost on the next rebuild.
_AGENT_WORKDIR = str(pathlib.Path(__file__).parent.parent / "data")
MAX_OUTPUT_CHARS = 10_000
MAX_READ_CHARS = 20_000
MAX_DIFF_LINES = 400 # cap unified-diff size returned to the UI
def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
"""Build a unified diff of a file write for display in the chat.
Returns {"text": <unified diff>, "added": N, "removed": M, "new_file": bool}
or None when there's no textual change. Truncates very large diffs.
"""
if old == new:
return None
import difflib
old_lines = old.splitlines()
new_lines = new.splitlines()
label = path or "file"
diff_lines = list(difflib.unified_diff(
old_lines, new_lines,
fromfile=f"a/{label}", tofile=f"b/{label}",
lineterm="",
))
added = sum(1 for l in diff_lines if l.startswith("+") and not l.startswith("+++"))
removed = sum(1 for l in diff_lines if l.startswith("-") and not l.startswith("---"))
truncated = False
if len(diff_lines) > MAX_DIFF_LINES:
diff_lines = diff_lines[:MAX_DIFF_LINES]
truncated = True
text = "\n".join(diff_lines)
if truncated:
text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
return {
"text": text,
"added": added,
"removed": removed,
"new_file": old == "",
"file": os.path.basename(path) or (path or "file"),
}
async def _do_edit_file(content: str, workspace: Optional[str] = None) -> Dict[str, Any]:
"""Exact string-replacement edit of an on-disk file.
content is JSON: {"path", "old_string", "new_string", "replace_all"?}.
Fails if old_string is missing or non-unique (unless replace_all) so the
model can't silently edit the wrong place. Returns a unified diff for the UI.
Confined to the workspace when one is set (same policy as write_file).
"""
try:
args = json.loads(content) if content.strip().startswith("{") else {}
except (json.JSONDecodeError, TypeError):
args = {}
raw_path = (args.get("path") or "").strip()
old = args.get("old_string", "")
new = args.get("new_string", "")
replace_all = bool(args.get("replace_all", False))
if not raw_path:
return {"error": "edit_file: path required", "exit_code": 1}
# Confine to the workspace when set, else the same allowlist + sensitive-file
# policy as read/write_file.
try:
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"edit_file: {e}", "exit_code": 1}
if old == "":
return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
if old == new:
return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}
def _apply():
with open(path, "r", encoding="utf-8") as f:
original = f.read()
count = original.count(old)
if count == 0:
return original, None, "not_found"
if count > 1 and not replace_all:
return original, None, f"not_unique:{count}"
updated = original.replace(old, new) if replace_all else original.replace(old, new, 1)
with open(path, "w", encoding="utf-8") as f:
f.write(updated)
return original, updated, "ok"
try:
original, updated, status = await asyncio.to_thread(_apply)
except FileNotFoundError:
return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
except (IsADirectoryError, UnicodeDecodeError):
return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
except PermissionError:
return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"edit_file: {path}: {e}", "exit_code": 1}
if status == "not_found":
return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
if status.startswith("not_unique"):
n = status.split(":", 1)[1]
return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}
n = original.count(old)
result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
diff = _unified_diff(original, updated, path)
if diff:
result["diff"] = diff
return result
# ---------------------------------------------------------------------------
# Path confinement for read_file / write_file
@@ -158,6 +271,40 @@ def _resolve_tool_path(raw_path: str) -> str:
f"path '{raw_path}' is outside the allowed roots"
)
def _resolve_tool_path_in_workspace(workspace: str, raw_path: str) -> str:
"""Confine a model-supplied path to the active workspace.
Layered on top of upstream's path policy: the workspace is the allowed
root (relative paths resolve under it; paths that escape it are rejected),
and the sensitive-file deny list (.ssh, .gnupg, id_rsa, …) still applies
inside it. When no workspace is set, callers use _resolve_tool_path (the
default data/tmp allowlist) instead.
"""
if raw_path is None or not str(raw_path).strip():
raise ValueError("path is required")
base = os.path.realpath(workspace)
expanded = os.path.expanduser(str(raw_path).strip())
candidate = expanded if os.path.isabs(expanded) else os.path.join(base, expanded)
resolved = os.path.realpath(candidate)
if _is_sensitive_path(resolved):
raise ValueError(
f"path '{raw_path}' is inside a sensitive directory "
f"(e.g. .ssh, .gnupg) or matches a sensitive filename"
)
if resolved != base:
# normcase so containment holds on case-insensitive filesystems
# (Windows, default macOS): it lowercases on Windows and is a no-op on
# POSIX. commonpath raises ValueError across Windows drives (C: vs D:)
# or mixed abs/rel — both mean "outside", so the except rejects them.
nbase = os.path.normcase(base)
try:
if os.path.commonpath([os.path.normcase(resolved), nbase]) != nbase:
raise ValueError
except ValueError:
raise ValueError(f"path '{raw_path}' is outside the workspace ({workspace})")
return resolved
# Bash + python tools used to share a single 60s timeout. That's
# enough for one-shot commands but starves real workloads (pip
# install, ffmpeg conversions, etc.) — and worse, the agent saw the
@@ -186,6 +333,39 @@ def get_mcp_manager():
return agent_tools.get_mcp_manager()
# Directories ignored by the code-nav tools' Python fallbacks so results aren't
# polluted by VCS internals / dependency trees / build caches. ripgrep already
# honours .gitignore; this is the parity floor for the no-rg path (and the
# explicit excludes passed to rg so it skips them even without a .gitignore).
_CODENAV_SKIP_DIRS = frozenset({
".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
".next", ".cache", "site-packages", ".idea", ".tox",
})
# Per-tool result caps (keep tool output cheap + model-friendly).
_CODENAV_MAX_HITS = 200
_CODENAV_MAX_LINE = 400
def _resolve_search_root(raw_path: str, workspace: Optional[str] = None) -> str:
"""Resolve + confine a code-nav path (grep/glob/ls).
With a workspace set, the workspace folder is the root and supplied paths are
confined inside it (same policy as read_file). Without one, an empty path
defaults to the agent's primary root (project data dir) and a supplied path
is confined by the global allowlist + sensitive-file policy.
"""
raw = (raw_path or "").strip()
if workspace:
if not raw:
return os.path.realpath(workspace)
return _resolve_tool_path_in_workspace(workspace, raw)
if not raw:
roots = _tool_path_roots()
return roots[0] if roots else os.path.realpath(".")
return _resolve_tool_path(raw)
def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
if len(text) > limit:
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
@@ -396,11 +576,12 @@ async def _call_mcp_tool(
tool: str,
content: str,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
workspace: Optional[str] = None,
) -> Dict:
"""Route a legacy tool call through the MCP manager, with direct fallbacks."""
mcp = get_mcp_manager()
if not mcp:
return await _direct_fallback(tool, content, progress_cb=progress_cb) or {"error": f"MCP manager not available for tool '{tool}'", "exit_code": 1}
return await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace) or {"error": f"MCP manager not available for tool '{tool}'", "exit_code": 1}
server_id, tool_name = _MCP_TOOL_MAP[tool]
qualified = f"mcp__{server_id}__{tool_name}"
@@ -409,7 +590,7 @@ async def _call_mcp_tool(
# If MCP server not connected, try direct fallback
if isinstance(result, dict) and result.get("exit_code") == 1 and "not connected" in result.get("error", ""):
fallback = await _direct_fallback(tool, content, progress_cb=progress_cb)
fallback = await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace)
if fallback:
return fallback
@@ -436,6 +617,7 @@ async def _direct_fallback(
tool: str,
content: str,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
workspace: Optional[str] = None,
) -> Optional[Dict]:
"""In-process execution path for the eight tools that used to live as
stdio MCP servers under mcp_servers/. Those servers were deleted in
@@ -461,6 +643,7 @@ async def _direct_fallback(
"TERM": "xterm-256color",
"COLUMNS": "120",
"LINES": "40",
"HOME": _AGENT_WORKDIR,
}
try:
@@ -470,6 +653,7 @@ async def _direct_fallback(
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
cwd=workspace or _AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
@@ -496,6 +680,7 @@ async def _direct_fallback(
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
cwd=workspace or _AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
@@ -512,14 +697,43 @@ async def _direct_fallback(
return {"output": output or "(no output)", "exit_code": rc or 0}
if tool == "read_file":
raw_path = content.split("\n", 1)[0].strip()
# Args: plain path on line 1 (back-compat) OR JSON
# {path, offset?, limit?} where offset/limit are a 1-based line range.
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
_stripped = content.strip()
if _stripped.startswith("{"):
try:
_a = _json.loads(_stripped)
raw_path = str(_a.get("path", "")).strip()
offset = int(_a.get("offset") or 0)
limit = int(_a.get("limit") or 0)
except (_json.JSONDecodeError, TypeError, ValueError):
pass
try:
path = _resolve_tool_path(raw_path)
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"read_file: {e}", "exit_code": 1}
try:
# Run blocking read in a thread to keep the loop responsive
# Run blocking read in a thread to keep the loop responsive.
def _read():
    """Read ``path`` as text, honouring the optional line range.

    Uses ``path``, ``offset`` and ``limit`` from the enclosing read_file
    branch. Undecodable bytes are replaced rather than raising.

    * No range (``offset`` and ``limit`` both <= 0): return up to
      ``MAX_READ_CHARS + 1`` characters; the extra character lets the
      caller detect that the file was longer than the cap.
    * Range read: return lines ``offset`` .. ``offset + limit - 1``
      (1-based; ``limit <= 0`` means "to end of file"), never more than
      ``MAX_READ_CHARS`` characters of file content. When the cap is hit a
      truncation marker is appended.
    """
    if offset <= 0 and limit <= 0:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return f.read(MAX_READ_CHARS + 1)

    first_line = max(offset, 1)  # offset 0 with a limit means "from line 1"
    pieces = []
    taken = 0
    budget = MAX_READ_CHARS
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        for lineno, line in enumerate(f, 1):
            if lineno < first_line:
                continue
            if limit > 0 and taken >= limit:
                break
            if len(line) >= budget:
                # Slice to the remaining budget so a single huge line
                # (e.g. a minified file) cannot blow past the cap.
                pieces.append(line[:budget])
                pieces.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
                break
            pieces.append(line)
            taken += 1
            budget -= len(line)
    return "".join(pieces)
data = await asyncio.to_thread(_read)
@@ -527,10 +741,11 @@ async def _direct_fallback(
return {"error": f"read_file: {path}: not found", "exit_code": 1}
except PermissionError:
return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
except IsADirectoryError:
return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
except OSError as e:
return {"error": f"read_file: {path}: {e}", "exit_code": 1}
truncated = len(data) > MAX_READ_CHARS
if truncated:
if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS:
data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
return {"output": data, "exit_code": 0}
@@ -539,23 +754,226 @@ async def _direct_fallback(
raw_path = lines[0].strip()
body = lines[1] if len(lines) > 1 else ""
try:
path = _resolve_tool_path(raw_path)
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"write_file: {e}", "exit_code": 1}
try:
def _write():
# Capture prior content (best-effort, text) so we can show a
# before/after diff. Missing/binary file → treat as empty.
old = ""
try:
with open(path, "r", encoding="utf-8") as f:
old = f.read()
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
old = ""
d = os.path.dirname(path)
if d:
os.makedirs(d, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
f.write(body)
return len(body)
size = await asyncio.to_thread(_write)
return old, len(body)
old_content, size = await asyncio.to_thread(_write)
except PermissionError:
return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"write_file: {path}: {e}", "exit_code": 1}
return {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
diff = _unified_diff(old_content, body, path)
result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
if diff:
result["diff"] = diff
return result
if tool == "grep":
# Args (JSON): {pattern, path?, glob?, ignore_case?, max_results?}.
# Bare string → treated as the pattern.
args: Dict[str, Any] = {}
_s = (content or "").strip()
if _s.startswith("{"):
try:
args = _json.loads(_s)
except _json.JSONDecodeError:
args = {}
else:
args = {"pattern": _s}
pattern = str(args.get("pattern", "")).strip()
if not pattern:
return {"error": "grep: pattern is required", "exit_code": 1}
ignore_case = bool(args.get("ignore_case"))
glob_pat = str(args.get("glob", "") or "").strip()
try:
max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
except (TypeError, ValueError):
max_hits = _CODENAV_MAX_HITS
max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
try:
root = _resolve_search_root(str(args.get("path", "")), workspace)
except ValueError as e:
return {"error": f"grep: {e}", "exit_code": 1}
def _grep():
    """Search ``root`` for ``pattern`` and return ``(lines, error)``.

    Uses ``pattern``, ``root``, ``glob_pat``, ``ignore_case`` and
    ``max_hits`` from the enclosing grep branch. On success returns
    ``(list_of_"file:line:text", None)`` (the list may be empty); on failure
    returns ``(None, message)``.

    ripgrep is used when it is on PATH; otherwise a pure-Python walk + regex
    is used, which skips ``_CODENAV_SKIP_DIRS`` and files that are not valid
    UTF-8.
    """
    import re as _re
    import shutil
    import subprocess

    rg = shutil.which("rg")
    if rg:
        # --max-count limits matches per file, so the overall list is
        # still sliced to max_hits below.
        cmd = [rg, "--line-number", "--no-heading", "--color=never",
               "--max-count", str(max_hits)]
        if ignore_case:
            cmd.append("--ignore-case")
        if glob_pat:
            cmd += ["--glob", glob_pat]
        # Exclude junk dirs even when the tree has no .gitignore, so
        # results match the Python fallback's skip set.
        for _d in _CODENAV_SKIP_DIRS:
            cmd += ["--glob", f"!**/{_d}/**"]
        # --regexp keeps a pattern that starts with '-' from being parsed
        # as an option.
        cmd += ["--regexp", pattern, root]
        try:
            p = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
        except subprocess.TimeoutExpired:
            return None, "grep: timed out"
        except Exception as _e:
            return None, f"grep: {_e}"
        lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
        # Exit status 0 = matches, 1 = no matches; anything else is an error
        # (bad regex, bad glob, ...). Report it instead of pretending there
        # were simply no matches. Partial results are still returned.
        if not lines and p.returncode not in (0, 1):
            detail = (p.stderr or "").strip() or f"ripgrep exited with status {p.returncode}"
            return None, f"grep: {detail[:_CODENAV_MAX_LINE]}"
        return lines, None

    # Python fallback (no ripgrep): walk + regex.
    try:
        rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
    except _re.error as _e:
        return None, f"grep: bad pattern: {_e}"
    import fnmatch
    hits = []
    if os.path.isfile(root):
        file_iter = [root]
    else:
        file_iter = []
        for dp, dns, fns in os.walk(root):
            # Prune in place so os.walk never descends into skipped dirs.
            dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
            for fn in fns:
                # The glob filter is applied to the file name only.
                if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
                    continue
                file_iter.append(os.path.join(dp, fn))
    for fp in file_iter:
        if len(hits) >= max_hits:
            break
        try:
            # Strict decoding doubles as the binary-file filter.
            with open(fp, "r", encoding="utf-8", errors="strict") as f:
                for i, line in enumerate(f, 1):
                    if rx.search(line):
                        hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
                        if len(hits) >= max_hits:
                            break
        except (UnicodeDecodeError, OSError):
            continue  # skip binary / unreadable
    return hits, None
lines, err = await asyncio.to_thread(_grep)
if err:
return {"error": err, "exit_code": 1}
if not lines:
return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
if len(lines) >= max_hits:
out += f"\n... [capped at {max_hits} matches]"
return {"output": _truncate(out), "exit_code": 0}
if tool == "glob":
args = {}
_s = (content or "").strip()
if _s.startswith("{"):
try:
args = _json.loads(_s)
except _json.JSONDecodeError:
args = {}
else:
args = {"pattern": _s}
pattern = str(args.get("pattern", "")).strip()
if not pattern:
return {"error": "glob: pattern is required", "exit_code": 1}
try:
root = _resolve_search_root(str(args.get("path", "")), workspace)
except ValueError as e:
return {"error": f"glob: {e}", "exit_code": 1}
def _glob():
    """Collect files under ``root`` matching ``pattern``, newest first.

    Uses ``root`` and ``pattern`` from the enclosing glob branch. Returns
    ``(paths, None)`` on success (at most ``_CODENAV_MAX_HITS`` paths, sorted
    by modification time, newest first) or ``(None, message)`` on failure.
    Matches whose path relative to ``root`` contains a ``_CODENAV_SKIP_DIRS``
    component are dropped.
    """
    from pathlib import Path, PurePath

    base = Path(root)
    if not base.is_dir():
        return None, f"glob: {root}: not a directory"

    # rglob() follows '..' components literally, which would let a pattern
    # such as '../*' list files outside the confined root, and it raises
    # NotImplementedError for anchored (absolute) patterns. Reject both.
    pat = PurePath(pattern)
    if pat.anchor or ".." in pat.parts:
        return None, "glob: pattern must be relative and must not contain '..'"

    matched = []
    try:
        for p in base.rglob(pattern):
            if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
                continue
            try:
                mtime = p.stat().st_mtime
            except OSError:
                mtime = 0  # unreadable entry: keep it, but sort it last
            matched.append((mtime, str(p)))
            # Over-collect (5x the cap) so the newest-first sort has more
            # than just the first-walked entries to pick from, while still
            # bounding the walk on huge trees.
            if len(matched) > _CODENAV_MAX_HITS * 5:
                break
    except (OSError, ValueError, NotImplementedError) as _e:
        return None, f"glob: {_e}"
    matched.sort(key=lambda t: t[0], reverse=True)  # newest first
    return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None
paths, err = await asyncio.to_thread(_glob)
if err:
return {"error": err, "exit_code": 1}
if not paths:
return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
out = "\n".join(paths)
if len(paths) >= _CODENAV_MAX_HITS:
out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
return {"output": _truncate(out), "exit_code": 0}
if tool == "ls":
raw_path = ""
_s = (content or "").strip()
if _s.startswith("{"):
try:
raw_path = str(_json.loads(_s).get("path", "")).strip()
except _json.JSONDecodeError:
raw_path = ""
else:
raw_path = _s.split("\n", 1)[0].strip()
try:
root = _resolve_search_root(raw_path, workspace)
except ValueError as e:
return {"error": f"ls: {e}", "exit_code": 1}
def _ls():
    """Render a one-level listing of ``root`` as text.

    Uses ``root`` from the enclosing ls branch. Returns ``(text, None)`` on
    success or ``(None, message)`` when ``root`` is not a directory or cannot
    be scanned. Dot-entries are omitted, directories are listed before files
    (each group case-insensitively by name), files carry their size in bytes,
    and at most ``_CODENAV_MAX_HITS`` entries are shown.
    """
    if not os.path.isdir(root):
        return None, f"ls: {root}: not a directory"

    entries = []
    try:
        with os.scandir(root) as scan:
            for item in scan:
                if item.name.startswith("."):
                    continue
                try:
                    # Symlinks are not followed: a link is reported as a
                    # file with the link's own size.
                    directory = item.is_dir(follow_symlinks=False)
                    nbytes = 0 if directory else item.stat(follow_symlinks=False).st_size
                except OSError:
                    continue  # entry vanished or is unreadable; leave it out
                entries.append((directory, item.name, nbytes))
    except OSError as _e:
        return None, f"ls: {_e}"

    # Directories first (False sorts before True), then by lowercase name.
    entries.sort(key=lambda e: (not e[0], e[1].lower()))

    listing = [f"{root}:"]
    listing.extend(
        f" {name}/" if directory else f" {name} ({nbytes} B)"
        for directory, name, nbytes in entries[:_CODENAV_MAX_HITS]
    )
    overflow = len(entries) - _CODENAV_MAX_HITS
    if overflow > 0:
        listing.append(f" ... [{overflow} more]")
    if not entries:
        listing.append(" (empty)")
    return "\n".join(listing), None
out, err = await asyncio.to_thread(_ls)
if err:
return {"error": err, "exit_code": 1}
return {"output": _truncate(out), "exit_code": 0}
if tool == "web_search":
from src.search import comprehensive_web_search
@@ -685,6 +1103,7 @@ async def execute_tool_block(
disabled_tools: Optional[set] = None,
owner: Optional[str] = None,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
workspace: Optional[str] = None,
) -> Tuple[str, Dict]:
"""Execute a single tool block. Returns (description, result_dict).
@@ -773,7 +1192,7 @@ async def execute_tool_block(
_is_bg, _bg_cmd = _split_bg_marker(content)
if _is_bg and _bg_cmd:
from src import bg_jobs
rec = bg_jobs.launch(_bg_cmd, session_id=session_id)
rec = bg_jobs.launch(_bg_cmd, session_id=session_id, cwd=workspace or _AGENT_WORKDIR)
short = _bg_cmd.strip().split(chr(10))[0][:80]
desc = f"bash (background): {short}"
result = {
@@ -795,7 +1214,14 @@ async def execute_tool_block(
if tool in _MCP_TOOL_MAP:
first_line = content.split(chr(10))[0][:80]
desc = f"{tool}: {first_line}"
result = await _call_mcp_tool(tool, content, progress_cb=progress_cb)
result = await _call_mcp_tool(tool, content, progress_cb=progress_cb, workspace=workspace)
elif tool in ("grep", "glob", "ls"):
# Code-navigation tools — no MCP server; run the direct implementation.
# Confined to the workspace when one is set (same policy as read_file).
first_line = content.split(chr(10))[0][:80]
desc = f"{tool}: {first_line}"
result = await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace) \
or {"error": f"{tool}: execution failed", "exit_code": 1}
elif tool == "create_document":
title = content.split("\n")[0].strip()[:60]
desc = f"create_document: {title}"
@@ -898,6 +1324,9 @@ async def execute_tool_block(
elif tool == "edit_image":
desc = "edit_image"
result = await do_edit_image(content, owner=owner)
elif tool == "edit_file":
result = await _do_edit_file(content, workspace=workspace)
desc = result.get("output") or result.get("error") or "edit_file"
elif tool == "trigger_research":
desc = "trigger_research"
result = await do_trigger_research(content, owner=owner)

View File

@@ -22,7 +22,13 @@ logger = logging.getLogger(__name__)
# Tools that are ALWAYS included regardless of retrieval results.
# These are the most commonly needed and should never be missing.
ALWAYS_AVAILABLE = frozenset({
"bash", "python", "web_search", "web_fetch", "read_file",
"bash", "python", "web_search", "web_fetch",
# File tools: read AND write/edit. An agent with disk access should always
# be able to change files, not just read them — otherwise a bare "edit X"
# request can miss write_file/edit_file (RAG-only) and the model wrongly
# falls back to edit_document (editor panel). All admin-gated by tool_security.
"read_file", "write_file", "edit_file",
"grep", "glob", "ls", # code-navigation tools (admin-gated by tool_security)
"api_call", # For configured integrations (Miniflux, Gitea, Linkding, etc.)
# The two genuinely AMBIENT cookbook tools — "what's running" and
# "kill it" can be asked any time without prior cookbook context,
@@ -71,8 +77,12 @@ BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
"python": "Execute Python code for computation, data processing, math, scripting, parsing, API calls. Not for writing code for the user.",
"web_search": "Quick single web lookup for a fact, current event, or doc mid-task. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
"read_file": "Read a file from disk and return its contents. View source code, config files, logs.",
"write_file": "Write content to a file on disk. Create new files, save output, update configs.",
"read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
"grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
"glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
"ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
"write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
"edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
"edit_document": "Preferred tool for editing an existing document — targeted find-and-replace. Use for any small change: add a function, fix a bug, tweak a section, rename things.",
"update_document": "Replace the entire active document content. ONLY for full rewrites (>50% changed). Do not use for small edits — use edit_document instead.",

View File

@@ -82,16 +82,65 @@ FUNCTION_TOOL_SCHEMAS = [
"type": "function",
"function": {
"name": "read_file",
"description": "Read a file from disk",
"description": "Read a file from disk. Optionally read a line range with offset/limit for large files.",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "File path to read"}
"path": {"type": "string", "description": "File path to read"},
"offset": {"type": "integer", "description": "1-based line to start reading from (optional)"},
"limit": {"type": "integer", "description": "Max number of lines to read from offset (optional)"}
},
"required": ["path"]
}
}
},
{
"type": "function",
"function": {
"name": "grep",
"description": "Search file contents for a regular expression across a directory tree (uses ripgrep when available, respecting .gitignore). Returns file:line:match. PREFER this over `bash grep/rg` for code search — confined to the allowed roots, structured output.",
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regular expression to search for"},
"path": {"type": "string", "description": "Directory or file to search (optional; defaults to the project root)"},
"glob": {"type": "string", "description": "Only search files matching this glob, e.g. '*.py' (optional)"},
"ignore_case": {"type": "boolean", "description": "Case-insensitive match (optional)"},
"max_results": {"type": "integer", "description": "Max matches to return (optional)"}
},
"required": ["pattern"]
}
}
},
{
"type": "function",
"function": {
"name": "glob",
"description": "Find files by glob pattern (recursive), newest first. e.g. '**/*.py'. PREFER this over `bash find/ls` for locating files — confined to the allowed roots.",
"parameters": {
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Glob pattern, e.g. '**/*.ts' or 'src/**/test_*.py'"},
"path": {"type": "string", "description": "Base directory (optional; defaults to the project root)"}
},
"required": ["pattern"]
}
}
},
{
"type": "function",
"function": {
"name": "ls",
"description": "List the entries of a directory (folders first, then files with sizes). PREFER this over `bash ls` — confined to the allowed roots.",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "Directory to list (optional; defaults to the project root)"}
},
"required": []
}
}
},
{
"type": "function",
"function": {
@@ -107,6 +156,23 @@ FUNCTION_TOOL_SCHEMAS = [
}
}
},
{
"type": "function",
"function": {
"name": "edit_file",
"description": "Edit a file ON DISK by exact string replacement (home folder, project files, any real path like ~/sweden.txt or /path/to/file). This is the right tool for files on disk — NOT edit_document (that's for editor-panel documents). PREFER this over bash (sed/echo) — it shows a diff. old_string must match the file exactly and be unique (or set replace_all). Use write_file to create a new file.",
"parameters": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "File path to edit"},
"old_string": {"type": "string", "description": "Exact text to replace (must match the file, including indentation)"},
"new_string": {"type": "string", "description": "Replacement text"},
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match"}
},
"required": ["path", "old_string", "new_string"]
}
}
},
{
"type": "function",
"function": {
@@ -127,7 +193,7 @@ FUNCTION_TOOL_SCHEMAS = [
"type": "function",
"function": {
"name": "edit_document",
"description": "PREFERRED way to change an existing document. Targeted find-and-replace with multiple FIND/REPLACE pairs per call. Use this for any edit smaller than a full rewrite: adding a function, fixing a bug, tweaking a section, renaming things. Do NOT send the whole file back via update_document for small edits — it wastes tokens and is hard to review.",
"description": "Edit a document OPEN IN THE EDITOR PANEL (created via create_document) — NOT a file on disk. For files on disk (home folder, project files, anything with a path like ~/x.txt or /path/to/file) use edit_file instead. Targeted find-and-replace with multiple FIND/REPLACE pairs per call; use for any edit smaller than a full rewrite. Do NOT send the whole file back via update_document for small edits.",
"parameters": {
"type": "object",
"properties": {
@@ -1126,9 +1192,17 @@ def function_call_to_tool_block(name: str, arguments: str) -> Optional[ToolBlock
else:
content = args.get("query", "")
elif tool_type == "read_file":
content = args.get("path", "")
# Plain path (back-compat) unless a line range is requested → JSON.
if args.get("offset") or args.get("limit"):
content = json.dumps(args)
else:
content = args.get("path", "")
elif tool_type in ("grep", "glob", "ls"):
content = json.dumps(args) if args else "{}"
elif tool_type == "write_file":
content = args.get("path", "") + "\n" + args.get("content", "")
elif tool_type == "edit_file":
content = json.dumps(args)
elif tool_type == "create_document":
parts = [args.get("title", "Untitled")]
if args.get("language"):

View File

@@ -16,6 +16,10 @@ NON_ADMIN_BLOCKED_TOOLS = {
"python",
"read_file",
"write_file",
"edit_file",
"grep",
"glob",
"ls",
"search_chats",
"manage_memory",
"manage_skills",