Skip to content

Commit

Permalink
fix RecursionError (#66)
Browse files Browse the repository at this point in the history
* max recursion test

* solve one recursion issue

* small performance increase

* fix RecursionError
  • Loading branch information
PythonFZ authored Mar 24, 2023
1 parent 53b225a commit 8293e23
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
51 changes: 51 additions & 0 deletions tests/test_recursion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import contextlib
import sys

import pytest

import znflow


class AddOne(znflow.Node):
def __init__(self, x):
super().__init__()
self.x = x

def run(self):
self.x += 1


@contextlib.contextmanager
def setrecursionlimit(limit: int):
"""Set the recursion limit for the duration of the context manager."""
_limit = sys.getrecursionlimit()
try:
sys.setrecursionlimit(limit)
yield
finally:
sys.setrecursionlimit(_limit)


@pytest.mark.parametrize("depth", [1, 10, 100, 1000])
def test_AddOneLoop(depth):
with setrecursionlimit(100):
with znflow.DiGraph() as graph:
start = AddOne(0)
for _ in range(depth):
start = AddOne(start.x)

graph.run()
assert len(graph.nodes) == depth + 1
assert start.x == depth + 1


@pytest.mark.parametrize("depth", [1, 10, 100, 1000])
def test_AddOneLoopNoGraph(depth):
with setrecursionlimit(100):
start = AddOne(0)
start.run()
for _ in range(depth):
start = AddOne(start.x)
start.run()

assert start.x == depth + 1
5 changes: 3 additions & 2 deletions znflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,9 @@ def add_connections(self, u_of_edge, v_of_edge, **attr):

def get_sorted_nodes(self):
all_pipelines = []
for stage in self.reverse():
all_pipelines += nx.dfs_postorder_nodes(self.reverse(), stage)
reverse = self.reverse(copy=False)
for stage in reverse:
all_pipelines += nx.dfs_postorder_nodes(reverse, stage)
return list(dict.fromkeys(all_pipelines)) # remove duplicates but keep order

def run(self):
Expand Down
11 changes: 9 additions & 2 deletions znflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ def _mark_init_in_construction(cls):
if "__init__" in dir(cls):

def wrap_init(func):
if getattr(func, "_already_wrapped", False):
# if the function is already wrapped, return it
# TODO this is solving the error but not the root cause
return func

@functools.wraps(func)
def wrapper(*args, **kwargs):
cls._in_construction = True
value = func(*args, **kwargs)
cls._in_construction = False
return value

wrapper._already_wrapped = True

return wrapper

cls.__init__ = wrap_init(cls.__init__)
Expand Down Expand Up @@ -58,7 +65,7 @@ def __new__(cls, *args, **kwargs):
return instance

def __getattribute__(self, item):
if item == "_graph_":
if item.startswith("_"):
return super().__getattribute__(item)
if self._graph_ not in [empty, None]:
with disable_graph():
Expand All @@ -67,7 +74,7 @@ def __getattribute__(self, item):
f"'{self.__class__.__name__}' object has no attribute '{item}'"
)

if item not in type(self)._protected_ and not item.startswith("_"):
if item not in type(self)._protected_:
if self._in_construction:
return super().__getattribute__(item)
return Connection(instance=self, attribute=item)
Expand Down

0 comments on commit 8293e23

Please sign in to comment.