Skip to content

Commit

Permalink
feat: terminal dashboard to show progress on the terminal
Browse files Browse the repository at this point in the history
  • Loading branch information
provos committed Sep 12, 2024
1 parent a53e0d7 commit 3bfaef7
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 6 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ paramiko = "^3.4.1"
openai = "^1.42.0"
flask = "^3.0.3"
anthropic = "^0.34.2"
colorama = "^0.4.6"

[tool.poetry.dev-dependencies]
pytest = "^8.3.2"
Expand Down
105 changes: 100 additions & 5 deletions src/planai/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Thread
from threading import Event, Thread
from typing import Dict, List, Optional, Set, Tuple, Type

from colorama import Fore, Style, init
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

from .dispatcher import Dispatcher
from .joined_task import InitialTaskWorker
from .task import Task, TaskType, TaskWorker

# Initialize colorama for Windows compatibility
init(autoreset=True)


class Graph(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
Expand Down Expand Up @@ -192,6 +198,7 @@ def run(
self,
initial_tasks: List[Tuple[TaskWorker, Task]],
run_dashboard: bool = False,
display_terminal: bool = True,
) -> None:
"""
Execute the Graph by initiating source tasks and managing the workflow.
Expand Down Expand Up @@ -236,6 +243,11 @@ def run(
dispatcher.start_web_interface()
self._dispatcher = dispatcher

if display_terminal:
self._start_terminal_display()
terminal_thread = Thread(target=self._terminal_display_thread)
terminal_thread.start()

# Apply the max parallel tasks settings
for worker_class, max_parallel_tasks in self._max_parallel_tasks.items():
dispatcher.set_max_parallel_tasks(worker_class, max_parallel_tasks)
Expand All @@ -260,11 +272,91 @@ def run(
dispatcher.wait_for_completion(wait_for_quit=run_dashboard)
dispatcher.stop()
dispatch_thread.join()

if display_terminal:
self._stop_terminal_display_event.set()
terminal_thread.join()

self._thread_pool.shutdown(wait=True)

for worker in self.workers:
worker.completed()

def _start_terminal_display(self):
self._stop_terminal_display_event = Event()
self._stop_terminal_display_event.clear()
self._log_lines = []

def _terminal_display_thread(self):
try:
while not self._stop_terminal_display_event.is_set():
self.display_terminal_status()
time.sleep(1) # Update interval
finally:
self._clear_terminal()
self._print_log()

def _clear_terminal(self):
# Clear the terminal when the thread is terminating
print("\033[H\033[J")

def _print_log(self):
print("\nLog:")
for line in self._log_lines[-10:]:
print(line)

def display_terminal_status(self):
data = {
"queued": self._dispatcher.get_queued_tasks(),
"active": self._dispatcher.get_active_tasks(),
"completed": self._dispatcher.get_completed_tasks(),
"failed": self._dispatcher.get_failed_tasks(),
}
terminal_size = shutil.get_terminal_size((80, 20))
terminal_width = terminal_size.columns

print("\033[H\033[J") # Clear terminal

for worker in sorted(
set(t["worker"] for tasks in data.values() for t in tasks)
):
completed = sum(1 for t in data["completed"] if t["worker"] == worker)
active = sum(1 for t in data["active"] if t["worker"] == worker)
queued = sum(1 for t in data["queued"] if t["worker"] == worker)
failed = sum(1 for t in data["failed"] if t["worker"] == worker)

total_tasks = completed + active + queued + failed
# Including space for worker name and separators
bar_length = (terminal_width - 48) // 2

if total_tasks > 0:
completed_ratio = completed / total_tasks
active_ratio = active / total_tasks
queued_ratio = queued / total_tasks
else:
completed_ratio = active_ratio = queued_ratio = 0

# Create bars based on ratios
completed_bar = Fore.GREEN + "🟩" * int(bar_length * completed_ratio)
active_bar = Fore.BLUE + "🔵" * int(bar_length * active_ratio)
queued_bar = Fore.LIGHTYELLOW_EX + "🟠" * int(bar_length * queued_ratio)
failed_bar = (
Fore.RED + "❌" * failed
) # Using a cross mark emoji for failed tasks

print(
f"{worker:20} | {completed_bar}{active_bar}{queued_bar}{Style.RESET_ALL} {failed_bar}"
)

self._print_log()

# Reset the cursor to the top
print("\033[H")

def print(self, *args):
message = " ".join(str(arg) for arg in args)
self._log_lines.append(message)

def __str__(self) -> str:
return f"Graph: {self.name} with {len(self.workers)} tasks"

Expand All @@ -288,24 +380,27 @@ class Task1Worker(TaskWorker):
output_types: List[Type[Task]] = [Task2WorkItem]

def consume_work(self, task: Task1WorkItem):
print(f"Task1 consuming: {task.data}")
self.print(f"Task1 consuming: {task.data}")
time.sleep(1)
processed = f"Processed: {task.data.upper()}"
self.publish_work(Task2WorkItem(processed_data=processed), input_task=task)

class Task2Worker(TaskWorker):
output_types: List[Type[Task]] = [Task3WorkItem]

def consume_work(self, task: Task2WorkItem):
print(f"Task2 consuming: {task.processed_data}")
self.print(f"Task2 consuming: {task.processed_data}")
time.sleep(1)
final = f"Final: {task.processed_data}!"
self.publish_work(Task3WorkItem(final_result=final), input_task=task)

class Task3Worker(TaskWorker):
output_types: Set[Type[Task]] = set()

def consume_work(self, task: Task3WorkItem):
print(f"Task3 consuming: {task.final_result}")
print("Workflow complete!")
self.print(f"Task3 consuming: {task.final_result}")
time.sleep(1)
self.print("Workflow complete!")

# Create Graph
graph = Graph(name="Simple Workflow")
Expand Down
9 changes: 9 additions & 0 deletions src/planai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,15 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool:
raise ValueError("Prefix must be a tuple")
return self._graph._dispatcher.unwatch(prefix, self)

def print(self, *args):
"""
Prints a message to the console.
Parameters:
*args: The message to print.
"""
self._graph.print(*args)

def _pre_consume_work(self, task: Task):
with self._state_lock:
self._last_input_task = task
Expand Down

0 comments on commit 3bfaef7

Please sign in to comment.