diff --git a/src/planai/dispatcher.py b/src/planai/dispatcher.py index dc572f5..a62c4a4 100644 --- a/src/planai/dispatcher.py +++ b/src/planai/dispatcher.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import sys import threading import time from collections import defaultdict, deque @@ -41,11 +42,16 @@ ProvenanceChain = Tuple[Tuple[TaskName, TaskID], ...] +class NotificationTask(Task): + pass + + class Dispatcher: def __init__(self, graph: "Graph", web_port=5000): self.graph = graph self.work_queue = Queue() self.provenance: DefaultDict[ProvenanceChain, int] = defaultdict(int) + self.provenance_trace: Dict[ProvenanceChain, dict] = {} self.notifiers: DefaultDict[ProvenanceChain, List[TaskWorker]] = defaultdict( list ) @@ -76,18 +82,63 @@ def _add_provenance(self, task: Task): for prefix in self._generate_prefixes(task): with self.provenance_lock: self.provenance[prefix] = self.provenance.get(prefix, 0) + 1 + if prefix in self.provenance_trace: + trace_entry = { + "worker": task._provenance[-1][0], + "action": "adding", + "task": task.name, + "count": self.provenance[prefix], + "status": "", + } + self.provenance_trace[prefix].append(trace_entry) def _remove_provenance(self, task: Task): to_notify = set() for prefix in self._generate_prefixes(task): with self.provenance_lock: self.provenance[prefix] -= 1 - if self.provenance[prefix] == 0: + effective_count = self.provenance[prefix] + + if effective_count < 0: + error_message = f"FATAL ERROR: Provenance count for prefix {prefix} became negative ({effective_count}). This indicates a serious bug in the provenance tracking system." + logging.critical(error_message) + print(error_message, file=sys.stderr) + sys.exit(1) + + if effective_count == 0: del self.provenance[prefix] to_notify.add(prefix) + if prefix in self.provenance_trace: + # Get the list of notifiers for this prefix + with self.notifiers_lock: + notifiers = [n.name for n in self.notifiers.get(prefix, [])] + + status = ( + "will notify watchers" + if effective_count == 0 + else "still waiting for other tasks" + ) + if effective_count == 0 and notifiers: + status += f" (Notifying: {', '.join(notifiers)})" + + trace_entry = { + "worker": task._provenance[-1][0], + "action": "removing", + "task": task.name, + "count": effective_count, + "status": status, + } + self.provenance_trace[prefix].append(trace_entry) + for prefix in to_notify: - self._notify_task_completion(prefix, task) + self._notify_task_completion(prefix) + + def trace(self, prefix: ProvenanceChain): + logging.debug(f"Starting trace for {prefix}") + with self.provenance_lock: + if prefix not in self.provenance_trace: + self.provenance_trace[prefix] = [] def watch( self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None @@ -155,7 +206,7 @@ def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool: return True return False - def _notify_task_completion(self, prefix: tuple, task: Task): + def _notify_task_completion(self, prefix: tuple): to_notify = [] with self.notifiers_lock: for notifier in self.notifiers[prefix]: @@ -166,8 +217,8 @@ def _notify_task_completion(self, prefix: tuple, task: Task): self.active_tasks += 1 # Use a named function instead of a lambda to avoid closure issues - def task_completed_callback(future, worker=notifier, task=task): - self._task_completed(worker, task, future) + def task_completed_callback(future, worker=notifier): + self._task_completed(worker, None, future) future = self.graph._thread_pool.submit(notifier.notify, prefix) future.add_done_callback(task_completed_callback) @@ -235,6 +286,10 @@ def _task_to_dict(self, worker: TaskWorker, task: Task, error: str = "") -> Dict data["error"] = error return data + def get_traces(self) -> Dict: + with self.provenance_lock: + return self.provenance_trace + def get_queued_tasks(self) -> List[Dict]: return [ self._task_to_dict(worker, task) for worker, task in self.work_queue.queue @@ -269,7 +324,7 @@ def _get_task_id(self, task: Task) -> str: # Fallback in case _provenance is empty return f"unknown_{id(task)}" - def _task_completed(self, worker: TaskWorker, task: Task, future): + def _task_completed(self, worker: TaskWorker, task: Optional[Task], future): success: bool = False error_message: str = "" try: @@ -277,45 +332,58 @@ def _task_completed(self, worker: TaskWorker, task: Task, future): _ = future.result() # Handle successful task completion - logging.info(f"Task {task.name} completed successfully") + if task: + logging.info(f"Task {task.name} completed successfully") + else: + logging.info( + f"Notification for worker {worker.name} completed successfully" + ) success = True except Exception as e: # Handle task failure error_message = str(e) - logging.exception(f"Task {task.name} failed with exception: {str(e)}") + if task: + logging.exception(f"Task {task.name} failed with exception: {str(e)}") + else: + logging.exception( + f"Notification for worker {worker.name} failed with exception: {str(e)}" + ) # Anything else that needs to be done when a task fails? finally: # This code will run whether the task succeeded or failed - - # Determine whether we should retry the task - if not success: - if worker.num_retries > 0: - if task.retry_count < worker.num_retries: - task.increment_retry_count() - with self.task_lock: - self.active_tasks -= 1 - self.work_queue.put((worker, task)) - logging.info( - f"Retrying task {task.name} for the {task.retry_count} time" - ) - return - - with self.task_lock: - self.failed_tasks.appendleft((worker, task, error_message)) - self.total_failed_tasks += 1 - logging.error( - f"Task {task.name} failed after {task.retry_count} retries" - ) - # we'll fall through and do the clean up - else: + if task is not None: # Only handle retry and provenance for actual tasks + # Determine whether we should retry the task + if not success: + if worker.num_retries > 0: + if task.retry_count < worker.num_retries: + task.increment_retry_count() + with self.task_lock: + self.active_tasks -= 1 + self.work_queue.put((worker, task)) + logging.info( + f"Retrying task {task.name} for the {task.retry_count} time" + ) + return + + with self.task_lock: + self.failed_tasks.appendleft((worker, task, error_message)) + self.total_failed_tasks += 1 + logging.error( + f"Task {task.name} failed after {task.retry_count} retries" + ) + # we'll fall through and do the clean up + self._remove_provenance(task) + + if success: with self.task_lock: - self.completed_tasks.appendleft((worker, task)) + self.completed_tasks.appendleft( + (worker, task if task else NotificationTask()) + ) self.total_completed_tasks += 1 - self._remove_provenance(task) with self.task_lock: self.active_tasks -= 1 if self.active_tasks == 0 and self.work_queue.empty(): diff --git a/src/planai/task.py b/src/planai/task.py index b3a3f42..d34302f 100644 --- a/src/planai/task.py +++ b/src/planai/task.py @@ -225,6 +225,21 @@ def next(self, downstream: "TaskWorker"): self._graph.set_dependency(self, downstream) return downstream + def trace(self, prefix: "ProvenanceChain"): + """ + Traces the provenance chain for a given prefix in the graph. + + This method sets up a trace on a given prefix in the provenance chain. It will be visible + in the dispatcher dashboard. + + Parameters: + ----------- + prefix : ProvenanceChain + The prefix to trace. Must be a tuple representing a part of a task's provenance chain. + This is the sequence of task identifiers leading up to (but not including) the current task. + """ + self._graph._dispatcher.trace(prefix) + def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool: """ Watches for the completion of a specific provenance chain prefix in the task graph. diff --git a/src/planai/templates/index.html b/src/planai/templates/index.html index 9b9e9a8..eb709d8 100644 --- a/src/planai/templates/index.html +++ b/src/planai/templates/index.html @@ -102,6 +102,30 @@ .nested-object { margin-left: 20px; } + + .trace-list { + background-color: var(--task-list-bg); + border: 1px solid var(--border-color); + border-radius: 4px; + padding: 10px; + margin-bottom: 20px; + } + + .trace-prefix { + background-color: var(--task-item-bg); + border: 1px solid var(--border-color); + border-radius: 4px; + padding: 10px; + margin-bottom: 10px; + } + + .trace-prefix h4 { + margin-top: 0; + } + + .trace-prefix ul { + padding-left: 20px; + } @@ -111,6 +135,9 @@