From 3843ceee2fb5ae515ea24bb2e2a4dd56a0eb05b3 Mon Sep 17 00:00:00 2001 From: Niels Provos Date: Sun, 6 Oct 2024 17:50:20 -0700 Subject: [PATCH] refactor: Rename GraphTask to SubGraphWorker in graph_task.py and test_graph_task.py --- src/planai/graph_task.py | 4 ++-- tests/planai/test_graph_task.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/planai/graph_task.py b/src/planai/graph_task.py index 7673b06..2d54724 100644 --- a/src/planai/graph_task.py +++ b/src/planai/graph_task.py @@ -21,7 +21,7 @@ PRIVATE_STATE_KEY = "_graph_task_private_state" -class GraphTask(TaskWorker): +class SubGraphWorker(TaskWorker): graph: Graph = Field( ..., description="The graph that will be run as part of this TaskWorker" ) @@ -164,7 +164,7 @@ def consume_work(self, task: Task3WorkItem): sub_graph.set_dependency(task1, task2) # Create the graph task - graph_task = GraphTask(graph=sub_graph, entry_worker=task1, exit_worker=task2) + graph_task = SubGraphWorker(graph=sub_graph, entry_worker=task1, exit_worker=task2) # Create the final consumer task3 = Task3Worker() diff --git a/tests/planai/test_graph_task.py b/tests/planai/test_graph_task.py index 104e7d6..a40df17 100644 --- a/tests/planai/test_graph_task.py +++ b/tests/planai/test_graph_task.py @@ -3,7 +3,7 @@ from typing import List, Type from planai.graph import Graph -from planai.graph_task import GraphTask +from planai.graph_task import SubGraphWorker from planai.joined_task import JoinedTaskWorker from planai.task import Task, TaskWorker @@ -31,7 +31,7 @@ def consume_work(self, task: InputTask): self.publish_work(sub_task, input_task=task) -class SubGraphWorker(TaskWorker): +class SubGraphHandler(TaskWorker): output_types: List[Type[Task]] = [SubGraphTask] def consume_work(self, task: SubGraphTask): @@ -54,7 +54,7 @@ def consume_work(self, task: SubGraphTask): expected_provenance = [ ("InitialTaskWorker", 1), ("MainWorker", 1), - ("GraphTask", 1), + ("SubGraphWorker", 1), ] self.verify_provenance(task, expected_provenance) @@ -70,13 +70,13 @@ class TestGraphTask(unittest.TestCase): def test_graph_task_provenance(self): # Create subgraph subgraph = Graph(name="SubGraph") - subgraph_worker = SubGraphWorker() + subgraph_worker = SubGraphHandler() subgraph_entry = subgraph_worker # Entry point subgraph_exit = subgraph_worker # Exit point subgraph.add_workers(subgraph_worker) # Create GraphTask - graph_task = GraphTask( + graph_task = SubGraphWorker( graph=subgraph, entry_worker=subgraph_entry, exit_worker=subgraph_exit ) @@ -168,7 +168,7 @@ def consume_work_joined(self, tasks: List[SubInputTask]): subgraph.set_dependency(sub_init_worker, sub_worker).next(sub_join_worker) # Create GraphTask - graph_task = GraphTask( + graph_task = SubGraphWorker( graph=subgraph, entry_worker=subgraph_entry, exit_worker=subgraph_exit )