From 1c07eccfecfd08ba5898fb4557d40cb128ace3df Mon Sep 17 00:00:00 2001 From: Niels Provos Date: Sun, 6 Oct 2024 14:23:26 -0700 Subject: [PATCH] fix: delete provenance when moving a task to a sub-graph --- src/planai/graph_task.py | 2 + src/planai/task.py | 2 +- tests/planai/test_graph_task.py | 122 ++++++++++++++++++++++++++++++-- 3 files changed, 120 insertions(+), 6 deletions(-) diff --git a/src/planai/graph_task.py b/src/planai/graph_task.py index 1f4f9cd..7673b06 100644 --- a/src/planai/graph_task.py +++ b/src/planai/graph_task.py @@ -76,6 +76,8 @@ def init(self): def consume_work(self, task: Task): # save the task provenance new_task = task.model_copy() + new_task._provenance = [] + new_task._input_provenance = [] new_task.add_private_state(PRIVATE_STATE_KEY, task) # and dispatch it to the sub-graph. this also sets the task provenance to InitialTaskWorker diff --git a/src/planai/task.py b/src/planai/task.py index e896e35..fb0ee2f 100644 --- a/src/planai/task.py +++ b/src/planai/task.py @@ -190,7 +190,7 @@ class TaskWorker(BaseModel, ABC): description="The types of work this task can output", ) num_retries: int = Field( - 0, description="The number of retries allowed for this task" + default=0, description="The number of retries allowed for this task" ) _state_lock: threading.RLock = PrivateAttr(default_factory=threading.RLock) diff --git a/tests/planai/test_graph_task.py b/tests/planai/test_graph_task.py index e969e37..7c57c3f 100644 --- a/tests/planai/test_graph_task.py +++ b/tests/planai/test_graph_task.py @@ -4,6 +4,7 @@ from planai.graph import Graph from planai.graph_task import GraphTask +from planai.joined_task import JoinedTaskWorker from planai.task import Task, TaskWorker @@ -13,7 +14,7 @@ class InputTask(Task): class SubGraphTask(Task): data: str - _intermediate: bool = False # Indicate if this task is from the subgraph + intermediate: bool = False # Indicate if this task is from the subgraph class FinalTask(Task): @@ -35,7 +36,7 @@ class SubGraphWorker(TaskWorker): def consume_work(self, task: SubGraphTask): # Mark task as processed in subgraph - task._intermediate = True + task.intermediate = True # Simulate processing in the subgraph time.sleep(0.1) self.publish_work(task, input_task=task) @@ -47,10 +48,8 @@ class FinalWorker(TaskWorker): def consume_work(self, task: SubGraphTask): # Simulate final processing in the main graph time.sleep(0.1) - print(task._provenance) - assert task._intermediate, "Task should have been processed by subgraph" + assert task.intermediate, "Task should have been processed by subgraph" assert len(task._provenance) == 3, "Provenance length should be 3" - # The provenance should include: MainWorker, GraphTask, FinalWorker # The subgraph's provenance should not appear here expected_provenance = [ ("InitialTaskWorker", 1), @@ -100,6 +99,7 @@ def test_graph_task_provenance(self): # Ensure that the dispatcher has completed all tasks dispatcher = graph._dispatcher self.assertIsNotNone(dispatcher) + assert dispatcher is not None self.assertEqual(dispatcher.work_queue.qsize(), 0) self.assertEqual(dispatcher.active_tasks, 0) self.assertEqual( @@ -112,6 +112,118 @@ def test_graph_task_provenance(self): # Optionally, we can check logs or other side effects if needed + def test_graph_task_with_joined_task_worker(self): + # Create subgraph with a JoinedTaskWorker + subgraph = Graph(name="SubGraphWithJoin") + + class SubInputTask(Task): + data: str + + class SubOutputTask(Task): + data: str + + class SubInitWorker(TaskWorker): + output_types: List[Type[Task]] = [SubGraphTask] + + def consume_work(self, task: SubGraphTask): + assert len(task._provenance) == 1, "Provenance length should be 3" + # The subgraph's provenance should not appear here + assert task._provenance == [ + ("InitialTaskWorker", 1) + ], "Provenance mismatch" + + sub_task = SubGraphTask(data=task.data) + self.publish_work(sub_task, input_task=task) + + class SubWorker(TaskWorker): + output_types: List[Type[Task]] = [SubInputTask] + + def consume_work(self, task: SubGraphTask): + # Simulate processing in the subgraph + time.sleep(0.1) + # Publish multiple SubInputTasks for joining + for i in range(3): + sub_task = SubInputTask(data=f"{task.data}-{i}") + self.publish_work(sub_task, input_task=task) + + class SubJoinWorker(JoinedTaskWorker): + join_type: Type[TaskWorker] = SubInitWorker + output_types: List[Type[Task]] = [SubGraphTask] + + def consume_work(self, task: SubInputTask): + super().consume_work(task) # Handle joining + + def consume_work_joined(self, tasks: List[SubInputTask]): + # Simulate processing after joining tasks + time.sleep(0.1) + # Concatenate data from joined tasks + joined_data = ";".join([t.data for t in tasks]) + output_task = SubGraphTask(data=joined_data, intermediate=True) + self.publish_work(output_task, input_task=tasks[0]) + + sub_init_worker = SubInitWorker() + sub_worker = SubWorker() + sub_join_worker = SubJoinWorker() + subgraph_entry = sub_init_worker + subgraph_exit = sub_join_worker + + subgraph.add_workers(sub_init_worker, sub_worker, sub_join_worker) + subgraph.set_dependency(sub_init_worker, sub_worker).next(sub_join_worker) + + # Create GraphTask + graph_task = GraphTask( + graph=subgraph, entry_worker=subgraph_entry, exit_worker=subgraph_exit + ) + + # Create main graph + graph = Graph(name="MainGraphWithJoin") + main_worker = MainWorker() + final_worker = FinalWorker() + + graph.add_workers(main_worker, graph_task, final_worker) + graph.set_dependency(main_worker, graph_task).next(final_worker) + + # Prepare initial work item + initial_task = InputTask(data="TestData") + initial_work = [(main_worker, initial_task)] + + graph.run( + initial_tasks=initial_work, run_dashboard=False, display_terminal=False + ) + + # Ensure that the dispatcher has completed all tasks + dispatcher = graph._dispatcher + self.assertIsNotNone(dispatcher) + assert dispatcher is not None + self.assertEqual(dispatcher.work_queue.qsize(), 0) + self.assertEqual(dispatcher.active_tasks, 0) + + # The total completed tasks should include: + # - MainWorker: 1 + # - GraphTask: 1 + # - FinalWorker: 1 + # - SubWorker: 1 (publishes 3 SubInputTasks) + # - SubJoinWorker: handles the 3 SubInputTasks and calls consume_work_joined once + # - AdapterSinkWorker: 1 + total_main_graph_tasks = 3 # main_worker, graph_task, final_worker + total_subgraph_tasks = ( + 1 # sub_init_worker + + 1 # sub_worker consumes SubGraphTask + + 3 # sub_join_worker consumes 3 SubInputTasks + + 1 # sub_join_worker consumes joined tasks + + 1 # AdapterSinkWorker + ) + expected_total_tasks = total_main_graph_tasks + total_subgraph_tasks + + self.assertEqual(dispatcher.total_completed_tasks, expected_total_tasks) + + # Since FinalWorker has no output, check that it processed the task correctly + final_tasks = graph.get_output_tasks() + self.assertEqual(len(final_tasks), 0) + + # Optionally, you can verify that the joined data is correct + # But since FinalWorker does not output anything, and we didn't sink any tasks, we can't retrieve outputs here + if __name__ == "__main__": unittest.main()