diff --git a/README.md b/README.md index e170e56..787bf34 100644 --- a/README.md +++ b/README.md @@ -127,26 +127,23 @@ class ComputeMean(znflow.Node): def run(self): self.results = (self.x + self.y) / 2 -with znflow.DiGraph() as graph: + +client = Client() +deployment = znflow.deployment.DaskDeployment(client=client) + + +with znflow.DiGraph(deployment=deployment) as graph: n1 = ComputeMean(2, 8) n2 = compute_mean(13, 7) # connecting classes and functions to a Node n3 = ComputeMean(n1.results, n2) -client = Client() -deployment = znflow.deployment.Deployment(graph=graph, client=client) -deployment.submit_graph() +graph.run() -n3 = deployment.get_results(n3) print(n3) # >>> ComputeMean(x=5.0, y=10.0, results=7.5) ``` -We need to get the updated instance from the Dask worker via -`Deployment.get_results`. Due to the way Dask works, an inplace update is not -possible. To retrieve the full graph, you can use -`Deployment.get_results(graph.nodes)` instead. - ### Working with lists ZnFlow supports some special features for working with lists. In the following diff --git a/pyproject.toml b/pyproject.toml index bc4f86e..e398d5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "znflow" -version = "0.1.15" +version = "0.2.0a0" description = "A general purpose framework for building and running computational graphs." authors = ["zincwarecode "] license = "Apache-2.0" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..99e4f3e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,20 @@ +import pytest +from distributed.utils_test import ( # noqa: F401 + cleanup, + client, + cluster_fixture, + loop, + loop_in_thread, +) + +import znflow + + +@pytest.fixture +def vanilla_deployment(): + return znflow.deployment.VanillaDeployment() + + +@pytest.fixture +def dask_deployment(client): # noqa: F811 + return znflow.deployment.DaskDeployment(client=client) diff --git a/tests/examples/test_ips_lotf.py b/tests/examples/test_ips_lotf.py new file mode 100644 index 0000000..808d4d4 --- /dev/null +++ b/tests/examples/test_ips_lotf.py @@ -0,0 +1,81 @@ +"""Mock version of IPS LotF workflow for testing purposes.""" + +import dataclasses +import random + +import pytest + +import znflow + + +@dataclasses.dataclass +class AddData(znflow.Node): + file: str + + def run(self): + if self.file is None: + raise ValueError("File is None") + print(f"Adding data from {self.file}") + + @property + def atoms(self): + return "Atoms" + + +@dataclasses.dataclass +class TrainModel(znflow.Node): + data: str + model: str = None + + def run(self): + if self.data is None: + raise ValueError("Data is None") + self.model = "Model" + print(f"Model: {self.model}") + + +@dataclasses.dataclass +class MD(znflow.Node): + model: str + atoms: str = None + + def run(self): + if self.model is None: + raise ValueError("Model is None") + self.atoms = "Atoms" + print(f"Atoms: {self.atoms}") + + +@dataclasses.dataclass +class EvaluateModel(znflow.Node): + model: str + seed: int + metrics: float = None + + def run(self): + random.seed(self.seed) + if self.model is None: + raise ValueError("Model is None") + self.metrics = random.random() + print(f"Metrics: {self.metrics}") + + +@pytest.mark.parametrize("deployment", ["vanilla_deployment", "dask_deployment"]) +def test_lotf(deployment, request): + deployment = request.getfixturevalue(deployment) + + graph = znflow.DiGraph(deployment=deployment) + with graph: + data = AddData(file="data.xyz") + model = TrainModel(data=data.atoms) + md = MD(model=model.model) + metrics = EvaluateModel(model=model.model, seed=0) + for idx in range(10): + model = TrainModel(data=md.atoms) + md = MD(model=model.model) + metrics = EvaluateModel(model=model.model, seed=idx) + if znflow.resolve(metrics.metrics) == pytest.approx(0.623, 1e-3): + # break loop after 6th iteration + break + + assert len(graph) == 22 diff --git a/tests/test_deployment.py b/tests/test_deployment.py index 21eb21b..47b7021 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,6 +1,7 @@ import dataclasses import numpy as np +import pytest import znflow @@ -25,76 +26,94 @@ def add_to_ComputeSum(instance: ComputeSum): return instance.outputs + 1 -def test_single_nodify(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_single_nodify(request, deployment): + deployment = request.getfixturevalue(deployment) + + with znflow.DiGraph(deployment=deployment) as graph: node1 = compute_sum(1, 2, 3) - depl = znflow.deployment.Deployment(graph=graph) - depl.submit_graph() + graph.run() - assert depl.get_results(node1) == 6 + assert node1.result == 6 -def test_single_Node(): - with znflow.DiGraph() as graph: - node1 = ComputeSum(inputs=[1, 2, 3]) +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_single_Node(request, deployment): + deployment = request.getfixturevalue(deployment) - depl = znflow.deployment.Deployment(graph=graph) - depl.submit_graph() + with znflow.DiGraph(deployment=deployment) as graph: + node1 = ComputeSum(inputs=[1, 2, 3]) - node1 = depl.get_results(node1) + graph.run() assert node1.outputs == 6 -def test_multiple_nodify(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_multiple_nodify(request, deployment): + deployment = request.getfixturevalue(deployment) + + with znflow.DiGraph(deployment=deployment) as graph: node1 = compute_sum(1, 2, 3) node2 = compute_sum(4, 5, 6) node3 = compute_sum(node1, node2) - depl = znflow.deployment.Deployment(graph=graph) - depl.submit_graph() + graph.run() + + assert node1.result == 6 + assert node2.result == 15 + assert node3.result == 21 - assert depl.get_results(node1) == 6 - assert depl.get_results(node2) == 15 - assert depl.get_results(node3) == 21 +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_multiple_Node(request, deployment): + deployment = request.getfixturevalue(deployment) -def test_multiple_Node(): - with znflow.DiGraph() as graph: + with znflow.DiGraph(deployment=deployment) as graph: node1 = ComputeSum(inputs=[1, 2, 3]) node2 = ComputeSum(inputs=[4, 5, 6]) node3 = ComputeSum(inputs=[node1.outputs, node2.outputs]) - depl = znflow.deployment.Deployment(graph=graph) - depl.submit_graph() + graph.run() - node1 = depl.get_results(node1) - node2 = depl.get_results(node2) - node3 = depl.get_results(node3) assert node1.outputs == 6 assert node2.outputs == 15 assert node3.outputs == 21 -def test_multiple_nodify_and_Node(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_multiple_nodify_and_Node(request, deployment): + deployment = request.getfixturevalue(deployment) + + with znflow.DiGraph(deployment=deployment) as graph: node1 = compute_sum(1, 2, 3) node2 = ComputeSum(inputs=[4, 5, 6]) node3 = compute_sum(node1, node2.outputs) node4 = ComputeSum(inputs=[node1, node2.outputs, node3]) node5 = add_to_ComputeSum(node4) - depl = znflow.deployment.Deployment(graph=graph) - depl.submit_graph() - - results = depl.get_results(graph.nodes) + graph.run() - assert results[node1.uuid] == 6 - assert results[node2.uuid].outputs == 15 - assert results[node3.uuid] == 21 - assert results[node4.uuid].outputs == 42 - assert results[node5.uuid] == 43 + assert node1.result == 6 + assert node2.outputs == 15 + assert node3.result == 21 + assert node4.outputs == 42 + assert node5.result == 43 @znflow.nodify @@ -107,16 +126,18 @@ def concatenate(forces): return np.concatenate(forces) -def test_concatenate(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_concatenate(request, deployment): + deployment = request.getfixturevalue(deployment) + + with znflow.DiGraph(deployment=deployment) as graph: forces = [get_forces() for _ in range(10)] forces = concatenate(forces) - deployment = znflow.deployment.Deployment( - graph=graph, - ) - deployment.submit_graph() - results = deployment.get_results(forces) + graph.run() - assert isinstance(results, np.ndarray) - assert results.shape == (1000, 3) + assert isinstance(forces.result, np.ndarray) + assert forces.result.shape == (1000, 3) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index db15991..cb5ffa2 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -1,5 +1,7 @@ import dataclasses +import pytest + import znflow @@ -14,9 +16,15 @@ def run(self): self.outputs = self.inputs + 1 -def test_break_loop(): +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_break_loop(request, deployment): """Test loop breaking when output exceeds 5.""" - graph = znflow.DiGraph() + deployment = request.getfixturevalue(deployment) + + graph = znflow.DiGraph(deployment=deployment) with graph: node1 = AddOne(inputs=1) for _ in range(10): @@ -33,9 +41,14 @@ def test_break_loop(): assert node1.outputs == 6 -def test_break_loop_multiple(): +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_break_loop_multiple(request, deployment): """Test loop breaking with multiple nodes and different conditions.""" - graph = znflow.DiGraph() + deployment = request.getfixturevalue(deployment) + graph = znflow.DiGraph(deployment=deployment) with graph: node1 = AddOne(inputs=1) node2 = AddOne(inputs=node1.outputs) # Add another node in the loop @@ -67,10 +80,15 @@ def test_break_loop_multiple(): ) -def test_resolvce_only_run_relevant_nodes(): +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_resolvce_only_run_relevant_nodes(request, deployment): """Test that when using resolve only nodes that are direct predecessors are run.""" # Check by asserting None to the output of the second node - graph = znflow.DiGraph() + deployment = request.getfixturevalue(deployment) + graph = znflow.DiGraph(deployment=deployment) with graph: node1 = AddOne(inputs=1) node2 = AddOne(inputs=1234) @@ -90,8 +108,13 @@ def test_resolvce_only_run_relevant_nodes(): assert node1.outputs == 6 -def test_connections_remain(): - graph = znflow.DiGraph() +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_connections_remain(request, deployment): + deployment = request.getfixturevalue(deployment) + graph = znflow.DiGraph(deployment=deployment) with graph: node1 = AddOne(inputs=1) result = znflow.resolve(node1.outputs) @@ -99,8 +122,13 @@ def test_connections_remain(): assert isinstance(node1.outputs, znflow.Connection) -def test_loop_over_results(): - graph = znflow.DiGraph() +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_loop_over_results(request, deployment): + deployment = request.getfixturevalue(deployment) + graph = znflow.DiGraph(deployment=deployment) with graph: node1 = AddOne(inputs=5) nodes = [] diff --git a/tests/test_external_node.py b/tests/test_external_node.py index f24dff7..1c3b4ab 100644 --- a/tests/test_external_node.py +++ b/tests/test_external_node.py @@ -5,6 +5,8 @@ import dataclasses +import pytest + import znflow @@ -18,8 +20,13 @@ def run(self): self.value = 42 -def test_external_node_run(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment"], +) +def test_external_node_run(deployment, request): + deployment = request.getfixturevalue(deployment) + with znflow.DiGraph(deployment=deployment) as graph: node = NodeWithExternal() graph.run() @@ -71,10 +78,15 @@ def run(self) -> None: self.result = sum(self.inputs) -def test_external_node(): +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment"], +) +def test_external_node(deployment, request): + deployment = request.getfixturevalue(deployment) node = ExternalNode() - with znflow.DiGraph() as graph: + with znflow.DiGraph(deployment=deployment) as graph: add_number = AddNumber(shift=1, input=node.number) graph.run() @@ -83,10 +95,15 @@ def test_external_node(): assert add_number.result == 43 -def test_external_node_from_node(): +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment"], +) +def test_external_node_from_node(deployment, request): + deployment = request.getfixturevalue(deployment) node = ExternalNode() - with znflow.DiGraph() as graph: + with znflow.DiGraph(deployment=deployment) as graph: add_number = AddNumberFromNodes(shift=1, input=node) graph.run() @@ -95,11 +112,16 @@ def test_external_node_from_node(): assert add_number.result == 43 -def test_external_node_lists(): +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment"], +) +def test_external_node_lists(deployment, request): + deployment = request.getfixturevalue(deployment) node1 = ExternalNode() node2 = ExternalNode() - with znflow.DiGraph() as graph: + with znflow.DiGraph(deployment=deployment) as graph: sum_numbers = SumNumbers(inputs=[node1.number, node2.number]) graph.run() diff --git a/tests/test_node.py b/tests/test_node.py index f9c5236..fac0245 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -291,8 +291,13 @@ def run(self): return sum(self.nodes) -def test_DictionaryConnection(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_DictionaryConnection(deployment, request): + deployment = request.getfixturevalue(deployment) + with znflow.DiGraph(deployment=deployment) as graph: node1 = PlainNode(value=42) node2 = PlainNode(value=42) node3 = DictionaryConnection(nodes={"node1": node1.value, "node2": node2.value}) @@ -328,8 +333,13 @@ def test_DictionaryConnection(): assert edge2[0]["v_attr"] == "nodes" -def test_ListConnection(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_ListConnection(deployment, request): + deployment = request.getfixturevalue(deployment) + with znflow.DiGraph(deployment=deployment) as graph: node1 = PlainNode(value=42) node2 = PlainNode(value=42) node3 = ListConnection(nodes=[node1.value, node2.value]) diff --git a/tests/test_node_func.py b/tests/test_node_func.py index 3b8108e..ea46156 100644 --- a/tests/test_node_func.py +++ b/tests/test_node_func.py @@ -1,6 +1,8 @@ import dataclasses import random +import pytest + import znflow @@ -30,8 +32,13 @@ def test_eager(): assert n3 == 0.2903973544626711 -def test_graph(): - with znflow.DiGraph() as graph: +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) +def test_graph(deployment, request): + deployment = request.getfixturevalue(deployment) + with znflow.DiGraph(deployment=deployment) as graph: n1 = random_number(5) n2 = random_number(10) compute_sum = ComputeSum(inputs=[n1, n2]) diff --git a/tests/test_node_postinit.py b/tests/test_node_postinit.py index 936d055..ca5003d 100644 --- a/tests/test_node_postinit.py +++ b/tests/test_node_postinit.py @@ -59,11 +59,16 @@ def test_ConvertInputs(cls): assert node3.result == 3.0 +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment", "dask_deployment"], +) @pytest.mark.parametrize( "cls", [ConvertInputsPlain, ConverInputsZnInit, ConvertInputsDataclass] ) -def test_ConvertInputsNoAttribute(cls): - with znflow.DiGraph() as graph: +def test_ConvertInputsNoAttribute(cls, deployment, request): + deployment = request.getfixturevalue(deployment) + with znflow.DiGraph(deployment=deployment) as graph: node1 = cls(inputs="1") node2 = cls(inputs="2") node3 = compute_sum_inputs(node1, node2) diff --git a/tests/test_recursion.py b/tests/test_recursion.py index 3e6d1eb..3e0bc51 100644 --- a/tests/test_recursion.py +++ b/tests/test_recursion.py @@ -26,10 +26,16 @@ def setrecursionlimit(limit: int): sys.setrecursionlimit(_limit) +@pytest.mark.parametrize( + "deployment", + ["vanilla_deployment"], + # "dask_deployment" struggles with recursion limit +) @pytest.mark.parametrize("depth", [1, 10, 100, 1000]) -def test_AddOneLoop(depth): +def test_AddOneLoop(depth, deployment, request): + deployment = request.getfixturevalue(deployment) with setrecursionlimit(100): - with znflow.DiGraph() as graph: + with znflow.DiGraph(deployment=deployment) as graph: start = AddOne(0) for _ in range(depth): start = AddOne(start.x) diff --git a/tests/test_znflow.py b/tests/test_znflow.py index 618ef94..530eb44 100644 --- a/tests/test_znflow.py +++ b/tests/test_znflow.py @@ -5,4 +5,4 @@ def test_version(): """Test the version.""" - assert znflow.__version__ == "0.1.15" + assert znflow.__version__ == "0.2.0a0" diff --git a/znflow/__init__.py b/znflow/__init__.py index cd119e1..27bff02 100644 --- a/znflow/__init__.py +++ b/znflow/__init__.py @@ -1,11 +1,10 @@ """The 'ZnFlow' package.""" -import contextlib import importlib.metadata import logging import sys -from znflow import exceptions +from znflow import deployment, exceptions from znflow.base import ( CombinedConnections, Connection, @@ -41,13 +40,9 @@ "empty_graph", "resolve", "Group", + "deployment", ] -with contextlib.suppress(ImportError): - from znflow import deployment - - __all__ += ["deployment"] - logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) diff --git a/znflow/deployment.py b/znflow/deployment.py deleted file mode 100644 index ed712e3..0000000 --- a/znflow/deployment.py +++ /dev/null @@ -1,117 +0,0 @@ -"""ZnFlow deployment using Dask.""" - -import dataclasses -import typing -import uuid - -from dask.distributed import Client, Future -from networkx.classes.reportviews import NodeView - -from znflow.graph import DiGraph -from znflow.handler import ( - LoadNodeFromDeploymentResults, - UpdateConnectionsWithPredecessor, -) -from znflow.node import Node - - -def node_submit(node, **kwargs): - """Submit script for Dask worker. - - Parameters - ---------- - node: any - the Node class - kwargs: dict - predecessors: dict of {uuid: Connection} shape - - Returns - ------- - any: - the Node class with updated state (after calling "Node.run"). - - """ - predecessors = kwargs.get("predecessors", {}) - updater = UpdateConnectionsWithPredecessor() - for item in dir(node): - # TODO this information is available in the graph, - # no need to expensively iterate over all attributes - if item.startswith("_"): - continue - value = updater(getattr(node, item), predecessors=predecessors) - if updater.updated: - setattr(node, item, value) - - node.run() - return node - - -@dataclasses.dataclass -class Deployment: - """ZnFlow deployment using Dask. - - Attributes - ---------- - graph: DiGraph - the znflow graph containing the nodes. - client: Client, optional - the Dask client. - results: Dict[uuid, Future] - a dictionary of {uuid: Future} shape that is filled after the graph is submitted. - - """ - - graph: DiGraph - client: Client = dataclasses.field(default_factory=Client) - results: typing.Dict[uuid.UUID, Future] = dataclasses.field( - default_factory=dict, init=False - ) - - def submit_graph(self): - """Submit the graph to Dask. - - When submitting to Dask, a Node is serialized, processed and a - copy can be returned. - - This requires: - - the connections to be updated to the respective Nodes coming from Dask futures. - - the Node to be returned from the workers and passed to all successors. - """ - for node_uuid in self.graph.reverse(): - node = self.graph.nodes[node_uuid]["value"] - predecessors = list(self.graph.predecessors(node.uuid)) - - if len(predecessors) == 0: - self.results[node.uuid] = self.client.submit( # TODO how to name - node_submit, node=node, pure=False - ) - else: - self.results[node.uuid] = self.client.submit( - node_submit, - node=node, - predecessors={ - x: self.results[x] for x in self.results if x in predecessors - }, - pure=False, - ) - - def get_results(self, obj: typing.Union[Node, list, dict, NodeView], /): - """Get the results from Dask based on the original object. - - Parameters - ---------- - obj: any - either a single Node or multiple Nodes from the submitted graph. - - Returns - ------- - any: - Returns an instance of obj which is updated with the results from Dask. - - """ - if isinstance(obj, NodeView): - data = LoadNodeFromDeploymentResults()(dict(obj), results=self.results) - return {x: v["value"] for x, v in data.items()} - elif isinstance(obj, DiGraph): - raise NotImplementedError - return LoadNodeFromDeploymentResults()(obj, results=self.results) diff --git a/znflow/deployment/__init__.py b/znflow/deployment/__init__.py new file mode 100644 index 0000000..e7368f6 --- /dev/null +++ b/znflow/deployment/__init__.py @@ -0,0 +1,10 @@ +import contextlib + +from .vanilla import VanillaDeployment + +__all__ = ["VanillaDeployment"] + +with contextlib.suppress(ImportError): + from .dask_depl import DaskDeployment + + __all__ += ["DaskDeployment"] diff --git a/znflow/deployment/base.py b/znflow/deployment/base.py new file mode 100644 index 0000000..180ec5a --- /dev/null +++ b/znflow/deployment/base.py @@ -0,0 +1,26 @@ +import abc +import typing as t + +if t.TYPE_CHECKING: + from znflow.graph import DiGraph + + +class DeploymentBase(abc.ABC): + graph: "DiGraph" + + def run(self, nodes: t.Optional[t.List] = None): + if nodes is None: + nodes = self.graph.get_sorted_nodes() + else: + # convert nodes to UUIDs + nodes = [node.uuid for node in nodes] + + for node_uuid in nodes: + self._run_node(node_uuid) + + def set_graph(self, graph: "DiGraph"): + self.graph = graph + + @abc.abstractmethod + def _run_node(self, node_uuid): + pass diff --git a/znflow/deployment/dask_depl.py b/znflow/deployment/dask_depl.py new file mode 100644 index 0000000..a0c16d3 --- /dev/null +++ b/znflow/deployment/dask_depl.py @@ -0,0 +1,101 @@ +"""ZnFlow deployment using Dask.""" + +import dataclasses +import typing +import typing as t +import uuid + +from dask.distributed import Client, Future + +from znflow import handler +from znflow.handler import UpdateConnectionsWithPredecessor +from znflow.node import Node + +from .base import DeploymentBase + +if typing.TYPE_CHECKING: + pass + + +def node_submit(node, **kwargs): + """Submit script for Dask worker. + + Parameters + ---------- + node: any + the Node class + kwargs: dict + predecessors: dict of {uuid: Connection} shape + + Returns + ------- + any: + the Node class with updated state (after calling "Node.run"). + + """ + predecessors = kwargs.get("predecessors", {}) + updater = UpdateConnectionsWithPredecessor() + for item in dir(node): + # TODO this information is available in the graph, + # no need to expensively iterate over all attributes + if item.startswith("_"): + continue + value = updater(getattr(node, item), predecessors=predecessors) + if updater.updated: + setattr(node, item, value) + + node.run() + return node + + +# TODO: release the future objects +@dataclasses.dataclass +class DaskDeployment(DeploymentBase): + client: Client = dataclasses.field(default_factory=Client) + results: typing.Dict[uuid.UUID, Future] = dataclasses.field( + default_factory=dict, init=False + ) + + def run(self, nodes: t.Optional[list] = None): + super().run(nodes) + self._load_results() + + def _run_node(self, node_uuid): + node = self.graph.nodes[node_uuid]["value"] + predecessors = list(self.graph.predecessors(node_uuid)) + for predecessor in predecessors: + predecessor_available = self.graph.nodes[predecessor].get("available", False) + if self.graph.immutable_nodes and predecessor_available: + continue + self._run_node(predecessor) + + node_available = self.graph.nodes[node_uuid].get("available", False) + if self.graph.immutable_nodes and node_available: + return + if node._external_: + raise NotImplementedError( + "External nodes are not supported in Dask deployment" + ) + + self.results[node_uuid] = self.client.submit( + node_submit, + node=node, + predecessors={x: self.results[x] for x in self.results if x in predecessors}, + pure=False, + key=f"{node.__class__.__name__}-{node_uuid}", + ) + self.graph.nodes[node_uuid]["available"] = True + + def _load_results(self): + # TODO: only load nodes that have actually changed + for node_uuid in self.graph.reverse(): + node = self.graph.nodes[node_uuid]["value"] + try: + result = self.results[node.uuid].result() + if isinstance(node, Node): + node.__dict__.update(result.__dict__) + self.graph._update_node_attributes(node, handler.UpdateConnectors()) + else: + node.result = result.result + except KeyError: + pass diff --git a/znflow/deployment/vanilla.py b/znflow/deployment/vanilla.py new file mode 100644 index 0000000..4b3f714 --- /dev/null +++ b/znflow/deployment/vanilla.py @@ -0,0 +1,28 @@ +import dataclasses + +from znflow import handler + +from .base import DeploymentBase + + +@dataclasses.dataclass +class VanillaDeployment(DeploymentBase): + + def _run_node(self, node_uuid): + node = self.graph.nodes[node_uuid]["value"] + predecessors = list(self.graph.predecessors(node_uuid)) + for predecessor in predecessors: + predecessor_available = self.graph.nodes[predecessor].get("available", False) + if self.graph.immutable_nodes and predecessor_available: + continue + self._run_node(predecessor) + + node_available = self.graph.nodes[node_uuid].get("available", False) + if self.graph.immutable_nodes and node_available: + return + if node._external_: + return + + self.graph._update_node_attributes(node, handler.UpdateConnectors()) + node.run() + self.graph.nodes[node_uuid]["available"] = True diff --git a/znflow/graph.py b/znflow/graph.py index e432fb4..af7de92 100644 --- a/znflow/graph.py +++ b/znflow/graph.py @@ -16,6 +16,7 @@ get_graph, set_graph, ) +from znflow.deployment import VanillaDeployment from znflow.node import Node log = logging.getLogger(__name__) @@ -45,7 +46,9 @@ def nodes(self) -> typing.List[NodeBaseMixin]: class DiGraph(nx.MultiDiGraph): - def __init__(self, *args, disable=False, immutable_nodes=True, **kwargs): + def __init__( + self, *args, disable=False, immutable_nodes=True, deployment=None, **kwargs + ): """ Attributes ---------- @@ -58,6 +61,8 @@ def __init__(self, *args, disable=False, immutable_nodes=True, **kwargs): self.immutable_nodes = immutable_nodes self._groups = {} self.active_group: typing.Union[Group, None] = None + self.deployment = deployment or VanillaDeployment() + self.deployment.set_graph(self) super().__init__(*args, **kwargs) @@ -190,40 +195,7 @@ def run( nodes : list[Node] The nodes to run. If None, all nodes are run. """ - if nodes is not None: - for node_uuid in self.reverse(): - if self.immutable_nodes and self.nodes[node_uuid].get("available", False): - continue - node = self.nodes[node_uuid]["value"] - if node in nodes: - predecessors = list(self.predecessors(node.uuid)) - for predecessor in predecessors: - predecessor_node = self.nodes[predecessor]["value"] - if self.immutable_nodes and self.nodes[predecessor].get( - "available", False - ): - continue - self._update_node_attributes( - predecessor_node, handler.UpdateConnectors() - ) - predecessor_node.run() - if self.immutable_nodes: - self.nodes[predecessor]["available"] = True - self._update_node_attributes(node, handler.UpdateConnectors()) - node.run() - if self.immutable_nodes: - self.nodes[node_uuid]["available"] = True - else: - for node_uuid in self.get_sorted_nodes(): - if self.immutable_nodes and self.nodes[node_uuid].get("available", False): - continue - node = self.nodes[node_uuid]["value"] - if not node._external_: - # update connectors - self._update_node_attributes(node, handler.UpdateConnectors()) - node.run() - if self.immutable_nodes: - self.nodes[node_uuid]["available"] = True + self.deployment.run(nodes) def write_graph(self, *args): for node in args: