From ee5e46e355bc5ff6975194068c8938f75f402b77 Mon Sep 17 00:00:00 2001 From: Fabian Zills <46721498+PythonFZ@users.noreply.github.com> Date: Tue, 28 Mar 2023 20:58:14 +0100 Subject: [PATCH] fix unexpected output with deployment (#74) * fix futures * format imports --- tests/test_deployment.py | 45 ++++++++++++++++++++++++++++++---------- znflow/deployment.py | 36 +++++++++++++++----------------- 2 files changed, 51 insertions(+), 30 deletions(-) diff --git a/tests/test_deployment.py b/tests/test_deployment.py index c77786f..21eb21b 100644 --- a/tests/test_deployment.py +++ b/tests/test_deployment.py @@ -1,5 +1,7 @@ import dataclasses +import numpy as np + import znflow @@ -30,8 +32,7 @@ def test_single_nodify(): depl = znflow.deployment.Deployment(graph=graph) depl.submit_graph() - node1 = depl.get_results(node1) - assert node1.result == 6 + assert depl.get_results(node1) == 6 def test_single_Node(): @@ -54,12 +55,9 @@ def test_multiple_nodify(): depl = znflow.deployment.Deployment(graph=graph) depl.submit_graph() - node1 = depl.get_results(node1) - node2 = depl.get_results(node2) - node3 = depl.get_results(node3) - assert node1.result == 6 - assert node2.result == 15 - assert node3.result == 21 + assert depl.get_results(node1) == 6 + assert depl.get_results(node2) == 15 + assert depl.get_results(node3) == 21 def test_multiple_Node(): @@ -92,8 +90,33 @@ def test_multiple_nodify_and_Node(): results = depl.get_results(graph.nodes) - assert results[node1.uuid].result == 6 + assert results[node1.uuid] == 6 assert results[node2.uuid].outputs == 15 - assert results[node3.uuid].result == 21 + assert results[node3.uuid] == 21 assert results[node4.uuid].outputs == 42 - assert results[node5.uuid].result == 43 + assert results[node5.uuid] == 43 + + +@znflow.nodify +def get_forces(): + return np.random.normal(size=(100, 3)) + + +@znflow.nodify +def concatenate(forces): + return np.concatenate(forces) + + +def test_concatenate(): + with znflow.DiGraph() as graph: + forces = [get_forces() for _ in range(10)] + forces = concatenate(forces) + + deployment = znflow.deployment.Deployment( + graph=graph, + ) + deployment.submit_graph() + results = deployment.get_results(forces) + + assert isinstance(results, np.ndarray) + assert results.shape == (1000, 3) diff --git a/znflow/deployment.py b/znflow/deployment.py index c236b3e..042a38d 100644 --- a/znflow/deployment.py +++ b/znflow/deployment.py @@ -7,8 +7,9 @@ from dask.distributed import Client, Future from networkx.classes.reportviews import NodeView -from znflow.base import Connection, NodeBaseMixin +from znflow.base import CombinedConnections, Connection, FunctionFuture from znflow.graph import DiGraph +from znflow.node import Node from znflow.utils import IterableHandler @@ -20,23 +21,20 @@ def default(self, value, **kwargs): Parameters ---------- - value: NodeBaseMixin|any - If a NodeBaseMixin, the node will be loaded and returned. + value: any + the value to be loaded from the results dict kwargs: dict - results: results dictionary of {uuid: node} shape. - - Returns - ------- - any: - If a NodeBaseMixin, the node will be loaded and returned. - Otherwise, the input value is returned. - + results: results dictionary of {uuid: Future} shape. """ results = kwargs["results"] - if isinstance(value, NodeBaseMixin): - return results[value.uuid].result() - return value + if isinstance(value, Node): + # results: dict[uuid, DaskFuture] + return results[value.uuid].result() + elif isinstance(value, (FunctionFuture, CombinedConnections, Connection)): + return results[value.uuid].result().result + else: + return value class _UpdateConnections(IterableHandler): @@ -66,19 +64,19 @@ def default(self, value, **kwargs): return value -def node_submit(node: NodeBaseMixin, **kwargs) -> NodeBaseMixin: +def node_submit(node, **kwargs): """Submit script for Dask worker. Parameters ---------- - node: NodeBaseMixin + node: any the Node class kwargs: dict predecessors: dict of {uuid: Connection} shape Returns ------- - NodeBaseMixin: + any: the Node class with updated state (after calling "Node.run"). """ @@ -146,12 +144,12 @@ def submit_graph(self): pure=False, ) - def get_results(self, obj: typing.Union[NodeBaseMixin, list, dict, NodeView], /): + def get_results(self, obj: typing.Union[Node, list, dict, NodeView], /): """Get the results from Dask based on the original object. Parameters ---------- - obj: NodeBaseMixin|list|dict|NodeView + obj: any either a single Node or multiple Nodes from the submitted graph. Returns