Skip to content

Commit

Permalink
Add ControlFlowTask
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo committed Jul 5, 2024
1 parent 3e5bd03 commit 9b497da
Show file tree
Hide file tree
Showing 15 changed files with 377 additions and 13 deletions.
4 changes: 4 additions & 0 deletions griptape/artifacts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from .media_artifact import MediaArtifact
from .image_artifact import ImageArtifact
from .audio_artifact import AudioArtifact
from .control_flow_artifact import ControlFlowArtifact
from .task_artifact import TaskArtifact


__all__ = [
Expand All @@ -23,4 +25,6 @@
"MediaArtifact",
"ImageArtifact",
"AudioArtifact",
"ControlFlowArtifact",
"TaskArtifact",
]
6 changes: 4 additions & 2 deletions griptape/artifacts/boolean_artifact.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Any, Union
from attrs import define, field
from griptape.artifacts import BaseArtifact
from griptape.artifacts import TextArtifact, BaseArtifact


@define
Expand All @@ -10,11 +10,13 @@ class BooleanArtifact(BaseArtifact):
meta: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})

@classmethod
def parse_bool(cls, value: Union[str, bool]) -> BooleanArtifact:
def parse_bool(cls, value: Union[str, bool, TextArtifact]) -> BooleanArtifact:
"""
Convert a string literal or bool to a BooleanArtifact. The string must be either "true" or "false" with any casing.
"""
if value is not None:
if isinstance(value, TextArtifact):
value = str(value)
if isinstance(value, str):
if value.lower() == "true":
return BooleanArtifact(True)
Expand Down
7 changes: 7 additions & 0 deletions griptape/artifacts/control_flow_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from attrs import define
from griptape.artifacts import BaseArtifact


@define
class ControlFlowArtifact(BaseArtifact):
pass
30 changes: 30 additions & 0 deletions griptape/artifacts/task_artifact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations
from attrs import define, field
from typing import TYPE_CHECKING
from griptape.artifacts import ControlFlowArtifact

if TYPE_CHECKING:
from griptape.tasks import BaseTask
from griptape.artifacts import BaseArtifact


@define
class TaskArtifact(ControlFlowArtifact):
value: BaseTask = field(metadata={"serializable": True})

@property
def task_id(self) -> str:
return self.value.id

@property
def task(self) -> BaseTask:
return self.value

def to_text(self) -> str:
return self.value.id

def __add__(self, other: BaseArtifact) -> BaseArtifact:
raise NotImplementedError("TaskArtifact does not support addition")

def __eq__(self, value: object) -> bool:
return self.value is value
3 changes: 2 additions & 1 deletion griptape/memory/meta/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_meta_entry import BaseMetaEntry
from .action_subtask_meta_entry import ActionSubtaskMetaEntry
from .control_flow_meta_entry import ControlFlowMetaEntry
from .meta_memory import MetaMemory

__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry"]
__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry", "ControlFlowMetaEntry"]
15 changes: 15 additions & 0 deletions griptape/memory/meta/control_flow_meta_entry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from attrs import field, define
from griptape.memory.meta import BaseMetaEntry

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact


@define
class ControlFlowMetaEntry(BaseMetaEntry):
type: str = field(default=__name__, kw_only=True, metadata={"serializable": False})
input_tasks: list[str] = field(factory=list, kw_only=True)
output_tasks: list[str] = field(factory=list, kw_only=True)
output: Optional[BaseArtifact] = field(default=None, kw_only=True)
3 changes: 3 additions & 0 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,9 @@ def default_task_memory(self) -> TaskMemory:
def is_finished(self) -> bool:
return all(s.is_finished() for s in self.tasks)

def is_complete(self) -> bool:
return all(s.is_complete() for s in self.tasks)

def is_executing(self) -> bool:
return any(s for s in self.tasks if s.is_executing())

Expand Down
16 changes: 9 additions & 7 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,22 @@ def insert_task(
def try_run(self, *args) -> Workflow:
exit_loop = False

while not self.is_finished() and not exit_loop:
while not self.is_complete() and not exit_loop:
futures_list = {}
ordered_tasks = self.order_tasks()
executable_tasks = [*filter(lambda task: task.can_execute(), self.order_tasks())]

for task in ordered_tasks:
if task.can_execute():
future = self.futures_executor_fn().submit(task.execute)
futures_list[future] = task
if not executable_tasks:
exit_loop = True
break

for task in executable_tasks:
future = self.futures_executor_fn().submit(task.execute)
futures_list[future] = task

# Wait for all tasks to complete
for future in futures.as_completed(futures_list):
if isinstance(future.result(), ErrorArtifact) and self.fail_fast:
exit_loop = True

break

if self.conversation_memory and self.output is not None:
Expand Down
4 changes: 4 additions & 0 deletions griptape/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from .text_to_speech_task import TextToSpeechTask
from .structure_run_task import StructureRunTask
from .audio_transcription_task import AudioTranscriptionTask
from .base_control_flow_task import BaseControlFlowTask
from .choice_control_flow_task import ChoiceControlFlowTask

__all__ = [
"BaseTask",
Expand All @@ -46,4 +48,6 @@
"TextToSpeechTask",
"StructureRunTask",
"AudioTranscriptionTask",
"BaseControlFlowTask",
"ChoiceControlFlowTask",
]
35 changes: 35 additions & 0 deletions griptape/tasks/base_control_flow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations
from abc import ABC
from attrs import define
from griptape.tasks import BaseTask
from griptape.memory.meta import ControlFlowMetaEntry


@define
class BaseControlFlowTask(BaseTask, ABC):
def before_run(self) -> None:
super().before_run()

self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}")

def after_run(self) -> None:
super().after_run()

self.structure.meta_memory.add_entry(
ControlFlowMetaEntry(
input_tasks=[parent.id for parent in self.parents],
output_tasks=[child.id for child in filter(lambda child: not child.is_finished(), self.children)],
output=self.output,
)
)

self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}")

def _cancel_children_rec(self, task: BaseTask, chosen_task: BaseTask) -> None:
for child in filter(lambda child: child != chosen_task, task.children):
if all(parent.is_complete() for parent in filter(lambda parent: parent != task, child.parents)):
child.state = BaseTask.State.CANCELLED
self._cancel_children_rec(child, chosen_task)

def _get_task(self, task: str | BaseTask) -> BaseTask:
return self.structure.find_task(task) if isinstance(task, str) else task
15 changes: 14 additions & 1 deletion griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class State(Enum):
PENDING = 1
EXECUTING = 2
FINISHED = 3
CANCELLED = 4

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
state: State = field(default=State.PENDING, kw_only=True)
Expand Down Expand Up @@ -101,9 +102,15 @@ def is_pending(self) -> bool:
def is_finished(self) -> bool:
return self.state == BaseTask.State.FINISHED

def is_cancelled(self) -> bool:
return self.state == BaseTask.State.CANCELLED

def is_executing(self) -> bool:
return self.state == BaseTask.State.EXECUTING

def is_complete(self) -> bool:
return self.is_finished() or self.is_cancelled()

def before_run(self) -> None:
if self.structure:
self.structure.publish_event(
Expand Down Expand Up @@ -147,7 +154,13 @@ def execute(self) -> Optional[BaseArtifact]:
return self.output

def can_execute(self) -> bool:
return self.state == BaseTask.State.PENDING and all(parent.is_finished() for parent in self.parents)
return self.is_pending() and (
(
all(parent.is_complete() for parent in self.parents)
and any(parent.is_finished() for parent in self.parents)
)
or len(self.parents) == 0
)

def reset(self) -> BaseTask:
self.state = BaseTask.State.PENDING
Expand Down
45 changes: 45 additions & 0 deletions griptape/tasks/boolean_control_flow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Union
from attrs import field

from griptape.artifacts import BooleanArtifact
from griptape.tasks import BaseControlFlowTask

if TYPE_CHECKING:
from griptape.tasks import BaseTask


class BooleanControlFlowTask(BaseControlFlowTask):
true_tasks: list[str | BaseTask] = field(factory=list, kw_only=True)
false_tasks: list[str | BaseTask] = field(factory=list, kw_only=True)
operator: Union[Literal["and"], Literal["or"], Literal["xor"]] = field(default="and", kw_only=True)
coerce_inputs_to_bool: bool = field(default=False, kw_only=True)

def run(self) -> BooleanArtifact:
if not all(
choice_task if isinstance(choice_task, str) else choice_task.id in self.child_ids
for choice_task in [*self.true_tasks, *self.false_tasks]
):
raise ValueError(f"BooleanControlFlowTask {self.id} has invalid true_tasks or false_tasks")

inputs = [task.output for task in self.parents]

if self.coerce_inputs_to_bool:
inputs = [BooleanArtifact(input) for input in inputs]
else:
if not all(isinstance(input, BooleanArtifact) for input in inputs):
raise ValueError(f"BooleanControlFlowTask {self.id} has non-BooleanArtifact inputs")

if self.operator == "and":
self.output = BooleanArtifact(all(inputs))
elif self.operator == "or":
self.output = BooleanArtifact(any(inputs))
elif self.operator == "xor":
self.output = BooleanArtifact(sum([int(input.value) for input in inputs]) == 1)
else:
raise ValueError(f"BooleanControlFlowTask {self.id} has invalid operator {self.operator}")

for task in self.true_tasks if self.output.value else self.false_tasks:
task = self._get_task(task)
self._cancel_children_rec(self, task)
return self.output
64 changes: 64 additions & 0 deletions griptape/tasks/choice_control_flow_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations
from typing import Callable
from attrs import define, field

from griptape.artifacts import BaseArtifact, ErrorArtifact, TaskArtifact, ListArtifact
from griptape.tasks import BaseTask
from griptape.tasks import BaseControlFlowTask


@define
class ChoiceControlFlowTask(BaseControlFlowTask):
control_flow_fn: Callable[[list[BaseTask] | BaseTask], list[BaseTask | str] | BaseTask | str] = field(
metadata={"serializable": False}
)

@property
def input(self) -> BaseArtifact:
if len(self.parents) == 1:
return TaskArtifact(self.parents[0])
return ListArtifact([TaskArtifact(parent) for parent in self.parents])

def before_run(self) -> None:
super().before_run()

self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}")

def after_run(self) -> None:
super().after_run()

self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}")

def run(self) -> BaseArtifact:
tasks = self.control_flow_fn(
[artifact.value for artifact in self.input.value]
if isinstance(self.input, ListArtifact)
else self.input.value
)

if not isinstance(tasks, list):
tasks = [tasks]

if tasks is None:
tasks = []

tasks = [self._get_task(task) for task in tasks]

for task in tasks:
if task.id not in self.child_ids:
self.output = ErrorArtifact(f"ControlFlowTask {self.id} did not return a valid child task")
return self.output

self.output = (
ListArtifact(
[
parent.value.output
for parent in filter(lambda parent: parent.value.output is not None, self.input.value)
] # pyright: ignore
)
if isinstance(self.input, ListArtifact)
else self.input.value.output
)
self._cancel_children_rec(self, task)

return self.output # pyright: ignore
7 changes: 6 additions & 1 deletion griptape/utils/structure_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,12 @@ def to_url(self) -> str:
def __render_task(self, task: BaseTask) -> str:
if task.children:
children = " & ".join([f"{self.__get_id(child.id)}({child.id})" for child in task.children])
return f"{self.__get_id(task.id)}({task.id})--> {children};"
from griptape.tasks import ChoiceControlFlowTask

if isinstance(task, ChoiceControlFlowTask):
return f"{self.__get_id(task.id)}{{{task.id}}}-.-> {children};"
else:
return f"{self.__get_id(task.id)}({task.id})--> {children};"
else:
return f"{self.__get_id(task.id)}({task.id});"

Expand Down
Loading

0 comments on commit 9b497da

Please sign in to comment.