diff --git a/src/task_scheduler.py b/src/task_scheduler.py index bf2dfe3..297a17c 100644 --- a/src/task_scheduler.py +++ b/src/task_scheduler.py @@ -241,6 +241,35 @@ class TaskScheduler: except Exception: logger.debug("Task progress update failed", exc_info=True) + def _mark_run_aborted(self, task_id: str, run_id: str | None = None, message: str = "Stopped by user") -> bool: + """Mark an active run as aborted. Used by stop/cancel paths.""" + try: + from core.database import SessionLocal, TaskRun + db = SessionLocal() + try: + q = db.query(TaskRun) + if run_id: + q = q.filter(TaskRun.id == run_id) + else: + q = q.filter( + TaskRun.task_id == task_id, + TaskRun.status.in_(("queued", "running")), + ).order_by(TaskRun.started_at.desc()) + run = q.first() + if not run or run.status not in ("queued", "running"): + return False + run.status = "aborted" + run.error = message + run.result = run.result or message + run.finished_at = datetime.utcnow() + db.commit() + return True + finally: + db.close() + except Exception: + logger.debug("Task abort marker failed for %s", task_id, exc_info=True) + return False + def add_notification(self, task_name: str, status: str, task_id: str = None, owner: str = None, body: str = None): """Store a notification about a completed task run. Tagged with the task's owner so `pop_notifications` can return only that user's @@ -581,12 +610,25 @@ class TaskScheduler: finally: _q_db.close() - if bypass_model_slot or not self._task_needs_model_slot(task_id): - await self._execute_task_locked(task_id, run_id, release_executing=release_executing) - return + try: + if bypass_model_slot or not self._task_needs_model_slot(task_id): + await self._execute_task_locked(task_id, run_id, release_executing=release_executing) + return - async with self._run_semaphore: - await self._execute_task_locked(task_id, run_id, release_executing=release_executing) + async with self._run_semaphore: + await self._execute_task_locked(task_id, run_id, release_executing=release_executing) + except asyncio.CancelledError: + # If cancellation happens while queued behind the semaphore, + # _execute_task_locked never runs and cannot update the Activity row. + self._mark_run_aborted(task_id, run_id) + raise + finally: + handle = self._task_handles.get(task_id) + if handle is current: + self._task_handles.pop(task_id, None) + if release_executing: + async with self._executing_lock: + self._executing.discard(task_id) async def _execute_task_locked(self, task_id: str, run_id: str, *, release_executing: bool = True): from core.database import SessionLocal, ScheduledTask, TaskRun @@ -1839,24 +1881,7 @@ class TaskScheduler: self._executing.discard(task_id) stopped = True - from core.database import SessionLocal, TaskRun - db = SessionLocal() - try: - run = ( - db.query(TaskRun) - .filter(TaskRun.task_id == task_id, TaskRun.status.in_(("queued", "running"))) - .order_by(TaskRun.started_at.desc()) - .first() - ) - if run: - run.status = "aborted" - run.error = "Stopped by user" - run.result = run.result or "Stopped by user" - run.finished_at = datetime.utcnow() - db.commit() - stopped = True - finally: - db.close() + stopped = self._mark_run_aborted(task_id) or stopped return stopped async def ensure_defaults(self, owner: str): diff --git a/tests/test_task_scheduler_cancel.py b/tests/test_task_scheduler_cancel.py new file mode 100644 index 0000000..3d399f1 --- /dev/null +++ b/tests/test_task_scheduler_cancel.py @@ -0,0 +1,105 @@ +import asyncio + +from sqlalchemy import Column, DateTime, String, Text, create_engine +from sqlalchemy.orm import declarative_base, sessionmaker + + +def _setup_db(tmp_path, monkeypatch): + import core.database as cd + + base = declarative_base() + + class ScheduledTask(base): + __tablename__ = "scheduled_tasks" + + id = Column(String, primary_key=True) + owner = Column(String) + name = Column(String) + task_type = Column(String, default="llm") + action = Column(String) + status = Column(String, default="active") + + class TaskRun(base): + __tablename__ = "task_runs" + + id = Column(String, primary_key=True) + task_id = Column(String) + started_at = Column(DateTime) + finished_at = Column(DateTime) + status = Column(String) + result = Column(Text) + error = Column(Text) + model = Column(String) + + engine = create_engine(f"sqlite:///{tmp_path / 'tasks.db'}") + base.metadata.create_all(engine) + session_local = sessionmaker(bind=engine, autocommit=False, autoflush=False) + monkeypatch.setattr(cd, "SessionLocal", session_local) + monkeypatch.setattr(cd, "ScheduledTask", ScheduledTask) + monkeypatch.setattr(cd, "TaskRun", TaskRun) + return session_local, ScheduledTask, TaskRun + + +def test_stop_task_cleans_up_queued_handle_and_run(tmp_path, monkeypatch): + session_local, ScheduledTask, TaskRun = _setup_db(tmp_path, monkeypatch) + + db = session_local() + db.add(ScheduledTask( + id="queued-task", + owner="alice", + name="Queued Task", + task_type="llm", + status="active", + )) + db.commit() + db.close() + + from src.task_scheduler import TaskScheduler + + async def drive(): + scheduler = TaskScheduler.__new__(TaskScheduler) + scheduler._executing = {"queued-task"} + scheduler._executing_lock = asyncio.Lock() + scheduler._run_semaphore = asyncio.Semaphore(1) + scheduler._task_handles = {} + scheduler._concurrency_cap = 1 + scheduler._task_defer_counts = {} + await scheduler._run_semaphore.acquire() + + task = asyncio.create_task(scheduler._execute_task("queued-task")) + try: + for _ in range(50): + if "queued-task" in scheduler._task_handles: + db2 = session_local() + try: + run = db2.query(TaskRun).filter(TaskRun.task_id == "queued-task").first() + if run: + break + finally: + db2.close() + await asyncio.sleep(0.01) + else: + raise AssertionError("queued run was not created") + + assert await scheduler.stop_task("queued-task") is True + try: + await task + except asyncio.CancelledError: + pass + finally: + scheduler._run_semaphore.release() + + assert "queued-task" not in scheduler._task_handles + assert "queued-task" not in scheduler._executing + + asyncio.run(drive()) + + db = session_local() + try: + run = db.query(TaskRun).filter(TaskRun.task_id == "queued-task").first() + assert run.status == "aborted" + assert run.error == "Stopped by user" + assert run.finished_at is not None + assert run.finished_at >= run.started_at + finally: + db.close()