Odysseus v1.0

This commit is contained in:
pewdiepie-archdaemon
2026-05-31 23:58:26 +09:00
commit e5c99a5eee
421 changed files with 271349 additions and 0 deletions

0
routes/__init__.py Normal file
View File

174
routes/admin_wipe_routes.py Normal file
View File

@@ -0,0 +1,174 @@
"""Admin Danger Zone — per-category wipes.
Each endpoint is admin-only and truncates exactly one domain so the
user can selectively reset memory / skills / notes / etc. without
nuking everything. The catch-all `chats` endpoint mirrors the
existing /api/sessions/all so the Danger Zone speaks one URL pattern.
URL shape: DELETE /api/admin/wipe/{kind}
Kinds: chats, memory, skills, notes, tasks, documents, gallery, calendar.
"""
import json
import logging
import os
import shutil
from fastapi import APIRouter, HTTPException, Request
from core.middleware import require_admin
from core.database import (
SessionLocal,
Session as DbSession,
ChatMessage as DbChatMessage,
Memory,
Note,
ScheduledTask,
TaskRun,
Document,
DocumentVersion,
GalleryImage,
CalendarEvent,
CalendarCal,
)
from src.constants import DATA_DIR
logger = logging.getLogger(__name__)
def _wipe_memory_files():
"""Blank memory.json + drop the per-owner tidy-state sidecar so the
next audit doesn't try to diff against gone memories."""
for name in ("memory.json", "memory_tidy_state.json"):
p = os.path.join(DATA_DIR, name)
if not os.path.exists(p):
continue
try:
if name == "memory.json":
with open(p, "w") as f:
json.dump([], f)
else:
os.remove(p)
except OSError as e:
logger.warning(f"Could not reset {name}: {e}")
def _rmtree_quiet(path: str):
"""rmtree that doesn't crash if the path doesn't exist."""
if os.path.isdir(path):
try:
shutil.rmtree(path)
except OSError as e:
logger.warning(f"Could not remove {path}: {e}")
def setup_admin_wipe_routes(session_manager):
"""The session_manager is passed in so we can also clear its
in-memory cache when wiping chats — without it the DB is empty
but the next /api/sessions returns stale entries."""
router = APIRouter(prefix="/api/admin")
@router.delete("/wipe/{kind}")
def wipe(kind: str, request: Request):
require_admin(request)
kind = (kind or "").strip().lower()
db = SessionLocal()
try:
if kind == "chats":
count = db.query(DbSession).count()
db.query(DbChatMessage).delete()
db.query(DbSession).delete()
db.commit()
try:
session_manager.sessions.clear()
except Exception:
pass
return {"status": "deleted", "kind": kind, "count": count}
if kind == "memory":
count = db.query(Memory).count()
db.query(Memory).delete()
db.commit()
_wipe_memory_files()
# Drop the vector store too so semantic search doesn't
# return ghosts. Lazy import — chromadb may not be
# initialised in every deployment.
try:
from src.memory_vector import get_memory_vector_store
mv = get_memory_vector_store()
if mv and hasattr(mv, "clear"):
mv.clear()
except Exception as e:
logger.info(f"Memory vector clear skipped: {e}")
return {"status": "deleted", "kind": kind, "count": count}
if kind == "skills":
# Skills live as SKILL.md files under data/skills/. Drop
# the entire directory; the SkillsManager re-creates the
# tree on next write.
skills_dir = os.path.join(DATA_DIR, "skills")
count = 0
if os.path.isdir(skills_dir):
# Count SKILL.md files for the response — quick walk.
for _, _, files in os.walk(skills_dir):
count += sum(1 for f in files if f == "SKILL.md")
_rmtree_quiet(skills_dir)
# Legacy fallback file
legacy = os.path.join(DATA_DIR, "skills.json")
if os.path.exists(legacy):
try:
os.remove(legacy)
except OSError:
pass
return {"status": "deleted", "kind": kind, "count": count}
if kind == "notes":
count = db.query(Note).count()
db.query(Note).delete()
db.commit()
return {"status": "deleted", "kind": kind, "count": count}
if kind == "tasks":
# TaskRun rows reference tasks via FK — clear them first.
db.query(TaskRun).delete()
count = db.query(ScheduledTask).count()
db.query(ScheduledTask).delete()
db.commit()
return {"status": "deleted", "kind": kind, "count": count}
if kind == "documents":
# DocumentVersion FKs Document — clear children first.
db.query(DocumentVersion).delete()
count = db.query(Document).count()
db.query(Document).delete()
db.commit()
return {"status": "deleted", "kind": kind, "count": count}
if kind == "gallery":
count = db.query(GalleryImage).count()
db.query(GalleryImage).delete()
db.commit()
# Also drop the upload dir so disk doesn't keep orphans.
_rmtree_quiet(os.path.join(DATA_DIR, "gallery"))
_rmtree_quiet(os.path.join(DATA_DIR, "gallery_uploads"))
return {"status": "deleted", "kind": kind, "count": count}
if kind == "calendar":
# Events FK calendars — clear children first, then both.
db.query(CalendarEvent).delete()
count = db.query(CalendarCal).count()
db.query(CalendarCal).delete()
db.commit()
return {"status": "deleted", "kind": kind, "count": count}
raise HTTPException(400, f"Unknown wipe kind: {kind!r}")
except HTTPException:
raise
except Exception as e:
db.rollback()
logger.exception(f"Wipe {kind} failed")
raise HTTPException(500, f"Wipe {kind} failed: {e}")
finally:
db.close()
return router

View File

@@ -0,0 +1,91 @@
"""API Token management routes — /api/tokens/*."""
import secrets
import uuid
import bcrypt
from fastapi import APIRouter, HTTPException, Request, Form
from core.database import get_db_session, ApiToken
from core.middleware import require_admin
from src.auth_helpers import get_current_user
MAX_NAME_LEN = 100
DEFAULT_SCOPES = "chat"
def setup_api_token_routes() -> APIRouter:
router = APIRouter(prefix="/api", tags=["api_tokens"])
@router.get("/tokens")
def list_tokens(request: Request):
require_admin(request)
with get_db_session() as db:
tokens = db.query(ApiToken).all()
return [
{
"id": t.id,
"name": t.name,
"owner": getattr(t, "owner", None),
"token_prefix": t.token_prefix,
"scopes": [s.strip() for s in (getattr(t, "scopes", "") or DEFAULT_SCOPES).split(",") if s.strip()],
"is_active": t.is_active,
"last_used_at": t.last_used_at.isoformat() if t.last_used_at else None,
"created_at": t.created_at.isoformat() if t.created_at else None,
}
for t in tokens
]
def _invalidate_cache(request: Request):
"""Tell the auth middleware its cached token map is stale."""
try:
invalidator = getattr(request.app.state, "invalidate_token_cache", None)
if invalidator:
invalidator()
except Exception:
pass
@router.post("/tokens")
def create_token(request: Request, name: str = Form("")):
require_admin(request)
name = name.strip()[:MAX_NAME_LEN]
if not name:
raise HTTPException(400, "Token name is required")
owner = get_current_user(request)
raw_token = "ody_" + secrets.token_urlsafe(32)
token_hash = bcrypt.hashpw(raw_token.encode(), bcrypt.gensalt()).decode()
token_id = str(uuid.uuid4())[:8]
with get_db_session() as db:
db.add(ApiToken(
id=token_id,
owner=owner,
name=name,
token_hash=token_hash,
token_prefix=raw_token[:8],
scopes=DEFAULT_SCOPES,
is_active=True,
))
_invalidate_cache(request)
return {
"id": token_id,
"name": name,
"owner": owner,
"token": raw_token,
"token_prefix": raw_token[:8],
"scopes": DEFAULT_SCOPES.split(","),
}
@router.delete("/tokens/{token_id}")
def delete_token(request: Request, token_id: str):
require_admin(request)
with get_db_session() as db:
deleted = db.query(ApiToken).filter(ApiToken.id == token_id).delete()
if not deleted:
raise HTTPException(404, "Token not found")
_invalidate_cache(request)
return {"status": "deleted"}
return router

325
routes/assistant_routes.py Normal file
View File

@@ -0,0 +1,325 @@
"""Personal assistant routes — resolve the per-user singleton, read/write
its settings, and list its scheduled check-in tasks.
The personal assistant is just a specially-flagged CrewMember that owns one
pinned Session and three daily ScheduledTasks ("Morning/Midday/Evening
check-in"). Everything about it is user-editable: name, personality, model,
enabled tools, timezone, and the three check-in times/prompts/enabled flags.
"""
import json
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from core.database import SessionLocal, CrewMember, ScheduledTask
from src.auth_helpers import get_current_user
from src.task_scheduler import compute_next_run
class CheckInUpdate(BaseModel):
id: str # ScheduledTask.id
name: Optional[str] = None
scheduled_time: Optional[str] = None # "HH:MM"
prompt: Optional[str] = None
enabled: Optional[bool] = None # maps to status "active"/"paused"
class AssistantSettingsUpdate(BaseModel):
name: Optional[str] = None
avatar: Optional[str] = None
personality: Optional[str] = None
model: Optional[str] = None
endpoint_url: Optional[str] = None
enabled_tools: Optional[list[str]] = None
allow_autonomous_email: Optional[bool] = None # convenience toggle
timezone: Optional[str] = None
check_ins: Optional[list[CheckInUpdate]] = None
_EMAIL_TOOLS = {"send_email", "reply_to_email"}
def _crew_to_dict(c: CrewMember) -> dict:
try:
tools = json.loads(c.enabled_tools) if c.enabled_tools else []
except Exception:
tools = []
return {
"id": c.id,
"name": c.name,
"avatar": c.avatar,
"personality": c.personality,
"model": c.model,
"endpoint_url": c.endpoint_url,
"greeting": c.greeting,
"enabled_tools": tools,
"session_id": c.session_id,
"is_default_assistant": bool(c.is_default_assistant),
"timezone": c.timezone,
"allow_autonomous_email": any(t in _EMAIL_TOOLS for t in tools),
}
def _task_to_checkin_dict(t: ScheduledTask) -> dict:
return {
"id": t.id,
"name": t.name,
"scheduled_time": t.scheduled_time,
"prompt": t.prompt,
"enabled": (t.status or "active") == "active",
"next_run": t.next_run.isoformat() + "Z" if t.next_run else None,
"last_run": t.last_run.isoformat() + "Z" if t.last_run else None,
"run_count": t.run_count or 0,
}
def setup_assistant_routes(task_scheduler) -> APIRouter:
router = APIRouter(prefix="/api/assistant", tags=["assistant"])
def _owner(request: Request) -> str:
owner = get_current_user(request)
if not owner:
raise HTTPException(status_code=401, detail="Not authenticated")
return owner
# Synthetic / non-human owners that should NEVER get an assistant +
# check-in tasks seeded. Hitting any /assistant route under one of these
# used to seed a full CrewMember + Morning/Midday/Evening tasks under that
# owner, which then double-fired alongside the real user's check-ins.
_SYNTHETIC_OWNERS = frozenset({"internal-tool", "api", "demo", "system", ""})
async def _get_or_create(owner: str) -> CrewMember:
"""Return the per-owner assistant CrewMember, creating it on demand."""
if not owner or owner in _SYNTHETIC_OWNERS:
raise HTTPException(status_code=400, detail=f"Cannot seed assistant for {owner!r}")
db = SessionLocal()
try:
crew = db.query(CrewMember).filter(
CrewMember.owner == owner,
CrewMember.is_default_assistant == True, # noqa: E712
).first()
if crew:
return crew
finally:
db.close()
# Seed lazily. This is the same code the startup hook runs for each
# user — safe to call again, it's idempotent.
await task_scheduler.ensure_assistant_defaults(owner)
db = SessionLocal()
try:
crew = db.query(CrewMember).filter(
CrewMember.owner == owner,
CrewMember.is_default_assistant == True, # noqa: E712
).first()
return crew
finally:
db.close()
@router.get("/session")
async def get_assistant_session(request: Request):
"""Resolve (or lazily create) the pinned Assistant session for this user."""
owner = _owner(request)
crew = await _get_or_create(owner)
if not crew or not crew.session_id:
raise HTTPException(status_code=500, detail="Assistant session could not be resolved")
return {
"session_id": crew.session_id,
"crew_member_id": crew.id,
"name": crew.name,
}
@router.get("/settings")
async def get_assistant_settings(request: Request):
"""Return CrewMember fields + the three check-in task rows + task IDs for logs."""
owner = _owner(request)
crew = await _get_or_create(owner)
if not crew:
raise HTTPException(status_code=500, detail="Assistant not available")
db = SessionLocal()
try:
tasks = db.query(ScheduledTask).filter(
ScheduledTask.owner == owner,
ScheduledTask.crew_member_id == crew.id,
).order_by(ScheduledTask.scheduled_time.asc()).all()
return {
"crew": _crew_to_dict(crew),
"check_ins": [_task_to_checkin_dict(t) for t in tasks],
"task_ids": [t.id for t in tasks],
}
finally:
db.close()
@router.patch("/settings")
async def update_assistant_settings(payload: AssistantSettingsUpdate, request: Request):
"""Update CrewMember fields and/or check-in tasks in one call."""
owner = _owner(request)
crew = await _get_or_create(owner)
if not crew:
raise HTTPException(status_code=500, detail="Assistant not available")
db = SessionLocal()
try:
crew_db = db.query(CrewMember).filter(CrewMember.id == crew.id).first()
if not crew_db:
raise HTTPException(status_code=404, detail="Assistant not found")
# Update CrewMember fields.
if payload.name is not None:
crew_db.name = payload.name.strip() or crew_db.name
if payload.avatar is not None:
crew_db.avatar = payload.avatar
if payload.personality is not None:
crew_db.personality = payload.personality
if payload.model is not None:
crew_db.model = payload.model or None
if payload.endpoint_url is not None:
crew_db.endpoint_url = payload.endpoint_url or None
if payload.timezone is not None:
crew_db.timezone = payload.timezone or None
# Tool list: either explicit list, or implicit toggle.
if payload.enabled_tools is not None:
crew_db.enabled_tools = json.dumps(payload.enabled_tools)
if payload.allow_autonomous_email is not None:
try:
existing = json.loads(crew_db.enabled_tools) if crew_db.enabled_tools else []
except Exception:
existing = []
if payload.allow_autonomous_email:
for t in ("send_email", "reply_to_email"):
if t not in existing:
existing.append(t)
else:
existing = [t for t in existing if t not in _EMAIL_TOOLS]
crew_db.enabled_tools = json.dumps(existing)
crew_db.updated_at = datetime.utcnow()
# Update check-in tasks.
if payload.check_ins:
now_utc = datetime.utcnow()
tz_name = crew_db.timezone or None
for ci in payload.check_ins:
task = db.query(ScheduledTask).filter(
ScheduledTask.id == ci.id,
ScheduledTask.owner == owner,
ScheduledTask.crew_member_id == crew_db.id,
).first()
if not task:
continue
if ci.name is not None:
task.name = ci.name.strip() or task.name
time_changed = False
if ci.scheduled_time is not None and ci.scheduled_time != task.scheduled_time:
task.scheduled_time = ci.scheduled_time
time_changed = True
if ci.prompt is not None:
task.prompt = ci.prompt
if ci.enabled is not None:
task.status = "active" if ci.enabled else "paused"
if time_changed or ci.enabled is True:
task.next_run = compute_next_run(
task.schedule or "daily",
task.scheduled_time,
task.scheduled_day,
task.scheduled_date,
after=now_utc,
cron_expression=task.cron_expression,
tz_name=tz_name,
)
task.updated_at = datetime.utcnow()
# Timezone change also shifts the NEXT run of all check-ins even if
# the user didn't touch the time fields.
if payload.timezone is not None:
now_utc = datetime.utcnow()
tz_name = crew_db.timezone or None
tasks = db.query(ScheduledTask).filter(
ScheduledTask.owner == owner,
ScheduledTask.crew_member_id == crew_db.id,
).all()
for t in tasks:
if t.schedule and t.scheduled_time:
t.next_run = compute_next_run(
t.schedule, t.scheduled_time, t.scheduled_day, t.scheduled_date,
after=now_utc, cron_expression=t.cron_expression, tz_name=tz_name,
)
db.commit()
# Re-read crew_db + tasks to return the fresh state.
crew_out = db.query(CrewMember).filter(CrewMember.id == crew.id).first()
tasks_out = db.query(ScheduledTask).filter(
ScheduledTask.owner == owner,
ScheduledTask.crew_member_id == crew.id,
).order_by(ScheduledTask.scheduled_time.asc()).all()
return {
"crew": _crew_to_dict(crew_out),
"check_ins": [_task_to_checkin_dict(t) for t in tasks_out],
"task_ids": [t.id for t in tasks_out],
}
finally:
db.close()
@router.post("/run/{task_id}")
async def run_check_in_now(task_id: str, request: Request):
"""Trigger one of the assistant's check-ins immediately (manual test)."""
owner = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(
ScheduledTask.id == task_id,
ScheduledTask.owner == owner,
).first()
if not task:
raise HTTPException(status_code=404, detail="Task not found")
crew = db.query(CrewMember).filter(
CrewMember.id == task.crew_member_id,
CrewMember.is_default_assistant == True, # noqa: E712
).first()
if not crew:
raise HTTPException(status_code=400, detail="Not an assistant task")
finally:
db.close()
started = await task_scheduler.run_task_now(task_id)
return {"started": bool(started)}
@router.get("/run-status/{task_id}")
async def run_status(task_id: str, request: Request):
"""Check whether the most recent run of a task has finished."""
from core.database import TaskRun, ScheduledTask
user = _owner(request)
db = SessionLocal()
try:
# SECURITY: 404 if the task doesn't belong to this user — without
# this any authenticated user could poll the status of any task_id.
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(404, "Task not found")
run = db.query(TaskRun).filter(
TaskRun.task_id == task_id,
).order_by(TaskRun.started_at.desc()).first()
if not run:
return {"status": "unknown"}
if run.status == "running":
return {"status": "running"}
return {"status": "done", "result_status": run.status}
finally:
db.close()
@router.get("/available-timezones")
async def list_timezones():
"""Return the IANA tz name list used to populate the settings dropdown."""
try:
from zoneinfo import available_timezones
zones = sorted(available_timezones())
except Exception:
zones = ["UTC"]
return {"timezones": zones}
return router

502
routes/auth_routes.py Normal file
View File

@@ -0,0 +1,502 @@
"""Authentication routes — login, logout, signup, status, user management."""
from fastapi import APIRouter, Request, Response, HTTPException
from pydantic import BaseModel
from typing import Optional
import logging
import os
from core.auth import AuthManager
from src.rate_limiter import RateLimiter
from src.settings import (
load_settings as _load_settings,
save_settings as _save_settings,
load_features as _load_features,
save_features as _save_features,
DEFAULT_SETTINGS,
)
from src.integrations import (
load_integrations,
add_integration,
update_integration,
delete_integration,
get_integration,
execute_api_call,
INTEGRATION_PRESETS,
migrate_from_settings,
)
logger = logging.getLogger(__name__)
class LoginRequest(BaseModel):
username: str
password: str
remember: bool = True
totp_code: Optional[str] = None
class SetupRequest(BaseModel):
username: str
password: str
class SignupRequest(BaseModel):
username: str
password: str
class ChangePasswordRequest(BaseModel):
current_password: str
new_password: str
class CreateUserRequest(BaseModel):
username: str
password: str
is_admin: bool = False
class DeleteUserRequest(BaseModel):
username: str
SESSION_COOKIE = "odysseus_session"
def setup_auth_routes(auth_manager: AuthManager) -> APIRouter:
router = APIRouter(prefix="/api/auth", tags=["auth"])
_login_limiter = RateLimiter(max_requests=15, window_seconds=60)
_signup_limiter = RateLimiter(max_requests=3, window_seconds=300)
_setup_limiter = RateLimiter(max_requests=3, window_seconds=300)
def _get_current_user(request: Request) -> Optional[str]:
token = request.cookies.get(SESSION_COOKIE)
return auth_manager.get_username_for_token(token)
@router.post("/setup")
async def first_run_setup(body: SetupRequest, request: Request):
"""Create initial admin account. Only works if no accounts exist."""
if not _setup_limiter.check(request.client.host):
raise HTTPException(429, "Too many requests — try again later")
if auth_manager.is_configured:
raise HTTPException(400, "Already configured")
if len(body.password) < 8:
raise HTTPException(400, "Password must be at least 8 characters")
ok = auth_manager.setup(body.username, body.password)
if not ok:
raise HTTPException(500, "Setup failed")
return {"ok": True, "message": "Admin account created"}
@router.post("/signup")
async def signup(body: SignupRequest, request: Request):
"""Create a new user account. Only works if signup is enabled by admin."""
if not _signup_limiter.check(request.client.host):
raise HTTPException(429, "Too many requests — try again later")
if not auth_manager.is_configured:
raise HTTPException(400, "Run setup first")
if not auth_manager.signup_enabled:
raise HTTPException(403, "Registration is disabled. Ask an admin for an account.")
if len(body.password) < 8:
raise HTTPException(400, "Password must be at least 8 characters")
if len(body.username.strip()) < 1:
raise HTTPException(400, "Username is required")
ok = auth_manager.create_user(body.username, body.password, is_admin=False)
if not ok:
raise HTTPException(409, "Username already taken")
return {"ok": True, "message": "Account created"}
@router.post("/login")
async def login(body: LoginRequest, request: Request, response: Response):
if not _login_limiter.check(request.client.host):
raise HTTPException(429, "Too many requests — try again later")
# Verify password first
username = body.username.strip().lower()
if not auth_manager.verify_password(username, body.password):
raise HTTPException(401, "Invalid credentials")
# Check 2FA if enabled
if auth_manager.totp_enabled(username):
if not body.totp_code:
# Password OK but need TOTP — tell client to show code input
return {"ok": False, "requires_totp": True, "username": username}
if not auth_manager.totp_verify(username, body.totp_code):
raise HTTPException(401, "Invalid 2FA code")
# All checks passed — create session
token = auth_manager.create_session(username, body.password)
if not token:
raise HTTPException(401, "Invalid credentials")
cookie_kwargs = dict(
key=SESSION_COOKIE,
value=token,
httponly=True,
samesite="lax",
secure=os.getenv("SECURE_COOKIES", "false").lower() == "true",
path="/",
)
if body.remember:
cookie_kwargs["max_age"] = 60 * 60 * 24 * 7 # 7 days
response.set_cookie(**cookie_kwargs)
return {"ok": True, "username": username}
@router.post("/logout")
async def logout(request: Request, response: Response):
token = request.cookies.get(SESSION_COOKIE)
if token:
auth_manager.revoke_token(token)
response.delete_cookie(SESSION_COOKIE, path="/")
return {"ok": True}
@router.get("/status")
async def auth_status(request: Request):
token = request.cookies.get(SESSION_COOKIE)
result = auth_manager.status(token)
result["signup_enabled"] = auth_manager.signup_enabled
# Include the caller's effective privileges so the frontend can
# hide / dim UI controls the user isn't allowed to use. Admins get
# ADMIN_PRIVILEGES (everything on), regular users get their stored
# set merged with DEFAULT_PRIVILEGES.
try:
u = result.get("username")
if u:
result["privileges"] = auth_manager.get_privileges(u)
except Exception:
pass
return result
@router.post("/change-password")
async def change_password(body: ChangePasswordRequest, request: Request):
user = _get_current_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
if len(body.new_password) < 8:
raise HTTPException(400, "Password must be at least 8 characters")
ok = auth_manager.change_password(user, body.current_password, body.new_password)
if not ok:
raise HTTPException(400, "Current password is incorrect")
return {"ok": True}
# ------------------------------------------------------------------
# Two-factor authentication
# ------------------------------------------------------------------
@router.post("/2fa/setup")
async def totp_setup(request: Request):
"""Generate a TOTP secret and return the QR code URI."""
user = _get_current_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
if auth_manager.totp_enabled(user):
raise HTTPException(400, "2FA is already enabled")
secret = auth_manager.totp_generate_secret(user)
if not secret:
raise HTTPException(500, "Failed to generate secret")
uri = auth_manager.totp_get_provisioning_uri(user, secret)
# Generate QR code as base64 PNG
import qrcode, io, base64
qr = qrcode.make(uri, box_size=6, border=2)
buf = io.BytesIO()
qr.save(buf, format="PNG")
qr_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
return {"secret": secret, "uri": uri, "qr_code": f"data:image/png;base64,{qr_b64}"}
class TotpVerifyRequest(BaseModel):
code: str
@router.post("/2fa/confirm")
async def totp_confirm(body: TotpVerifyRequest, request: Request):
"""Verify a TOTP code to confirm 2FA setup. Returns backup codes."""
user = _get_current_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
if not auth_manager.totp_confirm_enable(user, body.code):
raise HTTPException(400, "Invalid code — try again")
backup = auth_manager.users.get(user, {}).get("totp_backup_codes", [])
return {"ok": True, "backup_codes": backup}
class TotpDisableRequest(BaseModel):
password: str
@router.post("/2fa/disable")
async def totp_disable(body: TotpDisableRequest, request: Request):
"""Disable 2FA. Requires password confirmation."""
user = _get_current_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
if not auth_manager.totp_disable(user, body.password):
raise HTTPException(400, "Invalid password")
return {"ok": True}
@router.get("/2fa/status")
async def totp_status(request: Request):
"""Check if 2FA is enabled for the current user."""
user = _get_current_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
return {"enabled": auth_manager.totp_enabled(user)}
# Admin-only routes
@router.get("/users")
async def list_users(request: Request):
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
return {"users": auth_manager.list_users()}
@router.post("/users")
async def admin_create_user(body: CreateUserRequest, request: Request):
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
if len(body.password) < 8:
raise HTTPException(400, "Password must be at least 8 characters")
ok = auth_manager.create_user(body.username, body.password, body.is_admin)
if not ok:
raise HTTPException(409, "Username already taken")
return {"ok": True}
@router.put("/users/{username}/privileges")
async def update_user_privileges(username: str, request: Request):
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
body = await request.json()
ok = auth_manager.set_privileges(username, body)
if not ok:
raise HTTPException(404, "User not found or is admin")
return {"ok": True, "privileges": auth_manager.get_privileges(username)}
@router.post("/signup-toggle")
async def toggle_signup(request: Request):
"""Toggle open registration on/off. Admin only."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
auth_manager.signup_enabled = not auth_manager.signup_enabled
return {"ok": True, "signup_enabled": auth_manager.signup_enabled}
@router.delete("/users")
async def admin_delete_user(body: DeleteUserRequest, request: Request):
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
ok = auth_manager.delete_user(body.username, user)
if not ok:
raise HTTPException(400, "Cannot delete user")
return {"ok": True}
# ---- Feature visibility (admin-managed) ----
@router.get("/features")
async def get_features():
"""Public: returns which UI features are enabled."""
return _load_features()
@router.post("/features")
async def set_features(request: Request):
"""Admin only: update feature toggles."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
body = await request.json()
current = _load_features()
for key in current:
if key in body and isinstance(body[key], bool):
current[key] = body[key]
_save_features(current)
return current
# ---- App settings (admin-managed) ----
_SECRET_KEY_PATTERNS = ("_api_key", "_password", "_secret", "_token", "_key")
def _is_secret_key(name: str) -> bool:
n = (name or "").lower()
if n in ("google_pse_cx",): # public identifier, not a secret
return False
return any(n.endswith(p) or n == p.lstrip("_") for p in _SECRET_KEY_PATTERNS)
def _scrub_settings(settings: dict) -> dict:
"""Return a copy of settings with secret-shaped values masked.
Frontend reads /settings without auth for things like keybinds + TTS
prefs. Secrets (search-provider keys, IMAP/SMTP passwords) must NOT
be exposed to non-admin callers.
"""
scrubbed = {}
for k, v in (settings or {}).items():
if _is_secret_key(k) and isinstance(v, str) and v:
scrubbed[k] = "" # presence preserved, value blanked
else:
scrubbed[k] = v
return scrubbed
@router.get("/settings")
async def get_settings(request: Request):
"""Returns app settings. Admins get the full set; non-admins get
a scrubbed copy with secret keys blanked. The frontend uses this
for keybinds + TTS prefs, so it stays callable without admin."""
user = _get_current_user(request)
settings = _load_settings()
if user and auth_manager.is_admin(user):
return settings
return _scrub_settings(settings)
@router.post("/settings")
async def set_settings(request: Request):
"""Admin only: update app settings."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
body = await request.json()
current = _load_settings()
for key in DEFAULT_SETTINGS:
if key in body:
current[key] = body[key]
_save_settings(current)
return current
# ---- Integrations CRUD ----
# Run migration on startup
migrate_from_settings()
@router.get("/integrations")
async def list_integrations_route(request: Request):
"""List all integrations (admin only, keys masked)."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
items = load_integrations()
# Mask API keys for frontend display
safe = []
for item in items:
copy = dict(item)
if copy.get("api_key"):
copy["api_key"] = copy["api_key"][:4] + "****"
safe.append(copy)
return {"integrations": safe}
@router.get("/integrations/presets")
async def list_presets():
"""List available integration presets."""
return {"presets": {k: {kk: vv for kk, vv in v.items() if kk != "api_key"} for k, v in INTEGRATION_PRESETS.items()}}
@router.post("/integrations")
async def create_integration(request: Request):
"""Create a new integration (admin only)."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
body = await request.json()
item = add_integration(body)
return {"ok": True, "integration": item}
@router.put("/integrations/{integration_id}")
async def update_integration_route(integration_id: str, request: Request):
"""Update an existing integration (admin only)."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
body = await request.json()
item = update_integration(integration_id, body)
if not item:
raise HTTPException(404, "Integration not found")
return {"ok": True, "integration": item}
@router.delete("/integrations/{integration_id}")
async def delete_integration_route(integration_id: str, request: Request):
"""Delete an integration (admin only)."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
ok = delete_integration(integration_id)
if not ok:
raise HTTPException(404, "Integration not found")
return {"ok": True}
@router.post("/integrations/{integration_id}/test")
async def test_integration_route(integration_id: str, request: Request):
"""Test connectivity to an integration (admin only)."""
user = _get_current_user(request)
if not user or not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
integ = get_integration(integration_id)
if not integ:
raise HTTPException(404, "Integration not found")
preset = (integ.get("preset") or integ.get("name", "")).lower()
# ntfy is special: a GET / proves the server is reachable but
# publishes nothing, so the user has no way to know whether
# subscribers will actually receive notifications. Instead, do
# the real thing — POST a one-line "connectivity test" message
# to the topic the Reminders panel is configured to use. If the
# subscriber app is wired up correctly, this is what the green
# checkmark + a phone ping confirms together.
if preset == "ntfy":
import httpx
from urllib.parse import urlparse
# Strip any path/query the user accidentally pasted in the
# base URL (e.g. `http://host:8091/odysseus`) — otherwise
# the topic gets appended after the path and we publish to
# `/odysseus/odysseus` (which ntfy 404s on). ntfy itself
# only ever serves from the root.
raw_base = (integ.get("base_url") or "").strip()
parsed = urlparse(raw_base)
base = f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else raw_base.rstrip("/")
settings = _load_settings()
topic = (settings.get("reminder_ntfy_topic") or "reminders").strip() or "reminders"
full_url = f"{base}/{topic}"
api_key = integ.get("api_key", "")
auth_type = (integ.get("auth_type") or "none").lower()
headers = {
"Title": "Odysseus connectivity test",
"Tags": "white_check_mark",
"Priority": "default",
}
if api_key:
if auth_type == "bearer":
headers["Authorization"] = f"Bearer {api_key}"
elif auth_type == "header":
headers[integ.get("auth_header") or "Authorization"] = api_key
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.post(
full_url,
content="Connectivity test from Odysseus. If you see this on your phone, ntfy is wired up correctly.",
headers=headers,
)
if r.is_success:
# Tell the user EXACTLY where it went and what to
# subscribe to on their phone, so they can match
# without guesswork. The doubled-topic / wrong-host
# mistakes are easier to spot when the actual URL
# is right there in the success line.
return {
"ok": True,
"message": (
f"Sent to {full_url} — on your ntfy app, "
f"subscribe to topic \"{topic}\" with server "
f"\"{base}\" (or paste the full URL: {full_url})."
),
}
return {"ok": False, "message": f"ntfy returned HTTP {r.status_code} from {full_url}: {r.text[:200]}"}
except Exception as e:
return {"ok": False, "message": f"ntfy publish to {full_url} failed: {e}"[:300]}
# All other presets: GET against a known health endpoint.
# Fall back to detecting from name if preset is missing.
health_paths = {
"miniflux": "/v1/me",
"gitea": "/api/v1/version",
"linkding": "/api/tags/",
"homeassistant": "/api/",
"home assistant": "/api/",
}
path = health_paths.get(preset, "/")
result = await execute_api_call(integration_id, "GET", path)
if result.get("exit_code", 1) == 0:
return {"ok": True, "message": "Connection successful"}
return {"ok": False, "message": (result.get("error") or "Connection failed")[:300]}
return router

157
routes/backup_routes.py Normal file
View File

@@ -0,0 +1,157 @@
"""Backup routes — export/import user data (memories, presets, settings, skills, preferences)."""
import json
import logging
from datetime import datetime
from fastapi import APIRouter, HTTPException, Request, Response
from core.middleware import require_admin
from src.auth_helpers import get_current_user
from src.settings import load_settings, save_settings, load_features, save_features
logger = logging.getLogger(__name__)
def setup_backup_routes(memory_manager, preset_manager, skills_manager) -> APIRouter:
router = APIRouter(tags=["backup"])
@router.get("/api/export")
async def export_data(request: Request):
"""Export all user data as a downloadable JSON file."""
require_admin(request)
user = get_current_user(request)
# Memories (filtered by owner when auth is enabled)
memories = memory_manager.load(owner=user)
# Presets (shared across users — export all)
presets = preset_manager.get_all()
# Skills (filtered by owner when auth is enabled)
skills = skills_manager.load(owner=user)
# Settings
settings = load_settings()
# Feature flags
features = load_features()
# User preferences
from routes.prefs_routes import _load_for_user
preferences = _load_for_user(user)
export_data = {
"version": 1,
"exported_at": datetime.now().isoformat(),
"exported_by": user,
"memories": memories,
"presets": presets,
"skills": skills,
"settings": settings,
"features": features,
"preferences": preferences,
}
filename = f"odysseus_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
return Response(
content=json.dumps(export_data, indent=2, ensure_ascii=False),
media_type="application/json",
headers={"Content-Disposition": f"attachment; filename={filename}"},
)
@router.post("/api/import")
async def import_data(request: Request):
"""Import user data from a previously exported JSON file. Merges with existing data."""
require_admin(request)
user = get_current_user(request)
try:
body = await request.json()
except Exception:
raise HTTPException(400, "Invalid JSON")
if not isinstance(body, dict):
raise HTTPException(400, "Expected a JSON object")
imported = []
# ── Memories ──
if "memories" in body and isinstance(body["memories"], list):
existing = memory_manager.load_all()
existing_texts = {e.get("text", "").strip().lower() for e in existing}
added = 0
for mem in body["memories"]:
if not isinstance(mem, dict) or not mem.get("text"):
continue
if mem["text"].strip().lower() in existing_texts:
continue # skip duplicates
# Assign owner when auth is enabled
if user and not mem.get("owner"):
mem["owner"] = user
existing.append(mem)
existing_texts.add(mem["text"].strip().lower())
added += 1
memory_manager.save(existing)
imported.append(f"{added} memories")
# ── Skills ──
if "skills" in body and isinstance(body["skills"], list):
existing = skills_manager.load_all()
existing_ids = {s.get("id") for s in existing}
existing_titles = {s.get("title", "").strip().lower() for s in existing}
added = 0
for skill in body["skills"]:
if not isinstance(skill, dict) or not skill.get("title"):
continue
# Skip if same id or same title already exists
if skill.get("id") in existing_ids:
continue
if skill["title"].strip().lower() in existing_titles:
continue
if user and not skill.get("owner"):
skill["owner"] = user
existing.append(skill)
existing_ids.add(skill.get("id"))
existing_titles.add(skill["title"].strip().lower())
added += 1
skills_manager.save(existing)
imported.append(f"{added} skills")
# ── Presets ──
if "presets" in body and isinstance(body["presets"], dict):
current = preset_manager.get_all()
for key, value in body["presets"].items():
if isinstance(value, dict):
current[key] = value
elif isinstance(value, list):
current[key] = value
preset_manager.save(current)
imported.append("presets")
# ── Settings ──
if "settings" in body and isinstance(body["settings"], dict):
current = load_settings()
current.update(body["settings"])
save_settings(current)
imported.append("settings")
# ── Features ──
if "features" in body and isinstance(body["features"], dict):
current = load_features()
current.update(body["features"])
save_features(current)
imported.append("features")
# ── Preferences ──
if "preferences" in body and isinstance(body["preferences"], dict):
from routes.prefs_routes import _load_for_user, _save_for_user
current = _load_for_user(user)
current.update(body["preferences"])
_save_for_user(user, current)
imported.append("preferences")
if not imported:
return {"ok": False, "message": "No recognized data found in the file"}
return {"ok": True, "imported": imported, "message": f"Imported: {', '.join(imported)}"}
return router

1071
routes/calendar_routes.py Normal file

File diff suppressed because it is too large Load Diff

802
routes/chat_helpers.py Normal file
View File

@@ -0,0 +1,802 @@
"""Shared helpers for chat routes — context building, post-response tasks, auth resolution."""
import asyncio
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Optional
from core.models import ChatMessage
from core.database import SessionLocal
from core.database import Session as DBSession, ModelEndpoint
from src.llm_core import normalize_model_id
from src.context_compactor import maybe_compact, trim_for_context
from src.auth_helpers import get_current_user
from src.prompt_security import untrusted_context_message
from routes.prefs_routes import _load_for_user as load_prefs_for_user
from fastapi import HTTPException
logger = logging.getLogger(__name__)
# ── Data containers ────────────────────────────────────────────────────── #
@dataclass
class PresetInfo:
"""Extracted preset parameters."""
temperature: Optional[float]
max_tokens: Optional[int]
system_prompt: Optional[str]
character_name: Optional[str]
@dataclass
class PreprocessedMessage:
"""Result of chat_handler.preprocess_message."""
enhanced_message: str
user_content: Any # str or list (multimodal)
text_for_context: str
youtube_transcripts: list
attachment_meta: list
@dataclass
class ChatContext:
"""Everything needed to call the LLM after context-building."""
preface: list
rag_sources: list
web_sources: list
used_memories: list
messages: list
context_length: int
was_compacted: bool
user: Optional[str]
uprefs: dict
preset: PresetInfo
preprocessed: PreprocessedMessage
# Documents auto-created server-side during preprocess (e.g. when an
# attached fillable PDF gets rendered into a markdown editor doc).
# The chat route emits a doc_update SSE event for each before streaming
# begins, so the editor pane switches to the new doc immediately.
auto_opened_docs: list = field(default_factory=list)
# ── Helpers ────────────────────────────────────────────────────────────── #
def _enforce_chat_privileges(request, sess) -> None:
"""Apply the per-user privilege gates (allowed_models + max_messages_per_day)
that both /api/chat and /api/chat_stream must enforce BEFORE any LLM work.
Raises HTTPException(403) if the session's model is not in the user's
allowlist, or HTTPException(429) if the user has hit their daily message
cap. No-op for unauthenticated callers or when auth_manager is absent
(single-user mode). Admins receive ADMIN_PRIVILEGES from get_privileges,
which means empty allowed_models / zero cap → no-op for them.
"""
try:
user = get_current_user(request)
except Exception:
user = None
if not user:
return
auth_manager = getattr(getattr(request.app, "state", None), "auth_manager", None)
if not auth_manager:
return
privs = auth_manager.get_privileges(user) or {}
allowed = privs.get("allowed_models") or []
if allowed and sess.model and sess.model not in allowed:
raise HTTPException(403, f"Your account is not allowed to use model '{sess.model}'.")
cap = int(privs.get("max_messages_per_day") or 0)
if cap <= 0:
return
from datetime import datetime as _dt, timedelta as _td
from core.database import Session as _DbSess, ChatMessage as _Cm
db = SessionLocal()
try:
count = (
db.query(_Cm)
.join(_DbSess, _Cm.session_id == _DbSess.id)
.filter(_DbSess.owner == user,
_Cm.role == "user",
_Cm.timestamp >= _dt.utcnow() - _td(days=1))
.count()
)
finally:
db.close()
if count >= cap:
raise HTTPException(429, f"Daily message limit reached ({cap}). Try again in 24 hours.")
def needs_auto_name(name: str) -> bool:
"""Check if a session still has its default/placeholder name."""
if not name:
return True
if name.startswith("Chat:") or name == "Chat":
return True
# Default frontend name: "modelname HH:MM:SS AM/PM"
if re.match(r'^.+ \d{1,2}:\d{2}:\d{2}\s*(AM|PM)$', name):
return True
return False
async def auto_name_session(session_manager, sess):
"""Generate a short title for a session from its first user message."""
try:
from src.llm_core import llm_call_async
from src.task_endpoint import resolve_task_endpoint
# Find first user message
first_msg = ""
for msg in sess.history:
if msg.role == "user":
content = msg.content
if isinstance(content, list):
content = next(
(i.get("text", "") for i in content if isinstance(i, dict) and i.get("type") == "text"),
"",
)
first_msg = str(content)[:500]
break
if not first_msg:
return
t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers,
)
# max_tokens big enough that reasoning models (Minimax M2,
# DeepSeek R1, QwQ, etc.) have headroom for <think>…</think>
# plus the actual title — 200 used to clip them mid-reasoning
# so strip_think left an empty string and no rename happened.
# Timeout matches: 60s gives slow local reasoners room to finish.
title = await llm_call_async(
t_url,
t_model,
[
{"role": "system", "content": "Generate a short title (3-6 words, no quotes) for a conversation that starts with this message. Reply with ONLY the title, nothing else. Do NOT include any thinking, reasoning, or explanation — just the title."},
{"role": "user", "content": first_msg},
],
temperature=0.3,
max_tokens=4096,
headers=t_headers,
timeout=60,
)
title = title.strip().strip('"\'').strip()
# Strip <think>/<thinking> blocks (closed, dangling, or stray tags)
# via the central helper.
from src.text_helpers import strip_think
title = strip_think(title, prose=False, prompt_echo=False)
if title and len(title) < 80:
session_manager.update_session_name(sess.id, title)
logger.info(f"Auto-named session {sess.id}: {title}")
except Exception as e:
import traceback
logger.error(f"Auto-name failed for {sess.id}: {e}\n{traceback.format_exc()}")
def try_fallback_endpoint(sess, session_id: str) -> dict | None:
"""Find an alternative working endpoint when the current one fails.
Returns {"model": ..., "endpoint_url": ..., "endpoint_name": ...} or None.
"""
import requests as _req
from src.endpoint_resolver import build_chat_url, build_headers, normalize_base
current_url = sess.endpoint_url or ""
db = SessionLocal()
try:
endpoints = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True
).all()
finally:
db.close()
for ep in endpoints:
base = normalize_base(ep.base_url)
# Skip current endpoint
if current_url and base in current_url:
continue
# Quick ping
ping_url = base + "/models"
headers = {}
if ep.api_key:
headers["Authorization"] = f"Bearer {ep.api_key}"
try:
r = _req.get(ping_url, headers=headers, timeout=5)
r.raise_for_status()
data = r.json()
models = [m.get("id") for m in (data.get("data") or []) if m.get("id")]
if not models:
continue
# Found a working endpoint — update session
new_model = models[0]
chat_url = build_chat_url(base)
new_headers = build_headers(ep.api_key, base)
sess.model = new_model
sess.endpoint_url = chat_url
sess.headers = new_headers
# Persist
_db = SessionLocal()
try:
_db.query(DBSession).filter(DBSession.id == session_id).update({
"model": new_model,
"endpoint_url": chat_url,
"headers": json.dumps(new_headers),
})
_db.commit()
finally:
_db.close()
logger.info(f"Fallback: switched session {session_id} from {current_url} to {ep.name} ({new_model})")
return {
"model": new_model,
"endpoint_url": chat_url,
"endpoint_name": ep.name,
}
except Exception:
continue
return None
def extract_preset(chat_handler, preset_id) -> PresetInfo:
"""Extract preset parameters via chat_handler."""
temperature, max_tokens, system_prompt, char_name = (
chat_handler.validate_and_extract_preset(preset_id)
)
return PresetInfo(
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_prompt,
character_name=char_name,
)
async def preprocess(
chat_handler, message, att_ids, sess,
auto_opened_docs: Optional[list] = None,
) -> PreprocessedMessage:
"""Run chat_handler.preprocess_message and wrap the result."""
enhanced, user_content, text_ctx, yt_transcripts, att_meta = (
await chat_handler.preprocess_message(
message, att_ids, sess, auto_opened_docs=auto_opened_docs
)
)
return PreprocessedMessage(
enhanced_message=enhanced,
user_content=user_content,
text_for_context=text_ctx,
youtube_transcripts=yt_transcripts,
attachment_meta=att_meta,
)
def add_user_message(sess, chat_handler, preprocessed: PreprocessedMessage, incognito: bool = False):
"""Add user message to session history and update session name.
In incognito mode, still add to in-memory history (for conversation context)
but skip session name update (which would persist)."""
user_meta = {"attachments": preprocessed.attachment_meta} if preprocessed.attachment_meta else None
sess.add_message(ChatMessage("user", preprocessed.user_content, metadata=user_meta))
if not incognito:
chat_handler.update_session_name_if_needed(sess, preprocessed.text_for_context)
def fire_message_event(request, webhook_manager, session_id: str, sess, message: str, compare_mode: bool = False):
"""Fire webhook and event_bus events for a new user message."""
if webhook_manager and not compare_mode:
asyncio.create_task(webhook_manager.fire("chat.message", {
"session_id": session_id, "model": sess.model, "message": message[:2000],
}))
from src.event_bus import fire_event
user = get_current_user(request)
fire_event("message_sent", user)
def resolve_session_auth(sess, session_id: str):
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
)
if has_auth:
return
try:
from src.endpoint_resolver import build_headers
db = SessionLocal()
try:
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else ""
if domain:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first()
if ep and ep.api_key:
sess.headers = build_headers(ep.api_key, ep.base_url)
db.query(DBSession).filter(DBSession.id == session_id).update(
{"headers": json.dumps(sess.headers)}
)
db.commit()
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to resolve session headers: {e}")
async def build_chat_context(
sess,
request,
chat_handler,
chat_processor,
message: str,
session_id: str,
preset_id=None,
att_ids: list = None,
use_web=None,
use_rag=None,
use_research=None,
time_filter=None,
incognito: bool = False,
no_memory: bool = False,
search_context: str = None,
compare_mode: bool = False,
webhook_manager=None,
use_enhanced_message: bool = False,
agent_mode: bool = False,
) -> ChatContext:
"""Build the full context (preface + messages) for an LLM call.
This is the shared logic between /chat and /chat_stream — preset extraction,
message preprocessing, memory/RAG/web injection, compaction, normalization.
"""
# Preset
preset = extract_preset(chat_handler, preset_id)
# Preprocess message (CoT, YouTube, VL images, build content). The
# auto_opened_docs collector captures any docs created server-side
# (e.g. fillable PDF → markdown editor doc) so the chat route can
# announce them to the frontend before streaming.
auto_opened_docs: list = []
preprocessed = await preprocess(
chat_handler, message, att_ids or [], sess,
auto_opened_docs=auto_opened_docs,
)
# Add user message to history
add_user_message(sess, chat_handler, preprocessed, incognito=incognito)
# Fire events
if not incognito:
fire_message_event(request, webhook_manager, session_id, sess, message, compare_mode)
# Resolve user prefs
user = get_current_user(request)
uprefs = load_prefs_for_user(user)
# Memory enabled?
mem_enabled = not incognito and not no_memory and uprefs.get("memory_enabled", True)
# Skills injection respects its own enable toggle (mirrors memory_enabled).
# When off, the "Available skills" index is not added to the prompt.
skills_enabled = not incognito and uprefs.get("skills_enabled", True)
logger.debug(
"Memory enabled=%s for user=%s (incognito=%s, no_memory=%s, pref=%s)",
mem_enabled, user, incognito, no_memory, uprefs.get("memory_enabled", "NOT_SET"),
)
# Use RAG?
use_rag_val = (str(use_rag).lower() != "false") if use_rag is not None else True
if incognito:
use_rag_val = False
# If pre-fetched search context was provided (compare mode), skip live web search
skip_web = bool(search_context)
# Build context preface
# The stream path uses enhanced_message (with CoT/preprocessing applied),
# the sync path uses text_for_context.
_ctx_msg = preprocessed.enhanced_message if use_enhanced_message else preprocessed.text_for_context
_preface_kwargs = dict(
message=_ctx_msg,
session=sess,
use_web=use_web and not skip_web,
use_memory=mem_enabled,
time_filter=time_filter,
preset_system_prompt=preset.system_prompt,
owner=user,
character_name=preset.character_name,
agent_mode=agent_mode,
incognito=incognito,
use_skills=skills_enabled,
)
if use_rag is not None:
_preface_kwargs["use_rag"] = use_rag_val
preface, rag_sources, web_sources = chat_processor.build_context_preface(**_preface_kwargs)
# Capture used memories immediately
used_memories = getattr(chat_processor, '_last_used_memories', [])
# Inject pre-fetched search context (compare mode)
if search_context:
preface.append(untrusted_context_message("prefetched search context", search_context))
# YouTube transcripts
for transcript in preprocessed.youtube_transcripts:
preface.append(untrusted_context_message("youtube transcript", transcript))
# Normalize model ID
norm = normalize_model_id(sess.endpoint_url, sess.model)
if norm:
sess.model = norm
# Build messages
messages = preface + sess.get_context_messages()
# Auto-compact
messages, context_length, was_compacted = await maybe_compact(
sess, sess.endpoint_url, sess.model, messages, sess.headers,
)
messages = trim_for_context(messages, context_length)
return ChatContext(
preface=preface,
rag_sources=rag_sources,
web_sources=web_sources,
used_memories=used_memories,
messages=messages,
context_length=context_length,
was_compacted=was_compacted,
user=user,
uprefs=uprefs,
preset=preset,
preprocessed=preprocessed,
auto_opened_docs=auto_opened_docs,
)
def accumulate_token_usage(session_id: str, metrics: dict):
"""Add input/output token counts to the session's running totals."""
in_t = metrics.get("input_tokens", 0)
out_t = metrics.get("output_tokens", 0)
if not (in_t or out_t):
return
db = SessionLocal()
try:
db_s = db.query(DBSession).filter(DBSession.id == session_id).first()
if db_s:
db_s.total_input_tokens = (db_s.total_input_tokens or 0) + in_t
db_s.total_output_tokens = (db_s.total_output_tokens or 0) + out_t
db.commit()
except Exception:
db.rollback()
finally:
db.close()
def _normalize_thinking(text: str) -> str:
"""Wrap inline thinking patterns in <think> tags so they persist on reload.
Handles:
- "Thinking Process:" (Qwen3.5)
- Gemma-style inline reasoning ("The user said/asked...", "I should/need to...")
- Garbled <think> tags (reasoning before the tag, unclosed tags)
"""
import re
if not text:
return text
reasoning_prefix_re = re.compile(
r'^\s*(?:thinking(?:\s+process)?\s*:|the user |i need |i should |i will |they are |the question |i can )',
re.IGNORECASE,
)
thinking_prefix_re = re.compile(r'^thinking(?:\s+process)?\s*:\s*', re.IGNORECASE)
# Handle garbled <think> tags: reasoning text followed by <think> as separator
# e.g. "The user said...I should respond.\n<think>Hey! What's up?"
garbled = re.match(
r'^([\s\S]+?)\n*<think(?:ing)?>\s*([\s\S]*?)(?:</think(?:ing)?>)?\s*$',
text, re.IGNORECASE
)
if garbled:
before = garbled.group(1).strip()
after = garbled.group(2).strip()
# Only treat as garbled if the part before <think> looks like reasoning
reasoning_starts = (
'The user ', 'I need ', 'I should ', 'I will ',
'They are ', 'The question ', 'I can ',
'Thinking Process', 'Thinking:',
)
stripped_before = before.lstrip()
if any(stripped_before.startswith(p) for p in reasoning_starts) or reasoning_prefix_re.match(stripped_before):
# Strip "Thinking:" prefix from the thinking content
stripped_before = thinking_prefix_re.sub('', stripped_before)
return '<think>' + stripped_before + '</think>\n' + after
if '<think' in text.lower():
return text # already has proper think tags
# Qwen3.5: "Thinking Process:" or "Thinking:" prefix
if thinking_prefix_re.match(text.lstrip()):
# Try clean boundary first
m = re.match(
r'^(Thinking(?:\s+Process)?:[\s\S]*?)(\n\n(?=[A-Z]|Hey|Yo|Hi|Sure|I |What|Here|Let|The |This |OK|Ok|Yes|No |So |Well |Thank|Alright|Of course|Absolutely|Great|Hello|As ))',
text, re.IGNORECASE | re.MULTILINE
)
if m:
think = thinking_prefix_re.sub('', m.group(1)).strip()
return '<think>' + think + '</think>' + text[m.end()-2:]
# Fallback: find last non-indented paragraph as reply
parts = text.split('\n\n')
for i in range(len(parts) - 1, 0, -1):
line = parts[i].strip()
if line and not re.match(r'^[\d*\-\s(]', line) and len(line) > 5:
think = thinking_prefix_re.sub('', '\n\n'.join(parts[:i])).strip()
reply = '\n\n'.join(parts[i:])
return '<think>' + think + '</think>\n\n' + reply
# Last resort: look for a quoted final response inside the thinking
# Qwen often drafts the reply as "Option: ..." or * "reply text"
last_quote = re.findall(r'["\u201c]([^"\u201d]{10,})["\u201d]', text)
if last_quote:
reply = last_quote[-1].strip()
think = thinking_prefix_re.sub('', text).strip()
return '<think>' + think + '</think>\n\n' + reply
# Truly no reply found
think = thinking_prefix_re.sub('', text).strip()
return '<think>' + think + '</think>'
# Gemma-style: starts with reasoning ("The user", "I need", "I should", etc.)
stripped_text = text.lstrip()
first_line = stripped_text.split('\n')[0].strip()
reasoning_starts = (
'The user ', 'I need ', 'I should ', 'I will ',
'They are ', 'The question ', 'I can ',
)
reply_starts = (
'Hey', 'Hi ', 'Hi!', 'Hello', 'Sure', 'Yes', 'No ', 'No,', 'Yo', 'OK',
'Here', 'Absolutely', 'Of course', 'Great', 'Alright',
'Thanks', 'Welcome', 'Good ', "I'm happy", "I'd be",
)
if any(first_line.startswith(p) for p in reasoning_starts):
# Try line-by-line split first
lines = stripped_text.split('\n')
for i, line in enumerate(lines):
stripped = line.strip()
if not stripped:
continue
if i > 0 and any(stripped.startswith(p) for p in reply_starts):
think = '\n'.join(lines[:i])
reply = '\n'.join(lines[i:])
return '<think>' + think + '</think>\n' + reply
# Try within-line split — model mashed thinking + reply on one line
# Look for reply pattern after a period or sentence end
for p in reply_starts:
# Match: "...reasoning text.Reply text" or "...reasoning text. Reply text"
pattern = r'([.!?])\s*(' + re.escape(p) + r')'
m = re.search(pattern, stripped_text)
if m and m.start() > 20: # at least 20 chars of reasoning before
think = stripped_text[:m.start() + 1] # include the period
reply = stripped_text[m.start() + 1:].lstrip()
return '<think>' + think + '</think>\n' + reply
# Last resort: find last non-reasoning line
for i in range(len(lines) - 1, 0, -1):
stripped = lines[i].strip()
if stripped and not any(stripped.startswith(p) for p in reasoning_starts) and not stripped.startswith('*') and len(stripped) > 3:
think = '\n'.join(lines[:i])
reply = '\n'.join(lines[i:])
return '<think>' + think + '</think>\n' + reply
return text
def _extract_thinking_meta(text: str) -> dict | None:
"""Extract thinking content into metadata, return {thinking, reply, time} or None."""
import re
if not text:
return None
# Check for <think> tags (native or injected)
time_match = re.search(r'<think(?:ing)?\s+time="([\d.]+)"', text)
think_time = time_match.group(1) if time_match else None
# Strip time attr for parsing
clean = re.sub(r'<think(?:ing)?\s+time="[\d.]+"', '<think', text)
think_match = re.match(r'^[\s]*<think(?:ing)?>([\s\S]*?)</think(?:ing)?>\s*([\s\S]*)', clean, re.IGNORECASE)
if think_match:
thinking = think_match.group(1).strip()
reply = think_match.group(2).strip()
# Only strip the thinking out into metadata when there's an actual reply
# left over. If reply is empty (model hit max_tokens inside <think>, or
# the turn was reasoning-only), keep the raw text as content — otherwise
# the saved message has empty content and the bubble looks blank on
# reload. The renderer's processWithThinking still extracts the <think>
# block visually at display time, so nothing changes for the normal case.
if thinking and reply:
return {"thinking": thinking, "reply": reply, "time": think_time}
# Detect Thinking Process: or Gemma-style reasoning
normalized = _normalize_thinking(text)
if '<think>' in normalized:
think_match2 = re.match(r'^[\s]*<think(?:ing)?>([\s\S]*?)</think(?:ing)?>\s*([\s\S]*)', normalized, re.IGNORECASE)
if think_match2:
thinking = think_match2.group(1).strip()
reply = think_match2.group(2).strip()
if thinking and reply:
return {"thinking": thinking, "reply": reply, "time": think_time}
return None
def clean_thinking_for_save(content: str, metadata: dict | None = None) -> tuple[str, dict]:
"""Extract thinking from content into metadata. Use for save paths that bypass save_assistant_response."""
md = dict(metadata) if metadata else {}
info = _extract_thinking_meta(content)
if info:
md["thinking"] = info["thinking"]
if info.get("time"):
md["thinking_time"] = info["time"]
return info["reply"], md
return content, md
def save_assistant_response(
sess,
session_manager,
session_id: str,
full_response: str,
last_metrics: dict | None,
*,
character_name: str = None,
web_sources: list = None,
rag_sources: list = None,
research_sources: list = None,
used_memories: list = None,
do_research: bool = False,
tool_events: list = None,
incognito: bool = False,
):
"""Add assistant response to session history. In incognito mode, keeps in-memory context but skips DB persistence."""
md = dict(last_metrics) if last_metrics else {}
md["model"] = sess.model
if character_name:
md["character_name"] = character_name
if web_sources:
md["web_sources"] = web_sources
if rag_sources:
md["rag_sources"] = rag_sources
if research_sources:
md["research_sources"] = research_sources
if used_memories:
md["memories_used"] = used_memories
if do_research and not research_sources:
md["research_clarification"] = True
if tool_events:
md["tool_events"] = tool_events
# Extract thinking into metadata (don't pollute message content with <think> tags)
_think_info = _extract_thinking_meta(full_response)
if _think_info:
md["thinking"] = _think_info["thinking"]
md["thinking_time"] = _think_info.get("time")
_content = _think_info["reply"]
else:
_content = full_response
sess.add_message(ChatMessage("assistant", _content, metadata=md))
if not incognito:
from core.database import update_session_last_accessed
update_session_last_accessed(session_id)
session_manager.save_sessions()
# Return the persisted message's DB id so the stream can wire it onto the
# freshly-rendered bubble — lets the user edit/delete a just-streamed reply
# without reloading. Incognito returns None: those messages are ephemeral,
# so we don't hand out an edit/delete handle for them.
if incognito:
return None
try:
_last = sess.history[-1]
_meta = getattr(_last, "metadata", None)
if isinstance(_meta, dict):
return _meta.get("_db_id")
except (IndexError, AttributeError):
pass
return None
def run_post_response_tasks(
sess,
session_manager,
session_id: str,
message: str,
full_response: str,
last_metrics: dict | None,
uprefs: dict,
memory_manager,
memory_vector,
webhook_manager,
*,
incognito: bool = False,
compare_mode: bool = False,
character_name: str = None,
agent_rounds: int = 0,
agent_tool_calls: int = 0,
skills_manager=None,
owner: str = None,
extract_skills: bool = True,
):
"""Fire background tasks after a completed response: memory extraction, webhooks, auto-name, skill extraction."""
# Memory extraction — only every 4th message pair to avoid excess LLM calls
_msg_count = len(sess.history) if hasattr(sess, 'history') else 0
_should_extract = (_msg_count >= 4) and (_msg_count % 4 == 0)
if not incognito and not compare_mode and _should_extract and uprefs.get("auto_memory", True):
from services.memory.memory_extractor import extract_and_store
from src.task_endpoint import resolve_task_endpoint
t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers,
)
asyncio.create_task(extract_and_store(
sess, memory_manager, memory_vector,
t_url, t_model, t_headers,
))
# Skill extraction from complex agent runs. Only when the user actually
# chose agent mode — not a chat we auto-escalated for a notes/calendar
# intent, and never in incognito/compare.
auto_skills_enabled = bool(uprefs.get("auto_skills", True))
# Quiet by default — full gate/dispatch/start trace runs at DEBUG so
# users can re-enable diagnostics with LOG_LEVEL=DEBUG when something
# silently breaks. INFO-level only shows the outcome inside
# maybe_extract_skill (Auto-extracted / dropped / failed).
logger.debug(
"[skill-extract] gate: extract_skills=%s auto_skills=%s incognito=%s "
"compare=%s rounds=%d tools=%d skills_manager=%s",
extract_skills, auto_skills_enabled, incognito, compare_mode,
agent_rounds, agent_tool_calls, "set" if skills_manager else "MISSING",
)
if (
extract_skills
and auto_skills_enabled
and not incognito
and not compare_mode
and (agent_rounds >= 2 or agent_tool_calls >= 2)
):
if skills_manager is None:
logger.warning(
"[skill-extract] gate PASSED but skills_manager is None — "
"extraction skipped. (Bug: caller didn't pass skills_manager.)"
)
else:
from services.memory.skill_extractor import maybe_extract_skill
from src.task_endpoint import resolve_task_endpoint
s_url, s_model, s_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers,
)
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
asyncio.create_task(maybe_extract_skill(
sess, skills_manager,
s_url, s_model, s_headers,
agent_rounds, agent_tool_calls,
owner=owner,
))
# Token accumulation
if last_metrics:
accumulate_token_usage(session_id, last_metrics)
# Webhook
if webhook_manager and not compare_mode:
asyncio.create_task(webhook_manager.fire("chat.completed", {
"session_id": session_id, "model": sess.model,
"user_message": message, "response": full_response[:2000],
}))
# Auto-name
if needs_auto_name(sess.name):
asyncio.create_task(auto_name_session(session_manager, sess))

1114
routes/chat_routes.py Normal file

File diff suppressed because it is too large Load Diff

60
routes/cleanup_routes.py Normal file
View File

@@ -0,0 +1,60 @@
# routes/cleanup_routes.py
"""Routes for cleanup operations."""
import logging
from fastapi import APIRouter, HTTPException, Request
from src.cleanup_service import get_cleanup_preview, cleanup_sessions
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
def setup_cleanup_routes(session_manager):
"""
Setup cleanup-related routes.
Args:
session_manager: SessionManager instance
Returns:
APIRouter instance with cleanup routes
"""
router = APIRouter(prefix="/api/cleanup")
@router.get("/preview")
async def cleanup_preview(request: Request):
"""
Preview what would be cleaned up without making any changes.
Returns:
JSON response with lists of sessions that would be archived/deleted and estimated space savings
"""
user = get_current_user(request)
try:
preview = await get_cleanup_preview(owner=user)
return preview
except Exception as e:
logger.error(f"Cleanup preview failed: {e}")
raise HTTPException(500, "Cleanup preview generation failed")
@router.post("")
async def cleanup_endpoint(request: Request):
"""
Perform cleanup operations:
1. Archive inactive sessions (not accessed for 7 days)
2. Delete old sessions (archived, not important, not accessed for 14+ days, with fewer than 10 messages)
Returns:
JSON response with counts of deleted and archived sessions, and space freed
"""
user = get_current_user(request)
try:
archived_count, deleted_count, space_freed_mb = await cleanup_sessions(session_manager, owner=user)
return {
"archived_count": archived_count,
"deleted_count": deleted_count,
"space_freed_mb": round(space_freed_mb, 2)
}
except Exception as e:
logger.error(f"Cleanup failed: {e}")
raise HTTPException(500, "Cleanup operation failed")
return router

246
routes/compare_routes.py Normal file
View File

@@ -0,0 +1,246 @@
# routes/compare_routes.py
"""Model A/B comparison routes."""
import json
import uuid
import random
from datetime import datetime
from fastapi import APIRouter, Form, HTTPException, Request
from typing import List
from pydantic import BaseModel
import logging
from core.database import Comparison, SessionLocal
from core.session_manager import SessionManager
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/compare", tags=["compare"])
class RecordVoteRequest(BaseModel):
prompt: str
models: List[str]
winner: str # model name or "tie"
is_blind: bool = True
def setup_compare_routes(session_manager: SessionManager):
"""Setup comparison routes."""
@router.post("/start")
def start_comparison(
request: Request,
prompt: str = Form(...),
model_a: str = Form(...),
model_b: str = Form(...),
endpoint_a: str = Form(...),
endpoint_b: str = Form(...),
is_blind: str = Form("true"),
):
"""Create two ephemeral sessions and a comparison record.
Returns the comparison ID and the two session IDs so the client
can fire two independent SSE streams to /api/chat_stream.
"""
comp_id = str(uuid.uuid4())
sid_a = str(uuid.uuid4())
sid_b = str(uuid.uuid4())
# Create ephemeral sessions (prefixed [CMP])
for sid, model, endpoint in [(sid_a, model_a, endpoint_a), (sid_b, model_b, endpoint_b)]:
user = getattr(request.state, 'current_user', None)
session_manager.create_session(
session_id=sid,
name=f"[CMP] {model.split('/')[-1]}",
endpoint_url=endpoint,
model=model,
rag=False,
owner=user,
)
# Copy API key from endpoint config
db = SessionLocal()
try:
from core.database import ModelEndpoint
# Find matching endpoint by URL
ep = db.query(ModelEndpoint).filter(
ModelEndpoint.base_url == endpoint.replace('/chat/completions', '')
).first()
if ep and ep.api_key:
s = session_manager.sessions.get(sid)
if s:
s.headers = {"Authorization": f"Bearer {ep.api_key}"}
finally:
db.close()
# Blind mapping: randomly assign left/right
blind = str(is_blind).lower() == "true"
if blind:
mapping = {"left": "a", "right": "b"}
if random.random() > 0.5:
mapping = {"left": "b", "right": "a"}
else:
mapping = {"left": "a", "right": "b"}
# Store comparison record
db = SessionLocal()
try:
comp = Comparison(
id=comp_id,
prompt=prompt,
model_a=model_a,
model_b=model_b,
endpoint_a=endpoint_a,
endpoint_b=endpoint_b,
is_blind=blind,
blind_mapping=json.dumps(mapping),
owner=user,
)
db.add(comp)
db.commit()
finally:
db.close()
# Map session IDs to left/right based on blind mapping
session_left = sid_a if mapping["left"] == "a" else sid_b
session_right = sid_a if mapping["right"] == "a" else sid_b
return {
"id": comp_id,
"session_left": session_left,
"session_right": session_right,
"model_left": model_a if mapping["left"] == "a" else model_b,
"model_right": model_a if mapping["right"] == "a" else model_b,
"is_blind": blind,
"mapping": mapping,
}
@router.post("/{comp_id}/vote")
def vote_comparison(
request: Request,
comp_id: str,
winner: str = Form(...), # "left", "right", or "tie"
):
"""Record the user's vote and reveal model names if blind."""
user = get_current_user(request)
db = SessionLocal()
try:
comp = db.query(Comparison).filter(Comparison.id == comp_id).first()
if not comp:
raise HTTPException(404, "Comparison not found")
# SECURITY: strict ownership — null-owner Comparisons were
# accessible to every user.
if user and comp.owner != user:
raise HTTPException(404, "Comparison not found")
if comp.winner:
raise HTTPException(400, "Already voted")
mapping = json.loads(comp.blind_mapping) if comp.blind_mapping else {"left": "a", "right": "b"}
if winner == "tie":
comp.winner = "tie"
elif winner == "left":
comp.winner = mapping["left"]
elif winner == "right":
comp.winner = mapping["right"]
else:
raise HTTPException(400, "winner must be 'left', 'right', or 'tie'")
comp.voted_at = datetime.utcnow()
db.commit()
return {
"winner": comp.winner,
"model_a": comp.model_a,
"model_b": comp.model_b,
"revealed": {
"left": comp.model_a if mapping["left"] == "a" else comp.model_b,
"right": comp.model_a if mapping["right"] == "a" else comp.model_b,
},
}
finally:
db.close()
@router.post("/record")
def record_comparison(request: Request, body: RecordVoteRequest):
"""Lightweight endpoint to record a comparison vote from the frontend."""
user = get_current_user(request)
comp_id = str(uuid.uuid4())
model_a = body.models[0] if len(body.models) > 0 else ""
model_b = body.models[1] if len(body.models) > 1 else ""
# For N>2 models, store the full list as JSON in blind_mapping
if len(body.models) > 2:
blind_mapping = json.dumps({"models": body.models})
else:
blind_mapping = None
db = SessionLocal()
try:
comp = Comparison(
id=comp_id,
prompt=body.prompt[:500],
model_a=model_a,
model_b=model_b,
endpoint_a="",
endpoint_b="",
winner=body.winner,
is_blind=body.is_blind,
blind_mapping=blind_mapping,
voted_at=datetime.utcnow(),
owner=user,
)
db.add(comp)
db.commit()
finally:
db.close()
return {"status": "ok", "id": comp_id}
@router.get("/history")
def list_comparisons(request: Request):
"""List past comparisons."""
user = get_current_user(request)
db = SessionLocal()
try:
q = db.query(Comparison)
if user:
q = q.filter(Comparison.owner == user)
comps = q.order_by(Comparison.created_at.desc()).limit(50).all()
return [
{
"id": c.id,
"prompt": c.prompt[:100],
"model_a": c.model_a,
"model_b": c.model_b,
"winner": c.winner,
"is_blind": c.is_blind,
"voted_at": c.voted_at.isoformat() if c.voted_at else None,
"created_at": c.created_at.isoformat() if c.created_at else None,
}
for c in comps
]
finally:
db.close()
@router.delete("/{comp_id}")
def delete_comparison(request: Request, comp_id: str):
"""Delete a comparison and its ephemeral sessions."""
user = get_current_user(request)
db = SessionLocal()
try:
comp = db.query(Comparison).filter(Comparison.id == comp_id).first()
if not comp:
raise HTTPException(404, "Comparison not found")
# SECURITY: strict ownership — null-owner Comparisons were
# accessible to every user.
if user and comp.owner != user:
raise HTTPException(404, "Comparison not found")
db.delete(comp)
db.commit()
return {"status": "deleted"}
finally:
db.close()
return router

783
routes/contacts_routes.py Normal file
View File

@@ -0,0 +1,783 @@
"""
contacts_routes.py
CardDAV contacts integration. Reads from local Radicale, supports
search and adding new contacts.
"""
import re
import logging
import uuid
import json
import csv
import io
import httpx
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, Query, Depends, Response
from typing import List, Dict, Optional
from src.auth_helpers import require_user
from core.middleware import require_admin
logger = logging.getLogger(__name__)
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
SETTINGS_FILE = DATA_DIR / "settings.json"
LOCAL_CONTACTS_FILE = DATA_DIR / "contacts.json"
def _load_settings():
if SETTINGS_FILE.exists():
return json.loads(SETTINGS_FILE.read_text())
return {}
def _save_settings(settings):
from core.atomic_io import atomic_write_json
atomic_write_json(str(SETTINGS_FILE), settings, indent=2)
def _get_carddav_config():
import os
settings = _load_settings()
return {
"url": settings.get("carddav_url", os.environ.get("CARDDAV_URL", "")),
"username": settings.get("carddav_username", os.environ.get("CARDDAV_USERNAME", "")),
"password": settings.get("carddav_password", os.environ.get("CARDDAV_PASSWORD", "")),
}
def _carddav_configured(cfg: Optional[Dict] = None) -> bool:
cfg = cfg or _get_carddav_config()
return bool((cfg.get("url") or "").strip())
def _normalize_contact(contact: Dict) -> Dict:
emails = []
for e in contact.get("emails") or ([] if not contact.get("email") else [contact.get("email")]):
e = str(e or "").strip()
if e and e not in emails:
emails.append(e)
phones = []
for p in contact.get("phones") or ([] if not contact.get("phone") else [contact.get("phone")]):
p = str(p or "").strip()
if p and p not in phones:
phones.append(p)
name = str(contact.get("name") or "").strip()
if not name and emails:
name = emails[0].split("@")[0]
return {
"uid": str(contact.get("uid") or uuid.uuid4()),
"name": name,
"emails": emails,
"phones": phones,
}
def _load_local_contacts() -> List[Dict]:
try:
if not LOCAL_CONTACTS_FILE.exists():
return []
data = json.loads(LOCAL_CONTACTS_FILE.read_text())
rows = data.get("contacts", data) if isinstance(data, dict) else data
return [_normalize_contact(c) for c in (rows or []) if isinstance(c, dict)]
except Exception as e:
logger.error(f"Failed to load local contacts: {e}")
return []
def _save_local_contacts(contacts: List[Dict]) -> None:
from core.atomic_io import atomic_write_json
DATA_DIR.mkdir(parents=True, exist_ok=True)
atomic_write_json(str(LOCAL_CONTACTS_FILE), {"contacts": [_normalize_contact(c) for c in contacts]}, indent=2)
_contact_cache["contacts"] = [_normalize_contact(c) for c in contacts]
_contact_cache["fetched_at"] = datetime.utcnow()
# ── vCard parsing ──
def _vunesc(value: str) -> str:
"""Reverse _vesc() — turn escaped vCard text back into the raw value.
Order matters: handle \\n/\\, /\\; first, backslash-unescape last."""
if not value:
return value
out = []
i = 0
while i < len(value):
ch = value[i]
if ch == "\\" and i + 1 < len(value):
nxt = value[i + 1]
if nxt in ("n", "N"):
out.append("\n")
elif nxt in (",", ";", "\\"):
out.append(nxt)
else:
out.append(nxt)
i += 2
else:
out.append(ch)
i += 1
return "".join(out)
def _parse_vcards(text: str) -> List[Dict]:
"""Parse a stream of vCards into dicts with name, email, phone."""
contacts = []
for block in re.split(r"BEGIN:VCARD", text):
if not block.strip():
continue
contact = {"name": "", "emails": [], "phones": [], "uid": ""}
for line in block.split("\n"):
line = line.strip()
if line.startswith("FN:") or line.startswith("FN;"):
contact["name"] = _vunesc(line.split(":", 1)[1]) if ":" in line else ""
elif line.startswith("EMAIL"):
# Handle EMAIL:foo@bar OR EMAIL;TYPE=...:foo@bar OR EMAIL;PREF=1:foo@bar
if ":" in line:
email_addr = _vunesc(line.split(":", 1)[1])
if email_addr and email_addr not in contact["emails"]:
contact["emails"].append(email_addr)
elif line.startswith("TEL"):
if ":" in line:
phone = _vunesc(line.split(":", 1)[1])
if phone and phone not in contact["phones"]:
contact["phones"].append(phone)
elif line.startswith("UID:"):
contact["uid"] = _vunesc(line[4:])
if contact["name"] or contact["emails"]:
contacts.append(contact)
return contacts
def _vesc(value: str) -> str:
"""Escape a vCard property VALUE per RFC 6350 §3.4: backslash, comma,
semicolon, and newlines. Without this, a name like 'Sekisui House,Ltd'
or any value containing a newline produces a malformed vCard (broken
N/FN fields) or could inject arbitrary properties."""
return (
(value or "")
.replace("\\", "\\\\")
.replace("\n", "\\n")
.replace("\r", "")
.replace(",", "\\,")
.replace(";", "\\;")
)
def _build_vcard(name: str, email: str, uid: Optional[str] = None,
emails: Optional[List[str]] = None,
phones: Optional[List[str]] = None) -> str:
"""Build a vCard. Accepts either a single `email` (legacy callers) or
full `emails`/`phones` lists (edit path). The first email is marked
PREF=1. All values are RFC-6350-escaped."""
if not uid:
uid = str(uuid.uuid4())
# Normalize email lists — `email` arg is a convenience for single-email
# creation; `emails` (if given) is authoritative.
email_list = [e.strip() for e in (emails if emails is not None else ([email] if email else [])) if e and e.strip()]
phone_list = [p.strip() for p in (phones or []) if p and p.strip()]
# Try to split name into first/last
parts = name.strip().split()
if len(parts) >= 2:
first = parts[0]
last = " ".join(parts[1:])
else:
first = name
last = ""
# N field is structured (5 components separated by ';') — escape each
# component individually so a comma in the name doesn't split it.
n_field = f"{_vesc(last)};{_vesc(first)};;;"
lines = [
"BEGIN:VCARD",
"VERSION:4.0",
f"UID:{_vesc(uid)}",
f"FN:{_vesc(name)}",
f"N:{n_field}",
]
for i, em in enumerate(email_list):
# First email is the preferred one.
lines.append(f"EMAIL;PREF=1:{_vesc(em)}" if i == 0 else f"EMAIL:{_vesc(em)}")
for ph in phone_list:
lines.append(f"TEL:{_vesc(ph)}")
lines.append("END:VCARD")
return "\r\n".join(lines) + "\r\n"
# ── In-memory cache ──
_contact_cache = {"contacts": [], "fetched_at": None}
def _abs_url(href: str) -> str:
"""Combine a multistatus <href> (an absolute path like
/user/contacts/x.vcf) with the configured CardDAV server origin so we
get a fully-qualified URL to PUT/DELETE. If href is already absolute
(http...), return it as-is."""
from urllib.parse import urlparse, urlunparse
if href.startswith("http://") or href.startswith("https://"):
return href
cfg = _get_carddav_config()
p = urlparse(cfg["url"])
return urlunparse((p.scheme, p.netloc, href, "", "", ""))
# CardDAV REPORT body — pull every card's etag + raw vCard in ONE request,
# alongside the resource href. Lets us map each contact's UID to the real
# server resource path (which is NOT always <uid>.vcf for contacts created
# by other clients).
_ADDRESSBOOK_QUERY = (
'<?xml version="1.0" encoding="utf-8"?>'
'<C:addressbook-query xmlns:D="DAV:" xmlns:C="urn:ietf:params:xml:ns:carddav">'
'<D:prop><D:getetag/><C:address-data/></D:prop>'
'<C:filter/>'
'</C:addressbook-query>'
)
def _fetch_via_report(cfg, auth):
"""Try a CardDAV REPORT addressbook-query — returns contacts WITH an
`href` field, or None if the server doesn't support it / errors."""
from defusedxml import ElementTree as ET
try:
r = httpx.request(
"REPORT", cfg["url"],
content=_ADDRESSBOOK_QUERY.encode("utf-8"),
headers={"Content-Type": "application/xml; charset=utf-8", "Depth": "1"},
auth=auth, timeout=10,
)
if r.status_code not in (207, 200):
return None
root = ET.fromstring(r.text)
ns = {"D": "DAV:", "C": "urn:ietf:params:xml:ns:carddav"}
out = []
for resp in root.findall("D:response", ns):
href_el = resp.find("D:href", ns)
data_el = resp.find(".//C:address-data", ns)
if href_el is None or data_el is None or not (data_el.text or "").strip():
continue
parsed = _parse_vcards(data_el.text)
if not parsed:
continue
c = parsed[0]
c["href"] = href_el.text.strip()
out.append(c)
# If the REPORT parsed to ZERO contacts, don't trust it — some
# CardDAV servers treat an empty <filter/> as "match nothing" and
# return a valid-but-empty 207. Return None so the caller falls
# back to the plain GET (which lists everything). A genuinely empty
# address book just costs one extra GET that also returns nothing.
if not out:
return None
return out
except Exception as e:
logger.warning(f"CardDAV REPORT failed, falling back to GET: {e}")
return None
def _fetch_contacts(force=False):
"""Fetch all contacts. Uses CardDAV when configured, otherwise local JSON."""
if not force and _contact_cache["fetched_at"]:
age = (datetime.utcnow() - _contact_cache["fetched_at"]).total_seconds()
if age < 60:
return _contact_cache["contacts"]
cfg = _get_carddav_config()
if not _carddav_configured(cfg):
contacts = _load_local_contacts()
_contact_cache["contacts"] = contacts
_contact_cache["fetched_at"] = datetime.utcnow()
return contacts
try:
auth = None
if cfg["username"]:
auth = (cfg["username"], cfg["password"])
# Preferred path: REPORT gives us hrefs for reliable edit/delete.
contacts = _fetch_via_report(cfg, auth)
if contacts is None:
# Fallback: plain GET, concatenated vCards, no hrefs.
r = httpx.get(cfg["url"], auth=auth, timeout=10)
if r.status_code != 200:
logger.warning(f"CardDAV returned {r.status_code}")
return _contact_cache["contacts"]
contacts = _parse_vcards(r.text)
_contact_cache["contacts"] = contacts
_contact_cache["fetched_at"] = datetime.utcnow()
return contacts
except Exception as e:
logger.error(f"Failed to fetch contacts: {e}")
return _contact_cache["contacts"]
def _resolve_resource_url(uid: str) -> str:
"""Map a contact UID to its real CardDAV resource URL. Uses the href
captured during fetch when available (handles contacts whose filename
!= UID); falls back to the <uid>.vcf guess for app-created contacts or
when no href is known."""
def _lookup():
for c in _contact_cache.get("contacts", []):
if c.get("uid") == uid and c.get("href"):
return _abs_url(c["href"])
return None
found = _lookup()
if found:
return found
# Not in cache (or no href) — refresh once and retry before guessing.
try:
_fetch_contacts(force=True)
except Exception:
pass
return _lookup() or _vcard_url(uid)
def _create_contact(name: str, email: str) -> bool:
"""Add a new contact via CardDAV or local contacts."""
cfg = _get_carddav_config()
if not _carddav_configured(cfg):
contacts = _load_local_contacts()
email_l = (email or "").strip().lower()
for c in contacts:
if email_l and email_l in [e.lower() for e in c.get("emails", [])]:
return True
contacts.append(_normalize_contact({"name": name, "emails": [email]}))
_save_local_contacts(contacts)
return True
contact_uid = str(uuid.uuid4())
vcard = _build_vcard(name, email, contact_uid)
url = cfg["url"].rstrip("/") + "/" + contact_uid + ".vcf"
try:
auth = None
if cfg["username"]:
auth = (cfg["username"], cfg["password"])
r = httpx.put(
url,
data=vcard.encode("utf-8"),
headers={"Content-Type": "text/vcard; charset=utf-8"},
auth=auth,
timeout=10,
)
if r.status_code in (200, 201, 204):
# Invalidate cache
_contact_cache["fetched_at"] = None
return True
logger.warning(f"CardDAV PUT returned {r.status_code}: {r.text[:200]}")
return False
except Exception as e:
logger.error(f"Failed to create contact: {e}")
return False
def _vcard_url(uid: str) -> str:
"""The CardDAV resource URL for a given contact UID. The uid is URL-
encoded so a value containing '/', '..' or other path chars can't
escape the collection and target an arbitrary CardDAV resource."""
from urllib.parse import quote
cfg = _get_carddav_config()
return cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
def _import_vcards(text: str) -> Dict:
"""Import a (possibly multi-card) .vcf blob. Each card is PUT to the
CardDAV server PRESERVING its full original content (ADR/ORG/photo/
etc.) — we don't rebuild it, just ensure it has VERSION + UID and
normalize line endings. Returns {imported, failed, total}."""
from urllib.parse import quote
cfg = _get_carddav_config()
if not cfg.get("url"):
parsed = _parse_vcards(text)
contacts = _load_local_contacts()
existing = {
e.lower()
for c in contacts
for e in (c.get("emails") or [])
if e
}
imported = 0
for c in parsed:
emails = [e for e in (c.get("emails") or []) if e]
if emails and any(e.lower() in existing for e in emails):
continue
contacts.append(_normalize_contact(c))
for e in emails:
existing.add(e.lower())
imported += 1
if imported:
_save_local_contacts(contacts)
return {"imported": imported, "failed": 0, "total": len(parsed)}
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
# Split into individual cards. re.split drops the BEGIN line, so we
# re-add it. Normalize CRLF.
raw = (text or "").replace("\r\n", "\n").replace("\r", "\n")
blocks = []
for chunk in raw.split("BEGIN:VCARD"):
chunk = chunk.strip()
if not chunk:
continue
# Trim anything after END:VCARD (defensive).
end = chunk.upper().find("END:VCARD")
body = chunk[: end + len("END:VCARD")] if end != -1 else chunk
blocks.append("BEGIN:VCARD\n" + body)
imported = 0
failed = 0
for block in blocks:
# Extract or assign a UID.
m = re.search(r"^UID:(.+)$", block, re.MULTILINE)
uid = (m.group(1).strip() if m else "") or str(uuid.uuid4())
if not m:
# Inject a UID right after the VERSION line (or after BEGIN).
if re.search(r"^VERSION:", block, re.MULTILINE):
block = re.sub(r"(^VERSION:.*$)", r"\1\nUID:" + uid, block, count=1, flags=re.MULTILINE)
else:
block = block.replace("BEGIN:VCARD", f"BEGIN:VCARD\nVERSION:4.0\nUID:{uid}", 1)
elif not re.search(r"^VERSION:", block, re.MULTILINE):
block = block.replace("BEGIN:VCARD", "BEGIN:VCARD\nVERSION:4.0", 1)
vcard = block.replace("\n", "\r\n") + "\r\n"
url = cfg["url"].rstrip("/") + "/" + quote(uid, safe="") + ".vcf"
try:
r = httpx.put(
url, data=vcard.encode("utf-8"),
headers={"Content-Type": "text/vcard; charset=utf-8"},
auth=auth, timeout=15,
)
if r.status_code in (200, 201, 204):
imported += 1
else:
failed += 1
logger.warning(f"Import PUT {uid} returned {r.status_code}: {r.text[:120]}")
except Exception as e:
failed += 1
logger.error(f"Import PUT {uid} failed: {e}")
if imported:
_contact_cache["fetched_at"] = None
return {"imported": imported, "failed": failed, "total": len(blocks)}
def _import_csv_contacts(text: str) -> Dict:
"""Import contacts from CSV. Supports common headers:
name/full_name/display_name, email/email_address/e-mail, phone/tel.
Falls back to first columns as name,email,phone when no headers exist."""
raw = (text or "").strip()
if not raw:
return {"imported": 0, "failed": 0, "total": 0, "error": "No CSV data found"}
try:
sample = raw[:2048]
dialect = csv.Sniffer().sniff(sample)
except Exception:
dialect = csv.excel
stream = io.StringIO(raw)
try:
has_header = csv.Sniffer().has_header(raw[:2048])
except Exception:
has_header = True
rows = []
if has_header:
reader = csv.DictReader(stream, dialect=dialect)
for row in reader:
lowered = {str(k or "").strip().lower(): (v or "").strip() for k, v in row.items()}
name = (
lowered.get("name") or lowered.get("full name") or lowered.get("full_name")
or lowered.get("display name") or lowered.get("display_name")
or lowered.get("fn") or ""
)
email = (
lowered.get("email") or lowered.get("email address")
or lowered.get("email_address") or lowered.get("e-mail")
or lowered.get("mail") or ""
)
phone = lowered.get("phone") or lowered.get("telephone") or lowered.get("tel") or ""
rows.append((name, email, phone))
else:
stream.seek(0)
reader = csv.reader(stream, dialect=dialect)
for row in reader:
cols = [(c or "").strip() for c in row]
if not any(cols):
continue
rows.append((
cols[0] if len(cols) > 0 else "",
cols[1] if len(cols) > 1 else "",
cols[2] if len(cols) > 2 else "",
))
imported = 0
failed = 0
total = 0
existing_emails = {
e.lower()
for c in _fetch_contacts()
for e in (c.get("emails") or [])
if e
}
for name, email, phone in rows:
email = (email or "").strip()
name = (name or "").strip() or (email.split("@")[0] if email else "")
if not email:
continue
total += 1
if email.lower() in existing_emails:
continue
ok = _create_contact(name, email)
if ok:
imported += 1
existing_emails.add(email.lower())
# If the CSV had a phone number, rewrite the just-created row
# through the richer update path so phone lands in CardDAV too.
if phone:
try:
contacts = _fetch_contacts(force=True)
created = next((c for c in contacts if email.lower() in [e.lower() for e in c.get("emails", [])]), None)
if created and created.get("uid"):
_update_contact(created["uid"], name, [email], [phone])
except Exception:
pass
else:
failed += 1
if imported:
_contact_cache["fetched_at"] = None
return {"imported": imported, "failed": failed, "total": total}
def _contacts_to_vcf(contacts: List[Dict]) -> str:
return "".join(
_build_vcard(
c.get("name") or ((c.get("emails") or [""])[0].split("@")[0] if c.get("emails") else "Contact"),
"",
uid=c.get("uid") or str(uuid.uuid4()),
emails=c.get("emails") or [],
phones=c.get("phones") or [],
)
for c in contacts
)
def _contacts_to_csv(contacts: List[Dict]) -> str:
out = io.StringIO()
writer = csv.writer(out)
writer.writerow(["name", "email", "phone"])
for c in contacts:
emails = c.get("emails") or [""]
phones = c.get("phones") or [""]
max_len = max(len(emails), len(phones), 1)
for i in range(max_len):
writer.writerow([
c.get("name") or "",
emails[i] if i < len(emails) else "",
phones[i] if i < len(phones) else "",
])
return out.getvalue()
def _update_contact(uid: str, name: str, emails: List[str], phones: List[str]) -> bool:
"""Rewrite an existing contact via CardDAV or local contacts."""
cfg = _get_carddav_config()
if not _carddav_configured(cfg):
contacts = _load_local_contacts()
found = False
out = []
for c in contacts:
if c.get("uid") == uid:
out.append(_normalize_contact({"uid": uid, "name": name, "emails": emails, "phones": phones}))
found = True
else:
out.append(c)
if not found:
out.append(_normalize_contact({"uid": uid, "name": name, "emails": emails, "phones": phones}))
_save_local_contacts(out)
return True
vcard = _build_vcard(name, "", uid=uid, emails=emails, phones=phones)
# Use the real resource href (handles externally-created contacts whose
# filename != UID); falls back to the <uid>.vcf guess.
url = _resolve_resource_url(uid)
try:
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
r = httpx.put(
url,
data=vcard.encode("utf-8"),
headers={"Content-Type": "text/vcard; charset=utf-8"},
auth=auth,
timeout=10,
)
if r.status_code in (200, 201, 204):
_contact_cache["fetched_at"] = None
return True
logger.warning(f"CardDAV update PUT returned {r.status_code}: {r.text[:200]}")
return False
except Exception as e:
logger.error(f"Failed to update contact: {e}")
return False
def _delete_contact(uid: str) -> bool:
"""Delete a contact via CardDAV or local contacts."""
cfg = _get_carddav_config()
if not _carddav_configured(cfg):
contacts = _load_local_contacts()
remaining = [c for c in contacts if c.get("uid") != uid]
_save_local_contacts(remaining)
return True
url = _resolve_resource_url(uid)
try:
auth = (cfg["username"], cfg["password"]) if cfg["username"] else None
r = httpx.delete(url, auth=auth, timeout=10)
if r.status_code in (200, 204):
_contact_cache["fetched_at"] = None
return True
if r.status_code == 404:
# Resource not found at the resolved URL. With href resolution
# this should be rare (genuinely already deleted). Invalidate
# the cache and report success so the UI doesn't keep a ghost.
logger.info(f"CardDAV DELETE 404 for {uid} — treating as already gone")
_contact_cache["fetched_at"] = None
return True
logger.warning(f"CardDAV DELETE returned {r.status_code}: {r.text[:200]}")
return False
except Exception as e:
logger.error(f"Failed to delete contact: {e}")
return False
# ── Routes ──
def setup_contacts_routes():
router = APIRouter(prefix="/api/contacts", tags=["contacts"])
@router.get("/list")
async def list_contacts(_admin: str = Depends(require_admin)):
"""List all contacts."""
contacts = _fetch_contacts()
return {"contacts": contacts, "count": len(contacts)}
@router.get("/search")
async def search_contacts(q: str = Query(""), _admin: str = Depends(require_admin)):
"""Search contacts by name or email. Returns up to 10 matches."""
contacts = _fetch_contacts()
if not q:
return {"results": []}
q_lower = q.lower()
results = []
for c in contacts:
if q_lower in c["name"].lower():
results.append(c)
continue
for em in c["emails"]:
if q_lower in em.lower():
results.append(c)
break
return {"results": results[:10]}
@router.post("/add")
async def add_contact(data: dict, _admin: str = Depends(require_admin)):
"""Add a new contact."""
name = data.get("name", "").strip()
email = data.get("email", "").strip()
if not email:
return {"success": False, "error": "Email required"}
# Check if already exists
contacts = _fetch_contacts()
for c in contacts:
if email.lower() in [e.lower() for e in c["emails"]]:
return {"success": True, "message": "Already exists", "contact": c}
if not name:
name = email.split("@")[0]
ok = _create_contact(name, email)
return {"success": ok}
@router.post("/import")
async def import_vcf(data: dict, _admin: str = Depends(require_admin)):
"""Import contacts from .vcf or CSV. Body: {"vcf": "..."} or {"csv": "..."}."""
text = data.get("vcf") or data.get("text") or ""
csv_text = data.get("csv") or ""
if text.strip():
if "BEGIN:VCARD" not in text.upper():
return {"success": False, "error": "No vCard data found"}
result = _import_vcards(text)
elif csv_text.strip():
result = _import_csv_contacts(csv_text)
else:
return {"success": False, "error": "No contact data found"}
result["success"] = result.get("imported", 0) > 0
return result
@router.get("/export")
async def export_contacts(
format: str = Query("vcf", pattern="^(vcf|csv)$"),
_admin: str = Depends(require_admin),
):
"""Export all contacts as vCard or CSV."""
contacts = _fetch_contacts(force=True)
if format == "csv":
content = _contacts_to_csv(contacts)
media_type = "text/csv; charset=utf-8"
filename = "odysseus-contacts.csv"
else:
content = _contacts_to_vcf(contacts)
media_type = "text/vcard; charset=utf-8"
filename = "odysseus-contacts.vcf"
return Response(
content=content,
media_type=media_type,
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
@router.get("/config")
async def get_config(_admin: str = Depends(require_admin)):
cfg = _get_carddav_config()
# Mask password
if cfg["password"]:
cfg["password"] = "***"
return cfg
@router.put("/config")
async def update_config(data: dict, _admin: str = Depends(require_admin)):
settings = _load_settings()
for key in ("carddav_url", "carddav_username", "carddav_password"):
if key in data:
settings[key] = data[key]
_save_settings(settings)
# Force re-fetch
_contact_cache["fetched_at"] = None
return {"success": True}
@router.delete("/clear")
async def clear_contacts(_admin: str = Depends(require_admin)):
"""Clear all local contacts. If CardDAV is configured, only clears the local fallback cache."""
_save_local_contacts([])
return {"success": True}
# NOTE: the /{uid} routes are declared LAST so the literal paths above
# (/list, /search, /add, /config) win — otherwise PUT /config would
# match PUT /{uid} with uid="config".
@router.put("/{uid}")
async def edit_contact(uid: str, data: dict, _admin: str = Depends(require_admin)):
"""Edit an existing contact — name / emails / phones."""
name = (data.get("name") or "").strip()
emails = data.get("emails")
phones = data.get("phones")
if emails is None and data.get("email"):
emails = [data["email"]]
emails = [e.strip() for e in (emails or []) if e and e.strip()]
phones = [p.strip() for p in (phones or []) if p and p.strip()]
if not name and not emails:
return {"success": False, "error": "Name or email required"}
if not name and emails:
name = emails[0].split("@")[0]
ok = _update_contact(uid, name, emails, phones)
return {"success": ok}
@router.delete("/{uid}")
async def delete_contact(uid: str, _admin: str = Depends(require_admin)):
"""Delete a contact by UID."""
if not uid:
return {"success": False, "error": "UID required"}
ok = _delete_contact(uid)
return {"success": ok}
return router

340
routes/cookbook_helpers.py Normal file
View File

@@ -0,0 +1,340 @@
"""cookbook_helpers.py — validators + small helpers shared by the cookbook routes.
Extracted from cookbook_routes.py; the routes module imports the symbols it needs."""
import logging
import os
import re
import shlex
from fastapi import HTTPException
from pydantic import BaseModel
logger = logging.getLogger(__name__)
# HuggingFace repo IDs are <org>/<name>, both alphanumerics plus ._-
# Rejecting anything else up front closes off shell-interpolation vectors.
_REPO_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*/[A-Za-z0-9][A-Za-z0-9._-]*$")
# Include pattern is a glob: allow typical safe glyphs only.
_INCLUDE_RE = re.compile(r"^[A-Za-z0-9._\-*?/\[\]]+$")
# Remote host: user@host (optionally with :port-free hostname parts).
_REMOTE_HOST_RE = re.compile(r"^[A-Za-z0-9._-]+@[A-Za-z0-9._-]+$")
# HF tokens and API tokens are url-safe base64-like.
_TOKEN_RE = re.compile(r"^[A-Za-z0-9._~+/=-]+$")
# Session IDs we mint look like "cookbook-deadbeef" or "serve-deadbeef".
# Anything beyond plain alphanumerics + dash + underscore could break out
# of the shell/PowerShell contexts the value lands in.
_SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
_SSH_PORT_RE = re.compile(r"^\d{1,5}$")
_GPU_LIST_RE = re.compile(r"^\d+(?:,\d+)*$")
# A download target directory. Absolute or ~-relative path; safe path glyphs
# only (no quotes, shell metacharacters, or spaces) since it lands in a shell
# command. A leading ~ is expanded to $HOME at command-build time.
_LOCAL_DIR_RE = re.compile(r"^~?/[A-Za-z0-9._/-]*$|^~$")
def _validate_repo_id(v: str | None) -> str:
if not v or not _REPO_ID_RE.match(v):
raise HTTPException(400, "Invalid repo_id — must be <org>/<name> using [A-Za-z0-9._-]")
return v
def _validate_include(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _INCLUDE_RE.match(v):
raise HTTPException(400, "Invalid include pattern")
return v
def _validate_remote_host(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _REMOTE_HOST_RE.match(v):
raise HTTPException(400, "Invalid remote_host — must be user@host, no SSH option syntax")
return v
def _validate_token(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _TOKEN_RE.match(v):
raise HTTPException(400, "Invalid token characters")
return v
def _validate_local_dir(v: str | None) -> str | None:
if v is None or v == "":
return None
v = v.rstrip("/") or "/"
if not _LOCAL_DIR_RE.match(v):
raise HTTPException(400, "Invalid local_dir — must be an absolute or ~ path with no spaces or shell metacharacters")
return v
def _validate_ssh_port(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _SSH_PORT_RE.fullmatch(str(v)):
raise HTTPException(400, "Invalid ssh_port")
port = int(v)
if port < 1 or port > 65535:
raise HTTPException(400, "Invalid ssh_port")
return str(port)
def _validate_gpus(v: str | None) -> str | None:
if v is None or v == "":
return None
if not _GPU_LIST_RE.fullmatch(str(v)):
raise HTTPException(400, "Invalid gpus — expected comma-separated GPU indexes")
return str(v)
def _shell_path(p: str) -> str:
"""Render a validated path for a double-quoted shell context, expanding a
leading ~ to $HOME (single quotes wouldn't expand it). Safe because
_validate_local_dir already restricts the charset."""
if p == "~":
return '"$HOME"'
if p.startswith("~/"):
return '"$HOME/' + p[2:] + '"'
return '"' + p + '"'
def _ps_squote(v: str) -> str:
"""Escape a value for PowerShell single-quoted string interpolation.
Belt-and-suspenders on top of _validate_token's regex — if the regex
is ever loosened, this still keeps the heredoc shell-safe."""
return v.replace("'", "''")
def _bash_squote(v: str) -> str:
"""Escape a value for bash/sh single-quoted string interpolation."""
return v.replace("'", "'\\''")
# Allow-list of binaries permitted as the leading token of `req.cmd` for /api/model/serve.
# Anything else is rejected before the cmd is interpolated into a tmux/PowerShell wrapper.
_SERVE_CMD_ALLOWLIST = {
"vllm", "llama-server", "llama_server", "llama.cpp", "ollama",
"python", "python3",
"sglang", "lmdeploy",
"node", "npx",
}
# The llama.cpp GGUF launcher (static/js/cookbook.js) emits a fixed-shape
# prelude that resolves the cached .gguf on the target host before serving:
# MODEL_FILE=$( { find …; find …; } | head -1 ) && { [ -n "$MODEL_FILE" ] && \
# [ -f "$MODEL_FILE" ]; } || { echo "ERROR…"; exit 1; } && <serve> || <serve>
# That legitimately needs $(...)/&&/||, so we recognise this exact shape and
# validate the serve binaries it guards rather than rejecting it wholesale.
_GGUF_PRELUDE_RE = re.compile(
r'^MODEL_FILE=\$\([^\n]*?\)\s*&&\s*\{[^{}]*\}\s*\|\|\s*\{[^{}]*\}\s*&&\s*'
)
def _check_serve_binary(seg: str) -> None:
"""Validate that a single command segment starts with an allowlisted binary
(after skipping leading env-var assignments like `CUDA_VISIBLE_DEVICES=0`)."""
try:
tokens = shlex.split(seg) if seg.strip() else []
except ValueError:
raise HTTPException(400, "Invalid cmd — could not parse")
if not tokens:
return
env_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*=")
first = next((t for t in tokens if not env_re.match(t)), "")
base = os.path.basename(first)
if base not in _SERVE_CMD_ALLOWLIST:
raise HTTPException(
400,
f"cmd binary '{base or '(empty)'}' is not allowed. Must start with one of: "
f"{', '.join(sorted(_SERVE_CMD_ALLOWLIST))}",
)
def _validate_serve_cmd(v: str | None) -> str | None:
"""Reject serve commands that aren't in the allowlist or contain shell metachars.
`req.cmd` is dropped verbatim into a bash/PowerShell wrapper script and
executed in a tmux session. Without this gate, an admin (or anyone in the
pre-fix world) could pass arbitrary shell payloads.
Leading env-var assignments (e.g. `CUDA_VISIBLE_DEVICES=0 python3 ...`)
are stripped before checking the binary — several of our cmd builders
prepend them, and they shouldn't trip the allowlist.
"""
if v is None or v == "":
return None
# Collapse backslash-newline line continuations into single spaces. Serve
# commands (vLLM especially) are routinely pasted multi-line with trailing
# `\` — that's a safe shell/shlex continuation, so the command stays ONE
# logical invocation and the leading-token allowlist below still governs.
v = re.sub(r"\\[ \t]*\r?\n[ \t]*", " ", v).strip()
# Backticks and raw newlines are never legitimate here.
if any(c in v for c in ("`", "\n", "\r")):
raise HTTPException(400, "Invalid characters in cmd")
# Known GGUF launcher prelude → validate the serve invocation(s) it guards.
m = _GGUF_PRELUDE_RE.match(v)
if m:
rest = v[m.end():]
# rest is `[ENV=…] python3 -m llama_cpp.server … || [ENV=…] llama-server …`
for part in rest.split("||"):
_check_serve_binary(part.strip())
return v
# Otherwise: a single invocation — no shell metacharacters allowed.
# (`$(` was the original intent; bare `$` is fine for shell-safe paths.)
if any(c in v for c in (";", "&&", "||", "$(")):
raise HTTPException(400, "Invalid characters in cmd")
_check_serve_binary(v)
return v
class ModelDownloadRequest(BaseModel):
repo_id: str
include: str | None = None # glob pattern e.g. "*Q4_K_M*"
hf_token: str | None = None
env_prefix: str | None = None # e.g. "source ~/venv/bin/activate"
remote_host: str | None = None # e.g. "gpu-box" — run download on this host via SSH
ssh_port: str | None = None # e.g. "8022" for Termux
platform: str | None = None # "linux", "termux", or "windows"
local_dir: str | None = None # base dir to download into (a per-model subfolder is created under it); None = default HF cache
disable_hf_transfer: bool = False # skip the Rust hf_transfer downloader — slower but far more reliable on large files (used by retries)
class ServeRequest(BaseModel):
repo_id: str
cmd: str
remote_host: str | None = None
ssh_port: str | None = None
env_prefix: str | None = None
hf_token: str | None = None
gpus: str | None = None
platform: str | None = None # "linux", "termux", or "windows"
def _parse_serve_phase(snapshot: str, task_type: str = "serve") -> dict:
"""Parse a tmux snapshot of a serve task into structured phase info.
Single source of truth for serve task status detection. Returns:
{ "phase": str, "status": "ready"|"running"|"", "tps": float|None,
"reqs": int|None, "pct": int|None }
"""
import re
if task_type != "serve" or not snapshot:
return {}
# Strip newlines so tmux line-wrapping doesn't break regex matching
flat = re.sub(r'\s+', ' ', snapshot)
load_matches = re.findall(r'Loading safetensors.*?(\d+)%', flat)
# Prefer "Downloading (incomplete total...)" (real aggregate bytes) over
# "Fetching N files" (whole-file count, lags with hf_transfer's chunked pulls).
downloading_matches = re.findall(r'Downloading.*?(\d+)%', flat)
fetching_matches = re.findall(r'Fetching.*?(\d+)%', flat)
dl_matches = downloading_matches if downloading_matches else fetching_matches
# Match "Avg generation throughput: X tokens/s, Running: N reqs" (with line-wrap tolerance)
tps_matches = re.findall(
r'(?:Avg )?generation throughput:\s*([\d.]+)\s*tokens/s.*?Running:\s*(\d+)\s*reqs',
flat,
)
# Check throughput FIRST — the throughput log line contains "GPU KV cache usage"
# which would otherwise false-match the warmup check
if tps_matches:
tps_str, reqs_str = tps_matches[-1]
tps = float(tps_str)
reqs = int(reqs_str)
return {
"phase": f"{tps_str} tok/s" if reqs > 0 else "idle",
"status": "ready",
"tps": tps,
"reqs": reqs,
}
if "Application startup complete" in flat:
return {"phase": "ready", "status": "ready"}
# HTTP access logs (e.g. GET /v1/models 200 OK) mean the server is up and serving
if re.search(r'(?:GET|POST)\s+/[^\s]*\s+HTTP/[\d.]+"\s*\d{3}', flat):
return {"phase": "idle", "status": "ready"}
if "Loading weights took" in flat:
return {"phase": "initializing", "status": "running"}
# "GPU KV cache" alone (during allocation) — not "GPU KV cache usage" (runtime log)
if "GPU KV cache" in flat and "GPU KV cache usage" not in flat:
return {"phase": "warming up", "status": "running"}
if load_matches:
pct = int(load_matches[-1])
return {"phase": f"loading {pct}%", "status": "running", "pct": pct}
if dl_matches:
pct = int(dl_matches[-1])
return {"phase": f"downloading {pct}%", "status": "running", "pct": pct}
return {}
def _ssh(host, cmd, port=None):
"""Build SSH command string with optional port."""
pf = f"-p {port} " if port and port != "22" else ""
return f"ssh {pf}{host} '{cmd}'"
def _safe_env_prefix(ep: str | None) -> str | None:
"""Rewrite a `source <path>` env_prefix so it no-ops if the path is missing.
Prevents `line N: <path>: No such file or directory` errors when a serve
task is launched against a host that doesn't have the expected venv.
Also rewrites leading `~/` → `$HOME/` so the path expands inside double
quotes (bash only tilde-expands unquoted tokens at word start)."""
if not ep:
return ep
import shlex
try:
parts = shlex.split(ep, posix=True)
except ValueError:
raise HTTPException(400, "Invalid env_prefix")
if len(parts) != 2 or parts[0] not in {"source", "."}:
# Bash conda activation emitted by the frontend:
# eval "$(conda shell.bash hook)" && conda activate ENV
m = re.fullmatch(r'eval "\$\(conda shell\.bash hook\)" && conda activate (.+)', ep)
if m:
env = m.group(1).strip()
try:
env_parts = shlex.split(env, posix=True)
except ValueError:
raise HTTPException(400, "Invalid env_prefix")
if len(env_parts) != 1:
raise HTTPException(400, "Invalid env_prefix")
return 'eval "$(conda shell.bash hook)" && conda activate ' + shlex.quote(env_parts[0])
# Plain conda activation, used by Windows/PowerShell and some manual callers.
if len(parts) == 3 and parts[0] == "conda" and parts[1] == "activate":
return "conda activate " + shlex.quote(parts[2])
# PowerShell venv activation emitted by the frontend:
# & 'C:\path\Scripts\Activate.ps1'
if len(parts) == 2 and parts[0] == "&":
path = parts[1]
if any(c in path for c in "\r\n;&|`$<>"):
raise HTTPException(400, "Invalid env_prefix")
return "& '" + path.replace("'", "''") + "'"
raise HTTPException(400, "Invalid env_prefix")
path = parts[1]
if any(c in path for c in "\r\n;&|`$<>"):
raise HTTPException(400, "Invalid env_prefix")
# Replace a leading "~/" with "$HOME/" so it survives quoting
if path.startswith("~/"):
path = "$HOME/" + path[2:]
elif path == "~":
path = "$HOME"
path = path.replace('"', '\\"')
return f'[ -f "{path}" ] && source "{path}" || true'
def _ssh_ps(host, script_path, port=None):
"""Build SSH command to run a PowerShell script on a Windows remote."""
pf = f"-p {port} " if port and port != "22" else ""
return f'ssh {pf}{host} "powershell -ExecutionPolicy Bypass -File {script_path}"'
# Windows session dir — stored in user's temp on the remote
WIN_SESSION_DIR = "$env:TEMP\\\\odysseus-sessions"

1728
routes/cookbook_routes.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,71 @@
"""Diagnostics routes — /api/db/stats, /api/rag/stats, /api/test/youtube, /api/test-research."""
import logging
from typing import Dict, Any
from fastapi import APIRouter, HTTPException, Form
from services.youtube.youtube_handler import extract_youtube_id, extract_transcript_async
from core.constants import DEFAULT_HOST
logger = logging.getLogger(__name__)
def setup_diagnostics_routes(
rag_manager,
rag_available: bool,
research_handler,
) -> APIRouter:
router = APIRouter(tags=["diagnostics"])
@router.get("/api/db/stats")
async def get_database_stats() -> Dict[str, Any]:
try:
from core.database import get_detailed_stats
return get_detailed_stats()
except Exception as e:
logger.error(f"DB stats error: {e}")
raise HTTPException(500, "Failed to retrieve database statistics")
@router.get("/api/rag/stats")
async def get_rag_stats() -> Dict[str, Any]:
if rag_available and rag_manager:
return rag_manager.get_stats()
return {"error": "RAG system not available"}
@router.get("/api/test/youtube")
async def test_youtube(url: str) -> Dict[str, Any]:
try:
video_id = extract_youtube_id(url)
if not video_id:
return {"error": "Invalid YouTube URL"}
data = await extract_transcript_async(url, video_id)
return {
"video_id": video_id,
"transcript_success": data.get("success", False),
"transcript_length": len(data.get("transcript", "")) if data.get("success") else 0,
"transcript_preview": (data.get("transcript", "")[:500] + "...")
if data.get("success") and len(data.get("transcript", "")) > 500
else data.get("transcript", ""),
"error": data.get("error") if not data.get("success") else None,
}
except Exception as e:
return {"error": str(e)}
@router.post("/api/test-research")
async def test_research(query: str = Form("What is machine learning?")) -> Dict[str, Any]:
try:
endpoint = f"http://{DEFAULT_HOST}:8000/v1/chat/completions"
model = "gpt-oss-120b"
result = await research_handler.call_research_service(query, endpoint, model)
return {
"status": "success",
"query": query,
"result_preview": result[:200] + "..." if len(result) > 200 else result,
"result_length": len(result),
}
except Exception as e:
return {"status": "error", "error": str(e), "query": query}
return router

198
routes/document_helpers.py Normal file
View File

@@ -0,0 +1,198 @@
"""document_helpers.py — Pydantic models, doc serializers, owner gating, file-locator helpers shared with document_routes.py."""
"""Document routes — CRUD for living documents with version history."""
import logging
from typing import Dict, Any, Optional
from fastapi import HTTPException
from pydantic import BaseModel
from core.database import Document, DocumentVersion
from core.database import Session as DbSession
logger = logging.getLogger(__name__)
# ---- Request schemas ----
class DocumentCreate(BaseModel):
session_id: Optional[str] = None
title: str = "Untitled"
language: Optional[str] = None
content: str = ""
class DocumentUpdate(BaseModel):
content: str
summary: Optional[str] = None
class DocumentPatch(BaseModel):
title: Optional[str] = None
language: Optional[str] = None
session_id: Optional[str] = None # link/unlink document to a session
# ---- Helpers ----
def _doc_to_dict(doc: Document) -> Dict[str, Any]:
return {
"id": doc.id,
"session_id": doc.session_id,
"title": doc.title,
"language": doc.language,
"current_content": doc.current_content,
"version_count": doc.version_count,
"is_active": doc.is_active,
"archived": bool(getattr(doc, "archived", False)),
"created_at": (doc.created_at.isoformat() + "Z") if doc.created_at else None,
"updated_at": (doc.updated_at.isoformat() + "Z") if doc.updated_at else None,
# Source-email provenance (set when doc was created from an email
# attachment) — drives the "Send signed reply" menu item.
"source_email_uid": getattr(doc, "source_email_uid", None),
"source_email_folder": getattr(doc, "source_email_folder", None),
"source_email_account_id": getattr(doc, "source_email_account_id", None),
"source_email_message_id": getattr(doc, "source_email_message_id", None),
}
def _version_to_dict(v: DocumentVersion) -> Dict[str, Any]:
return {
"id": v.id,
"document_id": v.document_id,
"version_number": v.version_number,
"content": v.content,
"summary": v.summary,
"source": v.source,
"created_at": v.created_at.isoformat() if v.created_at else None,
}
def _verify_doc_owner(db, doc: Document, user: str):
"""Verify `user` owns this document. Raise 404 if not.
Documents now carry their own `owner` column, so a doc whose session
was deleted (session_id → NULL) can still prove ownership and stay
openable / cloneable. We trust that column first and only fall back to
the session join for any not-yet-backfilled legacy row.
"""
if user is None:
raise HTTPException(403, "Authentication required")
if doc.owner is not None:
if doc.owner != user:
raise HTTPException(404, "Document not found")
return
# Legacy fallback: derive ownership from the linked session.
if not doc.session_id:
raise HTTPException(404, "Document not found")
session = db.query(DbSession).filter(DbSession.id == doc.session_id).first()
if not session or session.owner != user:
raise HTTPException(404, "Document not found")
def _owner_session_filter(q, user):
"""Restrict a documents query to those owned by `user`.
Documents now carry their own `owner` column (backfilled at boot from
the linked session, or assigned to the admin user for legacy/orphaned
docs). We filter on that directly rather than on a session join, so a
document whose session was deleted (session_id → NULL) still shows up
for its owner instead of silently vanishing from the Library + search.
The owner backfill runs in init_db before the app serves requests, so
by the time this filter is live there are no NULL-owner rows to leak;
we therefore match the owner strictly."""
if user is None:
return q.filter(False)
return q.filter(Document.owner == user)
def _slug(name: str) -> str:
"""Filesystem-friendly version of a document title.
Whitespace becomes underscores; other unsafe punctuation is dropped.
Preserves letters, digits, dot, hyphen, underscore. Idempotent.
"""
import re as _re
s = (name or "").strip()
# Drop the trailing extension if the title happens to include one
s = _re.sub(r'\.pdf$', '', s, flags=_re.IGNORECASE)
s = _re.sub(r'\s+', '_', s)
s = _re.sub(r'[^A-Za-z0-9._-]', '', s)
s = _re.sub(r'_+', '_', s).strip('_')
return s or "form"
# DPI scale for the interactive PDF view. ~150 DPI (2x of 72 PDF user-units).
_PDF_RENDER_SCALE = 2.0
def _locate_upload(upload_dir: str, file_id: str):
"""Find an upload by its filename ID.
Lookup order:
1. Direct hit at `upload_dir/file_id` (very small deployments).
2. The `uploads.json` index that `UploadHandler.save_upload` maintains —
maps file_hash → metadata containing the full path. O(1) once loaded.
3. Fallback: `os.walk` the date-bucketed tree. Slow on large stores;
only triggers for legacy uploads recorded before the index existed.
`followlinks=False` keeps a stray symlink loop in `data/uploads/` from
spinning the walker into infinite recursion.
"""
import os
import json as _json
direct = os.path.join(upload_dir, file_id)
if os.path.exists(direct):
return direct
# O(1) via uploads.json
try:
idx_path = os.path.join(upload_dir, "uploads.json")
if os.path.exists(idx_path):
with open(idx_path, "r") as f:
idx = _json.load(f)
for meta in (idx.values() if isinstance(idx, dict) else []):
if meta.get("id") == file_id:
p = meta.get("path")
if p and os.path.exists(p):
return p
except Exception:
pass
for root, _dirs, files in os.walk(upload_dir, followlinks=False):
if file_id in files:
return os.path.join(root, file_id)
return None
def _derive_title(content: str) -> str:
"""Derive a title from document content."""
import re
text = content.strip()
if not text:
return "Untitled"
# Markdown header
md = re.match(r'^#{1,3}\s+(.+)', text, re.MULTILINE)
if md:
title = md.group(1).strip()
if len(title) > 50:
title = title[:48] + ""
return title
# HTML heading
html = re.search(r'<h[1-3][^>]*>([^<]+)</h[1-3]>', text, re.IGNORECASE)
if html:
title = html.group(1).strip()
if len(title) > 50:
title = title[:48] + ""
return title
# First non-empty line (if short enough)
for line in text.split('\n'):
line = line.strip()
if line and 2 <= len(line) <= 60:
title = re.sub(r'[:#*`]+$', '', line).strip()
if title and len(title) > 50:
title = title[:48] + ""
return title or "Untitled"
return "Untitled"

1643
routes/document_routes.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,184 @@
"""Editor draft routes — persisted in-progress gallery-editor sessions.
The gallery editor (image canvas) lets users layer edits on top of a
photo (or a blank canvas). Persisting those layered sessions to the
server makes them survive cache clears and roams across devices —
unlike the legacy per-image localStorage drafts.
Each draft carries:
- id — opaque uuid (the client never sees gallery-image ids
as draft ids, so blank-canvas drafts work too)
- source_image_id (nullable) — back-pointer for "this draft started as
an edit of GalleryImage X"
- payload — full JSON snapshot (layers as base64 PNG dataURLs,
offsets, opacities, etc.) the editor knows how to
rehydrate
- thumbnail — small data URL for the landing-list grid
"""
import json
import logging
import uuid
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from core.database import EditorDraft, SessionLocal
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
class DraftCreate(BaseModel):
name: Optional[str] = None
source_image_id: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
payload: Dict[str, Any]
thumbnail: Optional[str] = None
class DraftUpdate(BaseModel):
name: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
payload: Optional[Dict[str, Any]] = None
thumbnail: Optional[str] = None
def _owns(d: EditorDraft, user: Optional[str]) -> bool:
if user is None:
return True
return (d.owner or None) == user
def _summary(d: EditorDraft) -> Dict[str, Any]:
"""List-view representation — omits the bulky payload."""
return {
"id": d.id,
"name": d.name or "Untitled",
"source_image_id": d.source_image_id,
"width": d.width,
"height": d.height,
"thumbnail": d.thumbnail,
"created_at": d.created_at.isoformat() if d.created_at else None,
"updated_at": d.updated_at.isoformat() if d.updated_at else None,
}
def setup_editor_draft_routes() -> APIRouter:
router = APIRouter(tags=["editor-drafts"])
@router.get("/api/editor-drafts")
async def list_drafts(request: Request) -> Dict[str, List[Dict[str, Any]]]:
user = get_current_user(request)
db = SessionLocal()
try:
q = db.query(EditorDraft).filter(EditorDraft.is_active == True)
if user is not None:
q = q.filter(EditorDraft.owner == user)
rows = q.order_by(EditorDraft.updated_at.desc()).limit(200).all()
return {"drafts": [_summary(d) for d in rows]}
finally:
db.close()
@router.get("/api/editor-drafts/{draft_id}")
async def get_draft(request: Request, draft_id: str) -> Dict[str, Any]:
user = get_current_user(request)
db = SessionLocal()
try:
d = db.query(EditorDraft).filter(
EditorDraft.id == draft_id, EditorDraft.is_active == True
).first()
if not d or not _owns(d, user):
raise HTTPException(404, "Draft not found")
try:
payload = json.loads(d.payload) if d.payload else {}
except Exception:
payload = {}
return {
**_summary(d),
"payload": payload,
}
finally:
db.close()
@router.post("/api/editor-drafts")
async def create_draft(request: Request, body: DraftCreate) -> Dict[str, Any]:
user = get_current_user(request)
db = SessionLocal()
try:
d = EditorDraft(
id=str(uuid.uuid4()),
owner=user,
name=(body.name or "Untitled")[:200],
source_image_id=body.source_image_id,
width=body.width,
height=body.height,
payload=json.dumps(body.payload or {}),
thumbnail=body.thumbnail,
)
db.add(d)
db.commit()
db.refresh(d)
return _summary(d)
except Exception as e:
db.rollback()
logger.warning(f"editor-draft create failed: {e}")
raise HTTPException(500, "Could not save draft")
finally:
db.close()
@router.put("/api/editor-drafts/{draft_id}")
async def update_draft(request: Request, draft_id: str, body: DraftUpdate) -> Dict[str, Any]:
user = get_current_user(request)
db = SessionLocal()
try:
d = db.query(EditorDraft).filter(
EditorDraft.id == draft_id, EditorDraft.is_active == True
).first()
if not d or not _owns(d, user):
raise HTTPException(404, "Draft not found")
if body.name is not None:
d.name = body.name[:200]
if body.width is not None:
d.width = body.width
if body.height is not None:
d.height = body.height
if body.payload is not None:
d.payload = json.dumps(body.payload)
if body.thumbnail is not None:
d.thumbnail = body.thumbnail
db.commit()
db.refresh(d)
return _summary(d)
except HTTPException:
raise
except Exception as e:
db.rollback()
logger.warning(f"editor-draft update failed: {e}")
raise HTTPException(500, "Could not update draft")
finally:
db.close()
@router.delete("/api/editor-drafts/{draft_id}")
async def delete_draft(request: Request, draft_id: str) -> Dict[str, str]:
user = get_current_user(request)
db = SessionLocal()
try:
d = db.query(EditorDraft).filter(EditorDraft.id == draft_id).first()
if not d or not _owns(d, user):
raise HTTPException(404, "Draft not found")
d.is_active = False
db.commit()
return {"status": "deleted", "id": draft_id}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(500, str(e))
finally:
db.close()
return router

1274
routes/email_helpers.py Normal file

File diff suppressed because it is too large Load Diff

1006
routes/email_pollers.py Normal file

File diff suppressed because it is too large Load Diff

3038
routes/email_routes.py Normal file

File diff suppressed because it is too large Load Diff

318
routes/embedding_routes.py Normal file
View File

@@ -0,0 +1,318 @@
# routes/embedding_routes.py
"""Routes for managing local fastembed embedding models and custom endpoints."""
import os
import json
import shutil
import logging
import asyncio
from pathlib import Path
from fastapi import APIRouter, HTTPException, Form
from core.constants import BASE_DIR
logger = logging.getLogger(__name__)
_ENDPOINT_FILE = os.path.join(BASE_DIR, "data", "embedding_endpoint.json")
# Track in-progress downloads
_downloading: dict = {}
# Curated recommendations — good coverage of size/quality tiers
RECOMMENDED_MODELS = {
"sentence-transformers/all-MiniLM-L6-v2", # 384d, 90MB — fast & tiny, good default
"BAAI/bge-small-en-v1.5", # 384d, 67MB — smallest, solid quality
"nomic-ai/nomic-embed-text-v1.5-Q", # 768d, 130MB — quantized, great bang/buck
"BAAI/bge-base-en-v1.5", # 768d, 210MB — balanced mid-range
"snowflake/snowflake-arctic-embed-m", # 768d, 430MB — strong performer
"BAAI/bge-large-en-v1.5", # 1024d, 1.2GB — highest quality
}
def _cache_dir() -> str:
"""Get the fastembed cache directory.
Defaults to a persistent path under the repo's data/ dir. The old
default lived in /tmp, which many systems wipe on reboot — forcing a
full re-download of the embedding model after every restart.
"""
env = os.environ.get("FASTEMBED_CACHE_PATH")
if env:
return env
return os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"data", "fastembed_cache",
)
def _model_cache_name(hf_source: str) -> str:
"""Convert HF source like 'qdrant/all-MiniLM-L6-v2-onnx' to cache dir name."""
return "models--" + hf_source.replace("/", "--")
def _is_downloaded(hf_source: str) -> bool:
"""Check if a model is already cached."""
cache = _cache_dir()
model_dir = os.path.join(cache, _model_cache_name(hf_source))
if not os.path.isdir(model_dir):
return False
# Check for actual model files (not just empty dir)
snapshots = os.path.join(model_dir, "snapshots")
if os.path.isdir(snapshots):
return any(os.listdir(snapshots))
# Also check for blobs (older cache format)
blobs = os.path.join(model_dir, "blobs")
return os.path.isdir(blobs) and any(os.listdir(blobs))
def _active_model() -> str:
"""Get the currently configured fastembed model name."""
return os.environ.get("FASTEMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
def _dir_size_mb(path: str) -> float:
"""Get directory size in MB."""
total = 0
for dirpath, _, filenames in os.walk(path):
for f in filenames:
fp = os.path.join(dirpath, f)
try:
total += os.path.getsize(fp)
except OSError:
pass
return round(total / (1024 * 1024), 1)
def _load_custom_endpoint() -> dict:
"""Load the saved custom embedding endpoint, if any."""
try:
if os.path.exists(_ENDPOINT_FILE):
return json.loads(Path(_ENDPOINT_FILE).read_text())
except Exception:
pass
return {}
def _save_custom_endpoint(data: dict):
Path(_ENDPOINT_FILE).parent.mkdir(parents=True, exist_ok=True)
Path(_ENDPOINT_FILE).write_text(json.dumps(data, indent=2))
def setup_embedding_routes():
router = APIRouter(prefix="/api/embeddings")
@router.get("/models")
def list_models():
"""List all available fastembed models with download status."""
try:
from fastembed import TextEmbedding
except ImportError:
raise HTTPException(503, "fastembed is not installed")
active = _active_model()
catalog = TextEmbedding.list_supported_models()
result = []
for m in catalog:
hf_src = m.get("sources", {}).get("hf", "")
downloaded = _is_downloaded(hf_src) if hf_src else False
cached_size = None
if downloaded and hf_src:
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
cached_size = _dir_size_mb(model_path)
result.append({
"model": m["model"],
"dim": m.get("dim"),
"size_gb": m.get("size_in_GB", 0),
"description": m.get("description", ""),
"downloaded": downloaded,
"downloading": m["model"] in _downloading,
"active": m["model"] == active,
"recommended": m["model"] in RECOMMENDED_MODELS,
"cached_size_mb": cached_size,
})
# Sort: active first, then downloaded, then by size
result.sort(key=lambda x: (not x["active"], not x["downloaded"], x["size_gb"]))
return result
@router.post("/models/{model_name:path}/download")
async def download_model(model_name: str):
"""Download a fastembed model. Returns when complete."""
try:
from fastembed import TextEmbedding
except ImportError:
raise HTTPException(503, "fastembed is not installed")
# Validate model exists
catalog = {m["model"]: m for m in TextEmbedding.list_supported_models()}
if model_name not in catalog:
raise HTTPException(404, f"Unknown model: {model_name}")
hf_src = catalog[model_name].get("sources", {}).get("hf", "")
if hf_src and _is_downloaded(hf_src):
return {"status": "already_downloaded", "model": model_name}
if model_name in _downloading:
return {"status": "already_downloading", "model": model_name}
_downloading[model_name] = True
try:
# Run in thread to not block the event loop
loop = asyncio.get_event_loop()
cache = _cache_dir()
await loop.run_in_executor(
None,
lambda: TextEmbedding(model_name=model_name, cache_dir=cache),
)
return {"status": "downloaded", "model": model_name}
except Exception as e:
logger.error(f"Failed to download {model_name}: {e}")
raise HTTPException(500, f"Download failed: {str(e)}")
finally:
_downloading.pop(model_name, None)
@router.get("/models/{model_name:path}/status")
def download_status(model_name: str):
"""Check download status of a model."""
try:
from fastembed import TextEmbedding
except ImportError:
raise HTTPException(503, "fastembed is not installed")
catalog = {m["model"]: m for m in TextEmbedding.list_supported_models()}
if model_name not in catalog:
raise HTTPException(404, f"Unknown model: {model_name}")
hf_src = catalog[model_name].get("sources", {}).get("hf", "")
downloaded = _is_downloaded(hf_src) if hf_src else False
return {
"model": model_name,
"downloaded": downloaded,
"downloading": model_name in _downloading,
}
@router.delete("/models/{model_name:path}")
def delete_model(model_name: str):
"""Delete a cached model."""
if model_name == _active_model():
raise HTTPException(400, "Cannot delete the active embedding model")
if model_name in _downloading:
raise HTTPException(400, "Model is currently downloading")
try:
from fastembed import TextEmbedding
except ImportError:
raise HTTPException(503, "fastembed is not installed")
catalog = {m["model"]: m for m in TextEmbedding.list_supported_models()}
if model_name not in catalog:
raise HTTPException(404, f"Unknown model: {model_name}")
hf_src = catalog[model_name].get("sources", {}).get("hf", "")
if not hf_src:
raise HTTPException(400, "No cache source for this model")
model_path = os.path.join(_cache_dir(), _model_cache_name(hf_src))
if not os.path.isdir(model_path):
return {"deleted": False, "message": "Model not cached"}
shutil.rmtree(model_path)
logger.info(f"Deleted cached model: {model_name} ({model_path})")
return {"deleted": True, "model": model_name}
@router.get("/endpoint")
def get_endpoint():
"""Get the current custom embedding endpoint config."""
saved = _load_custom_endpoint()
current_url = os.environ.get("EMBEDDING_URL", "")
return {
"url": saved.get("url", current_url),
"model": saved.get("model", os.environ.get("EMBEDDING_MODEL", "")),
"active": bool(saved.get("url") or current_url),
}
@router.post("/endpoint")
def set_endpoint(url: str = Form(...), model: str = Form("")):
"""Save a custom embedding endpoint URL."""
url = url.strip()
if not url:
raise HTTPException(400, "URL is required")
# Quick health check
try:
import httpx
resp = httpx.post(
url,
json={"input": ["test"], "model": model or "test"},
timeout=10,
)
resp.raise_for_status()
except Exception as e:
raise HTTPException(400, f"Endpoint unreachable: {e}")
# Persist and set in environment for immediate use
data = {"url": url}
if model:
data["model"] = model
_save_custom_endpoint(data)
os.environ["EMBEDDING_URL"] = url
if model:
os.environ["EMBEDDING_MODEL"] = model
# Reset the RAG singleton so it picks up the new endpoint
import src.rag_singleton as _rs
_rs.rag_instance = None
_rs._last_attempt = 0
# Clear the HTTP-embedding "down" latch so the new endpoint is re-probed
# instead of staying on the FastEmbed fallback for the process lifetime.
try:
from src.embeddings import reset_http_embed_state
reset_http_embed_state()
except Exception:
pass
# Reset ChromaDB client (collections will be recreated with new embeddings)
try:
from src.chroma_client import reset_client
reset_client()
except Exception:
pass
logger.info(f"Custom embedding endpoint set: {url}")
return {"success": True, "url": url, "model": model}
@router.delete("/endpoint")
def clear_endpoint():
"""Clear the custom endpoint and revert to local fastembed."""
if os.path.exists(_ENDPOINT_FILE):
os.remove(_ENDPOINT_FILE)
# Remove from environment
os.environ.pop("EMBEDDING_URL", None)
os.environ.pop("EMBEDDING_MODEL", None)
# Reset the RAG singleton so it falls back to fastembed
import src.rag_singleton as _rs
_rs.rag_instance = None
_rs._last_attempt = 0
try:
from src.embeddings import reset_http_embed_state
reset_http_embed_state()
except Exception:
pass
# Reset ChromaDB client
try:
from src.chroma_client import reset_client
reset_client()
except Exception:
pass
logger.info("Custom embedding endpoint cleared, reverting to local fastembed")
return {"success": True}
return router

70
routes/emoji_routes.py Normal file
View File

@@ -0,0 +1,70 @@
# routes/emoji_routes.py
# Same-origin emoji SVG proxy. The frontend rewrites emoji in chat to a
# <span class="emoji" style="--em:url('/api/emoji/<codepoints>.svg')">
# which uses the returned SVG as a CSS mask tinted to the text color, so emoji
# render as monochrome line icons (project rule: never colorful emoji). The
# black line-art SVGs are lazily fetched from the OpenMoji CDN on first use and
# cached on disk, so:
# - the client only ever talks to our own origin (no CDN dep, no CSP change),
# - the repo isn't bloated with thousands of SVG files,
# - it works offline once an emoji has been seen once.
# Unknown/unreachable codepoints return a transparent SVG (not 404), so the CSS
# mask shows nothing rather than a solid currentColor box.
import logging
import re
from pathlib import Path
import httpx
from fastapi import APIRouter
from fastapi.responses import FileResponse, Response
logger = logging.getLogger(__name__)
_CACHE_DIR = Path(__file__).resolve().parent.parent / "data" / "emoji_cache"
# OpenMoji "black" set = monochrome line-art SVGs. Filenames are the codepoints
# in UPPERCASE (FE0F dropped, same as we compute), '-' joined.
_OPENMOJI_BASE = "https://cdn.jsdelivr.net/npm/openmoji@15.0.0/black/svg"
# codepoints like "1f600" or "1f468-200d-1f469-200d-1f467" (lowercase hex, '-' joined)
_CODE_RE = re.compile(r"^[0-9a-f]{2,6}(?:-[0-9a-f]{2,6})*$")
_SVG_HEADERS = {"Cache-Control": "public, max-age=31536000, immutable"}
# Returned when a codepoint is unknown/unreachable: an empty (transparent) SVG,
# so the CSS mask renders nothing instead of a solid box. Not cached, so a later
# request can still pick up the real glyph once the CDN is reachable.
_BLANK_SVG = b'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1 1"></svg>'
_BLANK_HEADERS = {"Cache-Control": "no-store"}
def setup_emoji_routes() -> APIRouter:
router = APIRouter(prefix="/api/emoji", tags=["emoji"])
def _blank() -> Response:
return Response(_BLANK_SVG, media_type="image/svg+xml", headers=_BLANK_HEADERS)
@router.get("/{code}.svg")
async def emoji_svg(code: str):
code = code.lower()
if not _CODE_RE.match(code):
return _blank()
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
fp = _CACHE_DIR / f"{code}.svg"
if fp.exists():
return FileResponse(fp, media_type="image/svg+xml", headers=_SVG_HEADERS)
# First time we've seen this emoji — fetch the OpenMoji black SVG + cache
# it. OpenMoji filenames are the codepoints uppercased.
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.get(f"{_OPENMOJI_BASE}/{code.upper()}.svg")
if r.status_code == 200 and b"<svg" in r.content[:256]:
try:
fp.write_bytes(r.content)
except Exception:
pass # cache write is best-effort
return Response(r.content, media_type="image/svg+xml", headers=_SVG_HEADERS)
except Exception as e:
logger.warning("emoji fetch %s failed: %s", code, e)
return _blank()
return router

47
routes/font_routes.py Normal file
View File

@@ -0,0 +1,47 @@
"""Custom font discovery — lists user-supplied font files in static/fonts/custom/."""
import os
import re
from fastapi import APIRouter
CUSTOM_FONTS_DIR = os.path.join("static", "fonts", "custom")
FONT_EXTENSIONS = {".ttf", ".otf", ".woff", ".woff2"}
def _derive_family(filename):
"""Derive a font-family name from a filename like 'JetBrainsMono-Regular.woff2''JetBrains Mono'."""
name = os.path.splitext(filename)[0]
# Strip common weight/style suffixes
name = re.sub(
r'[-_ ]?(Thin|ExtraLight|UltraLight|Light|Regular|Medium|SemiBold|DemiBold|Bold|ExtraBold|UltraBold|Black|Heavy|Italic|Oblique|Variable|VF)$',
'', name, flags=re.IGNORECASE
)
# Insert spaces before uppercase runs: "JetBrainsMono" → "Jet Brains Mono"
name = re.sub(r'(?<=[a-z])(?=[A-Z])', ' ', name)
# Replace dashes/underscores with spaces
name = re.sub(r'[-_]+', ' ', name).strip()
return name or filename
def setup_font_routes():
router = APIRouter(prefix="/api/fonts", tags=["fonts"])
@router.get("/custom")
async def list_custom_fonts():
"""Return available custom fonts grouped by derived family name."""
os.makedirs(CUSTOM_FONTS_DIR, exist_ok=True)
families = {}
for f in sorted(os.listdir(CUSTOM_FONTS_DIR)):
ext = os.path.splitext(f)[1].lower()
if ext not in FONT_EXTENSIONS:
continue
family = _derive_family(f)
if family not in families:
families[family] = []
families[family].append({
"file": f,
"url": f"/static/fonts/custom/{f}",
"format": ext.lstrip('.'),
})
return {"fonts": families}
return router

125
routes/gallery_helpers.py Normal file
View File

@@ -0,0 +1,125 @@
"""gallery_helpers.py — extracted helpers, models, and small utilities.
Imported by gallery_routes.py."""
"""Gallery routes — browsable library for photos and AI-generated images."""
import logging
from datetime import datetime
from typing import Dict, Any, Optional
from pydantic import BaseModel
from core.database import GalleryImage
logger = logging.getLogger(__name__)
# ---- Request schemas ----
class GalleryPatch(BaseModel):
tags: Optional[str] = None
favorite: Optional[bool] = None
album_id: Optional[str] = None
# ---- EXIF extraction ----
def _extract_exif(content: bytes) -> dict:
"""Extract EXIF metadata from image bytes. Returns dict of fields."""
result = {"width": None, "height": None}
try:
from PIL import Image
from io import BytesIO
img = Image.open(BytesIO(content))
result["width"] = img.width
result["height"] = img.height
exif = img._getexif() if hasattr(img, '_getexif') else None
if not exif:
return result
# EXIF tag IDs
# 271=Make, 272=Model, 306=DateTime, 36867=DateTimeOriginal
# 34853=GPSInfo
result["camera_make"] = str(exif.get(271, "")).strip() or None
result["camera_model"] = str(exif.get(272, "")).strip() or None
# Date taken
for tag_id in (36867, 36868, 306): # DateTimeOriginal, DateTimeDigitized, DateTime
raw = exif.get(tag_id)
if raw:
try:
result["taken_at"] = datetime.strptime(str(raw).strip(), "%Y:%m:%d %H:%M:%S")
break
except (ValueError, TypeError):
pass
# GPS
gps_info = exif.get(34853)
if gps_info and isinstance(gps_info, dict):
try:
def _to_deg(vals):
d, m, s = [float(v) for v in vals]
return d + m / 60 + s / 3600
if 2 in gps_info and 4 in gps_info:
lat = _to_deg(gps_info[2])
lng = _to_deg(gps_info[4])
if gps_info.get(1) == 'S': lat = -lat
if gps_info.get(3) == 'W': lng = -lng
result["gps_lat"] = f"{lat:.6f}"
result["gps_lng"] = f"{lng:.6f}"
except Exception:
pass
except Exception as e:
# User-visible failure (photo loses metadata): surface at WARNING
# and record on the result so the upload endpoint can pass it back.
logger.warning(f"EXIF extraction failed: {e}")
result["exif_error"] = str(e)
return result
# ---- Helpers ----
def _image_to_dict(img: GalleryImage, session_name: str = None) -> Dict[str, Any]:
return {
"id": img.id,
"filename": img.filename,
"url": f"/api/generated-image/{img.filename}",
"prompt": img.prompt,
"model": img.model,
"size": img.size,
"quality": img.quality,
"tags": img.tags or "",
"ai_tags": img.ai_tags or "",
"user_tags": img.tags or "",
"session_id": img.session_id,
"session_name": session_name,
"album_id": img.album_id,
"is_active": img.is_active,
"favorite": img.favorite or False,
"taken_at": img.taken_at.isoformat() if img.taken_at else None,
"camera": f"{img.camera_make or ''} {img.camera_model or ''}".strip() or None,
"gps": {"lat": img.gps_lat, "lng": img.gps_lng} if img.gps_lat else None,
"width": img.width,
"height": img.height,
"file_size": img.file_size,
"created_at": img.created_at.isoformat() if img.created_at else None,
"updated_at": img.updated_at.isoformat() if img.updated_at else None,
}
def _owner_filter(q, user):
"""Apply owner filtering to a gallery query."""
if user is None:
return q.filter(False)
return q.filter(GalleryImage.owner == user)
def _human_size(nbytes):
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if abs(nbytes) < 1024:
return f"{nbytes:.1f} {unit}"
nbytes /= 1024
return f"{nbytes:.1f} PB"

1763
routes/gallery_routes.py Normal file

File diff suppressed because it is too large Load Diff

619
routes/history_routes.py Normal file
View File

@@ -0,0 +1,619 @@
"""History routes — session history, truncation, fork, conversation topics."""
import json
import uuid
import logging
from typing import Dict, Any
from fastapi import APIRouter, Request, HTTPException
from core.models import ChatMessage
from core.database import SessionLocal, ChatMessage as DbChatMessage, Session as DbSession
from src.topic_analyzer import analyze_topics
from routes.session_routes import _verify_session_owner
logger = logging.getLogger(__name__)
def setup_history_routes(session_manager) -> APIRouter:
router = APIRouter(tags=["history"])
@router.get("/api/history/{session_id}")
async def get_session_history(request: Request, session_id: str) -> Dict[str, Any]:
_verify_session_owner(request, session_id)
try:
session = session_manager.get_session(session_id)
except KeyError:
raise HTTPException(404, f"Session '{session_id}' not found")
history_dict = []
for msg in session.history:
if isinstance(msg, ChatMessage):
# Skip hidden messages (e.g. compaction summaries for AI context)
if msg.metadata and msg.metadata.get("hidden"):
continue
entry = {"role": msg.role, "content": msg.content}
if msg.metadata:
entry["metadata"] = msg.metadata
history_dict.append(entry)
elif isinstance(msg, dict):
if msg.get("metadata", {}).get("hidden"):
continue
entry = {
"role": msg.get("role", ""),
"content": msg.get("content", ""),
}
if msg.get("metadata"):
entry["metadata"] = msg["metadata"]
history_dict.append(entry)
# Fallback: load from DB if in-memory is empty
if not history_dict:
db = SessionLocal()
try:
db_messages = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id)
.order_by(DbChatMessage.timestamp)
.all()
)
import json as _json
history_dict = []
for m in db_messages:
entry = {"role": m.role, "content": m.content}
meta = {}
if m.meta_data:
try:
meta = _json.loads(m.meta_data) or {}
except (json.JSONDecodeError, ValueError):
meta = {}
if m.timestamp and "timestamp" not in meta:
meta["timestamp"] = m.timestamp.isoformat() + "Z"
if meta:
entry["metadata"] = meta
history_dict.append(entry)
if history_dict:
session.history = [
ChatMessage(role=m["role"], content=m["content"], metadata=m.get("metadata"))
for m in history_dict
]
except Exception as e:
logger.error(f"DB fallback failed for {session_id}: {e}")
finally:
db.close()
return {
"history": history_dict,
"model": session.model,
"endpoint_url": session.endpoint_url,
"name": session.name,
}
@router.post("/api/session/{session_id}/truncate")
async def truncate_session(request: Request, session_id: str):
_verify_session_owner(request, session_id)
try:
body = await request.json()
keep_count = body.get("keep_count", 0)
result = session_manager.truncate_messages(session_id, keep_count)
return {"status": "ok", "kept": keep_count, "truncated": result}
except KeyError:
raise HTTPException(404, "Session not found")
except Exception as e:
logger.error(f"Truncate error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.post("/api/session/{session_id}/message")
async def add_message(request: Request, session_id: str):
"""Add a message to a session (for slash command persistence)."""
_verify_session_owner(request, session_id)
try:
body = await request.json()
role = body.get("role", "assistant")
content = body.get("content", "")
if not content:
raise HTTPException(400, "content is required")
msg = ChatMessage(role=role, content=content, metadata=body.get("metadata"))
session_manager.add_message(session_id, msg)
return {"status": "ok"}
except KeyError:
raise HTTPException(404, "Session not found")
@router.post("/api/session/{session_id}/delete-messages")
async def delete_messages(request: Request, session_id: str):
"""Delete specific messages by DB ID (or legacy index)."""
_verify_session_owner(request, session_id)
try:
body = await request.json()
msg_ids = body.get("msg_ids", [])
indices = body.get("indices") # legacy fallback
session = session_manager.get_session(session_id)
db = SessionLocal()
try:
if msg_ids:
# New ID-based delete
deleted = 0
for mid in msg_ids:
db_msg = db.query(DbChatMessage).filter(
DbChatMessage.id == mid,
DbChatMessage.session_id == session_id,
).first()
if db_msg:
db.delete(db_msg)
deleted += 1
# Remove from in-memory history by matching _db_id
def _get_db_id(m):
meta = m.metadata if isinstance(m, ChatMessage) else (m.get('metadata') if isinstance(m, dict) else None)
return meta.get('_db_id') if isinstance(meta, dict) else None
session.history = [m for m in session.history if _get_db_id(m) not in msg_ids]
elif indices:
# Legacy index-based delete
indices = sorted(indices, reverse=True)
db_messages = db.query(DbChatMessage).filter(
DbChatMessage.session_id == session_id
).order_by(DbChatMessage.timestamp).all()
deleted = 0
for idx in indices:
if 0 <= idx < len(db_messages):
db.delete(db_messages[idx])
deleted += 1
if 0 <= idx < len(session.history):
session.history.pop(idx)
else:
return {"status": "ok", "deleted": 0}
session.message_count = len(session.history)
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
if db_session:
db_session.message_count = len(session.history)
from datetime import datetime, timezone
db_session.updated_at = datetime.now(timezone.utc)
db.commit()
return {"status": "ok", "deleted": deleted}
finally:
db.close()
except KeyError:
raise HTTPException(404, "Session not found")
except Exception as e:
logger.error(f"Delete messages error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.post("/api/session/{session_id}/edit-message")
async def edit_message(request: Request, session_id: str):
"""Edit the content of a message by its database ID."""
_verify_session_owner(request, session_id)
try:
body = await request.json()
msg_id = body.get("msg_id")
content = body.get("content")
if not msg_id or content is None:
raise HTTPException(400, "msg_id and content are required")
session = session_manager.get_session(session_id)
db = SessionLocal()
try:
db_msg = db.query(DbChatMessage).filter(
DbChatMessage.id == msg_id,
DbChatMessage.session_id == session_id,
).first()
if not db_msg:
raise HTTPException(404, "Message not found")
db_msg.content = content
meta = {}
if db_msg.meta_data:
try: meta = json.loads(db_msg.meta_data)
except (json.JSONDecodeError, ValueError): pass
meta['edited'] = True
db_msg.meta_data = json.dumps(meta)
# Update in-memory history by matching _db_id
for hmsg in session.history:
hmeta = hmsg.metadata if isinstance(hmsg, ChatMessage) else hmsg.get('metadata')
if isinstance(hmeta, dict) and hmeta.get('_db_id') == msg_id:
if isinstance(hmsg, ChatMessage):
hmsg.content = content
hmsg.metadata['edited'] = True
elif isinstance(hmsg, dict):
hmsg['content'] = content
hmsg['metadata']['edited'] = True
break
db.commit()
return {"status": "ok"}
finally:
db.close()
except KeyError:
raise HTTPException(404, "Session not found")
except HTTPException:
raise
except Exception as e:
logger.error(f"Edit message error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.post("/api/session/{session_id}/mark-stopped")
async def mark_stopped(request: Request, session_id: str):
"""Mark the last assistant message as stopped by user."""
_verify_session_owner(request, session_id)
try:
session = session_manager.get_session(session_id)
# Find last assistant message and add stopped metadata
for msg in reversed(session.history):
if (isinstance(msg, ChatMessage) and msg.role == 'assistant') or \
(isinstance(msg, dict) and msg.get('role') == 'assistant'):
if isinstance(msg, ChatMessage):
if not msg.metadata:
msg.metadata = {}
msg.metadata['stopped'] = True
if not msg.metadata.get('model'):
msg.metadata['model'] = session.model
else:
if 'metadata' not in msg:
msg['metadata'] = {}
msg['metadata']['stopped'] = True
if not msg['metadata'].get('model'):
msg['metadata']['model'] = session.model
break
# Also update in DB
db = SessionLocal()
try:
import json as _json
db_messages = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
.order_by(DbChatMessage.created_at.desc())
.first()
)
if db_messages:
meta = {}
if db_messages.meta_data:
try:
meta = _json.loads(db_messages.meta_data)
except (json.JSONDecodeError, ValueError):
pass
meta['stopped'] = True
if not meta.get('model'):
meta['model'] = session.model
db_messages.meta_data = _json.dumps(meta)
db.commit()
finally:
db.close()
session_manager.save_sessions()
return {"status": "ok"}
except KeyError:
raise HTTPException(404, "Session not found")
except Exception as e:
logger.error(f"Mark stopped error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.post("/api/session/{session_id}/update-last-meta")
async def update_last_meta(request: Request, session_id: str):
"""Merge metadata into the last assistant message (e.g. save variants)."""
_verify_session_owner(request, session_id)
try:
body = await request.json()
meta_update = body.get("metadata", {})
session = session_manager.get_session(session_id)
# Update in-memory
for msg in reversed(session.history):
if (isinstance(msg, ChatMessage) and msg.role == 'assistant') or \
(isinstance(msg, dict) and msg.get('role') == 'assistant'):
if isinstance(msg, ChatMessage):
if not msg.metadata:
msg.metadata = {}
msg.metadata.update(meta_update)
else:
if 'metadata' not in msg:
msg['metadata'] = {}
msg['metadata'].update(meta_update)
break
# Update in DB
db = SessionLocal()
try:
import json as _json
db_msg = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id, DbChatMessage.role == 'assistant')
.order_by(DbChatMessage.created_at.desc())
.first()
)
if db_msg:
meta = {}
if db_msg.meta_data:
try: meta = _json.loads(db_msg.meta_data)
except (json.JSONDecodeError, ValueError): pass
meta.update(meta_update)
db_msg.meta_data = _json.dumps(meta)
db.commit()
finally:
db.close()
session_manager.save_sessions()
return {"status": "ok"}
except KeyError:
raise HTTPException(404, "Session not found")
except Exception as e:
logger.error(f"Update last meta error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.post("/api/session/{session_id}/merge-last-assistant")
async def merge_last_assistant(request: Request, session_id: str):
"""Merge the last two assistant messages into one (for continue)."""
_verify_session_owner(request, session_id)
try:
body = await request.json()
separator = body.get("separator", "\n\n")
session = session_manager.get_session(session_id)
# Find last two assistant messages in-memory
ai_indices = []
for i, msg in enumerate(session.history):
role = msg.role if isinstance(msg, ChatMessage) else msg.get('role', '')
if role == 'assistant':
ai_indices.append(i)
if len(ai_indices) < 2:
return {"status": "ok", "merged": False}
idx1, idx2 = ai_indices[-2], ai_indices[-1]
msg1, msg2 = session.history[idx1], session.history[idx2]
content1 = msg1.content if isinstance(msg1, ChatMessage) else msg1.get('content', '')
content2 = msg2.content if isinstance(msg2, ChatMessage) else msg2.get('content', '')
merged_content = content1 + separator + content2
# Merge metadata
meta1 = (msg1.metadata if isinstance(msg1, ChatMessage) else msg1.get('metadata')) or {}
meta2 = (msg2.metadata if isinstance(msg2, ChatMessage) else msg2.get('metadata')) or {}
merged_meta = {**meta1, **meta2}
merged_meta.pop('stopped', None) # no longer stopped after continue
# Update first message, remove second
if isinstance(msg1, ChatMessage):
msg1.content = merged_content
msg1.metadata = merged_meta
else:
msg1['content'] = merged_content
msg1['metadata'] = merged_meta
# Also remove the hidden "continue" user message between them if present
# It's the message at idx2-1 if it's a user message with continue text
remove_indices = [idx2]
if idx2 - 1 > idx1:
between = session.history[idx2 - 1]
between_role = between.role if isinstance(between, ChatMessage) else between.get('role', '')
between_content = between.content if isinstance(between, ChatMessage) else between.get('content', '')
if between_role == 'user' and 'previous response was interrupted' in between_content:
remove_indices.insert(0, idx2 - 1)
for ri in sorted(remove_indices, reverse=True):
session.history.pop(ri)
# Update DB
db = SessionLocal()
try:
import json as _json
db_messages = (
db.query(DbChatMessage)
.filter(DbChatMessage.session_id == session_id)
.order_by(DbChatMessage.created_at)
.all()
)
# Find last two assistant messages in DB
ai_db = [(i, m) for i, m in enumerate(db_messages) if m.role == 'assistant']
if len(ai_db) >= 2:
(_, db1), (_, db2) = ai_db[-2], ai_db[-1]
db1.content = merged_content
db1.meta_data = _json.dumps(merged_meta)
# Remove the continue user message if between them
db_idx2 = db_messages.index(db2)
db_idx1 = db_messages.index(db1)
for di in range(db_idx2, db_idx1, -1):
db.delete(db_messages[di])
db.commit()
finally:
db.close()
session_manager.save_sessions()
return {"status": "ok", "merged": True}
except KeyError:
raise HTTPException(404, "Session not found")
except Exception as e:
logger.error(f"Merge assistant error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.post("/api/session/{session_id}/fork")
async def fork_session(request: Request, session_id: str):
"""Create a new session with messages copied up to keep_count."""
_verify_session_owner(request, session_id)
try:
body = await request.json()
keep_count = body.get("keep_count", 0)
# Get the source session
source = session_manager.sessions.get(session_id)
if not source:
raise HTTPException(404, "Session not found")
# Create new session
new_id = str(uuid.uuid4())
fork_name = f"\u2ADD {source.name}"
new_session = session_manager.create_session(
session_id=new_id,
name=fork_name,
endpoint_url=source.endpoint_url,
model=source.model,
rag=False,
owner=getattr(source, 'owner', None),
)
# Copy messages up to keep_count
msgs_to_copy = source.history[:keep_count]
for msg in msgs_to_copy:
new_session.add_message(ChatMessage(msg.role, msg.content, msg.metadata))
try:
from src.event_bus import fire_event
fire_event("session_created", getattr(source, 'owner', None))
except Exception:
logger.debug("session_created event dispatch failed", exc_info=True)
return {
"status": "ok",
"id": new_id,
"name": fork_name,
"kept": len(msgs_to_copy),
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Fork error {session_id}: {e}")
raise HTTPException(500, str(e))
@router.get("/api/conversations/topics")
async def get_conversation_topics(request: Request) -> Dict[str, Any]:
from src.auth_helpers import get_current_user
user = get_current_user(request)
try:
return analyze_topics(session_manager, owner=user)
except Exception as e:
raise HTTPException(500, f"Topic analysis failed: {e}")
@router.post("/api/session/{session_id}/compact")
async def compact_session(request: Request, session_id: str):
"""Manually trigger context compaction for a session."""
_verify_session_owner(request, session_id)
try:
session = session_manager.get_session(session_id)
except KeyError:
raise HTTPException(404, "Session not found")
try:
from src.model_context import estimate_tokens, get_context_length
from src.llm_core import llm_call_async
from src.endpoint_resolver import resolve_endpoint
if len(session.history) < 6:
return {"status": "ok", "message": "Not enough messages to compact"}
ctx_len = get_context_length(session.endpoint_url, session.model)
messages_before = session.get_context_messages()
used_before = estimate_tokens(messages_before)
pct_before = round((used_before / ctx_len) * 100, 1) if ctx_len else 0
msg_count_before = len(session.history)
# Keep only last 4 messages, summarize the rest
keep_count = 4
older = session.history[:-keep_count]
recent = session.history[-keep_count:]
# Build text to summarize
convo_text = "\n".join(
f"{(m.role if isinstance(m, ChatMessage) else m.get('role', '')).upper()}: "
f"{(m.content if isinstance(m, ChatMessage) else m.get('content', ''))[:2000]}"
for m in older
)
# Use utility model if available
util_url, util_model, util_headers = resolve_endpoint("utility")
compact_url = util_url or session.endpoint_url
compact_model = util_model or session.model
compact_headers = util_headers if util_url else session.headers
from src.context_compactor import SELF_SUMMARY_SYSTEM_PROMPT
compaction_count = sum(1 for m in session.history if isinstance(m, ChatMessage) and "[Conversation summary" in (m.content or ""))
sys_prompt = SELF_SUMMARY_SYSTEM_PROMPT.replace("{count}", str(len(older))).replace("{n}", str(compaction_count + 1))
summary = await llm_call_async(
compact_url, compact_model,
[
{"role": "system", "content": sys_prompt},
{"role": "user", "content": convo_text},
],
temperature=0.2, max_tokens=1024,
headers=compact_headers, timeout=30,
)
# Replace session history: summary as system message + recent messages
# System message holds the full summary for AI context
system_summary = ChatMessage(
role="system",
content=f"[Conversation summary — {len(older)} earlier messages were compacted]\n\n{summary}",
metadata={"compacted": True, "hidden": True},
)
# Visible assistant message just shows stats
summary_msg = ChatMessage(
role="assistant",
content=f"**Conversation compacted** — {len(older)} messages summarized, {len(recent)} kept.",
metadata={"compacted": True, "messages_removed": len(older)},
)
new_history = [system_summary, summary_msg] + list(recent)
session.history = new_history
session.message_count = len(session.history)
logger.info(f"Compact: session {session_id} history now has {len(session.history)} messages (was {msg_count_before})")
# Update DB: delete old messages, insert summary
db = SessionLocal()
try:
db_msgs = db.query(DbChatMessage).filter(
DbChatMessage.session_id == session_id
).order_by(DbChatMessage.timestamp).all()
# Delete all but the last keep_count
for m in db_msgs[:-keep_count]:
db.delete(m)
# Insert system summary (hidden, for AI context) and visible summary
import json as _json
import uuid
from datetime import datetime, timezone
now = datetime.now(timezone.utc)
db_sys_summary = DbChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="system",
content=system_summary.content,
meta_data=_json.dumps(system_summary.metadata),
timestamp=now,
)
db.add(db_sys_summary)
db_summary = DbChatMessage(
id=str(uuid.uuid4()),
session_id=session_id,
role="assistant",
content=summary_msg.content,
meta_data=_json.dumps(summary_msg.metadata),
timestamp=now,
)
db.add(db_summary)
# Update session record
db_session = db.query(DbSession).filter(DbSession.id == session_id).first()
if db_session:
db_session.message_count = len(session.history)
db_session.updated_at = datetime.now(timezone.utc)
db.commit()
finally:
db.close()
session_manager.save_sessions()
used_after = estimate_tokens(session.get_context_messages())
pct_after = round((used_after / ctx_len) * 100, 1) if ctx_len else 0
return {
"status": "ok",
"message": f"Compacted: {msg_count_before} msgs → {len(session.history)} msgs ({pct_before}% → {pct_after}%)",
"before": pct_before,
"after": pct_after,
}
except Exception as e:
logger.error(f"Manual compact error {session_id}: {e}")
raise HTTPException(500, str(e))
return router

204
routes/hwfit_routes.py Normal file
View File

@@ -0,0 +1,204 @@
from copy import deepcopy
from fastapi import APIRouter
def setup_hwfit_routes():
router = APIRouter(prefix="/api/hwfit", tags=["hwfit"])
def _apply_manual_hardware(system, manual_mode="", manual_gpu_count="", manual_vram_gb="", manual_ram_gb="", manual_backend=""):
"""Manual hardware is a "what if I had this setup" simulator —
REPLACES the detected hardware entirely instead of adding to it.
The previous additive behavior averaged the manual VRAM across
all GPUs (base + manual), which meant adding "1× 400 GB" on top
of "2× 70 GB" only nudged the per-GPU cap from 70 to 180 GB
(= 540 / 3), so GGUF models bigger than that still didn't surface
— exactly the "cap stuck at detected level" bug the user hit.
"""
manual_mode = (manual_mode or "").lower()
if manual_mode not in {"gpu", "ram"}:
return system
try:
override_ram_gb = float(manual_ram_gb) if manual_ram_gb else 0
except ValueError:
override_ram_gb = 0
override_ram_gb = max(0.0, override_ram_gb)
if override_ram_gb:
# Replace RAM, don't add. The number in the field is the
# TOTAL system memory the user wants to simulate.
system["available_ram_gb"] = round(override_ram_gb, 1)
system["total_ram_gb"] = round(override_ram_gb, 1)
system["manual_hardware"] = True
if manual_mode == "ram":
# RAM-only simulation — wipe GPU entirely so the ranker uses
# CPU/RAM paths.
system["has_gpu"] = False
system["gpu_name"] = None
system["gpu_vram_gb"] = 0
system["gpu_count"] = 0
system["gpus"] = []
system["gpu_groups"] = []
system["backend"] = "cpu_x86"
return system
try:
count = int(manual_gpu_count) if manual_gpu_count else 1
except ValueError:
count = 1
try:
vram_each = float(manual_vram_gb) if manual_vram_gb else 8.0
except ValueError:
vram_each = 8.0
count = max(1, min(count, 16))
vram_each = max(1.0, vram_each)
backend = (manual_backend or system.get("backend") or "cuda").lower()
if backend not in {"cuda", "rocm", "cpu_x86", "cpu_arm"}:
backend = "cuda"
total_vram = round(vram_each * count, 1)
gpu_name = f"Simulated {backend.upper()} GPU" + (f" × {count}" if count > 1 else "")
system["has_gpu"] = True
system["gpu_name"] = gpu_name
system["gpu_vram_gb"] = total_vram
system["gpu_count"] = count
system["gpus"] = [
{"index": i, "name": gpu_name, "vram_gb": vram_each}
for i in range(count)
]
# Single homogeneous pool — vram_each here is the ACTUAL per-GPU
# VRAM the user entered, not an average. That's the whole point:
# raising vram_each lifts the per-GPU cap (GGUF, tensor-parallel
# math) all the way up, not just by a small fraction.
system["gpu_groups"] = [{
"name": gpu_name,
"vram_each": vram_each,
"count": count,
"indices": list(range(count)),
"vram_total": total_vram,
}]
system["homogeneous"] = True
system["backend"] = backend
return system
@router.get("/system")
def get_system(host: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False):
"""Detect and return current system hardware info. Pass host=user@server for remote.
fresh=true bypasses the per-host cache (the Rescan button)."""
from services.hwfit.hardware import detect_system
return detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh)
@router.get("/models")
def get_models(use_case: str = "", sort: str = "score", limit: int = 50, search: str = "", host: str = "", quant: str = "", gpu_count: str = "", gpu_group: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, manual_mode: str = "", manual_gpu_count: str = "", manual_vram_gb: str = "", manual_ram_gb: str = "", manual_backend: str = "", ignore_detected_gpu: bool = False, ignore_detected_ram: bool = False):
"""Rank LLM models against detected hardware and return scored results.
gpu_count: override GPU count (0 = CPU only, 1-N = simulate N GPUs of the
active group). gpu_group: index into system.gpu_groups (the homogeneous
pools) to target — empty/auto = the largest pool. vLLM can only
tensor-parallel across identical GPUs, so we never mix pools.
fresh=true bypasses the hardware-detection cache."""
from services.hwfit.hardware import detect_system
from services.hwfit.fit import rank_models
from services.hwfit.models import get_models, model_catalog_path
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
if not get_models():
return {
"system": system,
"models": [],
"error": f"Model catalog missing or empty: {model_catalog_path()}",
}
if ignore_detected_gpu:
system["has_gpu"] = False
system["gpu_name"] = None
system["gpu_vram_gb"] = 0
system["gpu_count"] = 0
system["gpus"] = []
system["gpu_groups"] = []
if ignore_detected_ram:
system["available_ram_gb"] = 0
system["total_ram_gb"] = 0
system = _apply_manual_hardware(system, manual_mode, manual_gpu_count, manual_vram_gb, manual_ram_gb, manual_backend)
# Keep the raw detection around so the UI can still show the box's full
# GPU complement even while we rank against one homogeneous pool.
system["detected_gpu_vram_gb"] = system.get("gpu_vram_gb")
system["detected_gpu_count"] = system.get("gpu_count")
groups = system.get("gpu_groups") or []
# Resolve the target homogeneous pool. Default (auto) = the largest pool,
# which for a uniform box is simply "all the GPUs" — no behaviour change.
grp = None
if groups:
try:
gidx = int(gpu_group) if gpu_group != "" else 0
except ValueError:
gidx = 0
if 0 <= gidx < len(groups):
grp = groups[gidx]
def _apply_group(g, n):
n = max(1, min(n, g["count"]))
system["gpu_count"] = n
system["gpu_vram_gb"] = round(g["vram_each"] * n, 1)
system["gpu_name"] = g["name"]
system["active_group"] = {**g, "use_count": n}
if gpu_count != "":
n = int(gpu_count)
if n == 0:
# RAM-only mode: rank against system memory, offload allowed.
system["has_gpu"] = False
system["gpu_vram_gb"] = 0
system["gpu_count"] = 0
system["gpu_only"] = False
system.pop("active_group", None)
elif grp:
_apply_group(grp, n)
system["gpu_only"] = True
else:
# No per-GPU detail (older detection) — assume uniform split.
single_vram = (system.get("gpu_vram_gb") or 0) / (system.get("gpu_count") or 1)
system["gpu_count"] = max(1, n)
system["gpu_vram_gb"] = round(single_vram * max(1, n), 1)
system["gpu_only"] = True
elif grp:
# No explicit count, but we still pin to one pool so heterogeneous
# boxes rank against a real mixable group, not a fictional VRAM sum.
# gpu_only stays off here so the default view still surfaces offload.
_apply_group(grp, grp["count"])
results = rank_models(system, use_case=use_case or None, limit=limit, search=search or None, sort=sort, quant=quant or None)
return {"system": system, "models": results}
@router.get("/image-models")
def get_image_models(sort: str = "fit", search: str = "", host: str = "", gpu_count: str = "", ssh_port: str = "", platform: str = "", fresh: bool = False, manual_mode: str = "", manual_gpu_count: str = "", manual_vram_gb: str = "", manual_ram_gb: str = "", manual_backend: str = "", ignore_detected_gpu: bool = False, ignore_detected_ram: bool = False):
"""Rank image generation models against detected hardware."""
from services.hwfit.hardware import detect_system
from services.hwfit.image_models import rank_image_models
system = deepcopy(detect_system(host=host, ssh_port=ssh_port, platform=platform, fresh=fresh))
if system.get("error"):
return {"system": system, "models": [], "error": system["error"]}
if ignore_detected_gpu:
system["has_gpu"] = False
system["gpu_name"] = None
system["gpu_vram_gb"] = 0
system["gpu_count"] = 0
system["gpus"] = []
system["gpu_groups"] = []
if ignore_detected_ram:
system["available_ram_gb"] = 0
system["total_ram_gb"] = 0
system = _apply_manual_hardware(system, manual_mode, manual_gpu_count, manual_vram_gb, manual_ram_gb, manual_backend)
# Image models use a single GPU — always use per-GPU VRAM
gpu_vrams = [float(g.get("vram_gb") or 0) for g in (system.get("gpus") or []) if isinstance(g, dict)]
single_vram = max(gpu_vrams) if gpu_vrams else ((system.get("gpu_vram_gb") or 0) / max(system.get("gpu_count") or 1, 1))
system["gpu_vram_gb"] = single_vram
system["gpu_count"] = 1 if single_vram > 0 else 0
results = rank_image_models(system, search=search or None, sort=sort)
return {"system": system, "models": results}
return router

574
routes/mcp_routes.py Normal file
View File

@@ -0,0 +1,574 @@
# routes/mcp_routes.py
"""MCP (Model Context Protocol) server management routes."""
import json
import os
import uuid
import urllib.parse
import html
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import RedirectResponse, HTMLResponse
import logging
import httpx
from core.database import McpServer, SessionLocal
from core.middleware import require_admin
from src.mcp_manager import McpManager
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/mcp", tags=["mcp"])
def _load_disabled_map():
"""Load per-server disabled tool sets from DB."""
db = SessionLocal()
try:
disabled_map = {}
for srv in db.query(McpServer).all():
if srv.disabled_tools:
try:
names = json.loads(srv.disabled_tools)
if names:
disabled_map[srv.id] = set(names)
except (json.JSONDecodeError, TypeError):
pass
return disabled_map
finally:
db.close()
def setup_mcp_routes(mcp_manager: McpManager):
"""Setup MCP routes with the provided manager."""
@router.get("/servers")
def list_servers(request: Request):
"""List all configured MCP servers with connection status."""
require_admin(request)
db = SessionLocal()
try:
servers = db.query(McpServer).all()
result = []
for srv in servers:
status = mcp_manager.get_server_status(srv.id)
oauth_cfg = json.loads(srv.oauth_config) if srv.oauth_config else None
needs_oauth = False
if oauth_cfg:
token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
needs_oauth = token_file and not os.path.exists(token_file)
disabled_list = json.loads(srv.disabled_tools) if srv.disabled_tools else []
total_tools = status.get("tool_count", 0)
result.append({
"id": srv.id,
"name": srv.name,
"transport": srv.transport,
"command": srv.command,
"args": json.loads(srv.args) if srv.args else [],
"env": json.loads(srv.env) if srv.env else {},
"url": srv.url,
"is_enabled": srv.is_enabled,
"status": status.get("status", "disconnected"),
"tool_count": total_tools,
"disabled_tool_count": len(disabled_list),
"enabled_tool_count": max(0, total_tools - len(disabled_list)),
"error": status.get("error"),
"has_oauth": oauth_cfg is not None,
"needs_oauth": needs_oauth,
})
return result
finally:
db.close()
@router.post("/servers")
async def add_server(
request: Request,
name: str = Form(...),
transport: str = Form("stdio"),
command: str = Form(None),
args: str = Form("[]"),
env: str = Form("{}"),
url: str = Form(None),
oauth_file: str = Form(None),
oauth_config: str = Form(None),
):
"""Add a new MCP server config and attempt connection. Admin-only:
registering a stdio server is equivalent to executing arbitrary
binaries on the host."""
require_admin(request)
server_id = str(uuid.uuid4())[:8]
# Validate
if transport == "stdio" and not command:
raise HTTPException(400, "command is required for stdio transport")
if transport == "sse" and not url:
raise HTTPException(400, "url is required for SSE transport")
# Parse JSON fields
try:
parsed_args = json.loads(args) if args else []
except json.JSONDecodeError:
parsed_args = []
try:
parsed_env = json.loads(env) if env else {}
except json.JSONDecodeError:
parsed_env = {}
# Parse OAuth config
parsed_oauth_config = None
if oauth_config:
try:
parsed_oauth_config = json.loads(oauth_config)
except json.JSONDecodeError:
pass
# Write OAuth credentials file if provided (for Google MCP servers)
logger.info(f"MCP add_server: oauth_file={oauth_file!r}")
if oauth_file:
try:
oauth_data = json.loads(oauth_file)
oauth_dir = os.path.expanduser(oauth_data.get("dir", ""))
oauth_filename = oauth_data.get("filename", "")
client_id = oauth_data.get("client_id", "")
client_secret = oauth_data.get("client_secret", "")
if oauth_dir and oauth_filename and client_id and client_secret:
os.makedirs(oauth_dir, exist_ok=True)
creds = {
"installed": {
"client_id": client_id,
"client_secret": client_secret,
"redirect_uris": ["http://localhost"],
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://accounts.google.com/o/oauth2/token",
}
}
filepath = os.path.join(oauth_dir, oauth_filename)
with open(filepath, "w") as f:
json.dump(creds, f, indent=2)
logger.info(f"Wrote OAuth credentials to {filepath}")
parsed_env.pop("GOOGLE_CLIENT_ID", None)
parsed_env.pop("GOOGLE_CLIENT_SECRET", None)
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"Failed to write OAuth file: {e}")
# Save to DB
db = SessionLocal()
try:
srv = McpServer(
id=server_id,
name=name,
transport=transport,
command=command,
args=json.dumps(parsed_args),
env=json.dumps(parsed_env),
url=url,
is_enabled=True,
oauth_config=json.dumps(parsed_oauth_config) if parsed_oauth_config else None,
)
db.add(srv)
db.commit()
finally:
db.close()
# Check if OAuth token already exists — skip connection attempt if not
needs_oauth = False
if parsed_oauth_config:
token_file = os.path.expanduser(parsed_oauth_config.get("token_file", ""))
if token_file and not os.path.exists(token_file):
needs_oauth = True
connected = False
if not needs_oauth:
connected = await mcp_manager.connect_server(
server_id=server_id,
name=name,
transport=transport,
command=command,
args=parsed_args,
env=parsed_env,
url=url,
)
status = mcp_manager.get_server_status(server_id)
return {
"id": server_id,
"name": name,
"connected": connected,
"status": "needs_oauth" if needs_oauth else status.get("status", "disconnected"),
"tool_count": status.get("tool_count", 0),
"error": "OAuth authorization required" if needs_oauth else status.get("error"),
"needs_oauth": needs_oauth,
}
@router.post("/servers/{server_id}/reconnect")
async def reconnect_server(server_id: str, request: Request):
"""Reconnect to an MCP server."""
require_admin(request)
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
raise HTTPException(404, "Server not found")
await mcp_manager.disconnect_server(server_id)
args = json.loads(srv.args) if srv.args else []
env = json.loads(srv.env) if srv.env else {}
connected = await mcp_manager.connect_server(
server_id=server_id,
name=srv.name,
transport=srv.transport,
command=srv.command,
args=args,
env=env,
url=srv.url,
)
status = mcp_manager.get_server_status(server_id)
return {
"connected": connected,
"status": status.get("status", "disconnected"),
"tool_count": status.get("tool_count", 0),
"error": status.get("error"),
}
finally:
db.close()
@router.patch("/servers/{server_id}")
async def toggle_server(server_id: str, request: Request, is_enabled: str = Form(...)):
"""Enable or disable an MCP server."""
require_admin(request)
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
raise HTTPException(404, "Server not found")
enabled = str(is_enabled).lower() == "true"
srv.is_enabled = enabled
db.commit()
if enabled:
args = json.loads(srv.args) if srv.args else []
env = json.loads(srv.env) if srv.env else {}
await mcp_manager.connect_server(
server_id=server_id,
name=srv.name,
transport=srv.transport,
command=srv.command,
args=args,
env=env,
url=srv.url,
)
else:
await mcp_manager.disconnect_server(server_id)
return {"id": server_id, "is_enabled": enabled}
finally:
db.close()
@router.delete("/servers/{server_id}")
async def delete_server(server_id: str, request: Request):
"""Remove an MCP server."""
require_admin(request)
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
raise HTTPException(404, "Server not found")
await mcp_manager.disconnect_server(server_id)
db.delete(srv)
db.commit()
return {"status": "deleted"}
finally:
db.close()
@router.get("/tools")
def list_tools(request: Request):
"""List all discovered MCP tools across all connected servers."""
require_admin(request)
disabled_map = _load_disabled_map()
return mcp_manager.get_all_tools(disabled_map)
@router.get("/servers/{server_id}/tools")
def list_server_tools(server_id: str, request: Request):
"""List all tools for a specific MCP server with enabled/disabled state."""
require_admin(request)
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
raise HTTPException(404, "Server not found")
disabled_list = json.loads(srv.disabled_tools) if srv.disabled_tools else []
disabled_set = set(disabled_list)
finally:
db.close()
all_tools = mcp_manager.get_all_tools()
server_tools = [t for t in all_tools if t["server_id"] == server_id]
for t in server_tools:
t["is_disabled"] = t["name"] in disabled_set
return server_tools
@router.patch("/servers/{server_id}/tools")
async def update_disabled_tools(server_id: str, request: Request):
"""Bulk update disabled tools list for a server.
Expects JSON body: {"disabled": ["tool_name_1", "tool_name_2"]}
"""
require_admin(request)
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
raise HTTPException(404, "Server not found")
body = await request.json()
disabled = body.get("disabled", [])
if not isinstance(disabled, list):
raise HTTPException(400, "disabled must be a list of tool names")
srv.disabled_tools = json.dumps(disabled) if disabled else None
db.commit()
return {"id": server_id, "disabled_count": len(disabled)}
finally:
db.close()
# ── OAuth flow for Google MCP servers ──────────────────────────
@router.get("/oauth/authorize/{server_id}")
def oauth_authorize(server_id: str, request: Request):
"""Show OAuth authorization page with Google sign-in link."""
require_admin(request)
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
raise HTTPException(404, "Server not found")
if not srv.oauth_config:
raise HTTPException(400, "Server has no OAuth config")
oauth_cfg = json.loads(srv.oauth_config)
keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
if not keys_file or not os.path.exists(keys_file):
raise HTTPException(400, "OAuth keys file not found")
with open(keys_file) as f:
keys_data = json.load(f)
keys = keys_data.get("installed") or keys_data.get("web")
if not keys:
raise HTTPException(400, "Invalid OAuth keys file format")
client_id = keys["client_id"]
scopes = oauth_cfg.get("scopes", [])
# For Desktop App creds, redirect to localhost — the user will
# paste the resulting URL back if they're on a different device.
redirect_uri = "http://localhost:7000/api/mcp/oauth/callback"
params = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(scopes),
"access_type": "offline",
"prompt": "consent",
"state": server_id,
}
auth_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urllib.parse.urlencode(params)
# Determine if user is accessing from the same machine
host = request.headers.get("host", "")
is_local = host.startswith("localhost") or host.startswith("127.0.0.1")
if is_local:
# Same machine — just redirect, callback will work directly
return RedirectResponse(auth_url)
else:
# Remote device — show paste-back page
return HTMLResponse(_oauth_authorize_page(auth_url, server_id, host))
finally:
db.close()
@router.get("/oauth/callback")
async def oauth_callback(code: str, state: str, request: Request):
"""Handle OAuth callback from Google — exchange code for tokens."""
require_admin(request)
server_id = state
return await _exchange_and_connect(server_id, code, request)
@router.post("/oauth/exchange/{server_id}")
async def oauth_exchange(server_id: str, request: Request, callback_url: str = Form(...)):
"""Manual code exchange — user pastes the callback URL from their browser."""
require_admin(request)
try:
parsed = urllib.parse.urlparse(callback_url)
params = urllib.parse.parse_qs(parsed.query)
code = params.get("code", [None])[0]
if not code:
return HTMLResponse(_oauth_result_page("Error", "No authorization code found in the URL. Make sure you copied the full URL from your browser."), status_code=400)
except Exception:
return HTMLResponse(_oauth_result_page("Error", "Invalid URL format."), status_code=400)
return await _exchange_and_connect(server_id, code, request)
async def _exchange_and_connect(server_id: str, code: str, request: Request):
"""Exchange auth code for tokens and connect the MCP server."""
db = SessionLocal()
try:
srv = db.query(McpServer).filter(McpServer.id == server_id).first()
if not srv:
return HTMLResponse(_oauth_result_page("Error", "Server not found."), status_code=404)
if not srv.oauth_config:
return HTMLResponse(_oauth_result_page("Error", "No OAuth config."), status_code=400)
oauth_cfg = json.loads(srv.oauth_config)
keys_file = os.path.expanduser(oauth_cfg.get("keys_file", ""))
token_file = os.path.expanduser(oauth_cfg.get("token_file", ""))
with open(keys_file) as f:
keys_data = json.load(f)
keys = keys_data.get("installed") or keys_data.get("web")
client_id = keys["client_id"]
client_secret = keys["client_secret"]
redirect_uri = "http://localhost:7000/api/mcp/oauth/callback"
async with httpx.AsyncClient() as client:
resp = await client.post(
"https://oauth2.googleapis.com/token",
data={
"code": code,
"client_id": client_id,
"client_secret": client_secret,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
if resp.status_code != 200:
err = resp.text
logger.error(f"OAuth token exchange failed: {err}")
return HTMLResponse(_oauth_result_page("Authorization Failed", f"Google returned an error: {err}"), status_code=400)
tokens = resp.json()
logger.info(f"OAuth tokens received for server {server_id}")
# Save tokens to the file the MCP package expects
os.makedirs(os.path.dirname(token_file), exist_ok=True)
with open(token_file, "w") as f:
json.dump(tokens, f, indent=2)
logger.info(f"Saved OAuth tokens to {token_file}")
# Attempt to connect the MCP server now
args = json.loads(srv.args) if srv.args else []
env = json.loads(srv.env) if srv.env else {}
connected = await mcp_manager.connect_server(
server_id=server_id,
name=srv.name,
transport=srv.transport,
command=srv.command,
args=args,
env=env,
url=srv.url,
)
if connected:
status = mcp_manager.get_server_status(server_id)
tool_count = status.get("tool_count", 0)
return HTMLResponse(_oauth_result_page(
"Authorization Successful",
f"{srv.name} connected with {tool_count} tools. You can close this window.",
success=True,
))
else:
status = mcp_manager.get_server_status(server_id)
return HTMLResponse(_oauth_result_page(
"Authorized but Connection Failed",
f"Tokens saved, but the server failed to connect: {status.get('error', 'unknown error')}. Try reconnecting from Settings.",
))
except Exception as e:
logger.exception(f"OAuth callback error: {e}")
return HTMLResponse(_oauth_result_page("Error", str(e)), status_code=500)
finally:
db.close()
return router
def _oauth_authorize_page(auth_url: str, server_id: str, host: str) -> str:
"""Page with Google sign-in link and URL paste-back form for remote access."""
return f"""<!DOCTYPE html>
<html><head>
<meta charset="UTF-8"><title>Authorize — Odysseus</title>
<style>
body {{ font-family: 'Fira Code', monospace; background: #0f0f0f; color: #e0e0e0;
display: flex; justify-content: center; align-items: center; min-height: 100vh; }}
.card {{ background: #1a1a1a; border: 1px solid #333; border-radius: 12px;
padding: 2rem; max-width: 480px; text-align: center; }}
h2 {{ color: #e06c75; margin-bottom: 0.5rem; font-size: 1.1rem; }}
p {{ color: #aaa; font-size: 0.82rem; line-height: 1.6; margin: 0.8rem 0; }}
.step {{ text-align: left; color: #ccc; font-size: 0.82rem; line-height: 1.7; margin: 1rem 0; }}
.step b {{ color: #e06c75; }}
a.auth-link {{
display: inline-block; margin: 1rem 0; padding: 0.6rem 1.5rem;
background: #e06c75; color: #fff; text-decoration: none; border-radius: 6px;
font-weight: 600; font-size: 0.9rem;
}}
a.auth-link:hover {{ background: #c55; }}
input[type=text] {{
width: 100%; padding: 0.5rem; margin: 0.5rem 0;
background: #0f0f0f; border: 1px solid #333; border-radius: 6px;
color: #e0e0e0; font-family: 'Fira Code', monospace; font-size: 0.8rem;
}}
input:focus {{ outline: none; border-color: #e06c75; }}
button {{
padding: 0.5rem 1.5rem; border: none; border-radius: 6px;
background: #e06c75; color: #fff; font-weight: 600; cursor: pointer;
font-family: 'Fira Code', monospace; font-size: 0.85rem; margin-top: 0.3rem;
}}
button:hover {{ background: #c55; }}
.divider {{ border-top: 1px solid #333; margin: 1.2rem 0; }}
</style></head>
<body><div class="card">
<h2>Authorize Google Account</h2>
<div class="step">
<b>1.</b> Click the button below to sign in with Google<br>
<b>2.</b> After approving, your browser will show an error page — that's normal<br>
<b>3.</b> Copy the full URL from your browser's address bar<br>
<b>4.</b> Paste it below and click Connect
</div>
<a class="auth-link" href="{auth_url}" target="_blank" rel="noopener">Sign in with Google</a>
<div class="divider"></div>
<form method="POST" action="http://{host}/api/mcp/oauth/exchange/{server_id}">
<p>Paste the URL from your browser after signing in:</p>
<input type="text" name="callback_url" placeholder="http://localhost:7000/api/mcp/oauth/callback?code=..." required>
<br><button type="submit">Connect</button>
</form>
</div></body></html>"""
def _oauth_result_page(title: str, message: str, success: bool = False) -> str:
"""Generate a simple HTML page for the OAuth result."""
safe_title = html.escape(title)
safe_message = html.escape(message)
color = "#00661a" if success else "#e06c75"
icon = "&#10003;" if success else "&#10007;"
return f"""<!DOCTYPE html>
<html><head>
<meta charset="UTF-8"><title>{safe_title}</title>
<style>
body {{ font-family: 'Fira Code', monospace; background: #0f0f0f; color: #e0e0e0;
display: flex; justify-content: center; align-items: center; min-height: 100vh; }}
.card {{ background: #1a1a1a; border: 1px solid #333; border-radius: 12px;
padding: 2rem; max-width: 420px; text-align: center; }}
.icon {{ font-size: 3rem; color: {color}; margin-bottom: 1rem; }}
h2 {{ color: {color}; margin-bottom: 0.5rem; font-size: 1.1rem; }}
p {{ color: #aaa; font-size: 0.85rem; line-height: 1.5; }}
</style></head>
<body><div class="card">
<div class="icon">{icon}</div>
<h2>{safe_title}</h2>
<p>{safe_message}</p>
</div></body></html>"""

517
routes/memory_routes.py Normal file
View File

@@ -0,0 +1,517 @@
# routes/memory_routes.py
from fastapi import APIRouter, Form, HTTPException, Request, UploadFile, File
from typing import Dict, Any, Optional, List
import json
import os
import re
import tempfile
import time
from datetime import datetime
import logging
# Leading list-marker like "1.", "12)", or "3:" plus surrounding whitespace.
# Strips one prefix per call so import-from-LLM-output doesn't leave the
# numbering inside the saved memory text. Bullet markers (-, *, •) are
# also peeled here for the same reason.
_LIST_PREFIX_RE = re.compile(r"^\s*(?:\d{1,3}[.):]\s+|[-*•]\s+)")
def _strip_list_prefix(text: str) -> str:
if not text:
return text
return _LIST_PREFIX_RE.sub("", text, count=1).strip()
from services.memory import MemoryManager
from core.session_manager import SessionManager
from src.request_models import MemoryAddRequest
from core.database import SessionLocal
from src.llm_core import llm_call_async
from services.memory.memory_extractor import audit_memories
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
def setup_memory_routes(memory_manager: MemoryManager, session_manager: SessionManager, memory_vector=None):
"""Set up memory-related routes."""
router = APIRouter(prefix="/api/memory", tags=["memory"])
def _owner(request: Request) -> Optional[str]:
return get_current_user(request)
def _verify_memory_owner(memory: dict, user: Optional[str]):
"""Raise 404 if user doesn't own this memory.
SECURITY: strict ownership — previously `mem_owner and mem_owner != user`
allowed any user to read/edit/delete memories with an empty/null owner
field, which leaked legacy data across the multi-user deploy.
"""
if user is None:
return # Auth disabled
if memory.get("owner") != user:
raise HTTPException(404, "Memory not found")
@router.post("/debug")
def debug_memory_relevance(request: Request, query: str = Form(...)):
"""Debug which memories would be triggered for a query"""
user = _owner(request)
memories = memory_manager.load(owner=user)
relevant = memory_manager.get_relevant_memories(query, memories, threshold=0.05)
return {
"query": query,
"total_memories": len(memories),
"relevant_count": len(relevant),
"relevant_memories": [{"text": m["text"], "category": m.get("category", "unknown")}
for m in relevant]
}
@router.post("/add", response_model=Dict[str, Any])
async def api_add_memory(
request: Request,
memory_data: Optional[MemoryAddRequest] = None
):
"""Add a new memory entry with optional category, source, and session reference."""
from src.auth_helpers import require_privilege
require_privilege(request, "can_manage_memory")
if memory_data is None:
form = await request.form()
memory_data = MemoryAddRequest(
text=form.get("text"),
category=form.get("category", "fact"),
source=form.get("source", "user"),
session_id=form.get("session_id")
)
user = _owner(request)
text = (memory_data.text or "").strip()
if not text:
raise HTTPException(400, "empty memory")
user_mem = memory_manager.load(owner=user)
if memory_manager.find_duplicates(text, user_mem):
return {"ok": True, "count": len(user_mem), "message": "Memory already exists"}
new_entry = memory_manager.add_entry(text, memory_data.source, memory_data.category, owner=user)
if memory_data.session_id:
new_entry["session_id"] = memory_data.session_id
all_mem = memory_manager.load_all()
all_mem.append(new_entry)
memory_manager.save(all_mem)
# Sync vector index
if memory_vector and memory_vector.healthy:
memory_vector.add(new_entry["id"], text)
try:
from src.event_bus import fire_event
fire_event("memory_added", user)
except Exception:
logger.debug("memory_added event dispatch failed", exc_info=True)
return {"ok": True, "count": len([m for m in all_mem if m.get("owner") == user])}
@router.get("")
def api_get_memory(request: Request):
"""Return all memory entries with their metadata."""
user = _owner(request)
return {"memory": memory_manager.load(owner=user)}
@router.post("/search")
def search_memories(request: Request, query: str = Form(...), session_id: str = Form(None), category: str = Form(None)):
"""Search across all memories with optional filters."""
user = _owner(request)
memories = memory_manager.load(owner=user)
if session_id:
memories = [m for m in memories if m.get("session_id") == session_id]
if category:
memories = [m for m in memories if category in m.get("categories", [m.get("category", "")])]
relevant = memory_manager.get_relevant_memories(query, memories, threshold=0.05, max_items=20)
return {"memories": relevant, "total": len(relevant), "query": query}
@router.get("/timeline")
def memory_timeline(request: Request):
"""Get memories in chronological order with source session information."""
user = _owner(request)
memories = memory_manager.load(owner=user)
sorted_memories = sorted(memories, key=lambda x: x.get("timestamp", 0), reverse=True)
results = []
for memory in sorted_memories:
if "timestamp" in memory:
try:
dt = datetime.fromtimestamp(memory["timestamp"])
memory["timestamp_str"] = dt.strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError, OverflowError):
memory["timestamp_str"] = "Unknown"
else:
memory["timestamp_str"] = "Unknown"
session_id = memory.get("session_id")
if session_id and session_id in session_manager.sessions:
session = session_manager.get_session(session_id)
memory["session_name"] = session.name if session else f"Session {session_id[:6]}"
else:
memory["session_name"] = "Unknown"
results.append(memory)
return {"timeline": results, "total": len(results)}
@router.get("/by-session/{session_id}")
def get_memory_by_session(request: Request, session_id: str):
"""Get all memories associated with a specific session."""
try:
session_manager.get_session(session_id)
except KeyError:
raise HTTPException(404, f"Session {session_id} not found")
user = _owner(request)
memories = memory_manager.load(owner=user)
session_memories = [m for m in memories if m.get("session_id") == session_id]
session_memories.sort(key=lambda x: x.get("timestamp", 0), reverse=True)
try:
session = session_manager.get_session(session_id)
session_name = session.name if session else f"Session {session_id[:6]}"
except KeyError:
session_name = f"Session {session_id[:6]}"
for memory in session_memories:
memory["session_name"] = session_name
return {
"session_id": session_id,
"session_name": session_name,
"memory_count": len(session_memories),
"memories": session_memories
}
@router.post("/extract")
async def extract_memory(request: Request, session: str = Form(...)) -> Dict[str, List[str]]:
"""Analyze a session's chat history and return memory suggestions."""
if not get_current_user(request):
raise HTTPException(401, "Not authenticated")
try:
sess = session_manager.get_session(session)
except KeyError:
raise HTTPException(404, "Session not found")
system_msg = {
"role": "system",
"content": (
"You are a helpful assistant. Analyze the entire conversation history provided and extract any "
"useful factual statements, contacts, addresses, phone numbers, or other information that the user "
"might want to remember for future interactions. Return each piece of information as a JSON object "
"with a 'text' field. For example: [{'text': 'Alice lives at 123 Main St'}, {'text': 'Bob works at Acme Corp'}]. "
"Only include information that is specific and likely to be useful later."
),
}
messages = [system_msg] + sess.get_context_messages()
try:
suggestion_text = await llm_call_async(
sess.endpoint_url,
sess.model,
messages,
temperature=0.2,
max_tokens=500,
headers=sess.headers,
)
try:
suggestions = json.loads(suggestion_text)
if isinstance(suggestions, list):
suggestions = [s if isinstance(s, str) else s.get("text", "") for s in suggestions]
else:
suggestions = []
except json.JSONDecodeError:
suggestions = [line.strip() for line in suggestion_text.splitlines() if line.strip()]
return {"suggestions": [s for s in suggestions if s]}
except Exception as e:
logger.error(f"LLM memory extraction failed (session {session}): {e}")
fallback = memory_manager.extract_memory_from_chat(sess.history, session)
return {"suggestions": [item["text"] for item in fallback]}
@router.post("/audit")
async def api_audit_memories(request: Request, session: str = Form(None)):
"""Deduplicate and consolidate memories via LLM.
Uses the default model from settings, or falls back to a session's model.
Returns before and after memory counts.
"""
from routes.model_routes import _load_settings, _normalize_base, build_chat_url
from core.database import ModelEndpoint
import json as _json
endpoint_url = model = None
headers = {}
# Try default model from settings first
settings = _load_settings()
ep_id = settings.get("default_endpoint_id", "")
default_model = settings.get("default_model", "")
if ep_id:
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(
ModelEndpoint.id == ep_id, ModelEndpoint.is_enabled == True
).first()
if ep:
base = _normalize_base(ep.base_url)
endpoint_url = build_chat_url(base)
model = default_model
if not model and ep.models:
try:
models = _json.loads(ep.models) if isinstance(ep.models, str) else ep.models
if models:
model = models[0]
except Exception:
pass
if ep.api_key:
headers = {"Authorization": f"Bearer {ep.api_key}"}
finally:
db.close()
# Fall back to session model if no default configured
if not endpoint_url and session:
try:
sess = session_manager.get_session(session)
endpoint_url = sess.endpoint_url
model = sess.model
headers = sess.headers
except KeyError:
pass
if not endpoint_url or not model:
raise HTTPException(400, "No default model configured — set one in Settings")
user = _owner(request)
result = await audit_memories(
memory_manager,
memory_vector,
endpoint_url,
model,
headers,
owner=user,
)
if "error" in result and "before" not in result:
raise HTTPException(502, f"Audit failed: {result['error']}")
return {
"ok": "error" not in result,
"before": result.get("before", 0),
"after": result.get("after", 0),
"removed": result.get("before", 0) - result.get("after", 0),
# True when the audit skipped the LLM because nothing changed
# since the last tidy. Frontend already says "Already clean"
# for removed==0, so this is here for future use / debugging.
"already_tidy": bool(result.get("already_tidy")),
}
@router.post("/import")
async def import_memories_from_file(
request: Request,
session: str = Form(...),
file: UploadFile = File(...)
):
"""Extract memory suggestions from an uploaded file (PDF, TXT, MD, etc.)."""
from src.auth_helpers import require_privilege
require_privilege(request, "can_manage_memory")
try:
sess = session_manager.get_session(session)
except KeyError:
raise HTTPException(404, "Session not found — needed for LLM config")
# Read file content
content = await file.read()
filename = file.filename or "upload"
_, ext = os.path.splitext(filename.lower())
allowed = {".txt", ".md", ".pdf", ".csv", ".log", ".json", ".py", ".js", ".html"}
if ext not in allowed:
raise HTTPException(400, f"Unsupported file type: {ext}")
# Extract text based on file type
if ext == ".pdf":
from src.document_processor import _process_pdf
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
tmp.write(content)
tmp_path = tmp.name
try:
text = _process_pdf(tmp_path)
finally:
os.unlink(tmp_path)
else:
try:
text = content.decode("utf-8")
except UnicodeDecodeError:
from charset_normalizer import detect
encoding = (detect(content) or {}).get("encoding") or "utf-8"
text = content.decode(encoding, errors="replace")
if not text.strip():
return {"suggestions": [], "message": "No readable content found"}
# Fast path: a .json upload that already looks like a memories export
# (list of {text, category, ...} dicts, or list of strings) round-trips
# directly without spending an LLM call to re-extract its own output.
# Without this, re-importing a memories.json from another account
# ran the file through the extractor, which often re-emitted the
# entries as a numbered list (and the numbering leaked into the
# `text` field).
if ext == ".json":
try:
parsed = json.loads(text)
except json.JSONDecodeError:
parsed = None
if isinstance(parsed, list) and parsed:
direct = []
for item in parsed:
if isinstance(item, dict) and item.get("text"):
direct.append({
"text": _strip_list_prefix(str(item["text"])),
"category": item.get("category") or "fact",
})
elif isinstance(item, str) and item.strip():
direct.append({
"text": _strip_list_prefix(item.strip()),
"category": "fact",
})
if direct:
return {"suggestions": direct, "filename": filename}
# Truncate very long documents
if len(text) > 15000:
text = text[:15000] + "\n[Truncated]"
# Send to LLM for memory extraction
import_prompt = (
"You are a memory extraction assistant. The user uploaded a document. "
"Analyze the text below and extract specific, useful facts — things like "
"names, preferences, jobs, locations, relationships, opinions, projects, "
"goals, contacts, or any other personal details worth remembering.\n\n"
"Rules:\n"
"- Each fact should be a short, self-contained statement\n"
"- Do NOT extract generic knowledge\n"
"- Focus on personal, memorable information\n"
"- If there are no useful facts, return an empty array\n\n"
"Return a JSON array of objects with 'text' and 'category' fields.\n"
"Categories: 'identity', 'preference', 'fact', 'contact', 'project', 'goal'\n\n"
"Return ONLY valid JSON, no markdown fences."
)
try:
raw = await llm_call_async(
sess.endpoint_url,
sess.model,
[
{"role": "system", "content": import_prompt},
{"role": "user", "content": f"Document: {filename}\n\n{text}"},
],
temperature=0.2,
max_tokens=2000,
headers=sess.headers,
)
# Parse JSON
raw = raw.strip()
if raw.startswith("```"):
raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
suggestions = json.loads(raw)
if isinstance(suggestions, list):
normalized = []
for s in suggestions:
if not s:
continue
if isinstance(s, dict):
s = dict(s)
if s.get("text"):
s["text"] = _strip_list_prefix(str(s["text"]))
normalized.append(s)
else:
normalized.append({"text": _strip_list_prefix(str(s)), "category": "fact"})
suggestions = normalized
else:
suggestions = []
return {"suggestions": suggestions, "filename": filename}
except json.JSONDecodeError:
# Fallback: split by lines, stripping any "1.", "2)" markdown-list
# numbering the model added so saved memories don't keep the prefix.
lines = [_strip_list_prefix(l.strip()) for l in raw.splitlines() if l.strip() and len(l.strip()) > 5]
return {"suggestions": [{"text": l, "category": "fact"} for l in lines[:20]], "filename": filename}
except Exception as e:
logger.error(f"Memory import extraction failed: {e}")
raise HTTPException(502, f"LLM extraction failed: {str(e)}")
@router.post("/{memory_id}/pin")
def pin_memory(request: Request, memory_id: str, pinned: bool = Form(True)):
"""Pin or unpin a memory. Pinned memories are always included in context."""
user = _owner(request)
all_mem = memory_manager.load_all()
for i, memory in enumerate(all_mem):
if memory["id"] == memory_id:
_verify_memory_owner(memory, user)
all_mem[i]["pinned"] = pinned
memory_manager.save(all_mem)
return {"ok": True, "pinned": pinned}
raise HTTPException(404, f"Memory item {memory_id} not found")
# Wildcard routes MUST come last — otherwise they swallow /import, /search, etc.
@router.get("/{memory_id}")
def get_memory_item(request: Request, memory_id: str):
"""Get a specific memory item by ID."""
user = _owner(request)
memories = memory_manager.load(owner=user)
for memory in memories:
if memory["id"] == memory_id:
return {"memory": memory}
raise HTTPException(404, "Memory not found")
@router.put("/{memory_id}")
def update_memory(request: Request, memory_id: str, text: str = Form(...), category: str = Form(None)):
"""Update an existing memory item with new text and optional category."""
user = _owner(request)
all_mem = memory_manager.load_all()
for i, memory in enumerate(all_mem):
if memory["id"] == memory_id:
_verify_memory_owner(memory, user)
all_mem[i]["text"] = text.strip()
if category:
all_mem[i]["category"] = category
all_mem[i]["timestamp"] = int(time.time())
memory_manager.save(all_mem)
# Sync vector index (remove old, add updated)
if memory_vector and memory_vector.healthy:
memory_vector.remove(memory_id)
memory_vector.add(memory_id, text.strip())
return {"ok": True, "message": "Memory updated successfully"}
raise HTTPException(404, f"Memory item {memory_id} not found")
@router.delete("/{memory_id}")
def delete_memory(request: Request, memory_id: str):
"""Delete a memory item by its ID."""
user = _owner(request)
all_mem = memory_manager.load_all()
# Find and verify ownership before deleting
target = next((m for m in all_mem if m["id"] == memory_id), None)
if not target:
raise HTTPException(404, f"Memory item {memory_id} not found")
_verify_memory_owner(target, user)
all_mem = [m for m in all_mem if m["id"] != memory_id]
memory_manager.save(all_mem)
# Sync vector index
if memory_vector and memory_vector.healthy:
memory_vector.remove(memory_id)
return {"ok": True, "message": "Memory deleted successfully"}
return router

1226
routes/model_routes.py Normal file

File diff suppressed because it is too large Load Diff

741
routes/note_routes.py Normal file
View File

@@ -0,0 +1,741 @@
# routes/note_routes.py
"""Google Keep-style notes / checklists API."""
import json
import uuid
import logging
from typing import Dict, Any, Optional
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from core.database import SessionLocal, Note
from src.auth_helpers import get_current_user
from sqlalchemy.orm.attributes import flag_modified
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Request models
# ---------------------------------------------------------------------------
class NoteCreate(BaseModel):
title: str = ""
content: Optional[str] = None
items: Optional[list] = None
note_type: str = "note"
color: Optional[str] = None
label: Optional[str] = None
pinned: bool = False
due_date: Optional[str] = None
source: str = "user"
session_id: Optional[str] = None
image_url: Optional[str] = None
repeat: Optional[str] = "none"
sort_order: Optional[int] = None
class NoteUpdate(BaseModel):
title: Optional[str] = None
content: Optional[str] = None
items: Optional[list] = None
note_type: Optional[str] = None
color: Optional[str] = None
label: Optional[str] = None
pinned: Optional[bool] = None
archived: Optional[bool] = None
due_date: Optional[str] = None
image_url: Optional[str] = None
repeat: Optional[str] = None
sort_order: Optional[int] = None
agent_session_id: Optional[str] = None
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _note_to_dict(note: Note) -> Dict[str, Any]:
items = None
if note.items:
try:
items = json.loads(note.items)
except (json.JSONDecodeError, TypeError):
items = None
ai_cls = None
raw_ai = getattr(note, "ai_classification", None)
if raw_ai:
try:
ai_cls = json.loads(raw_ai)
except (json.JSONDecodeError, TypeError):
ai_cls = None
return {
"id": note.id,
"owner": note.owner,
"title": note.title,
"content": note.content,
"items": items,
"note_type": note.note_type,
"color": note.color,
"label": note.label,
"pinned": note.pinned,
"archived": note.archived,
"due_date": note.due_date,
"source": note.source,
"session_id": note.session_id,
"sort_order": note.sort_order or 0,
"image_url": note.image_url,
"repeat": note.repeat or "none",
"ai_classification": ai_cls,
"ai_content_hash": getattr(note, "ai_content_hash", None),
"agent_session_id": getattr(note, "agent_session_id", None),
"created_at": note.created_at.isoformat() if note.created_at else None,
"updated_at": note.updated_at.isoformat() if note.updated_at else None,
}
# ---------------------------------------------------------------------------
# Reminder dispatch — module-level so background tasks (built-in actions)
# can call it directly without an HTTP roundtrip + auth cookie. The route
# version below is a thin wrapper that pulls `owner` from the request.
# ---------------------------------------------------------------------------
# Scheduler reference — set by setup_note_routes() so dispatch_reminder can
# push a parallel in-app notification (frontend polls the scheduler's queue
# and fires real browser Notification(...) popups). Optional; works without it.
_scheduler_ref = None
async def dispatch_reminder(
title: str,
note_body: str,
note_id: str,
owner: str = "",
queue_browser: bool = True,
) -> dict:
"""Fire a reminder via the configured channel (browser/email/ntfy).
Args:
title: short headline shown to the user
note_body: longer body text
note_id: stable id (used as tag/dedupe in browser notifications)
owner: the user this reminder belongs to — scopes SMTP config to
their account so we don't cross-leak credentials
Returns: {synthesis, email_sent, ntfy_sent}. Browser channel is wired via
the in-memory notification queue picked up by the frontend poller, so
nothing is "sent" synchronously for it — the channel just routes there.
"""
from src.settings import load_settings
settings = load_settings()
channel = settings.get("reminder_channel", "browser")
llm_on = bool(settings.get("reminder_llm_synthesis", False))
title = (title or "").strip()
note_body = (note_body or "").strip()
cache_key = str(note_id) if note_id else ""
cache = {}
cache_path = None
if cache_key:
try:
import json as _json
from datetime import datetime as _dt, timezone as _tz, timedelta as _td
from pathlib import Path as _P
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
cache_path = _P(f"data/note_pings_{_slug}.json")
if cache_path.exists():
cache = _json.loads(cache_path.read_text())
last = cache.get(cache_key)
if last:
last_channel = None
if isinstance(last, dict):
last_channel = last.get("channel")
last = last.get("at")
last_dt = _dt.fromisoformat(str(last))
if last_dt.tzinfo is None:
last_dt = last_dt.replace(tzinfo=_tz.utc)
# Legacy cache values were plain timestamps and could be
# written by the frontend even when the email/ntfy send failed.
# Treat those as browser-only dedupe so email reminders can be
# retried by the backend scanner after a failed frontend path.
should_skip = last_dt >= _dt.now(_tz.utc) - _td(minutes=25)
if should_skip and channel in ("email", "ntfy"):
should_skip = last_channel == channel
if should_skip:
return {
"synthesis": None,
"email_sent": False,
"ntfy_sent": False,
"browser_sent": True,
"skipped": True,
}
except Exception as _e:
logger.debug(f"dispatch_reminder: cache read failed: {_e}")
synthesis = None
_SYNTH_FAILED_TAG = "[utility model unavailable — no summary generated]"
if llm_on:
try:
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
url, model, headers = resolve_endpoint("utility")
if not url:
url, model, headers = resolve_endpoint("default")
if url and model:
raw = await llm_call_async(
url=url, model=model,
messages=[
{"role": "system", "content": "You are a reminder assistant. Write a single short, warm, motivating sentence (max 25 words) reminding the user about the note below. Do not add greetings, preamble, or hashtags. Output only the sentence."},
{"role": "user", "content": f"Title: {title}\n\n{note_body}".strip()},
],
temperature=0.7, max_tokens=200, headers=headers, timeout=30,
)
from src.text_helpers import strip_think as _strip_think
# prose=True strips untagged "The user wants me to…" chain-of-thought.
# prompt_echo=True strips Qwen-style "Thinking Process:" / leaked
# prompt prefixes. Both are safe here because this is a
# one-sentence LLM-only output, not user-pasted content.
synthesis = _strip_think(raw or "", prose=True, prompt_echo=True)
# Reminder synthesis is supposed to be ONE sentence. Strip-think's
# paragraph-based heuristic misses cases where the model puts
# reasoning + answer on consecutive lines inside one paragraph
# (e.g. "I should write... [\n] You have one task waiting...").
# Walk lines, drop reasoning/prompt-echo lines, then keep the
# last surviving line — that's the actual warm sentence.
if synthesis:
import re as _re
# Tightened: target ACTUAL self-talk (model narrating what
# it'll do) rather than any first-person sentence. The old
# pattern killed legit warm sentences like "I'll see you
# tomorrow" or "I should be done by then". New rules:
# • "I (need|should|have|'ll|will) (write|draft|reply|…)"
# only matches when followed by a TASK verb taking an
# OBJECT (so first-person + intransitive verb passes).
# • Self-instructional patterns the model emits verbatim:
# "I should write something that reminds them…",
# "I need to draft…", "Let me think…".
# • Explicit instructions echoed back from the prompt:
# "Keep it under 25 words", "No greetings".
_reasoning = _re.compile(
r"^\s*(?:"
# "I should write/draft/compose…" with a task-object follow
r"i (?:need|should|have|'ll|will|am going|am)\s+to\s+"
r"(?:write|draft|compose|craft|generate|produce|create|"
r"summarize|answer|provide|note|address|remind|output)"
r"\s+(?:a |an |the |something|this|that|here|them|him|her|"
r"you|user|reply|response|sentence|message|line|warm)|"
# The model literally narrating about the user
r"the user (?:wants|is asking|asks|needs|wrote|said|requested) (?:me )?(?:to|for|that|about|something)|"
# "Let me [think/write/draft/…] (about/for/the …)"
r"let me (?:think|write|draft|consider|note|see|check)\b\s+(?:about|for|the|this|that|if|whether)|"
# "Looking at the/this/that …"
r"looking at (?:the|this|that)\b|"
# "Based on the/this/what …"
r"based on (?:the|this|what|context|that)\b|"
# Prompt-echo of length / style instructions
r"keep it under \d+ words\b|"
r"(?:no greetings|no preamble|no hashtags|just output the)\b"
r").*",
_re.IGNORECASE,
)
# Echo of the prompt's "Pending:" / "<N> pending" tail.
_echo = _re.compile(
r"^\s*(?:pending\s*[:.]|(?:\d+|one|two|three|four|five)\s+pending\b)",
_re.IGNORECASE,
)
lines = [ln for ln in synthesis.splitlines() if ln.strip()]
cleaned = [ln for ln in lines if not _reasoning.match(ln) and not _echo.match(ln)]
if cleaned:
# The model's actual answer is normally the LAST surviving
# line — reasoning leads, answer trails.
synthesis = cleaned[-1].strip()
else:
synthesis = _SYNTH_FAILED_TAG
except Exception as e:
logger.warning(f"Reminder LLM synthesis failed: {e}")
synthesis = _SYNTH_FAILED_TAG
if synthesis:
_s = synthesis.strip(); _low = _s.lower()
if (not _s or _low.startswith("error:") or _low.startswith("[error")
or "operation failed" in _low
or ("upstream" in _low and "failed" in _low)) and synthesis != _SYNTH_FAILED_TAG:
logger.warning(f"Reminder synthesis looked like an error, replacing: {_s[:120]!r}")
synthesis = _SYNTH_FAILED_TAG
email_sent = False
email_error = ""
if channel == "email":
try:
from routes.email_routes import _get_email_config
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from datetime import datetime as _dt
# `reminder_email_account_id` lets the user pick WHICH email
# account to send reminders from (when they have several
# configured in Integrations). Falls back to the default
# account when no explicit choice is saved.
_acc_id = (settings.get("reminder_email_account_id") or "").strip() or None
cfg = _get_email_config(account_id=_acc_id, owner=owner or "")
if not (cfg.get("smtp_host") and cfg.get("smtp_user") and cfg.get("smtp_password")):
try:
from core.database import SessionLocal as _SL, EmailAccount as _EA
from sqlalchemy import and_, or_
db = _SL()
try:
q = db.query(_EA).filter(_EA.enabled == True) # noqa: E712
if owner:
unowned = or_(_EA.owner == None, _EA.owner == "") # noqa: E711
same_mailbox = or_(_EA.imap_user == owner, _EA.from_address == owner)
q = q.filter(or_(_EA.owner == owner, and_(unowned, same_mailbox)))
for row in q.order_by(_EA.is_default.desc(), _EA.created_at.asc()).all():
trial = _get_email_config(account_id=row.id, owner=owner or "")
if trial.get("smtp_host") and trial.get("smtp_user") and trial.get("smtp_password"):
cfg = trial
break
finally:
db.close()
except Exception as _fallback_error:
logger.debug(f"Reminder SMTP fallback lookup failed: {_fallback_error}")
from_addr = (cfg.get("from_address") or cfg.get("smtp_user") or "").strip()
recipient = (settings.get("reminder_email_to") or "").strip() or from_addr
# Loud diagnostic so we can see WHY a reminder didn't send (the
# previous "silently no-op when cfg has no smtp_host" was invisible).
logger.info(
f"dispatch_reminder[email] note_id={note_id} owner={owner!r} "
f"smtp_host={cfg.get('smtp_host')!r} smtp_user={cfg.get('smtp_user')!r} "
f"from={from_addr!r} recipient={recipient!r} "
f"account_name={cfg.get('account_name')!r}"
)
missing = []
if not cfg.get("smtp_host"):
missing.append("SMTP host")
if not cfg.get("smtp_user"):
missing.append("SMTP user")
if not cfg.get("smtp_password"):
missing.append("SMTP password")
if not from_addr:
missing.append("from address")
if not recipient:
missing.append("recipient")
if missing:
email_error = "Missing " + ", ".join(missing)
logger.warning(
"Reminder email not sent for note_id=%s account=%r: %s",
note_id, cfg.get("account_name"), email_error,
)
else:
msg = MIMEMultipart("alternative")
msg["From"] = from_addr
msg["To"] = recipient
_t = title or 'Note'
_t = _t[len('Reminder:'):].strip() if _t.lower().startswith('reminder:') else _t
msg["Subject"] = f"Reminder (Odysseus): {_t}"
msg["Date"] = _dt.utcnow().strftime("%a, %d %b %Y %H:%M:%S +0000")
msg["X-Odysseus-Origin"] = "odysseus-ui"
msg["X-Odysseus-Kind"] = "reminder"
msg["X-Odysseus-Ref"] = str(note_id)
# Body shape: synthesis (warm sentence) → blank line → bold
# title header → note details. The title was previously only
# in the subject line, so the email read like a faceless
# to-do list with no anchor to which note triggered it.
_body_chunks = []
if synthesis:
_body_chunks.append(synthesis)
if _t:
_body_chunks.append(_t)
if note_body:
_body_chunks.append(note_body)
plain = "\n\n".join(_body_chunks) if _body_chunks else title
msg.attach(MIMEText(plain, "plain", "utf-8"))
def _smtp_send():
from routes.email_helpers import _send_smtp_message
_send_smtp_message(cfg, from_addr, [recipient], msg.as_string())
import asyncio as _aio
await _aio.to_thread(_smtp_send)
email_sent = True
except Exception as e:
email_error = str(e) or e.__class__.__name__
logger.warning(f"Reminder email send failed: {e}")
ntfy_sent = False
ntfy_error = ""
if channel == "ntfy":
try:
from src.integrations import load_integrations
import httpx
intg = next(
(i for i in load_integrations()
if i.get("preset") == "ntfy" and i.get("enabled", True) and i.get("base_url")),
None,
)
if intg:
base = intg["base_url"].rstrip("/")
topic = settings.get("reminder_ntfy_topic") or "reminders"
ntfy_body = synthesis or note_body or title
hdrs = {"Title": title or "Reminder", "Priority": "high", "Tags": "bell"}
api_key = intg.get("api_key", "")
if api_key:
hdrs["Authorization"] = f"Bearer {api_key}"
async with httpx.AsyncClient(timeout=10.0) as client:
resp = await client.post(f"{base}/{topic}", content=ntfy_body, headers=hdrs)
ntfy_sent = resp.is_success
if not ntfy_sent:
ntfy_error = f"ntfy returned HTTP {resp.status_code}"
else:
ntfy_error = "No enabled ntfy integration"
except Exception as e:
ntfy_error = str(e) or e.__class__.__name__
logger.warning(f"Reminder ntfy send failed: {e}")
# In-app browser notification ALWAYS fires (regardless of channel). The
# frontend polls `/api/tasks/notifications` and turns any entry with a
# `body` into a real `Notification(...)` — same surface as task-success
# popups. Lets the user see reminders inside the app even when the
# primary channel is email/ntfy and the tab is open.
browser_sent = False
local_browser_sent = (not queue_browser and channel == "browser")
if queue_browser and _scheduler_ref is not None:
try:
_scheduler_ref.add_notification(
task_name=title or "Reminder",
status="success",
task_id=f"reminder-{note_id}",
owner=owner or None,
body=(synthesis or note_body or title or "").strip()[:500] or "Reminder",
)
browser_sent = True
except Exception as _e:
logger.debug(f"dispatch_reminder: in-app notif push failed: {_e}")
# Dedupe across paths: write to the same cache file `action_ping_notes`
# reads, so the background scanner's REPING_MIN window suppresses a
# second send for the same note within 25 min. Without this, a note
# whose due_date fires while the user has the app open got TWO emails
# (frontend-fired here + background-fired by ping_notes 05 min later).
if (email_sent or ntfy_sent or browser_sent or local_browser_sent) and note_id:
try:
import json as _json
from datetime import datetime as _dt, timezone as _tz
from pathlib import Path as _P
# Per-owner cache so the scanner's prune step on user A's run
# doesn't drop user B's just-fired entry (review C4).
_STATE = cache_path
if _STATE is None:
_slug = "".join(c if (c.isalnum() or c in "-_.@") else "_" for c in (owner or "default"))
_STATE = _P(f"data/note_pings_{_slug}.json")
_STATE.parent.mkdir(parents=True, exist_ok=True)
try:
_cache = cache or (_json.loads(_STATE.read_text()) if _STATE.exists() else {})
except Exception:
_cache = {}
sent_channel = "email" if email_sent else "ntfy" if ntfy_sent else "browser"
_cache[cache_key or str(note_id)] = {
"at": _dt.now(_tz.utc).isoformat(),
"channel": sent_channel,
}
_STATE.write_text(_json.dumps(_cache))
except Exception as _e:
logger.debug(f"dispatch_reminder: cache write failed: {_e}")
return {
"synthesis": synthesis,
"email_sent": email_sent,
"email_error": email_error,
"ntfy_sent": ntfy_sent,
"ntfy_error": ntfy_error,
"browser_sent": browser_sent or local_browser_sent,
}
# ---------------------------------------------------------------------------
# Router factory
# ---------------------------------------------------------------------------
def setup_note_routes(task_scheduler=None):
# Expose the scheduler to module-level `dispatch_reminder` so reminders
# can also push to the in-app notification queue (the polling system
# turns each entry into a real browser Notification + the existing
# tasks-tab badge / dot system).
global _scheduler_ref
_scheduler_ref = task_scheduler
router = APIRouter(prefix="/api/notes", tags=["notes"])
def _owner(request: Request) -> Optional[str]:
return get_current_user(request)
# --- LIST ---
@router.get("")
def list_notes(
request: Request,
archived: Optional[bool] = None,
label: Optional[str] = None,
):
user = _owner(request)
db = SessionLocal()
try:
q = db.query(Note)
if user is not None:
q = q.filter(Note.owner == user)
if archived is not None:
q = q.filter(Note.archived == archived)
else:
q = q.filter(Note.archived == False)
if label:
q = q.filter(Note.label == label)
# Archived view: most recently archived first. Active view: pin + manual order.
if archived is True:
notes = q.order_by(Note.updated_at.desc()).all()
else:
notes = q.order_by(Note.pinned.desc(), Note.sort_order.asc(), Note.updated_at.desc()).all()
return {"notes": [_note_to_dict(n) for n in notes]}
finally:
db.close()
# --- CREATE ---
@router.post("")
def create_note(request: Request, body: NoteCreate):
user = _owner(request)
db = SessionLocal()
try:
note = Note(
id=str(uuid.uuid4()),
owner=user,
title=body.title,
content=body.content,
items=json.dumps(body.items) if body.items is not None else None,
note_type=body.note_type,
color=body.color,
label=body.label,
pinned=body.pinned,
due_date=body.due_date,
source=body.source,
session_id=body.session_id,
image_url=body.image_url,
repeat=body.repeat or "none",
sort_order=body.sort_order if body.sort_order is not None else 0,
)
db.add(note)
db.commit()
db.refresh(note)
return _note_to_dict(note)
finally:
db.close()
# --- GET ONE ---
@router.get("/{note_id}")
def get_note(request: Request, note_id: str):
user = _owner(request)
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
# let any user touch a row whose owner field was null/empty.
if user is not None and note.owner != user:
raise HTTPException(404, "Note not found")
return _note_to_dict(note)
finally:
db.close()
# --- UPDATE ---
@router.put("/{note_id}")
def update_note(request: Request, note_id: str, body: NoteUpdate):
user = _owner(request)
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
# let any user touch a row whose owner field was null/empty.
if user is not None and note.owner != user:
raise HTTPException(404, "Note not found")
if body.title is not None:
note.title = body.title
if body.content is not None:
note.content = body.content
if body.items is not None:
note.items = json.dumps(body.items)
flag_modified(note, "items")
if body.note_type is not None:
note.note_type = body.note_type
if body.color is not None:
note.color = body.color
if body.label is not None:
note.label = body.label
if body.pinned is not None:
note.pinned = body.pinned
if body.archived is not None:
note.archived = body.archived
if body.due_date is not None:
note.due_date = body.due_date
if body.image_url is not None:
note.image_url = body.image_url
if body.repeat is not None:
note.repeat = body.repeat
if body.sort_order is not None:
note.sort_order = body.sort_order
if body.agent_session_id is not None:
note.agent_session_id = body.agent_session_id
db.commit()
db.refresh(note)
return _note_to_dict(note)
finally:
db.close()
# --- DELETE ---
@router.delete("/{note_id}")
def delete_note(request: Request, note_id: str):
user = _owner(request)
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
# let any user touch a row whose owner field was null/empty.
if user is not None and note.owner != user:
raise HTTPException(404, "Note not found")
db.delete(note)
db.commit()
return {"ok": True}
finally:
db.close()
# --- TOGGLE PIN ---
@router.post("/{note_id}/pin")
def toggle_pin(request: Request, note_id: str):
user = _owner(request)
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
# let any user touch a row whose owner field was null/empty.
if user is not None and note.owner != user:
raise HTTPException(404, "Note not found")
note.pinned = not note.pinned
db.commit()
return {"ok": True, "pinned": note.pinned}
finally:
db.close()
# --- TOGGLE ARCHIVE ---
@router.post("/{note_id}/archive")
def toggle_archive(request: Request, note_id: str):
user = _owner(request)
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
# let any user touch a row whose owner field was null/empty.
if user is not None and note.owner != user:
raise HTTPException(404, "Note not found")
note.archived = not note.archived
db.commit()
return {"ok": True, "archived": note.archived}
finally:
db.close()
# --- TOGGLE CHECKLIST ITEM ---
@router.post("/{note_id}/items/{index}/toggle")
def toggle_item(request: Request, note_id: str, index: int):
user = _owner(request)
db = SessionLocal()
try:
note = db.query(Note).filter(Note.id == note_id).first()
if not note:
raise HTTPException(404, "Note not found")
# SECURITY: strict ownership — previously `note.owner and note.owner != user`
# let any user touch a row whose owner field was null/empty.
if user is not None and note.owner != user:
raise HTTPException(404, "Note not found")
if not note.items:
raise HTTPException(400, "Note has no checklist items")
items = json.loads(note.items)
if index < 0 or index >= len(items):
raise HTTPException(400, f"Item index {index} out of range")
items[index]["done"] = not items[index].get("done", False)
note.items = json.dumps(items)
flag_modified(note, "items")
db.commit()
return {"ok": True, "items": items}
finally:
db.close()
# --- FIRE REMINDER ---
@router.post("/fire-reminder")
async def fire_reminder(request: Request):
"""Dispatch a reminder according to user settings.
Called by the frontend when a reminder fires. Optionally generates an
LLM synthesis line and/or sends an email through configured SMTP.
Returns {synthesis, email_sent}.
"""
# Gate against anonymous callers — LLM synthesis can burn tokens.
from src.auth_helpers import get_current_user as _gcu
if not _gcu(request):
raise HTTPException(401, "Not authenticated")
body = await request.json()
note_id = body.get("note_id")
title = (body.get("title") or "").strip()
note_body = (body.get("body") or "").strip()
if not note_id:
raise HTTPException(400, "note_id required")
# Delegate to the module-level helper so background tasks can reuse
# the same dispatch without an HTTP roundtrip + auth cookie.
return await dispatch_reminder(
title=title, note_body=note_body, note_id=note_id,
owner=_gcu(request) or "",
queue_browser=False,
)
# --- REORDER NOTES ---
@router.post("/reorder")
async def reorder_notes(request: Request):
"""Update sort_order for a list of note IDs in the order provided."""
user = _owner(request)
body = await request.json()
ids = body.get("ids", [])
if not isinstance(ids, list):
raise HTTPException(400, "ids must be a list")
# v2 review HIGH-12: drop the legacy `(owner == user) | (owner ==
# None)` OR which let an authenticated user silently reorder
# every legacy-null-owner note belonging to other accounts. In
# an unconfigured (single-user) auth deploy the OR is still safe
# because there's no second user to attack; we keep that branch
# explicit and gated on AuthManager.is_configured.
try:
from core.auth import AuthManager
_allow_null = not AuthManager().is_configured
except Exception:
_allow_null = False
db = SessionLocal()
try:
for i, nid in enumerate(ids):
q = db.query(Note).filter(Note.id == nid)
if user is not None:
if _allow_null:
q = q.filter((Note.owner == user) | (Note.owner == None)) # noqa: E711
else:
q = q.filter(Note.owner == user)
note = q.first()
if note:
note.sort_order = i
db.commit()
return {"ok": True, "count": len(ids)}
finally:
db.close()
return router

276
routes/personal_routes.py Normal file
View File

@@ -0,0 +1,276 @@
# routes/personal_routes.py
"""Routes for personal documents management."""
import os
import logging
from typing import List
from fastapi import APIRouter, HTTPException, Query, Request, UploadFile, File, Depends
from src.request_models import DirectoryRequest
from core.constants import BASE_DIR, PERSONAL_DIR
from src.rag_singleton import get_rag_manager
from src.auth_helpers import get_current_user, require_user
from core.middleware import require_admin
from src.upload_handler import secure_filename
UPLOADS_DIR = os.path.join(BASE_DIR, "data", "personal_uploads")
logger = logging.getLogger(__name__)
def setup_personal_routes(personal_docs_manager, rag_manager, rag_available):
"""
Setup personal documents related routes.
Args:
personal_docs_manager: PersonalDocsManager instance
rag_manager: RAG manager instance (may be None)
rag_available: Boolean indicating if RAG is available
Returns:
APIRouter instance with personal docs routes
"""
router = APIRouter(prefix="/api/personal")
def _rag():
"""Get the current RAG manager, retrying init if needed."""
return get_rag_manager()
def _resolve_allowed_personal_dir(directory: str) -> str:
"""Resolve a user-supplied personal-docs path under the allowed root."""
if not directory:
raise HTTPException(400, "Directory path is required")
base_abs = os.path.abspath(PERSONAL_DIR)
candidate = directory if os.path.isabs(directory) else os.path.join(base_abs, directory)
resolved = os.path.abspath(candidate)
try:
in_base = os.path.commonpath([resolved, base_abs]) == base_abs
except ValueError:
in_base = False
if not in_base:
raise HTTPException(403, "Directory must be inside personal documents")
return resolved
@router.get("")
def api_personal_list(owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
"""Enhanced version that includes directories"""
files = [{"name": f["name"], "size": f["size"], "path": f.get("path", "")} for f in personal_docs_manager.index]
directories = personal_docs_manager.get_indexed_directories() if hasattr(personal_docs_manager, "get_indexed_directories") else []
return {"files": files, "directories": directories}
@router.post("/reload")
def api_personal_reload(owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
personal_docs_manager.refresh_index()
return {"ok": True, "count": len(personal_docs_manager.index)}
@router.post("/add_directory")
async def add_directory_to_rag(
request: Request,
directory_request: DirectoryRequest,
owner: str = Depends(require_user), _admin: None = Depends(require_admin),
):
"""
Add a directory and all its subdirectories/files to the RAG index.
Args:
directory_request: Directory request model containing the directory path
Returns:
JSON response with indexing results
"""
directory = directory_request.directory
try:
directory = _resolve_allowed_personal_dir(directory)
# Security check - ensure directory exists and is accessible
if not os.path.exists(directory):
raise HTTPException(404, f"Directory not found: {directory}")
if not os.path.isdir(directory):
raise HTTPException(400, f"Path is not a directory: {directory}")
logger.info(f"Adding directory to RAG: {directory}")
# Use the RAGManager to index the directory
rag = _rag()
if rag:
result = rag.index_personal_documents(directory, owner=owner)
if result["success"]:
# Also update the personal_docs_manager to track this directory
personal_docs_manager.add_directory(directory, index=False)
return {
"success": True,
"message": f"Successfully indexed {result['indexed_count']} chunks from {directory}",
"indexed_count": result["indexed_count"],
"failed_count": result.get("failed_count", 0),
"directory": directory
}
else:
raise HTTPException(500, result.get("message", "Failed to index directory"))
else:
raise HTTPException(503, "RAG system is not available")
except HTTPException:
raise
except Exception as e:
logger.error(f"Error adding directory to RAG: {e}")
raise HTTPException(500, f"Failed to add directory: {str(e)}")
@router.delete("/remove_directory")
async def remove_directory_from_rag(directory: str = Query(...), owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
"""
Remove a directory from the RAG index.
Args:
directory: Path to the directory to remove
Returns:
JSON response confirming removal
"""
try:
if not directory:
raise HTTPException(400, "Directory path is required")
logger.info(f"Removing directory from RAG: {directory}")
# Always remove from personal_docs_manager tracking
if hasattr(personal_docs_manager, 'remove_directory'):
personal_docs_manager.remove_directory(directory)
# Remove from RAG vector store (best-effort)
rag = _rag()
if rag:
try:
rag.remove_directory(directory)
except Exception as e:
logger.warning(f"RAG removal failed for directory {directory}: {e}")
return {
"success": True,
"message": f"Successfully removed {directory} from RAG index",
"directory": directory
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error removing directory from RAG: {e}")
raise HTTPException(500, f"Failed to remove directory: {str(e)}")
@router.post("/upload")
async def upload_files_to_rag(request: Request, files: List[UploadFile] = File(...)):
"""Upload files directly into RAG. Supports text and PDF."""
user = get_current_user(request)
rag = _rag()
if not rag:
raise HTTPException(503, "RAG system is not available — is the embedding service running?")
os.makedirs(UPLOADS_DIR, exist_ok=True)
total_indexed = 0
total_failed = 0
uploaded_files = []
for upload in files:
try:
# Sanitize filename — strip directory components and unsafe chars
safe_name = secure_filename(os.path.basename(upload.filename or "upload"))
if not safe_name or safe_name.startswith("."):
safe_name = f"upload_{total_indexed + total_failed}"
file_path = os.path.join(UPLOADS_DIR, safe_name)
# Defense-in-depth: ensure resolved path stays under UPLOADS_DIR
base_abs = os.path.abspath(UPLOADS_DIR)
if os.path.commonpath([os.path.abspath(file_path), base_abs]) != base_abs:
logger.warning(f"Rejected unsafe upload path: {upload.filename!r}")
total_failed += 1
continue
content_bytes = await upload.read()
with open(file_path, "wb") as f:
f.write(content_bytes)
ext = os.path.splitext(safe_name)[1].lower()
if ext == ".pdf":
from src.personal_docs import extract_pdf_text
text = extract_pdf_text(file_path)
else:
text = content_bytes.decode("utf-8", errors="replace")
if not text or not text.strip():
total_failed += 1
continue
# Chunk and index
chunks = rag._split_into_chunks(text, chunk_size=500)
for i, chunk in enumerate(chunks):
metadata = {
"source": file_path,
"filename": safe_name,
"directory": UPLOADS_DIR,
"type": ext,
"chunk_id": i,
}
if user:
metadata["owner"] = user
if rag.add_document(chunk, metadata):
total_indexed += 1
else:
total_failed += 1
uploaded_files.append(safe_name)
except Exception as e:
logger.error(f"Failed to upload/index {upload.filename}: {e}")
total_failed += 1
# Track uploads directory
if uploaded_files and hasattr(personal_docs_manager, "add_directory"):
personal_docs_manager.add_directory(UPLOADS_DIR, index=False)
return {
"success": True,
"uploaded": uploaded_files,
"indexed_count": total_indexed,
"failed_count": total_failed,
}
@router.delete("/file")
async def delete_file_from_rag(filepath: str = Query(...), owner: str = Depends(require_user), _admin: None = Depends(require_admin)):
"""Delete a specific file from RAG index and optionally from disk."""
try:
# Remove chunks from RAG vector store (best-effort)
removed = 0
rag = _rag()
if rag:
try:
removed = rag.delete_by_source(filepath)
except Exception as e:
logger.warning(f"RAG removal failed for {filepath}: {e}")
# Delete file from disk if it's in uploads dir
deleted_from_disk = False
try:
abs_target = os.path.abspath(filepath)
base_abs = os.path.abspath(UPLOADS_DIR)
in_uploads = (
abs_target == base_abs
or os.path.commonpath([abs_target, base_abs]) == base_abs
)
except ValueError:
# commonpath raises on mixed drives / non-comparable paths
in_uploads = False
if in_uploads and abs_target != base_abs and os.path.exists(abs_target):
os.remove(abs_target)
deleted_from_disk = True
# Exclude the file from the listing (persists across restarts)
personal_docs_manager.exclude_file(filepath)
return {
"success": True,
"removed_chunks": removed,
"deleted_from_disk": deleted_from_disk,
}
except Exception as e:
logger.error(f"Failed to delete file {filepath}: {e}")
raise HTTPException(500, f"Failed to delete file: {str(e)}")
return router

74
routes/prefs_routes.py Normal file
View File

@@ -0,0 +1,74 @@
"""User preferences API — per-user key/value store backed by a JSON file."""
import json
import os
from typing import Optional
from fastapi import APIRouter, Request
from src.auth_helpers import get_current_user
PREFS_FILE = os.path.join("data", "user_prefs.json")
def _load():
"""Load the raw prefs file (internal use only)."""
try:
with open(PREFS_FILE, "r") as f:
return json.load(f)
except (FileNotFoundError, json.JSONDecodeError):
return {}
def _save(prefs):
os.makedirs(os.path.dirname(PREFS_FILE), exist_ok=True)
with open(PREFS_FILE, "w") as f:
json.dump(prefs, f, indent=2)
def _load_for_user(user: Optional[str] = None) -> dict:
"""Load preferences for a specific user."""
all_prefs = _load()
if "_users" in all_prefs:
if user is None:
# Auth disabled — return first user's prefs for backward compat
users = all_prefs["_users"]
return dict(next(iter(users.values()), {}))
return dict(all_prefs["_users"].get(user, {}))
# Legacy flat format — return as-is
return dict(all_prefs)
def _save_for_user(user: Optional[str], prefs: dict):
"""Save preferences for a specific user."""
all_prefs = _load()
if user is None:
# Auth disabled — save flat
_save(prefs)
return
if "_users" not in all_prefs:
all_prefs = {"_users": {}}
all_prefs["_users"][user] = prefs
_save(all_prefs)
def setup_prefs_routes():
router = APIRouter(prefix="/api/prefs", tags=["preferences"])
@router.get("")
async def get_all_prefs(request: Request):
user = get_current_user(request)
return _load_for_user(user)
@router.get("/{key}")
async def get_pref(request: Request, key: str):
user = get_current_user(request)
prefs = _load_for_user(user)
return {"key": key, "value": prefs.get(key)}
@router.put("/{key}")
async def set_pref(request: Request, key: str, body: dict):
user = get_current_user(request)
prefs = _load_for_user(user)
prefs[key] = body.get("value")
_save_for_user(user, prefs)
return {"key": key, "value": prefs[key]}
return router

123
routes/preset_routes.py Normal file
View File

@@ -0,0 +1,123 @@
"""Preset routes — /api/presets GET, /api/presets/custom POST, user templates CRUD."""
import logging
import uuid
from typing import Dict, Any, List
from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel, Field
from src.request_models import PresetUpdateRequest
from core.middleware import require_admin
logger = logging.getLogger(__name__)
class UserTemplateRequest(BaseModel):
id: str = ""
name: str = Field(..., min_length=1, max_length=100)
system_prompt: str = Field("", max_length=10000)
temperature: float = Field(1.0, ge=0.0, le=2.0)
max_tokens: int = Field(0, ge=0, le=65536)
def setup_preset_routes(preset_manager) -> APIRouter:
router = APIRouter(tags=["presets"])
@router.get("/api/presets")
async def get_presets() -> Dict[str, Any]:
return preset_manager.presets
@router.post("/api/presets/custom")
async def update_custom_preset(preset_update: PresetUpdateRequest, _admin: None = Depends(require_admin)) -> Dict[str, Any]:
try:
success = preset_manager.update_custom(
preset_update.temperature,
preset_update.max_tokens,
preset_update.system_prompt,
preset_update.name,
preset_update.enabled,
preset_update.inject_prefix,
preset_update.inject_suffix,
)
if success:
return {"success": True, "message": "Custom preset updated"}
return {"success": False, "message": "Failed to save preset"}
except Exception as e:
logger.error(f"Preset update error: {e}")
raise HTTPException(500, "Failed to update custom preset")
@router.get("/api/presets/templates")
async def get_user_templates() -> List[Dict]:
return preset_manager.get_user_templates()
@router.post("/api/presets/templates")
async def save_user_template(req: UserTemplateRequest, _admin: None = Depends(require_admin)) -> Dict[str, Any]:
template = req.model_dump()
if not template["id"]:
template["id"] = f"user-{uuid.uuid4().hex[:8]}"
success = preset_manager.save_user_template(template)
if success:
return {"success": True, "template": template}
return {"success": False, "message": "Failed to save template"}
@router.delete("/api/presets/templates/{template_id}")
async def delete_user_template(template_id: str, _admin: None = Depends(require_admin)) -> Dict[str, Any]:
success = preset_manager.delete_user_template(template_id)
if success:
return {"success": True}
return {"success": False, "message": "Failed to delete template"}
@router.post("/api/presets/expand")
async def expand_character_prompt(request: Request) -> Dict[str, Any]:
"""Use AI to expand a rough character description into a full system prompt."""
from src.ai_interaction import _resolve_model
from src.llm_core import llm_call_async
data = await request.json()
draft = (data.get("prompt") or "").strip()
name = (data.get("name") or "").strip()
if not draft and not name:
return {"success": False, "message": "Nothing to expand"}
user_input = ""
if name:
user_input += f"Character name: {name}\n"
if draft:
user_input += f"Notes: {draft}\n"
messages = [
{"role": "system", "content": (
"You are an expert at writing character system prompts for AI assistants. "
"The user will give you a character name and/or rough notes. "
"Write a concise, effective system prompt (3-6 sentences) that captures the character's personality, "
"speaking style, knowledge areas, and behavioral guidelines. "
"Output ONLY the system prompt text — no quotes, no preamble, no explanation."
)},
{"role": "user", "content": user_input},
]
try:
model_spec = data.get("model") or ""
url, model, headers = _resolve_model(model_spec)
result = await llm_call_async(url, model, messages, temperature=0.8, max_tokens=500, headers=headers)
return {"success": True, "prompt": result.strip()}
except Exception as e:
logger.error(f"Expand prompt failed: {e}")
return {"success": False, "message": str(e)}
# ── Group presets ──
@router.get("/api/presets/groups")
async def get_group_presets():
"""Get saved group chat presets."""
return {"groups": preset_manager.get_group_presets()}
@router.post("/api/presets/groups")
async def save_group_presets(request: Request, _admin: None = Depends(require_admin)):
"""Save group chat presets."""
data = await request.json()
preset_manager.save_group_presets(data.get("groups", []))
return {"ok": True}
return router

607
routes/research_routes.py Normal file
View File

@@ -0,0 +1,607 @@
"""Research background task routes — /api/research/*."""
import asyncio
import json
import logging
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, StreamingResponse
from pydantic import BaseModel, Field
from src.endpoint_resolver import resolve_endpoint
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
# Model-name substrings that are NOT chat/generation models — research must
# never pick these as its model. An OpenAI-style endpoint often lists
# `text-embedding-ada-002` etc. first in its model list, which is why research
# was failing with "Cannot reach model 'text-embedding-ada-002'".
_NON_CHAT_MODEL = (
"text-embedding", "embedding", "tts-", "whisper", "dall-e",
"moderation", "rerank", "reranker", "clip", "stable-diffusion",
)
def _first_chat_model(models) -> str:
"""First model that isn't an embedding/tts/etc. — falls back to models[0]."""
for m in (models or []):
if not any(p in str(m).lower() for p in _NON_CHAT_MODEL):
return m
return (models[0] if models else "")
def _resolve_research_endpoint(sess) -> tuple:
"""Return (endpoint_url, model, headers) for Deep Research, checking admin overrides."""
url, model, headers = resolve_endpoint(
"research",
fallback_url=sess.endpoint_url,
fallback_model=sess.model,
fallback_headers=sess.headers,
)
return url, model, headers
def setup_research_routes(research_handler, session_manager=None) -> APIRouter:
router = APIRouter(tags=["research"])
def _require_user(request: Request) -> str:
"""All research endpoints require an authenticated user. Research
data isn't owner-scoped in the on-disk JSON yet, so we at least
block anonymous access. Multi-tenant deploys should additionally
verify the session belongs to this user."""
user = get_current_user(request)
if not user:
raise HTTPException(401, "Not authenticated")
return user
def _owns_in_memory(session_id: str, user: str) -> bool:
"""Ownership check for an in-flight (in-memory) research task.
Falls back to the on-disk JSON if the task has already finished."""
entry = research_handler._active_tasks.get(session_id)
if entry is not None:
return entry.get("owner", "") == user
# Task no longer in memory — check the persisted JSON.
path = Path("data/deep_research") / f"{session_id}.json"
if not path.exists():
return False
try:
return json.loads(path.read_text()).get("owner") == user
except Exception:
return False
@router.get("/api/research/active")
async def research_active(request: Request):
"""List all currently active (running) research tasks."""
user = _require_user(request)
active = []
for sid, entry in research_handler._active_tasks.items():
# SECURITY: only show this user's running tasks.
if entry.get("owner", "") != user:
continue
if entry.get("status") == "running":
active.append({
"session_id": sid,
"query": entry.get("query", ""),
"status": "running",
"progress": entry.get("progress", {}),
"started_at": entry.get("started_at", 0),
})
return {"active": active}
@router.get("/api/research/status/{session_id}")
async def research_status(session_id: str, request: Request):
user = _require_user(request)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
status = research_handler.get_status(session_id)
if status is None:
raise HTTPException(404, "No research found for this session")
return status
@router.post("/api/research/cancel/{session_id}")
async def research_cancel(session_id: str, request: Request):
user = _require_user(request)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
cancelled = research_handler.cancel_research(session_id)
return {"cancelled": cancelled}
@router.post("/api/research/result/{session_id}")
async def research_result(session_id: str, request: Request):
user = _require_user(request)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research result available")
result = research_handler.get_result(session_id)
if result is None:
raise HTTPException(404, "No research result available")
sources = research_handler.get_sources(session_id) or []
raw_findings = research_handler.get_raw_findings(session_id) or []
research_handler.clear_result(session_id)
return {"result": result, "sources": sources, "raw_findings": raw_findings}
def _assert_owns_research(session_id: str, user: str) -> None:
"""404-not-403 ownership gate for a research session's on-disk JSON.
Use BEFORE returning any data or mutating the file."""
path = Path("data/deep_research") / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
try:
owner = json.loads(path.read_text()).get("owner")
except Exception:
raise HTTPException(404, "Research not found")
if owner != user:
raise HTTPException(404, "Research not found")
@router.get("/api/research/report/{session_id}")
async def research_report(session_id: str, request: Request):
"""Serve the visual HTML report for a completed research session."""
user = _require_user(request)
_assert_owns_research(session_id, user)
logger.info(f"Visual report requested for session {session_id}")
try:
html_content = research_handler.get_report_html(session_id)
except Exception as e:
logger.error(f"Visual report generation error: {e}", exc_info=True)
raise HTTPException(500, f"Report generation failed: {e}")
if html_content is None:
logger.warning(f"No report data found for session {session_id}")
raise HTTPException(404, "No visual report available for this session")
return HTMLResponse(content=html_content)
class HideImageRequest(BaseModel):
url: str
@router.post("/api/research/{session_id}/hide-image")
async def research_hide_image(session_id: str, body: HideImageRequest, request: Request):
"""Mark an image URL as hidden for this research's visual report.
Persisted to the research JSON so subsequent /report renders skip it."""
user = _require_user(request)
_assert_owns_research(session_id, user)
ok = research_handler.hide_image(session_id, body.url)
if not ok:
raise HTTPException(404, "Research not found")
return {"ok": True}
@router.post("/api/research/{session_id}/unhide-images")
async def research_unhide_images(session_id: str, request: Request):
"""Clear the hidden-images list for a research session."""
user = _require_user(request)
_assert_owns_research(session_id, user)
ok = research_handler.unhide_all_images(session_id)
if not ok:
raise HTTPException(404, "Research not found")
return {"ok": True}
@router.get("/api/research/library")
async def research_library(
request: Request,
search: Optional[str] = Query(None),
sort: str = Query("recent"),
limit: int = Query(50),
archived: bool = Query(False),
):
user = _require_user(request)
"""List all completed research for the Library panel."""
data_dir = Path("data/deep_research")
items = []
for p in data_dir.glob("*.json"):
try:
d = json.loads(p.read_text())
# SECURITY: only show research belonging to this user. Legacy
# JSONs without an `owner` field are hidden — auth was the only
# gate before, so every user saw every other user's reports.
if d.get("owner") != user:
continue
# Archived view shows ONLY archived reports; default hides them.
if bool(d.get("archived")) != archived:
continue
query = d.get("query", "")
if search and search.lower() not in query.lower():
continue
sources = d.get("sources", [])
items.append({
"id": p.stem,
"query": query,
"category": d.get("category") or "",
"source_count": len(sources),
"status": d.get("status", "done"),
"duration": d.get("stats", {}).get("Duration", ""),
"rounds": d.get("stats", {}).get("Rounds", ""),
"started_at": d.get("started_at", 0),
"completed_at": d.get("completed_at", 0),
"archived": bool(d.get("archived")),
})
except Exception:
continue
# Sort
if sort == "recent":
items.sort(key=lambda x: x["completed_at"] or 0, reverse=True)
elif sort == "oldest":
items.sort(key=lambda x: x["completed_at"] or 0)
elif sort == "most-messages":
items.sort(key=lambda x: x["source_count"], reverse=True)
elif sort == "alpha":
items.sort(key=lambda x: x["query"].lower())
return {"research": items[:limit], "total": len(items)}
@router.get("/api/research/detail/{session_id}")
async def research_detail(session_id: str, request: Request):
"""Return the full JSON for a single research result — sources,
summary, stats — used by the Library preview panel."""
user = _require_user(request)
path = Path("data/deep_research") / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
try:
data = json.loads(path.read_text())
except Exception as e:
raise HTTPException(500, f"Failed to read research: {e}")
# SECURITY: 404 (not 403) so we don't leak that the report exists.
if data.get("owner") != user:
raise HTTPException(404, "Research not found")
return data
@router.post("/api/research/{session_id}/archive")
async def research_archive(session_id: str, request: Request, archived: bool = Query(True)):
"""Soft-archive / restore a research report (sets `archived` in its JSON)."""
user = _require_user(request)
path = Path("data/deep_research") / f"{session_id}.json"
if not path.exists():
raise HTTPException(404, "Research not found")
try:
data = json.loads(path.read_text())
if data.get("owner") != user:
raise HTTPException(404, "Research not found")
data["archived"] = bool(archived)
path.write_text(json.dumps(data))
except HTTPException:
raise
except Exception as e:
raise HTTPException(500, f"Failed to update research: {e}")
return {"ok": True, "id": session_id, "archived": bool(archived)}
@router.delete("/api/research/{session_id}")
async def research_delete(session_id: str, request: Request):
"""Delete a research result from disk."""
user = _require_user(request)
data_dir = Path("data/deep_research")
json_path = data_dir / f"{session_id}.json"
deleted = False
if json_path.exists():
# SECURITY: verify ownership before letting the caller delete it.
try:
data = json.loads(json_path.read_text())
if data.get("owner") != user:
raise HTTPException(404, "Research not found")
except HTTPException:
raise
except Exception:
raise HTTPException(404, "Research not found")
json_path.unlink()
deleted = True
return {"deleted": deleted}
# ------------------------------------------------------------------
# Panel endpoints — launch research without a chat session
# ------------------------------------------------------------------
class ResearchStartRequest(BaseModel):
query: str
# max_rounds=0 means "Auto" — let the AI decide when to stop, capped at 20.
max_rounds: int = Field(default=0, ge=0, le=20)
search_provider: Optional[str] = None
endpoint_id: Optional[str] = None
model: Optional[str] = None
max_time: int = Field(default=300, ge=60, le=1800)
category: Optional[str] = None
@router.post("/api/research/start")
async def research_start(body: ResearchStartRequest, request: Request):
"""Launch a research job from the dedicated panel."""
from src.auth_helpers import require_privilege
user = require_privilege(request, "can_use_research")
if user == "internal-tool":
tool_owner = (request.headers.get("X-Odysseus-Owner") or "").strip()
if tool_owner and tool_owner not in {"internal-tool", "api", "demo", "system"}:
auth_mgr = getattr(request.app.state, "auth_manager", None)
if auth_mgr is not None and getattr(auth_mgr, "is_configured", False):
try:
privs = auth_mgr.get_privileges(tool_owner) or {}
if not privs.get("can_use_research", True):
raise HTTPException(403, f"Your account is not allowed to can use research.")
except HTTPException:
raise
except Exception:
pass
user = tool_owner
session_id = f"rp-{uuid.uuid4().hex[:12]}"
if body.endpoint_id:
from src.database import SessionLocal
from src.database import ModelEndpoint
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(
ModelEndpoint.id == body.endpoint_id,
ModelEndpoint.is_enabled == True,
).first()
if not ep:
raise HTTPException(404, "Endpoint not found or disabled")
base = normalize_base(ep.base_url)
ep_url = build_chat_url(base)
ep_headers = build_headers(ep.api_key, base)
ep_model = body.model or ""
if not ep_model:
try:
import json as _json
models = _json.loads(ep.cached_models) if ep.cached_models else []
if models:
ep_model = _first_chat_model(models)
except Exception:
pass
finally:
db.close()
else:
ep_url, ep_model, ep_headers = resolve_endpoint("research")
if not ep_url:
ep_url, ep_model, ep_headers = resolve_endpoint("utility")
# When neither research nor utility is configured, use the user's
# configured DEFAULT model (default_endpoint_id/default_model) rather
# than arbitrarily grabbing the first enabled endpoint's first model
# (which surfaced gpt-3.5). "Default" should mean the default model.
if not ep_url:
ep_url, ep_model, ep_headers = resolve_endpoint("default")
if not ep_url:
ep_url, ep_model, ep_headers = resolve_endpoint("chat")
if not ep_url:
from src.database import SessionLocal
from src.database import ModelEndpoint
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(
ModelEndpoint.is_enabled == True,
).first()
if ep:
base = normalize_base(ep.base_url)
ep_url = build_chat_url(base)
ep_headers = build_headers(ep.api_key, base)
ep_model = ""
if ep.cached_models:
try:
import json as _json
models = _json.loads(ep.cached_models)
if models:
ep_model = _first_chat_model(models)
except Exception:
pass
finally:
db.close()
if not ep_url:
raise HTTPException(400, "No endpoints configured. Add one in Settings first.")
if body.model:
ep_model = body.model
# max_rounds=0 → "Auto", let AI decide; pass 20 as the safety cap.
effective_max_rounds = body.max_rounds if body.max_rounds > 0 else 20
research_handler.start_research(
session_id=session_id,
query=body.query,
llm_endpoint=ep_url,
llm_model=ep_model,
max_time=body.max_time,
llm_headers=ep_headers,
max_rounds=effective_max_rounds,
search_provider=body.search_provider or None,
category=body.category or None,
owner=user,
)
return {"session_id": session_id, "status": "running", "query": body.query}
@router.get("/api/research/stream/{session_id}")
async def research_stream(session_id: str, request: Request):
"""SSE stream of research progress events."""
user = _require_user(request)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
async def _generate():
last_progress = None
while True:
status = research_handler.get_status(session_id)
if status is None:
yield f"data: {json.dumps({'status': 'not_found'})}\n\n"
return
st = status.get("status", "")
progress = status.get("progress", {})
if progress != last_progress:
last_progress = progress
yield f"data: {json.dumps({**progress, 'status': st})}\n\n"
if st != "running":
final = {'status': st, 'final': True}
task = research_handler._active_tasks.get(session_id, {})
if st == "error" and task.get("result"):
final['error'] = str(task["result"])[:500]
yield f"data: {json.dumps(final)}\n\n"
return
await asyncio.sleep(1.5)
return StreamingResponse(
_generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@router.post("/api/research/result-peek/{session_id}")
async def research_result_peek(session_id: str, request: Request):
"""Get research result without clearing it (for panel use)."""
user = _require_user(request)
if not _owns_in_memory(session_id, user):
raise HTTPException(404, "No research found for this session")
result = research_handler.get_result(session_id)
if result is None:
p = Path("data/deep_research") / f"{session_id}.json"
if p.exists():
d = json.loads(p.read_text())
return {
"result": d.get("result", ""),
"sources": d.get("sources", []),
"raw_findings": d.get("raw_findings", []),
"category": d.get("category") or "",
}
raise HTTPException(404, "No research result available")
sources = research_handler.get_sources(session_id) or []
raw_findings = research_handler.get_raw_findings(session_id) or []
return {"result": result, "sources": sources, "raw_findings": raw_findings, "category": ""}
@router.post("/api/research/spinoff/{session_id}")
async def research_spinoff(session_id: str, request: Request):
"""Create a new chat session pre-seeded with this research as context.
Reads the persisted research result + sources for `session_id`, creates
a fresh session (inheriting endpoint/model/headers from the source
session if available, otherwise from the resolved chat endpoint), and
injects a single system message containing the report and sources so
the user can ask follow-up questions in a clean conversation.
"""
_require_user(request)
if session_manager is None:
raise HTTPException(500, "session_manager not configured")
# Load research data — prefer in-memory result, fall back to disk
result = research_handler.get_result(session_id)
sources = research_handler.get_sources(session_id) or []
query = ""
path = Path("data/deep_research") / f"{session_id}.json"
if path.exists():
try:
disk = json.loads(path.read_text())
if not result:
result = disk.get("result")
if not sources:
sources = disk.get("sources", []) or []
query = disk.get("query", "") or ""
except Exception as e:
logger.warning(f"Could not read research JSON for spinoff: {e}")
if not result:
raise HTTPException(404, "No research result available for this session")
# Inherit endpoint/model/headers from the source session when possible.
# For panel-launched research (rp-* IDs), there is no chat session, so
# fall back through the same chain as /api/research/start: research →
# utility → first enabled endpoint in the DB.
ep_url, ep_model, ep_headers = "", "", {}
try:
src_sess = session_manager.get_session(session_id)
ep_url = src_sess.endpoint_url or ""
ep_model = src_sess.model or ""
ep_headers = dict(src_sess.headers or {})
except KeyError:
pass
def _merge(r_url, r_model, r_headers):
nonlocal ep_url, ep_model, ep_headers
if not ep_url and r_url:
ep_url = r_url
if not ep_model and r_model:
ep_model = r_model
if not ep_headers and r_headers:
ep_headers = dict(r_headers)
if not ep_url or not ep_model:
_merge(*resolve_endpoint("chat"))
if not ep_url or not ep_model:
_merge(*resolve_endpoint("research"))
if not ep_url or not ep_model:
_merge(*resolve_endpoint("utility"))
if not ep_url or not ep_model:
# Last resort: any enabled endpoint
from src.database import SessionLocal
from src.database import ModelEndpoint
from src.endpoint_resolver import normalize_base, build_chat_url, build_headers
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
if ep:
base = normalize_base(ep.base_url)
fallback_url = build_chat_url(base)
fallback_headers = build_headers(ep.api_key, base)
fallback_model = ""
if ep.cached_models:
try:
models = json.loads(ep.cached_models)
if models:
fallback_model = models[0]
except Exception:
pass
_merge(fallback_url, fallback_model, fallback_headers)
finally:
db.close()
if not ep_url or not ep_model:
raise HTTPException(400, "No endpoint configured — add one in Settings first")
# Create new session
new_sid = str(uuid.uuid4())
user = get_current_user(request)
title_query = (query or "research").strip()
if len(title_query) > 60:
title_query = title_query[:57] + ""
new_name = f"Follow-up: {title_query}"
new_sess = session_manager.create_session(
session_id=new_sid,
name=new_name,
endpoint_url=ep_url,
model=ep_model,
rag=False,
owner=user,
)
if ep_headers:
new_sess.headers = ep_headers
session_manager.save_sessions()
try:
from src.event_bus import fire_event
fire_event("session_created", user)
except Exception:
logger.debug("session_created event dispatch failed", exc_info=True)
# Build the priming system message — report only, no sources injected.
# The user can open the visual report for source details; keeping sources
# out of the chat context saves tokens and avoids the AI fabricating
# citations.
date_str = datetime.utcnow().strftime("%Y-%m-%d")
primer = (
f"[Research context — {date_str}]\n\n"
f"The user previously ran a deep research investigation. Use the "
f"report below as your primary knowledge base when answering "
f"follow-up questions. If the user asks something not covered, "
f"say so plainly rather than guessing.\n\n"
f"=== ORIGINAL QUERY ===\n{query or '(not recorded)'}\n\n"
f"=== REPORT ===\n{result}"
)
from core.models import ChatMessage
new_sess.add_message(ChatMessage(
role="system",
content=primer,
metadata={"research_spinoff_from": session_id},
))
session_manager.save_sessions()
return {
"session_id": new_sid,
"name": new_name,
"source_count": len(sources),
}
return router

111
routes/search_routes.py Normal file
View File

@@ -0,0 +1,111 @@
"""Search routes — /api/search/config GET, /api/search POST."""
import logging
from typing import Dict, Any
from fastapi import APIRouter, Request
import time
from services.search import get_search_config, comprehensive_web_search, PROVIDER_INFO
from services.search.core import _call_provider
from services.search.providers import _get_provider_key, _get_search_instance
logger = logging.getLogger(__name__)
async def _request_values(request: Request) -> Dict[str, Any]:
"""Accept JSON, form data, or query params for search endpoints.
The browser UI posts FormData, while the agent's generic app_api tool
posts JSON. FastAPI Form(...) rejects JSON with a 422 before our handler
runs, which made the model think SearXNG was broken.
"""
values: Dict[str, Any] = dict(request.query_params)
content_type = (request.headers.get("content-type") or "").lower()
try:
if "application/json" in content_type:
body = await request.json()
if isinstance(body, dict):
values.update(body)
else:
form = await request.form()
values.update(dict(form))
except Exception:
pass
return values
def setup_search_routes(config) -> APIRouter:
router = APIRouter(tags=["search"])
@router.get("/api/search/config")
async def get_search_settings() -> Dict[str, Any]:
return get_search_config()
@router.post("/api/search")
async def do_web_search(request: Request) -> Dict[str, Any]:
"""Standalone web search — returns context string + source list.
Used by Compare mode to pre-search once and share results across panes.
"""
values = await _request_values(request)
query = str(values.get("query") or values.get("q") or "").strip()
if not query:
return {"context": "", "sources": [], "error": "query is required"}
time_filter = values.get("time_filter") or values.get("freshness")
if time_filter is not None:
time_filter = str(time_filter).strip() or None
try:
context, sources = comprehensive_web_search(
query, return_sources=True, time_filter=time_filter,
)
return {"context": context, "sources": sources}
except Exception as e:
logger.error(f"Standalone web search failed: {e}")
return {"context": "", "sources": [], "error": str(e)}
@router.get("/api/search/providers")
async def list_search_providers():
"""Return available search providers with config status."""
providers = []
for pid, (label, needs_key, needs_url) in PROVIDER_INFO.items():
if pid == "disabled":
continue
available = True
if needs_key and not _get_provider_key(pid):
available = False
if needs_url and pid == "searxng" and not _get_search_instance():
available = False
providers.append({
"id": pid,
"label": label,
"available": available,
})
return providers
@router.post("/api/search/query")
async def search_with_provider(request: Request) -> Dict[str, Any]:
"""Search using a specific provider. Used by compare search mode."""
values = await _request_values(request)
query = str(values.get("query") or values.get("q") or "").strip()
provider = str(values.get("provider") or "").strip()
try:
count = int(values.get("count") or values.get("limit") or 10)
except Exception:
count = 10
if not query:
return {"results": [], "provider": provider, "error": "query is required"}
if provider not in PROVIDER_INFO or provider == "disabled":
return {"results": [], "provider": provider, "error": "Unknown provider"}
t0 = time.time()
try:
results = _call_provider(provider, query, min(count, 20))
elapsed = round(time.time() - t0, 2)
return {"results": results, "provider": provider, "time": elapsed}
except Exception as e:
elapsed = round(time.time() - t0, 2)
logger.error(f"Search provider {provider} failed: {e}")
return {"results": [], "provider": provider, "time": elapsed, "error": str(e)}
return router

1059
routes/session_routes.py Normal file

File diff suppressed because it is too large Load Diff

608
routes/shell_routes.py Normal file
View File

@@ -0,0 +1,608 @@
"""Shell routes — user-facing command execution endpoint."""
import asyncio
import json
import logging
import os
import pty
import fcntl
import shlex
import shutil
import uuid
import tempfile
from pathlib import Path
from typing import Dict, Any
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
def _require_admin(request: Request):
"""Reject non-admin callers. Shell exec is admin-only — never expose to
regular users; that's RCE-after-signup."""
auth_manager = getattr(request.app.state, "auth_manager", None)
if not auth_manager:
# No auth at all — only safe in fully-trusted localhost dev mode
return
user = getattr(request.state, "current_user", None)
# In-process tool loopback. The AuthMiddleware already validated the
# internal token + loopback client before setting this marker, so
# honour it here as admin-equivalent.
if user == "internal-tool":
return
if not user or user == "api":
raise HTTPException(403, "Admin only")
if not auth_manager.is_admin(user):
raise HTTPException(403, "Admin only")
logger = logging.getLogger(__name__)
def _find_line_break(buf):
"""Find next line terminator in buffer. Returns (index, separator_length) or (-1, 0)."""
ni = buf.find(b"\n")
ri = buf.find(b"\r")
if ni == -1 and ri == -1:
return -1, 0
if ni == -1:
return ri, 1
if ri == -1:
return ni, 1
if ri < ni:
return ri, (2 if ri + 1 == ni else 1)
return ni, 1
EXEC_TIMEOUT = 30 # seconds — shorter than agent's 60s
STREAM_TIMEOUT = 120 # default for short commands
MAX_OUTPUT = 200_000 # truncate limit
TMUX_LOG_DIR = Path(tempfile.gettempdir()) / "odysseus-tmux"
class ShellExecRequest(BaseModel):
command: str
timeout: int | None = None # optional override; 0 = no timeout (run until client disconnects)
use_pty: bool = False # use pseudo-TTY (for progress bars)
use_tmux: bool = False # run in tmux session (survives browser disconnect)
async def _exec_shell(command: str, timeout: int = EXEC_TIMEOUT) -> Dict[str, Any]:
"""Run a shell command and return stdout/stderr/exit_code."""
proc = None
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=str(Path.home()),
)
stdout_b, stderr_b = await asyncio.wait_for(
proc.communicate(), timeout=timeout
)
stdout = stdout_b.decode(errors="replace")[:MAX_OUTPUT]
stderr = stderr_b.decode(errors="replace")[:MAX_OUTPUT]
return {"stdout": stdout, "stderr": stderr, "exit_code": proc.returncode}
except asyncio.TimeoutError:
if proc:
try:
proc.kill()
await proc.wait()
except ProcessLookupError:
pass
return {"stdout": "", "stderr": f"Command timed out after {timeout}s", "exit_code": -1}
except Exception as e:
return {"stdout": "", "stderr": str(e), "exit_code": -1}
async def _generate_pty(cmd: str, timeout: int, request: Request):
"""Run command in a pseudo-TTY so tqdm/progress bars work natively."""
loop = asyncio.get_event_loop()
master_fd, slave_fd = pty.openpty()
# Set master to non-blocking
flags = fcntl.fcntl(master_fd, fcntl.F_GETFL)
fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
proc = await asyncio.create_subprocess_shell(
cmd,
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
cwd=str(Path.home()),
preexec_fn=os.setsid,
)
os.close(slave_fd) # parent doesn't need the slave side
deadline = (loop.time() + timeout) if timeout else None
buf = b""
process_done = asyncio.Event()
async def _wait_proc():
await proc.wait()
process_done.set()
wait_task = asyncio.create_task(_wait_proc())
try:
while not process_done.is_set():
if deadline and loop.time() > deadline:
proc.kill()
await proc.wait()
yield f"data: {json.dumps({'stream': 'stderr', 'data': f'Command timed out after {timeout}s'})}\n\n"
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
return
# Check client disconnect
if await request.is_disconnected():
proc.kill()
await proc.wait()
return
# Read available data from PTY
try:
chunk = await asyncio.wait_for(
loop.run_in_executor(None, _pty_read, master_fd),
timeout=2.0,
)
except asyncio.TimeoutError:
continue
except OSError:
break
if chunk is None:
# No data yet, keep waiting
continue
if chunk == b"":
# EOF — process closed the PTY
break
buf += chunk
# Split on \r or \n
while True:
idx, sep_len = _find_line_break(buf)
if idx == -1:
break
line = buf[:idx].decode(errors="replace")
buf = buf[idx + sep_len:]
if line:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
# Drain any remaining PTY output after process exits
try:
while True:
rest = _pty_read(master_fd)
if rest is None or rest == b"":
break
buf += rest
except OSError:
pass
# Flush remaining buffer
if buf:
# Split remaining buffer same as above
while True:
idx, sep_len = _find_line_break(buf)
if idx == -1:
break
line = buf[:idx].decode(errors="replace")
buf = buf[idx + sep_len:]
if line:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
if buf:
text = buf.decode(errors="replace").strip()
if text:
yield f"data: {json.dumps({'stream': 'stdout', 'data': text})}\n\n"
await wait_task
yield f"data: {json.dumps({'exit_code': proc.returncode})}\n\n"
except Exception as e:
try:
proc.kill()
await proc.wait()
except ProcessLookupError:
pass
yield f"data: {json.dumps({'stream': 'stderr', 'data': str(e)})}\n\n"
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
finally:
wait_task.cancel()
try:
os.close(master_fd)
except OSError:
pass
def _pty_read(fd: int) -> bytes | None:
"""Blocking read from PTY fd. Called via run_in_executor.
Returns bytes on data, None on timeout (no data yet)."""
import select
r, _, _ = select.select([fd], [], [], 1.0)
if r:
try:
data = os.read(fd, 4096)
return data if data else b"" # empty = EOF
except OSError:
return b"" # fd closed = EOF
return None # timeout, no data yet
async def _generate_tmux(cmd: str, request: Request):
"""Run command in a tmux session. Streams output via a log file.
The tmux session survives browser disconnect — user can reconnect or
`tmux attach -t <name>` to see it live."""
TMUX_LOG_DIR.mkdir(parents=True, exist_ok=True)
session_id = f"cookbook-{uuid.uuid4().hex[:8]}"
log_path = TMUX_LOG_DIR / f"{session_id}.log"
# Write a wrapper script that runs the command, tees output, and records exit code.
# Using a script avoids shell quoting issues with the tmux command.
script_path = TMUX_LOG_DIR / f"{session_id}.sh"
script_path.write_text(
f"#!/bin/bash\n"
f"ODYSSEUS_USER_SHELL=\"${{SHELL:-}}\"\n"
f"if [ -n \"$ODYSSEUS_USER_SHELL\" ] && [ -x \"$ODYSSEUS_USER_SHELL\" ]; then\n"
f" ODYSSEUS_USER_PATH=\"$(\"$ODYSSEUS_USER_SHELL\" -ic 'printf \"__ODYSSEUS_PATH__%s\\n\" \"$PATH\"' 2>/dev/null | sed -n 's/^__ODYSSEUS_PATH__//p' | tail -n 1 || true)\"\n"
f" if [ -n \"$ODYSSEUS_USER_PATH\" ]; then export PATH=\"$ODYSSEUS_USER_PATH:$PATH\"; fi\n"
f"fi\n"
f"{cmd} 2>&1 | tee '{log_path}'\n"
f"EC=${{PIPESTATUS[0]}}\n"
f"echo ':::EXIT_CODE:::'$EC >> '{log_path}'\n"
f"rm -f '{script_path}'\n"
f"exit $EC\n"
)
script_path.chmod(0o755)
logger.info("tmux wrapper script created: session=%s path=%s", session_id, script_path)
tmux_cmd = f"tmux new-session -d -s {session_id} {shlex.quote(str(script_path))}"
proc = await asyncio.create_subprocess_shell(
tmux_cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await proc.wait()
if proc.returncode != 0:
stderr = (await proc.stderr.read()).decode(errors="replace")
yield f"data: {json.dumps({'stream': 'stderr', 'data': f'Failed to start tmux: {stderr}'})}\n\n"
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
return
yield f"data: {json.dumps({'stream': 'stdout', 'data': f'Started tmux session: {session_id}'})}\n\n"
# Tail the log file, streaming new lines as SSE
lines_sent = 0
exit_code = None
while True:
# Check client disconnect
if await request.is_disconnected():
# tmux keeps running — that's the whole point
yield f"data: {json.dumps({'stream': 'stdout', 'data': f'Disconnected. tmux session {session_id} continues in background.'})}\n\n"
return
# Read new lines from log
try:
if log_path.exists():
lines = log_path.read_text(errors="replace").splitlines()
new_lines = lines[lines_sent:]
for line in new_lines:
if line.startswith(":::EXIT_CODE:::"):
try:
exit_code = int(line.split(":::")[-1])
except ValueError:
exit_code = -1
else:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
lines_sent = len(lines)
except Exception as e:
logger.debug(f"tmux log read error: {e}")
if exit_code is not None:
break
# Check if tmux session is still alive
check = await asyncio.create_subprocess_shell(
f"tmux has-session -t {session_id} 2>/dev/null",
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await check.wait()
if check.returncode != 0:
# Session ended — do one final read
await asyncio.sleep(0.5)
if log_path.exists():
lines = log_path.read_text(errors="replace").splitlines()
for line in lines[lines_sent:]:
if line.startswith(":::EXIT_CODE:::"):
try:
exit_code = int(line.split(":::")[-1])
except ValueError:
exit_code = -1
else:
yield f"data: {json.dumps({'stream': 'stdout', 'data': line})}\n\n"
if exit_code is None:
exit_code = 0
break
await asyncio.sleep(1.0)
yield f"data: {json.dumps({'exit_code': exit_code})}\n\n"
# Clean up log file
try:
log_path.unlink(missing_ok=True)
except Exception:
pass
def setup_shell_routes() -> APIRouter:
router = APIRouter(tags=["shell"])
@router.post("/api/shell/exec")
async def shell_exec(request: Request, req: ShellExecRequest) -> Dict[str, Any]:
"""Execute a shell command and return output. Admin only."""
_require_admin(request)
cmd = req.command.strip()
if not cmd:
return {"stdout": "", "stderr": "No command provided", "exit_code": 1}
logger.info("User shell exec requested: length=%d", len(cmd))
result = await _exec_shell(cmd, timeout=EXEC_TIMEOUT)
return result
@router.post("/api/shell/stream")
async def shell_stream(request: Request, req: ShellExecRequest):
"""Execute a shell command and stream output line-by-line via SSE. Admin only."""
_require_admin(request)
cmd = req.command.strip()
if not cmd:
async def empty():
yield f"data: {json.dumps({'stream': 'stderr', 'data': 'No command provided'})}\n\n"
yield f"data: {json.dumps({'exit_code': 1})}\n\n"
return StreamingResponse(empty(), media_type="text/event-stream")
timeout = req.timeout if req.timeout is not None else STREAM_TIMEOUT
use_pty = req.use_pty
use_tmux = req.use_tmux
logger.info(
"User shell stream requested: timeout=%s pty=%s tmux=%s length=%d",
"none" if timeout == 0 else f"{timeout}s",
use_pty,
use_tmux,
len(cmd),
)
if use_tmux:
return StreamingResponse(
_generate_tmux(cmd, request),
media_type="text/event-stream",
)
if use_pty:
return StreamingResponse(
_generate_pty(cmd, timeout, request),
media_type="text/event-stream",
)
async def generate():
proc = None
reader_tasks = []
try:
proc = await asyncio.create_subprocess_shell(
cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=str(Path.home()),
)
q: asyncio.Queue = asyncio.Queue()
async def _reader(stream, name):
"""Read chunks, split on \\n or \\r for progress bar support."""
try:
buf = b""
while True:
chunk = await stream.read(4096)
if not chunk:
if buf:
await q.put((name, buf.decode(errors="replace").rstrip("\r\n")))
break
buf += chunk
while True:
idx, sep_len = _find_line_break(buf)
if idx == -1:
break
line = buf[:idx].decode(errors="replace")
buf = buf[idx + sep_len:]
if line:
await q.put((name, line))
finally:
await q.put((name, None))
reader_tasks = [
asyncio.create_task(_reader(proc.stdout, "stdout")),
asyncio.create_task(_reader(proc.stderr, "stderr")),
]
finished = 0
deadline = (asyncio.get_event_loop().time() + timeout) if timeout else None
while finished < 2:
if deadline:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
raise asyncio.TimeoutError()
wait = min(remaining, 2.0)
else:
wait = 2.0
try:
name, text = await asyncio.wait_for(q.get(), timeout=wait)
except asyncio.TimeoutError:
if await request.is_disconnected():
if proc:
proc.kill()
return
continue
if text is None:
finished += 1
continue
yield f"data: {json.dumps({'stream': name, 'data': text})}\n\n"
await proc.wait()
yield f"data: {json.dumps({'exit_code': proc.returncode})}\n\n"
except asyncio.TimeoutError:
if proc:
try:
proc.kill()
await proc.wait()
except ProcessLookupError:
pass
yield f"data: {json.dumps({'stream': 'stderr', 'data': f'Command timed out after {timeout}s'})}\n\n"
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
except Exception as e:
yield f"data: {json.dumps({'stream': 'stderr', 'data': str(e)})}\n\n"
yield f"data: {json.dumps({'exit_code': -1})}\n\n"
finally:
for t in reader_tasks:
t.cancel()
return StreamingResponse(generate(), media_type="text/event-stream")
@router.get("/api/cookbook/packages")
async def list_packages(host: str | None = None, ssh_port: str | None = None, venv: str | None = None):
"""Check which optional packages are installed.
Local-target packages are checked in-process. Remote-target packages
(vllm, sglang, llama_cpp, diffusers, hf_transfer) are checked on the SELECTED
server over SSH, inside its venv — otherwise installing on a remote box
never reflected because the check only ever looked at the local host.
"""
import importlib, shlex, json as _json
packages = [
# ── System ── OS binaries, not pip packages
{"name": "tmux", "pip": "", "desc": "Required for Linux/Termux Cookbook background downloads and serves", "category": "System", "target": "remote", "kind": "system", "install_hint": "Run Cookbook server setup, or install tmux with apt/pacman/dnf/apk/zypper."},
{"name": "docker", "pip": "", "desc": "Required only for Docker-backed launch commands", "category": "System", "target": "remote", "kind": "system", "install_hint": "Install Docker on the selected server and allow this user to run docker."},
# ── LLM ── installs on GPU servers for model serving/downloading
{"name": "hf_transfer", "pip": "hf_transfer", "desc": "Fast model downloads from HuggingFace", "category": "LLM", "target": "remote"},
{"name": "llama_cpp", "pip": "llama-cpp-python[server]", "desc": "Serve GGUF models via llama.cpp", "category": "LLM", "target": "remote"},
{"name": "sglang", "pip": "sglang[all]", "desc": "Serve HF safetensors models via SGLang", "category": "LLM", "target": "remote"},
{"name": "vllm", "pip": "vllm", "desc": "High-throughput LLM serving engine", "category": "LLM", "target": "remote"},
# ── Image ── editor + diffusion model serving
{"name": "diffusers", "pip": "diffusers", "desc": "Image generation pipelines (SD, Flux)", "category": "Image", "target": "remote"},
{"name": "rembg", "pip": "rembg[gpu]", "desc": "AI background removal for image editor", "category": "Image", "target": "local"},
{"name": "realesrgan", "pip": "realesrgan", "desc": "AI denoise + upscale (Real-ESRGAN). Used by editor's Denoise and Upscale tools.", "category": "Image", "target": "local"},
# ── Tools ──
{"name": "playwright", "pip": "playwright", "desc": "Browser automation for web tools", "category": "Tools", "target": "local"},
]
# Remote check: for remote-target packages, probe the selected server's
# venv over SSH so a remote `pip install` actually reflects here.
remote_status: dict = {}
remote_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") != "system"]
remote_system_names = [p["name"] for p in packages if p.get("target") == "remote" and p.get("kind") == "system"]
if host and remote_names:
try:
names_lit = ",".join(repr(n) for n in remote_names)
py = (
"import importlib.util,json,shutil;"
f"names=[{names_lit}];"
"status={n:(importlib.util.find_spec(n) is not None) for n in names};"
"status['llama_cpp']=status.get('llama_cpp',False) or shutil.which('llama-server') is not None;"
"print(json.dumps(status))"
)
src = ""
if venv:
act = venv if venv.endswith("/bin/activate") else venv.rstrip("/") + "/bin/activate"
# NOT shlex.quoted: a leading ~ must stay shell-expandable on
# the remote (quoting it breaks `~/venv` → activation fails →
# the && short-circuits and every package reads as missing).
src = f". {act} && "
inner = f"{src}python3 -c {shlex.quote(py)}"
pf = f"-p {ssh_port} " if ssh_port and ssh_port not in ("", "22") else ""
ssh_cmd = (
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {pf}"
f"{shlex.quote(host)} {shlex.quote(inner)}"
)
proc = await asyncio.create_subprocess_shell(
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
txt = out.decode("utf-8", errors="replace").strip()
# The activate script can emit noise — take the last JSON line.
for line in reversed(txt.splitlines()):
line = line.strip()
if line.startswith("{"):
remote_status = _json.loads(line)
break
except Exception:
remote_status = {}
if host and remote_system_names:
try:
checks = []
for name in remote_system_names:
qn = shlex.quote(name)
checks.append(f"if command -v {qn} >/dev/null 2>&1; then echo {qn}=1; else echo {qn}=0; fi")
inner = " ; ".join(checks)
pf = f"-p {ssh_port} " if ssh_port and ssh_port not in ("", "22") else ""
ssh_cmd = (
f"ssh -o ConnectTimeout=6 -o StrictHostKeyChecking=no {pf}"
f"{shlex.quote(host)} {shlex.quote(inner)}"
)
proc = await asyncio.create_subprocess_shell(
ssh_cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
out, _err = await asyncio.wait_for(proc.communicate(), timeout=12)
txt = out.decode("utf-8", errors="replace").strip()
for line in txt.splitlines():
name, sep, value = line.strip().partition("=")
if sep and name in remote_system_names:
remote_status[name] = value == "1"
except Exception:
pass
for pkg in packages:
if host and pkg.get("target") == "remote":
pkg["installed"] = bool(remote_status.get(pkg["name"], False))
continue
if pkg.get("kind") == "system":
pkg["installed"] = shutil.which(pkg["name"]) is not None
continue
try:
if pkg["name"] == "llama_cpp" and shutil.which("llama-server"):
pkg["installed"] = True
continue
importlib.import_module(pkg["name"])
pkg["installed"] = True
except ImportError:
pkg["installed"] = False
return {"packages": packages}
@router.post("/api/cookbook/packages/install")
async def install_package(request: Request):
"""Install a package via pip. Admin only — pip install is effectively code exec."""
_require_admin(request)
import sys as _sys
body = await request.json()
pip_name = body.get("pip")
if not pip_name:
return {"ok": False, "error": "No package specified"}
# Validate against known packages to prevent arbitrary pip install
known = {
"rembg[gpu]", "hf_transfer", "llama-cpp-python[server]", "sglang[all]", "diffusers",
"TTS", "bark", "faster-whisper", "playwright", "realesrgan", "gfpgan",
"insightface", "onnxruntime-gpu", "onnxruntime", "hdbscan",
}
if pip_name not in known:
return {"ok": False, "error": f"Unknown package: {pip_name}"}
cmd = [_sys.executable, "-m", "pip", "install", pip_name]
proc = await asyncio.create_subprocess_exec(
*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await proc.communicate()
if proc.returncode == 0:
return {"ok": True, "output": stdout.decode()[-200:]}
return {"ok": False, "error": stderr.decode()[-300:]}
return router

123
routes/signature_routes.py Normal file
View File

@@ -0,0 +1,123 @@
"""Signature routes — CRUD for the user's saved visual signatures.
Signatures are reusable image stamps (drawn once, applied to many things):
PDF form fields, email composition, document insertion. Each signature is
stored as a base64 PNG so it can be embedded inline anywhere without a
separate fetch.
"""
import base64
import logging
import re
import uuid
from typing import Any, Dict, Optional
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from core.database import SessionLocal, Signature
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
_DATA_URL_RE = re.compile(
r'^data:image/(?P<fmt>png|jpeg|jpg);base64,(?P<data>.+)$',
re.IGNORECASE | re.DOTALL,
)
class SignatureCreate(BaseModel):
name: Optional[str] = None
data: str # base64 PNG, with or without `data:image/png;base64,` prefix
width: Optional[int] = None
height: Optional[int] = None
svg: Optional[str] = None
def _to_dict(s: Signature) -> Dict[str, Any]:
return {
"id": s.id,
"name": s.name,
"data_url": f"data:image/png;base64,{s.data_png}",
"width": s.width,
"height": s.height,
"created_at": (s.created_at.isoformat() + "Z") if s.created_at else None,
}
def setup_signature_routes() -> APIRouter:
router = APIRouter(tags=["signatures"])
@router.get("/api/signatures")
async def list_signatures(request: Request) -> Dict[str, Any]:
user = get_current_user(request)
db = SessionLocal()
try:
q = db.query(Signature)
if user is not None:
# SECURITY: strict ownership — the previous OR predicate
# returned every null-owner signature to every user.
q = q.filter(Signature.owner == user)
sigs = q.order_by(Signature.created_at.desc()).all()
return {"signatures": [_to_dict(s) for s in sigs]}
finally:
db.close()
@router.post("/api/signatures")
async def create_signature(request: Request, req: SignatureCreate) -> Dict[str, Any]:
user = get_current_user(request)
raw = (req.data or "").strip()
m = _DATA_URL_RE.match(raw)
b64 = m.group("data") if m else raw
try:
payload = base64.b64decode(b64, validate=True)
if not payload:
raise ValueError("empty payload")
except Exception:
raise HTTPException(400, "Signature data must be base64-encoded PNG bytes")
sig = Signature(
id=str(uuid.uuid4()),
owner=user,
name=(req.name or "Signature").strip() or "Signature",
data_png=b64,
width=req.width,
height=req.height,
svg=req.svg,
)
db = SessionLocal()
try:
db.add(sig)
db.commit()
db.refresh(sig)
return _to_dict(sig)
except Exception as e:
db.rollback()
logger.error(f"Failed to save signature: {e}")
raise HTTPException(500, f"Failed to save signature: {e}")
finally:
db.close()
@router.delete("/api/signatures/{sig_id}")
async def delete_signature(sig_id: str, request: Request) -> Dict[str, Any]:
user = get_current_user(request)
db = SessionLocal()
try:
sig = db.query(Signature).filter(Signature.id == sig_id).first()
if not sig:
raise HTTPException(404, "Signature not found")
if user and sig.owner != user:
raise HTTPException(403, "Not your signature")
db.delete(sig)
db.commit()
return {"deleted": sig_id}
except HTTPException:
raise
except Exception as e:
db.rollback()
raise HTTPException(500, f"Failed to delete signature: {e}")
finally:
db.close()
return router

1530
routes/skills_routes.py Normal file

File diff suppressed because it is too large Load Diff

55
routes/stt_routes.py Normal file
View File

@@ -0,0 +1,55 @@
# routes/stt_routes.py
"""STT API routes — multi-provider (local Whisper, API endpoint, browser)."""
from fastapi import APIRouter, HTTPException, UploadFile, File
import logging
logger = logging.getLogger(__name__)
def setup_stt_routes(stt_service):
"""Setup STT routes with the provided STT service"""
router = APIRouter(prefix="/api/stt", tags=["stt"])
@router.get("/stats")
async def get_stt_stats():
"""Get STT service statistics"""
try:
return stt_service.get_stats()
except Exception as e:
logger.error(f"Failed to get STT stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/transcribe")
async def transcribe_audio(file: UploadFile = File(...)):
"""Transcribe uploaded audio file to text"""
try:
if not stt_service.available:
raise HTTPException(
status_code=503,
detail={"message": "STT service not available or set to browser mode"}
)
audio_bytes = await file.read()
if not audio_bytes:
raise HTTPException(status_code=400, detail={"message": "Empty audio file"})
text = stt_service.transcribe(audio_bytes)
if text is None:
raise HTTPException(
status_code=500,
detail={"message": "Transcription failed"}
)
return {"text": text}
except HTTPException:
raise
except Exception as e:
logger.error(f"Transcription error: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={"message": f"Transcription failed: {str(e)}"}
)
return router

910
routes/task_routes.py Normal file
View File

@@ -0,0 +1,910 @@
"""CRUD routes for scheduled tasks."""
import json
import logging
import secrets
import uuid
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from core.database import SessionLocal, ScheduledTask, TaskRun
from src.auth_helpers import get_current_user
from src.task_scheduler import compute_next_run, HOUSEKEEPING_DEFAULTS
from routes.prefs_routes import _load_for_user, _save_for_user
logger = logging.getLogger(__name__)
class TaskCreate(BaseModel):
name: Optional[str] = None
prompt: Optional[str] = None
task_type: str = "llm" # "llm" | "action" | "research"
action: Optional[str] = None # builtin action name
schedule: Optional[str] = None # "once" | "daily" | "weekly" | "monthly" | "cron"
scheduled_time: str = "09:00" # HH:MM
scheduled_day: Optional[int] = None # day-of-week (0=Mon) or day-of-month
scheduled_date: Optional[str] = None # ISO datetime for "once"
cron_expression: Optional[str] = None # cron string e.g. "*/5 * * * *"
trigger_type: str = "schedule" # "schedule" | "event" | "webhook"
trigger_event: Optional[str] = None # e.g. "session_created"
trigger_count: Optional[int] = None # fire every N events
output_target: str = "session"
model: Optional[str] = None
endpoint_url: Optional[str] = None
then_task_id: Optional[str] = None # chain: run this task after success
notifications_enabled: Optional[bool] = None # None lets action-specific defaults apply
class TaskUpdate(BaseModel):
name: Optional[str] = None
prompt: Optional[str] = None
task_type: Optional[str] = None
action: Optional[str] = None
schedule: Optional[str] = None
scheduled_time: Optional[str] = None
scheduled_day: Optional[int] = None
scheduled_date: Optional[str] = None
cron_expression: Optional[str] = None
trigger_type: Optional[str] = None
trigger_event: Optional[str] = None
trigger_count: Optional[int] = None
output_target: Optional[str] = None
model: Optional[str] = None
endpoint_url: Optional[str] = None
then_task_id: Optional[str] = None
notifications_enabled: Optional[bool] = None
def _display_task_name(t: ScheduledTask) -> str:
defs = HOUSEKEEPING_DEFAULTS.get(t.action) if t.action else None
if defs and (t.name or "") in set(defs.get("legacy_names") or []):
return defs["name"]
return t.name
def _task_to_dict(t: ScheduledTask, include_last_run_result: bool = False) -> dict:
defs = HOUSEKEEPING_DEFAULTS.get(t.action) if t.action else None
d = {
"id": t.id,
"name": _display_task_name(t),
"prompt": t.prompt,
"task_type": t.task_type or "llm",
"action": t.action,
"schedule": t.schedule,
"scheduled_time": t.scheduled_time,
"scheduled_day": t.scheduled_day,
"scheduled_date": t.scheduled_date.isoformat() + "Z" if t.scheduled_date else None,
"cron_expression": t.cron_expression,
"trigger_type": t.trigger_type or "schedule",
"trigger_event": t.trigger_event,
"trigger_count": t.trigger_count,
"trigger_counter": t.trigger_counter or 0,
"next_run": t.next_run.isoformat() + "Z" if t.next_run else None,
"last_run": t.last_run.isoformat() + "Z" if t.last_run else None,
"status": t.status,
"output_target": t.output_target,
"session_id": t.session_id,
"crew_member_id": getattr(t, "crew_member_id", None),
"model": t.model,
"endpoint_url": t.endpoint_url,
"run_count": t.run_count or 0,
"then_task_id": t.then_task_id,
"notifications_enabled": bool(getattr(t, "notifications_enabled", True)),
"webhook_token": t.webhook_token if (t.trigger_type or "schedule") == "webhook" else None,
"created_at": t.created_at.isoformat() + "Z" if t.created_at else None,
"updated_at": t.updated_at.isoformat() + "Z" if t.updated_at else None,
}
# Built-in housekeeping tasks (identified by their action) are flagged so
# the UI can mark them and offer "revert to default" once altered.
d["is_builtin"] = defs is not None
if defs:
default_names = {defs["name"], *set(defs.get("legacy_names") or [])}
d["is_modified"] = (
(t.name or "") not in default_names
or (t.schedule or "") != (defs["schedule"] or "")
or (t.scheduled_time or "") != (defs["scheduled_time"] or "")
or (t.cron_expression or "") != (defs["cron_expression"] or "")
)
else:
d["is_modified"] = False
if include_last_run_result and t.runs:
last = t.runs[0] # ordered desc by started_at
d["last_run_status"] = last.status
d["last_run_result"] = (last.result or last.error or "")[:500]
return d
def _run_to_dict(r: TaskRun) -> dict:
return {
"id": r.id,
"task_id": r.task_id,
"started_at": r.started_at.isoformat() + "Z" if r.started_at else None,
"finished_at": r.finished_at.isoformat() + "Z" if r.finished_at else None,
"status": r.status,
"result": r.result,
"error": r.error,
"tokens_used": r.tokens_used,
"model": r.model,
}
def _run_research_id(task: ScheduledTask) -> str:
if (task.task_type or "llm") == "research" and task.session_id:
return task.session_id
return ""
def _resolve_run_endpoint(db, task: ScheduledTask, run: TaskRun) -> str:
"""Best-effort endpoint URL for reopening a task run in chat."""
if getattr(task, "endpoint_url", None):
return task.endpoint_url or ""
try:
if getattr(task, "session_id", None):
from core.database import Session as DbSession
sess = db.query(DbSession).filter(DbSession.id == task.session_id).first()
if sess and sess.endpoint_url:
return sess.endpoint_url or ""
except Exception:
pass
model = (getattr(run, "model", None) or getattr(task, "model", None) or "").strip()
if not model:
return ""
try:
from core.database import ModelEndpoint
eps = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
for ep in eps:
cached = []
if ep.cached_models:
try:
cached = json.loads(ep.cached_models) or []
except Exception:
cached = []
if model in cached:
return ep.base_url or ""
except Exception:
pass
return ""
def setup_task_routes(task_scheduler) -> APIRouter:
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
def _owner(request: Request):
return get_current_user(request)
async def _generate_task_name(prompt: str) -> str:
"""Use LLM to generate a short task name from the prompt."""
try:
from src.llm_core import llm_call_async
from core.database import Session as DbSession
db = SessionLocal()
try:
recent = db.query(DbSession).filter(
DbSession.endpoint_url.isnot(None),
DbSession.model.isnot(None),
).order_by(DbSession.created_at.desc()).first()
if not recent:
return prompt[:50].strip()
url, model = recent.endpoint_url, recent.model
finally:
db.close()
result = await llm_call_async(
url=url, model=model,
messages=[
{"role": "system", "content": "Generate a short title (3-5 words, no quotes) for this scheduled task. Reply with ONLY the title, nothing else."},
{"role": "user", "content": prompt[:500]},
],
max_tokens=20,
timeout=15,
)
title = result.strip().strip('"\'').strip()
return title[:60] if title else prompt[:50].strip()
except Exception:
first = prompt.split('\n')[0].split('.')[0].strip()
return first[:50] if first else "Untitled Task"
@router.get("")
async def list_tasks(request: Request, status: Optional[str] = None,
include_last_run: bool = False):
user = _owner(request)
if user:
await task_scheduler.ensure_defaults(user)
else:
db_seed = SessionLocal()
try:
owners = {
row[0] for row in db_seed.query(ScheduledTask.owner)
.filter(ScheduledTask.task_type == "action")
.filter(ScheduledTask.action.in_(list(HOUSEKEEPING_DEFAULTS.keys())))
.all()
if row[0]
}
finally:
db_seed.close()
for owner in owners:
await task_scheduler.ensure_defaults(owner)
db = SessionLocal()
try:
q = db.query(ScheduledTask)
if user:
q = q.filter(ScheduledTask.owner == user)
if status:
q = q.filter(ScheduledTask.status == status)
tasks = q.order_by(ScheduledTask.created_at.desc()).all()
return {"tasks": [_task_to_dict(t, include_last_run_result=include_last_run) for t in tasks]}
finally:
db.close()
@router.get("/onboarding")
async def get_tasks_onboarding(request: Request):
user = _owner(request)
prefs = _load_for_user(user) or {}
return {
"opened": bool(prefs.get("tasks_opened")),
"enabled": bool(prefs.get("tasks_enabled")),
}
@router.post("/onboarding")
async def update_tasks_onboarding(request: Request, body: dict):
user = _owner(request)
prefs = _load_for_user(user) or {}
prefs["tasks_opened"] = True
enable = bool(body.get("enabled"))
if enable:
prefs["tasks_enabled"] = True
_save_for_user(user, prefs)
if user:
await task_scheduler.ensure_defaults(user)
resumed = 0
if enable:
db = SessionLocal()
try:
tasks = db.query(ScheduledTask).filter(
ScheduledTask.owner == user,
ScheduledTask.task_type == "action",
ScheduledTask.action.in_(list(HOUSEKEEPING_DEFAULTS.keys())),
).all()
for task in tasks:
defs = HOUSEKEEPING_DEFAULTS.get(task.action or "")
if defs and defs.get("ship_paused"):
continue
if task.status == "active":
continue
task.status = "active"
if (task.trigger_type or "schedule") == "schedule":
task.next_run = compute_next_run(
task.schedule,
task.scheduled_time,
task.scheduled_day,
task.scheduled_date,
cron_expression=task.cron_expression,
)
resumed += 1
db.commit()
finally:
db.close()
return {"ok": True, "opened": True, "enabled": bool(prefs.get("tasks_enabled")), "resumed": resumed}
# Actions that execute shell/SSH commands — restricted to admins.
# Non-admin users cannot create tasks with these action types via the
# API. See review CRIT-C.
_ADMIN_ONLY_ACTIONS = {"run_local", "run_script", "ssh_command"}
def _is_admin(user: str | None) -> bool:
if not user:
return False
# In-process tool-loopback marker — AuthMiddleware validated
# the internal token + loopback client before stamping this,
# so treat as admin-equivalent.
if user == "internal-tool":
return True
try:
from core.auth import AuthManager
auth = AuthManager()
if not auth.is_configured:
# Unconfigured single-user deploy: trust the local owner.
return True
return bool(auth.is_admin(user))
except Exception:
return False
@router.post("")
async def create_task(request: Request, req: TaskCreate):
user = _owner(request)
# Validate
if req.task_type in ("llm", "research") and not req.prompt:
raise HTTPException(400, "Prompt is required for LLM/research tasks")
if req.task_type == "action" and not req.action:
raise HTTPException(400, "Action name is required for action tasks")
# Block shell-executing action types for non-admins. action_run_local
# uses subprocess.run(shell=True) and ssh_command / run_script run
# arbitrary commands.
if req.task_type == "action" and req.action in _ADMIN_ONLY_ACTIONS and not _is_admin(user):
raise HTTPException(403, f"Action '{req.action}' requires admin privileges")
if req.trigger_type == "schedule" and not req.schedule:
raise HTTPException(400, "Schedule is required for schedule-triggered tasks")
if req.trigger_type == "schedule" and req.schedule == "cron" and not req.cron_expression:
raise HTTPException(400, "Cron expression is required for cron schedule")
if req.trigger_type == "schedule" and req.schedule == "cron" and req.cron_expression:
try:
from croniter import croniter
croniter(req.cron_expression)
except Exception:
raise HTTPException(400, "Invalid cron expression")
if req.trigger_type == "event" and not req.trigger_event:
raise HTTPException(400, "Event name is required for event-triggered tasks")
if req.trigger_type == "event" and not req.trigger_count:
raise HTTPException(400, "Trigger count is required for event-triggered tasks")
# Auto-generate name
name = req.name
if not name:
if req.task_type == "action":
from src.builtin_actions import BUILTIN_ACTION_INFO
name = BUILTIN_ACTION_INFO.get(req.action, req.action or "Action Task")
elif req.prompt:
name = await _generate_task_name(req.prompt)
else:
name = "Untitled Task"
# Compute next_run for schedule-triggered tasks
next_run = None
sched_date = None
if req.trigger_type == "schedule":
if req.schedule == "once" and req.scheduled_date:
try:
sched_date = datetime.fromisoformat(req.scheduled_date.replace("Z", "+00:00")).replace(tzinfo=None)
except ValueError:
raise HTTPException(400, "Invalid scheduled_date format")
next_run = compute_next_run(
req.schedule, req.scheduled_time,
req.scheduled_day, sched_date,
cron_expression=req.cron_expression,
)
# Generate webhook token if needed
webhook_token = None
if req.trigger_type == "webhook":
webhook_token = secrets.token_urlsafe(32)
task_id = str(uuid.uuid4())
db = SessionLocal()
try:
notifications_enabled = (
False if req.task_type == "action" and req.notifications_enabled is None
else bool(req.notifications_enabled) if req.notifications_enabled is not None
else True
)
task = ScheduledTask(
id=task_id,
owner=user,
name=name,
prompt=req.prompt,
task_type=req.task_type,
action=req.action,
schedule=req.schedule,
scheduled_time=req.scheduled_time,
scheduled_day=req.scheduled_day,
scheduled_date=sched_date,
cron_expression=req.cron_expression,
trigger_type=req.trigger_type,
trigger_event=req.trigger_event,
trigger_count=req.trigger_count,
trigger_counter=0,
next_run=next_run,
status="active" if (req.trigger_type in ("event", "webhook") or next_run) else "completed",
output_target=req.output_target,
model=req.model or None,
endpoint_url=req.endpoint_url or None,
then_task_id=req.then_task_id or None,
webhook_token=webhook_token,
notifications_enabled=notifications_enabled,
)
db.add(task)
db.commit()
db.refresh(task)
return _task_to_dict(task)
finally:
db.close()
@router.get("/notifications")
async def get_notifications(request: Request):
"""Return and clear pending task-run notifications for the
current user. Anonymous callers get nothing (prevents
cross-tenant drain — see review CRIT-B)."""
user = _owner(request)
if not user:
return {"notifications": []}
notes = task_scheduler.pop_notifications(owner=user)
return {"notifications": notes}
@router.get("/{task_id}")
async def get_task(request: Request, task_id: str):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
return _task_to_dict(task)
finally:
db.close()
@router.put("/{task_id}")
async def update_task(request: Request, task_id: str, req: TaskUpdate):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
if req.name is not None:
task.name = req.name
if req.prompt is not None:
task.prompt = req.prompt
if req.task_type is not None:
task.task_type = req.task_type
if req.action is not None:
# Same admin-only gate as create — see CRIT-C.
if req.action in _ADMIN_ONLY_ACTIONS and not _is_admin(user):
raise HTTPException(403, f"Action '{req.action}' requires admin privileges")
task.action = req.action
if req.output_target is not None:
task.output_target = req.output_target
if req.model is not None:
task.model = req.model or None
if req.endpoint_url is not None:
task.endpoint_url = req.endpoint_url or None
if req.trigger_type is not None:
# Generate webhook token when switching to webhook trigger
if req.trigger_type == "webhook" and not task.webhook_token:
task.webhook_token = secrets.token_urlsafe(32)
task.trigger_type = req.trigger_type
if req.trigger_event is not None:
task.trigger_event = req.trigger_event
if req.trigger_count is not None:
task.trigger_count = req.trigger_count
if req.then_task_id is not None:
task.then_task_id = req.then_task_id or None
if req.notifications_enabled is not None:
task.notifications_enabled = bool(req.notifications_enabled)
if req.cron_expression is not None:
if req.cron_expression:
try:
from croniter import croniter
croniter(req.cron_expression)
except Exception:
raise HTTPException(400, "Invalid cron expression")
task.cron_expression = req.cron_expression or None
# Recompute next_run if schedule changed
schedule_changed = False
if req.schedule is not None:
task.schedule = req.schedule
schedule_changed = True
if req.scheduled_time is not None:
task.scheduled_time = req.scheduled_time
schedule_changed = True
if req.scheduled_day is not None:
task.scheduled_day = req.scheduled_day
schedule_changed = True
if req.scheduled_date is not None:
try:
task.scheduled_date = datetime.fromisoformat(
req.scheduled_date.replace("Z", "+00:00")
).replace(tzinfo=None)
except ValueError:
raise HTTPException(400, "Invalid scheduled_date format")
schedule_changed = True
if req.cron_expression is not None:
schedule_changed = True
if schedule_changed and task.status == "active" and (task.trigger_type or "schedule") == "schedule":
task.next_run = compute_next_run(
task.schedule, task.scheduled_time,
task.scheduled_day, task.scheduled_date,
cron_expression=task.cron_expression,
)
db.commit()
db.refresh(task)
return _task_to_dict(task)
finally:
db.close()
@router.delete("/{task_id}")
async def delete_task(request: Request, task_id: str):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
db.delete(task)
db.commit()
return {"ok": True}
finally:
db.close()
@router.post("/{task_id}/pause")
async def pause_task(request: Request, task_id: str):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
task.status = "paused"
db.commit()
return {"ok": True, "status": "paused"}
finally:
db.close()
@router.post("/{task_id}/resume")
async def resume_task(request: Request, task_id: str):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
task.status = "active"
if (task.trigger_type or "schedule") == "schedule":
task.next_run = compute_next_run(
task.schedule, task.scheduled_time,
task.scheduled_day, task.scheduled_date,
cron_expression=task.cron_expression,
)
db.commit()
return {"ok": True, "status": "active", "next_run": task.next_run.isoformat() + "Z" if task.next_run else None}
finally:
db.close()
@router.post("/{task_id}/revert")
async def revert_task(request: Request, task_id: str):
"""Reset a built-in (housekeeping) task to its default config."""
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
defs = HOUSEKEEPING_DEFAULTS.get(task.action) if task.action else None
if not defs:
raise HTTPException(400, "Not a built-in task")
task.name = defs["name"]
task.schedule = defs["schedule"]
task.scheduled_time = defs["scheduled_time"]
task.scheduled_day = None
task.scheduled_date = None
task.cron_expression = defs["cron_expression"]
task.trigger_type = defs.get("trigger_type", "schedule")
task.trigger_event = defs.get("trigger_event")
task.trigger_count = defs.get("trigger_count")
task.trigger_counter = 0
task.prompt = None
task.model = None
task.endpoint_url = None
task.status = "paused" if defs.get("ship_paused") else "active"
task.next_run = None
if task.trigger_type == "schedule":
task.next_run = compute_next_run(
defs["schedule"], defs["scheduled_time"], None, None,
cron_expression=defs["cron_expression"],
)
db.commit()
db.refresh(task)
return {"ok": True, "task": _task_to_dict(task)}
finally:
db.close()
@router.post("/{task_id}/run")
async def run_task_now(request: Request, task_id: str, force: bool = False):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
finally:
db.close()
started = await task_scheduler.run_task_now(task_id, force=force)
if not started:
raise HTTPException(409, "Task is already running")
return {"ok": True, "message": "Task triggered" + (" in parallel" if force else "")}
@router.get("/runs/recent")
async def list_recent_runs(request: Request, limit: int = 50):
"""Recent task runs across ALL tasks for this owner. Drives the Activity view."""
user = _owner(request)
limit = max(1, min(limit, 200))
db = SessionLocal()
try:
q = db.query(TaskRun, ScheduledTask).join(
ScheduledTask, TaskRun.task_id == ScheduledTask.id
)
if user:
# Strict owner scope — was previously OR'ing in `owner IS NULL`
# rows for "legacy single-user" back-compat, but that leaks any
# legacy/migrated task's full result text to every authenticated
# user. _migrate_assign_legacy_owner runs on startup to claim
# legacy rows for the admin, so the OR-NULL path is no longer
# needed for any sane deploy.
q = q.filter(ScheduledTask.owner == user)
# Pull a little extra before de-duping. When auth is bypassed on a
# local browser session, legacy/default tasks from multiple owners
# can be visible together; the built-in urgent-email scanner then
# produces several identical "no email accounts configured" rows in
# the same minute. Keep the task records intact, but collapse those
# duplicate Activity rows for display.
rows = q.order_by(TaskRun.started_at.desc()).limit(limit * 3).all()
deduped = []
seen_urgency_rows = set()
for r, t in rows:
if (t.action or "") == "check_email_urgency":
ts = r.started_at.replace(second=0, microsecond=0) if r.started_at else None
text = (r.result or r.error or "").strip()
key = (ts, r.status or "", text)
if key in seen_urgency_rows:
continue
seen_urgency_rows.add(key)
deduped.append((r, t))
if len(deduped) >= limit:
break
return {
"runs": [
{
**_run_to_dict(r),
"task_name": _display_task_name(t),
"task_type": t.task_type or "llm",
"action": t.action,
# Model + endpoint the task ran on, so the Activity
# view's "Open in chat" can reuse the same model.
"model": r.model or t.model or "",
"endpoint_url": _resolve_run_endpoint(db, t, r),
"session_id": t.session_id or "",
"research_id": _run_research_id(t),
# Where the task delivered its result — the Activity tab
# uses this to filter notification rows in/out.
"output_target": t.output_target or "session",
}
for r, t in deduped
]
}
finally:
db.close()
@router.get("/{task_id}/runs")
async def list_runs(request: Request, task_id: str, limit: int = 20, offset: int = 0):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
runs = db.query(TaskRun).filter(TaskRun.task_id == task_id)\
.order_by(TaskRun.started_at.desc())\
.offset(offset).limit(limit).all()
total = db.query(TaskRun).filter(TaskRun.task_id == task_id).count()
return {"runs": [_run_to_dict(r) for r in runs], "total": total}
finally:
db.close()
@router.get("/meta/output-targets")
async def list_output_targets(request: Request):
"""List available output targets — only delivery/send tools, not all MCP tools."""
_owner(request)
targets = [
{"value": "session", "label": "Session", "description": "Save result to a chat session"},
{"value": "notification", "label": "Notification", "description": "Push a browser notification with the result (also saved to the session for history)"},
{"value": "email", "label": "Email me", "description": "Send result through your configured SMTP account"},
]
# Only include tools whose NAME clearly indicates an outbound delivery
# action — match by verb in the tool name, not by any mention of "email"
# in the description (which falsely picked up search_email, list_email,
# etc.). Also exclude read/search/list tools whose names happen to start
# with a delivery verb.
_DELIVERY_VERBS = ("send", "notify", "post", "publish", "draft", "dispatch", "deliver")
_NON_DELIVERY = (
"search", "list", "get", "find", "read", "fetch", "view",
"tag", "label", "move", "archive", "delete", "mark", "schedule",
)
try:
from src.agent_tools import get_mcp_manager
mcp = get_mcp_manager()
if mcp:
for tool in mcp.get_all_tools():
name_lower = tool.get("name", "").lower()
if any(x in name_lower for x in _NON_DELIVERY):
continue
if not any(v in name_lower for v in _DELIVERY_VERBS):
continue
targets.append({
"value": tool["qualified_name"],
"label": f"{tool['server_name']}{tool['name']}",
"description": tool.get("description", ""),
})
except Exception:
pass
return {"targets": targets}
@router.get("/meta/actions")
async def list_actions(request: Request):
"""List available built-in actions."""
user = _owner(request)
from src.builtin_actions import BUILTIN_ACTION_INFO
return {"actions": [
{"name": name, "description": desc}
for name, desc in BUILTIN_ACTION_INFO.items()
if name not in _ADMIN_ONLY_ACTIONS or _is_admin(user)
]}
@router.get("/meta/events")
async def list_events(request: Request):
"""List available event triggers."""
_owner(request)
return {"events": [
{"name": "session_created", "description": "Fires when a new chat session is created"},
{"name": "message_sent", "description": "Fires when a user sends a message"},
{"name": "document_created", "description": "Fires when a document is created"},
{"name": "memory_added", "description": "Fires when a memory is added"},
{"name": "research_completed", "description": "Fires when a research report completes"},
{"name": "email_received", "description": "Fires when new inbox mail is observed"},
{"name": "skill_added", "description": "Fires when a new skill is created"},
]}
@router.post("/{task_id}/webhook/{token}")
async def webhook_trigger(task_id: str, token: str):
"""Unauthenticated endpoint — the token IS the auth."""
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(
ScheduledTask.id == task_id,
ScheduledTask.webhook_token == token,
ScheduledTask.status == "active",
).first()
if not task:
raise HTTPException(404, "Not found")
finally:
db.close()
started = await task_scheduler.run_task_now(task_id)
if not started:
raise HTTPException(409, "Task is already running")
return {"ok": True, "message": "Task triggered via webhook"}
@router.post("/{task_id}/webhook-regenerate")
async def regenerate_webhook(request: Request, task_id: str):
user = _owner(request)
db = SessionLocal()
try:
task = db.query(ScheduledTask).filter(ScheduledTask.id == task_id).first()
if not task:
raise HTTPException(404, "Task not found")
if user and task.owner != user:
raise HTTPException(403, "Access denied")
task.webhook_token = secrets.token_urlsafe(32)
db.commit()
return {"ok": True, "webhook_token": task.webhook_token}
finally:
db.close()
# --- PARSE NATURAL LANGUAGE → TASK DRAFT (AI) ---
@router.post("/parse")
async def parse_task(request: Request) -> Dict[str, Any]:
"""Turn a free-form description ("every weekday at 7am research the top
AI news and summarize it") into a structured task draft the frontend
can pre-fill the form with. Returns a draft only — the user reviews and
saves it, so a misread schedule never goes live unreviewed."""
from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async
from src.text_helpers import strip_think as _strip_think
import json as _json, re as _re
from datetime import datetime as _dt
body = await request.json()
desc = (body.get("description") or "").strip()
if not desc:
return {"success": False, "message": "Nothing to parse"}
now = _dt.now()
# Give the model the current date/time + weekday so relative phrasing
# ("tomorrow", "every Monday", "in an hour") resolves correctly.
ctx = now.strftime("%Y-%m-%d %H:%M (%A)")
sys = (
"You convert a user's description of a recurring or one-off task into "
"STRICT JSON for a task scheduler. The current local date/time is "
f"{ctx}. Output ONLY a JSON object, no prose, no markdown fences.\n\n"
"Schema (omit fields you can't infer):\n"
"{\n"
' "task_type": "llm" | "research", // "research" if it asks to research/investigate/find out; else "llm"\n'
' "name": "short 3-6 word title",\n'
' "prompt": "the instruction the AI should run on schedule (or the research question)",\n'
' "schedule": "daily" | "weekly" | "monthly" | "once" | "cron",\n'
' "scheduled_time": "HH:MM", // 24h LOCAL time\n'
' "scheduled_day": 0, // weekly: 0=Mon..6=Sun; monthly: 1..31\n'
' "scheduled_date": "YYYY-MM-DDTHH:MM", // only for "once"\n'
' "cron_expression": "m h dom mon dow", // only if schedule is "cron"\n'
' "output_target": "session" | "email" | "notification" // use email when the user asks to email the result\n'
"}\n\n"
"Rules: default schedule to 'daily' if a time is given without a frequency. "
"Default scheduled_time to '09:00' if none is stated. For 'every weekday' "
"use cron '0 H * * 1-5'. Keep the prompt actionable and self-contained."
)
try:
url, model, headers = resolve_endpoint("utility")
if not url:
url, model, headers = resolve_endpoint("default")
if not (url and model):
return {"success": False, "message": "No model endpoint configured"}
raw = await llm_call_async(
url=url, model=model,
messages=[{"role": "system", "content": sys},
{"role": "user", "content": desc[:1000]}],
temperature=0.2, max_tokens=400, headers=headers, timeout=45,
)
text = _strip_think(raw or "", prose=False, prompt_echo=False).strip()
if text.startswith("```"):
text = text.strip("`")
if text.lower().startswith("json"):
text = text[4:].lstrip()
# Pull the first {...} block in case the model added stray text.
m = _re.search(r"\{.*\}", text, _re.S)
draft = _json.loads(m.group(0) if m else text)
if not isinstance(draft, dict):
raise ValueError("not an object")
# Whitelist + light validation so the frontend gets clean fields.
out: Dict[str, Any] = {}
if draft.get("task_type") in ("llm", "research"):
out["task_type"] = draft["task_type"]
else:
out["task_type"] = "llm"
for k in ("name", "prompt", "cron_expression", "scheduled_date"):
if isinstance(draft.get(k), str) and draft[k].strip():
out[k] = draft[k].strip()
if draft.get("schedule") in ("daily", "weekly", "monthly", "once", "cron"):
out["schedule"] = draft["schedule"]
else:
out["schedule"] = "daily"
st = draft.get("scheduled_time")
if isinstance(st, str) and _re.match(r"^\d{1,2}:\d{2}$", st.strip()):
out["scheduled_time"] = st.strip()
if isinstance(draft.get("scheduled_day"), int):
out["scheduled_day"] = draft["scheduled_day"]
if draft.get("output_target") in ("session", "email", "notification"):
out["output_target"] = draft["output_target"]
out["trigger_type"] = "schedule"
if not out.get("prompt"):
return {"success": False, "message": "Could not extract a task instruction"}
return {"success": True, "draft": out}
except Exception as e:
logger.error(f"parse_task failed: {e}")
return {"success": False, "message": str(e)}
return router

87
routes/tts_routes.py Normal file
View File

@@ -0,0 +1,87 @@
# routes/tts_routes.py
"""
TTS API routes — multi-provider (local Kokoro, API endpoint, browser).
"""
from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
class TTSRequest(BaseModel):
text: str
format: str = "audio" # "audio" or "base64"
def setup_tts_routes(tts_service):
"""Setup TTS routes with the provided TTS service"""
router = APIRouter(prefix="/api/tts", tags=["tts"])
@router.get("/stats")
async def get_tts_stats():
"""Get TTS service statistics"""
try:
return tts_service.get_stats()
except Exception as e:
logger.error(f"Failed to get TTS stats: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/synthesize")
async def synthesize_speech(request: TTSRequest):
"""Synthesize speech from text"""
try:
if not tts_service.available:
raise HTTPException(
status_code=503,
detail={"message": "TTS service not available"}
)
if request.format == "base64":
audio_b64 = tts_service.synthesize_to_base64(request.text)
if not audio_b64:
raise HTTPException(
status_code=500,
detail={"message": "Synthesis failed"}
)
return {"audio": audio_b64}
else: # audio format
audio_data = tts_service.synthesize(request.text)
if not audio_data:
raise HTTPException(
status_code=500,
detail={"message": "Synthesis failed"}
)
# Detect format from magic bytes (MP3: ID3 tag or sync word ff e0+)
is_mp3 = audio_data[:3] == b'ID3' or (len(audio_data) >= 2 and audio_data[0] == 0xff and (audio_data[1] & 0xe0) == 0xe0)
mime = "audio/mpeg" if is_mp3 else "audio/wav"
return Response(
content=audio_data,
media_type=mime,
headers={
"Content-Disposition": "inline; filename=speech.mp3" if "mpeg" in mime else "inline; filename=speech.wav"
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Synthesis error: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={"message": f"Synthesis failed: {str(e)}"}
)
@router.post("/clear-cache")
async def clear_tts_cache():
"""Clear TTS cache"""
try:
tts_service.clear_cache()
return {"success": True, "message": "Cache cleared"}
except Exception as e:
logger.error(f"Failed to clear cache: {e}")
raise HTTPException(status_code=500, detail=str(e))
return router

251
routes/upload_routes.py Normal file
View File

@@ -0,0 +1,251 @@
# routes/upload_routes.py
import os
import time
import json
import asyncio
from fastapi import APIRouter, Request, File, UploadFile, HTTPException
from typing import List
import logging
from core.middleware import require_admin
from src.auth_helpers import get_current_user
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/upload", tags=["upload"])
def setup_upload_routes(upload_handler):
"""Setup upload routes with the provided handler"""
@router.post("")
async def api_upload(request: Request, files: List[UploadFile] = File(...)):
"""Upload files with enhanced security and organization."""
if not files:
raise HTTPException(400, "No files uploaded")
client_ip = request.client.host if request.client else "unknown"
out = []
# Limit concurrent uploads per IP
ip_upload_count = sum(
1 for f in files
if client_ip in upload_handler.upload_rate_log and
any(now > time.time() - 10 for now in upload_handler.upload_rate_log[client_ip][-len(files):])
)
if ip_upload_count >= upload_handler.max_concurrent_uploads:
raise HTTPException(
status_code=429,
detail=f"Maximum concurrent uploads ({upload_handler.max_concurrent_uploads}) exceeded"
)
for u in files:
try:
meta = upload_handler.save_upload(u, client_ip, owner=get_current_user(request))
out.append({
"id": meta["id"],
"name": meta["name"],
"mime": meta["mime"],
"size": meta["size"],
"hash": meta["hash"],
"uploaded_at": meta["uploaded_at"],
"width": meta.get("width"),
"height": meta.get("height"),
"is_duplicate": meta.get("is_duplicate", False)
})
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to process upload {u.filename}: {str(e)}")
continue
if not out:
raise HTTPException(500, "All file uploads failed")
return {"files": out}
@router.post("/cleanup")
async def manual_cleanup(request: Request):
"""Manually trigger cleanup of old uploads."""
require_admin(request)
cleaned_count = upload_handler.cleanup_old_uploads()
return {"status": "success", "files_cleaned": cleaned_count}
@router.get("/stats")
async def upload_stats(request: Request):
"""Get statistics about uploaded files."""
require_admin(request)
try:
return upload_handler.get_upload_stats()
except Exception as e:
logger.error(f"Failed to get upload stats: {e}")
raise HTTPException(500, "Failed to get upload statistics")
@router.get("/{file_id}")
async def download_file(request: Request, file_id: str, thumb: int = 0):
"""Serve an uploaded file by its ID. `?thumb=1` returns a small cached
JPEG thumbnail for images (used by chat attachment previews) so the
client isn't downloading the full-resolution photo just to show it tiny."""
if not upload_handler.validate_upload_id(file_id):
raise HTTPException(400, "Invalid file ID")
# Search upload directories for the file
from src.constants import UPLOAD_DIR
import mimetypes as _mt
path = os.path.join(UPLOAD_DIR, file_id)
if not os.path.exists(path):
for root, dirs, files in os.walk(UPLOAD_DIR):
if file_id in files:
path = os.path.join(root, file_id)
break
else:
raise HTTPException(404, "File not found")
if not upload_handler.inside_base_dir(path):
raise HTTPException(403, "Access denied")
# Look up original filename and owner from uploads.json
original_name = file_id
info = None
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
if os.path.exists(uploads_db):
with open(uploads_db) as f:
db = json.load(f)
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
if info:
original_name = info.get("name", file_id)
auth_mgr = getattr(request.app.state, "auth_manager", None)
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
current_user = get_current_user(request)
file_owner = info.get("owner") if info else None
if auth_configured:
if not current_user:
raise HTTPException(403, "Access denied")
if file_owner != current_user and not auth_mgr.is_admin(current_user):
raise HTTPException(404, "File not found")
mime = _mt.guess_type(path)[0] or "application/octet-stream"
from fastapi.responses import FileResponse
# Downscaled thumbnail for image previews — generated once and cached.
if thumb and mime.startswith("image/"):
try:
from PIL import Image, ImageOps
thumb_dir = os.path.join(UPLOAD_DIR, ".thumbs")
os.makedirs(thumb_dir, exist_ok=True)
thumb_path = os.path.join(thumb_dir, file_id + ".jpg")
if (not os.path.exists(thumb_path)
or os.path.getmtime(thumb_path) < os.path.getmtime(path)):
im = Image.open(path)
# iPhone / camera JPEGs encode rotation in EXIF rather than
# the pixel data. Browsers honour that on the original via
# image-orientation:from-image, but PIL strips EXIF when it
# saves the JPEG thumb, leaving the pixels sideways. Bake
# the rotation into the pixels before thumbnailing.
im = ImageOps.exif_transpose(im)
im.thumbnail((320, 320))
if im.mode not in ("RGB", "L"):
im = im.convert("RGB")
im.save(thumb_path, "JPEG", quality=80)
return FileResponse(thumb_path, media_type="image/jpeg")
except Exception as e:
logger.warning(f"Thumbnail generation failed for {file_id}: {e}")
# Fall through to the full image.
return FileResponse(path, media_type=mime, filename=original_name)
def _load_upload_info(file_id: str):
"""Look up the uploads.json record for a file_id, with owner/auth checks."""
from src.constants import UPLOAD_DIR
info = None
uploads_db = os.path.join(UPLOAD_DIR, "uploads.json")
if os.path.exists(uploads_db):
with open(uploads_db) as f:
db = json.load(f)
info = next((fi for fi in db.values() if fi["id"] == file_id), None)
return info
def _vision_cache_path(file_id: str) -> str:
from src.constants import UPLOAD_DIR
cache_dir = os.path.join(UPLOAD_DIR, ".vision")
os.makedirs(cache_dir, exist_ok=True)
return os.path.join(cache_dir, file_id + ".txt")
@router.get("/{file_id}/vision")
async def get_vision_text(request: Request, file_id: str, force: int = 0):
"""Return the vision-model OCR/description for an uploaded image.
Cached under UPLOAD_DIR/.vision/{file_id}.txt — first call computes,
subsequent loads are instant. Pass force=1 to recompute."""
if not upload_handler.validate_upload_id(file_id):
raise HTTPException(400, "Invalid file ID")
from src.constants import UPLOAD_DIR
path = os.path.join(UPLOAD_DIR, file_id)
if not os.path.exists(path):
for root, dirs, files in os.walk(UPLOAD_DIR):
if file_id in files:
path = os.path.join(root, file_id)
break
else:
raise HTTPException(404, "File not found")
if not upload_handler.inside_base_dir(path):
raise HTTPException(403, "Access denied")
info = _load_upload_info(file_id)
auth_mgr = getattr(request.app.state, "auth_manager", None)
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
current_user = get_current_user(request)
file_owner = info.get("owner") if info else None
if auth_configured:
if not current_user:
raise HTTPException(403, "Access denied")
if file_owner != current_user and not auth_mgr.is_admin(current_user):
raise HTTPException(404, "File not found")
import mimetypes as _mt
mime = _mt.guess_type(path)[0] or ""
if not mime.startswith("image/"):
raise HTTPException(400, "Not an image")
cache_path = _vision_cache_path(file_id)
if not force and os.path.exists(cache_path):
try:
with open(cache_path) as f:
return {"text": f.read(), "cached": True}
except Exception as e:
logger.warning(f"Vision cache read failed for {file_id}: {e}")
from src.document_processor import analyze_image_with_vl
try:
text = analyze_image_with_vl(path) or ""
except Exception as e:
logger.error(f"Vision analysis failed for {file_id}: {e}")
raise HTTPException(500, f"Vision analysis failed: {e}")
try:
with open(cache_path, "w") as f:
f.write(text)
except Exception as e:
logger.warning(f"Vision cache write failed for {file_id}: {e}")
return {"text": text, "cached": False}
@router.put("/{file_id}/vision")
async def put_vision_text(request: Request, file_id: str):
"""Persist a user-edited vision/OCR text for an attachment. Stored in
the same cache file so the chat send picks it up as the override."""
if not upload_handler.validate_upload_id(file_id):
raise HTTPException(400, "Invalid file ID")
info = _load_upload_info(file_id)
if not info:
raise HTTPException(404, "File not found")
auth_mgr = getattr(request.app.state, "auth_manager", None)
auth_configured = bool(auth_mgr and auth_mgr.is_configured)
current_user = get_current_user(request)
file_owner = info.get("owner")
if auth_configured:
if not current_user:
raise HTTPException(403, "Access denied")
if file_owner != current_user and not auth_mgr.is_admin(current_user):
raise HTTPException(404, "File not found")
body = await request.json()
text = (body or {}).get("text", "")
if not isinstance(text, str):
raise HTTPException(400, "text must be a string")
with open(_vision_cache_path(file_id), "w") as f:
f.write(text)
return {"ok": True}
async def periodic_rate_limit_cleanup():
"""Background task to run cleanup every hour"""
while True:
await asyncio.sleep(3600)
upload_handler.cleanup_rate_limits()
return router, periodic_rate_limit_cleanup

216
routes/vault_routes.py Normal file
View File

@@ -0,0 +1,216 @@
"""
vault_routes.py
Vaultwarden / Bitwarden CLI integration — config and unlock endpoints.
Stores the BW_SESSION key in data/vault.json with restrictive permissions.
"""
import json
import logging
import os
import shutil
import asyncio
from pathlib import Path
from datetime import datetime
from fastapi import APIRouter, Request
from pydantic import BaseModel
from core.middleware import require_admin
logger = logging.getLogger(__name__)
VAULT_FILE = Path("data/vault.json")
def _find_bw() -> str:
"""Locate the bw binary, checking PATH and common npm-global locations."""
p = shutil.which("bw")
if p:
return p
home = os.path.expanduser("~")
for candidate in (
f"{home}/.npm-global/bin/bw",
f"{home}/.nvm/versions/node/*/bin/bw",
"/usr/local/bin/bw",
"/opt/homebrew/bin/bw",
):
if "*" in candidate:
import glob
for m in glob.glob(candidate):
if os.path.isfile(m) and os.access(m, os.X_OK):
return m
elif os.path.isfile(candidate) and os.access(candidate, os.X_OK):
return candidate
return "bw" # fall back to PATH lookup (will FileNotFoundError, handled below)
def _load_config() -> dict:
if VAULT_FILE.exists():
try:
return json.loads(VAULT_FILE.read_text())
except Exception:
pass
return {}
def _save_config(cfg: dict):
VAULT_FILE.parent.mkdir(parents=True, exist_ok=True)
VAULT_FILE.write_text(json.dumps(cfg, indent=2))
try:
os.chmod(str(VAULT_FILE), 0o600)
except Exception:
pass
async def _run_bw(args: list, session: str = None, input_text: str = None) -> tuple:
env = {}
env.update(os.environ)
if session:
env["BW_SESSION"] = session
bw_path = _find_bw()
try:
proc = await asyncio.create_subprocess_exec(
bw_path, *args,
stdin=asyncio.subprocess.PIPE if input_text else None,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
env=env,
)
except FileNotFoundError:
return "", "bw CLI not installed (install `nodejs-bitwarden-cli` or `bitwarden-cli`)", 127
except Exception as e:
return "", f"Failed to launch bw: {e}", 1
try:
stdout, stderr = await proc.communicate(input=input_text.encode() if input_text else None)
except Exception as e:
return "", f"bw subprocess error: {e}", 1
return stdout.decode(errors="replace").strip(), stderr.decode(errors="replace").strip(), proc.returncode
class VaultConfig(BaseModel):
server_url: str = ""
email: str = ""
class VaultUnlockRequest(BaseModel):
master_password: str
class VaultLoginRequest(BaseModel):
email: str
master_password: str
def setup_vault_routes():
router = APIRouter(prefix="/api/vault", tags=["vault"])
@router.get("/config")
async def get_config(request: Request):
"""Return vault config (no sensitive fields)."""
require_admin(request)
cfg = _load_config()
return {
"server_url": cfg.get("server_url", ""),
"email": cfg.get("email", ""),
"unlocked": bool(cfg.get("session")),
"unlocked_at": cfg.get("unlocked_at", ""),
"bw_installed": await _check_bw_installed(),
}
@router.post("/config")
async def save_config(req: VaultConfig, request: Request):
"""Save vault URL + email. Runs 'bw config server' to point at Vaultwarden."""
require_admin(request)
cfg = _load_config()
cfg["server_url"] = req.server_url.strip().rstrip("/")
cfg["email"] = req.email.strip()
if cfg["server_url"]:
_, stderr, rc = await _run_bw(["config", "server", cfg["server_url"]])
if rc != 0:
return {"ok": False, "error": f"bw config failed: {stderr[:300]}"}
_save_config(cfg)
return {"ok": True}
@router.post("/login")
async def login(req: VaultLoginRequest, request: Request):
"""Log in to Vaultwarden (required once per account)."""
require_admin(request)
cfg = _load_config()
# Update email
cfg["email"] = req.email
_save_config(cfg)
stdout, stderr, rc = await _run_bw(
["login", req.email, "--raw"],
input_text=req.master_password + "\n",
)
if rc != 0:
# Already logged in is OK
if "already logged in" in stderr.lower():
return {"ok": True, "already": True}
return {"ok": False, "error": f"Login failed: {stderr[:300]}"}
# bw login --raw prints session key on success (when 2FA disabled)
if stdout:
cfg["session"] = stdout
cfg["unlocked_at"] = datetime.utcnow().isoformat()
_save_config(cfg)
return {"ok": True}
@router.post("/unlock")
async def unlock(req: VaultUnlockRequest, request: Request):
"""Unlock the vault and save the session key."""
require_admin(request)
stdout, stderr, rc = await _run_bw(
["unlock", req.master_password, "--raw"],
)
if rc != 0:
return {"ok": False, "error": f"Unlock failed: {stderr[:300]}"}
session = stdout.strip()
if not session:
return {"ok": False, "error": "bw returned empty session"}
cfg = _load_config()
cfg["session"] = session
cfg["unlocked_at"] = datetime.utcnow().isoformat()
_save_config(cfg)
return {"ok": True, "message": "Vault unlocked"}
@router.post("/lock")
async def lock(request: Request):
"""Lock the vault (clear session from config)."""
require_admin(request)
cfg = _load_config()
cfg.pop("session", None)
cfg.pop("unlocked_at", None)
_save_config(cfg)
# Also tell bw to lock
await _run_bw(["lock"])
return {"ok": True, "message": "Vault locked"}
@router.post("/logout")
async def logout(request: Request):
"""Log out of the Bitwarden CLI completely."""
require_admin(request)
await _run_bw(["logout"])
cfg = _load_config()
cfg.pop("session", None)
cfg.pop("email", None)
cfg.pop("unlocked_at", None)
_save_config(cfg)
return {"ok": True}
return router
async def _check_bw_installed() -> bool:
try:
proc = await asyncio.create_subprocess_exec(
_find_bw(), "--version",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await proc.communicate()
return proc.returncode == 0
except Exception:
return False

322
routes/webhook_routes.py Normal file
View File

@@ -0,0 +1,322 @@
"""Webhook, API Token, and sync chat routes."""
import asyncio
import uuid
import logging
from typing import Optional
import httpx
from fastapi import APIRouter, HTTPException, Request, Form
from pydantic import BaseModel, Field
from core.database import SessionLocal, Webhook
from src.webhook_manager import WebhookManager, validate_webhook_url, validate_events
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["webhooks"])
# Input limits
MAX_NAME_LEN = 100
MAX_URL_LEN = 2048
MAX_SECRET_LEN = 256
MAX_MESSAGE_LEN = 32_000
from core.middleware import require_admin as _require_admin
def setup_webhook_routes(
webhook_manager: WebhookManager,
auth_manager,
session_manager=None,
api_key_manager=None,
) -> APIRouter:
@router.get("/webhooks")
def list_webhooks(request: Request):
_require_admin(request)
db = SessionLocal()
try:
hooks = db.query(Webhook).all()
return [
{
"id": w.id,
"name": w.name,
"url": w.url,
"has_secret": bool(w.secret),
"events": w.events.split(",") if w.events else [],
"is_active": w.is_active,
"last_triggered_at": w.last_triggered_at.isoformat() if w.last_triggered_at else None,
"last_status_code": w.last_status_code,
"last_error": w.last_error,
"created_at": w.created_at.isoformat() if w.created_at else None,
}
for w in hooks
]
finally:
db.close()
@router.post("/webhooks")
def create_webhook(
request: Request,
name: str = Form(""),
url: str = Form(""),
secret: str = Form(""),
events: str = Form(""),
):
_require_admin(request)
name = name.strip()[:MAX_NAME_LEN]
if not name:
raise HTTPException(400, "Webhook name is required")
try:
url = validate_webhook_url(url)
except ValueError as e:
raise HTTPException(400, str(e))
try:
events = validate_events(events)
except ValueError as e:
raise HTTPException(400, str(e))
secret_val = secret.strip()[:MAX_SECRET_LEN] or None
# Encrypt the secret at rest using the same Fernet key as API keys
encrypted_secret = None
if secret_val and api_key_manager:
encrypted_secret = api_key_manager.encrypt_api_key(secret_val)
elif secret_val:
encrypted_secret = secret_val # Fallback if no encryption available
webhook_id = str(uuid.uuid4())[:8]
db = SessionLocal()
try:
db.add(Webhook(
id=webhook_id,
name=name,
url=url,
secret=encrypted_secret,
events=events,
is_active=True,
))
db.commit()
finally:
db.close()
return {"id": webhook_id, "name": name}
@router.post("/webhooks/{webhook_id}/test")
async def test_webhook(request: Request, webhook_id: str):
_require_admin(request)
db = SessionLocal()
try:
wh = db.query(Webhook).filter(Webhook.id == webhook_id).first()
if not wh:
raise HTTPException(404, "Webhook not found")
url, secret = wh.url, wh.secret
finally:
db.close()
await webhook_manager.deliver_test(webhook_id, url, secret)
return {"status": "sent"}
@router.patch("/webhooks/{webhook_id}")
def toggle_webhook(request: Request, webhook_id: str):
_require_admin(request)
db = SessionLocal()
try:
wh = db.query(Webhook).filter(Webhook.id == webhook_id).first()
if not wh:
raise HTTPException(404, "Webhook not found")
wh.is_active = not wh.is_active
db.commit()
return {"id": webhook_id, "is_active": wh.is_active}
finally:
db.close()
@router.delete("/webhooks/{webhook_id}")
def delete_webhook(request: Request, webhook_id: str):
_require_admin(request)
db = SessionLocal()
try:
deleted = db.query(Webhook).filter(Webhook.id == webhook_id).delete()
db.commit()
if not deleted:
raise HTTPException(404, "Webhook not found")
finally:
db.close()
return {"status": "deleted"}
# ================================================================
# Sync Chat Endpoint (for n8n / Make / Activepieces)
# ================================================================
# Known provider base URLs — auto-resolved from api_key prefix or model name
KNOWN_PROVIDERS = {
"deepseek": "https://api.deepseek.com/v1",
"openai": "https://api.openai.com/v1",
"mistral": "https://api.mistral.ai/v1",
"groq": "https://api.groq.com/openai/v1",
"together": "https://api.together.xyz/v1",
"openrouter": "https://openrouter.ai/api/v1",
"fireworks": "https://api.fireworks.ai/inference/v1",
}
# Model prefix → provider mapping for auto-detection
MODEL_PROVIDER_MAP = {
"deepseek": "deepseek",
"gpt-": "openai",
"o1": "openai",
"o3": "openai",
"o4": "openai",
"mistral": "mistral",
"llama": "groq",
"mixtral": "groq",
}
def _resolve_base_url(model: Optional[str], provider: Optional[str]) -> Optional[str]:
"""Try to auto-resolve a base URL from provider name or model prefix."""
if provider and provider.lower() in KNOWN_PROVIDERS:
return KNOWN_PROVIDERS[provider.lower()]
if model:
model_lower = model.lower()
for prefix, prov in MODEL_PROVIDER_MAP.items():
if model_lower.startswith(prefix):
return KNOWN_PROVIDERS[prov]
return None
class SyncChatRequest(BaseModel):
message: str = Field(..., max_length=MAX_MESSAGE_LEN)
model: Optional[str] = Field(None, max_length=200)
session: Optional[str] = Field(None, max_length=100)
api_key: Optional[str] = Field(None, max_length=256)
base_url: Optional[str] = Field(None, max_length=MAX_URL_LEN)
provider: Optional[str] = Field(None, max_length=50)
@router.post("/v1/chat")
async def sync_chat(request: Request, body: SyncChatRequest):
if not getattr(request.state, "api_token", False):
raise HTTPException(403, "This endpoint requires an API token")
scopes = set(getattr(request.state, "api_token_scopes", []) or [])
if "chat" not in scopes:
raise HTTPException(403, "API token is not scoped for chat")
token_owner = getattr(request.state, "api_token_owner", None)
from core.models import ChatMessage
from src.llm_core import llm_call_async
from core.database import ModelEndpoint
message = body.message.strip()
if not message:
raise HTTPException(400, "Message is required")
session_id = body.session
sess = None
# --- Case 1: Resume an existing session ---
if session_id and session_manager:
try:
sess = session_manager.get_session(session_id)
except (KeyError, Exception):
raise HTTPException(404, "Session not found")
# SECURITY: verify the API-token's user owns this session — without
# this any token holder could resume any user's chat by passing its
# ID. The token's user is on request.state.user (set by API-token
# middleware); fall back to require_user if not present.
try:
from src.auth_helpers import get_current_user as _gcu
_tok_user = token_owner or getattr(request.state, "user", None) or _gcu(request)
except Exception:
_tok_user = None
_sess_owner = getattr(sess, "owner", None)
if _tok_user and _sess_owner and _sess_owner != _tok_user:
raise HTTPException(404, "Session not found")
# --- Case 2: Direct API key + model (no pre-configured endpoint needed) ---
if not sess and body.api_key:
api_key = body.api_key.strip()
model = body.model or "deepseek-chat"
# Resolve base_url: explicit > provider name > model prefix auto-detect
base_url = body.base_url.strip().rstrip("/") if body.base_url else None
if not base_url:
base_url = _resolve_base_url(model, body.provider)
if not base_url:
raise HTTPException(400,
"Could not auto-detect provider. Pass base_url (e.g. 'https://api.deepseek.com/v1') "
"or provider ('deepseek', 'openai', 'groq', etc.)")
endpoint_url = base_url + "/chat/completions"
if not session_manager:
raise HTTPException(500, "Session manager not available")
sid = str(uuid.uuid4())
sess = session_manager.create_session(
session_id=sid, name="API Chat", endpoint_url=endpoint_url,
model=model, owner=token_owner,
)
sess.headers = {"Authorization": f"Bearer {api_key}"}
session_manager.save_sessions()
session_id = sid
# --- Case 3: Fall back to first configured ModelEndpoint ---
if not sess:
db = SessionLocal()
try:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).first()
finally:
db.close()
if not ep:
raise HTTPException(400,
"No session, api_key, or configured endpoints. "
"Pass api_key + model, or configure an endpoint in Admin.")
endpoint_url = ep.base_url.rstrip("/") + "/chat/completions"
model = body.model or "auto"
api_key = ep.api_key
if model == "auto":
try:
async with httpx.AsyncClient(timeout=5) as client:
models_url = ep.base_url.rstrip("/") + "/models"
hdrs = {"Authorization": f"Bearer {api_key}"} if api_key else {}
resp = await client.get(models_url, headers=hdrs)
resp.raise_for_status()
ids = [m.get("id") for m in (resp.json().get("data") or []) if m.get("id")]
model = ids[0] if ids else "auto"
except Exception:
raise HTTPException(500, "Could not discover models from endpoint")
if not session_manager:
raise HTTPException(500, "Session manager not available")
sid = str(uuid.uuid4())
sess = session_manager.create_session(
session_id=sid, name="API Chat", endpoint_url=endpoint_url,
model=model, owner=token_owner,
)
if api_key:
sess.headers = {"Authorization": f"Bearer {api_key}"}
session_manager.save_sessions()
session_id = sid
# --- Send message and get response ---
sess.add_message(ChatMessage("user", message))
messages = [{"role": m.role, "content": m.content} for m in sess.history]
reply = await llm_call_async(
sess.endpoint_url, sess.model, messages,
headers=sess.headers, timeout=120,
)
sess.add_message(ChatMessage("assistant", reply))
session_manager.save_sessions()
asyncio.create_task(webhook_manager.fire("chat.completed", {
"session_id": session_id, "model": sess.model,
"user_message": message[:2000], "response": reply[:2000],
}))
return {"response": reply, "session_id": session_id, "model": sess.model}
return router