Skip to content

Commit

Permalink
Add concatenate flag to .compute() (#1138)
Browse files Browse the repository at this point in the history
  • Loading branch information
hendrikmakait authored Sep 26, 2024
1 parent c4cee18 commit a4d0590
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
8 changes: 6 additions & 2 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ def persist(self, fuse=True, **kwargs):
out = self.optimize(fuse=fuse)
return DaskMethodsMixin.persist(out, **kwargs)

def compute(self, fuse=True, **kwargs):
def compute(self, fuse=True, concatenate=True, **kwargs):
"""Compute this DataFrame.
This turns a lazy Dask DataFrame into an in-memory pandas DataFrame.
Expand All @@ -463,6 +463,10 @@ def compute(self, fuse=True, **kwargs):
Whether to fuse the expression tree before computing. Fusing significantly
reduces the number of tasks and improves performance. It shouldn't be
disabled unless absolutely necessary.
concatenate : bool, default True
Whether to concatenate all partitions into a single one before computing.
Concatenating enables more powerful optimizations but it also incurs additional
data transfer cost. Generally, it should be enabled.
kwargs
Extra keywords to forward to the base compute function.
Expand All @@ -471,7 +475,7 @@ def compute(self, fuse=True, **kwargs):
dask.compute
"""
out = self
if not isinstance(out, Scalar):
if not isinstance(out, Scalar) and concatenate:
out = out.repartition(npartitions=1)
out = out.optimize(fuse=fuse)
return DaskMethodsMixin.compute(out, **kwargs)
Expand Down
41 changes: 39 additions & 2 deletions dask_expr/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

distributed = pytest.importorskip("distributed")

from distributed import Client, LocalCluster
from distributed import Client, LocalCluster, SchedulerPlugin
from distributed.shuffle._core import id_from_key
from distributed.utils_test import cleanup # noqa F401
from distributed.utils_test import client as c # noqa F401
from distributed.utils_test import gen_cluster
from distributed.utils_test import gen_cluster, loop, loop_in_thread # noqa F401

import dask_expr as dx

Expand Down Expand Up @@ -456,3 +457,39 @@ def test_respect_context_shuffle(df, pdf, func):
with dask.config.set({"dataframe.shuffle.method": "tasks"}):
result = q.optimize(fuse=False)
assert len([x for x in result.walk() if isinstance(x, P2PShuffle)]) > 0


@pytest.mark.parametrize("concatenate", [True, False])
def test_compute_concatenates(loop, concatenate):
    """Check that ``compute(concatenate=...)`` controls the repartition step.

    A scheduler plugin watches every submitted graph and records whether any
    dependency key belongs to a group whose name starts with
    ``"repartitiontofewer"``. The test asserts that such a key is seen exactly
    when ``concatenate`` is True.
    """
    frame = pd.DataFrame({"a": np.random.randint(1, 100, (100,)), "b": 1})
    ddf = from_pandas(frame, npartitions=10)

    class RepartitionTracker(SchedulerPlugin):
        """Scheduler plugin flagging graphs containing repartition keys."""

        def start(self, *args, **kwargs):
            # State is initialised here rather than in ``__init__``.
            self.repartition_in_tasks = False

        def update_graph(
            self,
            scheduler,
            *,
            client,
            keys,
            tasks,
            annotations,
            priority,
            dependencies,
            **kwargs,
        ):
            # Only tuple keys whose first element is a string are inspected;
            # everything else is ignored.
            for dep_key in dependencies:
                prefix = dep_key[0] if isinstance(dep_key, tuple) else None
                if isinstance(prefix, str):
                    self.repartition_in_tasks |= prefix.startswith(
                        "repartitiontofewer"
                    )

    with Client(loop=loop) as local_client:
        local_client.register_plugin(RepartitionTracker(), name="tracker")
        ddf.compute(fuse=False, concatenate=concatenate)
        # Read the flag from the plugin instance held by the scheduler
        # (presumably not the same object that was registered above -- verify).
        tracker = local_client.cluster.scheduler.plugins["tracker"]
        assert tracker.repartition_in_tasks is concatenate

0 comments on commit a4d0590

Please sign in to comment.