Skip to content

Commit

Permalink
Add graph.group(<name>) (#89)
Browse files Browse the repository at this point in the history
* add graph.group()

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests + docstrings

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* expose empty_graph

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bugfix

* prepare for release 0.1.14

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Sep 21, 2023
1 parent e4d9551 commit 60e0f58
Show file tree
Hide file tree
Showing 9 changed files with 343 additions and 28 deletions.
26 changes: 24 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ print(n3.results)
# >>> 7.5
```

## Dask Support
### Dask Support

ZnFlow comes with support for [Dask](https://www.dask.org/) to run your graph:

Expand Down Expand Up @@ -232,7 +232,29 @@ Instead, you can also use the `znflow.disable_graph` decorator / context manager
to disable the graph for a specific block of code or the `znflow.Property` as a
drop-in replacement for `property`.

# Supported Frameworks
### Groups

It is possible to create groups of `znflow.nodify` or `znflow.Nodes` independent
from the graph structure. To create a group you can use
`with graph.group(<name>)`. To access the group members, use
`graph.get_group(<name>) -> list`.

```python
import znflow

@znflow.nodify
def compute_mean(x, y):
return (x + y) / 2

graph = znflow.DiGraph()

with graph.group("grp1"):
n1 = compute_mean(2, 4)

assert n1.uuid in graph.get_group("grp1")
```

## Supported Frameworks

ZnFlow includes tests to ensure compatibility with:

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "znflow"
version = "0.1.13"
version = "0.1.14"
description = "A general purpose framework for building and running computational graphs."
authors = ["zincwarecode <zincwarecode@gmail.com>"]
license = "Apache-2.0"
Expand Down
7 changes: 3 additions & 4 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import zninit

import znflow
from znflow.base import empty


class PlainNode(znflow.Node):
Expand Down Expand Up @@ -76,7 +75,7 @@ def test_changed_graph():
with pytest.raises(ValueError):
with znflow.DiGraph():
znflow.base.NodeBaseMixin._graph_ = znflow.DiGraph()
znflow.base.NodeBaseMixin._graph_ = empty # reset after test
znflow.base.NodeBaseMixin._graph_ = znflow.empty_graph # reset after test


def test_add_others():
Expand Down Expand Up @@ -159,9 +158,9 @@ def test_disable_graph():
node1 = DataclassNode(value=42)
assert node1._graph_ is graph
with znflow.base.disable_graph():
assert node1._graph_ is empty
assert node1._graph_ is znflow.empty_graph
assert node1._graph_ is graph
assert node1._graph_ is empty
assert node1._graph_ is znflow.empty_graph


def test_get_attribute():
Expand Down
221 changes: 221 additions & 0 deletions tests/test_graph_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import pytest

import znflow


class PlainNode(znflow.Node):
def __init__(self, value):
self.value = value

def run(self):
self.value += 1


def test_empty_grp_name():
graph = znflow.DiGraph()

with pytest.raises(TypeError):
with graph.group(): # name required
pass


def test_grp():
graph = znflow.DiGraph()

assert graph.active_group is None

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name

node = PlainNode(1)

assert graph.active_group is None
graph.run()

assert grp_name == "my_grp"
assert node.value == 2
assert node.uuid in graph.nodes
assert grp_name in graph._groups
assert graph.get_group(grp_name) == [node.uuid]

assert len(graph._groups) == 1
assert len(graph) == 1


def test_muliple_grps():
graph = znflow.DiGraph()

assert graph.active_group is None

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name

node = PlainNode(1)

assert graph.active_group is None

with graph.group("my_grp2") as grp_name2:
assert graph.active_group == grp_name2

node2 = PlainNode(2)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"
assert grp_name2 == "my_grp2"

assert node.value == 2
assert node2.value == 3

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp_name in graph._groups
assert grp_name2 in graph._groups

assert graph.get_group(grp_name) == [node.uuid]
assert graph.get_group(grp_name2) == [node2.uuid]

assert len(graph._groups) == 2
assert len(graph) == 2


def test_nested_grps():
graph = znflow.DiGraph()

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name
with pytest.raises(TypeError):
with graph.group("my_grp2"):
pass


def test_grp_with_existing_nodes():
with znflow.DiGraph() as graph:
node = PlainNode(1)

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name

node2 = PlainNode(2)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"

assert node.value == 2
assert node2.value == 3

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp_name in graph._groups

assert graph.get_group(grp_name) == [node2.uuid]

assert len(graph._groups) == 1
assert len(graph) == 2


def test_grp_with_multiple_nodes():
with znflow.DiGraph() as graph:
node = PlainNode(1)
node2 = PlainNode(2)

with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name

node3 = PlainNode(3)
node4 = PlainNode(4)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"

assert node.value == 2
assert node2.value == 3
assert node3.value == 4
assert node4.value == 5

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes
assert node3.uuid in graph.nodes
assert node4.uuid in graph.nodes

assert grp_name in graph._groups

assert graph.get_group(grp_name) == [node3.uuid, node4.uuid]

assert len(graph._groups) == 1
assert len(graph) == 4


def test_reopen_grps():
with znflow.DiGraph() as graph:
with graph.group("my_grp") as grp_name:
assert graph.active_group == grp_name

node = PlainNode(1)

with graph.group("my_grp") as grp_name2:
assert graph.active_group == grp_name2

node2 = PlainNode(2)

assert graph.active_group is None

graph.run()

assert grp_name == "my_grp"
assert grp_name2 == grp_name

assert node.value == 2
assert node2.value == 3

assert node.uuid in graph.nodes
assert node2.uuid in graph.nodes

assert grp_name in graph._groups

assert graph.get_group(grp_name) == [node.uuid, node2.uuid]

assert len(graph._groups) == 1
assert len(graph) == 2


def test_tuple_grp_names():
graph = znflow.DiGraph()

assert graph.active_group is None
with graph.group(("grp", "1")) as grp_name:
assert graph.active_group == grp_name

node = PlainNode(1)

assert graph.active_group is None
graph.run()

assert grp_name == ("grp", "1")
assert node.value == 2
assert node.uuid in graph.nodes
assert grp_name in graph._groups
assert graph.get_group(grp_name) == [node.uuid]


def test_grp_nodify():
@znflow.nodify
def compute_mean(x, y):
return (x + y) / 2

graph = znflow.DiGraph()

with graph.group("grp1"):
n1 = compute_mean(2, 4)

assert n1.uuid in graph.get_group("grp1")
2 changes: 1 addition & 1 deletion tests/test_znflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

def test_version():
"""Test the version."""
assert znflow.__version__ == "0.1.13"
assert znflow.__version__ == "0.1.14"
4 changes: 4 additions & 0 deletions znflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
FunctionFuture,
Property,
disable_graph,
empty_graph,
get_attribute,
get_graph,
)
from znflow.combine import combine
from znflow.graph import DiGraph
Expand All @@ -33,6 +35,8 @@
"CombinedConnections",
"combine",
"exceptions",
"get_graph",
"empty_graph",
]

with contextlib.suppress(ImportError):
Expand Down
37 changes: 25 additions & 12 deletions znflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def disable_graph(*args, **kwargs):
This can be useful, if you e.g. want to use 'get_attribute'.
"""
graph = get_graph()
set_graph(empty)
set_graph(empty_graph)
try:
yield
finally:
Expand Down Expand Up @@ -77,11 +77,11 @@ def deleter(self, fdel):


@dataclasses.dataclass(frozen=True)
class Empty:
class EmptyGraph:
"""An empty class used as a default value for _graph_."""


empty = Empty()
empty_graph = EmptyGraph()


class NodeBaseMixin:
Expand All @@ -91,11 +91,18 @@ class NodeBaseMixin:
Attributes
----------
_graph_ : DiGraph
uuid : UUID
_graph_ : DiGraph
The graph this node belongs to.
This is only available within the graph context.
uuid : UUID
The unique identifier of this node.
_external_ : bool
If true, the node is allowed to be created outside of a graph context.
In this case connections can be created to this node, otherwise
an exception is raised.
"""

_graph_ = empty
_graph_ = empty_graph
_external_ = False
_uuid: UUID = None

Expand Down Expand Up @@ -141,15 +148,21 @@ def get_attribute(obj, name, default=_get_attribute_none):
@dataclasses.dataclass(frozen=True)
class Connection:
"""A Connector for Nodes.
instance: either a Node or FunctionFuture
attribute:
Node.attribute
or FunctionFuture.result
or None if the class is passed and not an attribute
Attributes
----------
instance: Node|FunctionFuture
the object this connection points to
attribute: str
Node.attribute
or FunctionFuture.result
or None if the class is passed and not an attribute
item: any
any slice or list index to be applied to the result
"""

instance: any
attribute: any
attribute: str
item: any = None

def __post_init__(self):
Expand Down
Loading

0 comments on commit 60e0f58

Please sign in to comment.