refactor: move provenance tracking into the graph
provos committed Oct 2, 2024
1 parent f6db2f8 commit e372425
Showing 6 changed files with 71 additions and 37 deletions.
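In short, the ProvenanceTracker instance that previously lived on the Dispatcher now lives on the Graph, and worker-level trace/watch/unwatch calls reach it through worker._graph instead of the dispatcher. A minimal sketch of the new call path, using only methods touched in this commit (the graph name and provenance prefix below are illustrative, not taken from the diff):

    from planai.graph import Graph

    graph = Graph(name="Example Graph")
    # ... add workers and dependencies as before ...

    # Provenance tracking is now owned by the Graph rather than the Dispatcher:
    graph.trace((("InitialTaskWorker", 1),))

    # Inside a TaskWorker, self.trace(prefix) / self.watch(prefix, task) /
    # self.unwatch(prefix) now resolve to the same graph-level tracker via
    # self._graph, so existing worker-level calls keep working unchanged.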
20 changes: 3 additions & 17 deletions src/planai/dispatcher.py
@@ -65,7 +65,6 @@
from threading import Event, Lock
from typing import TYPE_CHECKING, Any, Deque, Dict, List, Optional, Tuple, Type

-from .provenance import ProvenanceChain, ProvenanceTracker
from .stats import WorkerStat
from .task import Task, TaskWorker
from .user_input import UserInputRequest
@@ -111,8 +110,6 @@ def __init__(
# we are using the work_available Event to signal the dispatcher that there might be work
self.work_available = threading.Event()

-        self._provenance_tracker = ProvenanceTracker()
-
self.stop_event = Event()
self._active_tasks = 0
self.task_completion_event = Event()
@@ -134,17 +131,6 @@ def __init__(
self.user_input_requests = Queue()
self.user_pending_requests: Dict[str, UserInputRequest] = {}

-    def trace(self, prefix: ProvenanceChain):
-        self._provenance_tracker.trace(prefix)
-
-    def watch(
-        self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None
-    ) -> bool:
-        return self._provenance_tracker.watch(prefix, notifier, task)
-
-    def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
-        return self._provenance_tracker.unwatch(prefix, notifier)
-
@property
def active_tasks(self):
with self.task_lock:
@@ -479,14 +465,14 @@ def _task_completed(self, worker: TaskWorker, task: Optional[Task], future):
self.worker_stats[worker.name].increment_completed()

if task:
-            self._provenance_tracker._remove_provenance(task, worker)
+            worker._graph._provenance_tracker._remove_provenance(task, worker)

if self.decrement_active_tasks(worker):
self.task_completion_event.set()

def add_work(self, worker: TaskWorker, task: Task):
task_copy = task.model_copy()
-        self._provenance_tracker._add_provenance(task_copy)
+        worker._graph._provenance_tracker._add_provenance(task_copy)
self._add_to_queue(worker, task_copy)

def add_multiple_work(self, work_items: List[Tuple[TaskWorker, Task]]):
@@ -495,7 +481,7 @@ def add_multiple_work(self, work_items: List[Tuple[TaskWorker, Task]]):
# before all the provenance is added.
work_items = [(worker, task.model_copy()) for worker, task in work_items]
for worker, task in work_items:
-            self._provenance_tracker._add_provenance(task)
+            worker._graph._provenance_tracker._add_provenance(task)
for worker, task in work_items:
self._add_to_queue(worker, task)

30 changes: 30 additions & 0 deletions src/planai/graph.py
@@ -23,6 +23,7 @@

from .dispatcher import Dispatcher
from .joined_task import InitialTaskWorker
+from .provenance import ProvenanceChain, ProvenanceTracker
from .task import Task, TaskType, TaskWorker

# Initialize colorama for Windows compatibility
@@ -37,13 +38,28 @@ class Graph(BaseModel):
dependencies: Dict[TaskWorker, List[TaskWorker]] = Field(default_factory=dict)

_dispatcher: Optional[Dispatcher] = PrivateAttr(default=None)
+    _provenance_tracker: ProvenanceTracker = PrivateAttr(
+        default_factory=ProvenanceTracker
+    )
+
_max_parallel_tasks: Dict[Type[TaskWorker], int] = PrivateAttr(default_factory=dict)
_sink_tasks: List[TaskType] = PrivateAttr(default_factory=list)
_sink_worker: Optional[TaskWorker] = PrivateAttr(default=None)

_worker_distances: Dict[str, Dict[str, int]] = PrivateAttr(default_factory=dict)
_has_terminal: bool = PrivateAttr(default=False)

+    def trace(self, prefix: ProvenanceChain):
+        self._provenance_tracker.trace(prefix)
+
+    def watch(
+        self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None
+    ) -> bool:
+        return self._provenance_tracker.watch(prefix, notifier, task)
+
+    def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
+        return self._provenance_tracker.unwatch(prefix, notifier)
+
def add_worker(self, task: TaskWorker) -> "Graph":
"""Add a task to the Graph."""
if task in self.workers:
@@ -301,6 +317,20 @@ def run(
worker.completed()

def inject_initial_task_worker(self, initial_tasks: List[Tuple[TaskWorker, Task]]):
"""
Injects an initial task worker and sets up dependencies for the given initial tasks.
This method creates an `InitialTaskWorker` instance and adds it to the worker list.
It then sets up dependencies between the initial task worker and each worker in the
provided `initial_tasks` list without performing any checks.
Args:
initial_tasks (List[Tuple[TaskWorker, Task]]): A list of tuples where each tuple
contains a `TaskWorker` and a `Task`.
Returns:
None
"""
initial_worker = InitialTaskWorker()
self.add_worker(initial_worker)
for worker, _ in initial_tasks:
6 changes: 3 additions & 3 deletions src/planai/task.py
@@ -305,7 +305,7 @@ def trace(self, prefix: "ProvenanceChain"):
The prefix to trace. Must be a tuple representing a part of a task's provenance chain.
This is the sequence of task identifiers leading up to (but not including) the current task.
"""
-        self._graph._dispatcher.trace(prefix)
+        self._graph.trace(prefix)

def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
"""
@@ -338,7 +338,7 @@ def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")
-        return self._graph._dispatcher.watch(prefix, self, task)
+        return self._graph.watch(prefix, self, task)

def unwatch(self, prefix: "ProvenanceChain") -> bool:
"""
@@ -352,7 +352,7 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool:
"""
if not isinstance(prefix, tuple):
raise ValueError("Prefix must be a tuple")
-        return self._graph._dispatcher.unwatch(prefix, self)
+        return self._graph.unwatch(prefix, self)

def print(self, *args):
"""
38 changes: 30 additions & 8 deletions tests/planai/test_dispatcher.py
@@ -27,6 +27,7 @@

from planai.dispatcher import Dispatcher
from planai.graph import Graph
+from planai.provenance import ProvenanceTracker
from planai.task import Task, TaskWorker


@@ -110,13 +111,15 @@ def shutdown(self, wait=True):
class TestDispatcher(unittest.TestCase):
def setUp(self):
self.graph = Mock(spec=Graph)
+        self.graph._provenance_tracker = ProvenanceTracker()
self.dispatcher = Dispatcher(self.graph, start_thread_pool=False)
self.dispatcher.work_queue = Queue()
self.dispatcher.stop_event = Event()
self.dispatcher._thread_pool = SingleThreadedExecutor()

def test_dispatch(self):
worker = Mock(spec=TaskWorker)
+        worker._graph = self.graph
task = DummyTask(data="test")
self.dispatcher.work_queue.put((worker, task))

@@ -139,6 +142,7 @@

def test_execute_task(self):
worker = Mock(spec=TaskWorker)
+        worker._graph = self.graph
future = Mock()
task = DummyTask(data="test")
self.dispatcher._execute_task(worker, task)
@@ -149,6 +153,7 @@

def test_task_to_dict(self):
worker = DummyTaskWorkerSimple()
+        worker._graph = self.graph
task = DummyTask(data="test")
task._provenance = [("Task1", 1)]
task._input_provenance = [DummyTask(data="input")]
@@ -160,6 +165,7 @@

def test_get_queued_tasks(self):
worker = DummyTaskWorkerSimple()
+        worker._graph = self.graph
task = DummyTask(data="test")
self.dispatcher.work_queue.put((worker, task))
result = self.dispatcher.get_queued_tasks()
@@ -169,6 +175,7 @@

def test_get_active_tasks(self):
worker = DummyTaskWorkerSimple()
+        worker._graph = self.graph
task = DummyTask(data="test")
self.dispatcher.debug_active_tasks = {1: (worker, task)}
result = self.dispatcher.get_active_tasks()
@@ -178,6 +185,7 @@

def test_get_completed_tasks(self):
worker = DummyTaskWorkerSimple()
+        worker._graph = self.graph
task = DummyTask(data="test")
self.dispatcher.completed_tasks = deque([(worker, task)])
result = self.dispatcher.get_completed_tasks()
@@ -188,6 +196,7 @@
def test_notify_completed(self):
future = Mock()
worker = DummyTaskWorkerSimple()
+        worker._graph = self.graph
task = DummyTask(data="test")
self.dispatcher.increment_active_tasks(worker)
self.dispatcher._task_completed(worker, task, future)
@@ -196,6 +205,7 @@

def test_task_completed(self):
worker = Mock(spec=TaskWorker)
+        worker._graph = self.graph
task = DummyTask(data="test")
future = Mock()
future.result.return_value = None
@@ -205,7 +215,7 @@
self.dispatcher.work_queue = Queue() # Ensure the queue is empty

with patch.object(
-            self.dispatcher._provenance_tracker, "_remove_provenance"
+            self.graph._provenance_tracker, "_remove_provenance"
) as mock_remove:
self.dispatcher._task_completed(worker, task, future)
mock_remove.assert_called_once_with(task, worker)
@@ -215,6 +225,7 @@

def test_task_completed_with_remaining_tasks(self):
worker = Mock(spec=TaskWorker)
+        worker._graph = self.graph
task = DummyTask(data="test")
future = Mock()
future.result.return_value = None
@@ -225,7 +236,7 @@
self.dispatcher.work_queue = Queue() # Ensure the queue is empty

with patch.object(
-            self.dispatcher._provenance_tracker, "_remove_provenance"
+            self.graph._provenance_tracker, "_remove_provenance"
) as mock_remove:
self.dispatcher._task_completed(worker, task, future)
mock_remove.assert_called_once_with(task, worker)
@@ -235,9 +246,10 @@

def test_add_work(self):
worker = Mock(spec=TaskWorker)
+        worker._graph = self.graph
task = DummyTask(data="test")
with patch.object(
-            self.dispatcher._provenance_tracker, "_add_provenance"
+            self.graph._provenance_tracker, "_add_provenance"
) as mock_add:
self.dispatcher.add_work(worker, task)
mock_add.assert_called_once_with(task)
@@ -262,7 +274,7 @@ def test_start_web_interface(self, mock_run_web_interface):

class TestDispatcherThreading(unittest.TestCase):
def setUp(self):
-        self.graph = Mock(spec=Graph)
+        self.graph = Graph(name="Test Graph")
self.dispatcher = Dispatcher(self.graph)

def tearDown(self):
@@ -275,6 +287,7 @@ def test_concurrent_add_work(self):
def add_work():
for _ in range(num_tasks_per_thread):
worker = Mock(spec=TaskWorker)
+                worker._graph = self.graph
task = DummyTask(data="test")
self.dispatcher.add_work(worker, task)

@@ -294,13 +307,14 @@ def test_race_condition_provenance(self):
num_threads = 10
num_operations = 1000
worker = DummyTaskWorkerSimple()
+        worker._graph = self.graph

def modify_provenance():
for _ in range(num_operations):
task = DummyTask(data="test")
task._provenance = [("Task1", 1)]
-                self.dispatcher._provenance_tracker._add_provenance(task)
-                self.dispatcher._provenance_tracker._remove_provenance(task, worker)
+                self.graph._provenance_tracker._add_provenance(task)
+                self.graph._provenance_tracker._remove_provenance(task, worker)

threads = [
threading.Thread(target=modify_provenance) for _ in range(num_threads)
@@ -313,7 +327,7 @@ def modify_provenance():
thread.join()

# All operations should cancel out, leaving the provenance empty or with zero counts
-        for value in self.dispatcher._provenance_tracker.provenance.values():
+        for value in self.graph._provenance_tracker.provenance.values():
self.assertEqual(value, 0, "Provenance count should be 0 for all tasks")

def test_stress_dispatcher(self):
Expand All @@ -322,6 +336,8 @@ def test_stress_dispatcher(self):
num_tasks_per_worker = 1000

workers = [Mock(spec=TaskWorker) for _ in range(num_workers)]
+        for worker in workers:
+            worker._graph = self.graph

def worker_task(worker):
for i in range(num_tasks_per_worker):
@@ -374,6 +390,7 @@ def test_exception_logging(self, mock_logging_exception):

# Create a worker that raises an exception
worker = ExceptionRaisingTaskWorker()
+        worker._graph = self.graph

# Add tasks to the dispatcher
for i in range(num_tasks):
@@ -415,6 +432,8 @@ def test_exception_logging(self, mock_logging_exception):
@patch("planai.dispatcher.logging.exception")
def test_task_retry(self, mock_log_exception, mock_log_error, mock_log_info):
worker = RetryTaskWorker(num_retries=2, fail_attempts=2)
+        worker._graph = self.graph
+
task = DummyTask(data="test-retry")

self.dispatcher.increment_active_tasks(worker)
@@ -456,6 +475,8 @@ def test_task_retry_exhausted(
self, mock_log_exception, mock_log_error, mock_log_info
):
worker = RetryTaskWorker(num_retries=2, fail_attempts=3)
+        worker._graph = self.graph
+
task = DummyTask(data="test-retry-exhausted")

self.dispatcher.increment_active_tasks(worker)
@@ -537,7 +558,7 @@ def test_exception_handling_end_to_end(self):

# Check that the provenance was properly removed
self.assertEqual(
-            len(dispatcher._provenance_tracker.provenance),
+            len(self.graph._provenance_tracker.provenance),
0,
"Provenance should be empty",
)
@@ -566,6 +587,7 @@ def test_max_parallel_tasks(self):

# Create a custom worker
worker = LimitedParallelTaskWorker()
+        worker._graph = self.graph

# Set up the dispatcher
self.dispatcher.set_max_parallel_tasks(LimitedParallelTaskWorker, max_parallel)
2 changes: 1 addition & 1 deletion tests/planai/test_joined_multiple_task.py
@@ -152,7 +152,7 @@ def test_complex_joined_task_workflow(self):
dispatch_thread = threading.Thread(target=self.dispatcher.dispatch)
dispatch_thread.start()

-        self.dispatcher.trace((("InitialTaskWorker", 1),))
+        self.graph.trace((("InitialTaskWorker", 1),))

# Add initial work
self.dispatcher.add_multiple_work(initial_work)
