Tasks: clean up queued cancellation state
This commit is contained in:
@@ -241,6 +241,35 @@ class TaskScheduler:
|
||||
except Exception:
|
||||
logger.debug("Task progress update failed", exc_info=True)
|
||||
|
||||
def _mark_run_aborted(self, task_id: str, run_id: str | None = None, message: str = "Stopped by user") -> bool:
    """Flag a still-active run as aborted; shared by the stop/cancel paths.

    Args:
        task_id: Task whose run should be aborted. Used to locate the run
            when ``run_id`` is not given, and in the failure log message.
        run_id: Exact run to abort. When falsy, the most recently started
            run of ``task_id`` that is still "queued" or "running" is used.
        message: Text stored in the run's ``error`` field, and in
            ``result`` when the run has no result yet.

    Returns:
        True when a queued/running run was found, updated and committed.
        False when no such run exists, when the selected run has already
        left the queued/running states, or when any exception occurred
        (the failure is only logged at debug level).
    """
    active_states = ("queued", "running")
    try:
        from core.database import SessionLocal, TaskRun

        session = SessionLocal()
        try:
            if run_id:
                # An explicit run id wins; its status is re-checked below.
                lookup = session.query(TaskRun).filter(TaskRun.id == run_id)
            else:
                # Fall back to the newest run of this task that is still active.
                lookup = (
                    session.query(TaskRun)
                    .filter(
                        TaskRun.task_id == task_id,
                        TaskRun.status.in_(active_states),
                    )
                    .order_by(TaskRun.started_at.desc())
                )

            target = lookup.first()
            if not target or target.status not in active_states:
                return False

            target.status = "aborted"
            target.error = message
            # Keep any result text the run already has.
            target.result = target.result or message
            target.finished_at = datetime.utcnow()
            session.commit()
            return True
        finally:
            session.close()
    except Exception:
        # Best-effort marker: callers only learn about the failure via False.
        logger.debug("Task abort marker failed for %s", task_id, exc_info=True)
        return False
|
||||
|
||||
def add_notification(self, task_name: str, status: str, task_id: str = None, owner: str = None, body: str = None):
|
||||
"""Store a notification about a completed task run. Tagged with the
|
||||
task's owner so `pop_notifications` can return only that user's
|
||||
@@ -581,12 +610,25 @@ class TaskScheduler:
|
||||
finally:
|
||||
_q_db.close()
|
||||
|
||||
if bypass_model_slot or not self._task_needs_model_slot(task_id):
|
||||
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
|
||||
return
|
||||
try:
|
||||
if bypass_model_slot or not self._task_needs_model_slot(task_id):
|
||||
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
|
||||
return
|
||||
|
||||
async with self._run_semaphore:
|
||||
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
|
||||
async with self._run_semaphore:
|
||||
await self._execute_task_locked(task_id, run_id, release_executing=release_executing)
|
||||
except asyncio.CancelledError:
|
||||
# If cancellation happens while queued behind the semaphore,
|
||||
# _execute_task_locked never runs and cannot update the Activity row.
|
||||
self._mark_run_aborted(task_id, run_id)
|
||||
raise
|
||||
finally:
|
||||
handle = self._task_handles.get(task_id)
|
||||
if handle is current:
|
||||
self._task_handles.pop(task_id, None)
|
||||
if release_executing:
|
||||
async with self._executing_lock:
|
||||
self._executing.discard(task_id)
|
||||
|
||||
async def _execute_task_locked(self, task_id: str, run_id: str, *, release_executing: bool = True):
|
||||
from core.database import SessionLocal, ScheduledTask, TaskRun
|
||||
@@ -1839,24 +1881,7 @@ class TaskScheduler:
|
||||
self._executing.discard(task_id)
|
||||
stopped = True
|
||||
|
||||
from core.database import SessionLocal, TaskRun
|
||||
db = SessionLocal()
|
||||
try:
|
||||
run = (
|
||||
db.query(TaskRun)
|
||||
.filter(TaskRun.task_id == task_id, TaskRun.status.in_(("queued", "running")))
|
||||
.order_by(TaskRun.started_at.desc())
|
||||
.first()
|
||||
)
|
||||
if run:
|
||||
run.status = "aborted"
|
||||
run.error = "Stopped by user"
|
||||
run.result = run.result or "Stopped by user"
|
||||
run.finished_at = datetime.utcnow()
|
||||
db.commit()
|
||||
stopped = True
|
||||
finally:
|
||||
db.close()
|
||||
stopped = self._mark_run_aborted(task_id) or stopped
|
||||
return stopped
|
||||
|
||||
async def ensure_defaults(self, owner: str):
|
||||
|
||||
105
tests/test_task_scheduler_cancel.py
Normal file
105
tests/test_task_scheduler_cancel.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, Text, create_engine
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
|
||||
|
||||
def _setup_db(tmp_path, monkeypatch):
    """Create a throwaway SQLite database and patch it into ``core.database``.

    Declares minimal ``scheduled_tasks`` and ``task_runs`` models, creates
    their tables in ``tmp_path / 'tasks.db'``, and replaces the
    ``SessionLocal``, ``ScheduledTask`` and ``TaskRun`` attributes of
    ``core.database`` with these stand-ins through ``monkeypatch``.

    Returns:
        A ``(session_factory, ScheduledTask, TaskRun)`` tuple.
    """
    import core.database as core_db

    Base = declarative_base()

    class ScheduledTask(Base):
        __tablename__ = "scheduled_tasks"

        id = Column(String, primary_key=True)
        owner = Column(String)
        name = Column(String)
        task_type = Column(String, default="llm")
        action = Column(String)
        status = Column(String, default="active")

    class TaskRun(Base):
        __tablename__ = "task_runs"

        id = Column(String, primary_key=True)
        task_id = Column(String)
        started_at = Column(DateTime)
        finished_at = Column(DateTime)
        status = Column(String)
        result = Column(Text)
        error = Column(Text)
        model = Column(String)

    db_engine = create_engine(f"sqlite:///{tmp_path / 'tasks.db'}")
    Base.metadata.create_all(db_engine)
    session_factory = sessionmaker(bind=db_engine, autocommit=False, autoflush=False)

    # Point the code under test at the temporary database and models.
    replacements = (
        ("SessionLocal", session_factory),
        ("ScheduledTask", ScheduledTask),
        ("TaskRun", TaskRun),
    )
    for attr_name, stand_in in replacements:
        monkeypatch.setattr(core_db, attr_name, stand_in)

    return session_factory, ScheduledTask, TaskRun
|
||||
|
||||
|
||||
def test_stop_task_cleans_up_queued_handle_and_run(tmp_path, monkeypatch):
    """Stopping a task before it gets the run semaphore must leave no residue.

    The test holds the scheduler's only semaphore slot, starts
    ``_execute_task`` for a seeded task, waits until both a task handle and
    a ``TaskRun`` row exist, then calls ``stop_task``. Afterwards the handle
    and the ``_executing`` entry must be gone and the run row must be marked
    "aborted" with the "Stopped by user" error and a finish timestamp.
    """
    session_local, ScheduledTask, TaskRun = _setup_db(tmp_path, monkeypatch)
    task_id = "queued-task"

    seed_session = session_local()
    seed_session.add(
        ScheduledTask(
            id=task_id,
            owner="alice",
            name="Queued Task",
            task_type="llm",
            status="active",
        )
    )
    seed_session.commit()
    seed_session.close()

    from src.task_scheduler import TaskScheduler

    def run_row_exists():
        """Report whether any TaskRun row has been written for the task."""
        probe = session_local()
        try:
            row = probe.query(TaskRun).filter(TaskRun.task_id == task_id).first()
            return bool(row)
        finally:
            probe.close()

    async def drive():
        # Bypass __init__ and set only the attributes this test assigns.
        scheduler = TaskScheduler.__new__(TaskScheduler)
        scheduler._executing = {task_id}
        scheduler._executing_lock = asyncio.Lock()
        scheduler._run_semaphore = asyncio.Semaphore(1)
        scheduler._task_handles = {}
        scheduler._concurrency_cap = 1
        scheduler._task_defer_counts = {}
        # Take the single slot so a task needing it cannot start executing.
        await scheduler._run_semaphore.acquire()

        pending = asyncio.create_task(scheduler._execute_task(task_id))
        try:
            # Poll (up to 50 x 10 ms) until the handle and the run row exist.
            for _ in range(50):
                if task_id in scheduler._task_handles and run_row_exists():
                    break
                await asyncio.sleep(0.01)
            else:
                raise AssertionError("queued run was not created")

            assert await scheduler.stop_task(task_id) is True
            try:
                await pending
            except asyncio.CancelledError:
                pass
        finally:
            scheduler._run_semaphore.release()

        assert task_id not in scheduler._task_handles
        assert task_id not in scheduler._executing

    asyncio.run(drive())

    verify_session = session_local()
    try:
        run = verify_session.query(TaskRun).filter(TaskRun.task_id == task_id).first()
        assert run.status == "aborted"
        assert run.error == "Stopped by user"
        assert run.finished_at is not None
        assert run.finished_at >= run.started_at
    finally:
        verify_session.close()
|
||||
Reference in New Issue
Block a user