Skip to content

Commit

Permalink
Fix fusion calling things multiple times (#1161)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 11, 2024
1 parent 2205ad8 commit ea970f1
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 10 deletions.
35 changes: 27 additions & 8 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import pandas as pd
from dask._task_spec import Alias, DataNode, Task, TaskRef
from dask._task_spec import Alias, DataNode, Task, TaskRef, execute_graph
from dask.array import Array
from dask.core import flatten
from dask.dataframe import methods
Expand Down Expand Up @@ -3766,19 +3766,38 @@ def _broadcast_dep(self, dep: Expr):
return dep.npartitions == 1

def _task(self, name: Key, index: int) -> Task:
    """Build one fused task for partition ``index``.

    Collects the task of every fused sub-expression, separates keys that
    are produced *inside* the fused group from genuinely external
    dependencies, and wraps everything into a single ``Task`` that
    executes the internal graph at runtime via
    ``Fused._execute_internal_graph`` — ensuring each internal task runs
    exactly once.
    """
    internal_tasks = []
    seen_keys = set()
    external_deps = set()
    for _expr in self.exprs:
        if self._broadcast_dep(_expr):
            # Broadcast dependencies have a single partition; always
            # reference partition 0 regardless of ``index``.
            subname = (_expr._name, 0)
        else:
            subname = (_expr._name, index)
        t = _expr._task(subname, subname[1])
        assert t.key == subname
        internal_tasks.append(t)
        seen_keys.add(subname)
        external_deps.update(t.dependencies)
    # Keys produced within the fused group are satisfied internally; only
    # the remainder must be provided by the outer graph.
    external_deps -= seen_keys
    dependencies = {dep: TaskRef(dep) for dep in external_deps}
    t = Task(
        name,
        Fused._execute_internal_graph,
        # Wrap the actual subgraph as a data node such that the tasks are
        # not erroneously parsed. The external task would otherwise carry
        # the internal keys as dependencies which is not satisfiable
        DataNode(None, internal_tasks),
        dependencies,
        (self.exprs[0]._name, index),
    )
    return t
@staticmethod
def _execute_internal_graph(internal_tasks, dependencies, outkey):
    """Run the fused subgraph at runtime and return the value for ``outkey``.

    ``dependencies`` maps external keys to their already-computed values;
    they seed the cache so ``execute_graph`` only evaluates the internal
    tasks, each exactly once.
    """
    results = execute_graph(internal_tasks, cache=dict(dependencies), keys=[outkey])
    return results[outkey]


# Used for sorting with None
Expand Down
7 changes: 6 additions & 1 deletion dask_expr/_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,12 @@ def _layer_cache(self):
return convert_legacy_graph(self._layer())

def _task(self, name: Key, index: int) -> Task:
    """Return the cached task for partition ``index``, re-keyed to ``name``.

    The task pulled from ``self._layer_cache`` may carry a key that
    differs from the requested ``name`` (e.g. after fusion renames
    outputs); it must be re-keyed so the emitted graph is consistent.
    """
    t = self._layer_cache[(self._name, index)]
    if isinstance(t, Alias):
        # Preserve alias semantics but point the alias at the requested key.
        return Alias(name, t.target)
    elif t.key != name:
        # Wrap in an identity task so the produced key matches ``name``.
        return Task(name, lambda x: x, t)
    return t


class LocUnknown(Blockwise):
Expand Down
5 changes: 4 additions & 1 deletion dask_expr/io/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ def _tasks(self):
def _filtered_task(self, name: Key, index: int) -> Task:
    """Return the read task for partition ``index``, keyed exactly as ``name``.

    When only a single column is requested (``self._series``), wrap the
    partition task in a ``getitem`` that extracts that column; otherwise
    re-key the cached task if its key does not already match ``name``.
    """
    if self._series:
        return Task(name, operator.getitem, self._tasks[index], self.columns[0])
    t = self._tasks[index]
    if t.key != name:
        # Identity wrapper so the emitted task carries the requested key.
        return Task(name, lambda x: x, t)
    return t


class ReadTable(ReadCSV):
Expand Down
15 changes: 15 additions & 0 deletions dask_expr/tests/test_fusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dask.dataframe as dd
import pytest

from dask_expr import from_pandas, optimize
Expand Down Expand Up @@ -128,3 +129,17 @@ def test_name(df):
assert "getitem" in str(fused.expr)
assert "sub" in str(fused.expr)
assert str(fused.expr) == str(fused.expr).lower()


def test_fusion_executes_only_once():
    """Regression test: fusion must evaluate each wrapped task exactly once."""
    import pandas as pd

    calls = []

    def make_partition(i):
        # Record every invocation so over-execution is observable.
        calls.append(i)
        return pd.DataFrame({"a": [1, 2, 3], "b": 1})

    ddf = dd.from_map(make_partition, [1], meta=[("a", "i8"), ("b", "i8")])
    ddf = ddf[ddf.a > 1]
    ddf.sum().compute()
    assert len(calls) == 1

0 comments on commit ea970f1

Please sign in to comment.