Skip to content

Commit

Permalink
fix unexpected output with deployment (#74)
Browse files Browse the repository at this point in the history
* fix futures

* format imports
  • Loading branch information
PythonFZ authored Mar 28, 2023
1 parent be4a986 commit ee5e46e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 30 deletions.
45 changes: 34 additions & 11 deletions tests/test_deployment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import dataclasses

import numpy as np

import znflow


Expand Down Expand Up @@ -30,8 +32,7 @@ def test_single_nodify():
depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

node1 = depl.get_results(node1)
assert node1.result == 6
assert depl.get_results(node1) == 6


def test_single_Node():
Expand All @@ -54,12 +55,9 @@ def test_multiple_nodify():
depl = znflow.deployment.Deployment(graph=graph)
depl.submit_graph()

node1 = depl.get_results(node1)
node2 = depl.get_results(node2)
node3 = depl.get_results(node3)
assert node1.result == 6
assert node2.result == 15
assert node3.result == 21
assert depl.get_results(node1) == 6
assert depl.get_results(node2) == 15
assert depl.get_results(node3) == 21


def test_multiple_Node():
Expand Down Expand Up @@ -92,8 +90,33 @@ def test_multiple_nodify_and_Node():

results = depl.get_results(graph.nodes)

assert results[node1.uuid].result == 6
assert results[node1.uuid] == 6
assert results[node2.uuid].outputs == 15
assert results[node3.uuid].result == 21
assert results[node3.uuid] == 21
assert results[node4.uuid].outputs == 42
assert results[node5.uuid].result == 43
assert results[node5.uuid] == 43


@znflow.nodify
def get_forces():
return np.random.normal(size=(100, 3))


@znflow.nodify
def concatenate(forces):
return np.concatenate(forces)


def test_concatenate():
with znflow.DiGraph() as graph:
forces = [get_forces() for _ in range(10)]
forces = concatenate(forces)

deployment = znflow.deployment.Deployment(
graph=graph,
)
deployment.submit_graph()
results = deployment.get_results(forces)

assert isinstance(results, np.ndarray)
assert results.shape == (1000, 3)
36 changes: 17 additions & 19 deletions znflow/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from dask.distributed import Client, Future
from networkx.classes.reportviews import NodeView

from znflow.base import Connection, NodeBaseMixin
from znflow.base import CombinedConnections, Connection, FunctionFuture
from znflow.graph import DiGraph
from znflow.node import Node
from znflow.utils import IterableHandler


Expand All @@ -20,23 +21,20 @@ def default(self, value, **kwargs):
Parameters
----------
value: NodeBaseMixin|any
If a NodeBaseMixin, the node will be loaded and returned.
value: any
the value to be loaded from the results dict
kwargs: dict
results: results dictionary of {uuid: node} shape.
Returns
-------
any:
If a NodeBaseMixin, the node will be loaded and returned.
Otherwise, the input value is returned.
results: results dictionary of {uuid: Future} shape.
"""
results = kwargs["results"]
if isinstance(value, NodeBaseMixin):
return results[value.uuid].result()

return value
if isinstance(value, Node):
# results: dict[uuid, DaskFuture]
return results[value.uuid].result()
elif isinstance(value, (FunctionFuture, CombinedConnections, Connection)):
return results[value.uuid].result().result
else:
return value


class _UpdateConnections(IterableHandler):
Expand Down Expand Up @@ -66,19 +64,19 @@ def default(self, value, **kwargs):
return value


def node_submit(node: NodeBaseMixin, **kwargs) -> NodeBaseMixin:
def node_submit(node, **kwargs):
"""Submit script for Dask worker.
Parameters
----------
node: NodeBaseMixin
node: any
the Node class
kwargs: dict
predecessors: dict of {uuid: Connection} shape
Returns
-------
NodeBaseMixin:
any:
the Node class with updated state (after calling "Node.run").
"""
Expand Down Expand Up @@ -146,12 +144,12 @@ def submit_graph(self):
pure=False,
)

def get_results(self, obj: typing.Union[NodeBaseMixin, list, dict, NodeView], /):
def get_results(self, obj: typing.Union[Node, list, dict, NodeView], /):
"""Get the results from Dask based on the original object.
Parameters
----------
obj: NodeBaseMixin|list|dict|NodeView
obj: any
either a single Node or multiple Nodes from the submitted graph.
Returns
Expand Down

0 comments on commit ee5e46e

Please sign in to comment.