Harden session endpoint owner scope (#1308)
This commit is contained in:
@@ -148,8 +148,9 @@ async def auto_name_session(session_manager, sess):
|
|||||||
if not first_msg:
|
if not first_msg:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
owner = getattr(sess, "owner", None)
|
||||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||||
sess.endpoint_url, sess.model, sess.headers,
|
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||||
)
|
)
|
||||||
if not t_model:
|
if not t_model:
|
||||||
logger.debug("[auto-name] No model provided, skipping")
|
logger.debug("[auto-name] No model provided, skipping")
|
||||||
@@ -311,7 +312,24 @@ def fire_message_event(request, webhook_manager, session_id: str, sess, message:
|
|||||||
fire_event("message_sent", user)
|
fire_event("message_sent", user)
|
||||||
|
|
||||||
|
|
||||||
def resolve_session_auth(sess, session_id: str):
|
def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
    """Return True when ``session_url`` points at the endpoint ``endpoint_base``.

    Trailing slashes are ignored on both sides. The session URL must equal one
    of: the normalized base, the base plus ``/chat/completions``, or the URL
    that ``build_chat_url`` produces for that base. Prefix or substring
    similarity is deliberately not enough.

    Args:
        session_url: URL stored on the session.
        endpoint_base: ``base_url`` of a candidate endpoint row.

    Returns:
        True on an exact match against one of the accepted forms; False for
        empty inputs, inputs that become empty after normalization, or any
        error raised while normalizing.
    """
    if not session_url or not endpoint_base:
        return False

    # A URL made only of slashes strips down to "". It must never be allowed
    # to compare equal to a base that also collapses to "".
    sess_url = session_url.rstrip("/")
    if not sess_url:
        return False

    try:
        from src.endpoint_resolver import build_chat_url, normalize_base

        base = normalize_base(endpoint_base).rstrip("/")
        if not base:
            return False
        accepted = {
            base,
            base + "/chat/completions",
            build_chat_url(base).rstrip("/"),
        }
        return sess_url in accepted
    except Exception:
        # Fail closed: an endpoint we cannot normalize is treated as no match.
        return False
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_session_auth(sess, session_id: str, owner: Optional[str] = None):
|
||||||
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
"""Ensure session has auth headers — resolve from endpoint DB if missing."""
|
||||||
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
has_auth = sess.headers and isinstance(sess.headers, dict) and any(
|
||||||
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
k.lower() in ('authorization', 'x-api-key') for k in sess.headers
|
||||||
@@ -320,19 +338,33 @@ def resolve_session_auth(sess, session_id: str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import build_headers
|
from src.endpoint_resolver import build_headers, normalize_base
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
domain = sess.endpoint_url.split("//")[1].split("/")[0] if "//" in sess.endpoint_url else ""
|
target_url = getattr(sess, "endpoint_url", "") or ""
|
||||||
if domain:
|
if not target_url:
|
||||||
ep = db.query(ModelEndpoint).filter(ModelEndpoint.base_url.contains(domain)).first()
|
return
|
||||||
if ep and ep.api_key:
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
sess.headers = build_headers(ep.api_key, ep.base_url)
|
if owner:
|
||||||
db.query(DBSession).filter(DBSession.id == session_id).update(
|
# Missing headers usually means "recover from the saved endpoint".
|
||||||
{"headers": json.dumps(sess.headers)}
|
# Scope that lookup to the session owner, otherwise two users
|
||||||
)
|
# with similar endpoint URLs can borrow each other's API key.
|
||||||
db.commit()
|
from src.auth_helpers import owner_filter
|
||||||
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
for ep in q.all():
|
||||||
|
if not _session_url_matches_endpoint(target_url, ep.base_url or ""):
|
||||||
|
continue
|
||||||
|
if not ep.api_key:
|
||||||
|
return
|
||||||
|
base = normalize_base(ep.base_url or "")
|
||||||
|
sess.headers = build_headers(ep.api_key, base)
|
||||||
|
update_q = db.query(DBSession).filter(DBSession.id == session_id)
|
||||||
|
if owner:
|
||||||
|
update_q = update_q.filter(DBSession.owner == owner)
|
||||||
|
update_q.update({"headers": sess.headers})
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"Resolved and persisted auth headers for session {session_id} from endpoint {ep.name}")
|
||||||
|
return
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -806,7 +838,7 @@ def run_post_response_tasks(
|
|||||||
from services.memory.memory_extractor import extract_and_store
|
from services.memory.memory_extractor import extract_and_store
|
||||||
from src.task_endpoint import resolve_task_endpoint
|
from src.task_endpoint import resolve_task_endpoint
|
||||||
t_url, t_model, t_headers = resolve_task_endpoint(
|
t_url, t_model, t_headers = resolve_task_endpoint(
|
||||||
sess.endpoint_url, sess.model, sess.headers,
|
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||||
)
|
)
|
||||||
asyncio.create_task(extract_and_store(
|
asyncio.create_task(extract_and_store(
|
||||||
sess, memory_manager, memory_vector,
|
sess, memory_manager, memory_vector,
|
||||||
@@ -843,7 +875,7 @@ def run_post_response_tasks(
|
|||||||
from services.memory.skill_extractor import maybe_extract_skill
|
from services.memory.skill_extractor import maybe_extract_skill
|
||||||
from src.task_endpoint import resolve_task_endpoint
|
from src.task_endpoint import resolve_task_endpoint
|
||||||
s_url, s_model, s_headers = resolve_task_endpoint(
|
s_url, s_model, s_headers = resolve_task_endpoint(
|
||||||
sess.endpoint_url, sess.model, sess.headers,
|
sess.endpoint_url, sess.model, sess.headers, owner=owner,
|
||||||
)
|
)
|
||||||
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
|
logger.debug("[skill-extract] dispatching extractor (model=%s)", s_model)
|
||||||
asyncio.create_task(maybe_extract_skill(
|
asyncio.create_task(maybe_extract_skill(
|
||||||
|
|||||||
@@ -72,13 +72,17 @@ def _session_url_matches_endpoint(session_url: str, endpoint_base: str) -> bool:
|
|||||||
return sess in variants or sess.startswith(base + "/")
|
return sess in variants or sess.startswith(base + "/")
|
||||||
|
|
||||||
|
|
||||||
def _clear_orphaned_session_endpoint(sess) -> bool:
|
def _clear_orphaned_session_endpoint(sess, owner: str | None = None) -> bool:
|
||||||
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
|
"""Clear a session model if its endpoint was deleted from ModelEndpoint."""
|
||||||
if not getattr(sess, "endpoint_url", ""):
|
if not getattr(sess, "endpoint_url", ""):
|
||||||
return False
|
return False
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
for ep in endpoints:
|
for ep in endpoints:
|
||||||
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
|
if _session_url_matches_endpoint(sess.endpoint_url or "", ep.base_url or ""):
|
||||||
return False
|
return False
|
||||||
@@ -118,7 +122,7 @@ def _endpoint_cache_contains_model(endpoint, model: str) -> bool:
|
|||||||
return wanted in {str(item).strip() for item in models}
|
return wanted in {str(item).strip() for item in models}
|
||||||
|
|
||||||
|
|
||||||
def _is_image_generation_session(sess) -> bool:
|
def _is_image_generation_session(sess, owner: str | None = None) -> bool:
|
||||||
"""Whether this chat session should bypass text chat and generate images.
|
"""Whether this chat session should bypass text chat and generate images.
|
||||||
|
|
||||||
Model-name prefixes are explicit image models. Endpoint type is only used
|
Model-name prefixes are explicit image models. Endpoint type is only used
|
||||||
@@ -137,7 +141,11 @@ def _is_image_generation_session(sess) -> bool:
|
|||||||
|
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
for endpoint in endpoints:
|
for endpoint in endpoints:
|
||||||
if (getattr(endpoint, "model_type", None) or "llm") != "image":
|
if (getattr(endpoint, "model_type", None) or "llm") != "image":
|
||||||
continue
|
continue
|
||||||
@@ -152,7 +160,7 @@ def _is_image_generation_session(sess) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _recover_empty_session_model(sess, session_id: str) -> bool:
|
def _recover_empty_session_model(sess, session_id: str, owner: str | None = None) -> bool:
|
||||||
"""Re-populate sess.model from the matching endpoint's cached models.
|
"""Re-populate sess.model from the matching endpoint's cached models.
|
||||||
|
|
||||||
Covers the window between endpoint setup and the first chat send: the
|
Covers the window between endpoint setup and the first chat send: the
|
||||||
@@ -172,7 +180,11 @@ def _recover_empty_session_model(sess, session_id: str) -> bool:
|
|||||||
# cached model is the most defensible default.
|
# cached model is the most defensible default.
|
||||||
ep = None
|
ep = None
|
||||||
if getattr(sess, "endpoint_url", ""):
|
if getattr(sess, "endpoint_url", ""):
|
||||||
endpoints = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True).all()
|
q = db.query(ModelEndpoint).filter(ModelEndpoint.is_enabled == True)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
endpoints = q.all()
|
||||||
for cand in endpoints:
|
for cand in endpoints:
|
||||||
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
|
if _session_url_matches_endpoint(sess.endpoint_url or "", cand.base_url or ""):
|
||||||
ep = cand
|
ep = cand
|
||||||
@@ -251,13 +263,14 @@ def setup_chat_routes(
|
|||||||
sess = session_manager.get_session(session)
|
sess = session_manager.get_session(session)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise HTTPException(404, f"Session '{session}' not found")
|
raise HTTPException(404, f"Session '{session}' not found")
|
||||||
if _clear_orphaned_session_endpoint(sess):
|
owner = get_current_user(request)
|
||||||
|
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||||
|
|
||||||
# Empty model + live endpoint = setup race (Issue #587). Repair from
|
# Empty model + live endpoint = setup race (Issue #587). Repair from
|
||||||
# the endpoint's cached model list before privilege checks, which
|
# the endpoint's cached model list before privilege checks, which
|
||||||
# otherwise see "" and behave inconsistently with the allowlist.
|
# otherwise see "" and behave inconsistently with the allowlist.
|
||||||
_recover_empty_session_model(sess, session)
|
_recover_empty_session_model(sess, session, owner=owner)
|
||||||
if not getattr(sess, "model", "").strip():
|
if not getattr(sess, "model", "").strip():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
400,
|
400,
|
||||||
@@ -401,7 +414,8 @@ def setup_chat_routes(
|
|||||||
# but BEFORE loading. Prevents cross-user session hijack.
|
# but BEFORE loading. Prevents cross-user session hijack.
|
||||||
_verify_session_owner(request, session)
|
_verify_session_owner(request, session)
|
||||||
sess = session_manager.get_session(session)
|
sess = session_manager.get_session(session)
|
||||||
if _clear_orphaned_session_endpoint(sess):
|
owner = get_current_user(request)
|
||||||
|
if _clear_orphaned_session_endpoint(sess, owner=owner):
|
||||||
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
raise HTTPException(400, "Selected model endpoint was removed. Pick another model in Settings.")
|
||||||
# Issue #587: picker shows a model from the endpoint cache but
|
# Issue #587: picker shows a model from the endpoint cache but
|
||||||
# s.model never made it onto the DB row (first-send race after
|
# s.model never made it onto the DB row (first-send race after
|
||||||
@@ -409,7 +423,7 @@ def setup_chat_routes(
|
|||||||
# the first cached model off the matching endpoint so the
|
# the first cached model off the matching endpoint so the
|
||||||
# upstream isn't called with model="" (which surfaces as a
|
# upstream isn't called with model="" (which surfaces as a
|
||||||
# generic 401/503).
|
# generic 401/503).
|
||||||
_recover_empty_session_model(sess, session)
|
_recover_empty_session_model(sess, session, owner=owner)
|
||||||
if not getattr(sess, "model", "").strip():
|
if not getattr(sess, "model", "").strip():
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
400,
|
400,
|
||||||
@@ -431,7 +445,7 @@ def setup_chat_routes(
|
|||||||
_enforce_chat_privileges(request, sess)
|
_enforce_chat_privileges(request, sess)
|
||||||
|
|
||||||
# Ensure session has auth headers
|
# Ensure session has auth headers
|
||||||
resolve_session_auth(sess, session)
|
resolve_session_auth(sess, session, owner=get_current_user(request))
|
||||||
|
|
||||||
# Check for research_pending BEFORE mode persist overwrites it
|
# Check for research_pending BEFORE mode persist overwrites it
|
||||||
do_research = str(use_research).lower() == "true"
|
do_research = str(use_research).lower() == "true"
|
||||||
@@ -768,7 +782,7 @@ def setup_chat_routes(
|
|||||||
# output. Resolved once per request.
|
# output. Resolved once per request.
|
||||||
try:
|
try:
|
||||||
from src.endpoint_resolver import resolve_chat_fallback_candidates
|
from src.endpoint_resolver import resolve_chat_fallback_candidates
|
||||||
_fallback_candidates = resolve_chat_fallback_candidates()
|
_fallback_candidates = resolve_chat_fallback_candidates(owner=_user)
|
||||||
except Exception:
|
except Exception:
|
||||||
_fallback_candidates = []
|
_fallback_candidates = []
|
||||||
|
|
||||||
@@ -781,7 +795,7 @@ def setup_chat_routes(
|
|||||||
_model_info["character_name"] = ctx.preset.character_name
|
_model_info["character_name"] = ctx.preset.character_name
|
||||||
yield f'data: {json.dumps(_model_info)}\n\n'
|
yield f'data: {json.dumps(_model_info)}\n\n'
|
||||||
|
|
||||||
if _is_image_generation_session(sess):
|
if _is_image_generation_session(sess, owner=_user):
|
||||||
from src.settings import get_setting
|
from src.settings import get_setting
|
||||||
if not get_setting("image_gen_enabled", True):
|
if not get_setting("image_gen_enabled", True):
|
||||||
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
yield f'data: {json.dumps({"delta": "Image generation is disabled by the administrator."})}\n\n'
|
||||||
@@ -792,7 +806,7 @@ def setup_chat_routes(
|
|||||||
_user_msg = message or ""
|
_user_msg = message or ""
|
||||||
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
|
yield f'data: {json.dumps({"type": "tool_start", "tool": "generate_image", "command": _user_msg[:100]})}\n\n'
|
||||||
yield ": heartbeat\n\n"
|
yield ": heartbeat\n\n"
|
||||||
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session)
|
_img_result = await do_generate_image(f"{_user_msg}\n{sess.model}", session, owner=_user)
|
||||||
_img_output = _img_result.get("results", _img_result.get("error", ""))
|
_img_output = _img_result.get("results", _img_result.get("error", ""))
|
||||||
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
|
_img_tool_data = {"type": "tool_output", "tool": "generate_image", "command": _user_msg[:100], "output": _img_output, "exit_code": 0 if "error" not in _img_result else 1}
|
||||||
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):
|
for _k in ("image_url", "image_id", "image_prompt", "image_model", "image_size", "image_quality"):
|
||||||
|
|||||||
@@ -58,23 +58,71 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["sessions"])
|
router = APIRouter(prefix="/api", tags=["sessions"])
|
||||||
|
|
||||||
def _pick_endpoint_for_sort():
|
|
||||||
|
def _current_user_is_admin(request: Request, user: str | None) -> bool:
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
auth_mgr = getattr(request.app.state, "auth_manager", None)
|
||||||
|
is_admin = getattr(auth_mgr, "is_admin", None)
|
||||||
|
if not callable(is_admin):
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
return bool(is_admin(user))
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _reject_raw_endpoint_url_for_non_admin(
    request: Request,
    user: str | None,
    endpoint_id: str | None,
    endpoint_url: str | None,
) -> None:
    """Block signed-in non-admins from pointing a session at a raw URL.

    Nothing happens when a non-blank ``endpoint_id`` is supplied, when no
    ``endpoint_url`` is supplied, when there is no signed-in user, or when the
    user is an admin.

    Raises:
        HTTPException: 403 when a signed-in non-admin supplies only a raw
            ``endpoint_url``.
    """
    uses_registered_endpoint = bool(endpoint_id and endpoint_id.strip())
    if uses_registered_endpoint or not endpoint_url:
        return

    # A raw URL makes the server connect to whatever host the request names.
    # Non-admins must go through a saved endpoint row instead, so that owner
    # scoping and endpoint validation have already been applied.
    if not user:
        return
    if _current_user_is_admin(request, user):
        return
    raise HTTPException(403, "Choose a registered model endpoint")
|
||||||
|
|
||||||
|
|
||||||
|
def _persist_session_headers(session_id: str, headers: dict | None) -> None:
    """Write ``headers`` onto the stored session row with id ``session_id``.

    A falsy ``headers`` value is stored as an empty dict, and ``updated_at``
    is refreshed in the same commit. If no row has that id, nothing is
    written. Any exception rolls the transaction back and is re-raised; the
    DB handle is closed in every case.
    """
    db = SessionLocal()
    try:
        row = db.query(DbSession).filter(DbSession.id == session_id).first()
        if not row:
            return
        row.headers = headers if headers else {}
        row.updated_at = datetime.utcnow()
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_endpoint_for_sort(owner=None):
|
||||||
"""Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default."""
|
"""Pick model endpoint for auto-sort LLM call — uses utility endpoint setting, falls back to default."""
|
||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
# Try utility endpoint first (what the user configured for background tasks)
|
# Try utility endpoint first (what the user configured for background tasks)
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=owner)
|
||||||
if url and model:
|
if url and model:
|
||||||
return url, model, headers
|
return url, model, headers
|
||||||
# Fall back to task endpoint
|
# Fall back to task endpoint
|
||||||
try:
|
try:
|
||||||
from src.task_endpoint import resolve_task_endpoint
|
from src.task_endpoint import resolve_task_endpoint
|
||||||
url, model, headers = resolve_task_endpoint()
|
url, model, headers = resolve_task_endpoint(owner=owner)
|
||||||
if url and model:
|
if url and model:
|
||||||
return url, model, headers
|
return url, model, headers
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# Fall back to default
|
# Fall back to default
|
||||||
url, model, headers = resolve_endpoint("default")
|
url, model, headers = resolve_endpoint("default", owner=owner)
|
||||||
if url and model:
|
if url and model:
|
||||||
return url, model, headers
|
return url, model, headers
|
||||||
return None, None, None
|
return None, None, None
|
||||||
@@ -197,11 +245,41 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
endpoint_id: str = Form(""),
|
endpoint_id: str = Form(""),
|
||||||
):
|
):
|
||||||
skip_val = str(skip_validation).lower() == "true"
|
skip_val = str(skip_validation).lower() == "true"
|
||||||
|
user = get_current_user(request)
|
||||||
|
endpoint_api_key = ""
|
||||||
|
endpoint_base_url = ""
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||||
|
if endpoint_id and endpoint_id.strip():
|
||||||
|
from core.database import ModelEndpoint
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||||
|
_db = SessionLocal()
|
||||||
|
try:
|
||||||
|
q = _db.query(ModelEndpoint).filter(
|
||||||
|
ModelEndpoint.id == endpoint_id.strip(),
|
||||||
|
ModelEndpoint.is_enabled == True,
|
||||||
|
)
|
||||||
|
if user:
|
||||||
|
q = owner_filter(q, ModelEndpoint, user)
|
||||||
|
endpoint_row = q.first()
|
||||||
|
if not endpoint_row:
|
||||||
|
raise HTTPException(400, "Model endpoint no longer exists")
|
||||||
|
endpoint_base_url = endpoint_row.base_url or ""
|
||||||
|
endpoint_api_key = endpoint_row.api_key or ""
|
||||||
|
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
|
||||||
|
finally:
|
||||||
|
_db.close()
|
||||||
|
|
||||||
if not endpoint_url and not skip_val:
|
if not endpoint_url and not skip_val:
|
||||||
raise HTTPException(400, "endpoint_url is required (choose from /api/models)")
|
raise HTTPException(400, "endpoint_url is required (choose from /api/models)")
|
||||||
|
|
||||||
model_to_use = model
|
model_to_use = model
|
||||||
|
request_api_key = api_key.strip() if api_key else ""
|
||||||
|
effective_api_key = request_api_key or endpoint_api_key
|
||||||
|
validation_headers = None
|
||||||
|
if effective_api_key:
|
||||||
|
from src.endpoint_resolver import build_headers
|
||||||
|
validation_headers = build_headers(effective_api_key, endpoint_base_url or endpoint_url)
|
||||||
|
|
||||||
if skip_val:
|
if skip_val:
|
||||||
# skip_validation = trust the caller and do NOT probe /v1/models.
|
# skip_validation = trust the caller and do NOT probe /v1/models.
|
||||||
@@ -212,7 +290,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
elif not model_to_use:
|
elif not model_to_use:
|
||||||
from src.llm_core import list_model_ids
|
from src.llm_core import list_model_ids
|
||||||
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
ids = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||||
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
|
headers=validation_headers)
|
||||||
if not ids:
|
if not ids:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
# Default to the first CHAT model — endpoints often list embedding/
|
# Default to the first CHAT model — endpoints often list embedding/
|
||||||
@@ -227,7 +305,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
import os as _os
|
import os as _os
|
||||||
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
req_base = _os.path.basename(model_to_use.rstrip("/"))
|
||||||
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
avail = list_model_ids(endpoint_url, timeout=REQUEST_TIMEOUT,
|
||||||
headers={"Authorization": f"Bearer {api_key}"} if api_key.strip() else None)
|
headers=validation_headers)
|
||||||
if not avail:
|
if not avail:
|
||||||
raise HTTPException(400, "Cannot reach /v1/models")
|
raise HTTPException(400, "Cannot reach /v1/models")
|
||||||
if model_to_use not in avail:
|
if model_to_use not in avail:
|
||||||
@@ -252,22 +330,15 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
owner=user,
|
owner=user,
|
||||||
)
|
)
|
||||||
# Set auth headers for custom API-key endpoints
|
# Set auth headers for custom API-key endpoints
|
||||||
resolved_key = api_key.strip() if api_key else ""
|
resolved_key = request_api_key
|
||||||
resolved_base = endpoint_url
|
resolved_base = endpoint_url
|
||||||
if not resolved_key and endpoint_id and endpoint_id.strip():
|
if not resolved_key and endpoint_api_key:
|
||||||
from core.database import ModelEndpoint
|
resolved_key = endpoint_api_key
|
||||||
_db = SessionLocal()
|
resolved_base = endpoint_base_url
|
||||||
try:
|
|
||||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id.strip()).first()
|
|
||||||
if ep and ep.api_key:
|
|
||||||
resolved_key = ep.api_key
|
|
||||||
resolved_base = ep.base_url
|
|
||||||
finally:
|
|
||||||
_db.close()
|
|
||||||
if resolved_key:
|
if resolved_key:
|
||||||
from src.endpoint_resolver import build_headers
|
from src.endpoint_resolver import build_headers
|
||||||
session.headers = build_headers(resolved_key, resolved_base)
|
session.headers = build_headers(resolved_key, resolved_base)
|
||||||
session_manager.save_sessions()
|
_persist_session_headers(sid, session.headers)
|
||||||
# Fire webhook (sync-safe)
|
# Fire webhook (sync-safe)
|
||||||
if webhook_manager:
|
if webhook_manager:
|
||||||
webhook_manager.fire_and_forget("session.created", {
|
webhook_manager.fire_and_forget("session.created", {
|
||||||
@@ -313,27 +384,38 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
db.close()
|
db.close()
|
||||||
# Switch model/endpoint mid-session
|
# Switch model/endpoint mid-session
|
||||||
if model is not None and endpoint_url is not None:
|
if model is not None and endpoint_url is not None:
|
||||||
|
user = get_current_user(request)
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(request, user, endpoint_id, endpoint_url)
|
||||||
|
endpoint_api_key = ""
|
||||||
|
endpoint_base_url = ""
|
||||||
if endpoint_id:
|
if endpoint_id:
|
||||||
from core.database import ModelEndpoint
|
from core.database import ModelEndpoint
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
from src.endpoint_resolver import build_chat_url, normalize_base
|
||||||
_db = SessionLocal()
|
_db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
q = _db.query(ModelEndpoint).filter(
|
||||||
|
ModelEndpoint.id == endpoint_id,
|
||||||
|
ModelEndpoint.is_enabled == True,
|
||||||
|
)
|
||||||
|
if user:
|
||||||
|
q = owner_filter(q, ModelEndpoint, user)
|
||||||
|
ep = q.first()
|
||||||
if not ep:
|
if not ep:
|
||||||
raise HTTPException(400, "Model endpoint no longer exists")
|
raise HTTPException(400, "Model endpoint no longer exists")
|
||||||
|
endpoint_base_url = ep.base_url or ""
|
||||||
|
endpoint_api_key = ep.api_key or ""
|
||||||
|
endpoint_url = build_chat_url(normalize_base(endpoint_base_url))
|
||||||
finally:
|
finally:
|
||||||
_db.close()
|
_db.close()
|
||||||
session.model = model
|
session.model = model
|
||||||
session.endpoint_url = endpoint_url
|
session.endpoint_url = endpoint_url
|
||||||
# Update auth headers from the endpoint's stored API key
|
# Update auth headers from the endpoint's stored API key
|
||||||
if endpoint_id:
|
if endpoint_api_key:
|
||||||
_db = SessionLocal()
|
from src.endpoint_resolver import build_headers
|
||||||
try:
|
session.headers = build_headers(endpoint_api_key, endpoint_base_url)
|
||||||
ep = _db.query(ModelEndpoint).filter(ModelEndpoint.id == endpoint_id).first()
|
else:
|
||||||
if ep and ep.api_key:
|
session.headers = {}
|
||||||
from src.endpoint_resolver import build_headers
|
|
||||||
session.headers = build_headers(ep.api_key, ep.base_url)
|
|
||||||
finally:
|
|
||||||
_db.close()
|
|
||||||
# Persist to DB
|
# Persist to DB
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
@@ -341,6 +423,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
if db_session:
|
if db_session:
|
||||||
db_session.model = model
|
db_session.model = model
|
||||||
db_session.endpoint_url = endpoint_url
|
db_session.endpoint_url = endpoint_url
|
||||||
|
db_session.headers = session.headers or {}
|
||||||
db_session.updated_at = datetime.utcnow()
|
db_session.updated_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
finally:
|
finally:
|
||||||
@@ -754,7 +837,7 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
from src.llm_core import llm_call_async
|
from src.llm_core import llm_call_async
|
||||||
|
|
||||||
url, model, headers = resolve_endpoint("utility")
|
url, model, headers = resolve_endpoint("utility", owner=get_current_user(request))
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
url, model, headers = session.endpoint_url, session.model, session.headers
|
url, model, headers = session.endpoint_url, session.model, session.headers
|
||||||
if not url or not model:
|
if not url or not model:
|
||||||
@@ -954,9 +1037,9 @@ def setup_session_routes(session_manager: SessionManager, config: dict, webhook_
|
|||||||
|
|
||||||
# Pick an endpoint — prefer admin-configured task endpoint
|
# Pick an endpoint — prefer admin-configured task endpoint
|
||||||
from src.task_endpoint import resolve_task_endpoint
|
from src.task_endpoint import resolve_task_endpoint
|
||||||
url, model, headers = resolve_task_endpoint()
|
url, model, headers = resolve_task_endpoint(owner=user)
|
||||||
if not url:
|
if not url:
|
||||||
url, model, headers = _pick_endpoint_for_sort()
|
url, model, headers = _pick_endpoint_for_sort(owner=user)
|
||||||
if not url:
|
if not url:
|
||||||
raise HTTPException(503, "No available model endpoint for auto-sort")
|
raise HTTPException(503, "No available model endpoint for auto-sort")
|
||||||
|
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ def resolve_endpoint(
|
|||||||
ep_id = _stg(f"{setting_prefix}_endpoint_id")
|
ep_id = _stg(f"{setting_prefix}_endpoint_id")
|
||||||
model = _stg(f"{setting_prefix}_model")
|
model = _stg(f"{setting_prefix}_model")
|
||||||
|
|
||||||
# If the specific endpoint is not configured, but the caller provided a
|
# If the specific endpoint is not configured, but the caller provided a
|
||||||
# valid fallback (e.g. the active session model), use that immediately.
|
# valid fallback (e.g. the active session model), use that immediately.
|
||||||
# This prevents background tasks from jumping to the global default_model
|
# This prevents background tasks from jumping to the global default_model
|
||||||
# when the user is mid-conversation with a different model.
|
# when the user is mid-conversation with a different model.
|
||||||
@@ -295,7 +295,7 @@ def resolve_endpoint(
|
|||||||
|
|
||||||
|
|
||||||
def resolve_endpoint_by_id(
|
def resolve_endpoint_by_id(
|
||||||
ep_id: str, model: Optional[str] = None
|
ep_id: str, model: Optional[str] = None, owner: Optional[str] = None
|
||||||
) -> Optional[Tuple[str, str, Dict]]:
|
) -> Optional[Tuple[str, str, Dict]]:
|
||||||
"""Resolve a specific endpoint id (+ optional model) to (chat_url, model, headers).
|
"""Resolve a specific endpoint id (+ optional model) to (chat_url, model, headers).
|
||||||
|
|
||||||
@@ -306,10 +306,14 @@ def resolve_endpoint_by_id(
|
|||||||
return None
|
return None
|
||||||
db = SessionLocal()
|
db = SessionLocal()
|
||||||
try:
|
try:
|
||||||
ep = db.query(ModelEndpoint).filter(
|
q = db.query(ModelEndpoint).filter(
|
||||||
ModelEndpoint.id == ep_id,
|
ModelEndpoint.id == ep_id,
|
||||||
ModelEndpoint.is_enabled == True,
|
ModelEndpoint.is_enabled == True,
|
||||||
).first()
|
)
|
||||||
|
if owner:
|
||||||
|
from src.auth_helpers import owner_filter
|
||||||
|
q = owner_filter(q, ModelEndpoint, owner)
|
||||||
|
ep = q.first()
|
||||||
if not ep:
|
if not ep:
|
||||||
return None
|
return None
|
||||||
base = normalize_base(ep.base_url)
|
base = normalize_base(ep.base_url)
|
||||||
@@ -332,14 +336,14 @@ def resolve_endpoint_by_id(
|
|||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
def resolve_chat_fallback_candidates() -> list:
|
def resolve_chat_fallback_candidates(owner: Optional[str] = None) -> list:
|
||||||
"""Build the configured default-chat fallback chain as a list of
|
"""Build the configured default-chat fallback chain as a list of
|
||||||
(chat_url, model, headers) tuples, skipping any that can't resolve.
|
(chat_url, model, headers) tuples, skipping any that can't resolve.
|
||||||
|
|
||||||
The primary model is NOT included — callers prepend their session's
|
The primary model is NOT included — callers prepend their session's
|
||||||
current (url, model, headers) so per-session model overrides are honored.
|
current (url, model, headers) so per-session model overrides are honored.
|
||||||
"""
|
"""
|
||||||
return _resolve_fallback_candidates("default_model_fallbacks")
|
return _resolve_fallback_candidates("default_model_fallbacks", owner=owner)
|
||||||
|
|
||||||
|
|
||||||
def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list:
|
def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list:
|
||||||
@@ -355,9 +359,9 @@ def resolve_utility_fallback_candidates(owner: Optional[str] = None) -> list:
|
|||||||
return _resolve_fallback_candidates("utility_model_fallbacks", owner=owner)
|
return _resolve_fallback_candidates("utility_model_fallbacks", owner=owner)
|
||||||
|
|
||||||
|
|
||||||
def resolve_vision_fallback_candidates() -> list:
|
def resolve_vision_fallback_candidates(owner: Optional[str] = None) -> list:
|
||||||
"""Configured fallback chain for the Vision model (`vision_model_fallbacks`)."""
|
"""Configured fallback chain for the Vision model (`vision_model_fallbacks`)."""
|
||||||
return _resolve_fallback_candidates("vision_model_fallbacks")
|
return _resolve_fallback_candidates("vision_model_fallbacks", owner=owner)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_fallback_candidates(setting_key: str, owner: Optional[str] = None) -> list:
|
def _resolve_fallback_candidates(setting_key: str, owner: Optional[str] = None) -> list:
|
||||||
@@ -371,7 +375,7 @@ def _resolve_fallback_candidates(setting_key: str, owner: Optional[str] = None)
|
|||||||
for entry in chain:
|
for entry in chain:
|
||||||
if not isinstance(entry, dict):
|
if not isinstance(entry, dict):
|
||||||
continue
|
continue
|
||||||
resolved = resolve_endpoint_by_id(entry.get("endpoint_id", ""), entry.get("model", ""))
|
resolved = resolve_endpoint_by_id(entry.get("endpoint_id", ""), entry.get("model", ""), owner=owner)
|
||||||
if resolved:
|
if resolved:
|
||||||
out.append(resolved)
|
out.append(resolved)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -3,11 +3,11 @@
|
|||||||
from src.endpoint_resolver import resolve_endpoint
|
from src.endpoint_resolver import resolve_endpoint
|
||||||
|
|
||||||
|
|
||||||
def resolve_task_endpoint(fallback_url=None, fallback_model=None, fallback_headers=None):
|
def resolve_task_endpoint(fallback_url=None, fallback_model=None, fallback_headers=None, owner=None):
|
||||||
"""Return (endpoint_url, model, headers) for background tasks.
|
"""Return (endpoint_url, model, headers) for background tasks.
|
||||||
|
|
||||||
Reads task_endpoint_id / task_model from admin settings.
|
Reads task_endpoint_id / task_model from admin settings.
|
||||||
Falls back to the provided values when the setting is empty or the
|
Falls back to the provided values when the setting is empty or the
|
||||||
endpoint cannot be resolved.
|
endpoint cannot be resolved.
|
||||||
"""
|
"""
|
||||||
return resolve_endpoint("task", fallback_url, fallback_model, fallback_headers)
|
return resolve_endpoint("task", fallback_url, fallback_model, fallback_headers, owner=owner)
|
||||||
|
|||||||
57
tests/test_session_endpoint_owner_scope.py
Normal file
57
tests/test_session_endpoint_owner_scope.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
# Import the route helper during collection so sibling session tests that use
|
||||||
|
# partial import stubs do not become the first loader of core.session_manager.
|
||||||
|
from routes.session_routes import _reject_raw_endpoint_url_for_non_admin
|
||||||
|
|
||||||
|
|
||||||
|
def _request(user, *, admin=False):
|
||||||
|
auth_manager = SimpleNamespace(is_admin=lambda username: bool(admin))
|
||||||
|
return SimpleNamespace(
|
||||||
|
state=SimpleNamespace(current_user=user),
|
||||||
|
app=SimpleNamespace(state=SimpleNamespace(auth_manager=auth_manager)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_admin_session_create_rejects_raw_endpoint_url_without_endpoint_id():
|
||||||
|
with pytest.raises(HTTPException) as exc:
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(
|
||||||
|
_request("alice", admin=False),
|
||||||
|
"alice",
|
||||||
|
"",
|
||||||
|
"http://169.254.169.254/latest/meta-data",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_and_registered_endpoint_can_use_endpoint_url():
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(
|
||||||
|
_request("alice", admin=False),
|
||||||
|
"alice",
|
||||||
|
"endpoint-id",
|
||||||
|
"http://127.0.0.1:8000/v1/chat/completions",
|
||||||
|
)
|
||||||
|
_reject_raw_endpoint_url_for_non_admin(
|
||||||
|
_request("admin", admin=True),
|
||||||
|
"admin",
|
||||||
|
"",
|
||||||
|
"http://127.0.0.1:8000/v1/chat/completions",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_endpoint_recovery_paths_are_owner_scoped():
|
||||||
|
root = Path(__file__).resolve().parents[1]
|
||||||
|
chat_routes = (root / "routes" / "chat_routes.py").read_text(encoding="utf-8")
|
||||||
|
chat_helpers = (root / "routes" / "chat_helpers.py").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "def _clear_orphaned_session_endpoint(sess, owner:" in chat_routes
|
||||||
|
assert "def _recover_empty_session_model(sess, session_id: str, owner:" in chat_routes
|
||||||
|
assert "q = owner_filter(q, ModelEndpoint, owner)" in chat_routes
|
||||||
|
assert "resolve_session_auth(sess, session, owner=get_current_user(request))" in chat_routes
|
||||||
|
assert "def resolve_session_auth(sess, session_id: str, owner:" in chat_helpers
|
||||||
|
assert "update_q = update_q.filter(DBSession.owner == owner)" in chat_helpers
|
||||||
Reference in New Issue
Block a user