Odysseus v1.0
This commit is contained in:
0
routes/__init__.py
Normal file
0
routes/__init__.py
Normal file
174
routes/admin_wipe_routes.py
Normal file
174
routes/admin_wipe_routes.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Admin Danger Zone — per-category wipes.
|
||||
|
||||
Each endpoint is admin-only and truncates exactly one domain so the
|
||||
user can selectively reset memory / skills / notes / etc. without
|
||||
nuking everything. The catch-all `chats` endpoint mirrors the
|
||||
existing /api/sessions/all so the Danger Zone speaks one URL pattern.
|
||||
|
||||
URL shape: DELETE /api/admin/wipe/{kind}
|
||||
Kinds: chats, memory, skills, notes, tasks, documents, gallery, calendar.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
|
||||
from core.middleware import require_admin
|
||||
from core.database import (
|
||||
SessionLocal,
|
||||
Session as DbSession,
|
||||
ChatMessage as DbChatMessage,
|
||||
Memory,
|
||||
Note,
|
||||
ScheduledTask,
|
||||
TaskRun,
|
||||
Document,
|
||||
DocumentVersion,
|
||||
GalleryImage,
|
||||
CalendarEvent,
|
||||
CalendarCal,
|
||||
)
|
||||
from src.constants import DATA_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _wipe_memory_files():
|
||||
"""Blank memory.json + drop the per-owner tidy-state sidecar so the
|
||||
next audit doesn't try to diff against gone memories."""
|
||||
for name in ("memory.json", "memory_tidy_state.json"):
|
||||
p = os.path.join(DATA_DIR, name)
|
||||
if not os.path.exists(p):
|
||||
continue
|
||||
try:
|
||||
if name == "memory.json":
|
||||
with open(p, "w") as f:
|
||||
json.dump([], f)
|
||||
else:
|
||||
os.remove(p)
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not reset {name}: {e}")
|
||||
|
||||
|
||||
def _rmtree_quiet(path: str):
|
||||
"""rmtree that doesn't crash if the path doesn't exist."""
|
||||
if os.path.isdir(path):
|
||||
try:
|
||||
shutil.rmtree(path)
|
||||
except OSError as e:
|
||||
logger.warning(f"Could not remove {path}: {e}")
|
||||
|
||||
|
||||
def setup_admin_wipe_routes(session_manager):
|
||||
"""The session_manager is passed in so we can also clear its
|
||||
in-memory cache when wiping chats — without it the DB is empty
|
||||
but the next /api/sessions returns stale entries."""
|
||||
router = APIRouter(prefix="/api/admin")
|
||||
|
||||
@router.delete("/wipe/{kind}")
|
||||
def wipe(kind: str, request: Request):
|
||||
require_admin(request)
|
||||
kind = (kind or "").strip().lower()
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if kind == "chats":
|
||||
count = db.query(DbSession).count()
|
||||
db.query(DbChatMessage).delete()
|
||||
db.query(DbSession).delete()
|
||||
db.commit()
|
||||
try:
|
||||
session_manager.sessions.clear()
|
||||
except Exception:
|
||||
pass
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "memory":
|
||||
count = db.query(Memory).count()
|
||||
db.query(Memory).delete()
|
||||
db.commit()
|
||||
_wipe_memory_files()
|
||||
# Drop the vector store too so semantic search doesn't
|
||||
# return ghosts. Lazy import — chromadb may not be
|
||||
# initialised in every deployment.
|
||||
try:
|
||||
from src.memory_vector import get_memory_vector_store
|
||||
mv = get_memory_vector_store()
|
||||
if mv and hasattr(mv, "clear"):
|
||||
mv.clear()
|
||||
except Exception as e:
|
||||
logger.info(f"Memory vector clear skipped: {e}")
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "skills":
|
||||
# Skills live as SKILL.md files under data/skills/. Drop
|
||||
# the entire directory; the SkillsManager re-creates the
|
||||
# tree on next write.
|
||||
skills_dir = os.path.join(DATA_DIR, "skills")
|
||||
count = 0
|
||||
if os.path.isdir(skills_dir):
|
||||
# Count SKILL.md files for the response — quick walk.
|
||||
for _, _, files in os.walk(skills_dir):
|
||||
count += sum(1 for f in files if f == "SKILL.md")
|
||||
_rmtree_quiet(skills_dir)
|
||||
# Legacy fallback file
|
||||
legacy = os.path.join(DATA_DIR, "skills.json")
|
||||
if os.path.exists(legacy):
|
||||
try:
|
||||
os.remove(legacy)
|
||||
except OSError:
|
||||
pass
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "notes":
|
||||
count = db.query(Note).count()
|
||||
db.query(Note).delete()
|
||||
db.commit()
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "tasks":
|
||||
# TaskRun rows reference tasks via FK — clear them first.
|
||||
db.query(TaskRun).delete()
|
||||
count = db.query(ScheduledTask).count()
|
||||
db.query(ScheduledTask).delete()
|
||||
db.commit()
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "documents":
|
||||
# DocumentVersion FKs Document — clear children first.
|
||||
db.query(DocumentVersion).delete()
|
||||
count = db.query(Document).count()
|
||||
db.query(Document).delete()
|
||||
db.commit()
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "gallery":
|
||||
count = db.query(GalleryImage).count()
|
||||
db.query(GalleryImage).delete()
|
||||
db.commit()
|
||||
# Also drop the upload dir so disk doesn't keep orphans.
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
|
||||
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
if kind == "calendar":
|
||||
# Events FK calendars — clear children first, then both.
|
||||
db.query(CalendarEvent).delete()
|
||||
count = db.query(CalendarCal).count()
|
||||
db.query(CalendarCal).delete()
|
||||
db.commit()
|
||||
return {"status": "deleted", "kind": kind, "count": count}
|
||||
|
||||
raise HTTPException(400, f"Unknown wipe kind: {kind!r}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.exception(f"Wipe {kind} failed")
|
||||
raise HTTPException(500, f"Wipe {kind} failed: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
91
routes/api_token_routes.py
Normal file
91
routes/api_token_routes.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""API Token management routes — /api/tokens/*."""
|
||||
|
||||
import secrets
|
||||
import uuid
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, HTTPException, Request, Form
|
||||
|
||||
from core.database import get_db_session, ApiToken
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
MAX_NAME_LEN = 100
|
||||
DEFAULT_SCOPES = "chat"
|
||||
|
||||
|
||||
def setup_api_token_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api", tags=["api_tokens"])
|
||||
|
||||
@router.get("/tokens")
|
||||
def list_tokens(request: Request):
|
||||
require_admin(request)
|
||||
with get_db_session() as db:
|
||||
tokens = db.query(ApiToken).all()
|
||||
return [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"owner": getattr(t, "owner", None),
|
||||
"token_prefix": t.token_prefix,
|
||||
"scopes": [s.strip() for s in (getattr(t, "scopes", "") or DEFAULT_SCOPES).split(",") if s.strip()],
|
||||
"is_active": t.is_active,
|
||||
"last_used_at": t.last_used_at.isoformat() if t.last_used_at else None,
|
||||
"created_at": t.created_at.isoformat() if t.created_at else None,
|
||||
}
|
||||
for t in tokens
|
||||
]
|
||||
|
||||
def _invalidate_cache(request: Request):
|
||||
"""Tell the auth middleware its cached token map is stale."""
|
||||
try:
|
||||
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
|
||||
if invalidator:
|
||||
invalidator()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@router.post("/tokens")
|
||||
def create_token(request: Request, name: str = Form("")):
|
||||
require_admin(request)
|
||||
name = name.strip()[:MAX_NAME_LEN]
|
||||
if not name:
|
||||
raise HTTPException(400, "Token name is required")
|
||||
owner = get_current_user(request)
|
||||
|
||||
raw_token = "ody_" + secrets.token_urlsafe(32)
|
||||
token_hash = bcrypt.hashpw(raw_token.encode(), bcrypt.gensalt()).decode()
|
||||
token_id = str(uuid.uuid4())[:8]
|
||||
|
||||
with get_db_session() as db:
|
||||
db.add(ApiToken(
|
||||
id=token_id,
|
||||
owner=owner,
|
||||
name=name,
|
||||
token_hash=token_hash,
|
||||
token_prefix=raw_token[:8],
|
||||
scopes=DEFAULT_SCOPES,
|
||||
is_active=True,
|
||||
))
|
||||
_invalidate_cache(request)
|
||||
|
||||
return {
|
||||
"id": token_id,
|
||||
"name": name,
|
||||
"owner": owner,
|
||||
"token": raw_token,
|
||||
"token_prefix": raw_token[:8],
|
||||
"scopes": DEFAULT_SCOPES.split(","),
|
||||
}
|
||||
|
||||
@router.delete("/tokens/{token_id}")
|
||||
def delete_token(request: Request, token_id: str):
|
||||
require_admin(request)
|
||||
with get_db_session() as db:
|
||||
deleted = db.query(ApiToken).filter(ApiToken.id == token_id).delete()
|
||||
if not deleted:
|
||||
raise HTTPException(404, "Token not found")
|
||||
_invalidate_cache(request)
|
||||
return {"status": "deleted"}
|
||||
|
||||
return router
|
||||
325
routes/assistant_routes.py
Normal file
325
routes/assistant_routes.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""Personal assistant routes — resolve the per-user singleton, read/write
|
||||
its settings, and list its scheduled check-in tasks.
|
||||
|
||||
The personal assistant is just a specially-flagged CrewMember that owns one
|
||||
pinned Session and three daily ScheduledTasks ("Morning/Midday/Evening
|
||||
check-in"). Everything about it is user-editable: name, personality, model,
|
||||
enabled tools, timezone, and the three check-in times/prompts/enabled flags.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, CrewMember, ScheduledTask
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.task_scheduler import compute_next_run
|
||||
|
||||
|
||||
class CheckInUpdate(BaseModel):
|
||||
id: str # ScheduledTask.id
|
||||
name: Optional[str] = None
|
||||
scheduled_time: Optional[str] = None # "HH:MM"
|
||||
prompt: Optional[str] = None
|
||||
enabled: Optional[bool] = None # maps to status "active"/"paused"
|
||||
|
||||
|
||||
class AssistantSettingsUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
avatar: Optional[str] = None
|
||||
personality: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
endpoint_url: Optional[str] = None
|
||||
enabled_tools: Optional[list[str]] = None
|
||||
allow_autonomous_email: Optional[bool] = None # convenience toggle
|
||||
timezone: Optional[str] = None
|
||||
check_ins: Optional[list[CheckInUpdate]] = None
|
||||
|
||||
|
||||
_EMAIL_TOOLS = {"send_email", "reply_to_email"}
|
||||
|
||||
|
||||
def _crew_to_dict(c: CrewMember) -> dict:
|
||||
try:
|
||||
tools = json.loads(c.enabled_tools) if c.enabled_tools else []
|
||||
except Exception:
|
||||
tools = []
|
||||
return {
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"avatar": c.avatar,
|
||||
"personality": c.personality,
|
||||
"model": c.model,
|
||||
"endpoint_url": c.endpoint_url,
|
||||
"greeting": c.greeting,
|
||||
"enabled_tools": tools,
|
||||
"session_id": c.session_id,
|
||||
"is_default_assistant": bool(c.is_default_assistant),
|
||||
"timezone": c.timezone,
|
||||
"allow_autonomous_email": any(t in _EMAIL_TOOLS for t in tools),
|
||||
}
|
||||
|
||||
|
||||
def _task_to_checkin_dict(t: ScheduledTask) -> dict:
|
||||
return {
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"scheduled_time": t.scheduled_time,
|
||||
"prompt": t.prompt,
|
||||
"enabled": (t.status or "active") == "active",
|
||||
"next_run": t.next_run.isoformat() + "Z" if t.next_run else None,
|
||||
"last_run": t.last_run.isoformat() + "Z" if t.last_run else None,
|
||||
"run_count": t.run_count or 0,
|
||||
}
|
||||
|
||||
|
||||
def setup_assistant_routes(task_scheduler) -> APIRouter:
|
||||
router = APIRouter(prefix="/api/assistant", tags=["assistant"])
|
||||
|
||||
def _owner(request: Request) -> str:
|
||||
owner = get_current_user(request)
|
||||
if not owner:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return owner
|
||||
|
||||
# Synthetic / non-human owners that should NEVER get an assistant +
|
||||
# check-in tasks seeded. Hitting any /assistant route under one of these
|
||||
# used to seed a full CrewMember + Morning/Midday/Evening tasks under that
|
||||
# owner, which then double-fired alongside the real user's check-ins.
|
||||
_SYNTHETIC_OWNERS = frozenset({"internal-tool", "api", "demo", "system", ""})
|
||||
|
||||
async def _get_or_create(owner: str) -> CrewMember:
|
||||
"""Return the per-owner assistant CrewMember, creating it on demand."""
|
||||
if not owner or owner in _SYNTHETIC_OWNERS:
|
||||
raise HTTPException(status_code=400, detail=f"Cannot seed assistant for {owner!r}")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
crew = db.query(CrewMember).filter(
|
||||
CrewMember.owner == owner,
|
||||
CrewMember.is_default_assistant == True, # noqa: E712
|
||||
).first()
|
||||
if crew:
|
||||
return crew
|
||||
finally:
|
||||
db.close()
|
||||
# Seed lazily. This is the same code the startup hook runs for each
|
||||
# user — safe to call again, it's idempotent.
|
||||
await task_scheduler.ensure_assistant_defaults(owner)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
crew = db.query(CrewMember).filter(
|
||||
CrewMember.owner == owner,
|
||||
CrewMember.is_default_assistant == True, # noqa: E712
|
||||
).first()
|
||||
return crew
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/session")
|
||||
async def get_assistant_session(request: Request):
|
||||
"""Resolve (or lazily create) the pinned Assistant session for this user."""
|
||||
owner = _owner(request)
|
||||
crew = await _get_or_create(owner)
|
||||
if not crew or not crew.session_id:
|
||||
raise HTTPException(status_code=500, detail="Assistant session could not be resolved")
|
||||
return {
|
||||
"session_id": crew.session_id,
|
||||
"crew_member_id": crew.id,
|
||||
"name": crew.name,
|
||||
}
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_assistant_settings(request: Request):
|
||||
"""Return CrewMember fields + the three check-in task rows + task IDs for logs."""
|
||||
owner = _owner(request)
|
||||
crew = await _get_or_create(owner)
|
||||
if not crew:
|
||||
raise HTTPException(status_code=500, detail="Assistant not available")
|
||||
db = SessionLocal()
|
||||
try:
|
||||
tasks = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.owner == owner,
|
||||
ScheduledTask.crew_member_id == crew.id,
|
||||
).order_by(ScheduledTask.scheduled_time.asc()).all()
|
||||
return {
|
||||
"crew": _crew_to_dict(crew),
|
||||
"check_ins": [_task_to_checkin_dict(t) for t in tasks],
|
||||
"task_ids": [t.id for t in tasks],
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.patch("/settings")
|
||||
async def update_assistant_settings(payload: AssistantSettingsUpdate, request: Request):
|
||||
"""Update CrewMember fields and/or check-in tasks in one call."""
|
||||
owner = _owner(request)
|
||||
crew = await _get_or_create(owner)
|
||||
if not crew:
|
||||
raise HTTPException(status_code=500, detail="Assistant not available")
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
crew_db = db.query(CrewMember).filter(CrewMember.id == crew.id).first()
|
||||
if not crew_db:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
# Update CrewMember fields.
|
||||
if payload.name is not None:
|
||||
crew_db.name = payload.name.strip() or crew_db.name
|
||||
if payload.avatar is not None:
|
||||
crew_db.avatar = payload.avatar
|
||||
if payload.personality is not None:
|
||||
crew_db.personality = payload.personality
|
||||
if payload.model is not None:
|
||||
crew_db.model = payload.model or None
|
||||
if payload.endpoint_url is not None:
|
||||
crew_db.endpoint_url = payload.endpoint_url or None
|
||||
if payload.timezone is not None:
|
||||
crew_db.timezone = payload.timezone or None
|
||||
|
||||
# Tool list: either explicit list, or implicit toggle.
|
||||
if payload.enabled_tools is not None:
|
||||
crew_db.enabled_tools = json.dumps(payload.enabled_tools)
|
||||
if payload.allow_autonomous_email is not None:
|
||||
try:
|
||||
existing = json.loads(crew_db.enabled_tools) if crew_db.enabled_tools else []
|
||||
except Exception:
|
||||
existing = []
|
||||
if payload.allow_autonomous_email:
|
||||
for t in ("send_email", "reply_to_email"):
|
||||
if t not in existing:
|
||||
existing.append(t)
|
||||
else:
|
||||
existing = [t for t in existing if t not in _EMAIL_TOOLS]
|
||||
crew_db.enabled_tools = json.dumps(existing)
|
||||
|
||||
crew_db.updated_at = datetime.utcnow()
|
||||
|
||||
# Update check-in tasks.
|
||||
if payload.check_ins:
|
||||
now_utc = datetime.utcnow()
|
||||
tz_name = crew_db.timezone or None
|
||||
for ci in payload.check_ins:
|
||||
task = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.id == ci.id,
|
||||
ScheduledTask.owner == owner,
|
||||
ScheduledTask.crew_member_id == crew_db.id,
|
||||
).first()
|
||||
if not task:
|
||||
continue
|
||||
if ci.name is not None:
|
||||
task.name = ci.name.strip() or task.name
|
||||
time_changed = False
|
||||
if ci.scheduled_time is not None and ci.scheduled_time != task.scheduled_time:
|
||||
task.scheduled_time = ci.scheduled_time
|
||||
time_changed = True
|
||||
if ci.prompt is not None:
|
||||
task.prompt = ci.prompt
|
||||
if ci.enabled is not None:
|
||||
task.status = "active" if ci.enabled else "paused"
|
||||
if time_changed or ci.enabled is True:
|
||||
task.next_run = compute_next_run(
|
||||
task.schedule or "daily",
|
||||
task.scheduled_time,
|
||||
task.scheduled_day,
|
||||
task.scheduled_date,
|
||||
after=now_utc,
|
||||
cron_expression=task.cron_expression,
|
||||
tz_name=tz_name,
|
||||
)
|
||||
task.updated_at = datetime.utcnow()
|
||||
|
||||
# Timezone change also shifts the NEXT run of all check-ins even if
|
||||
# the user didn't touch the time fields.
|
||||
if payload.timezone is not None:
|
||||
now_utc = datetime.utcnow()
|
||||
tz_name = crew_db.timezone or None
|
||||
tasks = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.owner == owner,
|
||||
ScheduledTask.crew_member_id == crew_db.id,
|
||||
).all()
|
||||
for t in tasks:
|
||||
if t.schedule and t.scheduled_time:
|
||||
t.next_run = compute_next_run(
|
||||
t.schedule, t.scheduled_time, t.scheduled_day, t.scheduled_date,
|
||||
after=now_utc, cron_expression=t.cron_expression, tz_name=tz_name,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# Re-read crew_db + tasks to return the fresh state.
|
||||
crew_out = db.query(CrewMember).filter(CrewMember.id == crew.id).first()
|
||||
tasks_out = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.owner == owner,
|
||||
ScheduledTask.crew_member_id == crew.id,
|
||||
).order_by(ScheduledTask.scheduled_time.asc()).all()
|
||||
return {
|
||||
"crew": _crew_to_dict(crew_out),
|
||||
"check_ins": [_task_to_checkin_dict(t) for t in tasks_out],
|
||||
"task_ids": [t.id for t in tasks_out],
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/run/{task_id}")
|
||||
async def run_check_in_now(task_id: str, request: Request):
|
||||
"""Trigger one of the assistant's check-ins immediately (manual test)."""
|
||||
owner = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.id == task_id,
|
||||
ScheduledTask.owner == owner,
|
||||
).first()
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
crew = db.query(CrewMember).filter(
|
||||
CrewMember.id == task.crew_member_id,
|
||||
CrewMember.is_default_assistant == True, # noqa: E712
|
||||
).first()
|
||||
if not crew:
|
||||
raise HTTPException(status_code=400, detail="Not an assistant task")
|
||||
finally:
|
||||
db.close()
|
||||
started = await task_scheduler.run_task_now(task_id)
|
||||
return {"started": bool(started)}
|
||||
|
||||
@router.get("/run-status/{task_id}")
|
||||
async def run_status(task_id: str, request: Request):
|
||||
"""Check whether the most recent run of a task has finished."""
|
||||
from core.database import TaskRun, ScheduledTask
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# SECURITY: 404 if the task doesn't belong to this user — without
|
||||
# this any authenticated user could poll the status of any task_id.
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(404, "Task not found")
|
||||
run = db.query(TaskRun).filter(
|
||||
TaskRun.task_id == task_id,
|
||||
).order_by(TaskRun.started_at.desc()).first()
|
||||
if not run:
|
||||
return {"status": "unknown"}
|
||||
if run.status == "running":
|
||||
return {"status": "running"}
|
||||
return {"status": "done", "result_status": run.status}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/available-timezones")
|
||||
async def list_timezones():
|
||||
"""Return the IANA tz name list used to populate the settings dropdown."""
|
||||
try:
|
||||
from zoneinfo import available_timezones
|
||||
zones = sorted(available_timezones())
|
||||
except Exception:
|
||||
zones = ["UTC"]
|
||||
return {"timezones": zones}
|
||||
|
||||
return router
|
||||
502
routes/auth_routes.py
Normal file
502
routes/auth_routes.py
Normal file
@@ -0,0 +1,502 @@
|
||||
"""Authentication routes — login, logout, signup, status, user management."""
|
||||
|
||||
from fastapi import APIRouter, Request, Response, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import logging
|
||||
import os
|
||||
|
||||
from core.auth import AuthManager
|
||||
from src.rate_limiter import RateLimiter
|
||||
from src.settings import (
|
||||
load_settings as _load_settings,
|
||||
save_settings as _save_settings,
|
||||
load_features as _load_features,
|
||||
save_features as _save_features,
|
||||
DEFAULT_SETTINGS,
|
||||
)
|
||||
from src.integrations import (
|
||||
load_integrations,
|
||||
add_integration,
|
||||
update_integration,
|
||||
delete_integration,
|
||||
get_integration,
|
||||
execute_api_call,
|
||||
INTEGRATION_PRESETS,
|
||||
migrate_from_settings,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
remember: bool = True
|
||||
totp_code: Optional[str] = None
|
||||
|
||||
|
||||
class SetupRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class SignupRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
||||
|
||||
class CreateUserRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
is_admin: bool = False
|
||||
|
||||
|
||||
class DeleteUserRequest(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
SESSION_COOKIE = "odysseus_session"
|
||||
|
||||
|
||||
def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
_login_limiter = RateLimiter(max_requests=15, window_seconds=60)
|
||||
_signup_limiter = RateLimiter(max_requests=3, window_seconds=300)
|
||||
_setup_limiter = RateLimiter(max_requests=3, window_seconds=300)
|
||||
|
||||
def _get_current_user(request: Request) -> Optional[str]:
|
||||
token = request.cookies.get(SESSION_COOKIE)
|
||||
return auth_manager.get_username_for_token(token)
|
||||
|
||||
@router.post("/setup")
|
||||
async def first_run_setup(body: SetupRequest, request: Request):
|
||||
"""Create initial admin account. Only works if no accounts exist."""
|
||||
if not _setup_limiter.check(request.client.host):
|
||||
raise HTTPException(429, "Too many requests — try again later")
|
||||
if auth_manager.is_configured:
|
||||
raise HTTPException(400, "Already configured")
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(400, "Password must be at least 8 characters")
|
||||
ok = auth_manager.setup(body.username, body.password)
|
||||
if not ok:
|
||||
raise HTTPException(500, "Setup failed")
|
||||
return {"ok": True, "message": "Admin account created"}
|
||||
|
||||
@router.post("/signup")
|
||||
async def signup(body: SignupRequest, request: Request):
|
||||
"""Create a new user account. Only works if signup is enabled by admin."""
|
||||
if not _signup_limiter.check(request.client.host):
|
||||
raise HTTPException(429, "Too many requests — try again later")
|
||||
if not auth_manager.is_configured:
|
||||
raise HTTPException(400, "Run setup first")
|
||||
if not auth_manager.signup_enabled:
|
||||
raise HTTPException(403, "Registration is disabled. Ask an admin for an account.")
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(400, "Password must be at least 8 characters")
|
||||
if len(body.username.strip()) < 1:
|
||||
raise HTTPException(400, "Username is required")
|
||||
ok = auth_manager.create_user(body.username, body.password, is_admin=False)
|
||||
if not ok:
|
||||
raise HTTPException(409, "Username already taken")
|
||||
return {"ok": True, "message": "Account created"}
|
||||
|
||||
@router.post("/login")
|
||||
async def login(body: LoginRequest, request: Request, response: Response):
|
||||
if not _login_limiter.check(request.client.host):
|
||||
raise HTTPException(429, "Too many requests — try again later")
|
||||
# Verify password first
|
||||
username = body.username.strip().lower()
|
||||
if not auth_manager.verify_password(username, body.password):
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
# Check 2FA if enabled
|
||||
if auth_manager.totp_enabled(username):
|
||||
if not body.totp_code:
|
||||
# Password OK but need TOTP — tell client to show code input
|
||||
return {"ok": False, "requires_totp": True, "username": username}
|
||||
if not auth_manager.totp_verify(username, body.totp_code):
|
||||
raise HTTPException(401, "Invalid 2FA code")
|
||||
# All checks passed — create session
|
||||
token = auth_manager.create_session(username, body.password)
|
||||
if not token:
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
cookie_kwargs = dict(
|
||||
key=SESSION_COOKIE,
|
||||
value=token,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
secure=os.getenv("SECURE_COOKIES", "false").lower() == "true",
|
||||
path="/",
|
||||
)
|
||||
if body.remember:
|
||||
cookie_kwargs["max_age"] = 60 * 60 * 24 * 7 # 7 days
|
||||
response.set_cookie(**cookie_kwargs)
|
||||
return {"ok": True, "username": username}
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(request: Request, response: Response):
|
||||
token = request.cookies.get(SESSION_COOKIE)
|
||||
if token:
|
||||
auth_manager.revoke_token(token)
|
||||
response.delete_cookie(SESSION_COOKIE, path="/")
|
||||
return {"ok": True}
|
||||
|
||||
@router.get("/status")
|
||||
async def auth_status(request: Request):
|
||||
token = request.cookies.get(SESSION_COOKIE)
|
||||
result = auth_manager.status(token)
|
||||
result["signup_enabled"] = auth_manager.signup_enabled
|
||||
# Include the caller's effective privileges so the frontend can
|
||||
# hide / dim UI controls the user isn't allowed to use. Admins get
|
||||
# ADMIN_PRIVILEGES (everything on), regular users get their stored
|
||||
# set merged with DEFAULT_PRIVILEGES.
|
||||
try:
|
||||
u = result.get("username")
|
||||
if u:
|
||||
result["privileges"] = auth_manager.get_privileges(u)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
@router.post("/change-password")
|
||||
async def change_password(body: ChangePasswordRequest, request: Request):
|
||||
user = _get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
if len(body.new_password) < 8:
|
||||
raise HTTPException(400, "Password must be at least 8 characters")
|
||||
ok = auth_manager.change_password(user, body.current_password, body.new_password)
|
||||
if not ok:
|
||||
raise HTTPException(400, "Current password is incorrect")
|
||||
return {"ok": True}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Two-factor authentication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@router.post("/2fa/setup")
|
||||
async def totp_setup(request: Request):
|
||||
"""Generate a TOTP secret and return the QR code URI."""
|
||||
user = _get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
if auth_manager.totp_enabled(user):
|
||||
raise HTTPException(400, "2FA is already enabled")
|
||||
secret = auth_manager.totp_generate_secret(user)
|
||||
if not secret:
|
||||
raise HTTPException(500, "Failed to generate secret")
|
||||
uri = auth_manager.totp_get_provisioning_uri(user, secret)
|
||||
# Generate QR code as base64 PNG
|
||||
import qrcode, io, base64
|
||||
qr = qrcode.make(uri, box_size=6, border=2)
|
||||
buf = io.BytesIO()
|
||||
qr.save(buf, format="PNG")
|
||||
qr_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
||||
return {"secret": secret, "uri": uri, "qr_code": f"data:image/png;base64,{qr_b64}"}
|
||||
|
||||
class TotpVerifyRequest(BaseModel):
|
||||
code: str
|
||||
|
||||
@router.post("/2fa/confirm")
|
||||
async def totp_confirm(body: TotpVerifyRequest, request: Request):
|
||||
"""Verify a TOTP code to confirm 2FA setup. Returns backup codes."""
|
||||
user = _get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
if not auth_manager.totp_confirm_enable(user, body.code):
|
||||
raise HTTPException(400, "Invalid code — try again")
|
||||
backup = auth_manager.users.get(user, {}).get("totp_backup_codes", [])
|
||||
return {"ok": True, "backup_codes": backup}
|
||||
|
||||
class TotpDisableRequest(BaseModel):
|
||||
password: str
|
||||
|
||||
@router.post("/2fa/disable")
|
||||
async def totp_disable(body: TotpDisableRequest, request: Request):
|
||||
"""Disable 2FA. Requires password confirmation."""
|
||||
user = _get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
if not auth_manager.totp_disable(user, body.password):
|
||||
raise HTTPException(400, "Invalid password")
|
||||
return {"ok": True}
|
||||
|
||||
@router.get("/2fa/status")
|
||||
async def totp_status(request: Request):
|
||||
"""Check if 2FA is enabled for the current user."""
|
||||
user = _get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
return {"enabled": auth_manager.totp_enabled(user)}
|
||||
|
||||
# Admin-only routes
|
||||
@router.get("/users")
|
||||
async def list_users(request: Request):
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
return {"users": auth_manager.list_users()}
|
||||
|
||||
@router.post("/users")
|
||||
async def admin_create_user(body: CreateUserRequest, request: Request):
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(400, "Password must be at least 8 characters")
|
||||
ok = auth_manager.create_user(body.username, body.password, body.is_admin)
|
||||
if not ok:
|
||||
raise HTTPException(409, "Username already taken")
|
||||
return {"ok": True}
|
||||
|
||||
@router.put("/users/{username}/privileges")
|
||||
async def update_user_privileges(username: str, request: Request):
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
body = await request.json()
|
||||
ok = auth_manager.set_privileges(username, body)
|
||||
if not ok:
|
||||
raise HTTPException(404, "User not found or is admin")
|
||||
return {"ok": True, "privileges": auth_manager.get_privileges(username)}
|
||||
|
||||
@router.post("/signup-toggle")
|
||||
async def toggle_signup(request: Request):
|
||||
"""Toggle open registration on/off. Admin only."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
auth_manager.signup_enabled = not auth_manager.signup_enabled
|
||||
return {"ok": True, "signup_enabled": auth_manager.signup_enabled}
|
||||
|
||||
@router.delete("/users")
|
||||
async def admin_delete_user(body: DeleteUserRequest, request: Request):
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
ok = auth_manager.delete_user(body.username, user)
|
||||
if not ok:
|
||||
raise HTTPException(400, "Cannot delete user")
|
||||
return {"ok": True}
|
||||
|
||||
# ---- Feature visibility (admin-managed) ----
|
||||
|
||||
@router.get("/features")
|
||||
async def get_features():
|
||||
"""Public: returns which UI features are enabled."""
|
||||
return _load_features()
|
||||
|
||||
@router.post("/features")
|
||||
async def set_features(request: Request):
|
||||
"""Admin only: update feature toggles."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
body = await request.json()
|
||||
current = _load_features()
|
||||
for key in current:
|
||||
if key in body and isinstance(body[key], bool):
|
||||
current[key] = body[key]
|
||||
_save_features(current)
|
||||
return current
|
||||
|
||||
# ---- App settings (admin-managed) ----
|
||||
|
||||
_SECRET_KEY_PATTERNS = ("_api_key", "_password", "_secret", "_token", "_key")
|
||||
|
||||
def _is_secret_key(name: str) -> bool:
|
||||
n = (name or "").lower()
|
||||
if n in ("google_pse_cx",): # public identifier, not a secret
|
||||
return False
|
||||
return any(n.endswith(p) or n == p.lstrip("_") for p in _SECRET_KEY_PATTERNS)
|
||||
|
||||
def _scrub_settings(settings: dict) -> dict:
|
||||
"""Return a copy of settings with secret-shaped values masked.
|
||||
|
||||
Frontend reads /settings without auth for things like keybinds + TTS
|
||||
prefs. Secrets (search-provider keys, IMAP/SMTP passwords) must NOT
|
||||
be exposed to non-admin callers.
|
||||
"""
|
||||
scrubbed = {}
|
||||
for k, v in (settings or {}).items():
|
||||
if _is_secret_key(k) and isinstance(v, str) and v:
|
||||
scrubbed[k] = "" # presence preserved, value blanked
|
||||
else:
|
||||
scrubbed[k] = v
|
||||
return scrubbed
|
||||
|
||||
@router.get("/settings")
|
||||
async def get_settings(request: Request):
|
||||
"""Returns app settings. Admins get the full set; non-admins get
|
||||
a scrubbed copy with secret keys blanked. The frontend uses this
|
||||
for keybinds + TTS prefs, so it stays callable without admin."""
|
||||
user = _get_current_user(request)
|
||||
settings = _load_settings()
|
||||
if user and auth_manager.is_admin(user):
|
||||
return settings
|
||||
return _scrub_settings(settings)
|
||||
|
||||
@router.post("/settings")
|
||||
async def set_settings(request: Request):
|
||||
"""Admin only: update app settings."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
body = await request.json()
|
||||
current = _load_settings()
|
||||
for key in DEFAULT_SETTINGS:
|
||||
if key in body:
|
||||
current[key] = body[key]
|
||||
_save_settings(current)
|
||||
return current
|
||||
|
||||
# ---- Integrations CRUD ----
|
||||
|
||||
# Run migration on startup
|
||||
migrate_from_settings()
|
||||
|
||||
@router.get("/integrations")
|
||||
async def list_integrations_route(request: Request):
|
||||
"""List all integrations (admin only, keys masked)."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
items = load_integrations()
|
||||
# Mask API keys for frontend display
|
||||
safe = []
|
||||
for item in items:
|
||||
copy = dict(item)
|
||||
if copy.get("api_key"):
|
||||
copy["api_key"] = copy["api_key"][:4] + "****"
|
||||
safe.append(copy)
|
||||
return {"integrations": safe}
|
||||
|
||||
@router.get("/integrations/presets")
|
||||
async def list_presets():
|
||||
"""List available integration presets."""
|
||||
return {"presets": {k: {kk: vv for kk, vv in v.items() if kk != "api_key"} for k, v in INTEGRATION_PRESETS.items()}}
|
||||
|
||||
@router.post("/integrations")
|
||||
async def create_integration(request: Request):
|
||||
"""Create a new integration (admin only)."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
body = await request.json()
|
||||
item = add_integration(body)
|
||||
return {"ok": True, "integration": item}
|
||||
|
||||
@router.put("/integrations/{integration_id}")
|
||||
async def update_integration_route(integration_id: str, request: Request):
|
||||
"""Update an existing integration (admin only)."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
body = await request.json()
|
||||
item = update_integration(integration_id, body)
|
||||
if not item:
|
||||
raise HTTPException(404, "Integration not found")
|
||||
return {"ok": True, "integration": item}
|
||||
|
||||
@router.delete("/integrations/{integration_id}")
|
||||
async def delete_integration_route(integration_id: str, request: Request):
|
||||
"""Delete an integration (admin only)."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
ok = delete_integration(integration_id)
|
||||
if not ok:
|
||||
raise HTTPException(404, "Integration not found")
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/integrations/{integration_id}/test")
|
||||
async def test_integration_route(integration_id: str, request: Request):
|
||||
"""Test connectivity to an integration (admin only)."""
|
||||
user = _get_current_user(request)
|
||||
if not user or not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
integ = get_integration(integration_id)
|
||||
if not integ:
|
||||
raise HTTPException(404, "Integration not found")
|
||||
preset = (integ.get("preset") or integ.get("name", "")).lower()
|
||||
|
||||
# ntfy is special: a GET / proves the server is reachable but
|
||||
# publishes nothing, so the user has no way to know whether
|
||||
# subscribers will actually receive notifications. Instead, do
|
||||
# the real thing — POST a one-line "connectivity test" message
|
||||
# to the topic the Reminders panel is configured to use. If the
|
||||
# subscriber app is wired up correctly, this is what the green
|
||||
# checkmark + a phone ping confirms together.
|
||||
if preset == "ntfy":
|
||||
import httpx
|
||||
from urllib.parse import urlparse
|
||||
# Strip any path/query the user accidentally pasted in the
|
||||
# base URL (e.g. `http://host:8091/odysseus`) — otherwise
|
||||
# the topic gets appended after the path and we publish to
|
||||
# `/odysseus/odysseus` (which ntfy 404s on). ntfy itself
|
||||
# only ever serves from the root.
|
||||
raw_base = (integ.get("base_url") or "").strip()
|
||||
parsed = urlparse(raw_base)
|
||||
base = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else raw_base.rstrip("/")
|
||||
settings = _load_settings()
|
||||
topic = (settings.get("reminder_ntfy_topic") or "reminders").strip() or "reminders"
|
||||
full_url = f"{base}/{topic}"
|
||||
api_key = integ.get("api_key", "")
|
||||
auth_type = (integ.get("auth_type") or "none").lower()
|
||||
headers = {
|
||||
"Title": "Odysseus connectivity test",
|
||||
"Tags": "white_check_mark",
|
||||
"Priority": "default",
|
||||
}
|
||||
if api_key:
|
||||
if auth_type == "bearer":
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
elif auth_type == "header":
|
||||
headers[integ.get("auth_header") or "Authorization"] = api_key
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||
r = await client.post(
|
||||
full_url,
|
||||
content="Connectivity test from Odysseus. If you see this on your phone, ntfy is wired up correctly.",
|
||||
headers=headers,
|
||||
)
|
||||
if r.is_success:
|
||||
# Tell the user EXACTLY where it went and what to
|
||||
# subscribe to on their phone, so they can match
|
||||
# without guesswork. The doubled-topic / wrong-host
|
||||
# mistakes are easier to spot when the actual URL
|
||||
# is right there in the success line.
|
||||
return {
|
||||
"ok": True,
|
||||
"message": (
|
||||
f"Sent to {full_url} — on your ntfy app, "
|
||||
f"subscribe to topic \"{topic}\" with server "
|
||||
f"\"{base}\" (or paste the full URL: {full_url})."
|
||||
),
|
||||
}
|
||||
return {"ok": False, "message": f"ntfy returned HTTP {r.status_code} from {full_url}: {r.text[:200]}"}
|
||||
except Exception as e:
|
||||
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}"[:300]}
|
||||
|
||||
# All other presets: GET against a known health endpoint.
|
||||
# Fall back to detecting from name if preset is missing.
|
||||
health_paths = {
|
||||
"miniflux": "/v1/me",
|
||||
"gitea": "/api/v1/version",
|
||||
"linkding": "/api/tags/",
|
||||
"homeassistant": "/api/",
|
||||
"home assistant": "/api/",
|
||||
}
|
||||
path = health_paths.get(preset, "/")
|
||||
result = await execute_api_call(integration_id, "GET", path)
|
||||
if result.get("exit_code", 1) == 0:
|
||||
return {"ok": True, "message": "Connection successful"}
|
||||
return {"ok": False, "message": (result.get("error") or "Connection failed")[:300]}
|
||||
|
||||
return router
|
||||
157
routes/backup_routes.py
Normal file
157
routes/backup_routes.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Backup routes — export/import user data (memories, presets, settings, skills, preferences)."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.settings import load_settings, save_settings, load_features, save_features
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRouter:
|
||||
router = APIRouter(tags=["backup"])
|
||||
|
||||
@router.get("/api/export")
|
||||
async def export_data(request: Request):
|
||||
"""Export all user data as a downloadable JSON file."""
|
||||
require_admin(request)
|
||||
user = get_current_user(request)
|
||||
|
||||
# Memories (filtered by owner when auth is enabled)
|
||||
memories = memory_manager.load(owner=user)
|
||||
|
||||
# Presets (shared across users — export all)
|
||||
presets = preset_manager.get_all()
|
||||
|
||||
# Skills (filtered by owner when auth is enabled)
|
||||
skills = skills_manager.load(owner=user)
|
||||
|
||||
# Settings
|
||||
settings = load_settings()
|
||||
|
||||
# Feature flags
|
||||
features = load_features()
|
||||
|
||||
# User preferences
|
||||
from routes.prefs_routes import _load_for_user
|
||||
preferences = _load_for_user(user)
|
||||
|
||||
export_data = {
|
||||
"version": 1,
|
||||
"exported_at": datetime.now().isoformat(),
|
||||
"exported_by": user,
|
||||
"memories": memories,
|
||||
"presets": presets,
|
||||
"skills": skills,
|
||||
"settings": settings,
|
||||
"features": features,
|
||||
"preferences": preferences,
|
||||
}
|
||||
|
||||
filename = f"odysseus_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
return Response(
|
||||
content=json.dumps(export_data, indent=2, ensure_ascii=False),
|
||||
media_type="application/json",
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"},
|
||||
)
|
||||
|
||||
@router.post("/api/import")
|
||||
async def import_data(request: Request):
|
||||
"""Import user data from a previously exported JSON file. Merges with existing data."""
|
||||
require_admin(request)
|
||||
user = get_current_user(request)
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
raise HTTPException(400, "Invalid JSON")
|
||||
|
||||
if not isinstance(body, dict):
|
||||
raise HTTPException(400, "Expected a JSON object")
|
||||
|
||||
imported = []
|
||||
|
||||
# ── Memories ──
|
||||
if "memories" in body and isinstance(body["memories"], list):
|
||||
existing = memory_manager.load_all()
|
||||
existing_texts = {e.get("text", "").strip().lower() for e in existing}
|
||||
added = 0
|
||||
for mem in body["memories"]:
|
||||
if not isinstance(mem, dict) or not mem.get("text"):
|
||||
continue
|
||||
if mem["text"].strip().lower() in existing_texts:
|
||||
continue # skip duplicates
|
||||
# Assign owner when auth is enabled
|
||||
if user and not mem.get("owner"):
|
||||
mem["owner"] = user
|
||||
existing.append(mem)
|
||||
existing_texts.add(mem["text"].strip().lower())
|
||||
added += 1
|
||||
memory_manager.save(existing)
|
||||
imported.append(f"{added} memories")
|
||||
|
||||
# ── Skills ──
|
||||
if "skills" in body and isinstance(body["skills"], list):
|
||||
existing = skills_manager.load_all()
|
||||
existing_ids = {s.get("id") for s in existing}
|
||||
existing_titles = {s.get("title", "").strip().lower() for s in existing}
|
||||
added = 0
|
||||
for skill in body["skills"]:
|
||||
if not isinstance(skill, dict) or not skill.get("title"):
|
||||
continue
|
||||
# Skip if same id or same title already exists
|
||||
if skill.get("id") in existing_ids:
|
||||
continue
|
||||
if skill["title"].strip().lower() in existing_titles:
|
||||
continue
|
||||
if user and not skill.get("owner"):
|
||||
skill["owner"] = user
|
||||
existing.append(skill)
|
||||
existing_ids.add(skill.get("id"))
|
||||
existing_titles.add(skill["title"].strip().lower())
|
||||
added += 1
|
||||
skills_manager.save(existing)
|
||||
imported.append(f"{added} skills")
|
||||
|
||||
# ── Presets ──
|
||||
if "presets" in body and isinstance(body["presets"], dict):
|
||||
current = preset_manager.get_all()
|
||||
for key, value in body["presets"].items():
|
||||
if isinstance(value, dict):
|
||||
current[key] = value
|
||||
elif isinstance(value, list):
|
||||
current[key] = value
|
||||
preset_manager.save(current)
|
||||
imported.append("presets")
|
||||
|
||||
# ── Settings ──
|
||||
if "settings" in body and isinstance(body["settings"], dict):
|
||||
current = load_settings()
|
||||
current.update(body["settings"])
|
||||
save_settings(current)
|
||||
imported.append("settings")
|
||||
|
||||
# ── Features ──
|
||||
if "features" in body and isinstance(body["features"], dict):
|
||||
current = load_features()
|
||||
current.update(body["features"])
|
||||
save_features(current)
|
||||
imported.append("features")
|
||||
|
||||
# ── Preferences ──
|
||||
if "preferences" in body and isinstance(body["preferences"], dict):
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
current = _load_for_user(user)
|
||||
current.update(body["preferences"])
|
||||
_save_for_user(user, current)
|
||||
imported.append("preferences")
|
||||
|
||||
if not imported:
|
||||
return {"ok": False, "message": "No recognized data found in the file"}
|
||||
|
||||
return {"ok": True, "imported": imported, "message": f"Imported: {', '.join(imported)}"}
|
||||
|
||||
return router
|
||||
1071
routes/calendar_routes.py
Normal file
1071
routes/calendar_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
802
routes/chat_helpers.py
Normal file
802
routes/chat_helpers.py
Normal file
@@ -0,0 +1,802 @@
|
||||
"""Shared helpers for chat routes — context building, post-response tasks, auth resolution."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.models import ChatMessage
|
||||
from core.database import SessionLocal
|
||||
from core.database import Session as DBSession, ModelEndpoint
|
||||
from src.llm_core import normalize_model_id
|
||||
from src.context_compactor import maybe_compact, trim_for_context
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.prompt_security import untrusted_context_message
|
||||
from routes.prefs_routes import _load_for_user as load_prefs_for_user
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Data containers ────────────────────────────────────────────────────── #
|
||||
|
||||
@dataclass
|
||||
class PresetInfo:
|
||||
"""Extracted preset parameters."""
|
||||
temperature: Optional[float]
|
||||
max_tokens: Optional[int]
|
||||
system_prompt: Optional[str]
|
||||
character_name: Optional[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessedMessage:
|
||||
"""Result of chat_handler.preprocess_message."""
|
||||
enhanced_message: str
|
||||
user_content: Any # str or list (multimodal)
|
||||
text_for_context: str
|
||||
youtube_transcripts: list
|
||||
attachment_meta: list
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatContext:
|
||||
"""Everything needed to call the LLM after context-building."""
|
||||
preface: list
|
||||
rag_sources: list
|
||||
web_sources: list
|
||||
used_memories: list
|
||||
messages: list
|
||||
context_length: int
|
||||
was_compacted: bool
|
||||
user: Optional[str]
|
||||
uprefs: dict
|
||||
preset: PresetInfo
|
||||
preprocessed: PreprocessedMessage
|
||||
# Documents auto-created server-side during preprocess (e.g. when an
|
||||
# attached fillable PDF gets rendered into a markdown editor doc).
|
||||
# The chat route emits a doc_update SSE event for each before streaming
|
||||
# begins, so the editor pane switches to the new doc immediately.
|
||||
auto_opened_docs: list = field(default_factory=list)
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────── #
|
||||
|
||||
def _enforce_chat_privileges(request, sess) -> None:
|
||||
"""Apply the per-user privilege gates (allowed_models + max_messages_per_day)
|
||||
that both /api/chat and /api/chat_stream must enforce BEFORE any LLM work.
|
||||
|
||||
Raises HTTPException(403) if the session's model is not in the user's
|
||||
allowlist, or HTTPException(429) if the user has hit their daily message
|
||||
cap. No-op for unauthenticated callers or when auth_manager is absent
|
||||
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
|
||||
which means empty allowed_models / zero cap → no-op for them.
|
||||
"""
|
||||
try:
|
||||
user = get_current_user(request)
|
||||
except Exception:
|
||||
user = None
|
||||
if not user:
|
||||
return
|
||||
auth_manager = getattr(getattr(request.app, "state", None), "auth_manager", None)
|
||||
if not auth_manager:
|
||||
return
|
||||
|
||||
privs = auth_manager.get_privileges(user) or {}
|
||||
allowed = privs.get("allowed_models") or []
|
||||
if allowed and sess.model and sess.model not in allowed:
|
||||
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
|
||||
|
||||
cap = int(privs.get("max_messages_per_day") or 0)
|
||||
if cap <= 0:
|
||||
return
|
||||
|
||||
from datetime import datetime as _dt, timedelta as _td
|
||||
from core.database import Session as _DbSess, ChatMessage as _Cm
|
||||
db = SessionLocal()
|
||||
try:
|
||||
count = (
|
||||
db.query(_Cm)
|
||||
.join(_DbSess, _Cm.session_id == _DbSess.id)
|
||||
.filter(_DbSess.owner == user,
|
||||
_Cm.role == "user",
|
||||
_Cm.timestamp >= _dt.utcnow() - _td(days=1))
|
||||
.count()
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
if count >= cap:
|
||||
raise HTTPException(429, f"Daily message limit reached ({cap}). Try again in 24 hours.")
|
||||
|
||||
|
||||
def needs_auto_name(name: str) -> bool:
|
||||
"""Check if a session still has its default/placeholder name."""
|
||||
if not name:
|
||||
return True
|
||||
if name.startswith("Chat:") or name == "Chat":
|
||||
return True
|
||||
# Default frontend name: "modelname HH:MM:SS AM/PM"
|
||||
if re.match(r'^.+ \d{1,2}:\d{2}:\d{2}\s*(AM|PM)$', name):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def auto_name_session(session_manager, sess):
|
||||
"""Generate a short title for a session from its first user message."""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
|
||||
# Find first user message
|
||||
first_msg = ""
|
||||
for msg in sess.history:
|
||||
if msg.role == "user":
|
||||
content = msg.content
|
||||
if isinstance(content, list):
|
||||
content = next(
|
||||
(i.get("text", "") for i in content if isinstance(i, dict) and i.get("type") == "text"),
|
||||
"",
|
||||
)
|
||||
first_msg = str(content)[:500]
|
||||
break
|
||||
|
||||
if not first_msg:
|
||||
return
|
||||
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
)
|
||||
|
||||
# max_tokens big enough that reasoning models (Minimax M2,
|
||||
# DeepSeek R1, QwQ, etc.) have headroom for <think>…</think>
|
||||
# plus the actual title — 200 used to clip them mid-reasoning
|
||||
# so strip_think left an empty string and no rename happened.
|
||||
# Timeout matches: 60s gives slow local reasoners room to finish.
|
||||
title = await llm_call_async(
|
||||
t_url,
|
||||
t_model,
|
||||
[
|
||||
{"role": "system", "content": "Generate a short title (3-6 words, no quotes) for a conversation that starts with this message. Reply with ONLY the title, nothing else. Do NOT include any thinking, reasoning, or explanation — just the title."},
|
||||
{"role": "user", "content": first_msg},
|
||||
],
|
||||
temperature=0.3,
|
||||
max_tokens=4096,
|
||||
headers=t_headers,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
title = title.strip().strip('"\'').strip()
|
||||
# Strip <think>/<thinking> blocks (closed, dangling, or stray tags)
|
||||
# via the central helper.
|
||||
from src.text_helpers import strip_think
|
||||
title = strip_think(title, prose=False, prompt_echo=False)
|
||||
if title and len(title) < 80:
|
||||
session_manager.update_session_name(sess.id, title)
|
||||
logger.info(f"Auto-named session {sess.id}: {title}")
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Auto-name failed for {sess.id}: {e}\n{traceback.format_exc()}")
|
||||
|
||||
|
||||
def try_fallback_endpoint(sess, session_id: str) -> dict | None:
|
||||
"""Find an alternative working endpoint when the current one fails.
|
||||
|
||||
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
|
||||
"""
|
||||
import requests as _req
|
||||
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
|
||||
|
||||
current_url = sess.endpoint_url or ""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
endpoints = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True
|
||||
).all()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
for ep in endpoints:
|
||||
base = normalize_base(ep.base_url)
|
||||
# Skip current endpoint
|
||||
if current_url and base in current_url:
|
||||
continue
|
||||
# Quick ping
|
||||
ping_url = base + "/models"
|
||||
headers = {}
|
||||
if ep.api_key:
|
||||
headers["Authorization"] = f"Bearer {ep.api_key}"
|
||||
try:
|
||||
r = _req.get(ping_url, headers=headers, timeout=5)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
|
||||
if not models:
|
||||
continue
|
||||
# Found a working endpoint — update session
|
||||
new_model = models[0]
|
||||
chat_url = build_chat_url(base)
|
||||
new_headers = build_headers(ep.api_key, base)
|
||||
|
||||
sess.model = new_model
|
||||
sess.endpoint_url = chat_url
|
||||
sess.headers = new_headers
|
||||
|
||||
# Persist
|
||||
_db = SessionLocal()
|
||||
try:
|
||||
_db.query(DBSession).filter(DBSession.id == session_id).update({
|
||||
"model": new_model,
|
||||
"endpoint_url": chat_url,
|
||||
"headers": json.dumps(new_headers),
|
||||
})
|
||||
_db.commit()
|
||||
finally:
|
||||
_db.close()
|
||||
|
||||
logger.info(f"Fallback: switched session {session_id} from {current_url} to {ep.name} ({new_model})")
|
||||
return {
|
||||
"model": new_model,
|
||||
"endpoint_url": chat_url,
|
||||
"endpoint_name": ep.name,
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_preset(chat_handler, preset_id) -> PresetInfo:
|
||||
"""Extract preset parameters via chat_handler."""
|
||||
temperature, max_tokens, system_prompt, char_name = (
|
||||
chat_handler.validate_and_extract_preset(preset_id)
|
||||
)
|
||||
return PresetInfo(
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_prompt,
|
||||
character_name=char_name,
|
||||
)
|
||||
|
||||
|
||||
async def preprocess(
|
||||
chat_handler, message, att_ids, sess,
|
||||
auto_opened_docs: Optional[list] = None,
|
||||
) -> PreprocessedMessage:
|
||||
"""Run chat_handler.preprocess_message and wrap the result."""
|
||||
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
|
||||
await chat_handler.preprocess_message(
|
||||
message, att_ids, sess, auto_opened_docs=auto_opened_docs
|
||||
)
|
||||
)
|
||||
return PreprocessedMessage(
|
||||
enhanced_message=enhanced,
|
||||
user_content=user_content,
|
||||
text_for_context=text_ctx,
|
||||
youtube_transcripts=yt_transcripts,
|
||||
attachment_meta=att_meta,
|
||||
)
|
||||
|
||||
|
||||
def add_user_message(sess, chat_handler, preprocessed: PreprocessedMessage, incognito: bool = False):
|
||||
"""Add user message to session history and update session name.
|
||||
In incognito mode, still add to in-memory history (for conversation context)
|
||||
but skip session name update (which would persist)."""
|
||||
user_meta = {"attachments": preprocessed.attachment_meta} if preprocessed.attachment_meta else None
|
||||
sess.add_message(ChatMessage("user", preprocessed.user_content, metadata=user_meta))
|
||||
if not incognito:
|
||||
chat_handler.update_session_name_if_needed(sess, preprocessed.text_for_context)
|
||||
|
||||
|
||||
def fire_message_event(request, webhook_manager, session_id: str, sess, message: str, compare_mode: bool = False):
|
||||
"""Fire webhook and event_bus events for a new user message."""
|
||||
if webhook_manager and not compare_mode:
|
||||
asyncio.create_task(webhook_manager.fire("chat.message", {
|
||||
"session_id": session_id, "model": sess.model, "message": message[:2000],
|
||||
}))
|
||||
from src.event_bus import fire_event
|
||||
user = get_current_user(request)
|
||||
fire_event("message_sent", user)
|
||||
|
||||
|
||||
def resolve_session_auth(sess, session_id: str):
|
||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||
)
|
||||
if has_auth:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.endpoint_resolver import build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else ""
|
||||
if domain:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first()
|
||||
if ep and ep.api_key:
|
||||
sess.headers = build_headers(ep.api_key, ep.base_url)
|
||||
db.query(DBSession).filter(DBSession.id == session_id).update(
|
||||
{"headers": json.dumps(sess.headers)}
|
||||
)
|
||||
db.commit()
|
||||
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve session headers: {e}")
|
||||
|
||||
|
||||
async def build_chat_context(
|
||||
sess,
|
||||
request,
|
||||
chat_handler,
|
||||
chat_processor,
|
||||
message: str,
|
||||
session_id: str,
|
||||
preset_id=None,
|
||||
att_ids: list = None,
|
||||
use_web=None,
|
||||
use_rag=None,
|
||||
use_research=None,
|
||||
time_filter=None,
|
||||
incognito: bool = False,
|
||||
no_memory: bool = False,
|
||||
search_context: str = None,
|
||||
compare_mode: bool = False,
|
||||
webhook_manager=None,
|
||||
use_enhanced_message: bool = False,
|
||||
agent_mode: bool = False,
|
||||
) -> ChatContext:
|
||||
"""Build the full context (preface + messages) for an LLM call.
|
||||
|
||||
This is the shared logic between /chat and /chat_stream — preset extraction,
|
||||
message preprocessing, memory/RAG/web injection, compaction, normalization.
|
||||
"""
|
||||
# Preset
|
||||
preset = extract_preset(chat_handler, preset_id)
|
||||
|
||||
# Preprocess message (CoT, YouTube, VL images, build content). The
|
||||
# auto_opened_docs collector captures any docs created server-side
|
||||
# (e.g. fillable PDF → markdown editor doc) so the chat route can
|
||||
# announce them to the frontend before streaming.
|
||||
auto_opened_docs: list = []
|
||||
preprocessed = await preprocess(
|
||||
chat_handler, message, att_ids or [], sess,
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
)
|
||||
|
||||
# Add user message to history
|
||||
add_user_message(sess, chat_handler, preprocessed, incognito=incognito)
|
||||
|
||||
# Fire events
|
||||
if not incognito:
|
||||
fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode)
|
||||
|
||||
# Resolve user prefs
|
||||
user = get_current_user(request)
|
||||
uprefs = load_prefs_for_user(user)
|
||||
|
||||
# Memory enabled?
|
||||
mem_enabled = not incognito and not no_memory and uprefs.get("memory_enabled", True)
|
||||
# Skills injection respects its own enable toggle (mirrors memory_enabled).
|
||||
# When off, the "Available skills" index is not added to the prompt.
|
||||
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
|
||||
logger.debug(
|
||||
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
|
||||
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
|
||||
)
|
||||
|
||||
# Use RAG?
|
||||
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
|
||||
if incognito:
|
||||
use_rag_val = False
|
||||
|
||||
# If pre-fetched search context was provided (compare mode), skip live web search
|
||||
skip_web = bool(search_context)
|
||||
|
||||
# Build context preface
|
||||
# The stream path uses enhanced_message (with CoT/preprocessing applied),
|
||||
# the sync path uses text_for_context.
|
||||
_ctx_msg = preprocessed.enhanced_message if use_enhanced_message else preprocessed.text_for_context
|
||||
_preface_kwargs = dict(
|
||||
message=_ctx_msg,
|
||||
session=sess,
|
||||
use_web=use_web and not skip_web,
|
||||
use_memory=mem_enabled,
|
||||
time_filter=time_filter,
|
||||
preset_system_prompt=preset.system_prompt,
|
||||
owner=user,
|
||||
character_name=preset.character_name,
|
||||
agent_mode=agent_mode,
|
||||
incognito=incognito,
|
||||
use_skills=skills_enabled,
|
||||
)
|
||||
if use_rag is not None:
|
||||
_preface_kwargs["use_rag"] = use_rag_val
|
||||
preface, rag_sources, web_sources = chat_processor.build_context_preface(**_preface_kwargs)
|
||||
|
||||
# Capture used memories immediately
|
||||
used_memories = getattr(chat_processor, '_last_used_memories', [])
|
||||
|
||||
# Inject pre-fetched search context (compare mode)
|
||||
if search_context:
|
||||
preface.append(untrusted_context_message("prefetched search context", search_context))
|
||||
|
||||
# YouTube transcripts
|
||||
for transcript in preprocessed.youtube_transcripts:
|
||||
preface.append(untrusted_context_message("youtube transcript", transcript))
|
||||
|
||||
# Normalize model ID
|
||||
norm = normalize_model_id(sess.endpoint_url, sess.model)
|
||||
if norm:
|
||||
sess.model = norm
|
||||
|
||||
# Build messages
|
||||
messages = preface + sess.get_context_messages()
|
||||
|
||||
# Auto-compact
|
||||
messages, context_length, was_compacted = await maybe_compact(
|
||||
sess, sess.endpoint_url, sess.model, messages, sess.headers,
|
||||
)
|
||||
messages = trim_for_context(messages, context_length)
|
||||
|
||||
return ChatContext(
|
||||
preface=preface,
|
||||
rag_sources=rag_sources,
|
||||
web_sources=web_sources,
|
||||
used_memories=used_memories,
|
||||
messages=messages,
|
||||
context_length=context_length,
|
||||
was_compacted=was_compacted,
|
||||
user=user,
|
||||
uprefs=uprefs,
|
||||
preset=preset,
|
||||
preprocessed=preprocessed,
|
||||
auto_opened_docs=auto_opened_docs,
|
||||
)
|
||||
|
||||
|
||||
def accumulate_token_usage(session_id: str, metrics: dict):
|
||||
"""Add input/output token counts to the session's running totals."""
|
||||
in_t = metrics.get("input_tokens", 0)
|
||||
out_t = metrics.get("output_tokens", 0)
|
||||
if not (in_t or out_t):
|
||||
return
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_s = db.query(DBSession).filter(DBSession.id == session_id).first()
|
||||
if db_s:
|
||||
db_s.total_input_tokens = (db_s.total_input_tokens or 0) + in_t
|
||||
db_s.total_output_tokens = (db_s.total_output_tokens or 0) + out_t
|
||||
db.commit()
|
||||
except Exception:
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _normalize_thinking(text: str) -> str:
|
||||
"""Wrap inline thinking patterns in <think> tags so they persist on reload.
|
||||
|
||||
Handles:
|
||||
- "Thinking Process:" (Qwen3.5)
|
||||
- Gemma-style inline reasoning ("The user said/asked...", "I should/need to...")
|
||||
- Garbled <think> tags (reasoning before the tag, unclosed tags)
|
||||
"""
|
||||
import re
|
||||
if not text:
|
||||
return text
|
||||
reasoning_prefix_re = re.compile(
|
||||
r'^\s*(?:thinking(?:\s+process)?\s*:|the user |i need |i should |i will |they are |the question |i can )',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
thinking_prefix_re = re.compile(r'^thinking(?:\s+process)?\s*:\s*', re.IGNORECASE)
|
||||
|
||||
# Handle garbled <think> tags: reasoning text followed by <think> as separator
|
||||
# e.g. "The user said...I should respond.\n<think>Hey! What's up?"
|
||||
garbled = re.match(
|
||||
r'^([\s\S]+?)\n*<think(?:ing)?>\s*([\s\S]*?)(?:</think(?:ing)?>)?\s*$',
|
||||
text, re.IGNORECASE
|
||||
)
|
||||
if garbled:
|
||||
before = garbled.group(1).strip()
|
||||
after = garbled.group(2).strip()
|
||||
# Only treat as garbled if the part before <think> looks like reasoning
|
||||
reasoning_starts = (
|
||||
'The user ', 'I need ', 'I should ', 'I will ',
|
||||
'They are ', 'The question ', 'I can ',
|
||||
'Thinking Process', 'Thinking:',
|
||||
)
|
||||
stripped_before = before.lstrip()
|
||||
if any(stripped_before.startswith(p) for p in reasoning_starts) or reasoning_prefix_re.match(stripped_before):
|
||||
# Strip "Thinking:" prefix from the thinking content
|
||||
stripped_before = thinking_prefix_re.sub('', stripped_before)
|
||||
return '<think>' + stripped_before + '</think>\n' + after
|
||||
|
||||
if '<think' in text.lower():
|
||||
return text # already has proper think tags
|
||||
|
||||
# Qwen3.5: "Thinking Process:" or "Thinking:" prefix
|
||||
if thinking_prefix_re.match(text.lstrip()):
|
||||
# Try clean boundary first
|
||||
m = re.match(
|
||||
r'^(Thinking(?:\s+Process)?:[\s\S]*?)(\n\n(?=[A-Z]|Hey|Yo|Hi|Sure|I |What|Here|Let|The |This |OK|Ok|Yes|No |So |Well |Thank|Alright|Of course|Absolutely|Great|Hello|As ))',
|
||||
text, re.IGNORECASE | re.MULTILINE
|
||||
)
|
||||
if m:
|
||||
think = thinking_prefix_re.sub('', m.group(1)).strip()
|
||||
return '<think>' + think + '</think>' + text[m.end()-2:]
|
||||
# Fallback: find last non-indented paragraph as reply
|
||||
parts = text.split('\n\n')
|
||||
for i in range(len(parts) - 1, 0, -1):
|
||||
line = parts[i].strip()
|
||||
if line and not re.match(r'^[\d*\-\s(]', line) and len(line) > 5:
|
||||
think = thinking_prefix_re.sub('', '\n\n'.join(parts[:i])).strip()
|
||||
reply = '\n\n'.join(parts[i:])
|
||||
return '<think>' + think + '</think>\n\n' + reply
|
||||
# Last resort: look for a quoted final response inside the thinking
|
||||
# Qwen often drafts the reply as "Option: ..." or * "reply text"
|
||||
last_quote = re.findall(r'["\u201c]([^"\u201d]{10,})["\u201d]', text)
|
||||
if last_quote:
|
||||
reply = last_quote[-1].strip()
|
||||
think = thinking_prefix_re.sub('', text).strip()
|
||||
return '<think>' + think + '</think>\n\n' + reply
|
||||
# Truly no reply found
|
||||
think = thinking_prefix_re.sub('', text).strip()
|
||||
return '<think>' + think + '</think>'
|
||||
|
||||
# Gemma-style: starts with reasoning ("The user", "I need", "I should", etc.)
|
||||
stripped_text = text.lstrip()
|
||||
first_line = stripped_text.split('\n')[0].strip()
|
||||
reasoning_starts = (
|
||||
'The user ', 'I need ', 'I should ', 'I will ',
|
||||
'They are ', 'The question ', 'I can ',
|
||||
)
|
||||
reply_starts = (
|
||||
'Hey', 'Hi ', 'Hi!', 'Hello', 'Sure', 'Yes', 'No ', 'No,', 'Yo', 'OK',
|
||||
'Here', 'Absolutely', 'Of course', 'Great', 'Alright',
|
||||
'Thanks', 'Welcome', 'Good ', "I'm happy", "I'd be",
|
||||
)
|
||||
if any(first_line.startswith(p) for p in reasoning_starts):
|
||||
# Try line-by-line split first
|
||||
lines = stripped_text.split('\n')
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
if i > 0 and any(stripped.startswith(p) for p in reply_starts):
|
||||
think = '\n'.join(lines[:i])
|
||||
reply = '\n'.join(lines[i:])
|
||||
return '<think>' + think + '</think>\n' + reply
|
||||
|
||||
# Try within-line split — model mashed thinking + reply on one line
|
||||
# Look for reply pattern after a period or sentence end
|
||||
for p in reply_starts:
|
||||
# Match: "...reasoning text.Reply text" or "...reasoning text. Reply text"
|
||||
pattern = r'([.!?])\s*(' + re.escape(p) + r')'
|
||||
m = re.search(pattern, stripped_text)
|
||||
if m and m.start() > 20: # at least 20 chars of reasoning before
|
||||
think = stripped_text[:m.start() + 1] # include the period
|
||||
reply = stripped_text[m.start() + 1:].lstrip()
|
||||
return '<think>' + think + '</think>\n' + reply
|
||||
|
||||
# Last resort: find last non-reasoning line
|
||||
for i in range(len(lines) - 1, 0, -1):
|
||||
stripped = lines[i].strip()
|
||||
if stripped and not any(stripped.startswith(p) for p in reasoning_starts) and not stripped.startswith('*') and len(stripped) > 3:
|
||||
think = '\n'.join(lines[:i])
|
||||
reply = '\n'.join(lines[i:])
|
||||
return '<think>' + think + '</think>\n' + reply
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def _extract_thinking_meta(text: str) -> dict | None:
|
||||
"""Extract thinking content into metadata, return {thinking, reply, time} or None."""
|
||||
import re
|
||||
if not text:
|
||||
return None
|
||||
|
||||
# Check for <think> tags (native or injected)
|
||||
time_match = re.search(r'<think(?:ing)?\s+time="([\d.]+)"', text)
|
||||
think_time = time_match.group(1) if time_match else None
|
||||
# Strip time attr for parsing
|
||||
clean = re.sub(r'<think(?:ing)?\s+time="[\d.]+"', '<think', text)
|
||||
|
||||
think_match = re.match(r'^[\s]*<think(?:ing)?>([\s\S]*?)</think(?:ing)?>\s*([\s\S]*)', clean, re.IGNORECASE)
|
||||
if think_match:
|
||||
thinking = think_match.group(1).strip()
|
||||
reply = think_match.group(2).strip()
|
||||
# Only strip the thinking out into metadata when there's an actual reply
|
||||
# left over. If reply is empty (model hit max_tokens inside <think>, or
|
||||
# the turn was reasoning-only), keep the raw text as content — otherwise
|
||||
# the saved message has empty content and the bubble looks blank on
|
||||
# reload. The renderer's processWithThinking still extracts the <think>
|
||||
# block visually at display time, so nothing changes for the normal case.
|
||||
if thinking and reply:
|
||||
return {"thinking": thinking, "reply": reply, "time": think_time}
|
||||
|
||||
# Detect Thinking Process: or Gemma-style reasoning
|
||||
normalized = _normalize_thinking(text)
|
||||
if '<think>' in normalized:
|
||||
think_match2 = re.match(r'^[\s]*<think(?:ing)?>([\s\S]*?)</think(?:ing)?>\s*([\s\S]*)', normalized, re.IGNORECASE)
|
||||
if think_match2:
|
||||
thinking = think_match2.group(1).strip()
|
||||
reply = think_match2.group(2).strip()
|
||||
if thinking and reply:
|
||||
return {"thinking": thinking, "reply": reply, "time": think_time}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def clean_thinking_for_save(content: str, metadata: dict | None = None) -> tuple[str, dict]:
|
||||
"""Extract thinking from content into metadata. Use for save paths that bypass save_assistant_response."""
|
||||
md = dict(metadata) if metadata else {}
|
||||
info = _extract_thinking_meta(content)
|
||||
if info:
|
||||
md["thinking"] = info["thinking"]
|
||||
if info.get("time"):
|
||||
md["thinking_time"] = info["time"]
|
||||
return info["reply"], md
|
||||
return content, md
|
||||
|
||||
|
||||
def save_assistant_response(
|
||||
sess,
|
||||
session_manager,
|
||||
session_id: str,
|
||||
full_response: str,
|
||||
last_metrics: dict | None,
|
||||
*,
|
||||
character_name: str = None,
|
||||
web_sources: list = None,
|
||||
rag_sources: list = None,
|
||||
research_sources: list = None,
|
||||
used_memories: list = None,
|
||||
do_research: bool = False,
|
||||
tool_events: list = None,
|
||||
incognito: bool = False,
|
||||
):
|
||||
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
|
||||
md = dict(last_metrics) if last_metrics else {}
|
||||
md["model"] = sess.model
|
||||
if character_name:
|
||||
md["character_name"] = character_name
|
||||
if web_sources:
|
||||
md["web_sources"] = web_sources
|
||||
if rag_sources:
|
||||
md["rag_sources"] = rag_sources
|
||||
if research_sources:
|
||||
md["research_sources"] = research_sources
|
||||
if used_memories:
|
||||
md["memories_used"] = used_memories
|
||||
if do_research and not research_sources:
|
||||
md["research_clarification"] = True
|
||||
if tool_events:
|
||||
md["tool_events"] = tool_events
|
||||
|
||||
# Extract thinking into metadata (don't pollute message content with <think> tags)
|
||||
_think_info = _extract_thinking_meta(full_response)
|
||||
if _think_info:
|
||||
md["thinking"] = _think_info["thinking"]
|
||||
md["thinking_time"] = _think_info.get("time")
|
||||
_content = _think_info["reply"]
|
||||
else:
|
||||
_content = full_response
|
||||
sess.add_message(ChatMessage("assistant", _content, metadata=md))
|
||||
|
||||
if not incognito:
|
||||
from core.database import update_session_last_accessed
|
||||
update_session_last_accessed(session_id)
|
||||
session_manager.save_sessions()
|
||||
|
||||
# Return the persisted message's DB id so the stream can wire it onto the
|
||||
# freshly-rendered bubble — lets the user edit/delete a just-streamed reply
|
||||
# without reloading. Incognito returns None: those messages are ephemeral,
|
||||
# so we don't hand out an edit/delete handle for them.
|
||||
if incognito:
|
||||
return None
|
||||
try:
|
||||
_last = sess.history[-1]
|
||||
_meta = getattr(_last, "metadata", None)
|
||||
if isinstance(_meta, dict):
|
||||
return _meta.get("_db_id")
|
||||
except (IndexError, AttributeError):
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def run_post_response_tasks(
|
||||
sess,
|
||||
session_manager,
|
||||
session_id: str,
|
||||
message: str,
|
||||
full_response: str,
|
||||
last_metrics: dict | None,
|
||||
uprefs: dict,
|
||||
memory_manager,
|
||||
memory_vector,
|
||||
webhook_manager,
|
||||
*,
|
||||
incognito: bool = False,
|
||||
compare_mode: bool = False,
|
||||
character_name: str = None,
|
||||
agent_rounds: int = 0,
|
||||
agent_tool_calls: int = 0,
|
||||
skills_manager=None,
|
||||
owner: str = None,
|
||||
extract_skills: bool = True,
|
||||
):
|
||||
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
|
||||
# Memory extraction — only every 4th message pair to avoid excess LLM calls
|
||||
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
|
||||
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
|
||||
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
|
||||
from services.memory.memory_extractor import extract_and_store
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
)
|
||||
asyncio.create_task(extract_and_store(
|
||||
sess, memory_manager, memory_vector,
|
||||
t_url, t_model, t_headers,
|
||||
))
|
||||
|
||||
# Skill extraction from complex agent runs. Only when the user actually
|
||||
# chose agent mode — not a chat we auto-escalated for a notes/calendar
|
||||
# intent, and never in incognito/compare.
|
||||
auto_skills_enabled = bool(uprefs.get("auto_skills", True))
|
||||
# Quiet by default — full gate/dispatch/start trace runs at DEBUG so
|
||||
# users can re-enable diagnostics with LOG_LEVEL=DEBUG when something
|
||||
# silently breaks. INFO-level only shows the outcome inside
|
||||
# maybe_extract_skill (Auto-extracted / dropped / failed).
|
||||
logger.debug(
|
||||
"[skill-extract] gate: extract_skills=%s auto_skills=%s incognito=%s "
|
||||
"compare=%s rounds=%d tools=%d skills_manager=%s",
|
||||
extract_skills, auto_skills_enabled, incognito, compare_mode,
|
||||
agent_rounds, agent_tool_calls, "set" if skills_manager else "MISSING",
|
||||
)
|
||||
if (
|
||||
extract_skills
|
||||
and auto_skills_enabled
|
||||
and not incognito
|
||||
and not compare_mode
|
||||
and (agent_rounds >= 2 or agent_tool_calls >= 2)
|
||||
):
|
||||
if skills_manager is None:
|
||||
logger.warning(
|
||||
"[skill-extract] gate PASSED but skills_manager is None — "
|
||||
"extraction skipped. (Bug: caller didn't pass skills_manager.)"
|
||||
)
|
||||
else:
|
||||
from services.memory.skill_extractor import maybe_extract_skill
|
||||
from src.task_endpoint import resolve_task_endpoint
|
||||
s_url, s_model, s_headers = resolve_task_endpoint(
|
||||
sess.endpoint_url, sess.model, sess.headers,
|
||||
)
|
||||
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
|
||||
asyncio.create_task(maybe_extract_skill(
|
||||
sess, skills_manager,
|
||||
s_url, s_model, s_headers,
|
||||
agent_rounds, agent_tool_calls,
|
||||
owner=owner,
|
||||
))
|
||||
|
||||
# Token accumulation
|
||||
if last_metrics:
|
||||
accumulate_token_usage(session_id, last_metrics)
|
||||
|
||||
# Webhook
|
||||
if webhook_manager and not compare_mode:
|
||||
asyncio.create_task(webhook_manager.fire("chat.completed", {
|
||||
"session_id": session_id, "model": sess.model,
|
||||
"user_message": message, "response": full_response[:2000],
|
||||
}))
|
||||
|
||||
# Auto-name
|
||||
if needs_auto_name(sess.name):
|
||||
asyncio.create_task(auto_name_session(session_manager, sess))
|
||||
1114
routes/chat_routes.py
Normal file
1114
routes/chat_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
60
routes/cleanup_routes.py
Normal file
60
routes/cleanup_routes.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# routes/cleanup_routes.py
|
||||
"""Routes for cleanup operations."""
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from src.cleanup_service import get_cleanup_preview, cleanup_sessions
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_cleanup_routes(session_manager):
|
||||
"""
|
||||
Setup cleanup-related routes.
|
||||
|
||||
Args:
|
||||
session_manager: SessionManager instance
|
||||
|
||||
Returns:
|
||||
APIRouter instance with cleanup routes
|
||||
"""
|
||||
router = APIRouter(prefix="/api/cleanup")
|
||||
|
||||
@router.get("/preview")
|
||||
async def cleanup_preview(request: Request):
|
||||
"""
|
||||
Preview what would be cleaned up without making any changes.
|
||||
|
||||
Returns:
|
||||
JSON response with lists of sessions that would be archived/deleted and estimated space savings
|
||||
"""
|
||||
user = get_current_user(request)
|
||||
try:
|
||||
preview = await get_cleanup_preview(owner=user)
|
||||
return preview
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup preview failed: {e}")
|
||||
raise HTTPException(500, "Cleanup preview generation failed")
|
||||
|
||||
@router.post("")
|
||||
async def cleanup_endpoint(request: Request):
|
||||
"""
|
||||
Perform cleanup operations:
|
||||
1. Archive inactive sessions (not accessed for 7 days)
|
||||
2. Delete old sessions (archived, not important, not accessed for 14+ days, with fewer than 10 messages)
|
||||
|
||||
Returns:
|
||||
JSON response with counts of deleted and archived sessions, and space freed
|
||||
"""
|
||||
user = get_current_user(request)
|
||||
try:
|
||||
archived_count, deleted_count, space_freed_mb = await cleanup_sessions(session_manager, owner=user)
|
||||
return {
|
||||
"archived_count": archived_count,
|
||||
"deleted_count": deleted_count,
|
||||
"space_freed_mb": round(space_freed_mb, 2)
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Cleanup failed: {e}")
|
||||
raise HTTPException(500, "Cleanup operation failed")
|
||||
|
||||
return router
|
||||
246
routes/compare_routes.py
Normal file
246
routes/compare_routes.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# routes/compare_routes.py
|
||||
"""Model A/B comparison routes."""
|
||||
import json
|
||||
import uuid
|
||||
import random
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from typing import List
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
from core.database import Comparison, SessionLocal
|
||||
from core.session_manager import SessionManager
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/compare", tags=["compare"])
|
||||
|
||||
|
||||
class RecordVoteRequest(BaseModel):
|
||||
prompt: str
|
||||
models: List[str]
|
||||
winner: str # model name or "tie"
|
||||
is_blind: bool = True
|
||||
|
||||
|
||||
def setup_compare_routes(session_manager: SessionManager):
|
||||
"""Setup comparison routes."""
|
||||
|
||||
@router.post("/start")
|
||||
def start_comparison(
|
||||
request: Request,
|
||||
prompt: str = Form(...),
|
||||
model_a: str = Form(...),
|
||||
model_b: str = Form(...),
|
||||
endpoint_a: str = Form(...),
|
||||
endpoint_b: str = Form(...),
|
||||
is_blind: str = Form("true"),
|
||||
):
|
||||
"""Create two ephemeral sessions and a comparison record.
|
||||
|
||||
Returns the comparison ID and the two session IDs so the client
|
||||
can fire two independent SSE streams to /api/chat_stream.
|
||||
"""
|
||||
comp_id = str(uuid.uuid4())
|
||||
sid_a = str(uuid.uuid4())
|
||||
sid_b = str(uuid.uuid4())
|
||||
|
||||
# Create ephemeral sessions (prefixed [CMP])
|
||||
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
|
||||
user = getattr(request.state, 'current_user', None)
|
||||
session_manager.create_session(
|
||||
session_id=sid,
|
||||
name=f"[CMP] {model.split('/')[-1]}",
|
||||
endpoint_url=endpoint,
|
||||
model=model,
|
||||
rag=False,
|
||||
owner=user,
|
||||
)
|
||||
# Copy API key from endpoint config
|
||||
db = SessionLocal()
|
||||
try:
|
||||
from core.database import ModelEndpoint
|
||||
# Find matching endpoint by URL
|
||||
ep = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.base_url == endpoint.replace('/chat/completions', '')
|
||||
).first()
|
||||
if ep and ep.api_key:
|
||||
s = session_manager.sessions.get(sid)
|
||||
if s:
|
||||
s.headers = {"Authorization": f"Bearer {ep.api_key}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Blind mapping: randomly assign left/right
|
||||
blind = str(is_blind).lower() == "true"
|
||||
if blind:
|
||||
mapping = {"left": "a", "right": "b"}
|
||||
if random.random() > 0.5:
|
||||
mapping = {"left": "b", "right": "a"}
|
||||
else:
|
||||
mapping = {"left": "a", "right": "b"}
|
||||
|
||||
# Store comparison record
|
||||
db = SessionLocal()
|
||||
try:
|
||||
comp = Comparison(
|
||||
id=comp_id,
|
||||
prompt=prompt,
|
||||
model_a=model_a,
|
||||
model_b=model_b,
|
||||
endpoint_a=endpoint_a,
|
||||
endpoint_b=endpoint_b,
|
||||
is_blind=blind,
|
||||
blind_mapping=json.dumps(mapping),
|
||||
owner=user,
|
||||
)
|
||||
db.add(comp)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Map session IDs to left/right based on blind mapping
|
||||
session_left = sid_a if mapping["left"] == "a" else sid_b
|
||||
session_right = sid_a if mapping["right"] == "a" else sid_b
|
||||
|
||||
return {
|
||||
"id": comp_id,
|
||||
"session_left": session_left,
|
||||
"session_right": session_right,
|
||||
"model_left": model_a if mapping["left"] == "a" else model_b,
|
||||
"model_right": model_a if mapping["right"] == "a" else model_b,
|
||||
"is_blind": blind,
|
||||
"mapping": mapping,
|
||||
}
|
||||
|
||||
@router.post("/{comp_id}/vote")
|
||||
def vote_comparison(
|
||||
request: Request,
|
||||
comp_id: str,
|
||||
winner: str = Form(...), # "left", "right", or "tie"
|
||||
):
|
||||
"""Record the user's vote and reveal model names if blind."""
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
comp = db.query(Comparison).filter(Comparison.id == comp_id).first()
|
||||
if not comp:
|
||||
raise HTTPException(404, "Comparison not found")
|
||||
# SECURITY: strict ownership — null-owner Comparisons were
|
||||
# accessible to every user.
|
||||
if user and comp.owner != user:
|
||||
raise HTTPException(404, "Comparison not found")
|
||||
if comp.winner:
|
||||
raise HTTPException(400, "Already voted")
|
||||
|
||||
mapping = json.loads(comp.blind_mapping) if comp.blind_mapping else {"left": "a", "right": "b"}
|
||||
|
||||
if winner == "tie":
|
||||
comp.winner = "tie"
|
||||
elif winner == "left":
|
||||
comp.winner = mapping["left"]
|
||||
elif winner == "right":
|
||||
comp.winner = mapping["right"]
|
||||
else:
|
||||
raise HTTPException(400, "winner must be 'left', 'right', or 'tie'")
|
||||
|
||||
comp.voted_at = datetime.utcnow()
|
||||
db.commit()
|
||||
|
||||
return {
|
||||
"winner": comp.winner,
|
||||
"model_a": comp.model_a,
|
||||
"model_b": comp.model_b,
|
||||
"revealed": {
|
||||
"left": comp.model_a if mapping["left"] == "a" else comp.model_b,
|
||||
"right": comp.model_a if mapping["right"] == "a" else comp.model_b,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/record")
|
||||
def record_comparison(request: Request, body: RecordVoteRequest):
|
||||
"""Lightweight endpoint to record a comparison vote from the frontend."""
|
||||
user = get_current_user(request)
|
||||
comp_id = str(uuid.uuid4())
|
||||
|
||||
model_a = body.models[0] if len(body.models) > 0 else ""
|
||||
model_b = body.models[1] if len(body.models) > 1 else ""
|
||||
|
||||
# For N>2 models, store the full list as JSON in blind_mapping
|
||||
if len(body.models) > 2:
|
||||
blind_mapping = json.dumps({"models": body.models})
|
||||
else:
|
||||
blind_mapping = None
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
comp = Comparison(
|
||||
id=comp_id,
|
||||
prompt=body.prompt[:500],
|
||||
model_a=model_a,
|
||||
model_b=model_b,
|
||||
endpoint_a="",
|
||||
endpoint_b="",
|
||||
winner=body.winner,
|
||||
is_blind=body.is_blind,
|
||||
blind_mapping=blind_mapping,
|
||||
voted_at=datetime.utcnow(),
|
||||
owner=user,
|
||||
)
|
||||
db.add(comp)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return {"status": "ok", "id": comp_id}
|
||||
|
||||
@router.get("/history")
|
||||
def list_comparisons(request: Request):
|
||||
"""List past comparisons."""
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(Comparison)
|
||||
if user:
|
||||
q = q.filter(Comparison.owner == user)
|
||||
comps = q.order_by(Comparison.created_at.desc()).limit(50).all()
|
||||
return [
|
||||
{
|
||||
"id": c.id,
|
||||
"prompt": c.prompt[:100],
|
||||
"model_a": c.model_a,
|
||||
"model_b": c.model_b,
|
||||
"winner": c.winner,
|
||||
"is_blind": c.is_blind,
|
||||
"voted_at": c.voted_at.isoformat() if c.voted_at else None,
|
||||
"created_at": c.created_at.isoformat() if c.created_at else None,
|
||||
}
|
||||
for c in comps
|
||||
]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/{comp_id}")
|
||||
def delete_comparison(request: Request, comp_id: str):
|
||||
"""Delete a comparison and its ephemeral sessions."""
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
comp = db.query(Comparison).filter(Comparison.id == comp_id).first()
|
||||
if not comp:
|
||||
raise HTTPException(404, "Comparison not found")
|
||||
# SECURITY: strict ownership — null-owner Comparisons were
|
||||
# accessible to every user.
|
||||
if user and comp.owner != user:
|
||||
raise HTTPException(404, "Comparison not found")
|
||||
db.delete(comp)
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
783
routes/contacts_routes.py
Normal file
783
routes/contacts_routes.py
Normal file
@@ -0,0 +1,783 @@
|
||||
"""
|
||||
contacts_routes.py
|
||||
|
||||
CardDAV contacts integration. Reads from local Radicale, supports
|
||||
search and adding new contacts.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
import uuid
|
||||
import json
|
||||
import csv
|
||||
import io
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Query, Depends, Response
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from src.auth_helpers import require_user
|
||||
from core.middleware import require_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
||||
SETTINGS_FILE = DATA_DIR / "settings.json"
|
||||
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
|
||||
|
||||
|
||||
def _load_settings():
|
||||
if SETTINGS_FILE.exists():
|
||||
return json.loads(SETTINGS_FILE.read_text())
|
||||
return {}
|
||||
|
||||
|
||||
def _save_settings(settings):
|
||||
from core.atomic_io import atomic_write_json
|
||||
atomic_write_json(str(SETTINGS_FILE), settings, indent=2)
|
||||
|
||||
|
||||
def _get_carddav_config():
|
||||
import os
|
||||
settings = _load_settings()
|
||||
return {
|
||||
"url": settings.get("carddav_url", os.environ.get("CARDDAV_URL", "")),
|
||||
"username": settings.get("carddav_username", os.environ.get("CARDDAV_USERNAME", "")),
|
||||
"password": settings.get("carddav_password", os.environ.get("CARDDAV_PASSWORD", "")),
|
||||
}
|
||||
|
||||
|
||||
def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
|
||||
cfg = cfg or _get_carddav_config()
|
||||
return bool((cfg.get("url") or "").strip())
|
||||
|
||||
|
||||
def _normalize_contact(contact: Dict) -> Dict:
|
||||
emails = []
|
||||
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
|
||||
e = str(e or "").strip()
|
||||
if e and e not in emails:
|
||||
emails.append(e)
|
||||
phones = []
|
||||
for p in contact.get("phones") or ([] if not contact.get("phone") else [contact.get("phone")]):
|
||||
p = str(p or "").strip()
|
||||
if p and p not in phones:
|
||||
phones.append(p)
|
||||
name = str(contact.get("name") or "").strip()
|
||||
if not name and emails:
|
||||
name = emails[0].split("@")[0]
|
||||
return {
|
||||
"uid": str(contact.get("uid") or uuid.uuid4()),
|
||||
"name": name,
|
||||
"emails": emails,
|
||||
"phones": phones,
|
||||
}
|
||||
|
||||
|
||||
def _load_local_contacts() -> List[Dict]:
|
||||
try:
|
||||
if not LOCAL_CONTACTS_FILE.exists():
|
||||
return []
|
||||
data = json.loads(LOCAL_CONTACTS_FILE.read_text())
|
||||
rows = data.get("contacts", data) if isinstance(data, dict) else data
|
||||
return [_normalize_contact(c) for c in (rows or []) if isinstance(c, dict)]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load local contacts: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _save_local_contacts(contacts: List[Dict]) -> None:
|
||||
from core.atomic_io import atomic_write_json
|
||||
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
||||
atomic_write_json(str(LOCAL_CONTACTS_FILE), {"contacts": [_normalize_contact(c) for c in contacts]}, indent=2)
|
||||
_contact_cache["contacts"] = [_normalize_contact(c) for c in contacts]
|
||||
_contact_cache["fetched_at"] = datetime.utcnow()
|
||||
|
||||
|
||||
# ── vCard parsing ──
|
||||
|
||||
def _vunesc(value: str) -> str:
|
||||
"""Reverse _vesc() — turn escaped vCard text back into the raw value.
|
||||
Order matters: handle \\n/\\, /\\; first, backslash-unescape last."""
|
||||
if not value:
|
||||
return value
|
||||
out = []
|
||||
i = 0
|
||||
while i < len(value):
|
||||
ch = value[i]
|
||||
if ch == "\\" and i + 1 < len(value):
|
||||
nxt = value[i + 1]
|
||||
if nxt in ("n", "N"):
|
||||
out.append("\n")
|
||||
elif nxt in (",", ";", "\\"):
|
||||
out.append(nxt)
|
||||
else:
|
||||
out.append(nxt)
|
||||
i += 2
|
||||
else:
|
||||
out.append(ch)
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _parse_vcards(text: str) -> List[Dict]:
|
||||
"""Parse a stream of vCards into dicts with name, email, phone."""
|
||||
contacts = []
|
||||
for block in re.split(r"BEGIN:VCARD", text):
|
||||
if not block.strip():
|
||||
continue
|
||||
contact = {"name": "", "emails": [], "phones": [], "uid": ""}
|
||||
for line in block.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("FN:") or line.startswith("FN;"):
|
||||
contact["name"] = _vunesc(line.split(":", 1)[1]) if ":" in line else ""
|
||||
elif line.startswith("EMAIL"):
|
||||
# Handle EMAIL:foo@bar OR EMAIL;TYPE=...:foo@bar OR EMAIL;PREF=1:foo@bar
|
||||
if ":" in line:
|
||||
email_addr = _vunesc(line.split(":", 1)[1])
|
||||
if email_addr and email_addr not in contact["emails"]:
|
||||
contact["emails"].append(email_addr)
|
||||
elif line.startswith("TEL"):
|
||||
if ":" in line:
|
||||
phone = _vunesc(line.split(":", 1)[1])
|
||||
if phone and phone not in contact["phones"]:
|
||||
contact["phones"].append(phone)
|
||||
elif line.startswith("UID:"):
|
||||
contact["uid"] = _vunesc(line[4:])
|
||||
if contact["name"] or contact["emails"]:
|
||||
contacts.append(contact)
|
||||
return contacts
|
||||
|
||||
|
||||
def _vesc(value: str) -> str:
|
||||
"""Escape a vCard property VALUE per RFC 6350 §3.4: backslash, comma,
|
||||
semicolon, and newlines. Without this, a name like 'Sekisui House,Ltd'
|
||||
or any value containing a newline produces a malformed vCard (broken
|
||||
N/FN fields) or could inject arbitrary properties."""
|
||||
return (
|
||||
(value or "")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("\n", "\\n")
|
||||
.replace("\r", "")
|
||||
.replace(",", "\\,")
|
||||
.replace(";", "\\;")
|
||||
)
|
||||
|
||||
|
||||
def _build_vcard(name: str, email: str, uid: Optional[str] = None,
|
||||
emails: Optional[List[str]] = None,
|
||||
phones: Optional[List[str]] = None) -> str:
|
||||
"""Build a vCard. Accepts either a single `email` (legacy callers) or
|
||||
full `emails`/`phones` lists (edit path). The first email is marked
|
||||
PREF=1. All values are RFC-6350-escaped."""
|
||||
if not uid:
|
||||
uid = str(uuid.uuid4())
|
||||
# Normalize email lists — `email` arg is a convenience for single-email
|
||||
# creation; `emails` (if given) is authoritative.
|
||||
email_list = [e.strip() for e in (emails if emails is not None else ([email] if email else [])) if e and e.strip()]
|
||||
phone_list = [p.strip() for p in (phones or []) if p and p.strip()]
|
||||
# Try to split name into first/last
|
||||
parts = name.strip().split()
|
||||
if len(parts) >= 2:
|
||||
first = parts[0]
|
||||
last = " ".join(parts[1:])
|
||||
else:
|
||||
first = name
|
||||
last = ""
|
||||
# N field is structured (5 components separated by ';') — escape each
|
||||
# component individually so a comma in the name doesn't split it.
|
||||
n_field = f"{_vesc(last)};{_vesc(first)};;;"
|
||||
lines = [
|
||||
"BEGIN:VCARD",
|
||||
"VERSION:4.0",
|
||||
f"UID:{_vesc(uid)}",
|
||||
f"FN:{_vesc(name)}",
|
||||
f"N:{n_field}",
|
||||
]
|
||||
for i, em in enumerate(email_list):
|
||||
# First email is the preferred one.
|
||||
lines.append(f"EMAIL;PREF=1:{_vesc(em)}" if i == 0 else f"EMAIL:{_vesc(em)}")
|
||||
for ph in phone_list:
|
||||
lines.append(f"TEL:{_vesc(ph)}")
|
||||
lines.append("END:VCARD")
|
||||
return "\r\n".join(lines) + "\r\n"
|
||||
|
||||
|
||||
# ── In-memory cache ──
|
||||
|
||||
_contact_cache = {"contacts": [], "fetched_at": None}
|
||||
|
||||
|
||||
def _abs_url(href: str) -> str:
|
||||
"""Combine a multistatus <href> (an absolute path like
|
||||
/user/contacts/x.vcf) with the configured CardDAV server origin so we
|
||||
get a fully-qualified URL to PUT/DELETE. If href is already absolute
|
||||
(http...), return it as-is."""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
if href.startswith("http://") or href.startswith("https://"):
|
||||
return href
|
||||
cfg = _get_carddav_config()
|
||||
p = urlparse(cfg["url"])
|
||||
return urlunparse((p.scheme, p.netloc, href, "", "", ""))
|
||||
|
||||
|
||||
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
|
||||
# alongside the resource href. Lets us map each contact's UID to the real
|
||||
# server resource path (which is NOT always <uid>.vcf for contacts created
|
||||
# by other clients).
|
||||
_ADDRESSBOOK_QUERY = (
|
||||
'<?xml version="1.0" encoding="utf-8"?>'
|
||||
'<C:addressbook-query xmlns:D="DAV:" xmlns:C="urn:ietf:params:xml:ns:carddav">'
|
||||
'<D:prop><D:getetag/><C:address-data/></D:prop>'
|
||||
'<C:filter/>'
|
||||
'</C:addressbook-query>'
|
||||
)
|
||||
|
||||
|
||||
def _fetch_via_report(cfg, auth):
|
||||
"""Try a CardDAV REPORT addressbook-query — returns contacts WITH an
|
||||
`href` field, or None if the server doesn't support it / errors."""
|
||||
from defusedxml import ElementTree as ET
|
||||
try:
|
||||
r = httpx.request(
|
||||
"REPORT", cfg["url"],
|
||||
content=_ADDRESSBOOK_QUERY.encode("utf-8"),
|
||||
headers={"Content-Type": "application/xml; charset=utf-8", "Depth": "1"},
|
||||
auth=auth, timeout=10,
|
||||
)
|
||||
if r.status_code not in (207, 200):
|
||||
return None
|
||||
root = ET.fromstring(r.text)
|
||||
ns = {"D": "DAV:", "C": "urn:ietf:params:xml:ns:carddav"}
|
||||
out = []
|
||||
for resp in root.findall("D:response", ns):
|
||||
href_el = resp.find("D:href", ns)
|
||||
data_el = resp.find(".//C:address-data", ns)
|
||||
if href_el is None or data_el is None or not (data_el.text or "").strip():
|
||||
continue
|
||||
parsed = _parse_vcards(data_el.text)
|
||||
if not parsed:
|
||||
continue
|
||||
c = parsed[0]
|
||||
c["href"] = href_el.text.strip()
|
||||
out.append(c)
|
||||
# If the REPORT parsed to ZERO contacts, don't trust it — some
|
||||
# CardDAV servers treat an empty <filter/> as "match nothing" and
|
||||
# return a valid-but-empty 207. Return None so the caller falls
|
||||
# back to the plain GET (which lists everything). A genuinely empty
|
||||
# address book just costs one extra GET that also returns nothing.
|
||||
if not out:
|
||||
return None
|
||||
return out
|
||||
except Exception as e:
|
||||
logger.warning(f"CardDAV REPORT failed, falling back to GET: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_contacts(force=False):
|
||||
"""Fetch all contacts. Uses CardDAV when configured, otherwise local JSON."""
|
||||
if not force and _contact_cache["fetched_at"]:
|
||||
age = (datetime.utcnow() - _contact_cache["fetched_at"]).total_seconds()
|
||||
if age < 60:
|
||||
return _contact_cache["contacts"]
|
||||
|
||||
cfg = _get_carddav_config()
|
||||
if not _carddav_configured(cfg):
|
||||
contacts = _load_local_contacts()
|
||||
_contact_cache["contacts"] = contacts
|
||||
_contact_cache["fetched_at"] = datetime.utcnow()
|
||||
return contacts
|
||||
|
||||
try:
|
||||
auth = None
|
||||
if cfg["username"]:
|
||||
auth = (cfg["username"], cfg["password"])
|
||||
# Preferred path: REPORT gives us hrefs for reliable edit/delete.
|
||||
contacts = _fetch_via_report(cfg, auth)
|
||||
if contacts is None:
|
||||
# Fallback: plain GET, concatenated vCards, no hrefs.
|
||||
r = httpx.get(cfg["url"], auth=auth, timeout=10)
|
||||
if r.status_code != 200:
|
||||
logger.warning(f"CardDAV returned {r.status_code}")
|
||||
return _contact_cache["contacts"]
|
||||
contacts = _parse_vcards(r.text)
|
||||
_contact_cache["contacts"] = contacts
|
||||
_contact_cache["fetched_at"] = datetime.utcnow()
|
||||
return contacts
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch contacts: {e}")
|
||||
return _contact_cache["contacts"]
|
||||
|
||||
|
||||
def _resolve_resource_url(uid: str) -> str:
|
||||
"""Map a contact UID to its real CardDAV resource URL. Uses the href
|
||||
captured during fetch when available (handles contacts whose filename
|
||||
!= UID); falls back to the <uid>.vcf guess for app-created contacts or
|
||||
when no href is known."""
|
||||
def _lookup():
|
||||
for c in _contact_cache.get("contacts", []):
|
||||
if c.get("uid") == uid and c.get("href"):
|
||||
return _abs_url(c["href"])
|
||||
return None
|
||||
found = _lookup()
|
||||
if found:
|
||||
return found
|
||||
# Not in cache (or no href) — refresh once and retry before guessing.
|
||||
try:
|
||||
_fetch_contacts(force=True)
|
||||
except Exception:
|
||||
pass
|
||||
return _lookup() or _vcard_url(uid)
|
||||
|
||||
|
||||
def _create_contact(name: str, email: str) -> bool:
|
||||
"""Add a new contact via CardDAV or local contacts."""
|
||||
cfg = _get_carddav_config()
|
||||
if not _carddav_configured(cfg):
|
||||
contacts = _load_local_contacts()
|
||||
email_l = (email or "").strip().lower()
|
||||
for c in contacts:
|
||||
if email_l and email_l in [e.lower() for e in c.get("emails", [])]:
|
||||
return True
|
||||
contacts.append(_normalize_contact({"name": name, "emails": [email]}))
|
||||
_save_local_contacts(contacts)
|
||||
return True
|
||||
|
||||
contact_uid = str(uuid.uuid4())
|
||||
vcard = _build_vcard(name, email, contact_uid)
|
||||
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
|
||||
try:
|
||||
auth = None
|
||||
if cfg["username"]:
|
||||
auth = (cfg["username"], cfg["password"])
|
||||
r = httpx.put(
|
||||
url,
|
||||
data=vcard.encode("utf-8"),
|
||||
headers={"Content-Type": "text/vcard; charset=utf-8"},
|
||||
auth=auth,
|
||||
timeout=10,
|
||||
)
|
||||
if r.status_code in (200, 201, 204):
|
||||
# Invalidate cache
|
||||
_contact_cache["fetched_at"] = None
|
||||
return True
|
||||
logger.warning(f"CardDAV PUT returned {r.status_code}: {r.text[:200]}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create contact: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _vcard_url(uid: str) -> str:
|
||||
"""The CardDAV resource URL for a given contact UID. The uid is URL-
|
||||
encoded so a value containing '/', '..' or other path chars can't
|
||||
escape the collection and target an arbitrary CardDAV resource."""
|
||||
from urllib.parse import quote
|
||||
cfg = _get_carddav_config()
|
||||
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
||||
|
||||
|
||||
def _import_vcards(text: str) -> Dict:
|
||||
"""Import a (possibly multi-card) .vcf blob. Each card is PUT to the
|
||||
CardDAV server PRESERVING its full original content (ADR/ORG/photo/
|
||||
etc.) — we don't rebuild it, just ensure it has VERSION + UID and
|
||||
normalize line endings. Returns {imported, failed, total}."""
|
||||
from urllib.parse import quote
|
||||
cfg = _get_carddav_config()
|
||||
if not cfg.get("url"):
|
||||
parsed = _parse_vcards(text)
|
||||
contacts = _load_local_contacts()
|
||||
existing = {
|
||||
e.lower()
|
||||
for c in contacts
|
||||
for e in (c.get("emails") or [])
|
||||
if e
|
||||
}
|
||||
imported = 0
|
||||
for c in parsed:
|
||||
emails = [e for e in (c.get("emails") or []) if e]
|
||||
if emails and any(e.lower() in existing for e in emails):
|
||||
continue
|
||||
contacts.append(_normalize_contact(c))
|
||||
for e in emails:
|
||||
existing.add(e.lower())
|
||||
imported += 1
|
||||
if imported:
|
||||
_save_local_contacts(contacts)
|
||||
return {"imported": imported, "failed": 0, "total": len(parsed)}
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
# Split into individual cards. re.split drops the BEGIN line, so we
|
||||
# re-add it. Normalize CRLF.
|
||||
raw = (text or "").replace("\r\n", "\n").replace("\r", "\n")
|
||||
blocks = []
|
||||
for chunk in raw.split("BEGIN:VCARD"):
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
# Trim anything after END:VCARD (defensive).
|
||||
end = chunk.upper().find("END:VCARD")
|
||||
body = chunk[: end + len("END:VCARD")] if end != -1 else chunk
|
||||
blocks.append("BEGIN:VCARD\n" + body)
|
||||
imported = 0
|
||||
failed = 0
|
||||
for block in blocks:
|
||||
# Extract or assign a UID.
|
||||
m = re.search(r"^UID:(.+)$", block, re.MULTILINE)
|
||||
uid = (m.group(1).strip() if m else "") or str(uuid.uuid4())
|
||||
if not m:
|
||||
# Inject a UID right after the VERSION line (or after BEGIN).
|
||||
if re.search(r"^VERSION:", block, re.MULTILINE):
|
||||
block = re.sub(r"(^VERSION:.*$)", r"\1\nUID:" + uid, block, count=1, flags=re.MULTILINE)
|
||||
else:
|
||||
block = block.replace("BEGIN:VCARD", f"BEGIN:VCARD\nVERSION:4.0\nUID:{uid}", 1)
|
||||
elif not re.search(r"^VERSION:", block, re.MULTILINE):
|
||||
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
|
||||
vcard = block.replace("\n", "\r\n") + "\r\n"
|
||||
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
|
||||
try:
|
||||
r = httpx.put(
|
||||
url, data=vcard.encode("utf-8"),
|
||||
headers={"Content-Type": "text/vcard; charset=utf-8"},
|
||||
auth=auth, timeout=15,
|
||||
)
|
||||
if r.status_code in (200, 201, 204):
|
||||
imported += 1
|
||||
else:
|
||||
failed += 1
|
||||
logger.warning(f"Import PUT {uid} returned {r.status_code}: {r.text[:120]}")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"Import PUT {uid} failed: {e}")
|
||||
if imported:
|
||||
_contact_cache["fetched_at"] = None
|
||||
return {"imported": imported, "failed": failed, "total": len(blocks)}
|
||||
|
||||
|
||||
def _import_csv_contacts(text: str) -> Dict:
|
||||
"""Import contacts from CSV. Supports common headers:
|
||||
name/full_name/display_name, email/email_address/e-mail, phone/tel.
|
||||
Falls back to first columns as name,email,phone when no headers exist."""
|
||||
raw = (text or "").strip()
|
||||
if not raw:
|
||||
return {"imported": 0, "failed": 0, "total": 0, "error": "No CSV data found"}
|
||||
|
||||
try:
|
||||
sample = raw[:2048]
|
||||
dialect = csv.Sniffer().sniff(sample)
|
||||
except Exception:
|
||||
dialect = csv.excel
|
||||
|
||||
stream = io.StringIO(raw)
|
||||
try:
|
||||
has_header = csv.Sniffer().has_header(raw[:2048])
|
||||
except Exception:
|
||||
has_header = True
|
||||
|
||||
rows = []
|
||||
if has_header:
|
||||
reader = csv.DictReader(stream, dialect=dialect)
|
||||
for row in reader:
|
||||
lowered = {str(k or "").strip().lower(): (v or "").strip() for k, v in row.items()}
|
||||
name = (
|
||||
lowered.get("name") or lowered.get("full name") or lowered.get("full_name")
|
||||
or lowered.get("display name") or lowered.get("display_name")
|
||||
or lowered.get("fn") or ""
|
||||
)
|
||||
email = (
|
||||
lowered.get("email") or lowered.get("email address")
|
||||
or lowered.get("email_address") or lowered.get("e-mail")
|
||||
or lowered.get("mail") or ""
|
||||
)
|
||||
phone = lowered.get("phone") or lowered.get("telephone") or lowered.get("tel") or ""
|
||||
rows.append((name, email, phone))
|
||||
else:
|
||||
stream.seek(0)
|
||||
reader = csv.reader(stream, dialect=dialect)
|
||||
for row in reader:
|
||||
cols = [(c or "").strip() for c in row]
|
||||
if not any(cols):
|
||||
continue
|
||||
rows.append((
|
||||
cols[0] if len(cols) > 0 else "",
|
||||
cols[1] if len(cols) > 1 else "",
|
||||
cols[2] if len(cols) > 2 else "",
|
||||
))
|
||||
|
||||
imported = 0
|
||||
failed = 0
|
||||
total = 0
|
||||
existing_emails = {
|
||||
e.lower()
|
||||
for c in _fetch_contacts()
|
||||
for e in (c.get("emails") or [])
|
||||
if e
|
||||
}
|
||||
for name, email, phone in rows:
|
||||
email = (email or "").strip()
|
||||
name = (name or "").strip() or (email.split("@")[0] if email else "")
|
||||
if not email:
|
||||
continue
|
||||
total += 1
|
||||
if email.lower() in existing_emails:
|
||||
continue
|
||||
ok = _create_contact(name, email)
|
||||
if ok:
|
||||
imported += 1
|
||||
existing_emails.add(email.lower())
|
||||
# If the CSV had a phone number, rewrite the just-created row
|
||||
# through the richer update path so phone lands in CardDAV too.
|
||||
if phone:
|
||||
try:
|
||||
contacts = _fetch_contacts(force=True)
|
||||
created = next((c for c in contacts if email.lower() in [e.lower() for e in c.get("emails", [])]), None)
|
||||
if created and created.get("uid"):
|
||||
_update_contact(created["uid"], name, [email], [phone])
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
if imported:
|
||||
_contact_cache["fetched_at"] = None
|
||||
return {"imported": imported, "failed": failed, "total": total}
|
||||
|
||||
|
||||
def _contacts_to_vcf(contacts: List[Dict]) -> str:
|
||||
return "".join(
|
||||
_build_vcard(
|
||||
c.get("name") or ((c.get("emails") or [""])[0].split("@")[0] if c.get("emails") else "Contact"),
|
||||
"",
|
||||
uid=c.get("uid") or str(uuid.uuid4()),
|
||||
emails=c.get("emails") or [],
|
||||
phones=c.get("phones") or [],
|
||||
)
|
||||
for c in contacts
|
||||
)
|
||||
|
||||
|
||||
def _contacts_to_csv(contacts: List[Dict]) -> str:
|
||||
out = io.StringIO()
|
||||
writer = csv.writer(out)
|
||||
writer.writerow(["name", "email", "phone"])
|
||||
for c in contacts:
|
||||
emails = c.get("emails") or [""]
|
||||
phones = c.get("phones") or [""]
|
||||
max_len = max(len(emails), len(phones), 1)
|
||||
for i in range(max_len):
|
||||
writer.writerow([
|
||||
c.get("name") or "",
|
||||
emails[i] if i < len(emails) else "",
|
||||
phones[i] if i < len(phones) else "",
|
||||
])
|
||||
return out.getvalue()
|
||||
|
||||
|
||||
def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -> bool:
|
||||
"""Rewrite an existing contact via CardDAV or local contacts."""
|
||||
cfg = _get_carddav_config()
|
||||
if not _carddav_configured(cfg):
|
||||
contacts = _load_local_contacts()
|
||||
found = False
|
||||
out = []
|
||||
for c in contacts:
|
||||
if c.get("uid") == uid:
|
||||
out.append(_normalize_contact({"uid": uid, "name": name, "emails": emails, "phones": phones}))
|
||||
found = True
|
||||
else:
|
||||
out.append(c)
|
||||
if not found:
|
||||
out.append(_normalize_contact({"uid": uid, "name": name, "emails": emails, "phones": phones}))
|
||||
_save_local_contacts(out)
|
||||
return True
|
||||
|
||||
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
|
||||
# Use the real resource href (handles externally-created contacts whose
|
||||
# filename != UID); falls back to the <uid>.vcf guess.
|
||||
url = _resolve_resource_url(uid)
|
||||
try:
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
r = httpx.put(
|
||||
url,
|
||||
data=vcard.encode("utf-8"),
|
||||
headers={"Content-Type": "text/vcard; charset=utf-8"},
|
||||
auth=auth,
|
||||
timeout=10,
|
||||
)
|
||||
if r.status_code in (200, 201, 204):
|
||||
_contact_cache["fetched_at"] = None
|
||||
return True
|
||||
logger.warning(f"CardDAV update PUT returned {r.status_code}: {r.text[:200]}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update contact: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _delete_contact(uid: str) -> bool:
|
||||
"""Delete a contact via CardDAV or local contacts."""
|
||||
cfg = _get_carddav_config()
|
||||
if not _carddav_configured(cfg):
|
||||
contacts = _load_local_contacts()
|
||||
remaining = [c for c in contacts if c.get("uid") != uid]
|
||||
_save_local_contacts(remaining)
|
||||
return True
|
||||
|
||||
url = _resolve_resource_url(uid)
|
||||
try:
|
||||
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
|
||||
r = httpx.delete(url, auth=auth, timeout=10)
|
||||
if r.status_code in (200, 204):
|
||||
_contact_cache["fetched_at"] = None
|
||||
return True
|
||||
if r.status_code == 404:
|
||||
# Resource not found at the resolved URL. With href resolution
|
||||
# this should be rare (genuinely already deleted). Invalidate
|
||||
# the cache and report success so the UI doesn't keep a ghost.
|
||||
logger.info(f"CardDAV DELETE 404 for {uid} — treating as already gone")
|
||||
_contact_cache["fetched_at"] = None
|
||||
return True
|
||||
logger.warning(f"CardDAV DELETE returned {r.status_code}: {r.text[:200]}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete contact: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ── Routes ──
|
||||
|
||||
def setup_contacts_routes():
|
||||
router = APIRouter(prefix="/api/contacts", tags=["contacts"])
|
||||
|
||||
@router.get("/list")
|
||||
async def list_contacts(_admin: str = Depends(require_admin)):
|
||||
"""List all contacts."""
|
||||
contacts = _fetch_contacts()
|
||||
return {"contacts": contacts, "count": len(contacts)}
|
||||
|
||||
@router.get("/search")
|
||||
async def search_contacts(q: str = Query(""), _admin: str = Depends(require_admin)):
|
||||
"""Search contacts by name or email. Returns up to 10 matches."""
|
||||
contacts = _fetch_contacts()
|
||||
if not q:
|
||||
return {"results": []}
|
||||
q_lower = q.lower()
|
||||
results = []
|
||||
for c in contacts:
|
||||
if q_lower in c["name"].lower():
|
||||
results.append(c)
|
||||
continue
|
||||
for em in c["emails"]:
|
||||
if q_lower in em.lower():
|
||||
results.append(c)
|
||||
break
|
||||
return {"results": results[:10]}
|
||||
|
||||
@router.post("/add")
|
||||
async def add_contact(data: dict, _admin: str = Depends(require_admin)):
|
||||
"""Add a new contact."""
|
||||
name = data.get("name", "").strip()
|
||||
email = data.get("email", "").strip()
|
||||
if not email:
|
||||
return {"success": False, "error": "Email required"}
|
||||
# Check if already exists
|
||||
contacts = _fetch_contacts()
|
||||
for c in contacts:
|
||||
if email.lower() in [e.lower() for e in c["emails"]]:
|
||||
return {"success": True, "message": "Already exists", "contact": c}
|
||||
if not name:
|
||||
name = email.split("@")[0]
|
||||
ok = _create_contact(name, email)
|
||||
return {"success": ok}
|
||||
|
||||
@router.post("/import")
|
||||
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
|
||||
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
|
||||
text = data.get("vcf") or data.get("text") or ""
|
||||
csv_text = data.get("csv") or ""
|
||||
if text.strip():
|
||||
if "BEGIN:VCARD" not in text.upper():
|
||||
return {"success": False, "error": "No vCard data found"}
|
||||
result = _import_vcards(text)
|
||||
elif csv_text.strip():
|
||||
result = _import_csv_contacts(csv_text)
|
||||
else:
|
||||
return {"success": False, "error": "No contact data found"}
|
||||
result["success"] = result.get("imported", 0) > 0
|
||||
return result
|
||||
|
||||
@router.get("/export")
|
||||
async def export_contacts(
|
||||
format: str = Query("vcf", pattern="^(vcf|csv)$"),
|
||||
_admin: str = Depends(require_admin),
|
||||
):
|
||||
"""Export all contacts as vCard or CSV."""
|
||||
contacts = _fetch_contacts(force=True)
|
||||
if format == "csv":
|
||||
content = _contacts_to_csv(contacts)
|
||||
media_type = "text/csv; charset=utf-8"
|
||||
filename = "odysseus-contacts.csv"
|
||||
else:
|
||||
content = _contacts_to_vcf(contacts)
|
||||
media_type = "text/vcard; charset=utf-8"
|
||||
filename = "odysseus-contacts.vcf"
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=media_type,
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config(_admin: str = Depends(require_admin)):
|
||||
cfg = _get_carddav_config()
|
||||
# Mask password
|
||||
if cfg["password"]:
|
||||
cfg["password"] = "***"
|
||||
return cfg
|
||||
|
||||
@router.put("/config")
|
||||
async def update_config(data: dict, _admin: str = Depends(require_admin)):
|
||||
settings = _load_settings()
|
||||
for key in ("carddav_url", "carddav_username", "carddav_password"):
|
||||
if key in data:
|
||||
settings[key] = data[key]
|
||||
_save_settings(settings)
|
||||
# Force re-fetch
|
||||
_contact_cache["fetched_at"] = None
|
||||
return {"success": True}
|
||||
|
||||
@router.delete("/clear")
|
||||
async def clear_contacts(_admin: str = Depends(require_admin)):
|
||||
"""Clear all local contacts. If CardDAV is configured, only clears the local fallback cache."""
|
||||
_save_local_contacts([])
|
||||
return {"success": True}
|
||||
|
||||
# NOTE: the /{uid} routes are declared LAST so the literal paths above
|
||||
# (/list, /search, /add, /config) win — otherwise PUT /config would
|
||||
# match PUT /{uid} with uid="config".
|
||||
@router.put("/{uid}")
|
||||
async def edit_contact(uid: str, data: dict, _admin: str = Depends(require_admin)):
|
||||
"""Edit an existing contact — name / emails / phones."""
|
||||
name = (data.get("name") or "").strip()
|
||||
emails = data.get("emails")
|
||||
phones = data.get("phones")
|
||||
if emails is None and data.get("email"):
|
||||
emails = [data["email"]]
|
||||
emails = [e.strip() for e in (emails or []) if e and e.strip()]
|
||||
phones = [p.strip() for p in (phones or []) if p and p.strip()]
|
||||
if not name and not emails:
|
||||
return {"success": False, "error": "Name or email required"}
|
||||
if not name and emails:
|
||||
name = emails[0].split("@")[0]
|
||||
ok = _update_contact(uid, name, emails, phones)
|
||||
return {"success": ok}
|
||||
|
||||
@router.delete("/{uid}")
|
||||
async def delete_contact(uid: str, _admin: str = Depends(require_admin)):
|
||||
"""Delete a contact by UID."""
|
||||
if not uid:
|
||||
return {"success": False, "error": "UID required"}
|
||||
ok = _delete_contact(uid)
|
||||
return {"success": ok}
|
||||
|
||||
return router
|
||||
340
routes/cookbook_helpers.py
Normal file
340
routes/cookbook_helpers.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""cookbook_helpers.py — validators + small helpers shared by the cookbook routes.
|
||||
Extracted from cookbook_routes.py; the routes module imports the symbols it needs."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# HuggingFace repo IDs are <org>/<name>, both alphanumerics plus ._-
|
||||
# Rejecting anything else up front closes off shell-interpolation vectors.
|
||||
_REPO_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*$")
|
||||
# Include pattern is a glob: allow typical safe glyphs only.
|
||||
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
|
||||
# Remote host: user@host (optionally with :port-free hostname parts).
|
||||
_REMOTE_HOST_RE = re.compile(r"^[A-Za-z0-9._-]+@[A-Za-z0-9._-]+$")
|
||||
# HF tokens and API tokens are url-safe base64-like.
|
||||
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
|
||||
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
|
||||
# Anything beyond plain alphanumerics + dash + underscore could break out
|
||||
# of the shell/PowerShell contexts the value lands in.
|
||||
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
|
||||
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
|
||||
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
|
||||
# A download target directory. Absolute or ~-relative path; safe path glyphs
|
||||
# only (no quotes, shell metacharacters, or spaces) since it lands in a shell
|
||||
# command. A leading ~ is expanded to $HOME at command-build time.
|
||||
_LOCAL_DIR_RE = re.compile(r"^~?/[A-Za-z0-9._/-]*$|^~$")
|
||||
|
||||
|
||||
def _validate_repo_id(v: str | None) -> str:
|
||||
if not v or not _REPO_ID_RE.match(v):
|
||||
raise HTTPException(400, "Invalid repo_id — must be <org>/<name> using [A-Za-z0-9._-]")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_include(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _INCLUDE_RE.match(v):
|
||||
raise HTTPException(400, "Invalid include pattern")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_remote_host(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _REMOTE_HOST_RE.match(v):
|
||||
raise HTTPException(400, "Invalid remote_host — must be user@host, no SSH option syntax")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_token(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _TOKEN_RE.match(v):
|
||||
raise HTTPException(400, "Invalid token characters")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_local_dir(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
v = v.rstrip("/") or "/"
|
||||
if not _LOCAL_DIR_RE.match(v):
|
||||
raise HTTPException(400, "Invalid local_dir — must be an absolute or ~ path with no spaces or shell metacharacters")
|
||||
return v
|
||||
|
||||
|
||||
def _validate_ssh_port(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _SSH_PORT_RE.fullmatch(str(v)):
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
port = int(v)
|
||||
if port < 1 or port > 65535:
|
||||
raise HTTPException(400, "Invalid ssh_port")
|
||||
return str(port)
|
||||
|
||||
|
||||
def _validate_gpus(v: str | None) -> str | None:
|
||||
if v is None or v == "":
|
||||
return None
|
||||
if not _GPU_LIST_RE.fullmatch(str(v)):
|
||||
raise HTTPException(400, "Invalid gpus — expected comma-separated GPU indexes")
|
||||
return str(v)
|
||||
|
||||
|
||||
def _shell_path(p: str) -> str:
|
||||
"""Render a validated path for a double-quoted shell context, expanding a
|
||||
leading ~ to $HOME (single quotes wouldn't expand it). Safe because
|
||||
_validate_local_dir already restricts the charset."""
|
||||
if p == "~":
|
||||
return '"$HOME"'
|
||||
if p.startswith("~/"):
|
||||
return '"$HOME/' + p[2:] + '"'
|
||||
return '"' + p + '"'
|
||||
|
||||
|
||||
def _ps_squote(v: str) -> str:
|
||||
"""Escape a value for PowerShell single-quoted string interpolation.
|
||||
Belt-and-suspenders on top of _validate_token's regex — if the regex
|
||||
is ever loosened, this still keeps the heredoc shell-safe."""
|
||||
return v.replace("'", "''")
|
||||
|
||||
|
||||
def _bash_squote(v: str) -> str:
|
||||
"""Escape a value for bash/sh single-quoted string interpolation."""
|
||||
return v.replace("'", "'\\''")
|
||||
|
||||
|
||||
# Allow-list of binaries permitted as the leading token of `req.cmd` for /api/model/serve.
|
||||
# Anything else is rejected before the cmd is interpolated into a tmux/PowerShell wrapper.
|
||||
_SERVE_CMD_ALLOWLIST = {
|
||||
"vllm", "llama-server", "llama_server", "llama.cpp", "ollama",
|
||||
"python", "python3",
|
||||
"sglang", "lmdeploy",
|
||||
"node", "npx",
|
||||
}
|
||||
|
||||
|
||||
# The llama.cpp GGUF launcher (static/js/cookbook.js) emits a fixed-shape
|
||||
# prelude that resolves the cached .gguf on the target host before serving:
|
||||
# MODEL_FILE=$( { find …; find …; } | head -1 ) && { [ -n "$MODEL_FILE" ] && \
|
||||
# [ -f "$MODEL_FILE" ]; } || { echo "ERROR…"; exit 1; } && <serve> || <serve>
|
||||
# That legitimately needs $(...)/&&/||, so we recognise this exact shape and
|
||||
# validate the serve binaries it guards rather than rejecting it wholesale.
|
||||
_GGUF_PRELUDE_RE = re.compile(
|
||||
r'^MODEL_FILE=\$\([^\n]*?\)\s*&&\s*\{[^{}]*\}\s*\|\|\s*\{[^{}]*\}\s*&&\s*'
|
||||
)
|
||||
|
||||
|
||||
def _check_serve_binary(seg: str) -> None:
|
||||
"""Validate that a single command segment starts with an allowlisted binary
|
||||
(after skipping leading env-var assignments like `CUDA_VISIBLE_DEVICES=0`)."""
|
||||
try:
|
||||
tokens = shlex.split(seg) if seg.strip() else []
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid cmd — could not parse")
|
||||
if not tokens:
|
||||
return
|
||||
env_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*=")
|
||||
first = next((t for t in tokens if not env_re.match(t)), "")
|
||||
base = os.path.basename(first)
|
||||
if base not in _SERVE_CMD_ALLOWLIST:
|
||||
raise HTTPException(
|
||||
400,
|
||||
f"cmd binary '{base or '(empty)'}' is not allowed. Must start with one of: "
|
||||
f"{', '.join(sorted(_SERVE_CMD_ALLOWLIST))}",
|
||||
)
|
||||
|
||||
|
||||
def _validate_serve_cmd(v: str | None) -> str | None:
|
||||
"""Reject serve commands that aren't in the allowlist or contain shell metachars.
|
||||
|
||||
`req.cmd` is dropped verbatim into a bash/PowerShell wrapper script and
|
||||
executed in a tmux session. Without this gate, an admin (or anyone in the
|
||||
pre-fix world) could pass arbitrary shell payloads.
|
||||
|
||||
Leading env-var assignments (e.g. `CUDA_VISIBLE_DEVICES=0 python3 ...`)
|
||||
are stripped before checking the binary — several of our cmd builders
|
||||
prepend them, and they shouldn't trip the allowlist.
|
||||
"""
|
||||
if v is None or v == "":
|
||||
return None
|
||||
# Collapse backslash-newline line continuations into single spaces. Serve
|
||||
# commands (vLLM especially) are routinely pasted multi-line with trailing
|
||||
# `\` — that's a safe shell/shlex continuation, so the command stays ONE
|
||||
# logical invocation and the leading-token allowlist below still governs.
|
||||
v = re.sub(r"\\[ \t]*\r?\n[ \t]*", " ", v).strip()
|
||||
# Backticks and raw newlines are never legitimate here.
|
||||
if any(c in v for c in ("`", "\n", "\r")):
|
||||
raise HTTPException(400, "Invalid characters in cmd")
|
||||
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
|
||||
m = _GGUF_PRELUDE_RE.match(v)
|
||||
if m:
|
||||
rest = v[m.end():]
|
||||
# rest is `[ENV=…] python3 -m llama_cpp.server … || [ENV=…] llama-server …`
|
||||
for part in rest.split("||"):
|
||||
_check_serve_binary(part.strip())
|
||||
return v
|
||||
# Otherwise: a single invocation — no shell metacharacters allowed.
|
||||
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
|
||||
if any(c in v for c in (";", "&&", "||", "$(")):
|
||||
raise HTTPException(400, "Invalid characters in cmd")
|
||||
_check_serve_binary(v)
|
||||
return v
|
||||
|
||||
|
||||
class ModelDownloadRequest(BaseModel):
|
||||
repo_id: str
|
||||
include: str | None = None # glob pattern e.g. "*Q4_K_M*"
|
||||
hf_token: str | None = None
|
||||
env_prefix: str | None = None # e.g. "source ~/venv/bin/activate"
|
||||
remote_host: str | None = None # e.g. "gpu-box" — run download on this host via SSH
|
||||
ssh_port: str | None = None # e.g. "8022" for Termux
|
||||
platform: str | None = None # "linux", "termux", or "windows"
|
||||
local_dir: str | None = None # base dir to download into (a per-model subfolder is created under it); None = default HF cache
|
||||
disable_hf_transfer: bool = False # skip the Rust hf_transfer downloader — slower but far more reliable on large files (used by retries)
|
||||
|
||||
|
||||
class ServeRequest(BaseModel):
|
||||
repo_id: str
|
||||
cmd: str
|
||||
remote_host: str | None = None
|
||||
ssh_port: str | None = None
|
||||
env_prefix: str | None = None
|
||||
hf_token: str | None = None
|
||||
gpus: str | None = None
|
||||
platform: str | None = None # "linux", "termux", or "windows"
|
||||
|
||||
|
||||
def _parse_serve_phase(snapshot: str, task_type: str = "serve") -> dict:
|
||||
"""Parse a tmux snapshot of a serve task into structured phase info.
|
||||
|
||||
Single source of truth for serve task status detection. Returns:
|
||||
{ "phase": str, "status": "ready"|"running"|"", "tps": float|None,
|
||||
"reqs": int|None, "pct": int|None }
|
||||
"""
|
||||
import re
|
||||
if task_type != "serve" or not snapshot:
|
||||
return {}
|
||||
# Strip newlines so tmux line-wrapping doesn't break regex matching
|
||||
flat = re.sub(r'\s+', ' ', snapshot)
|
||||
|
||||
load_matches = re.findall(r'Loading safetensors.*?(\d+)%', flat)
|
||||
# Prefer "Downloading (incomplete total...)" (real aggregate bytes) over
|
||||
# "Fetching N files" (whole-file count, lags with hf_transfer's chunked pulls).
|
||||
downloading_matches = re.findall(r'Downloading.*?(\d+)%', flat)
|
||||
fetching_matches = re.findall(r'Fetching.*?(\d+)%', flat)
|
||||
dl_matches = downloading_matches if downloading_matches else fetching_matches
|
||||
# Match "Avg generation throughput: X tokens/s, Running: N reqs" (with line-wrap tolerance)
|
||||
tps_matches = re.findall(
|
||||
r'(?:Avg )?generation throughput:\s*([\d.]+)\s*tokens/s.*?Running:\s*(\d+)\s*reqs',
|
||||
flat,
|
||||
)
|
||||
|
||||
# Check throughput FIRST — the throughput log line contains "GPU KV cache usage"
|
||||
# which would otherwise false-match the warmup check
|
||||
if tps_matches:
|
||||
tps_str, reqs_str = tps_matches[-1]
|
||||
tps = float(tps_str)
|
||||
reqs = int(reqs_str)
|
||||
return {
|
||||
"phase": f"{tps_str} tok/s" if reqs > 0 else "idle",
|
||||
"status": "ready",
|
||||
"tps": tps,
|
||||
"reqs": reqs,
|
||||
}
|
||||
if "Application startup complete" in flat:
|
||||
return {"phase": "ready", "status": "ready"}
|
||||
# HTTP access logs (e.g. GET /v1/models 200 OK) mean the server is up and serving
|
||||
if re.search(r'(?:GET|POST)\s+/[^\s]*\s+HTTP/[\d.]+"\s*\d{3}', flat):
|
||||
return {"phase": "idle", "status": "ready"}
|
||||
if "Loading weights took" in flat:
|
||||
return {"phase": "initializing", "status": "running"}
|
||||
# "GPU KV cache" alone (during allocation) — not "GPU KV cache usage" (runtime log)
|
||||
if "GPU KV cache" in flat and "GPU KV cache usage" not in flat:
|
||||
return {"phase": "warming up", "status": "running"}
|
||||
if load_matches:
|
||||
pct = int(load_matches[-1])
|
||||
return {"phase": f"loading {pct}%", "status": "running", "pct": pct}
|
||||
if dl_matches:
|
||||
pct = int(dl_matches[-1])
|
||||
return {"phase": f"downloading {pct}%", "status": "running", "pct": pct}
|
||||
return {}
|
||||
|
||||
|
||||
def _ssh(host, cmd, port=None):
|
||||
"""Build SSH command string with optional port."""
|
||||
pf = f"-p {port} " if port and port != "22" else ""
|
||||
return f"ssh {pf}{host} '{cmd}'"
|
||||
|
||||
|
||||
def _safe_env_prefix(ep: str | None) -> str | None:
|
||||
"""Rewrite a `source <path>` env_prefix so it no-ops if the path is missing.
|
||||
Prevents `line N: <path>: No such file or directory` errors when a serve
|
||||
task is launched against a host that doesn't have the expected venv.
|
||||
|
||||
Also rewrites leading `~/` → `$HOME/` so the path expands inside double
|
||||
quotes (bash only tilde-expands unquoted tokens at word start)."""
|
||||
if not ep:
|
||||
return ep
|
||||
import shlex
|
||||
try:
|
||||
parts = shlex.split(ep, posix=True)
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid env_prefix")
|
||||
if len(parts) != 2 or parts[0] not in {"source", "."}:
|
||||
# Bash conda activation emitted by the frontend:
|
||||
# eval "$(conda shell.bash hook)" && conda activate ENV
|
||||
m = re.fullmatch(r'eval "\$\(conda shell\.bash hook\)" && conda activate (.+)', ep)
|
||||
if m:
|
||||
env = m.group(1).strip()
|
||||
try:
|
||||
env_parts = shlex.split(env, posix=True)
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid env_prefix")
|
||||
if len(env_parts) != 1:
|
||||
raise HTTPException(400, "Invalid env_prefix")
|
||||
return 'eval "$(conda shell.bash hook)" && conda activate ' + shlex.quote(env_parts[0])
|
||||
|
||||
# Plain conda activation, used by Windows/PowerShell and some manual callers.
|
||||
if len(parts) == 3 and parts[0] == "conda" and parts[1] == "activate":
|
||||
return "conda activate " + shlex.quote(parts[2])
|
||||
|
||||
# PowerShell venv activation emitted by the frontend:
|
||||
# & 'C:\path\Scripts\Activate.ps1'
|
||||
if len(parts) == 2 and parts[0] == "&":
|
||||
path = parts[1]
|
||||
if any(c in path for c in "\r\n;&|`$<>"):
|
||||
raise HTTPException(400, "Invalid env_prefix")
|
||||
return "& '" + path.replace("'", "''") + "'"
|
||||
|
||||
raise HTTPException(400, "Invalid env_prefix")
|
||||
path = parts[1]
|
||||
if any(c in path for c in "\r\n;&|`$<>"):
|
||||
raise HTTPException(400, "Invalid env_prefix")
|
||||
# Replace a leading "~/" with "$HOME/" so it survives quoting
|
||||
if path.startswith("~/"):
|
||||
path = "$HOME/" + path[2:]
|
||||
elif path == "~":
|
||||
path = "$HOME"
|
||||
path = path.replace('"', '\\"')
|
||||
return f'[ -f "{path}" ] && source "{path}" || true'
|
||||
|
||||
|
||||
def _ssh_ps(host, script_path, port=None):
|
||||
"""Build SSH command to run a PowerShell script on a Windows remote."""
|
||||
pf = f"-p {port} " if port and port != "22" else ""
|
||||
return f'ssh {pf}{host} "powershell -ExecutionPolicy Bypass -File {script_path}"'
|
||||
|
||||
|
||||
# Windows session dir — stored in user's temp on the remote
|
||||
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"
|
||||
1728
routes/cookbook_routes.py
Normal file
1728
routes/cookbook_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
71
routes/diagnostics_routes.py
Normal file
71
routes/diagnostics_routes.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Diagnostics routes — /api/db/stats, /api/rag/stats, /api/test/youtube, /api/test-research."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Form
|
||||
|
||||
from services.youtube.youtube_handler import extract_youtube_id, extract_transcript_async
|
||||
from core.constants import DEFAULT_HOST
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_diagnostics_routes(
|
||||
rag_manager,
|
||||
rag_available: bool,
|
||||
research_handler,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(tags=["diagnostics"])
|
||||
|
||||
@router.get("/api/db/stats")
|
||||
async def get_database_stats() -> Dict[str, Any]:
|
||||
try:
|
||||
from core.database import get_detailed_stats
|
||||
return get_detailed_stats()
|
||||
except Exception as e:
|
||||
logger.error(f"DB stats error: {e}")
|
||||
raise HTTPException(500, "Failed to retrieve database statistics")
|
||||
|
||||
@router.get("/api/rag/stats")
|
||||
async def get_rag_stats() -> Dict[str, Any]:
|
||||
if rag_available and rag_manager:
|
||||
return rag_manager.get_stats()
|
||||
return {"error": "RAG system not available"}
|
||||
|
||||
@router.get("/api/test/youtube")
|
||||
async def test_youtube(url: str) -> Dict[str, Any]:
|
||||
try:
|
||||
video_id = extract_youtube_id(url)
|
||||
if not video_id:
|
||||
return {"error": "Invalid YouTube URL"}
|
||||
|
||||
data = await extract_transcript_async(url, video_id)
|
||||
return {
|
||||
"video_id": video_id,
|
||||
"transcript_success": data.get("success", False),
|
||||
"transcript_length": len(data.get("transcript", "")) if data.get("success") else 0,
|
||||
"transcript_preview": (data.get("transcript", "")[:500] + "...")
|
||||
if data.get("success") and len(data.get("transcript", "")) > 500
|
||||
else data.get("transcript", ""),
|
||||
"error": data.get("error") if not data.get("success") else None,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
@router.post("/api/test-research")
|
||||
async def test_research(query: str = Form("What is machine learning?")) -> Dict[str, Any]:
|
||||
try:
|
||||
endpoint = f"http://{DEFAULT_HOST}:8000/v1/chat/completions"
|
||||
model = "gpt-oss-120b"
|
||||
result = await research_handler.call_research_service(query, endpoint, model)
|
||||
return {
|
||||
"status": "success",
|
||||
"query": query,
|
||||
"result_preview": result[:200] + "..." if len(result) > 200 else result,
|
||||
"result_length": len(result),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e), "query": query}
|
||||
|
||||
return router
|
||||
198
routes/document_helpers.py
Normal file
198
routes/document_helpers.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""document_helpers.py — Pydantic models, doc serializers, owner gating, file-locator helpers shared with document_routes.py."""
|
||||
|
||||
"""Document routes — CRUD for living documents with version history."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import Document, DocumentVersion
|
||||
from core.database import Session as DbSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---- Request schemas ----
|
||||
|
||||
class DocumentCreate(BaseModel):
|
||||
session_id: Optional[str] = None
|
||||
title: str = "Untitled"
|
||||
language: Optional[str] = None
|
||||
content: str = ""
|
||||
|
||||
class DocumentUpdate(BaseModel):
|
||||
content: str
|
||||
summary: Optional[str] = None
|
||||
|
||||
class DocumentPatch(BaseModel):
|
||||
title: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
session_id: Optional[str] = None # link/unlink document to a session
|
||||
|
||||
|
||||
# ---- Helpers ----
|
||||
|
||||
def _doc_to_dict(doc: Document) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": doc.id,
|
||||
"session_id": doc.session_id,
|
||||
"title": doc.title,
|
||||
"language": doc.language,
|
||||
"current_content": doc.current_content,
|
||||
"version_count": doc.version_count,
|
||||
"is_active": doc.is_active,
|
||||
"archived": bool(getattr(doc, "archived", False)),
|
||||
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
|
||||
"updated_at": (doc.updated_at.isoformat() + "Z") if doc.updated_at else None,
|
||||
# Source-email provenance (set when doc was created from an email
|
||||
# attachment) — drives the "Send signed reply" menu item.
|
||||
"source_email_uid": getattr(doc, "source_email_uid", None),
|
||||
"source_email_folder": getattr(doc, "source_email_folder", None),
|
||||
"source_email_account_id": getattr(doc, "source_email_account_id", None),
|
||||
"source_email_message_id": getattr(doc, "source_email_message_id", None),
|
||||
}
|
||||
|
||||
def _version_to_dict(v: DocumentVersion) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": v.id,
|
||||
"document_id": v.document_id,
|
||||
"version_number": v.version_number,
|
||||
"content": v.content,
|
||||
"summary": v.summary,
|
||||
"source": v.source,
|
||||
"created_at": v.created_at.isoformat() if v.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _verify_doc_owner(db, doc: Document, user: str):
|
||||
"""Verify `user` owns this document. Raise 404 if not.
|
||||
|
||||
Documents now carry their own `owner` column, so a doc whose session
|
||||
was deleted (session_id → NULL) can still prove ownership and stay
|
||||
openable / cloneable. We trust that column first and only fall back to
|
||||
the session join for any not-yet-backfilled legacy row.
|
||||
"""
|
||||
if user is None:
|
||||
raise HTTPException(403, "Authentication required")
|
||||
if doc.owner is not None:
|
||||
if doc.owner != user:
|
||||
raise HTTPException(404, "Document not found")
|
||||
return
|
||||
# Legacy fallback: derive ownership from the linked session.
|
||||
if not doc.session_id:
|
||||
raise HTTPException(404, "Document not found")
|
||||
session = db.query(DbSession).filter(DbSession.id == doc.session_id).first()
|
||||
if not session or session.owner != user:
|
||||
raise HTTPException(404, "Document not found")
|
||||
|
||||
|
||||
def _owner_session_filter(q, user):
|
||||
"""Restrict a documents query to those owned by `user`.
|
||||
|
||||
Documents now carry their own `owner` column (backfilled at boot from
|
||||
the linked session, or assigned to the admin user for legacy/orphaned
|
||||
docs). We filter on that directly rather than on a session join, so a
|
||||
document whose session was deleted (session_id → NULL) still shows up
|
||||
for its owner instead of silently vanishing from the Library + search.
|
||||
|
||||
The owner backfill runs in init_db before the app serves requests, so
|
||||
by the time this filter is live there are no NULL-owner rows to leak;
|
||||
we therefore match the owner strictly."""
|
||||
if user is None:
|
||||
return q.filter(False)
|
||||
return q.filter(Document.owner == user)
|
||||
|
||||
|
||||
|
||||
def _slug(name: str) -> str:
|
||||
"""Filesystem-friendly version of a document title.
|
||||
|
||||
Whitespace becomes underscores; other unsafe punctuation is dropped.
|
||||
Preserves letters, digits, dot, hyphen, underscore. Idempotent.
|
||||
"""
|
||||
import re as _re
|
||||
s = (name or "").strip()
|
||||
# Drop the trailing extension if the title happens to include one
|
||||
s = _re.sub(r'\.pdf$', '', s, flags=_re.IGNORECASE)
|
||||
s = _re.sub(r'\s+', '_', s)
|
||||
s = _re.sub(r'[^A-Za-z0-9._-]', '', s)
|
||||
s = _re.sub(r'_+', '_', s).strip('_')
|
||||
return s or "form"
|
||||
|
||||
|
||||
# DPI scale for the interactive PDF view. ~150 DPI (2x of 72 PDF user-units).
|
||||
_PDF_RENDER_SCALE = 2.0
|
||||
|
||||
|
||||
def _locate_upload(upload_dir: str, file_id: str):
|
||||
"""Find an upload by its filename ID.
|
||||
|
||||
Lookup order:
|
||||
1. Direct hit at `upload_dir/file_id` (very small deployments).
|
||||
2. The `uploads.json` index that `UploadHandler.save_upload` maintains —
|
||||
maps file_hash → metadata containing the full path. O(1) once loaded.
|
||||
3. Fallback: `os.walk` the date-bucketed tree. Slow on large stores;
|
||||
only triggers for legacy uploads recorded before the index existed.
|
||||
|
||||
`followlinks=False` keeps a stray symlink loop in `data/uploads/` from
|
||||
spinning the walker into infinite recursion.
|
||||
"""
|
||||
import os
|
||||
import json as _json
|
||||
direct = os.path.join(upload_dir, file_id)
|
||||
if os.path.exists(direct):
|
||||
return direct
|
||||
# O(1) via uploads.json
|
||||
try:
|
||||
idx_path = os.path.join(upload_dir, "uploads.json")
|
||||
if os.path.exists(idx_path):
|
||||
with open(idx_path, "r") as f:
|
||||
idx = _json.load(f)
|
||||
for meta in (idx.values() if isinstance(idx, dict) else []):
|
||||
if meta.get("id") == file_id:
|
||||
p = meta.get("path")
|
||||
if p and os.path.exists(p):
|
||||
return p
|
||||
except Exception:
|
||||
pass
|
||||
for root, _dirs, files in os.walk(upload_dir, followlinks=False):
|
||||
if file_id in files:
|
||||
return os.path.join(root, file_id)
|
||||
return None
|
||||
|
||||
|
||||
def _derive_title(content: str) -> str:
|
||||
"""Derive a title from document content."""
|
||||
import re
|
||||
text = content.strip()
|
||||
if not text:
|
||||
return "Untitled"
|
||||
|
||||
# Markdown header
|
||||
md = re.match(r'^#{1,3}\s+(.+)', text, re.MULTILINE)
|
||||
if md:
|
||||
title = md.group(1).strip()
|
||||
if len(title) > 50:
|
||||
title = title[:48] + "…"
|
||||
return title
|
||||
|
||||
# HTML heading
|
||||
html = re.search(r'<h[1-3][^>]*>([^<]+)</h[1-3]>', text, re.IGNORECASE)
|
||||
if html:
|
||||
title = html.group(1).strip()
|
||||
if len(title) > 50:
|
||||
title = title[:48] + "…"
|
||||
return title
|
||||
|
||||
# First non-empty line (if short enough)
|
||||
for line in text.split('\n'):
|
||||
line = line.strip()
|
||||
if line and 2 <= len(line) <= 60:
|
||||
title = re.sub(r'[:#*`]+$', '', line).strip()
|
||||
if title and len(title) > 50:
|
||||
title = title[:48] + "…"
|
||||
return title or "Untitled"
|
||||
|
||||
return "Untitled"
|
||||
1643
routes/document_routes.py
Normal file
1643
routes/document_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
184
routes/editor_draft_routes.py
Normal file
184
routes/editor_draft_routes.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Editor draft routes — persisted in-progress gallery-editor sessions.
|
||||
|
||||
The gallery editor (image canvas) lets users layer edits on top of a
|
||||
photo (or a blank canvas). Persisting those layered sessions to the
|
||||
server makes them survive cache clears and roams across devices —
|
||||
unlike the legacy per-image localStorage drafts.
|
||||
|
||||
Each draft carries:
|
||||
- id — opaque uuid (the client never sees gallery-image ids
|
||||
as draft ids, so blank-canvas drafts work too)
|
||||
- source_image_id (nullable) — back-pointer for "this draft started as
|
||||
an edit of GalleryImage X"
|
||||
- payload — full JSON snapshot (layers as base64 PNG dataURLs,
|
||||
offsets, opacities, etc.) the editor knows how to
|
||||
rehydrate
|
||||
- thumbnail — small data URL for the landing-list grid
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import EditorDraft, SessionLocal
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DraftCreate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
source_image_id: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
payload: Dict[str, Any]
|
||||
thumbnail: Optional[str] = None
|
||||
|
||||
|
||||
class DraftUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
payload: Optional[Dict[str, Any]] = None
|
||||
thumbnail: Optional[str] = None
|
||||
|
||||
|
||||
def _owns(d: EditorDraft, user: Optional[str]) -> bool:
|
||||
if user is None:
|
||||
return True
|
||||
return (d.owner or None) == user
|
||||
|
||||
|
||||
def _summary(d: EditorDraft) -> Dict[str, Any]:
|
||||
"""List-view representation — omits the bulky payload."""
|
||||
return {
|
||||
"id": d.id,
|
||||
"name": d.name or "Untitled",
|
||||
"source_image_id": d.source_image_id,
|
||||
"width": d.width,
|
||||
"height": d.height,
|
||||
"thumbnail": d.thumbnail,
|
||||
"created_at": d.created_at.isoformat() if d.created_at else None,
|
||||
"updated_at": d.updated_at.isoformat() if d.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def setup_editor_draft_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["editor-drafts"])
|
||||
|
||||
@router.get("/api/editor-drafts")
|
||||
async def list_drafts(request: Request) -> Dict[str, List[Dict[str, Any]]]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(EditorDraft).filter(EditorDraft.is_active == True)
|
||||
if user is not None:
|
||||
q = q.filter(EditorDraft.owner == user)
|
||||
rows = q.order_by(EditorDraft.updated_at.desc()).limit(200).all()
|
||||
return {"drafts": [_summary(d) for d in rows]}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/api/editor-drafts/{draft_id}")
|
||||
async def get_draft(request: Request, draft_id: str) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
d = db.query(EditorDraft).filter(
|
||||
EditorDraft.id == draft_id, EditorDraft.is_active == True
|
||||
).first()
|
||||
if not d or not _owns(d, user):
|
||||
raise HTTPException(404, "Draft not found")
|
||||
try:
|
||||
payload = json.loads(d.payload) if d.payload else {}
|
||||
except Exception:
|
||||
payload = {}
|
||||
return {
|
||||
**_summary(d),
|
||||
"payload": payload,
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/api/editor-drafts")
|
||||
async def create_draft(request: Request, body: DraftCreate) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
d = EditorDraft(
|
||||
id=str(uuid.uuid4()),
|
||||
owner=user,
|
||||
name=(body.name or "Untitled")[:200],
|
||||
source_image_id=body.source_image_id,
|
||||
width=body.width,
|
||||
height=body.height,
|
||||
payload=json.dumps(body.payload or {}),
|
||||
thumbnail=body.thumbnail,
|
||||
)
|
||||
db.add(d)
|
||||
db.commit()
|
||||
db.refresh(d)
|
||||
return _summary(d)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.warning(f"editor-draft create failed: {e}")
|
||||
raise HTTPException(500, "Could not save draft")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.put("/api/editor-drafts/{draft_id}")
|
||||
async def update_draft(request: Request, draft_id: str, body: DraftUpdate) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
d = db.query(EditorDraft).filter(
|
||||
EditorDraft.id == draft_id, EditorDraft.is_active == True
|
||||
).first()
|
||||
if not d or not _owns(d, user):
|
||||
raise HTTPException(404, "Draft not found")
|
||||
if body.name is not None:
|
||||
d.name = body.name[:200]
|
||||
if body.width is not None:
|
||||
d.width = body.width
|
||||
if body.height is not None:
|
||||
d.height = body.height
|
||||
if body.payload is not None:
|
||||
d.payload = json.dumps(body.payload)
|
||||
if body.thumbnail is not None:
|
||||
d.thumbnail = body.thumbnail
|
||||
db.commit()
|
||||
db.refresh(d)
|
||||
return _summary(d)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.warning(f"editor-draft update failed: {e}")
|
||||
raise HTTPException(500, "Could not update draft")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/api/editor-drafts/{draft_id}")
|
||||
async def delete_draft(request: Request, draft_id: str) -> Dict[str, str]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
d = db.query(EditorDraft).filter(EditorDraft.id == draft_id).first()
|
||||
if not d or not _owns(d, user):
|
||||
raise HTTPException(404, "Draft not found")
|
||||
d.is_active = False
|
||||
db.commit()
|
||||
return {"status": "deleted", "id": draft_id}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(500, str(e))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
1274
routes/email_helpers.py
Normal file
1274
routes/email_helpers.py
Normal file
File diff suppressed because it is too large
Load Diff
1006
routes/email_pollers.py
Normal file
1006
routes/email_pollers.py
Normal file
File diff suppressed because it is too large
Load Diff
3038
routes/email_routes.py
Normal file
3038
routes/email_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
318
routes/embedding_routes.py
Normal file
318
routes/embedding_routes.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# routes/embedding_routes.py
|
||||
"""Routes for managing local fastembed embedding models and custom endpoints."""
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import logging
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException, Form
|
||||
from core.constants import BASE_DIR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
|
||||
|
||||
# Track in-progress downloads
|
||||
_downloading: dict = {}
|
||||
|
||||
# Curated recommendations — good coverage of size/quality tiers
|
||||
RECOMMENDED_MODELS = {
|
||||
"sentence-transformers/all-MiniLM-L6-v2", # 384d, 90MB — fast & tiny, good default
|
||||
"BAAI/bge-small-en-v1.5", # 384d, 67MB — smallest, solid quality
|
||||
"nomic-ai/nomic-embed-text-v1.5-Q", # 768d, 130MB — quantized, great bang/buck
|
||||
"BAAI/bge-base-en-v1.5", # 768d, 210MB — balanced mid-range
|
||||
"snowflake/snowflake-arctic-embed-m", # 768d, 430MB — strong performer
|
||||
"BAAI/bge-large-en-v1.5", # 1024d, 1.2GB — highest quality
|
||||
}
|
||||
|
||||
|
||||
def _cache_dir() -> str:
|
||||
"""Get the fastembed cache directory.
|
||||
|
||||
Defaults to a persistent path under the repo's data/ dir. The old
|
||||
default lived in /tmp, which many systems wipe on reboot — forcing a
|
||||
full re-download of the embedding model after every restart.
|
||||
"""
|
||||
env = os.environ.get("FASTEMBED_CACHE_PATH")
|
||||
if env:
|
||||
return env
|
||||
return os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"data", "fastembed_cache",
|
||||
)
|
||||
|
||||
|
||||
def _model_cache_name(hf_source: str) -> str:
|
||||
"""Convert HF source like 'qdrant/all-MiniLM-L6-v2-onnx' to cache dir name."""
|
||||
return "models--" + hf_source.replace("/", "--")
|
||||
|
||||
|
||||
def _is_downloaded(hf_source: str) -> bool:
|
||||
"""Check if a model is already cached."""
|
||||
cache = _cache_dir()
|
||||
model_dir = os.path.join(cache, _model_cache_name(hf_source))
|
||||
if not os.path.isdir(model_dir):
|
||||
return False
|
||||
# Check for actual model files (not just empty dir)
|
||||
snapshots = os.path.join(model_dir, "snapshots")
|
||||
if os.path.isdir(snapshots):
|
||||
return any(os.listdir(snapshots))
|
||||
# Also check for blobs (older cache format)
|
||||
blobs = os.path.join(model_dir, "blobs")
|
||||
return os.path.isdir(blobs) and any(os.listdir(blobs))
|
||||
|
||||
|
||||
def _active_model() -> str:
|
||||
"""Get the currently configured fastembed model name."""
|
||||
return os.environ.get("FASTEMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
|
||||
|
||||
|
||||
def _dir_size_mb(path: str) -> float:
|
||||
"""Get directory size in MB."""
|
||||
total = 0
|
||||
for dirpath, _, filenames in os.walk(path):
|
||||
for f in filenames:
|
||||
fp = os.path.join(dirpath, f)
|
||||
try:
|
||||
total += os.path.getsize(fp)
|
||||
except OSError:
|
||||
pass
|
||||
return round(total / (1024 * 1024), 1)
|
||||
|
||||
|
||||
def _load_custom_endpoint() -> dict:
|
||||
"""Load the saved custom embedding endpoint, if any."""
|
||||
try:
|
||||
if os.path.exists(_ENDPOINT_FILE):
|
||||
return json.loads(Path(_ENDPOINT_FILE).read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_custom_endpoint(data: dict):
|
||||
Path(_ENDPOINT_FILE).parent.mkdir(parents=True, exist_ok=True)
|
||||
Path(_ENDPOINT_FILE).write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def setup_embedding_routes():
|
||||
router = APIRouter(prefix="/api/embeddings")
|
||||
|
||||
@router.get("/models")
|
||||
def list_models():
|
||||
"""List all available fastembed models with download status."""
|
||||
try:
|
||||
from fastembed import TextEmbedding
|
||||
except ImportError:
|
||||
raise HTTPException(503, "fastembed is not installed")
|
||||
|
||||
active = _active_model()
|
||||
catalog = TextEmbedding.list_supported_models()
|
||||
result = []
|
||||
|
||||
for m in catalog:
|
||||
hf_src = m.get("sources", {}).get("hf", "")
|
||||
downloaded = _is_downloaded(hf_src) if hf_src else False
|
||||
|
||||
cached_size = None
|
||||
if downloaded and hf_src:
|
||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
||||
cached_size = _dir_size_mb(model_path)
|
||||
|
||||
result.append({
|
||||
"model": m["model"],
|
||||
"dim": m.get("dim"),
|
||||
"size_gb": m.get("size_in_GB", 0),
|
||||
"description": m.get("description", ""),
|
||||
"downloaded": downloaded,
|
||||
"downloading": m["model"] in _downloading,
|
||||
"active": m["model"] == active,
|
||||
"recommended": m["model"] in RECOMMENDED_MODELS,
|
||||
"cached_size_mb": cached_size,
|
||||
})
|
||||
|
||||
# Sort: active first, then downloaded, then by size
|
||||
result.sort(key=lambda x: (not x["active"], not x["downloaded"], x["size_gb"]))
|
||||
return result
|
||||
|
||||
@router.post("/models/{model_name:path}/download")
|
||||
async def download_model(model_name: str):
|
||||
"""Download a fastembed model. Returns when complete."""
|
||||
try:
|
||||
from fastembed import TextEmbedding
|
||||
except ImportError:
|
||||
raise HTTPException(503, "fastembed is not installed")
|
||||
|
||||
# Validate model exists
|
||||
catalog = {m["model"]: m for m in TextEmbedding.list_supported_models()}
|
||||
if model_name not in catalog:
|
||||
raise HTTPException(404, f"Unknown model: {model_name}")
|
||||
|
||||
hf_src = catalog[model_name].get("sources", {}).get("hf", "")
|
||||
if hf_src and _is_downloaded(hf_src):
|
||||
return {"status": "already_downloaded", "model": model_name}
|
||||
|
||||
if model_name in _downloading:
|
||||
return {"status": "already_downloading", "model": model_name}
|
||||
|
||||
_downloading[model_name] = True
|
||||
try:
|
||||
# Run in thread to not block the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
cache = _cache_dir()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: TextEmbedding(model_name=model_name, cache_dir=cache),
|
||||
)
|
||||
return {"status": "downloaded", "model": model_name}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download {model_name}: {e}")
|
||||
raise HTTPException(500, f"Download failed: {str(e)}")
|
||||
finally:
|
||||
_downloading.pop(model_name, None)
|
||||
|
||||
@router.get("/models/{model_name:path}/status")
|
||||
def download_status(model_name: str):
|
||||
"""Check download status of a model."""
|
||||
try:
|
||||
from fastembed import TextEmbedding
|
||||
except ImportError:
|
||||
raise HTTPException(503, "fastembed is not installed")
|
||||
|
||||
catalog = {m["model"]: m for m in TextEmbedding.list_supported_models()}
|
||||
if model_name not in catalog:
|
||||
raise HTTPException(404, f"Unknown model: {model_name}")
|
||||
|
||||
hf_src = catalog[model_name].get("sources", {}).get("hf", "")
|
||||
downloaded = _is_downloaded(hf_src) if hf_src else False
|
||||
|
||||
return {
|
||||
"model": model_name,
|
||||
"downloaded": downloaded,
|
||||
"downloading": model_name in _downloading,
|
||||
}
|
||||
|
||||
@router.delete("/models/{model_name:path}")
|
||||
def delete_model(model_name: str):
|
||||
"""Delete a cached model."""
|
||||
if model_name == _active_model():
|
||||
raise HTTPException(400, "Cannot delete the active embedding model")
|
||||
|
||||
if model_name in _downloading:
|
||||
raise HTTPException(400, "Model is currently downloading")
|
||||
|
||||
try:
|
||||
from fastembed import TextEmbedding
|
||||
except ImportError:
|
||||
raise HTTPException(503, "fastembed is not installed")
|
||||
|
||||
catalog = {m["model"]: m for m in TextEmbedding.list_supported_models()}
|
||||
if model_name not in catalog:
|
||||
raise HTTPException(404, f"Unknown model: {model_name}")
|
||||
|
||||
hf_src = catalog[model_name].get("sources", {}).get("hf", "")
|
||||
if not hf_src:
|
||||
raise HTTPException(400, "No cache source for this model")
|
||||
|
||||
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
|
||||
if not os.path.isdir(model_path):
|
||||
return {"deleted": False, "message": "Model not cached"}
|
||||
|
||||
shutil.rmtree(model_path)
|
||||
logger.info(f"Deleted cached model: {model_name} ({model_path})")
|
||||
return {"deleted": True, "model": model_name}
|
||||
|
||||
@router.get("/endpoint")
|
||||
def get_endpoint():
|
||||
"""Get the current custom embedding endpoint config."""
|
||||
saved = _load_custom_endpoint()
|
||||
current_url = os.environ.get("EMBEDDING_URL", "")
|
||||
return {
|
||||
"url": saved.get("url", current_url),
|
||||
"model": saved.get("model", os.environ.get("EMBEDDING_MODEL", "")),
|
||||
"active": bool(saved.get("url") or current_url),
|
||||
}
|
||||
|
||||
@router.post("/endpoint")
|
||||
def set_endpoint(url: str = Form(...), model: str = Form("")):
|
||||
"""Save a custom embedding endpoint URL."""
|
||||
url = url.strip()
|
||||
if not url:
|
||||
raise HTTPException(400, "URL is required")
|
||||
|
||||
# Quick health check
|
||||
try:
|
||||
import httpx
|
||||
resp = httpx.post(
|
||||
url,
|
||||
json={"input": ["test"], "model": model or "test"},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except Exception as e:
|
||||
raise HTTPException(400, f"Endpoint unreachable: {e}")
|
||||
|
||||
# Persist and set in environment for immediate use
|
||||
data = {"url": url}
|
||||
if model:
|
||||
data["model"] = model
|
||||
_save_custom_endpoint(data)
|
||||
os.environ["EMBEDDING_URL"] = url
|
||||
if model:
|
||||
os.environ["EMBEDDING_MODEL"] = model
|
||||
|
||||
# Reset the RAG singleton so it picks up the new endpoint
|
||||
import src.rag_singleton as _rs
|
||||
_rs.rag_instance = None
|
||||
_rs._last_attempt = 0
|
||||
|
||||
# Clear the HTTP-embedding "down" latch so the new endpoint is re-probed
|
||||
# instead of staying on the FastEmbed fallback for the process lifetime.
|
||||
try:
|
||||
from src.embeddings import reset_http_embed_state
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reset ChromaDB client (collections will be recreated with new embeddings)
|
||||
try:
|
||||
from src.chroma_client import reset_client
|
||||
reset_client()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info(f"Custom embedding endpoint set: {url}")
|
||||
return {"success": True, "url": url, "model": model}
|
||||
|
||||
@router.delete("/endpoint")
|
||||
def clear_endpoint():
|
||||
"""Clear the custom endpoint and revert to local fastembed."""
|
||||
if os.path.exists(_ENDPOINT_FILE):
|
||||
os.remove(_ENDPOINT_FILE)
|
||||
|
||||
# Remove from environment
|
||||
os.environ.pop("EMBEDDING_URL", None)
|
||||
os.environ.pop("EMBEDDING_MODEL", None)
|
||||
|
||||
# Reset the RAG singleton so it falls back to fastembed
|
||||
import src.rag_singleton as _rs
|
||||
_rs.rag_instance = None
|
||||
_rs._last_attempt = 0
|
||||
try:
|
||||
from src.embeddings import reset_http_embed_state
|
||||
reset_http_embed_state()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reset ChromaDB client
|
||||
try:
|
||||
from src.chroma_client import reset_client
|
||||
reset_client()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("Custom embedding endpoint cleared, reverting to local fastembed")
|
||||
return {"success": True}
|
||||
|
||||
return router
|
||||
70
routes/emoji_routes.py
Normal file
70
routes/emoji_routes.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# routes/emoji_routes.py
|
||||
# Same-origin emoji SVG proxy. The frontend rewrites emoji in chat to a
|
||||
# <span class="emoji" style="--em:url('/api/emoji/<codepoints>.svg')">
|
||||
# which uses the returned SVG as a CSS mask tinted to the text color, so emoji
|
||||
# render as monochrome line icons (project rule: never colorful emoji). The
|
||||
# black line-art SVGs are lazily fetched from the OpenMoji CDN on first use and
|
||||
# cached on disk, so:
|
||||
# - the client only ever talks to our own origin (no CDN dep, no CSP change),
|
||||
# - the repo isn't bloated with thousands of SVG files,
|
||||
# - it works offline once an emoji has been seen once.
|
||||
# Unknown/unreachable codepoints return a transparent SVG (not 404), so the CSS
|
||||
# mask shows nothing rather than a solid currentColor box.
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import FileResponse, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
|
||||
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
|
||||
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
|
||||
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
|
||||
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
|
||||
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
|
||||
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
|
||||
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
|
||||
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
|
||||
# request can still pick up the real glyph once the CDN is reachable.
|
||||
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
|
||||
_BLANK_HEADERS = {"Cache-Control": "no-store"}
|
||||
|
||||
|
||||
def setup_emoji_routes() -> APIRouter:
|
||||
router = APIRouter(prefix="/api/emoji", tags=["emoji"])
|
||||
|
||||
def _blank() -> Response:
|
||||
return Response(_BLANK_SVG, media_type="image/svg+xml", headers=_BLANK_HEADERS)
|
||||
|
||||
@router.get("/{code}.svg")
|
||||
async def emoji_svg(code: str):
|
||||
code = code.lower()
|
||||
if not _CODE_RE.match(code):
|
||||
return _blank()
|
||||
|
||||
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
fp = _CACHE_DIR / f"{code}.svg"
|
||||
if fp.exists():
|
||||
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||
|
||||
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
|
||||
# it. OpenMoji filenames are the codepoints uppercased.
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=8.0) as client:
|
||||
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
|
||||
if r.status_code == 200 and b"<svg" in r.content[:256]:
|
||||
try:
|
||||
fp.write_bytes(r.content)
|
||||
except Exception:
|
||||
pass # cache write is best-effort
|
||||
return Response(r.content, media_type="image/svg+xml", headers=_SVG_HEADERS)
|
||||
except Exception as e:
|
||||
logger.warning("emoji fetch %s failed: %s", code, e)
|
||||
|
||||
return _blank()
|
||||
|
||||
return router
|
||||
47
routes/font_routes.py
Normal file
47
routes/font_routes.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Custom font discovery — lists user-supplied font files in static/fonts/custom/."""
|
||||
import os
|
||||
import re
|
||||
from fastapi import APIRouter
|
||||
|
||||
CUSTOM_FONTS_DIR = os.path.join("static", "fonts", "custom")
|
||||
FONT_EXTENSIONS = {".ttf", ".otf", ".woff", ".woff2"}
|
||||
|
||||
|
||||
def _derive_family(filename):
|
||||
"""Derive a font-family name from a filename like 'JetBrainsMono-Regular.woff2' → 'JetBrains Mono'."""
|
||||
name = os.path.splitext(filename)[0]
|
||||
# Strip common weight/style suffixes
|
||||
name = re.sub(
|
||||
r'[-_ ]?(Thin|ExtraLight|UltraLight|Light|Regular|Medium|SemiBold|DemiBold|Bold|ExtraBold|UltraBold|Black|Heavy|Italic|Oblique|Variable|VF)$',
|
||||
'', name, flags=re.IGNORECASE
|
||||
)
|
||||
# Insert spaces before uppercase runs: "JetBrainsMono" → "Jet Brains Mono"
|
||||
name = re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', name)
|
||||
# Replace dashes/underscores with spaces
|
||||
name = re.sub(r'[-_]+', ' ', name).strip()
|
||||
return name or filename
|
||||
|
||||
|
||||
def setup_font_routes():
|
||||
router = APIRouter(prefix="/api/fonts", tags=["fonts"])
|
||||
|
||||
@router.get("/custom")
|
||||
async def list_custom_fonts():
|
||||
"""Return available custom fonts grouped by derived family name."""
|
||||
os.makedirs(CUSTOM_FONTS_DIR, exist_ok=True)
|
||||
families = {}
|
||||
for f in sorted(os.listdir(CUSTOM_FONTS_DIR)):
|
||||
ext = os.path.splitext(f)[1].lower()
|
||||
if ext not in FONT_EXTENSIONS:
|
||||
continue
|
||||
family = _derive_family(f)
|
||||
if family not in families:
|
||||
families[family] = []
|
||||
families[family].append({
|
||||
"file": f,
|
||||
"url": f"/static/fonts/custom/{f}",
|
||||
"format": ext.lstrip('.'),
|
||||
})
|
||||
return {"fonts": families}
|
||||
|
||||
return router
|
||||
125
routes/gallery_helpers.py
Normal file
125
routes/gallery_helpers.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""gallery_helpers.py — extracted helpers, models, and small utilities.
|
||||
|
||||
Imported by gallery_routes.py."""
|
||||
|
||||
"""Gallery routes — browsable library for photos and AI-generated images."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import GalleryImage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---- Request schemas ----
|
||||
|
||||
class GalleryPatch(BaseModel):
|
||||
tags: Optional[str] = None
|
||||
favorite: Optional[bool] = None
|
||||
album_id: Optional[str] = None
|
||||
|
||||
|
||||
# ---- EXIF extraction ----
|
||||
|
||||
def _extract_exif(content: bytes) -> dict:
|
||||
"""Extract EXIF metadata from image bytes. Returns dict of fields."""
|
||||
result = {"width": None, "height": None}
|
||||
try:
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
img = Image.open(BytesIO(content))
|
||||
result["width"] = img.width
|
||||
result["height"] = img.height
|
||||
|
||||
exif = img._getexif() if hasattr(img, '_getexif') else None
|
||||
if not exif:
|
||||
return result
|
||||
|
||||
# EXIF tag IDs
|
||||
# 271=Make, 272=Model, 306=DateTime, 36867=DateTimeOriginal
|
||||
# 34853=GPSInfo
|
||||
result["camera_make"] = str(exif.get(271, "")).strip() or None
|
||||
result["camera_model"] = str(exif.get(272, "")).strip() or None
|
||||
|
||||
# Date taken
|
||||
for tag_id in (36867, 36868, 306): # DateTimeOriginal, DateTimeDigitized, DateTime
|
||||
raw = exif.get(tag_id)
|
||||
if raw:
|
||||
try:
|
||||
result["taken_at"] = datetime.strptime(str(raw).strip(), "%Y:%m:%d %H:%M:%S")
|
||||
break
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# GPS
|
||||
gps_info = exif.get(34853)
|
||||
if gps_info and isinstance(gps_info, dict):
|
||||
try:
|
||||
def _to_deg(vals):
|
||||
d, m, s = [float(v) for v in vals]
|
||||
return d + m / 60 + s / 3600
|
||||
if 2 in gps_info and 4 in gps_info:
|
||||
lat = _to_deg(gps_info[2])
|
||||
lng = _to_deg(gps_info[4])
|
||||
if gps_info.get(1) == 'S': lat = -lat
|
||||
if gps_info.get(3) == 'W': lng = -lng
|
||||
result["gps_lat"] = f"{lat:.6f}"
|
||||
result["gps_lng"] = f"{lng:.6f}"
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
# User-visible failure (photo loses metadata): surface at WARNING
|
||||
# and record on the result so the upload endpoint can pass it back.
|
||||
logger.warning(f"EXIF extraction failed: {e}")
|
||||
result["exif_error"] = str(e)
|
||||
return result
|
||||
|
||||
|
||||
# ---- Helpers ----
|
||||
|
||||
def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": img.id,
|
||||
"filename": img.filename,
|
||||
"url": f"/api/generated-image/{img.filename}",
|
||||
"prompt": img.prompt,
|
||||
"model": img.model,
|
||||
"size": img.size,
|
||||
"quality": img.quality,
|
||||
"tags": img.tags or "",
|
||||
"ai_tags": img.ai_tags or "",
|
||||
"user_tags": img.tags or "",
|
||||
"session_id": img.session_id,
|
||||
"session_name": session_name,
|
||||
"album_id": img.album_id,
|
||||
"is_active": img.is_active,
|
||||
"favorite": img.favorite or False,
|
||||
"taken_at": img.taken_at.isoformat() if img.taken_at else None,
|
||||
"camera": f"{img.camera_make or ''} {img.camera_model or ''}".strip() or None,
|
||||
"gps": {"lat": img.gps_lat, "lng": img.gps_lng} if img.gps_lat else None,
|
||||
"width": img.width,
|
||||
"height": img.height,
|
||||
"file_size": img.file_size,
|
||||
"created_at": img.created_at.isoformat() if img.created_at else None,
|
||||
"updated_at": img.updated_at.isoformat() if img.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _owner_filter(q, user):
|
||||
"""Apply owner filtering to a gallery query."""
|
||||
if user is None:
|
||||
return q.filter(False)
|
||||
return q.filter(GalleryImage.owner == user)
|
||||
|
||||
|
||||
|
||||
def _human_size(nbytes):
|
||||
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
|
||||
if abs(nbytes) < 1024:
|
||||
return f"{nbytes:.1f} {unit}"
|
||||
nbytes /= 1024
|
||||
return f"{nbytes:.1f} PB"
|
||||
1763
routes/gallery_routes.py
Normal file
1763
routes/gallery_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
619
routes/history_routes.py
Normal file
619
routes/history_routes.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""History routes — session history, truncation, fork, conversation topics."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
|
||||
from core.models import ChatMessage
|
||||
from core.database import SessionLocal, ChatMessage as DbChatMessage, Session as DbSession
|
||||
from src.topic_analyzer import analyze_topics
|
||||
from routes.session_routes import _verify_session_owner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_history_routes(session_manager) -> APIRouter:
|
||||
router = APIRouter(tags=["history"])
|
||||
|
||||
@router.get("/api/history/{session_id}")
|
||||
async def get_session_history(request: Request, session_id: str) -> Dict[str, Any]:
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session '{session_id}' not found")
|
||||
|
||||
history_dict = []
|
||||
for msg in session.history:
|
||||
if isinstance(msg, ChatMessage):
|
||||
# Skip hidden messages (e.g. compaction summaries for AI context)
|
||||
if msg.metadata and msg.metadata.get("hidden"):
|
||||
continue
|
||||
entry = {"role": msg.role, "content": msg.content}
|
||||
if msg.metadata:
|
||||
entry["metadata"] = msg.metadata
|
||||
history_dict.append(entry)
|
||||
elif isinstance(msg, dict):
|
||||
if msg.get("metadata", {}).get("hidden"):
|
||||
continue
|
||||
entry = {
|
||||
"role": msg.get("role", ""),
|
||||
"content": msg.get("content", ""),
|
||||
}
|
||||
if msg.get("metadata"):
|
||||
entry["metadata"] = msg["metadata"]
|
||||
history_dict.append(entry)
|
||||
|
||||
# Fallback: load from DB if in-memory is empty
|
||||
if not history_dict:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_messages = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id)
|
||||
.order_by(DbChatMessage.timestamp)
|
||||
.all()
|
||||
)
|
||||
import json as _json
|
||||
history_dict = []
|
||||
for m in db_messages:
|
||||
entry = {"role": m.role, "content": m.content}
|
||||
meta = {}
|
||||
if m.meta_data:
|
||||
try:
|
||||
meta = _json.loads(m.meta_data) or {}
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
meta = {}
|
||||
if m.timestamp and "timestamp" not in meta:
|
||||
meta["timestamp"] = m.timestamp.isoformat() + "Z"
|
||||
if meta:
|
||||
entry["metadata"] = meta
|
||||
history_dict.append(entry)
|
||||
if history_dict:
|
||||
session.history = [
|
||||
ChatMessage(role=m["role"], content=m["content"], metadata=m.get("metadata"))
|
||||
for m in history_dict
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"DB fallback failed for {session_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return {
|
||||
"history": history_dict,
|
||||
"model": session.model,
|
||||
"endpoint_url": session.endpoint_url,
|
||||
"name": session.name,
|
||||
}
|
||||
|
||||
@router.post("/api/session/{session_id}/truncate")
|
||||
async def truncate_session(request: Request, session_id: str):
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
keep_count = body.get("keep_count", 0)
|
||||
result = session_manager.truncate_messages(session_id, keep_count)
|
||||
return {"status": "ok", "kept": keep_count, "truncated": result}
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Truncate error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.post("/api/session/{session_id}/message")
|
||||
async def add_message(request: Request, session_id: str):
|
||||
"""Add a message to a session (for slash command persistence)."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
role = body.get("role", "assistant")
|
||||
content = body.get("content", "")
|
||||
if not content:
|
||||
raise HTTPException(400, "content is required")
|
||||
msg = ChatMessage(role=role, content=content, metadata=body.get("metadata"))
|
||||
session_manager.add_message(session_id, msg)
|
||||
return {"status": "ok"}
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
|
||||
@router.post("/api/session/{session_id}/delete-messages")
|
||||
async def delete_messages(request: Request, session_id: str):
|
||||
"""Delete specific messages by DB ID (or legacy index)."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
msg_ids = body.get("msg_ids", [])
|
||||
indices = body.get("indices") # legacy fallback
|
||||
|
||||
session = session_manager.get_session(session_id)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
if msg_ids:
|
||||
# New ID-based delete
|
||||
deleted = 0
|
||||
for mid in msg_ids:
|
||||
db_msg = db.query(DbChatMessage).filter(
|
||||
DbChatMessage.id == mid,
|
||||
DbChatMessage.session_id == session_id,
|
||||
).first()
|
||||
if db_msg:
|
||||
db.delete(db_msg)
|
||||
deleted += 1
|
||||
|
||||
# Remove from in-memory history by matching _db_id
|
||||
def _get_db_id(m):
|
||||
meta = m.metadata if isinstance(m, ChatMessage) else (m.get('metadata') if isinstance(m, dict) else None)
|
||||
return meta.get('_db_id') if isinstance(meta, dict) else None
|
||||
session.history = [m for m in session.history if _get_db_id(m) not in msg_ids]
|
||||
elif indices:
|
||||
# Legacy index-based delete
|
||||
indices = sorted(indices, reverse=True)
|
||||
db_messages = db.query(DbChatMessage).filter(
|
||||
DbChatMessage.session_id == session_id
|
||||
).order_by(DbChatMessage.timestamp).all()
|
||||
|
||||
deleted = 0
|
||||
for idx in indices:
|
||||
if 0 <= idx < len(db_messages):
|
||||
db.delete(db_messages[idx])
|
||||
deleted += 1
|
||||
if 0 <= idx < len(session.history):
|
||||
session.history.pop(idx)
|
||||
else:
|
||||
return {"status": "ok", "deleted": 0}
|
||||
|
||||
session.message_count = len(session.history)
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.message_count = len(session.history)
|
||||
from datetime import datetime, timezone
|
||||
db_session.updated_at = datetime.now(timezone.utc)
|
||||
|
||||
db.commit()
|
||||
return {"status": "ok", "deleted": deleted}
|
||||
finally:
|
||||
db.close()
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Delete messages error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.post("/api/session/{session_id}/edit-message")
|
||||
async def edit_message(request: Request, session_id: str):
|
||||
"""Edit the content of a message by its database ID."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
msg_id = body.get("msg_id")
|
||||
content = body.get("content")
|
||||
if not msg_id or content is None:
|
||||
raise HTTPException(400, "msg_id and content are required")
|
||||
|
||||
session = session_manager.get_session(session_id)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_msg = db.query(DbChatMessage).filter(
|
||||
DbChatMessage.id == msg_id,
|
||||
DbChatMessage.session_id == session_id,
|
||||
).first()
|
||||
if not db_msg:
|
||||
raise HTTPException(404, "Message not found")
|
||||
|
||||
db_msg.content = content
|
||||
meta = {}
|
||||
if db_msg.meta_data:
|
||||
try: meta = json.loads(db_msg.meta_data)
|
||||
except (json.JSONDecodeError, ValueError): pass
|
||||
meta['edited'] = True
|
||||
db_msg.meta_data = json.dumps(meta)
|
||||
|
||||
# Update in-memory history by matching _db_id
|
||||
for hmsg in session.history:
|
||||
hmeta = hmsg.metadata if isinstance(hmsg, ChatMessage) else hmsg.get('metadata')
|
||||
if isinstance(hmeta, dict) and hmeta.get('_db_id') == msg_id:
|
||||
if isinstance(hmsg, ChatMessage):
|
||||
hmsg.content = content
|
||||
hmsg.metadata['edited'] = True
|
||||
elif isinstance(hmsg, dict):
|
||||
hmsg['content'] = content
|
||||
hmsg['metadata']['edited'] = True
|
||||
break
|
||||
|
||||
db.commit()
|
||||
return {"status": "ok"}
|
||||
finally:
|
||||
db.close()
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Edit message error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.post("/api/session/{session_id}/mark-stopped")
|
||||
async def mark_stopped(request: Request, session_id: str):
|
||||
"""Mark the last assistant message as stopped by user."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
# Find last assistant message and add stopped metadata
|
||||
for msg in reversed(session.history):
|
||||
if (isinstance(msg, ChatMessage) and msg.role == 'assistant') or \
|
||||
(isinstance(msg, dict) and msg.get('role') == 'assistant'):
|
||||
if isinstance(msg, ChatMessage):
|
||||
if not msg.metadata:
|
||||
msg.metadata = {}
|
||||
msg.metadata['stopped'] = True
|
||||
if not msg.metadata.get('model'):
|
||||
msg.metadata['model'] = session.model
|
||||
else:
|
||||
if 'metadata' not in msg:
|
||||
msg['metadata'] = {}
|
||||
msg['metadata']['stopped'] = True
|
||||
if not msg['metadata'].get('model'):
|
||||
msg['metadata']['model'] = session.model
|
||||
break
|
||||
# Also update in DB
|
||||
db = SessionLocal()
|
||||
try:
|
||||
import json as _json
|
||||
db_messages = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
|
||||
.order_by(DbChatMessage.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if db_messages:
|
||||
meta = {}
|
||||
if db_messages.meta_data:
|
||||
try:
|
||||
meta = _json.loads(db_messages.meta_data)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
meta['stopped'] = True
|
||||
if not meta.get('model'):
|
||||
meta['model'] = session.model
|
||||
db_messages.meta_data = _json.dumps(meta)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
session_manager.save_sessions()
|
||||
return {"status": "ok"}
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Mark stopped error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.post("/api/session/{session_id}/update-last-meta")
|
||||
async def update_last_meta(request: Request, session_id: str):
|
||||
"""Merge metadata into the last assistant message (e.g. save variants)."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
meta_update = body.get("metadata", {})
|
||||
session = session_manager.get_session(session_id)
|
||||
|
||||
# Update in-memory
|
||||
for msg in reversed(session.history):
|
||||
if (isinstance(msg, ChatMessage) and msg.role == 'assistant') or \
|
||||
(isinstance(msg, dict) and msg.get('role') == 'assistant'):
|
||||
if isinstance(msg, ChatMessage):
|
||||
if not msg.metadata:
|
||||
msg.metadata = {}
|
||||
msg.metadata.update(meta_update)
|
||||
else:
|
||||
if 'metadata' not in msg:
|
||||
msg['metadata'] = {}
|
||||
msg['metadata'].update(meta_update)
|
||||
break
|
||||
|
||||
# Update in DB
|
||||
db = SessionLocal()
|
||||
try:
|
||||
import json as _json
|
||||
db_msg = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
|
||||
.order_by(DbChatMessage.created_at.desc())
|
||||
.first()
|
||||
)
|
||||
if db_msg:
|
||||
meta = {}
|
||||
if db_msg.meta_data:
|
||||
try: meta = _json.loads(db_msg.meta_data)
|
||||
except (json.JSONDecodeError, ValueError): pass
|
||||
meta.update(meta_update)
|
||||
db_msg.meta_data = _json.dumps(meta)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
session_manager.save_sessions()
|
||||
return {"status": "ok"}
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Update last meta error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.post("/api/session/{session_id}/merge-last-assistant")
|
||||
async def merge_last_assistant(request: Request, session_id: str):
|
||||
"""Merge the last two assistant messages into one (for continue)."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
separator = body.get("separator", "\n\n")
|
||||
session = session_manager.get_session(session_id)
|
||||
|
||||
# Find last two assistant messages in-memory
|
||||
ai_indices = []
|
||||
for i, msg in enumerate(session.history):
|
||||
role = msg.role if isinstance(msg, ChatMessage) else msg.get('role', '')
|
||||
if role == 'assistant':
|
||||
ai_indices.append(i)
|
||||
|
||||
if len(ai_indices) < 2:
|
||||
return {"status": "ok", "merged": False}
|
||||
|
||||
idx1, idx2 = ai_indices[-2], ai_indices[-1]
|
||||
msg1, msg2 = session.history[idx1], session.history[idx2]
|
||||
|
||||
content1 = msg1.content if isinstance(msg1, ChatMessage) else msg1.get('content', '')
|
||||
content2 = msg2.content if isinstance(msg2, ChatMessage) else msg2.get('content', '')
|
||||
merged_content = content1 + separator + content2
|
||||
|
||||
# Merge metadata
|
||||
meta1 = (msg1.metadata if isinstance(msg1, ChatMessage) else msg1.get('metadata')) or {}
|
||||
meta2 = (msg2.metadata if isinstance(msg2, ChatMessage) else msg2.get('metadata')) or {}
|
||||
merged_meta = {**meta1, **meta2}
|
||||
merged_meta.pop('stopped', None) # no longer stopped after continue
|
||||
|
||||
# Update first message, remove second
|
||||
if isinstance(msg1, ChatMessage):
|
||||
msg1.content = merged_content
|
||||
msg1.metadata = merged_meta
|
||||
else:
|
||||
msg1['content'] = merged_content
|
||||
msg1['metadata'] = merged_meta
|
||||
|
||||
# Also remove the hidden "continue" user message between them if present
|
||||
# It's the message at idx2-1 if it's a user message with continue text
|
||||
remove_indices = [idx2]
|
||||
if idx2 - 1 > idx1:
|
||||
between = session.history[idx2 - 1]
|
||||
between_role = between.role if isinstance(between, ChatMessage) else between.get('role', '')
|
||||
between_content = between.content if isinstance(between, ChatMessage) else between.get('content', '')
|
||||
if between_role == 'user' and 'previous response was interrupted' in between_content:
|
||||
remove_indices.insert(0, idx2 - 1)
|
||||
|
||||
for ri in sorted(remove_indices, reverse=True):
|
||||
session.history.pop(ri)
|
||||
|
||||
# Update DB
|
||||
db = SessionLocal()
|
||||
try:
|
||||
import json as _json
|
||||
db_messages = (
|
||||
db.query(DbChatMessage)
|
||||
.filter(DbChatMessage.session_id == session_id)
|
||||
.order_by(DbChatMessage.created_at)
|
||||
.all()
|
||||
)
|
||||
# Find last two assistant messages in DB
|
||||
ai_db = [(i, m) for i, m in enumerate(db_messages) if m.role == 'assistant']
|
||||
if len(ai_db) >= 2:
|
||||
(_, db1), (_, db2) = ai_db[-2], ai_db[-1]
|
||||
db1.content = merged_content
|
||||
db1.meta_data = _json.dumps(merged_meta)
|
||||
|
||||
# Remove the continue user message if between them
|
||||
db_idx2 = db_messages.index(db2)
|
||||
db_idx1 = db_messages.index(db1)
|
||||
for di in range(db_idx2, db_idx1, -1):
|
||||
db.delete(db_messages[di])
|
||||
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
session_manager.save_sessions()
|
||||
return {"status": "ok", "merged": True}
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Merge assistant error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.post("/api/session/{session_id}/fork")
|
||||
async def fork_session(request: Request, session_id: str):
|
||||
"""Create a new session with messages copied up to keep_count."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
body = await request.json()
|
||||
keep_count = body.get("keep_count", 0)
|
||||
|
||||
# Get the source session
|
||||
source = session_manager.sessions.get(session_id)
|
||||
if not source:
|
||||
raise HTTPException(404, "Session not found")
|
||||
|
||||
# Create new session
|
||||
new_id = str(uuid.uuid4())
|
||||
fork_name = f"\u2ADD {source.name}"
|
||||
new_session = session_manager.create_session(
|
||||
session_id=new_id,
|
||||
name=fork_name,
|
||||
endpoint_url=source.endpoint_url,
|
||||
model=source.model,
|
||||
rag=False,
|
||||
owner=getattr(source, 'owner', None),
|
||||
)
|
||||
|
||||
# Copy messages up to keep_count
|
||||
msgs_to_copy = source.history[:keep_count]
|
||||
for msg in msgs_to_copy:
|
||||
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", getattr(source, 'owner', None))
|
||||
except Exception:
|
||||
logger.debug("session_created event dispatch failed", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"id": new_id,
|
||||
"name": fork_name,
|
||||
"kept": len(msgs_to_copy),
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Fork error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
@router.get("/api/conversations/topics")
|
||||
async def get_conversation_topics(request: Request) -> Dict[str, Any]:
|
||||
from src.auth_helpers import get_current_user
|
||||
user = get_current_user(request)
|
||||
try:
|
||||
return analyze_topics(session_manager, owner=user)
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Topic analysis failed: {e}")
|
||||
|
||||
@router.post("/api/session/{session_id}/compact")
|
||||
async def compact_session(request: Request, session_id: str):
|
||||
"""Manually trigger context compaction for a session."""
|
||||
_verify_session_owner(request, session_id)
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
|
||||
try:
|
||||
from src.model_context import estimate_tokens, get_context_length
|
||||
from src.llm_core import llm_call_async
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
|
||||
if len(session.history) < 6:
|
||||
return {"status": "ok", "message": "Not enough messages to compact"}
|
||||
|
||||
ctx_len = get_context_length(session.endpoint_url, session.model)
|
||||
messages_before = session.get_context_messages()
|
||||
used_before = estimate_tokens(messages_before)
|
||||
pct_before = round((used_before / ctx_len) * 100, 1) if ctx_len else 0
|
||||
msg_count_before = len(session.history)
|
||||
|
||||
# Keep only last 4 messages, summarize the rest
|
||||
keep_count = 4
|
||||
older = session.history[:-keep_count]
|
||||
recent = session.history[-keep_count:]
|
||||
|
||||
# Build text to summarize
|
||||
convo_text = "\n".join(
|
||||
f"{(m.role if isinstance(m, ChatMessage) else m.get('role', '')).upper()}: "
|
||||
f"{(m.content if isinstance(m, ChatMessage) else m.get('content', ''))[:2000]}"
|
||||
for m in older
|
||||
)
|
||||
|
||||
# Use utility model if available
|
||||
util_url, util_model, util_headers = resolve_endpoint("utility")
|
||||
compact_url = util_url or session.endpoint_url
|
||||
compact_model = util_model or session.model
|
||||
compact_headers = util_headers if util_url else session.headers
|
||||
|
||||
from src.context_compactor import SELF_SUMMARY_SYSTEM_PROMPT
|
||||
compaction_count = sum(1 for m in session.history if isinstance(m, ChatMessage) and "[Conversation summary" in (m.content or ""))
|
||||
sys_prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace("{count}", str(len(older))).replace("{n}", str(compaction_count + 1))
|
||||
summary = await llm_call_async(
|
||||
compact_url, compact_model,
|
||||
[
|
||||
{"role": "system", "content": sys_prompt},
|
||||
{"role": "user", "content": convo_text},
|
||||
],
|
||||
temperature=0.2, max_tokens=1024,
|
||||
headers=compact_headers, timeout=30,
|
||||
)
|
||||
|
||||
# Replace session history: summary as system message + recent messages
|
||||
# System message holds the full summary for AI context
|
||||
system_summary = ChatMessage(
|
||||
role="system",
|
||||
content=f"[Conversation summary — {len(older)} earlier messages were compacted]\n\n{summary}",
|
||||
metadata={"compacted": True, "hidden": True},
|
||||
)
|
||||
# Visible assistant message just shows stats
|
||||
summary_msg = ChatMessage(
|
||||
role="assistant",
|
||||
content=f"**Conversation compacted** — {len(older)} messages summarized, {len(recent)} kept.",
|
||||
metadata={"compacted": True, "messages_removed": len(older)},
|
||||
)
|
||||
new_history = [system_summary, summary_msg] + list(recent)
|
||||
session.history = new_history
|
||||
session.message_count = len(session.history)
|
||||
logger.info(f"Compact: session {session_id} history now has {len(session.history)} messages (was {msg_count_before})")
|
||||
|
||||
# Update DB: delete old messages, insert summary
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_msgs = db.query(DbChatMessage).filter(
|
||||
DbChatMessage.session_id == session_id
|
||||
).order_by(DbChatMessage.timestamp).all()
|
||||
|
||||
# Delete all but the last keep_count
|
||||
for m in db_msgs[:-keep_count]:
|
||||
db.delete(m)
|
||||
|
||||
# Insert system summary (hidden, for AI context) and visible summary
|
||||
import json as _json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
now = datetime.now(timezone.utc)
|
||||
db_sys_summary = DbChatMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
session_id=session_id,
|
||||
role="system",
|
||||
content=system_summary.content,
|
||||
meta_data=_json.dumps(system_summary.metadata),
|
||||
timestamp=now,
|
||||
)
|
||||
db.add(db_sys_summary)
|
||||
db_summary = DbChatMessage(
|
||||
id=str(uuid.uuid4()),
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content=summary_msg.content,
|
||||
meta_data=_json.dumps(summary_msg.metadata),
|
||||
timestamp=now,
|
||||
)
|
||||
db.add(db_summary)
|
||||
|
||||
# Update session record
|
||||
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
|
||||
if db_session:
|
||||
db_session.message_count = len(session.history)
|
||||
db_session.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
session_manager.save_sessions()
|
||||
|
||||
used_after = estimate_tokens(session.get_context_messages())
|
||||
pct_after = round((used_after / ctx_len) * 100, 1) if ctx_len else 0
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"message": f"Compacted: {msg_count_before} msgs → {len(session.history)} msgs ({pct_before}% → {pct_after}%)",
|
||||
"before": pct_before,
|
||||
"after": pct_after,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Manual compact error {session_id}: {e}")
|
||||
raise HTTPException(500, str(e))
|
||||
|
||||
return router
|
||||
204
routes/hwfit_routes.py
Normal file
204
routes/hwfit_routes.py
Normal file
@@ -0,0 +1,204 @@
|
||||
from copy import deepcopy
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
|
||||
def setup_hwfit_routes():
|
||||
router = APIRouter(prefix="/api/hwfit", tags=["hwfit"])
|
||||
|
||||
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
|
||||
"""Manual hardware is a "what if I had this setup" simulator —
|
||||
REPLACES the detected hardware entirely instead of adding to it.
|
||||
|
||||
The previous additive behavior averaged the manual VRAM across
|
||||
all GPUs (base + manual), which meant adding "1× 400 GB" on top
|
||||
of "2× 70 GB" only nudged the per-GPU cap from 70 to 180 GB
|
||||
(= 540 / 3), so GGUF models bigger than that still didn't surface
|
||||
— exactly the "cap stuck at detected level" bug the user hit.
|
||||
"""
|
||||
manual_mode = (manual_mode or "").lower()
|
||||
if manual_mode not in {"gpu", "ram"}:
|
||||
return system
|
||||
|
||||
try:
|
||||
override_ram_gb = float(manual_ram_gb) if manual_ram_gb else 0
|
||||
except ValueError:
|
||||
override_ram_gb = 0
|
||||
override_ram_gb = max(0.0, override_ram_gb)
|
||||
if override_ram_gb:
|
||||
# Replace RAM, don't add. The number in the field is the
|
||||
# TOTAL system memory the user wants to simulate.
|
||||
system["available_ram_gb"] = round(override_ram_gb, 1)
|
||||
system["total_ram_gb"] = round(override_ram_gb, 1)
|
||||
system["manual_hardware"] = True
|
||||
|
||||
if manual_mode == "ram":
|
||||
# RAM-only simulation — wipe GPU entirely so the ranker uses
|
||||
# CPU/RAM paths.
|
||||
system["has_gpu"] = False
|
||||
system["gpu_name"] = None
|
||||
system["gpu_vram_gb"] = 0
|
||||
system["gpu_count"] = 0
|
||||
system["gpus"] = []
|
||||
system["gpu_groups"] = []
|
||||
system["backend"] = "cpu_x86"
|
||||
return system
|
||||
|
||||
try:
|
||||
count = int(manual_gpu_count) if manual_gpu_count else 1
|
||||
except ValueError:
|
||||
count = 1
|
||||
try:
|
||||
vram_each = float(manual_vram_gb) if manual_vram_gb else 8.0
|
||||
except ValueError:
|
||||
vram_each = 8.0
|
||||
count = max(1, min(count, 16))
|
||||
vram_each = max(1.0, vram_each)
|
||||
backend = (manual_backend or system.get("backend") or "cuda").lower()
|
||||
if backend not in {"cuda", "rocm", "cpu_x86", "cpu_arm"}:
|
||||
backend = "cuda"
|
||||
total_vram = round(vram_each * count, 1)
|
||||
gpu_name = f"Simulated {backend.upper()} GPU" + (f" × {count}" if count > 1 else "")
|
||||
system["has_gpu"] = True
|
||||
system["gpu_name"] = gpu_name
|
||||
system["gpu_vram_gb"] = total_vram
|
||||
system["gpu_count"] = count
|
||||
system["gpus"] = [
|
||||
{"index": i, "name": gpu_name, "vram_gb": vram_each}
|
||||
for i in range(count)
|
||||
]
|
||||
# Single homogeneous pool — vram_each here is the ACTUAL per-GPU
|
||||
# VRAM the user entered, not an average. That's the whole point:
|
||||
# raising vram_each lifts the per-GPU cap (GGUF, tensor-parallel
|
||||
# math) all the way up, not just by a small fraction.
|
||||
system["gpu_groups"] = [{
|
||||
"name": gpu_name,
|
||||
"vram_each": vram_each,
|
||||
"count": count,
|
||||
"indices": list(range(count)),
|
||||
"vram_total": total_vram,
|
||||
}]
|
||||
system["homogeneous"] = True
|
||||
system["backend"] = backend
|
||||
return system
|
||||
|
||||
@router.get("/system")
|
||||
def get_system(host: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False):
|
||||
"""Detect and return current system hardware info. Pass host=user@server for remote.
|
||||
fresh=true bypasses the per-host cache (the Rescan button)."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
|
||||
|
||||
@router.get("/models")
|
||||
def get_models(use_case: str = "", sort: str = "score", limit: int = 50, search: str = "", host: str = "", quant: str = "", gpu_count: str = "", gpu_group: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, manual_mode: str = "", manual_gpu_count: str = "", manual_vram_gb: str = "", manual_ram_gb: str = "", manual_backend: str = "", ignore_detected_gpu: bool = False, ignore_detected_ram: bool = False):
|
||||
"""Rank LLM models against detected hardware and return scored results.
|
||||
gpu_count: override GPU count (0 = CPU only, 1-N = simulate N GPUs of the
|
||||
active group). gpu_group: index into system.gpu_groups (the homogeneous
|
||||
pools) to target — empty/auto = the largest pool. vLLM can only
|
||||
tensor-parallel across identical GPUs, so we never mix pools.
|
||||
fresh=true bypasses the hardware-detection cache."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.fit import rank_models
|
||||
from services.hwfit.models import get_models, model_catalog_path
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
if not get_models():
|
||||
return {
|
||||
"system": system,
|
||||
"models": [],
|
||||
"error": f"Model catalog missing or empty: {model_catalog_path()}",
|
||||
}
|
||||
|
||||
if ignore_detected_gpu:
|
||||
system["has_gpu"] = False
|
||||
system["gpu_name"] = None
|
||||
system["gpu_vram_gb"] = 0
|
||||
system["gpu_count"] = 0
|
||||
system["gpus"] = []
|
||||
system["gpu_groups"] = []
|
||||
if ignore_detected_ram:
|
||||
system["available_ram_gb"] = 0
|
||||
system["total_ram_gb"] = 0
|
||||
|
||||
system = _apply_manual_hardware(system, manual_mode, manual_gpu_count, manual_vram_gb, manual_ram_gb, manual_backend)
|
||||
|
||||
# Keep the raw detection around so the UI can still show the box's full
|
||||
# GPU complement even while we rank against one homogeneous pool.
|
||||
system["detected_gpu_vram_gb"] = system.get("gpu_vram_gb")
|
||||
system["detected_gpu_count"] = system.get("gpu_count")
|
||||
|
||||
groups = system.get("gpu_groups") or []
|
||||
# Resolve the target homogeneous pool. Default (auto) = the largest pool,
|
||||
# which for a uniform box is simply "all the GPUs" — no behaviour change.
|
||||
grp = None
|
||||
if groups:
|
||||
try:
|
||||
gidx = int(gpu_group) if gpu_group != "" else 0
|
||||
except ValueError:
|
||||
gidx = 0
|
||||
if 0 <= gidx < len(groups):
|
||||
grp = groups[gidx]
|
||||
|
||||
def _apply_group(g, n):
|
||||
n = max(1, min(n, g["count"]))
|
||||
system["gpu_count"] = n
|
||||
system["gpu_vram_gb"] = round(g["vram_each"] * n, 1)
|
||||
system["gpu_name"] = g["name"]
|
||||
system["active_group"] = {**g, "use_count": n}
|
||||
|
||||
if gpu_count != "":
|
||||
n = int(gpu_count)
|
||||
if n == 0:
|
||||
# RAM-only mode: rank against system memory, offload allowed.
|
||||
system["has_gpu"] = False
|
||||
system["gpu_vram_gb"] = 0
|
||||
system["gpu_count"] = 0
|
||||
system["gpu_only"] = False
|
||||
system.pop("active_group", None)
|
||||
elif grp:
|
||||
_apply_group(grp, n)
|
||||
system["gpu_only"] = True
|
||||
else:
|
||||
# No per-GPU detail (older detection) — assume uniform split.
|
||||
single_vram = (system.get("gpu_vram_gb") or 0) / (system.get("gpu_count") or 1)
|
||||
system["gpu_count"] = max(1, n)
|
||||
system["gpu_vram_gb"] = round(single_vram * max(1, n), 1)
|
||||
system["gpu_only"] = True
|
||||
elif grp:
|
||||
# No explicit count, but we still pin to one pool so heterogeneous
|
||||
# boxes rank against a real mixable group, not a fictional VRAM sum.
|
||||
# gpu_only stays off here so the default view still surfaces offload.
|
||||
_apply_group(grp, grp["count"])
|
||||
|
||||
results = rank_models(system, use_case=use_case or None, limit=limit, search=search or None, sort=sort, quant=quant or None)
|
||||
return {"system": system, "models": results}
|
||||
|
||||
@router.get("/image-models")
|
||||
def get_image_models(sort: str = "fit", search: str = "", host: str = "", gpu_count: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, manual_mode: str = "", manual_gpu_count: str = "", manual_vram_gb: str = "", manual_ram_gb: str = "", manual_backend: str = "", ignore_detected_gpu: bool = False, ignore_detected_ram: bool = False):
|
||||
"""Rank image generation models against detected hardware."""
|
||||
from services.hwfit.hardware import detect_system
|
||||
from services.hwfit.image_models import rank_image_models
|
||||
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
|
||||
if system.get("error"):
|
||||
return {"system": system, "models": [], "error": system["error"]}
|
||||
if ignore_detected_gpu:
|
||||
system["has_gpu"] = False
|
||||
system["gpu_name"] = None
|
||||
system["gpu_vram_gb"] = 0
|
||||
system["gpu_count"] = 0
|
||||
system["gpus"] = []
|
||||
system["gpu_groups"] = []
|
||||
if ignore_detected_ram:
|
||||
system["available_ram_gb"] = 0
|
||||
system["total_ram_gb"] = 0
|
||||
system = _apply_manual_hardware(system, manual_mode, manual_gpu_count, manual_vram_gb, manual_ram_gb, manual_backend)
|
||||
# Image models use a single GPU — always use per-GPU VRAM
|
||||
gpu_vrams = [float(g.get("vram_gb") or 0) for g in (system.get("gpus") or []) if isinstance(g, dict)]
|
||||
single_vram = max(gpu_vrams) if gpu_vrams else ((system.get("gpu_vram_gb") or 0) / max(system.get("gpu_count") or 1, 1))
|
||||
system["gpu_vram_gb"] = single_vram
|
||||
system["gpu_count"] = 1 if single_vram > 0 else 0
|
||||
results = rank_image_models(system, search=search or None, sort=sort)
|
||||
return {"system": system, "models": results}
|
||||
|
||||
return router
|
||||
574
routes/mcp_routes.py
Normal file
574
routes/mcp_routes.py
Normal file
@@ -0,0 +1,574 @@
|
||||
# routes/mcp_routes.py
|
||||
"""MCP (Model Context Protocol) server management routes."""
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import html
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import RedirectResponse, HTMLResponse
|
||||
import logging
|
||||
import httpx
|
||||
|
||||
from core.database import McpServer, SessionLocal
|
||||
from core.middleware import require_admin
|
||||
from src.mcp_manager import McpManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/mcp", tags=["mcp"])
|
||||
|
||||
|
||||
def _load_disabled_map():
|
||||
"""Load per-server disabled tool sets from DB."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
disabled_map = {}
|
||||
for srv in db.query(McpServer).all():
|
||||
if srv.disabled_tools:
|
||||
try:
|
||||
names = json.loads(srv.disabled_tools)
|
||||
if names:
|
||||
disabled_map[srv.id] = set(names)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return disabled_map
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def setup_mcp_routes(mcp_manager: McpManager):
|
||||
"""Setup MCP routes with the provided manager."""
|
||||
|
||||
@router.get("/servers")
|
||||
def list_servers(request: Request):
|
||||
"""List all configured MCP servers with connection status."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
servers = db.query(McpServer).all()
|
||||
result = []
|
||||
for srv in servers:
|
||||
status = mcp_manager.get_server_status(srv.id)
|
||||
oauth_cfg = json.loads(srv.oauth_config) if srv.oauth_config else None
|
||||
needs_oauth = False
|
||||
if oauth_cfg:
|
||||
token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
|
||||
needs_oauth = token_file and not os.path.exists(token_file)
|
||||
disabled_list = json.loads(srv.disabled_tools) if srv.disabled_tools else []
|
||||
total_tools = status.get("tool_count", 0)
|
||||
result.append({
|
||||
"id": srv.id,
|
||||
"name": srv.name,
|
||||
"transport": srv.transport,
|
||||
"command": srv.command,
|
||||
"args": json.loads(srv.args) if srv.args else [],
|
||||
"env": json.loads(srv.env) if srv.env else {},
|
||||
"url": srv.url,
|
||||
"is_enabled": srv.is_enabled,
|
||||
"status": status.get("status", "disconnected"),
|
||||
"tool_count": total_tools,
|
||||
"disabled_tool_count": len(disabled_list),
|
||||
"enabled_tool_count": max(0, total_tools - len(disabled_list)),
|
||||
"error": status.get("error"),
|
||||
"has_oauth": oauth_cfg is not None,
|
||||
"needs_oauth": needs_oauth,
|
||||
})
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/servers")
|
||||
async def add_server(
|
||||
request: Request,
|
||||
name: str = Form(...),
|
||||
transport: str = Form("stdio"),
|
||||
command: str = Form(None),
|
||||
args: str = Form("[]"),
|
||||
env: str = Form("{}"),
|
||||
url: str = Form(None),
|
||||
oauth_file: str = Form(None),
|
||||
oauth_config: str = Form(None),
|
||||
):
|
||||
"""Add a new MCP server config and attempt connection. Admin-only:
|
||||
registering a stdio server is equivalent to executing arbitrary
|
||||
binaries on the host."""
|
||||
require_admin(request)
|
||||
server_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Validate
|
||||
if transport == "stdio" and not command:
|
||||
raise HTTPException(400, "command is required for stdio transport")
|
||||
if transport == "sse" and not url:
|
||||
raise HTTPException(400, "url is required for SSE transport")
|
||||
|
||||
# Parse JSON fields
|
||||
try:
|
||||
parsed_args = json.loads(args) if args else []
|
||||
except json.JSONDecodeError:
|
||||
parsed_args = []
|
||||
try:
|
||||
parsed_env = json.loads(env) if env else {}
|
||||
except json.JSONDecodeError:
|
||||
parsed_env = {}
|
||||
|
||||
# Parse OAuth config
|
||||
parsed_oauth_config = None
|
||||
if oauth_config:
|
||||
try:
|
||||
parsed_oauth_config = json.loads(oauth_config)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Write OAuth credentials file if provided (for Google MCP servers)
|
||||
logger.info(f"MCP add_server: oauth_file={oauth_file!r}")
|
||||
if oauth_file:
|
||||
try:
|
||||
oauth_data = json.loads(oauth_file)
|
||||
oauth_dir = os.path.expanduser(oauth_data.get("dir", ""))
|
||||
oauth_filename = oauth_data.get("filename", "")
|
||||
client_id = oauth_data.get("client_id", "")
|
||||
client_secret = oauth_data.get("client_secret", "")
|
||||
if oauth_dir and oauth_filename and client_id and client_secret:
|
||||
os.makedirs(oauth_dir, exist_ok=True)
|
||||
creds = {
|
||||
"installed": {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"redirect_uris": ["http://localhost"],
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://accounts.google.com/o/oauth2/token",
|
||||
}
|
||||
}
|
||||
filepath = os.path.join(oauth_dir, oauth_filename)
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(creds, f, indent=2)
|
||||
logger.info(f"Wrote OAuth credentials to {filepath}")
|
||||
parsed_env.pop("GOOGLE_CLIENT_ID", None)
|
||||
parsed_env.pop("GOOGLE_CLIENT_SECRET", None)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning(f"Failed to write OAuth file: {e}")
|
||||
|
||||
# Save to DB
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = McpServer(
|
||||
id=server_id,
|
||||
name=name,
|
||||
transport=transport,
|
||||
command=command,
|
||||
args=json.dumps(parsed_args),
|
||||
env=json.dumps(parsed_env),
|
||||
url=url,
|
||||
is_enabled=True,
|
||||
oauth_config=json.dumps(parsed_oauth_config) if parsed_oauth_config else None,
|
||||
)
|
||||
db.add(srv)
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Check if OAuth token already exists — skip connection attempt if not
|
||||
needs_oauth = False
|
||||
if parsed_oauth_config:
|
||||
token_file = os.path.expanduser(parsed_oauth_config.get("token_file", ""))
|
||||
if token_file and not os.path.exists(token_file):
|
||||
needs_oauth = True
|
||||
|
||||
connected = False
|
||||
if not needs_oauth:
|
||||
connected = await mcp_manager.connect_server(
|
||||
server_id=server_id,
|
||||
name=name,
|
||||
transport=transport,
|
||||
command=command,
|
||||
args=parsed_args,
|
||||
env=parsed_env,
|
||||
url=url,
|
||||
)
|
||||
|
||||
status = mcp_manager.get_server_status(server_id)
|
||||
return {
|
||||
"id": server_id,
|
||||
"name": name,
|
||||
"connected": connected,
|
||||
"status": "needs_oauth" if needs_oauth else status.get("status", "disconnected"),
|
||||
"tool_count": status.get("tool_count", 0),
|
||||
"error": "OAuth authorization required" if needs_oauth else status.get("error"),
|
||||
"needs_oauth": needs_oauth,
|
||||
}
|
||||
|
||||
@router.post("/servers/{server_id}/reconnect")
|
||||
async def reconnect_server(server_id: str, request: Request):
|
||||
"""Reconnect to an MCP server."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
raise HTTPException(404, "Server not found")
|
||||
|
||||
await mcp_manager.disconnect_server(server_id)
|
||||
|
||||
args = json.loads(srv.args) if srv.args else []
|
||||
env = json.loads(srv.env) if srv.env else {}
|
||||
connected = await mcp_manager.connect_server(
|
||||
server_id=server_id,
|
||||
name=srv.name,
|
||||
transport=srv.transport,
|
||||
command=srv.command,
|
||||
args=args,
|
||||
env=env,
|
||||
url=srv.url,
|
||||
)
|
||||
|
||||
status = mcp_manager.get_server_status(server_id)
|
||||
return {
|
||||
"connected": connected,
|
||||
"status": status.get("status", "disconnected"),
|
||||
"tool_count": status.get("tool_count", 0),
|
||||
"error": status.get("error"),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.patch("/servers/{server_id}")
|
||||
async def toggle_server(server_id: str, request: Request, is_enabled: str = Form(...)):
|
||||
"""Enable or disable an MCP server."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
raise HTTPException(404, "Server not found")
|
||||
|
||||
enabled = str(is_enabled).lower() == "true"
|
||||
srv.is_enabled = enabled
|
||||
db.commit()
|
||||
|
||||
if enabled:
|
||||
args = json.loads(srv.args) if srv.args else []
|
||||
env = json.loads(srv.env) if srv.env else {}
|
||||
await mcp_manager.connect_server(
|
||||
server_id=server_id,
|
||||
name=srv.name,
|
||||
transport=srv.transport,
|
||||
command=srv.command,
|
||||
args=args,
|
||||
env=env,
|
||||
url=srv.url,
|
||||
)
|
||||
else:
|
||||
await mcp_manager.disconnect_server(server_id)
|
||||
|
||||
return {"id": server_id, "is_enabled": enabled}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/servers/{server_id}")
|
||||
async def delete_server(server_id: str, request: Request):
|
||||
"""Remove an MCP server."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
raise HTTPException(404, "Server not found")
|
||||
|
||||
await mcp_manager.disconnect_server(server_id)
|
||||
|
||||
db.delete(srv)
|
||||
db.commit()
|
||||
return {"status": "deleted"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/tools")
|
||||
def list_tools(request: Request):
|
||||
"""List all discovered MCP tools across all connected servers."""
|
||||
require_admin(request)
|
||||
disabled_map = _load_disabled_map()
|
||||
return mcp_manager.get_all_tools(disabled_map)
|
||||
|
||||
@router.get("/servers/{server_id}/tools")
|
||||
def list_server_tools(server_id: str, request: Request):
|
||||
"""List all tools for a specific MCP server with enabled/disabled state."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
raise HTTPException(404, "Server not found")
|
||||
disabled_list = json.loads(srv.disabled_tools) if srv.disabled_tools else []
|
||||
disabled_set = set(disabled_list)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
all_tools = mcp_manager.get_all_tools()
|
||||
server_tools = [t for t in all_tools if t["server_id"] == server_id]
|
||||
for t in server_tools:
|
||||
t["is_disabled"] = t["name"] in disabled_set
|
||||
return server_tools
|
||||
|
||||
@router.patch("/servers/{server_id}/tools")
|
||||
async def update_disabled_tools(server_id: str, request: Request):
|
||||
"""Bulk update disabled tools list for a server.
|
||||
|
||||
Expects JSON body: {"disabled": ["tool_name_1", "tool_name_2"]}
|
||||
"""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
raise HTTPException(404, "Server not found")
|
||||
|
||||
body = await request.json()
|
||||
disabled = body.get("disabled", [])
|
||||
if not isinstance(disabled, list):
|
||||
raise HTTPException(400, "disabled must be a list of tool names")
|
||||
|
||||
srv.disabled_tools = json.dumps(disabled) if disabled else None
|
||||
db.commit()
|
||||
|
||||
return {"id": server_id, "disabled_count": len(disabled)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# ── OAuth flow for Google MCP servers ──────────────────────────
|
||||
|
||||
@router.get("/oauth/authorize/{server_id}")
|
||||
def oauth_authorize(server_id: str, request: Request):
|
||||
"""Show OAuth authorization page with Google sign-in link."""
|
||||
require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
raise HTTPException(404, "Server not found")
|
||||
if not srv.oauth_config:
|
||||
raise HTTPException(400, "Server has no OAuth config")
|
||||
|
||||
oauth_cfg = json.loads(srv.oauth_config)
|
||||
keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
|
||||
if not keys_file or not os.path.exists(keys_file):
|
||||
raise HTTPException(400, "OAuth keys file not found")
|
||||
|
||||
with open(keys_file) as f:
|
||||
keys_data = json.load(f)
|
||||
keys = keys_data.get("installed") or keys_data.get("web")
|
||||
if not keys:
|
||||
raise HTTPException(400, "Invalid OAuth keys file format")
|
||||
|
||||
client_id = keys["client_id"]
|
||||
scopes = oauth_cfg.get("scopes", [])
|
||||
|
||||
# For Desktop App creds, redirect to localhost — the user will
|
||||
# paste the resulting URL back if they're on a different device.
|
||||
redirect_uri = "http://localhost:7000/api/mcp/oauth/callback"
|
||||
|
||||
params = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": " ".join(scopes),
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"state": server_id,
|
||||
}
|
||||
auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urllib.parse.urlencode(params)
|
||||
|
||||
# Determine if user is accessing from the same machine
|
||||
host = request.headers.get("host", "")
|
||||
is_local = host.startswith("localhost") or host.startswith("127.0.0.1")
|
||||
|
||||
if is_local:
|
||||
# Same machine — just redirect, callback will work directly
|
||||
return RedirectResponse(auth_url)
|
||||
else:
|
||||
# Remote device — show paste-back page
|
||||
return HTMLResponse(_oauth_authorize_page(auth_url, server_id, host))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/oauth/callback")
|
||||
async def oauth_callback(code: str, state: str, request: Request):
|
||||
"""Handle OAuth callback from Google — exchange code for tokens."""
|
||||
require_admin(request)
|
||||
server_id = state
|
||||
return await _exchange_and_connect(server_id, code, request)
|
||||
|
||||
@router.post("/oauth/exchange/{server_id}")
|
||||
async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)):
|
||||
"""Manual code exchange — user pastes the callback URL from their browser."""
|
||||
require_admin(request)
|
||||
try:
|
||||
parsed = urllib.parse.urlparse(callback_url)
|
||||
params = urllib.parse.parse_qs(parsed.query)
|
||||
code = params.get("code", [None])[0]
|
||||
if not code:
|
||||
return HTMLResponse(_oauth_result_page("Error", "No authorization code found in the URL. Make sure you copied the full URL from your browser."), status_code=400)
|
||||
except Exception:
|
||||
return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400)
|
||||
|
||||
return await _exchange_and_connect(server_id, code, request)
|
||||
|
||||
async def _exchange_and_connect(server_id: str, code: str, request: Request):
|
||||
"""Exchange auth code for tokens and connect the MCP server."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
|
||||
if not srv:
|
||||
return HTMLResponse(_oauth_result_page("Error", "Server not found."), status_code=404)
|
||||
if not srv.oauth_config:
|
||||
return HTMLResponse(_oauth_result_page("Error", "No OAuth config."), status_code=400)
|
||||
|
||||
oauth_cfg = json.loads(srv.oauth_config)
|
||||
keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
|
||||
token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
|
||||
|
||||
with open(keys_file) as f:
|
||||
keys_data = json.load(f)
|
||||
keys = keys_data.get("installed") or keys_data.get("web")
|
||||
client_id = keys["client_id"]
|
||||
client_secret = keys["client_secret"]
|
||||
|
||||
redirect_uri = "http://localhost:7000/api/mcp/oauth/callback"
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(
|
||||
"https://oauth2.googleapis.com/token",
|
||||
data={
|
||||
"code": code,
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"redirect_uri": redirect_uri,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
err = resp.text
|
||||
logger.error(f"OAuth token exchange failed: {err}")
|
||||
return HTMLResponse(_oauth_result_page("Authorization Failed", f"Google returned an error: {err}"), status_code=400)
|
||||
|
||||
tokens = resp.json()
|
||||
logger.info(f"OAuth tokens received for server {server_id}")
|
||||
|
||||
# Save tokens to the file the MCP package expects
|
||||
os.makedirs(os.path.dirname(token_file), exist_ok=True)
|
||||
with open(token_file, "w") as f:
|
||||
json.dump(tokens, f, indent=2)
|
||||
logger.info(f"Saved OAuth tokens to {token_file}")
|
||||
|
||||
# Attempt to connect the MCP server now
|
||||
args = json.loads(srv.args) if srv.args else []
|
||||
env = json.loads(srv.env) if srv.env else {}
|
||||
connected = await mcp_manager.connect_server(
|
||||
server_id=server_id,
|
||||
name=srv.name,
|
||||
transport=srv.transport,
|
||||
command=srv.command,
|
||||
args=args,
|
||||
env=env,
|
||||
url=srv.url,
|
||||
)
|
||||
|
||||
if connected:
|
||||
status = mcp_manager.get_server_status(server_id)
|
||||
tool_count = status.get("tool_count", 0)
|
||||
return HTMLResponse(_oauth_result_page(
|
||||
"Authorization Successful",
|
||||
f"{srv.name} connected with {tool_count} tools. You can close this window.",
|
||||
success=True,
|
||||
))
|
||||
else:
|
||||
status = mcp_manager.get_server_status(server_id)
|
||||
return HTMLResponse(_oauth_result_page(
|
||||
"Authorized but Connection Failed",
|
||||
f"Tokens saved, but the server failed to connect: {status.get('error', 'unknown error')}. Try reconnecting from Settings.",
|
||||
))
|
||||
except Exception as e:
|
||||
logger.exception(f"OAuth callback error: {e}")
|
||||
return HTMLResponse(_oauth_result_page("Error", str(e)), status_code=500)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def _oauth_authorize_page(auth_url: str, server_id: str, host: str) -> str:
|
||||
"""Page with Google sign-in link and URL paste-back form for remote access."""
|
||||
return f"""<!DOCTYPE html>
|
||||
<html><head>
|
||||
<meta charset="UTF-8"><title>Authorize — Odysseus</title>
|
||||
<style>
|
||||
body {{ font-family: 'Fira Code', monospace; background: #0f0f0f; color: #e0e0e0;
|
||||
display: flex; justify-content: center; align-items: center; min-height: 100vh; }}
|
||||
.card {{ background: #1a1a1a; border: 1px solid #333; border-radius: 12px;
|
||||
padding: 2rem; max-width: 480px; text-align: center; }}
|
||||
h2 {{ color: #e06c75; margin-bottom: 0.5rem; font-size: 1.1rem; }}
|
||||
p {{ color: #aaa; font-size: 0.82rem; line-height: 1.6; margin: 0.8rem 0; }}
|
||||
.step {{ text-align: left; color: #ccc; font-size: 0.82rem; line-height: 1.7; margin: 1rem 0; }}
|
||||
.step b {{ color: #e06c75; }}
|
||||
a.auth-link {{
|
||||
display: inline-block; margin: 1rem 0; padding: 0.6rem 1.5rem;
|
||||
background: #e06c75; color: #fff; text-decoration: none; border-radius: 6px;
|
||||
font-weight: 600; font-size: 0.9rem;
|
||||
}}
|
||||
a.auth-link:hover {{ background: #c55; }}
|
||||
input[type=text] {{
|
||||
width: 100%; padding: 0.5rem; margin: 0.5rem 0;
|
||||
background: #0f0f0f; border: 1px solid #333; border-radius: 6px;
|
||||
color: #e0e0e0; font-family: 'Fira Code', monospace; font-size: 0.8rem;
|
||||
}}
|
||||
input:focus {{ outline: none; border-color: #e06c75; }}
|
||||
button {{
|
||||
padding: 0.5rem 1.5rem; border: none; border-radius: 6px;
|
||||
background: #e06c75; color: #fff; font-weight: 600; cursor: pointer;
|
||||
font-family: 'Fira Code', monospace; font-size: 0.85rem; margin-top: 0.3rem;
|
||||
}}
|
||||
button:hover {{ background: #c55; }}
|
||||
.divider {{ border-top: 1px solid #333; margin: 1.2rem 0; }}
|
||||
</style></head>
|
||||
<body><div class="card">
|
||||
<h2>Authorize Google Account</h2>
|
||||
<div class="step">
|
||||
<b>1.</b> Click the button below to sign in with Google<br>
|
||||
<b>2.</b> After approving, your browser will show an error page — that's normal<br>
|
||||
<b>3.</b> Copy the full URL from your browser's address bar<br>
|
||||
<b>4.</b> Paste it below and click Connect
|
||||
</div>
|
||||
<a class="auth-link" href="{auth_url}" target="_blank" rel="noopener">Sign in with Google</a>
|
||||
<div class="divider"></div>
|
||||
<form method="POST" action="http://{host}/api/mcp/oauth/exchange/{server_id}">
|
||||
<p>Paste the URL from your browser after signing in:</p>
|
||||
<input type="text" name="callback_url" placeholder="http://localhost:7000/api/mcp/oauth/callback?code=..." required>
|
||||
<br><button type="submit">Connect</button>
|
||||
</form>
|
||||
</div></body></html>"""
|
||||
|
||||
|
||||
def _oauth_result_page(title: str, message: str, success: bool = False) -> str:
|
||||
"""Generate a simple HTML page for the OAuth result."""
|
||||
safe_title = html.escape(title)
|
||||
safe_message = html.escape(message)
|
||||
color = "#00661a" if success else "#e06c75"
|
||||
icon = "✓" if success else "✗"
|
||||
return f"""<!DOCTYPE html>
|
||||
<html><head>
|
||||
<meta charset="UTF-8"><title>{safe_title}</title>
|
||||
<style>
|
||||
body {{ font-family: 'Fira Code', monospace; background: #0f0f0f; color: #e0e0e0;
|
||||
display: flex; justify-content: center; align-items: center; min-height: 100vh; }}
|
||||
.card {{ background: #1a1a1a; border: 1px solid #333; border-radius: 12px;
|
||||
padding: 2rem; max-width: 420px; text-align: center; }}
|
||||
.icon {{ font-size: 3rem; color: {color}; margin-bottom: 1rem; }}
|
||||
h2 {{ color: {color}; margin-bottom: 0.5rem; font-size: 1.1rem; }}
|
||||
p {{ color: #aaa; font-size: 0.85rem; line-height: 1.5; }}
|
||||
</style></head>
|
||||
<body><div class="card">
|
||||
<div class="icon">{icon}</div>
|
||||
<h2>{safe_title}</h2>
|
||||
<p>{safe_message}</p>
|
||||
</div></body></html>"""
|
||||
517
routes/memory_routes.py
Normal file
517
routes/memory_routes.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# routes/memory_routes.py
|
||||
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
|
||||
from typing import Dict, Any, Optional, List
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
# Leading list-marker like "1.", "12)", or "3:" plus surrounding whitespace.
|
||||
# Strips one prefix per call so import-from-LLM-output doesn't leave the
|
||||
# numbering inside the saved memory text. Bullet markers (-, *, •) are
|
||||
# also peeled here for the same reason.
|
||||
_LIST_PREFIX_RE = re.compile(r"^\s*(?:\d{1,3}[.):]\s+|[-*•]\s+)")
|
||||
|
||||
|
||||
def _strip_list_prefix(text: str) -> str:
|
||||
if not text:
|
||||
return text
|
||||
return _LIST_PREFIX_RE.sub("", text, count=1).strip()
|
||||
|
||||
from services.memory import MemoryManager
|
||||
from core.session_manager import SessionManager
|
||||
from src.request_models import MemoryAddRequest
|
||||
from core.database import SessionLocal
|
||||
from src.llm_core import llm_call_async
|
||||
from services.memory.memory_extractor import audit_memories
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
|
||||
"""Set up memory-related routes."""
|
||||
router = APIRouter(prefix="/api/memory", tags=["memory"])
|
||||
|
||||
def _owner(request: Request) -> Optional[str]:
|
||||
return get_current_user(request)
|
||||
|
||||
def _verify_memory_owner(memory: dict, user: Optional[str]):
|
||||
"""Raise 404 if user doesn't own this memory.
|
||||
|
||||
SECURITY: strict ownership — previously `mem_owner and mem_owner != user`
|
||||
allowed any user to read/edit/delete memories with an empty/null owner
|
||||
field, which leaked legacy data across the multi-user deploy.
|
||||
"""
|
||||
if user is None:
|
||||
return # Auth disabled
|
||||
if memory.get("owner") != user:
|
||||
raise HTTPException(404, "Memory not found")
|
||||
|
||||
@router.post("/debug")
|
||||
def debug_memory_relevance(request: Request, query: str = Form(...)):
|
||||
"""Debug which memories would be triggered for a query"""
|
||||
user = _owner(request)
|
||||
memories = memory_manager.load(owner=user)
|
||||
relevant = memory_manager.get_relevant_memories(query, memories, threshold=0.05)
|
||||
|
||||
return {
|
||||
"query": query,
|
||||
"total_memories": len(memories),
|
||||
"relevant_count": len(relevant),
|
||||
"relevant_memories": [{"text": m["text"], "category": m.get("category", "unknown")}
|
||||
for m in relevant]
|
||||
}
|
||||
|
||||
@router.post("/add", response_model=Dict[str, Any])
|
||||
async def api_add_memory(
|
||||
request: Request,
|
||||
memory_data: Optional[MemoryAddRequest] = None
|
||||
):
|
||||
"""Add a new memory entry with optional category, source, and session reference."""
|
||||
from src.auth_helpers import require_privilege
|
||||
require_privilege(request, "can_manage_memory")
|
||||
if memory_data is None:
|
||||
form = await request.form()
|
||||
memory_data = MemoryAddRequest(
|
||||
text=form.get("text"),
|
||||
category=form.get("category", "fact"),
|
||||
source=form.get("source", "user"),
|
||||
session_id=form.get("session_id")
|
||||
)
|
||||
|
||||
user = _owner(request)
|
||||
text = (memory_data.text or "").strip()
|
||||
if not text:
|
||||
raise HTTPException(400, "empty memory")
|
||||
user_mem = memory_manager.load(owner=user)
|
||||
if memory_manager.find_duplicates(text, user_mem):
|
||||
return {"ok": True, "count": len(user_mem), "message": "Memory already exists"}
|
||||
|
||||
new_entry = memory_manager.add_entry(text, memory_data.source, memory_data.category, owner=user)
|
||||
if memory_data.session_id:
|
||||
new_entry["session_id"] = memory_data.session_id
|
||||
all_mem = memory_manager.load_all()
|
||||
all_mem.append(new_entry)
|
||||
memory_manager.save(all_mem)
|
||||
# Sync vector index
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.add(new_entry["id"], text)
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("memory_added", user)
|
||||
except Exception:
|
||||
logger.debug("memory_added event dispatch failed", exc_info=True)
|
||||
return {"ok": True, "count": len([m for m in all_mem if m.get("owner") == user])}
|
||||
|
||||
@router.get("")
|
||||
def api_get_memory(request: Request):
|
||||
"""Return all memory entries with their metadata."""
|
||||
user = _owner(request)
|
||||
return {"memory": memory_manager.load(owner=user)}
|
||||
|
||||
@router.post("/search")
|
||||
def search_memories(request: Request, query: str = Form(...), session_id: str = Form(None), category: str = Form(None)):
|
||||
"""Search across all memories with optional filters."""
|
||||
user = _owner(request)
|
||||
memories = memory_manager.load(owner=user)
|
||||
|
||||
if session_id:
|
||||
memories = [m for m in memories if m.get("session_id") == session_id]
|
||||
|
||||
if category:
|
||||
memories = [m for m in memories if category in m.get("categories", [m.get("category", "")])]
|
||||
|
||||
relevant = memory_manager.get_relevant_memories(query, memories, threshold=0.05, max_items=20)
|
||||
|
||||
return {"memories": relevant, "total": len(relevant), "query": query}
|
||||
|
||||
@router.get("/timeline")
|
||||
def memory_timeline(request: Request):
|
||||
"""Get memories in chronological order with source session information."""
|
||||
user = _owner(request)
|
||||
memories = memory_manager.load(owner=user)
|
||||
sorted_memories = sorted(memories, key=lambda x: x.get("timestamp", 0), reverse=True)
|
||||
|
||||
results = []
|
||||
for memory in sorted_memories:
|
||||
if "timestamp" in memory:
|
||||
try:
|
||||
dt = datetime.fromtimestamp(memory["timestamp"])
|
||||
memory["timestamp_str"] = dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
except (ValueError, OSError, OverflowError):
|
||||
memory["timestamp_str"] = "Unknown"
|
||||
else:
|
||||
memory["timestamp_str"] = "Unknown"
|
||||
|
||||
session_id = memory.get("session_id")
|
||||
if session_id and session_id in session_manager.sessions:
|
||||
session = session_manager.get_session(session_id)
|
||||
memory["session_name"] = session.name if session else f"Session {session_id[:6]}"
|
||||
else:
|
||||
memory["session_name"] = "Unknown"
|
||||
|
||||
results.append(memory)
|
||||
|
||||
return {"timeline": results, "total": len(results)}
|
||||
|
||||
@router.get("/by-session/{session_id}")
|
||||
def get_memory_by_session(request: Request, session_id: str):
|
||||
"""Get all memories associated with a specific session."""
|
||||
try:
|
||||
session_manager.get_session(session_id)
|
||||
except KeyError:
|
||||
raise HTTPException(404, f"Session {session_id} not found")
|
||||
|
||||
user = _owner(request)
|
||||
memories = memory_manager.load(owner=user)
|
||||
session_memories = [m for m in memories if m.get("session_id") == session_id]
|
||||
|
||||
session_memories.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
|
||||
|
||||
try:
|
||||
session = session_manager.get_session(session_id)
|
||||
session_name = session.name if session else f"Session {session_id[:6]}"
|
||||
except KeyError:
|
||||
session_name = f"Session {session_id[:6]}"
|
||||
|
||||
for memory in session_memories:
|
||||
memory["session_name"] = session_name
|
||||
|
||||
return {
|
||||
"session_id": session_id,
|
||||
"session_name": session_name,
|
||||
"memory_count": len(session_memories),
|
||||
"memories": session_memories
|
||||
}
|
||||
|
||||
@router.post("/extract")
|
||||
async def extract_memory(request: Request, session: str = Form(...)) -> Dict[str, List[str]]:
|
||||
"""Analyze a session's chat history and return memory suggestions."""
|
||||
if not get_current_user(request):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
try:
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found")
|
||||
|
||||
system_msg = {
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a helpful assistant. Analyze the entire conversation history provided and extract any "
|
||||
"useful factual statements, contacts, addresses, phone numbers, or other information that the user "
|
||||
"might want to remember for future interactions. Return each piece of information as a JSON object "
|
||||
"with a 'text' field. For example: [{'text': 'Alice lives at 123 Main St'}, {'text': 'Bob works at Acme Corp'}]. "
|
||||
"Only include information that is specific and likely to be useful later."
|
||||
),
|
||||
}
|
||||
messages = [system_msg] + sess.get_context_messages()
|
||||
|
||||
try:
|
||||
suggestion_text = await llm_call_async(
|
||||
sess.endpoint_url,
|
||||
sess.model,
|
||||
messages,
|
||||
temperature=0.2,
|
||||
max_tokens=500,
|
||||
headers=sess.headers,
|
||||
)
|
||||
try:
|
||||
suggestions = json.loads(suggestion_text)
|
||||
if isinstance(suggestions, list):
|
||||
suggestions = [s if isinstance(s, str) else s.get("text", "") for s in suggestions]
|
||||
else:
|
||||
suggestions = []
|
||||
except json.JSONDecodeError:
|
||||
suggestions = [line.strip() for line in suggestion_text.splitlines() if line.strip()]
|
||||
|
||||
return {"suggestions": [s for s in suggestions if s]}
|
||||
except Exception as e:
|
||||
logger.error(f"LLM memory extraction failed (session {session}): {e}")
|
||||
fallback = memory_manager.extract_memory_from_chat(sess.history, session)
|
||||
return {"suggestions": [item["text"] for item in fallback]}
|
||||
|
||||
@router.post("/audit")
|
||||
async def api_audit_memories(request: Request, session: str = Form(None)):
|
||||
"""Deduplicate and consolidate memories via LLM.
|
||||
|
||||
Uses the default model from settings, or falls back to a session's model.
|
||||
Returns before and after memory counts.
|
||||
"""
|
||||
from routes.model_routes import _load_settings, _normalize_base, build_chat_url
|
||||
from core.database import ModelEndpoint
|
||||
import json as _json
|
||||
|
||||
endpoint_url = model = None
|
||||
headers = {}
|
||||
|
||||
# Try default model from settings first
|
||||
settings = _load_settings()
|
||||
ep_id = settings.get("default_endpoint_id", "")
|
||||
default_model = settings.get("default_model", "")
|
||||
if ep_id:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True
|
||||
).first()
|
||||
if ep:
|
||||
base = _normalize_base(ep.base_url)
|
||||
endpoint_url = build_chat_url(base)
|
||||
model = default_model
|
||||
if not model and ep.models:
|
||||
try:
|
||||
models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models
|
||||
if models:
|
||||
model = models[0]
|
||||
except Exception:
|
||||
pass
|
||||
if ep.api_key:
|
||||
headers = {"Authorization": f"Bearer {ep.api_key}"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Fall back to session model if no default configured
|
||||
if not endpoint_url and session:
|
||||
try:
|
||||
sess = session_manager.get_session(session)
|
||||
endpoint_url = sess.endpoint_url
|
||||
model = sess.model
|
||||
headers = sess.headers
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if not endpoint_url or not model:
|
||||
raise HTTPException(400, "No default model configured — set one in Settings")
|
||||
|
||||
user = _owner(request)
|
||||
result = await audit_memories(
|
||||
memory_manager,
|
||||
memory_vector,
|
||||
endpoint_url,
|
||||
model,
|
||||
headers,
|
||||
owner=user,
|
||||
)
|
||||
|
||||
if "error" in result and "before" not in result:
|
||||
raise HTTPException(502, f"Audit failed: {result['error']}")
|
||||
|
||||
return {
|
||||
"ok": "error" not in result,
|
||||
"before": result.get("before", 0),
|
||||
"after": result.get("after", 0),
|
||||
"removed": result.get("before", 0) - result.get("after", 0),
|
||||
# True when the audit skipped the LLM because nothing changed
|
||||
# since the last tidy. Frontend already says "Already clean"
|
||||
# for removed==0, so this is here for future use / debugging.
|
||||
"already_tidy": bool(result.get("already_tidy")),
|
||||
}
|
||||
|
||||
@router.post("/import")
|
||||
async def import_memories_from_file(
|
||||
request: Request,
|
||||
session: str = Form(...),
|
||||
file: UploadFile = File(...)
|
||||
):
|
||||
"""Extract memory suggestions from an uploaded file (PDF, TXT, MD, etc.)."""
|
||||
from src.auth_helpers import require_privilege
|
||||
require_privilege(request, "can_manage_memory")
|
||||
try:
|
||||
sess = session_manager.get_session(session)
|
||||
except KeyError:
|
||||
raise HTTPException(404, "Session not found — needed for LLM config")
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
filename = file.filename or "upload"
|
||||
_, ext = os.path.splitext(filename.lower())
|
||||
|
||||
allowed = {".txt", ".md", ".pdf", ".csv", ".log", ".json", ".py", ".js", ".html"}
|
||||
if ext not in allowed:
|
||||
raise HTTPException(400, f"Unsupported file type: {ext}")
|
||||
|
||||
# Extract text based on file type
|
||||
if ext == ".pdf":
|
||||
from src.document_processor import _process_pdf
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_path = tmp.name
|
||||
try:
|
||||
text = _process_pdf(tmp_path)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
else:
|
||||
try:
|
||||
text = content.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
from charset_normalizer import detect
|
||||
encoding = (detect(content) or {}).get("encoding") or "utf-8"
|
||||
text = content.decode(encoding, errors="replace")
|
||||
|
||||
if not text.strip():
|
||||
return {"suggestions": [], "message": "No readable content found"}
|
||||
|
||||
# Fast path: a .json upload that already looks like a memories export
|
||||
# (list of {text, category, ...} dicts, or list of strings) round-trips
|
||||
# directly without spending an LLM call to re-extract its own output.
|
||||
# Without this, re-importing a memories.json from another account
|
||||
# ran the file through the extractor, which often re-emitted the
|
||||
# entries as a numbered list (and the numbering leaked into the
|
||||
# `text` field).
|
||||
if ext == ".json":
|
||||
try:
|
||||
parsed = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
if isinstance(parsed, list) and parsed:
|
||||
direct = []
|
||||
for item in parsed:
|
||||
if isinstance(item, dict) and item.get("text"):
|
||||
direct.append({
|
||||
"text": _strip_list_prefix(str(item["text"])),
|
||||
"category": item.get("category") or "fact",
|
||||
})
|
||||
elif isinstance(item, str) and item.strip():
|
||||
direct.append({
|
||||
"text": _strip_list_prefix(item.strip()),
|
||||
"category": "fact",
|
||||
})
|
||||
if direct:
|
||||
return {"suggestions": direct, "filename": filename}
|
||||
|
||||
# Truncate very long documents
|
||||
if len(text) > 15000:
|
||||
text = text[:15000] + "\n[Truncated]"
|
||||
|
||||
# Send to LLM for memory extraction
|
||||
import_prompt = (
|
||||
"You are a memory extraction assistant. The user uploaded a document. "
|
||||
"Analyze the text below and extract specific, useful facts — things like "
|
||||
"names, preferences, jobs, locations, relationships, opinions, projects, "
|
||||
"goals, contacts, or any other personal details worth remembering.\n\n"
|
||||
"Rules:\n"
|
||||
"- Each fact should be a short, self-contained statement\n"
|
||||
"- Do NOT extract generic knowledge\n"
|
||||
"- Focus on personal, memorable information\n"
|
||||
"- If there are no useful facts, return an empty array\n\n"
|
||||
"Return a JSON array of objects with 'text' and 'category' fields.\n"
|
||||
"Categories: 'identity', 'preference', 'fact', 'contact', 'project', 'goal'\n\n"
|
||||
"Return ONLY valid JSON, no markdown fences."
|
||||
)
|
||||
|
||||
try:
|
||||
raw = await llm_call_async(
|
||||
sess.endpoint_url,
|
||||
sess.model,
|
||||
[
|
||||
{"role": "system", "content": import_prompt},
|
||||
{"role": "user", "content": f"Document: {filename}\n\n{text}"},
|
||||
],
|
||||
temperature=0.2,
|
||||
max_tokens=2000,
|
||||
headers=sess.headers,
|
||||
)
|
||||
|
||||
# Parse JSON
|
||||
raw = raw.strip()
|
||||
if raw.startswith("```"):
|
||||
raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
|
||||
|
||||
suggestions = json.loads(raw)
|
||||
if isinstance(suggestions, list):
|
||||
normalized = []
|
||||
for s in suggestions:
|
||||
if not s:
|
||||
continue
|
||||
if isinstance(s, dict):
|
||||
s = dict(s)
|
||||
if s.get("text"):
|
||||
s["text"] = _strip_list_prefix(str(s["text"]))
|
||||
normalized.append(s)
|
||||
else:
|
||||
normalized.append({"text": _strip_list_prefix(str(s)), "category": "fact"})
|
||||
suggestions = normalized
|
||||
else:
|
||||
suggestions = []
|
||||
|
||||
return {"suggestions": suggestions, "filename": filename}
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Fallback: split by lines, stripping any "1.", "2)" markdown-list
|
||||
# numbering the model added so saved memories don't keep the prefix.
|
||||
lines = [_strip_list_prefix(l.strip()) for l in raw.splitlines() if l.strip() and len(l.strip()) > 5]
|
||||
return {"suggestions": [{"text": l, "category": "fact"} for l in lines[:20]], "filename": filename}
|
||||
except Exception as e:
|
||||
logger.error(f"Memory import extraction failed: {e}")
|
||||
raise HTTPException(502, f"LLM extraction failed: {str(e)}")
|
||||
|
||||
@router.post("/{memory_id}/pin")
|
||||
def pin_memory(request: Request, memory_id: str, pinned: bool = Form(True)):
|
||||
"""Pin or unpin a memory. Pinned memories are always included in context."""
|
||||
user = _owner(request)
|
||||
all_mem = memory_manager.load_all()
|
||||
for i, memory in enumerate(all_mem):
|
||||
if memory["id"] == memory_id:
|
||||
_verify_memory_owner(memory, user)
|
||||
all_mem[i]["pinned"] = pinned
|
||||
memory_manager.save(all_mem)
|
||||
return {"ok": True, "pinned": pinned}
|
||||
raise HTTPException(404, f"Memory item {memory_id} not found")
|
||||
|
||||
# Wildcard routes MUST come last — otherwise they swallow /import, /search, etc.
|
||||
@router.get("/{memory_id}")
|
||||
def get_memory_item(request: Request, memory_id: str):
|
||||
"""Get a specific memory item by ID."""
|
||||
user = _owner(request)
|
||||
memories = memory_manager.load(owner=user)
|
||||
for memory in memories:
|
||||
if memory["id"] == memory_id:
|
||||
return {"memory": memory}
|
||||
|
||||
raise HTTPException(404, "Memory not found")
|
||||
|
||||
@router.put("/{memory_id}")
|
||||
def update_memory(request: Request, memory_id: str, text: str = Form(...), category: str = Form(None)):
|
||||
"""Update an existing memory item with new text and optional category."""
|
||||
user = _owner(request)
|
||||
all_mem = memory_manager.load_all()
|
||||
for i, memory in enumerate(all_mem):
|
||||
if memory["id"] == memory_id:
|
||||
_verify_memory_owner(memory, user)
|
||||
all_mem[i]["text"] = text.strip()
|
||||
if category:
|
||||
all_mem[i]["category"] = category
|
||||
all_mem[i]["timestamp"] = int(time.time())
|
||||
|
||||
memory_manager.save(all_mem)
|
||||
# Sync vector index (remove old, add updated)
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.remove(memory_id)
|
||||
memory_vector.add(memory_id, text.strip())
|
||||
return {"ok": True, "message": "Memory updated successfully"}
|
||||
|
||||
raise HTTPException(404, f"Memory item {memory_id} not found")
|
||||
|
||||
@router.delete("/{memory_id}")
|
||||
def delete_memory(request: Request, memory_id: str):
|
||||
"""Delete a memory item by its ID."""
|
||||
user = _owner(request)
|
||||
all_mem = memory_manager.load_all()
|
||||
|
||||
# Find and verify ownership before deleting
|
||||
target = next((m for m in all_mem if m["id"] == memory_id), None)
|
||||
if not target:
|
||||
raise HTTPException(404, f"Memory item {memory_id} not found")
|
||||
_verify_memory_owner(target, user)
|
||||
|
||||
all_mem = [m for m in all_mem if m["id"] != memory_id]
|
||||
memory_manager.save(all_mem)
|
||||
# Sync vector index
|
||||
if memory_vector and memory_vector.healthy:
|
||||
memory_vector.remove(memory_id)
|
||||
return {"ok": True, "message": "Memory deleted successfully"}
|
||||
|
||||
return router
|
||||
1226
routes/model_routes.py
Normal file
1226
routes/model_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
741
routes/note_routes.py
Normal file
741
routes/note_routes.py
Normal file
@@ -0,0 +1,741 @@
|
||||
# routes/note_routes.py
|
||||
"""Google Keep-style notes / checklists API."""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, Note
|
||||
from src.auth_helpers import get_current_user
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class NoteCreate(BaseModel):
|
||||
title: str = ""
|
||||
content: Optional[str] = None
|
||||
items: Optional[list] = None
|
||||
note_type: str = "note"
|
||||
color: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
pinned: bool = False
|
||||
due_date: Optional[str] = None
|
||||
source: str = "user"
|
||||
session_id: Optional[str] = None
|
||||
image_url: Optional[str] = None
|
||||
repeat: Optional[str] = "none"
|
||||
sort_order: Optional[int] = None
|
||||
|
||||
|
||||
class NoteUpdate(BaseModel):
|
||||
title: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
items: Optional[list] = None
|
||||
note_type: Optional[str] = None
|
||||
color: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
pinned: Optional[bool] = None
|
||||
archived: Optional[bool] = None
|
||||
due_date: Optional[str] = None
|
||||
image_url: Optional[str] = None
|
||||
repeat: Optional[str] = None
|
||||
sort_order: Optional[int] = None
|
||||
agent_session_id: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _note_to_dict(note: Note) -> Dict[str, Any]:
|
||||
items = None
|
||||
if note.items:
|
||||
try:
|
||||
items = json.loads(note.items)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
items = None
|
||||
ai_cls = None
|
||||
raw_ai = getattr(note, "ai_classification", None)
|
||||
if raw_ai:
|
||||
try:
|
||||
ai_cls = json.loads(raw_ai)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
ai_cls = None
|
||||
return {
|
||||
"id": note.id,
|
||||
"owner": note.owner,
|
||||
"title": note.title,
|
||||
"content": note.content,
|
||||
"items": items,
|
||||
"note_type": note.note_type,
|
||||
"color": note.color,
|
||||
"label": note.label,
|
||||
"pinned": note.pinned,
|
||||
"archived": note.archived,
|
||||
"due_date": note.due_date,
|
||||
"source": note.source,
|
||||
"session_id": note.session_id,
|
||||
"sort_order": note.sort_order or 0,
|
||||
"image_url": note.image_url,
|
||||
"repeat": note.repeat or "none",
|
||||
"ai_classification": ai_cls,
|
||||
"ai_content_hash": getattr(note, "ai_content_hash", None),
|
||||
"agent_session_id": getattr(note, "agent_session_id", None),
|
||||
"created_at": note.created_at.isoformat() if note.created_at else None,
|
||||
"updated_at": note.updated_at.isoformat() if note.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reminder dispatch — module-level so background tasks (built-in actions)
|
||||
# can call it directly without an HTTP roundtrip + auth cookie. The route
|
||||
# version below is a thin wrapper that pulls `owner` from the request.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Scheduler reference — set by setup_note_routes() so dispatch_reminder can
|
||||
# push a parallel in-app notification (frontend polls the scheduler's queue
|
||||
# and fires real browser Notification(...) popups). Optional; works without it.
|
||||
_scheduler_ref = None
|
||||
|
||||
|
||||
async def dispatch_reminder(
|
||||
title: str,
|
||||
note_body: str,
|
||||
note_id: str,
|
||||
owner: str = "",
|
||||
queue_browser: bool = True,
|
||||
) -> dict:
|
||||
"""Fire a reminder via the configured channel (browser/email/ntfy).
|
||||
|
||||
Args:
|
||||
title: short headline shown to the user
|
||||
note_body: longer body text
|
||||
note_id: stable id (used as tag/dedupe in browser notifications)
|
||||
owner: the user this reminder belongs to — scopes SMTP config to
|
||||
their account so we don't cross-leak credentials
|
||||
|
||||
Returns: {synthesis, email_sent, ntfy_sent}. Browser channel is wired via
|
||||
the in-memory notification queue picked up by the frontend poller, so
|
||||
nothing is "sent" synchronously for it — the channel just routes there.
|
||||
"""
|
||||
from src.settings import load_settings
|
||||
settings = load_settings()
|
||||
channel = settings.get("reminder_channel", "browser")
|
||||
llm_on = bool(settings.get("reminder_llm_synthesis", False))
|
||||
title = (title or "").strip()
|
||||
note_body = (note_body or "").strip()
|
||||
cache_key = str(note_id) if note_id else ""
|
||||
cache = {}
|
||||
cache_path = None
|
||||
if cache_key:
|
||||
try:
|
||||
import json as _json
|
||||
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
|
||||
from pathlib import Path as _P
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
cache_path = _P(f"data/note_pings_{_slug}.json")
|
||||
if cache_path.exists():
|
||||
cache = _json.loads(cache_path.read_text())
|
||||
last = cache.get(cache_key)
|
||||
if last:
|
||||
last_channel = None
|
||||
if isinstance(last, dict):
|
||||
last_channel = last.get("channel")
|
||||
last = last.get("at")
|
||||
last_dt = _dt.fromisoformat(str(last))
|
||||
if last_dt.tzinfo is None:
|
||||
last_dt = last_dt.replace(tzinfo=_tz.utc)
|
||||
# Legacy cache values were plain timestamps and could be
|
||||
# written by the frontend even when the email/ntfy send failed.
|
||||
# Treat those as browser-only dedupe so email reminders can be
|
||||
# retried by the backend scanner after a failed frontend path.
|
||||
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
|
||||
if should_skip and channel in ("email", "ntfy"):
|
||||
should_skip = last_channel == channel
|
||||
if should_skip:
|
||||
return {
|
||||
"synthesis": None,
|
||||
"email_sent": False,
|
||||
"ntfy_sent": False,
|
||||
"browser_sent": True,
|
||||
"skipped": True,
|
||||
}
|
||||
except Exception as _e:
|
||||
logger.debug(f"dispatch_reminder: cache read failed: {_e}")
|
||||
|
||||
synthesis = None
|
||||
_SYNTH_FAILED_TAG = "[utility model unavailable — no summary generated]"
|
||||
if llm_on:
|
||||
try:
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
if url and model:
|
||||
raw = await llm_call_async(
|
||||
url=url, model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a reminder assistant. Write a single short, warm, motivating sentence (max 25 words) reminding the user about the note below. Do not add greetings, preamble, or hashtags. Output only the sentence."},
|
||||
{"role": "user", "content": f"Title: {title}\n\n{note_body}".strip()},
|
||||
],
|
||||
temperature=0.7, max_tokens=200, headers=headers, timeout=30,
|
||||
)
|
||||
from src.text_helpers import strip_think as _strip_think
|
||||
# prose=True strips untagged "The user wants me to…" chain-of-thought.
|
||||
# prompt_echo=True strips Qwen-style "Thinking Process:" / leaked
|
||||
# prompt prefixes. Both are safe here because this is a
|
||||
# one-sentence LLM-only output, not user-pasted content.
|
||||
synthesis = _strip_think(raw or "", prose=True, prompt_echo=True)
|
||||
# Reminder synthesis is supposed to be ONE sentence. Strip-think's
|
||||
# paragraph-based heuristic misses cases where the model puts
|
||||
# reasoning + answer on consecutive lines inside one paragraph
|
||||
# (e.g. "I should write... [\n] You have one task waiting...").
|
||||
# Walk lines, drop reasoning/prompt-echo lines, then keep the
|
||||
# last surviving line — that's the actual warm sentence.
|
||||
if synthesis:
|
||||
import re as _re
|
||||
# Tightened: target ACTUAL self-talk (model narrating what
|
||||
# it'll do) rather than any first-person sentence. The old
|
||||
# pattern killed legit warm sentences like "I'll see you
|
||||
# tomorrow" or "I should be done by then". New rules:
|
||||
# • "I (need|should|have|'ll|will) (write|draft|reply|…)"
|
||||
# only matches when followed by a TASK verb taking an
|
||||
# OBJECT (so first-person + intransitive verb passes).
|
||||
# • Self-instructional patterns the model emits verbatim:
|
||||
# "I should write something that reminds them…",
|
||||
# "I need to draft…", "Let me think…".
|
||||
# • Explicit instructions echoed back from the prompt:
|
||||
# "Keep it under 25 words", "No greetings".
|
||||
_reasoning = _re.compile(
|
||||
r"^\s*(?:"
|
||||
# "I should write/draft/compose…" with a task-object follow
|
||||
r"i (?:need|should|have|'ll|will|am going|am)\s+to\s+"
|
||||
r"(?:write|draft|compose|craft|generate|produce|create|"
|
||||
r"summarize|answer|provide|note|address|remind|output)"
|
||||
r"\s+(?:a |an |the |something|this|that|here|them|him|her|"
|
||||
r"you|user|reply|response|sentence|message|line|warm)|"
|
||||
# The model literally narrating about the user
|
||||
r"the user (?:wants|is asking|asks|needs|wrote|said|requested) (?:me )?(?:to|for|that|about|something)|"
|
||||
# "Let me [think/write/draft/…] (about/for/the …)"
|
||||
r"let me (?:think|write|draft|consider|note|see|check)\b\s+(?:about|for|the|this|that|if|whether)|"
|
||||
# "Looking at the/this/that …"
|
||||
r"looking at (?:the|this|that)\b|"
|
||||
# "Based on the/this/what …"
|
||||
r"based on (?:the|this|what|context|that)\b|"
|
||||
# Prompt-echo of length / style instructions
|
||||
r"keep it under \d+ words\b|"
|
||||
r"(?:no greetings|no preamble|no hashtags|just output the)\b"
|
||||
r").*",
|
||||
_re.IGNORECASE,
|
||||
)
|
||||
# Echo of the prompt's "Pending:" / "<N> pending" tail.
|
||||
_echo = _re.compile(
|
||||
r"^\s*(?:pending\s*[:.]|(?:\d+|one|two|three|four|five)\s+pending\b)",
|
||||
_re.IGNORECASE,
|
||||
)
|
||||
lines = [ln for ln in synthesis.splitlines() if ln.strip()]
|
||||
cleaned = [ln for ln in lines if not _reasoning.match(ln) and not _echo.match(ln)]
|
||||
if cleaned:
|
||||
# The model's actual answer is normally the LAST surviving
|
||||
# line — reasoning leads, answer trails.
|
||||
synthesis = cleaned[-1].strip()
|
||||
else:
|
||||
synthesis = _SYNTH_FAILED_TAG
|
||||
except Exception as e:
|
||||
logger.warning(f"Reminder LLM synthesis failed: {e}")
|
||||
synthesis = _SYNTH_FAILED_TAG
|
||||
if synthesis:
|
||||
_s = synthesis.strip(); _low = _s.lower()
|
||||
if (not _s or _low.startswith("error:") or _low.startswith("[error")
|
||||
or "operation failed" in _low
|
||||
or ("upstream" in _low and "failed" in _low)) and synthesis != _SYNTH_FAILED_TAG:
|
||||
logger.warning(f"Reminder synthesis looked like an error, replacing: {_s[:120]!r}")
|
||||
synthesis = _SYNTH_FAILED_TAG
|
||||
|
||||
email_sent = False
|
||||
email_error = ""
|
||||
if channel == "email":
|
||||
try:
|
||||
from routes.email_routes import _get_email_config
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from datetime import datetime as _dt
|
||||
# `reminder_email_account_id` lets the user pick WHICH email
|
||||
# account to send reminders from (when they have several
|
||||
# configured in Integrations). Falls back to the default
|
||||
# account when no explicit choice is saved.
|
||||
_acc_id = (settings.get("reminder_email_account_id") or "").strip() or None
|
||||
cfg = _get_email_config(account_id=_acc_id, owner=owner or "")
|
||||
if not (cfg.get("smtp_host") and cfg.get("smtp_user") and cfg.get("smtp_password")):
|
||||
try:
|
||||
from core.database import SessionLocal as _SL, EmailAccount as _EA
|
||||
from sqlalchemy import and_, or_
|
||||
db = _SL()
|
||||
try:
|
||||
q = db.query(_EA).filter(_EA.enabled == True) # noqa: E712
|
||||
if owner:
|
||||
unowned = or_(_EA.owner == None, _EA.owner == "") # noqa: E711
|
||||
same_mailbox = or_(_EA.imap_user == owner, _EA.from_address == owner)
|
||||
q = q.filter(or_(_EA.owner == owner, and_(unowned, same_mailbox)))
|
||||
for row in q.order_by(_EA.is_default.desc(), _EA.created_at.asc()).all():
|
||||
trial = _get_email_config(account_id=row.id, owner=owner or "")
|
||||
if trial.get("smtp_host") and trial.get("smtp_user") and trial.get("smtp_password"):
|
||||
cfg = trial
|
||||
break
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as _fallback_error:
|
||||
logger.debug(f"Reminder SMTP fallback lookup failed: {_fallback_error}")
|
||||
from_addr = (cfg.get("from_address") or cfg.get("smtp_user") or "").strip()
|
||||
recipient = (settings.get("reminder_email_to") or "").strip() or from_addr
|
||||
# Loud diagnostic so we can see WHY a reminder didn't send (the
|
||||
# previous "silently no-op when cfg has no smtp_host" was invisible).
|
||||
logger.info(
|
||||
f"dispatch_reminder[email] note_id={note_id} owner={owner!r} "
|
||||
f"smtp_host={cfg.get('smtp_host')!r} smtp_user={cfg.get('smtp_user')!r} "
|
||||
f"from={from_addr!r} recipient={recipient!r} "
|
||||
f"account_name={cfg.get('account_name')!r}"
|
||||
)
|
||||
missing = []
|
||||
if not cfg.get("smtp_host"):
|
||||
missing.append("SMTP host")
|
||||
if not cfg.get("smtp_user"):
|
||||
missing.append("SMTP user")
|
||||
if not cfg.get("smtp_password"):
|
||||
missing.append("SMTP password")
|
||||
if not from_addr:
|
||||
missing.append("from address")
|
||||
if not recipient:
|
||||
missing.append("recipient")
|
||||
if missing:
|
||||
email_error = "Missing " + ", ".join(missing)
|
||||
logger.warning(
|
||||
"Reminder email not sent for note_id=%s account=%r: %s",
|
||||
note_id, cfg.get("account_name"), email_error,
|
||||
)
|
||||
else:
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = recipient
|
||||
_t = title or 'Note'
|
||||
_t = _t[len('Reminder:'):].strip() if _t.lower().startswith('reminder:') else _t
|
||||
msg["Subject"] = f"Reminder (Odysseus): {_t}"
|
||||
msg["Date"] = _dt.utcnow().strftime("%a, %d %b %Y %H:%M:%S +0000")
|
||||
msg["X-Odysseus-Origin"] = "odysseus-ui"
|
||||
msg["X-Odysseus-Kind"] = "reminder"
|
||||
msg["X-Odysseus-Ref"] = str(note_id)
|
||||
# Body shape: synthesis (warm sentence) → blank line → bold
|
||||
# title header → note details. The title was previously only
|
||||
# in the subject line, so the email read like a faceless
|
||||
# to-do list with no anchor to which note triggered it.
|
||||
_body_chunks = []
|
||||
if synthesis:
|
||||
_body_chunks.append(synthesis)
|
||||
if _t:
|
||||
_body_chunks.append(_t)
|
||||
if note_body:
|
||||
_body_chunks.append(note_body)
|
||||
plain = "\n\n".join(_body_chunks) if _body_chunks else title
|
||||
msg.attach(MIMEText(plain, "plain", "utf-8"))
|
||||
|
||||
def _smtp_send():
|
||||
from routes.email_helpers import _send_smtp_message
|
||||
_send_smtp_message(cfg, from_addr, [recipient], msg.as_string())
|
||||
|
||||
import asyncio as _aio
|
||||
await _aio.to_thread(_smtp_send)
|
||||
email_sent = True
|
||||
except Exception as e:
|
||||
email_error = str(e) or e.__class__.__name__
|
||||
logger.warning(f"Reminder email send failed: {e}")
|
||||
|
||||
ntfy_sent = False
|
||||
ntfy_error = ""
|
||||
if channel == "ntfy":
|
||||
try:
|
||||
from src.integrations import load_integrations
|
||||
import httpx
|
||||
intg = next(
|
||||
(i for i in load_integrations()
|
||||
if i.get("preset") == "ntfy" and i.get("enabled", True) and i.get("base_url")),
|
||||
None,
|
||||
)
|
||||
if intg:
|
||||
base = intg["base_url"].rstrip("/")
|
||||
topic = settings.get("reminder_ntfy_topic") or "reminders"
|
||||
ntfy_body = synthesis or note_body or title
|
||||
hdrs = {"Title": title or "Reminder", "Priority": "high", "Tags": "bell"}
|
||||
api_key = intg.get("api_key", "")
|
||||
if api_key:
|
||||
hdrs["Authorization"] = f"Bearer {api_key}"
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
resp = await client.post(f"{base}/{topic}", content=ntfy_body, headers=hdrs)
|
||||
ntfy_sent = resp.is_success
|
||||
if not ntfy_sent:
|
||||
ntfy_error = f"ntfy returned HTTP {resp.status_code}"
|
||||
else:
|
||||
ntfy_error = "No enabled ntfy integration"
|
||||
except Exception as e:
|
||||
ntfy_error = str(e) or e.__class__.__name__
|
||||
logger.warning(f"Reminder ntfy send failed: {e}")
|
||||
|
||||
# In-app browser notification ALWAYS fires (regardless of channel). The
|
||||
# frontend polls `/api/tasks/notifications` and turns any entry with a
|
||||
# `body` into a real `Notification(...)` — same surface as task-success
|
||||
# popups. Lets the user see reminders inside the app even when the
|
||||
# primary channel is email/ntfy and the tab is open.
|
||||
browser_sent = False
|
||||
local_browser_sent = (not queue_browser and channel == "browser")
|
||||
if queue_browser and _scheduler_ref is not None:
|
||||
try:
|
||||
_scheduler_ref.add_notification(
|
||||
task_name=title or "Reminder",
|
||||
status="success",
|
||||
task_id=f"reminder-{note_id}",
|
||||
owner=owner or None,
|
||||
body=(synthesis or note_body or title or "").strip()[:500] or "Reminder",
|
||||
)
|
||||
browser_sent = True
|
||||
except Exception as _e:
|
||||
logger.debug(f"dispatch_reminder: in-app notif push failed: {_e}")
|
||||
|
||||
# Dedupe across paths: write to the same cache file `action_ping_notes`
|
||||
# reads, so the background scanner's REPING_MIN window suppresses a
|
||||
# second send for the same note within 25 min. Without this, a note
|
||||
# whose due_date fires while the user has the app open got TWO emails
|
||||
# (frontend-fired here + background-fired by ping_notes 0–5 min later).
|
||||
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
|
||||
try:
|
||||
import json as _json
|
||||
from datetime import datetime as _dt, timezone as _tz
|
||||
from pathlib import Path as _P
|
||||
# Per-owner cache so the scanner's prune step on user A's run
|
||||
# doesn't drop user B's just-fired entry (review C4).
|
||||
_STATE = cache_path
|
||||
if _STATE is None:
|
||||
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
|
||||
_STATE = _P(f"data/note_pings_{_slug}.json")
|
||||
_STATE.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
_cache = cache or (_json.loads(_STATE.read_text()) if _STATE.exists() else {})
|
||||
except Exception:
|
||||
_cache = {}
|
||||
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
|
||||
_cache[cache_key or str(note_id)] = {
|
||||
"at": _dt.now(_tz.utc).isoformat(),
|
||||
"channel": sent_channel,
|
||||
}
|
||||
_STATE.write_text(_json.dumps(_cache))
|
||||
except Exception as _e:
|
||||
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
|
||||
|
||||
return {
|
||||
"synthesis": synthesis,
|
||||
"email_sent": email_sent,
|
||||
"email_error": email_error,
|
||||
"ntfy_sent": ntfy_sent,
|
||||
"ntfy_error": ntfy_error,
|
||||
"browser_sent": browser_sent or local_browser_sent,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Router factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def setup_note_routes(task_scheduler=None):
|
||||
# Expose the scheduler to module-level `dispatch_reminder` so reminders
|
||||
# can also push to the in-app notification queue (the polling system
|
||||
# turns each entry into a real browser Notification + the existing
|
||||
# tasks-tab badge / dot system).
|
||||
global _scheduler_ref
|
||||
_scheduler_ref = task_scheduler
|
||||
|
||||
router = APIRouter(prefix="/api/notes", tags=["notes"])
|
||||
|
||||
def _owner(request: Request) -> Optional[str]:
|
||||
return get_current_user(request)
|
||||
|
||||
# --- LIST ---
|
||||
@router.get("")
|
||||
def list_notes(
|
||||
request: Request,
|
||||
archived: Optional[bool] = None,
|
||||
label: Optional[str] = None,
|
||||
):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(Note)
|
||||
if user is not None:
|
||||
q = q.filter(Note.owner == user)
|
||||
if archived is not None:
|
||||
q = q.filter(Note.archived == archived)
|
||||
else:
|
||||
q = q.filter(Note.archived == False)
|
||||
if label:
|
||||
q = q.filter(Note.label == label)
|
||||
# Archived view: most recently archived first. Active view: pin + manual order.
|
||||
if archived is True:
|
||||
notes = q.order_by(Note.updated_at.desc()).all()
|
||||
else:
|
||||
notes = q.order_by(Note.pinned.desc(), Note.sort_order.asc(), Note.updated_at.desc()).all()
|
||||
return {"notes": [_note_to_dict(n) for n in notes]}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- CREATE ---
|
||||
@router.post("")
|
||||
def create_note(request: Request, body: NoteCreate):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = Note(
|
||||
id=str(uuid.uuid4()),
|
||||
owner=user,
|
||||
title=body.title,
|
||||
content=body.content,
|
||||
items=json.dumps(body.items) if body.items is not None else None,
|
||||
note_type=body.note_type,
|
||||
color=body.color,
|
||||
label=body.label,
|
||||
pinned=body.pinned,
|
||||
due_date=body.due_date,
|
||||
source=body.source,
|
||||
session_id=body.session_id,
|
||||
image_url=body.image_url,
|
||||
repeat=body.repeat or "none",
|
||||
sort_order=body.sort_order if body.sort_order is not None else 0,
|
||||
)
|
||||
db.add(note)
|
||||
db.commit()
|
||||
db.refresh(note)
|
||||
return _note_to_dict(note)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- GET ONE ---
|
||||
@router.get("/{note_id}")
|
||||
def get_note(request: Request, note_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
|
||||
# let any user touch a row whose owner field was null/empty.
|
||||
if user is not None and note.owner != user:
|
||||
raise HTTPException(404, "Note not found")
|
||||
return _note_to_dict(note)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- UPDATE ---
|
||||
@router.put("/{note_id}")
|
||||
def update_note(request: Request, note_id: str, body: NoteUpdate):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
|
||||
# let any user touch a row whose owner field was null/empty.
|
||||
if user is not None and note.owner != user:
|
||||
raise HTTPException(404, "Note not found")
|
||||
|
||||
if body.title is not None:
|
||||
note.title = body.title
|
||||
if body.content is not None:
|
||||
note.content = body.content
|
||||
if body.items is not None:
|
||||
note.items = json.dumps(body.items)
|
||||
flag_modified(note, "items")
|
||||
if body.note_type is not None:
|
||||
note.note_type = body.note_type
|
||||
if body.color is not None:
|
||||
note.color = body.color
|
||||
if body.label is not None:
|
||||
note.label = body.label
|
||||
if body.pinned is not None:
|
||||
note.pinned = body.pinned
|
||||
if body.archived is not None:
|
||||
note.archived = body.archived
|
||||
if body.due_date is not None:
|
||||
note.due_date = body.due_date
|
||||
if body.image_url is not None:
|
||||
note.image_url = body.image_url
|
||||
if body.repeat is not None:
|
||||
note.repeat = body.repeat
|
||||
if body.sort_order is not None:
|
||||
note.sort_order = body.sort_order
|
||||
if body.agent_session_id is not None:
|
||||
note.agent_session_id = body.agent_session_id
|
||||
|
||||
db.commit()
|
||||
db.refresh(note)
|
||||
return _note_to_dict(note)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- DELETE ---
|
||||
@router.delete("/{note_id}")
|
||||
def delete_note(request: Request, note_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
|
||||
# let any user touch a row whose owner field was null/empty.
|
||||
if user is not None and note.owner != user:
|
||||
raise HTTPException(404, "Note not found")
|
||||
db.delete(note)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- TOGGLE PIN ---
|
||||
@router.post("/{note_id}/pin")
|
||||
def toggle_pin(request: Request, note_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
|
||||
# let any user touch a row whose owner field was null/empty.
|
||||
if user is not None and note.owner != user:
|
||||
raise HTTPException(404, "Note not found")
|
||||
note.pinned = not note.pinned
|
||||
db.commit()
|
||||
return {"ok": True, "pinned": note.pinned}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- TOGGLE ARCHIVE ---
|
||||
@router.post("/{note_id}/archive")
|
||||
def toggle_archive(request: Request, note_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
|
||||
# let any user touch a row whose owner field was null/empty.
|
||||
if user is not None and note.owner != user:
|
||||
raise HTTPException(404, "Note not found")
|
||||
note.archived = not note.archived
|
||||
db.commit()
|
||||
return {"ok": True, "archived": note.archived}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- TOGGLE CHECKLIST ITEM ---
|
||||
@router.post("/{note_id}/items/{index}/toggle")
|
||||
def toggle_item(request: Request, note_id: str, index: int):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
note = db.query(Note).filter(Note.id == note_id).first()
|
||||
if not note:
|
||||
raise HTTPException(404, "Note not found")
|
||||
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
|
||||
# let any user touch a row whose owner field was null/empty.
|
||||
if user is not None and note.owner != user:
|
||||
raise HTTPException(404, "Note not found")
|
||||
if not note.items:
|
||||
raise HTTPException(400, "Note has no checklist items")
|
||||
items = json.loads(note.items)
|
||||
if index < 0 or index >= len(items):
|
||||
raise HTTPException(400, f"Item index {index} out of range")
|
||||
items[index]["done"] = not items[index].get("done", False)
|
||||
note.items = json.dumps(items)
|
||||
flag_modified(note, "items")
|
||||
db.commit()
|
||||
return {"ok": True, "items": items}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- FIRE REMINDER ---
|
||||
@router.post("/fire-reminder")
|
||||
async def fire_reminder(request: Request):
|
||||
"""Dispatch a reminder according to user settings.
|
||||
|
||||
Called by the frontend when a reminder fires. Optionally generates an
|
||||
LLM synthesis line and/or sends an email through configured SMTP.
|
||||
Returns {synthesis, email_sent}.
|
||||
"""
|
||||
# Gate against anonymous callers — LLM synthesis can burn tokens.
|
||||
from src.auth_helpers import get_current_user as _gcu
|
||||
if not _gcu(request):
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
body = await request.json()
|
||||
note_id = body.get("note_id")
|
||||
title = (body.get("title") or "").strip()
|
||||
note_body = (body.get("body") or "").strip()
|
||||
if not note_id:
|
||||
raise HTTPException(400, "note_id required")
|
||||
|
||||
# Delegate to the module-level helper so background tasks can reuse
|
||||
# the same dispatch without an HTTP roundtrip + auth cookie.
|
||||
return await dispatch_reminder(
|
||||
title=title, note_body=note_body, note_id=note_id,
|
||||
owner=_gcu(request) or "",
|
||||
queue_browser=False,
|
||||
)
|
||||
|
||||
# --- REORDER NOTES ---
|
||||
@router.post("/reorder")
|
||||
async def reorder_notes(request: Request):
|
||||
"""Update sort_order for a list of note IDs in the order provided."""
|
||||
user = _owner(request)
|
||||
body = await request.json()
|
||||
ids = body.get("ids", [])
|
||||
if not isinstance(ids, list):
|
||||
raise HTTPException(400, "ids must be a list")
|
||||
# v2 review HIGH-12: drop the legacy `(owner == user) | (owner ==
|
||||
# None)` OR which let an authenticated user silently reorder
|
||||
# every legacy-null-owner note belonging to other accounts. In
|
||||
# an unconfigured (single-user) auth deploy the OR is still safe
|
||||
# because there's no second user to attack; we keep that branch
|
||||
# explicit and gated on AuthManager.is_configured.
|
||||
try:
|
||||
from core.auth import AuthManager
|
||||
_allow_null = not AuthManager().is_configured
|
||||
except Exception:
|
||||
_allow_null = False
|
||||
db = SessionLocal()
|
||||
try:
|
||||
for i, nid in enumerate(ids):
|
||||
q = db.query(Note).filter(Note.id == nid)
|
||||
if user is not None:
|
||||
if _allow_null:
|
||||
q = q.filter((Note.owner == user) | (Note.owner == None)) # noqa: E711
|
||||
else:
|
||||
q = q.filter(Note.owner == user)
|
||||
note = q.first()
|
||||
if note:
|
||||
note.sort_order = i
|
||||
db.commit()
|
||||
return {"ok": True, "count": len(ids)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
276
routes/personal_routes.py
Normal file
276
routes/personal_routes.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# routes/personal_routes.py
|
||||
"""Routes for personal documents management."""
|
||||
import os
|
||||
import logging
|
||||
from typing import List
|
||||
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
|
||||
from src.request_models import DirectoryRequest
|
||||
from core.constants import BASE_DIR, PERSONAL_DIR
|
||||
from src.rag_singleton import get_rag_manager
|
||||
from src.auth_helpers import get_current_user, require_user
|
||||
from core.middleware import require_admin
|
||||
from src.upload_handler import secure_filename
|
||||
|
||||
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
|
||||
"""
|
||||
Setup personal documents related routes.
|
||||
|
||||
Args:
|
||||
personal_docs_manager: PersonalDocsManager instance
|
||||
rag_manager: RAG manager instance (may be None)
|
||||
rag_available: Boolean indicating if RAG is available
|
||||
|
||||
Returns:
|
||||
APIRouter instance with personal docs routes
|
||||
"""
|
||||
router = APIRouter(prefix="/api/personal")
|
||||
|
||||
def _rag():
|
||||
"""Get the current RAG manager, retrying init if needed."""
|
||||
return get_rag_manager()
|
||||
|
||||
def _resolve_allowed_personal_dir(directory: str) -> str:
|
||||
"""Resolve a user-supplied personal-docs path under the allowed root."""
|
||||
if not directory:
|
||||
raise HTTPException(400, "Directory path is required")
|
||||
|
||||
base_abs = os.path.abspath(PERSONAL_DIR)
|
||||
candidate = directory if os.path.isabs(directory) else os.path.join(base_abs, directory)
|
||||
resolved = os.path.abspath(candidate)
|
||||
try:
|
||||
in_base = os.path.commonpath([resolved, base_abs]) == base_abs
|
||||
except ValueError:
|
||||
in_base = False
|
||||
if not in_base:
|
||||
raise HTTPException(403, "Directory must be inside personal documents")
|
||||
return resolved
|
||||
|
||||
@router.get("")
|
||||
def api_personal_list(owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
|
||||
"""Enhanced version that includes directories"""
|
||||
files = [{"name": f["name"], "size": f["size"], "path": f.get("path", "")} for f in personal_docs_manager.index]
|
||||
directories = personal_docs_manager.get_indexed_directories() if hasattr(personal_docs_manager, "get_indexed_directories") else []
|
||||
return {"files": files, "directories": directories}
|
||||
|
||||
@router.post("/reload")
|
||||
def api_personal_reload(owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
|
||||
personal_docs_manager.refresh_index()
|
||||
return {"ok": True, "count": len(personal_docs_manager.index)}
|
||||
|
||||
@router.post("/add_directory")
|
||||
async def add_directory_to_rag(
|
||||
request: Request,
|
||||
directory_request: DirectoryRequest,
|
||||
owner: str = Depends(require_user), _admin: None = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
Add a directory and all its subdirectories/files to the RAG index.
|
||||
|
||||
Args:
|
||||
directory_request: Directory request model containing the directory path
|
||||
|
||||
Returns:
|
||||
JSON response with indexing results
|
||||
"""
|
||||
directory = directory_request.directory
|
||||
try:
|
||||
directory = _resolve_allowed_personal_dir(directory)
|
||||
|
||||
# Security check - ensure directory exists and is accessible
|
||||
if not os.path.exists(directory):
|
||||
raise HTTPException(404, f"Directory not found: {directory}")
|
||||
|
||||
if not os.path.isdir(directory):
|
||||
raise HTTPException(400, f"Path is not a directory: {directory}")
|
||||
|
||||
logger.info(f"Adding directory to RAG: {directory}")
|
||||
|
||||
# Use the RAGManager to index the directory
|
||||
rag = _rag()
|
||||
if rag:
|
||||
result = rag.index_personal_documents(directory, owner=owner)
|
||||
|
||||
if result["success"]:
|
||||
# Also update the personal_docs_manager to track this directory
|
||||
personal_docs_manager.add_directory(directory, index=False)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully indexed {result['indexed_count']} chunks from {directory}",
|
||||
"indexed_count": result["indexed_count"],
|
||||
"failed_count": result.get("failed_count", 0),
|
||||
"directory": directory
|
||||
}
|
||||
else:
|
||||
raise HTTPException(500, result.get("message", "Failed to index directory"))
|
||||
else:
|
||||
raise HTTPException(503, "RAG system is not available")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding directory to RAG: {e}")
|
||||
raise HTTPException(500, f"Failed to add directory: {str(e)}")
|
||||
|
||||
@router.delete("/remove_directory")
|
||||
async def remove_directory_from_rag(directory: str = Query(...), owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
|
||||
"""
|
||||
Remove a directory from the RAG index.
|
||||
|
||||
Args:
|
||||
directory: Path to the directory to remove
|
||||
|
||||
Returns:
|
||||
JSON response confirming removal
|
||||
"""
|
||||
try:
|
||||
if not directory:
|
||||
raise HTTPException(400, "Directory path is required")
|
||||
|
||||
logger.info(f"Removing directory from RAG: {directory}")
|
||||
|
||||
# Always remove from personal_docs_manager tracking
|
||||
if hasattr(personal_docs_manager, 'remove_directory'):
|
||||
personal_docs_manager.remove_directory(directory)
|
||||
|
||||
# Remove from RAG vector store (best-effort)
|
||||
rag = _rag()
|
||||
if rag:
|
||||
try:
|
||||
rag.remove_directory(directory)
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG removal failed for directory {directory}: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Successfully removed {directory} from RAG index",
|
||||
"directory": directory
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing directory from RAG: {e}")
|
||||
raise HTTPException(500, f"Failed to remove directory: {str(e)}")
|
||||
|
||||
@router.post("/upload")
|
||||
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
|
||||
"""Upload files directly into RAG. Supports text and PDF."""
|
||||
user = get_current_user(request)
|
||||
rag = _rag()
|
||||
if not rag:
|
||||
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
|
||||
|
||||
os.makedirs(UPLOADS_DIR, exist_ok=True)
|
||||
|
||||
total_indexed = 0
|
||||
total_failed = 0
|
||||
uploaded_files = []
|
||||
|
||||
for upload in files:
|
||||
try:
|
||||
# Sanitize filename — strip directory components and unsafe chars
|
||||
safe_name = secure_filename(os.path.basename(upload.filename or "upload"))
|
||||
if not safe_name or safe_name.startswith("."):
|
||||
safe_name = f"upload_{total_indexed + total_failed}"
|
||||
file_path = os.path.join(UPLOADS_DIR, safe_name)
|
||||
# Defense-in-depth: ensure resolved path stays under UPLOADS_DIR
|
||||
base_abs = os.path.abspath(UPLOADS_DIR)
|
||||
if os.path.commonpath([os.path.abspath(file_path), base_abs]) != base_abs:
|
||||
logger.warning(f"Rejected unsafe upload path: {upload.filename!r}")
|
||||
total_failed += 1
|
||||
continue
|
||||
content_bytes = await upload.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content_bytes)
|
||||
|
||||
ext = os.path.splitext(safe_name)[1].lower()
|
||||
if ext == ".pdf":
|
||||
from src.personal_docs import extract_pdf_text
|
||||
text = extract_pdf_text(file_path)
|
||||
else:
|
||||
text = content_bytes.decode("utf-8", errors="replace")
|
||||
|
||||
if not text or not text.strip():
|
||||
total_failed += 1
|
||||
continue
|
||||
|
||||
# Chunk and index
|
||||
chunks = rag._split_into_chunks(text, chunk_size=500)
|
||||
for i, chunk in enumerate(chunks):
|
||||
metadata = {
|
||||
"source": file_path,
|
||||
"filename": safe_name,
|
||||
"directory": UPLOADS_DIR,
|
||||
"type": ext,
|
||||
"chunk_id": i,
|
||||
}
|
||||
if user:
|
||||
metadata["owner"] = user
|
||||
if rag.add_document(chunk, metadata):
|
||||
total_indexed += 1
|
||||
else:
|
||||
total_failed += 1
|
||||
|
||||
uploaded_files.append(safe_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload/index {upload.filename}: {e}")
|
||||
total_failed += 1
|
||||
|
||||
# Track uploads directory
|
||||
if uploaded_files and hasattr(personal_docs_manager, "add_directory"):
|
||||
personal_docs_manager.add_directory(UPLOADS_DIR, index=False)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"uploaded": uploaded_files,
|
||||
"indexed_count": total_indexed,
|
||||
"failed_count": total_failed,
|
||||
}
|
||||
|
||||
@router.delete("/file")
|
||||
async def delete_file_from_rag(filepath: str = Query(...), owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
|
||||
"""Delete a specific file from RAG index and optionally from disk."""
|
||||
try:
|
||||
# Remove chunks from RAG vector store (best-effort)
|
||||
removed = 0
|
||||
rag = _rag()
|
||||
if rag:
|
||||
try:
|
||||
removed = rag.delete_by_source(filepath)
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG removal failed for {filepath}: {e}")
|
||||
|
||||
# Delete file from disk if it's in uploads dir
|
||||
deleted_from_disk = False
|
||||
try:
|
||||
abs_target = os.path.abspath(filepath)
|
||||
base_abs = os.path.abspath(UPLOADS_DIR)
|
||||
in_uploads = (
|
||||
abs_target == base_abs
|
||||
or os.path.commonpath([abs_target, base_abs]) == base_abs
|
||||
)
|
||||
except ValueError:
|
||||
# commonpath raises on mixed drives / non-comparable paths
|
||||
in_uploads = False
|
||||
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
|
||||
os.remove(abs_target)
|
||||
deleted_from_disk = True
|
||||
|
||||
# Exclude the file from the listing (persists across restarts)
|
||||
personal_docs_manager.exclude_file(filepath)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"removed_chunks": removed,
|
||||
"deleted_from_disk": deleted_from_disk,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {filepath}: {e}")
|
||||
raise HTTPException(500, f"Failed to delete file: {str(e)}")
|
||||
|
||||
return router
|
||||
74
routes/prefs_routes.py
Normal file
74
routes/prefs_routes.py
Normal file
@@ -0,0 +1,74 @@
|
||||
"""User preferences API — per-user key/value store backed by a JSON file."""
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
PREFS_FILE = os.path.join("data", "user_prefs.json")
|
||||
|
||||
|
||||
def _load():
|
||||
"""Load the raw prefs file (internal use only)."""
|
||||
try:
|
||||
with open(PREFS_FILE, "r") as f:
|
||||
return json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
return {}
|
||||
|
||||
|
||||
def _save(prefs):
|
||||
os.makedirs(os.path.dirname(PREFS_FILE), exist_ok=True)
|
||||
with open(PREFS_FILE, "w") as f:
|
||||
json.dump(prefs, f, indent=2)
|
||||
|
||||
|
||||
def _load_for_user(user: Optional[str] = None) -> dict:
|
||||
"""Load preferences for a specific user."""
|
||||
all_prefs = _load()
|
||||
if "_users" in all_prefs:
|
||||
if user is None:
|
||||
# Auth disabled — return first user's prefs for backward compat
|
||||
users = all_prefs["_users"]
|
||||
return dict(next(iter(users.values()), {}))
|
||||
return dict(all_prefs["_users"].get(user, {}))
|
||||
# Legacy flat format — return as-is
|
||||
return dict(all_prefs)
|
||||
|
||||
|
||||
def _save_for_user(user: Optional[str], prefs: dict):
|
||||
"""Save preferences for a specific user."""
|
||||
all_prefs = _load()
|
||||
if user is None:
|
||||
# Auth disabled — save flat
|
||||
_save(prefs)
|
||||
return
|
||||
if "_users" not in all_prefs:
|
||||
all_prefs = {"_users": {}}
|
||||
all_prefs["_users"][user] = prefs
|
||||
_save(all_prefs)
|
||||
|
||||
|
||||
def setup_prefs_routes():
|
||||
router = APIRouter(prefix="/api/prefs", tags=["preferences"])
|
||||
|
||||
@router.get("")
|
||||
async def get_all_prefs(request: Request):
|
||||
user = get_current_user(request)
|
||||
return _load_for_user(user)
|
||||
|
||||
@router.get("/{key}")
|
||||
async def get_pref(request: Request, key: str):
|
||||
user = get_current_user(request)
|
||||
prefs = _load_for_user(user)
|
||||
return {"key": key, "value": prefs.get(key)}
|
||||
|
||||
@router.put("/{key}")
|
||||
async def set_pref(request: Request, key: str, body: dict):
|
||||
user = get_current_user(request)
|
||||
prefs = _load_for_user(user)
|
||||
prefs[key] = body.get("value")
|
||||
_save_for_user(user, prefs)
|
||||
return {"key": key, "value": prefs[key]}
|
||||
|
||||
return router
|
||||
123
routes/preset_routes.py
Normal file
123
routes/preset_routes.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Preset routes — /api/presets GET, /api/presets/custom POST, user templates CRUD."""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.request_models import PresetUpdateRequest
|
||||
from core.middleware import require_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserTemplateRequest(BaseModel):
|
||||
id: str = ""
|
||||
name: str = Field(..., min_length=1, max_length=100)
|
||||
system_prompt: str = Field("", max_length=10000)
|
||||
temperature: float = Field(1.0, ge=0.0, le=2.0)
|
||||
max_tokens: int = Field(0, ge=0, le=65536)
|
||||
|
||||
|
||||
def setup_preset_routes(preset_manager) -> APIRouter:
|
||||
router = APIRouter(tags=["presets"])
|
||||
|
||||
@router.get("/api/presets")
|
||||
async def get_presets() -> Dict[str, Any]:
|
||||
return preset_manager.presets
|
||||
|
||||
@router.post("/api/presets/custom")
|
||||
async def update_custom_preset(preset_update: PresetUpdateRequest, _admin: None = Depends(require_admin)) -> Dict[str, Any]:
|
||||
try:
|
||||
success = preset_manager.update_custom(
|
||||
preset_update.temperature,
|
||||
preset_update.max_tokens,
|
||||
preset_update.system_prompt,
|
||||
preset_update.name,
|
||||
preset_update.enabled,
|
||||
preset_update.inject_prefix,
|
||||
preset_update.inject_suffix,
|
||||
)
|
||||
if success:
|
||||
return {"success": True, "message": "Custom preset updated"}
|
||||
return {"success": False, "message": "Failed to save preset"}
|
||||
except Exception as e:
|
||||
logger.error(f"Preset update error: {e}")
|
||||
raise HTTPException(500, "Failed to update custom preset")
|
||||
|
||||
@router.get("/api/presets/templates")
|
||||
async def get_user_templates() -> List[Dict]:
|
||||
return preset_manager.get_user_templates()
|
||||
|
||||
@router.post("/api/presets/templates")
|
||||
async def save_user_template(req: UserTemplateRequest, _admin: None = Depends(require_admin)) -> Dict[str, Any]:
|
||||
template = req.model_dump()
|
||||
if not template["id"]:
|
||||
template["id"] = f"user-{uuid.uuid4().hex[:8]}"
|
||||
success = preset_manager.save_user_template(template)
|
||||
if success:
|
||||
return {"success": True, "template": template}
|
||||
return {"success": False, "message": "Failed to save template"}
|
||||
|
||||
@router.delete("/api/presets/templates/{template_id}")
|
||||
async def delete_user_template(template_id: str, _admin: None = Depends(require_admin)) -> Dict[str, Any]:
|
||||
success = preset_manager.delete_user_template(template_id)
|
||||
if success:
|
||||
return {"success": True}
|
||||
return {"success": False, "message": "Failed to delete template"}
|
||||
|
||||
@router.post("/api/presets/expand")
|
||||
async def expand_character_prompt(request: Request) -> Dict[str, Any]:
|
||||
"""Use AI to expand a rough character description into a full system prompt."""
|
||||
from src.ai_interaction import _resolve_model
|
||||
from src.llm_core import llm_call_async
|
||||
|
||||
data = await request.json()
|
||||
draft = (data.get("prompt") or "").strip()
|
||||
name = (data.get("name") or "").strip()
|
||||
|
||||
if not draft and not name:
|
||||
return {"success": False, "message": "Nothing to expand"}
|
||||
|
||||
user_input = ""
|
||||
if name:
|
||||
user_input += f"Character name: {name}\n"
|
||||
if draft:
|
||||
user_input += f"Notes: {draft}\n"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are an expert at writing character system prompts for AI assistants. "
|
||||
"The user will give you a character name and/or rough notes. "
|
||||
"Write a concise, effective system prompt (3-6 sentences) that captures the character's personality, "
|
||||
"speaking style, knowledge areas, and behavioral guidelines. "
|
||||
"Output ONLY the system prompt text — no quotes, no preamble, no explanation."
|
||||
)},
|
||||
{"role": "user", "content": user_input},
|
||||
]
|
||||
|
||||
try:
|
||||
model_spec = data.get("model") or ""
|
||||
url, model, headers = _resolve_model(model_spec)
|
||||
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
|
||||
return {"success": True, "prompt": result.strip()}
|
||||
except Exception as e:
|
||||
logger.error(f"Expand prompt failed: {e}")
|
||||
return {"success": False, "message": str(e)}
|
||||
|
||||
# ── Group presets ──
|
||||
@router.get("/api/presets/groups")
|
||||
async def get_group_presets():
|
||||
"""Get saved group chat presets."""
|
||||
return {"groups": preset_manager.get_group_presets()}
|
||||
|
||||
@router.post("/api/presets/groups")
|
||||
async def save_group_presets(request: Request, _admin: None = Depends(require_admin)):
|
||||
"""Save group chat presets."""
|
||||
data = await request.json()
|
||||
preset_manager.save_group_presets(data.get("groups", []))
|
||||
return {"ok": True}
|
||||
|
||||
return router
|
||||
607
routes/research_routes.py
Normal file
607
routes/research_routes.py
Normal file
@@ -0,0 +1,607 @@
|
||||
"""Research background task routes — /api/research/*."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import HTMLResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Model-name substrings that are NOT chat/generation models — research must
|
||||
# never pick these as its model. An OpenAI-style endpoint often lists
|
||||
# `text-embedding-ada-002` etc. first in its model list, which is why research
|
||||
# was failing with "Cannot reach model 'text-embedding-ada-002'".
|
||||
_NON_CHAT_MODEL = (
|
||||
"text-embedding", "embedding", "tts-", "whisper", "dall-e",
|
||||
"moderation", "rerank", "reranker", "clip", "stable-diffusion",
|
||||
)
|
||||
|
||||
|
||||
def _first_chat_model(models) -> str:
|
||||
"""First model that isn't an embedding/tts/etc. — falls back to models[0]."""
|
||||
for m in (models or []):
|
||||
if not any(p in str(m).lower() for p in _NON_CHAT_MODEL):
|
||||
return m
|
||||
return (models[0] if models else "")
|
||||
|
||||
|
||||
def _resolve_research_endpoint(sess) -> tuple:
|
||||
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
|
||||
url, model, headers = resolve_endpoint(
|
||||
"research",
|
||||
fallback_url=sess.endpoint_url,
|
||||
fallback_model=sess.model,
|
||||
fallback_headers=sess.headers,
|
||||
)
|
||||
return url, model, headers
|
||||
|
||||
|
||||
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
|
||||
router = APIRouter(tags=["research"])
|
||||
|
||||
def _require_user(request: Request) -> str:
|
||||
"""All research endpoints require an authenticated user. Research
|
||||
data isn't owner-scoped in the on-disk JSON yet, so we at least
|
||||
block anonymous access. Multi-tenant deploys should additionally
|
||||
verify the session belongs to this user."""
|
||||
user = get_current_user(request)
|
||||
if not user:
|
||||
raise HTTPException(401, "Not authenticated")
|
||||
return user
|
||||
|
||||
def _owns_in_memory(session_id: str, user: str) -> bool:
|
||||
"""Ownership check for an in-flight (in-memory) research task.
|
||||
Falls back to the on-disk JSON if the task has already finished."""
|
||||
entry = research_handler._active_tasks.get(session_id)
|
||||
if entry is not None:
|
||||
return entry.get("owner", "") == user
|
||||
# Task no longer in memory — check the persisted JSON.
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
return False
|
||||
try:
|
||||
return json.loads(path.read_text()).get("owner") == user
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@router.get("/api/research/active")
|
||||
async def research_active(request: Request):
|
||||
"""List all currently active (running) research tasks."""
|
||||
user = _require_user(request)
|
||||
active = []
|
||||
for sid, entry in research_handler._active_tasks.items():
|
||||
# SECURITY: only show this user's running tasks.
|
||||
if entry.get("owner", "") != user:
|
||||
continue
|
||||
if entry.get("status") == "running":
|
||||
active.append({
|
||||
"session_id": sid,
|
||||
"query": entry.get("query", ""),
|
||||
"status": "running",
|
||||
"progress": entry.get("progress", {}),
|
||||
"started_at": entry.get("started_at", 0),
|
||||
})
|
||||
return {"active": active}
|
||||
|
||||
@router.get("/api/research/status/{session_id}")
|
||||
async def research_status(session_id: str, request: Request):
|
||||
user = _require_user(request)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
status = research_handler.get_status(session_id)
|
||||
if status is None:
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
return status
|
||||
|
||||
@router.post("/api/research/cancel/{session_id}")
|
||||
async def research_cancel(session_id: str, request: Request):
|
||||
user = _require_user(request)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
cancelled = research_handler.cancel_research(session_id)
|
||||
return {"cancelled": cancelled}
|
||||
|
||||
@router.post("/api/research/result/{session_id}")
|
||||
async def research_result(session_id: str, request: Request):
|
||||
user = _require_user(request)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research result available")
|
||||
result = research_handler.get_result(session_id)
|
||||
if result is None:
|
||||
raise HTTPException(404, "No research result available")
|
||||
sources = research_handler.get_sources(session_id) or []
|
||||
raw_findings = research_handler.get_raw_findings(session_id) or []
|
||||
research_handler.clear_result(session_id)
|
||||
return {"result": result, "sources": sources, "raw_findings": raw_findings}
|
||||
|
||||
def _assert_owns_research(session_id: str, user: str) -> None:
|
||||
"""404-not-403 ownership gate for a research session's on-disk JSON.
|
||||
Use BEFORE returning any data or mutating the file."""
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
owner = json.loads(path.read_text()).get("owner")
|
||||
except Exception:
|
||||
raise HTTPException(404, "Research not found")
|
||||
if owner != user:
|
||||
raise HTTPException(404, "Research not found")
|
||||
|
||||
@router.get("/api/research/report/{session_id}")
|
||||
async def research_report(session_id: str, request: Request):
|
||||
"""Serve the visual HTML report for a completed research session."""
|
||||
user = _require_user(request)
|
||||
_assert_owns_research(session_id, user)
|
||||
logger.info(f"Visual report requested for session {session_id}")
|
||||
try:
|
||||
html_content = research_handler.get_report_html(session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Visual report generation error: {e}", exc_info=True)
|
||||
raise HTTPException(500, f"Report generation failed: {e}")
|
||||
if html_content is None:
|
||||
logger.warning(f"No report data found for session {session_id}")
|
||||
raise HTTPException(404, "No visual report available for this session")
|
||||
return HTMLResponse(content=html_content)
|
||||
|
||||
class HideImageRequest(BaseModel):
|
||||
url: str
|
||||
|
||||
@router.post("/api/research/{session_id}/hide-image")
|
||||
async def research_hide_image(session_id: str, body: HideImageRequest, request: Request):
|
||||
"""Mark an image URL as hidden for this research's visual report.
|
||||
Persisted to the research JSON so subsequent /report renders skip it."""
|
||||
user = _require_user(request)
|
||||
_assert_owns_research(session_id, user)
|
||||
ok = research_handler.hide_image(session_id, body.url)
|
||||
if not ok:
|
||||
raise HTTPException(404, "Research not found")
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/api/research/{session_id}/unhide-images")
|
||||
async def research_unhide_images(session_id: str, request: Request):
|
||||
"""Clear the hidden-images list for a research session."""
|
||||
user = _require_user(request)
|
||||
_assert_owns_research(session_id, user)
|
||||
ok = research_handler.unhide_all_images(session_id)
|
||||
if not ok:
|
||||
raise HTTPException(404, "Research not found")
|
||||
return {"ok": True}
|
||||
|
||||
@router.get("/api/research/library")
|
||||
async def research_library(
|
||||
request: Request,
|
||||
search: Optional[str] = Query(None),
|
||||
sort: str = Query("recent"),
|
||||
limit: int = Query(50),
|
||||
archived: bool = Query(False),
|
||||
):
|
||||
user = _require_user(request)
|
||||
"""List all completed research for the Library panel."""
|
||||
data_dir = Path("data/deep_research")
|
||||
items = []
|
||||
for p in data_dir.glob("*.json"):
|
||||
try:
|
||||
d = json.loads(p.read_text())
|
||||
# SECURITY: only show research belonging to this user. Legacy
|
||||
# JSONs without an `owner` field are hidden — auth was the only
|
||||
# gate before, so every user saw every other user's reports.
|
||||
if d.get("owner") != user:
|
||||
continue
|
||||
# Archived view shows ONLY archived reports; default hides them.
|
||||
if bool(d.get("archived")) != archived:
|
||||
continue
|
||||
query = d.get("query", "")
|
||||
if search and search.lower() not in query.lower():
|
||||
continue
|
||||
sources = d.get("sources", [])
|
||||
items.append({
|
||||
"id": p.stem,
|
||||
"query": query,
|
||||
"category": d.get("category") or "",
|
||||
"source_count": len(sources),
|
||||
"status": d.get("status", "done"),
|
||||
"duration": d.get("stats", {}).get("Duration", ""),
|
||||
"rounds": d.get("stats", {}).get("Rounds", ""),
|
||||
"started_at": d.get("started_at", 0),
|
||||
"completed_at": d.get("completed_at", 0),
|
||||
"archived": bool(d.get("archived")),
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Sort
|
||||
if sort == "recent":
|
||||
items.sort(key=lambda x: x["completed_at"] or 0, reverse=True)
|
||||
elif sort == "oldest":
|
||||
items.sort(key=lambda x: x["completed_at"] or 0)
|
||||
elif sort == "most-messages":
|
||||
items.sort(key=lambda x: x["source_count"], reverse=True)
|
||||
elif sort == "alpha":
|
||||
items.sort(key=lambda x: x["query"].lower())
|
||||
|
||||
return {"research": items[:limit], "total": len(items)}
|
||||
|
||||
@router.get("/api/research/detail/{session_id}")
|
||||
async def research_detail(session_id: str, request: Request):
|
||||
"""Return the full JSON for a single research result — sources,
|
||||
summary, stats — used by the Library preview panel."""
|
||||
user = _require_user(request)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
data = json.loads(path.read_text())
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Failed to read research: {e}")
|
||||
# SECURITY: 404 (not 403) so we don't leak that the report exists.
|
||||
if data.get("owner") != user:
|
||||
raise HTTPException(404, "Research not found")
|
||||
return data
|
||||
|
||||
@router.post("/api/research/{session_id}/archive")
|
||||
async def research_archive(session_id: str, request: Request, archived: bool = Query(True)):
|
||||
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
|
||||
user = _require_user(request)
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if not path.exists():
|
||||
raise HTTPException(404, "Research not found")
|
||||
try:
|
||||
data = json.loads(path.read_text())
|
||||
if data.get("owner") != user:
|
||||
raise HTTPException(404, "Research not found")
|
||||
data["archived"] = bool(archived)
|
||||
path.write_text(json.dumps(data))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"Failed to update research: {e}")
|
||||
return {"ok": True, "id": session_id, "archived": bool(archived)}
|
||||
|
||||
@router.delete("/api/research/{session_id}")
|
||||
async def research_delete(session_id: str, request: Request):
|
||||
"""Delete a research result from disk."""
|
||||
user = _require_user(request)
|
||||
data_dir = Path("data/deep_research")
|
||||
json_path = data_dir / f"{session_id}.json"
|
||||
deleted = False
|
||||
if json_path.exists():
|
||||
# SECURITY: verify ownership before letting the caller delete it.
|
||||
try:
|
||||
data = json.loads(json_path.read_text())
|
||||
if data.get("owner") != user:
|
||||
raise HTTPException(404, "Research not found")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
raise HTTPException(404, "Research not found")
|
||||
json_path.unlink()
|
||||
deleted = True
|
||||
return {"deleted": deleted}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Panel endpoints — launch research without a chat session
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
class ResearchStartRequest(BaseModel):
|
||||
query: str
|
||||
# max_rounds=0 means "Auto" — let the AI decide when to stop, capped at 20.
|
||||
max_rounds: int = Field(default=0, ge=0, le=20)
|
||||
search_provider: Optional[str] = None
|
||||
endpoint_id: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
max_time: int = Field(default=300, ge=60, le=1800)
|
||||
category: Optional[str] = None
|
||||
|
||||
@router.post("/api/research/start")
|
||||
async def research_start(body: ResearchStartRequest, request: Request):
|
||||
"""Launch a research job from the dedicated panel."""
|
||||
from src.auth_helpers import require_privilege
|
||||
user = require_privilege(request, "can_use_research")
|
||||
if user == "internal-tool":
|
||||
tool_owner = (request.headers.get("X-Odysseus-Owner") or "").strip()
|
||||
if tool_owner and tool_owner not in {"internal-tool", "api", "demo", "system"}:
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
if auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
|
||||
try:
|
||||
privs = auth_mgr.get_privileges(tool_owner) or {}
|
||||
if not privs.get("can_use_research", True):
|
||||
raise HTTPException(403, f"Your account is not allowed to can use research.")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
user = tool_owner
|
||||
session_id = f"rp-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
if body.endpoint_id:
|
||||
from src.database import SessionLocal
|
||||
from src.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.id == body.endpoint_id,
|
||||
ModelEndpoint.is_enabled == True,
|
||||
).first()
|
||||
if not ep:
|
||||
raise HTTPException(404, "Endpoint not found or disabled")
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = body.model or ""
|
||||
if not ep_model:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models) if ep.cached_models else []
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("research")
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
|
||||
# When neither research nor utility is configured, use the user's
|
||||
# configured DEFAULT model (default_endpoint_id/default_model) rather
|
||||
# than arbitrarily grabbing the first enabled endpoint's first model
|
||||
# (which surfaced gpt-3.5). "Default" should mean the default model.
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("default")
|
||||
if not ep_url:
|
||||
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
|
||||
if not ep_url:
|
||||
from src.database import SessionLocal
|
||||
from src.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(
|
||||
ModelEndpoint.is_enabled == True,
|
||||
).first()
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
ep_url = build_chat_url(base)
|
||||
ep_headers = build_headers(ep.api_key, base)
|
||||
ep_model = ""
|
||||
if ep.cached_models:
|
||||
try:
|
||||
import json as _json
|
||||
models = _json.loads(ep.cached_models)
|
||||
if models:
|
||||
ep_model = _first_chat_model(models)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
db.close()
|
||||
if not ep_url:
|
||||
raise HTTPException(400, "No endpoints configured. Add one in Settings first.")
|
||||
if body.model:
|
||||
ep_model = body.model
|
||||
|
||||
# max_rounds=0 → "Auto", let AI decide; pass 20 as the safety cap.
|
||||
effective_max_rounds = body.max_rounds if body.max_rounds > 0 else 20
|
||||
research_handler.start_research(
|
||||
session_id=session_id,
|
||||
query=body.query,
|
||||
llm_endpoint=ep_url,
|
||||
llm_model=ep_model,
|
||||
max_time=body.max_time,
|
||||
llm_headers=ep_headers,
|
||||
max_rounds=effective_max_rounds,
|
||||
search_provider=body.search_provider or None,
|
||||
category=body.category or None,
|
||||
owner=user,
|
||||
)
|
||||
return {"session_id": session_id, "status": "running", "query": body.query}
|
||||
|
||||
@router.get("/api/research/stream/{session_id}")
|
||||
async def research_stream(session_id: str, request: Request):
|
||||
"""SSE stream of research progress events."""
|
||||
user = _require_user(request)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
async def _generate():
|
||||
last_progress = None
|
||||
while True:
|
||||
status = research_handler.get_status(session_id)
|
||||
if status is None:
|
||||
yield f"data: {json.dumps({'status': 'not_found'})}\n\n"
|
||||
return
|
||||
st = status.get("status", "")
|
||||
progress = status.get("progress", {})
|
||||
if progress != last_progress:
|
||||
last_progress = progress
|
||||
yield f"data: {json.dumps({**progress, 'status': st})}\n\n"
|
||||
if st != "running":
|
||||
final = {'status': st, 'final': True}
|
||||
task = research_handler._active_tasks.get(session_id, {})
|
||||
if st == "error" and task.get("result"):
|
||||
final['error'] = str(task["result"])[:500]
|
||||
yield f"data: {json.dumps(final)}\n\n"
|
||||
return
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
return StreamingResponse(
|
||||
_generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
@router.post("/api/research/result-peek/{session_id}")
|
||||
async def research_result_peek(session_id: str, request: Request):
|
||||
"""Get research result without clearing it (for panel use)."""
|
||||
user = _require_user(request)
|
||||
if not _owns_in_memory(session_id, user):
|
||||
raise HTTPException(404, "No research found for this session")
|
||||
result = research_handler.get_result(session_id)
|
||||
if result is None:
|
||||
p = Path("data/deep_research") / f"{session_id}.json"
|
||||
if p.exists():
|
||||
d = json.loads(p.read_text())
|
||||
return {
|
||||
"result": d.get("result", ""),
|
||||
"sources": d.get("sources", []),
|
||||
"raw_findings": d.get("raw_findings", []),
|
||||
"category": d.get("category") or "",
|
||||
}
|
||||
raise HTTPException(404, "No research result available")
|
||||
sources = research_handler.get_sources(session_id) or []
|
||||
raw_findings = research_handler.get_raw_findings(session_id) or []
|
||||
return {"result": result, "sources": sources, "raw_findings": raw_findings, "category": ""}
|
||||
|
||||
@router.post("/api/research/spinoff/{session_id}")
|
||||
async def research_spinoff(session_id: str, request: Request):
|
||||
"""Create a new chat session pre-seeded with this research as context.
|
||||
|
||||
Reads the persisted research result + sources for `session_id`, creates
|
||||
a fresh session (inheriting endpoint/model/headers from the source
|
||||
session if available, otherwise from the resolved chat endpoint), and
|
||||
injects a single system message containing the report and sources so
|
||||
the user can ask follow-up questions in a clean conversation.
|
||||
"""
|
||||
_require_user(request)
|
||||
if session_manager is None:
|
||||
raise HTTPException(500, "session_manager not configured")
|
||||
|
||||
# Load research data — prefer in-memory result, fall back to disk
|
||||
result = research_handler.get_result(session_id)
|
||||
sources = research_handler.get_sources(session_id) or []
|
||||
query = ""
|
||||
|
||||
path = Path("data/deep_research") / f"{session_id}.json"
|
||||
if path.exists():
|
||||
try:
|
||||
disk = json.loads(path.read_text())
|
||||
if not result:
|
||||
result = disk.get("result")
|
||||
if not sources:
|
||||
sources = disk.get("sources", []) or []
|
||||
query = disk.get("query", "") or ""
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not read research JSON for spinoff: {e}")
|
||||
|
||||
if not result:
|
||||
raise HTTPException(404, "No research result available for this session")
|
||||
|
||||
# Inherit endpoint/model/headers from the source session when possible.
|
||||
# For panel-launched research (rp-* IDs), there is no chat session, so
|
||||
# fall back through the same chain as /api/research/start: research →
|
||||
# utility → first enabled endpoint in the DB.
|
||||
ep_url, ep_model, ep_headers = "", "", {}
|
||||
try:
|
||||
src_sess = session_manager.get_session(session_id)
|
||||
ep_url = src_sess.endpoint_url or ""
|
||||
ep_model = src_sess.model or ""
|
||||
ep_headers = dict(src_sess.headers or {})
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def _merge(r_url, r_model, r_headers):
|
||||
nonlocal ep_url, ep_model, ep_headers
|
||||
if not ep_url and r_url:
|
||||
ep_url = r_url
|
||||
if not ep_model and r_model:
|
||||
ep_model = r_model
|
||||
if not ep_headers and r_headers:
|
||||
ep_headers = dict(r_headers)
|
||||
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("chat"))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("research"))
|
||||
if not ep_url or not ep_model:
|
||||
_merge(*resolve_endpoint("utility"))
|
||||
if not ep_url or not ep_model:
|
||||
# Last resort: any enabled endpoint
|
||||
from src.database import SessionLocal
|
||||
from src.database import ModelEndpoint
|
||||
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
||||
if ep:
|
||||
base = normalize_base(ep.base_url)
|
||||
fallback_url = build_chat_url(base)
|
||||
fallback_headers = build_headers(ep.api_key, base)
|
||||
fallback_model = ""
|
||||
if ep.cached_models:
|
||||
try:
|
||||
models = json.loads(ep.cached_models)
|
||||
if models:
|
||||
fallback_model = models[0]
|
||||
except Exception:
|
||||
pass
|
||||
_merge(fallback_url, fallback_model, fallback_headers)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if not ep_url or not ep_model:
|
||||
raise HTTPException(400, "No endpoint configured — add one in Settings first")
|
||||
|
||||
# Create new session
|
||||
new_sid = str(uuid.uuid4())
|
||||
user = get_current_user(request)
|
||||
|
||||
title_query = (query or "research").strip()
|
||||
if len(title_query) > 60:
|
||||
title_query = title_query[:57] + "…"
|
||||
new_name = f"Follow-up: {title_query}"
|
||||
|
||||
new_sess = session_manager.create_session(
|
||||
session_id=new_sid,
|
||||
name=new_name,
|
||||
endpoint_url=ep_url,
|
||||
model=ep_model,
|
||||
rag=False,
|
||||
owner=user,
|
||||
)
|
||||
if ep_headers:
|
||||
new_sess.headers = ep_headers
|
||||
session_manager.save_sessions()
|
||||
try:
|
||||
from src.event_bus import fire_event
|
||||
fire_event("session_created", user)
|
||||
except Exception:
|
||||
logger.debug("session_created event dispatch failed", exc_info=True)
|
||||
|
||||
# Build the priming system message — report only, no sources injected.
|
||||
# The user can open the visual report for source details; keeping sources
|
||||
# out of the chat context saves tokens and avoids the AI fabricating
|
||||
# citations.
|
||||
date_str = datetime.utcnow().strftime("%Y-%m-%d")
|
||||
primer = (
|
||||
f"[Research context — {date_str}]\n\n"
|
||||
f"The user previously ran a deep research investigation. Use the "
|
||||
f"report below as your primary knowledge base when answering "
|
||||
f"follow-up questions. If the user asks something not covered, "
|
||||
f"say so plainly rather than guessing.\n\n"
|
||||
f"=== ORIGINAL QUERY ===\n{query or '(not recorded)'}\n\n"
|
||||
f"=== REPORT ===\n{result}"
|
||||
)
|
||||
|
||||
from core.models import ChatMessage
|
||||
new_sess.add_message(ChatMessage(
|
||||
role="system",
|
||||
content=primer,
|
||||
metadata={"research_spinoff_from": session_id},
|
||||
))
|
||||
session_manager.save_sessions()
|
||||
|
||||
return {
|
||||
"session_id": new_sid,
|
||||
"name": new_name,
|
||||
"source_count": len(sources),
|
||||
}
|
||||
|
||||
return router
|
||||
111
routes/search_routes.py
Normal file
111
routes/search_routes.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Search routes — /api/search/config GET, /api/search POST."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
import time
|
||||
|
||||
from services.search import get_search_config, comprehensive_web_search, PROVIDER_INFO
|
||||
from services.search.core import _call_provider
|
||||
from services.search.providers import _get_provider_key, _get_search_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _request_values(request: Request) -> Dict[str, Any]:
|
||||
"""Accept JSON, form data, or query params for search endpoints.
|
||||
|
||||
The browser UI posts FormData, while the agent's generic app_api tool
|
||||
posts JSON. FastAPI Form(...) rejects JSON with a 422 before our handler
|
||||
runs, which made the model think SearXNG was broken.
|
||||
"""
|
||||
values: Dict[str, Any] = dict(request.query_params)
|
||||
content_type = (request.headers.get("content-type") or "").lower()
|
||||
try:
|
||||
if "application/json" in content_type:
|
||||
body = await request.json()
|
||||
if isinstance(body, dict):
|
||||
values.update(body)
|
||||
else:
|
||||
form = await request.form()
|
||||
values.update(dict(form))
|
||||
except Exception:
|
||||
pass
|
||||
return values
|
||||
|
||||
|
||||
def setup_search_routes(config) -> APIRouter:
|
||||
router = APIRouter(tags=["search"])
|
||||
|
||||
@router.get("/api/search/config")
|
||||
async def get_search_settings() -> Dict[str, Any]:
|
||||
return get_search_config()
|
||||
|
||||
@router.post("/api/search")
|
||||
async def do_web_search(request: Request) -> Dict[str, Any]:
|
||||
"""Standalone web search — returns context string + source list.
|
||||
|
||||
Used by Compare mode to pre-search once and share results across panes.
|
||||
"""
|
||||
values = await _request_values(request)
|
||||
query = str(values.get("query") or values.get("q") or "").strip()
|
||||
if not query:
|
||||
return {"context": "", "sources": [], "error": "query is required"}
|
||||
time_filter = values.get("time_filter") or values.get("freshness")
|
||||
if time_filter is not None:
|
||||
time_filter = str(time_filter).strip() or None
|
||||
try:
|
||||
context, sources = comprehensive_web_search(
|
||||
query, return_sources=True, time_filter=time_filter,
|
||||
)
|
||||
return {"context": context, "sources": sources}
|
||||
except Exception as e:
|
||||
logger.error(f"Standalone web search failed: {e}")
|
||||
return {"context": "", "sources": [], "error": str(e)}
|
||||
|
||||
@router.get("/api/search/providers")
|
||||
async def list_search_providers():
|
||||
"""Return available search providers with config status."""
|
||||
providers = []
|
||||
for pid, (label, needs_key, needs_url) in PROVIDER_INFO.items():
|
||||
if pid == "disabled":
|
||||
continue
|
||||
available = True
|
||||
if needs_key and not _get_provider_key(pid):
|
||||
available = False
|
||||
if needs_url and pid == "searxng" and not _get_search_instance():
|
||||
available = False
|
||||
providers.append({
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"available": available,
|
||||
})
|
||||
return providers
|
||||
|
||||
@router.post("/api/search/query")
|
||||
async def search_with_provider(request: Request) -> Dict[str, Any]:
|
||||
"""Search using a specific provider. Used by compare search mode."""
|
||||
values = await _request_values(request)
|
||||
query = str(values.get("query") or values.get("q") or "").strip()
|
||||
provider = str(values.get("provider") or "").strip()
|
||||
try:
|
||||
count = int(values.get("count") or values.get("limit") or 10)
|
||||
except Exception:
|
||||
count = 10
|
||||
if not query:
|
||||
return {"results": [], "provider": provider, "error": "query is required"}
|
||||
if provider not in PROVIDER_INFO or provider == "disabled":
|
||||
return {"results": [], "provider": provider, "error": "Unknown provider"}
|
||||
t0 = time.time()
|
||||
try:
|
||||
results = _call_provider(provider, query, min(count, 20))
|
||||
elapsed = round(time.time() - t0, 2)
|
||||
return {"results": results, "provider": provider, "time": elapsed}
|
||||
except Exception as e:
|
||||
elapsed = round(time.time() - t0, 2)
|
||||
logger.error(f"Search provider {provider} failed: {e}")
|
||||
return {"results": [], "provider": provider, "time": elapsed, "error": str(e)}
|
||||
|
||||
return router
|
||||
1059
routes/session_routes.py
Normal file
1059
routes/session_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
608
routes/shell_routes.py
Normal file
608
routes/shell_routes.py
Normal file
@@ -0,0 +1,608 @@
|
||||
"""Shell routes — user-facing command execution endpoint."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pty
|
||||
import fcntl
|
||||
import shlex
|
||||
import shutil
|
||||
import uuid
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def _require_admin(request: Request):
|
||||
"""Reject non-admin callers. Shell exec is admin-only — never expose to
|
||||
regular users; that's RCE-after-signup."""
|
||||
auth_manager = getattr(request.app.state, "auth_manager", None)
|
||||
if not auth_manager:
|
||||
# No auth at all — only safe in fully-trusted localhost dev mode
|
||||
return
|
||||
user = getattr(request.state, "current_user", None)
|
||||
# In-process tool loopback. The AuthMiddleware already validated the
|
||||
# internal token + loopback client before setting this marker, so
|
||||
# honour it here as admin-equivalent.
|
||||
if user == "internal-tool":
|
||||
return
|
||||
if not user or user == "api":
|
||||
raise HTTPException(403, "Admin only")
|
||||
if not auth_manager.is_admin(user):
|
||||
raise HTTPException(403, "Admin only")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _find_line_break(buf):
|
||||
"""Find next line terminator in buffer. Returns (index, separator_length) or (-1, 0)."""
|
||||
ni = buf.find(b"\n")
|
||||
ri = buf.find(b"\r")
|
||||
if ni == -1 and ri == -1:
|
||||
return -1, 0
|
||||
if ni == -1:
|
||||
return ri, 1
|
||||
if ri == -1:
|
||||
return ni, 1
|
||||
if ri < ni:
|
||||
return ri, (2 if ri + 1 == ni else 1)
|
||||
return ni, 1
|
||||
|
||||
|
||||
EXEC_TIMEOUT = 30 # seconds — shorter than agent's 60s
|
||||
STREAM_TIMEOUT = 120 # default for short commands
|
||||
MAX_OUTPUT = 200_000 # truncate limit
|
||||
TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
|
||||
|
||||
|
||||
class ShellExecRequest(BaseModel):
|
||||
command: str
|
||||
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
|
||||
use_pty: bool = False # use pseudo-TTY (for progress bars)
|
||||
use_tmux: bool = False # run in tmux session (survives browser disconnect)
|
||||
|
||||
|
||||
async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, Any]:
|
||||
"""Run a shell command and return stdout/stderr/exit_code."""
|
||||
proc = None
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
stdout_b, stderr_b = await asyncio.wait_for(
|
||||
proc.communicate(), timeout=timeout
|
||||
)
|
||||
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
|
||||
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
|
||||
except asyncio.TimeoutError:
|
||||
if proc:
|
||||
try:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
|
||||
except Exception as e:
|
||||
return {"stdout": "", "stderr": str(e), "exit_code": -1}
|
||||
|
||||
|
||||
async def _generate_pty(cmd: str, timeout: int, request: Request):
|
||||
"""Run command in a pseudo-TTY so tqdm/progress bars work natively."""
|
||||
loop = asyncio.get_event_loop()
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
|
||||
# Set master to non-blocking
|
||||
flags = fcntl.fcntl(master_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stdin=slave_fd,
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
cwd=str(Path.home()),
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
os.close(slave_fd) # parent doesn't need the slave side
|
||||
|
||||
deadline = (loop.time() + timeout) if timeout else None
|
||||
buf = b""
|
||||
process_done = asyncio.Event()
|
||||
|
||||
async def _wait_proc():
|
||||
await proc.wait()
|
||||
process_done.set()
|
||||
|
||||
wait_task = asyncio.create_task(_wait_proc())
|
||||
|
||||
try:
|
||||
while not process_done.is_set():
|
||||
if deadline and loop.time() > deadline:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': f'Command timed out after {timeout}s'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
|
||||
return
|
||||
|
||||
# Check client disconnect
|
||||
if await request.is_disconnected():
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
return
|
||||
|
||||
# Read available data from PTY
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, _pty_read, master_fd),
|
||||
timeout=2.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except OSError:
|
||||
break
|
||||
|
||||
if chunk is None:
|
||||
# No data yet, keep waiting
|
||||
continue
|
||||
if chunk == b"":
|
||||
# EOF — process closed the PTY
|
||||
break
|
||||
|
||||
buf += chunk
|
||||
# Split on \r or \n
|
||||
while True:
|
||||
idx, sep_len = _find_line_break(buf)
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
if line:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
|
||||
# Drain any remaining PTY output after process exits
|
||||
try:
|
||||
while True:
|
||||
rest = _pty_read(master_fd)
|
||||
if rest is None or rest == b"":
|
||||
break
|
||||
buf += rest
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Flush remaining buffer
|
||||
if buf:
|
||||
# Split remaining buffer same as above
|
||||
while True:
|
||||
idx, sep_len = _find_line_break(buf)
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
if line:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
if buf:
|
||||
text = buf.decode(errors="replace").strip()
|
||||
if text:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': text})}\n\n"
|
||||
|
||||
await wait_task
|
||||
yield f"data: {json.dumps({'exit_code': proc.returncode})}\n\n"
|
||||
|
||||
except Exception as e:
|
||||
try:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': str(e)})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
|
||||
finally:
|
||||
wait_task.cancel()
|
||||
try:
|
||||
os.close(master_fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _pty_read(fd: int) -> bytes | None:
|
||||
"""Blocking read from PTY fd. Called via run_in_executor.
|
||||
Returns bytes on data, None on timeout (no data yet)."""
|
||||
import select
|
||||
r, _, _ = select.select([fd], [], [], 1.0)
|
||||
if r:
|
||||
try:
|
||||
data = os.read(fd, 4096)
|
||||
return data if data else b"" # empty = EOF
|
||||
except OSError:
|
||||
return b"" # fd closed = EOF
|
||||
return None # timeout, no data yet
|
||||
|
||||
|
||||
async def _generate_tmux(cmd: str, request: Request):
|
||||
"""Run command in a tmux session. Streams output via a log file.
|
||||
The tmux session survives browser disconnect — user can reconnect or
|
||||
`tmux attach -t <name>` to see it live."""
|
||||
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
session_id = f"cookbook-{uuid.uuid4().hex[:8]}"
|
||||
log_path = TMUX_LOG_DIR / f"{session_id}.log"
|
||||
|
||||
# Write a wrapper script that runs the command, tees output, and records exit code.
|
||||
# Using a script avoids shell quoting issues with the tmux command.
|
||||
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
|
||||
script_path.write_text(
|
||||
f"#!/bin/bash\n"
|
||||
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
|
||||
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
|
||||
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
|
||||
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
|
||||
f"fi\n"
|
||||
f"{cmd} 2>&1 | tee '{log_path}'\n"
|
||||
f"EC=${{PIPESTATUS[0]}}\n"
|
||||
f"echo ':::EXIT_CODE:::'$EC >> '{log_path}'\n"
|
||||
f"rm -f '{script_path}'\n"
|
||||
f"exit $EC\n"
|
||||
)
|
||||
script_path.chmod(0o755)
|
||||
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
|
||||
|
||||
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
|
||||
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
tmux_cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
await proc.wait()
|
||||
if proc.returncode != 0:
|
||||
stderr = (await proc.stderr.read()).decode(errors="replace")
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': f'Failed to start tmux: {stderr}'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
|
||||
return
|
||||
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': f'Started tmux session: {session_id}'})}\n\n"
|
||||
|
||||
# Tail the log file, streaming new lines as SSE
|
||||
lines_sent = 0
|
||||
exit_code = None
|
||||
|
||||
while True:
|
||||
# Check client disconnect
|
||||
if await request.is_disconnected():
|
||||
# tmux keeps running — that's the whole point
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': f'Disconnected. tmux session {session_id} continues in background.'})}\n\n"
|
||||
return
|
||||
|
||||
# Read new lines from log
|
||||
try:
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(errors="replace").splitlines()
|
||||
new_lines = lines[lines_sent:]
|
||||
for line in new_lines:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
try:
|
||||
exit_code = int(line.split(":::")[-1])
|
||||
except ValueError:
|
||||
exit_code = -1
|
||||
else:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
lines_sent = len(lines)
|
||||
except Exception as e:
|
||||
logger.debug(f"tmux log read error: {e}")
|
||||
|
||||
if exit_code is not None:
|
||||
break
|
||||
|
||||
# Check if tmux session is still alive
|
||||
check = await asyncio.create_subprocess_shell(
|
||||
f"tmux has-session -t {session_id} 2>/dev/null",
|
||||
stdout=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.DEVNULL,
|
||||
)
|
||||
await check.wait()
|
||||
if check.returncode != 0:
|
||||
# Session ended — do one final read
|
||||
await asyncio.sleep(0.5)
|
||||
if log_path.exists():
|
||||
lines = log_path.read_text(errors="replace").splitlines()
|
||||
for line in lines[lines_sent:]:
|
||||
if line.startswith(":::EXIT_CODE:::"):
|
||||
try:
|
||||
exit_code = int(line.split(":::")[-1])
|
||||
except ValueError:
|
||||
exit_code = -1
|
||||
else:
|
||||
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
|
||||
if exit_code is None:
|
||||
exit_code = 0
|
||||
break
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
yield f"data: {json.dumps({'exit_code': exit_code})}\n\n"
|
||||
|
||||
# Clean up log file
|
||||
try:
|
||||
log_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def setup_shell_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["shell"])
|
||||
|
||||
@router.post("/api/shell/exec")
|
||||
async def shell_exec(request: Request, req: ShellExecRequest) -> Dict[str, Any]:
|
||||
"""Execute a shell command and return output. Admin only."""
|
||||
_require_admin(request)
|
||||
cmd = req.command.strip()
|
||||
if not cmd:
|
||||
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
|
||||
|
||||
logger.info("User shell exec requested: length=%d", len(cmd))
|
||||
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
|
||||
return result
|
||||
|
||||
@router.post("/api/shell/stream")
|
||||
async def shell_stream(request: Request, req: ShellExecRequest):
|
||||
"""Execute a shell command and stream output line-by-line via SSE. Admin only."""
|
||||
_require_admin(request)
|
||||
cmd = req.command.strip()
|
||||
if not cmd:
|
||||
async def empty():
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
|
||||
return StreamingResponse(empty(), media_type="text/event-stream")
|
||||
|
||||
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
|
||||
use_pty = req.use_pty
|
||||
use_tmux = req.use_tmux
|
||||
logger.info(
|
||||
"User shell stream requested: timeout=%s pty=%s tmux=%s length=%d",
|
||||
"none" if timeout == 0 else f"{timeout}s",
|
||||
use_pty,
|
||||
use_tmux,
|
||||
len(cmd),
|
||||
)
|
||||
|
||||
if use_tmux:
|
||||
return StreamingResponse(
|
||||
_generate_tmux(cmd, request),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
if use_pty:
|
||||
return StreamingResponse(
|
||||
_generate_pty(cmd, timeout, request),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
async def generate():
|
||||
proc = None
|
||||
reader_tasks = []
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=str(Path.home()),
|
||||
)
|
||||
|
||||
q: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def _reader(stream, name):
|
||||
"""Read chunks, split on \\n or \\r for progress bar support."""
|
||||
try:
|
||||
buf = b""
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
if buf:
|
||||
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
|
||||
break
|
||||
buf += chunk
|
||||
while True:
|
||||
idx, sep_len = _find_line_break(buf)
|
||||
if idx == -1:
|
||||
break
|
||||
line = buf[:idx].decode(errors="replace")
|
||||
buf = buf[idx + sep_len:]
|
||||
if line:
|
||||
await q.put((name, line))
|
||||
finally:
|
||||
await q.put((name, None))
|
||||
|
||||
reader_tasks = [
|
||||
asyncio.create_task(_reader(proc.stdout, "stdout")),
|
||||
asyncio.create_task(_reader(proc.stderr, "stderr")),
|
||||
]
|
||||
|
||||
finished = 0
|
||||
deadline = (asyncio.get_event_loop().time() + timeout) if timeout else None
|
||||
while finished < 2:
|
||||
if deadline:
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
raise asyncio.TimeoutError()
|
||||
wait = min(remaining, 2.0)
|
||||
else:
|
||||
wait = 2.0
|
||||
|
||||
try:
|
||||
name, text = await asyncio.wait_for(q.get(), timeout=wait)
|
||||
except asyncio.TimeoutError:
|
||||
if await request.is_disconnected():
|
||||
if proc:
|
||||
proc.kill()
|
||||
return
|
||||
continue
|
||||
|
||||
if text is None:
|
||||
finished += 1
|
||||
continue
|
||||
yield f"data: {json.dumps({'stream': name, 'data': text})}\n\n"
|
||||
|
||||
await proc.wait()
|
||||
yield f"data: {json.dumps({'exit_code': proc.returncode})}\n\n"
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
if proc:
|
||||
try:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': f'Command timed out after {timeout}s'})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
|
||||
except Exception as e:
|
||||
yield f"data: {json.dumps({'stream': 'stderr', 'data': str(e)})}\n\n"
|
||||
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
|
||||
finally:
|
||||
for t in reader_tasks:
|
||||
t.cancel()
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream")
|
||||
|
||||
@router.get("/api/cookbook/packages")
|
||||
async def list_packages(host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
|
||||
"""Check which optional packages are installed.
|
||||
|
||||
Local-target packages are checked in-process. Remote-target packages
|
||||
(vllm, sglang, llama_cpp, diffusers, hf_transfer) are checked on the SELECTED
|
||||
server over SSH, inside its venv — otherwise installing on a remote box
|
||||
never reflected because the check only ever looked at the local host.
|
||||
"""
|
||||
import importlib, shlex, json as _json
|
||||
packages = [
|
||||
# ── System ── OS binaries, not pip packages
|
||||
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
|
||||
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
|
||||
# ── LLM ── installs on GPU servers for model serving/downloading
|
||||
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
|
||||
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
|
||||
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
|
||||
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
|
||||
# ── Image ── editor + diffusion model serving
|
||||
{"name": "diffusers", "pip": "diffusers", "desc": "Image generation pipelines (SD, Flux)", "category": "Image", "target": "remote"},
|
||||
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
|
||||
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
|
||||
# ── Tools ──
|
||||
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
|
||||
]
|
||||
# Remote check: for remote-target packages, probe the selected server's
|
||||
# venv over SSH so a remote `pip install` actually reflects here.
|
||||
remote_status: dict = {}
|
||||
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
|
||||
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
|
||||
if host and remote_names:
|
||||
try:
|
||||
names_lit = ",".join(repr(n) for n in remote_names)
|
||||
py = (
|
||||
"import importlib.util,json,shutil;"
|
||||
f"names=[{names_lit}];"
|
||||
"status={n:(importlib.util.find_spec(n) is not None) for n in names};"
|
||||
"status['llama_cpp']=status.get('llama_cpp',False) or shutil.which('llama-server') is not None;"
|
||||
"print(json.dumps(status))"
|
||||
)
|
||||
src = ""
|
||||
if venv:
|
||||
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
|
||||
# NOT shlex.quoted: a leading ~ must stay shell-expandable on
|
||||
# the remote (quoting it breaks `~/venv` → activation fails →
|
||||
# the && short-circuits and every package reads as missing).
|
||||
src = f". {act} && "
|
||||
inner = f"{src}python3 -c {shlex.quote(py)}"
|
||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port not in ("", "22") else ""
|
||||
ssh_cmd = (
|
||||
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {pf}"
|
||||
f"{shlex.quote(host)} {shlex.quote(inner)}"
|
||||
)
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
# The activate script can emit noise — take the last JSON line.
|
||||
for line in reversed(txt.splitlines()):
|
||||
line = line.strip()
|
||||
if line.startswith("{"):
|
||||
remote_status = _json.loads(line)
|
||||
break
|
||||
except Exception:
|
||||
remote_status = {}
|
||||
if host and remote_system_names:
|
||||
try:
|
||||
checks = []
|
||||
for name in remote_system_names:
|
||||
qn = shlex.quote(name)
|
||||
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
|
||||
inner = " ; ".join(checks)
|
||||
pf = f"-p {ssh_port} " if ssh_port and ssh_port not in ("", "22") else ""
|
||||
ssh_cmd = (
|
||||
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {pf}"
|
||||
f"{shlex.quote(host)} {shlex.quote(inner)}"
|
||||
)
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
|
||||
txt = out.decode("utf-8", errors="replace").strip()
|
||||
for line in txt.splitlines():
|
||||
name, sep, value = line.strip().partition("=")
|
||||
if sep and name in remote_system_names:
|
||||
remote_status[name] = value == "1"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for pkg in packages:
|
||||
if host and pkg.get("target") == "remote":
|
||||
pkg["installed"] = bool(remote_status.get(pkg["name"], False))
|
||||
continue
|
||||
if pkg.get("kind") == "system":
|
||||
pkg["installed"] = shutil.which(pkg["name"]) is not None
|
||||
continue
|
||||
try:
|
||||
if pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
|
||||
pkg["installed"] = True
|
||||
continue
|
||||
importlib.import_module(pkg["name"])
|
||||
pkg["installed"] = True
|
||||
except ImportError:
|
||||
pkg["installed"] = False
|
||||
return {"packages": packages}
|
||||
|
||||
@router.post("/api/cookbook/packages/install")
|
||||
async def install_package(request: Request):
|
||||
"""Install a package via pip. Admin only — pip install is effectively code exec."""
|
||||
_require_admin(request)
|
||||
import sys as _sys
|
||||
body = await request.json()
|
||||
pip_name = body.get("pip")
|
||||
if not pip_name:
|
||||
return {"ok": False, "error": "No package specified"}
|
||||
# Validate against known packages to prevent arbitrary pip install
|
||||
known = {
|
||||
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers",
|
||||
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
|
||||
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan",
|
||||
}
|
||||
if pip_name not in known:
|
||||
return {"ok": False, "error": f"Unknown package: {pip_name}"}
|
||||
cmd = [_sys.executable, "-m", "pip", "install", pip_name]
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
if proc.returncode == 0:
|
||||
return {"ok": True, "output": stdout.decode()[-200:]}
|
||||
return {"ok": False, "error": stderr.decode()[-300:]}
|
||||
|
||||
return router
|
||||
123
routes/signature_routes.py
Normal file
123
routes/signature_routes.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Signature routes — CRUD for the user's saved visual signatures.
|
||||
|
||||
Signatures are reusable image stamps (drawn once, applied to many things):
|
||||
PDF form fields, email composition, document insertion. Each signature is
|
||||
stored as a base64 PNG so it can be embedded inline anywhere without a
|
||||
separate fetch.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, Signature
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DATA_URL_RE = re.compile(
|
||||
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
class SignatureCreate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
data: str # base64 PNG, with or without `data:image/png;base64,` prefix
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
svg: Optional[str] = None
|
||||
|
||||
|
||||
def _to_dict(s: Signature) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"data_url": f"data:image/png;base64,{s.data_png}",
|
||||
"width": s.width,
|
||||
"height": s.height,
|
||||
"created_at": (s.created_at.isoformat() + "Z") if s.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
def setup_signature_routes() -> APIRouter:
|
||||
router = APIRouter(tags=["signatures"])
|
||||
|
||||
@router.get("/api/signatures")
|
||||
async def list_signatures(request: Request) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(Signature)
|
||||
if user is not None:
|
||||
# SECURITY: strict ownership — the previous OR predicate
|
||||
# returned every null-owner signature to every user.
|
||||
q = q.filter(Signature.owner == user)
|
||||
sigs = q.order_by(Signature.created_at.desc()).all()
|
||||
return {"signatures": [_to_dict(s) for s in sigs]}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/api/signatures")
|
||||
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
raw = (req.data or "").strip()
|
||||
m = _DATA_URL_RE.match(raw)
|
||||
b64 = m.group("data") if m else raw
|
||||
try:
|
||||
payload = base64.b64decode(b64, validate=True)
|
||||
if not payload:
|
||||
raise ValueError("empty payload")
|
||||
except Exception:
|
||||
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
|
||||
|
||||
sig = Signature(
|
||||
id=str(uuid.uuid4()),
|
||||
owner=user,
|
||||
name=(req.name or "Signature").strip() or "Signature",
|
||||
data_png=b64,
|
||||
width=req.width,
|
||||
height=req.height,
|
||||
svg=req.svg,
|
||||
)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db.add(sig)
|
||||
db.commit()
|
||||
db.refresh(sig)
|
||||
return _to_dict(sig)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Failed to save signature: {e}")
|
||||
raise HTTPException(500, f"Failed to save signature: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/api/signatures/{sig_id}")
|
||||
async def delete_signature(sig_id: str, request: Request) -> Dict[str, Any]:
|
||||
user = get_current_user(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
sig = db.query(Signature).filter(Signature.id == sig_id).first()
|
||||
if not sig:
|
||||
raise HTTPException(404, "Signature not found")
|
||||
if user and sig.owner != user:
|
||||
raise HTTPException(403, "Not your signature")
|
||||
db.delete(sig)
|
||||
db.commit()
|
||||
return {"deleted": sig_id}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(500, f"Failed to delete signature: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return router
|
||||
1530
routes/skills_routes.py
Normal file
1530
routes/skills_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
55
routes/stt_routes.py
Normal file
55
routes/stt_routes.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# routes/stt_routes.py
|
||||
"""STT API routes — multi-provider (local Whisper, API endpoint, browser)."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_stt_routes(stt_service):
|
||||
"""Setup STT routes with the provided STT service"""
|
||||
router = APIRouter(prefix="/api/stt", tags=["stt"])
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stt_stats():
|
||||
"""Get STT service statistics"""
|
||||
try:
|
||||
return stt_service.get_stats()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get STT stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/transcribe")
|
||||
async def transcribe_audio(file: UploadFile = File(...)):
|
||||
"""Transcribe uploaded audio file to text"""
|
||||
try:
|
||||
if not stt_service.available:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={"message": "STT service not available or set to browser mode"}
|
||||
)
|
||||
|
||||
audio_bytes = await file.read()
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=400, detail={"message": "Empty audio file"})
|
||||
|
||||
text = stt_service.transcribe(audio_bytes)
|
||||
if text is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Transcription failed"}
|
||||
)
|
||||
|
||||
return {"text": text}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription error: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": f"Transcription failed: {str(e)}"}
|
||||
)
|
||||
|
||||
return router
|
||||
910
routes/task_routes.py
Normal file
910
routes/task_routes.py
Normal file
@@ -0,0 +1,910 @@
|
||||
"""CRUD routes for scheduled tasks."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.database import SessionLocal, ScheduledTask, TaskRun
|
||||
from src.auth_helpers import get_current_user
|
||||
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
|
||||
from routes.prefs_routes import _load_for_user, _save_for_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskCreate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
task_type: str = "llm" # "llm" | "action" | "research"
|
||||
action: Optional[str] = None # builtin action name
|
||||
schedule: Optional[str] = None # "once" | "daily" | "weekly" | "monthly" | "cron"
|
||||
scheduled_time: str = "09:00" # HH:MM
|
||||
scheduled_day: Optional[int] = None # day-of-week (0=Mon) or day-of-month
|
||||
scheduled_date: Optional[str] = None # ISO datetime for "once"
|
||||
cron_expression: Optional[str] = None # cron string e.g. "*/5 * * * *"
|
||||
trigger_type: str = "schedule" # "schedule" | "event" | "webhook"
|
||||
trigger_event: Optional[str] = None # e.g. "session_created"
|
||||
trigger_count: Optional[int] = None # fire every N events
|
||||
output_target: str = "session"
|
||||
model: Optional[str] = None
|
||||
endpoint_url: Optional[str] = None
|
||||
then_task_id: Optional[str] = None # chain: run this task after success
|
||||
notifications_enabled: Optional[bool] = None # None lets action-specific defaults apply
|
||||
|
||||
|
||||
class TaskUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
prompt: Optional[str] = None
|
||||
task_type: Optional[str] = None
|
||||
action: Optional[str] = None
|
||||
schedule: Optional[str] = None
|
||||
scheduled_time: Optional[str] = None
|
||||
scheduled_day: Optional[int] = None
|
||||
scheduled_date: Optional[str] = None
|
||||
cron_expression: Optional[str] = None
|
||||
trigger_type: Optional[str] = None
|
||||
trigger_event: Optional[str] = None
|
||||
trigger_count: Optional[int] = None
|
||||
output_target: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
endpoint_url: Optional[str] = None
|
||||
then_task_id: Optional[str] = None
|
||||
notifications_enabled: Optional[bool] = None
|
||||
|
||||
|
||||
def _display_task_name(t: ScheduledTask) -> str:
|
||||
defs = HOUSEKEEPING_DEFAULTS.get(t.action) if t.action else None
|
||||
if defs and (t.name or "") in set(defs.get("legacy_names") or []):
|
||||
return defs["name"]
|
||||
return t.name
|
||||
|
||||
|
||||
def _task_to_dict(t: ScheduledTask, include_last_run_result: bool = False) -> dict:
|
||||
defs = HOUSEKEEPING_DEFAULTS.get(t.action) if t.action else None
|
||||
d = {
|
||||
"id": t.id,
|
||||
"name": _display_task_name(t),
|
||||
"prompt": t.prompt,
|
||||
"task_type": t.task_type or "llm",
|
||||
"action": t.action,
|
||||
"schedule": t.schedule,
|
||||
"scheduled_time": t.scheduled_time,
|
||||
"scheduled_day": t.scheduled_day,
|
||||
"scheduled_date": t.scheduled_date.isoformat() + "Z" if t.scheduled_date else None,
|
||||
"cron_expression": t.cron_expression,
|
||||
"trigger_type": t.trigger_type or "schedule",
|
||||
"trigger_event": t.trigger_event,
|
||||
"trigger_count": t.trigger_count,
|
||||
"trigger_counter": t.trigger_counter or 0,
|
||||
"next_run": t.next_run.isoformat() + "Z" if t.next_run else None,
|
||||
"last_run": t.last_run.isoformat() + "Z" if t.last_run else None,
|
||||
"status": t.status,
|
||||
"output_target": t.output_target,
|
||||
"session_id": t.session_id,
|
||||
"crew_member_id": getattr(t, "crew_member_id", None),
|
||||
"model": t.model,
|
||||
"endpoint_url": t.endpoint_url,
|
||||
"run_count": t.run_count or 0,
|
||||
"then_task_id": t.then_task_id,
|
||||
"notifications_enabled": bool(getattr(t, "notifications_enabled", True)),
|
||||
"webhook_token": t.webhook_token if (t.trigger_type or "schedule") == "webhook" else None,
|
||||
"created_at": t.created_at.isoformat() + "Z" if t.created_at else None,
|
||||
"updated_at": t.updated_at.isoformat() + "Z" if t.updated_at else None,
|
||||
}
|
||||
# Built-in housekeeping tasks (identified by their action) are flagged so
|
||||
# the UI can mark them and offer "revert to default" once altered.
|
||||
d["is_builtin"] = defs is not None
|
||||
if defs:
|
||||
default_names = {defs["name"], *set(defs.get("legacy_names") or [])}
|
||||
d["is_modified"] = (
|
||||
(t.name or "") not in default_names
|
||||
or (t.schedule or "") != (defs["schedule"] or "")
|
||||
or (t.scheduled_time or "") != (defs["scheduled_time"] or "")
|
||||
or (t.cron_expression or "") != (defs["cron_expression"] or "")
|
||||
)
|
||||
else:
|
||||
d["is_modified"] = False
|
||||
if include_last_run_result and t.runs:
|
||||
last = t.runs[0] # ordered desc by started_at
|
||||
d["last_run_status"] = last.status
|
||||
d["last_run_result"] = (last.result or last.error or "")[:500]
|
||||
return d
|
||||
|
||||
|
||||
def _run_to_dict(r: TaskRun) -> dict:
|
||||
return {
|
||||
"id": r.id,
|
||||
"task_id": r.task_id,
|
||||
"started_at": r.started_at.isoformat() + "Z" if r.started_at else None,
|
||||
"finished_at": r.finished_at.isoformat() + "Z" if r.finished_at else None,
|
||||
"status": r.status,
|
||||
"result": r.result,
|
||||
"error": r.error,
|
||||
"tokens_used": r.tokens_used,
|
||||
"model": r.model,
|
||||
}
|
||||
|
||||
|
||||
def _run_research_id(task: ScheduledTask) -> str:
|
||||
if (task.task_type or "llm") == "research" and task.session_id:
|
||||
return task.session_id
|
||||
return ""
|
||||
|
||||
|
||||
def _resolve_run_endpoint(db, task: ScheduledTask, run: TaskRun) -> str:
|
||||
"""Best-effort endpoint URL for reopening a task run in chat."""
|
||||
if getattr(task, "endpoint_url", None):
|
||||
return task.endpoint_url or ""
|
||||
|
||||
try:
|
||||
if getattr(task, "session_id", None):
|
||||
from core.database import Session as DbSession
|
||||
sess = db.query(DbSession).filter(DbSession.id == task.session_id).first()
|
||||
if sess and sess.endpoint_url:
|
||||
return sess.endpoint_url or ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model = (getattr(run, "model", None) or getattr(task, "model", None) or "").strip()
|
||||
if not model:
|
||||
return ""
|
||||
|
||||
try:
|
||||
from core.database import ModelEndpoint
|
||||
eps = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
||||
for ep in eps:
|
||||
cached = []
|
||||
if ep.cached_models:
|
||||
try:
|
||||
cached = json.loads(ep.cached_models) or []
|
||||
except Exception:
|
||||
cached = []
|
||||
if model in cached:
|
||||
return ep.base_url or ""
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def setup_task_routes(task_scheduler) -> APIRouter:
|
||||
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
|
||||
|
||||
def _owner(request: Request):
|
||||
return get_current_user(request)
|
||||
|
||||
async def _generate_task_name(prompt: str) -> str:
|
||||
"""Use LLM to generate a short task name from the prompt."""
|
||||
try:
|
||||
from src.llm_core import llm_call_async
|
||||
from core.database import Session as DbSession
|
||||
db = SessionLocal()
|
||||
try:
|
||||
recent = db.query(DbSession).filter(
|
||||
DbSession.endpoint_url.isnot(None),
|
||||
DbSession.model.isnot(None),
|
||||
).order_by(DbSession.created_at.desc()).first()
|
||||
if not recent:
|
||||
return prompt[:50].strip()
|
||||
url, model = recent.endpoint_url, recent.model
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
result = await llm_call_async(
|
||||
url=url, model=model,
|
||||
messages=[
|
||||
{"role": "system", "content": "Generate a short title (3-5 words, no quotes) for this scheduled task. Reply with ONLY the title, nothing else."},
|
||||
{"role": "user", "content": prompt[:500]},
|
||||
],
|
||||
max_tokens=20,
|
||||
timeout=15,
|
||||
)
|
||||
title = result.strip().strip('"\'').strip()
|
||||
return title[:60] if title else prompt[:50].strip()
|
||||
except Exception:
|
||||
first = prompt.split('\n')[0].split('.')[0].strip()
|
||||
return first[:50] if first else "Untitled Task"
|
||||
|
||||
@router.get("")
|
||||
async def list_tasks(request: Request, status: Optional[str] = None,
|
||||
include_last_run: bool = False):
|
||||
user = _owner(request)
|
||||
if user:
|
||||
await task_scheduler.ensure_defaults(user)
|
||||
else:
|
||||
db_seed = SessionLocal()
|
||||
try:
|
||||
owners = {
|
||||
row[0] for row in db_seed.query(ScheduledTask.owner)
|
||||
.filter(ScheduledTask.task_type == "action")
|
||||
.filter(ScheduledTask.action.in_(list(HOUSEKEEPING_DEFAULTS.keys())))
|
||||
.all()
|
||||
if row[0]
|
||||
}
|
||||
finally:
|
||||
db_seed.close()
|
||||
for owner in owners:
|
||||
await task_scheduler.ensure_defaults(owner)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(ScheduledTask)
|
||||
if user:
|
||||
q = q.filter(ScheduledTask.owner == user)
|
||||
if status:
|
||||
q = q.filter(ScheduledTask.status == status)
|
||||
tasks = q.order_by(ScheduledTask.created_at.desc()).all()
|
||||
return {"tasks": [_task_to_dict(t, include_last_run_result=include_last_run) for t in tasks]}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/onboarding")
|
||||
async def get_tasks_onboarding(request: Request):
|
||||
user = _owner(request)
|
||||
prefs = _load_for_user(user) or {}
|
||||
return {
|
||||
"opened": bool(prefs.get("tasks_opened")),
|
||||
"enabled": bool(prefs.get("tasks_enabled")),
|
||||
}
|
||||
|
||||
@router.post("/onboarding")
|
||||
async def update_tasks_onboarding(request: Request, body: dict):
|
||||
user = _owner(request)
|
||||
prefs = _load_for_user(user) or {}
|
||||
prefs["tasks_opened"] = True
|
||||
enable = bool(body.get("enabled"))
|
||||
if enable:
|
||||
prefs["tasks_enabled"] = True
|
||||
_save_for_user(user, prefs)
|
||||
if user:
|
||||
await task_scheduler.ensure_defaults(user)
|
||||
|
||||
resumed = 0
|
||||
if enable:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
tasks = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.owner == user,
|
||||
ScheduledTask.task_type == "action",
|
||||
ScheduledTask.action.in_(list(HOUSEKEEPING_DEFAULTS.keys())),
|
||||
).all()
|
||||
for task in tasks:
|
||||
defs = HOUSEKEEPING_DEFAULTS.get(task.action or "")
|
||||
if defs and defs.get("ship_paused"):
|
||||
continue
|
||||
if task.status == "active":
|
||||
continue
|
||||
task.status = "active"
|
||||
if (task.trigger_type or "schedule") == "schedule":
|
||||
task.next_run = compute_next_run(
|
||||
task.schedule,
|
||||
task.scheduled_time,
|
||||
task.scheduled_day,
|
||||
task.scheduled_date,
|
||||
cron_expression=task.cron_expression,
|
||||
)
|
||||
resumed += 1
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
return {"ok": True, "opened": True, "enabled": bool(prefs.get("tasks_enabled")), "resumed": resumed}
|
||||
|
||||
# Actions that execute shell/SSH commands — restricted to admins.
|
||||
# Non-admin users cannot create tasks with these action types via the
|
||||
# API. See review CRIT-C.
|
||||
_ADMIN_ONLY_ACTIONS = {"run_local", "run_script", "ssh_command"}
|
||||
|
||||
def _is_admin(user: str | None) -> bool:
|
||||
if not user:
|
||||
return False
|
||||
# In-process tool-loopback marker — AuthMiddleware validated
|
||||
# the internal token + loopback client before stamping this,
|
||||
# so treat as admin-equivalent.
|
||||
if user == "internal-tool":
|
||||
return True
|
||||
try:
|
||||
from core.auth import AuthManager
|
||||
auth = AuthManager()
|
||||
if not auth.is_configured:
|
||||
# Unconfigured single-user deploy: trust the local owner.
|
||||
return True
|
||||
return bool(auth.is_admin(user))
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@router.post("")
|
||||
async def create_task(request: Request, req: TaskCreate):
|
||||
user = _owner(request)
|
||||
|
||||
# Validate
|
||||
if req.task_type in ("llm", "research") and not req.prompt:
|
||||
raise HTTPException(400, "Prompt is required for LLM/research tasks")
|
||||
if req.task_type == "action" and not req.action:
|
||||
raise HTTPException(400, "Action name is required for action tasks")
|
||||
# Block shell-executing action types for non-admins. action_run_local
|
||||
# uses subprocess.run(shell=True) and ssh_command / run_script run
|
||||
# arbitrary commands.
|
||||
if req.task_type == "action" and req.action in _ADMIN_ONLY_ACTIONS and not _is_admin(user):
|
||||
raise HTTPException(403, f"Action '{req.action}' requires admin privileges")
|
||||
if req.trigger_type == "schedule" and not req.schedule:
|
||||
raise HTTPException(400, "Schedule is required for schedule-triggered tasks")
|
||||
if req.trigger_type == "schedule" and req.schedule == "cron" and not req.cron_expression:
|
||||
raise HTTPException(400, "Cron expression is required for cron schedule")
|
||||
if req.trigger_type == "schedule" and req.schedule == "cron" and req.cron_expression:
|
||||
try:
|
||||
from croniter import croniter
|
||||
croniter(req.cron_expression)
|
||||
except Exception:
|
||||
raise HTTPException(400, "Invalid cron expression")
|
||||
if req.trigger_type == "event" and not req.trigger_event:
|
||||
raise HTTPException(400, "Event name is required for event-triggered tasks")
|
||||
if req.trigger_type == "event" and not req.trigger_count:
|
||||
raise HTTPException(400, "Trigger count is required for event-triggered tasks")
|
||||
|
||||
# Auto-generate name
|
||||
name = req.name
|
||||
if not name:
|
||||
if req.task_type == "action":
|
||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
|
||||
elif req.prompt:
|
||||
name = await _generate_task_name(req.prompt)
|
||||
else:
|
||||
name = "Untitled Task"
|
||||
|
||||
# Compute next_run for schedule-triggered tasks
|
||||
next_run = None
|
||||
sched_date = None
|
||||
if req.trigger_type == "schedule":
|
||||
if req.schedule == "once" and req.scheduled_date:
|
||||
try:
|
||||
sched_date = datetime.fromisoformat(req.scheduled_date.replace("Z", "+00:00")).replace(tzinfo=None)
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid scheduled_date format")
|
||||
next_run = compute_next_run(
|
||||
req.schedule, req.scheduled_time,
|
||||
req.scheduled_day, sched_date,
|
||||
cron_expression=req.cron_expression,
|
||||
)
|
||||
|
||||
# Generate webhook token if needed
|
||||
webhook_token = None
|
||||
if req.trigger_type == "webhook":
|
||||
webhook_token = secrets.token_urlsafe(32)
|
||||
|
||||
task_id = str(uuid.uuid4())
|
||||
db = SessionLocal()
|
||||
try:
|
||||
notifications_enabled = (
|
||||
False if req.task_type == "action" and req.notifications_enabled is None
|
||||
else bool(req.notifications_enabled) if req.notifications_enabled is not None
|
||||
else True
|
||||
)
|
||||
task = ScheduledTask(
|
||||
id=task_id,
|
||||
owner=user,
|
||||
name=name,
|
||||
prompt=req.prompt,
|
||||
task_type=req.task_type,
|
||||
action=req.action,
|
||||
schedule=req.schedule,
|
||||
scheduled_time=req.scheduled_time,
|
||||
scheduled_day=req.scheduled_day,
|
||||
scheduled_date=sched_date,
|
||||
cron_expression=req.cron_expression,
|
||||
trigger_type=req.trigger_type,
|
||||
trigger_event=req.trigger_event,
|
||||
trigger_count=req.trigger_count,
|
||||
trigger_counter=0,
|
||||
next_run=next_run,
|
||||
status="active" if (req.trigger_type in ("event", "webhook") or next_run) else "completed",
|
||||
output_target=req.output_target,
|
||||
model=req.model or None,
|
||||
endpoint_url=req.endpoint_url or None,
|
||||
then_task_id=req.then_task_id or None,
|
||||
webhook_token=webhook_token,
|
||||
notifications_enabled=notifications_enabled,
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return _task_to_dict(task)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/notifications")
|
||||
async def get_notifications(request: Request):
|
||||
"""Return and clear pending task-run notifications for the
|
||||
current user. Anonymous callers get nothing (prevents
|
||||
cross-tenant drain — see review CRIT-B)."""
|
||||
user = _owner(request)
|
||||
if not user:
|
||||
return {"notifications": []}
|
||||
notes = task_scheduler.pop_notifications(owner=user)
|
||||
return {"notifications": notes}
|
||||
|
||||
@router.get("/{task_id}")
|
||||
async def get_task(request: Request, task_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
return _task_to_dict(task)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.put("/{task_id}")
|
||||
async def update_task(request: Request, task_id: str, req: TaskUpdate):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
|
||||
if req.name is not None:
|
||||
task.name = req.name
|
||||
if req.prompt is not None:
|
||||
task.prompt = req.prompt
|
||||
if req.task_type is not None:
|
||||
task.task_type = req.task_type
|
||||
if req.action is not None:
|
||||
# Same admin-only gate as create — see CRIT-C.
|
||||
if req.action in _ADMIN_ONLY_ACTIONS and not _is_admin(user):
|
||||
raise HTTPException(403, f"Action '{req.action}' requires admin privileges")
|
||||
task.action = req.action
|
||||
if req.output_target is not None:
|
||||
task.output_target = req.output_target
|
||||
if req.model is not None:
|
||||
task.model = req.model or None
|
||||
if req.endpoint_url is not None:
|
||||
task.endpoint_url = req.endpoint_url or None
|
||||
if req.trigger_type is not None:
|
||||
# Generate webhook token when switching to webhook trigger
|
||||
if req.trigger_type == "webhook" and not task.webhook_token:
|
||||
task.webhook_token = secrets.token_urlsafe(32)
|
||||
task.trigger_type = req.trigger_type
|
||||
if req.trigger_event is not None:
|
||||
task.trigger_event = req.trigger_event
|
||||
if req.trigger_count is not None:
|
||||
task.trigger_count = req.trigger_count
|
||||
if req.then_task_id is not None:
|
||||
task.then_task_id = req.then_task_id or None
|
||||
if req.notifications_enabled is not None:
|
||||
task.notifications_enabled = bool(req.notifications_enabled)
|
||||
if req.cron_expression is not None:
|
||||
if req.cron_expression:
|
||||
try:
|
||||
from croniter import croniter
|
||||
croniter(req.cron_expression)
|
||||
except Exception:
|
||||
raise HTTPException(400, "Invalid cron expression")
|
||||
task.cron_expression = req.cron_expression or None
|
||||
|
||||
# Recompute next_run if schedule changed
|
||||
schedule_changed = False
|
||||
if req.schedule is not None:
|
||||
task.schedule = req.schedule
|
||||
schedule_changed = True
|
||||
if req.scheduled_time is not None:
|
||||
task.scheduled_time = req.scheduled_time
|
||||
schedule_changed = True
|
||||
if req.scheduled_day is not None:
|
||||
task.scheduled_day = req.scheduled_day
|
||||
schedule_changed = True
|
||||
if req.scheduled_date is not None:
|
||||
try:
|
||||
task.scheduled_date = datetime.fromisoformat(
|
||||
req.scheduled_date.replace("Z", "+00:00")
|
||||
).replace(tzinfo=None)
|
||||
except ValueError:
|
||||
raise HTTPException(400, "Invalid scheduled_date format")
|
||||
schedule_changed = True
|
||||
|
||||
if req.cron_expression is not None:
|
||||
schedule_changed = True
|
||||
|
||||
if schedule_changed and task.status == "active" and (task.trigger_type or "schedule") == "schedule":
|
||||
task.next_run = compute_next_run(
|
||||
task.schedule, task.scheduled_time,
|
||||
task.scheduled_day, task.scheduled_date,
|
||||
cron_expression=task.cron_expression,
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return _task_to_dict(task)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/{task_id}")
|
||||
async def delete_task(request: Request, task_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
db.delete(task)
|
||||
db.commit()
|
||||
return {"ok": True}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/{task_id}/pause")
|
||||
async def pause_task(request: Request, task_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
task.status = "paused"
|
||||
db.commit()
|
||||
return {"ok": True, "status": "paused"}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/{task_id}/resume")
|
||||
async def resume_task(request: Request, task_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
task.status = "active"
|
||||
if (task.trigger_type or "schedule") == "schedule":
|
||||
task.next_run = compute_next_run(
|
||||
task.schedule, task.scheduled_time,
|
||||
task.scheduled_day, task.scheduled_date,
|
||||
cron_expression=task.cron_expression,
|
||||
)
|
||||
db.commit()
|
||||
return {"ok": True, "status": "active", "next_run": task.next_run.isoformat() + "Z" if task.next_run else None}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/{task_id}/revert")
|
||||
async def revert_task(request: Request, task_id: str):
|
||||
"""Reset a built-in (housekeeping) task to its default config."""
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
defs = HOUSEKEEPING_DEFAULTS.get(task.action) if task.action else None
|
||||
if not defs:
|
||||
raise HTTPException(400, "Not a built-in task")
|
||||
task.name = defs["name"]
|
||||
task.schedule = defs["schedule"]
|
||||
task.scheduled_time = defs["scheduled_time"]
|
||||
task.scheduled_day = None
|
||||
task.scheduled_date = None
|
||||
task.cron_expression = defs["cron_expression"]
|
||||
task.trigger_type = defs.get("trigger_type", "schedule")
|
||||
task.trigger_event = defs.get("trigger_event")
|
||||
task.trigger_count = defs.get("trigger_count")
|
||||
task.trigger_counter = 0
|
||||
task.prompt = None
|
||||
task.model = None
|
||||
task.endpoint_url = None
|
||||
task.status = "paused" if defs.get("ship_paused") else "active"
|
||||
task.next_run = None
|
||||
if task.trigger_type == "schedule":
|
||||
task.next_run = compute_next_run(
|
||||
defs["schedule"], defs["scheduled_time"], None, None,
|
||||
cron_expression=defs["cron_expression"],
|
||||
)
|
||||
db.commit()
|
||||
db.refresh(task)
|
||||
return {"ok": True, "task": _task_to_dict(task)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/{task_id}/run")
|
||||
async def run_task_now(request: Request, task_id: str, force: bool = False):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
finally:
|
||||
db.close()
|
||||
started = await task_scheduler.run_task_now(task_id, force=force)
|
||||
if not started:
|
||||
raise HTTPException(409, "Task is already running")
|
||||
return {"ok": True, "message": "Task triggered" + (" in parallel" if force else "")}
|
||||
|
||||
@router.get("/runs/recent")
|
||||
async def list_recent_runs(request: Request, limit: int = 50):
|
||||
"""Recent task runs across ALL tasks for this owner. Drives the Activity view."""
|
||||
user = _owner(request)
|
||||
limit = max(1, min(limit, 200))
|
||||
db = SessionLocal()
|
||||
try:
|
||||
q = db.query(TaskRun, ScheduledTask).join(
|
||||
ScheduledTask, TaskRun.task_id == ScheduledTask.id
|
||||
)
|
||||
if user:
|
||||
# Strict owner scope — was previously OR'ing in `owner IS NULL`
|
||||
# rows for "legacy single-user" back-compat, but that leaks any
|
||||
# legacy/migrated task's full result text to every authenticated
|
||||
# user. _migrate_assign_legacy_owner runs on startup to claim
|
||||
# legacy rows for the admin, so the OR-NULL path is no longer
|
||||
# needed for any sane deploy.
|
||||
q = q.filter(ScheduledTask.owner == user)
|
||||
# Pull a little extra before de-duping. When auth is bypassed on a
|
||||
# local browser session, legacy/default tasks from multiple owners
|
||||
# can be visible together; the built-in urgent-email scanner then
|
||||
# produces several identical "no email accounts configured" rows in
|
||||
# the same minute. Keep the task records intact, but collapse those
|
||||
# duplicate Activity rows for display.
|
||||
rows = q.order_by(TaskRun.started_at.desc()).limit(limit * 3).all()
|
||||
deduped = []
|
||||
seen_urgency_rows = set()
|
||||
for r, t in rows:
|
||||
if (t.action or "") == "check_email_urgency":
|
||||
ts = r.started_at.replace(second=0, microsecond=0) if r.started_at else None
|
||||
text = (r.result or r.error or "").strip()
|
||||
key = (ts, r.status or "", text)
|
||||
if key in seen_urgency_rows:
|
||||
continue
|
||||
seen_urgency_rows.add(key)
|
||||
deduped.append((r, t))
|
||||
if len(deduped) >= limit:
|
||||
break
|
||||
return {
|
||||
"runs": [
|
||||
{
|
||||
**_run_to_dict(r),
|
||||
"task_name": _display_task_name(t),
|
||||
"task_type": t.task_type or "llm",
|
||||
"action": t.action,
|
||||
# Model + endpoint the task ran on, so the Activity
|
||||
# view's "Open in chat" can reuse the same model.
|
||||
"model": r.model or t.model or "",
|
||||
"endpoint_url": _resolve_run_endpoint(db, t, r),
|
||||
"session_id": t.session_id or "",
|
||||
"research_id": _run_research_id(t),
|
||||
# Where the task delivered its result — the Activity tab
|
||||
# uses this to filter notification rows in/out.
|
||||
"output_target": t.output_target or "session",
|
||||
}
|
||||
for r, t in deduped
|
||||
]
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/{task_id}/runs")
|
||||
async def list_runs(request: Request, task_id: str, limit: int = 20, offset: int = 0):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
runs = db.query(TaskRun).filter(TaskRun.task_id == task_id)\
|
||||
.order_by(TaskRun.started_at.desc())\
|
||||
.offset(offset).limit(limit).all()
|
||||
total = db.query(TaskRun).filter(TaskRun.task_id == task_id).count()
|
||||
return {"runs": [_run_to_dict(r) for r in runs], "total": total}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.get("/meta/output-targets")
|
||||
async def list_output_targets(request: Request):
|
||||
"""List available output targets — only delivery/send tools, not all MCP tools."""
|
||||
_owner(request)
|
||||
targets = [
|
||||
{"value": "session", "label": "Session", "description": "Save result to a chat session"},
|
||||
{"value": "notification", "label": "Notification", "description": "Push a browser notification with the result (also saved to the session for history)"},
|
||||
{"value": "email", "label": "Email me", "description": "Send result through your configured SMTP account"},
|
||||
]
|
||||
# Only include tools whose NAME clearly indicates an outbound delivery
|
||||
# action — match by verb in the tool name, not by any mention of "email"
|
||||
# in the description (which falsely picked up search_email, list_email,
|
||||
# etc.). Also exclude read/search/list tools whose names happen to start
|
||||
# with a delivery verb.
|
||||
_DELIVERY_VERBS = ("send", "notify", "post", "publish", "draft", "dispatch", "deliver")
|
||||
_NON_DELIVERY = (
|
||||
"search", "list", "get", "find", "read", "fetch", "view",
|
||||
"tag", "label", "move", "archive", "delete", "mark", "schedule",
|
||||
)
|
||||
try:
|
||||
from src.agent_tools import get_mcp_manager
|
||||
mcp = get_mcp_manager()
|
||||
if mcp:
|
||||
for tool in mcp.get_all_tools():
|
||||
name_lower = tool.get("name", "").lower()
|
||||
if any(x in name_lower for x in _NON_DELIVERY):
|
||||
continue
|
||||
if not any(v in name_lower for v in _DELIVERY_VERBS):
|
||||
continue
|
||||
targets.append({
|
||||
"value": tool["qualified_name"],
|
||||
"label": f"{tool['server_name']} → {tool['name']}",
|
||||
"description": tool.get("description", ""),
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
return {"targets": targets}
|
||||
|
||||
@router.get("/meta/actions")
|
||||
async def list_actions(request: Request):
|
||||
"""List available built-in actions."""
|
||||
user = _owner(request)
|
||||
from src.builtin_actions import BUILTIN_ACTION_INFO
|
||||
return {"actions": [
|
||||
{"name": name, "description": desc}
|
||||
for name, desc in BUILTIN_ACTION_INFO.items()
|
||||
if name not in _ADMIN_ONLY_ACTIONS or _is_admin(user)
|
||||
]}
|
||||
|
||||
@router.get("/meta/events")
|
||||
async def list_events(request: Request):
|
||||
"""List available event triggers."""
|
||||
_owner(request)
|
||||
return {"events": [
|
||||
{"name": "session_created", "description": "Fires when a new chat session is created"},
|
||||
{"name": "message_sent", "description": "Fires when a user sends a message"},
|
||||
{"name": "document_created", "description": "Fires when a document is created"},
|
||||
{"name": "memory_added", "description": "Fires when a memory is added"},
|
||||
{"name": "research_completed", "description": "Fires when a research report completes"},
|
||||
{"name": "email_received", "description": "Fires when new inbox mail is observed"},
|
||||
{"name": "skill_added", "description": "Fires when a new skill is created"},
|
||||
]}
|
||||
|
||||
@router.post("/{task_id}/webhook/{token}")
|
||||
async def webhook_trigger(task_id: str, token: str):
|
||||
"""Unauthenticated endpoint — the token IS the auth."""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(
|
||||
ScheduledTask.id == task_id,
|
||||
ScheduledTask.webhook_token == token,
|
||||
ScheduledTask.status == "active",
|
||||
).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Not found")
|
||||
finally:
|
||||
db.close()
|
||||
started = await task_scheduler.run_task_now(task_id)
|
||||
if not started:
|
||||
raise HTTPException(409, "Task is already running")
|
||||
return {"ok": True, "message": "Task triggered via webhook"}
|
||||
|
||||
@router.post("/{task_id}/webhook-regenerate")
|
||||
async def regenerate_webhook(request: Request, task_id: str):
|
||||
user = _owner(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
|
||||
if not task:
|
||||
raise HTTPException(404, "Task not found")
|
||||
if user and task.owner != user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
task.webhook_token = secrets.token_urlsafe(32)
|
||||
db.commit()
|
||||
return {"ok": True, "webhook_token": task.webhook_token}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# --- PARSE NATURAL LANGUAGE → TASK DRAFT (AI) ---
|
||||
@router.post("/parse")
|
||||
async def parse_task(request: Request) -> Dict[str, Any]:
|
||||
"""Turn a free-form description ("every weekday at 7am research the top
|
||||
AI news and summarize it") into a structured task draft the frontend
|
||||
can pre-fill the form with. Returns a draft only — the user reviews and
|
||||
saves it, so a misread schedule never goes live unreviewed."""
|
||||
from src.endpoint_resolver import resolve_endpoint
|
||||
from src.llm_core import llm_call_async
|
||||
from src.text_helpers import strip_think as _strip_think
|
||||
import json as _json, re as _re
|
||||
from datetime import datetime as _dt
|
||||
|
||||
body = await request.json()
|
||||
desc = (body.get("description") or "").strip()
|
||||
if not desc:
|
||||
return {"success": False, "message": "Nothing to parse"}
|
||||
|
||||
now = _dt.now()
|
||||
# Give the model the current date/time + weekday so relative phrasing
|
||||
# ("tomorrow", "every Monday", "in an hour") resolves correctly.
|
||||
ctx = now.strftime("%Y-%m-%d %H:%M (%A)")
|
||||
sys = (
|
||||
"You convert a user's description of a recurring or one-off task into "
|
||||
"STRICT JSON for a task scheduler. The current local date/time is "
|
||||
f"{ctx}. Output ONLY a JSON object, no prose, no markdown fences.\n\n"
|
||||
"Schema (omit fields you can't infer):\n"
|
||||
"{\n"
|
||||
' "task_type": "llm" | "research", // "research" if it asks to research/investigate/find out; else "llm"\n'
|
||||
' "name": "short 3-6 word title",\n'
|
||||
' "prompt": "the instruction the AI should run on schedule (or the research question)",\n'
|
||||
' "schedule": "daily" | "weekly" | "monthly" | "once" | "cron",\n'
|
||||
' "scheduled_time": "HH:MM", // 24h LOCAL time\n'
|
||||
' "scheduled_day": 0, // weekly: 0=Mon..6=Sun; monthly: 1..31\n'
|
||||
' "scheduled_date": "YYYY-MM-DDTHH:MM", // only for "once"\n'
|
||||
' "cron_expression": "m h dom mon dow", // only if schedule is "cron"\n'
|
||||
' "output_target": "session" | "email" | "notification" // use email when the user asks to email the result\n'
|
||||
"}\n\n"
|
||||
"Rules: default schedule to 'daily' if a time is given without a frequency. "
|
||||
"Default scheduled_time to '09:00' if none is stated. For 'every weekday' "
|
||||
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
|
||||
)
|
||||
try:
|
||||
url, model, headers = resolve_endpoint("utility")
|
||||
if not url:
|
||||
url, model, headers = resolve_endpoint("default")
|
||||
if not (url and model):
|
||||
return {"success": False, "message": "No model endpoint configured"}
|
||||
raw = await llm_call_async(
|
||||
url=url, model=model,
|
||||
messages=[{"role": "system", "content": sys},
|
||||
{"role": "user", "content": desc[:1000]}],
|
||||
temperature=0.2, max_tokens=400, headers=headers, timeout=45,
|
||||
)
|
||||
text = _strip_think(raw or "", prose=False, prompt_echo=False).strip()
|
||||
if text.startswith("```"):
|
||||
text = text.strip("`")
|
||||
if text.lower().startswith("json"):
|
||||
text = text[4:].lstrip()
|
||||
# Pull the first {...} block in case the model added stray text.
|
||||
m = _re.search(r"\{.*\}", text, _re.S)
|
||||
draft = _json.loads(m.group(0) if m else text)
|
||||
if not isinstance(draft, dict):
|
||||
raise ValueError("not an object")
|
||||
# Whitelist + light validation so the frontend gets clean fields.
|
||||
out: Dict[str, Any] = {}
|
||||
if draft.get("task_type") in ("llm", "research"):
|
||||
out["task_type"] = draft["task_type"]
|
||||
else:
|
||||
out["task_type"] = "llm"
|
||||
for k in ("name", "prompt", "cron_expression", "scheduled_date"):
|
||||
if isinstance(draft.get(k), str) and draft[k].strip():
|
||||
out[k] = draft[k].strip()
|
||||
if draft.get("schedule") in ("daily", "weekly", "monthly", "once", "cron"):
|
||||
out["schedule"] = draft["schedule"]
|
||||
else:
|
||||
out["schedule"] = "daily"
|
||||
st = draft.get("scheduled_time")
|
||||
if isinstance(st, str) and _re.match(r"^\d{1,2}:\d{2}$", st.strip()):
|
||||
out["scheduled_time"] = st.strip()
|
||||
if isinstance(draft.get("scheduled_day"), int):
|
||||
out["scheduled_day"] = draft["scheduled_day"]
|
||||
if draft.get("output_target") in ("session", "email", "notification"):
|
||||
out["output_target"] = draft["output_target"]
|
||||
out["trigger_type"] = "schedule"
|
||||
if not out.get("prompt"):
|
||||
return {"success": False, "message": "Could not extract a task instruction"}
|
||||
return {"success": True, "draft": out}
|
||||
except Exception as e:
|
||||
logger.error(f"parse_task failed: {e}")
|
||||
return {"success": False, "message": str(e)}
|
||||
|
||||
return router
|
||||
87
routes/tts_routes.py
Normal file
87
routes/tts_routes.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# routes/tts_routes.py
|
||||
"""
|
||||
TTS API routes — multi-provider (local Kokoro, API endpoint, browser).
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
text: str
|
||||
format: str = "audio" # "audio" or "base64"
|
||||
|
||||
def setup_tts_routes(tts_service):
|
||||
"""Setup TTS routes with the provided TTS service"""
|
||||
router = APIRouter(prefix="/api/tts", tags=["tts"])
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_tts_stats():
|
||||
"""Get TTS service statistics"""
|
||||
try:
|
||||
return tts_service.get_stats()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get TTS stats: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.post("/synthesize")
|
||||
async def synthesize_speech(request: TTSRequest):
|
||||
"""Synthesize speech from text"""
|
||||
try:
|
||||
if not tts_service.available:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail={"message": "TTS service not available"}
|
||||
)
|
||||
|
||||
if request.format == "base64":
|
||||
audio_b64 = tts_service.synthesize_to_base64(request.text)
|
||||
if not audio_b64:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Synthesis failed"}
|
||||
)
|
||||
return {"audio": audio_b64}
|
||||
|
||||
else: # audio format
|
||||
audio_data = tts_service.synthesize(request.text)
|
||||
if not audio_data:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": "Synthesis failed"}
|
||||
)
|
||||
|
||||
# Detect format from magic bytes (MP3: ID3 tag or sync word ff e0+)
|
||||
is_mp3 = audio_data[:3] == b'ID3' or (len(audio_data) >= 2 and audio_data[0] == 0xff and (audio_data[1] & 0xe0) == 0xe0)
|
||||
mime = "audio/mpeg" if is_mp3 else "audio/wav"
|
||||
return Response(
|
||||
content=audio_data,
|
||||
media_type=mime,
|
||||
headers={
|
||||
"Content-Disposition": "inline; filename=speech.mp3" if "mpeg" in mime else "inline; filename=speech.wav"
|
||||
}
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Synthesis error: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"message": f"Synthesis failed: {str(e)}"}
|
||||
)
|
||||
|
||||
@router.post("/clear-cache")
|
||||
async def clear_tts_cache():
|
||||
"""Clear TTS cache"""
|
||||
try:
|
||||
tts_service.clear_cache()
|
||||
return {"success": True, "message": "Cache cleared"}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear cache: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return router
|
||||
251
routes/upload_routes.py
Normal file
251
routes/upload_routes.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# routes/upload_routes.py
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Request, File, UploadFile, HTTPException
|
||||
from typing import List
|
||||
import logging
|
||||
from core.middleware import require_admin
|
||||
from src.auth_helpers import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/upload", tags=["upload"])
|
||||
|
||||
def setup_upload_routes(upload_handler):
|
||||
"""Setup upload routes with the provided handler"""
|
||||
|
||||
@router.post("")
|
||||
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
|
||||
"""Upload files with enhanced security and organization."""
|
||||
if not files:
|
||||
raise HTTPException(400, "No files uploaded")
|
||||
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
out = []
|
||||
|
||||
# Limit concurrent uploads per IP
|
||||
ip_upload_count = sum(
|
||||
1 for f in files
|
||||
if client_ip in upload_handler.upload_rate_log and
|
||||
any(now > time.time() - 10 for now in upload_handler.upload_rate_log[client_ip][-len(files):])
|
||||
)
|
||||
|
||||
if ip_upload_count >= upload_handler.max_concurrent_uploads:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Maximum concurrent uploads ({upload_handler.max_concurrent_uploads}) exceeded"
|
||||
)
|
||||
|
||||
for u in files:
|
||||
try:
|
||||
meta = upload_handler.save_upload(u, client_ip, owner=get_current_user(request))
|
||||
out.append({
|
||||
"id": meta["id"],
|
||||
"name": meta["name"],
|
||||
"mime": meta["mime"],
|
||||
"size": meta["size"],
|
||||
"hash": meta["hash"],
|
||||
"uploaded_at": meta["uploaded_at"],
|
||||
"width": meta.get("width"),
|
||||
"height": meta.get("height"),
|
||||
"is_duplicate": meta.get("is_duplicate", False)
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process upload {u.filename}: {str(e)}")
|
||||
continue
|
||||
|
||||
if not out:
|
||||
raise HTTPException(500, "All file uploads failed")
|
||||
|
||||
return {"files": out}
|
||||
|
||||
@router.post("/cleanup")
|
||||
async def manual_cleanup(request: Request):
|
||||
"""Manually trigger cleanup of old uploads."""
|
||||
require_admin(request)
|
||||
cleaned_count = upload_handler.cleanup_old_uploads()
|
||||
return {"status": "success", "files_cleaned": cleaned_count}
|
||||
|
||||
@router.get("/stats")
|
||||
async def upload_stats(request: Request):
|
||||
"""Get statistics about uploaded files."""
|
||||
require_admin(request)
|
||||
try:
|
||||
return upload_handler.get_upload_stats()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get upload stats: {e}")
|
||||
raise HTTPException(500, "Failed to get upload statistics")
|
||||
|
||||
@router.get("/{file_id}")
|
||||
async def download_file(request: Request, file_id: str, thumb: int = 0):
|
||||
"""Serve an uploaded file by its ID. `?thumb=1` returns a small cached
|
||||
JPEG thumbnail for images (used by chat attachment previews) so the
|
||||
client isn't downloading the full-resolution photo just to show it tiny."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
# Search upload directories for the file
|
||||
from src.constants import UPLOAD_DIR
|
||||
import mimetypes as _mt
|
||||
path = os.path.join(UPLOAD_DIR, file_id)
|
||||
if not os.path.exists(path):
|
||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
||||
if file_id in files:
|
||||
path = os.path.join(root, file_id)
|
||||
break
|
||||
else:
|
||||
raise HTTPException(404, "File not found")
|
||||
if not upload_handler.inside_base_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
# Look up original filename and owner from uploads.json
|
||||
original_name = file_id
|
||||
info = None
|
||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db) as f:
|
||||
db = json.load(f)
|
||||
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
|
||||
if info:
|
||||
original_name = info.get("name", file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
current_user = get_current_user(request)
|
||||
file_owner = info.get("owner") if info else None
|
||||
if auth_configured:
|
||||
if not current_user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
mime = _mt.guess_type(path)[0] or "application/octet-stream"
|
||||
from fastapi.responses import FileResponse
|
||||
# Downscaled thumbnail for image previews — generated once and cached.
|
||||
if thumb and mime.startswith("image/"):
|
||||
try:
|
||||
from PIL import Image, ImageOps
|
||||
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
|
||||
os.makedirs(thumb_dir, exist_ok=True)
|
||||
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
|
||||
if (not os.path.exists(thumb_path)
|
||||
or os.path.getmtime(thumb_path) < os.path.getmtime(path)):
|
||||
im = Image.open(path)
|
||||
# iPhone / camera JPEGs encode rotation in EXIF rather than
|
||||
# the pixel data. Browsers honour that on the original via
|
||||
# image-orientation:from-image, but PIL strips EXIF when it
|
||||
# saves the JPEG thumb, leaving the pixels sideways. Bake
|
||||
# the rotation into the pixels before thumbnailing.
|
||||
im = ImageOps.exif_transpose(im)
|
||||
im.thumbnail((320, 320))
|
||||
if im.mode not in ("RGB", "L"):
|
||||
im = im.convert("RGB")
|
||||
im.save(thumb_path, "JPEG", quality=80)
|
||||
return FileResponse(thumb_path, media_type="image/jpeg")
|
||||
except Exception as e:
|
||||
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
|
||||
# Fall through to the full image.
|
||||
return FileResponse(path, media_type=mime, filename=original_name)
|
||||
|
||||
def _load_upload_info(file_id: str):
|
||||
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
|
||||
from src.constants import UPLOAD_DIR
|
||||
info = None
|
||||
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
|
||||
if os.path.exists(uploads_db):
|
||||
with open(uploads_db) as f:
|
||||
db = json.load(f)
|
||||
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
|
||||
return info
|
||||
|
||||
def _vision_cache_path(file_id: str) -> str:
|
||||
from src.constants import UPLOAD_DIR
|
||||
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
return os.path.join(cache_dir, file_id + ".txt")
|
||||
|
||||
@router.get("/{file_id}/vision")
|
||||
async def get_vision_text(request: Request, file_id: str, force: int = 0):
|
||||
"""Return the vision-model OCR/description for an uploaded image.
|
||||
Cached under UPLOAD_DIR/.vision/{file_id}.txt — first call computes,
|
||||
subsequent loads are instant. Pass force=1 to recompute."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
from src.constants import UPLOAD_DIR
|
||||
path = os.path.join(UPLOAD_DIR, file_id)
|
||||
if not os.path.exists(path):
|
||||
for root, dirs, files in os.walk(UPLOAD_DIR):
|
||||
if file_id in files:
|
||||
path = os.path.join(root, file_id)
|
||||
break
|
||||
else:
|
||||
raise HTTPException(404, "File not found")
|
||||
if not upload_handler.inside_base_dir(path):
|
||||
raise HTTPException(403, "Access denied")
|
||||
info = _load_upload_info(file_id)
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
current_user = get_current_user(request)
|
||||
file_owner = info.get("owner") if info else None
|
||||
if auth_configured:
|
||||
if not current_user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
import mimetypes as _mt
|
||||
mime = _mt.guess_type(path)[0] or ""
|
||||
if not mime.startswith("image/"):
|
||||
raise HTTPException(400, "Not an image")
|
||||
cache_path = _vision_cache_path(file_id)
|
||||
if not force and os.path.exists(cache_path):
|
||||
try:
|
||||
with open(cache_path) as f:
|
||||
return {"text": f.read(), "cached": True}
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision cache read failed for {file_id}: {e}")
|
||||
from src.document_processor import analyze_image_with_vl
|
||||
try:
|
||||
text = analyze_image_with_vl(path) or ""
|
||||
except Exception as e:
|
||||
logger.error(f"Vision analysis failed for {file_id}: {e}")
|
||||
raise HTTPException(500, f"Vision analysis failed: {e}")
|
||||
try:
|
||||
with open(cache_path, "w") as f:
|
||||
f.write(text)
|
||||
except Exception as e:
|
||||
logger.warning(f"Vision cache write failed for {file_id}: {e}")
|
||||
return {"text": text, "cached": False}
|
||||
|
||||
@router.put("/{file_id}/vision")
|
||||
async def put_vision_text(request: Request, file_id: str):
|
||||
"""Persist a user-edited vision/OCR text for an attachment. Stored in
|
||||
the same cache file so the chat send picks it up as the override."""
|
||||
if not upload_handler.validate_upload_id(file_id):
|
||||
raise HTTPException(400, "Invalid file ID")
|
||||
info = _load_upload_info(file_id)
|
||||
if not info:
|
||||
raise HTTPException(404, "File not found")
|
||||
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
|
||||
current_user = get_current_user(request)
|
||||
file_owner = info.get("owner")
|
||||
if auth_configured:
|
||||
if not current_user:
|
||||
raise HTTPException(403, "Access denied")
|
||||
if file_owner != current_user and not auth_mgr.is_admin(current_user):
|
||||
raise HTTPException(404, "File not found")
|
||||
body = await request.json()
|
||||
text = (body or {}).get("text", "")
|
||||
if not isinstance(text, str):
|
||||
raise HTTPException(400, "text must be a string")
|
||||
with open(_vision_cache_path(file_id), "w") as f:
|
||||
f.write(text)
|
||||
return {"ok": True}
|
||||
|
||||
async def periodic_rate_limit_cleanup():
|
||||
"""Background task to run cleanup every hour"""
|
||||
while True:
|
||||
await asyncio.sleep(3600)
|
||||
upload_handler.cleanup_rate_limits()
|
||||
|
||||
return router, periodic_rate_limit_cleanup
|
||||
216
routes/vault_routes.py
Normal file
216
routes/vault_routes.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
vault_routes.py
|
||||
|
||||
Vaultwarden / Bitwarden CLI integration — config and unlock endpoints.
|
||||
Stores the BW_SESSION key in data/vault.json with restrictive permissions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.middleware import require_admin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VAULT_FILE = Path("data/vault.json")
|
||||
|
||||
|
||||
def _find_bw() -> str:
|
||||
"""Locate the bw binary, checking PATH and common npm-global locations."""
|
||||
p = shutil.which("bw")
|
||||
if p:
|
||||
return p
|
||||
home = os.path.expanduser("~")
|
||||
for candidate in (
|
||||
f"{home}/.npm-global/bin/bw",
|
||||
f"{home}/.nvm/versions/node/*/bin/bw",
|
||||
"/usr/local/bin/bw",
|
||||
"/opt/homebrew/bin/bw",
|
||||
):
|
||||
if "*" in candidate:
|
||||
import glob
|
||||
for m in glob.glob(candidate):
|
||||
if os.path.isfile(m) and os.access(m, os.X_OK):
|
||||
return m
|
||||
elif os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
return candidate
|
||||
return "bw" # fall back to PATH lookup (will FileNotFoundError, handled below)
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
if VAULT_FILE.exists():
|
||||
try:
|
||||
return json.loads(VAULT_FILE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_config(cfg: dict):
|
||||
VAULT_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
VAULT_FILE.write_text(json.dumps(cfg, indent=2))
|
||||
try:
|
||||
os.chmod(str(VAULT_FILE), 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _run_bw(args: list, session: str = None, input_text: str = None) -> tuple:
|
||||
env = {}
|
||||
env.update(os.environ)
|
||||
if session:
|
||||
env["BW_SESSION"] = session
|
||||
bw_path = _find_bw()
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
bw_path, *args,
|
||||
stdin=asyncio.subprocess.PIPE if input_text else None,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return "", "bw CLI not installed (install `nodejs-bitwarden-cli` or `bitwarden-cli`)", 127
|
||||
except Exception as e:
|
||||
return "", f"Failed to launch bw: {e}", 1
|
||||
try:
|
||||
stdout, stderr = await proc.communicate(input=input_text.encode() if input_text else None)
|
||||
except Exception as e:
|
||||
return "", f"bw subprocess error: {e}", 1
|
||||
return stdout.decode(errors="replace").strip(), stderr.decode(errors="replace").strip(), proc.returncode
|
||||
|
||||
|
||||
class VaultConfig(BaseModel):
|
||||
server_url: str = ""
|
||||
email: str = ""
|
||||
|
||||
|
||||
class VaultUnlockRequest(BaseModel):
|
||||
master_password: str
|
||||
|
||||
|
||||
class VaultLoginRequest(BaseModel):
|
||||
email: str
|
||||
master_password: str
|
||||
|
||||
|
||||
def setup_vault_routes():
|
||||
router = APIRouter(prefix="/api/vault", tags=["vault"])
|
||||
|
||||
@router.get("/config")
|
||||
async def get_config(request: Request):
|
||||
"""Return vault config (no sensitive fields)."""
|
||||
require_admin(request)
|
||||
cfg = _load_config()
|
||||
return {
|
||||
"server_url": cfg.get("server_url", ""),
|
||||
"email": cfg.get("email", ""),
|
||||
"unlocked": bool(cfg.get("session")),
|
||||
"unlocked_at": cfg.get("unlocked_at", ""),
|
||||
"bw_installed": await _check_bw_installed(),
|
||||
}
|
||||
|
||||
@router.post("/config")
|
||||
async def save_config(req: VaultConfig, request: Request):
|
||||
"""Save vault URL + email. Runs 'bw config server' to point at Vaultwarden."""
|
||||
require_admin(request)
|
||||
cfg = _load_config()
|
||||
cfg["server_url"] = req.server_url.strip().rstrip("/")
|
||||
cfg["email"] = req.email.strip()
|
||||
|
||||
if cfg["server_url"]:
|
||||
_, stderr, rc = await _run_bw(["config", "server", cfg["server_url"]])
|
||||
if rc != 0:
|
||||
return {"ok": False, "error": f"bw config failed: {stderr[:300]}"}
|
||||
|
||||
_save_config(cfg)
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/login")
|
||||
async def login(req: VaultLoginRequest, request: Request):
|
||||
"""Log in to Vaultwarden (required once per account)."""
|
||||
require_admin(request)
|
||||
cfg = _load_config()
|
||||
# Update email
|
||||
cfg["email"] = req.email
|
||||
_save_config(cfg)
|
||||
|
||||
stdout, stderr, rc = await _run_bw(
|
||||
["login", req.email, "--raw"],
|
||||
input_text=req.master_password + "\n",
|
||||
)
|
||||
if rc != 0:
|
||||
# Already logged in is OK
|
||||
if "already logged in" in stderr.lower():
|
||||
return {"ok": True, "already": True}
|
||||
return {"ok": False, "error": f"Login failed: {stderr[:300]}"}
|
||||
# bw login --raw prints session key on success (when 2FA disabled)
|
||||
if stdout:
|
||||
cfg["session"] = stdout
|
||||
cfg["unlocked_at"] = datetime.utcnow().isoformat()
|
||||
_save_config(cfg)
|
||||
return {"ok": True}
|
||||
|
||||
@router.post("/unlock")
|
||||
async def unlock(req: VaultUnlockRequest, request: Request):
|
||||
"""Unlock the vault and save the session key."""
|
||||
require_admin(request)
|
||||
stdout, stderr, rc = await _run_bw(
|
||||
["unlock", req.master_password, "--raw"],
|
||||
)
|
||||
if rc != 0:
|
||||
return {"ok": False, "error": f"Unlock failed: {stderr[:300]}"}
|
||||
session = stdout.strip()
|
||||
if not session:
|
||||
return {"ok": False, "error": "bw returned empty session"}
|
||||
cfg = _load_config()
|
||||
cfg["session"] = session
|
||||
cfg["unlocked_at"] = datetime.utcnow().isoformat()
|
||||
_save_config(cfg)
|
||||
return {"ok": True, "message": "Vault unlocked"}
|
||||
|
||||
@router.post("/lock")
|
||||
async def lock(request: Request):
|
||||
"""Lock the vault (clear session from config)."""
|
||||
require_admin(request)
|
||||
cfg = _load_config()
|
||||
cfg.pop("session", None)
|
||||
cfg.pop("unlocked_at", None)
|
||||
_save_config(cfg)
|
||||
# Also tell bw to lock
|
||||
await _run_bw(["lock"])
|
||||
return {"ok": True, "message": "Vault locked"}
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(request: Request):
|
||||
"""Log out of the Bitwarden CLI completely."""
|
||||
require_admin(request)
|
||||
await _run_bw(["logout"])
|
||||
cfg = _load_config()
|
||||
cfg.pop("session", None)
|
||||
cfg.pop("email", None)
|
||||
cfg.pop("unlocked_at", None)
|
||||
_save_config(cfg)
|
||||
return {"ok": True}
|
||||
|
||||
return router
|
||||
|
||||
|
||||
async def _check_bw_installed() -> bool:
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
_find_bw(), "--version",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
await proc.communicate()
|
||||
return proc.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
322
routes/webhook_routes.py
Normal file
322
routes/webhook_routes.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Webhook, API Token, and sync chat routes."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Request, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.database import SessionLocal, Webhook
|
||||
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["webhooks"])
|
||||
|
||||
# Input limits
|
||||
MAX_NAME_LEN = 100
|
||||
MAX_URL_LEN = 2048
|
||||
MAX_SECRET_LEN = 256
|
||||
MAX_MESSAGE_LEN = 32_000
|
||||
|
||||
|
||||
from core.middleware import require_admin as _require_admin
|
||||
|
||||
|
||||
def setup_webhook_routes(
|
||||
webhook_manager: WebhookManager,
|
||||
auth_manager,
|
||||
session_manager=None,
|
||||
api_key_manager=None,
|
||||
) -> APIRouter:
|
||||
|
||||
@router.get("/webhooks")
|
||||
def list_webhooks(request: Request):
|
||||
_require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
hooks = db.query(Webhook).all()
|
||||
return [
|
||||
{
|
||||
"id": w.id,
|
||||
"name": w.name,
|
||||
"url": w.url,
|
||||
"has_secret": bool(w.secret),
|
||||
"events": w.events.split(",") if w.events else [],
|
||||
"is_active": w.is_active,
|
||||
"last_triggered_at": w.last_triggered_at.isoformat() if w.last_triggered_at else None,
|
||||
"last_status_code": w.last_status_code,
|
||||
"last_error": w.last_error,
|
||||
"created_at": w.created_at.isoformat() if w.created_at else None,
|
||||
}
|
||||
for w in hooks
|
||||
]
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.post("/webhooks")
|
||||
def create_webhook(
|
||||
request: Request,
|
||||
name: str = Form(""),
|
||||
url: str = Form(""),
|
||||
secret: str = Form(""),
|
||||
events: str = Form(""),
|
||||
):
|
||||
_require_admin(request)
|
||||
name = name.strip()[:MAX_NAME_LEN]
|
||||
if not name:
|
||||
raise HTTPException(400, "Webhook name is required")
|
||||
try:
|
||||
url = validate_webhook_url(url)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
try:
|
||||
events = validate_events(events)
|
||||
except ValueError as e:
|
||||
raise HTTPException(400, str(e))
|
||||
|
||||
secret_val = secret.strip()[:MAX_SECRET_LEN] or None
|
||||
# Encrypt the secret at rest using the same Fernet key as API keys
|
||||
encrypted_secret = None
|
||||
if secret_val and api_key_manager:
|
||||
encrypted_secret = api_key_manager.encrypt_api_key(secret_val)
|
||||
elif secret_val:
|
||||
encrypted_secret = secret_val # Fallback if no encryption available
|
||||
|
||||
webhook_id = str(uuid.uuid4())[:8]
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db.add(Webhook(
|
||||
id=webhook_id,
|
||||
name=name,
|
||||
url=url,
|
||||
secret=encrypted_secret,
|
||||
events=events,
|
||||
is_active=True,
|
||||
))
|
||||
db.commit()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return {"id": webhook_id, "name": name}
|
||||
|
||||
@router.post("/webhooks/{webhook_id}/test")
|
||||
async def test_webhook(request: Request, webhook_id: str):
|
||||
_require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
wh = db.query(Webhook).filter(Webhook.id == webhook_id).first()
|
||||
if not wh:
|
||||
raise HTTPException(404, "Webhook not found")
|
||||
url, secret = wh.url, wh.secret
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
await webhook_manager.deliver_test(webhook_id, url, secret)
|
||||
return {"status": "sent"}
|
||||
|
||||
@router.patch("/webhooks/{webhook_id}")
|
||||
def toggle_webhook(request: Request, webhook_id: str):
|
||||
_require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
wh = db.query(Webhook).filter(Webhook.id == webhook_id).first()
|
||||
if not wh:
|
||||
raise HTTPException(404, "Webhook not found")
|
||||
wh.is_active = not wh.is_active
|
||||
db.commit()
|
||||
return {"id": webhook_id, "is_active": wh.is_active}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@router.delete("/webhooks/{webhook_id}")
|
||||
def delete_webhook(request: Request, webhook_id: str):
|
||||
_require_admin(request)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
deleted = db.query(Webhook).filter(Webhook.id == webhook_id).delete()
|
||||
db.commit()
|
||||
if not deleted:
|
||||
raise HTTPException(404, "Webhook not found")
|
||||
finally:
|
||||
db.close()
|
||||
return {"status": "deleted"}
|
||||
|
||||
# ================================================================
|
||||
# Sync Chat Endpoint (for n8n / Make / Activepieces)
|
||||
# ================================================================
|
||||
|
||||
# Known provider base URLs — auto-resolved from api_key prefix or model name
|
||||
KNOWN_PROVIDERS = {
|
||||
"deepseek": "https://api.deepseek.com/v1",
|
||||
"openai": "https://api.openai.com/v1",
|
||||
"mistral": "https://api.mistral.ai/v1",
|
||||
"groq": "https://api.groq.com/openai/v1",
|
||||
"together": "https://api.together.xyz/v1",
|
||||
"openrouter": "https://openrouter.ai/api/v1",
|
||||
"fireworks": "https://api.fireworks.ai/inference/v1",
|
||||
}
|
||||
|
||||
# Model prefix → provider mapping for auto-detection
|
||||
MODEL_PROVIDER_MAP = {
|
||||
"deepseek": "deepseek",
|
||||
"gpt-": "openai",
|
||||
"o1": "openai",
|
||||
"o3": "openai",
|
||||
"o4": "openai",
|
||||
"mistral": "mistral",
|
||||
"llama": "groq",
|
||||
"mixtral": "groq",
|
||||
}
|
||||
|
||||
def _resolve_base_url(model: Optional[str], provider: Optional[str]) -> Optional[str]:
|
||||
"""Try to auto-resolve a base URL from provider name or model prefix."""
|
||||
if provider and provider.lower() in KNOWN_PROVIDERS:
|
||||
return KNOWN_PROVIDERS[provider.lower()]
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
for prefix, prov in MODEL_PROVIDER_MAP.items():
|
||||
if model_lower.startswith(prefix):
|
||||
return KNOWN_PROVIDERS[prov]
|
||||
return None
|
||||
|
||||
class SyncChatRequest(BaseModel):
|
||||
message: str = Field(..., max_length=MAX_MESSAGE_LEN)
|
||||
model: Optional[str] = Field(None, max_length=200)
|
||||
session: Optional[str] = Field(None, max_length=100)
|
||||
api_key: Optional[str] = Field(None, max_length=256)
|
||||
base_url: Optional[str] = Field(None, max_length=MAX_URL_LEN)
|
||||
provider: Optional[str] = Field(None, max_length=50)
|
||||
|
||||
@router.post("/v1/chat")
|
||||
async def sync_chat(request: Request, body: SyncChatRequest):
|
||||
if not getattr(request.state, "api_token", False):
|
||||
raise HTTPException(403, "This endpoint requires an API token")
|
||||
scopes = set(getattr(request.state, "api_token_scopes", []) or [])
|
||||
if "chat" not in scopes:
|
||||
raise HTTPException(403, "API token is not scoped for chat")
|
||||
token_owner = getattr(request.state, "api_token_owner", None)
|
||||
|
||||
from core.models import ChatMessage
|
||||
from src.llm_core import llm_call_async
|
||||
from core.database import ModelEndpoint
|
||||
|
||||
message = body.message.strip()
|
||||
if not message:
|
||||
raise HTTPException(400, "Message is required")
|
||||
|
||||
session_id = body.session
|
||||
sess = None
|
||||
|
||||
# --- Case 1: Resume an existing session ---
|
||||
if session_id and session_manager:
|
||||
try:
|
||||
sess = session_manager.get_session(session_id)
|
||||
except (KeyError, Exception):
|
||||
raise HTTPException(404, "Session not found")
|
||||
# SECURITY: verify the API-token's user owns this session — without
|
||||
# this any token holder could resume any user's chat by passing its
|
||||
# ID. The token's user is on request.state.user (set by API-token
|
||||
# middleware); fall back to require_user if not present.
|
||||
try:
|
||||
from src.auth_helpers import get_current_user as _gcu
|
||||
_tok_user = token_owner or getattr(request.state, "user", None) or _gcu(request)
|
||||
except Exception:
|
||||
_tok_user = None
|
||||
_sess_owner = getattr(sess, "owner", None)
|
||||
if _tok_user and _sess_owner and _sess_owner != _tok_user:
|
||||
raise HTTPException(404, "Session not found")
|
||||
|
||||
# --- Case 2: Direct API key + model (no pre-configured endpoint needed) ---
|
||||
if not sess and body.api_key:
|
||||
api_key = body.api_key.strip()
|
||||
model = body.model or "deepseek-chat"
|
||||
|
||||
# Resolve base_url: explicit > provider name > model prefix auto-detect
|
||||
base_url = body.base_url.strip().rstrip("/") if body.base_url else None
|
||||
if not base_url:
|
||||
base_url = _resolve_base_url(model, body.provider)
|
||||
if not base_url:
|
||||
raise HTTPException(400,
|
||||
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
|
||||
"or provider ('deepseek', 'openai', 'groq', etc.)")
|
||||
|
||||
endpoint_url = base_url + "/chat/completions"
|
||||
|
||||
if not session_manager:
|
||||
raise HTTPException(500, "Session manager not available")
|
||||
|
||||
sid = str(uuid.uuid4())
|
||||
sess = session_manager.create_session(
|
||||
session_id=sid, name="API Chat", endpoint_url=endpoint_url,
|
||||
model=model, owner=token_owner,
|
||||
)
|
||||
sess.headers = {"Authorization": f"Bearer {api_key}"}
|
||||
session_manager.save_sessions()
|
||||
session_id = sid
|
||||
|
||||
# --- Case 3: Fall back to first configured ModelEndpoint ---
|
||||
if not sess:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if not ep:
|
||||
raise HTTPException(400,
|
||||
"No session, api_key, or configured endpoints. "
|
||||
"Pass api_key + model, or configure an endpoint in Admin.")
|
||||
|
||||
endpoint_url = ep.base_url.rstrip("/") + "/chat/completions"
|
||||
model = body.model or "auto"
|
||||
api_key = ep.api_key
|
||||
|
||||
if model == "auto":
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as client:
|
||||
models_url = ep.base_url.rstrip("/") + "/models"
|
||||
hdrs = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
resp = await client.get(models_url, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
ids = [m.get("id") for m in (resp.json().get("data") or []) if m.get("id")]
|
||||
model = ids[0] if ids else "auto"
|
||||
except Exception:
|
||||
raise HTTPException(500, "Could not discover models from endpoint")
|
||||
|
||||
if not session_manager:
|
||||
raise HTTPException(500, "Session manager not available")
|
||||
|
||||
sid = str(uuid.uuid4())
|
||||
sess = session_manager.create_session(
|
||||
session_id=sid, name="API Chat", endpoint_url=endpoint_url,
|
||||
model=model, owner=token_owner,
|
||||
)
|
||||
if api_key:
|
||||
sess.headers = {"Authorization": f"Bearer {api_key}"}
|
||||
session_manager.save_sessions()
|
||||
session_id = sid
|
||||
|
||||
# --- Send message and get response ---
|
||||
sess.add_message(ChatMessage("user", message))
|
||||
|
||||
messages = [{"role": m.role, "content": m.content} for m in sess.history]
|
||||
|
||||
reply = await llm_call_async(
|
||||
sess.endpoint_url, sess.model, messages,
|
||||
headers=sess.headers, timeout=120,
|
||||
)
|
||||
sess.add_message(ChatMessage("assistant", reply))
|
||||
session_manager.save_sessions()
|
||||
|
||||
asyncio.create_task(webhook_manager.fire("chat.completed", {
|
||||
"session_id": session_id, "model": sess.model,
|
||||
"user_message": message[:2000], "response": reply[:2000],
|
||||
}))
|
||||
|
||||
return {"response": reply, "session_id": session_id, "model": sess.model}
|
||||
|
||||
return router
|
||||
Reference in New Issue
Block a user