Skip to content

Commit

Permalink
fix: notification tasks would remove provenance and cause joiners to …
Browse files Browse the repository at this point in the history
…be called prematurely

feat: introduce a tracing mechanism for provenance in the dashboard
  • Loading branch information
provos committed Sep 3, 2024
1 parent 9d9ef9c commit f85a189
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 38 deletions.
132 changes: 100 additions & 32 deletions src/planai/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import sys
import threading
import time
from collections import defaultdict, deque
Expand Down Expand Up @@ -41,11 +42,16 @@
ProvenanceChain = Tuple[Tuple[TaskName, TaskID], ...]


class NotificationTask(Task):
    """Sentinel task recorded when a watcher notification completes.

    Notification callbacks have no real Task associated with them, so an
    instance of this class is stored in the completed-tasks list in place
    of a regular task (see the `task if task else NotificationTask()`
    branch in `_task_completed`).
    """

    pass


class Dispatcher:
def __init__(self, graph: "Graph", web_port=5000):
self.graph = graph
self.work_queue = Queue()
self.provenance: DefaultDict[ProvenanceChain, int] = defaultdict(int)
self.provenance_trace: Dict[ProvenanceChain, dict] = {}
self.notifiers: DefaultDict[ProvenanceChain, List[TaskWorker]] = defaultdict(
list
)
Expand Down Expand Up @@ -76,18 +82,63 @@ def _add_provenance(self, task: Task):
for prefix in self._generate_prefixes(task):
with self.provenance_lock:
self.provenance[prefix] = self.provenance.get(prefix, 0) + 1
if prefix in self.provenance_trace:
trace_entry = {
"worker": task._provenance[-1][0],
"action": "adding",
"task": task.name,
"count": self.provenance[prefix],
"status": "",
}
self.provenance_trace[prefix].append(trace_entry)

def _remove_provenance(self, task: Task):
to_notify = set()
for prefix in self._generate_prefixes(task):
with self.provenance_lock:
self.provenance[prefix] -= 1
if self.provenance[prefix] == 0:
effective_count = self.provenance[prefix]

if effective_count < 0:
error_message = f"FATAL ERROR: Provenance count for prefix {prefix} became negative ({effective_count}). This indicates a serious bug in the provenance tracking system."
logging.critical(error_message)
print(error_message, file=sys.stderr)
sys.exit(1)

if effective_count == 0:
del self.provenance[prefix]
to_notify.add(prefix)

if prefix in self.provenance_trace:
# Get the list of notifiers for this prefix
with self.notifiers_lock:
notifiers = [n.name for n in self.notifiers.get(prefix, [])]

status = (
"will notify watchers"
if effective_count == 0
else "still waiting for other tasks"
)
if effective_count == 0 and notifiers:
status += f" (Notifying: {', '.join(notifiers)})"

trace_entry = {
"worker": task._provenance[-1][0],
"action": "removing",
"task": task.name,
"count": effective_count,
"status": status,
}
self.provenance_trace[prefix].append(trace_entry)

for prefix in to_notify:
self._notify_task_completion(prefix, task)
self._notify_task_completion(prefix)

def trace(self, prefix: ProvenanceChain):
    """Enable provenance tracing for *prefix*.

    Registers an (initially empty) trace-entry list for the prefix so that
    subsequent provenance add/remove events are recorded and become visible
    in the dispatcher dashboard. Calling this repeatedly for the same
    prefix is a no-op and preserves any entries already collected.
    """
    logging.debug(f"Starting trace for {prefix}")
    with self.provenance_lock:
        # setdefault keeps an existing entry list intact while creating
        # one for prefixes seen for the first time.
        self.provenance_trace.setdefault(prefix, [])

def watch(
self, prefix: ProvenanceChain, notifier: TaskWorker, task: Optional[Task] = None
Expand Down Expand Up @@ -155,7 +206,7 @@ def unwatch(self, prefix: ProvenanceChain, notifier: TaskWorker) -> bool:
return True
return False

def _notify_task_completion(self, prefix: tuple, task: Task):
def _notify_task_completion(self, prefix: tuple):
to_notify = []
with self.notifiers_lock:
for notifier in self.notifiers[prefix]:
Expand All @@ -166,8 +217,8 @@ def _notify_task_completion(self, prefix: tuple, task: Task):
self.active_tasks += 1

# Use a named function instead of a lambda to avoid closure issues
def task_completed_callback(future, worker=notifier, task=task):
self._task_completed(worker, task, future)
def task_completed_callback(future, worker=notifier):
self._task_completed(worker, None, future)

future = self.graph._thread_pool.submit(notifier.notify, prefix)
future.add_done_callback(task_completed_callback)
Expand Down Expand Up @@ -235,6 +286,10 @@ def _task_to_dict(self, worker: TaskWorker, task: Task, error: str = "") -> Dict
data["error"] = error
return data

def get_traces(self) -> Dict:
    """Return a snapshot of the collected provenance traces.

    Returns a copy of ``provenance_trace`` (with each entry list shallow-
    copied) rather than the live dict. Returning the live dict would defeat
    the lock — callers would read it unsynchronized after the lock was
    released — and, worse, callers that cache the result and compare it to
    a later call would be comparing the same object to itself and never
    observe changes.

    Returns:
        Dict mapping each traced ProvenanceChain prefix to its list of
        trace-entry dicts.
    """
    with self.provenance_lock:
        # Entry dicts are appended but not mutated afterwards, so copying
        # the per-prefix lists is sufficient for a consistent snapshot.
        return {
            prefix: entries[:]
            for prefix, entries in self.provenance_trace.items()
        }

def get_queued_tasks(self) -> List[Dict]:
return [
self._task_to_dict(worker, task) for worker, task in self.work_queue.queue
Expand Down Expand Up @@ -269,53 +324,66 @@ def _get_task_id(self, task: Task) -> str:
# Fallback in case _provenance is empty
return f"unknown_{id(task)}"

def _task_completed(self, worker: TaskWorker, task: Task, future):
def _task_completed(self, worker: TaskWorker, task: Optional[Task], future):
success: bool = False
error_message: str = ""
try:
# This will raise an exception if the task failed
_ = future.result()

# Handle successful task completion
logging.info(f"Task {task.name} completed successfully")
if task:
logging.info(f"Task {task.name} completed successfully")
else:
logging.info(
f"Notification for worker {worker.name} completed successfully"
)
success = True

except Exception as e:
# Handle task failure
error_message = str(e)
logging.exception(f"Task {task.name} failed with exception: {str(e)}")
if task:
logging.exception(f"Task {task.name} failed with exception: {str(e)}")
else:
logging.exception(
f"Notification for worker {worker.name} failed with exception: {str(e)}"
)

# Anything else that needs to be done when a task fails?

finally:
# This code will run whether the task succeeded or failed

# Determine whether we should retry the task
if not success:
if worker.num_retries > 0:
if task.retry_count < worker.num_retries:
task.increment_retry_count()
with self.task_lock:
self.active_tasks -= 1
self.work_queue.put((worker, task))
logging.info(
f"Retrying task {task.name} for the {task.retry_count} time"
)
return

with self.task_lock:
self.failed_tasks.appendleft((worker, task, error_message))
self.total_failed_tasks += 1
logging.error(
f"Task {task.name} failed after {task.retry_count} retries"
)
# we'll fall through and do the clean up
else:
if task is not None: # Only handle retry and provenance for actual tasks
# Determine whether we should retry the task
if not success:
if worker.num_retries > 0:
if task.retry_count < worker.num_retries:
task.increment_retry_count()
with self.task_lock:
self.active_tasks -= 1
self.work_queue.put((worker, task))
logging.info(
f"Retrying task {task.name} for the {task.retry_count} time"
)
return

with self.task_lock:
self.failed_tasks.appendleft((worker, task, error_message))
self.total_failed_tasks += 1
logging.error(
f"Task {task.name} failed after {task.retry_count} retries"
)
# we'll fall through and do the clean up
self._remove_provenance(task)

if success:
with self.task_lock:
self.completed_tasks.appendleft((worker, task))
self.completed_tasks.appendleft(
(worker, task if task else NotificationTask())
)
self.total_completed_tasks += 1

self._remove_provenance(task)
with self.task_lock:
self.active_tasks -= 1
if self.active_tasks == 0 and self.work_queue.empty():
Expand Down
15 changes: 15 additions & 0 deletions src/planai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,21 @@ def next(self, downstream: "TaskWorker"):
self._graph.set_dependency(self, downstream)
return downstream

def trace(self, prefix: "ProvenanceChain"):
    """
    Trace the provenance chain for a given prefix in the graph.

    Sets up a trace on the given prefix: provenance events matching it will
    be recorded by the dispatcher and shown in the dispatcher dashboard.

    Parameters:
    -----------
    prefix : ProvenanceChain
        The prefix to trace. Must be a tuple representing a part of a task's
        provenance chain. This is the sequence of task identifiers leading up
        to (but not including) the current task.
    """
    # Delegate to the graph's dispatcher, which owns the trace storage.
    self._graph._dispatcher.trace(prefix)

def watch(self, prefix: "ProvenanceChain", task: Optional[Task] = None) -> bool:
"""
Watches for the completion of a specific provenance chain prefix in the task graph.
Expand Down
87 changes: 86 additions & 1 deletion src/planai/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,30 @@
.nested-object {
margin-left: 20px;
}

.trace-list {
background-color: var(--task-list-bg);
border: 1px solid var(--border-color);
border-radius: 4px;
padding: 10px;
margin-bottom: 20px;
}

.trace-prefix {
background-color: var(--task-item-bg);
border: 1px solid var(--border-color);
border-radius: 4px;
padding: 10px;
margin-bottom: 10px;
}

.trace-prefix h4 {
margin-top: 0;
}

.trace-prefix ul {
padding-left: 20px;
}
</style>
</head>

Expand All @@ -111,6 +135,9 @@
<h1>PlanAI Dispatcher Dashboard</h1>
<button id="quit-button">Quit</button>

<h2>Provenance Trace</h2>
<div id="trace-visualization" class="trace-list"></div>

<h2>Queued Tasks</h2>
<div id="queued-tasks" class="task-list"></div>

Expand Down Expand Up @@ -201,7 +228,6 @@ <h4>Input Provenance:</h4>
});
}


function renderObject(obj, indent = '') {
return Object.entries(obj).map(([key, value]) => {
if (value === null) return `${indent}<strong>${key}:</strong> null`;
Expand All @@ -224,7 +250,65 @@ <h4>Input Provenance:</h4>
}
}

// Trace visualization
let traceEventSource;

// Open (or re-open) the server-sent-events connection that streams
// provenance-trace updates from the /trace_stream endpoint.
function setupTraceEventSource() {
    traceEventSource = new EventSource("/trace_stream");

    // Each SSE message carries the complete trace state as a JSON object;
    // re-render the whole visualization from it.
    traceEventSource.onmessage = function (event) {
        const traceData = JSON.parse(event.data);
        updateTraceVisualization(traceData);
    };

    // On any connection error, close the source and retry after one
    // second rather than relying on the browser's built-in reconnect.
    traceEventSource.onerror = function (error) {
        traceEventSource.close();
        setTimeout(setupTraceEventSource, 1000);
    };
}

setupTraceEventSource();

// Rebuild the #trace-visualization panel from scratch.
// traceData maps a stringified prefix (tuple elements joined with "_" by
// the server) to an array of trace-entry objects
// ({worker, action, task, count, status}).
function updateTraceVisualization(traceData) {
    const traceElement = document.getElementById('trace-visualization');
    // Full re-render: drop the previous DOM and recreate it.
    traceElement.innerHTML = '';

    for (const [prefixStr, entries] of Object.entries(traceData)) {
        const prefixElement = document.createElement('div');
        prefixElement.className = 'trace-prefix';

        // Undo the server-side "_" join for display as a tuple-like list.
        const prefix = prefixStr.split('_').join(', ');

        prefixElement.innerHTML = `<h4>Prefix: (${prefix})</h4>`;

        const table = document.createElement('table');
        table.className = 'trace-table';

        // Create table header
        const headerRow = table.insertRow();
        ['Worker', 'Action', 'Task', 'Count', 'Status'].forEach(header => {
            const th = document.createElement('th');
            th.textContent = header;
            headerRow.appendChild(th);
        });

        // Populate table with entry data; the row class encodes the
        // action ("adding"/"removing") plus a zero-count marker for styling.
        entries.forEach(entry => {
            const row = table.insertRow();
            row.className = entry.action + (entry.count === 0 ? ' zero-count' : '');
            ['worker', 'action', 'task', 'count', 'status'].forEach(key => {
                const cell = row.insertCell();
                cell.textContent = entry[key];
            });
        });

        prefixElement.appendChild(table);
        traceElement.appendChild(prefixElement);
    }
}


// Quit button functionality
document.getElementById('quit-button').addEventListener('click', function () {
fetch('/quit', { method: 'POST' })
.then(response => response.json())
Expand All @@ -234,6 +318,7 @@ <h4>Input Provenance:</h4>
}
});
});

// Theme toggle functionality
const themeToggle = document.getElementById('theme-toggle');
const prefersDarkScheme = window.matchMedia("(prefers-color-scheme: dark)");
Expand Down
20 changes: 19 additions & 1 deletion src/planai/web_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def event_stream():
if current_data != last_data:
yield f"data: {json.dumps(current_data)}\n\n"
last_data = current_data
time.sleep(1)
time.sleep(0.2)

def get_current_data():
queued_tasks = dispatcher.get_queued_tasks()
Expand All @@ -61,6 +61,24 @@ def get_current_data():
return Response(event_stream(), mimetype="text/event-stream")


@app.route("/trace_stream")
def trace_stream():
def event_stream():
last_trace = None
while True:
current_trace = dispatcher.get_traces()
if current_trace != last_trace:
# Convert tuple keys to strings
serializable_trace = {
"_".join(map(str, k)): v for k, v in current_trace.items()
}
yield f"data: {json.dumps(serializable_trace)}\n\n"
last_trace = current_trace
time.sleep(0.2)

return Response(event_stream(), mimetype="text/event-stream")


@app.route("/quit", methods=["POST"])
def quit():
global quit_event
Expand Down
Loading

0 comments on commit f85a189

Please sign in to comment.