diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 956668b8..739456a1 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -10,7 +10,7 @@
 import numpy as np
 import pandas as pd
 
-from dask._task_spec import Alias, DataNode, Task, TaskRef
+from dask._task_spec import Alias, DataNode, Task, TaskRef, execute_graph
 from dask.array import Array
 from dask.core import flatten
 from dask.dataframe import methods
@@ -3766,19 +3766,38 @@ def _broadcast_dep(self, dep: Expr):
         return dep.npartitions == 1
 
     def _task(self, name: Key, index: int) -> Task:
-        subgraphs = {}
+        internal_tasks = []
+        seen_keys = set()
+        external_deps = set()
         for _expr in self.exprs:
             if self._broadcast_dep(_expr):
                 subname = (_expr._name, 0)
             else:
                 subname = (_expr._name, index)
-            subgraphs[subname] = _expr._task(subname, subname[1])
-
-        for i, dep in enumerate(self.dependencies()):
-            subgraphs[self._blockwise_arg(dep, index)] = "_" + str(i)
+            t = _expr._task(subname, subname[1])
+            assert t.key == subname
+            internal_tasks.append(t)
+            seen_keys.add(subname)
+            external_deps.update(t.dependencies)
+        external_deps -= seen_keys
+        dependencies = {dep: TaskRef(dep) for dep in external_deps}
+        t = Task(
+            name,
+            Fused._execute_internal_graph,
+            # Wrap the actual subgraph as a data node such that the tasks are
+            # not erroneously parsed. The external task would otherwise carry
+            # the internal keys as dependencies which is not satisfiable
+            DataNode(None, internal_tasks),
+            dependencies,
+            (self.exprs[0]._name, index),
+        )
+        return t
 
-        result = subgraphs.pop((self.exprs[0]._name, index))
-        return result.inline(subgraphs)
+    @staticmethod
+    def _execute_internal_graph(internal_tasks, dependencies, outkey):
+        cache = dict(dependencies)
+        res = execute_graph(internal_tasks, cache=cache, keys=[outkey])
+        return res[outkey]
 
 
 # Used for sorting with None
diff --git a/dask_expr/_indexing.py b/dask_expr/_indexing.py
index caa80de7..2bb12a23 100644
--- a/dask_expr/_indexing.py
+++ b/dask_expr/_indexing.py
@@ -173,7 +173,12 @@ def _layer_cache(self):
         return convert_legacy_graph(self._layer())
 
     def _task(self, name: Key, index: int) -> Task:
-        return self._layer_cache[(self._name, index)]
+        t = self._layer_cache[(self._name, index)]
+        if isinstance(t, Alias):
+            return Alias(name, t.target)
+        elif t.key != name:
+            return Task(name, lambda x: x, t)
+        return t
 
 
 class LocUnknown(Blockwise):
diff --git a/dask_expr/io/csv.py b/dask_expr/io/csv.py
index 75684108..af79dee1 100644
--- a/dask_expr/io/csv.py
+++ b/dask_expr/io/csv.py
@@ -127,7 +127,10 @@ def _tasks(self):
     def _filtered_task(self, name: Key, index: int) -> Task:
         if self._series:
             return Task(name, operator.getitem, self._tasks[index], self.columns[0])
-        return self._tasks[index]
+        t = self._tasks[index]
+        if t.key != name:
+            return Task(name, lambda x: x, t)
+        return t
 
 
 class ReadTable(ReadCSV):
diff --git a/dask_expr/tests/test_fusion.py b/dask_expr/tests/test_fusion.py
index 6546c435..1a82bb14 100644
--- a/dask_expr/tests/test_fusion.py
+++ b/dask_expr/tests/test_fusion.py
@@ -1,3 +1,4 @@
+import dask.dataframe as dd
 import pytest
 
 from dask_expr import from_pandas, optimize
@@ -128,3 +129,17 @@ def test_name(df):
     assert "getitem" in str(fused.expr)
     assert "sub" in str(fused.expr)
     assert str(fused.expr) == str(fused.expr).lower()
+
+
+def test_fusion_executes_only_once():
+    times_called = []
+    import pandas as pd
+
+    def test(i):
+        times_called.append(i)
+        return pd.DataFrame({"a": [1, 2, 3], "b": 1})
+
+    df = dd.from_map(test, [1], meta=[("a", "i8"), ("b", "i8")])
+    df = df[df.a > 1]
+    df.sum().compute()
+    assert len(times_called) == 1