From a4d05904205d2c22054b3efed8bbc5a7a4b9be0e Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 26 Sep 2024 13:52:40 +0200 Subject: [PATCH 1/2] Add concatenate flag to .compute() (#1138) --- dask_expr/_collection.py | 8 ++++-- dask_expr/tests/test_distributed.py | 41 +++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index 26772730..326f1f49 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -447,7 +447,7 @@ def persist(self, fuse=True, **kwargs): out = self.optimize(fuse=fuse) return DaskMethodsMixin.persist(out, **kwargs) - def compute(self, fuse=True, **kwargs): + def compute(self, fuse=True, concatenate=True, **kwargs): """Compute this DataFrame. This turns a lazy Dask DataFrame into an in-memory pandas DataFrame. @@ -463,6 +463,10 @@ def compute(self, fuse=True, **kwargs): Whether to fuse the expression tree before computing. Fusing significantly reduces the number of tasks and improves performance. It shouldn't be disabled unless absolutely necessary. + concatenate : bool, default True + Whether to concatenate all partitions into a single one before computing. + Concatenating enables more powerful optimizations but it also incurs additional + data transfer cost. Generally, it should be enabled. kwargs Extra keywords to forward to the base compute function. @@ -471,7 +475,7 @@ def compute(self, fuse=True, **kwargs): dask.compute """ out = self - if not isinstance(out, Scalar): + if not isinstance(out, Scalar) and concatenate: out = out.repartition(npartitions=1) out = out.optimize(fuse=fuse) return DaskMethodsMixin.compute(out, **kwargs) diff --git a/dask_expr/tests/test_distributed.py b/dask_expr/tests/test_distributed.py index dc791eb6..38741c39 100644 --- a/dask_expr/tests/test_distributed.py +++ b/dask_expr/tests/test_distributed.py @@ -11,10 +11,11 @@ distributed = pytest.importorskip("distributed") -from distributed import Client, LocalCluster +from distributed import Client, LocalCluster, SchedulerPlugin from distributed.shuffle._core import id_from_key +from distributed.utils_test import cleanup # noqa F401 from distributed.utils_test import client as c # noqa F401 -from distributed.utils_test import gen_cluster +from distributed.utils_test import gen_cluster, loop, loop_in_thread # noqa F401 import dask_expr as dx @@ -456,3 +457,39 @@ def test_respect_context_shuffle(df, pdf, func): with dask.config.set({"dataframe.shuffle.method": "tasks"}): result = q.optimize(fuse=False) assert len([x for x in result.walk() if isinstance(x, P2PShuffle)]) > 0 + + +@pytest.mark.parametrize("concatenate", [True, False]) +def test_compute_concatenates(loop, concatenate): + pdf = pd.DataFrame({"a": np.random.randint(1, 100, (100,)), "b": 1}) + df = from_pandas(pdf, npartitions=10) + + class Plugin(SchedulerPlugin): + def start(self, *args, **kwargs): + self.repartition_in_tasks = False + + def update_graph( + self, + scheduler, + *, + client, + keys, + tasks, + annotations, + priority, + dependencies, + **kwargs, + ): + for key in dependencies: + if not isinstance(key, tuple): + continue + group = key[0] + if not isinstance(group, str): + continue + self.repartition_in_tasks |= group.startswith("repartitiontofewer") + + with Client(loop=loop) as c: + c.register_plugin(Plugin(), name="tracker") + df.compute(fuse=False, concatenate=concatenate) + plugin = c.cluster.scheduler.plugins["tracker"] + assert plugin.repartition_in_tasks is concatenate From ca5db2215c94d011ea0f2a086195a96734477d3f Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Fri, 27 Sep 2024 19:57:05 -0500 Subject: [PATCH 2/2] Release for dask 2024.9.1 --- changes.md | 8 ++++++++ pyproject.toml | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/changes.md b/changes.md index fe389acb..c5601701 100644 --- a/changes.md +++ b/changes.md @@ -1,5 +1,13 @@ ## Dask-expr +# v1.1.15 + +- Add concatenate flag to .compute() (:pr:`1138`) `Hendrik Makait`_ + +# v1.1.14 + +- Import from tokenize (:pr:`1133`) `Patrick Hoefler`_ + # v1.1.14 - Import from tokenize (:pr:`1133`) `Patrick Hoefler`_ diff --git a/pyproject.toml b/pyproject.toml index 6ad2147f..4752a5f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ readme = "README.md" requires-python = ">=3.10" dependencies = [ - "dask == 2024.9.0", + "dask == 2024.9.1", "pyarrow>=14.0.1", "pandas >= 2", ]