Fix database stubs in regression tests (#301)

* Fix database stubs in regression tests

* Keep regression tests independent of SQLAlchemy

---------

Co-authored-by: red <red@red-MacBook-Air.local>
This commit is contained in:
red person
2026-06-01 00:55:09 -07:00
committed by GitHub
parent be260f43e8
commit c9c6b919ff
2 changed files with 54 additions and 16 deletions

View File

@@ -26,6 +26,10 @@ class _FakeModelEndpoint:
owner = _FakeColumn("owner") owner = _FakeColumn("owner")
class _FakeDbSession:
endpoint_url = _FakeColumn("endpoint_url")
class _FakeQuery: class _FakeQuery:
def __init__(self, rows): def __init__(self, rows):
self.rows = list(rows) self.rows = list(rows)
@@ -68,6 +72,7 @@ def _install_model_route_import_stubs(monkeypatch):
db_mod = types.ModuleType("core.database") db_mod = types.ModuleType("core.database")
db_mod.SessionLocal = lambda: _FakeDb([]) db_mod.SessionLocal = lambda: _FakeDb([])
db_mod.ModelEndpoint = _FakeModelEndpoint db_mod.ModelEndpoint = _FakeModelEndpoint
db_mod.Session = _FakeDbSession
middleware_mod = types.ModuleType("core.middleware") middleware_mod = types.ModuleType("core.middleware")
middleware_mod.require_admin = lambda request: None middleware_mod.require_admin = lambda request: None
multipart_mod = types.ModuleType("python_multipart") multipart_mod = types.ModuleType("python_multipart")
@@ -80,6 +85,18 @@ def _install_model_route_import_stubs(monkeypatch):
monkeypatch.setitem(sys.modules, "python_multipart", multipart_mod) monkeypatch.setitem(sys.modules, "python_multipart", multipart_mod)
def _install_core_auth_stub(monkeypatch):
"""Install the narrow auth surface needed by tool-policy tests."""
core_mod = types.ModuleType("core")
core_mod.__path__ = []
auth_mod = types.ModuleType("core.auth")
auth_mod.AuthManager = MagicMock()
core_mod.auth = auth_mod
monkeypatch.setitem(sys.modules, "core", core_mod)
monkeypatch.setitem(sys.modules, "core.auth", auth_mod)
return auth_mod
def test_default_chat_does_not_auto_pick_shared_endpoint_for_fresh_user(monkeypatch): def test_default_chat_does_not_auto_pick_shared_endpoint_for_fresh_user(monkeypatch):
_install_model_route_import_stubs(monkeypatch) _install_model_route_import_stubs(monkeypatch)
import routes.model_routes as model_routes import routes.model_routes as model_routes
@@ -335,8 +352,8 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_agent_tools_require_admin(monkeypatch): async def test_admin_agent_tools_require_admin(monkeypatch):
auth_mod = _install_core_auth_stub(monkeypatch)
from src.tool_execution import execute_tool_block from src.tool_execution import execute_tool_block
import core.auth
class FakeAuth: class FakeAuth:
is_configured = True is_configured = True
@@ -344,7 +361,7 @@ async def test_admin_agent_tools_require_admin(monkeypatch):
def is_admin(self, username): def is_admin(self, username):
return False return False
monkeypatch.setattr(core.auth, "AuthManager", lambda: FakeAuth()) monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
desc, result = await execute_tool_block( desc, result = await execute_tool_block(
SimpleNamespace(tool_type="manage_tokens", content='{"action":"create","name":"bad"}'), SimpleNamespace(tool_type="manage_tokens", content='{"action":"create","name":"bad"}'),
@@ -358,8 +375,8 @@ async def test_admin_agent_tools_require_admin(monkeypatch):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch): async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
auth_mod = _install_core_auth_stub(monkeypatch)
from src.tool_execution import execute_tool_block from src.tool_execution import execute_tool_block
import core.auth
class FakeAuth: class FakeAuth:
is_configured = True is_configured = True
@@ -367,7 +384,7 @@ async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
def is_admin(self, username): def is_admin(self, username):
return False return False
monkeypatch.setattr(core.auth, "AuthManager", lambda: FakeAuth()) monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
for tool_name in ("send_email", "read_file", "app_api", "mcp__email__send_email"): for tool_name in ("send_email", "read_file", "app_api", "mcp__email__send_email"):
desc, result = await execute_tool_block( desc, result = await execute_tool_block(
@@ -380,7 +397,7 @@ async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
def test_public_agent_policy_hides_sensitive_tools(monkeypatch): def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
import core.auth auth_mod = _install_core_auth_stub(monkeypatch)
from src.tool_security import blocked_tools_for_owner from src.tool_security import blocked_tools_for_owner
class FakeAuth: class FakeAuth:
@@ -389,7 +406,7 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
def is_admin(self, username): def is_admin(self, username):
return False return False
monkeypatch.setattr(core.auth, "AuthManager", lambda: FakeAuth()) monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
blocked = blocked_tools_for_owner("regular-user") blocked = blocked_tools_for_owner("regular-user")

View File

@@ -14,23 +14,44 @@ that a mid-operation DB error neither raises out of the helper nor leaks the
connection. The error-path cases fail against the old close()-inside-try connection. The error-path cases fail against the old close()-inside-try
pattern. pattern.
""" """
import os import ast
os.environ.setdefault("DATABASE_URL", "sqlite:///:memory:") from contextlib import contextmanager
from pathlib import Path
from types import SimpleNamespace
from typing import Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
from core import database as db
def _load_db_helpers():
"""Load only the helper bodies under test, without importing SQLAlchemy."""
db_path = Path(__file__).parents[1] / "core" / "database.py"
tree = ast.parse(db_path.read_text(encoding="utf-8"), filename=str(db_path))
wanted = {"get_db_session", "get_session_mode", "set_session_mode"}
helper_nodes = [
node for node in tree.body
if isinstance(node, ast.FunctionDef) and node.name in wanted
]
namespace = {
"contextmanager": contextmanager,
"Generator": Generator,
"Session": MagicMock(),
"SessionLocal": MagicMock(),
"logger": MagicMock(),
}
exec(compile(ast.Module(helper_nodes, type_ignores=[]), str(db_path), "exec"), namespace)
return SimpleNamespace(**namespace, _namespace=namespace)
def _mock_session(monkeypatch): def _mock_session(monkeypatch):
"""Make get_db_session() hand out a MagicMock session (no real DB).""" """Make get_db_session() hand out a MagicMock session (no real DB)."""
db = _load_db_helpers()
sess = MagicMock() sess = MagicMock()
monkeypatch.setattr(db, "SessionLocal", lambda: sess) monkeypatch.setitem(db._namespace, "SessionLocal", lambda: sess)
return sess return db, sess
def test_set_session_mode_commits_and_closes_on_success(monkeypatch): def test_set_session_mode_commits_and_closes_on_success(monkeypatch):
sess = _mock_session(monkeypatch) db, sess = _mock_session(monkeypatch)
assert db.set_session_mode("s1", "agent") is True assert db.set_session_mode("s1", "agent") is True
sess.query.return_value.filter.return_value.update.assert_called_once_with({"mode": "agent"}) sess.query.return_value.filter.return_value.update.assert_called_once_with({"mode": "agent"})
sess.commit.assert_called_once() sess.commit.assert_called_once()
@@ -38,7 +59,7 @@ def test_set_session_mode_commits_and_closes_on_success(monkeypatch):
def test_set_session_mode_does_not_leak_on_error(monkeypatch): def test_set_session_mode_does_not_leak_on_error(monkeypatch):
sess = _mock_session(monkeypatch) db, sess = _mock_session(monkeypatch)
sess.query.return_value.filter.return_value.update.side_effect = RuntimeError("database is locked") sess.query.return_value.filter.return_value.update.side_effect = RuntimeError("database is locked")
# Best-effort: the error is swallowed and False returned... # Best-effort: the error is swallowed and False returned...
assert db.set_session_mode("s1", "agent") is False assert db.set_session_mode("s1", "agent") is False
@@ -48,14 +69,14 @@ def test_set_session_mode_does_not_leak_on_error(monkeypatch):
def test_get_session_mode_reads_and_closes(monkeypatch): def test_get_session_mode_reads_and_closes(monkeypatch):
sess = _mock_session(monkeypatch) db, sess = _mock_session(monkeypatch)
sess.query.return_value.filter.return_value.scalar.return_value = "research_pending" sess.query.return_value.filter.return_value.scalar.return_value = "research_pending"
assert db.get_session_mode("s1") == "research_pending" assert db.get_session_mode("s1") == "research_pending"
sess.close.assert_called_once() sess.close.assert_called_once()
def test_get_session_mode_does_not_leak_on_error(monkeypatch): def test_get_session_mode_does_not_leak_on_error(monkeypatch):
sess = _mock_session(monkeypatch) db, sess = _mock_session(monkeypatch)
sess.query.return_value.filter.return_value.scalar.side_effect = RuntimeError("database is locked") sess.query.return_value.filter.return_value.scalar.side_effect = RuntimeError("database is locked")
assert db.get_session_mode("s1") is None assert db.get_session_mode("s1") is None
sess.close.assert_called_once() sess.close.assert_called_once()