Fix database stubs in regression tests (#301)
* Fix database stubs in regression tests * Keep regression tests independent of SQLAlchemy --------- Co-authored-by: red <red@red-MacBook-Air.local>
This commit is contained in:
@@ -26,6 +26,10 @@ class _FakeModelEndpoint:
|
|||||||
owner = _FakeColumn("owner")
|
owner = _FakeColumn("owner")
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeDbSession:
|
||||||
|
endpoint_url = _FakeColumn("endpoint_url")
|
||||||
|
|
||||||
|
|
||||||
class _FakeQuery:
|
class _FakeQuery:
|
||||||
def __init__(self, rows):
|
def __init__(self, rows):
|
||||||
self.rows = list(rows)
|
self.rows = list(rows)
|
||||||
@@ -68,6 +72,7 @@ def _install_model_route_import_stubs(monkeypatch):
|
|||||||
db_mod = types.ModuleType("core.database")
|
db_mod = types.ModuleType("core.database")
|
||||||
db_mod.SessionLocal = lambda: _FakeDb([])
|
db_mod.SessionLocal = lambda: _FakeDb([])
|
||||||
db_mod.ModelEndpoint = _FakeModelEndpoint
|
db_mod.ModelEndpoint = _FakeModelEndpoint
|
||||||
|
db_mod.Session = _FakeDbSession
|
||||||
middleware_mod = types.ModuleType("core.middleware")
|
middleware_mod = types.ModuleType("core.middleware")
|
||||||
middleware_mod.require_admin = lambda request: None
|
middleware_mod.require_admin = lambda request: None
|
||||||
multipart_mod = types.ModuleType("python_multipart")
|
multipart_mod = types.ModuleType("python_multipart")
|
||||||
@@ -80,6 +85,18 @@ def _install_model_route_import_stubs(monkeypatch):
|
|||||||
monkeypatch.setitem(sys.modules, "python_multipart", multipart_mod)
|
monkeypatch.setitem(sys.modules, "python_multipart", multipart_mod)
|
||||||
|
|
||||||
|
|
||||||
|
def _install_core_auth_stub(monkeypatch):
|
||||||
|
"""Install the narrow auth surface needed by tool-policy tests."""
|
||||||
|
core_mod = types.ModuleType("core")
|
||||||
|
core_mod.__path__ = []
|
||||||
|
auth_mod = types.ModuleType("core.auth")
|
||||||
|
auth_mod.AuthManager = MagicMock()
|
||||||
|
core_mod.auth = auth_mod
|
||||||
|
monkeypatch.setitem(sys.modules, "core", core_mod)
|
||||||
|
monkeypatch.setitem(sys.modules, "core.auth", auth_mod)
|
||||||
|
return auth_mod
|
||||||
|
|
||||||
|
|
||||||
def test_default_chat_does_not_auto_pick_shared_endpoint_for_fresh_user(monkeypatch):
|
def test_default_chat_does_not_auto_pick_shared_endpoint_for_fresh_user(monkeypatch):
|
||||||
_install_model_route_import_stubs(monkeypatch)
|
_install_model_route_import_stubs(monkeypatch)
|
||||||
import routes.model_routes as model_routes
|
import routes.model_routes as model_routes
|
||||||
@@ -335,8 +352,8 @@ async def test_build_chat_context_incognito_does_not_duplicate_current_user_mess
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_admin_agent_tools_require_admin(monkeypatch):
|
async def test_admin_agent_tools_require_admin(monkeypatch):
|
||||||
|
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||||
from src.tool_execution import execute_tool_block
|
from src.tool_execution import execute_tool_block
|
||||||
import core.auth
|
|
||||||
|
|
||||||
class FakeAuth:
|
class FakeAuth:
|
||||||
is_configured = True
|
is_configured = True
|
||||||
@@ -344,7 +361,7 @@ async def test_admin_agent_tools_require_admin(monkeypatch):
|
|||||||
def is_admin(self, username):
|
def is_admin(self, username):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr(core.auth, "AuthManager", lambda: FakeAuth())
|
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||||
|
|
||||||
desc, result = await execute_tool_block(
|
desc, result = await execute_tool_block(
|
||||||
SimpleNamespace(tool_type="manage_tokens", content='{"action":"create","name":"bad"}'),
|
SimpleNamespace(tool_type="manage_tokens", content='{"action":"create","name":"bad"}'),
|
||||||
@@ -358,8 +375,8 @@ async def test_admin_agent_tools_require_admin(monkeypatch):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
|
async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
|
||||||
|
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||||
from src.tool_execution import execute_tool_block
|
from src.tool_execution import execute_tool_block
|
||||||
import core.auth
|
|
||||||
|
|
||||||
class FakeAuth:
|
class FakeAuth:
|
||||||
is_configured = True
|
is_configured = True
|
||||||
@@ -367,7 +384,7 @@ async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
|
|||||||
def is_admin(self, username):
|
def is_admin(self, username):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr(core.auth, "AuthManager", lambda: FakeAuth())
|
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||||
|
|
||||||
for tool_name in ("send_email", "read_file", "app_api", "mcp__email__send_email"):
|
for tool_name in ("send_email", "read_file", "app_api", "mcp__email__send_email"):
|
||||||
desc, result = await execute_tool_block(
|
desc, result = await execute_tool_block(
|
||||||
@@ -380,7 +397,7 @@ async def test_public_agent_policy_blocks_sensitive_tools(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
|
def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
|
||||||
import core.auth
|
auth_mod = _install_core_auth_stub(monkeypatch)
|
||||||
from src.tool_security import blocked_tools_for_owner
|
from src.tool_security import blocked_tools_for_owner
|
||||||
|
|
||||||
class FakeAuth:
|
class FakeAuth:
|
||||||
@@ -389,7 +406,7 @@ def test_public_agent_policy_hides_sensitive_tools(monkeypatch):
|
|||||||
def is_admin(self, username):
|
def is_admin(self, username):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
monkeypatch.setattr(core.auth, "AuthManager", lambda: FakeAuth())
|
monkeypatch.setattr(auth_mod, "AuthManager", lambda: FakeAuth())
|
||||||
|
|
||||||
blocked = blocked_tools_for_owner("regular-user")
|
blocked = blocked_tools_for_owner("regular-user")
|
||||||
|
|
||||||
|
|||||||
@@ -14,23 +14,44 @@ that a mid-operation DB error neither raises out of the helper nor leaks the
|
|||||||
connection. The error-path cases fail against the old close()-inside-try
|
connection. The error-path cases fail against the old close()-inside-try
|
||||||
pattern.
|
pattern.
|
||||||
"""
|
"""
|
||||||
import os
|
import ast
|
||||||
os.environ.setdefault("DATABASE_URL", "sqlite:///:memory:")
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Generator
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from core import database as db
|
|
||||||
|
def _load_db_helpers():
|
||||||
|
"""Load only the helper bodies under test, without importing SQLAlchemy."""
|
||||||
|
db_path = Path(__file__).parents[1] / "core" / "database.py"
|
||||||
|
tree = ast.parse(db_path.read_text(encoding="utf-8"), filename=str(db_path))
|
||||||
|
wanted = {"get_db_session", "get_session_mode", "set_session_mode"}
|
||||||
|
helper_nodes = [
|
||||||
|
node for node in tree.body
|
||||||
|
if isinstance(node, ast.FunctionDef) and node.name in wanted
|
||||||
|
]
|
||||||
|
namespace = {
|
||||||
|
"contextmanager": contextmanager,
|
||||||
|
"Generator": Generator,
|
||||||
|
"Session": MagicMock(),
|
||||||
|
"SessionLocal": MagicMock(),
|
||||||
|
"logger": MagicMock(),
|
||||||
|
}
|
||||||
|
exec(compile(ast.Module(helper_nodes, type_ignores=[]), str(db_path), "exec"), namespace)
|
||||||
|
return SimpleNamespace(**namespace, _namespace=namespace)
|
||||||
|
|
||||||
|
|
||||||
def _mock_session(monkeypatch):
|
def _mock_session(monkeypatch):
|
||||||
"""Make get_db_session() hand out a MagicMock session (no real DB)."""
|
"""Make get_db_session() hand out a MagicMock session (no real DB)."""
|
||||||
|
db = _load_db_helpers()
|
||||||
sess = MagicMock()
|
sess = MagicMock()
|
||||||
monkeypatch.setattr(db, "SessionLocal", lambda: sess)
|
monkeypatch.setitem(db._namespace, "SessionLocal", lambda: sess)
|
||||||
return sess
|
return db, sess
|
||||||
|
|
||||||
|
|
||||||
def test_set_session_mode_commits_and_closes_on_success(monkeypatch):
|
def test_set_session_mode_commits_and_closes_on_success(monkeypatch):
|
||||||
sess = _mock_session(monkeypatch)
|
db, sess = _mock_session(monkeypatch)
|
||||||
assert db.set_session_mode("s1", "agent") is True
|
assert db.set_session_mode("s1", "agent") is True
|
||||||
sess.query.return_value.filter.return_value.update.assert_called_once_with({"mode": "agent"})
|
sess.query.return_value.filter.return_value.update.assert_called_once_with({"mode": "agent"})
|
||||||
sess.commit.assert_called_once()
|
sess.commit.assert_called_once()
|
||||||
@@ -38,7 +59,7 @@ def test_set_session_mode_commits_and_closes_on_success(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
def test_set_session_mode_does_not_leak_on_error(monkeypatch):
|
def test_set_session_mode_does_not_leak_on_error(monkeypatch):
|
||||||
sess = _mock_session(monkeypatch)
|
db, sess = _mock_session(monkeypatch)
|
||||||
sess.query.return_value.filter.return_value.update.side_effect = RuntimeError("database is locked")
|
sess.query.return_value.filter.return_value.update.side_effect = RuntimeError("database is locked")
|
||||||
# Best-effort: the error is swallowed and False returned...
|
# Best-effort: the error is swallowed and False returned...
|
||||||
assert db.set_session_mode("s1", "agent") is False
|
assert db.set_session_mode("s1", "agent") is False
|
||||||
@@ -48,14 +69,14 @@ def test_set_session_mode_does_not_leak_on_error(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
def test_get_session_mode_reads_and_closes(monkeypatch):
|
def test_get_session_mode_reads_and_closes(monkeypatch):
|
||||||
sess = _mock_session(monkeypatch)
|
db, sess = _mock_session(monkeypatch)
|
||||||
sess.query.return_value.filter.return_value.scalar.return_value = "research_pending"
|
sess.query.return_value.filter.return_value.scalar.return_value = "research_pending"
|
||||||
assert db.get_session_mode("s1") == "research_pending"
|
assert db.get_session_mode("s1") == "research_pending"
|
||||||
sess.close.assert_called_once()
|
sess.close.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_get_session_mode_does_not_leak_on_error(monkeypatch):
|
def test_get_session_mode_does_not_leak_on_error(monkeypatch):
|
||||||
sess = _mock_session(monkeypatch)
|
db, sess = _mock_session(monkeypatch)
|
||||||
sess.query.return_value.filter.return_value.scalar.side_effect = RuntimeError("database is locked")
|
sess.query.return_value.filter.return_value.scalar.side_effect = RuntimeError("database is locked")
|
||||||
assert db.get_session_mode("s1") is None
|
assert db.get_session_mode("s1") is None
|
||||||
sess.close.assert_called_once()
|
sess.close.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user