diff --git a/poetry.lock b/poetry.lock index bacb0c4..eb5e8f3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2655,4 +2655,4 @@ docs = [] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "a9f53032be43cf9b5e1421a9922d41707795017727db143eadf3c850358b80ec" +content-hash = "dcf4b0cf69a661dbaea422465fda6c7f1e81d7dc067891956f199e8bb19a17e6" diff --git a/pyproject.toml b/pyproject.toml index b6274a7..9cfc8ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ paramiko = "^3.4.1" openai = "^1.42.0" flask = "^3.0.3" anthropic = "^0.34.2" +colorama = "^0.4.6" [tool.poetry.dev-dependencies] pytest = "^8.3.2" diff --git a/src/planai/graph.py b/src/planai/graph.py index 691b43f..7e5a4c7 100644 --- a/src/planai/graph.py +++ b/src/planai/graph.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import shutil +import time from concurrent.futures import ThreadPoolExecutor -from threading import Thread +from threading import Event, Thread from typing import Dict, List, Optional, Set, Tuple, Type +from colorama import Fore, Style, init from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from .dispatcher import Dispatcher from .joined_task import InitialTaskWorker from .task import Task, TaskType, TaskWorker +# Initialize colorama for Windows compatibility +init(autoreset=True) + class Graph(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -192,6 +198,7 @@ def run( self, initial_tasks: List[Tuple[TaskWorker, Task]], run_dashboard: bool = False, + display_terminal: bool = True, ) -> None: """ Execute the Graph by initiating source tasks and managing the workflow. @@ -236,6 +243,11 @@ def run( dispatcher.start_web_interface() self._dispatcher = dispatcher + if display_terminal: + self._start_terminal_display() + terminal_thread = Thread(target=self._terminal_display_thread) + terminal_thread.start() + # Apply the max parallel tasks settings for worker_class, max_parallel_tasks in self._max_parallel_tasks.items(): dispatcher.set_max_parallel_tasks(worker_class, max_parallel_tasks) @@ -260,11 +272,91 @@ def run( dispatcher.wait_for_completion(wait_for_quit=run_dashboard) dispatcher.stop() dispatch_thread.join() + + if display_terminal: + self._stop_terminal_display_event.set() + terminal_thread.join() + self._thread_pool.shutdown(wait=True) for worker in self.workers: worker.completed() + def _start_terminal_display(self): + self._stop_terminal_display_event = Event() + self._stop_terminal_display_event.clear() + self._log_lines = [] + + def _terminal_display_thread(self): + try: + while not self._stop_terminal_display_event.is_set(): + self.display_terminal_status() + time.sleep(1) # Update interval + finally: + self._clear_terminal() + self._print_log() + + def _clear_terminal(self): + # Clear the terminal when the thread is terminating + print("\033[H\033[J") + + def _print_log(self): + print("\nLog:") + for line in self._log_lines[-10:]: + print(line) + + def display_terminal_status(self): + data = { + "queued": self._dispatcher.get_queued_tasks(), + "active": self._dispatcher.get_active_tasks(), + "completed": self._dispatcher.get_completed_tasks(), + "failed": self._dispatcher.get_failed_tasks(), + } + terminal_size = shutil.get_terminal_size((80, 20)) + terminal_width = terminal_size.columns + + print("\033[H\033[J") # Clear terminal + + for worker in sorted( + set(t["worker"] for tasks in data.values() for t in tasks) + ): + completed = sum(1 for t in data["completed"] if t["worker"] == worker) + active = sum(1 for t in data["active"] if t["worker"] == worker) + queued = sum(1 for t in data["queued"] if t["worker"] == worker) + failed = sum(1 for t in data["failed"] if t["worker"] == worker) + + total_tasks = completed + active + queued + failed + # Including space for worker name and separators + bar_length = (terminal_width - 48) // 2 + + if total_tasks > 0: + completed_ratio = completed / total_tasks + active_ratio = active / total_tasks + queued_ratio = queued / total_tasks + else: + completed_ratio = active_ratio = queued_ratio = 0 + + # Create bars based on ratios + completed_bar = Fore.GREEN + "🟩" * int(bar_length * completed_ratio) + active_bar = Fore.BLUE + "🔵" * int(bar_length * active_ratio) + queued_bar = Fore.LIGHTYELLOW_EX + "🟠" * int(bar_length * queued_ratio) + failed_bar = ( + Fore.RED + "❌" * failed + ) # Using a cross mark emoji for failed tasks + + print( + f"{worker:20} | {completed_bar}{active_bar}{queued_bar}{Style.RESET_ALL} {failed_bar}" + ) + + self._print_log() + + # Reset the cursor to the top + print("\033[H") + + def print(self, *args): + message = " ".join(str(arg) for arg in args) + self._log_lines.append(message) + def __str__(self) -> str: return f"Graph: {self.name} with {len(self.workers)} tasks" @@ -288,7 +380,8 @@ class Task1Worker(TaskWorker): output_types: List[Type[Task]] = [Task2WorkItem] def consume_work(self, task: Task1WorkItem): - print(f"Task1 consuming: {task.data}") + self.print(f"Task1 consuming: {task.data}") + time.sleep(1) processed = f"Processed: {task.data.upper()}" self.publish_work(Task2WorkItem(processed_data=processed), input_task=task) @@ -296,7 +389,8 @@ class Task2Worker(TaskWorker): output_types: List[Type[Task]] = [Task3WorkItem] def consume_work(self, task: Task2WorkItem): - print(f"Task2 consuming: {task.processed_data}") + self.print(f"Task2 consuming: {task.processed_data}") + time.sleep(1) final = f"Final: {task.processed_data}!" self.publish_work(Task3WorkItem(final_result=final), input_task=task) @@ -304,8 +398,9 @@ class Task3Worker(TaskWorker): output_types: Set[Type[Task]] = set() def consume_work(self, task: Task3WorkItem): - print(f"Task3 consuming: {task.final_result}") - print("Workflow complete!") + self.print(f"Task3 consuming: {task.final_result}") + time.sleep(1) + self.print("Workflow complete!") # Create Graph graph = Graph(name="Simple Workflow") diff --git a/src/planai/task.py b/src/planai/task.py index 2e9af65..02054c3 100644 --- a/src/planai/task.py +++ b/src/planai/task.py @@ -336,6 +336,15 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool: raise ValueError("Prefix must be a tuple") return self._graph._dispatcher.unwatch(prefix, self) + def print(self, *args): + """ + Prints a message to the console. + + Parameters: + *args: The message to print. + """ + self._graph.print(*args) + def _pre_consume_work(self, task: Task): with self._state_lock: self._last_input_task = task