-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
15 changed files
with
377 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from attrs import define | ||
from griptape.artifacts import BaseArtifact | ||
|
||
|
||
@define | ||
class ControlFlowArtifact(BaseArtifact): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from __future__ import annotations | ||
from attrs import define, field | ||
from typing import TYPE_CHECKING | ||
from griptape.artifacts import ControlFlowArtifact | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tasks import BaseTask | ||
from griptape.artifacts import BaseArtifact | ||
|
||
|
||
@define | ||
class TaskArtifact(ControlFlowArtifact): | ||
value: BaseTask = field(metadata={"serializable": True}) | ||
|
||
@property | ||
def task_id(self) -> str: | ||
return self.value.id | ||
|
||
@property | ||
def task(self) -> BaseTask: | ||
return self.value | ||
|
||
def to_text(self) -> str: | ||
return self.value.id | ||
|
||
def __add__(self, other: BaseArtifact) -> BaseArtifact: | ||
raise NotImplementedError("TaskArtifact does not support addition") | ||
|
||
def __eq__(self, value: object) -> bool: | ||
return self.value is value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .base_meta_entry import BaseMetaEntry | ||
from .action_subtask_meta_entry import ActionSubtaskMetaEntry | ||
from .control_flow_meta_entry import ControlFlowMetaEntry | ||
from .meta_memory import MetaMemory | ||
|
||
__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry"] | ||
__all__ = ["BaseMetaEntry", "MetaMemory", "ActionSubtaskMetaEntry", "ControlFlowMetaEntry"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING, Optional | ||
from attrs import field, define | ||
from griptape.memory.meta import BaseMetaEntry | ||
|
||
if TYPE_CHECKING: | ||
from griptape.artifacts import BaseArtifact | ||
|
||
|
||
@define | ||
class ControlFlowMetaEntry(BaseMetaEntry): | ||
type: str = field(default=__name__, kw_only=True, metadata={"serializable": False}) | ||
input_tasks: list[str] = field(factory=list, kw_only=True) | ||
output_tasks: list[str] = field(factory=list, kw_only=True) | ||
output: Optional[BaseArtifact] = field(default=None, kw_only=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from __future__ import annotations | ||
from abc import ABC | ||
from attrs import define | ||
from griptape.tasks import BaseTask | ||
from griptape.memory.meta import ControlFlowMetaEntry | ||
|
||
|
||
@define | ||
class BaseControlFlowTask(BaseTask, ABC): | ||
def before_run(self) -> None: | ||
super().before_run() | ||
|
||
self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") | ||
|
||
def after_run(self) -> None: | ||
super().after_run() | ||
|
||
self.structure.meta_memory.add_entry( | ||
ControlFlowMetaEntry( | ||
input_tasks=[parent.id for parent in self.parents], | ||
output_tasks=[child.id for child in filter(lambda child: not child.is_finished(), self.children)], | ||
output=self.output, | ||
) | ||
) | ||
|
||
self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") | ||
|
||
def _cancel_children_rec(self, task: BaseTask, chosen_task: BaseTask) -> None: | ||
for child in filter(lambda child: child != chosen_task, task.children): | ||
if all(parent.is_complete() for parent in filter(lambda parent: parent != task, child.parents)): | ||
child.state = BaseTask.State.CANCELLED | ||
self._cancel_children_rec(child, chosen_task) | ||
|
||
def _get_task(self, task: str | BaseTask) -> BaseTask: | ||
return self.structure.find_task(task) if isinstance(task, str) else task |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from __future__ import annotations | ||
from typing import TYPE_CHECKING, Literal, Union | ||
from attrs import field | ||
|
||
from griptape.artifacts import BooleanArtifact | ||
from griptape.tasks import BaseControlFlowTask | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tasks import BaseTask | ||
|
||
|
||
class BooleanControlFlowTask(BaseControlFlowTask): | ||
true_tasks: list[str | BaseTask] = field(factory=list, kw_only=True) | ||
false_tasks: list[str | BaseTask] = field(factory=list, kw_only=True) | ||
operator: Union[Literal["and"], Literal["or"], Literal["xor"]] = field(default="and", kw_only=True) | ||
coerce_inputs_to_bool: bool = field(default=False, kw_only=True) | ||
|
||
def run(self) -> BooleanArtifact: | ||
if not all( | ||
choice_task if isinstance(choice_task, str) else choice_task.id in self.child_ids | ||
for choice_task in [*self.true_tasks, *self.false_tasks] | ||
): | ||
raise ValueError(f"BooleanControlFlowTask {self.id} has invalid true_tasks or false_tasks") | ||
|
||
inputs = [task.output for task in self.parents] | ||
|
||
if self.coerce_inputs_to_bool: | ||
inputs = [BooleanArtifact(input) for input in inputs] | ||
else: | ||
if not all(isinstance(input, BooleanArtifact) for input in inputs): | ||
raise ValueError(f"BooleanControlFlowTask {self.id} has non-BooleanArtifact inputs") | ||
|
||
if self.operator == "and": | ||
self.output = BooleanArtifact(all(inputs)) | ||
elif self.operator == "or": | ||
self.output = BooleanArtifact(any(inputs)) | ||
elif self.operator == "xor": | ||
self.output = BooleanArtifact(sum([int(input.value) for input in inputs]) == 1) | ||
else: | ||
raise ValueError(f"BooleanControlFlowTask {self.id} has invalid operator {self.operator}") | ||
|
||
for task in self.true_tasks if self.output.value else self.false_tasks: | ||
task = self._get_task(task) | ||
self._cancel_children_rec(self, task) | ||
return self.output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from __future__ import annotations | ||
from typing import Callable | ||
from attrs import define, field | ||
|
||
from griptape.artifacts import BaseArtifact, ErrorArtifact, TaskArtifact, ListArtifact | ||
from griptape.tasks import BaseTask | ||
from griptape.tasks import BaseControlFlowTask | ||
|
||
|
||
@define | ||
class ChoiceControlFlowTask(BaseControlFlowTask): | ||
control_flow_fn: Callable[[list[BaseTask] | BaseTask], list[BaseTask | str] | BaseTask | str] = field( | ||
metadata={"serializable": False} | ||
) | ||
|
||
@property | ||
def input(self) -> BaseArtifact: | ||
if len(self.parents) == 1: | ||
return TaskArtifact(self.parents[0]) | ||
return ListArtifact([TaskArtifact(parent) for parent in self.parents]) | ||
|
||
def before_run(self) -> None: | ||
super().before_run() | ||
|
||
self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nInput: {self.input.to_text()}") | ||
|
||
def after_run(self) -> None: | ||
super().after_run() | ||
|
||
self.structure.logger.info(f"{self.__class__.__name__} {self.id}\nOutput: {self.output.to_text()}") | ||
|
||
def run(self) -> BaseArtifact: | ||
tasks = self.control_flow_fn( | ||
[artifact.value for artifact in self.input.value] | ||
if isinstance(self.input, ListArtifact) | ||
else self.input.value | ||
) | ||
|
||
if not isinstance(tasks, list): | ||
tasks = [tasks] | ||
|
||
if tasks is None: | ||
tasks = [] | ||
|
||
tasks = [self._get_task(task) for task in tasks] | ||
|
||
for task in tasks: | ||
if task.id not in self.child_ids: | ||
self.output = ErrorArtifact(f"ControlFlowTask {self.id} did not return a valid child task") | ||
return self.output | ||
|
||
self.output = ( | ||
ListArtifact( | ||
[ | ||
parent.value.output | ||
for parent in filter(lambda parent: parent.value.output is not None, self.input.value) | ||
] # pyright: ignore | ||
) | ||
if isinstance(self.input, ListArtifact) | ||
else self.input.value.output | ||
) | ||
self._cancel_children_rec(self, task) | ||
|
||
return self.output # pyright: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.