Skip to content

Commit

Permalink
fix: provenance corruption
Browse files Browse the repository at this point in the history
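Because Python passes objects by reference, a task object shared between the dispatcher and its consumers could be mutated after it was queued, which corrupted the provenance tracking. The dispatcher now hands consumers a copy of the task instead.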
  • Loading branch information
provos committed Sep 4, 2024
1 parent f8dca3e commit e3bd7bf
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def _add_provenance(self, task: Task):
for prefix in self._generate_prefixes(task):
with self.provenance_lock:
self.provenance[prefix] = self.provenance.get(prefix, 0) + 1
logging.debug(
"+Provenance for %s is now %s", prefix, self.provenance[prefix]
)
if prefix in self.provenance_trace:
trace_entry = {
"worker": task._provenance[-1][0],
Expand All @@ -91,20 +94,20 @@ def _add_provenance(self, task: Task):
"status": "",
}
self.provenance_trace[prefix].append(trace_entry)
logging.info(
"Tracing: add provenance for %s: %s", prefix, trace_entry
)

def _remove_provenance(self, task: Task):
to_notify = set()
for prefix in self._generate_prefixes(task):
with self.provenance_lock:
self.provenance[prefix] -= 1
effective_count = self.provenance[prefix]

if effective_count < 0:
error_message = f"FATAL ERROR: Provenance count for prefix {prefix} became negative ({effective_count}). This indicates a serious bug in the provenance tracking system."
logging.critical(error_message)
print(error_message, file=sys.stderr)
sys.exit(1)
logging.debug(
"-Provenance for %s is now %s", prefix, self.provenance[prefix]
)

effective_count = self.provenance[prefix]
if effective_count == 0:
del self.provenance[prefix]
to_notify.add(prefix)
Expand All @@ -130,12 +133,21 @@ def _remove_provenance(self, task: Task):
"status": status,
}
self.provenance_trace[prefix].append(trace_entry)
logging.info(
"Tracing: remove provenance for %s: %s", prefix, trace_entry
)

if effective_count < 0:
error_message = f"FATAL ERROR: Provenance count for prefix {prefix} became negative ({effective_count}). This indicates a serious bug in the provenance tracking system."
logging.critical(error_message)
print(error_message, file=sys.stderr)
sys.exit(1)

for prefix in to_notify:
self._notify_task_completion(prefix)

def trace(self, prefix: ProvenanceChain):
logging.debug(f"Starting trace for {prefix}")
logging.info(f"Starting trace for {prefix}")
with self.provenance_lock:
if prefix not in self.provenance_trace:
self.provenance_trace[prefix] = []
Expand Down Expand Up @@ -191,7 +203,7 @@ def watch(
should_notify = True

if should_notify:
self._notify_task_completion(prefix, task)
self._notify_task_completion(prefix)

return added

Expand All @@ -213,6 +225,7 @@ def _notify_task_completion(self, prefix: tuple):
to_notify.append((notifier, prefix))

for notifier, prefix in to_notify:
logging.info(f"Notifying {notifier.name} that prefix {prefix} is complete")
with self.task_lock:
self.active_tasks += 1

Expand Down Expand Up @@ -258,7 +271,9 @@ def _execute_task(self, worker: TaskWorker, task: Task):
self.debug_active_tasks[task_id] = (worker, task)

try:
worker._pre_consume_work(task)
# since we are storing a lot of references to the task, we need to make sure
# that we are not storing the same task object in multiple places
worker._pre_consume_work(task.copy())
except Exception:
raise # Re-raise the caught exception
finally:
Expand Down Expand Up @@ -354,9 +369,9 @@ def _task_completed(self, worker: TaskWorker, task: Optional[Task], future):

finally:
# This code will run whether the task succeeded or failed
if task is not None: # Only handle retry and provenance for actual tasks
# Determine whether we should retry the task
if not success:
if not success:
if task is not None:
# Determine whether we should retry the task
if worker.num_retries > 0:
if task.retry_count < worker.num_retries:
task.increment_retry_count()
Expand All @@ -368,30 +383,35 @@ def _task_completed(self, worker: TaskWorker, task: Optional[Task], future):
)
return

with self.task_lock:
self.failed_tasks.appendleft((worker, task, error_message))
self.total_failed_tasks += 1
with self.task_lock:
self.failed_tasks.appendleft(
(worker, task if task else NotificationTask(), error_message)
)
self.total_failed_tasks += 1

if task:
logging.error(
f"Task {task.name} failed after {task.retry_count} retries"
)
# we'll fall through and do the clean up
self._remove_provenance(task)

if success:
else:
with self.task_lock:
self.completed_tasks.appendleft(
(worker, task if task else NotificationTask())
)
self.total_completed_tasks += 1

if task:
self._remove_provenance(task)

with self.task_lock:
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
self.task_completion_event.set()

def add_work(self, worker: TaskWorker, task: Task):
self._add_provenance(task)
self.work_queue.put((worker, task))
task_copy = task.copy()
self._add_provenance(task_copy)
self.work_queue.put((worker, task_copy))

def add_multiple_work(self, work_items: List[Tuple[TaskWorker, Task]]):
# the ordering of adding provenance first is important for join tasks to
Expand Down

0 comments on commit e3bd7bf

Please sign in to comment.