From 0218b6bb900a32ca9032de6459c266a2f3364f6d Mon Sep 17 00:00:00 2001
From: Patrick Hoefler <61934744+phofl@users.noreply.github.com>
Date: Wed, 13 Nov 2024 19:45:22 +0100
Subject: [PATCH] Introduce more caching when walking the expression

---
 dask_expr/_core.py | 21 +++++++++++++++++----
 dask_expr/_expr.py |  2 +-
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/dask_expr/_core.py b/dask_expr/_core.py
index 04e043df..ed9b8c41 100644
--- a/dask_expr/_core.py
+++ b/dask_expr/_core.py
@@ -160,17 +160,26 @@ def __reduce__(self):
             raise RuntimeError(f"Serializing a {type(self)} object")
         return type(self), tuple(self.operands)
 
-    def _depth(self):
+    def _depth(self, cache=None):
         """Depth of the expression tree
 
         Returns
         -------
         depth: int
         """
+        if cache is None:
+            cache = {}
         if not self.dependencies():
             return 1
         else:
-            return max(expr._depth() for expr in self.dependencies()) + 1
+            result = []
+            for expr in self.dependencies():
+                if expr._name in cache:
+                    result.append(cache[expr._name])
+                else:
+                    result.append(expr._depth(cache) + 1)
+                    cache[expr._name] = result[-1]
+            return max(result)
 
     def operand(self, key):
         # Access an operand unambiguously
@@ -242,7 +251,7 @@ def _layer(self) -> dict:
             for i in range(self.npartitions)
         }
 
-    def rewrite(self, kind: str):
+    def rewrite(self, kind: str, rewritten):
         """Rewrite an expression
 
         This leverages the ``._{kind}_down`` and ``._{kind}_up``
@@ -255,6 +264,9 @@ def rewrite(self, kind: str):
         changed:
             whether or not any change occured
         """
+        if self._name in rewritten:
+            return rewritten[self._name]
+
         expr = self
         down_name = f"_{kind}_down"
         up_name = f"_{kind}_up"
@@ -291,7 +303,8 @@ def rewrite(self, kind: str):
             changed = False
             for operand in expr.operands:
                 if isinstance(operand, Expr):
-                    new = operand.rewrite(kind=kind)
+                    new = operand.rewrite(kind=kind, rewritten=rewritten)
+                    rewritten[operand._name] = new
                     if new._name != operand._name:
                         changed = True
                 else:
diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 9f518911..61ff5c72 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -3053,7 +3053,7 @@ def optimize_until(expr: Expr, stage: core.OptimizerStage) -> Expr:
         return expr
 
     # Manipulate Expression to make it more efficient
-    expr = expr.rewrite(kind="tune")
+    expr = expr.rewrite(kind="tune", rewritten={})
     if stage == "tuned-logical":
         return expr
 