diff --git a/src/planai/dispatcher.py b/src/planai/dispatcher.py index dc572f5..a62c4a4 100644 --- a/src/planai/dispatcher.py +++ b/src/planai/dispatcher.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import sys import threading import time from collections import defaultdict, deque @@ -41,11 +42,16 @@ ProvenanceChain = Tuple[Tuple[TaskName, TaskID], ...] +class NotificationTask(Task): + pass + + class Dispatcher: def __init__(self, graph: "Graph", web_port=5000): self.graph = graph self.work_queue = Queue() self.provenance: DefaultDict[ProvenanceChain, int] = defaultdict(int) + self.provenance_trace: Dict[ProvenanceChain, dict] = {} self.notifiers: DefaultDict[ProvenanceChain, List[TaskWorker]] = defaultdict( list ) @@ -76,18 +82,63 @@ def _add_provenance(self, task: Task): for prefix in self._generate_prefixes(task): with self.provenance_lock: self.provenance[prefix] = self.provenance.get(prefix, 0) + 1 + if prefix in self.provenance_trace: + trace_entry = { + "worker": task._provenance[-1][0], + "action": "adding", + "task": task.name, + "count": self.provenance[prefix], + "status": "", + } + self.provenance_trace[prefix].append(trace_entry) def _remove_provenance(self, task: Task): to_notify = set() for prefix in self._generate_prefixes(task): with self.provenance_lock: self.provenance[prefix] -= 1 - if self.provenance[prefix] == 0: + effective_count = self.provenance[prefix] + + if effective_count < 0: + error_message = f"FATAL ERROR: Provenance count for prefix {prefix} became negative ({effective_count}). This indicates a serious bug in the provenance tracking system." + logging.critical(error_message) + print(error_message, file=sys.stderr) + sys.exit(1) + + if effective_count == 0: del self.provenance[prefix] to_notify.add(prefix) + if prefix in self.provenance_trace: + # Get the list of notifiers for this prefix + with self.notifiers_lock: + notifiers = [n.name for n in self.notifiers.get(prefix, [])] + + status = ( + "will notify watchers" + if effective_count == 0 + else "still waiting for other tasks" + ) + if effective_count == 0 and notifiers: + status += f" (Notifying: {', '.join(notifiers)})" + + trace_entry = { + "worker": task._provenance[-1][0], + "action": "removing", + "task": task.name, + "count": effective_count, + "status": status, + } + self.provenance_trace[prefix].append(trace_entry) + for prefix in to_notify: - self._notify_task_completion(prefix, task) + self._notify_task_completion(prefix) + + def trace(self, prefix: ProvenanceChain): + logging.debug(f"Starting trace for {prefix}") + with self.provenance_lock: + if prefix not in self.provenance_trace: + self.provenance_trace[prefix] = [] def watch( self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None @@ -155,7 +206,7 @@ def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool: return True return False - def _notify_task_completion(self, prefix: tuple, task: Task): + def _notify_task_completion(self, prefix: tuple): to_notify = [] with self.notifiers_lock: for notifier in self.notifiers[prefix]: @@ -166,8 +217,8 @@ def _notify_task_completion(self, prefix: tuple, task: Task): self.active_tasks += 1 # Use a named function instead of a lambda to avoid closure issues - def task_completed_callback(future, worker=notifier, task=task): - self._task_completed(worker, task, future) + def task_completed_callback(future, worker=notifier): + self._task_completed(worker, None, future) future = self.graph._thread_pool.submit(notifier.notify, prefix) future.add_done_callback(task_completed_callback) @@ -235,6 +286,10 @@ def _task_to_dict(self, worker: TaskWorker, task: Task, error: str = "") -> Dict data["error"] = error return data + def get_traces(self) -> Dict: + with self.provenance_lock: + return self.provenance_trace + def get_queued_tasks(self) -> List[Dict]: return [ self._task_to_dict(worker, task) for worker, task in self.work_queue.queue @@ -269,7 +324,7 @@ def _get_task_id(self, task: Task) -> str: # Fallback in case _provenance is empty return f"unknown_{id(task)}" - def _task_completed(self, worker: TaskWorker, task: Task, future): + def _task_completed(self, worker: TaskWorker, task: Optional[Task], future): success: bool = False error_message: str = "" try: @@ -277,45 +332,58 @@ def _task_completed(self, worker: TaskWorker, task: Task, future): _ = future.result() # Handle successful task completion - logging.info(f"Task {task.name} completed successfully") + if task: + logging.info(f"Task {task.name} completed successfully") + else: + logging.info( + f"Notification for worker {worker.name} completed successfully" + ) success = True except Exception as e: # Handle task failure error_message = str(e) - logging.exception(f"Task {task.name} failed with exception: {str(e)}") + if task: + logging.exception(f"Task {task.name} failed with exception: {str(e)}") + else: + logging.exception( + f"Notification for worker {worker.name} failed with exception: {str(e)}" + ) # Anything else that needs to be done when a task fails? finally: # This code will run whether the task succeeded or failed - - # Determine whether we should retry the task - if not success: - if worker.num_retries > 0: - if task.retry_count < worker.num_retries: - task.increment_retry_count() - with self.task_lock: - self.active_tasks -= 1 - self.work_queue.put((worker, task)) - logging.info( - f"Retrying task {task.name} for the {task.retry_count} time" - ) - return - - with self.task_lock: - self.failed_tasks.appendleft((worker, task, error_message)) - self.total_failed_tasks += 1 - logging.error( - f"Task {task.name} failed after {task.retry_count} retries" - ) - # we'll fall through and do the clean up - else: + if task is not None: # Only handle retry and provenance for actual tasks + # Determine whether we should retry the task + if not success: + if worker.num_retries > 0: + if task.retry_count < worker.num_retries: + task.increment_retry_count() + with self.task_lock: + self.active_tasks -= 1 + self.work_queue.put((worker, task)) + logging.info( + f"Retrying task {task.name} for the {task.retry_count} time" + ) + return + + with self.task_lock: + self.failed_tasks.appendleft((worker, task, error_message)) + self.total_failed_tasks += 1 + logging.error( + f"Task {task.name} failed after {task.retry_count} retries" + ) + # we'll fall through and do the clean up + self._remove_provenance(task) + + if success: with self.task_lock: - self.completed_tasks.appendleft((worker, task)) + self.completed_tasks.appendleft( + (worker, task if task else NotificationTask()) + ) self.total_completed_tasks += 1 - self._remove_provenance(task) with self.task_lock: self.active_tasks -= 1 if self.active_tasks == 0 and self.work_queue.empty(): diff --git a/src/planai/task.py b/src/planai/task.py index b3a3f42..d34302f 100644 --- a/src/planai/task.py +++ b/src/planai/task.py @@ -225,6 +225,21 @@ def next(self, downstream: "TaskWorker"): self._graph.set_dependency(self, downstream) return downstream + def trace(self, prefix: "ProvenanceChain"): + """ + Traces the provenance chain for a given prefix in the graph. + + This method sets up a trace on a given prefix in the provenance chain. It will be visible + in the dispatcher dashboard. + + Parameters: + ----------- + prefix : ProvenanceChain + The prefix to trace. Must be a tuple representing a part of a task's provenance chain. + This is the sequence of task identifiers leading up to (but not including) the current task. + """ + self._graph._dispatcher.trace(prefix) + def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool: """ Watches for the completion of a specific provenance chain prefix in the task graph. diff --git a/src/planai/templates/index.html b/src/planai/templates/index.html index 9b9e9a8..eb709d8 100644 --- a/src/planai/templates/index.html +++ b/src/planai/templates/index.html @@ -102,6 +102,30 @@ .nested-object { margin-left: 20px; } + + .trace-list { + background-color: var(--task-list-bg); + border: 1px solid var(--border-color); + border-radius: 4px; + padding: 10px; + margin-bottom: 20px; + } + + .trace-prefix { + background-color: var(--task-item-bg); + border: 1px solid var(--border-color); + border-radius: 4px; + padding: 10px; + margin-bottom: 10px; + } + + .trace-prefix h4 { + margin-top: 0; + } + + .trace-prefix ul { + padding-left: 20px; + } @@ -111,6 +135,9 @@

PlanAI Dispatcher Dashboard

+

Provenance Trace

+
+

Queued Tasks

@@ -201,7 +228,6 @@

Input Provenance:

}); } - function renderObject(obj, indent = '') { return Object.entries(obj).map(([key, value]) => { if (value === null) return `${indent}${key}: null`; @@ -224,7 +250,65 @@

Input Provenance:

} } + // Trace visualization + let traceEventSource; + function setupTraceEventSource() { + traceEventSource = new EventSource("/trace_stream"); + + traceEventSource.onmessage = function (event) { + const traceData = JSON.parse(event.data); + updateTraceVisualization(traceData); + }; + + traceEventSource.onerror = function (error) { + traceEventSource.close(); + setTimeout(setupTraceEventSource, 1000); + }; + } + + setupTraceEventSource(); + + function updateTraceVisualization(traceData) { + const traceElement = document.getElementById('trace-visualization'); + traceElement.innerHTML = ''; + + for (const [prefixStr, entries] of Object.entries(traceData)) { + const prefixElement = document.createElement('div'); + prefixElement.className = 'trace-prefix'; + + const prefix = prefixStr.split('_').join(', '); + + prefixElement.innerHTML = `

Prefix: (${prefix})

`; + + const table = document.createElement('table'); + table.className = 'trace-table'; + + // Create table header + const headerRow = table.insertRow(); + ['Worker', 'Action', 'Task', 'Count', 'Status'].forEach(header => { + const th = document.createElement('th'); + th.textContent = header; + headerRow.appendChild(th); + }); + + // Populate table with entry data + entries.forEach(entry => { + const row = table.insertRow(); + row.className = entry.action + (entry.count === 0 ? ' zero-count' : ''); + ['worker', 'action', 'task', 'count', 'status'].forEach(key => { + const cell = row.insertCell(); + cell.textContent = entry[key]; + }); + }); + + prefixElement.appendChild(table); + traceElement.appendChild(prefixElement); + } + } + + + // Quit button functionality document.getElementById('quit-button').addEventListener('click', function () { fetch('/quit', { method: 'POST' }) .then(response => response.json()) @@ -234,6 +318,7 @@

Input Provenance:

} }); }); + // Theme toggle functionality const themeToggle = document.getElementById('theme-toggle'); const prefersDarkScheme = window.matchMedia("(prefers-color-scheme: dark)"); diff --git a/src/planai/web_interface.py b/src/planai/web_interface.py index 033b2bf..67075e4 100644 --- a/src/planai/web_interface.py +++ b/src/planai/web_interface.py @@ -41,7 +41,7 @@ def event_stream(): if current_data != last_data: yield f"data: {json.dumps(current_data)}\n\n" last_data = current_data - time.sleep(1) + time.sleep(0.2) def get_current_data(): queued_tasks = dispatcher.get_queued_tasks() @@ -61,6 +61,24 @@ def get_current_data(): return Response(event_stream(), mimetype="text/event-stream") +@app.route("/trace_stream") +def trace_stream(): + def event_stream(): + last_trace = None + while True: + current_trace = dispatcher.get_traces() + if current_trace != last_trace: + # Convert tuple keys to strings + serializable_trace = { + "_".join(map(str, k)): v for k, v in current_trace.items() + } + yield f"data: {json.dumps(serializable_trace)}\n\n" + last_trace = current_trace + time.sleep(0.2) + + return Response(event_stream(), mimetype="text/event-stream") + + @app.route("/quit", methods=["POST"]) def quit(): global quit_event diff --git a/tests/planai/test_dispatcher.py b/tests/planai/test_dispatcher.py index ee54b9c..8d12f82 100644 --- a/tests/planai/test_dispatcher.py +++ b/tests/planai/test_dispatcher.py @@ -127,14 +127,13 @@ def test_remove_provenance(self): with patch.object(self.dispatcher, "_notify_task_completion") as mock_notify: self.dispatcher._remove_provenance(task) self.assertEqual(self.dispatcher.provenance, {}) - mock_notify.assert_any_call((("Task1", 1),), task) - mock_notify.assert_any_call((("Task1", 1), ("Task2", 2)), task) + mock_notify.assert_any_call((("Task1", 1),)) + mock_notify.assert_any_call((("Task1", 1), ("Task2", 2))) def test_notify_task_completion(self): notifier = Mock(spec=TaskWorker) - task = DummyTask(data="test") self.dispatcher.notifiers = {(("Task1", 1),): [notifier]} - self.dispatcher._notify_task_completion((("Task1", 1),), task) + self.dispatcher._notify_task_completion((("Task1", 1),)) self.assertEqual(self.dispatcher.active_tasks, 1) notifier.notify.assert_called_once_with((("Task1", 1),))