Skip to content

Commit

Permalink
Ensure subgraphs are releasing eagerly
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Nov 11, 2024
1 parent 5da4a7b commit 7ea87ed
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
26 changes: 26 additions & 0 deletions dask/_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,32 @@ def is_coro(self):
self._is_coro = False
return self._is_coro

@staticmethod
def fuse(*tasks: Task, key: KeyType | None = None) -> Task:
leafs = set()
all_keys = set()
all_deps: set[KeyType] = set()
for t in tasks:
if t.key not in all_deps:
leafs.add(t.key)
all_deps.update(t.dependencies)
all_keys.add(t.key)
leafs -= t.dependencies
external_deps = all_deps - set(all_keys)
if len(leafs) > 1:
raise ValueError("Cannot fuse tasks with multiple outputs")

outkey = leafs.pop()

return Task(
key or outkey,
_execute_subgraph,
DataNode(None, {t.key: t for t in tasks}),
outkey,
{k: Alias(k) for k in external_deps},
{},
)


class DependenciesMapping(MutableMapping):
def __init__(self, dsk):
Expand Down
91 changes: 90 additions & 1 deletion dask/tests/test_task_spec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
import pickle
import sys
from collections import namedtuple
Expand Down Expand Up @@ -817,7 +818,95 @@ def test_dependencies_mapping_doesnt_mutate_task():
assert t2.dependencies == {"key"}


def test_subgraph_dont_hold_in_memory_too_long():
def test_fuse_tasks_key():
a = Task("key-1", func, "a", "b")
b = Task("key-2", func2, a.ref(), "d")
for t1, t2 in itertools.permutations((a, b)):
fused = Task.fuse(t2, t1)
assert fused.key == b.key

fused = Task.fuse(t2, t1, key="new-key")
assert fused.key == "new-key"


def test_fuse_tasks():
a = Task("key-1", func, "a", "b")
b = Task("key-2", func2, a.ref(), "d")
for t1, t2 in itertools.permutations((a, b)):
fused = Task.fuse(t2, t1)

assert fused() == func2(func("a", "b"), "d")

t1 = Task("key-1", func, TaskRef("dependency"), "b")
t2 = Task("key-2", func2, t1.ref(), "d")

fused = Task.fuse(t2, t1)
assert fused.dependencies == {"dependency"}

assert fused({"dependency": "dep"}) == func2(func("dep", "b"), "d")


def test_fuse_reject_multiple_outputs():
a = Task("key-1", func, "a", "b")
b = Task("key-2", func2, "a", "d")
for t1, t2 in itertools.permutations((a, b)):
with pytest.raises(ValueError, match="multiple outputs"):
Task.fuse(t1, t2)


def test_fused_ensure_only_executed_once():
counter = []

def counter_func(a, b):
counter.append(None)
return func(a, b)

a = Task("key-1", counter_func, "a", "a")
b = Task("key-2", func2, a.ref(), "b")
c = Task("key-3", func2, a.ref(), "c")
d = Task("key-4", func, b.ref(), c.ref())
for perm in itertools.permutations([a, b, c, d]):
fused = Task.fuse(*perm)
counter.clear()
assert fused() == func(func2(func("a", "a"), "b"), func2(func("a", "a"), "c"))
assert len(counter) == 1


def test_fused_dont_hold_in_memory_too_long():
tasks = []
prev = None

# If we execute a fused task we want to release objects as quickly as
# possible. If every task generates this object, we must at most hold two of
# them in memory
class OnlyTwice:
counter = 0
total = 0

def __init__(self):
OnlyTwice.counter += 1
OnlyTwice.total += 1
if OnlyTwice.counter > 2:
raise ValueError("Didn't release as expected")

def __del__(self):
OnlyTwice.counter -= 1

def generate_object(arg):
return OnlyTwice()

prev = None
for ix in range(10):
prev = t = Task(
f"key-{ix}", generate_object, prev.ref() if prev is not None else ix
)
tasks.append(t)
fuse = Task.fuse(*tasks)
assert fuse()
assert OnlyTwice.total == 10


def test_subgraph_dont_hold_in_memory_too_long_legacy():
prev = None

# If we execute a fused task we want to release objects as quickly as
Expand Down

0 comments on commit 7ea87ed

Please sign in to comment.