Skip to content

Commit

Permalink
fix: use task_completed callback for notifies as well
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Sep 2, 2024
1 parent 9df986e commit c8a9bfa
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 14 deletions.
27 changes: 17 additions & 10 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def _remove_provenance(self, task: Task):
to_notify.add(prefix)

for prefix in to_notify:
self._notify_task_completion(prefix)
self._notify_task_completion(prefix, task)

def watch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
if not isinstance(prefix, tuple):
Expand All @@ -97,7 +97,7 @@ def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
return True
return False

def _notify_task_completion(self, prefix: tuple):
def _notify_task_completion(self, prefix: tuple, task: Task):
to_notify = []
with self.notifiers_lock:
for notifier in self.notifiers[prefix]:
Expand All @@ -106,8 +106,13 @@ def _notify_task_completion(self, prefix: tuple):
for notifier, prefix in to_notify:
with self.task_lock:
self.active_tasks += 1

# Use a named function instead of a lambda to avoid closure issues
def task_completed_callback(future, worker=notifier, task=task):
self._task_completed(worker, task, future)

future = self.graph._thread_pool.submit(notifier.notify, prefix)
future.add_done_callback(self._notify_completed)
future.add_done_callback(task_completed_callback)

def _dispatch_once(self) -> bool:
try:
Expand All @@ -127,7 +132,15 @@ def task_completed_callback(future, worker=worker, task=task):
return False

def dispatch(self):
while not self.stop_event.is_set() or not self.work_queue.empty():
while True:
# making sure that we can access active_tasks in a thread-safe way
with self.task_lock:
if (
self.stop_event.is_set()
and self.work_queue.empty()
and self.active_tasks == 0
):
break
self._dispatch_once()

def _execute_task(self, worker: TaskWorker, task: Task):
Expand Down Expand Up @@ -197,12 +210,6 @@ def _get_task_id(self, task: Task) -> str:
# Fallback in case _provenance is empty
return f"unknown_{id(task)}"

def _notify_completed(self, future):
with self.task_lock:
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
self.task_completion_event.set()

def _task_completed(self, worker: TaskWorker, task: Task, future):
success: bool = False
error_message: str = ""
Expand Down
11 changes: 7 additions & 4 deletions tests/planai/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,14 @@ def test_remove_provenance(self):
with patch.object(self.dispatcher, "_notify_task_completion") as mock_notify:
self.dispatcher._remove_provenance(task)
self.assertEqual(self.dispatcher.provenance, {})
mock_notify.assert_any_call((("Task1", 1),))
mock_notify.assert_any_call((("Task1", 1), ("Task2", 2)))
mock_notify.assert_any_call((("Task1", 1),), task)
mock_notify.assert_any_call((("Task1", 1), ("Task2", 2)), task)

def test_notify_task_completion(self):
notifier = Mock(spec=TaskWorker)
task = DummyTask(data="test")
self.dispatcher.notifiers = {(("Task1", 1),): [notifier]}
self.dispatcher._notify_task_completion((("Task1", 1),))
self.dispatcher._notify_task_completion((("Task1", 1),), task)
self.assertEqual(self.dispatcher.active_tasks, 1)
notifier.notify.assert_called_once_with((("Task1", 1),))

Expand Down Expand Up @@ -222,8 +223,10 @@ def test_get_completed_tasks(self):

def test_notify_completed(self):
future = Mock()
worker = DummyTaskWorkerSimple()
task = DummyTask(data="test")
self.dispatcher.active_tasks = 1
self.dispatcher._notify_completed(future)
self.dispatcher._task_completed(worker, task, future)
self.assertEqual(self.dispatcher.active_tasks, 0)
self.assertTrue(self.dispatcher.task_completion_event.is_set())

Expand Down

0 comments on commit c8a9bfa

Please sign in to comment.