-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding dask support * improve dask support * add IterableHandler.updated * support NodeView, remove support for graph * add TODO * install all extras * fix + add tests for deployment * update README.md * move bokeh to extras * update documentation * update documentation * poetry update * isort * typo * bump version
- Loading branch information
Showing
9 changed files
with
659 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "znflow" | ||
version = "0.1.9" | ||
version = "0.1.10" | ||
description = "A general purpose framework for building and running computational graphs." | ||
authors = ["zincwarecode <[email protected]>"] | ||
license = "Apache-2.0" | ||
|
@@ -11,6 +11,11 @@ python = "^3.8" | |
networkx = "^3.0" | ||
matplotlib = "^3.6.3" | ||
|
||
dask = { version = "^2022.12.1", optional = true } | ||
distributed = { version = "^2022.12.1", optional = true } | ||
dask-jobqueue = { version = "^0.8.1", optional = true } | ||
bokeh = { version = "^2.4.2", optional = true } | ||
|
||
[tool.poetry.group.lint.dependencies] | ||
black = "^22.10.0" | ||
isort = "^5.10.1" | ||
|
@@ -25,6 +30,10 @@ attrs = "^22.2.0" | |
[tool.poetry.group.notebook.dependencies] | ||
jupyterlab = "^3.5.1" | ||
|
||
[tool.poetry.extras] | ||
dask = ["dask", "distributed", "dask-jobqueue", "bokeh"] | ||
|
||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import dataclasses | ||
|
||
import znflow | ||
|
||
|
||
@znflow.nodify
def compute_sum(*args):
    """Return the sum of all positional arguments."""
    total = sum(args)
    return total
|
||
|
||
@dataclasses.dataclass
class ComputeSum(znflow.Node):
    """Node that stores the sum of ``inputs`` in ``outputs`` once it is run."""

    inputs: list
    outputs: float = None

    def run(self):
        """Sum ``inputs`` and store the value on ``outputs``."""
        # Calling the nodified function here executes it directly;
        # no graph is constructed at this point.
        total = compute_sum(*self.inputs)
        self.outputs = total
|
||
|
||
@znflow.nodify
def add_to_ComputeSum(instance: ComputeSum):
    """Return ``instance.outputs`` incremented by one."""
    incremented = instance.outputs + 1
    return incremented
|
||
|
||
def test_single_nodify():
    """Deploy a graph holding one nodified function and check its result."""
    with znflow.DiGraph() as workflow:
        summed = compute_sum(1, 2, 3)

    deployment = znflow.deployment.Deployment(graph=workflow)
    deployment.submit_graph()

    loaded = deployment.get_results(summed)
    assert loaded.result == 6
|
||
|
||
def test_single_Node():
    """Deploy a graph holding one ``ComputeSum`` node and check its output."""
    with znflow.DiGraph() as workflow:
        summed = ComputeSum(inputs=[1, 2, 3])

    deployment = znflow.deployment.Deployment(graph=workflow)
    deployment.submit_graph()

    loaded = deployment.get_results(summed)
    assert loaded.outputs == 6
|
||
|
||
def test_multiple_nodify():
    """Deploy three chained nodified functions and check every result."""
    with znflow.DiGraph() as workflow:
        first = compute_sum(1, 2, 3)
        second = compute_sum(4, 5, 6)
        combined = compute_sum(first, second)

    deployment = znflow.deployment.Deployment(graph=workflow)
    deployment.submit_graph()

    # Fetch every result first, then compare, mirroring the submission order.
    loaded = [deployment.get_results(item) for item in (first, second, combined)]
    first_loaded, second_loaded, combined_loaded = loaded
    assert first_loaded.result == 6
    assert second_loaded.result == 15
    assert combined_loaded.result == 21
|
||
|
||
def test_multiple_Node():
    """Deploy three chained ``ComputeSum`` nodes and check every output."""
    with znflow.DiGraph() as workflow:
        first = ComputeSum(inputs=[1, 2, 3])
        second = ComputeSum(inputs=[4, 5, 6])
        combined = ComputeSum(inputs=[first.outputs, second.outputs])

    deployment = znflow.deployment.Deployment(graph=workflow)
    deployment.submit_graph()

    # Fetch every result first, then compare, mirroring the submission order.
    loaded = [deployment.get_results(item) for item in (first, second, combined)]
    first_loaded, second_loaded, combined_loaded = loaded
    assert first_loaded.outputs == 6
    assert second_loaded.outputs == 15
    assert combined_loaded.outputs == 21
|
||
|
||
def test_multiple_nodify_and_Node():
    """Deploy a graph mixing nodified functions and Nodes; check all results."""
    with znflow.DiGraph() as workflow:
        func_a = compute_sum(1, 2, 3)
        node_b = ComputeSum(inputs=[4, 5, 6])
        func_c = compute_sum(func_a, node_b.outputs)
        node_d = ComputeSum(inputs=[func_a, node_b.outputs, func_c])
        func_e = add_to_ComputeSum(node_d)

    deployment = znflow.deployment.Deployment(graph=workflow)
    deployment.submit_graph()

    # Passing the NodeView returns a mapping keyed by node uuid.
    loaded = deployment.get_results(workflow.nodes)

    assert loaded[func_a.uuid].result == 6
    assert loaded[node_b.uuid].outputs == 15
    assert loaded[func_c.uuid].result == 21
    assert loaded[node_d.uuid].outputs == 42
    assert loaded[func_e.uuid].result == 43
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
"""ZnFlow deployment using Dask.""" | ||
|
||
import dataclasses | ||
import typing | ||
import uuid | ||
|
||
from dask.distributed import Client, Future | ||
from networkx.classes.reportviews import NodeView | ||
|
||
from znflow.base import Connection, NodeBaseMixin | ||
from znflow.graph import DiGraph | ||
from znflow.utils import IterableHandler | ||
|
||
|
||
class _LoadNode(IterableHandler):
    """Iterable handler that swaps nodes for their results."""

    def default(self, value, **kwargs):
        """Resolve a single value against the results dictionary.

        Parameters
        ----------
        value: NodeBaseMixin|any
            The item to resolve.
        kwargs: dict
            results: results dictionary of {uuid: node} shape. This key
                is required and is looked up for every value.

        Returns
        -------
        any:
            For a NodeBaseMixin, ``results[value.uuid].result()``.
            Any other value is returned unchanged.
        """
        futures = kwargs["results"]
        if not isinstance(value, NodeBaseMixin):
            return value
        return futures[value.uuid].result()
|
||
|
||
class _UpdateConnections(IterableHandler):
    """Iterable handler that swaps connections for the values they point to."""

    def default(self, value, **kwargs):
        """Resolve a single value against the predecessors dictionary.

        Parameters
        ----------
        value: Connection|any
            The item to resolve.
        kwargs: dict
            predecessors: dict of {uuid: Connection} shape. This key is
                required and is looked up for every value.

        Returns
        -------
        any:
            For a Connection, the ``result`` of a copy of that connection
            whose ``instance`` is ``predecessors[value.uuid]``.
            Any other value is returned unchanged.
        """
        upstream = kwargs["predecessors"]
        if not isinstance(value, Connection):
            return value
        # Only the value behind the connection is of interest, not the
        # connection object itself.
        rebound = dataclasses.replace(value, instance=upstream[value.uuid])
        return rebound.result
|
||
|
||
def node_submit(node: NodeBaseMixin, **kwargs) -> NodeBaseMixin:
    """Run a node on a Dask worker after resolving its connections.

    Parameters
    ----------
    node: NodeBaseMixin
        the Node class
    kwargs: dict
        predecessors: dict of {uuid: Connection} shape. Defaults to an
            empty dict when not given.

    Returns
    -------
    NodeBaseMixin:
        the same Node after "Node.run" has been called on it.
    """
    upstream = kwargs.get("predecessors", {})

    # TODO this information is available in the graph,
    #  no need to expensively iterate over all attributes
    public_names = [name for name in dir(node) if not name.startswith("_")]
    for name in public_names:
        # A fresh handler per attribute, so its ``updated`` flag only
        # reflects the attribute currently being processed.
        handler = _UpdateConnections()
        resolved = handler(getattr(node, name), predecessors=upstream)
        if handler.updated:
            setattr(node, name, resolved)

    node.run()
    return node
|
||
|
||
@dataclasses.dataclass
class Deployment:
    """ZnFlow deployment using Dask.

    Attributes
    ----------
    graph: DiGraph
        the znflow graph containing the nodes.
    client: Client, optional
        the Dask client. A new ``Client()`` is created if none is given.
    results: Dict[uuid, Future]
        a dictionary of {uuid: Future} shape that is filled after the graph
        is submitted.
    """

    graph: DiGraph
    client: Client = dataclasses.field(default_factory=Client)
    results: typing.Dict[uuid.UUID, Future] = dataclasses.field(
        default_factory=dict, init=False
    )

    def submit_graph(self):
        """Submit the graph to Dask.

        When submitting to Dask, a Node is serialized, processed and a
        copy can be returned.

        This requires:
        - the connections to be updated to the respective Nodes coming
          from Dask futures.
        - the Node to be returned from the workers and passed to all
          successors.

        Nodes are submitted in topological order, so the future of every
        predecessor exists before a node that depends on it is submitted.

        Raises
        ------
        networkx.NetworkXUnfeasible:
            if the graph contains a cycle and therefore has no
            topological order.
        """
        # Local import keeps this block self-contained; networkx is already
        # a dependency of this module.
        from networkx import topological_sort

        # NOTE: "graph.reverse()" was used here before. It builds a deep copy
        # of the whole graph and iterates it in insertion order, which is
        # not a dependency order.
        for node_uuid in topological_sort(self.graph):
            node = self.graph.nodes[node_uuid]["value"]
            predecessors = {
                x: self.results[x] for x in self.graph.predecessors(node.uuid)
            }
            # An empty "predecessors" dict is equivalent to omitting the
            # argument, "node_submit" falls back to an empty dict itself.
            self.results[node.uuid] = self.client.submit(  # TODO how to name
                node_submit,
                node=node,
                predecessors=predecessors,
                pure=False,
            )

    def get_results(self, obj: typing.Union[NodeBaseMixin, list, dict, NodeView], /):
        """Get the results from Dask based on the original object.

        Parameters
        ----------
        obj: NodeBaseMixin|list|dict|NodeView
            either a single Node or multiple Nodes from the submitted graph.

        Returns
        -------
        any:
            Returns an instance of obj which is updated with the results
            from Dask. For a NodeView this is a dict of {uuid: node} shape.

        Raises
        ------
        NotImplementedError:
            if a DiGraph is passed.
        """
        if isinstance(obj, NodeView):
            data = _LoadNode()(dict(obj), results=self.results)
            # Unwrap the node attribute dicts of the view.
            return {x: v["value"] for x, v in data.items()}
        if isinstance(obj, DiGraph):
            raise NotImplementedError(
                "Passing a DiGraph is not supported, pass 'graph.nodes' instead."
            )
        return _LoadNode()(obj, results=self.results)