diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index e9ed637..911b4b9 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -25,7 +25,7 @@ Fixes #
## Checklist
- [ ] I searched [open issues](https://github.com/pewdiepie-archdaemon/odysseus/issues) and [open PRs](https://github.com/pewdiepie-archdaemon/odysseus/pulls) — this is not a duplicate.
-- [ ] This PR targets `main`
+- [ ] This PR targets `dev`
- [ ] My changes are limited to the scope described above — no unrelated refactors or whitespace changes mixed in.
- [ ] I actually ran the app (`docker compose up` or `uvicorn app:app`) and verified the change works end-to-end. Type-checks and unit tests are not enough.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..3978ef5
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,60 @@
+name: CI
+
+# Pushes are checked on the long-lived branches: pull requests now target
+# `dev`, so it is listed next to `main` and merged code is re-verified there.
+# Every pull request is checked regardless of its base branch.
+on:
+  push:
+    branches: [main, dev]
+  pull_request:
+
+# Least privilege: none of the jobs write to the repo.
+permissions:
+  contents: read
+
+# Cancel superseded runs on the same ref to save Actions minutes.
+concurrency:
+  group: ci-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  python-syntax:
+    name: Python syntax (compileall)
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5  # v4
+      - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065  # v5
+        with:
+          python-version: "3.11"
+      # Byte-compile sources — catches syntax errors without installing deps.
+      - run: python -m compileall -q app.py core routes src services scripts tests
+
+  node-syntax:
+    name: JS syntax (node --check)
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5  # v4
+      - uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020  # v4
+        with:
+          node-version: "20"
+      # Syntax-check our own JS (skip vendored libs in static/lib).
+      - name: node --check
+        run: |
+          shopt -s globstar nullglob
+          for f in static/app.js static/js/**/*.js; do
+            node --check "$f"
+          done
+
+  python-tests:
+    name: Python tests (pytest)
+    runs-on: ubuntu-latest
+    # Informational for now: the suite has known flaky / environment-dependent
+    # failures (test isolation + embedding-model assertions). Tracked under the
+    # ROADMAP "fresh install smoke tests" item; make this required once green.
+    continue-on-error: true
+    # Bound the job so a hung test cannot hold a runner until the platform's
+    # much longer default job timeout.
+    timeout-minutes: 30
+    steps:
+      - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5  # v4
+      - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065  # v5
+        with:
+          python-version: "3.11"
+          cache: pip
+      - run: pip install -r requirements.txt
+      - run: mkdir -p data  # sqlite DB lives at ./data/app.db
+      - run: python -m pytest -q
diff --git a/Dockerfile b/Dockerfile
index 535f0a0..ad273ce 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -22,9 +22,12 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
WORKDIR /app
-# Install Python deps first (layer cache)
-COPY requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt
+# Install Python deps first (layer cache). Optional extras (PyMuPDF AGPL, etc.)
+# are opt-in so the default image stays MIT-core; see requirements-optional.txt.
+ARG INSTALL_OPTIONAL=false
+COPY requirements.txt requirements-optional.txt ./
+RUN pip install --no-cache-dir -r requirements.txt \
+ && if [ "$INSTALL_OPTIONAL" = "true" ]; then pip install --no-cache-dir -r requirements-optional.txt; fi
# Copy app code
COPY . .
diff --git a/README.md b/README.md
index 4fd7f48..638089f 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@
A self-hosted AI workspace -- meant to be the self-hosted version of the UI experience you get from ChatGPT and Claude. But with more jank and fun. Running on your own hardware, with your own data -- local-first, privacy-first, and no trojan.
## Features
- - **Chat** -- chat with any local model or API; adding them is super simple. vLLM · llama.cpp · Ollama · OpenRouter · OpenAI
+ - **Chat** -- chat with any local model or API; adding them is super simple. vLLM · llama.cpp · Ollama · OpenRouter · OpenAI · GitHub Copilot
- **Agent** -- hand it tools and let it run the whole task itself. built on [opencode](https://github.com/anomalyco/opencode) · MCP · web · files · shell · skills · memory
- **Cookbook** -- Scans your hardware, recommends models, click to download and serve.. easy! built on [llmfit](https://github.com/AlexsJones/llmfit) · VRAM-aware · GGUF / FP8 / AWQ · fit scoring · vLLM / llama.cpp serving
- **Deep Research** -- multi-step runs that gather, read, and synthesize sources into a nice visual report. adapted from [Tongyi DeepResearch](https://github.com/Alibaba-NLP/DeepResearch)
@@ -64,6 +64,8 @@ cd odysseus
cp .env.example .env # optional, but recommended for explicit defaults
docker compose up -d --build
```
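+To include the optional extras in the image (PDF viewer, Office extraction; this pulls in the AGPL-licensed PyMuPDF), run `docker compose build --build-arg INSTALL_OPTIONAL=true` first and then start the stack with `docker compose up -d` (without `--build`, so the image is not rebuilt with the default `INSTALL_OPTIONAL=false`).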
+
Open `http://localhost:7000` when the containers are healthy. Docker Compose
binds the web UI to `127.0.0.1` by default. If the port is taken, set
`APP_PORT=7001` in `.env` and recreate the container. Set `APP_BIND=0.0.0.0`
diff --git a/app.py b/app.py
index 7a00722..b34b818 100644
--- a/app.py
+++ b/app.py
@@ -34,7 +34,6 @@ from dotenv import load_dotenv
# is silently ignored and the user is unexpectedly forced to log in (issue #142).
# utf-8-sig reads plain UTF-8 (no BOM) identically, so this is safe everywhere.
load_dotenv(encoding="utf-8-sig")
-import uuid
import asyncio
import logging
@@ -526,6 +525,9 @@ upload_cleanup_task = None
from routes.emoji_routes import setup_emoji_routes
app.include_router(setup_emoji_routes())
+from routes.workspace_routes import setup_workspace_routes
+app.include_router(setup_workspace_routes())
+
# Sessions
from routes.session_routes import setup_session_routes
session_config = {"REQUEST_TIMEOUT": REQUEST_TIMEOUT, "OPENAI_API_KEY": OPENAI_API_KEY, "SESSIONS_FILE": SESSIONS_FILE}
@@ -588,6 +590,10 @@ app.include_router(setup_embedding_routes())
from routes.model_routes import setup_model_routes
app.include_router(setup_model_routes(model_discovery))
+# GitHub Copilot device-flow login
+from routes.copilot_routes import setup_copilot_routes
+app.include_router(setup_copilot_routes())
+
# TTS
from routes.tts_routes import setup_tts_routes
app.include_router(setup_tts_routes(tts_service))
diff --git a/core/database.py b/core/database.py
index 4788a45..8a88b28 100644
--- a/core/database.py
+++ b/core/database.py
@@ -375,6 +375,7 @@ class McpServer(TimestampMixin, Base):
is_enabled = Column(Boolean, default=True)
oauth_config = Column(Text, nullable=True) # JSON: provider, keys_file, token_file, scopes
disabled_tools = Column(Text, nullable=True) # JSON array of tool names to hide from LLM
+ oauth_tokens = Column(EncryptedText, nullable=True) # JSON {tokens, client_info} for generic MCP OAuth, encrypted at rest
class Comparison(TimestampMixin, Base):
@@ -1311,6 +1312,23 @@ def _migrate_add_disabled_tools():
except Exception as e:
logging.getLogger(__name__).warning(f"disabled_tools migration: {e}")
+def _migrate_add_mcp_oauth_tokens_column():
+    """Ensure mcp_servers has an oauth_tokens column, adding it when absent.
+
+    The ORM model types this column as EncryptedText, yet the DDL below uses
+    plain TEXT deliberately: EncryptedText is a SQLAlchemy TypeDecorator that
+    encrypts in Python and persists the ciphertext as TEXT, so TEXT is the real
+    storage type. The other encrypted columns (see _migrate_encrypt_*) follow
+    the same convention. Any failure is logged as a warning, never raised."""
+    log = logging.getLogger(__name__)
+    try:
+        with engine.connect() as conn:
+            existing = {row[1] for row in conn.execute(text("PRAGMA table_info(mcp_servers)"))}
+            # Column already present: nothing to migrate.
+            if "oauth_tokens" in existing:
+                return
+            conn.execute(text("ALTER TABLE mcp_servers ADD COLUMN oauth_tokens TEXT"))
+            conn.commit()
+            log.info("Added oauth_tokens column to mcp_servers")
+    except Exception as e:
+        log.warning(f"oauth_tokens migration: {e}")
+
def _migrate_add_task_v2_columns():
"""Add cron_expression, then_task_id, webhook_token to scheduled_tasks."""
new_cols = {
@@ -1467,6 +1485,10 @@ class CalendarEvent(TimestampMixin, Base):
importance = Column(String, default="normal") # low | normal | high | critical
event_type = Column(String, nullable=True) # work | personal | health | travel | meal | social | admin | other
last_pinged = Column(DateTime, nullable=True) # last time the assistant pinged about this event
+ # "caldav" = pulled from a CalDAV server (so the sync may prune it when it
+ # vanishes upstream). NULL/local = created locally (agent, email triage, or
+ # a UI event whose write-back failed) and must NOT be pruned by the sync.
+ origin = Column(String, nullable=True, index=True)
calendar = relationship("CalendarCal", back_populates="events")
@@ -1589,6 +1611,7 @@ def init_db():
_migrate_add_oauth_config()
_migrate_add_task_automation_columns()
_migrate_add_disabled_tools()
+ _migrate_add_mcp_oauth_tokens_column()
_migrate_add_task_v2_columns()
_migrate_add_notifications_enabled()
_migrate_drop_ping_notes_tasks()
@@ -1598,6 +1621,7 @@ def init_db():
_migrate_seed_email_account()
_migrate_add_calendar_metadata()
_migrate_add_calendar_is_utc()
+ _migrate_add_calendar_origin()
_migrate_encrypt_email_passwords()
_migrate_encrypt_signatures()
_migrate_encrypt_endpoint_keys()
@@ -1740,6 +1764,28 @@ def _migrate_add_calendar_is_utc():
logging.getLogger(__name__).warning(f"is_utc migration failed: {e}")
+def _migrate_add_calendar_origin():
+    """Add `origin` to calendar_events so the CalDAV sync can tell server-pulled
+    rows (prunable when they vanish upstream) from locally-created ones (agent /
+    email triage / failed write-back), which must never be pruned. Idempotent.
+
+    Does nothing when the SQLite file does not exist yet, when the table has no
+    columns (i.e. it has not been created), or when `origin` is already there.
+    Errors are logged as warnings and never raised, so startup continues.
+    """
+    import sqlite3
+    from contextlib import closing
+    db_path = DATABASE_URL.replace("sqlite:///", "")
+    if not os.path.exists(db_path):
+        return
+    try:
+        # closing() releases the handle even when the PRAGMA / ALTER / CREATE
+        # INDEX raises; a trailing conn.close() would be skipped on that path
+        # and leak the connection.
+        with closing(sqlite3.connect(db_path)) as conn:
+            cursor = conn.execute("PRAGMA table_info(calendar_events)")
+            columns = [row[1] for row in cursor.fetchall()]
+            if columns and "origin" not in columns:
+                conn.execute("ALTER TABLE calendar_events ADD COLUMN origin TEXT")
+                conn.execute("CREATE INDEX IF NOT EXISTS ix_calendar_events_origin ON calendar_events(origin)")
+                conn.commit()
+                logging.getLogger(__name__).info("Migrated: added 'origin' column to calendar_events")
+    except Exception as e:
+        logging.getLogger(__name__).warning(f"calendar_events.origin migration failed: {e}")
+
+
def _migrate_add_calendar_metadata():
"""Add importance/event_type/last_pinged columns to calendar_events table."""
import sqlite3
diff --git a/core/models.py b/core/models.py
index 6914b20..1adae65 100644
--- a/core/models.py
+++ b/core/models.py
@@ -76,8 +76,20 @@ class Session:
_session_manager._persist_message(self.id, message)
def get_context_messages(self) -> List[Dict[str, Any]]:
- """Get messages in format for LLM API."""
- return [msg.to_dict() for msg in self.history]
+        """Return the history entries that are sent to the LLM API, as dicts.
+
+        Replies produced by slash commands / setup (for example ``/setup ...``
+        and its status lines) are stored in ``history`` so the transcript can
+        show them, but they are UI chatter rather than conversation. Such
+        entries are tagged ``metadata.source == "slash"`` and are left out
+        here so the model never sees them. Code that displays or reloads the
+        transcript reads the raw ``history`` and is not affected.
+        """
+        def _is_conversation(entry) -> bool:
+            # Missing metadata counts as a normal conversation message.
+            return (entry.metadata or {}).get("source") != "slash"
+
+        return [entry.to_dict() for entry in self.history if _is_conversation(entry)]
def get(self, key: str, default=None):
"""Dict-like access for compatibility."""
diff --git a/core/session_manager.py b/core/session_manager.py
index 6a884f8..5491929 100644
--- a/core/session_manager.py
+++ b/core/session_manager.py
@@ -273,7 +273,10 @@ class SessionManager:
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
if db_session:
- db_session.message_count = keep_count
+ # keep_count can exceed the real message total (e.g. the AI tool
+ # defaults to keep_count=10 on a short session); message_count must
+ # track the rows that actually remain, not the requested cap.
+ db_session.message_count = min(keep_count, len(db_messages))
db_session.updated_at = datetime.now(timezone.utc)
db.commit()
diff --git a/routes/auth_routes.py b/routes/auth_routes.py
index 1992f8c..644b12d 100644
--- a/routes/auth_routes.py
+++ b/routes/auth_routes.py
@@ -340,6 +340,14 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
ok = auth_manager.rename_user(old_username, new_username, user)
if not ok:
raise HTTPException(400, "Cannot rename user")
+ # The owner-rename loop above updated ApiToken.owner in the DB, but the
+ # bearer-token cache still maps each token to the OLD owner. Without
+ # refreshing it, the renamed user's API tokens resolve to the old (now
+ # non-existent) owner and stop reaching their data until the cache next
+ # goes dirty. Invalidate it now, like the token CRUD routes do.
+ invalidator = getattr(request.app.state, "invalidate_token_cache", None)
+ if callable(invalidator):
+ invalidator()
return {"ok": True, "username": new_username, "renamed_self": old_username == user}
@router.post("/signup-toggle", deprecated=True)
@@ -430,9 +438,24 @@ def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
raise HTTPException(403, "Admin only")
body = await request.json()
current = _load_settings()
+ # Per-key validation for numeric settings: coerce to int and clamp to a
+ # sane range so a bad value can't disable the agent or let it run away.
+ _INT_RANGES = {
+ "agent_max_rounds": (1, 200),
+ "agent_max_tool_calls": (0, 1000), # 0 = unlimited
+ }
for key in DEFAULT_SETTINGS:
- if key in body:
- current[key] = body[key]
+ if key not in body:
+ continue
+ val = body[key]
+ if key in _INT_RANGES:
+ lo, hi = _INT_RANGES[key]
+ try:
+ val = int(val)
+ except (TypeError, ValueError):
+ raise HTTPException(400, f"{key} must be an integer")
+ val = max(lo, min(val, hi))
+ current[key] = val
_save_settings(current)
return current
diff --git a/routes/chat_helpers.py b/routes/chat_helpers.py
index cc20036..0929b69 100644
--- a/routes/chat_helpers.py
+++ b/routes/chat_helpers.py
@@ -589,6 +589,8 @@ def _normalize_thinking(text: str) -> str:
import re
if not text:
return text
+ from src.text_helpers import normalize_thinking_markup
+ text = normalize_thinking_markup(text)
reasoning_prefix_re = re.compile(
r'^\s*(?:thinking(?:\s+process)?\s*:|the user |i need |i should |i will |they are |the question |i can )',
re.IGNORECASE,
@@ -699,6 +701,10 @@ def _extract_thinking_meta(text: str) -> dict | None:
import re
if not text:
return None
+ from src.text_helpers import normalize_thinking_markup
+ original_text = text
+ text = normalize_thinking_markup(text)
+ normalized_changed = text != original_text
# Check for tags (native or injected)
time_match = re.search(r' dict | None:
if thinking and reply:
return {"thinking": thinking, "reply": reply, "time": think_time}
+ if normalized_changed and text.strip() and text.strip() != original_text.strip():
+ return {"thinking": "", "reply": text.strip(), "time": think_time}
+
return None
@@ -737,7 +746,8 @@ def clean_thinking_for_save(content: str, metadata: dict | None = None) -> tuple
md = dict(metadata) if metadata else {}
info = _extract_thinking_meta(content)
if info:
- md["thinking"] = info["thinking"]
+ if info.get("thinking"):
+ md["thinking"] = info["thinking"]
if info.get("time"):
md["thinking_time"] = info["time"]
return info["reply"], md
@@ -781,8 +791,10 @@ def save_assistant_response(
# Extract thinking into metadata (don't pollute message content with tags)
_think_info = _extract_thinking_meta(full_response)
if _think_info:
- md["thinking"] = _think_info["thinking"]
- md["thinking_time"] = _think_info.get("time")
+ if _think_info.get("thinking"):
+ md["thinking"] = _think_info["thinking"]
+ if _think_info.get("time"):
+ md["thinking_time"] = _think_info.get("time")
_content = _think_info["reply"]
else:
_content = full_response
diff --git a/routes/chat_routes.py b/routes/chat_routes.py
index a3c6c16..a18a1a6 100644
--- a/routes/chat_routes.py
+++ b/routes/chat_routes.py
@@ -2,6 +2,7 @@
import asyncio
import json
+import os
import time
import logging
from datetime import datetime
@@ -394,6 +395,12 @@ def setup_chat_routes(
compare_mode = str(form_data.get("compare_mode", "")).lower() == "true"
incognito = str(form_data.get("incognito", "")).lower() == "true"
chat_mode = str(form_data.get("mode", "")).lower() # 'chat' or 'agent'
+ # Workspace: confine the agent's file/shell tools to this folder. Validate
+ # it's a real directory; ignore (no confinement) otherwise.
+ workspace = (form_data.get("workspace") or "").strip()
+ if workspace:
+ _ws_real = os.path.realpath(os.path.expanduser(workspace))
+ workspace = _ws_real if os.path.isdir(_ws_real) else ""
# Did the USER explicitly pick agent mode? (vs. us auto-escalating
# below). Skill extraction should only learn from real agent sessions,
# not chats we quietly promoted for a notes/calendar intent.
@@ -981,7 +988,15 @@ def setup_chat_routes(
_answered_by = None # set if the selected model failed and a fallback answered
try:
from src.settings import get_setting
+ from src.agent_tools import MAX_AGENT_ROUNDS as _DEFAULT_ROUNDS
_tool_budget = int(get_setting("agent_max_tool_calls", 0))
+ # Per-message round cap from settings; clamp defensively in
+ # case settings.json was hand-edited to a bad value.
+ try:
+ _max_rounds = int(get_setting("agent_max_rounds", _DEFAULT_ROUNDS) or _DEFAULT_ROUNDS)
+ except (TypeError, ValueError):
+ _max_rounds = _DEFAULT_ROUNDS
+ _max_rounds = max(1, min(_max_rounds, 200))
async for chunk in stream_agent_loop(
sess.endpoint_url,
@@ -992,12 +1007,14 @@ def setup_chat_routes(
max_tokens=ctx.preset.max_tokens,
prompt_type=preset_id,
max_tool_calls=_tool_budget,
+ max_rounds=_max_rounds,
context_length=ctx.context_length,
active_document=active_doc,
session_id=session,
disabled_tools=disabled_tools if disabled_tools else None,
owner=_user,
fallbacks=_fallback_candidates,
+ workspace=workspace or None,
):
if chunk.startswith("data: ") and not chunk.startswith("data: [DONE]"):
try:
@@ -1017,6 +1034,7 @@ def setup_chat_routes(
"tool_start", "tool_output", "agent_step",
"doc_stream_open", "doc_stream_delta",
"doc_update", "doc_suggestions", "ui_control",
+ "rounds_exhausted",
):
if data.get("type") == "agent_step":
_agent_rounds = max(_agent_rounds, data.get("round", 1))
diff --git a/routes/copilot_routes.py b/routes/copilot_routes.py
new file mode 100644
index 0000000..bb2b1d2
--- /dev/null
+++ b/routes/copilot_routes.py
@@ -0,0 +1,223 @@
+# routes/copilot_routes.py
+"""GitHub Copilot device-flow login.
+
+Drives the GitHub OAuth *device flow* and, on success, creates (or refreshes)
+an owner-scoped ``ModelEndpoint`` pointing at the Copilot API with the
+device-flow access token stored as its (encrypted) ``api_key``. After that the
+endpoint behaves like any other OpenAI-compatible provider — the Copilot-
+specific request headers are injected centrally by ``build_headers`` /
+``_provider_headers`` (see :mod:`src.copilot`).
+
+Flow:
+ 1. ``POST /api/copilot/device/start`` → returns a ``poll_id`` plus the
+ ``user_code`` + ``verification_uri`` to show the user. The secret
+ ``device_code`` is kept server-side, never sent to the browser.
+ 2. The browser polls ``POST /api/copilot/device/poll`` with ``poll_id``.
+ While pending it returns ``{status: "pending"}``; once the user authorises
+ it provisions the endpoint and returns ``{status: "authorized", ...}``.
+
+All routes are admin-gated (endpoint/provider management is an admin action).
+"""
+
+import json
+import time
+import uuid
+import logging
+import threading
+from typing import Dict, Optional
+
+import httpx
+from fastapi import APIRouter, Request, Form, HTTPException
+
+from core.database import SessionLocal, ModelEndpoint
+from core.middleware import require_admin
+from src.auth_helpers import get_current_user
+from src import copilot
+
+logger = logging.getLogger(__name__)
+
+# Pending device-flow logins, keyed by an opaque poll_id. The device_code is a
+# bearer-like secret, so it lives here (server memory) rather than in the
+# browser. Entries expire with the GitHub device code.
+#
+# NOTE: this is per-process state. The device flow assumes a single worker
+# (Odysseus' default): with multiple uvicorn workers, the poll request can land
+# on a worker that never saw the start, returning "Unknown or expired login
+# session". Move this to a shared store (DB/Redis) if running multi-worker.
+_PENDING: Dict[str, Dict] = {}
+_PENDING_LOCK = threading.Lock()
+
+
+def _prune_expired() -> None:
+    """Remove every pending login whose ``expires_at`` lies in the past."""
+    cutoff = time.time()
+    with _PENDING_LOCK:
+        # Collect the keys first: the dict must not change size while iterated.
+        stale = [pid for pid, entry in _PENDING.items() if entry.get("expires_at", 0) < cutoff]
+        for pid in stale:
+            _PENDING.pop(pid, None)
+
+
+def _provision_endpoint(token: str, base: str, owner: Optional[str]) -> Dict:
+    """Create or update the owner's Copilot endpoint with a fresh token.
+
+    Args:
+        token: Device-flow access token; stored as the endpoint's ``api_key``.
+        base: Copilot API base URL the endpoint points at.
+        owner: User the endpoint is created for, or ``None``.
+
+    Returns:
+        A dict with the endpoint's ``id``, ``name``, ``base_url`` and the list
+        of discovered model ids under ``models`` (empty when the fetch failed).
+
+    Raises:
+        Whatever the database layer raises while querying or committing; the
+        model fetch and the cache refresh are best-effort and never raise.
+    """
+    try:
+        models = copilot.fetch_models(base, token)
+    except Exception as e:
+        logger.warning(f"Copilot model fetch failed during provisioning: {e}")
+        models = []
+    # Keep only well-formed entries. The login itself has already succeeded at
+    # this point, so a single catalogue row without an "id" must not raise and
+    # throw the freshly issued token away.
+    models = [m for m in (models or []) if isinstance(m, dict) and m.get("id")]
+    model_ids = [m["id"] for m in models]
+    # Tool-capable if any picker model advertises tool_calls. When no usable
+    # model list is available (fetch failed / empty) default to True, since
+    # Copilot picker models support OpenAI-style tool calling; the agent loop
+    # then sends native tool schemas.
+    supports_tools = bool(not models or any(m.get("tool_calls") for m in models))
+
+    db = SessionLocal()
+    try:
+        # Reuse an endpoint for this base URL that is shared (no owner) or
+        # belongs to `owner`. NOTE(review): the DESC ordering is presumably
+        # meant to prefer the owned row over the shared one, which relies on
+        # NULLs sorting last under DESC (true for SQLite) — confirm if the
+        # backend ever changes.
+        ep = (
+            db.query(ModelEndpoint)
+            .filter(ModelEndpoint.base_url == base)
+            .filter((ModelEndpoint.owner.is_(None)) | (ModelEndpoint.owner == owner))
+            .order_by(ModelEndpoint.owner.desc())
+            .first()
+        )
+        if ep is None:
+            ep = ModelEndpoint(
+                id=str(uuid.uuid4())[:8],
+                name="GitHub Copilot",
+                base_url=base,
+                model_type="llm",
+                owner=owner,
+            )
+            db.add(ep)
+        ep.api_key = token
+        ep.is_enabled = True
+        ep.supports_tools = supports_tools
+        # Only overwrite the cached list when the fetch produced something, so
+        # a transient fetch failure does not wipe a previously cached list.
+        if model_ids:
+            ep.cached_models = json.dumps(model_ids)
+        db.commit()
+        result = {
+            "id": ep.id,
+            "name": ep.name,
+            "base_url": ep.base_url,
+            "models": model_ids,
+        }
+    finally:
+        db.close()
+
+    # Best-effort: refresh the model cache so the new endpoint shows up.
+    try:
+        from routes.model_routes import _invalidate_models_cache
+        _invalidate_models_cache()
+    except Exception:
+        # Still non-fatal, but leave a trace instead of swallowing silently.
+        logger.debug("Model cache invalidation after Copilot provisioning failed", exc_info=True)
+    return result
+
+
+def setup_copilot_routes() -> APIRouter:
+    """Build the ``/api/copilot`` router for the GitHub device-flow login.
+
+    Registers three POST routes, each of which calls ``require_admin`` first:
+    ``/device/start``, ``/device/poll`` and ``/device/cancel``. Pending logins
+    are kept in the module-level ``_PENDING`` dict, guarded by
+    ``_PENDING_LOCK``.
+
+    Returns:
+        The configured ``APIRouter``.
+    """
+    router = APIRouter(prefix="/api/copilot", tags=["copilot"])
+
+    @router.post("/device/start")
+    def device_start(request: Request, enterprise_url: str = Form("")):
+        # Request a device code from GitHub (or the given enterprise host) and
+        # park it server-side under a fresh poll_id.
+        require_admin(request)
+        _prune_expired()
+        host = copilot.GITHUB_HOST
+        ent = (enterprise_url or "").strip()
+        if ent:
+            host = copilot.normalize_domain(ent)
+        try:
+            data = copilot.request_device_code(host)
+        except httpx.HTTPStatusError as e:
+            status = e.response.status_code if e.response is not None else "unknown"
+            raise HTTPException(502, f"GitHub device-code request failed (HTTP {status})")
+        except Exception as e:
+            raise HTTPException(502, f"GitHub device-code request failed: {e}")
+
+        device_code = data.get("device_code")
+        if not device_code:
+            raise HTTPException(502, "GitHub did not return a device code")
+        interval = int(data.get("interval") or 5)
+        expires_in = int(data.get("expires_in") or 900)
+        poll_id = uuid.uuid4().hex
+        with _PENDING_LOCK:
+            _PENDING[poll_id] = {
+                "device_code": device_code,
+                "host": host,
+                "enterprise_url": ent,
+                "interval": interval,
+                "owner": get_current_user(request) or None,
+                "expires_at": time.time() + expires_in,
+                "next_poll_at": 0.0,
+            }
+        # verification_uri_complete embeds the user code, so the browser tab we
+        # open lands the user straight on GitHub's "Authorize" screen with the
+        # code pre-filled — one click, no manual code entry.
+        return {
+            "poll_id": poll_id,
+            "user_code": data.get("user_code"),
+            "verification_uri": data.get("verification_uri"),
+            "verification_uri_complete": data.get("verification_uri_complete"),
+            "interval": interval,
+            "expires_in": expires_in,
+        }
+
+    @router.post("/device/poll")
+    def device_poll(request: Request, poll_id: str = Form(...)):
+        # Check one pending login against GitHub; provision the endpoint once
+        # the user has authorised it.
+        require_admin(request)
+        _prune_expired()
+        now = time.time()
+        with _PENDING_LOCK:
+            pending = _PENDING.get(poll_id)
+            if not pending:
+                raise HTTPException(404, "Unknown or expired login session")
+
+            # Enforce GitHub's polling interval server-side so a chatty client
+            # can't trip slow_down.
+            if now < pending.get("next_poll_at", 0):
+                return {"status": "pending"}
+            # Reserve the next slot while still holding the lock. Doing the
+            # check and the update atomically means concurrent requests for the
+            # same poll_id cannot both reach GitHub (or both provision), and
+            # the error paths below are throttled as well.
+            pending["next_poll_at"] = now + pending["interval"]
+
+        try:
+            data = copilot.poll_access_token(pending["host"], pending["device_code"])
+        except Exception as e:
+            return {"status": "pending", "detail": f"poll error: {e}"}
+
+        token = data.get("access_token")
+        if token:
+            base = copilot.enterprise_base(pending["enterprise_url"]) if pending["enterprise_url"] else copilot.COPILOT_BASE
+            try:
+                result = _provision_endpoint(token, base, pending["owner"])
+            except Exception as e:
+                logger.exception("Copilot endpoint provisioning failed")
+                with _PENDING_LOCK:
+                    _PENDING.pop(poll_id, None)
+                raise HTTPException(500, f"Login succeeded but provisioning failed: {e}")
+            with _PENDING_LOCK:
+                _PENDING.pop(poll_id, None)
+            return {"status": "authorized", "endpoint": result}
+
+        err = data.get("error")
+        if err == "authorization_pending":
+            # next_poll_at was already pushed out when the slot was reserved.
+            return {"status": "pending"}
+        if err == "slow_down":
+            # GitHub asked for a longer interval: adopt it (or add 5 seconds).
+            new_interval = int(data.get("interval") or (pending["interval"] + 5))
+            with _PENDING_LOCK:
+                if poll_id in _PENDING:
+                    _PENDING[poll_id]["interval"] = new_interval
+                    _PENDING[poll_id]["next_poll_at"] = now + new_interval
+            return {"status": "pending"}
+        if err in ("expired_token", "access_denied"):
+            with _PENDING_LOCK:
+                _PENDING.pop(poll_id, None)
+            return {"status": "failed", "error": err}
+        # Unknown error — surface but keep the session for another try.
+        return {"status": "pending", "detail": err or "unknown"}
+
+    @router.post("/device/cancel")
+    def device_cancel(request: Request, poll_id: str = Form(...)):
+        # Forget a pending login; unknown poll_ids are ignored.
+        require_admin(request)
+        with _PENDING_LOCK:
+            _PENDING.pop(poll_id, None)
+        return {"status": "cancelled"}
+
+    return router
diff --git a/routes/document_routes.py b/routes/document_routes.py
index 5625df8..03661b2 100644
--- a/routes/document_routes.py
+++ b/routes/document_routes.py
@@ -153,7 +153,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
with a `pdf_source` marker so the viewer renders the pages without
overlays.
"""
- from src.constants import UPLOAD_DIR
from src.pdf_forms import has_form_fields, extract_fields
from src.pdf_form_doc import (
save_field_sidecar,
@@ -950,7 +949,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
any wrong values before triggering the actual download.
"""
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, load_field_sidecar
- from src.constants import UPLOAD_DIR
user = get_current_user(request)
db = SessionLocal()
@@ -1015,7 +1013,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
Frontend overlays HTML form controls at those positions.
"""
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, load_field_sidecar
- from src.constants import UPLOAD_DIR
user = get_current_user(request)
db = SessionLocal()
@@ -1083,7 +1080,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
frontend overlays HTML form inputs on top)."""
from fastapi.responses import Response
from src.pdf_form_doc import find_source_upload_id
- from src.constants import UPLOAD_DIR
user = get_current_user(request)
db = SessionLocal()
@@ -1132,7 +1128,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
import json
import fitz
from src.pdf_form_doc import find_source_upload_id
- from src.constants import UPLOAD_DIR
from src.document_processor import _resolve_vl_model, _load_vl_settings
from src.llm_core import llm_call_async
@@ -1275,7 +1270,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
from starlette.background import BackgroundTask
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, parse_markdown_annotations
from src.pdf_forms import fill_fields, stamp_annotations
- from src.constants import UPLOAD_DIR
from core.database import Signature
# Track temp files for this request so they get unlinked AFTER
@@ -1370,7 +1364,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
from starlette.background import BackgroundTask
from src.pdf_form_doc import find_source_upload_id, parse_markdown_to_values, load_field_sidecar, parse_markdown_annotations
from src.pdf_forms import fill_fields, stamp_signatures, stamp_annotations
- from src.constants import UPLOAD_DIR
from core.database import Signature
_to_unlink: list[str] = []
@@ -1512,7 +1505,6 @@ def setup_document_routes(session_manager, upload_handler=None) -> APIRouter:
load_field_sidecar, parse_markdown_annotations,
)
from src.pdf_forms import fill_fields, stamp_signatures, stamp_annotations
- from src.constants import UPLOAD_DIR
from core.database import Signature
# COMPOSE_UPLOADS_DIR lives in email_routes — re-derive here so we
# don't import from a routes file (cycle-prone). Same env override
diff --git a/routes/email_helpers.py b/routes/email_helpers.py
index 409c6c4..fef2944 100644
--- a/routes/email_helpers.py
+++ b/routes/email_helpers.py
@@ -266,6 +266,48 @@ COMPOSE_UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
SCHEDULED_DB = DATA_DIR / "scheduled_emails.db"
+OWNER_SCOPED_EMAIL_CACHE_TABLES = {
+ "email_summaries",
+ "email_ai_replies",
+ "email_calendar_extractions",
+ "email_urgency_alerts",
+}
+
+
+def _email_cache_owner_clause(owner: str = "") -> tuple[str, tuple[str, ...]]:
+    """Return the SQL condition and bind parameters that scope a query to ``owner``.
+
+    A non-blank owner (after stripping whitespace) yields a parameterised
+    equality test; a blank or missing owner matches rows whose owner is the
+    empty string or NULL and needs no parameters.
+    """
+    normalized = (owner or "").strip()
+    if not normalized:
+        return "(owner = '' OR owner IS NULL)", ()
+    return "owner = ?", (normalized,)
+
+
+def _ensure_owner_scoped_email_cache_table(conn, table: str, create_sql: str, columns: list[str]):
+    """Rebuild legacy Message-ID-only cache tables with owner in the PK.
+
+    Args:
+        conn: Open database connection the statements are executed on.
+        table: Name of the cache table; interpolated directly into SQL, so it
+            must be a trusted identifier.
+        create_sql: ``CREATE TABLE IF NOT EXISTS`` statement for the
+            owner-scoped schema of ``table``.
+        columns: Column names of the new schema; those also present in the old
+            table (other than ``owner``) are copied across.
+
+    The table is first created if missing. If its primary key is already
+    exactly ``(message_id, owner)`` nothing else happens. Otherwise the table
+    is renamed aside, recreated, refilled from the old rows and the old copy is
+    dropped. Any error during that rebuild is logged as a warning and
+    suppressed. No commit is issued here.
+    """
+    conn.execute(create_sql)
+    try:
+        info = conn.execute(f"PRAGMA table_info({table})").fetchall()
+        cols = [r[1] for r in info]
+        # r[5] is the `pk` field of table_info: 0 for non-key columns, otherwise
+        # the column's position inside the primary key, hence the sort.
+        pk_cols = [r[1] for r in sorted((r for r in info if r[5]), key=lambda r: r[5])]
+        if "owner" in cols and pk_cols == ["message_id", "owner"]:
+            return
+
+        # Legacy layout: move it aside and recreate the table with the new PK.
+        conn.execute(f"ALTER TABLE {table} RENAME TO {table}__old")
+        conn.execute(create_sql)
+        old_cols = [r[1] for r in conn.execute(f"PRAGMA table_info({table}__old)").fetchall()]
+        copy_cols = [c for c in columns if c != "owner" and c in old_cols]
+        # Rows from a table that had no owner column get the empty-string owner;
+        # NULL owners are normalised to '' as well.
+        source_owner = "COALESCE(owner, '')" if "owner" in old_cols else "''"
+        target_cols = ["owner", *copy_cols]
+        select_exprs = [source_owner, *copy_cols]
+        # OR IGNORE: rows that would collide on the new primary key are skipped.
+        conn.execute(
+            f"INSERT OR IGNORE INTO {table} ({', '.join(target_cols)}) "
+            f"SELECT {', '.join(select_exprs)} FROM {table}__old"
+        )
+        conn.execute(f"DROP TABLE {table}__old")
+    except Exception as _mig_e:
+        # Best-effort: report the failure and let initialisation continue.
+        import logging as _lg
+        _lg.getLogger(__name__).warning(f"{table} owner-migration skipped: {_mig_e}")
+
+
def attachment_extract_dir(folder: str, uid: str) -> Path:
"""Containment-safe extraction directory for an attachment.
@@ -301,30 +343,35 @@ def _init_scheduled_db():
owner TEXT DEFAULT ''
)
""")
- # Email summary cache (keyed by Message-ID)
- conn.execute("""
+ # Email summary cache. SECURITY: Message-IDs are global, so AI-derived
+ # cache rows must be owner-scoped just like email_tags.
+ _ensure_owner_scoped_email_cache_table(conn, "email_summaries", """
CREATE TABLE IF NOT EXISTS email_summaries (
- message_id TEXT PRIMARY KEY,
+ message_id TEXT,
+ owner TEXT DEFAULT '',
uid TEXT,
folder TEXT,
subject TEXT,
sender TEXT,
summary TEXT NOT NULL,
model_used TEXT,
- created_at TEXT NOT NULL
+ created_at TEXT NOT NULL,
+ PRIMARY KEY (message_id, owner)
)
- """)
+ """, ["message_id", "owner", "uid", "folder", "subject", "sender", "summary", "model_used", "created_at"])
# Email AI reply cache (pre-generated draft replies)
- conn.execute("""
+ _ensure_owner_scoped_email_cache_table(conn, "email_ai_replies", """
CREATE TABLE IF NOT EXISTS email_ai_replies (
- message_id TEXT PRIMARY KEY,
+ message_id TEXT,
+ owner TEXT DEFAULT '',
uid TEXT,
folder TEXT,
reply TEXT NOT NULL,
model_used TEXT,
- created_at TEXT NOT NULL
+ created_at TEXT NOT NULL,
+ PRIMARY KEY (message_id, owner)
)
- """)
+ """, ["message_id", "owner", "uid", "folder", "reply", "model_used", "created_at"])
# Email tags / spam classification cache. SECURITY: keyed by
# (message_id, owner) because Message-IDs are GLOBAL (a newsletter goes
# to many users with the same Message-ID). Without owner-scoping, a
@@ -384,17 +431,20 @@ def _init_scheduled_db():
# Best-effort — log via the module logger if available
import logging as _lg
_lg.getLogger(__name__).warning(f"email_tags owner-migration skipped: {_mig_e}")
- conn.execute("""
+ _ensure_owner_scoped_email_cache_table(conn, "email_calendar_extractions", """
CREATE TABLE IF NOT EXISTS email_calendar_extractions (
- message_id TEXT PRIMARY KEY,
+ message_id TEXT,
+ owner TEXT DEFAULT '',
uid TEXT,
events_created INTEGER DEFAULT 0,
- created_at TEXT NOT NULL
+ created_at TEXT NOT NULL,
+ PRIMARY KEY (message_id, owner)
)
- """)
- conn.execute("""
+ """, ["message_id", "owner", "uid", "events_created", "created_at"])
+ _ensure_owner_scoped_email_cache_table(conn, "email_urgency_alerts", """
CREATE TABLE IF NOT EXISTS email_urgency_alerts (
- message_id TEXT PRIMARY KEY,
+ message_id TEXT,
+ owner TEXT DEFAULT '',
uid TEXT,
folder TEXT,
subject TEXT,
@@ -402,9 +452,10 @@ def _init_scheduled_db():
urgency TEXT,
reason TEXT,
alerted INTEGER DEFAULT 0,
- created_at TEXT NOT NULL
+ created_at TEXT NOT NULL,
+ PRIMARY KEY (message_id, owner)
)
- """)
+ """, ["message_id", "owner", "uid", "folder", "subject", "sender", "urgency", "reason", "alerted", "created_at"])
conn.execute("""
CREATE TABLE IF NOT EXISTS email_event_seen (
owner TEXT NOT NULL,
diff --git a/routes/email_pollers.py b/routes/email_pollers.py
index 7bed2f6..04ffb0a 100644
--- a/routes/email_pollers.py
+++ b/routes/email_pollers.py
@@ -39,7 +39,7 @@ from routes.email_helpers import (
_extract_attachment_text, _extract_text,
_pre_retrieve_context,
_attach_compose_uploads, _cleanup_compose_uploads, _q,
- SCHEDULED_DB, _EMAIL_REPLY_SYS_PROMPT_BASE,
+ SCHEDULED_DB, _EMAIL_REPLY_SYS_PROMPT_BASE, _email_cache_owner_clause,
)
logger = logging.getLogger(__name__)
@@ -243,8 +243,15 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
await _emit_progress(progress_cb, f"Found {len(uid_list)} recent email(s); checking cache…")
_c = _sql3.connect(SCHEDULED_DB)
- _sum_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_summaries").fetchall()}
- _reply_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_ai_replies").fetchall()}
+ _cache_owner_clause, _cache_owner_params = _email_cache_owner_clause(account_owner)
+ _sum_existing = {r[0] for r in _c.execute(
+ f"SELECT message_id FROM email_summaries WHERE {_cache_owner_clause}",
+ _cache_owner_params,
+ ).fetchall()}
+ _reply_existing = {r[0] for r in _c.execute(
+ f"SELECT message_id FROM email_ai_replies WHERE {_cache_owner_clause}",
+ _cache_owner_params,
+ ).fetchall()}
if auto_tag or auto_spam:
if account_owner:
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner=?", (account_owner,)).fetchall()}
@@ -252,12 +259,18 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
_tag_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_tags WHERE owner='' OR owner IS NULL").fetchall()}
else:
_tag_existing = set()
- _cal_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_calendar_extractions").fetchall()} if auto_cal else set()
+ _cal_existing = {r[0] for r in _c.execute(
+ f"SELECT message_id FROM email_calendar_extractions WHERE {_cache_owner_clause}",
+ _cache_owner_params,
+ ).fetchall()} if auto_cal else set()
# Urgency is handled by the built-in `check_email_urgency` task. Keep
# this legacy poller path disabled so users don't get two independent
# urgent-email systems.
auto_urgent = False
- _urgent_existing = {r[0] for r in _c.execute("SELECT message_id FROM email_urgency_alerts").fetchall()} if auto_urgent else set()
+ _urgent_existing = {r[0] for r in _c.execute(
+ f"SELECT message_id FROM email_urgency_alerts WHERE {_cache_owner_clause}",
+ _cache_owner_params,
+ ).fetchall()} if auto_urgent else set()
_c.close()
# Hoist the self-address lookup OUT of the per-email loop — fetching
@@ -415,9 +428,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
_c = _sql3.connect(SCHEDULED_DB)
_c.execute("""
INSERT OR REPLACE INTO email_summaries
- (message_id, uid, folder, subject, sender, summary, model_used, created_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
- """, (message_id, uid.decode() if isinstance(uid, bytes) else str(uid), _folder, subject, sender, summary, model, datetime.utcnow().isoformat()))
+ (message_id, owner, uid, folder, subject, sender, summary, model_used, created_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """, (message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid), _folder, subject, sender, summary, model, datetime.utcnow().isoformat()))
_c.commit()
_c.close()
_sum_existing.add(message_id)
@@ -458,9 +471,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
_c = _sql3.connect(SCHEDULED_DB)
_c.execute("""
INSERT OR REPLACE INTO email_ai_replies
- (message_id, uid, folder, reply, model_used, created_at)
- VALUES (?, ?, ?, ?, ?, ?)
- """, (message_id, uid.decode() if isinstance(uid, bytes) else str(uid), _folder, reply, model, datetime.utcnow().isoformat()))
+ (message_id, owner, uid, folder, reply, model_used, created_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?)
+ """, (message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid), _folder, reply, model, datetime.utcnow().isoformat()))
_c.commit()
_c.close()
_reply_existing.add(message_id)
@@ -675,8 +688,8 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
_cc = _sql3.connect(SCHEDULED_DB)
_cc.execute(
"INSERT OR REPLACE INTO email_calendar_extractions "
- "(message_id, uid, events_created, created_at) VALUES (?, ?, ?, ?)",
- (message_id, uid.decode() if isinstance(uid, bytes) else str(uid),
+ "(message_id, owner, uid, events_created, created_at) VALUES (?, ?, ?, ?, ?)",
+ (message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid),
_cal_run_count, datetime.utcnow().isoformat())
)
_cc.commit()
@@ -733,9 +746,9 @@ async def _auto_summarize_pass_single(days_back: int = 1, account_id: str | None
_uc = _sql3.connect(SCHEDULED_DB)
_uc.execute(
"INSERT OR REPLACE INTO email_urgency_alerts "
- "(message_id, uid, folder, subject, sender, urgency, reason, alerted, created_at) "
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
- (message_id, uid.decode() if isinstance(uid, bytes) else str(uid),
+ "(message_id, owner, uid, folder, subject, sender, urgency, reason, alerted, created_at) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (message_id, account_owner or "", uid.decode() if isinstance(uid, bytes) else str(uid),
_folder, subject, sender, urgency, reason,
1 if urgency in ("critical", "high") else 0,
datetime.utcnow().isoformat())
diff --git a/routes/email_routes.py b/routes/email_routes.py
index 87f8e76..7ab033b 100644
--- a/routes/email_routes.py
+++ b/routes/email_routes.py
@@ -49,7 +49,7 @@ from routes.email_helpers import (
_EMAIL_REPLY_SYS_PROMPT_BASE, _POOL_HOOKS,
SendEmailRequest, ExtractStyleRequest,
ATTACHMENTS_DIR, COMPOSE_UPLOADS_DIR, SCHEDULED_DB,
- attachment_extract_dir,
+ attachment_extract_dir, _email_cache_owner_clause,
)
from routes.email_pollers import _start_poller
@@ -934,9 +934,11 @@ def setup_email_routes():
import sqlite3 as _sql3
_c = _sql3.connect(SCHEDULED_DB)
placeholders = ",".join("?" * len(ids))
+ owner_clause, owner_params = _email_cache_owner_clause(owner)
rows = _c.execute(
- f"SELECT message_id, summary FROM email_summaries WHERE message_id IN ({placeholders})",
- ids,
+ f"SELECT message_id, summary FROM email_summaries "
+ f"WHERE message_id IN ({placeholders}) AND {owner_clause}",
+ (*ids, *owner_params),
).fetchall()
_c.close()
by_id = {r[0]: r[1] for r in rows}
@@ -1219,15 +1221,16 @@ def setup_email_routes():
try:
import sqlite3 as _sql3
_c = _sql3.connect(SCHEDULED_DB)
+ owner_clause, owner_params = _email_cache_owner_clause(owner)
_row = _c.execute(
- "SELECT summary FROM email_summaries WHERE message_id = ?",
- (message_id.strip(),),
+ f"SELECT summary FROM email_summaries WHERE message_id = ? AND {owner_clause}",
+ (message_id.strip(), *owner_params),
).fetchone()
if _row:
cached_summary = _row[0]
_row2 = _c.execute(
- "SELECT reply FROM email_ai_replies WHERE message_id = ?",
- (message_id.strip(),),
+ f"SELECT reply FROM email_ai_replies WHERE message_id = ? AND {owner_clause}",
+ (message_id.strip(), *owner_params),
).fetchone()
if _row2:
cached_ai_reply = _apply_email_style_mechanics(_extract_reply(_row2[0] or ""))
@@ -2549,10 +2552,10 @@ def setup_email_routes():
_c = _sql3.connect(SCHEDULED_DB)
_c.execute("""
INSERT OR REPLACE INTO email_summaries
- (message_id, uid, folder, subject, sender, summary, model_used, created_at)
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
+ (message_id, owner, uid, folder, subject, sender, summary, model_used, created_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
- mid, data.get("uid", ""), data.get("folder", ""),
+ mid, owner, data.get("uid", ""), data.get("folder", ""),
subject, sender, content, model, datetime.utcnow().isoformat(),
))
_c.commit()
@@ -2587,9 +2590,10 @@ def setup_email_routes():
if message_id:
try:
_c = _sql3.connect(SCHEDULED_DB)
+ owner_clause, owner_params = _email_cache_owner_clause(owner)
_row = _c.execute(
- "SELECT reply, model_used FROM email_ai_replies WHERE message_id = ?",
- (message_id,),
+ f"SELECT reply, model_used FROM email_ai_replies WHERE message_id = ? AND {owner_clause}",
+ (message_id, *owner_params),
).fetchone()
_c.close()
if _row and _row[0]:
@@ -2791,9 +2795,9 @@ def setup_email_routes():
_c = _sql3.connect(SCHEDULED_DB)
_c.execute("""
INSERT OR REPLACE INTO email_ai_replies
- (message_id, uid, folder, reply, model_used, created_at)
- VALUES (?, ?, ?, ?, ?, ?)
- """, (message_id, source_uid, source_folder, reply, model, datetime.utcnow().isoformat()))
+ (message_id, owner, uid, folder, reply, model_used, created_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?)
+ """, (message_id, owner, source_uid, source_folder, reply, model, datetime.utcnow().isoformat()))
_c.commit()
_c.close()
except Exception as e:
diff --git a/routes/history_routes.py b/routes/history_routes.py
index 9efaa94..35aaff2 100644
--- a/routes/history_routes.py
+++ b/routes/history_routes.py
@@ -10,11 +10,36 @@ from fastapi import APIRouter, Request, HTTPException
from core.models import ChatMessage
from core.database import SessionLocal, ChatMessage as DbChatMessage, Session as DbSession
from src.topic_analyzer import analyze_topics
-from routes.session_routes import _verify_session_owner
+from routes.session_routes import (
+ _message_role,
+ _message_text,
+ _reject_compact_during_active_run,
+ _verify_session_owner,
+)
logger = logging.getLogger(__name__)
+def _merge_continue_rows_to_delete(db_messages, db1, db2):
+    """Pick the DB rows to remove when the last two assistant messages are merged.
+
+    The result always starts with the second assistant row (``db2``). The row
+    immediately before ``db2`` is added only when it lies strictly after
+    ``db1``, has role "user", and its content contains "previous response
+    was interrupted" (the "continue" prompt). Any other rows between the two
+    assistant messages are left alone.
+    """
+    def _position_of(row):
+        # Identity lookup (not equality) so look-alike rows are not confused.
+        for idx, candidate in enumerate(db_messages):
+            if candidate is row:
+                return idx
+        return None
+
+    doomed = [db2]
+    first_pos = _position_of(db1)
+    second_pos = _position_of(db2)
+    if first_pos is None or second_pos is None:
+        return doomed
+    if second_pos - 1 <= first_pos:
+        # Nothing sits between the two assistant rows.
+        return doomed
+
+    candidate = db_messages[second_pos - 1]
+    is_user = getattr(candidate, "role", "") == "user"
+    if is_user and "previous response was interrupted" in (getattr(candidate, "content", "") or ""):
+        doomed.append(candidate)
+    return doomed
+
+
def setup_history_routes(session_manager) -> APIRouter:
router = APIRouter(tags=["history"])
@@ -418,11 +443,13 @@ def setup_history_routes(session_manager) -> APIRouter:
db1.content = merged_content
db1.meta_data = _json.dumps(merged_meta)
- # Remove the continue user message if between them
- db_idx2 = db_messages.index(db2)
- db_idx1 = db_messages.index(db1)
- for di in range(db_idx2, db_idx1, -1):
- db.delete(db_messages[di])
+ # Mirror the in-memory deletion: remove the second assistant
+ # message and ONLY the "continue" user message between them
+ # (not arbitrary tool/system/user rows). The old
+ # range-delete destroyed every row between the two assistant
+ # messages, desyncing the DB from the in-memory history.
+ for _row in _merge_continue_rows_to_delete(db_messages, db1, db2):
+ db.delete(_row)
db.commit()
finally:
@@ -499,6 +526,7 @@ def setup_history_routes(session_manager) -> APIRouter:
session = session_manager.get_session(session_id)
except KeyError:
raise HTTPException(404, "Session not found")
+ _reject_compact_during_active_run(session_id)
try:
from src.model_context import estimate_tokens, get_context_length
@@ -521,8 +549,8 @@ def setup_history_routes(session_manager) -> APIRouter:
# Build text to summarize
convo_text = "\n".join(
- f"{(m.role if isinstance(m, ChatMessage) else m.get('role', '')).upper()}: "
- f"{(m.content if isinstance(m, ChatMessage) else m.get('content', ''))[:2000]}"
+ f"{_message_role(m).upper()}: "
+ f"{_message_text(m)[:2000]}"
for m in older
)
diff --git a/routes/mcp_routes.py b/routes/mcp_routes.py
index c09108f..e3a73c8 100644
--- a/routes/mcp_routes.py
+++ b/routes/mcp_routes.py
@@ -5,6 +5,7 @@ import os
import uuid
import urllib.parse
import html
+from pathlib import Path
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import RedirectResponse, HTMLResponse
import logging
@@ -12,6 +13,7 @@ import httpx
from core.database import McpServer, SessionLocal
from core.middleware import require_admin
+from src.constants import DATA_DIR
from src.mcp_manager import McpManager
logger = logging.getLogger(__name__)
@@ -19,6 +21,75 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/mcp", tags=["mcp"])
+def _mcp_oauth_base_dir() -> Path:
+    """Return the resolved ``DATA_DIR/mcp_oauth`` directory used for Odysseus-managed OAuth files."""
+    oauth_root = Path(DATA_DIR, "mcp_oauth")
+    return oauth_root.resolve(strict=False)
+
+
+def _resolve_mcp_oauth_path(raw_path, field_name: str) -> str:
+    """Resolve an MCP OAuth path and keep it under DATA_DIR/mcp_oauth.
+
+    Args:
+        raw_path: user/DB supplied path (str or Path); ``~`` is expanded and
+            relative paths are taken relative to the mcp_oauth base dir.
+        field_name: name of the config field, used only in the error text.
+
+    Returns:
+        The fully resolved path as a string, or "" when ``raw_path`` is
+        empty/blank.
+
+    Raises:
+        HTTPException: 400 when the resolved path escapes the base directory
+            or the path cannot be resolved at all.
+    """
+    raw = str(raw_path or "").strip()
+    if not raw:
+        return ""
+
+    base = _mcp_oauth_base_dir()
+    try:
+        path = Path(os.path.expanduser(raw))
+        if not path.is_absolute():
+            path = base / path
+        # resolve() sits inside the guard too: malformed input (e.g. an
+        # embedded NUL byte -> ValueError, or a symlink loop) must be
+        # rejected as a bad request rather than escaping as an unhandled error.
+        resolved = path.resolve(strict=False)
+        # relative_to() raises ValueError when resolved is outside base.
+        resolved.relative_to(base)
+    except (ValueError, OSError, RuntimeError) as exc:
+        raise HTTPException(
+            400,
+            f"Invalid OAuth {field_name}: path must stay under {base}",
+        ) from exc
+    return str(resolved)
+
+
+def _sanitize_mcp_oauth_config(oauth_cfg):
+    """Copy an OAuth config, rewriting its file-path fields to confined paths.
+
+    Falsy input is returned unchanged and a truthy non-dict becomes ``{}``.
+    For a dict, a shallow copy is returned in which every non-empty
+    ``keys_file`` / ``token_file`` value has been passed through
+    ``_resolve_mcp_oauth_path`` (which raises if the path is not allowed).
+    """
+    if not oauth_cfg:
+        return oauth_cfg
+    if not isinstance(oauth_cfg, dict):
+        return {}
+
+    confined = dict(oauth_cfg)
+    for path_field in ("keys_file", "token_file"):
+        current = confined.get(path_field)
+        if not current:
+            continue
+        confined[path_field] = _resolve_mcp_oauth_path(current, path_field)
+    return confined
+
+
+def _mcp_oauth_token_missing(oauth_cfg, *, strict: bool = True) -> bool:
+    """Report whether the configured OAuth token file is absent on disk.
+
+    Non-dict configs and configs without a ``token_file`` give ``False``.
+    When the token path is rejected by ``_resolve_mcp_oauth_path``, the
+    HTTPException propagates if ``strict`` is true; otherwise a warning is
+    logged and the token is reported as missing so listing keeps working.
+    """
+    if not isinstance(oauth_cfg, dict):
+        return False
+
+    try:
+        token_path = _resolve_mcp_oauth_path(oauth_cfg.get("token_file", ""), "token_file")
+    except HTTPException:
+        if not strict:
+            logger.warning("Ignoring MCP OAuth config with unsafe token_file")
+            return True
+        raise
+
+    if not token_path:
+        return False
+    return not os.path.exists(token_path)
+
+
+def _apply_mcp_oauth_env(env: dict, oauth_cfg) -> None:
+    """Copy the OAuth file paths from ``oauth_cfg`` into ``env`` in place.
+
+    ``keys_file`` is exported as ``GMAIL_OAUTH_PATH`` and ``token_file`` as
+    ``GMAIL_CREDENTIALS_PATH``; empty values are skipped. Nothing happens
+    when ``oauth_cfg`` is falsy or ``env`` is not a dict.
+    """
+    if not (oauth_cfg and isinstance(env, dict)):
+        return
+
+    keys_path, token_path = oauth_cfg.get("keys_file"), oauth_cfg.get("token_file")
+    for env_name, value in (
+        ("GMAIL_OAUTH_PATH", keys_path),
+        ("GMAIL_CREDENTIALS_PATH", token_path),
+    ):
+        if value:
+            env[env_name] = value
+
+
def _load_disabled_map():
"""Load per-server disabled tool sets from DB."""
db = SessionLocal()
@@ -53,8 +124,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
oauth_cfg = json.loads(srv.oauth_config) if srv.oauth_config else None
needs_oauth = False
if oauth_cfg:
- token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
- needs_oauth = token_file and not os.path.exists(token_file)
+ needs_oauth = _mcp_oauth_token_missing(oauth_cfg, strict=False)
disabled_list = json.loads(srv.disabled_tools) if srv.disabled_tools else []
total_tools = status.get("tool_count", 0)
result.append({
@@ -71,6 +141,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
"disabled_tool_count": len(disabled_list),
"enabled_tool_count": max(0, total_tools - len(disabled_list)),
"error": status.get("error"),
+ "auth_url": status.get("auth_url"),
"has_oauth": oauth_cfg is not None,
"needs_oauth": needs_oauth,
})
@@ -101,6 +172,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
raise HTTPException(400, "command is required for stdio transport")
if transport == "sse" and not url:
raise HTTPException(400, "url is required for SSE transport")
+ if transport == "http" and not url:
+ raise HTTPException(400, "url is required for HTTP transport")
# Parse JSON fields
try:
@@ -111,26 +184,33 @@ def setup_mcp_routes(mcp_manager: McpManager):
parsed_env = json.loads(env) if env else {}
except json.JSONDecodeError:
parsed_env = {}
+ if not isinstance(parsed_env, dict):
+ parsed_env = {}
# Parse OAuth config
parsed_oauth_config = None
if oauth_config:
try:
- parsed_oauth_config = json.loads(oauth_config)
+ parsed_oauth_config = _sanitize_mcp_oauth_config(json.loads(oauth_config))
except json.JSONDecodeError:
pass
+ _apply_mcp_oauth_env(parsed_env, parsed_oauth_config)
# Write OAuth credentials file if provided (for Google MCP servers)
logger.info(f"MCP add_server: oauth_file={oauth_file!r}")
if oauth_file:
try:
oauth_data = json.loads(oauth_file)
- oauth_dir = os.path.expanduser(oauth_data.get("dir", ""))
+ oauth_dir = _resolve_mcp_oauth_path(oauth_data.get("dir", ""), "dir")
oauth_filename = oauth_data.get("filename", "")
client_id = oauth_data.get("client_id", "")
client_secret = oauth_data.get("client_secret", "")
if oauth_dir and oauth_filename and client_id and client_secret:
- os.makedirs(oauth_dir, exist_ok=True)
+ filepath = _resolve_mcp_oauth_path(
+ Path(oauth_dir) / str(oauth_filename),
+ "filename",
+ )
+ os.makedirs(os.path.dirname(filepath), exist_ok=True)
creds = {
"installed": {
"client_id": client_id,
@@ -140,7 +220,6 @@ def setup_mcp_routes(mcp_manager: McpManager):
"token_uri": "https://accounts.google.com/o/oauth2/token",
}
}
- filepath = os.path.join(oauth_dir, oauth_filename)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(creds, f, indent=2)
logger.info(f"Wrote OAuth credentials to {filepath}")
@@ -171,9 +250,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
# Check if OAuth token already exists — skip connection attempt if not
needs_oauth = False
if parsed_oauth_config:
- token_file = os.path.expanduser(parsed_oauth_config.get("token_file", ""))
- if token_file and not os.path.exists(token_file):
- needs_oauth = True
+ needs_oauth = _mcp_oauth_token_missing(parsed_oauth_config)
connected = False
if not needs_oauth:
@@ -188,6 +265,7 @@ def setup_mcp_routes(mcp_manager: McpManager):
)
status = mcp_manager.get_server_status(server_id)
+ needs_auth = status.get("status") == "needs_auth"
return {
"id": server_id,
"name": name,
@@ -196,6 +274,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
"tool_count": status.get("tool_count", 0),
"error": "OAuth authorization required" if needs_oauth else status.get("error"),
"needs_oauth": needs_oauth,
+ "needs_auth": needs_auth,
+ "auth_url": status.get("auth_url"),
}
@router.post("/servers/{server_id}/reconnect")
@@ -228,6 +308,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
"status": status.get("status", "disconnected"),
"tool_count": status.get("tool_count", 0),
"error": status.get("error"),
+ "auth_url": status.get("auth_url"),
+ "needs_auth": status.get("status") == "needs_auth",
}
finally:
db.close()
@@ -349,8 +431,8 @@ def setup_mcp_routes(mcp_manager: McpManager):
if not srv.oauth_config:
raise HTTPException(400, "Server has no OAuth config")
- oauth_cfg = json.loads(srv.oauth_config)
- keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
+ oauth_cfg = _sanitize_mcp_oauth_config(json.loads(srv.oauth_config))
+ keys_file = oauth_cfg.get("keys_file", "")
if not keys_file or not os.path.exists(keys_file):
raise HTTPException(400, "OAuth keys file not found")
@@ -393,10 +475,18 @@ def setup_mcp_routes(mcp_manager: McpManager):
@router.get("/oauth/callback")
async def oauth_callback(code: str, state: str, request: Request):
- """Handle OAuth callback from Google — exchange code for tokens."""
+ """Handle OAuth callback. Generic MCP OAuth flows resolve via the
+ pending-state registry; Google flows fall through to the legacy path."""
require_admin(request)
- server_id = state
- return await _exchange_and_connect(server_id, code, request)
+ from src.mcp_oauth import resolve_pending
+ if resolve_pending(state, code):
+ return HTMLResponse(_oauth_result_page(
+ "Authorization Successful",
+ "The MCP server is connecting. You can close this window and return to Odysseus.",
+ success=True,
+ ))
+ # Legacy Google path: state is the server_id
+ return await _exchange_and_connect(state, code, request)
@router.post("/oauth/exchange/{server_id}")
async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)):
@@ -411,6 +501,17 @@ def setup_mcp_routes(mcp_manager: McpManager):
except Exception:
return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400)
+ # Generic MCP OAuth: if the pasted URL carries a state we are waiting on,
+ # resolve it directly (the background connect finishes the handshake).
+ state = params.get("state", [None])[0]
+ from src.mcp_oauth import resolve_pending
+ if state and resolve_pending(state, code):
+ return HTMLResponse(_oauth_result_page(
+ "Authorization Successful",
+ "The MCP server is connecting. You can close this window and return to Odysseus.",
+ success=True,
+ ))
+
return await _exchange_and_connect(server_id, code, request)
async def _exchange_and_connect(server_id: str, code: str, request: Request):
@@ -423,9 +524,11 @@ def setup_mcp_routes(mcp_manager: McpManager):
if not srv.oauth_config:
return HTMLResponse(_oauth_result_page("Error", "No OAuth config."), status_code=400)
- oauth_cfg = json.loads(srv.oauth_config)
- keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
- token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
+ oauth_cfg = _sanitize_mcp_oauth_config(json.loads(srv.oauth_config))
+ keys_file = oauth_cfg.get("keys_file", "")
+ token_file = oauth_cfg.get("token_file", "")
+ if not keys_file or not token_file:
+ raise HTTPException(400, "OAuth keys/token file not configured")
with open(keys_file, encoding="utf-8") as f:
keys_data = json.load(f)
@@ -488,6 +591,9 @@ def setup_mcp_routes(mcp_manager: McpManager):
"Authorized but Connection Failed",
f"Tokens saved, but the server failed to connect: {status.get('error', 'unknown error')}. Try reconnecting from Settings.",
))
+ except HTTPException as e:
+ logger.warning(f"OAuth callback rejected: {e.detail}")
+ return HTMLResponse(_oauth_result_page("Error", str(e.detail)), status_code=e.status_code)
except Exception as e:
logger.exception(f"OAuth callback error: {e}")
return HTMLResponse(_oauth_result_page("Error", str(e)), status_code=500)
diff --git a/routes/model_routes.py b/routes/model_routes.py
index ac025ad..6220305 100644
--- a/routes/model_routes.py
+++ b/routes/model_routes.py
@@ -1029,12 +1029,13 @@ def setup_model_routes(model_discovery):
for ep in endpoints:
base = _normalize_base(ep.base_url)
provider = _detect_provider(base)
- # Use cached models — background refresh keeps them updated
- model_ids = _cached_model_ids(ep)
+ # Merge cached + pinned models, then filter out hidden ones
ep_model_type = getattr(ep, "model_type", None) or "llm"
- # Filter out hidden (probe-failed) models
- hidden = _hidden_model_ids(ep)
- model_ids = [m for m in model_ids if m not in hidden]
+ model_ids = _visible_models(
+ _cached_model_ids(ep),
+ ep.hidden_models,
+ getattr(ep, "pinned_models", None),
+ )
# Build correct URL based on provider
chat_url = build_chat_url(base)
kind = _effective_endpoint_kind(ep, base)
@@ -1043,6 +1044,13 @@ def setup_model_routes(model_discovery):
if model_ids:
curated_key = _match_provider_curated(base, None)
curated, extra = _curate_models(model_ids, curated_key)
+ # Pinned models are admin-selected — they always belong in the
+ # primary curated list, not buried in extras.
+ pinned = _normalize_model_ids(getattr(ep, "pinned_models", None))
+ for m in pinned:
+ if m not in curated:
+ curated.append(m)
+ extra = [m for m in extra if m not in pinned]
items.append({
"host": "custom",
"port": 0,
@@ -1891,9 +1899,10 @@ def setup_model_routes(model_discovery):
if body:
if "supports_tools" in body:
v = body["supports_tools"]
- ep.supports_tools = bool(v) if v in (True, False, "true", "false", 1, 0) else None
+ ep.supports_tools = {True: True, False: False, 'true': True, 'false': False, 1: True, 0: False}.get(v)
if "is_enabled" in body:
- ep.is_enabled = bool(body["is_enabled"])
+ v_ie = body['is_enabled']
+ ep.is_enabled = v_ie.lower() in ('true', '1', 'yes') if isinstance(v_ie, str) else bool(v_ie)
if "name" in body and isinstance(body["name"], str):
ep.name = body["name"].strip() or ep.name
if "model_type" in body and isinstance(body["model_type"], str):
diff --git a/routes/session_routes.py b/routes/session_routes.py
index 049635d..049323c 100644
--- a/routes/session_routes.py
+++ b/routes/session_routes.py
@@ -57,6 +57,40 @@ def _content_to_text(content) -> str:
return ""
+def _message_role(message) -> str:
+    """Return the role of a history entry (ChatMessage, dict, or other object), or "" if unset."""
+    if isinstance(message, ChatMessage):
+        role = message.role
+    elif isinstance(message, dict):
+        role = message.get("role", "")
+    else:
+        role = getattr(message, "role", "")
+    # Falsy roles (None, "") are normalized to the empty string.
+    return role if role else ""
+
+
+def _message_text(message) -> str:
+    """Return a history entry's content converted to text via ``_content_to_text``.
+
+    Accepts a ChatMessage, a dict, or any object; a missing content
+    key/attribute is passed on as ``None``.
+    """
+    if isinstance(message, ChatMessage):
+        return _content_to_text(message.content)
+    if isinstance(message, dict):
+        return _content_to_text(message.get("content"))
+    return _content_to_text(getattr(message, "content", None))
+
+
+def _message_metadata(message) -> dict:
+    """Return a history entry's metadata dict, or ``{}`` when it is absent or not a dict."""
+    if isinstance(message, ChatMessage):
+        raw = message.metadata
+    elif isinstance(message, dict):
+        raw = message.get("metadata")
+    else:
+        raw = getattr(message, "metadata", None)
+
+    if isinstance(raw, dict):
+        return raw
+    return {}
+
+
+def _reject_compact_during_active_run(session_id: str) -> None:
+    """Raise HTTP 409 if ``agent_runs`` reports the session as active; otherwise do nothing."""
+    # Imported here, at call time, rather than at module import.
+    from src import agent_runs
+
+    if not agent_runs.is_active(session_id):
+        return
+    raise HTTPException(409, "Session has an active run; try compacting after it finishes")
+
+
def _verify_session_owner(request: Request, session_id: str, session_manager=None):
"""Verify the current user owns the session. Raises 404 if not.
@@ -872,6 +906,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
session = session_manager.get_session(session_id)
except KeyError:
raise HTTPException(404, f"Session {session_id} not found")
+ _reject_compact_during_active_run(session_id)
history = list(session.history or [])
if len(history) < 6:
@@ -897,7 +932,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
prior_compactions = sum(
1 for m in history
- if (m.metadata or {}).get("compacted") or "[Conversation summary" in (m.content or "")
+ if _message_metadata(m).get("compacted") or "[Conversation summary" in _message_text(m)
)
prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace(
"{count}", str(len(older))
@@ -905,7 +940,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
"{n}", str(prior_compactions + 1)
)
convo_text = "\n".join(
- f"{m.role.upper()}: {(m.content or '')[:2000]}"
+ f"{_message_role(m).upper()}: {_message_text(m)[:2000]}"
for m in older
)
try:
diff --git a/routes/task_routes.py b/routes/task_routes.py
index ebe9da1..a5c49ad 100644
--- a/routes/task_routes.py
+++ b/routes/task_routes.py
@@ -455,7 +455,7 @@ def setup_task_routes(task_scheduler) -> APIRouter:
import sqlite3
from pathlib import Path
- from routes.email_helpers import SCHEDULED_DB
+ from routes.email_helpers import SCHEDULED_DB, OWNER_SCOPED_EMAIL_CACHE_TABLES, _email_cache_owner_clause
cleared = {}
conn = sqlite3.connect(SCHEDULED_DB)
@@ -468,6 +468,13 @@ def setup_task_routes(task_scheduler) -> APIRouter:
(user,),
).fetchone()[0]
conn.execute("DELETE FROM email_tags WHERE owner = ? OR owner = ''", (user,))
+ elif table in OWNER_SCOPED_EMAIL_CACHE_TABLES and user:
+ owner_clause, owner_params = _email_cache_owner_clause(user)
+ before = conn.execute(
+ f"SELECT COUNT(*) FROM {table} WHERE {owner_clause}",
+ owner_params,
+ ).fetchone()[0]
+ conn.execute(f"DELETE FROM {table} WHERE {owner_clause}", owner_params)
else:
before = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
conn.execute(f"DELETE FROM {table}")
diff --git a/routes/workspace_routes.py b/routes/workspace_routes.py
new file mode 100644
index 0000000..f7b27fb
--- /dev/null
+++ b/routes/workspace_routes.py
@@ -0,0 +1,56 @@
+"""Workspace API — browse server directories to pick a tool workspace folder."""
+import os
+from fastapi import APIRouter, Request, HTTPException, Query
+
+from src.auth_helpers import get_current_user
+from src.tool_security import owner_is_admin_or_single_user
+
+
+def setup_workspace_routes():
+    """Create the ``/api/workspace`` router.
+
+    The router exposes a single endpoint, ``GET /browse``, that lists the
+    sub-directories of a server folder so the UI can pick a tool workspace.
+    """
+    router = APIRouter(prefix="/api/workspace", tags=["workspace"])
+
+    def _visible_subdirs(folder):
+        """Return ``{"name", "path"}`` entries for the real, non-hidden
+        directories directly inside ``folder`` (``[]`` if it can't be listed)."""
+        found = []
+        try:
+            with os.scandir(folder) as entries:
+                for item in entries:
+                    try:
+                        # follow_symlinks=False: a symlinked directory is not
+                        # classified as a directory, so it is left out rather
+                        # than letting the browser wander off via a link.
+                        is_real_dir = item.is_dir(follow_symlinks=False)
+                    except OSError:
+                        continue
+                    if not is_real_dir or item.name.startswith("."):
+                        # Non-directories and hidden entries are omitted.
+                        continue
+                    # Build the child path server-side with os.path.join so
+                    # it's correct on Windows (backslashes) and Linux.
+                    found.append({"name": item.name, "path": os.path.join(folder, item.name)})
+        except OSError:
+            # PermissionError is an OSError subclass. Any failure while
+            # listing discards whatever was collected so far.
+            return []
+        return found
+
+    @router.get("/browse")
+    def browse(request: Request, path: str = Query(default="")):
+        """List subdirectories of `path` (default: home) so the UI can navigate
+        the server filesystem and pick a workspace folder. Directories only.
+
+        ADMIN-ONLY: this enumerates the server filesystem, so it is gated the
+        same way the file/shell tools are (read_file/write_file/bash are in
+        NON_ADMIN_BLOCKED_TOOLS). A non-admin who can't use those tools must not
+        be able to map the host's directory tree either.
+        """
+        owner = get_current_user(request)
+        if not owner_is_admin_or_single_user(owner):
+            raise HTTPException(status_code=403, detail="Workspace browsing is admin-only")
+
+        # Resolve symlinks so the reported path is canonical and the UI
+        # navigates real directories.
+        requested = path.strip() or "~"
+        target = os.path.realpath(os.path.expanduser(requested))
+        if not os.path.isdir(target):
+            # Anything that is not an existing directory falls back to home.
+            target = os.path.realpath(os.path.expanduser("~"))
+
+        listing = _visible_subdirs(target)
+        listing.sort(key=lambda d: d["name"].lower())
+
+        parent = os.path.dirname(target)
+        if not parent or parent == target:
+            # dirname() of a filesystem root is the root itself: no parent.
+            parent = None
+        return {"path": target, "parent": parent, "dirs": listing}
+
+    return router
diff --git a/services/hwfit/fit.py b/services/hwfit/fit.py
index 9a45b53..09aea29 100644
--- a/services/hwfit/fit.py
+++ b/services/hwfit/fit.py
@@ -576,6 +576,7 @@ def rank_models(system, use_case=None, limit=50, search=None, sort="score", quan
system_backend = (system.get("backend") or "").lower()
apple_silicon = system_backend in ("mps", "metal", "apple")
rocm = system_backend == "rocm"
+ is_windows = system.get("platform") == "windows"
# Consumer AMD Radeon (RDNA, gfx10/11/12): the practical local serving path
# is GGUF via llama.cpp. vLLM/SGLang on ROCm are validated for datacenter
@@ -615,7 +616,11 @@ def rank_models(system, use_case=None, limit=50, search=None, sort="score", quan
# servable path, so a model needs a real GGUF to be recommended.
# Otherwise the Cookbook rates vLLM-only AWQ/GPTQ builds "GOOD" on a
# Radeon that can't actually serve them.
- if (apple_silicon or consumer_amd) and not (m.get("is_gguf") or m.get("gguf_sources")):
+ #
+ # Windows is the same: Odysseus only supports llama.cpp on Windows,
+ # which requires GGUF. vLLM/SGLang are explicitly blocked, so AWQ/GPTQ
+ # models without a GGUF source are unservable there.
+ if (apple_silicon or consumer_amd or is_windows) and not (m.get("is_gguf") or m.get("gguf_sources")):
continue
# Format filter: AWQ tab -> only AWQ models, FP4 tab -> FP4-family models, etc.
diff --git a/services/hwfit/hardware.py b/services/hwfit/hardware.py
index 9815327..f961b70 100644
--- a/services/hwfit/hardware.py
+++ b/services/hwfit/hardware.py
@@ -539,6 +539,7 @@ def _detect_windows():
"backend": d.get("gpu_backend", "cpu_x86"),
"homogeneous": True,
"gpu_error": None,
+ "platform": "windows",
}
# PowerShell only reports aggregate GPU info, not per-card detail, so we
# can't tell a mixed box from a uniform one here — assume one homogeneous
diff --git a/services/memory/memory_extractor.py b/services/memory/memory_extractor.py
index 32412e6..44a9f1f 100644
--- a/services/memory/memory_extractor.py
+++ b/services/memory/memory_extractor.py
@@ -345,8 +345,17 @@ async def extract_and_store(
logger.warning(f"Memory dedup (vector) unavailable, using text fallback: {e}")
existing_id = None
if existing_id:
- logger.debug(f"Memory dedup (vector): '{fact_text[:50]}' matches {existing_id}")
- continue
+ # The vector store is a single shared collection with no
+ # owner metadata, so find_similar can return ANOTHER
+ # tenant's memory. Only treat it as a duplicate when the
+ # match is this user's own (or a legacy unowned) memory —
+ # otherwise the user's freshly-extracted fact would be
+ # silently dropped. Mirror the owner predicate used by the
+ # text dedup below; cross-tenant/stale matches fall through.
+ _match = next((e for e in existing if e.get("id") == existing_id), None)
+ if _match is not None and (_match.get("owner") == _owner or _match.get("owner") is None):
+ logger.debug(f"Memory dedup (vector): '{fact_text[:50]}' matches {existing_id}")
+ continue
# Text dedup fallback: exact match + fuzzy similarity
user_existing = [e for e in existing if e.get("owner") == _owner or e.get("owner") is None] if _owner else existing
diff --git a/services/memory/service.py b/services/memory/service.py
index d07eb17..0a5b9b5 100644
--- a/services/memory/service.py
+++ b/services/memory/service.py
@@ -7,6 +7,7 @@ import os
from .memory import MemoryManager
from .memory_vector import MemoryVectorStore
+from src.memory_provider import MemoryRecord, NativeMemoryProvider
@dataclass
@@ -42,6 +43,10 @@ class MemoryService:
self.vector_store = MemoryVectorStore(data_dir) if os.path.exists(
os.path.join(data_dir, "memory_vectors")
) else None
+ self.provider = NativeMemoryProvider(self.manager, self.vector_store)
+
+ def _sync_provider(self) -> None:
+ """Re-point the provider at this service's current ``vector_store``."""
+ # remember()/recall() call this right before delegating to the provider;
+ # presumably so a vector store swapped in after __init__ is picked up —
+ # NOTE(review): confirm against code that reassigns ``self.vector_store``.
+ self.provider.memory_vector = self.vector_store
@staticmethod
def _to_memory(entry: Dict[str, Any], metadata: Optional[Dict[str, Any]] = None) -> Memory:
@@ -53,6 +58,19 @@ class MemoryService:
metadata=metadata or {},
)
+    @staticmethod
+    def _record_to_memory(record: MemoryRecord, metadata: Optional[Dict[str, Any]] = None) -> Memory:
+        """Convert a provider ``MemoryRecord`` into a service-level ``Memory``.
+
+        The record's own metadata is copied first; keys from ``metadata`` (if
+        given) are layered on top and win on conflict. ``record.metadata``
+        itself is not mutated.
+        """
+        combined = dict(record.metadata)
+        combined.update(metadata or {})
+        copied = {
+            field: getattr(record, field)
+            for field in ("id", "text", "timestamp", "session_id")
+        }
+        return Memory(metadata=combined, **copied)
+
async def remember(self, text: str, session_id: Optional[str] = None) -> Memory:
"""
Store a new memory.
@@ -64,19 +82,9 @@ class MemoryService:
Returns:
Created Memory object
"""
- entry = self.manager.add_entry(text)
- if session_id:
- entry["session_id"] = session_id
-
- memories = self.manager.load_all()
- memories.append(entry)
- self.manager.save(memories)
-
- # Also add to vector store if available
- if self.vector_store and self.vector_store.healthy:
- self.vector_store.add(entry["id"], entry["text"])
-
- return self._to_memory(entry)
+ self._sync_provider()
+ record = await self.provider.remember(text, session_id=session_id)
+ return self._record_to_memory(record)
async def recall(self, query: str, top_k: int = 5) -> MemorySearchResult:
"""
@@ -89,28 +97,20 @@ class MemoryService:
Returns:
MemorySearchResult with matching memories
"""
- # Try vector search first
- all_memories = self.manager.load_all()
- by_id = {m.get("id"): m for m in all_memories}
- if self.vector_store and self.vector_store.healthy:
- results = self.vector_store.search(query, k=top_k)
- found = []
- for result in results:
- entry = by_id.get(result.get("memory_id"))
- if entry:
- found.append(self._to_memory(entry, metadata={"score": result.get("score")}))
- if found:
- return MemorySearchResult(memories=found, query=query, total=len(found))
-
- # Fallback to keyword search
- results = self.manager.get_relevant_memories(query, all_memories, max_items=top_k)
- memories = [self._to_memory(m) for m in results]
+ self._sync_provider()
+ results = await self.provider.recall(query, top_k=top_k)
+ memories = [
+ self._record_to_memory(hit.memory, metadata={"score": hit.score})
+ if hit.score is not None
+ else self._record_to_memory(hit.memory)
+ for hit in results
+ ]
return MemorySearchResult(memories=memories, query=query, total=len(memories))
def get_all(self, limit: int = 100) -> List[Memory]:
"""Get all memories."""
- memories = self.manager.load_all()[:limit]
- return [self._to_memory(m) for m in memories]
+ records = self.manager.load_all()[:limit]
+ return [self._to_memory(m) for m in records]
def delete(self, memory_id: str) -> bool:
"""Delete a memory by ID."""
diff --git a/src/agent_loop.py b/src/agent_loop.py
index a990e19..401e7bb 100644
--- a/src/agent_loop.py
+++ b/src/agent_loop.py
@@ -177,6 +177,7 @@ TOOL_SECTIONS = {
```
Run any shell command. Output is returned to you. Use for: installing packages, checking files, git, curl, system info, etc.
+NEVER use bash to create or change files — no `>`/`>>` redirects, no heredocs (`cat > f << 'EOF'`), no `tee`, `sed -i`, `awk -i`, no `python -c` that writes. To CREATE or fully rewrite a file use `write_file`; to change part of an existing file use `edit_file`. Those show a diff and are the ONLY allowed way to write files. (bash is for read-only inspection: `ls`, `cat` to READ, `grep`, `git status`/`git diff`, builds, installs.)
For LONG-running commands (package installs, pip/npm, ffmpeg, model downloads, training, builds — anything that may take more than ~20s), make the FIRST line `#!bg` to run it in the BACKGROUND. You get a job id back immediately and are automatically re-invoked with the full output when it finishes — so you never block the chat waiting. Example:
```bash
#!bg
@@ -220,6 +221,12 @@ Read a file and return its contents.""",
```
Write content to a file. First line is the path, rest is the content.""",
+ "edit_file": """\
+```edit_file
+{"path": "", "old_string": "", "new_string": "", "replace_all": false}
+```
+Edit an EXISTING file by exact string replacement. PREFER this over bash (sed/echo/redirects) for changing files — it shows a before/after diff. `old_string` must match the file exactly and be unique unless `replace_all` is true. Use write_file to create a new file.""",
+
"create_document": """\
```create_document
@@ -236,7 +243,7 @@ old text to find
new replacement text
<<>>
```
-PREFERRED way to change an existing document. Find exact text and replace it. Multiple FIND/REPLACE blocks per call OK. Use this for any edit smaller than a full rewrite — adding a function, fixing a bug, tweaking a section, renaming things. **If a document is open in the editor, treat it as the user's current context: don't ask which file they mean, and don't create a new one — just edit_document the active one.** Do NOT re-send the whole file with update_document for small changes.""",
+Edit a document OPEN IN THE EDITOR PANEL — NOT a file on disk. For files on disk (home folder, project files, any real path like ~/sweden.txt) use `edit_file` instead. Find exact text and replace it. Multiple FIND/REPLACE blocks per call OK. Use for any edit smaller than a full rewrite. **If a document is open in the editor, treat it as the user's current context: don't ask which file they mean, and don't create a new one — just edit_document the active one.** Do NOT re-send the whole file with update_document for small changes.""",
"update_document": """\
```update_document
@@ -462,13 +469,14 @@ _API_HOSTS = frozenset([
"api.together.xyz", "api.fireworks.ai",
"api.perplexity.ai", "api.x.ai",
"ollama.com", "api.venice.ai",
+ "api.githubcopilot.com",
# Local OpenAI-compatible endpoints (llama.cpp, vLLM, LM Studio, etc.).
# Without these, `_is_api_model` falls back to keyword sniffing on the
# model name, so well-behaved local servers don't get native tool
# schemas and the agent silently degrades to fenced-block parsing.
"localhost", "127.0.0.1", "host.docker.internal",
])
-_MCP_KEYWORDS = frozenset(["browse", "browser", "website", "calendar", "event", "email",
+_MCP_KEYWORDS = frozenset(["mcp", "browse", "browser", "website", "calendar", "event", "email",
"gmail", "screenshot", "navigate", "click", "miniflux", "rss", "feed"])
_ADMIN_SCHEMA_NAMES = frozenset([
"manage_session", "manage_skills", "manage_tasks",
@@ -1380,6 +1388,7 @@ async def stream_agent_loop(
owner: Optional[str] = None,
relevant_tools: Optional[Set[str]] = None,
fallbacks: Optional[List[tuple]] = None,
+ workspace: Optional[str] = None,
_is_teacher_run: bool = False,
) -> AsyncGenerator[str, None]:
"""Streaming agent loop generator.
@@ -1546,6 +1555,27 @@ async def stream_agent_loop(
compact=_is_api_model,
owner=owner,
)
+ if workspace:
+ # PREPEND (not append) so it dominates the large base prompt — appended
+ # at the end, small models ignored it and asked the user for code. The
+ # folder IS the project; the agent must explore it, not ask.
+ _ws_note = (
+ f"## ACTIVE WORKSPACE — READ FIRST\n"
+ f"The user is working in this folder: {workspace}\n"
+ f"It IS the project. bash/python run with cwd set here and "
+ f"read_file/write_file are confined to it (paths outside are rejected).\n"
+ f"When the user says \"the code\" / \"this project\" / \"the workspace\" "
+ f"or asks to review/find/edit something WITHOUT a path, they mean THIS "
+ f"folder. Do NOT ask the user for code or a path, and do NOT read a file "
+ f"literally named \"workspace\". ALWAYS start by exploring it yourself: "
+ f"run `bash` → `git ls-files` (or `ls -R`) to see the files, then "
+ f"read_file the relevant ones by path RELATIVE to the workspace."
+ )
+ if messages and messages[0].get("role") == "system":
+ messages[0]["content"] = _ws_note + "\n\n" + (messages[0].get("content") or "")
+ else:
+ messages.insert(0, {"role": "system", "content": _ws_note})
+ logger.info("[workspace] active for this turn: %s", workspace)
prep_timings["prompt_build"] = time.time() - _t2
_t3 = time.time()
@@ -1658,6 +1688,11 @@ async def stream_agent_loop(
_doc_opened = False # whether doc_stream_open was sent
_doc_last_len = 0 # last content length sent
+ # Set when the loop runs out of rounds while the agent was still actively
+ # using tools — i.e. it was cut off, not finished. Drives a "Continue" event
+ # so the user can resume instead of the turn silently stalling.
+ _exhausted_rounds = False
+
for round_num in range(1, max_rounds + 1):
round_response = ""
round_reasoning = "" # reasoning_content deltas (DeepSeek-thinking, vLLM --reasoning-parser)
@@ -2167,6 +2202,7 @@ async def stream_agent_loop(
disabled_tools=disabled_tools,
owner=owner,
progress_cb=_push_progress,
+ workspace=workspace,
)
finally:
# Sentinel so the drainer knows to stop.
@@ -2282,6 +2318,9 @@ async def stream_agent_loop(
if result.get("images"):
img = result["images"][0]
tool_output_data["screenshot"] = f"data:{img['mimeType']};base64,{img['data']}"
+ # Forward a file-write diff for inline before/after rendering
+ if "diff" in result:
+ tool_output_data["diff"] = result["diff"]
yield f'data: {json.dumps(tool_output_data)}\n\n'
# Native document tools open in the editor + carry the REAL doc id.
@@ -2324,6 +2363,10 @@ async def stream_agent_loop(
if result.get("doc_id"):
tool_event["doc_id"] = result["doc_id"]
tool_event["doc_title"] = result.get("title", "")
+ # Persist the file-write/edit diff so it re-renders on reload — without
+ # this the diff shows live but vanishes from saved history.
+ if result.get("diff"):
+ tool_event["diff"] = result["diff"]
tool_events.append(tool_event)
if block.tool_type in _VERIFIER_EFFECTFUL_TOOLS:
_effectful_used = True
@@ -2348,6 +2391,20 @@ async def stream_agent_loop(
# Separator in accumulated response
full_response += "\n\n"
+ else:
+ # The for-loop completed every allowed round WITHOUT an early `break`
+ # (a `break` fires on "done", budget, or error). Reaching this `else`
+ # means the agent kept working until it ran out of rounds — so offer
+ # Continue instead of stopping silently. This catches ALL exhaustion
+ # paths, including a verifier `continue` on the final round (the old
+ # bottom-of-loop flag missed those).
+ _exhausted_rounds = True
+
+ # If the loop hit the round cap while still working, tell the client so it
+ # can show a "Continue" affordance instead of the turn just stopping.
+ if _exhausted_rounds:
+ logger.info("[agent] round cap (%d) reached mid-task — emitting rounds_exhausted", max_rounds)
+ yield f'data: {json.dumps({"type": "rounds_exhausted", "rounds": max_rounds})}\n\n'
# If the response is completely empty and no tools were executed,
# yield a fallback message so the user is not left hanging.
diff --git a/src/agent_tools.py b/src/agent_tools.py
index f162bc5..b86bd48 100644
--- a/src/agent_tools.py
+++ b/src/agent_tools.py
@@ -26,7 +26,8 @@ MAX_OUTPUT_CHARS = 10_000
MAX_READ_CHARS = 20_000
# Tool types that trigger execution
-TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file",
+TOOL_TAGS = {"bash", "python", "web_search", "web_fetch", "read_file", "write_file", "edit_file",
+ "grep", "glob", "ls",
"create_document", "update_document", "edit_document",
"search_chats",
"chat_with_model", "create_session", "list_sessions",
diff --git a/src/app_initializer.py b/src/app_initializer.py
index 1cfa308..7d6b8c2 100644
--- a/src/app_initializer.py
+++ b/src/app_initializer.py
@@ -9,6 +9,7 @@ from src.constants import (
SESSIONS_FILE, DEFAULT_HOST, OPENAI_API_KEY
)
from src.memory import MemoryManager
+from src.memory_provider import MemoryProviderRegistry, NativeMemoryProvider
from services.memory.skills import SkillsManager
from core.session_manager import SessionManager
from core.models import set_session_manager
@@ -73,6 +74,10 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
logger.warning(f"MemoryVectorStore DEGRADED: {e}")
memory_vector = None
+ memory_provider_registry = MemoryProviderRegistry([
+ NativeMemoryProvider(memory_manager, memory_vector),
+ ])
+
# Initialize processors
chat_processor = ChatProcessor(memory_manager, personal_docs_manager, memory_vector=memory_vector, skills_manager=skills_manager)
research_handler = ResearchHandler()
@@ -99,6 +104,7 @@ def initialize_managers(base_dir: str, rag_manager=None) -> Dict[str, Any]:
return {
"memory_manager": memory_manager,
"memory_vector": memory_vector,
+ "memory_provider_registry": memory_provider_registry,
"skills_manager": skills_manager,
"session_manager": session_manager,
"upload_handler": upload_handler,
diff --git a/src/caldav_sync.py b/src/caldav_sync.py
index 10ca51c..663c0bd 100644
--- a/src/caldav_sync.py
+++ b/src/caldav_sync.py
@@ -265,6 +265,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
existing.all_day = all_day
existing.is_utc = row_is_utc
existing.rrule = rrule
+ existing.origin = "caldav"
else:
new_ev = CalendarEvent(
uid=uid_val,
@@ -277,6 +278,7 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
all_day=all_day,
is_utc=row_is_utc,
rrule=rrule,
+ origin="caldav",
)
db.add(new_ev)
pending[uid_val] = new_ev
@@ -286,8 +288,13 @@ def _sync_blocking(owner: str, url: str, username: str, password: str) -> dict:
# Prune locally-cached CalDAV events that vanished
# upstream (only within our sync window — events outside
# the window aren't in `objs`, so we'd false-delete them).
+ # Only rows we previously pulled from the server (origin=="caldav")
+ # are prunable; locally-created events (agent / email triage / a
+ # UI event whose write-back failed) carry origin NULL and must
+ # never be deleted just because the server didn't return them.
stale = db.query(CalendarEvent).filter(
CalendarEvent.calendar_id == local_cal.id,
+ CalendarEvent.origin == "caldav",
CalendarEvent.dtstart >= start,
CalendarEvent.dtstart <= end,
~CalendarEvent.uid.in_(seen_uids) if seen_uids else CalendarEvent.uid.isnot(None),
diff --git a/src/copilot.py b/src/copilot.py
new file mode 100644
index 0000000..62d2b8c
--- /dev/null
+++ b/src/copilot.py
@@ -0,0 +1,253 @@
+# src/copilot.py
+"""GitHub Copilot provider support.
+
+Copilot exposes an OpenAI-compatible API at ``https://api.githubcopilot.com``
+(``/chat/completions`` + ``/models``). Authentication is a GitHub OAuth
+**device flow**: the user authorises a device code in their browser and we
+receive a long-lived ``access_token`` that is sent directly as
+``Authorization: Bearer <token>`` — there is no separate Copilot-token
+exchange and no refresh (mirrors how editors / opencode talk to Copilot).
+
+The only provider-specific wrinkle beyond the bearer token is a handful of
+required request headers (API version, intent, an editor-style User-Agent,
+and ``x-initiator`` for agent-vs-user request accounting). Those live in
+:func:`copilot_headers`.
+
+This module holds the constants + pure helpers; the HTTP device-flow calls
+live in :mod:`routes.copilot_routes` so they can be auth-gated.
+"""
+
+import os
+from typing import Dict, List, Optional
+from urllib.parse import urlparse
+
+import httpx
+
+# ---------------------------------------------------------------------------
+# Constants
+# ---------------------------------------------------------------------------
+
+# GitHub OAuth client id used for the device flow. Copilot's token endpoint
+# only accepts client ids that GitHub has allow-listed for Copilot access, so
+# we reuse the public VS Code client id (the de-facto standard third-party
+# clients use). Override via env if you register your own allow-listed app.
+COPILOT_CLIENT_ID = os.environ.get(
+ "ODYSSEUS_COPILOT_CLIENT_ID", "01ab8ac9400c4e429b23"
+)
+
+# Dated API version header required by the Copilot API (models + chat).
+# Sent as ``X-GitHub-Api-Version`` by copilot_headers().
+COPILOT_API_VERSION = os.environ.get(
+ "ODYSSEUS_COPILOT_API_VERSION", "2026-06-01"
+)
+
+# Public Copilot API base. GitHub Enterprise uses ``copilot-api.<domain>``
+# instead (see enterprise_base()).
+COPILOT_BASE = "https://api.githubcopilot.com"
+
+# Copilot wants an editor-like User-Agent + integration id. These identify the
+# client to GitHub; keep them stable.
+COPILOT_USER_AGENT = os.environ.get(
+ "ODYSSEUS_COPILOT_USER_AGENT", "Odysseus/1.0"
+)
+COPILOT_INTEGRATION_ID = os.environ.get(
+ "ODYSSEUS_COPILOT_INTEGRATION_ID", "vscode-chat"
+)
+COPILOT_EDITOR_VERSION = os.environ.get(
+ "ODYSSEUS_COPILOT_EDITOR_VERSION", "Odysseus/1.0"
+)
+
+# OAuth scope requested during the device flow.
+COPILOT_SCOPE = "read:user"
+
+# Default GitHub host for the device flow (public github.com); it is the
+# default ``host`` argument of device_code_url() / access_token_url().
+GITHUB_HOST = "github.com"
+
+
+def device_code_url(host: str = GITHUB_HOST) -> str:
+    """Return the HTTPS URL of the device-code endpoint
+    (``/login/device/code``) on ``host``."""
+    return "https://{}/login/device/code".format(host)
+
+
+def access_token_url(host: str = GITHUB_HOST) -> str:
+    """Return the HTTPS URL of the OAuth access-token endpoint
+    (``/login/oauth/access_token``) on ``host``."""
+    return "https://{}/login/oauth/access_token".format(host)
+
+
+def normalize_domain(url: str) -> str:
+    """Strip scheme/trailing slash from a GitHub Enterprise URL or domain.
+
+    Surrounding whitespace is dropped, then a single *leading* ``http://`` or
+    ``https://`` (matched case-insensitively) is removed, then any trailing
+    ``/`` characters. ``None`` / empty input yields ``""``. Anything after the
+    host (a path, a port) is left untouched.
+    """
+    domain = (url or "").strip()
+    for scheme in ("https://", "http://"):
+        # Compare only the prefix so text later in the value is never altered
+        # and ``HTTPS://`` is handled like ``https://``.
+        if domain[:len(scheme)].lower() == scheme:
+            domain = domain[len(scheme):]
+            break
+    return domain.rstrip("/")
+
+
+def enterprise_base(enterprise_url: Optional[str]) -> str:
+    """Pick the Copilot API base URL for a deployment.
+
+    * No enterprise URL (``None`` / empty) → the public ``COPILOT_BASE``
+      (``https://api.githubcopilot.com``).
+    * Otherwise → ``https://copilot-api.<domain>``, where ``<domain>`` is
+      ``normalize_domain(enterprise_url)``.
+    """
+    if enterprise_url:
+        return "https://copilot-api." + normalize_domain(enterprise_url)
+    return COPILOT_BASE
+
+
+def is_copilot_base(url: Optional[str]) -> bool:
+    """Report whether ``url`` points at the Copilot API.
+
+    The decision is made on the parsed hostname only (lower-cased, trailing
+    dot removed). A missing, unparsable or host-less URL is never a match.
+    """
+    if not url:
+        return False
+    try:
+        host = (urlparse(url).hostname or "").lower().rstrip(".")
+    except Exception:
+        return False
+    public_root = "githubcopilot.com"
+    # Public: api.githubcopilot.com (or any *.githubcopilot.com).
+    is_public = host == public_root or host.endswith("." + public_root)
+    # Enterprise: copilot-api.<domain>.
+    is_enterprise = host.startswith("copilot-api.")
+    return is_public or is_enterprise
+
+
+def copilot_headers(
+    api_key: Optional[str],
+    *,
+    agent: bool = False,
+    vision: bool = False,
+) -> Dict[str, str]:
+    """Assemble the request headers the Copilot API expects.
+
+    Args:
+        api_key: GitHub device-flow access token. When truthy it is sent as
+            ``Authorization: Bearer <api_key>``; otherwise no Authorization
+            header is added.
+        agent: True when the request comes from the agent loop (a
+            tool-driven turn) rather than a direct user message; selects
+            ``x-initiator: agent`` instead of ``user``.
+        vision: True when the request carries an image part; adds
+            ``Copilot-Vision-Request: true``.
+
+    Returns:
+        A new header dict on every call.
+    """
+    initiator = "agent" if agent else "user"
+    headers: Dict[str, str] = dict([
+        ("X-GitHub-Api-Version", COPILOT_API_VERSION),
+        ("Openai-Intent", "conversation-edits"),
+        ("User-Agent", COPILOT_USER_AGENT),
+        ("Editor-Version", COPILOT_EDITOR_VERSION),
+        ("Copilot-Integration-Id", COPILOT_INTEGRATION_ID),
+        ("x-initiator", initiator),
+    ])
+    if api_key:
+        headers.update({"Authorization": "Bearer {}".format(api_key)})
+    if vision:
+        headers.update({"Copilot-Vision-Request": "true"})
+    return headers
+
+
+# ---------------------------------------------------------------------------
+# Device-flow OAuth (pure HTTP; orchestration lives in routes.copilot_routes)
+# ---------------------------------------------------------------------------
+
+def _oauth_post_headers() -> Dict[str, str]:
+    """Headers shared by the device-flow POSTs: JSON in, JSON out, plus the
+    configured User-Agent."""
+    json_type = "application/json"
+    headers = {"Accept": json_type, "Content-Type": json_type}
+    headers["User-Agent"] = COPILOT_USER_AGENT
+    return headers
+
+
+def request_device_code(host: str = GITHUB_HOST, *, timeout: float = 10.0) -> Dict:
+    """Start the device flow with one blocking POST to ``device_code_url(host)``.
+
+    Returns the decoded JSON body, i.e. GitHub's
+    ``{device_code, user_code, verification_uri, expires_in, interval}``.
+    A non-success HTTP status is raised by ``raise_for_status()``.
+    """
+    body = {"client_id": COPILOT_CLIENT_ID, "scope": COPILOT_SCOPE}
+    response = httpx.post(
+        device_code_url(host),
+        headers=_oauth_post_headers(),
+        json=body,
+        timeout=timeout,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def poll_access_token(host: str, device_code: str, *, timeout: float = 10.0) -> Dict:
+    """Make a single poll of the access-token endpoint for ``device_code``.
+
+    While the user hasn't authorised yet GitHub answers HTTP 200 with an
+    ``error`` field (``authorization_pending``/``slow_down``); once they have,
+    the body is ``{access_token, ...}``. The decoded JSON is returned either
+    way, so the caller inspects it and decides whether to poll again.
+    A non-success HTTP status is raised by ``raise_for_status()``.
+    """
+    body = {
+        "client_id": COPILOT_CLIENT_ID,
+        "device_code": device_code,
+        "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
+    }
+    response = httpx.post(
+        access_token_url(host),
+        headers=_oauth_post_headers(),
+        json=body,
+        timeout=timeout,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def fetch_models(base: str, token: str, *, timeout: float = 15.0) -> List[Dict]:
+    """Fetch Copilot's model catalogue from ``<base>/models``.
+
+    Returns a list of ``{id, tool_calls, vision}`` dicts, restricted to the
+    models that advertise ``model_picker_enabled``. If no model does, the
+    full list is returned instead (defensive against API-shape drift).
+    Entries without an ``id`` are skipped. A non-success HTTP status is
+    raised by ``raise_for_status()``.
+    """
+    response = httpx.get(
+        base.rstrip("/") + "/models",
+        headers=copilot_headers(token),
+        timeout=timeout,
+    )
+    response.raise_for_status()
+    catalogue = (response.json() or {}).get("data") or []
+
+    # Pair each usable entry with its picker flag so the flag never has to be
+    # stored in (and later removed from) the returned dict.
+    entries = []
+    for raw in catalogue:
+        model_id = raw.get("id")
+        if not model_id:
+            continue
+        supports = ((raw.get("capabilities") or {}).get("supports")) or {}
+        summary = {
+            "id": model_id,
+            "tool_calls": bool(supports.get("tool_calls")),
+            "vision": bool(supports.get("vision")),
+        }
+        entries.append((bool(raw.get("model_picker_enabled")), summary))
+
+    picker_enabled = [summary for flagged, summary in entries if flagged]
+    return picker_enabled or [summary for _, summary in entries]
+
+
+# ---------------------------------------------------------------------------
+# Per-request header flags
+# ---------------------------------------------------------------------------
+
+# Content-part ``type`` values that mark a message part as an image.
+_IMAGE_PART_TYPES = ("image_url", "input_image", "image")
+
+
+def request_flags(messages) -> tuple:
+    """Derive ``(agent, vision)`` from an OpenAI-style message list.
+
+    Mirrors opencode's logic:
+    * ``agent`` — the last message is *not* a plain user message (i.e. it's a
+      tool result / assistant follow-up), so Copilot should treat the request
+      as agent-initiated for request accounting.
+    * ``vision`` — any message carries an image content part.
+
+    ``None`` or an empty list yields ``(False, False)``. Messages that are not
+    dicts are ignored for both flags; a non-dict last message therefore gives
+    ``agent == False`` instead of raising.
+    """
+    msgs = messages or []
+    last = msgs[-1] if msgs else None
+    # Same dict guard as the vision scan below: only a dict exposes ``role``.
+    agent = isinstance(last, dict) and bool(last) and last.get("role") != "user"
+    vision = any(
+        isinstance(part, dict) and part.get("type") in _IMAGE_PART_TYPES
+        for m in msgs
+        if isinstance(m, dict) and isinstance(m.get("content"), list)
+        for part in m["content"]
+    )
+    return agent, vision
+
+
+def apply_request_headers(headers: Dict[str, str], messages) -> Dict[str, str]:
+    """Stamp the per-request Copilot flags onto ``headers``.
+
+    ``x-initiator`` is always overwritten (``agent`` or ``user``);
+    ``Copilot-Vision-Request`` is added only when the messages carry an
+    image. The dict is modified in place and also returned.
+    """
+    is_agent, has_image = request_flags(messages)
+    initiator = "user"
+    if is_agent:
+        initiator = "agent"
+    headers["x-initiator"] = initiator
+    if has_image:
+        headers.update({"Copilot-Vision-Request": "true"})
+    return headers
+
diff --git a/src/deep_research.py b/src/deep_research.py
index 4617439..7a31422 100644
--- a/src/deep_research.py
+++ b/src/deep_research.py
@@ -196,6 +196,8 @@ class DeepResearcher:
max_content_chars: int = 15000,
max_report_tokens: int = 8192,
extraction_timeout: int = 90,
+ planning_timeout: int = 90,
+ query_timeout: int = 120,
extraction_concurrency: int = 3,
min_rounds: int = 2,
max_empty_rounds: int = 2,
@@ -215,6 +217,8 @@ class DeepResearcher:
self.max_content_chars = max_content_chars
self.max_report_tokens = max_report_tokens
self.extraction_timeout = min(3600, max(15, int(extraction_timeout or 90)))
+ self.planning_timeout = min(3600, max(15, int(planning_timeout or 90)))
+ self.query_timeout = min(3600, max(15, int(query_timeout or 120)))
self.extraction_concurrency = min(12, max(1, int(extraction_concurrency or 3)))
self.min_rounds = min_rounds
self.max_empty_rounds = max_empty_rounds
@@ -395,7 +399,7 @@ class DeepResearcher:
[{"role": "user", "content": prompt}],
temperature=0.3,
max_tokens=1024,
- timeout=30,
+ timeout=getattr(self, "planning_timeout", 90),
)
# Try to parse as JSON for structured plan
parsed = self._parse_json_object(response)
@@ -478,6 +482,7 @@ class DeepResearcher:
[{"role": "user", "content": prompt}],
temperature=0.5,
max_tokens=4096,
+ timeout=getattr(self, "query_timeout", 120),
)
queries = self._parse_json_array(response)
# Deduplicate
diff --git a/src/endpoint_resolver.py b/src/endpoint_resolver.py
index 073f8d7..a9ab5c7 100644
--- a/src/endpoint_resolver.py
+++ b/src/endpoint_resolver.py
@@ -194,6 +194,9 @@ def build_headers(api_key: Optional[str], base: str) -> Dict[str, str]:
headers["x-api-key"] = api_key
headers["anthropic-version"] = "2023-06-01"
return headers
+ if provider == "copilot":
+ from src.copilot import copilot_headers
+ return copilot_headers(api_key)
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
if provider == "openrouter":
diff --git a/src/llm_core.py b/src/llm_core.py
index a929edc..7dcf380 100644
--- a/src/llm_core.py
+++ b/src/llm_core.py
@@ -67,7 +67,7 @@ _host_health_lock = threading.Lock()
_model_activity: Dict[str, float] = {}
def _model_activity_key(url: str, model: str) -> str:
- return f"{(url or '').strip().rstrip()}|{(model or '').strip()}"
+ return f"{(url or '').strip()}|{(model or '').strip()}"
def note_model_activity(url: str, model: str):
"""Record that a real upstream request used this endpoint/model."""
@@ -317,6 +317,9 @@ def _detect_provider(url: str) -> str:
return "openrouter"
if _host_match(url, "groq.com"):
return "groq"
+ from src.copilot import is_copilot_base
+ if is_copilot_base(url):
+ return "copilot"
return "openai"
@@ -327,6 +330,14 @@ def _provider_headers(provider: str, headers: Optional[Dict] = None) -> Dict[str
if provider == "openrouter":
h.setdefault("HTTP-Referer", "https://github.com/pewdiepie-archdaemon/odysseus")
h.setdefault("X-OpenRouter-Title", "Odysseus")
+ if provider == "copilot":
+ # Ensure the Copilot-required headers are present even when the caller
+ # didn't pass pre-built headers (e.g. model listing). build_headers()
+ # already injects these for the live chat path; setdefault keeps any
+ # request-specific values (x-initiator/vision) the caller set.
+ from src.copilot import copilot_headers
+ for k, v in copilot_headers(None).items():
+ h.setdefault(k, v)
return h
@@ -340,6 +351,8 @@ def _provider_label(url: str) -> str:
if _host_match(url, "openai.com"): return "OpenAI"
if _host_match(url, "openrouter.ai"): return "OpenRouter"
if _host_match(url, "groq.com"): return "Groq"
+ from src.copilot import is_copilot_base
+ if is_copilot_base(url): return "GitHub Copilot"
if _host_match(url, "mistral.ai"): return "Mistral"
if _host_match(url, "deepseek.com"): return "DeepSeek"
if _host_match(url, "googleapis.com"): return "Google"
@@ -481,7 +494,7 @@ def _build_anthropic_payload(model, messages, temperature, max_tokens, stream=Fa
chat_messages = []
for m in messages:
if m.get("role") == "system":
- system_parts.append(m["content"])
+ system_parts.append(m.get("content") or "")
elif m.get("role") == "tool":
# Convert OpenAI tool result to Anthropic format
chat_messages.append({
@@ -884,7 +897,7 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
non_sys = []
for m in messages_copy:
if m.get("role") == "system":
- sys_parts.append(m["content"])
+ sys_parts.append(m.get('content') or '')
else:
non_sys.append(m)
if sys_parts:
@@ -911,6 +924,9 @@ def llm_call(url: str, model: str, messages: List[Dict], temperature: float = LL
)
else:
target_url = url
+ if provider == "copilot":
+ from src.copilot import apply_request_headers
+ apply_request_headers(h, messages_copy)
payload = {
"model": model,
"messages": messages_copy,
@@ -1028,7 +1044,7 @@ async def llm_call_async(
non_sys = []
for m in messages_copy:
if m.get("role") == "system":
- sys_parts.append(m["content"])
+ sys_parts.append(m.get('content') or '')
else:
non_sys.append(m)
if sys_parts:
@@ -1058,6 +1074,9 @@ async def llm_call_async(
else:
target_url = url
h = _provider_headers(provider, headers)
+ if provider == "copilot":
+ from src.copilot import apply_request_headers
+ apply_request_headers(h, messages_copy)
payload = {
"model": model,
"messages": messages_copy,
@@ -1088,6 +1107,9 @@ async def llm_call_async(
f"LLM async call to {target_url} failed in {duration:.2f}s "
f"(attempt {attempt}): HTTP {r.status_code} {friendly}"
)
+ if r.status_code in (429, 502, 503, 504) and attempt < max_retries:
+ await asyncio.sleep(LLMConfig.RETRY_DELAY)
+ continue
raise HTTPException(r.status_code, friendly)
logger.info(f"LLM async call to {target_url} succeeded in {duration:.2f}s (attempt {attempt})")
_clear_host_dead(target_url)
@@ -1109,7 +1131,9 @@ async def llm_call_async(
duration = time.time() - start
_tail = f" — host cooled for {DEAD_HOST_COOLDOWN:.0f}s" if _cooled else " — transient, will retry"
logger.warning(f"LLM async connect to {target_url} failed after {duration:.2f}s: {e}{_tail}")
- raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
+ if _cooled or attempt >= max_retries:
+ raise HTTPException(503, f"Cannot reach {_host_key(target_url)}: {e}")
+ await asyncio.sleep(LLMConfig.RETRY_DELAY)
except (httpx.RequestError, httpx.HTTPStatusError) as e:
duration = time.time() - start
logger.warning(f"LLM async call attempt {attempt} failed after {duration:.2f}s: {e}")
@@ -1138,7 +1162,7 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
non_sys = []
for m in messages_copy:
if m.get("role") == "system":
- sys_parts.append(m["content"])
+ sys_parts.append(m.get('content') or '')
else:
non_sys.append(m)
if sys_parts:
@@ -1177,6 +1201,9 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
if tools:
payload["tools"] = tools
h = _provider_headers(provider, headers)
+ if provider == "copilot":
+ from src.copilot import apply_request_headers
+ apply_request_headers(h, messages_copy)
# Short connect timeout: a reachable peer answers SYN in <100ms even on
# Tailscale. 3s is plenty; 30s let one dead upstream wedge the UI.
@@ -1358,6 +1385,8 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
# can detect thinking-in-progress (some models output </think> but no <think>)
_thinking_model = _supports_thinking(model)
_first_content_sent = False
+ _in_think_tag = False # True while consuming <think>…</think> content
+ _think_open_stripped = False # opening <think> tag already removed
def _emit_tool_calls():
"""Build the tool_calls event string if any were accumulated."""
@@ -1439,14 +1468,53 @@ async def stream_llm(url: str, model: str, messages: List[Dict], temperature: fl
yield f'data: {json.dumps({"delta": reasoning, "thinking": True})}\n\n'
content = delta.get("content") or ""
if content:
- # Some thinking backends start normal content with a
- # stray closing tag. Repair only that shape; do not
- # wrap every first token for model families like
- # MiniMax, which often stream ordinary answers.
- if _thinking_model and not _first_content_sent and content.lstrip().lower().startswith("</think"):
- content = "<think>" + content
- _first_content_sent = True
- yield f'data: {json.dumps({"delta": content})}\n\n'
+ stripped = content.lstrip()
+ # Auto-detect <think>…</think> in content stream.
+ # Covers Qwen3-derived models (Qwopus, QwQ forks) whose
+ # names don't match _THINKING_MODEL_PATTERNS but still
+ # emit literal <think> markup via llama.cpp --jinja.
+ if not _first_content_sent and not _thinking_model and not _in_think_tag and stripped.lower().startswith("")
+ if close_idx != -1:
+ # Split: up-to-</think> → thinking, remainder → content
+ think_part = content[:close_idx]
+ if not _think_open_stripped:
+ # Strip the opening <think> from the first chunk.
+ # Use a dedicated flag — _first_content_sent stays False
+ # throughout the think block, so it must not be reused.
+ tag_end = think_part.lower().find(">")
+ if tag_end != -1:
+ think_part = think_part[tag_end + 1:]
+ _think_open_stripped = True
+ regular_part = content[close_idx + len("</think>"):]
+ _in_think_tag = False
+ if think_part:
+ yield f'data: {json.dumps({"delta": think_part, "thinking": True})}\n\n'
+ if regular_part:
+ _first_content_sent = True
+ yield f'data: {json.dumps({"delta": regular_part})}\n\n'
+ else:
+ # Still inside <think>: route to thinking channel
+ if not _think_open_stripped:
+ # Strip the opening tag (first chunk only)
+ tag_end = stripped.lower().find(">")
+ if tag_end != -1:
+ content = stripped[tag_end + 1:]
+ _think_open_stripped = True
+ if content:
+ yield f'data: {json.dumps({"delta": content, "thinking": True})}\n\n'
+ else:
+ # Some thinking backends start normal content with a
+ # stray closing tag. Repair only that shape; do not
+ # wrap every first token for model families like
+ # MiniMax, which often stream ordinary answers.
+ if _thinking_model and not _first_content_sent and stripped.lower().startswith("</think"):
+ content = "<think>" + content
+ _first_content_sent = True
+ yield f'data: {json.dumps({"delta": content})}\n\n'
# Native tool calls — accumulate across chunks
for tc in delta.get("tool_calls") or []:
if tc is None:
diff --git a/src/mcp_manager.py b/src/mcp_manager.py
index 811094f..03bcf18 100644
--- a/src/mcp_manager.py
+++ b/src/mcp_manager.py
@@ -8,6 +8,7 @@ Each server exposes tools that are made available to the agent loop.
import json
import logging
import os
+import re
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@@ -30,6 +31,64 @@ def _format_mcp_connection_error(name: str, command: str = "", args: Optional[Li
return raw_error
+# Caps for rendering untrusted MCP tool schemas into the agent prompt (issue #2660).
+# MCP servers are third-party/user-added, so field names and parameter counts are
+# untrusted input — bound them so an odd or hostile schema cannot distort the prompt.
+_MCP_PARAM_MAX = 12 # max params rendered per tool
+_MCP_TOKEN_MAX = 40 # max chars per rendered name / type token
+_MCP_HINT_MAX = 300 # total-length backstop for the whole hint
+
+
+def _sanitize_schema_token(value: Any, limit: int = _MCP_TOKEN_MAX) -> str:
+ """Make an untrusted JSON-Schema token safe to splice into the prompt.
+
+ Replaces control chars / newlines with a space, collapses whitespace, and
+ length-caps the result, so a weird field name or type cannot inject newlines
+ or run on. Normal short identifiers pass through unchanged.
+ """
+ text = re.sub(r"[\x00-\x1f\x7f]+", " ", str(value))
+ text = re.sub(r"\s+", " ", text).strip()
+ if len(text) > limit:
+ text = text[:limit].rstrip() + "…"
+ return text
+
+
+def _format_mcp_params(input_schema: Any) -> str:
+ """Render an MCP tool's JSON-Schema inputs as a compact prompt hint.
+
+ Without this the agent only sees a tool's name + description and has to
+ guess its arguments (issue #2509). Produces e.g.
+ ` Args (JSON): {"path": string (required), "limit": integer}` — names,
+ coarse types, and required-ness, kept short so it stays prompt-friendly.
+ Returns "" when there are no parameters.
+
+ MCP servers are third-party, so names/types are sanitized and the parameter
+ count + total length are capped (issue #2660); normal schemas are unaffected.
+ """
+ if not isinstance(input_schema, dict):
+ return ""
+ props = input_schema.get("properties")
+ if not isinstance(props, dict) or not props:
+ return ""
+ required = set(input_schema.get("required") or [])
+ parts = []
+ for pname, pinfo in list(props.items())[:_MCP_PARAM_MAX]:
+ pinfo = pinfo if isinstance(pinfo, dict) else {}
+ ptype = pinfo.get("type") or "any"
+ if isinstance(ptype, list):
+ ptype = "|".join(str(x) for x in ptype)
+ tag = f'"{_sanitize_schema_token(pname)}": {_sanitize_schema_token(ptype)}'
+ if pname in required:
+ tag += " (required)"
+ parts.append(tag)
+ extra = len(props) - len(parts)
+ if extra > 0:
+ parts.append(f"…+{extra} more")
+ hint = " Args (JSON): {" + ", ".join(parts) + "}"
+ if len(hint) > _MCP_HINT_MAX:
+ hint = hint[:_MCP_HINT_MAX - 1].rstrip() + "…"
+ return hint
+
class McpManager:
"""Manages MCP server connections and tool routing."""
@@ -43,7 +102,9 @@ class McpManager:
self._sessions: Dict[str, Any] = {}
# server_id -> exit stack (for cleanup)
self._stacks: Dict[str, Any] = {}
- # Tracking updates to tools/connections for RAG indexing
+ # server_id -> background connect task (HTTP transport / OAuth)
+ self._connect_tasks: Dict[str, Any] = {}
+ # Tracking updates to tools/connections for RAG indexing / prompt cache
self._generation = 0
async def connect_server(
@@ -56,12 +117,14 @@ class McpManager:
env: Optional[Dict[str, str]] = None,
url: Optional[str] = None,
) -> bool:
- """Connect to an MCP server via stdio or SSE transport."""
+ """Connect to an MCP server via stdio, SSE, or Streamable HTTP transport."""
try:
if transport == "stdio":
res = await self._connect_stdio(server_id, name, command, args or [], env or {})
elif transport == "sse":
res = await self._connect_sse(server_id, name, url)
+ elif transport == "http":
+ res = await self._start_http_connect(server_id, name, url)
else:
logger.error(f"Unknown MCP transport: {transport}")
res = False
@@ -184,8 +247,101 @@ class McpManager:
self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
return False
+ async def _start_http_connect(self, server_id: str, name: str, url: str, wait: float = 8.0) -> bool:
+ """Begin a Streamable HTTP connect in the background. Returns within
+ `wait` seconds: True if it connected (cached-token path), otherwise the
+ flow is awaiting browser authorization and status becomes 'needs_auth'."""
+ import asyncio
+ self._connections[server_id] = {"status": "connecting", "name": name, "transport": "http"}
+ task = asyncio.create_task(self._connect_http(server_id, name, url))
+ self._connect_tasks[server_id] = task
+ done, _ = await asyncio.wait({task}, timeout=wait)
+ if task in done:
+ try:
+ return task.result()
+ except Exception as e:
+ self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
+ return False
+ # Still running → either awaiting authorization, or discovery/DCR is
+ # still in flight. If _on_redirect already published needs_auth+auth_url,
+ # leave it; otherwise mark needs_auth (auth_url filled in once it fires).
+ from src.mcp_oauth import pop_auth_url
+ cur = self._connections.get(server_id, {})
+ if cur.get("status") != "needs_auth":
+ self._connections[server_id] = {
+ "status": "needs_auth", "name": name, "transport": "http",
+ "auth_url": pop_auth_url(server_id),
+ }
+ return False
+
+ async def _connect_http(self, server_id: str, name: str, url: str) -> bool:
+ """Connect to a Streamable HTTP MCP server (with automatic OAuth)."""
+ try:
+ from mcp import ClientSession
+ from mcp.client.streamable_http import streamablehttp_client
+ from contextlib import AsyncExitStack
+ from src.mcp_oauth import build_provider, clear_auth_url
+
+ def _on_redirect(auth_url):
+ # Publish needs_auth the moment the URL is known, independent of
+ # how long discovery/DCR took (may exceed the bounded start wait).
+ self._connections[server_id] = {
+ "status": "needs_auth", "name": name, "transport": "http",
+ "auth_url": auth_url,
+ }
+
+ provider = build_provider(server_id, url, on_redirect=_on_redirect)
+ stack = AsyncExitStack()
+ transport = await stack.enter_async_context(streamablehttp_client(url, auth=provider))
+ read_stream, write_stream, _get_session_id = transport
+ session = await stack.enter_async_context(ClientSession(read_stream, write_stream))
+ await session.initialize()
+
+ tools_result = await session.list_tools()
+ tools = []
+ for tool in tools_result.tools:
+ tools.append({
+ "name": tool.name,
+ "description": tool.description or "",
+ "input_schema": tool.inputSchema if hasattr(tool, "inputSchema") else {},
+ })
+
+ self._sessions[server_id] = session
+ self._stacks[server_id] = stack
+ self._tools[server_id] = tools
+ self._connections[server_id] = {
+ "status": "connected", "name": name, "transport": "http",
+ "tool_count": len(tools),
+ }
+ clear_auth_url(server_id)
+ # Tools changed (this can complete after connect_server already
+ # returned, via the background OAuth flow), so bump the generation
+ # to invalidate the tool-prompt cache.
+ self._generation += 1
+ logger.info(f"MCP server connected: {name} ({server_id}) - {len(tools)} tools via http")
+ return True
+ except ImportError:
+ logger.warning("MCP package not installed. Install with: pip install mcp")
+ self._connections[server_id] = {"status": "error", "error": "mcp package not installed", "name": name}
+ return False
+ except Exception as e:
+ logger.error(f"Failed to connect HTTP MCP server {name} ({server_id}): {e}")
+ self._connections[server_id] = {"status": "error", "error": str(e), "name": name}
+ return False
+
async def disconnect_server(self, server_id: str):
"""Disconnect from an MCP server."""
+ # Cancel any in-flight HTTP/OAuth background connect so it stops
+ # publishing status for a server that may be getting deleted.
+ task = self._connect_tasks.pop(server_id, None)
+ if task is not None and not task.done():
+ task.cancel()
+ try:
+ from src.mcp_oauth import clear_auth_url
+ clear_auth_url(server_id)
+ except Exception:
+ pass
+
stack = self._stacks.pop(server_id, None)
if stack:
try:
@@ -376,6 +532,7 @@ class McpManager:
"name": tool["name"],
"qualified_name": f"mcp__{server_id}__{tool['name']}",
"description": tool.get("description", ""),
+ "input_schema": tool.get("input_schema") or {},
"is_disabled": tool["name"] in disabled,
})
return result
@@ -439,7 +596,11 @@ class McpManager:
for t in server_tools:
# Truncate long descriptions
desc = t['description'][:120] + '...' if len(t['description']) > 120 else t['description']
- lines.append(f" - {t['qualified_name']}: {desc}")
+ # Include the tool's declared inputs so the model calls it with
+ # real argument names instead of guessing from the description
+ # alone (issue #2509).
+ args_hint = _format_mcp_params(t.get("input_schema"))
+ lines.append(f" - {t['qualified_name']}: {desc}{args_hint}")
result = "\n".join(lines)
self._cached_prompt_desc = result
diff --git a/src/mcp_oauth.py b/src/mcp_oauth.py
new file mode 100644
index 0000000..9f3b2ad
--- /dev/null
+++ b/src/mcp_oauth.py
@@ -0,0 +1,193 @@
+"""mcp_oauth.py — generic OAuth for remote (Streamable HTTP) MCP servers.
+
+Bridges the mcp SDK's OAuthClientProvider (RFC 9728 discovery, Dynamic Client
+Registration, authorization-code + PKCE, token refresh) to Odysseus's web
+callback route. Tokens and the dynamic registration persist per-server,
+encrypted, so the interactive flow runs only once.
+"""
+import asyncio
+import json
+import logging
+import os
+import time
+from typing import Dict, Optional, Tuple
+from urllib.parse import urlparse, parse_qs
+
+logger = logging.getLogger(__name__)
+
+# OAuth redirect URI registered with every authorization server via DCR. Loopback
+# is allowed for native/desktop clients (RFC 8252); remote users finish via the
+# paste-back flow. Deployments not reachable at http://localhost:7000 (custom
+# port, reverse proxy, or public domain) must set OAUTH_REDIRECT_BASE_URL (or
+# APP_PUBLIC_URL) to their externally reachable origin so the redirect lands back
+# on Odysseus. APP_PORT is intentionally not used: it is only the Docker host
+# port-map; the app always listens on 7000 inside the container.
+_REDIRECT_BASE = (
+ os.environ.get("OAUTH_REDIRECT_BASE_URL")
+ or os.environ.get("APP_PUBLIC_URL")
+ or "http://localhost:7000"
+).rstrip("/")
+REDIRECT_URI = f"{_REDIRECT_BASE}/api/mcp/oauth/callback"
+
+# How long the background connect waits for the user to authorize before giving up.
+AUTH_WAIT_SECONDS = 300
+
+_pending: Dict[str, asyncio.Future] = {} # state -> Future[(code, state)]
+_pending_ts: Dict[str, float] = {} # state -> monotonic timestamp, for pruning
+_auth_urls: Dict[str, str] = {} # server_id -> authorization URL
+
+
+def _prune_stale() -> None:
+ """Drop abandoned flows whose authorization window has elapsed so the
+ module-level registries don't grow unbounded (e.g. a user who never
+ finishes the browser step)."""
+ now = time.monotonic()
+ for state in [s for s, ts in _pending_ts.items() if now - ts > AUTH_WAIT_SECONDS]:
+ fut = _pending.pop(state, None)
+ _pending_ts.pop(state, None)
+ if fut is not None and not fut.done():
+ fut.cancel()
+
+
+def _discard_pending(state: Optional[str]) -> None:
+ if state is None:
+ return
+ _pending.pop(state, None)
+ _pending_ts.pop(state, None)
+
+
+def register_pending(state: str) -> asyncio.Future:
+ _prune_stale()
+ fut = asyncio.get_running_loop().create_future()
+ _pending[state] = fut
+ _pending_ts[state] = time.monotonic()
+ return fut
+
+
+def resolve_pending(state: str, code: str) -> bool:
+ fut = _pending.get(state)
+ if fut is not None and not fut.done():
+ fut.set_result((code, state))
+ return True
+ return False
+
+
+def pop_auth_url(server_id: str) -> Optional[str]:
+ return _auth_urls.get(server_id)
+
+
+def clear_auth_url(server_id: str) -> None:
+ _auth_urls.pop(server_id, None)
+
+
+class DbTokenStorage:
+ """SDK TokenStorage backed by the encrypted McpServer.oauth_tokens column."""
+
+ def __init__(self, server_id: str, session_factory=None):
+ self.server_id = server_id
+ if session_factory is None:
+ from core.database import SessionLocal
+ session_factory = SessionLocal
+ self._sf = session_factory
+
+ def _load(self) -> dict:
+ from core.database import McpServer
+ db = self._sf()
+ try:
+ srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
+ if srv and srv.oauth_tokens:
+ return json.loads(srv.oauth_tokens)
+ finally:
+ db.close()
+ return {}
+
+ def _update(self, key: str, value: dict) -> None:
+ """Load, set one key, and persist the oauth_tokens JSON in a single
+ session/commit (avoids the load+save double round-trip per write)."""
+ from core.database import McpServer
+ db = self._sf()
+ try:
+ srv = db.query(McpServer).filter(McpServer.id == self.server_id).first()
+ if srv is None:
+ return
+ data = json.loads(srv.oauth_tokens) if srv.oauth_tokens else {}
+ data[key] = value
+ srv.oauth_tokens = json.dumps(data)
+ db.commit()
+ finally:
+ db.close()
+
+ async def get_tokens(self):
+ from mcp.shared.auth import OAuthToken
+ data = self._load().get("tokens")
+ return OAuthToken.model_validate(data) if data else None
+
+ async def set_tokens(self, tokens) -> None:
+ self._update("tokens", json.loads(tokens.model_dump_json()))
+
+ async def get_client_info(self):
+ from mcp.shared.auth import OAuthClientInformationFull
+ data = self._load().get("client_info")
+ return OAuthClientInformationFull.model_validate(data) if data else None
+
+ async def set_client_info(self, client_info) -> None:
+ self._update("client_info", json.loads(client_info.model_dump_json()))
+
+
+def build_provider(server_id: str, url: str, on_redirect=None):
+ """Construct an OAuthClientProvider that drives the browser flow via the
+ Odysseus callback route.
+
+ on_redirect(authorization_url): optional sync callback invoked the moment
+ the authorization URL is known (after discovery + DCR). The manager uses it
+ to publish 'needs_auth' + auth_url to connection state regardless of how
+ long discovery/DCR took.
+ """
+ from mcp.client.auth import OAuthClientProvider
+ from mcp.shared.auth import OAuthClientMetadata
+
+ client_metadata = OAuthClientMetadata(
+ client_name="Odysseus",
+ redirect_uris=[REDIRECT_URI],
+ grant_types=["authorization_code", "refresh_token"],
+ response_types=["code"],
+ # Leave scope unset: the SDK applies the MCP scope-selection strategy and
+ # overwrites this from the server's WWW-Authenticate / protected-resource
+ # metadata before building the auth URL. Hardcoding an OIDC scope here
+ # would break the many MCP servers that are not OpenID providers.
+ scope=None,
+ token_endpoint_auth_method="none",
+ )
+
+ async def redirect_handler(authorization_url: str) -> None:
+ state = (parse_qs(urlparse(authorization_url).query).get("state") or [None])[0]
+ if state:
+ register_pending(state)
+ _auth_urls[server_id] = authorization_url
+ if on_redirect is not None:
+ try:
+ on_redirect(authorization_url)
+ except Exception as e:
+ logger.warning(f"MCP OAuth on_redirect callback failed: {e}")
+ logger.info(f"MCP OAuth: server {server_id} awaiting authorization (state={state})")
+
+ async def callback_handler() -> Tuple[str, Optional[str]]:
+ auth_url = _auth_urls.get(server_id)
+ state = (parse_qs(urlparse(auth_url).query).get("state") or [None])[0] if auth_url else None
+ fut = _pending.get(state)
+ if fut is None:
+ raise RuntimeError("No pending OAuth flow for this server")
+ try:
+ code, ret_state = await asyncio.wait_for(fut, timeout=AUTH_WAIT_SECONDS)
+ return code, ret_state
+ finally:
+ _discard_pending(state)
+ _auth_urls.pop(server_id, None)
+
+ return OAuthClientProvider(
+ server_url=url,
+ client_metadata=client_metadata,
+ storage=DbTokenStorage(server_id),
+ redirect_handler=redirect_handler,
+ callback_handler=callback_handler,
+ )
diff --git a/src/memory_provider.py b/src/memory_provider.py
new file mode 100644
index 0000000..925c591
--- /dev/null
+++ b/src/memory_provider.py
@@ -0,0 +1,320 @@
+"""Memory provider interfaces for native and external memory systems."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Dict, Iterable, List, Optional
+
+
+@dataclass
+class MemoryRecord:
+ """Provider-neutral memory entry."""
+
+ id: str
+ text: str
+ timestamp: int = 0
+ category: str = "fact"
+ source: str = "unknown"
+ owner: Optional[str] = None
+ session_id: Optional[str] = None
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class MemorySearchHit:
+ """A memory returned by provider recall."""
+
+ memory: MemoryRecord
+ provider_id: str
+ score: Optional[float] = None
+
+
+class MemoryProvider(ABC):
+ """Base contract for Odysseus memory providers.
+
+ The native memory provider should always be available. External providers
+ can add recall/write behavior and their own tools without replacing the
+ built-in local memory baseline.
+ """
+
+ provider_id = "unknown"
+ display_name = "Unknown"
+ enabled = True
+
+ async def initialize(self) -> None:
+ """Prepare provider resources before use."""
+
+ async def shutdown(self) -> None:
+ """Release provider resources."""
+
+ @abstractmethod
+ async def remember(
+ self,
+ text: str,
+ *,
+ owner: Optional[str] = None,
+ session_id: Optional[str] = None,
+ category: str = "fact",
+ source: str = "user",
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> MemoryRecord:
+ """Store a memory and return the stored record."""
+
+ @abstractmethod
+ async def recall(
+ self,
+ query: str,
+ *,
+ owner: Optional[str] = None,
+ top_k: int = 5,
+ ) -> List[MemorySearchHit]:
+ """Return provider memories relevant to the query."""
+
+ @abstractmethod
+ async def list_memories(
+ self,
+ *,
+ owner: Optional[str] = None,
+ limit: int = 100,
+ ) -> List[MemoryRecord]:
+ """List memories visible to the owner."""
+
+ @abstractmethod
+ async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
+ """Delete a memory by ID when allowed by the provider."""
+
+ def get_tool_schemas(self) -> List[Dict[str, Any]]:
+ """Return provider-defined tool schemas when this provider is enabled."""
+ return []
+
+ async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
+ """Handle a provider-defined tool call."""
+ raise KeyError(f"Provider {self.provider_id} does not expose tool {name}")
+
+
+class NativeMemoryProvider(MemoryProvider):
+ """Provider adapter for Odysseus' built-in memory manager and vector store."""
+
+ provider_id = "native"
+ display_name = "Odysseus native memory"
+
+ _CORE_FIELDS = {
+ "id",
+ "text",
+ "timestamp",
+ "source",
+ "category",
+ "uses",
+ "owner",
+ "session_id",
+ "metadata",
+ }
+
+ def __init__(self, memory_manager, memory_vector=None):
+ self.memory_manager = memory_manager
+ self.memory_vector = memory_vector
+
+ def _to_record(self, entry: Dict[str, Any]) -> MemoryRecord:
+ metadata = {
+ key: value
+ for key, value in entry.items()
+ if key not in self._CORE_FIELDS
+ }
+ stored_metadata = entry.get("metadata")
+ if isinstance(stored_metadata, dict):
+ metadata.update(stored_metadata)
+
+ return MemoryRecord(
+ id=entry.get("id", ""),
+ text=entry.get("text", ""),
+ timestamp=entry.get("timestamp", 0),
+ category=entry.get("category", "fact"),
+ source=entry.get("source", "unknown"),
+ owner=entry.get("owner"),
+ session_id=entry.get("session_id"),
+ metadata=metadata,
+ )
+
+ async def remember(
+ self,
+ text: str,
+ *,
+ owner: Optional[str] = None,
+ session_id: Optional[str] = None,
+ category: str = "fact",
+ source: str = "user",
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> MemoryRecord:
+ entry = self.memory_manager.add_entry(
+ text,
+ source=source,
+ category=category,
+ owner=owner,
+ )
+ if session_id:
+ entry["session_id"] = session_id
+ if metadata:
+ entry["metadata"] = dict(metadata)
+
+ memories = self.memory_manager.load_all()
+ memories.append(entry)
+ self.memory_manager.save(memories)
+
+ if self._vector_available():
+ self.memory_vector.add(entry["id"], entry["text"])
+
+ return self._to_record(entry)
+
+ async def recall(
+ self,
+ query: str,
+ *,
+ owner: Optional[str] = None,
+ top_k: int = 5,
+ ) -> List[MemorySearchHit]:
+ memories = self.memory_manager.load(owner=owner)
+ by_id = {m.get("id"): m for m in memories}
+
+ if self._vector_available():
+ hits: List[MemorySearchHit] = []
+ for result in self.memory_vector.search(query, k=top_k):
+ if not isinstance(result, dict):
+ continue
+ memory_id = result.get("memory_id")
+ entry = by_id.get(memory_id) if memory_id else result
+ if not entry:
+ continue
+ if owner is not None and entry.get("owner") != owner:
+ continue
+ hits.append(
+ MemorySearchHit(
+ memory=self._to_record(entry),
+ provider_id=self.provider_id,
+ score=result.get("score"),
+ )
+ )
+ if hits:
+ return hits
+
+ fallback = self.memory_manager.get_relevant_memories(
+ query,
+ memories,
+ max_items=top_k,
+ )
+ return [
+ MemorySearchHit(
+ memory=self._to_record(entry),
+ provider_id=self.provider_id,
+ score=None,
+ )
+ for entry in fallback
+ ]
+
+ async def list_memories(
+ self,
+ *,
+ owner: Optional[str] = None,
+ limit: int = 100,
+ ) -> List[MemoryRecord]:
+ return [
+ self._to_record(entry)
+ for entry in self.memory_manager.load(owner=owner)[:limit]
+ ]
+
+ async def delete(self, memory_id: str, *, owner: Optional[str] = None) -> bool:
+ memories = self.memory_manager.load_all()
+ remaining = []
+ deleted_id = None
+
+ for entry in memories:
+ if entry.get("id") != memory_id:
+ remaining.append(entry)
+ continue
+ if owner is not None and entry.get("owner") != owner:
+ remaining.append(entry)
+ continue
+ deleted_id = entry.get("id")
+
+ if deleted_id is None:
+ return False
+
+ self.memory_manager.save(remaining)
+ if self._vector_available():
+ self.memory_vector.remove(deleted_id)
+ return True
+
+ def _vector_available(self) -> bool:
+ return bool(self.memory_vector and getattr(self.memory_vector, "healthy", True))
+
+
+class MemoryProviderRegistry:
+ """Container for native and optional external memory providers."""
+
+ def __init__(self, providers: Optional[Iterable[MemoryProvider]] = None):
+ self._providers: Dict[str, MemoryProvider] = {}
+ for provider in providers or []:
+ self.register(provider)
+
+ def register(self, provider: MemoryProvider) -> None:
+ if provider.provider_id in self._providers:
+ raise ValueError(f"Memory provider already registered: {provider.provider_id}")
+ self._providers[provider.provider_id] = provider
+
+ def get(self, provider_id: str) -> MemoryProvider:
+ return self._providers[provider_id]
+
+ def all(self) -> List[MemoryProvider]:
+ return list(self._providers.values())
+
+ def active(self) -> List[MemoryProvider]:
+ return [provider for provider in self._providers.values() if provider.enabled]
+
+ def get_tool_schemas(self) -> List[Dict[str, Any]]:
+ schemas: List[Dict[str, Any]] = []
+ seen: Dict[str, str] = {}
+
+ for provider in self.active():
+ for schema in provider.get_tool_schemas():
+ name = self._tool_name(schema)
+ if name in seen:
+ raise ValueError(
+ f"Memory tool name conflict: {name} from "
+ f"{provider.provider_id} already exposed by {seen[name]}"
+ )
+ seen[name] = provider.provider_id
+ schemas.append(schema)
+
+ return schemas
+
+ async def handle_tool_call(self, name: str, arguments: Dict[str, Any]) -> Any:
+ provider_by_tool: Dict[str, MemoryProvider] = {}
+ for provider in self.active():
+ for schema in provider.get_tool_schemas():
+ tool_name = self._tool_name(schema)
+ if tool_name in provider_by_tool:
+ raise ValueError(
+ f"Memory tool name conflict: {tool_name} from "
+ f"{provider.provider_id} already exposed by "
+ f"{provider_by_tool[tool_name].provider_id}"
+ )
+ provider_by_tool[tool_name] = provider
+
+ provider = provider_by_tool.get(name)
+ if provider:
+ return await provider.handle_tool_call(name, arguments)
+ raise KeyError(f"No active memory provider exposes tool {name}")
+
+ @staticmethod
+ def _tool_name(schema: Dict[str, Any]) -> str:
+ if not isinstance(schema, dict):
+ raise ValueError("Memory provider tool schema must be a dict")
+ name = schema.get("name")
+ if isinstance(name, str) and name:
+ return name
+ function = schema.get("function")
+ if isinstance(function, dict):
+ function_name = function.get("name")
+ if isinstance(function_name, str) and function_name:
+ return function_name
+ raise ValueError("Memory provider tool schema is missing a tool name")
diff --git a/src/model_context.py b/src/model_context.py
index c985d3d..3a445fe 100644
--- a/src/model_context.py
+++ b/src/model_context.py
@@ -7,7 +7,7 @@ Provides token estimation for context usage tracking.
import logging
import sys
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse
@@ -208,27 +208,32 @@ KNOWN_CONTEXT_WINDOWS = {
# ---------------------------------------------------------------------------
# Cache
# ---------------------------------------------------------------------------
-_context_cache: Dict[str, int] = {}
+_context_cache: Dict[Tuple[str, str], int] = {}
def get_context_length(endpoint_url: str, model: str) -> int:
"""Get the context window size for a model.
Queries /v1/models on the endpoint and looks for context_length
- or context_window fields. Caches result per model ID.
+ or context_window fields. Caches result per (endpoint, model).
Falls back to DEFAULT_CONTEXT if unavailable.
"""
configured_kind = _configured_endpoint_kind(endpoint_url)
is_local = _is_local_endpoint(endpoint_url)
- if not is_local and model in _context_cache:
- return _context_cache[model]
+ # Key on (endpoint_url, model): the same model id can be served by two
+ # different remote endpoints with different real context windows (e.g. a
+ # capped proxy vs. the full provider), so caching by model id alone would
+ # serve one endpoint's window for the other (issue #2603).
+ cache_key = (endpoint_url, model)
+ if not is_local and cache_key in _context_cache:
+ return _context_cache[cache_key]
ctx = _query_context_length(endpoint_url, model)
# Only cache non-default values to allow retry on next request.
# Local endpoints can restart with a different --max-model-len while keeping
# the same model id, so always re-query them instead of serving stale cache.
if not is_local and (ctx != DEFAULT_CONTEXT or configured_kind in ("api", "proxy")):
- _context_cache[model] = ctx
+ _context_cache[cache_key] = ctx
logger.info(f"Context length for {model}: {ctx}")
return ctx
@@ -282,6 +287,16 @@ def _query_context_length(endpoint_url: str, model: str) -> int:
except Exception:
pass
+ # GitHub Copilot's /models requires auth + X-GitHub-Api-Version headers that
+ # aren't available here; an unauthenticated probe just 400s. All Copilot
+ # picker models are major API models covered by the known-context table, so
+ # rely on that instead of a doomed network call.
+ from src.copilot import is_copilot_base
+ if is_copilot_base(endpoint_url):
+ if known:
+ logger.info(f"Using known context window for {model}: {known}")
+ return known or DEFAULT_CONTEXT
+
models_url = endpoint_url.replace("/chat/completions", "/models")
try:
r = httpx.get(models_url, timeout=REQUEST_TIMEOUT)
diff --git a/src/research_handler.py b/src/research_handler.py
index f5d7f83..bec9695 100644
--- a/src/research_handler.py
+++ b/src/research_handler.py
@@ -722,6 +722,18 @@ class ResearchHandler:
minimum=1,
maximum=12,
)
+ _planning_timeout = _bounded_int(
+ get_setting("research_planning_timeout_seconds", _extraction_timeout),
+ default=_extraction_timeout,
+ minimum=15,
+ maximum=3600,
+ )
+ _query_timeout = _bounded_int(
+ get_setting("research_query_timeout_seconds", _extraction_timeout),
+ default=_extraction_timeout,
+ minimum=15,
+ maximum=3600,
+ )
researcher = DeepResearcher(
llm_endpoint=llm_endpoint,
@@ -732,6 +744,8 @@ class ResearchHandler:
max_time=max_time,
max_report_tokens=_max_report_tokens,
extraction_timeout=_extraction_timeout,
+ planning_timeout=_planning_timeout,
+ query_timeout=_query_timeout,
extraction_concurrency=_extraction_concurrency,
progress_callback=progress_callback,
search_provider=search_provider,
diff --git a/src/search/analytics.py b/src/search/analytics.py
index 58aa1b0..93b8114 100644
--- a/src/search/analytics.py
+++ b/src/search/analytics.py
@@ -1,141 +1,12 @@
-"""Search analytics, metrics tracking, and exception hierarchy."""
+"""Compatibility re-export shim for the live analytics module.
-import json
-import logging
-from collections import Counter
-from pathlib import Path
-from typing import Dict, Any
+The real implementation lives in :mod:`services.search.analytics`, which is
+what the search runtime imports. Alias this module to that implementation so
+mutable module state such as ``ANALYTICS_FILE`` cannot drift out of sync.
+"""
-from .cache import cache_metrics
+import sys
-logger = logging.getLogger(__name__)
+from services.search import analytics as _analytics
-# Dedicated error logger with file handler
-_error_log_path = Path(__file__).resolve().parent.parent / "search_engine_error.log"
-_error_handler = logging.FileHandler(_error_log_path, encoding="utf-8")
-_error_handler.setLevel(logging.WARNING)
-_error_handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))
-error_logger = logging.getLogger("search_engine_error")
-error_logger.addHandler(_error_handler)
-error_logger.propagate = False
-
-# Analytics file
-ANALYTICS_FILE = Path(__file__).resolve().parent.parent / "search_analytics.json"
-
-
-# ----------------------------------------------------------------------
-# Custom exception hierarchy
-# ----------------------------------------------------------------------
-class SearchEngineError(Exception):
- """Base class for all search-engine related errors."""
-
-
-class NetworkError(SearchEngineError):
- """Raised when a network request fails (e.g., timeout, DNS error)."""
-
-
-class ParseError(SearchEngineError):
- """Raised when HTML or other content cannot be parsed."""
-
-
-class RateLimitError(SearchEngineError):
- """Raised when the remote service returns a rate-limit (HTTP 429)."""
-
-
-# ----------------------------------------------------------------------
-# Analytics helpers
-# ----------------------------------------------------------------------
-def _default_analytics() -> Dict[str, Any]:
- """A fresh analytics document with every counter present."""
- return {
- "total_queries": 0,
- "successful_queries": 0,
- "failed_queries": 0,
- "cache_hits": 0,
- "cache_misses": 0,
- "query_patterns": {},
- }
-
-
-def _load_analytics() -> Dict[str, Any]:
- """Load analytics data from the JSON file, creating defaults if missing."""
- if not ANALYTICS_FILE.exists():
- default = _default_analytics()
- _save_analytics(default)
- return default
- try:
- with open(ANALYTICS_FILE, "r", encoding="utf-8") as f:
- data = json.load(f)
- # Merge over defaults so a file written by an older schema (or a
- # partial write) still has every counter — _record_query indexes
- # these keys directly and would otherwise raise KeyError.
- merged = _default_analytics()
- if isinstance(data, dict):
- merged.update(data)
- return merged
- except Exception as e:
- logger.warning(f"Failed to load analytics file: {e}")
- return _default_analytics()
-
-
-def _save_analytics(data: Dict[str, Any]) -> None:
- """Persist analytics data to the JSON file."""
- try:
- with open(ANALYTICS_FILE, "w", encoding="utf-8") as f:
- json.dump(data, f, indent=2)
- except Exception as e:
- logger.warning(f"Failed to write analytics file: {e}")
-
-
-def _record_query(query: str, success: bool, cache_hit: bool) -> None:
- """Update analytics for a single query execution."""
- analytics = _load_analytics()
- analytics["total_queries"] += 1
- if success:
- analytics["successful_queries"] += 1
- else:
- analytics["failed_queries"] += 1
-
- if cache_hit:
- analytics["cache_hits"] += 1
- cache_metrics["hits"] += 1
- else:
- analytics["cache_misses"] += 1
- cache_metrics["misses"] += 1
-
- patterns = analytics["query_patterns"]
- entry = patterns.get(query, {"count": 0, "successes": 0})
- entry["count"] += 1
- if success:
- entry["successes"] += 1
- patterns[query] = entry
-
- _save_analytics(analytics)
-
-
-def get_search_stats() -> Dict[str, Any]:
- """Return aggregated search analytics."""
- analytics = _load_analytics()
- total = analytics.get("total_queries", 0) or 1
- success_rate = analytics.get("successful_queries", 0) / total
- cache_total = analytics.get("cache_hits", 0) + analytics.get("cache_misses", 0) or 1
- cache_hit_rate = analytics.get("cache_hits", 0) / cache_total
-
- pattern_counter = Counter({
- q: data["count"] for q, data in analytics.get("query_patterns", {}).items()
- })
- most_common = [q for q, _ in pattern_counter.most_common(5)]
-
- return {
- "most_common_queries": most_common,
- "success_rate": success_rate,
- "cache_hit_rate": cache_hit_rate,
- "total_queries": analytics.get("total_queries", 0),
- "successful_queries": analytics.get("successful_queries", 0),
- "failed_queries": analytics.get("failed_queries", 0),
- "cache_hits": analytics.get("cache_hits", 0),
- "cache_misses": analytics.get("cache_misses", 0),
- "cache_evictions": cache_metrics["evictions"],
- "runtime_cache_hits": cache_metrics["hits"],
- "runtime_cache_misses": cache_metrics["misses"],
- }
+sys.modules[__name__] = _analytics
diff --git a/src/search/cache.py b/src/search/cache.py
index 11fe722..e66aaff 100644
--- a/src/search/cache.py
+++ b/src/search/cache.py
@@ -1,57 +1,11 @@
-"""Search and content caching with LRU eviction."""
+"""Compatibility wrapper for the canonical services.search.cache module.
-import hashlib
-import logging
-from datetime import datetime, timedelta
-from pathlib import Path
-from typing import Dict
+``src.search.cache`` stays importable for older agent/deep-research code, but the
+implementation now lives in ``services.search.cache`` so the two cannot drift.
+"""
-logger = logging.getLogger(__name__)
+import sys
-# Cache directories
-CACHE_DIR = Path(__file__).resolve().parent.parent / "cache"
-SEARCH_CACHE_DIR = CACHE_DIR / "search"
-CONTENT_CACHE_DIR = CACHE_DIR / "content"
-CACHE_MAX_ENTRIES = 1000
+from services.search import cache as _cache
-# Create cache directories
-SEARCH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
-CONTENT_CACHE_DIR.mkdir(parents=True, exist_ok=True)
-
-# Track cache size for LRU eviction
-search_cache_index: Dict[str, datetime] = {}
-content_cache_index: Dict[str, datetime] = {}
-
-# Cache metrics (shared across modules)
-cache_metrics = {"hits": 0, "misses": 0, "evictions": 0}
-
-
-def generate_cache_key(data: str) -> str:
- """Generate a unique cache key using SHA-256 hash."""
- return hashlib.sha256(data.encode("utf-8")).hexdigest()
-
-
-def cleanup_cache(cache_dir: Path, cache_index: Dict[str, datetime], max_age: timedelta):
- """Remove expired cache entries and enforce LRU policy."""
- current_time = datetime.now()
- files_in_dir = {f.name.split(".")[0]: f for f in cache_dir.glob("*.cache")}
-
- to_remove = []
- for key, timestamp in list(cache_index.items()):
- if current_time - timestamp > max_age or key not in files_in_dir:
- to_remove.append(key)
- if key in files_in_dir:
- files_in_dir[key].unlink(missing_ok=True)
-
- for key in to_remove:
- cache_index.pop(key, None)
- cache_metrics["evictions"] += 1
-
- if len(cache_index) > CACHE_MAX_ENTRIES:
- sorted_items = sorted(cache_index.items(), key=lambda x: x[1])
- excess_count = len(cache_index) - CACHE_MAX_ENTRIES
- for key, _ in sorted_items[:excess_count]:
- cache_index.pop(key, None)
- cache_file = cache_dir / f"{key}.cache"
- cache_file.unlink(missing_ok=True)
- cache_metrics["evictions"] += 1
+sys.modules[__name__] = _cache
diff --git a/src/search/content.py b/src/search/content.py
index 42f8e34..971d4c2 100644
--- a/src/search/content.py
+++ b/src/search/content.py
@@ -1,419 +1,11 @@
-"""Webpage content fetching with caching, PDF extraction, and summarization helpers."""
+"""Compatibility wrapper for the canonical services.search.content module.
-import copy
-import io
-import ipaddress
-import json
-import os
-import re
-import logging
-import socket
-from datetime import datetime, timedelta
-from typing import List
-from urllib.parse import urljoin, urlparse
+``src.search.content`` stays importable for older agent/deep-research code, but the
+implementation now lives in ``services.search.content`` so the two cannot drift.
+"""
-import httpx
-from bs4 import BeautifulSoup
+import sys
-from .analytics import RateLimitError, error_logger
-from .cache import (
- CONTENT_CACHE_DIR,
- content_cache_index,
- generate_cache_key,
- cleanup_cache,
-)
+from services.search import content as _content
-logger = logging.getLogger(__name__)
-
-_PRIVATE_NETWORKS = (
- ipaddress.ip_network("0.0.0.0/8"),
- ipaddress.ip_network("10.0.0.0/8"),
- ipaddress.ip_network("127.0.0.0/8"),
- ipaddress.ip_network("169.254.0.0/16"),
- ipaddress.ip_network("172.16.0.0/12"),
- ipaddress.ip_network("192.168.0.0/16"),
- ipaddress.ip_network("::1/128"),
- ipaddress.ip_network("fc00::/7"),
- ipaddress.ip_network("fe80::/10"),
-)
-
-
-def _is_private_address(addr: ipaddress._BaseAddress) -> bool:
- if isinstance(addr, ipaddress.IPv6Address) and addr.ipv4_mapped is not None:
- addr = addr.ipv4_mapped
- return (
- addr.is_private
- or addr.is_loopback
- or addr.is_link_local
- or addr.is_reserved
- or addr.is_multicast
- or addr.is_unspecified
- or any(addr in net for net in _PRIVATE_NETWORKS)
- )
-
-
-def _resolve_hostname_ips(hostname: str) -> List[ipaddress._BaseAddress]:
- ips = []
- for family, _, _, _, sockaddr in socket.getaddrinfo(hostname, None):
- if family in (socket.AF_INET, socket.AF_INET6):
- ips.append(ipaddress.ip_address(sockaddr[0]))
- return ips
-
-
-def _public_http_url(url: str) -> bool:
- parsed = urlparse(url)
- if parsed.scheme not in ("http", "https") or not parsed.hostname:
- return False
- host = parsed.hostname.strip().lower()
- if host in ("localhost", "metadata.google.internal", "metadata"):
- return False
- if host.endswith((".local", ".localhost", ".internal", ".lan", ".intranet")):
- return False
- try:
- return not _is_private_address(ipaddress.ip_address(host))
- except ValueError:
- pass
- try:
- ips = _resolve_hostname_ips(host)
- except OSError:
- return False
- # Fail closed: a hostname that resolves to nothing is treated as
- # non-public (an empty all(...) would otherwise return True).
- return bool(ips) and all(not _is_private_address(ip) for ip in ips)
-
-
-def _get_public_url(url: str, *, headers: dict, timeout: int) -> httpx.Response:
- if not _public_http_url(url):
- raise httpx.RequestError(f"Blocked non-public URL: {url}")
-
- current = url
- with httpx.Client(headers=headers, timeout=timeout, follow_redirects=False) as client:
- for _ in range(8):
- response = client.get(current)
- if response.status_code not in (301, 302, 303, 307, 308):
- return response
- location = response.headers.get("location")
- if not location:
- return response
- current = urljoin(current, location)
- if not _public_http_url(current):
- raise httpx.RequestError(f"Blocked redirect to non-public URL: {current}")
- raise httpx.RequestError("Too many redirects")
-
-# PDF extraction (optional dependency)
-try:
- from pdfminer.high_level import extract_text as pdf_extract_text
-except ImportError:
- pdf_extract_text = None # type: ignore
-
-
-# ----------------------------------------------------------------------
-# HTML extraction helpers
-# ----------------------------------------------------------------------
-def _extract_meta(soup: BeautifulSoup) -> dict:
- """Pull meta description and keywords if present."""
- description = ""
- keywords = ""
- desc_tag = soup.find("meta", attrs={"name": re.compile("description", re.I)})
- if desc_tag and desc_tag.get("content"):
- description = desc_tag["content"].strip()
- kw_tag = soup.find("meta", attrs={"name": re.compile("keywords", re.I)})
- if kw_tag and kw_tag.get("content"):
- keywords = kw_tag["content"].strip()
- return {"description": description, "keywords": keywords}
-
-
-def _extract_og_image(soup: BeautifulSoup) -> str:
- """Extract the best representative image URL from meta tags.
-
- Only returns absolute http(s) URLs — skips relative paths and data URIs.
- """
- candidates = []
- # Open Graph image (most reliable)
- for prop in ("og:image", "og:image:url", "og:image:secure_url"):
- tag = soup.find("meta", attrs={"property": prop})
- if tag and tag.get("content", "").strip():
- candidates.append(tag["content"].strip())
- # Twitter card image
- tag = soup.find("meta", attrs={"name": "twitter:image"})
- if tag and tag.get("content", "").strip():
- candidates.append(tag["content"].strip())
- # Thumbnail meta
- tag = soup.find("meta", attrs={"name": "thumbnail"})
- if tag and tag.get("content", "").strip():
- candidates.append(tag["content"].strip())
- # Return first absolute http(s) URL
- for url in candidates:
- if url.startswith(("https://", "http://")) and not url.endswith((".svg", ".ico")):
- return url
- return ""
-
-
-def _extract_lists(soup: BeautifulSoup) -> List[List[str]]:
- """Return a list of lists, each inner list representing a <ul>/<ol>."""
- all_lists = []
- for lst in soup.find_all(["ul", "ol"]):
- items = [li.get_text(separator=" ", strip=True) for li in lst.find_all("li")]
- if items:
- all_lists.append(items)
- return all_lists
-
-
-def _extract_tables(soup: BeautifulSoup) -> List[List[List[str]]]:
- """Return a list of tables, each table is a list of rows, each row a list of cell texts."""
- tables_data = []
- for table in soup.find_all("table"):
- rows = []
- for tr in table.find_all("tr"):
- cells = [td.get_text(separator=" ", strip=True) for td in tr.find_all(["td", "th"])]
- if cells:
- rows.append(cells)
- if rows:
- tables_data.append(rows)
- return tables_data
-
-
-def _extract_code_blocks(soup: BeautifulSoup) -> List[str]:
- """Collect text from <pre> and <code> blocks."""
- blocks = []
- for tag in soup.find_all(["pre", "code"]):
- txt = tag.get_text(separator=" ", strip=True)
- if txt:
- blocks.append(txt)
- return blocks
-
-
-def _detect_js_frameworks(soup: BeautifulSoup) -> bool:
- """Very naive detection of common JS frameworks."""
- js_indicators = [
- "react", "angular", "vue", "svelte", "next", "nuxt",
- "ember", "backbone", "jquery", "polymer", "mithril",
- ]
- for script in soup.find_all("script"):
- src = script.get("src", "").lower()
- if any(fr in src for fr in js_indicators):
- return True
- if script.string:
- content = script.string.lower()
- if any(fr in content for fr in js_indicators):
- return True
- if soup.find(attrs={"data-reactroot": True}) or soup.find(attrs={"ng-app": True}):
- return True
- return False
-
-
-def _empty_result(url: str, error: str = "") -> dict:
- """Build a standard failure result dict."""
- return {
- "url": url,
- "title": "",
- "content": "",
- "lists": [],
- "tables": [],
- "code_blocks": [],
- "meta_description": "",
- "meta_keywords": "",
- "js_rendered": False,
- "js_message": "",
- "success": False,
- "error": error,
- }
-
-
-# ----------------------------------------------------------------------
-# Main content fetcher
-# ----------------------------------------------------------------------
-def fetch_webpage_content(url: str, timeout: int = 5, retry_attempt: int = 0) -> dict:
- """Fetch and extract meaningful content from a webpage with caching."""
- cache_key = generate_cache_key(url)
- cache_file = CONTENT_CACHE_DIR / f"{cache_key}.cache"
-
- # Check cache
- if cache_file.exists():
- try:
- with open(cache_file, "r", encoding="utf-8") as f:
- cached_data = json.load(f)
- timestamp = datetime.fromisoformat(cached_data["timestamp"])
- if datetime.now() - timestamp < timedelta(hours=2):
- logger.debug(f"Content cache hit for URL: {url}")
- return cached_data["data"]
- else:
- cache_file.unlink(missing_ok=True)
- content_cache_index.pop(cache_key, None)
- except Exception as e:
- logger.warning(f"Failed to read content cache for {url}: {e}")
- cache_file.unlink(missing_ok=True)
- content_cache_index.pop(cache_key, None)
-
- # Fetch
- try:
- headers = {
- "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36",
- "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
- "Accept-Language": "en-US,en;q=0.5",
- "Accept-Encoding": "gzip, deflate",
- "Connection": "keep-alive",
- }
- response = _get_public_url(url, headers=headers, timeout=timeout)
-
- if response.status_code == 429:
- raise RateLimitError(f"Rate limit hit for {url} (attempt {retry_attempt})")
-
- response.raise_for_status()
- except httpx.RequestError as e:
- error_logger.error(f"NetworkError fetching {url} (attempt {retry_attempt}): {e}")
- return _empty_result(url, f"NetworkError: {e}")
- except RateLimitError as e:
- error_logger.error(str(e))
- return _empty_result(url, str(e))
-
- # PDF handling
- content_type = response.headers.get("Content-Type", "").lower()
- if "application/pdf" in content_type or url.lower().endswith(".pdf"):
- if pdf_extract_text is None:
- logger.error("pdfminer.six is not installed; cannot extract PDF text.")
- pdf_text = ""
- else:
- try:
- pdf_bytes = io.BytesIO(response.content)
- pdf_text = pdf_extract_text(pdf_bytes)
- except Exception as e:
- logger.warning(f"PDF extraction failed for {url}: {e}")
- pdf_text = ""
- result = {
- "url": url,
- "title": os.path.basename(url),
- "content": pdf_text,
- "lists": [],
- "tables": [],
- "code_blocks": [],
- "meta_description": "",
- "meta_keywords": "",
- "js_rendered": False,
- "js_message": "",
- "success": bool(pdf_text),
- "error": "" if pdf_text else "Failed to extract PDF text",
- }
- _cache_result(cache_file, cache_key, result, url)
- return result
-
- # HTML handling
- try:
- soup = BeautifulSoup(response.text, "html.parser")
- except Exception as e:
- error_logger.error(f"ParseError parsing HTML from {url} (attempt {retry_attempt}): {e}")
- result = _empty_result(url, f"ParseError: {e}")
- _cache_result(cache_file, cache_key, result, url)
- return result
-
- title_tag = soup.find("title")
- title_text = title_tag.get_text(strip=True) if title_tag else ""
- meta_info = _extract_meta(soup)
- og_image = _extract_og_image(soup)
- js_rendered = _detect_js_frameworks(soup)
- js_message = "Page appears to be rendered by a JavaScript framework; content may be incomplete." if js_rendered else ""
-
- # Main textual content (heuristic): prefer semantic / "content"-classed
- # containers to skip nav/footer/boilerplate; tuned for article pages.
- main_content = ""
- content_areas = soup.find_all(
- ["main", "article", "section", "div"],
- class_=re.compile("content|main|body|article|post|entry|text", re.I),
- )
- if content_areas:
- for area in content_areas[:3]:
- main_content += area.get_text(separator=" ", strip=True) + " "
- main_content = re.sub(r"\s+", " ", main_content).strip()
-
- # The class heuristic can latch onto a small wrapper and miss the real
- # content (app/landing pages, or SSR sites whose body isn't in a
- # "content"-classed div, so these came back nearly empty before). When the
- # heuristic returns nothing OR suspiciously little, fall back to the full
- # <body>, stripping scripts/styles (so JSON/JS doesn't leak into the text)
- # plus nav/header/footer/aside (boilerplate), and keep whichever yields
- # more readable text.
- THIN_CONTENT_CHARS = 600 # below this the heuristic likely missed the page
- if len(main_content) < THIN_CONTENT_CHARS:
- body = soup.find("body")
- if body:
- # Strip from a copy so the later list/table/code extractors still
- # see the original soup unmodified.
- body_copy = copy.copy(body)
- for _noise in body_copy.find_all(
- ["script", "style", "noscript", "template", "nav", "header", "footer", "aside"]
- ):
- _noise.extract()
- body_text = re.sub(r"\s+", " ", body_copy.get_text(separator=" ", strip=True)).strip()
- if len(body_text) > len(main_content):
- main_content = body_text
-
- result = {
- "url": url,
- "title": title_text,
- "content": main_content,
- "lists": _extract_lists(soup),
- "tables": _extract_tables(soup),
- "code_blocks": _extract_code_blocks(soup),
- "meta_description": meta_info.get("description", ""),
- "meta_keywords": meta_info.get("keywords", ""),
- "og_image": og_image,
- "js_rendered": js_rendered,
- "js_message": js_message,
- "success": True,
- "error": "",
- }
- _cache_result(cache_file, cache_key, result, url)
- return result
-
-
-def _cache_result(cache_file, cache_key: str, result: dict, url: str):
- """Write a result to the content cache."""
- try:
- cache_data = {"timestamp": datetime.now().isoformat(), "data": result}
- with open(cache_file, "w", encoding="utf-8") as f:
- json.dump(cache_data, f)
- content_cache_index[cache_key] = datetime.now()
- cleanup_cache(CONTENT_CACHE_DIR, content_cache_index, timedelta(hours=2))
- except Exception as e:
- logger.warning(f"Failed to write content cache for {url}: {e}")
-
-
-# ----------------------------------------------------------------------
-# Content summarization helpers
-# ----------------------------------------------------------------------
-def extract_key_points(text: str) -> List[str]:
- """Pull out bullet-style key points from a block of text."""
- points: List[str] = []
- bullet_pat = re.compile(r"^\s*[-*•]\s+(.*)")
- numbered_pat = re.compile(r"^\s*\d+[\.\)]\s+(.*)")
- for line in text.splitlines():
- m = bullet_pat.match(line) or numbered_pat.match(line)
- if m:
- points.append(m.group(1).strip())
- return points
-
-
-def get_tldr(text: str, max_sentences: int = 3) -> str:
- """Produce a very short TL;DR by taking the first few sentences."""
- sentences = re.split(r"(?<=[.!?])\s+", text)
- selected = [s.strip() for s in sentences if s][:max_sentences]
- return " ".join(selected)
-
-
-def extract_quotes(text: str) -> List[str]:
- """Return quoted excerpts that are at least 15 characters long."""
- # Backreference the opening quote so the closing quote must match it —
- # otherwise `"text'` (open double, close single) is treated as a quote.
- return [m.group(2).strip() for m in re.finditer(r'(["\'])([^"\']{15,}?)\1', text)]
-
-
-def extract_statistics(text: str) -> List[str]:
- """Find numbers, percentages, dates and simple measurements."""
- # Match a comma-grouped number (1,000,000) OR a plain digit run (50000) —
- # the old `\d{1,3}(?:,\d{3})*` matched only the first 3 digits of a
- # comma-less number, and the trailing `\b` dropped a closing `%`.
- pattern = re.compile(
- r"\b(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?\s*(%|percent|‰|per cent|[a-zA-Z]+)?",
- re.IGNORECASE,
- )
- return [m.group(0).strip() for m in pattern.finditer(text)]
+sys.modules[__name__] = _content
diff --git a/src/search/query.py b/src/search/query.py
index 844d1b8..dc5299d 100644
--- a/src/search/query.py
+++ b/src/search/query.py
@@ -1,141 +1,11 @@
-"""Query enhancement, entity extraction, and cache duration helpers."""
+"""Compatibility wrapper for the canonical services.search.query module.
-import re
-import logging
-from datetime import timedelta
-from typing import Dict, List, Optional, Tuple
+``src.search.query`` stays importable for older agent/deep-research code, but the
+implementation now lives in ``services.search.query`` so the two cannot drift.
+"""
-logger = logging.getLogger(__name__)
+import sys
+from services.search import query as _query
-# ----------------------------------------------------------------------
-# Query processing helpers
-# ----------------------------------------------------------------------
-def _detect_question_type(query: str) -> Optional[str]:
- """Return the leading question word if present (who, what, when, where, why, how)."""
- if not isinstance(query, str):
- return None
- q = query.strip().lower()
- for word in ("who", "what", "when", "where", "why", "how"):
- # Require a whole-word match: a bare prefix mis-flags ordinary queries
- # like "whatsapp pricing" (-> what) or "however ..." (-> how), which
- # then get spurious boost terms OR-appended in enhance_query.
- if q == word or q.startswith(word + " "):
- return word
- return None
-
-
-def _extract_entities(query: str) -> Dict[str, List[str]]:
- """Lightweight entity extraction: capitalized words and date patterns."""
- entities: Dict[str, List[str]] = {"names": [], "dates": []}
- qtype = _detect_question_type(query)
- cleaned = query
- if qtype:
- cleaned = re.sub(rf"^{qtype}\b", "", cleaned, flags=re.I).strip()
- for token in re.findall(r"\b[A-Z][a-zA-Z]+\b", cleaned):
- entities["names"].append(token)
- for year in re.findall(r"\b(?:19|20)\d{2}\b", cleaned):
- entities["dates"].append(year)
- month_day_year = re.findall(
- r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b",
- cleaned,
- flags=re.I,
- )
- entities["dates"].extend(month_day_year)
- return entities
-
-
-def _split_multi_part(query: str) -> List[str]:
- """Split a query into sub-queries on common conjunctions."""
- if not isinstance(query, str):
- return []
- parts = re.split(r"\s+and\s+|\s+or\s+|;", query, flags=re.I)
- return [p.strip() for p in parts if p.strip()]
-
-
-def _extract_site_filter(query: str) -> Tuple[str, Optional[str]]:
- """Detect a 'site:example.com' token. Returns (query_without_token, site_or_None)."""
- if not isinstance(query, str):
- return "", None
- match = re.search(r"\bsite:([^\s]+)", query, flags=re.I)
- if match:
- site = match.group(1)
- new_query = re.sub(r"\bsite:[^\s]+", "", query, flags=re.I).strip()
- return new_query, site
- return query, None
-
-
-def _boost_entities_in_query(base_query: str, entities: Dict[str, List[str]]) -> str:
- """Append extracted entities to the query using OR to increase relevance."""
- parts = [base_query]
- if entities.get("names"):
- parts.append(" OR ".join(f'"{n}"' for n in entities["names"]))
- if entities.get("dates"):
- parts.append(" OR ".join(f'"{d}"' for d in entities["dates"]))
- return " ".join(parts)
-
-
-def enhance_query(original_query: str) -> Tuple[str, Optional[str]]:
- """Process the original query: site filter, question type boosts, entity extraction."""
- if not isinstance(original_query, str):
- original_query = ""
- query_without_site, site = _extract_site_filter(original_query)
- sub_queries = _split_multi_part(query_without_site)
-
- enhanced_subs: List[str] = []
- for sub in sub_queries:
- qtype = _detect_question_type(sub)
- boost_keywords = []
- if qtype == "who":
- boost_keywords.append("person")
- elif qtype == "when":
- boost_keywords.append("date")
- elif qtype == "where":
- boost_keywords.append("location")
- elif qtype == "why":
- boost_keywords.append("reason")
- elif qtype == "how":
- boost_keywords.append("method")
- entities = _extract_entities(sub)
- boosted = _boost_entities_in_query(sub, entities)
- if boost_keywords:
- boosted = f'({boosted}) OR ({" OR ".join(boost_keywords)})'
- enhanced_subs.append(boosted)
-
- final_query = " AND ".join(f"({s})" for s in enhanced_subs)
- if site:
- final_query = f"{final_query} site:{site}"
- return final_query, site
-
-
-def build_enhanced_query(query: str, time_filter: str = None) -> str:
- """Build an enhanced search query with optional time filtering."""
- enhanced_query, _ = enhance_query(query)
-
- if time_filter:
- time_map = {"day": "d", "week": "w", "month": "m", "year": "y"}
- if time_filter in time_map:
- enhanced_query = f"{enhanced_query} after:{time_map[time_filter]}"
- logger.info(f"Added time filter '{time_filter}' to query")
-
- logger.info(f"Enhanced query: '{query}' -> '{enhanced_query}'")
- return enhanced_query
-
-
-# ----------------------------------------------------------------------
-# Cache duration helpers
-# ----------------------------------------------------------------------
-def _is_news_query(query: str) -> bool:
- """Lightweight heuristic to decide if a query is news-oriented."""
- if not isinstance(query, str):
- return False
- news_terms = {"news", "latest", "breaking", "today", "today's", "current", "updates", "happening"}
- tokens = set(re.findall(r"\b\w+\b", query.lower()))
- return bool(tokens & news_terms)
-
-
-def _cache_duration_for_query(query: str) -> timedelta:
- """News queries -> 30 minutes, reference queries -> 24 hours."""
- if _is_news_query(query):
- return timedelta(minutes=30)
- return timedelta(hours=24)
+sys.modules[__name__] = _query
diff --git a/src/settings.py b/src/settings.py
index 09a53c9..5bce0fc 100644
--- a/src/settings.py
+++ b/src/settings.py
@@ -85,6 +85,11 @@ DEFAULT_SETTINGS = {
"research_search_provider": "",
"research_max_tokens": 16384,
"research_extraction_timeout_seconds": 90,
+ # Lightweight planning/query LLM calls happen before any search starts.
+ # Keep them separately tunable so slow local backends are not capped by
+ # the old 30s/60s per-call defaults.
+ "research_planning_timeout_seconds": 90,
+ "research_query_timeout_seconds": 90,
"research_extraction_concurrency": 3,
# Hard wall-clock cap on a single deep-research run. The previous 600s
# (10 min) default cut off slow local / edge LLMs mid-synthesis; 1800s
@@ -95,6 +100,7 @@ DEFAULT_SETTINGS = {
# Tune via Settings or by editing data/settings.json.
"research_run_timeout_seconds": 1800,
"agent_max_tool_calls": 0,
+ "agent_max_rounds": 20, # per-message agent step cap (clamped 1..200)
"agent_input_token_budget": 6000,
# Ceiling on the *auto-derived* input budget that #1230 introduced. Has
# no effect when `agent_input_token_budget` is explicitly set (the user's
diff --git a/src/text_helpers.py b/src/text_helpers.py
index 90d66a9..733ced0 100644
--- a/src/text_helpers.py
+++ b/src/text_helpers.py
@@ -15,18 +15,33 @@ from __future__ import annotations
import re
+_THINK_TAG_NAME = r"(?:think(?:ing)?|thought)"
+
# Closed reasoning blocks. Multi-pass loop in `strip_think` handles nested
# `<think>...</think>` patterns some models emit.
-_THINK_CLOSED_RE = re.compile(r"[\s\S]*?\s*", re.IGNORECASE)
+_THINK_CLOSED_RE = re.compile(rf"<{_THINK_TAG_NAME}(?:\s+[^>]*)?>[\s\S]*?{_THINK_TAG_NAME}>\s*", re.IGNORECASE)
# Orphan opening or closing tags that survive after the closed-pass.
-_THINK_TAG_RE = re.compile(r"?think(?:ing)?[^>]*>\s*", re.IGNORECASE)
+_THINK_TAG_RE = re.compile(rf"?{_THINK_TAG_NAME}[^>]*>\s*", re.IGNORECASE)
# Dangling opener anywhere in the response with no closer — strip everything
# from `<think>` to the end of string.
-_THINK_OPEN_RE = re.compile(r"[\s\S]*$", re.IGNORECASE)
+_THINK_OPEN_RE = re.compile(rf"<{_THINK_TAG_NAME}(?:\s+[^>]*)?>[\s\S]*$", re.IGNORECASE)
# Streaming models occasionally emit `<think ...>`-style attributes.
# Normalize to a plain `<think>` so the regexes above catch them.
-_THINK_ATTR_RE = re.compile(r"]*>", re.IGNORECASE)
-_THINK_ATTR_CLOSE_RE = re.compile(r"]*>", re.IGNORECASE)
+_THINK_ATTR_RE = re.compile(rf"<{_THINK_TAG_NAME}\s+[^>]*>", re.IGNORECASE)
+_THINK_ATTR_CLOSE_RE = re.compile(rf"{_THINK_TAG_NAME}\s+[^>]*>", re.IGNORECASE)
+_GEMMA_THOUGHT_OPEN_RE = re.compile(r"<\|channel>thought\s*\n?[\s\S]*$", re.IGNORECASE)
+_GEMMA_RESPONSE_CHANNEL_RE = re.compile(
+ r"<\|channel>response\s*\n?([\s\S]*?)",
+ re.IGNORECASE,
+)
+_GEMMA_RESPONSE_OPEN_RE = re.compile(r"<\|channel>response\s*\n?", re.IGNORECASE)
+_GEMMA_CHANNEL_CLOSE_RE = re.compile(r"", re.IGNORECASE)
+_THOUGHT_TAG_OPEN_RE = re.compile(r"]*)?>", re.IGNORECASE)
+_THOUGHT_TAG_CLOSE_RE = re.compile(r"", re.IGNORECASE)
+_GEMMA_THOUGHT_CHANNEL_CAPTURE_RE = re.compile(
+ r"<\|channel>thought\s*\n?([\s\S]*?)\s*",
+ re.IGNORECASE,
+)
# Qwen and a few other models prefix the response with a "Thinking Process:"
# block before the real answer.
_QWEN_THINKING_RE = re.compile(
@@ -78,6 +93,30 @@ def _strip_reasoning_prose(text: str) -> str:
return "\n\n".join(keep).strip() if keep else text
+def normalize_thinking_markup(text: str) -> str:
+ """Canonicalize supported thinking wrappers to `<think>` markup.
+
+ The chat UI and persistence layer already understand `<think>...</think>`.
+ Gemma 4 may instead emit `<|channel>thought\n...<channel|>`, and some
+ gateways/models emit `<thought>...</thought>`. Normalize those shapes into
+ the existing representation and strip empty thought channels.
+ """
+ if not text:
+ return text
+ out = _THOUGHT_TAG_OPEN_RE.sub(lambda m: "", text)
+ out = _THOUGHT_TAG_CLOSE_RE.sub("", out)
+
+ def _replace_gemma_thought(match: re.Match) -> str:
+ thought = match.group(1).strip()
+ return f"{thought}\n" if thought else ""
+
+ out = _GEMMA_THOUGHT_CHANNEL_CAPTURE_RE.sub(_replace_gemma_thought, out)
+ out = _GEMMA_RESPONSE_CHANNEL_RE.sub(lambda m: m.group(1), out)
+ out = _GEMMA_RESPONSE_OPEN_RE.sub("", out)
+ out = _GEMMA_CHANNEL_CLOSE_RE.sub("", out)
+ return out
+
+
def strip_think(text: str, *, prose: bool = False, prompt_echo: bool = True) -> str:
"""Strip `` blocks from model output.
@@ -92,13 +131,21 @@ def strip_think(text: str, *, prose: bool = False, prompt_echo: bool = True) ->
"The user asks:" / "We need to" leaked prompt echoes.
Robust to:
- * closed `...` (any depth, both `` and ``)
- * dangling unclosed `...`
+ * closed `...` (any depth, plus ``/``)
+ * dangling unclosed `...` / `...`
* stray opener/closer tags
* ``-style attributes
+ * Gemma 4 `<|channel>thought...` wrappers
"""
if not text:
return ""
+ # Gemma 4 thinking-capable models use channel control tokens rather than
+ # XML tags when the runtime does not split reasoning into a separate field.
+ # The thought channel can be empty in non-thinking mode; either way it is
+ # not user-facing content. A response channel, when present, is only a
+ # wrapper around the final answer.
+ text = normalize_thinking_markup(text)
+ text = _GEMMA_THOUGHT_OPEN_RE.sub("", text)
# Normalize attributes so the closed/open regexes can catch them.
text = _THINK_ATTR_RE.sub("", text)
text = _THINK_ATTR_CLOSE_RE.sub("", text)
diff --git a/src/tool_execution.py b/src/tool_execution.py
index c43fca9..e84a414 100644
--- a/src/tool_execution.py
+++ b/src/tool_execution.py
@@ -12,14 +12,127 @@ import collections
import json
import logging
import os
+import pathlib
import sys
import time
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple
from src.tool_security import is_public_blocked_tool, owner_is_admin_or_single_user
+# Persistent working directory for agent subprocesses.
+# Resolves to /data, which is the bind-mounted volume in Docker
+# (/app/data) and the local data directory for manual installs.
+# Using this as cwd and HOME prevents the agent from silently creating files
+# in ephemeral container layers that are lost on the next rebuild.
+_AGENT_WORKDIR = str(pathlib.Path(__file__).parent.parent / "data")
+
MAX_OUTPUT_CHARS = 10_000
MAX_READ_CHARS = 20_000
+MAX_DIFF_LINES = 400 # cap unified-diff size returned to the UI
+
+
+def _unified_diff(old: str, new: str, path: str) -> Optional[Dict[str, Any]]:
+    """Build a unified diff of a file write for display in the chat.
+
+    Returns {"text": diff text, "added": N, "removed": M, "new_file": bool,
+    "file": basename of path}, or None when there is no line-level change to
+    show. Diffs longer than MAX_DIFF_LINES are truncated; the added/removed
+    counts always describe the full diff.
+    """
+    if old == new:
+        return None
+    import difflib
+
+    # Fall back to a generic label so an empty/None path cannot break the
+    # headers or the basename lookup below.
+    label = path or "file"
+    diff_lines = list(difflib.unified_diff(
+        old.splitlines(), new.splitlines(),
+        fromfile=f"a/{label}", tofile=f"b/{label}",
+        lineterm="",
+    ))
+    if not diff_lines:
+        # Only line endings / the final newline differ; splitlines() hides
+        # that, so there is nothing to render.
+        return None
+    # Entries 0 and 1 are the "---"/"+++" file headers. Count only what follows
+    # them, so content that itself begins with "--" or "++" is still counted.
+    hunk_lines = diff_lines[2:]
+    added = sum(1 for ln in hunk_lines if ln.startswith("+"))
+    removed = sum(1 for ln in hunk_lines if ln.startswith("-"))
+    text = "\n".join(diff_lines[:MAX_DIFF_LINES])
+    if len(diff_lines) > MAX_DIFF_LINES:
+        text += f"\n… diff truncated at {MAX_DIFF_LINES} lines"
+    return {
+        "text": text,
+        "added": added,
+        "removed": removed,
+        "new_file": old == "",
+        "file": os.path.basename(label) or label,
+    }
+
+
+async def _do_edit_file(content: str, workspace: Optional[str] = None) -> Dict[str, Any]:
+    """Exact string-replacement edit of an on-disk file.
+
+    content is JSON: {"path", "old_string", "new_string", "replace_all"?}.
+    Fails if old_string is missing or non-unique (unless replace_all) so the
+    model can't silently edit the wrong place. Confined to the workspace when
+    one is set (same policy as write_file).
+
+    Returns {"output", "exit_code": 0} plus a "diff" entry when one can be
+    built, or {"error", "exit_code": 1} on any failure. Nothing is written
+    unless the match is unambiguous.
+    """
+    try:
+        args = json.loads(content) if (content or "").strip().startswith("{") else {}
+    except (json.JSONDecodeError, TypeError):
+        args = {}
+    raw_path = args.get("path") or ""
+    old = args.get("old_string", "")
+    new = args.get("new_string", "")
+    replace_all = bool(args.get("replace_all", False))
+    # The arguments are model-supplied JSON: reject wrong types up front rather
+    # than crash later in .strip() / str.count() with an uncaught exception.
+    if not isinstance(raw_path, str) or not raw_path.strip():
+        return {"error": "edit_file: path required", "exit_code": 1}
+    if not isinstance(old, str) or not isinstance(new, str):
+        return {"error": "edit_file: old_string and new_string must be strings", "exit_code": 1}
+    raw_path = raw_path.strip()
+    # Confine to the workspace when set, else the same allowlist + sensitive-file
+    # policy as read/write_file.
+    try:
+        path = (_resolve_tool_path_in_workspace(workspace, raw_path)
+                if workspace else _resolve_tool_path(raw_path))
+    except ValueError as e:
+        return {"error": f"edit_file: {e}", "exit_code": 1}
+    if old == "":
+        return {"error": "edit_file: old_string required (use write_file to create a file)", "exit_code": 1}
+    if old == new:
+        return {"error": "edit_file: old_string and new_string are identical", "exit_code": 1}
+
+    def _apply():
+        # Blocking file I/O, run in a worker thread. Returns
+        # (original, updated_or_None, status); writes only on status "ok".
+        with open(path, "r", encoding="utf-8") as f:
+            original = f.read()
+        count = original.count(old)
+        if count == 0:
+            return original, None, "not_found"
+        if count > 1 and not replace_all:
+            return original, None, f"not_unique:{count}"
+        updated = original.replace(old, new) if replace_all else original.replace(old, new, 1)
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(updated)
+        return original, updated, "ok"
+
+    try:
+        original, updated, status = await asyncio.to_thread(_apply)
+    except FileNotFoundError:
+        return {"error": f"edit_file: {path}: not found (use write_file to create it)", "exit_code": 1}
+    except (IsADirectoryError, UnicodeDecodeError):
+        return {"error": f"edit_file: {path}: not an editable text file", "exit_code": 1}
+    except PermissionError:
+        return {"error": f"edit_file: {path}: permission denied", "exit_code": 1}
+    except OSError as e:
+        return {"error": f"edit_file: {path}: {e}", "exit_code": 1}
+
+    if status == "not_found":
+        return {"error": f"edit_file: old_string not found in {path}. Read the file and match it exactly.", "exit_code": 1}
+    if status.startswith("not_unique"):
+        n = status.split(":", 1)[1]
+        return {"error": f"edit_file: old_string is not unique in {path} ({n} matches). Add surrounding context or set replace_all=true.", "exit_code": 1}
+
+    # Without replace_all the match was unique, so this is 1; with it, every
+    # occurrence was replaced.
+    n = original.count(old)
+    result = {"output": f"Edited {path} ({n} replacement{'s' if n != 1 else ''})", "exit_code": 0}
+    diff = _unified_diff(original, updated, path)
+    if diff:
+        result["diff"] = diff
+    return result
# ---------------------------------------------------------------------------
# Path confinement for read_file / write_file
@@ -158,6 +271,40 @@ def _resolve_tool_path(raw_path: str) -> str:
f"path '{raw_path}' is outside the allowed roots"
)
+
+def _resolve_tool_path_in_workspace(workspace: str, raw_path: str) -> str:
+    """Resolve a model-supplied path, confined to the active workspace.
+
+    Relative paths resolve under the workspace. A path that resolves outside
+    it, or that hits the sensitive-file deny list, raises ValueError; so does
+    an empty path. Without a workspace, callers use _resolve_tool_path (the
+    default data/tmp allowlist) instead.
+    """
+    requested = "" if raw_path is None else str(raw_path).strip()
+    if not requested:
+        raise ValueError("path is required")
+    root = os.path.realpath(workspace)
+    target = os.path.expanduser(requested)
+    if not os.path.isabs(target):
+        target = os.path.join(root, target)
+    resolved = os.path.realpath(target)
+    if _is_sensitive_path(resolved):
+        raise ValueError(
+            f"path '{raw_path}' is inside a sensitive directory "
+            f"(e.g. .ssh, .gnupg) or matches a sensitive filename"
+        )
+    if resolved == root:
+        return resolved
+    # normcase keeps the containment test correct on case-insensitive
+    # filesystems (it lowercases on Windows, no-op on POSIX). commonpath raises
+    # ValueError across Windows drives or for mixed abs/rel input — both mean
+    # "outside", so that case is rejected as well.
+    norm_root = os.path.normcase(root)
+    try:
+        contained = os.path.commonpath([os.path.normcase(resolved), norm_root]) == norm_root
+    except ValueError:
+        contained = False
+    if not contained:
+        raise ValueError(f"path '{raw_path}' is outside the workspace ({workspace})")
+    return resolved
+
# Bash + python tools used to share a single 60s timeout. That's
# enough for one-shot commands but starves real workloads (pip
# install, ffmpeg conversions, etc.) — and worse, the agent saw the
@@ -186,6 +333,39 @@ def get_mcp_manager():
return agent_tools.get_mcp_manager()
+# Directories ignored by the code-nav tools' Python fallbacks so results aren't
+# polluted by VCS internals / dependency trees / build caches. ripgrep already
+# honours .gitignore; this is the parity floor for the no-rg path (and the
+# explicit excludes passed to rg so it skips them even without a .gitignore).
+_CODENAV_SKIP_DIRS = frozenset({
+ ".git", ".hg", ".svn", "node_modules", "venv", ".venv", "__pycache__",
+ ".mypy_cache", ".pytest_cache", ".ruff_cache", "dist", "build",
+ ".next", ".cache", "site-packages", ".idea", ".tox",
+})
+# Per-tool result caps (keep tool output cheap + model-friendly).
+_CODENAV_MAX_HITS = 200
+_CODENAV_MAX_LINE = 400
+
+
+def _resolve_search_root(raw_path: str, workspace: Optional[str] = None) -> str:
+    """Resolve + confine the path a code-nav tool (grep/glob/ls) should use.
+
+    A supplied path goes through the workspace resolver when a workspace is
+    set, otherwise through _resolve_tool_path. An empty path defaults to the
+    workspace folder, else to the first _tool_path_roots() entry, else to the
+    current directory.
+    """
+    requested = (raw_path or "").strip()
+    if requested:
+        if workspace:
+            return _resolve_tool_path_in_workspace(workspace, requested)
+        return _resolve_tool_path(requested)
+    if workspace:
+        return os.path.realpath(workspace)
+    roots = _tool_path_roots()
+    if roots:
+        return roots[0]
+    return os.path.realpath(".")
+
+
def _truncate(text: str, limit: int = MAX_OUTPUT_CHARS) -> str:
if len(text) > limit:
return text[:limit] + f"\n... (truncated, {len(text)} chars total)"
@@ -396,11 +576,12 @@ async def _call_mcp_tool(
tool: str,
content: str,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
+ workspace: Optional[str] = None,
) -> Dict:
"""Route a legacy tool call through the MCP manager, with direct fallbacks."""
mcp = get_mcp_manager()
if not mcp:
- return await _direct_fallback(tool, content, progress_cb=progress_cb) or {"error": f"MCP manager not available for tool '{tool}'", "exit_code": 1}
+ return await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace) or {"error": f"MCP manager not available for tool '{tool}'", "exit_code": 1}
server_id, tool_name = _MCP_TOOL_MAP[tool]
qualified = f"mcp__{server_id}__{tool_name}"
@@ -409,7 +590,7 @@ async def _call_mcp_tool(
# If MCP server not connected, try direct fallback
if isinstance(result, dict) and result.get("exit_code") == 1 and "not connected" in result.get("error", ""):
- fallback = await _direct_fallback(tool, content, progress_cb=progress_cb)
+ fallback = await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace)
if fallback:
return fallback
@@ -436,6 +617,7 @@ async def _direct_fallback(
tool: str,
content: str,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
+ workspace: Optional[str] = None,
) -> Optional[Dict]:
"""In-process execution path for the eight tools that used to live as
stdio MCP servers under mcp_servers/. Those servers were deleted in
@@ -461,6 +643,7 @@ async def _direct_fallback(
"TERM": "xterm-256color",
"COLUMNS": "120",
"LINES": "40",
+ "HOME": _AGENT_WORKDIR,
}
try:
@@ -470,6 +653,7 @@ async def _direct_fallback(
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
+ cwd=workspace or _AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
@@ -496,6 +680,7 @@ async def _direct_fallback(
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=_subproc_env,
+ cwd=workspace or _AGENT_WORKDIR,
)
stdout, stderr, rc, timed_out = await _run_subprocess_streaming(
proc,
@@ -512,14 +697,43 @@ async def _direct_fallback(
return {"output": output or "(no output)", "exit_code": rc or 0}
if tool == "read_file":
- raw_path = content.split("\n", 1)[0].strip()
+ # Args: plain path on line 1 (back-compat) OR JSON
+ # {path, offset?, limit?} where offset/limit are a 1-based line range.
+ raw_path, offset, limit = content.split("\n", 1)[0].strip(), 0, 0
+ _stripped = content.strip()
+ if _stripped.startswith("{"):
+ try:
+ _a = _json.loads(_stripped)
+ raw_path = str(_a.get("path", "")).strip()
+ offset = int(_a.get("offset") or 0)
+ limit = int(_a.get("limit") or 0)
+ except (_json.JSONDecodeError, TypeError, ValueError):
+ pass
try:
- path = _resolve_tool_path(raw_path)
+ path = (_resolve_tool_path_in_workspace(workspace, raw_path)
+ if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"read_file: {e}", "exit_code": 1}
try:
- # Run blocking read in a thread to keep the loop responsive
+ # Run blocking read in a thread to keep the loop responsive.
def _read():
+ if offset > 0 or limit > 0:
+ # Line-range read: slice [offset, offset+limit).
+ start = max(offset, 1)
+ out, n, budget = [], 0, MAX_READ_CHARS
+ with open(path, "r", encoding="utf-8", errors="replace") as f:
+ for i, line in enumerate(f, 1):
+ if i < start:
+ continue
+ if limit > 0 and n >= limit:
+ break
+ out.append(line)
+ n += 1
+ budget -= len(line)
+ if budget <= 0:
+ out.append(f"\n... [truncated at {MAX_READ_CHARS} chars]")
+ break
+ return "".join(out)
with open(path, "r", encoding="utf-8", errors="replace") as f:
return f.read(MAX_READ_CHARS + 1)
data = await asyncio.to_thread(_read)
@@ -527,10 +741,11 @@ async def _direct_fallback(
return {"error": f"read_file: {path}: not found", "exit_code": 1}
except PermissionError:
return {"error": f"read_file: {path}: permission denied", "exit_code": 1}
+ except IsADirectoryError:
+ return {"error": f"read_file: {path}: is a directory (use ls)", "exit_code": 1}
except OSError as e:
return {"error": f"read_file: {path}: {e}", "exit_code": 1}
- truncated = len(data) > MAX_READ_CHARS
- if truncated:
+ if not (offset > 0 or limit > 0) and len(data) > MAX_READ_CHARS:
data = data[:MAX_READ_CHARS] + f"\n... [truncated at {MAX_READ_CHARS} chars]"
return {"output": data, "exit_code": 0}
@@ -539,23 +754,226 @@ async def _direct_fallback(
raw_path = lines[0].strip()
body = lines[1] if len(lines) > 1 else ""
try:
- path = _resolve_tool_path(raw_path)
+ path = (_resolve_tool_path_in_workspace(workspace, raw_path)
+ if workspace else _resolve_tool_path(raw_path))
except ValueError as e:
return {"error": f"write_file: {e}", "exit_code": 1}
try:
def _write():
+ # Capture prior content (best-effort, text) so we can show a
+ # before/after diff. Missing/binary file → treat as empty.
+ old = ""
+ try:
+ with open(path, "r", encoding="utf-8") as f:
+ old = f.read()
+ except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
+ old = ""
d = os.path.dirname(path)
if d:
os.makedirs(d, exist_ok=True)
with open(path, "w", encoding="utf-8") as f:
f.write(body)
- return len(body)
- size = await asyncio.to_thread(_write)
+ return old, len(body)
+ old_content, size = await asyncio.to_thread(_write)
except PermissionError:
return {"error": f"write_file: {path}: permission denied", "exit_code": 1}
except OSError as e:
return {"error": f"write_file: {path}: {e}", "exit_code": 1}
- return {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
+ diff = _unified_diff(old_content, body, path)
+ result = {"output": f"Wrote {size} bytes to {path}", "exit_code": 0}
+ if diff:
+ result["diff"] = diff
+ return result
+
+ if tool == "grep":
+ # Args (JSON): {pattern, path?, glob?, ignore_case?, max_results?}.
+ # Bare string → treated as the pattern.
+ args: Dict[str, Any] = {}
+ _s = (content or "").strip()
+ if _s.startswith("{"):
+ try:
+ args = _json.loads(_s)
+ except _json.JSONDecodeError:
+ args = {}
+ else:
+ args = {"pattern": _s}
+ pattern = str(args.get("pattern", "")).strip()
+ if not pattern:
+ return {"error": "grep: pattern is required", "exit_code": 1}
+ ignore_case = bool(args.get("ignore_case"))
+ glob_pat = str(args.get("glob", "") or "").strip()
+ try:
+ max_hits = int(args.get("max_results") or _CODENAV_MAX_HITS)
+ except (TypeError, ValueError):
+ max_hits = _CODENAV_MAX_HITS
+ max_hits = max(1, min(max_hits, _CODENAV_MAX_HITS))
+ try:
+ root = _resolve_search_root(str(args.get("path", "")), workspace)
+ except ValueError as e:
+ return {"error": f"grep: {e}", "exit_code": 1}
+
+        def _grep():
+            """Search `root` for `pattern`; returns (lines, None) or (None, error).
+
+            Blocking, so it is run in a worker thread. Uses ripgrep when it is on
+            PATH, otherwise walks the tree and matches with Python's `re`.
+            """
+            import re as _re
+            import shutil
+            import subprocess
+            rg = shutil.which("rg")
+            if rg:
+                cmd = [rg, "--line-number", "--no-heading", "--color=never",
+                       "--max-count", str(max_hits)]
+                if ignore_case:
+                    cmd.append("--ignore-case")
+                if glob_pat:
+                    cmd += ["--glob", glob_pat]
+                # Exclude junk dirs even when the tree has no .gitignore, so
+                # results match the Python fallback's skip set.
+                for _d in _CODENAV_SKIP_DIRS:
+                    cmd += ["--glob", f"!**/{_d}/**"]
+                cmd += ["--regexp", pattern, root]
+                try:
+                    # errors="replace": one undecodable byte in a matching line
+                    # must not fail the whole search.
+                    p = subprocess.run(cmd, capture_output=True, text=True,
+                                       errors="replace", timeout=20)
+                except subprocess.TimeoutExpired:
+                    return None, "grep: timed out"
+                except Exception as _e:
+                    return None, f"grep: {_e}"
+                lines = [ln for ln in (p.stdout or "").splitlines() if ln][:max_hits]
+                # rg exits 0 on matches, 1 on no matches and 2 on errors (bad
+                # regex, unreadable root). Surface a failure that produced no
+                # results instead of reporting it as "no matches".
+                if p.returncode not in (0, 1) and not lines:
+                    detail = (p.stderr or "").strip().splitlines()
+                    reason = detail[0] if detail else f"ripgrep exited with status {p.returncode}"
+                    return None, f"grep: {reason[:_CODENAV_MAX_LINE]}"
+                return lines, None
+            # Python fallback (no ripgrep): walk + regex.
+            try:
+                rx = _re.compile(pattern, _re.IGNORECASE if ignore_case else 0)
+            except _re.error as _e:
+                return None, f"grep: bad pattern: {_e}"
+            import fnmatch
+            hits = []
+            if os.path.isfile(root):
+                file_iter = [root]
+            else:
+                file_iter = []
+                for dp, dns, fns in os.walk(root):
+                    # Prune in place so os.walk never descends into skipped dirs.
+                    dns[:] = [d for d in dns if d not in _CODENAV_SKIP_DIRS]
+                    for fn in fns:
+                        if glob_pat and not fnmatch.fnmatch(fn, glob_pat):
+                            continue
+                        file_iter.append(os.path.join(dp, fn))
+            for fp in file_iter:
+                if len(hits) >= max_hits:
+                    break
+                try:
+                    with open(fp, "r", encoding="utf-8", errors="strict") as f:
+                        for i, line in enumerate(f, 1):
+                            if rx.search(line):
+                                hits.append(f"{fp}:{i}:{line.rstrip()[:_CODENAV_MAX_LINE]}")
+                                if len(hits) >= max_hits:
+                                    break
+                except (UnicodeDecodeError, OSError):
+                    continue  # skip binary / unreadable
+            return hits, None
+
+ lines, err = await asyncio.to_thread(_grep)
+ if err:
+ return {"error": err, "exit_code": 1}
+ if not lines:
+ return {"output": f"No matches for {pattern!r} under {root}", "exit_code": 0}
+ out = "\n".join(ln[:_CODENAV_MAX_LINE] for ln in lines)
+ if len(lines) >= max_hits:
+ out += f"\n... [capped at {max_hits} matches]"
+ return {"output": _truncate(out), "exit_code": 0}
+
+ if tool == "glob":
+ args = {}
+ _s = (content or "").strip()
+ if _s.startswith("{"):
+ try:
+ args = _json.loads(_s)
+ except _json.JSONDecodeError:
+ args = {}
+ else:
+ args = {"pattern": _s}
+ pattern = str(args.get("pattern", "")).strip()
+ if not pattern:
+ return {"error": "glob: pattern is required", "exit_code": 1}
+ try:
+ root = _resolve_search_root(str(args.get("path", "")), workspace)
+ except ValueError as e:
+ return {"error": f"glob: {e}", "exit_code": 1}
+
+        def _glob():
+            """Collect paths under `root` matching `pattern`, newest first.
+
+            Returns (paths, None) or (None, error). Blocking, so it is run in a
+            worker thread. At most _CODENAV_MAX_HITS paths are returned.
+            """
+            from pathlib import Path, PurePath
+            base = Path(root)
+            if not base.is_dir():
+                return None, f"glob: {root}: not a directory"
+            # `root` is already confined, but the pattern is model-supplied too:
+            # an anchored pattern makes rglob raise NotImplementedError, and a
+            # ".." component would let the match walk out of the confined root.
+            pat = PurePath(pattern)
+            if pat.anchor or ".." in pat.parts:
+                return None, "glob: pattern must be relative and must not contain '..'"
+            matched = []
+            try:
+                for p in base.rglob(pattern):
+                    if set(p.relative_to(base).parts) & _CODENAV_SKIP_DIRS:
+                        continue
+                    try:
+                        mtime = p.stat().st_mtime
+                    except OSError:
+                        mtime = 0
+                    matched.append((mtime, str(p)))
+                    # Bound the scan; the newest-first cut below is taken from
+                    # whatever was collected up to this point.
+                    if len(matched) > _CODENAV_MAX_HITS * 5:
+                        break
+            except (OSError, ValueError, NotImplementedError) as _e:
+                return None, f"glob: {_e}"
+            matched.sort(key=lambda t: t[0], reverse=True)  # newest first
+            return [pth for _, pth in matched[:_CODENAV_MAX_HITS]], None
+
+ paths, err = await asyncio.to_thread(_glob)
+ if err:
+ return {"error": err, "exit_code": 1}
+ if not paths:
+ return {"output": f"No files matching {pattern!r} under {root}", "exit_code": 0}
+ out = "\n".join(paths)
+ if len(paths) >= _CODENAV_MAX_HITS:
+ out += f"\n... [capped at {_CODENAV_MAX_HITS} files]"
+ return {"output": _truncate(out), "exit_code": 0}
+
+ if tool == "ls":
+ raw_path = ""
+ _s = (content or "").strip()
+ if _s.startswith("{"):
+ try:
+ raw_path = str(_json.loads(_s).get("path", "")).strip()
+ except _json.JSONDecodeError:
+ raw_path = ""
+ else:
+ raw_path = _s.split("\n", 1)[0].strip()
+ try:
+ root = _resolve_search_root(raw_path, workspace)
+ except ValueError as e:
+ return {"error": f"ls: {e}", "exit_code": 1}
+
+        def _ls():
+            """List `root`'s non-hidden entries: directories first, then files.
+
+            Returns (text, None) or (None, error). Entries that cannot be stat'ed
+            are skipped; output is capped at _CODENAV_MAX_HITS rows.
+            """
+            if not os.path.isdir(root):
+                return None, f"ls: {root}: not a directory"
+            entries = []
+            try:
+                with os.scandir(root) as scan:
+                    for item in scan:
+                        if item.name.startswith("."):
+                            continue
+                        try:
+                            if item.is_dir(follow_symlinks=False):
+                                row = (True, item.name, 0)
+                            else:
+                                row = (False, item.name, item.stat(follow_symlinks=False).st_size)
+                        except OSError:
+                            continue
+                        entries.append(row)
+            except OSError as _e:  # PermissionError is an OSError subclass
+                return None, f"ls: {_e}"
+            # Directories first, then case-insensitive by name.
+            entries.sort(key=lambda row: (not row[0], row[1].lower()))
+            shown = entries[:_CODENAV_MAX_HITS]
+            listing = [f"{root}:"]
+            listing.extend(
+                (f" {name}/" if is_dir else f" {name} ({size} B)")
+                for is_dir, name, size in shown
+            )
+            if len(entries) > _CODENAV_MAX_HITS:
+                listing.append(f" ... [{len(entries) - _CODENAV_MAX_HITS} more]")
+            if not entries:
+                listing.append(" (empty)")
+            return "\n".join(listing), None
+
+ out, err = await asyncio.to_thread(_ls)
+ if err:
+ return {"error": err, "exit_code": 1}
+ return {"output": _truncate(out), "exit_code": 0}
if tool == "web_search":
from src.search import comprehensive_web_search
@@ -685,6 +1103,7 @@ async def execute_tool_block(
disabled_tools: Optional[set] = None,
owner: Optional[str] = None,
progress_cb: Optional[Callable[[Dict], Awaitable[None]]] = None,
+ workspace: Optional[str] = None,
) -> Tuple[str, Dict]:
"""Execute a single tool block. Returns (description, result_dict).
@@ -773,7 +1192,7 @@ async def execute_tool_block(
_is_bg, _bg_cmd = _split_bg_marker(content)
if _is_bg and _bg_cmd:
from src import bg_jobs
- rec = bg_jobs.launch(_bg_cmd, session_id=session_id)
+ rec = bg_jobs.launch(_bg_cmd, session_id=session_id, cwd=workspace or _AGENT_WORKDIR)
short = _bg_cmd.strip().split(chr(10))[0][:80]
desc = f"bash (background): {short}"
result = {
@@ -795,7 +1214,14 @@ async def execute_tool_block(
if tool in _MCP_TOOL_MAP:
first_line = content.split(chr(10))[0][:80]
desc = f"{tool}: {first_line}"
- result = await _call_mcp_tool(tool, content, progress_cb=progress_cb)
+ result = await _call_mcp_tool(tool, content, progress_cb=progress_cb, workspace=workspace)
+ elif tool in ("grep", "glob", "ls"):
+ # Code-navigation tools — no MCP server; run the direct implementation.
+ # Confined to the workspace when one is set (same policy as read_file).
+ first_line = content.split(chr(10))[0][:80]
+ desc = f"{tool}: {first_line}"
+ result = await _direct_fallback(tool, content, progress_cb=progress_cb, workspace=workspace) \
+ or {"error": f"{tool}: execution failed", "exit_code": 1}
elif tool == "create_document":
title = content.split("\n")[0].strip()[:60]
desc = f"create_document: {title}"
@@ -898,6 +1324,9 @@ async def execute_tool_block(
elif tool == "edit_image":
desc = "edit_image"
result = await do_edit_image(content, owner=owner)
+ elif tool == "edit_file":
+ result = await _do_edit_file(content, workspace=workspace)
+ desc = result.get("output") or result.get("error") or "edit_file"
elif tool == "trigger_research":
desc = "trigger_research"
result = await do_trigger_research(content, owner=owner)
diff --git a/src/tool_index.py b/src/tool_index.py
index c648715..6d5f457 100644
--- a/src/tool_index.py
+++ b/src/tool_index.py
@@ -22,7 +22,13 @@ logger = logging.getLogger(__name__)
# Tools that are ALWAYS included regardless of retrieval results.
# These are the most commonly needed and should never be missing.
ALWAYS_AVAILABLE = frozenset({
- "bash", "python", "web_search", "web_fetch", "read_file",
+ "bash", "python", "web_search", "web_fetch",
+ # File tools: read AND write/edit. An agent with disk access should always
+ # be able to change files, not just read them — otherwise a bare "edit X"
+ # request can miss write_file/edit_file (RAG-only) and the model wrongly
+ # falls back to edit_document (editor panel). All admin-gated by tool_security.
+ "read_file", "write_file", "edit_file",
+ "grep", "glob", "ls", # code-navigation tools (admin-gated by tool_security)
"api_call", # For configured integrations (Miniflux, Gitea, Linkding, etc.)
# The two genuinely AMBIENT cookbook tools — "what's running" and
# "kill it" can be asked any time without prior cookbook context,
@@ -71,8 +77,12 @@ BUILTIN_TOOL_DESCRIPTIONS: Dict[str, str] = {
"python": "Execute Python code for computation, data processing, math, scripting, parsing, API calls. Not for writing code for the user.",
"web_search": "Quick single web lookup for a fact, current event, or doc mid-task. NOT for 'research X' / 'do research on X' requests — those are deep-research jobs (use trigger_research). web_search = one query; trigger_research = a full researched report in the sidebar.",
"web_fetch": "Fetch and read the text content of a specific URL/website the user names (e.g. 'check example.com', 'open this link'). Use when you have a concrete URL; for open-ended lookups use web_search instead.",
- "read_file": "Read a file from disk and return its contents. View source code, config files, logs.",
- "write_file": "Write content to a file on disk. Create new files, save output, update configs.",
+ "read_file": "Read a file from disk and return its contents. View source code, config files, logs. Supports an optional line range (offset/limit) for large files.",
+ "grep": "Search file CONTENTS for a regex across a directory tree (ripgrep-backed, honours .gitignore). Returns file:line:match. Use to find where code/symbols/strings live — prefer over bash grep.",
+ "glob": "Find FILES by glob pattern (e.g. '**/*.py'), newest first. Use to locate files by name/extension — prefer over bash find/ls.",
+ "ls": "List a directory's entries (folders then files with sizes). Use to see what's in a folder — prefer over bash ls.",
+ "write_file": "Write/create or fully rewrite a file ON DISK (source code, configs, project files). Use for new files or full rewrites — NOT create_document (editor panel) and NOT a bash heredoc.",
+ "edit_file": "Edit an existing file ON DISK by exact string replacement (fix a bug, change a function). Shows a diff. The tool for changing files on disk — NOT edit_document (editor panel) and NOT bash sed/heredoc.",
"create_document": "Create a new document in the editor panel. For code, articles, text content longer than 15 lines, unless an already-open document/email draft is the obvious target. If an email compose draft is open, edit that draft instead of creating another document.",
"edit_document": "Preferred tool for editing an existing document — targeted find-and-replace. Use for any small change: add a function, fix a bug, tweak a section, rename things.",
"update_document": "Replace the entire active document content. ONLY for full rewrites (>50% changed). Do not use for small edits — use edit_document instead.",
diff --git a/src/tool_schemas.py b/src/tool_schemas.py
index dd8eb74..e45415d 100644
--- a/src/tool_schemas.py
+++ b/src/tool_schemas.py
@@ -82,16 +82,65 @@ FUNCTION_TOOL_SCHEMAS = [
"type": "function",
"function": {
"name": "read_file",
- "description": "Read a file from disk",
+ "description": "Read a file from disk. Optionally read a line range with offset/limit for large files.",
"parameters": {
"type": "object",
"properties": {
- "path": {"type": "string", "description": "File path to read"}
+ "path": {"type": "string", "description": "File path to read"},
+ "offset": {"type": "integer", "description": "1-based line to start reading from (optional)"},
+ "limit": {"type": "integer", "description": "Max number of lines to read from offset (optional)"}
},
"required": ["path"]
}
}
},
+ {
+ "type": "function",
+ "function": {
+ "name": "grep",
+ "description": "Search file contents for a regular expression across a directory tree (uses ripgrep when available, respecting .gitignore). Returns file:line:match. PREFER this over `bash grep/rg` for code search — confined to the allowed roots, structured output.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "pattern": {"type": "string", "description": "Regular expression to search for"},
+ "path": {"type": "string", "description": "Directory or file to search (optional; defaults to the project root)"},
+ "glob": {"type": "string", "description": "Only search files matching this glob, e.g. '*.py' (optional)"},
+ "ignore_case": {"type": "boolean", "description": "Case-insensitive match (optional)"},
+ "max_results": {"type": "integer", "description": "Max matches to return (optional)"}
+ },
+ "required": ["pattern"]
+ }
+ }
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "glob",
+ "description": "Find files by glob pattern (recursive), newest first. e.g. '**/*.py'. PREFER this over `bash find/ls` for locating files — confined to the allowed roots.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "pattern": {"type": "string", "description": "Glob pattern, e.g. '**/*.ts' or 'src/**/test_*.py'"},
+ "path": {"type": "string", "description": "Base directory (optional; defaults to the project root)"}
+ },
+ "required": ["pattern"]
+ }
+ }
+ },
+ {
+ "type": "function",
+ "function": {
+ "name": "ls",
+ "description": "List the entries of a directory (folders first, then files with sizes). PREFER this over `bash ls` — confined to the allowed roots.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "path": {"type": "string", "description": "Directory to list (optional; defaults to the project root)"}
+ },
+ "required": []
+ }
+ }
+ },
{
"type": "function",
"function": {
@@ -107,6 +156,23 @@ FUNCTION_TOOL_SCHEMAS = [
}
}
},
+ {
+ "type": "function",
+ "function": {
+ "name": "edit_file",
+ "description": "Edit a file ON DISK by exact string replacement (home folder, project files, any real path like ~/sweden.txt or /path/to/file). This is the right tool for files on disk — NOT edit_document (that's for editor-panel documents). PREFER this over bash (sed/echo) — it shows a diff. old_string must match the file exactly and be unique (or set replace_all). Use write_file to create a new file.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "path": {"type": "string", "description": "File path to edit"},
+ "old_string": {"type": "string", "description": "Exact text to replace (must match the file, including indentation)"},
+ "new_string": {"type": "string", "description": "Replacement text"},
+ "replace_all": {"type": "boolean", "description": "Replace all occurrences instead of requiring a unique match"}
+ },
+ "required": ["path", "old_string", "new_string"]
+ }
+ }
+ },
{
"type": "function",
"function": {
@@ -127,7 +193,7 @@ FUNCTION_TOOL_SCHEMAS = [
"type": "function",
"function": {
"name": "edit_document",
- "description": "PREFERRED way to change an existing document. Targeted find-and-replace with multiple FIND/REPLACE pairs per call. Use this for any edit smaller than a full rewrite: adding a function, fixing a bug, tweaking a section, renaming things. Do NOT send the whole file back via update_document for small edits — it wastes tokens and is hard to review.",
+ "description": "Edit a document OPEN IN THE EDITOR PANEL (created via create_document) — NOT a file on disk. For files on disk (home folder, project files, anything with a path like ~/x.txt or /path/to/file) use edit_file instead. Targeted find-and-replace with multiple FIND/REPLACE pairs per call; use for any edit smaller than a full rewrite. Do NOT send the whole file back via update_document for small edits.",
"parameters": {
"type": "object",
"properties": {
@@ -1126,9 +1192,17 @@ def function_call_to_tool_block(name: str, arguments: str) -> Optional[ToolBlock
else:
content = args.get("query", "")
elif tool_type == "read_file":
- content = args.get("path", "")
+ # Plain path (back-compat) unless a truthy offset/limit is given -> JSON-encoded args (offset=0/limit=0 count as absent).
+ if args.get("offset") or args.get("limit"):
+ content = json.dumps(args)
+ else:
+ content = args.get("path", "")
+ elif tool_type in ("grep", "glob", "ls"):
+ content = json.dumps(args) if args else "{}"
elif tool_type == "write_file":
content = args.get("path", "") + "\n" + args.get("content", "")
+ elif tool_type == "edit_file":
+ content = json.dumps(args)
elif tool_type == "create_document":
parts = [args.get("title", "Untitled")]
if args.get("language"):
diff --git a/src/tool_security.py b/src/tool_security.py
index c4094b9..8ffa50f 100644
--- a/src/tool_security.py
+++ b/src/tool_security.py
@@ -16,6 +16,10 @@ NON_ADMIN_BLOCKED_TOOLS = {
"python",
"read_file",
"write_file",
+ "edit_file",
+ "grep",
+ "glob",
+ "ls",
"search_chats",
"manage_memory",
"manage_skills",
diff --git a/static/app.js b/static/app.js
index 683e0e5..08ab121 100644
--- a/static/app.js
+++ b/static/app.js
@@ -4,6 +4,7 @@
// ============================================
import Storage from './js/storage.js';
import uiModule from './js/ui.js';
+import workspaceModule from './js/workspace.js';
import fileHandlerModule from './js/fileHandler.js';
import modelsModule from './js/models.js';
import ragModule from './js/rag.js';
@@ -13,6 +14,7 @@ import chatModule from './js/chat.js';
import compareModule from './js/compare/index.js';
import documentModule from './js/document.js';
import searchChatModule from './js/search-chat.js';
+import { makeWindowDraggable } from './js/windowDrag.js';
import markdownModule from './js/markdown.js';
import chatRenderer from './js/chatRenderer.js';
import sessionModule from './js/sessions.js';
@@ -1686,6 +1688,7 @@ function initializeEventListeners() {
}
setupToggle('web-toggle-btn', 'web-toggle', 'web');
setupToggle('bash-toggle-btn', 'bash-toggle', 'bash');
+ try { workspaceModule.initWorkspace(); } catch (e) { console.error('Workspace init error:', e); } // best-effort: log, don't abort the remaining setup
// Document editor toggle (special: uses module panel, not a checkbox)
const overflowDocBtn = el('overflow-doc-btn');
@@ -2683,82 +2686,38 @@ function initializeEventListeners() {
// Apply saved visibility on load
applyUIVis(loadUIVis());
- // Generic draggable for all .modal elements
- const _sharedDragModalIds = new Set(['settings-modal']);
- try { document.querySelectorAll('.modal').forEach(m => {
- if (_sharedDragModalIds.has(m.id)) return;
- const content = m.querySelector('.modal-content');
- const header = m.querySelector('.modal-header');
- if (!content || !header) return;
- let dragX, dragY, startLeft, startTop, dragging = false;
-
- // Reset to flex-centered position each time modal opens
- new MutationObserver(() => {
- if (!m.classList.contains('hidden')) {
- content.style.position = '';
- content.style.left = '';
- content.style.top = '';
- content.style.right = '';
- content.style.bottom = '';
- content.style.margin = '';
- }
- }).observe(m, { attributes: true, attributeFilter: ['class'] });
-
- function startDrag(clientX, clientY) {
- dragging = true;
- const rect = content.getBoundingClientRect();
- dragX = clientX; dragY = clientY;
- startLeft = rect.left; startTop = rect.top;
- // Switch to fixed so it can be freely positioned
- content.style.position = 'fixed';
- content.style.left = startLeft + 'px';
- content.style.top = startTop + 'px';
- content.style.margin = '0';
- }
-
- header.addEventListener('mousedown', (e) => {
- if (e.target.closest('.close-btn')) return;
- e.preventDefault();
- startDrag(e.clientX, e.clientY);
- document.addEventListener('mousemove', onDrag);
- document.addEventListener('mouseup', stopDrag);
- });
- function onDrag(e) {
- if (!dragging) return;
- content.style.left = (startLeft + e.clientX - dragX) + 'px';
- content.style.top = (startTop + e.clientY - dragY) + 'px';
- }
- function stopDrag() {
- dragging = false;
- document.removeEventListener('mousemove', onDrag);
- document.removeEventListener('mouseup', stopDrag);
- }
-
- // Touch drag is desktop-only — on mobile, modals are bottom sheets and
- // ui.js handles swipe-down-to-dismiss. Attaching this listener fights
- // the swipe-dismiss gesture.
- if (window.innerWidth > 768) {
- header.addEventListener('touchstart', (e) => {
- if (e.target.closest('.close-btn')) return;
- const t = e.touches[0];
- startDrag(t.clientX, t.clientY);
- document.addEventListener('touchmove', onTouchDrag, { passive: false });
- document.addEventListener('touchend', stopTouchDrag);
+ // The only two modals without a per-module makeWindowDraggable call. Wire
+ // them onto the shared helper, drag-only, to match their old behavior.
+ try {
+ ['custom-preset-modal', 'rename-session-modal'].forEach((id) => {
+ const m = document.getElementById(id);
+ if (!m) return;
+ const content = m.querySelector('.modal-content');
+ const header = m.querySelector('.modal-header');
+ if (!content || !header) return;
+ makeWindowDraggable(m, {
+ content, header,
+ skipSelector: '.close-btn',
+ enableDock: false,
+ enableResize: false,
});
- }
- function onTouchDrag(e) {
- if (!dragging) return;
- e.preventDefault();
- const t = e.touches[0];
- content.style.left = (startLeft + t.clientX - dragX) + 'px';
- content.style.top = (startTop + t.clientY - dragY) + 'px';
- }
- function stopTouchDrag() {
- dragging = false;
- document.removeEventListener('touchmove', onTouchDrag);
- document.removeEventListener('touchend', stopTouchDrag);
- }
- }); } catch(e) { console.error('Modal drag init error:', e); }
+ // Re-center on open (these persist in the DOM). Guard on the
+ // hidden→visible edge so it never fires mid-drag.
+ let wasHidden = m.classList.contains('hidden');
+ new MutationObserver(() => {
+ const isHidden = m.classList.contains('hidden');
+ if (wasHidden && !isHidden) {
+ content.style.position = '';
+ content.style.left = '';
+ content.style.top = '';
+ content.style.right = '';
+ content.style.bottom = '';
+ content.style.margin = '';
+ }
+ wasHidden = isHidden;
+ }).observe(m, { attributes: true, attributeFilter: ['class'] });
+ });
+ } catch (e) { console.error('Dialog drag init error:', e); }
})();
// ── Modal minimize → dock ──
diff --git a/static/index.html b/static/index.html
index 72544de..c5f3828 100644
--- a/static/index.html
+++ b/static/index.html
@@ -1031,6 +1031,13 @@
RAG
+
+