Commit b17764d

fix: several bugs that were introduced with buffering publish_work items
provos committed Sep 2, 2024
1 parent 2d6edbb commit b17764d
Showing 5 changed files with 43 additions and 20 deletions.
src/planai/cached_task.py (11 additions, 10 deletions)
@@ -35,19 +35,20 @@ def __init__(self, **data):
         self._cache = Cache(self.cache_dir, size_limit=self.cache_size_limit)
 
     def _pre_consume_work(self, task: Task):
-        self.pre_consume_work(task)
+        with self._work_buffer_context:
+            self.pre_consume_work(task)
 
-        cache_key = self._get_cache_key(task)
-        cached_results = self._cache.get(cache_key)
+            cache_key = self._get_cache_key(task)
+            cached_results = self._cache.get(cache_key)
 
-        if cached_results is not None:
-            logging.info("Cache hit for %s with key: %s", self.name, cache_key)
-            self._publish_cached_results(cached_results, task)
-        else:
-            logging.info("Cache miss for %s with key: %s", self.name, cache_key)
-            self.consume_work(task)
+            if cached_results is not None:
+                logging.info("Cache hit for %s with key: %s", self.name, cache_key)
+                self._publish_cached_results(cached_results, task)
+            else:
+                logging.info("Cache miss for %s with key: %s", self.name, cache_key)
+                self.consume_work(task)
 
-        self.post_consume_work(task)
+            self.post_consume_work(task)
 
     def pre_consume_work(self, task: Task):
         """
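The fix moves the whole consume pipeline, including the cache-hit path, inside the work-buffer context, so results republished from the cache are flushed together with everything else; the joined_task.py change below applies the same wrapping. A minimal toy sketch of that contract (ToyWorker and the string items are illustrative assumptions, not PlanAI classes):

# Toy sketch: publish_work calls made anywhere inside the "with" block,
# cache hit or cache miss, are buffered and handed off in a single flush.
class ToyWorker:
    def __init__(self):
        self._buffer = []
        self.flushed = []

    def publish_work(self, item):
        self._buffer.append(item)  # buffered, not dispatched immediately

    def __enter__(self):
        self._buffer = []
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.flushed.extend(self._buffer)  # single hand-off on exit
        self._buffer = []

worker = ToyWorker()
with worker:
    worker.publish_work("cached-result")    # cache-hit path would publish here
    worker.publish_work("computed-result")  # cache-miss path would publish here
assert worker.flushed == ["cached-result", "computed-result"]
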
src/planai/joined_task.py (3 additions, 2 deletions)
@@ -72,8 +72,9 @@ def notify(self, prefix: str):
             sorted_tasks = sorted(
                 self._joined_results[prefix], key=attrgetter("_provenance")
             )
-            self.consume_work_joined(sorted_tasks)
-            del self._joined_results[prefix]
+            with self._work_buffer_context:
+                self.consume_work_joined(sorted_tasks)
+                del self._joined_results[prefix]
 
     def _validate_connection(self) -> None:
         """
src/planai/llm_task.py (3 additions, 3 deletions)
@@ -61,7 +61,7 @@ def _invoke_llm(self, task: Task) -> Task:
         parser = PydanticOutputParser(pydantic_object=self._output_type())
 
         # allow subclasses to customize the prompt based on the input task
-        prompt = self.format_prompt(task)
+        task_prompt = self.format_prompt(task)
 
         # allow subclasses to pre-process the task and present it more clearly to the LLM
         processed_task = self.pre_process(task)

@@ -71,7 +71,7 @@ def _invoke_llm(self, task: Task) -> Task:
             output_schema=self._output_type(),
             system="You are a helpful AI assistant. Please help the user with the following task and produce output in JSON.",
             task=processed_task.model_dump_json(indent=2),
-            instructions=prompt,
+            instructions=task_prompt,
             format_instructions=parser.get_format_instructions(),
         )

@@ -117,7 +117,7 @@ def post_process(self, response: Optional[Task], input_task: Task):
         else:
             logging.error(
                 "LLM did not return a valid response for task %s with provenance %s",
-                input_task.__class__.__name__,
+                input_task.name,
                 input_task._provenance,
             )
src/planai/task.py (25 additions, 4 deletions)
@@ -16,6 +16,7 @@
 import threading
 import uuid
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
     Any,

@@ -100,6 +101,17 @@ def prefix_for_input_task(
         return None
 
 
+class WorkBufferContext:
+    def __init__(self, worker):
+        self.worker = worker
+
+    def __enter__(self):
+        self.worker._init_work_buffer()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.worker._flush_work_buffer()
+
+
 class TaskWorker(BaseModel, ABC):
     """
     Base class for all task workers.

@@ -133,6 +145,11 @@ class TaskWorker(BaseModel, ABC):
     _last_input_task: Optional[Task] = PrivateAttr(default=None)
     _instance_id: uuid.UUID = PrivateAttr(default_factory=uuid.uuid4)
     _local: threading.local = PrivateAttr(default_factory=threading.local)
+    _work_buffer_context: Optional[WorkBufferContext] = PrivateAttr(default=None)
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._work_buffer_context = WorkBufferContext(self)
 
     def __hash__(self):
         return hash(self._instance_id)

@@ -243,9 +260,8 @@ def unwatch(self, prefix: "ProvenanceChain") -> bool:
     def _pre_consume_work(self, task: Task):
         with self._state_lock:
             self._last_input_task = task
-        self._init_work_buffer()
-        self.consume_work(task)
-        self.flush_work_buffer()
+        with self._work_buffer_context:
+            self.consume_work(task)
 
     def init(self):
         """

@@ -315,9 +331,14 @@ def publish_work(self, task: Task, input_task: Optional[Task]):
         self._init_work_buffer()
         self._local.work_buffer.append((consumer, task))
 
-    def flush_work_buffer(self):
+    def _flush_work_buffer(self):
         self._init_work_buffer()
         if self._graph and self._graph._dispatcher:
+            logging.info(
+                "Worker %s flushing work buffer with %d items",
+                self.name,
+                len(self._local.work_buffer),
+            )
             self._graph._dispatcher.add_multiple_work(self._local.work_buffer)
         else:
             for consumer, task in self._local.work_buffer:
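Taken together, publish_work now appends to a thread-local buffer and the renamed _flush_work_buffer hands the buffered items to the dispatcher in one call when the context exits. A hedged sketch of how the pieces compose (SketchWorker and DummyDispatcher are illustrative stand-ins; only WorkBufferContext mirrors the class added above, and clearing the buffer at the end of the flush is an assumption of the sketch):

import threading

class WorkBufferContext:
    """As in the diff above: init the buffer on enter, flush it on exit."""

    def __init__(self, worker):
        self.worker = worker

    def __enter__(self):
        self.worker._init_work_buffer()

    def __exit__(self, exc_type, exc_value, traceback):
        self.worker._flush_work_buffer()

class SketchWorker:
    def __init__(self, dispatcher):
        self._local = threading.local()  # each thread sees its own buffer
        self._dispatcher = dispatcher
        self._work_buffer_context = WorkBufferContext(self)

    def _init_work_buffer(self):
        if not hasattr(self._local, "work_buffer"):
            self._local.work_buffer = []

    def publish_work(self, task):
        self._init_work_buffer()
        self._local.work_buffer.append(task)  # buffered until flush

    def _flush_work_buffer(self):
        self._init_work_buffer()
        self._dispatcher.add_multiple_work(self._local.work_buffer)
        self._local.work_buffer = []  # assumption: buffer is reset after hand-off

class DummyDispatcher:
    def __init__(self):
        self.received = []

    def add_multiple_work(self, items):
        self.received.extend(items)

dispatcher = DummyDispatcher()
worker = SketchWorker(dispatcher)
with worker._work_buffer_context:
    worker.publish_work("a")
    worker.publish_work("b")
assert dispatcher.received == ["a", "b"]

Because __enter__ and __exit__ only touch the worker's thread-local state, a single WorkBufferContext instance per worker can be shared safely across the threads that call _pre_consume_work.
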
tests/planai/test_task.py (1 addition, 1 deletion)
@@ -141,7 +141,7 @@ def test_publish_work(self):
         self.worker.register_consumer(DummyTask, self.worker)
 
         self.worker.publish_work(task, input_task)
-        self.worker.flush_work_buffer()
+        self.worker._flush_work_buffer()
 
         self.assertEqual(len(task._provenance), 1)
         self.assertEqual(task._provenance[0][0], self.worker.name)
