diff --git a/src/planai/cached_task.py b/src/planai/cached_task.py
index 2e59e30..b87299e 100644
--- a/src/planai/cached_task.py
+++ b/src/planai/cached_task.py
@@ -35,19 +35,20 @@ def __init__(self, **data):
         self._cache = Cache(self.cache_dir, size_limit=self.cache_size_limit)
 
     def _pre_consume_work(self, task: Task):
-        self.pre_consume_work(task)
+        with self._work_buffer_context:
+            self.pre_consume_work(task)
 
-        cache_key = self._get_cache_key(task)
-        cached_results = self._cache.get(cache_key)
+            cache_key = self._get_cache_key(task)
+            cached_results = self._cache.get(cache_key)
 
-        if cached_results is not None:
-            logging.info("Cache hit for %s with key: %s", self.name, cache_key)
-            self._publish_cached_results(cached_results, task)
-        else:
-            logging.info("Cache miss for %s with key: %s", self.name, cache_key)
-            self.consume_work(task)
+            if cached_results is not None:
+                logging.info("Cache hit for %s with key: %s", self.name, cache_key)
+                self._publish_cached_results(cached_results, task)
+            else:
+                logging.info("Cache miss for %s with key: %s", self.name, cache_key)
+                self.consume_work(task)
 
-        self.post_consume_work(task)
+            self.post_consume_work(task)
 
     def pre_consume_work(self, task: Task):
         """
diff --git a/src/planai/joined_task.py b/src/planai/joined_task.py
index 9c7f744..fb1cc8c 100644
--- a/src/planai/joined_task.py
+++ b/src/planai/joined_task.py
@@ -72,8 +72,9 @@ def notify(self, prefix: str):
         sorted_tasks = sorted(
             self._joined_results[prefix], key=attrgetter("_provenance")
         )
-        self.consume_work_joined(sorted_tasks)
-        del self._joined_results[prefix]
+        with self._work_buffer_context:
+            self.consume_work_joined(sorted_tasks)
+            del self._joined_results[prefix]
 
     def _validate_connection(self) -> None:
         """
diff --git a/src/planai/llm_task.py b/src/planai/llm_task.py
index 067c193..0bb8e10 100644
--- a/src/planai/llm_task.py
+++ b/src/planai/llm_task.py
@@ -61,7 +61,7 @@ def _invoke_llm(self, task: Task) -> Task:
         parser = PydanticOutputParser(pydantic_object=self._output_type())
 
         # allow subclasses to customize the prompt based on the input task
-        prompt = self.format_prompt(task)
+        task_prompt = self.format_prompt(task)
 
         # allow subclasses to pre-process the task and present it more clearly to the LLM
         processed_task = self.pre_process(task)
@@ -71,7 +71,7 @@ def _invoke_llm(self, task: Task) -> Task:
             output_schema=self._output_type(),
             system="You are a helpful AI assistant. Please help the user with the following task and produce output in JSON.",
             task=processed_task.model_dump_json(indent=2),
-            instructions=prompt,
+            instructions=task_prompt,
             format_instructions=parser.get_format_instructions(),
         )
 
@@ -117,7 +117,7 @@ def post_process(self, response: Optional[Task], input_task: Task):
         else:
             logging.error(
                 "LLM did not return a valid response for task %s with provenance %s",
-                input_task.__class__.__name__,
+                input_task.name,
                 input_task._provenance,
             )
 
diff --git a/src/planai/task.py b/src/planai/task.py
index 8575888..148b6ac 100644
--- a/src/planai/task.py
+++ b/src/planai/task.py
@@ -16,6 +16,7 @@
 import threading
 import uuid
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -100,6 +101,17 @@ def prefix_for_input_task(
         return None
 
 
+class WorkBufferContext:
+    def __init__(self, worker):
+        self.worker = worker
+
+    def __enter__(self):
+        self.worker._init_work_buffer()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.worker._flush_work_buffer()
+
+
 class TaskWorker(BaseModel, ABC):
     """
     Base class for all task workers.
@@ -133,6 +145,11 @@ class TaskWorker(BaseModel, ABC):
     _last_input_task: Optional[Task] = PrivateAttr(default=None)
     _instance_id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)
     _local: threading.local = PrivateAttr(default_factory=threading.local)
+    _work_buffer_context: Optional[WorkBufferContext] = PrivateAttr(default=None)
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._work_buffer_context = WorkBufferContext(self)
 
     def __hash__(self):
         return hash(self._instance_id)
@@ -243,9 +260,8 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool:
     def _pre_consume_work(self, task: Task):
         with self._state_lock:
             self._last_input_task = task
-        self._init_work_buffer()
-        self.consume_work(task)
-        self.flush_work_buffer()
+        with self._work_buffer_context:
+            self.consume_work(task)
 
     def init(self):
         """
@@ -315,9 +331,14 @@ def publish_work(self, task: Task, input_task: Optional[Task]):
         self._init_work_buffer()
         self._local.work_buffer.append((consumer, task))
 
-    def flush_work_buffer(self):
+    def _flush_work_buffer(self):
         self._init_work_buffer()
         if self._graph and self._graph._dispatcher:
+            logging.info(
+                "Worker %s flushing work buffer with %d items",
+                self.name,
+                len(self._local.work_buffer),
+            )
             self._graph._dispatcher.add_multiple_work(self._local.work_buffer)
         else:
             for consumer, task in self._local.work_buffer:
diff --git a/tests/planai/test_task.py b/tests/planai/test_task.py
index a9eaaf1..acf9a84 100644
--- a/tests/planai/test_task.py
+++ b/tests/planai/test_task.py
@@ -141,7 +141,7 @@ def test_publish_work(self):
        self.worker.register_consumer(DummyTask, self.worker)
 
        self.worker.publish_work(task, input_task)
-       self.worker.flush_work_buffer()
+       self.worker._flush_work_buffer()
 
        self.assertEqual(len(task._provenance), 1)
        self.assertEqual(task._provenance[0][0], self.worker.name)
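Note: a minimal, self-contained sketch of the context-manager pattern this change introduces, assuming only what the diff itself shows. WorkBufferContext below mirrors the class added in src/planai/task.py; ToyWorker and its print-based buffer are hypothetical stand-ins for TaskWorker and the dispatcher, not PlanAI code.

class WorkBufferContext:
    """Initialize the worker's buffer on entry, flush it on exit (matches the diff)."""

    def __init__(self, worker):
        self.worker = worker

    def __enter__(self):
        self.worker._init_work_buffer()

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs even if consume_work raises, so buffered work is always flushed.
        self.worker._flush_work_buffer()


class ToyWorker:
    """Hypothetical stand-in showing how _pre_consume_work uses the context."""

    def __init__(self):
        self._work_buffer = []
        self._work_buffer_context = WorkBufferContext(self)

    def _init_work_buffer(self):
        self._work_buffer = []

    def _flush_work_buffer(self):
        print(f"flushing {len(self._work_buffer)} item(s)")
        self._work_buffer.clear()

    def publish_work(self, item):
        # Work is only buffered here; delivery happens at flush time.
        self._work_buffer.append(item)

    def consume_work(self, task):
        self.publish_work(f"derived-from-{task}")

    def _pre_consume_work(self, task):
        # Mirrors the change in task.py: init/flush now live in __enter__/__exit__.
        with self._work_buffer_context:
            self.consume_work(task)


if __name__ == "__main__":
    ToyWorker()._pre_consume_work("task-1")  # prints: flushing 1 item(s)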