Skip to content

Commit

Permalink
fix: delete provenance when moving a task to a sub-graph
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Oct 6, 2024
1 parent dbfce66 commit 1c07ecc
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/planai/graph_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def init(self):
def consume_work(self, task: Task):
# save the task provenance
new_task = task.model_copy()
new_task._provenance = []
new_task._input_provenance = []
new_task.add_private_state(PRIVATE_STATE_KEY, task)

# and dispatch it to the sub-graph. this also sets the task provenance to InitialTaskWorker
Expand Down
2 changes: 1 addition & 1 deletion src/planai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class TaskWorker(BaseModel, ABC):
description="The types of work this task can output",
)
num_retries: int = Field(
0, description="The number of retries allowed for this task"
default=0, description="The number of retries allowed for this task"
)

_state_lock: threading.RLock = PrivateAttr(default_factory=threading.RLock)
Expand Down
122 changes: 117 additions & 5 deletions tests/planai/test_graph_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from planai.graph import Graph
from planai.graph_task import GraphTask
from planai.joined_task import JoinedTaskWorker
from planai.task import Task, TaskWorker


Expand All @@ -13,7 +14,7 @@ class InputTask(Task):

class SubGraphTask(Task):
data: str
_intermediate: bool = False # Indicate if this task is from the subgraph
intermediate: bool = False # Indicate if this task is from the subgraph


class FinalTask(Task):
Expand All @@ -35,7 +36,7 @@ class SubGraphWorker(TaskWorker):

def consume_work(self, task: SubGraphTask):
# Mark task as processed in subgraph
task._intermediate = True
task.intermediate = True
# Simulate processing in the subgraph
time.sleep(0.1)
self.publish_work(task, input_task=task)
Expand All @@ -47,10 +48,8 @@ class FinalWorker(TaskWorker):
def consume_work(self, task: SubGraphTask):
# Simulate final processing in the main graph
time.sleep(0.1)
print(task._provenance)
assert task._intermediate, "Task should have been processed by subgraph"
assert task.intermediate, "Task should have been processed by subgraph"
assert len(task._provenance) == 3, "Provenance length should be 3"
# The provenance should include: MainWorker, GraphTask, FinalWorker
# The subgraph's provenance should not appear here
expected_provenance = [
("InitialTaskWorker", 1),
Expand Down Expand Up @@ -100,6 +99,7 @@ def test_graph_task_provenance(self):
# Ensure that the dispatcher has completed all tasks
dispatcher = graph._dispatcher
self.assertIsNotNone(dispatcher)
assert dispatcher is not None
self.assertEqual(dispatcher.work_queue.qsize(), 0)
self.assertEqual(dispatcher.active_tasks, 0)
self.assertEqual(
Expand All @@ -112,6 +112,118 @@ def test_graph_task_provenance(self):

# Optionally, we can check logs or other side effects if needed

def test_graph_task_with_joined_task_worker(self):
# Create subgraph with a JoinedTaskWorker
subgraph = Graph(name="SubGraphWithJoin")

class SubInputTask(Task):
data: str

class SubOutputTask(Task):
data: str

class SubInitWorker(TaskWorker):
output_types: List[Type[Task]] = [SubGraphTask]

def consume_work(self, task: SubGraphTask):
assert len(task._provenance) == 1, "Provenance length should be 3"
# The subgraph's provenance should not appear here
assert task._provenance == [
("InitialTaskWorker", 1)
], "Provenance mismatch"

sub_task = SubGraphTask(data=task.data)
self.publish_work(sub_task, input_task=task)

class SubWorker(TaskWorker):
output_types: List[Type[Task]] = [SubInputTask]

def consume_work(self, task: SubGraphTask):
# Simulate processing in the subgraph
time.sleep(0.1)
# Publish multiple SubInputTasks for joining
for i in range(3):
sub_task = SubInputTask(data=f"{task.data}-{i}")
self.publish_work(sub_task, input_task=task)

class SubJoinWorker(JoinedTaskWorker):
join_type: Type[TaskWorker] = SubInitWorker
output_types: List[Type[Task]] = [SubGraphTask]

def consume_work(self, task: SubInputTask):
super().consume_work(task) # Handle joining

def consume_work_joined(self, tasks: List[SubInputTask]):
# Simulate processing after joining tasks
time.sleep(0.1)
# Concatenate data from joined tasks
joined_data = ";".join([t.data for t in tasks])
output_task = SubGraphTask(data=joined_data, intermediate=True)
self.publish_work(output_task, input_task=tasks[0])

sub_init_worker = SubInitWorker()
sub_worker = SubWorker()
sub_join_worker = SubJoinWorker()
subgraph_entry = sub_init_worker
subgraph_exit = sub_join_worker

subgraph.add_workers(sub_init_worker, sub_worker, sub_join_worker)
subgraph.set_dependency(sub_init_worker, sub_worker).next(sub_join_worker)

# Create GraphTask
graph_task = GraphTask(
graph=subgraph, entry_worker=subgraph_entry, exit_worker=subgraph_exit
)

# Create main graph
graph = Graph(name="MainGraphWithJoin")
main_worker = MainWorker()
final_worker = FinalWorker()

graph.add_workers(main_worker, graph_task, final_worker)
graph.set_dependency(main_worker, graph_task).next(final_worker)

# Prepare initial work item
initial_task = InputTask(data="TestData")
initial_work = [(main_worker, initial_task)]

graph.run(
initial_tasks=initial_work, run_dashboard=False, display_terminal=False
)

# Ensure that the dispatcher has completed all tasks
dispatcher = graph._dispatcher
self.assertIsNotNone(dispatcher)
assert dispatcher is not None
self.assertEqual(dispatcher.work_queue.qsize(), 0)
self.assertEqual(dispatcher.active_tasks, 0)

# The total completed tasks should include:
# - MainWorker: 1
# - GraphTask: 1
# - FinalWorker: 1
# - SubWorker: 1 (publishes 3 SubInputTasks)
# - SubJoinWorker: handles the 3 SubInputTasks and calls consume_work_joined once
# - AdapterSinkWorker: 1
total_main_graph_tasks = 3 # main_worker, graph_task, final_worker
total_subgraph_tasks = (
1 # sub_init_worker
+ 1 # sub_worker consumes SubGraphTask
+ 3 # sub_join_worker consumes 3 SubInputTasks
+ 1 # sub_join_worker consumes joined tasks
+ 1 # AdapterSinkWorker
)
expected_total_tasks = total_main_graph_tasks + total_subgraph_tasks

self.assertEqual(dispatcher.total_completed_tasks, expected_total_tasks)

# Since FinalWorker has no output, check that it processed the task correctly
final_tasks = graph.get_output_tasks()
self.assertEqual(len(final_tasks), 0)

# Optionally, you can verify that the joined data is correct
# But since FinalWorker does not output anything, and we didn't sink any tasks, we can't retrieve outputs here


if __name__ == "__main__":
unittest.main()

0 comments on commit 1c07ecc

Please sign in to comment.