Skip to content

Commit

Permalink
refactor: move the thread pool executor into the dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Oct 1, 2024
1 parent 1180bf8 commit f6db2f8
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 37 deletions.
23 changes: 10 additions & 13 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,10 @@
import threading
import time
from collections import defaultdict, deque
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from threading import Event, Lock
from typing import (
TYPE_CHECKING,
Any,
Deque,
Dict,
List,
Optional,
Tuple,
Type,
)
from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Tuple, Type

from .provenance import ProvenanceChain, ProvenanceTracker
from .stats import WorkerStat
Expand Down Expand Up @@ -99,8 +91,11 @@ def get_inheritance_chain(cls: Type[TaskWorker]) -> List[str]:


class Dispatcher:
def __init__(self, graph: "Graph", web_port=5000):
self.graph = graph
def __init__(
self, graph: "Graph", web_port: int = 5000, start_thread_pool: bool = True
):
self._thread_pool = ThreadPoolExecutor() if start_thread_pool else None

self.web_port = web_port

self.task_lock = Lock()
Expand Down Expand Up @@ -245,7 +240,7 @@ def submit_work(
task_completed_callback: callable,
):
self.increment_active_tasks(worker)
future = self.graph._thread_pool.submit(*arguments)
future = self._thread_pool.submit(*arguments)
future.add_done_callback(task_completed_callback)

def _dispatch_user_requests(self):
Expand Down Expand Up @@ -527,6 +522,8 @@ def wait_for_completion(self, wait_for_quit=False):
# Sleep for a short time to avoid busy waiting
time.sleep(0.1)

self._thread_pool.shutdown(wait=True)

def start_web_interface(self):
web_thread = threading.Thread(
target=run_web_interface, args=(self, self.web_port)
Expand Down
9 changes: 0 additions & 9 deletions src/planai/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import shutil
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from threading import Event, Thread
from typing import Dict, List, Optional, Set, Tuple, Type

Expand All @@ -38,19 +37,13 @@ class Graph(BaseModel):
dependencies: Dict[TaskWorker, List[TaskWorker]] = Field(default_factory=dict)

_dispatcher: Optional[Dispatcher] = PrivateAttr(default=None)
_thread_pool: Optional[ThreadPoolExecutor] = PrivateAttr(default=None)
_max_parallel_tasks: Dict[Type[TaskWorker], int] = PrivateAttr(default_factory=dict)
_sink_tasks: List[TaskType] = PrivateAttr(default_factory=list)
_sink_worker: Optional[TaskWorker] = PrivateAttr(default=None)

_worker_distances: Dict[str, Dict[str, int]] = PrivateAttr(default_factory=dict)
_has_terminal: bool = PrivateAttr(default=False)

def __init__(self, **data):
super().__init__(**data)
if self._thread_pool is None:
self._thread_pool = ThreadPoolExecutor()

def add_worker(self, task: TaskWorker) -> "Graph":
"""Add a task to the Graph."""
if task in self.workers:
Expand Down Expand Up @@ -304,8 +297,6 @@ def run(
terminal_thread.join()
logging.info("Terminal display stopped")

self._thread_pool.shutdown(wait=True)

for worker in self.workers:
worker.completed()

Expand Down
15 changes: 6 additions & 9 deletions tests/planai/test_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def shutdown(self, wait=True):
class TestDispatcher(unittest.TestCase):
def setUp(self):
self.graph = Mock(spec=Graph)
self.graph._thread_pool = SingleThreadedExecutor()
self.dispatcher = Dispatcher(self.graph)
self.dispatcher = Dispatcher(self.graph, start_thread_pool=False)
self.dispatcher.work_queue = Queue()
self.dispatcher.stop_event = Event()
self.dispatcher._thread_pool = SingleThreadedExecutor()

def test_dispatch(self):
worker = Mock(spec=TaskWorker)
Expand All @@ -129,7 +129,7 @@ def test_dispatch(self):
self.assertEqual(self.dispatcher.active_tasks, 1)

# Simulate task completion
future = self.graph._thread_pool.tasks[0]
future = self.dispatcher._thread_pool.tasks[0]
future.add_done_callback.assert_called_once()
callback = future.add_done_callback.call_args[0][0]
callback(future)
Expand Down Expand Up @@ -263,11 +263,10 @@ def test_start_web_interface(self, mock_run_web_interface):
class TestDispatcherThreading(unittest.TestCase):
def setUp(self):
self.graph = Mock(spec=Graph)
self.graph._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=4)
self.dispatcher = Dispatcher(self.graph)

def tearDown(self):
self.graph._thread_pool.shutdown(wait=True)
self.dispatcher._thread_pool.shutdown(wait=True)

def test_concurrent_add_work(self):
num_threads = 10
Expand Down Expand Up @@ -561,8 +560,6 @@ def test_exception_handling_end_to_end(self):
"Task DummyTask failed with exception: Test exception", cm.output[0]
)

self.graph._thread_pool.shutdown(wait=True)

def test_max_parallel_tasks(self):
num_tasks = 10
max_parallel = 2
Expand Down Expand Up @@ -671,7 +668,7 @@ def add_initial_work():
dispatcher.total_completed_tasks == total_processed
), f"Completed tasks {dispatcher.total_completed_tasks} should match total processed ({total_processed})"

graph._thread_pool.shutdown(wait=False)
dispatcher._thread_pool.shutdown(wait=False)

def test_concurrent_success_and_failures(self):
graph = Graph(name="Test Graph")
Expand Down Expand Up @@ -757,7 +754,7 @@ def add_work_for_worker(worker):
self.assertEqual(dispatcher.work_queue.qsize(), 0, "Work queue should be empty")
self.assertEqual(dispatcher.active_tasks, 0, "No active tasks should remain")

graph._thread_pool.shutdown(wait=True)
dispatcher._thread_pool.shutdown(wait=True)


if __name__ == "__main__":
Expand Down
2 changes: 0 additions & 2 deletions tests/planai/test_joined_multiple_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ def test_complex_joined_task_workflow(self):
self.assertEqual(len(outputs), 1)
self.assertEqual(outputs[0].result, 5)

self.graph._thread_pool.shutdown(wait=True)


if __name__ == "__main__":
unittest.main()
4 changes: 0 additions & 4 deletions tests/planai/test_joined_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,6 @@ def test_joined_task_worker(self):
len(self.worker3._joined_results), 0
) # All joined results should have been processed

self.graph._thread_pool.shutdown(wait=True)


class InitialTask(Task):
data: str
Expand Down Expand Up @@ -233,8 +231,6 @@ def add_initial_work():
self.assertEqual(self.dispatcher.work_queue.qsize(), 0)
self.assertEqual(self.dispatcher._active_tasks, 0)

self.graph._thread_pool.shutdown(wait=True)


if __name__ == "__main__":
unittest.main()

0 comments on commit f6db2f8

Please sign in to comment.