fix(scheduler): push next_run forward on startup to stop restart double-fire (#708)
TaskScheduler.start() aborts stale TaskRun rows but never advanced ScheduledTask.next_run. Across a restart the in-process _executing set is empty, so the first post-restart _check_due_tasks() call dispatches every task whose next_run is still in the past — and so does every subsequent poll, until the task's regular _execute_task path finally runs compute_next_run and pushes it forward. start() now queries active tasks with next_run < now and pushes each one to now + 60s. The first poll after restart sees them as not-yet-due, the task runs once normally, and compute_next_run puts the schedule back on its real cadence. Paused and not-yet-due tasks are left alone. The validator test was rewritten as a regression test asserting the opposite of the bug it originally demonstrated, plus two narrower cases to lock down the filter (only active+overdue is touched).
This commit is contained in:
@@ -312,6 +312,33 @@ class TaskScheduler:
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not clear stale task_runs on startup: {e}")
|
||||
|
||||
# Advance next_run for active tasks whose next_run is already in the
|
||||
# past. Without this, a restart hits _check_due_tasks() with an empty
|
||||
# in-process _executing set, and the same overdue task fires once per
|
||||
# poll until it completes.
|
||||
try:
|
||||
from core.database import SessionLocal as _SL, ScheduledTask as _ST
|
||||
db = _SL()
|
||||
try:
|
||||
now = datetime.utcnow()
|
||||
overdue = db.query(_ST).filter(
|
||||
_ST.status == "active",
|
||||
_ST.next_run.isnot(None),
|
||||
_ST.next_run < now,
|
||||
).all()
|
||||
if overdue:
|
||||
for t in overdue:
|
||||
t.next_run = now + timedelta(seconds=60)
|
||||
db.commit()
|
||||
logger.info(
|
||||
"Pushed next_run forward by 60s for %d overdue active tasks on startup",
|
||||
len(overdue),
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not advance overdue next_run on startup: {e}")
|
||||
|
||||
# Defense-in-depth dedupe sweep: for any owner with >1 rows where
|
||||
# is_default_assistant=True, keep the oldest and demote the rest +
|
||||
# delete their orphaned check-in tasks. This is the safety net for
|
||||
|
||||
190
tests/test_scheduler_restart_doublefire.py
Normal file
190
tests/test_scheduler_restart_doublefire.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Validator + regression test for FINDING 6.2 — restart double-fires overdue
|
||||
scheduled tasks.
|
||||
|
||||
Demonstrates the bug: TaskScheduler.start() aborts stale TaskRun rows but never
|
||||
advances ScheduledTask.next_run, so the in-memory _executing guard resets
|
||||
across a restart and _check_due_tasks will re-dispatch any task whose
|
||||
next_run is still in the past.
|
||||
|
||||
After the fix (start() advances overdue next_run to now + 60s), the regression
|
||||
test asserts the opposite: the task fires at most once across two consecutive
|
||||
polls.
|
||||
"""
|
||||
import sys, types, asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
from sqlalchemy import create_engine, Column, String, DateTime, Integer, Boolean, Text
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
|
||||
|
||||
def _stub_heavy():
|
||||
for name in [
|
||||
"src.builtin_actions", "src.ai_interaction", "src.endpoint_resolver",
|
||||
"src.agent_loop", "src.session_manager",
|
||||
]:
|
||||
sys.modules.setdefault(name, types.ModuleType(name))
|
||||
|
||||
|
||||
def _setup_isolated_db():
|
||||
import core.database as cd
|
||||
B = declarative_base()
|
||||
|
||||
class ScheduledTask(B):
|
||||
__tablename__ = "scheduled_tasks"
|
||||
id = Column(String, primary_key=True)
|
||||
owner = Column(String)
|
||||
name = Column(String, default="t")
|
||||
prompt = Column(Text)
|
||||
task_type = Column(String, default="llm")
|
||||
next_run = Column(DateTime, index=True)
|
||||
last_run = Column(DateTime)
|
||||
status = Column(String, default="active")
|
||||
run_count = Column(Integer, default=0)
|
||||
|
||||
class TaskRun(B):
|
||||
__tablename__ = "task_runs"
|
||||
id = Column(String, primary_key=True)
|
||||
task_id = Column(String)
|
||||
started_at = Column(DateTime)
|
||||
finished_at = Column(DateTime)
|
||||
status = Column(String, default="queued")
|
||||
error = Column(Text)
|
||||
|
||||
eng = create_engine("sqlite:///:memory:")
|
||||
B.metadata.create_all(eng)
|
||||
cd.engine = eng
|
||||
cd.SessionLocal = sessionmaker(bind=eng, autocommit=False, autoflush=False)
|
||||
cd.ScheduledTask = ScheduledTask
|
||||
cd.TaskRun = TaskRun
|
||||
return cd, ScheduledTask, TaskRun
|
||||
|
||||
|
||||
def _drive_scheduler(monkeypatch, pre_start_setup=None):
|
||||
"""Build a TaskScheduler bypassing __init__ and run start() + two polls."""
|
||||
_stub_heavy()
|
||||
cd, ScheduledTask, TaskRun = _setup_isolated_db()
|
||||
|
||||
from src.task_scheduler import TaskScheduler
|
||||
sch = TaskScheduler.__new__(TaskScheduler)
|
||||
sch._executing = set()
|
||||
sch._executing_lock = asyncio.Lock()
|
||||
sch._concurrency_cap = 1
|
||||
sch._run_semaphore = asyncio.Semaphore(1)
|
||||
sch._running = True
|
||||
sch._task = None
|
||||
sch._note_pings_task = None
|
||||
sch._known_task_owners = lambda: []
|
||||
sch._task_defer_counts = {}
|
||||
|
||||
if pre_start_setup:
|
||||
pre_start_setup(cd, ScheduledTask, TaskRun)
|
||||
|
||||
async def _never():
|
||||
await asyncio.sleep(3600)
|
||||
monkeypatch.setattr(sch, "_loop", _never)
|
||||
monkeypatch.setattr(sch, "_note_pings_loop", _never)
|
||||
|
||||
dispatched = []
|
||||
def _fake_create_task(coro):
|
||||
dispatched.append(coro)
|
||||
class _T:
|
||||
def cancel(self): pass
|
||||
return _T()
|
||||
monkeypatch.setattr("src.task_scheduler.asyncio.create_task", _fake_create_task)
|
||||
|
||||
async def _drive():
|
||||
await sch.start()
|
||||
await sch._check_due_tasks()
|
||||
await sch._check_due_tasks()
|
||||
return dispatched
|
||||
|
||||
all_dispatched = asyncio.run(_drive())
|
||||
# start() also fires the long-lived _loop and _note_pings_loop as tasks
|
||||
# (stubbed to _never here); filter those out so the test only counts
|
||||
# real per-poll task dispatches.
|
||||
real_dispatches = [c for c in all_dispatched if c.__name__ != "_never"]
|
||||
return cd, ScheduledTask, TaskRun, real_dispatches
|
||||
|
||||
|
||||
def test_restart_does_not_re_dispatch_overdue_task(monkeypatch):
|
||||
"""After restart, an overdue active task should fire at most once across
|
||||
two consecutive polls (the first poll re-fires it, but next_run is then
|
||||
advanced so the second poll does not)."""
|
||||
def _setup(cd, ScheduledTask, TaskRun):
|
||||
db = cd.SessionLocal()
|
||||
db.add(ScheduledTask(
|
||||
id="t_due_1", owner="alice", name="overdue",
|
||||
task_type="llm",
|
||||
next_run=datetime.utcnow() - timedelta(hours=1),
|
||||
status="active",
|
||||
))
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
cd, ScheduledTask, TaskRun, dispatched = _drive_scheduler(monkeypatch, _setup)
|
||||
|
||||
db = cd.SessionLocal()
|
||||
t = db.query(ScheduledTask).filter(ScheduledTask.id == "t_due_1").first()
|
||||
db.close()
|
||||
assert t.next_run >= datetime.utcnow() - timedelta(seconds=1), (
|
||||
f"After start(), next_run should have been pushed into the future; "
|
||||
f"got {t.next_run}"
|
||||
)
|
||||
assert len(dispatched) <= 1, (
|
||||
f"Expected at most 1 dispatch across two polls; got {len(dispatched)}. "
|
||||
"The startup next_run advance is not preventing the second poll from "
|
||||
"re-firing the same overdue task."
|
||||
)
|
||||
|
||||
|
||||
def test_startup_does_not_advance_fresh_tasks(monkeypatch):
|
||||
"""Tasks whose next_run is in the future must be untouched by the startup
|
||||
sweep — only overdue ones get pushed forward."""
|
||||
future = datetime.utcnow() + timedelta(hours=2)
|
||||
def _setup(cd, ScheduledTask, TaskRun):
|
||||
db = cd.SessionLocal()
|
||||
db.add(ScheduledTask(
|
||||
id="t_fresh", owner="alice", name="fresh",
|
||||
task_type="llm", next_run=future, status="active",
|
||||
))
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
cd, ScheduledTask, TaskRun, dispatched = _drive_scheduler(monkeypatch, _setup)
|
||||
|
||||
db = cd.SessionLocal()
|
||||
t = db.query(ScheduledTask).filter(ScheduledTask.id == "t_fresh").first()
|
||||
db.close()
|
||||
assert t.next_run == future, (
|
||||
f"Fresh task's next_run was modified: expected {future}, got {t.next_run}"
|
||||
)
|
||||
assert len(dispatched) == 0
|
||||
|
||||
|
||||
def test_startup_does_not_advance_paused_tasks(monkeypatch):
|
||||
"""A paused task with an old next_run is not overdue for execution —
|
||||
it should not be advanced by the startup sweep."""
|
||||
def _setup(cd, ScheduledTask, TaskRun):
|
||||
db = cd.SessionLocal()
|
||||
db.add(ScheduledTask(
|
||||
id="t_paused", owner="alice", name="paused",
|
||||
task_type="llm",
|
||||
next_run=datetime.utcnow() - timedelta(hours=1),
|
||||
status="paused",
|
||||
))
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
cd, ScheduledTask, TaskRun, dispatched = _drive_scheduler(monkeypatch, _setup)
|
||||
|
||||
db = cd.SessionLocal()
|
||||
t = db.query(ScheduledTask).filter(ScheduledTask.id == "t_paused").first()
|
||||
db.close()
|
||||
# The stored next_run should still be ~1h in the past (the startup sweep
|
||||
# only advances active overdue tasks; a paused task with an old next_run
|
||||
# is left alone). Allow a small delta to absorb the time the sweep took.
|
||||
one_hour_ago = datetime.utcnow() - timedelta(hours=1)
|
||||
assert abs((t.next_run - one_hour_ago).total_seconds()) < 5, (
|
||||
f"Paused task's next_run was modified: "
|
||||
f"expected ~{one_hour_ago}, got {t.next_run}"
|
||||
)
|
||||
Reference in New Issue
Block a user