Harden session endpoint owner scope (#1308)

This commit is contained in:
Vykos
2026-06-02 19:40:22 +02:00
committed by GitHub
parent 80de69ebb0
commit 4771d80eb2
6 changed files with 261 additions and 71 deletions

View File

@@ -148,8 +148,9 @@ async def auto_name_session(session_manager, sess):
if not first_msg: if not first_msg:
return return
owner = getattr(sess, "owner", None)
t_url, t_model, t_headers = resolve_task_endpoint( t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, sess.endpoint_url, sess.model, sess.headers, owner=owner,
) )
if not t_model: if not t_model:
logger.debug("[auto-name] No model provided, skipping") logger.debug("[auto-name] No model provided, skipping")
@@ -311,7 +312,24 @@ def fire_message_event(request, webhook_manager, session_id: str, sess, message:
fire_event("message_sent", user) fire_event("message_sent", user)
def resolve_session_auth(sess, session_id: str): def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
if not session_url or not endpoint_base:
return False
try:
from src.endpoint_resolver import build_chat_url, normalize_base
sess_url = session_url.rstrip("/")
base = normalize_base(endpoint_base).rstrip("/")
return sess_url in {
base,
base + "/chat/completions",
build_chat_url(base).rstrip("/"),
}
except Exception:
return False
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
"""Ensure session has auth headers — resolve from endpoint DB if missing.""" """Ensure session has auth headers — resolve from endpoint DB if missing."""
has_auth = sess.headers and isinstance(sess.headers, dict) and any( has_auth = sess.headers and isinstance(sess.headers, dict) and any(
k.lower() in ('authorization', 'x-api-key') for k in sess.headers k.lower() in ('authorization', 'x-api-key') for k in sess.headers
@@ -320,19 +338,33 @@ def resolve_session_auth(sess, session_id: str):
return return
try: try:
from src.endpoint_resolver import build_headers from src.endpoint_resolver import build_headers, normalize_base
db = SessionLocal() db = SessionLocal()
try: try:
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else "" target_url = getattr(sess, "endpoint_url", "") or ""
if domain: if not target_url:
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first() return
if ep and ep.api_key: q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
sess.headers = build_headers(ep.api_key, ep.base_url) if owner:
db.query(DBSession).filter(DBSession.id == session_id).update( # Missing headers usually means "recover from the saved endpoint".
{"headers": json.dumps(sess.headers)} # Scope that lookup to the session owner, otherwise two users
) # with similar endpoint URLs can borrow each other's API key.
db.commit() from src.auth_helpers import owner_filter
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}") q = owner_filter(q, ModelEndpoint, owner)
for ep in q.all():
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
continue
if not ep.api_key:
return
base = normalize_base(ep.base_url or "")
sess.headers = build_headers(ep.api_key, base)
update_q = db.query(DBSession).filter(DBSession.id == session_id)
if owner:
update_q = update_q.filter(DBSession.owner == owner)
update_q.update({"headers": sess.headers})
db.commit()
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
return
finally: finally:
db.close() db.close()
except Exception as e: except Exception as e:
@@ -806,7 +838,7 @@ def run_post_response_tasks(
from services.memory.memory_extractor import extract_and_store from services.memory.memory_extractor import extract_and_store
from src.task_endpoint import resolve_task_endpoint from src.task_endpoint import resolve_task_endpoint
t_url, t_model, t_headers = resolve_task_endpoint( t_url, t_model, t_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, sess.endpoint_url, sess.model, sess.headers, owner=owner,
) )
asyncio.create_task(extract_and_store( asyncio.create_task(extract_and_store(
sess, memory_manager, memory_vector, sess, memory_manager, memory_vector,
@@ -843,7 +875,7 @@ def run_post_response_tasks(
from services.memory.skill_extractor import maybe_extract_skill from services.memory.skill_extractor import maybe_extract_skill
from src.task_endpoint import resolve_task_endpoint from src.task_endpoint import resolve_task_endpoint
s_url, s_model, s_headers = resolve_task_endpoint( s_url, s_model, s_headers = resolve_task_endpoint(
sess.endpoint_url, sess.model, sess.headers, sess.endpoint_url, sess.model, sess.headers, owner=owner,
) )
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model) logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
asyncio.create_task(maybe_extract_skill( asyncio.create_task(maybe_extract_skill(

View File

@@ -72,13 +72,17 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
return sess in variants or sess.startswith(base + "/") return sess in variants or sess.startswith(base + "/")
def _clear_orphaned_session_endpoint(sess) -> bool: def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
"""Clear a session model if its endpoint was deleted from ModelEndpoint.""" """Clear a session model if its endpoint was deleted from ModelEndpoint."""
if not getattr(sess, "endpoint_url", ""): if not getattr(sess, "endpoint_url", ""):
return False return False
db = SessionLocal() db = SessionLocal()
try: try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for ep in endpoints: for ep in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""): if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
return False return False
@@ -118,7 +122,7 @@ def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
return wanted in {str(item).strip() for item in models} return wanted in {str(item).strip() for item in models}
def _is_image_generation_session(sess) -> bool: def _is_image_generation_session(sess, owner: str | None = None) -> bool:
"""Whether this chat session should bypass text chat and generate images. """Whether this chat session should bypass text chat and generate images.
Model-name prefixes are explicit image models. Endpoint type is only used Model-name prefixes are explicit image models. Endpoint type is only used
@@ -137,7 +141,11 @@ def _is_image_generation_session(sess) -> bool:
db = SessionLocal() db = SessionLocal()
try: try:
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for endpoint in endpoints: for endpoint in endpoints:
if (getattr(endpoint, "model_type", None) or "llm") != "image": if (getattr(endpoint, "model_type", None) or "llm") != "image":
continue continue
@@ -152,7 +160,7 @@ def _is_image_generation_session(sess) -> bool:
return False return False
def _recover_empty_session_model(sess, session_id: str) -> bool: def _recover_empty_session_model(sess, session_id: str, owner: str | None = None) -> bool:
"""Re-populate sess.model from the matching endpoint's cached models. """Re-populate sess.model from the matching endpoint's cached models.
Covers the window between endpoint setup and the first chat send: the Covers the window between endpoint setup and the first chat send: the
@@ -172,7 +180,11 @@ def _recover_empty_session_model(sess, session_id: str) -> bool:
# cached model is the most defensible default. # cached model is the most defensible default.
ep = None ep = None
if getattr(sess, "endpoint_url", ""): if getattr(sess, "endpoint_url", ""):
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all() q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
endpoints = q.all()
for cand in endpoints: for cand in endpoints:
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""): if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
ep = cand ep = cand
@@ -251,13 +263,14 @@ def setup_chat_routes(
sess = session_manager.get_session(session) sess = session_manager.get_session(session)
except KeyError: except KeyError:
raise HTTPException(404, f"Session '{session}' not found") raise HTTPException(404, f"Session '{session}' not found")
if _clear_orphaned_session_endpoint(sess): owner = get_current_user(request)
if _clear_orphaned_session_endpoint(sess, owner=owner):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.") raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Empty model + live endpoint = setup race (Issue #587). Repair from # Empty model + live endpoint = setup race (Issue #587). Repair from
# the endpoint's cached model list before privilege checks, which # the endpoint's cached model list before privilege checks, which
# otherwise see "" and behave inconsistently with the allowlist. # otherwise see "" and behave inconsistently with the allowlist.
_recover_empty_session_model(sess, session) _recover_empty_session_model(sess, session, owner=owner)
if not getattr(sess, "model", "").strip(): if not getattr(sess, "model", "").strip():
raise HTTPException( raise HTTPException(
400, 400,
@@ -401,7 +414,8 @@ def setup_chat_routes(
# but BEFORE loading. Prevents cross-user session hijack. # but BEFORE loading. Prevents cross-user session hijack.
_verify_session_owner(request, session) _verify_session_owner(request, session)
sess = session_manager.get_session(session) sess = session_manager.get_session(session)
if _clear_orphaned_session_endpoint(sess): owner = get_current_user(request)
if _clear_orphaned_session_endpoint(sess, owner=owner):
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.") raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
# Issue #587: picker shows a model from the endpoint cache but # Issue #587: picker shows a model from the endpoint cache but
# s.model never made it onto the DB row (first-send race after # s.model never made it onto the DB row (first-send race after
@@ -409,7 +423,7 @@ def setup_chat_routes(
# the first cached model off the matching endpoint so the # the first cached model off the matching endpoint so the
# upstream isn't called with model="" (which surfaces as a # upstream isn't called with model="" (which surfaces as a
# generic 401/503). # generic 401/503).
_recover_empty_session_model(sess, session) _recover_empty_session_model(sess, session, owner=owner)
if not getattr(sess, "model", "").strip(): if not getattr(sess, "model", "").strip():
raise HTTPException( raise HTTPException(
400, 400,
@@ -431,7 +445,7 @@ def setup_chat_routes(
_enforce_chat_privileges(request, sess) _enforce_chat_privileges(request, sess)
# Ensure session has auth headers # Ensure session has auth headers
resolve_session_auth(sess, session) resolve_session_auth(sess, session, owner=get_current_user(request))
# Check for research_pending BEFORE mode persist overwrites it # Check for research_pending BEFORE mode persist overwrites it
do_research = str(use_research).lower() == "true" do_research = str(use_research).lower() == "true"
@@ -768,7 +782,7 @@ def setup_chat_routes(
# output. Resolved once per request. # output. Resolved once per request.
try: try:
from src.endpoint_resolver import resolve_chat_fallback_candidates from src.endpoint_resolver import resolve_chat_fallback_candidates
_fallback_candidates = resolve_chat_fallback_candidates() _fallback_candidates = resolve_chat_fallback_candidates(owner=_user)
except Exception: except Exception:
_fallback_candidates = [] _fallback_candidates = []
@@ -781,7 +795,7 @@ def setup_chat_routes(
_model_info["character_name"] = ctx.preset.character_name _model_info["character_name"] = ctx.preset.character_name
yield f'data: {json.dumps(_model_info)}\n\n' yield f'data: {json.dumps(_model_info)}\n\n'
if _is_image_generation_session(sess): if _is_image_generation_session(sess, owner=_user):
from src.settings import get_setting from src.settings import get_setting
if not get_setting("image_gen_enabled", True): if not get_setting("image_gen_enabled", True):
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n' yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
@@ -792,7 +806,7 @@ def setup_chat_routes(
_user_msg = message or "" _user_msg = message or ""
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n' yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
yield ": heartbeat\n\n" yield ": heartbeat\n\n"
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session) _img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session, owner=_user)
_img_output = _img_result.get("results", _img_result.get("error", "")) _img_output = _img_result.get("results", _img_result.get("error", ""))
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1} _img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"): for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):

View File

@@ -58,23 +58,71 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["sessions"]) router = APIRouter(prefix="/api", tags=["sessions"])
def _pick_endpoint_for_sort():
def _current_user_is_admin(request: Request, user: str | None) -> bool:
    """Return True only when the app's auth manager confirms *user* is an admin.

    The check fails closed: a missing user, a missing ``auth_manager`` on
    ``request.app.state``, a non-callable ``is_admin`` attribute, or an
    exception raised by ``is_admin`` all yield ``False``.

    Args:
        request: Incoming request; only ``request.app.state.auth_manager`` is read.
        user: Name of the signed-in user, or ``None``/empty when not signed in.

    Returns:
        ``bool(auth_manager.is_admin(user))`` when that call succeeds,
        otherwise ``False``.
    """
    if not user:
        return False
    auth_mgr = getattr(request.app.state, "auth_manager", None)
    is_admin = getattr(auth_mgr, "is_admin", None)
    if not callable(is_admin):
        return False
    try:
        return bool(is_admin(user))
    except Exception:
        # Still fail closed, but leave a trace: a broken auth backend would
        # otherwise look exactly like "this user is not an admin".
        logger.exception("Admin check failed for user %r; treating as non-admin", user)
        return False
def _reject_raw_endpoint_url_for_non_admin(
    request: Request,
    user: str | None,
    endpoint_id: str | None,
    endpoint_url: str | None,
) -> None:
    """Block signed-in non-admins from pointing a session at a raw URL.

    Nothing happens when a non-blank ``endpoint_id`` is supplied or when
    ``endpoint_url`` is empty. Requests without a user are also let through.

    Raises:
        HTTPException: 403 when a signed-in, non-admin user passes an
            ``endpoint_url`` without a registered ``endpoint_id``.
    """
    uses_registered_endpoint = bool(endpoint_id and endpoint_id.strip())
    if uses_registered_endpoint or not endpoint_url:
        return

    # A raw URL means the server would dial whatever host the caller names.
    # Non-admins must go through a saved endpoint row instead, so owner
    # scoping and endpoint validation have already been applied.
    signed_in_non_admin = bool(user) and not _current_user_is_admin(request, user)
    if signed_in_non_admin:
        raise HTTPException(403, "Choose a registered model endpoint")
def _persist_session_headers(session_id: str, headers: dict | None) -> None:
    """Store *headers* on the database row for session *session_id*.

    A falsy ``headers`` value is saved as an empty dict, and the row's
    ``updated_at`` is refreshed. If no row matches, nothing is written.
    Any exception triggers a rollback and is re-raised; the database
    handle is closed in every case.
    """
    db = SessionLocal()
    try:
        row = db.query(DbSession).filter(DbSession.id == session_id).first()
        if not row:
            return
        row.headers = headers or {}
        row.updated_at = datetime.utcnow()
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
def _pick_endpoint_for_sort(owner=None):
"""Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default.""" """Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default."""
from src.endpoint_resolver import resolve_endpoint from src.endpoint_resolver import resolve_endpoint
# Try utility endpoint first (what the user configured for background tasks) # Try utility endpoint first (what the user configured for background tasks)
url, model, headers = resolve_endpoint("utility") url, model, headers = resolve_endpoint("utility", owner=owner)
if url and model: if url and model:
return url, model, headers return url, model, headers
# Fall back to task endpoint # Fall back to task endpoint
try: try:
from src.task_endpoint import resolve_task_endpoint from src.task_endpoint import resolve_task_endpoint
url, model, headers = resolve_task_endpoint() url, model, headers = resolve_task_endpoint(owner=owner)
if url and model: if url and model:
return url, model, headers return url, model, headers
except Exception: except Exception:
pass pass
# Fall back to default # Fall back to default
url, model, headers = resolve_endpoint("default") url, model, headers = resolve_endpoint("default", owner=owner)
if url and model: if url and model:
return url, model, headers return url, model, headers
return None, None, None return None, None, None
@@ -197,11 +245,41 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
endpoint_id: str = Form(""), endpoint_id: str = Form(""),
): ):
skip_val = str(skip_validation).lower() == "true" skip_val = str(skip_validation).lower() == "true"
user = get_current_user(request)
endpoint_api_key = ""
endpoint_base_url = ""
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
if endpoint_id and endpoint_id.strip():
from core.database import ModelEndpoint
from src.auth_helpers import owner_filter
from src.endpoint_resolver import build_chat_url, normalize_base
_db = SessionLocal()
try:
q = _db.query(ModelEndpoint).filter(
ModelEndpoint.id == endpoint_id.strip(),
ModelEndpoint.is_enabled == True,
)
if user:
q = owner_filter(q, ModelEndpoint, user)
endpoint_row = q.first()
if not endpoint_row:
raise HTTPException(400, "Model endpoint no longer exists")
endpoint_base_url = endpoint_row.base_url or ""
endpoint_api_key = endpoint_row.api_key or ""
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
finally:
_db.close()
if not endpoint_url and not skip_val: if not endpoint_url and not skip_val:
raise HTTPException(400, "endpoint_url is required (choose from /api/models)") raise HTTPException(400, "endpoint_url is required (choose from /api/models)")
model_to_use = model model_to_use = model
request_api_key = api_key.strip() if api_key else ""
effective_api_key = request_api_key or endpoint_api_key
validation_headers = None
if effective_api_key:
from src.endpoint_resolver import build_headers
validation_headers = build_headers(effective_api_key, endpoint_base_url or endpoint_url)
if skip_val: if skip_val:
# skip_validation = trust the caller and do NOT probe /v1/models. # skip_validation = trust the caller and do NOT probe /v1/models.
@@ -212,7 +290,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
elif not model_to_use: elif not model_to_use:
from src.llm_core import list_model_ids from src.llm_core import list_model_ids
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT, ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None) headers=validation_headers)
if not ids: if not ids:
raise HTTPException(400, "Cannot reach /v1/models") raise HTTPException(400, "Cannot reach /v1/models")
# Default to the first CHAT model — endpoints often list embedding/ # Default to the first CHAT model — endpoints often list embedding/
@@ -227,7 +305,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
import os as _os import os as _os
req_base = _os.path.basename(model_to_use.rstrip("/")) req_base = _os.path.basename(model_to_use.rstrip("/"))
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT, avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None) headers=validation_headers)
if not avail: if not avail:
raise HTTPException(400, "Cannot reach /v1/models") raise HTTPException(400, "Cannot reach /v1/models")
if model_to_use not in avail: if model_to_use not in avail:
@@ -252,22 +330,15 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
owner=user, owner=user,
) )
# Set auth headers for custom API-key endpoints # Set auth headers for custom API-key endpoints
resolved_key = api_key.strip() if api_key else "" resolved_key = request_api_key
resolved_base = endpoint_url resolved_base = endpoint_url
if not resolved_key and endpoint_id and endpoint_id.strip(): if not resolved_key and endpoint_api_key:
from core.database import ModelEndpoint resolved_key = endpoint_api_key
_db = SessionLocal() resolved_base = endpoint_base_url
try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id.strip()).first()
if ep and ep.api_key:
resolved_key = ep.api_key
resolved_base = ep.base_url
finally:
_db.close()
if resolved_key: if resolved_key:
from src.endpoint_resolver import build_headers from src.endpoint_resolver import build_headers
session.headers = build_headers(resolved_key, resolved_base) session.headers = build_headers(resolved_key, resolved_base)
session_manager.save_sessions() _persist_session_headers(sid, session.headers)
# Fire webhook (sync-safe) # Fire webhook (sync-safe)
if webhook_manager: if webhook_manager:
webhook_manager.fire_and_forget("session.created", { webhook_manager.fire_and_forget("session.created", {
@@ -313,27 +384,38 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
db.close() db.close()
# Switch model/endpoint mid-session # Switch model/endpoint mid-session
if model is not None and endpoint_url is not None: if model is not None and endpoint_url is not None:
user = get_current_user(request)
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
endpoint_api_key = ""
endpoint_base_url = ""
if endpoint_id: if endpoint_id:
from core.database import ModelEndpoint from core.database import ModelEndpoint
from src.auth_helpers import owner_filter
from src.endpoint_resolver import build_chat_url, normalize_base
_db = SessionLocal() _db = SessionLocal()
try: try:
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first() q = _db.query(ModelEndpoint).filter(
ModelEndpoint.id == endpoint_id,
ModelEndpoint.is_enabled == True,
)
if user:
q = owner_filter(q, ModelEndpoint, user)
ep = q.first()
if not ep: if not ep:
raise HTTPException(400, "Model endpoint no longer exists") raise HTTPException(400, "Model endpoint no longer exists")
endpoint_base_url = ep.base_url or ""
endpoint_api_key = ep.api_key or ""
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
finally: finally:
_db.close() _db.close()
session.model = model session.model = model
session.endpoint_url = endpoint_url session.endpoint_url = endpoint_url
# Update auth headers from the endpoint's stored API key # Update auth headers from the endpoint's stored API key
if endpoint_id: if endpoint_api_key:
_db = SessionLocal() from src.endpoint_resolver import build_headers
try: session.headers = build_headers(endpoint_api_key, endpoint_base_url)
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first() else:
if ep and ep.api_key: session.headers = {}
from src.endpoint_resolver import build_headers
session.headers = build_headers(ep.api_key, ep.base_url)
finally:
_db.close()
# Persist to DB # Persist to DB
db = SessionLocal() db = SessionLocal()
try: try:
@@ -341,6 +423,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
if db_session: if db_session:
db_session.model = model db_session.model = model
db_session.endpoint_url = endpoint_url db_session.endpoint_url = endpoint_url
db_session.headers = session.headers or {}
db_session.updated_at = datetime.utcnow() db_session.updated_at = datetime.utcnow()
db.commit() db.commit()
finally: finally:
@@ -754,7 +837,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
from src.endpoint_resolver import resolve_endpoint from src.endpoint_resolver import resolve_endpoint
from src.llm_core import llm_call_async from src.llm_core import llm_call_async
url, model, headers = resolve_endpoint("utility") url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
if not url or not model: if not url or not model:
url, model, headers = session.endpoint_url, session.model, session.headers url, model, headers = session.endpoint_url, session.model, session.headers
if not url or not model: if not url or not model:
@@ -954,9 +1037,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
# Pick an endpoint — prefer admin-configured task endpoint # Pick an endpoint — prefer admin-configured task endpoint
from src.task_endpoint import resolve_task_endpoint from src.task_endpoint import resolve_task_endpoint
url, model, headers = resolve_task_endpoint() url, model, headers = resolve_task_endpoint(owner=user)
if not url: if not url:
url, model, headers = _pick_endpoint_for_sort() url, model, headers = _pick_endpoint_for_sort(owner=user)
if not url: if not url:
raise HTTPException(503, "No available model endpoint for auto-sort") raise HTTPException(503, "No available model endpoint for auto-sort")

View File

@@ -234,7 +234,7 @@ def resolve_endpoint(
ep_id = _stg(f"{setting_prefix}_endpoint_id") ep_id = _stg(f"{setting_prefix}_endpoint_id")
model = _stg(f"{setting_prefix}_model") model = _stg(f"{setting_prefix}_model")
# If the specific endpoint is not configured, but the caller provided a # If the specific endpoint is not configured, but the caller provided a
# valid fallback (e.g. the active session model), use that immediately. # valid fallback (e.g. the active session model), use that immediately.
# This prevents background tasks from jumping to the global default_model # This prevents background tasks from jumping to the global default_model
# when the user is mid-conversation with a different model. # when the user is mid-conversation with a different model.
@@ -295,7 +295,7 @@ def resolve_endpoint(
def resolve_endpoint_by_id( def resolve_endpoint_by_id(
ep_id: str, model: Optional[str] = None ep_id: str, model: Optional[str] = None, owner: Optional[str] = None
) -> Optional[Tuple[str, str, Dict]]: ) -> Optional[Tuple[str, str, Dict]]:
"""Resolve a specific endpoint id (+ optional model) to (chat_url, model, headers). """Resolve a specific endpoint id (+ optional model) to (chat_url, model, headers).
@@ -306,10 +306,14 @@ def resolve_endpoint_by_id(
return None return None
db = SessionLocal() db = SessionLocal()
try: try:
ep = db.query(ModelEndpoint).filter( q = db.query(ModelEndpoint).filter(
ModelEndpoint.id == ep_id, ModelEndpoint.id == ep_id,
ModelEndpoint.is_enabled == True, ModelEndpoint.is_enabled == True,
).first() )
if owner:
from src.auth_helpers import owner_filter
q = owner_filter(q, ModelEndpoint, owner)
ep = q.first()
if not ep: if not ep:
return None return None
base = normalize_base(ep.base_url) base = normalize_base(ep.base_url)
@@ -332,14 +336,14 @@ def resolve_endpoint_by_id(
db.close() db.close()
def resolve_chat_fallback_candidates() -> list: def resolve_chat_fallback_candidates(owner: Optional[str] = None) -> list:
"""Build the configured default-chat fallback chain as a list of """Build the configured default-chat fallback chain as a list of
(chat_url, model, headers) tuples, skipping any that can't resolve. (chat_url, model, headers) tuples, skipping any that can't resolve.
The primary model is NOT included — callers prepend their session's The primary model is NOT included — callers prepend their session's
current (url, model, headers) so per-session model overrides are honored. current (url, model, headers) so per-session model overrides are honored.
""" """
return _resolve_fallback_candidates("default_model_fallbacks") return _resolve_fallback_candidates("default_model_fallbacks", owner=owner)
def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list: def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list:
@@ -355,9 +359,9 @@ def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list:
return _resolve_fallback_candidates("utility_model_fallbacks", owner=owner) return _resolve_fallback_candidates("utility_model_fallbacks", owner=owner)
def resolve_vision_fallback_candidates() -> list: def resolve_vision_fallback_candidates(owner: Optional[str] = None) -> list:
"""Configured fallback chain for the Vision model (`vision_model_fallbacks`).""" """Configured fallback chain for the Vision model (`vision_model_fallbacks`)."""
return _resolve_fallback_candidates("vision_model_fallbacks") return _resolve_fallback_candidates("vision_model_fallbacks", owner=owner)
def _resolve_fallback_candidates(setting_key: str, owner: Optional[str] = None) -> list: def _resolve_fallback_candidates(setting_key: str, owner: Optional[str] = None) -> list:
@@ -371,7 +375,7 @@ def _resolve_fallback_candidates(setting_key: str, owner: Optional[str] = None)
for entry in chain: for entry in chain:
if not isinstance(entry, dict): if not isinstance(entry, dict):
continue continue
resolved = resolve_endpoint_by_id(entry.get("endpoint_id", ""), entry.get("model", "")) resolved = resolve_endpoint_by_id(entry.get("endpoint_id", ""), entry.get("model", ""), owner=owner)
if resolved: if resolved:
out.append(resolved) out.append(resolved)
return out return out

View File

@@ -3,11 +3,11 @@
from src.endpoint_resolver import resolve_endpoint from src.endpoint_resolver import resolve_endpoint
def resolve_task_endpoint(fallback_url=None, fallback_model=None, fallback_headers=None): def resolve_task_endpoint(fallback_url=None, fallback_model=None, fallback_headers=None, owner=None):
"""Return (endpoint_url, model, headers) for background tasks. """Return (endpoint_url, model, headers) for background tasks.
Reads task_endpoint_id / task_model from admin settings. Reads task_endpoint_id / task_model from admin settings.
Falls back to the provided values when the setting is empty or the Falls back to the provided values when the setting is empty or the
endpoint cannot be resolved. endpoint cannot be resolved.
""" """
return resolve_endpoint("task", fallback_url, fallback_model, fallback_headers) return resolve_endpoint("task", fallback_url, fallback_model, fallback_headers, owner=owner)

View File

@@ -0,0 +1,57 @@
from pathlib import Path
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
# Import the route helper during collection so sibling session tests that use
# partial import stubs do not become the first loader of core.session_manager.
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
def _request(user, *, admin=False):
    """Build a minimal stand-in request object for the endpoint-guard tests.

    The result exposes ``state.current_user`` and
    ``app.state.auth_manager.is_admin``; the latter ignores its argument and
    always answers ``bool(admin)``.
    """
    def _is_admin(username):
        return bool(admin)

    app_state = SimpleNamespace(auth_manager=SimpleNamespace(is_admin=_is_admin))
    request_state = SimpleNamespace(current_user=user)
    return SimpleNamespace(state=request_state, app=SimpleNamespace(state=app_state))
def test_non_admin_session_create_rejects_raw_endpoint_url_without_endpoint_id():
    """A signed-in non-admin passing only a raw URL (no endpoint id) gets a 403."""
    non_admin_request = _request("alice", admin=False)
    raw_url = "http://169.254.169.254/latest/meta-data"

    with pytest.raises(HTTPException) as exc:
        _reject_raw_endpoint_url_for_non_admin(non_admin_request, "alice", "", raw_url)

    assert exc.value.status_code == 403
def test_admin_and_registered_endpoint_can_use_endpoint_url():
    """Neither a registered endpoint id nor an admin caller triggers the rejection."""
    local_url = "http://127.0.0.1:8000/v1/chat/completions"
    allowed_cases = (
        # (user, is_admin, endpoint_id)
        ("alice", False, "endpoint-id"),  # non-admin, but a saved endpoint is named
        ("admin", True, ""),              # raw URL, but the caller is an admin
    )
    for user, is_admin, endpoint_id in allowed_cases:
        # Passing means the guard returns without raising.
        _reject_raw_endpoint_url_for_non_admin(
            _request(user, admin=is_admin),
            user,
            endpoint_id,
            local_url,
        )
def test_chat_endpoint_recovery_paths_are_owner_scoped():
    """Check the chat route/helper sources still contain the owner-scoping code.

    This is a textual check: it reads ``routes/chat_routes.py`` and
    ``routes/chat_helpers.py`` (relative to this file's parent directory)
    and looks for the exact signatures and statements below.
    """
    routes_dir = Path(__file__).resolve().parents[1] / "routes"
    chat_routes = (routes_dir / "chat_routes.py").read_text(encoding="utf-8")
    chat_helpers = (routes_dir / "chat_helpers.py").read_text(encoding="utf-8")

    expected_in_routes = (
        "def _clear_orphaned_session_endpoint(sess, owner:",
        "def _recover_empty_session_model(sess, session_id: str, owner:",
        "q = owner_filter(q, ModelEndpoint, owner)",
        "resolve_session_auth(sess, session, owner=get_current_user(request))",
    )
    expected_in_helpers = (
        "def resolve_session_auth(sess, session_id: str, owner:",
        "update_q = update_q.filter(DBSession.owner == owner)",
    )

    for snippet in expected_in_routes:
        assert snippet in chat_routes
    for snippet in expected_in_helpers:
        assert snippet in chat_helpers