diff --git a/routes/gallery_routes.py b/routes/gallery_routes.py index fd791bd..db17bfe 100644 --- a/routes/gallery_routes.py +++ b/routes/gallery_routes.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, HTTPException, Query, Request from core.database import SessionLocal, GalleryImage, GalleryAlbum, ModelEndpoint from core.database import Session as DbSession -from src.auth_helpers import get_current_user +from src.auth_helpers import get_current_user, require_privilege from routes.gallery_helpers import ( GalleryPatch, _extract_exif, _image_to_dict, _owner_filter, _human_size, @@ -233,6 +233,7 @@ def setup_gallery_routes() -> APIRouter: """AI upscale using img2img with the diffusion server.""" import base64, httpx + require_privilege(request, "can_generate_images") form = await request.form() file = form.get("image") if not file: raise HTTPException(400, "No image") @@ -275,6 +276,7 @@ def setup_gallery_routes() -> APIRouter: """Style transfer using img2img with the diffusion server.""" import base64, httpx + require_privilege(request, "can_generate_images") form = await request.form() file = form.get("image") prompt = form.get("prompt", "") @@ -906,6 +908,7 @@ def setup_gallery_routes() -> APIRouter: the request for /v1/images/edits (multipart, inverted mask). Otherwise proxy through to a self-hosted diffusion server's /v1/images/inpaint.""" import httpx + require_privilege(request, "can_generate_images") body = await request.json() # Use endpoint from request body (editor dropdown) or fall back to DB lookup base = (body.pop("_endpoint", "") or "").rstrip("/") @@ -1093,6 +1096,7 @@ def setup_gallery_routes() -> APIRouter: you get edge blending + lighting unification while keeping the composition recognisable.""" import httpx, base64 as _b64 + require_privilege(request, "can_generate_images") body = await request.json() image_b64 = body.get("image") @@ -1298,6 +1302,7 @@ def setup_gallery_routes() -> APIRouter: # error so the client can prompt the user to install via Cookbook. @router.post("/api/image/denoise") async def denoise_image(request: Request): + require_privilege(request, "can_generate_images") body = await request.json() image_b64 = body.get("image") if not image_b64: @@ -1347,6 +1352,7 @@ def setup_gallery_routes() -> APIRouter: # server required. Used by the editor's AI Upscale button. @router.post("/api/image/upscale-local") async def upscale_image_local(request: Request): + require_privilege(request, "can_generate_images") body = await request.json() image_b64 = body.get("image") if not image_b64: @@ -1403,6 +1409,7 @@ def setup_gallery_routes() -> APIRouter: outside the hint becomes transparent regardless of what the model thought was foreground. """ + require_privilege(request, "can_generate_images") body = await request.json() image_b64 = body.get("image") hint_b64 = body.get("hint_mask") @@ -1484,6 +1491,7 @@ def setup_gallery_routes() -> APIRouter: @router.post("/api/image/enhance-face") async def enhance_face(request: Request): """Face/portrait enhancement. Uses GFPGAN if available, falls back to PIL.""" + require_privilege(request, "can_generate_images") body = await request.json() image_b64 = body.get("image") if not image_b64: @@ -1760,4 +1768,3 @@ def setup_gallery_routes() -> APIRouter: return router - diff --git a/tests/test_gallery_image_privileges.py b/tests/test_gallery_image_privileges.py new file mode 100644 index 0000000..2fe21c3 --- /dev/null +++ b/tests/test_gallery_image_privileges.py @@ -0,0 +1,40 @@ +import ast +from pathlib import Path + + +GATED_IMAGE_FUNCTIONS = { + "gallery_ai_upscale", + "gallery_style_transfer", + "inpaint_proxy", + "harmonize_image", + "denoise_image", + "upscale_image_local", + "remove_background", + "enhance_face", +} + + +def _gallery_source(): + return Path("routes/gallery_routes.py").read_text(encoding="utf-8") + + +def _function_sources(source): + tree = ast.parse(source) + return { + node.name: ast.get_source_segment(source, node) or "" + for node in ast.walk(tree) + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + } + + +def test_image_generation_endpoints_require_image_privilege(): + source = _gallery_source() + functions = _function_sources(source) + + for name in GATED_IMAGE_FUNCTIONS: + assert name in functions + assert 'require_privilege(request, "can_generate_images")' in functions[name] + + +def test_gallery_routes_imports_privilege_helper(): + assert "from src.auth_helpers import get_current_user, require_privilege" in _gallery_source()