diff --git a/src/planai/dispatcher.py b/src/planai/dispatcher.py index 0270476..dd0d580 100644 --- a/src/planai/dispatcher.py +++ b/src/planai/dispatcher.py @@ -60,18 +60,10 @@ import threading import time from collections import defaultdict, deque +from concurrent.futures import ThreadPoolExecutor from queue import Empty, Queue from threading import Event, Lock -from typing import ( - TYPE_CHECKING, - Any, - Deque, - Dict, - List, - Optional, - Tuple, - Type, -) +from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Tuple, Type from .provenance import ProvenanceChain, ProvenanceTracker from .stats import WorkerStat @@ -99,8 +91,11 @@ def get_inheritance_chain(cls: Type[TaskWorker]) -> List[str]: class Dispatcher: - def __init__(self, graph: "Graph", web_port=5000): - self.graph = graph + def __init__( + self, graph: "Graph", web_port: int = 5000, start_thread_pool: bool = True + ): + self._thread_pool = ThreadPoolExecutor() if start_thread_pool else None + self.web_port = web_port self.task_lock = Lock() @@ -245,7 +240,7 @@ def submit_work( task_completed_callback: callable, ): self.increment_active_tasks(worker) - future = self.graph._thread_pool.submit(*arguments) + future = self._thread_pool.submit(*arguments) future.add_done_callback(task_completed_callback) def _dispatch_user_requests(self): @@ -527,6 +522,8 @@ def wait_for_completion(self, wait_for_quit=False): # Sleep for a short time to avoid busy waiting time.sleep(0.1) + self._thread_pool.shutdown(wait=True) + def start_web_interface(self): web_thread = threading.Thread( target=run_web_interface, args=(self, self.web_port) diff --git a/src/planai/graph.py b/src/planai/graph.py index 3480e6d..1598f42 100644 --- a/src/planai/graph.py +++ b/src/planai/graph.py @@ -15,7 +15,6 @@ import shutil import time from collections import deque -from concurrent.futures import ThreadPoolExecutor from threading import Event, Thread from typing import Dict, List, Optional, Set, Tuple, Type @@ -38,7 +37,6 @@ class Graph(BaseModel): dependencies: Dict[TaskWorker, List[TaskWorker]] = Field(default_factory=dict) _dispatcher: Optional[Dispatcher] = PrivateAttr(default=None) - _thread_pool: Optional[ThreadPoolExecutor] = PrivateAttr(default=None) _max_parallel_tasks: Dict[Type[TaskWorker], int] = PrivateAttr(default_factory=dict) _sink_tasks: List[TaskType] = PrivateAttr(default_factory=list) _sink_worker: Optional[TaskWorker] = PrivateAttr(default=None) @@ -46,11 +44,6 @@ class Graph(BaseModel): _worker_distances: Dict[str, Dict[str, int]] = PrivateAttr(default_factory=dict) _has_terminal: bool = PrivateAttr(default=False) - def __init__(self, **data): - super().__init__(**data) - if self._thread_pool is None: - self._thread_pool = ThreadPoolExecutor() - def add_worker(self, task: TaskWorker) -> "Graph": """Add a task to the Graph.""" if task in self.workers: @@ -304,8 +297,6 @@ def run( terminal_thread.join() logging.info("Terminal display stopped") - self._thread_pool.shutdown(wait=True) - for worker in self.workers: worker.completed() diff --git a/tests/planai/test_dispatcher.py b/tests/planai/test_dispatcher.py index 1f67f52..accb2f9 100644 --- a/tests/planai/test_dispatcher.py +++ b/tests/planai/test_dispatcher.py @@ -110,10 +110,10 @@ def shutdown(self, wait=True): class TestDispatcher(unittest.TestCase): def setUp(self): self.graph = Mock(spec=Graph) - self.graph._thread_pool = SingleThreadedExecutor() - self.dispatcher = Dispatcher(self.graph) + self.dispatcher = Dispatcher(self.graph, start_thread_pool=False) self.dispatcher.work_queue = Queue() self.dispatcher.stop_event = Event() + self.dispatcher._thread_pool = SingleThreadedExecutor() def test_dispatch(self): worker = Mock(spec=TaskWorker) @@ -129,7 +129,7 @@ def test_dispatch(self): self.assertEqual(self.dispatcher.active_tasks, 1) # Simulate task completion - future = self.graph._thread_pool.tasks[0] + future = self.dispatcher._thread_pool.tasks[0] future.add_done_callback.assert_called_once() callback = future.add_done_callback.call_args[0][0] callback(future) @@ -263,11 +263,10 @@ def test_start_web_interface(self, mock_run_web_interface): class TestDispatcherThreading(unittest.TestCase): def setUp(self): self.graph = Mock(spec=Graph) - self.graph._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4) self.dispatcher = Dispatcher(self.graph) def tearDown(self): - self.graph._thread_pool.shutdown(wait=True) + self.dispatcher._thread_pool.shutdown(wait=True) def test_concurrent_add_work(self): num_threads = 10 @@ -561,8 +560,6 @@ def test_exception_handling_end_to_end(self): "Task DummyTask failed with exception: Test exception", cm.output[0] ) - self.graph._thread_pool.shutdown(wait=True) - def test_max_parallel_tasks(self): num_tasks = 10 max_parallel = 2 @@ -671,7 +668,7 @@ def add_initial_work(): dispatcher.total_completed_tasks == total_processed ), f"Completed tasks {dispatcher.total_completed_tasks} should match total processed ({total_processed})" - graph._thread_pool.shutdown(wait=False) + dispatcher._thread_pool.shutdown(wait=False) def test_concurrent_success_and_failures(self): graph = Graph(name="Test Graph") @@ -757,7 +754,7 @@ def add_work_for_worker(worker): self.assertEqual(dispatcher.work_queue.qsize(), 0, "Work queue should be empty") self.assertEqual(dispatcher.active_tasks, 0, "No active tasks should remain") - graph._thread_pool.shutdown(wait=True) + dispatcher._thread_pool.shutdown(wait=True) if __name__ == "__main__": diff --git a/tests/planai/test_joined_multiple_task.py b/tests/planai/test_joined_multiple_task.py index 6c5073c..34a4f52 100644 --- a/tests/planai/test_joined_multiple_task.py +++ b/tests/planai/test_joined_multiple_task.py @@ -177,8 +177,6 @@ def test_complex_joined_task_workflow(self): self.assertEqual(len(outputs), 1) self.assertEqual(outputs[0].result, 5) - self.graph._thread_pool.shutdown(wait=True) - if __name__ == "__main__": unittest.main() diff --git a/tests/planai/test_joined_task.py b/tests/planai/test_joined_task.py index 4c3e939..165adcc 100644 --- a/tests/planai/test_joined_task.py +++ b/tests/planai/test_joined_task.py @@ -121,8 +121,6 @@ def test_joined_task_worker(self): len(self.worker3._joined_results), 0 ) # All joined results should have been processed - self.graph._thread_pool.shutdown(wait=True) - class InitialTask(Task): data: str @@ -233,8 +231,6 @@ def add_initial_work(): self.assertEqual(self.dispatcher.work_queue.qsize(), 0) self.assertEqual(self.dispatcher._active_tasks, 0) - self.graph._thread_pool.shutdown(wait=True) - if __name__ == "__main__": unittest.main()