Skip to content

Commit

Permalink
better support for properties (#42)
Browse files Browse the repository at this point in the history
* better support for properties

* resolve attribute access issue
  • Loading branch information
PythonFZ authored Mar 8, 2023
1 parent ffb261d commit a24ee7c
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 15 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ assert n1.results == 4.0

```

Instead, you can also use the ``znflow.disable_graph`` decorator / context manager to disable the graph for a specific block of code or the ``znflow.Property`` as a drop-in replacement for ``property``.


# Supported Frameworks
ZnFlow includes tests to ensure compatibility with:
- "Plain classes"
Expand Down
96 changes: 88 additions & 8 deletions tests/test_get_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import znflow


class POW2(znflow.Node):
x_factor: float = 0.5
class POW2Base(znflow.Node):
x_factor: float = 1.0
results: float = None
_x: float = None

Expand All @@ -22,6 +22,11 @@ def x_(self, value):
"""
self._x = value * self.x_factor

def run(self):
self.results = self.x**1


class POW2GetAttr(POW2Base):
@property
def x(self):
return self._x
Expand All @@ -30,21 +35,96 @@ def x(self):
def x(self, value):
self._x = value * znflow.get_attribute(self, "x_factor")

def run(self):
self.results = self.x**2

class POW2Decorate(POW2Base):
@property
def x(self):
return self._x

@znflow.disable_graph()
@x.setter
def x(self, value):
self._x = value * self.x_factor


class POW2Decorate2(POW2Base):
@znflow.Property
def x(self):
return self._x

@x.setter
def x(self, value):
self._x = value * self.x_factor


class POW2Context(POW2Base):
@property
def x(self):
return self._x

@x.setter
def x(self, value):
with znflow.disable_graph():
self._x = value * self.x_factor

def test_get_attribute():

@pytest.mark.parametrize("cls", [POW2GetAttr, POW2Decorate, POW2Context, POW2Decorate2])
def test_get_attribute(cls):
with znflow.DiGraph() as graph:
n1 = POW2()
n1 = cls()
n1.x = 4.0 # converted to 2.0

graph.run()
assert n1.x == 2.0
assert n1.x == 4.0
assert n1.results == 4.0

with znflow.DiGraph() as graph:
n1 = POW2()
n1 = cls()
with pytest.raises(TypeError):
# TypeError: unsupported operand type(s) for *: 'float' and 'Connection'
n1.x_ = 4.0


class InvalidAttribute(znflow.Node):
@property
def invalid_attribute(self):
raise ValueError("attribute not available")


def test_invalid_attribute():
node = InvalidAttribute()
with pytest.raises(ValueError):
node.invalid_attribute

with znflow.DiGraph() as graph:
node = InvalidAttribute()
invalid_attribute = node.invalid_attribute
assert isinstance(invalid_attribute, znflow.Connection)
assert invalid_attribute.instance == node
assert invalid_attribute.attribute == "invalid_attribute"
assert node.uuid in graph


class NodeWithInit(znflow.Node):
def __init__(self):
self.x = 1.0


def test_attribute_not_found():
"""Try to access an Attribute which does not exist."""
with pytest.raises(AttributeError):
node = InvalidAttribute()
node.this_does_not_exist

with znflow.DiGraph():
node = POW2GetAttr()
with pytest.raises(AttributeError):
node.this_does_not_exist

with znflow.DiGraph():
node = NodeWithInit()
with pytest.raises(AttributeError):
node.this_does_not_exist
outs = node.x

assert outs.result == 1.0
19 changes: 19 additions & 0 deletions tests/test_node_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ def run(self):
self.outputs = sum(self.inputs)


@dataclasses.dataclass
class SumNodesFromDict(znflow.Node):
inputs: dict
outputs: float = None

def run(self):
self.outputs = sum(self.inputs.values())


def test_eager():
node = Node(inputs=1)
node.run()
Expand Down Expand Up @@ -102,3 +111,13 @@ def test_graph_multi():
graph.run()

assert node7.outputs == 80


def test_SumNodesFromDict():
with znflow.DiGraph() as graph:
node1 = Node(inputs=5)
node2 = Node(inputs=10)
node3 = SumNodesFromDict(inputs={"a": node1.outputs, "b": node2.outputs})
graph.run()

assert node3.outputs == 30
10 changes: 9 additions & 1 deletion znflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import logging
import sys

from znflow.base import Connection, FunctionFuture, get_attribute
from znflow.base import (
Connection,
FunctionFuture,
Property,
disable_graph,
get_attribute,
)
from znflow.graph import DiGraph
from znflow.node import Node, nodify
from znflow.visualize import draw
Expand All @@ -18,6 +24,8 @@
"FunctionFuture",
"Connection",
"get_attribute",
"disable_graph",
"Property",
]

logger = logging.getLogger(__name__)
Expand Down
55 changes: 54 additions & 1 deletion znflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@contextlib.contextmanager
def disable_graph():
def disable_graph(*args, **kwargs):
"""Temporarily disable set the graph to None.
This can be useful, if you e.g. want to use 'get_attribute'.
Expand All @@ -20,6 +20,59 @@ def disable_graph():
set_graph(graph)


class Property:
"""Custom Property with disabled graph.
References
----------
Adapted from https://docs.python.org/3/howto/descriptor.html#properties
"""

def __init__(self, fget=None, fset=None, fdel=None, doc=None):
self.fget = disable_graph()(fget)
self.fset = disable_graph()(fset)
self.fdel = disable_graph()(fdel)
if doc is None and fget is not None:
doc = fget.__doc__
self.__doc__ = doc
self._name = ""

def __set_name__(self, owner, name):
self._name = name

def __get__(self, obj, objtype=None):
if obj is None:
return self
if self.fget is None:
raise AttributeError(f"property '{self._name}' has no getter")
return self.fget(obj)

def __set__(self, obj, value):
if self.fset is None:
raise AttributeError(f"property '{self._name}' has no setter")
self.fset(obj, value)

def __delete__(self, obj):
if self.fdel is None:
raise AttributeError(f"property '{self._name}' has no deleter")
self.fdel(obj)

def getter(self, fget):
prop = type(self)(fget, self.fset, self.fdel, self.__doc__)
prop._name = self._name
return prop

def setter(self, fset):
prop = type(self)(self.fget, fset, self.fdel, self.__doc__)
prop._name = self._name
return prop

def deleter(self, fdel):
prop = type(self)(self.fget, self.fset, fdel, self.__doc__)
prop._name = self._name
return prop


class NodeBaseMixin:
"""A Parent for all Nodes.
Expand Down
8 changes: 7 additions & 1 deletion znflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,13 @@ def _update_node_attributes(self, node_instance: Node, updater) -> None:
if attribute.startswith("_") or attribute in Node._protected_:
# We do not allow connections to private attributes.
continue
value = getattr(node_instance, attribute)
try:
value = getattr(node_instance, attribute)
except Exception:
# It might be, that the value is currently not available.
# For example, it could be a property that is not yet set.
# In this case we skip updating the attribute, no matter the exception.
continue
value = updater(value)
if updater.updated:
try:
Expand Down
19 changes: 15 additions & 4 deletions znflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import inspect
import uuid

from znflow.base import Connection, FunctionFuture, NodeBaseMixin, get_graph
from znflow.base import (
Connection,
FunctionFuture,
NodeBaseMixin,
disable_graph,
get_graph,
)


def _mark_init_in_construction(cls):
Expand Down Expand Up @@ -51,14 +57,19 @@ def __new__(cls, *args, **kwargs):
return instance

def __getattribute__(self, item):
value = super().__getattribute__(item)
if get_graph() is not None:
with disable_graph():
if item not in set(dir(self)):
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{item}'"
)

if item not in type(self)._protected_ and not item.startswith("_"):
if self._in_construction:
return value
return super().__getattribute__(item)
connector = Connection(instance=self, attribute=item)
return connector
return value
return super().__getattribute__(item)

def __setattr__(self, item, value) -> None:
super().__setattr__(item, value)
Expand Down

0 comments on commit a24ee7c

Please sign in to comment.