Merge remote-tracking branch 'origin/dev'
This commit is contained in:
@@ -177,6 +177,7 @@ TOOL_SECTIONS = {
|
||||
<shell command>
|
||||
```
|
||||
Run any shell command. Output is returned to you. Use for: installing packages, checking files, git, curl, system info, etc.
|
||||
NEVER use bash to create or change files — no `>`/`>>` redirects, no heredocs (`cat > f << 'EOF'`), no `tee`, `sed -i`, `awk -i`, no `python -c` that writes. To CREATE or fully rewrite a file use `write_file`; to change part of an existing file use `edit_file`. Those show a diff and are the ONLY allowed way to write files. (bash is for read-only inspection: `ls`, `cat` to READ, `grep`, `git status`/`git diff`, builds, installs.)
|
||||
For LONG-running commands (package installs, pip/npm, ffmpeg, model downloads, training, builds — anything that may take more than ~20s), make the FIRST line `#!bg` to run it in the BACKGROUND. You get a job id back immediately and are automatically re-invoked with the full output when it finishes — so you never block the chat waiting. Example:
|
||||
```bash
|
||||
#!bg
|
||||
@@ -220,6 +221,12 @@ Read a file and return its contents.""",
|
||||
```
|
||||
Write content to a file. First line is the path, rest is the content.""",
|
||||
|
||||
"edit_file": """\
|
||||
```edit_file
|
||||
{"path": "<file path>", "old_string": "<exact text to replace>", "new_string": "<replacement>", "replace_all": false}
|
||||
```
|
||||
Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
|
||||
|
||||
"create_document": """\
|
||||
```create_document
|
||||
<title>
|
||||
@@ -236,7 +243,7 @@ old text to find
|
||||
new replacement text
|
||||
<<<END>>>
|
||||
```
|
||||
PREFERRED way to change an existing document. Find exact text and replace it. Multiple FIND/REPLACE blocks per call OK. Use this for any edit smaller than a full rewrite — adding a function, fixing a bug, tweaking a section, renaming things. **If a document is open in the editor, treat it as the user's current context: don't ask which file they mean, and don't create a new one — just edit_document the active one.** Do NOT re-send the whole file with update_document for small changes.""",
|
||||
Edit a document OPEN IN THE EDITOR PANEL — NOT a file on disk. For files on disk (home folder, project files, any real path like ~/sweden.txt) use `edit_file` instead. Find exact text and replace it. Multiple FIND/REPLACE blocks per call OK. Use for any edit smaller than a full rewrite. **If a document is open in the editor, treat it as the user's current context: don't ask which file they mean, and don't create a new one — just edit_document the active one.** Do NOT re-send the whole file with update_document for small changes.""",
|
||||
|
||||
"update_document": """\
|
||||
```update_document
|
||||
@@ -462,13 +469,14 @@ _API_HOSTS = frozenset([
|
||||
"api.together.xyz", "api.fireworks.ai",
|
||||
"api.perplexity.ai", "api.x.ai",
|
||||
"ollama.com", "api.venice.ai",
|
||||
"api.githubcopilot.com",
|
||||
# Local OpenAI-compatible endpoints (llama.cpp, vLLM, LM Studio, etc.).
|
||||
# Without these, `_is_api_model` falls back to keyword sniffing on the
|
||||
# model name, so well-behaved local servers don't get native tool
|
||||
# schemas and the agent silently degrades to fenced-block parsing.
|
||||
"localhost", "127.0.0.1", "host.docker.internal",
|
||||
])
|
||||
_MCP_KEYWORDS = frozenset(["browse", "browser", "website", "calendar", "event", "email",
|
||||
_MCP_KEYWORDS = frozenset(["mcp", "browse", "browser", "website", "calendar", "event", "email",
|
||||
"gmail", "screenshot", "navigate", "click", "miniflux", "rss", "feed"])
|
||||
_ADMIN_SCHEMA_NAMES = frozenset([
|
||||
"manage_session", "manage_skills", "manage_tasks",
|
||||
@@ -1380,6 +1388,7 @@ async def stream_agent_loop(
|
||||
owner: Optional[str] = None,
|
||||
relevant_tools: Optional[Set[str]] = None,
|
||||
fallbacks: Optional[List[tuple]] = None,
|
||||
workspace: Optional[str] = None,
|
||||
_is_teacher_run: bool = False,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Streaming agent loop generator.
|
||||
@@ -1546,6 +1555,27 @@ async def stream_agent_loop(
|
||||
compact=_is_api_model,
|
||||
owner=owner,
|
||||
)
|
||||
if workspace:
|
||||
# PREPEND (not append) so it dominates the large base prompt — appended
|
||||
# at the end, small models ignored it and asked the user for code. The
|
||||
# folder IS the project; the agent must explore it, not ask.
|
||||
_ws_note = (
|
||||
f"## ACTIVE WORKSPACE — READ FIRST\n"
|
||||
f"The user is working in this folder: {workspace}\n"
|
||||
f"It IS the project. bash/python run with cwd set here and "
|
||||
f"read_file/write_file are confined to it (paths outside are rejected).\n"
|
||||
f"When the user says \"the code\" / \"this project\" / \"the workspace\" "
|
||||
f"or asks to review/find/edit something WITHOUT a path, they mean THIS "
|
||||
f"folder. Do NOT ask the user for code or a path, and do NOT read a file "
|
||||
f"literally named \"workspace\". ALWAYS start by exploring it yourself: "
|
||||
f"run `bash` → `git ls-files` (or `ls -R`) to see the files, then "
|
||||
f"read_file the relevant ones by path RELATIVE to the workspace."
|
||||
)
|
||||
if messages and messages[0].get("role") == "system":
|
||||
messages[0]["content"] = _ws_note + "\n\n" + (messages[0].get("content") or "")
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": _ws_note})
|
||||
logger.info("[workspace] active for this turn: %s", workspace)
|
||||
prep_timings["prompt_build"] = time.time() - _t2
|
||||
|
||||
_t3 = time.time()
|
||||
@@ -1658,6 +1688,11 @@ async def stream_agent_loop(
|
||||
_doc_opened = False # whether doc_stream_open was sent
|
||||
_doc_last_len = 0 # last content length sent
|
||||
|
||||
# Set when the loop runs out of rounds while the agent was still actively
|
||||
# using tools — i.e. it was cut off, not finished. Drives a "Continue" event
|
||||
# so the user can resume instead of the turn silently stalling.
|
||||
_exhausted_rounds = False
|
||||
|
||||
for round_num in range(1, max_rounds + 1):
|
||||
round_response = ""
|
||||
round_reasoning = "" # reasoning_content deltas (DeepSeek-thinking, vLLM --reasoning-parser)
|
||||
@@ -2167,6 +2202,7 @@ async def stream_agent_loop(
|
||||
disabled_tools=disabled_tools,
|
||||
owner=owner,
|
||||
progress_cb=_push_progress,
|
||||
workspace=workspace,
|
||||
)
|
||||
finally:
|
||||
# Sentinel so the drainer knows to stop.
|
||||
@@ -2282,6 +2318,9 @@ async def stream_agent_loop(
|
||||
if result.get("images"):
|
||||
img = result["images"][0]
|
||||
tool_output_data["screenshot"] = f"data:{img['mimeType']};base64,{img['data']}"
|
||||
# Forward a file-write diff for inline before/after rendering
|
||||
if "diff" in result:
|
||||
tool_output_data["diff"] = result["diff"]
|
||||
yield f'data: {json.dumps(tool_output_data)}\n\n'
|
||||
|
||||
# Native document tools open in the editor + carry the REAL doc id.
|
||||
@@ -2324,6 +2363,10 @@ async def stream_agent_loop(
|
||||
if result.get("doc_id"):
|
||||
tool_event["doc_id"] = result["doc_id"]
|
||||
tool_event["doc_title"] = result.get("title", "")
|
||||
# Persist the file-write/edit diff so it re-renders on reload — without
|
||||
# this the diff shows live but vanishes from saved history.
|
||||
if result.get("diff"):
|
||||
tool_event["diff"] = result["diff"]
|
||||
tool_events.append(tool_event)
|
||||
if block.tool_type in _VERIFIER_EFFECTFUL_TOOLS:
|
||||
_effectful_used = True
|
||||
@@ -2348,6 +2391,20 @@ async def stream_agent_loop(
|
||||
|
||||
# Separator in accumulated response
|
||||
full_response += "\n\n"
|
||||
else:
|
||||
# The for-loop completed every allowed round WITHOUT an early `break`
|
||||
# (a `break` fires on "done", budget, or error). Reaching this `else`
|
||||
# means the agent kept working until it ran out of rounds — so offer
|
||||
# Continue instead of stopping silently. This catches ALL exhaustion
|
||||
# paths, including a verifier `continue` on the final round (the old
|
||||
# bottom-of-loop flag missed those).
|
||||
_exhausted_rounds = True
|
||||
|
||||
# If the loop hit the round cap while still working, tell the client so it
|
||||
# can show a "Continue" affordance instead of the turn just stopping.
|
||||
if _exhausted_rounds:
|
||||
logger.info("[agent] round cap (%d) reached mid-task — emitting rounds_exhausted", max_rounds)
|
||||
yield f'data: {json.dumps({"type": "rounds_exhausted", "rounds": max_rounds})}\n\n'
|
||||
|
||||
# If the response is completely empty and no tools were executed,
|
||||
# yield a fallback message so the user is not left hanging.
|
||||
|
||||
@@ -26,7 +26,8 @@ MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_READ_CHARS = 20_000
|
||||
|
||||
# Tool types that trigger execution
|
||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file",
|
||||
TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
|
||||
"grep", "glob", "ls",
|
||||
"create_document", "update_document", "edit_document",
|
||||
"search_chats",
|
||||
"chat_with_model", "create_session", "list_sessions",
|
||||
|
||||
@@ -9,6 +9,7 @@ from src.constants import (
|
||||
SESSIONS_FILE, DEFAULT_HOST, OPENAI_API_KEY
|
||||
)
|
||||
from src.memory import MemoryManager
|
||||
from src.memory_provider import MemoryProviderRegistry, NativeMemoryProvider
|
||||
from services.memory.skills import SkillsManager
|
||||
from core.session_manager import SessionManager
|
||||
from core.models import set_session_manager
|
||||
@@ -73,6 +74,10 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
|
||||
logger.warning(f"MemoryVectorStore DEGRADED: {e}")
|
||||
memory_vector = None
|
||||
|
||||
memory_provider_registry = MemoryProviderRegistry([
|
||||
NativeMemoryProvider(memory_manager, memory_vector),
|
||||
])
|
||||
|
||||
# Initialize processors
|
||||
chat_processor = ChatProcessor(memory_manager, personal_docs_manager, memory_vector=memory_vector, skills_manager=skills_manager)
|
||||
research_handler = ResearchHandler()
|
||||
@@ -99,6 +104,7 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
|
||||
return {
|
||||
"memory_manager": memory_manager,
|
||||
"memory_vector": memory_vector,
|
||||
"memory_provider_registry": memory_provider_registry,
|
||||
"skills_manager": skills_manager,
|
||||
"session_manager": session_manager,
|
||||
"upload_handler": upload_handler,
|
||||
|
||||
@@ -265,6 +265,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
existing.all_day = all_day
|
||||
existing.is_utc = row_is_utc
|
||||
existing.rrule = rrule
|
||||
existing.origin = "caldav"
|
||||
else:
|
||||
new_ev = CalendarEvent(
|
||||
uid=uid_val,
|
||||
@@ -277,6 +278,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
all_day=all_day,
|
||||
is_utc=row_is_utc,
|
||||
rrule=rrule,
|
||||
origin="caldav",
|
||||
)
|
||||
db.add(new_ev)
|
||||
pending[uid_val] = new_ev
|
||||
@@ -286,8 +288,13 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
|
||||
# Prune locally-cached CalDAV events that vanished
|
||||
# upstream (only within our sync window — events outside
|
||||
# the window aren't in `objs`, so we'd false-delete them).
|
||||
# Only rows we previously pulled from the server (origin=="caldav")
|
||||
# are prunable; locally-created events (agent / email triage / a
|
||||
# UI event whose write-back failed) carry origin NULL and must
|
||||
# never be deleted just because the server didn't return them.
|
||||
stale = db.query(CalendarEvent).filter(
|
||||
CalendarEvent.calendar_id == local_cal.id,
|
||||
CalendarEvent.origin == "caldav",
|
||||
CalendarEvent.dtstart >= start,
|
||||
CalendarEvent.dtstart <= end,
|
||||
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
|
||||
|
||||
253
src/copilot.py
Normal file
253
src/copilot.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# src/copilot.py
|
||||
"""GitHub Copilot provider support.
|
||||
|
||||
Copilot exposes an OpenAI-compatible API at ``https://api.githubcopilot.com``
|
||||
(``/chat/completions`` + ``/models``). Authentication is a GitHub OAuth
|
||||
**device flow**: the user authorises a device code in their browser and we
|
||||
receive a long-lived ``access_token`` that is sent directly as
|
||||
``Authorization: Bearer <token>`` — there is no separate Copilot-token
|
||||
exchange and no refresh (mirrors how editors / opencode talk to Copilot).
|
||||
|
||||
The only provider-specific wrinkle beyond the bearer token is a handful of
|
||||
required request headers (API version, intent, an editor-style User-Agent,
|
||||
and ``x-initiator`` for agent-vs-user request accounting). Those live in
|
||||
:func:`copilot_headers`.
|
||||
|
||||
This module holds the constants + pure helpers; the HTTP device-flow calls
|
||||
live in :mod:`routes.copilot_routes` so they can be auth-gated.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# GitHub OAuth client id used for the device flow. Copilot's token endpoint
|
||||
# only accepts client ids that GitHub has allow-listed for Copilot access, so
|
||||
# we reuse the public VS Code client id (the de-facto standard third-party
|
||||
# clients use). Override via env if you register your own allow-listed app.
|
||||
COPILOT_CLIENT_ID = os.environ.get(
|
||||
"ODYSSEUS_COPILOT_CLIENT_ID", "01ab8ac9400c4e429b23"
|
||||
)
|
||||
|
||||
# Dated API version header required by the Copilot API (models + chat).
|
||||
COPILOT_API_VERSION = os.environ.get(
|
||||
"ODYSSEUS_COPILOT_API_VERSION", "2026-06-01"
|
||||
)
|
||||
|
||||
# Public Copilot API base. GitHub Enterprise uses ``copilot-api.<domain>``.
|
||||
COPILOT_BASE = "https://api.githubcopilot.com"
|
||||
|
||||
# Copilot wants an editor-like User-Agent + integration id. These identify the
|
||||
# client to GitHub; keep them stable.
|
||||
COPILOT_USER_AGENT = os.environ.get(
|
||||
"ODYSSEUS_COPILOT_USER_AGENT", "Odysseus/1.0"
|
||||
)
|
||||
COPILOT_INTEGRATION_ID = os.environ.get(
|
||||
"ODYSSEUS_COPILOT_INTEGRATION_ID", "vscode-chat"
|
||||
)
|
||||
COPILOT_EDITOR_VERSION = os.environ.get(
|
||||
"ODYSSEUS_COPILOT_EDITOR_VERSION", "Odysseus/1.0"
|
||||
)
|
||||
|
||||
# OAuth scope requested during the device flow.
|
||||
COPILOT_SCOPE = "read:user"
|
||||
|
||||
# Default GitHub host for the device flow (public github.com).
|
||||
GITHUB_HOST = "github.com"
|
||||
|
||||
|
||||
def device_code_url(host: str = GITHUB_HOST) -> str:
|
||||
return f"https://{host}/login/device/code"
|
||||
|
||||
|
||||
def access_token_url(host: str = GITHUB_HOST) -> str:
|
||||
return f"https://{host}/login/oauth/access_token"
|
||||
|
||||
|
||||
def normalize_domain(url: str) -> str:
|
||||
"""Strip scheme/trailing slash from a GitHub Enterprise URL or domain."""
|
||||
return (url or "").replace("https://", "").replace("http://", "").rstrip("/")
|
||||
|
||||
|
||||
def enterprise_base(enterprise_url: Optional[str]) -> str:
|
||||
"""Return the Copilot API base for a deployment.
|
||||
|
||||
Public github.com → ``https://api.githubcopilot.com``.
|
||||
Enterprise <domain> → ``https://copilot-api.<domain>``.
|
||||
"""
|
||||
if not enterprise_url:
|
||||
return COPILOT_BASE
|
||||
return f"https://copilot-api.{normalize_domain(enterprise_url)}"
|
||||
|
||||
|
||||
def is_copilot_base(url: Optional[str]) -> bool:
|
||||
"""True if a base URL points at the Copilot API (public or enterprise)."""
|
||||
if not url:
|
||||
return False
|
||||
try:
|
||||
host = (urlparse(url).hostname or "").lower().rstrip(".")
|
||||
except Exception:
|
||||
return False
|
||||
if not host:
|
||||
return False
|
||||
# Public: api.githubcopilot.com (or any *.githubcopilot.com).
|
||||
if host == "githubcopilot.com" or host.endswith(".githubcopilot.com"):
|
||||
return True
|
||||
# Enterprise: copilot-api.<domain>.
|
||||
if host.startswith("copilot-api."):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def copilot_headers(
|
||||
api_key: Optional[str],
|
||||
*,
|
||||
agent: bool = False,
|
||||
vision: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Build the Copilot-specific request headers.
|
||||
|
||||
Args:
|
||||
api_key: the GitHub device-flow access token (sent as Bearer).
|
||||
agent: request originates from the agent loop (a tool-driven turn)
|
||||
rather than a direct user message. Sets ``x-initiator`` for
|
||||
Copilot's agent-vs-user request accounting.
|
||||
vision: the request carries an image part.
|
||||
"""
|
||||
headers: Dict[str, str] = {
|
||||
"X-GitHub-Api-Version": COPILOT_API_VERSION,
|
||||
"Openai-Intent": "conversation-edits",
|
||||
"User-Agent": COPILOT_USER_AGENT,
|
||||
"Editor-Version": COPILOT_EDITOR_VERSION,
|
||||
"Copilot-Integration-Id": COPILOT_INTEGRATION_ID,
|
||||
"x-initiator": "agent" if agent else "user",
|
||||
}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if vision:
|
||||
headers["Copilot-Vision-Request"] = "true"
|
||||
return headers
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device-flow OAuth (pure HTTP; orchestration lives in routes.copilot_routes)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _oauth_post_headers() -> Dict[str, str]:
|
||||
return {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": COPILOT_USER_AGENT,
|
||||
}
|
||||
|
||||
|
||||
def request_device_code(host: str = GITHUB_HOST, *, timeout: float = 10.0) -> Dict:
|
||||
"""Start the device flow. Returns GitHub's
|
||||
``{device_code, user_code, verification_uri, expires_in, interval}``.
|
||||
"""
|
||||
r = httpx.post(
|
||||
device_code_url(host),
|
||||
headers=_oauth_post_headers(),
|
||||
json={"client_id": COPILOT_CLIENT_ID, "scope": COPILOT_SCOPE},
|
||||
timeout=timeout,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
|
||||
def poll_access_token(host: str, device_code: str, *, timeout: float = 10.0) -> Dict:
|
||||
"""Poll once for the access token. GitHub returns HTTP 200 with an
|
||||
``error`` field (``authorization_pending``/``slow_down``) while the user
|
||||
hasn't authorised yet, or ``{access_token, ...}`` once they have.
|
||||
"""
|
||||
r = httpx.post(
|
||||
access_token_url(host),
|
||||
headers=_oauth_post_headers(),
|
||||
json={
|
||||
"client_id": COPILOT_CLIENT_ID,
|
||||
"device_code": device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
|
||||
def fetch_models(base: str, token: str, *, timeout: float = 15.0) -> List[Dict]:
|
||||
"""Fetch Copilot's model catalogue, filtered to picker-enabled models.
|
||||
|
||||
Returns a list of ``{id, tool_calls, vision}`` dicts. Falls back to the
|
||||
full list if no model advertises ``model_picker_enabled`` (defensive
|
||||
against API-shape drift).
|
||||
"""
|
||||
url = base.rstrip("/") + "/models"
|
||||
r = httpx.get(url, headers=copilot_headers(token), timeout=timeout)
|
||||
r.raise_for_status()
|
||||
data = (r.json() or {}).get("data") or []
|
||||
|
||||
def _parse(item: Dict) -> Optional[Dict]:
|
||||
mid = item.get("id")
|
||||
if not mid:
|
||||
return None
|
||||
supports = ((item.get("capabilities") or {}).get("supports")) or {}
|
||||
return {
|
||||
"id": mid,
|
||||
"tool_calls": bool(supports.get("tool_calls")),
|
||||
"vision": bool(supports.get("vision")),
|
||||
"picker": bool(item.get("model_picker_enabled")),
|
||||
}
|
||||
|
||||
parsed = [p for p in (_parse(it) for it in data) if p]
|
||||
picker = [p for p in parsed if p["picker"]]
|
||||
chosen = picker or parsed
|
||||
for p in chosen:
|
||||
p.pop("picker", None)
|
||||
return chosen
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-request header flags
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMAGE_PART_TYPES = ("image_url", "input_image", "image")
|
||||
|
||||
|
||||
def request_flags(messages) -> tuple:
|
||||
"""Derive ``(agent, vision)`` from an OpenAI-style message list.
|
||||
|
||||
Mirrors opencode's logic:
|
||||
* ``agent`` — the last message is *not* a plain user message (i.e. it's a
|
||||
tool result / assistant follow-up), so Copilot should treat the request
|
||||
as agent-initiated for request accounting.
|
||||
* ``vision`` — any message carries an image content part.
|
||||
"""
|
||||
msgs = messages or []
|
||||
last = msgs[-1] if msgs else None
|
||||
agent = bool(last) and last.get("role") != "user"
|
||||
vision = False
|
||||
for m in msgs:
|
||||
content = m.get("content") if isinstance(m, dict) else None
|
||||
if isinstance(content, list) and any(
|
||||
isinstance(p, dict) and p.get("type") in _IMAGE_PART_TYPES for p in content
|
||||
):
|
||||
vision = True
|
||||
break
|
||||
return agent, vision
|
||||
|
||||
|
||||
def apply_request_headers(headers: Dict[str, str], messages) -> Dict[str, str]:
|
||||
"""Set ``x-initiator`` / ``Copilot-Vision-Request`` on a header dict based
|
||||
on the outgoing messages. Mutates and returns ``headers``."""
|
||||
agent, vision = request_flags(messages)
|
||||
headers["x-initiator"] = "agent" if agent else "user"
|
||||
if vision:
|
||||
headers["Copilot-Vision-Request"] = "true"
|
||||
return headers
|
||||
|
||||
@@ -196,6 +196,8 @@ class DeepResearcher:
|
||||
max_content_chars: int = 15000,
|
||||
max_report_tokens: int = 8192,
|
||||
extraction_timeout: int = 90,
|
||||
planning_timeout: int = 90,
|
||||
query_timeout: int = 120,
|
||||
extraction_concurrency: int = 3,
|
||||
min_rounds: int = 2,
|
||||
max_empty_rounds: int = 2,
|
||||
@@ -215,6 +217,8 @@ class DeepResearcher:
|
||||
self.max_content_chars = max_content_chars
|
||||
self.max_report_tokens = max_report_tokens
|
||||
self.extraction_timeout = min(3600, max(15, int(extraction_timeout or 90)))
|
||||
self.planning_timeout = min(3600, max(15, int(planning_timeout or 90)))
|
||||
self.query_timeout = min(3600, max(15, int(query_timeout or 120)))
|
||||
self.extraction_concurrency = min(12, max(1, int(extraction_concurrency or 3)))
|
||||
self.min_rounds = min_rounds
|
||||
self.max_empty_rounds = max_empty_rounds
|
||||
@@ -395,7 +399,7 @@ class DeepResearcher:
|
||||
[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
max_tokens=1024,
|
||||
timeout=30,
|
||||
timeout=getattr(self, "planning_timeout", 90),
|
||||
)
|
||||
# Try to parse as JSON for structured plan
|
||||
parsed = self._parse_json_object(response)
|
||||
@@ -478,6 +482,7 @@ class DeepResearcher:
|
||||
[{"role": "user", "content": prompt}],
|
||||
temperature=0.5,
|
||||
max_tokens=4096,
|
||||
timeout=getattr(self, "query_timeout", 120),
|
||||
)
|
||||
queries = self._parse_json_array(response)
|
||||
# Deduplicate
|
||||
|
||||
@@ -194,6 +194,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
|
||||
headers["x-api-key"] = api_key
|
||||
headers["anthropic-version"] = "2023-06-01"
|
||||
return headers
|
||||
if provider == "copilot":
|
||||
from src.copilot import copilot_headers
|
||||
return copilot_headers(api_key)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
if provider == "openrouter":
|
||||
|
||||
@@ -67,7 +67,7 @@ _host_health_lock = threading.Lock()
|
||||
_model_activity: Dict[str, float] = {}
|
||||
|
||||
def _model_activity_key(url: str, model: str) -> str:
|
||||
return f"{(url or '').strip().rstrip()}|{(model or '').strip()}"
|
||||
return f"{(url or '').strip()}|{(model or '').strip()}"
|
||||
|
||||
def note_model_activity(url: str, model: str):
|
||||
"""Record that a real upstream request used this endpoint/model."""
|
||||
@@ -317,6 +317,9 @@ def _detect_provider(url: str) -> str:
|
||||
return "openrouter"
|
||||
if _host_match(url, "groq.com"):
|
||||
return "groq"
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(url):
|
||||
return "copilot"
|
||||
return "openai"
|
||||
|
||||
|
||||
@@ -327,6 +330,14 @@ def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str
|
||||
if provider == "openrouter":
|
||||
h.setdefault("HTTP-Referer", "https://github.com/pewdiepie-archdaemon/odysseus")
|
||||
h.setdefault("X-OpenRouter-Title", "Odysseus")
|
||||
if provider == "copilot":
|
||||
# Ensure the Copilot-required headers are present even when the caller
|
||||
# didn't pass pre-built headers (e.g. model listing). build_headers()
|
||||
# already injects these for the live chat path; setdefault keeps any
|
||||
# request-specific values (x-initiator/vision) the caller set.
|
||||
from src.copilot import copilot_headers
|
||||
for k, v in copilot_headers(None).items():
|
||||
h.setdefault(k, v)
|
||||
return h
|
||||
|
||||
|
||||
@@ -340,6 +351,8 @@ def _provider_label(url: str) -> str:
|
||||
if _host_match(url, "openai.com"): return "OpenAI"
|
||||
if _host_match(url, "openrouter.ai"): return "OpenRouter"
|
||||
if _host_match(url, "groq.com"): return "Groq"
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(url): return "GitHub Copilot"
|
||||
if _host_match(url, "mistral.ai"): return "Mistral"
|
||||
if _host_match(url, "deepseek.com"): return "DeepSeek"
|
||||
if _host_match(url, "googleapis.com"): return "Google"
|
||||
@@ -481,7 +494,7 @@ def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=Fa
|
||||
chat_messages = []
|
||||
for m in messages:
|
||||
if m.get("role") == "system":
|
||||
system_parts.append(m["content"])
|
||||
system_parts.append(m.get("content") or "")
|
||||
elif m.get("role") == "tool":
|
||||
# Convert OpenAI tool result to Anthropic format
|
||||
chat_messages.append({
|
||||
@@ -884,7 +897,7 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
|
||||
non_sys = []
|
||||
for m in messages_copy:
|
||||
if m.get("role") == "system":
|
||||
sys_parts.append(m["content"])
|
||||
sys_parts.append(m.get('content') or '')
|
||||
else:
|
||||
non_sys.append(m)
|
||||
if sys_parts:
|
||||
@@ -911,6 +924,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
|
||||
)
|
||||
else:
|
||||
target_url = url
|
||||
if provider == "copilot":
|
||||
from src.copilot import apply_request_headers
|
||||
apply_request_headers(h, messages_copy)
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages_copy,
|
||||
@@ -1028,7 +1044,7 @@ async def llm_call_async(
|
||||
non_sys = []
|
||||
for m in messages_copy:
|
||||
if m.get("role") == "system":
|
||||
sys_parts.append(m["content"])
|
||||
sys_parts.append(m.get('content') or '')
|
||||
else:
|
||||
non_sys.append(m)
|
||||
if sys_parts:
|
||||
@@ -1058,6 +1074,9 @@ async def llm_call_async(
|
||||
else:
|
||||
target_url = url
|
||||
h = _provider_headers(provider, headers)
|
||||
if provider == "copilot":
|
||||
from src.copilot import apply_request_headers
|
||||
apply_request_headers(h, messages_copy)
|
||||
payload = {
|
||||
"model": model,
|
||||
"messages": messages_copy,
|
||||
@@ -1088,6 +1107,9 @@ async def llm_call_async(
|
||||
f"LLM async call to {target_url} failed in {duration:.2f}s "
|
||||
f"(attempt {attempt}): HTTP {r.status_code} {friendly}"
|
||||
)
|
||||
if r.status_code in (429, 502, 503, 504) and attempt < max_retries:
|
||||
await asyncio.sleep(LLMConfig.RETRY_DELAY)
|
||||
continue
|
||||
raise HTTPException(r.status_code, friendly)
|
||||
logger.info(f"LLM async call to {target_url} succeeded in {duration:.2f}s (attempt {attempt})")
|
||||
_clear_host_dead(target_url)
|
||||
@@ -1109,7 +1131,9 @@ async def llm_call_async(
|
||||
duration = time.time() - start
|
||||
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
|
||||
logger.warning(f"LLM async connect to {target_url} failed after {duration:.2f}s: {e}{_tail}")
|
||||
raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
|
||||
if _cooled or attempt >= max_retries:
|
||||
raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
|
||||
await asyncio.sleep(LLMConfig.RETRY_DELAY)
|
||||
except (httpx.RequestError, httpx.HTTPStatusError) as e:
|
||||
duration = time.time() - start
|
||||
logger.warning(f"LLM async call attempt {attempt} failed after {duration:.2f}s: {e}")
|
||||
@@ -1138,7 +1162,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
non_sys = []
|
||||
for m in messages_copy:
|
||||
if m.get("role") == "system":
|
||||
sys_parts.append(m["content"])
|
||||
sys_parts.append(m.get('content') or '')
|
||||
else:
|
||||
non_sys.append(m)
|
||||
if sys_parts:
|
||||
@@ -1177,6 +1201,9 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
if tools:
|
||||
payload["tools"] = tools
|
||||
h = _provider_headers(provider, headers)
|
||||
if provider == "copilot":
|
||||
from src.copilot import apply_request_headers
|
||||
apply_request_headers(h, messages_copy)
|
||||
|
||||
# Short connect timeout: a reachable peer answers SYN in <100ms even on
|
||||
# Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
|
||||
@@ -1358,6 +1385,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
# can detect thinking-in-progress (some models output </think> but no <think>)
|
||||
_thinking_model = _supports_thinking(model)
|
||||
_first_content_sent = False
|
||||
_in_think_tag = False # True while consuming <think>…</think> content
|
||||
_think_open_stripped = False # opening <think> tag already removed
|
||||
|
||||
def _emit_tool_calls():
|
||||
"""Build the tool_calls event string if any were accumulated."""
|
||||
@@ -1439,14 +1468,53 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
|
||||
yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n'
|
||||
content = delta.get("content") or ""
|
||||
if content:
|
||||
# Some thinking backends start normal content with a
|
||||
# stray closing tag. Repair only that shape; do not
|
||||
# wrap every first token for model families like
|
||||
# MiniMax, which often stream ordinary answers.
|
||||
if _thinking_model and not _first_content_sent and content.lstrip().lower().startswith("</think"):
|
||||
content = "<think>" + content
|
||||
_first_content_sent = True
|
||||
yield f'data: {json.dumps({"delta": content})}\n\n'
|
||||
stripped = content.lstrip()
|
||||
# Auto-detect <think>…</think> in content stream.
|
||||
# Covers Qwen3-derived models (Qwopus, QwQ forks) whose
|
||||
# names don't match _THINKING_MODEL_PATTERNS but still
|
||||
# emit literal <think> markup via llama.cpp --jinja.
|
||||
if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("<think"):
|
||||
_thinking_model = True
|
||||
_in_think_tag = True
|
||||
if _in_think_tag:
|
||||
close_idx = content.lower().find("</think>")
|
||||
if close_idx != -1:
|
||||
# Split: up-to-</think> → thinking, remainder → content
|
||||
think_part = content[:close_idx]
|
||||
if not _think_open_stripped:
|
||||
# Strip the opening <think[...] > from the first chunk.
|
||||
# Use a dedicated flag — _first_content_sent stays False
|
||||
# throughout the think block, so it must not be reused.
|
||||
tag_end = think_part.lower().find(">")
|
||||
if tag_end != -1:
|
||||
think_part = think_part[tag_end + 1:]
|
||||
_think_open_stripped = True
|
||||
regular_part = content[close_idx + len("</think>"):]
|
||||
_in_think_tag = False
|
||||
if think_part:
|
||||
yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
|
||||
if regular_part:
|
||||
_first_content_sent = True
|
||||
yield f'data: {json.dumps({"delta": regular_part})}\n\n'
|
||||
else:
|
||||
# Still inside <think>: route to thinking channel
|
||||
if not _think_open_stripped:
|
||||
# Strip the opening <think[...] > tag (first chunk only)
|
||||
tag_end = stripped.lower().find(">")
|
||||
if tag_end != -1:
|
||||
content = stripped[tag_end + 1:]
|
||||
_think_open_stripped = True
|
||||
if content:
|
||||
yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
|
||||
else:
|
||||
# Some thinking backends start normal content with a
|
||||
# stray closing tag. Repair only that shape; do not
|
||||
# wrap every first token for model families like
|
||||
# MiniMax, which often stream ordinary answers.
|
||||
if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
|
||||
content = "<think>" + content
|
||||
_first_content_sent = True
|
||||
yield f'data: {json.dumps({"delta": content})}\n\n'
|
||||
# Native tool calls — accumulate across chunks
|
||||
for tc in delta.get("tool_calls") or []:
|
||||
if tc is None:
|
||||
|
||||
@@ -8,6 +8,7 @@ Each server exposes tools that are made available to the agent loop.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -30,6 +31,64 @@ def _format_mcp_connection_error(name: str, command: str = "", args: Optional[Li
|
||||
return raw_error
|
||||
|
||||
|
||||
# Caps for rendering untrusted MCP tool schemas into the agent prompt (issue #2660).
|
||||
# MCP servers are third-party/user-added, so field names and parameter counts are
|
||||
# untrusted input — bound them so an odd or hostile schema cannot distort the prompt.
|
||||
_MCP_PARAM_MAX = 12 # max params rendered per tool
|
||||
_MCP_TOKEN_MAX = 40 # max chars per rendered name / type token
|
||||
_MCP_HINT_MAX = 300 # total-length backstop for the whole hint
|
||||
|
||||
|
||||
def _sanitize_schema_token(value: Any, limit: int = _MCP_TOKEN_MAX) -> str:
|
||||
"""Make an untrusted JSON-Schema token safe to splice into the prompt.
|
||||
|
||||
Replaces control chars / newlines with a space, collapses whitespace, and
|
||||
length-caps the result, so a weird field name or type cannot inject newlines
|
||||
or run on. Normal short identifiers pass through unchanged.
|
||||
"""
|
||||
text = re.sub(r"[\x00-\x1f\x7f]+", " ", str(value))
|
||||
text = re.sub(r"\s+", " ", text).strip()
|
||||
if len(text) > limit:
|
||||
text = text[:limit].rstrip() + "…"
|
||||
return text
|
||||
|
||||
|
||||
def _format_mcp_params(input_schema: Any) -> str:
|
||||
"""Render an MCP tool's JSON-Schema inputs as a compact prompt hint.
|
||||
|
||||
Without this the agent only sees a tool's name + description and has to
|
||||
guess its arguments (issue #2509). Produces e.g.
|
||||
` Args (JSON): {"path": string (required), "limit": integer}` — names,
|
||||
coarse types, and required-ness, kept short so it stays prompt-friendly.
|
||||
Returns "" when there are no parameters.
|
||||
|
||||
MCP servers are third-party, so names/types are sanitized and the parameter
|
||||
count + total length are capped (issue #2660); normal schemas are unaffected.
|
||||
"""
|
||||
if not isinstance(input_schema, dict):
|
||||
return ""
|
||||
props = input_schema.get("properties")
|
||||
if not isinstance(props, dict) or not props:
|
||||
return ""
|
||||
required = set(input_schema.get("required") or [])
|
||||
parts = []
|
||||
for pname, pinfo in list(props.items())[:_MCP_PARAM_MAX]:
|
||||
pinfo = pinfo if isinstance(pinfo, dict) else {}
|
||||
ptype = pinfo.get("type") or "any"
|
||||
if isinstance(ptype, list):
|
||||
ptype = "|".join(str(x) for x in ptype)
|
||||
tag = f'"{_sanitize_schema_token(pname)}": {_sanitize_schema_token(ptype)}'
|
||||
if pname in required:
|
||||
tag += " (required)"
|
||||
parts.append(tag)
|
||||
extra = len(props) - len(parts)
|
||||
if extra > 0:
|
||||
parts.append(f"…+{extra} more")
|
||||
hint = " Args (JSON): {" + ", ".join(parts) + "}"
|
||||
if len(hint) > _MCP_HINT_MAX:
|
||||
hint = hint[:_MCP_HINT_MAX - 1].rstrip() + "…"
|
||||
return hint
|
||||
|
||||
|
||||
class McpManager:
|
||||
"""Manages MCP server connections and tool routing."""
|
||||
@@ -43,7 +102,9 @@ class McpManager:
|
||||
self._sessions: Dict[str, Any] = {}
|
||||
# server_id -> exit stack (for cleanup)
|
||||
self._stacks: Dict[str, Any] = {}
|
||||
# Tracking updates to tools/connections for RAG indexing
|
||||
# server_id -> background connect task (HTTP transport / OAuth)
|
||||
self._connect_tasks: Dict[str, Any] = {}
|
||||
# Tracking updates to tools/connections for RAG indexing / prompt cache
|
||||
self._generation = 0
|
||||
|
||||
async def connect_server(
|
||||
@@ -56,12 +117,14 @@ class McpManager:
|
||||
env: Optional[Dict[str, str]] = None,
|
||||
url: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Connect to an MCP server via stdio or SSE transport."""
|
||||
"""Connect to an MCP server via stdio, SSE, or Streamable HTTP transport."""
|
||||
try:
|
||||
if transport == "stdio":
|
||||
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
|
||||
elif transport == "sse":
|
||||
res = await self._connect_sse(server_id, name, url)
|
||||
elif transport == "http":
|
||||
res = await self._start_http_connect(server_id, name, url)
|
||||
else:
|
||||
logger.error(f"Unknown MCP transport: {transport}")
|
||||
res = False
|
||||
@@ -184,8 +247,101 @@ class McpManager:
|
||||
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
|
||||
return False
|
||||
|
||||
async def _start_http_connect(self, server_id: str, name: str, url: str, wait: float = 8.0) -> bool:
|
||||
"""Begin a Streamable HTTP connect in the background. Returns within
|
||||
`wait` seconds: True if it connected (cached-token path), otherwise the
|
||||
flow is awaiting browser authorization and status becomes 'needs_auth'."""
|
||||
import asyncio
|
||||
self._connections[server_id] = {"status": "connecting", "name": name, "transport": "http"}
|
||||
task = asyncio.create_task(self._connect_http(server_id, name, url))
|
||||
self._connect_tasks[server_id] = task
|
||||
done, _ = await asyncio.wait({task}, timeout=wait)
|
||||
if task in done:
|
||||
try:
|
||||
return task.result()
|
||||
except Exception as e:
|
||||
self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
|
||||
return False
|
||||
# Still running → either awaiting authorization, or discovery/DCR is
|
||||
# still in flight. If _on_redirect already published needs_auth+auth_url,
|
||||
# leave it; otherwise mark needs_auth (auth_url filled in once it fires).
|
||||
from src.mcp_oauth import pop_auth_url
|
||||
cur = self._connections.get(server_id, {})
|
||||
if cur.get("status") != "needs_auth":
|
||||
self._connections[server_id] = {
|
||||
"status": "needs_auth", "name": name, "transport": "http",
|
||||
"auth_url": pop_auth_url(server_id),
|
||||
}
|
||||
return False
|
||||
|
||||
async def _connect_http(self, server_id: str, name: str, url: str) -> bool:
|
||||
"""Connect to a Streamable HTTP MCP server (with automatic OAuth)."""
|
||||
try:
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from contextlib import AsyncExitStack
|
||||
from src.mcp_oauth import build_provider, clear_auth_url
|
||||
|
||||
def _on_redirect(auth_url):
|
||||
# Publish needs_auth the moment the URL is known, independent of
|
||||
# how long discovery/DCR took (may exceed the bounded start wait).
|
||||
self._connections[server_id] = {
|
||||
"status": "needs_auth", "name": name, "transport": "http",
|
||||
"auth_url": auth_url,
|
||||
}
|
||||
|
||||
provider = build_provider(server_id, url, on_redirect=_on_redirect)
|
||||
stack = AsyncExitStack()
|
||||
transport = await stack.enter_async_context(streamablehttp_client(url, auth=provider))
|
||||
read_stream, write_stream, _get_session_id = transport
|
||||
session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
|
||||
await session.initialize()
|
||||
|
||||
tools_result = await session.list_tools()
|
||||
tools = []
|
||||
for tool in tools_result.tools:
|
||||
tools.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
|
||||
})
|
||||
|
||||
self._sessions[server_id] = session
|
||||
self._stacks[server_id] = stack
|
||||
self._tools[server_id] = tools
|
||||
self._connections[server_id] = {
|
||||
"status": "connected", "name": name, "transport": "http",
|
||||
"tool_count": len(tools),
|
||||
}
|
||||
clear_auth_url(server_id)
|
||||
# Tools changed (this can complete after connect_server already
|
||||
# returned, via the background OAuth flow), so bump the generation
|
||||
# to invalidate the tool-prompt cache.
|
||||
self._generation += 1
|
||||
logger.info(f"MCP server connected: {name} ({server_id}) - {len(tools)} tools via http")
|
||||
return True
|
||||
except ImportError:
|
||||
logger.warning("MCP package not installed. Install with: pip install mcp")
|
||||
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect HTTP MCP server {name} ({server_id}): {e}")
|
||||
self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
|
||||
return False
|
||||
|
||||
async def disconnect_server(self, server_id: str):
|
||||
"""Disconnect from an MCP server."""
|
||||
# Cancel any in-flight HTTP/OAuth background connect so it stops
|
||||
# publishing status for a server that may be getting deleted.
|
||||
task = self._connect_tasks.pop(server_id, None)
|
||||
if task is not None and not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
from src.mcp_oauth import clear_auth_url
|
||||
clear_auth_url(server_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
stack = self._stacks.pop(server_id, None)
|
||||
if stack:
|
||||
try:
|
||||
@@ -376,6 +532,7 @@ class McpManager:
|
||||
"name": tool["name"],
|
||||
"qualified_name": f"mcp__{server_id}__{tool['name']}",
|
||||
"description": tool.get("description", ""),
|
||||
"input_schema": tool.get("input_schema") or {},
|
||||
"is_disabled": tool["name"] in disabled,
|
||||
})
|
||||
return result
|
||||
@@ -439,7 +596,11 @@ class McpManager:
|
||||
for t in server_tools:
|
||||
# Truncate long descriptions
|
||||
desc = t['description'][:120] + '...' if len(t['description']) > 120 else t['description']
|
||||
lines.append(f" - {t['qualified_name']}: {desc}")
|
||||
# Include the tool's declared inputs so the model calls it with
|
||||
# real argument names instead of guessing from the description
|
||||
# alone (issue #2509).
|
||||
args_hint = _format_mcp_params(t.get("input_schema"))
|
||||
lines.append(f" - {t['qualified_name']}: {desc}{args_hint}")
|
||||
|
||||
result = "\n".join(lines)
|
||||
self._cached_prompt_desc = result
|
||||
|
||||
193
src/mcp_oauth.py
Normal file
193
src/mcp_oauth.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""mcp_oauth.py — generic OAuth for remote (Streamable HTTP) MCP servers.
|
||||
|
||||
Bridges the mcp SDK's OAuthClientProvider (RFC 9728 discovery, Dynamic Client
|
||||
Registration, authorization-code + PKCE, token refresh) to Odysseus's web
|
||||
callback route. Tokens and the dynamic registration persist per-server,
|
||||
encrypted, so the interactive flow runs only once.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OAuth redirect URI registered with every authorization server via DCR. Loopback
|
||||
# is allowed for native/desktop clients (RFC 8252); remote users finish via the
|
||||
# paste-back flow. Deployments not reachable at http://localhost:7000 (custom
|
||||
# port, reverse proxy, or public domain) must set OAUTH_REDIRECT_BASE_URL (or
|
||||
# APP_PUBLIC_URL) to their externally reachable origin so the redirect lands back
|
||||
# on Odysseus. APP_PORT is intentionally not used: it is only the Docker host
|
||||
# port-map; the app always listens on 7000 inside the container.
|
||||
_REDIRECT_BASE = (
|
||||
os.environ.get("OAUTH_REDIRECT_BASE_URL")
|
||||
or os.environ.get("APP_PUBLIC_URL")
|
||||
or "http://localhost:7000"
|
||||
).rstrip("/")
|
||||
REDIRECT_URI = f"{_REDIRECT_BASE}/api/mcp/oauth/callback"
|
||||
|
||||
# How long the background connect waits for the user to authorize before giving up.
|
||||
AUTH_WAIT_SECONDS = 300
|
||||
|
||||
_pending: Dict[str, asyncio.Future] = {} # state -> Future[(code, state)]
|
||||
_pending_ts: Dict[str, float] = {} # state -> monotonic timestamp, for pruning
|
||||
_auth_urls: Dict[str, str] = {} # server_id -> authorization URL
|
||||
|
||||
|
||||
def _prune_stale() -> None:
|
||||
"""Drop abandoned flows whose authorization window has elapsed so the
|
||||
module-level registries don't grow unbounded (e.g. a user who never
|
||||
finishes the browser step)."""
|
||||
now = time.monotonic()
|
||||
for state in [s for s, ts in _pending_ts.items() if now - ts > AUTH_WAIT_SECONDS]:
|
||||
fut = _pending.pop(state, None)
|
||||
_pending_ts.pop(state, None)
|
||||
if fut is not None and not fut.done():
|
||||
fut.cancel()
|
||||
|
||||
|
||||
def _discard_pending(state: Optional[str]) -> None:
|
||||
if state is None:
|
||||
return
|
||||
_pending.pop(state, None)
|
||||
_pending_ts.pop(state, None)
|
||||
|
||||
|
||||
def register_pending(state: str) -> asyncio.Future:
|
||||
_prune_stale()
|
||||
fut = asyncio.get_running_loop().create_future()
|
||||
_pending[state] = fut
|
||||
_pending_ts[state] = time.monotonic()
|
||||
return fut
|
||||
|
||||
|
||||
def resolve_pending(state: str, code: str) -> bool:
|
||||
fut = _pending.get(state)
|
||||
if fut is not None and not fut.done():
|
||||
fut.set_result((code, state))
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def pop_auth_url(server_id: str) -> Optional[str]:
|
||||
return _auth_urls.get(server_id)
|
||||
|
||||
|
||||
def clear_auth_url(server_id: str) -> None:
|
||||
_auth_urls.pop(server_id, None)
|
||||
|
||||
|
||||
class DbTokenStorage:
|
||||
"""SDK TokenStorage backed by the encrypted McpServer.oauth_tokens column."""
|
||||
|
||||
def __init__(self, server_id: str, session_factory=None):
|
||||
self.server_id = server_id
|
||||
if session_factory is None:
|
||||
from core.database import SessionLocal
|
||||
session_factory = SessionLocal
|
||||
self._sf = session_factory
|
||||
|
||||
def _load(self) -> dict:
|
||||
from core.database import McpServer
|
||||
db = self._sf()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
|
||||
if srv and srv.oauth_tokens:
|
||||
return json.loads(srv.oauth_tokens)
|
||||
finally:
|
||||
db.close()
|
||||
return {}
|
||||
|
||||
def _update(self, key: str, value: dict) -> None:
|
||||
"""Load, set one key, and persist the oauth_tokens JSON in a single
|
||||
session/commit (avoids the load+save double round-trip per write)."""
|
||||
from core.database import McpServer
|
||||
db = self._sf()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
|
||||
if srv is None:
|
||||
return
|
||||
data = json.loads(srv.oauth_tokens) if srv.oauth_tokens else {}
|
||||
data[key] = value
|
||||
srv.oauth_tokens = json.dumps(data)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def get_tokens(self):
|
||||
from mcp.shared.auth import OAuthToken
|
||||
data = self._load().get("tokens")
|
||||
return OAuthToken.model_validate(data) if data else None
|
||||
|
||||
async def set_tokens(self, tokens) -> None:
|
||||
self._update("tokens", json.loads(tokens.model_dump_json()))
|
||||
|
||||
async def get_client_info(self):
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
data = self._load().get("client_info")
|
||||
return OAuthClientInformationFull.model_validate(data) if data else None
|
||||
|
||||
async def set_client_info(self, client_info) -> None:
|
||||
self._update("client_info", json.loads(client_info.model_dump_json()))
|
||||
|
||||
|
||||
def build_provider(server_id: str, url: str, on_redirect=None):
|
||||
"""Construct an OAuthClientProvider that drives the browser flow via the
|
||||
Odysseus callback route.
|
||||
|
||||
on_redirect(authorization_url): optional sync callback invoked the moment
|
||||
the authorization URL is known (after discovery + DCR). The manager uses it
|
||||
to publish 'needs_auth' + auth_url to connection state regardless of how
|
||||
long discovery/DCR took.
|
||||
"""
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="Odysseus",
|
||||
redirect_uris=[REDIRECT_URI],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
# Leave scope unset: the SDK applies the MCP scope-selection strategy and
|
||||
# overwrites this from the server's WWW-Authenticate / protected-resource
|
||||
# metadata before building the auth URL. Hardcoding an OIDC scope here
|
||||
# would break the many MCP servers that are not OpenID providers.
|
||||
scope=None,
|
||||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
async def redirect_handler(authorization_url: str) -> None:
|
||||
state = (parse_qs(urlparse(authorization_url).query).get("state") or [None])[0]
|
||||
if state:
|
||||
register_pending(state)
|
||||
_auth_urls[server_id] = authorization_url
|
||||
if on_redirect is not None:
|
||||
try:
|
||||
on_redirect(authorization_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"MCP OAuth on_redirect callback failed: {e}")
|
||||
logger.info(f"MCP OAuth: server {server_id} awaiting authorization (state={state})")
|
||||
|
||||
async def callback_handler() -> Tuple[str, Optional[str]]:
|
||||
auth_url = _auth_urls.get(server_id)
|
||||
state = (parse_qs(urlparse(auth_url).query).get("state") or [None])[0] if auth_url else None
|
||||
fut = _pending.get(state)
|
||||
if fut is None:
|
||||
raise RuntimeError("No pending OAuth flow for this server")
|
||||
try:
|
||||
code, ret_state = await asyncio.wait_for(fut, timeout=AUTH_WAIT_SECONDS)
|
||||
return code, ret_state
|
||||
finally:
|
||||
_discard_pending(state)
|
||||
_auth_urls.pop(server_id, None)
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=url,
|
||||
client_metadata=client_metadata,
|
||||
storage=DbTokenStorage(server_id),
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
)
|
||||
320
src/memory_provider.py
Normal file
320
src/memory_provider.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""Memory provider interfaces for native and external memory systems."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemoryRecord:
|
||||
"""Provider-neutral memory entry."""
|
||||
|
||||
id: str
|
||||
text: str
|
||||
timestamp: int = 0
|
||||
category: str = "fact"
|
||||
source: str = "unknown"
|
||||
owner: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemorySearchHit:
|
||||
"""A memory returned by provider recall."""
|
||||
|
||||
memory: MemoryRecord
|
||||
provider_id: str
|
||||
score: Optional[float] = None
|
||||
|
||||
|
||||
class MemoryProvider(ABC):
|
||||
"""Base contract for Odysseus memory providers.
|
||||
|
||||
The native memory provider should always be available. External providers
|
||||
can add recall/write behavior and their own tools without replacing the
|
||||
built-in local memory baseline.
|
||||
"""
|
||||
|
||||
provider_id = "unknown"
|
||||
display_name = "Unknown"
|
||||
enabled = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Prepare provider resources before use."""
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Release provider resources."""
|
||||
|
||||
@abstractmethod
|
||||
async def remember(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
category: str = "fact",
|
||||
source: str = "user",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> MemoryRecord:
|
||||
"""Store a memory and return the stored record."""
|
||||
|
||||
@abstractmethod
|
||||
async def recall(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
top_k: int = 5,
|
||||
) -> List[MemorySearchHit]:
|
||||
"""Return provider memories relevant to the query."""
|
||||
|
||||
@abstractmethod
|
||||
async def list_memories(
|
||||
self,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[MemoryRecord]:
|
||||
"""List memories visible to the owner."""
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
|
||||
"""Delete a memory by ID when allowed by the provider."""
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
"""Return provider-defined tool schemas when this provider is enabled."""
|
||||
return []
|
||||
|
||||
async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""Handle a provider-defined tool call."""
|
||||
raise KeyError(f"Provider {self.provider_id} does not expose tool {name}")
|
||||
|
||||
|
||||
class NativeMemoryProvider(MemoryProvider):
|
||||
"""Provider adapter for Odysseus' built-in memory manager and vector store."""
|
||||
|
||||
provider_id = "native"
|
||||
display_name = "Odysseus native memory"
|
||||
|
||||
_CORE_FIELDS = {
|
||||
"id",
|
||||
"text",
|
||||
"timestamp",
|
||||
"source",
|
||||
"category",
|
||||
"uses",
|
||||
"owner",
|
||||
"session_id",
|
||||
"metadata",
|
||||
}
|
||||
|
||||
def __init__(self, memory_manager, memory_vector=None):
|
||||
self.memory_manager = memory_manager
|
||||
self.memory_vector = memory_vector
|
||||
|
||||
def _to_record(self, entry: Dict[str, Any]) -> MemoryRecord:
|
||||
metadata = {
|
||||
key: value
|
||||
for key, value in entry.items()
|
||||
if key not in self._CORE_FIELDS
|
||||
}
|
||||
stored_metadata = entry.get("metadata")
|
||||
if isinstance(stored_metadata, dict):
|
||||
metadata.update(stored_metadata)
|
||||
|
||||
return MemoryRecord(
|
||||
id=entry.get("id", ""),
|
||||
text=entry.get("text", ""),
|
||||
timestamp=entry.get("timestamp", 0),
|
||||
category=entry.get("category", "fact"),
|
||||
source=entry.get("source", "unknown"),
|
||||
owner=entry.get("owner"),
|
||||
session_id=entry.get("session_id"),
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def remember(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
category: str = "fact",
|
||||
source: str = "user",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> MemoryRecord:
|
||||
entry = self.memory_manager.add_entry(
|
||||
text,
|
||||
source=source,
|
||||
category=category,
|
||||
owner=owner,
|
||||
)
|
||||
if session_id:
|
||||
entry["session_id"] = session_id
|
||||
if metadata:
|
||||
entry["metadata"] = dict(metadata)
|
||||
|
||||
memories = self.memory_manager.load_all()
|
||||
memories.append(entry)
|
||||
self.memory_manager.save(memories)
|
||||
|
||||
if self._vector_available():
|
||||
self.memory_vector.add(entry["id"], entry["text"])
|
||||
|
||||
return self._to_record(entry)
|
||||
|
||||
async def recall(
|
||||
self,
|
||||
query: str,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
top_k: int = 5,
|
||||
) -> List[MemorySearchHit]:
|
||||
memories = self.memory_manager.load(owner=owner)
|
||||
by_id = {m.get("id"): m for m in memories}
|
||||
|
||||
if self._vector_available():
|
||||
hits: List[MemorySearchHit] = []
|
||||
for result in self.memory_vector.search(query, k=top_k):
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
memory_id = result.get("memory_id")
|
||||
entry = by_id.get(memory_id) if memory_id else result
|
||||
if not entry:
|
||||
continue
|
||||
if owner is not None and entry.get("owner") != owner:
|
||||
continue
|
||||
hits.append(
|
||||
MemorySearchHit(
|
||||
memory=self._to_record(entry),
|
||||
provider_id=self.provider_id,
|
||||
score=result.get("score"),
|
||||
)
|
||||
)
|
||||
if hits:
|
||||
return hits
|
||||
|
||||
fallback = self.memory_manager.get_relevant_memories(
|
||||
query,
|
||||
memories,
|
||||
max_items=top_k,
|
||||
)
|
||||
return [
|
||||
MemorySearchHit(
|
||||
memory=self._to_record(entry),
|
||||
provider_id=self.provider_id,
|
||||
score=None,
|
||||
)
|
||||
for entry in fallback
|
||||
]
|
||||
|
||||
async def list_memories(
|
||||
self,
|
||||
*,
|
||||
owner: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
) -> List[MemoryRecord]:
|
||||
return [
|
||||
self._to_record(entry)
|
||||
for entry in self.memory_manager.load(owner=owner)[:limit]
|
||||
]
|
||||
|
||||
async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
|
||||
memories = self.memory_manager.load_all()
|
||||
remaining = []
|
||||
deleted_id = None
|
||||
|
||||
for entry in memories:
|
||||
if entry.get("id") != memory_id:
|
||||
remaining.append(entry)
|
||||
continue
|
||||
if owner is not None and entry.get("owner") != owner:
|
||||
remaining.append(entry)
|
||||
continue
|
||||
deleted_id = entry.get("id")
|
||||
|
||||
if deleted_id is None:
|
||||
return False
|
||||
|
||||
self.memory_manager.save(remaining)
|
||||
if self._vector_available():
|
||||
self.memory_vector.remove(deleted_id)
|
||||
return True
|
||||
|
||||
def _vector_available(self) -> bool:
|
||||
return bool(self.memory_vector and getattr(self.memory_vector, "healthy", True))
|
||||
|
||||
|
||||
class MemoryProviderRegistry:
|
||||
"""Container for native and optional external memory providers."""
|
||||
|
||||
def __init__(self, providers: Optional[Iterable[MemoryProvider]] = None):
|
||||
self._providers: Dict[str, MemoryProvider] = {}
|
||||
for provider in providers or []:
|
||||
self.register(provider)
|
||||
|
||||
def register(self, provider: MemoryProvider) -> None:
|
||||
if provider.provider_id in self._providers:
|
||||
raise ValueError(f"Memory provider already registered: {provider.provider_id}")
|
||||
self._providers[provider.provider_id] = provider
|
||||
|
||||
def get(self, provider_id: str) -> MemoryProvider:
|
||||
return self._providers[provider_id]
|
||||
|
||||
def all(self) -> List[MemoryProvider]:
|
||||
return list(self._providers.values())
|
||||
|
||||
def active(self) -> List[MemoryProvider]:
|
||||
return [provider for provider in self._providers.values() if provider.enabled]
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
schemas: List[Dict[str, Any]] = []
|
||||
seen: Dict[str, str] = {}
|
||||
|
||||
for provider in self.active():
|
||||
for schema in provider.get_tool_schemas():
|
||||
name = self._tool_name(schema)
|
||||
if name in seen:
|
||||
raise ValueError(
|
||||
f"Memory tool name conflict: {name} from "
|
||||
f"{provider.provider_id} already exposed by {seen[name]}"
|
||||
)
|
||||
seen[name] = provider.provider_id
|
||||
schemas.append(schema)
|
||||
|
||||
return schemas
|
||||
|
||||
async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
|
||||
provider_by_tool: Dict[str, MemoryProvider] = {}
|
||||
for provider in self.active():
|
||||
for schema in provider.get_tool_schemas():
|
||||
tool_name = self._tool_name(schema)
|
||||
if tool_name in provider_by_tool:
|
||||
raise ValueError(
|
||||
f"Memory tool name conflict: {tool_name} from "
|
||||
f"{provider.provider_id} already exposed by "
|
||||
f"{provider_by_tool[tool_name].provider_id}"
|
||||
)
|
||||
provider_by_tool[tool_name] = provider
|
||||
|
||||
provider = provider_by_tool.get(name)
|
||||
if provider:
|
||||
return await provider.handle_tool_call(name, arguments)
|
||||
raise KeyError(f"No active memory provider exposes tool {name}")
|
||||
|
||||
@staticmethod
|
||||
def _tool_name(schema: Dict[str, Any]) -> str:
|
||||
if not isinstance(schema, dict):
|
||||
raise ValueError("Memory provider tool schema must be a dict")
|
||||
name = schema.get("name")
|
||||
if isinstance(name, str) and name:
|
||||
return name
|
||||
function = schema.get("function")
|
||||
if isinstance(function, dict):
|
||||
function_name = function.get("name")
|
||||
if isinstance(function_name, str) and function_name:
|
||||
return function_name
|
||||
raise ValueError("Memory provider tool schema is missing a tool name")
|
||||
@@ -7,7 +7,7 @@ Provides token estimation for context usage tracking.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -208,27 +208,32 @@ KNOWN_CONTEXT_WINDOWS = {
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache
|
||||
# ---------------------------------------------------------------------------
|
||||
_context_cache: Dict[str, int] = {}
|
||||
_context_cache: Dict[Tuple[str, str], int] = {}
|
||||
|
||||
|
||||
def get_context_length(endpoint_url: str, model: str) -> int:
|
||||
"""Get the context window size for a model.
|
||||
|
||||
Queries /v1/models on the endpoint and looks for context_length
|
||||
or context_window fields. Caches result per model ID.
|
||||
or context_window fields. Caches result per (endpoint, model).
|
||||
Falls back to DEFAULT_CONTEXT if unavailable.
|
||||
"""
|
||||
configured_kind = _configured_endpoint_kind(endpoint_url)
|
||||
is_local = _is_local_endpoint(endpoint_url)
|
||||
if not is_local and model in _context_cache:
|
||||
return _context_cache[model]
|
||||
# Key on (endpoint_url, model): the same model id can be served by two
|
||||
# different remote endpoints with different real context windows (e.g. a
|
||||
# capped proxy vs. the full provider), so caching by model id alone would
|
||||
# serve one endpoint's window for the other (issue #2603).
|
||||
cache_key = (endpoint_url, model)
|
||||
if not is_local and cache_key in _context_cache:
|
||||
return _context_cache[cache_key]
|
||||
|
||||
ctx = _query_context_length(endpoint_url, model)
|
||||
# Only cache non-default values to allow retry on next request.
|
||||
# Local endpoints can restart with a different --max-model-len while keeping
|
||||
# the same model id, so always re-query them instead of serving stale cache.
|
||||
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
|
||||
_context_cache[model] = ctx
|
||||
_context_cache[cache_key] = ctx
|
||||
logger.info(f"Context length for {model}: {ctx}")
|
||||
return ctx
|
||||
|
||||
@@ -282,6 +287,16 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# GitHub Copilot's /models requires auth + X-GitHub-Api-Version headers that
|
||||
# aren't available here; an unauthenticated probe just 400s. All Copilot
|
||||
# picker models are major API models covered by the known-context table, so
|
||||
# rely on that instead of a doomed network call.
|
||||
from src.copilot import is_copilot_base
|
||||
if is_copilot_base(endpoint_url):
|
||||
if known:
|
||||
logger.info(f"Using known context window for {model}: {known}")
|
||||
return known or DEFAULT_CONTEXT
|
||||
|
||||
models_url = endpoint_url.replace("/chat/completions", "/models")
|
||||
try:
|
||||
r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)
|
||||
|
||||
@@ -722,6 +722,18 @@ class ResearchHandler:
|
||||
minimum=1,
|
||||
maximum=12,
|
||||
)
|
||||
_planning_timeout = _bounded_int(
|
||||
get_setting("research_planning_timeout_seconds", _extraction_timeout),
|
||||
default=_extraction_timeout,
|
||||
minimum=15,
|
||||
maximum=3600,
|
||||
)
|
||||
_query_timeout = _bounded_int(
|
||||
get_setting("research_query_timeout_seconds", _extraction_timeout),
|
||||
default=_extraction_timeout,
|
||||
minimum=15,
|
||||
maximum=3600,
|
||||
)
|
||||
|
||||
researcher = DeepResearcher(
|
||||
llm_endpoint=llm_endpoint,
|
||||
@@ -732,6 +744,8 @@ class ResearchHandler:
|
||||
max_time=max_time,
|
||||
max_report_tokens=_max_report_tokens,
|
||||
extraction_timeout=_extraction_timeout,
|
||||
planning_timeout=_planning_timeout,
|
||||
query_timeout=_query_timeout,
|
||||
extraction_concurrency=_extraction_concurrency,
|
||||
progress_callback=progress_callback,
|
||||
search_provider=search_provider,
|
||||
|
||||
@@ -1,141 +1,12 @@
|
||||
"""Search analytics, metrics tracking, and exception hierarchy."""
|
||||
"""Compatibility re-export shim for the live analytics module.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
The real implementation lives in :mod:`services.search.analytics`, which is
|
||||
what the search runtime imports. Alias this module to that implementation so
|
||||
mutable module state such as ``ANALYTICS_FILE`` cannot drift out of sync.
|
||||
"""
|
||||
|
||||
from .cache import cache_metrics
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from services.search import analytics as _analytics
|
||||
|
||||
# Dedicated error logger with file handler
|
||||
_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
|
||||
_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
|
||||
_error_handler.setLevel(logging.WARNING)
|
||||
_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
|
||||
error_logger = logging.getLogger("search_engine_error")
|
||||
error_logger.addHandler(_error_handler)
|
||||
error_logger.propagate = False
|
||||
|
||||
# Analytics file
|
||||
ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Custom exception hierarchy
|
||||
# ----------------------------------------------------------------------
|
||||
class SearchEngineError(Exception):
|
||||
"""Base class for all search-engine related errors."""
|
||||
|
||||
|
||||
class NetworkError(SearchEngineError):
|
||||
"""Raised when a network request fails (e.g., timeout, DNS error)."""
|
||||
|
||||
|
||||
class ParseError(SearchEngineError):
|
||||
"""Raised when HTML or other content cannot be parsed."""
|
||||
|
||||
|
||||
class RateLimitError(SearchEngineError):
|
||||
"""Raised when the remote service returns a rate-limit (HTTP 429)."""
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Analytics helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _default_analytics() -> Dict[str, Any]:
|
||||
"""A fresh analytics document with every counter present."""
|
||||
return {
|
||||
"total_queries": 0,
|
||||
"successful_queries": 0,
|
||||
"failed_queries": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"query_patterns": {},
|
||||
}
|
||||
|
||||
|
||||
def _load_analytics() -> Dict[str, Any]:
|
||||
"""Load analytics data from the JSON file, creating defaults if missing."""
|
||||
if not ANALYTICS_FILE.exists():
|
||||
default = _default_analytics()
|
||||
_save_analytics(default)
|
||||
return default
|
||||
try:
|
||||
with open(ANALYTICS_FILE, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
# Merge over defaults so a file written by an older schema (or a
|
||||
# partial write) still has every counter — _record_query indexes
|
||||
# these keys directly and would otherwise raise KeyError.
|
||||
merged = _default_analytics()
|
||||
if isinstance(data, dict):
|
||||
merged.update(data)
|
||||
return merged
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load analytics file: {e}")
|
||||
return _default_analytics()
|
||||
|
||||
|
||||
def _save_analytics(data: Dict[str, Any]) -> None:
|
||||
"""Persist analytics data to the JSON file."""
|
||||
try:
|
||||
with open(ANALYTICS_FILE, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write analytics file: {e}")
|
||||
|
||||
|
||||
def _record_query(query: str, success: bool, cache_hit: bool) -> None:
|
||||
"""Update analytics for a single query execution."""
|
||||
analytics = _load_analytics()
|
||||
analytics["total_queries"] += 1
|
||||
if success:
|
||||
analytics["successful_queries"] += 1
|
||||
else:
|
||||
analytics["failed_queries"] += 1
|
||||
|
||||
if cache_hit:
|
||||
analytics["cache_hits"] += 1
|
||||
cache_metrics["hits"] += 1
|
||||
else:
|
||||
analytics["cache_misses"] += 1
|
||||
cache_metrics["misses"] += 1
|
||||
|
||||
patterns = analytics["query_patterns"]
|
||||
entry = patterns.get(query, {"count": 0, "successes": 0})
|
||||
entry["count"] += 1
|
||||
if success:
|
||||
entry["successes"] += 1
|
||||
patterns[query] = entry
|
||||
|
||||
_save_analytics(analytics)
|
||||
|
||||
|
||||
def get_search_stats() -> Dict[str, Any]:
|
||||
"""Return aggregated search analytics."""
|
||||
analytics = _load_analytics()
|
||||
total = analytics.get("total_queries", 0) or 1
|
||||
success_rate = analytics.get("successful_queries", 0) / total
|
||||
cache_total = analytics.get("cache_hits", 0) + analytics.get("cache_misses", 0) or 1
|
||||
cache_hit_rate = analytics.get("cache_hits", 0) / cache_total
|
||||
|
||||
pattern_counter = Counter({
|
||||
q: data["count"] for q, data in analytics.get("query_patterns", {}).items()
|
||||
})
|
||||
most_common = [q for q, _ in pattern_counter.most_common(5)]
|
||||
|
||||
return {
|
||||
"most_common_queries": most_common,
|
||||
"success_rate": success_rate,
|
||||
"cache_hit_rate": cache_hit_rate,
|
||||
"total_queries": analytics.get("total_queries", 0),
|
||||
"successful_queries": analytics.get("successful_queries", 0),
|
||||
"failed_queries": analytics.get("failed_queries", 0),
|
||||
"cache_hits": analytics.get("cache_hits", 0),
|
||||
"cache_misses": analytics.get("cache_misses", 0),
|
||||
"cache_evictions": cache_metrics["evictions"],
|
||||
"runtime_cache_hits": cache_metrics["hits"],
|
||||
"runtime_cache_misses": cache_metrics["misses"],
|
||||
}
|
||||
sys.modules[__name__] = _analytics
|
||||
|
||||
@@ -1,57 +1,11 @@
|
||||
"""Search and content caching with LRU eviction."""
|
||||
"""Compatibility wrapper for the canonical services.search.cache module.
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
``src.search.cache`` stays importable for older agent/deep-research code, but the
|
||||
implementation now lives in ``services.search.cache`` so the two cannot drift.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import sys
|
||||
|
||||
# Cache directories
|
||||
CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
|
||||
SEARCH_CACHE_DIR = CACHE_DIR / "search"
|
||||
CONTENT_CACHE_DIR = CACHE_DIR / "content"
|
||||
CACHE_MAX_ENTRIES = 1000
|
||||
from services.search import cache as _cache
|
||||
|
||||
# Create cache directories
|
||||
SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Track cache size for LRU eviction
|
||||
search_cache_index: Dict[str, datetime] = {}
|
||||
content_cache_index: Dict[str, datetime] = {}
|
||||
|
||||
# Cache metrics (shared across modules)
|
||||
cache_metrics = {"hits": 0, "misses": 0, "evictions": 0}
|
||||
|
||||
|
||||
def generate_cache_key(data: str) -> str:
|
||||
"""Generate a unique cache key using SHA-256 hash."""
|
||||
return hashlib.sha256(data.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def cleanup_cache(cache_dir: Path, cache_index: Dict[str, datetime], max_age: timedelta):
|
||||
"""Remove expired cache entries and enforce LRU policy."""
|
||||
current_time = datetime.now()
|
||||
files_in_dir = {f.name.split(".")[0]: f for f in cache_dir.glob("*.cache")}
|
||||
|
||||
to_remove = []
|
||||
for key, timestamp in list(cache_index.items()):
|
||||
if current_time - timestamp > max_age or key not in files_in_dir:
|
||||
to_remove.append(key)
|
||||
if key in files_in_dir:
|
||||
files_in_dir[key].unlink(missing_ok=True)
|
||||
|
||||
for key in to_remove:
|
||||
cache_index.pop(key, None)
|
||||
cache_metrics["evictions"] += 1
|
||||
|
||||
if len(cache_index) > CACHE_MAX_ENTRIES:
|
||||
sorted_items = sorted(cache_index.items(), key=lambda x: x[1])
|
||||
excess_count = len(cache_index) - CACHE_MAX_ENTRIES
|
||||
for key, _ in sorted_items[:excess_count]:
|
||||
cache_index.pop(key, None)
|
||||
cache_file = cache_dir / f"{key}.cache"
|
||||
cache_file.unlink(missing_ok=True)
|
||||
cache_metrics["evictions"] += 1
|
||||
sys.modules[__name__] = _cache
|
||||
|
||||
@@ -1,419 +1,11 @@
|
||||
"""Webpage content fetching with caching, PDF extraction, and summarization helpers."""
|
||||
"""Compatibility wrapper for the canonical services.search.content module.
|
||||
|
||||
import copy
|
||||
import io
|
||||
import ipaddress
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
import socket
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List
|
||||
from urllib.parse import urljoin, urlparse
|
||||
``src.search.content`` stays importable for older agent/deep-research code, but the
|
||||
implementation now lives in ``services.search.content`` so the two cannot drift.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from bs4 import BeautifulSoup
|
||||
import sys
|
||||
|
||||
from .analytics import RateLimitError, error_logger
|
||||
from .cache import (
|
||||
CONTENT_CACHE_DIR,
|
||||
content_cache_index,
|
||||
generate_cache_key,
|
||||
cleanup_cache,
|
||||
)
|
||||
from services.search import content as _content
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PRIVATE_NETWORKS = (
|
||||
ipaddress.ip_network("0.0.0.0/8"),
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("169.254.0.0/16"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("fc00::/7"),
|
||||
ipaddress.ip_network("fe80::/10"),
|
||||
)
|
||||
|
||||
|
||||
def _is_private_address(addr: ipaddress._BaseAddress) -> bool:
|
||||
if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
|
||||
addr = addr.ipv4_mapped
|
||||
return (
|
||||
addr.is_private
|
||||
or addr.is_loopback
|
||||
or addr.is_link_local
|
||||
or addr.is_reserved
|
||||
or addr.is_multicast
|
||||
or addr.is_unspecified
|
||||
or any(addr in net for net in _PRIVATE_NETWORKS)
|
||||
)
|
||||
|
||||
|
||||
def _resolve_hostname_ips(hostname: str) -> List[ipaddress._BaseAddress]:
|
||||
ips = []
|
||||
for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
|
||||
if family in (socket.AF_INET, socket.AF_INET6):
|
||||
ips.append(ipaddress.ip_address(sockaddr[0]))
|
||||
return ips
|
||||
|
||||
|
||||
def _public_http_url(url: str) -> bool:
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in ("http", "https") or not parsed.hostname:
|
||||
return False
|
||||
host = parsed.hostname.strip().lower()
|
||||
if host in ("localhost", "metadata.google.internal", "metadata"):
|
||||
return False
|
||||
if host.endswith((".local", ".localhost", ".internal", ".lan", ".intranet")):
|
||||
return False
|
||||
try:
|
||||
return not _is_private_address(ipaddress.ip_address(host))
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
ips = _resolve_hostname_ips(host)
|
||||
except OSError:
|
||||
return False
|
||||
# Fail closed: a hostname that resolves to nothing is treated as
|
||||
# non-public (an empty all(...) would otherwise return True).
|
||||
return bool(ips) and all(not _is_private_address(ip) for ip in ips)
|
||||
|
||||
|
||||
def _get_public_url(url: str, *, headers: dict, timeout: int) -> httpx.Response:
|
||||
if not _public_http_url(url):
|
||||
raise httpx.RequestError(f"Blocked non-public URL: {url}")
|
||||
|
||||
current = url
|
||||
with httpx.Client(headers=headers, timeout=timeout, follow_redirects=False) as client:
|
||||
for _ in range(8):
|
||||
response = client.get(current)
|
||||
if response.status_code not in (301, 302, 303, 307, 308):
|
||||
return response
|
||||
location = response.headers.get("location")
|
||||
if not location:
|
||||
return response
|
||||
current = urljoin(current, location)
|
||||
if not _public_http_url(current):
|
||||
raise httpx.RequestError(f"Blocked redirect to non-public URL: {current}")
|
||||
raise httpx.RequestError("Too many redirects")
|
||||
|
||||
# PDF extraction (optional dependency)
|
||||
try:
|
||||
from pdfminer.high_level import extract_text as pdf_extract_text
|
||||
except ImportError:
|
||||
pdf_extract_text = None # type: ignore
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# HTML extraction helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _extract_meta(soup: BeautifulSoup) -> dict:
|
||||
"""Pull meta description and keywords if present."""
|
||||
description = ""
|
||||
keywords = ""
|
||||
desc_tag = soup.find("meta", attrs={"name": re.compile("description", re.I)})
|
||||
if desc_tag and desc_tag.get("content"):
|
||||
description = desc_tag["content"].strip()
|
||||
kw_tag = soup.find("meta", attrs={"name": re.compile("keywords", re.I)})
|
||||
if kw_tag and kw_tag.get("content"):
|
||||
keywords = kw_tag["content"].strip()
|
||||
return {"description": description, "keywords": keywords}
|
||||
|
||||
|
||||
def _extract_og_image(soup: BeautifulSoup) -> str:
|
||||
"""Extract the best representative image URL from meta tags.
|
||||
|
||||
Only returns absolute http(s) URLs — skips relative paths and data URIs.
|
||||
"""
|
||||
candidates = []
|
||||
# Open Graph image (most reliable)
|
||||
for prop in ("og:image", "og:image:url", "og:image:secure_url"):
|
||||
tag = soup.find("meta", attrs={"property": prop})
|
||||
if tag and tag.get("content", "").strip():
|
||||
candidates.append(tag["content"].strip())
|
||||
# Twitter card image
|
||||
tag = soup.find("meta", attrs={"name": "twitter:image"})
|
||||
if tag and tag.get("content", "").strip():
|
||||
candidates.append(tag["content"].strip())
|
||||
# Thumbnail meta
|
||||
tag = soup.find("meta", attrs={"name": "thumbnail"})
|
||||
if tag and tag.get("content", "").strip():
|
||||
candidates.append(tag["content"].strip())
|
||||
# Return first absolute http(s) URL
|
||||
for url in candidates:
|
||||
if url.startswith(("https://", "http://")) and not url.endswith((".svg", ".ico")):
|
||||
return url
|
||||
return ""
|
||||
|
||||
|
||||
def _extract_lists(soup: BeautifulSoup) -> List[List[str]]:
|
||||
"""Return a list of lists, each inner list representing a <ul>/<ol>."""
|
||||
all_lists = []
|
||||
for lst in soup.find_all(["ul", "ol"]):
|
||||
items = [li.get_text(separator=" ", strip=True) for li in lst.find_all("li")]
|
||||
if items:
|
||||
all_lists.append(items)
|
||||
return all_lists
|
||||
|
||||
|
||||
def _extract_tables(soup: BeautifulSoup) -> List[List[List[str]]]:
|
||||
"""Return a list of tables, each table is a list of rows, each row a list of cell texts."""
|
||||
tables_data = []
|
||||
for table in soup.find_all("table"):
|
||||
rows = []
|
||||
for tr in table.find_all("tr"):
|
||||
cells = [td.get_text(separator=" ", strip=True) for td in tr.find_all(["td", "th"])]
|
||||
if cells:
|
||||
rows.append(cells)
|
||||
if rows:
|
||||
tables_data.append(rows)
|
||||
return tables_data
|
||||
|
||||
|
||||
def _extract_code_blocks(soup: BeautifulSoup) -> List[str]:
|
||||
"""Collect text from <pre> and <code> blocks."""
|
||||
blocks = []
|
||||
for tag in soup.find_all(["pre", "code"]):
|
||||
txt = tag.get_text(separator=" ", strip=True)
|
||||
if txt:
|
||||
blocks.append(txt)
|
||||
return blocks
|
||||
|
||||
|
||||
def _detect_js_frameworks(soup: BeautifulSoup) -> bool:
|
||||
"""Very naive detection of common JS frameworks."""
|
||||
js_indicators = [
|
||||
"react", "angular", "vue", "svelte", "next", "nuxt",
|
||||
"ember", "backbone", "jquery", "polymer", "mithril",
|
||||
]
|
||||
for script in soup.find_all("script"):
|
||||
src = script.get("src", "").lower()
|
||||
if any(fr in src for fr in js_indicators):
|
||||
return True
|
||||
if script.string:
|
||||
content = script.string.lower()
|
||||
if any(fr in content for fr in js_indicators):
|
||||
return True
|
||||
if soup.find(attrs={"data-reactroot": True}) or soup.find(attrs={"ng-app": True}):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _empty_result(url: str, error: str = "") -> dict:
|
||||
"""Build a standard failure result dict."""
|
||||
return {
|
||||
"url": url,
|
||||
"title": "",
|
||||
"content": "",
|
||||
"lists": [],
|
||||
"tables": [],
|
||||
"code_blocks": [],
|
||||
"meta_description": "",
|
||||
"meta_keywords": "",
|
||||
"js_rendered": False,
|
||||
"js_message": "",
|
||||
"success": False,
|
||||
"error": error,
|
||||
}
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Main content fetcher
|
||||
# ----------------------------------------------------------------------
|
||||
def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> dict:
|
||||
"""Fetch and extract meaningful content from a webpage with caching."""
|
||||
cache_key = generate_cache_key(url)
|
||||
cache_file = CONTENT_CACHE_DIR / f"{cache_key}.cache"
|
||||
|
||||
# Check cache
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
cached_data = json.load(f)
|
||||
timestamp = datetime.fromisoformat(cached_data["timestamp"])
|
||||
if datetime.now() - timestamp < timedelta(hours=2):
|
||||
logger.debug(f"Content cache hit for URL: {url}")
|
||||
return cached_data["data"]
|
||||
else:
|
||||
cache_file.unlink(missing_ok=True)
|
||||
content_cache_index.pop(cache_key, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read content cache for {url}: {e}")
|
||||
cache_file.unlink(missing_ok=True)
|
||||
content_cache_index.pop(cache_key, None)
|
||||
|
||||
# Fetch
|
||||
try:
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
|
||||
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
|
||||
"Accept-Language": "en-US,en;q=0.5",
|
||||
"Accept-Encoding": "gzip, deflate",
|
||||
"Connection": "keep-alive",
|
||||
}
|
||||
response = _get_public_url(url, headers=headers, timeout=timeout)
|
||||
|
||||
if response.status_code == 429:
|
||||
raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
|
||||
|
||||
response.raise_for_status()
|
||||
except httpx.RequestError as e:
|
||||
error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
|
||||
return _empty_result(url, f"NetworkError: {e}")
|
||||
except RateLimitError as e:
|
||||
error_logger.error(str(e))
|
||||
return _empty_result(url, str(e))
|
||||
|
||||
# PDF handling
|
||||
content_type = response.headers.get("Content-Type", "").lower()
|
||||
if "application/pdf" in content_type or url.lower().endswith(".pdf"):
|
||||
if pdf_extract_text is None:
|
||||
logger.error("pdfminer.six is not installed; cannot extract PDF text.")
|
||||
pdf_text = ""
|
||||
else:
|
||||
try:
|
||||
pdf_bytes = io.BytesIO(response.content)
|
||||
pdf_text = pdf_extract_text(pdf_bytes)
|
||||
except Exception as e:
|
||||
logger.warning(f"PDF extraction failed for {url}: {e}")
|
||||
pdf_text = ""
|
||||
result = {
|
||||
"url": url,
|
||||
"title": os.path.basename(url),
|
||||
"content": pdf_text,
|
||||
"lists": [],
|
||||
"tables": [],
|
||||
"code_blocks": [],
|
||||
"meta_description": "",
|
||||
"meta_keywords": "",
|
||||
"js_rendered": False,
|
||||
"js_message": "",
|
||||
"success": bool(pdf_text),
|
||||
"error": "" if pdf_text else "Failed to extract PDF text",
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
# HTML handling
|
||||
try:
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
except Exception as e:
|
||||
error_logger.error(f"ParseError parsing HTML from {url} (attempt {retry_attempt}): {e}")
|
||||
result = _empty_result(url, f"ParseError: {e}")
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
title_tag = soup.find("title")
|
||||
title_text = title_tag.get_text(strip=True) if title_tag else ""
|
||||
meta_info = _extract_meta(soup)
|
||||
og_image = _extract_og_image(soup)
|
||||
js_rendered = _detect_js_frameworks(soup)
|
||||
js_message = "Page appears to be rendered by a JavaScript framework; content may be incomplete." if js_rendered else ""
|
||||
|
||||
# Main textual content (heuristic): prefer semantic / "content"-classed
|
||||
# containers to skip nav/footer/boilerplate; tuned for article pages.
|
||||
main_content = ""
|
||||
content_areas = soup.find_all(
|
||||
["main", "article", "section", "div"],
|
||||
class_=re.compile("content|main|body|article|post|entry|text", re.I),
|
||||
)
|
||||
if content_areas:
|
||||
for area in content_areas[:3]:
|
||||
main_content += area.get_text(separator=" ", strip=True) + " "
|
||||
main_content = re.sub(r"\s+", " ", main_content).strip()
|
||||
|
||||
# The class heuristic can latch onto a small wrapper and miss the real
|
||||
# content (app/landing pages, or SSR sites whose body isn't in a
|
||||
# "content"-classed div, so these came back nearly empty before). When the
|
||||
# heuristic returns nothing OR suspiciously little, fall back to the full
|
||||
# <body>, stripping scripts/styles (so JSON/JS doesn't leak into the text)
|
||||
# plus nav/header/footer/aside (boilerplate), and keep whichever yields
|
||||
# more readable text.
|
||||
THIN_CONTENT_CHARS = 600 # below this the heuristic likely missed the page
|
||||
if len(main_content) < THIN_CONTENT_CHARS:
|
||||
body = soup.find("body")
|
||||
if body:
|
||||
# Strip from a copy so the later list/table/code extractors still
|
||||
# see the original soup unmodified.
|
||||
body_copy = copy.copy(body)
|
||||
for _noise in body_copy.find_all(
|
||||
["script", "style", "noscript", "template", "nav", "header", "footer", "aside"]
|
||||
):
|
||||
_noise.extract()
|
||||
body_text = re.sub(r"\s+", " ", body_copy.get_text(separator=" ", strip=True)).strip()
|
||||
if len(body_text) > len(main_content):
|
||||
main_content = body_text
|
||||
|
||||
result = {
|
||||
"url": url,
|
||||
"title": title_text,
|
||||
"content": main_content,
|
||||
"lists": _extract_lists(soup),
|
||||
"tables": _extract_tables(soup),
|
||||
"code_blocks": _extract_code_blocks(soup),
|
||||
"meta_description": meta_info.get("description", ""),
|
||||
"meta_keywords": meta_info.get("keywords", ""),
|
||||
"og_image": og_image,
|
||||
"js_rendered": js_rendered,
|
||||
"js_message": js_message,
|
||||
"success": True,
|
||||
"error": "",
|
||||
}
|
||||
_cache_result(cache_file, cache_key, result, url)
|
||||
return result
|
||||
|
||||
|
||||
def _cache_result(cache_file, cache_key: str, result: dict, url: str):
|
||||
"""Write a result to the content cache."""
|
||||
try:
|
||||
cache_data = {"timestamp": datetime.now().isoformat(), "data": result}
|
||||
with open(cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(cache_data, f)
|
||||
content_cache_index[cache_key] = datetime.now()
|
||||
cleanup_cache(CONTENT_CACHE_DIR, content_cache_index, timedelta(hours=2))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to write content cache for {url}: {e}")
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Content summarization helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def extract_key_points(text: str) -> List[str]:
|
||||
"""Pull out bullet-style key points from a block of text."""
|
||||
points: List[str] = []
|
||||
bullet_pat = re.compile(r"^\s*[-*•]\s+(.*)")
|
||||
numbered_pat = re.compile(r"^\s*\d+[\.\)]\s+(.*)")
|
||||
for line in text.splitlines():
|
||||
m = bullet_pat.match(line) or numbered_pat.match(line)
|
||||
if m:
|
||||
points.append(m.group(1).strip())
|
||||
return points
|
||||
|
||||
|
||||
def get_tldr(text: str, max_sentences: int = 3) -> str:
|
||||
"""Produce a very short TL;DR by taking the first few sentences."""
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text)
|
||||
selected = [s.strip() for s in sentences if s][:max_sentences]
|
||||
return " ".join(selected)
|
||||
|
||||
|
||||
def extract_quotes(text: str) -> List[str]:
|
||||
"""Return quoted excerpts that are at least 15 characters long."""
|
||||
# Backreference the opening quote so the closing quote must match it —
|
||||
# otherwise `"text'` (open double, close single) is treated as a quote.
|
||||
return [m.group(2).strip() for m in re.finditer(r'(["\'])([^"\']{15,}?)\1', text)]
|
||||
|
||||
|
||||
def extract_statistics(text: str) -> List[str]:
|
||||
"""Find numbers, percentages, dates and simple measurements."""
|
||||
# Match a comma-grouped number (1,000,000) OR a plain digit run (50000) —
|
||||
# the old `\d{1,3}(?:,\d{3})*` matched only the first 3 digits of a
|
||||
# comma-less number, and the trailing `\b` dropped a closing `%`.
|
||||
pattern = re.compile(
|
||||
r"\b(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return [m.group(0).strip() for m in pattern.finditer(text)]
|
||||
sys.modules[__name__] = _content
|
||||
|
||||
@@ -1,141 +1,11 @@
|
||||
"""Query enhancement, entity extraction, and cache duration helpers."""
|
||||
"""Compatibility wrapper for the canonical services.search.query module.
|
||||
|
||||
import re
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
``src.search.query`` stays importable for older agent/deep-research code, but the
|
||||
implementation now lives in ``services.search.query`` so the two cannot drift.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
import sys
|
||||
|
||||
from services.search import query as _query
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Query processing helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _detect_question_type(query: str) -> Optional[str]:
|
||||
"""Return the leading question word if present (who, what, when, where, why, how)."""
|
||||
if not isinstance(query, str):
|
||||
return None
|
||||
q = query.strip().lower()
|
||||
for word in ("who", "what", "when", "where", "why", "how"):
|
||||
# Require a whole-word match: a bare prefix mis-flags ordinary queries
|
||||
# like "whatsapp pricing" (-> what) or "however ..." (-> how), which
|
||||
# then get spurious boost terms OR-appended in enhance_query.
|
||||
if q == word or q.startswith(word + " "):
|
||||
return word
|
||||
return None
|
||||
|
||||
|
||||
def _extract_entities(query: str) -> Dict[str, List[str]]:
|
||||
"""Lightweight entity extraction: capitalized words and date patterns."""
|
||||
entities: Dict[str, List[str]] = {"names": [], "dates": []}
|
||||
qtype = _detect_question_type(query)
|
||||
cleaned = query
|
||||
if qtype:
|
||||
cleaned = re.sub(rf"^{qtype}\b", "", cleaned, flags=re.I).strip()
|
||||
for token in re.findall(r"\b[A-Z][a-zA-Z]+\b", cleaned):
|
||||
entities["names"].append(token)
|
||||
for year in re.findall(r"\b(?:19|20)\d{2}\b", cleaned):
|
||||
entities["dates"].append(year)
|
||||
month_day_year = re.findall(
|
||||
r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b",
|
||||
cleaned,
|
||||
flags=re.I,
|
||||
)
|
||||
entities["dates"].extend(month_day_year)
|
||||
return entities
|
||||
|
||||
|
||||
def _split_multi_part(query: str) -> List[str]:
|
||||
"""Split a query into sub-queries on common conjunctions."""
|
||||
if not isinstance(query, str):
|
||||
return []
|
||||
parts = re.split(r"\s+and\s+|\s+or\s+|;", query, flags=re.I)
|
||||
return [p.strip() for p in parts if p.strip()]
|
||||
|
||||
|
||||
def _extract_site_filter(query: str) -> Tuple[str, Optional[str]]:
|
||||
"""Detect a 'site:example.com' token. Returns (query_without_token, site_or_None)."""
|
||||
if not isinstance(query, str):
|
||||
return "", None
|
||||
match = re.search(r"\bsite:([^\s]+)", query, flags=re.I)
|
||||
if match:
|
||||
site = match.group(1)
|
||||
new_query = re.sub(r"\bsite:[^\s]+", "", query, flags=re.I).strip()
|
||||
return new_query, site
|
||||
return query, None
|
||||
|
||||
|
||||
def _boost_entities_in_query(base_query: str, entities: Dict[str, List[str]]) -> str:
|
||||
"""Append extracted entities to the query using OR to increase relevance."""
|
||||
parts = [base_query]
|
||||
if entities.get("names"):
|
||||
parts.append(" OR ".join(f'"{n}"' for n in entities["names"]))
|
||||
if entities.get("dates"):
|
||||
parts.append(" OR ".join(f'"{d}"' for d in entities["dates"]))
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def enhance_query(original_query: str) -> Tuple[str, Optional[str]]:
|
||||
"""Process the original query: site filter, question type boosts, entity extraction."""
|
||||
if not isinstance(original_query, str):
|
||||
original_query = ""
|
||||
query_without_site, site = _extract_site_filter(original_query)
|
||||
sub_queries = _split_multi_part(query_without_site)
|
||||
|
||||
enhanced_subs: List[str] = []
|
||||
for sub in sub_queries:
|
||||
qtype = _detect_question_type(sub)
|
||||
boost_keywords = []
|
||||
if qtype == "who":
|
||||
boost_keywords.append("person")
|
||||
elif qtype == "when":
|
||||
boost_keywords.append("date")
|
||||
elif qtype == "where":
|
||||
boost_keywords.append("location")
|
||||
elif qtype == "why":
|
||||
boost_keywords.append("reason")
|
||||
elif qtype == "how":
|
||||
boost_keywords.append("method")
|
||||
entities = _extract_entities(sub)
|
||||
boosted = _boost_entities_in_query(sub, entities)
|
||||
if boost_keywords:
|
||||
boosted = f'({boosted}) OR ({" OR ".join(boost_keywords)})'
|
||||
enhanced_subs.append(boosted)
|
||||
|
||||
final_query = " AND ".join(f"({s})" for s in enhanced_subs)
|
||||
if site:
|
||||
final_query = f"{final_query} site:{site}"
|
||||
return final_query, site
|
||||
|
||||
|
||||
def build_enhanced_query(query: str, time_filter: str = None) -> str:
|
||||
"""Build an enhanced search query with optional time filtering."""
|
||||
enhanced_query, _ = enhance_query(query)
|
||||
|
||||
if time_filter:
|
||||
time_map = {"day": "d", "week": "w", "month": "m", "year": "y"}
|
||||
if time_filter in time_map:
|
||||
enhanced_query = f"{enhanced_query} after:{time_map[time_filter]}"
|
||||
logger.info(f"Added time filter '{time_filter}' to query")
|
||||
|
||||
logger.info(f"Enhanced query: '{query}' -> '{enhanced_query}'")
|
||||
return enhanced_query
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Cache duration helpers
|
||||
# ----------------------------------------------------------------------
|
||||
def _is_news_query(query: str) -> bool:
|
||||
"""Lightweight heuristic to decide if a query is news-oriented."""
|
||||
if not isinstance(query, str):
|
||||
return False
|
||||
news_terms = {"news", "latest", "breaking", "today", "today's", "current", "updates", "happening"}
|
||||
tokens = set(re.findall(r"\b\w+\b", query.lower()))
|
||||
return bool(tokens & news_terms)
|
||||
|
||||
|
||||
def _cache_duration_for_query(query: str) -> timedelta:
|
||||
"""News queries -> 30 minutes, reference queries -> 24 hours."""
|
||||
if _is_news_query(query):
|
||||
return timedelta(minutes=30)
|
||||
return timedelta(hours=24)
|
||||
sys.modules[__name__] = _query
|
||||
|
||||
@@ -85,6 +85,11 @@ DEFAULT_SETTINGS = {
|
||||
"research_search_provider": "",
|
||||
"research_max_tokens": 16384,
|
||||
"research_extraction_timeout_seconds": 90,
|
||||
# Lightweight planning/query LLM calls happen before any search starts.
|
||||
# Keep them separately tunable so slow local backends are not capped by
|
||||
# the old 30s/60s per-call defaults.
|
||||
"research_planning_timeout_seconds": 90,
|
||||
"research_query_timeout_seconds": 90,
|
||||
"research_extraction_concurrency": 3,
|
||||
# Hard wall-clock cap on a single deep-research run. The previous 600s
|
||||
# (10 min) default cut off slow local / edge LLMs mid-synthesis; 1800s
|
||||
@@ -95,6 +100,7 @@ DEFAULT_SETTINGS = {
|
||||
# Tune via Settings or by editing data/settings.json.
|
||||
"research_run_timeout_seconds": 1800,
|
||||
"agent_max_tool_calls": 0,
|
||||
"agent_max_rounds": 20, # per-message agent step cap (clamped 1..200)
|
||||
"agent_input_token_budget": 6000,
|
||||
# Ceiling on the *auto-derived* input budget that #1230 introduced. Has
|
||||
# no effect when `agent_input_token_budget` is explicitly set (the user's
|
||||
|
||||
@@ -15,18 +15,33 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
_THINK_TAG_NAME = r"(?:think(?:ing)?|thought)"
|
||||
|
||||
# Closed reasoning blocks. Multi-pass loop in `strip_think` handles nested
|
||||
# `<think><think>...</think></think>` patterns some models emit.
|
||||
_THINK_CLOSED_RE = re.compile(r"<think(?:ing)?>[\s\S]*?</think(?:ing)?>\s*", re.IGNORECASE)
|
||||
_THINK_CLOSED_RE = re.compile(rf"<{_THINK_TAG_NAME}(?:\s+[^>]*)?>[\s\S]*?</{_THINK_TAG_NAME}>\s*", re.IGNORECASE)
|
||||
# Orphan opening or closing tags that survive after the closed-pass.
|
||||
_THINK_TAG_RE = re.compile(r"</?think(?:ing)?[^>]*>\s*", re.IGNORECASE)
|
||||
_THINK_TAG_RE = re.compile(rf"</?{_THINK_TAG_NAME}[^>]*>\s*", re.IGNORECASE)
|
||||
# Dangling opener anywhere in the response with no closer — strip everything
|
||||
# from `<think>` to the end of string.
|
||||
_THINK_OPEN_RE = re.compile(r"<think(?:ing)?>[\s\S]*$", re.IGNORECASE)
|
||||
_THINK_OPEN_RE = re.compile(rf"<{_THINK_TAG_NAME}(?:\s+[^>]*)?>[\s\S]*$", re.IGNORECASE)
|
||||
# Streaming models occasionally emit `<thinking time="0.42">`-style attributes.
|
||||
# Normalize to a plain `<think>` so the regexes above catch them.
|
||||
_THINK_ATTR_RE = re.compile(r"<think(?:ing)?\s+[^>]*>", re.IGNORECASE)
|
||||
_THINK_ATTR_CLOSE_RE = re.compile(r"</think(?:ing)?\s+[^>]*>", re.IGNORECASE)
|
||||
_THINK_ATTR_RE = re.compile(rf"<{_THINK_TAG_NAME}\s+[^>]*>", re.IGNORECASE)
|
||||
_THINK_ATTR_CLOSE_RE = re.compile(rf"</{_THINK_TAG_NAME}\s+[^>]*>", re.IGNORECASE)
|
||||
_GEMMA_THOUGHT_OPEN_RE = re.compile(r"<\|channel>thought\s*\n?[\s\S]*$", re.IGNORECASE)
|
||||
_GEMMA_RESPONSE_CHANNEL_RE = re.compile(
|
||||
r"<\|channel>response\s*\n?([\s\S]*?)<channel\|>",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_GEMMA_RESPONSE_OPEN_RE = re.compile(r"<\|channel>response\s*\n?", re.IGNORECASE)
|
||||
_GEMMA_CHANNEL_CLOSE_RE = re.compile(r"<channel\|>", re.IGNORECASE)
|
||||
_THOUGHT_TAG_OPEN_RE = re.compile(r"<thought(\s+[^>]*)?>", re.IGNORECASE)
|
||||
_THOUGHT_TAG_CLOSE_RE = re.compile(r"</thought>", re.IGNORECASE)
|
||||
_GEMMA_THOUGHT_CHANNEL_CAPTURE_RE = re.compile(
|
||||
r"<\|channel>thought\s*\n?([\s\S]*?)<channel\|>\s*",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
# Qwen and a few other models prefix the response with a "Thinking Process:"
|
||||
# block before the real answer.
|
||||
_QWEN_THINKING_RE = re.compile(
|
||||
@@ -78,6 +93,30 @@ def _strip_reasoning_prose(text: str) -> str:
|
||||
return "\n\n".join(keep).strip() if keep else text
|
||||
|
||||
|
||||
def normalize_thinking_markup(text: str) -> str:
|
||||
"""Canonicalize supported thinking wrappers to `<think>` markup.
|
||||
|
||||
The chat UI and persistence layer already understand `<think>...</think>`.
|
||||
Gemma 4 may instead emit `<|channel>thought\n...<channel|>`, and some
|
||||
gateways/models emit `<thought>...</thought>`. Normalize those shapes into
|
||||
the existing representation and strip empty thought channels.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
out = _THOUGHT_TAG_OPEN_RE.sub(lambda m: "<think" + (m.group(1) or "") + ">", text)
|
||||
out = _THOUGHT_TAG_CLOSE_RE.sub("</think>", out)
|
||||
|
||||
def _replace_gemma_thought(match: re.Match) -> str:
|
||||
thought = match.group(1).strip()
|
||||
return f"<think>{thought}</think>\n" if thought else ""
|
||||
|
||||
out = _GEMMA_THOUGHT_CHANNEL_CAPTURE_RE.sub(_replace_gemma_thought, out)
|
||||
out = _GEMMA_RESPONSE_CHANNEL_RE.sub(lambda m: m.group(1), out)
|
||||
out = _GEMMA_RESPONSE_OPEN_RE.sub("", out)
|
||||
out = _GEMMA_CHANNEL_CLOSE_RE.sub("", out)
|
||||
return out
|
||||
|
||||
|
||||
def strip_think(text: str, *, prose: bool = False, prompt_echo: bool = True) -> str:
|
||||
"""Strip `<think>` blocks from model output.
|
||||
|
||||
@@ -92,13 +131,21 @@ def strip_think(text: str, *, prose: bool = False, prompt_echo: bool = True) ->
|
||||
"The user asks:" / "We need to" leaked prompt echoes.
|
||||
|
||||
Robust to:
|
||||
* closed `<think>...</think>` (any depth, both `<think>` and `<thinking>`)
|
||||
* dangling unclosed `<think>...`
|
||||
* closed `<think>...</think>` (any depth, plus `<thinking>`/`<thought>`)
|
||||
* dangling unclosed `<think>...` / `<thought>...`
|
||||
* stray opener/closer tags
|
||||
* `<think time="0.42">`-style attributes
|
||||
* Gemma 4 `<|channel>thought...<channel|>` wrappers
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
# Gemma 4 thinking-capable models use channel control tokens rather than
|
||||
# XML tags when the runtime does not split reasoning into a separate field.
|
||||
# The thought channel can be empty in non-thinking mode; either way it is
|
||||
# not user-facing content. A response channel, when present, is only a
|
||||
# wrapper around the final answer.
|
||||
text = normalize_thinking_markup(text)
|
||||
text = _GEMMA_THOUGHT_OPEN_RE.sub("", text)
|
||||
# Normalize attributes so the closed/open regexes can catch them.
|
||||
text = _THINK_ATTR_RE.sub("<think>", text)
|
||||
text = _THINK_ATTR_CLOSE_RE.sub("</think>", text)
|
||||
|
||||
@@ -12,14 +12,127 @@ import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
|
||||
|
||||
from src.tool_security import is_public_blocked_tool, owner_is_admin_or_single_user
|
||||
|
||||
# Persistent working directory for agent subprocesses.
|
||||
# Resolves to <repo_root>/data, which is the bind-mounted volume in Docker
|
||||
# (/app/data) and the local data directory for manual installs.
|
||||
# Using this as cwd and HOME prevents the agent from silently creating files
|
||||
# in ephemeral container layers that are lost on the next rebuild.
|
||||
_AGENT_WORKDIR = str(pathlib.Path(__file__).parent.parent / "data")
|
||||
|
||||
MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_READ_CHARS = 20_000
|
||||
MAX_DIFF_LINES = 400 # cap unified-diff size returned to the UI
|
||||
|
||||
|
||||
def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
|
||||
"""Build a unified diff of a file write for display in the chat.
|
||||
|
||||
Returns {"text": <unified diff>, "added": N, "removed": M, "new_file": bool}
|
||||
or None when there's no textual change. Truncates very large diffs.
|
||||
"""
|
||||
if old == new:
|
||||
return None
|
||||
import difflib
|
||||
|
||||
old_lines = old.splitlines()
|
||||
new_lines = new.splitlines()
|
||||
label = path or "file"
|
||||
diff_lines = list(difflib.unified_diff(
|
||||
old_lines, new_lines,
|
||||
fromfile=f"a/{label}", tofile=f"b/{label}",
|
||||
lineterm="",
|
||||
))
|
||||
added = sum(1 for l in diff_lines if l.startswith("+") and not l.startswith("+++"))
|
||||
removed = sum(1 for l in diff_lines if l.startswith("-") and not l.startswith("---"))
|
||||
truncated = False
|
||||
if len(diff_lines) > MAX_DIFF_LINES:
|
||||
diff_lines = diff_lines[:MAX_DIFF_LINES]
|
||||
truncated = True
|
||||
text = "\n".join(diff_lines)
|
||||
if truncated:
|
||||
text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
|
||||
return {
|
||||
"text": text,
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"new_file": old == "",
|
||||
"file": os.path.basename(path) or (path or "file"),
|
||||
}
|
||||
|
||||
|
||||
async def _do_edit_file(content: str, workspace: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Exact string-replacement edit of an on-disk file.
|
||||
|
||||
content is JSON: {"path", "old_string", "new_string", "replace_all"?}.
|
||||
Fails if old_string is missing or non-unique (unless replace_all) so the
|
||||
model can't silently edit the wrong place. Returns a unified diff for the UI.
|
||||
Confined to the workspace when one is set (same policy as write_file).
|
||||
"""
|
||||
try:
|
||||
args = json.loads(content) if content.strip().startswith("{") else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
args = {}
|
||||
raw_path = (args.get("path") or "").strip()
|
||||
old = args.get("old_string", "")
|
||||
new = args.get("new_string", "")
|
||||
replace_all = bool(args.get("replace_all", False))
|
||||
if not raw_path:
|
||||
return {"error": "edit_file: path required", "exit_code": 1}
|
||||
# Confine to the workspace when set, else the same allowlist + sensitive-file
|
||||
# policy as read/write_file.
|
||||
try:
|
||||
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
|
||||
if workspace else _resolve_tool_path(raw_path))
|
||||
except ValueError as e:
|
||||
return {"error": f"edit_file: {e}", "exit_code": 1}
|
||||
if old == "":
|
||||
return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
|
||||
if old == new:
|
||||
return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}
|
||||
|
||||
def _apply():
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
original = f.read()
|
||||
count = original.count(old)
|
||||
if count == 0:
|
||||
return original, None, "not_found"
|
||||
if count > 1 and not replace_all:
|
||||
return original, None, f"not_unique:{count}"
|
||||
updated = original.replace(old, new) if replace_all else original.replace(old, new, 1)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(updated)
|
||||
return original, updated, "ok"
|
||||
|
||||
try:
|
||||
original, updated, status = await asyncio.to_thread(_apply)
|
||||
except FileNotFoundError:
|
||||
return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
|
||||
except (IsADirectoryError, UnicodeDecodeError):
|
||||
return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
|
||||
except PermissionError:
|
||||
return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
|
||||
except OSError as e:
|
||||
return {"error": f"edit_file: {path}: {e}", "exit_code": 1}
|
||||
|
||||
if status == "not_found":
|
||||
return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
|
||||
if status.startswith("not_unique"):
|
||||
n = status.split(":", 1)[1]
|
||||
return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}
|
||||
|
||||
n = original.count(old)
|
||||
result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
|
||||
diff = _unified_diff(original, updated, path)
|
||||
if diff:
|
||||
result["diff"] = diff
|
||||
return result
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path confinement for read_file / write_file
|
||||
@@ -158,6 +271,40 @@ def _resolve_tool_path(raw_path: str) -> str:
|
||||
f"path '{raw_path}' is outside the allowed roots"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_tool_path_in_workspace(workspace: str, raw_path: str) -> str:
|
||||
"""Confine a model-supplied path to the active workspace.
|
||||
|
||||
Layered on top of upstream's path policy: the workspace is the allowed
|
||||
root (relative paths resolve under it; paths that escape it are rejected),
|
||||
and the sensitive-file deny list (.ssh, .gnupg, id_rsa, …) still applies
|
||||
inside it. When no workspace is set, callers use _resolve_tool_path (the
|
||||
default data/tmp allowlist) instead.
|
||||
"""
|
||||
if raw_path is None or not str(raw_path).strip():
|
||||
raise ValueError("path is required")
|
||||
base = os.path.realpath(workspace)
|
||||
expanded = os.path.expanduser(str(raw_path).strip())
|
||||
candidate = expanded if os.path.isabs(expanded) else os.path.join(base, expanded)
|
||||
resolved = os.path.realpath(candidate)
|
||||
if _is_sensitive_path(resolved):
|
||||
raise ValueError(
|
||||
f"path '{raw_path}' is inside a sensitive directory "
|
||||
f"(e.g. .ssh, .gnupg) or matches a sensitive filename"
|
||||
)
|
||||
if resolved != base:
|
||||
# normcase so containment holds on case-insensitive filesystems
|
||||
# (Windows, default macOS): it lowercases on Windows and is a no-op on
|
||||
# POSIX. commonpath raises ValueError across Windows drives (C: vs D:)
|
||||
# or mixed abs/rel — both mean "outside", so the except rejects them.
|
||||
nbase = os.path.normcase(base)
|
||||
try:
|
||||
if os.path.commonpath([os.path.normcase(resolved), nbase]) != nbase:
|
||||
raise ValueError
|
||||
except ValueError:
|
||||
raise ValueError(f"path '{raw_path}' is outside the workspace ({workspace})")
|
||||
return resolved
|
||||
|
||||
# Bash + python tools used to share a single 60s timeout. That's
|
||||
# enough for one-shot commands but starves real workloads (pip
|
||||
# install, ffmpeg conversions, etc.) — and worse, the agent saw the
|
||||
@@ -186,6 +333,39 @@ def get_mcp_manager():
|
||||
return agent_tools.get_mcp_manager()
|
||||
|
||||
|
||||
# Directories ignored by the code-nav tools' Python fallbacks so results aren't
|
||||
# polluted by VCS internals / dependency trees / build caches. ripgrep already
|
||||
# honours .gitignore; this is the parity floor for the no-rg path (and the
|
||||
# explicit excludes passed to rg so it skips them even without a .gitignore).
|
||||
_CODENAV_SKIP_DIRS = frozenset({
|
||||
".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
|
||||
".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
|
||||
".next", ".cache", "site-packages", ".idea", ".tox",
|
||||
})
|
||||
# Per-tool result caps (keep tool output cheap + model-friendly).
|
||||
_CODENAV_MAX_HITS = 200
|
||||
_CODENAV_MAX_LINE = 400
|
||||
|
||||
|
||||
def _resolve_search_root(raw_path: str, workspace: Optional[str] = None) -> str:
|
||||
"""Resolve + confine a code-nav path (grep/glob/ls).
|
||||
|
||||
With a workspace set, the workspace folder is the root and supplied paths are
|
||||
confined inside it (same policy as read_file). Without one, an empty path
|
||||
defaults to the agent's primary root (project data dir) and a supplied path
|
||||
is confined by the global allowlist + sensitive-file policy.
|
||||
"""
|
||||
raw = (raw_path or "").strip()
|
||||
if workspace:
|
||||
if not raw:
|
||||
return os.path.realpath(workspace)
|
||||
return _resolve_tool_path_in_workspace(workspace, raw)
|
||||
if not raw:
|
||||
roots = _tool_path_roots()
|
||||
return roots[0] if roots else os.path.realpath(".")
|
||||
return _resolve_tool_path(raw)
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
|
||||
if len(text) > limit:
|
||||
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
|
||||
@@ -396,11 +576,12 @@ async def _call_mcp_tool(
|
||||
tool: str,
|
||||
content: str,
|
||||
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||
workspace: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""Route a legacy tool call through the MCP manager, with direct fallbacks."""
|
||||
mcp = get_mcp_manager()
|
||||
if not mcp:
|
||||
return await _direct_fallback(tool, content, progress_cb=progress_cb) or {"error": f"MCP manager not available for tool '{tool}'", "exit_code": 1}
|
||||
return await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace) or {"error": f"MCP manager not available for tool '{tool}'", "exit_code": 1}
|
||||
|
||||
server_id, tool_name = _MCP_TOOL_MAP[tool]
|
||||
qualified = f"mcp__{server_id}__{tool_name}"
|
||||
@@ -409,7 +590,7 @@ async def _call_mcp_tool(
|
||||
|
||||
# If MCP server not connected, try direct fallback
|
||||
if isinstance(result, dict) and result.get("exit_code") == 1 and "not connected" in result.get("error", ""):
|
||||
fallback = await _direct_fallback(tool, content, progress_cb=progress_cb)
|
||||
fallback = await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace)
|
||||
if fallback:
|
||||
return fallback
|
||||
|
||||
@@ -436,6 +617,7 @@ async def _direct_fallback(
|
||||
tool: str,
|
||||
content: str,
|
||||
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||
workspace: Optional[str] = None,
|
||||
) -> Optional[Dict]:
|
||||
"""In-process execution path for the eight tools that used to live as
|
||||
stdio MCP servers under mcp_servers/. Those servers were deleted in
|
||||
@@ -461,6 +643,7 @@ async def _direct_fallback(
|
||||
"TERM": "xterm-256color",
|
||||
"COLUMNS": "120",
|
||||
"LINES": "40",
|
||||
"HOME": _AGENT_WORKDIR,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -470,6 +653,7 @@ async def _direct_fallback(
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=_subproc_env,
|
||||
cwd=workspace or _AGENT_WORKDIR,
|
||||
)
|
||||
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
||||
proc,
|
||||
@@ -496,6 +680,7 @@ async def _direct_fallback(
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=_subproc_env,
|
||||
cwd=workspace or _AGENT_WORKDIR,
|
||||
)
|
||||
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
|
||||
proc,
|
||||
@@ -512,14 +697,43 @@ async def _direct_fallback(
|
||||
return {"output": output or "(no output)", "exit_code": rc or 0}
|
||||
|
||||
if tool == "read_file":
|
||||
raw_path = content.split("\n", 1)[0].strip()
|
||||
# Args: plain path on line 1 (back-compat) OR JSON
|
||||
# {path, offset?, limit?} where offset/limit are a 1-based line range.
|
||||
raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
|
||||
_stripped = content.strip()
|
||||
if _stripped.startswith("{"):
|
||||
try:
|
||||
_a = _json.loads(_stripped)
|
||||
raw_path = str(_a.get("path", "")).strip()
|
||||
offset = int(_a.get("offset") or 0)
|
||||
limit = int(_a.get("limit") or 0)
|
||||
except (_json.JSONDecodeError, TypeError, ValueError):
|
||||
pass
|
||||
try:
|
||||
path = _resolve_tool_path(raw_path)
|
||||
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
|
||||
if workspace else _resolve_tool_path(raw_path))
|
||||
except ValueError as e:
|
||||
return {"error": f"read_file: {e}", "exit_code": 1}
|
||||
try:
|
||||
# Run blocking read in a thread to keep the loop responsive
|
||||
# Run blocking read in a thread to keep the loop responsive.
|
||||
def _read():
|
||||
if offset > 0 or limit > 0:
|
||||
# Line-range read: slice [offset, offset+limit).
|
||||
start = max(offset, 1)
|
||||
out, n, budget = [], 0, MAX_READ_CHARS
|
||||
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
||||
for i, line in enumerate(f, 1):
|
||||
if i < start:
|
||||
continue
|
||||
if limit > 0 and n >= limit:
|
||||
break
|
||||
out.append(line)
|
||||
n += 1
|
||||
budget -= len(line)
|
||||
if budget <= 0:
|
||||
out.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
|
||||
break
|
||||
return "".join(out)
|
||||
with open(path, "r", encoding="utf-8", errors="replace") as f:
|
||||
return f.read(MAX_READ_CHARS + 1)
|
||||
data = await asyncio.to_thread(_read)
|
||||
@@ -527,10 +741,11 @@ async def _direct_fallback(
|
||||
return {"error": f"read_file: {path}: not found", "exit_code": 1}
|
||||
except PermissionError:
|
||||
return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
|
||||
except IsADirectoryError:
|
||||
return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
|
||||
except OSError as e:
|
||||
return {"error": f"read_file: {path}: {e}", "exit_code": 1}
|
||||
truncated = len(data) > MAX_READ_CHARS
|
||||
if truncated:
|
||||
if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS:
|
||||
data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
|
||||
return {"output": data, "exit_code": 0}
|
||||
|
||||
@@ -539,23 +754,226 @@ async def _direct_fallback(
|
||||
raw_path = lines[0].strip()
|
||||
body = lines[1] if len(lines) > 1 else ""
|
||||
try:
|
||||
path = _resolve_tool_path(raw_path)
|
||||
path = (_resolve_tool_path_in_workspace(workspace, raw_path)
|
||||
if workspace else _resolve_tool_path(raw_path))
|
||||
except ValueError as e:
|
||||
return {"error": f"write_file: {e}", "exit_code": 1}
|
||||
try:
|
||||
def _write():
|
||||
# Capture prior content (best-effort, text) so we can show a
|
||||
# before/after diff. Missing/binary file → treat as empty.
|
||||
old = ""
|
||||
try:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
old = f.read()
|
||||
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
|
||||
old = ""
|
||||
d = os.path.dirname(path)
|
||||
if d:
|
||||
os.makedirs(d, exist_ok=True)
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write(body)
|
||||
return len(body)
|
||||
size = await asyncio.to_thread(_write)
|
||||
return old, len(body)
|
||||
old_content, size = await asyncio.to_thread(_write)
|
||||
except PermissionError:
|
||||
return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
|
||||
except OSError as e:
|
||||
return {"error": f"write_file: {path}: {e}", "exit_code": 1}
|
||||
return {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
|
||||
diff = _unified_diff(old_content, body, path)
|
||||
result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
|
||||
if diff:
|
||||
result["diff"] = diff
|
||||
return result
|
||||
|
||||
if tool == "grep":
|
||||
# Args (JSON): {pattern, path?, glob?, ignore_case?, max_results?}.
|
||||
# Bare string → treated as the pattern.
|
||||
args: Dict[str, Any] = {}
|
||||
_s = (content or "").strip()
|
||||
if _s.startswith("{"):
|
||||
try:
|
||||
args = _json.loads(_s)
|
||||
except _json.JSONDecodeError:
|
||||
args = {}
|
||||
else:
|
||||
args = {"pattern": _s}
|
||||
pattern = str(args.get("pattern", "")).strip()
|
||||
if not pattern:
|
||||
return {"error": "grep: pattern is required", "exit_code": 1}
|
||||
ignore_case = bool(args.get("ignore_case"))
|
||||
glob_pat = str(args.get("glob", "") or "").strip()
|
||||
try:
|
||||
max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
|
||||
except (TypeError, ValueError):
|
||||
max_hits = _CODENAV_MAX_HITS
|
||||
max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
|
||||
try:
|
||||
root = _resolve_search_root(str(args.get("path", "")), workspace)
|
||||
except ValueError as e:
|
||||
return {"error": f"grep: {e}", "exit_code": 1}
|
||||
|
||||
def _grep():
|
||||
import re as _re
|
||||
import shutil
|
||||
rg = shutil.which("rg")
|
||||
if rg:
|
||||
cmd = [rg, "--line-number", "--no-heading", "--color=never",
|
||||
"--max-count", str(max_hits)]
|
||||
if ignore_case:
|
||||
cmd.append("--ignore-case")
|
||||
if glob_pat:
|
||||
cmd += ["--glob", glob_pat]
|
||||
# Exclude junk dirs even when the tree has no .gitignore, so
|
||||
# results match the Python fallback's skip set.
|
||||
for _d in _CODENAV_SKIP_DIRS:
|
||||
cmd += ["--glob", f"!**/{_d}/**"]
|
||||
cmd += ["--regexp", pattern, root]
|
||||
try:
|
||||
import subprocess
|
||||
p = subprocess.run(cmd, capture_output=True, text=True, timeout=20)
|
||||
lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
|
||||
return lines, None
|
||||
except subprocess.TimeoutExpired:
|
||||
return None, "grep: timed out"
|
||||
except Exception as _e:
|
||||
return None, f"grep: {_e}"
|
||||
# Python fallback (no ripgrep): walk + regex.
|
||||
try:
|
||||
rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
|
||||
except _re.error as _e:
|
||||
return None, f"grep: bad pattern: {_e}"
|
||||
import fnmatch
|
||||
hits = []
|
||||
if os.path.isfile(root):
|
||||
file_iter = [root]
|
||||
else:
|
||||
file_iter = []
|
||||
for dp, dns, fns in os.walk(root):
|
||||
dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
|
||||
for fn in fns:
|
||||
if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
|
||||
continue
|
||||
file_iter.append(os.path.join(dp, fn))
|
||||
for fp in file_iter:
|
||||
if len(hits) >= max_hits:
|
||||
break
|
||||
try:
|
||||
with open(fp, "r", encoding="utf-8", errors="strict") as f:
|
||||
for i, line in enumerate(f, 1):
|
||||
if rx.search(line):
|
||||
hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
|
||||
if len(hits) >= max_hits:
|
||||
break
|
||||
except (UnicodeDecodeError, OSError):
|
||||
continue # skip binary / unreadable
|
||||
return hits, None
|
||||
|
||||
lines, err = await asyncio.to_thread(_grep)
|
||||
if err:
|
||||
return {"error": err, "exit_code": 1}
|
||||
if not lines:
|
||||
return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
|
||||
out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
|
||||
if len(lines) >= max_hits:
|
||||
out += f"\n... [capped at {max_hits} matches]"
|
||||
return {"output": _truncate(out), "exit_code": 0}
|
||||
|
||||
if tool == "glob":
|
||||
args = {}
|
||||
_s = (content or "").strip()
|
||||
if _s.startswith("{"):
|
||||
try:
|
||||
args = _json.loads(_s)
|
||||
except _json.JSONDecodeError:
|
||||
args = {}
|
||||
else:
|
||||
args = {"pattern": _s}
|
||||
pattern = str(args.get("pattern", "")).strip()
|
||||
if not pattern:
|
||||
return {"error": "glob: pattern is required", "exit_code": 1}
|
||||
try:
|
||||
root = _resolve_search_root(str(args.get("path", "")), workspace)
|
||||
except ValueError as e:
|
||||
return {"error": f"glob: {e}", "exit_code": 1}
|
||||
|
||||
def _glob():
|
||||
from pathlib import Path
|
||||
base = Path(root)
|
||||
if not base.is_dir():
|
||||
return None, f"glob: {root}: not a directory"
|
||||
matched = []
|
||||
try:
|
||||
for p in base.rglob(pattern):
|
||||
if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
|
||||
continue
|
||||
try:
|
||||
mtime = p.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0
|
||||
matched.append((mtime, str(p)))
|
||||
if len(matched) > _CODENAV_MAX_HITS * 5:
|
||||
break
|
||||
except (OSError, ValueError) as _e:
|
||||
return None, f"glob: {_e}"
|
||||
matched.sort(key=lambda t: t[0], reverse=True) # newest first
|
||||
return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None
|
||||
|
||||
paths, err = await asyncio.to_thread(_glob)
|
||||
if err:
|
||||
return {"error": err, "exit_code": 1}
|
||||
if not paths:
|
||||
return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
|
||||
out = "\n".join(paths)
|
||||
if len(paths) >= _CODENAV_MAX_HITS:
|
||||
out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
|
||||
return {"output": _truncate(out), "exit_code": 0}
|
||||
|
||||
if tool == "ls":
|
||||
raw_path = ""
|
||||
_s = (content or "").strip()
|
||||
if _s.startswith("{"):
|
||||
try:
|
||||
raw_path = str(_json.loads(_s).get("path", "")).strip()
|
||||
except _json.JSONDecodeError:
|
||||
raw_path = ""
|
||||
else:
|
||||
raw_path = _s.split("\n", 1)[0].strip()
|
||||
try:
|
||||
root = _resolve_search_root(raw_path, workspace)
|
||||
except ValueError as e:
|
||||
return {"error": f"ls: {e}", "exit_code": 1}
|
||||
|
||||
def _ls():
|
||||
if not os.path.isdir(root):
|
||||
return None, f"ls: {root}: not a directory"
|
||||
rows = []
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
if entry.name.startswith("."):
|
||||
continue
|
||||
try:
|
||||
is_dir = entry.is_dir(follow_symlinks=False)
|
||||
size = entry.stat(follow_symlinks=False).st_size if not is_dir else 0
|
||||
except OSError:
|
||||
continue
|
||||
rows.append((is_dir, entry.name, size))
|
||||
except (PermissionError, OSError) as _e:
|
||||
return None, f"ls: {_e}"
|
||||
rows.sort(key=lambda r: (not r[0], r[1].lower())) # dirs first, then name
|
||||
lines = [f"{root}:"]
|
||||
for is_dir, name, size in rows[:_CODENAV_MAX_HITS]:
|
||||
lines.append(f" {name}/" if is_dir else f" {name} ({size} B)")
|
||||
if len(rows) > _CODENAV_MAX_HITS:
|
||||
lines.append(f" ... [{len(rows) - _CODENAV_MAX_HITS} more]")
|
||||
if not rows:
|
||||
lines.append(" (empty)")
|
||||
return "\n".join(lines), None
|
||||
|
||||
out, err = await asyncio.to_thread(_ls)
|
||||
if err:
|
||||
return {"error": err, "exit_code": 1}
|
||||
return {"output": _truncate(out), "exit_code": 0}
|
||||
|
||||
if tool == "web_search":
|
||||
from src.search import comprehensive_web_search
|
||||
@@ -685,6 +1103,7 @@ async def execute_tool_block(
|
||||
disabled_tools: Optional[set] = None,
|
||||
owner: Optional[str] = None,
|
||||
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
|
||||
workspace: Optional[str] = None,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Execute a single tool block. Returns (description, result_dict).
|
||||
|
||||
@@ -773,7 +1192,7 @@ async def execute_tool_block(
|
||||
_is_bg, _bg_cmd = _split_bg_marker(content)
|
||||
if _is_bg and _bg_cmd:
|
||||
from src import bg_jobs
|
||||
rec = bg_jobs.launch(_bg_cmd, session_id=session_id)
|
||||
rec = bg_jobs.launch(_bg_cmd, session_id=session_id, cwd=workspace or _AGENT_WORKDIR)
|
||||
short = _bg_cmd.strip().split(chr(10))[0][:80]
|
||||
desc = f"bash (background): {short}"
|
||||
result = {
|
||||
@@ -795,7 +1214,14 @@ async def execute_tool_block(
|
||||
if tool in _MCP_TOOL_MAP:
|
||||
first_line = content.split(chr(10))[0][:80]
|
||||
desc = f"{tool}: {first_line}"
|
||||
result = await _call_mcp_tool(tool, content, progress_cb=progress_cb)
|
||||
result = await _call_mcp_tool(tool, content, progress_cb=progress_cb, workspace=workspace)
|
||||
elif tool in ("grep", "glob", "ls"):
|
||||
# Code-navigation tools — no MCP server; run the direct implementation.
|
||||
# Confined to the workspace when one is set (same policy as read_file).
|
||||
first_line = content.split(chr(10))[0][:80]
|
||||
desc = f"{tool}: {first_line}"
|
||||
result = await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace) \
|
||||
or {"error": f"{tool}: execution failed", "exit_code": 1}
|
||||
elif tool == "create_document":
|
||||
title = content.split("\n")[0].strip()[:60]
|
||||
desc = f"create_document: {title}"
|
||||
@@ -898,6 +1324,9 @@ async def execute_tool_block(
|
||||
elif tool == "edit_image":
|
||||
desc = "edit_image"
|
||||
result = await do_edit_image(content, owner=owner)
|
||||
elif tool == "edit_file":
|
||||
result = await _do_edit_file(content, workspace=workspace)
|
||||
desc = result.get("output") or result.get("error") or "edit_file"
|
||||
elif tool == "trigger_research":
|
||||
desc = "trigger_research"
|
||||
result = await do_trigger_research(content, owner=owner)
|
||||
|
||||
@@ -22,7 +22,13 @@ logger = logging.getLogger(__name__)
|
||||
# Tools that are ALWAYS included regardless of retrieval results.
|
||||
# These are the most commonly needed and should never be missing.
|
||||
ALWAYS_AVAILABLE = frozenset({
|
||||
"bash", "python", "web_search", "web_fetch", "read_file",
|
||||
"bash", "python", "web_search", "web_fetch",
|
||||
# File tools: read AND write/edit. An agent with disk access should always
|
||||
# be able to change files, not just read them — otherwise a bare "edit X"
|
||||
# request can miss write_file/edit_file (RAG-only) and the model wrongly
|
||||
# falls back to edit_document (editor panel). All admin-gated by tool_security.
|
||||
"read_file", "write_file", "edit_file",
|
||||
"grep", "glob", "ls", # code-navigation tools (admin-gated by tool_security)
|
||||
"api_call", # For configured integrations (Miniflux, Gitea, Linkding, etc.)
|
||||
# The two genuinely AMBIENT cookbook tools — "what's running" and
|
||||
# "kill it" can be asked any time without prior cookbook context,
|
||||
@@ -71,8 +77,12 @@ BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
|
||||
"python": "Execute Python code for computation, data processing, math, scripting, parsing, API calls. Not for writing code for the user.",
|
||||
"web_search": "Quick single web lookup for a fact, current event, or doc mid-task. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
|
||||
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
|
||||
"read_file": "Read a file from disk and return its contents. View source code, config files, logs.",
|
||||
"write_file": "Write content to a file on disk. Create new files, save output, update configs.",
|
||||
"read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
|
||||
"grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
|
||||
"glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
|
||||
"ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
|
||||
"write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
|
||||
"edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
|
||||
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
|
||||
"edit_document": "Preferred tool for editing an existing document — targeted find-and-replace. Use for any small change: add a function, fix a bug, tweak a section, rename things.",
|
||||
"update_document": "Replace the entire active document content. ONLY for full rewrites (>50% changed). Do not use for small edits — use edit_document instead.",
|
||||
|
||||
@@ -82,16 +82,65 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file from disk",
|
||||
"description": "Read a file from disk. Optionally read a line range with offset/limit for large files.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "File path to read"}
|
||||
"path": {"type": "string", "description": "File path to read"},
|
||||
"offset": {"type": "integer", "description": "1-based line to start reading from (optional)"},
|
||||
"limit": {"type": "integer", "description": "Max number of lines to read from offset (optional)"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "grep",
|
||||
"description": "Search file contents for a regular expression across a directory tree (uses ripgrep when available, respecting .gitignore). Returns file:line:match. PREFER this over `bash grep/rg` for code search — confined to the allowed roots, structured output.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regular expression to search for"},
|
||||
"path": {"type": "string", "description": "Directory or file to search (optional; defaults to the project root)"},
|
||||
"glob": {"type": "string", "description": "Only search files matching this glob, e.g. '*.py' (optional)"},
|
||||
"ignore_case": {"type": "boolean", "description": "Case-insensitive match (optional)"},
|
||||
"max_results": {"type": "integer", "description": "Max matches to return (optional)"}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "glob",
|
||||
"description": "Find files by glob pattern (recursive), newest first. e.g. '**/*.py'. PREFER this over `bash find/ls` for locating files — confined to the allowed roots.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Glob pattern, e.g. '**/*.ts' or 'src/**/test_*.py'"},
|
||||
"path": {"type": "string", "description": "Base directory (optional; defaults to the project root)"}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ls",
|
||||
"description": "List the entries of a directory (folders first, then files with sizes). PREFER this over `bash ls` — confined to the allowed roots.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Directory to list (optional; defaults to the project root)"}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -107,6 +156,23 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "edit_file",
|
||||
"description": "Edit a file ON DISK by exact string replacement (home folder, project files, any real path like ~/sweden.txt or /path/to/file). This is the right tool for files on disk — NOT edit_document (that's for editor-panel documents). PREFER this over bash (sed/echo) — it shows a diff. old_string must match the file exactly and be unique (or set replace_all). Use write_file to create a new file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "File path to edit"},
|
||||
"old_string": {"type": "string", "description": "Exact text to replace (must match the file, including indentation)"},
|
||||
"new_string": {"type": "string", "description": "Replacement text"},
|
||||
"replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match"}
|
||||
},
|
||||
"required": ["path", "old_string", "new_string"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
@@ -127,7 +193,7 @@ FUNCTION_TOOL_SCHEMAS = [
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "edit_document",
|
||||
"description": "PREFERRED way to change an existing document. Targeted find-and-replace with multiple FIND/REPLACE pairs per call. Use this for any edit smaller than a full rewrite: adding a function, fixing a bug, tweaking a section, renaming things. Do NOT send the whole file back via update_document for small edits — it wastes tokens and is hard to review.",
|
||||
"description": "Edit a document OPEN IN THE EDITOR PANEL (created via create_document) — NOT a file on disk. For files on disk (home folder, project files, anything with a path like ~/x.txt or /path/to/file) use edit_file instead. Targeted find-and-replace with multiple FIND/REPLACE pairs per call; use for any edit smaller than a full rewrite. Do NOT send the whole file back via update_document for small edits.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -1126,9 +1192,17 @@ def function_call_to_tool_block(name: str, arguments: str) -> Optional[ToolBlock
|
||||
else:
|
||||
content = args.get("query", "")
|
||||
elif tool_type == "read_file":
|
||||
content = args.get("path", "")
|
||||
# Plain path (back-compat) unless a line range is requested → JSON.
|
||||
if args.get("offset") or args.get("limit"):
|
||||
content = json.dumps(args)
|
||||
else:
|
||||
content = args.get("path", "")
|
||||
elif tool_type in ("grep", "glob", "ls"):
|
||||
content = json.dumps(args) if args else "{}"
|
||||
elif tool_type == "write_file":
|
||||
content = args.get("path", "") + "\n" + args.get("content", "")
|
||||
elif tool_type == "edit_file":
|
||||
content = json.dumps(args)
|
||||
elif tool_type == "create_document":
|
||||
parts = [args.get("title", "Untitled")]
|
||||
if args.get("language"):
|
||||
|
||||
@@ -16,6 +16,10 @@ NON_ADMIN_BLOCKED_TOOLS = {
|
||||
"python",
|
||||
"read_file",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"grep",
|
||||
"glob",
|
||||
"ls",
|
||||
"search_chats",
|
||||
"manage_memory",
|
||||
"manage_skills",
|
||||
|
||||
Reference in New Issue
Block a user