From 5df8490116e9deaeaa217784c45bb6862550ffad Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Fri, 1 Mar 2024 12:17:33 -0600 Subject: [PATCH 1/7] multifill and typetracing optimizations --- src/dask_histogram/boost.py | 153 ++++++++++++++++-------------- src/dask_histogram/core.py | 130 ++++++++++++++++++------- src/dask_histogram/histogram.yaml | 2 +- tests/test_boost.py | 12 ++- tests/test_core.py | 32 +++---- 5 files changed, 202 insertions(+), 127 deletions(-) diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py index e36cc98..e36939c 100644 --- a/src/dask_histogram/boost.py +++ b/src/dask_histogram/boost.py @@ -10,7 +10,6 @@ import boost_histogram.storage as storage import dask import dask.array as da -from dask.bag.core import empty_safe_aggregate, partition_all from dask.base import DaskMethodsMixin, dont_optimize, is_dask_collection, tokenize from dask.context import globalmethod from dask.delayed import Delayed, delayed @@ -20,7 +19,14 @@ from tlz import first from dask_histogram.bins import normalize_bins_range -from dask_histogram.core import AggHistogram, _get_optimization_function, factory +from dask_histogram.core import ( + AggHistogram, + _get_optimization_function, + _partitioned_histogram_multifill, + _reduction, + factory, + is_dask_awkward_like, +) if TYPE_CHECKING: from dask_histogram.typing import ( @@ -36,55 +42,6 @@ __all__ = ("Histogram", "histogram", "histogram2d", "histogramdd") -def _build_staged_tree_reduce( - stages: list[AggHistogram], split_every: int | bool -) -> HighLevelGraph: - if not split_every: - split_every = len(stages) - - reducer = sum - - token = tokenize(stages, reducer, split_every) - - k = len(stages) - b = "" - fmt = f"staged-fill-aggregate-{token}" - depth = 0 - - dsk = {} - - if k > 1: - while k > split_every: - c = fmt + str(depth) - for i, inds in enumerate(partition_all(split_every, range(k))): - dsk[(c, i)] = ( - empty_safe_aggregate, - reducer, - [ - (stages[j].name if depth == 0 else b, 0 if depth == 0 else j) - for j in inds - ], - False, - ) - - k = i + 1 - b = c - depth += 1 - - dsk[(fmt, 0)] = ( - empty_safe_aggregate, - reducer, - [ - (stages[j].name if depth == 0 else b, 0 if depth == 0 else j) - for j in range(k) - ], - True, - ) - return fmt, HighLevelGraph.from_collections(fmt, dsk, dependencies=stages) - - return stages[0].name, stages[0].dask - - class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram): """Histogram object capable of lazy computation. @@ -97,9 +54,6 @@ class Histogram(bh.Histogram, DaskMethodsMixin, family=dask_histogram): type is :py:class:`boost_histogram.storage.Double`. metadata : Any Data that is passed along if a new histogram is created. - split_every : int | bool | None, default None - Width of aggregation layers for staged fills. - If False, all staged fills are added in one layer (memory intensive!). See Also -------- @@ -139,7 +93,7 @@ def __init__( ) -> None: """Construct a Histogram object.""" super().__init__(*axes, storage=storage, metadata=metadata) - self._staged: list[AggHistogram] | None = None + self._staged: AggHistogram | None = None self._dask_name: str | None = ( f"empty-histogram-{tokenize(*axes, storage, metadata)}" ) @@ -148,14 +102,12 @@ def __init__( {}, ) self._split_every = split_every - if self._split_every is None: - self._split_every = dask.config.get("histogram.aggregation.split_every", 8) @property def _histref(self): return ( tuple(self.axes), - self.storage_type, + self.storage_type(), self.metadata, ) @@ -164,12 +116,6 @@ def __iadd__(self, other): self._staged += other._staged elif not self.staged_fills() and other.staged_fills(): self._staged = other._staged - if self.staged_fills(): - new_name, new_graph = _build_staged_tree_reduce( - self._staged, self._split_every - ) - self._dask = new_graph - self._dask_name = new_name return self def __add__(self, other): @@ -234,6 +180,8 @@ def _in_memory_type(self) -> type[bh.Histogram]: @property def dask_name(self) -> str: + if self._dask_name == "__not_yet_calculated__" and self._dask is None: + self._build_taskgraph() if self._dask_name is None: raise RuntimeError( "The dask name should never be None when it's requested." @@ -242,12 +190,73 @@ def dask_name(self) -> str: @property def dask(self) -> HighLevelGraph: + if self._dask_name == "__not_yet_calculated__" and self._dask is None: + self._build_taskgraph() if self._dask is None: raise RuntimeError( "The dask graph should never be None when it's requested." ) return self._dask + def _build_taskgraph(self): + data_list = [] + weights = [] + samples = [] + + dask_data = tuple( + datum + for datum in ( + self._staged[0]["args"] + tuple(self._staged[0]["kwargs"].values()) + ) + if is_dask_collection(datum) + ) + + if is_dask_awkward_like(dask_data[0]): + + for afill in self._staged: + data_list.append(afill["args"]) + weights.append(afill["kwargs"]["weight"]) + samples.append(afill["kwargs"]["sample"]) + + if all(weight is None for weight in weights): + weights = None + + if not all(sample is None for sample in samples): + samples = None + + split_every = self._split_every + if split_every is None: + split_every = dask.config.get("histogram.aggregation.split-every", 8) + + fills = _partitioned_histogram_multifill( + data_list, self._histref, weights, samples + ) + + output_hist = _reduction(fills, split_every) + else: + + first_fill = self._staged.pop() + + output_hist = factory( + *first_fill["args"], + histref=self._histref, + weights=first_fill["kwargs"]["weight"], + sample=first_fill["kwargs"]["sample"], + ) + + for afill in self._staged: + output_hist += factory( + *afill["args"], + histref=self._histref, + weights=afill["kwargs"]["weight"], + sample=afill["kwargs"]["sample"], + ) + + self._staged = None + self._staged_result = output_hist + self._dask = output_hist.dask + self._dask_name = output_hist.name + def fill( # type: ignore self, *args: DaskCollection, @@ -318,14 +327,14 @@ def fill( # type: ignore else: raise ValueError(f"Cannot interpret input data: {args}") - new_fill = factory(*args, histref=self._histref, weights=weight, sample=sample) + # new_fill = partitioned_factory(*args, histref=self._histref, weights=weight, sample=sample) + new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}} if self._staged is None: self._staged = [new_fill] else: - self._staged += [new_fill] - new_name, new_graph = _build_staged_tree_reduce(self._staged, self._split_every) - self._dask = new_graph - self._dask_name = new_name + self._staged.append(new_fill) + self._dask = None # self._staged.__dask_graph__() + self._dask_name = "__not_yet_calculated__" return self @@ -383,7 +392,8 @@ def to_delayed(self) -> Delayed: """ if self._staged is not None: - return sum(self._staged[1:], start=self._staged[0]).to_delayed() + self._build_taskgraph() + return self._staged_result.to_delayed() return delayed(bh.Histogram(self)) def __repr__(self) -> str: @@ -449,7 +459,8 @@ def to_dask_array(self, flow: bool = False, dd: bool = True) -> Any: """ if self._staged is not None: - return sum(self._staged).to_dask_array(flow=flow, dd=dd) + self._build_taskgraph() + return self._staged_result.to_dask_array(flow=flow, dd=dd) else: counts, edges = self.to_numpy(flow=flow, dd=True, view=False) counts = da.from_array(counts) diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py index 629527b..89b3163 100644 --- a/src/dask_histogram/core.py +++ b/src/dask_histogram/core.py @@ -14,6 +14,7 @@ from dask.core import flatten from dask.delayed import Delayed from dask.highlevelgraph import HighLevelGraph +from dask.local import identity from dask.threaded import get as tget from dask.utils import is_dataframe_like, key_split @@ -32,6 +33,11 @@ ) +def hist_safe_sum(items): + safe_items = [item for item in items if not isinstance(item, tuple)] + return sum(safe_items) + + def clone(histref: bh.Histogram | None = None) -> bh.Histogram: """Create a Histogram object based on another. @@ -63,7 +69,7 @@ def _blocked_sa( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data) @@ -83,7 +89,7 @@ def _blocked_sa_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data, sample=sample) @@ -103,7 +109,7 @@ def _blocked_sa_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data, weight=weights) @@ -124,7 +130,7 @@ def _blocked_sa_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) if data.ndim == 1: return thehist.fill(data, weight=weights, sample=sample) @@ -142,7 +148,7 @@ def _blocked_ma( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*data) @@ -157,7 +163,7 @@ def _blocked_ma_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*data, sample=sample) @@ -172,7 +178,7 @@ def _blocked_ma_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*data, weight=weights) @@ -188,7 +194,7 @@ def _blocked_ma_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*data, weight=weights, sample=sample) @@ -201,7 +207,7 @@ def _blocked_df( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), weight=None) @@ -215,7 +221,7 @@ def _blocked_df_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), sample=sample) @@ -230,7 +236,7 @@ def _blocked_df_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), weight=weights) @@ -246,7 +252,7 @@ def _blocked_df_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*(data[c] for c in data.columns), weight=weights, sample=sample) @@ -279,7 +285,7 @@ def _blocked_dak( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(thedata, weight=theweights, sample=thesample) @@ -302,7 +308,7 @@ def _blocked_dak_ma( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*tuple(thedata)) @@ -330,9 +336,13 @@ def _blocked_dak_ma_w( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) - return thehist.fill(*tuple(thedata), weight=theweights) + + if ak.backend(*data) != "typetracer": + thehist.fill(*tuple(thedata), weight=theweights) + + return thehist def _blocked_dak_ma_s( @@ -358,7 +368,7 @@ def _blocked_dak_ma_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*tuple(thedata), sample=thesample) @@ -391,11 +401,45 @@ def _blocked_dak_ma_w_s( thehist = ( clone(histref) if not isinstance(histref, tuple) - else bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) return thehist.fill(*tuple(thedata), weight=theweights, sample=thesample) +def _blocked_multi_dak( + data_list: tuple[tuple[Any]], + weights: tuple[Any] | None, + samples: tuple[Any] | None, + histref: tuple | bh.Histogram | None = None, +) -> bh.Histogram: + import awkward as ak + + thehist = ( + clone(histref) + if not isinstance(histref, tuple) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + ) + + backend = ak.backend(*data_list[0]) + + for idata, data in enumerate(data_list): + weight = None if weights is None else weights[idata] + sample = None if samples is None else samples[idata] + + if backend != "typetracer": + thehist.fill(*data, weight=weight, sample=sample) + else: + for datum in data: + if isinstance(datum, ak.highlevel.Array): + ak.typetracer.touch_data(datum) + if isinstance(weight, ak.highlevel.Array): + ak.typetracer.touch_data(weight) + if isinstance(sample, ak.highlevel.Array): + ak.typetracer.touch_data(sample) + + return thehist + + def optimize( dsk: Mapping, keys: Hashable | list[Hashable] | set[Hashable], @@ -516,15 +560,11 @@ def histref(self): @property def _storage_type(self) -> type[bh.storage.Storage]: """Storage type of the histogram.""" - if isinstance(self.histref, tuple): - return self.histref[1] return self.histref.storage_type @property def ndim(self) -> int: """Total number of dimensions.""" - if isinstance(self.histref, tuple): - return len(self.histref[0]) return self.histref.ndim @property @@ -751,17 +791,12 @@ def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]: return [Delayed(k, graph, layer=layer) for k in keys] -def _hist_safe_sum(items): - safe_items = [item for item in items if not isinstance(item, tuple)] - return sum(safe_items) - - def _reduction( ph: PartitionedHistogram, split_every: int | None = None, ) -> AggHistogram: if split_every is None: - split_every = dask.config.get("histogram.aggregation.split_every", 8) + split_every = dask.config.get("histogram.aggregation.split-every", 8) if split_every is False: split_every = ph.npartitions @@ -776,9 +811,9 @@ def _reduction( name=name_agg, name_input=ph.name, npartitions_input=ph.npartitions, - concat_func=_hist_safe_sum, - tree_node_func=lambda x: x, - finalize_func=lambda x: x, + concat_func=hist_safe_sum, + tree_node_func=identity, + finalize_func=identity, split_every=split_every, tree_node_name=name_comb, ) @@ -878,6 +913,35 @@ def _partitionwise(func, layer_name, *args, **kwargs): ) +class PackedMultifill: + def __init__(self, repacker): + self.repacker = repacker + + def __call__(self, *args): + return _blocked_multi_dak(*self.repacker(args)) + + +def _partitioned_histogram_multifill( + data: tuple[DaskCollection | tuple], + histref: bh.Histogram | tuple, + weights: tuple[DaskCollection] | None = None, + samples: tuple[DaskCollection] | None = None, +): + name = f"hist-on-block-{tokenize(data, histref, weights, samples)}" + + from dask.base import unpack_collections + from dask_awkward.lib.core import partitionwise_layer as dak_pwl + + flattened_deps, repacker = unpack_collections(data, weights, samples, histref) + + graph = dak_pwl(PackedMultifill(repacker), name, *flattened_deps) + + hlg = HighLevelGraph.from_collections(name, graph, dependencies=flattened_deps) + return PartitionedHistogram( + hlg, name, flattened_deps[0].npartitions, histref=histref + ) + + def _partitioned_histogram( *data: DaskCollection, histref: bh.Histogram | tuple, @@ -1006,9 +1070,7 @@ def to_dask_array(agghist: AggHistogram, flow: bool = False, dd: bool = False) - thehist = agghist.histref if isinstance(thehist, tuple): thehist = bh.Histogram( - *agghist.histref[0], - storage=agghist.histref[1](), - metadata=agghist.histref[2], + *agghist.histref[0], storage=agghist.histref[1], metadata=agghist.histref[2] ) zeros = (0,) * thehist.ndim dsk = {(name, *zeros): (lambda x, f: x.to_numpy(flow=f)[0], agghist.key, flow)} diff --git a/src/dask_histogram/histogram.yaml b/src/dask_histogram/histogram.yaml index 293b71a..43ea6e4 100644 --- a/src/dask_histogram/histogram.yaml +++ b/src/dask_histogram/histogram.yaml @@ -7,4 +7,4 @@ histogram: # aggregated histogram this parameter controls how the tree # reduction is handled; this number of nodes will be combined at a # time as a new dask task. - split_every: 8 + split-every: 8 diff --git a/tests/test_boost.py b/tests/test_boost.py index d16e0a1..3268e70 100644 --- a/tests/test_boost.py +++ b/tests/test_boost.py @@ -159,10 +159,12 @@ def test_obj_5D_strcat_intcat_rectangular_dak(use_weights): x = dak.from_dask_array(da.random.standard_normal(size=2000, chunks=400)) y = dak.from_dask_array(da.random.standard_normal(size=2000, chunks=400)) z = dak.from_dask_array(da.random.standard_normal(size=2000, chunks=400)) + weights = [] if use_weights: - weights = dak.from_dask_array( - da.random.uniform(0.5, 0.75, size=2000, chunks=400) - ) + for i in range(25): + weights.append( + dak.from_dask_array(da.random.uniform(0.5, 0.75, size=2000, chunks=400)) + ) storage = dhb.storage.Weight() else: weights = None @@ -181,7 +183,7 @@ def test_obj_5D_strcat_intcat_rectangular_dak(use_weights): assert h.__dask_optimize__ == dak.lib.optimize.all_optimizations for i in range(25): - h.fill(f"testcat{i+1}", i + 1, x, y, z, weight=weights) + h.fill(f"testcat{i+1}", i + 1, x, y, z, weight=weights[i] if weights else None) h = h.compute() control = bh.Histogram(*h.axes, storage=h.storage_type()) @@ -189,7 +191,7 @@ def test_obj_5D_strcat_intcat_rectangular_dak(use_weights): if use_weights: for i in range(25): control.fill( - f"testcat{i+1}", i + 1, x_c, y_c, z_c, weight=weights.compute() + f"testcat{i+1}", i + 1, x_c, y_c, z_c, weight=weights[i].compute() ) else: for i in range(25): diff --git a/tests/test_core.py b/tests/test_core.py index 60459af..e76061c 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -12,13 +12,13 @@ def _gen_storage(weights, sample): if weights is not None and sample is not None: - store = bh.storage.WeightedMean + store = bh.storage.WeightedMean() elif weights is None and sample is not None: - store = bh.storage.Mean + store = bh.storage.Mean() elif weights is not None and sample is None: - store = bh.storage.Weight + store = bh.storage.Weight() else: - store = bh.storage.Double + store = bh.storage.Double() return store @@ -31,7 +31,7 @@ def test_1d_array(weights, sample): sample = da.random.uniform(2, 8, size=(2000,), chunks=(250,)) store = _gen_storage(weights, sample) histref = ((bh.axis.Regular(10, -3, 3),), store, None) - h = bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + h = bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) x = da.random.standard_normal(size=(2000,), chunks=(250,)) dh = dhc.factory(x, histref=histref, weights=weights, split_every=4, sample=sample) h.fill( @@ -59,7 +59,7 @@ def test_array_input(weights, shape, sample): sample = da.random.uniform(3, 9, size=(2000,), chunks=(200,)) store = _gen_storage(weights, sample) histref = (axes, store, None) - h = bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + h = bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) dh = dhc.factory(x, histref=histref, weights=weights, split_every=4, sample=sample) h.fill( *xc, @@ -76,12 +76,12 @@ def test_multi_array(weights): bh.axis.Regular(10, -3, 3), bh.axis.Regular(10, -3, 3), ), - bh.storage.Weight, + bh.storage.Weight(), None, ) h = bh.Histogram( *histref[0], - storage=histref[1](), + storage=histref[1], metadata=histref[2], ) if weights is not None: @@ -105,12 +105,12 @@ def test_nd_array(weights): bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1), ), - bh.storage.Weight, + bh.storage.Weight(), None, ) h = bh.Histogram( *histref[0], - storage=histref[1](), + storage=histref[1], metadata=histref[2], ) if weights is not None: @@ -134,10 +134,10 @@ def test_df_input(weights): bh.axis.Regular(12, 0, 1), bh.axis.Regular(12, 0, 1), ), - bh.storage.Weight, + bh.storage.Weight(), None, ) - h = bh.Histogram(*histref[0], storage=histref[1](), metadata=histref[2]) + h = bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) df = dds.timeseries(freq="600s", partition_freq="2d") dfc = df.compute() if weights is not None: @@ -166,7 +166,7 @@ def test_to_dask_array(weights, shape): ) h = bh.Histogram(*axes, storage=bh.storage.Weight()) dh = dhc.factory( - x, histref=(axes, bh.storage.Weight, None), weights=weights, split_every=4 + x, histref=(axes, bh.storage.Weight(), None), weights=weights, split_every=4 ) h.fill(*xc, weight=weights.compute() if weights is not None else None) c, _ = dh.to_dask_array(flow=False, dd=True) @@ -181,7 +181,7 @@ def gen_hist_1D( ) -> dhc.AggHistogram: histref = ( (bh.axis.Regular(bins, range[0], range[1]),), - bh.storage.Weight, + bh.storage.Weight(), None, ) x = da.random.standard_normal(size=size, chunks=chunks) @@ -319,12 +319,12 @@ def test_agghist_to_delayed(weights): bh.axis.Regular(10, 0, 1), bh.axis.Regular(10, 0, 1), ), - bh.storage.Weight, + bh.storage.Weight(), None, ) h = bh.Histogram( *histref[0], - storage=histref[1](), + storage=histref[1], metadata=histref[2], ) if weights is not None: From f7c499e185bd615b998416c81a0133aeccbf7f2e Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 4 Mar 2024 08:29:43 -0600 Subject: [PATCH 2/7] reorg _blocked_multifill_dak and use a partial --- src/dask_histogram/core.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py index 89b3163..f0adba3 100644 --- a/src/dask_histogram/core.py +++ b/src/dask_histogram/core.py @@ -3,6 +3,7 @@ from __future__ import annotations import operator +from functools import partial from typing import TYPE_CHECKING, Any, Callable, Hashable, Literal, Mapping, Sequence import boost_histogram as bh @@ -407,13 +408,13 @@ def _blocked_dak_ma_w_s( def _blocked_multi_dak( - data_list: tuple[tuple[Any]], - weights: tuple[Any] | None, - samples: tuple[Any] | None, - histref: tuple | bh.Histogram | None = None, + repacker: Callable, + *flattened_inputs: tuple[Any], ) -> bh.Histogram: import awkward as ak + data_list, weights, samples, histref = repacker(flattened_inputs) + thehist = ( clone(histref) if not isinstance(histref, tuple) @@ -913,14 +914,6 @@ def _partitionwise(func, layer_name, *args, **kwargs): ) -class PackedMultifill: - def __init__(self, repacker): - self.repacker = repacker - - def __call__(self, *args): - return _blocked_multi_dak(*self.repacker(args)) - - def _partitioned_histogram_multifill( data: tuple[DaskCollection | tuple], histref: bh.Histogram | tuple, @@ -934,7 +927,9 @@ def _partitioned_histogram_multifill( flattened_deps, repacker = unpack_collections(data, weights, samples, histref) - graph = dak_pwl(PackedMultifill(repacker), name, *flattened_deps) + unpacked_multifill = partial(_blocked_multi_dak, repacker) + + graph = dak_pwl(unpacked_multifill, name, *flattened_deps) hlg = HighLevelGraph.from_collections(name, graph, dependencies=flattened_deps) return PartitionedHistogram( From a70c7793d9505d216992091bb2e322141702e7bb Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 4 Mar 2024 08:33:14 -0600 Subject: [PATCH 3/7] do not create intermediate list in hist_safe_sum --- src/dask_histogram/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py index f0adba3..0d6d9b9 100644 --- a/src/dask_histogram/core.py +++ b/src/dask_histogram/core.py @@ -35,8 +35,7 @@ def hist_safe_sum(items): - safe_items = [item for item in items if not isinstance(item, tuple)] - return sum(safe_items) + return sum(item for item in items if not isinstance(item, tuple)) def clone(histref: bh.Histogram | None = None) -> bh.Histogram: From eda5a43f340d61ea0b1ce4d115db25ccd9bc668c Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 4 Mar 2024 10:56:28 -0600 Subject: [PATCH 4/7] use multifill over all input types to dask-boost-histogram --- src/dask_histogram/boost.py | 60 +++++++------------------- src/dask_histogram/core.py | 84 +++++++++++++++++++++++++++++++++---- 2 files changed, 91 insertions(+), 53 deletions(-) diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py index e36939c..28dfc13 100644 --- a/src/dask_histogram/boost.py +++ b/src/dask_histogram/boost.py @@ -24,8 +24,6 @@ _get_optimization_function, _partitioned_histogram_multifill, _reduction, - factory, - is_dask_awkward_like, ) if TYPE_CHECKING: @@ -203,54 +201,26 @@ def _build_taskgraph(self): weights = [] samples = [] - dask_data = tuple( - datum - for datum in ( - self._staged[0]["args"] + tuple(self._staged[0]["kwargs"].values()) - ) - if is_dask_collection(datum) - ) - - if is_dask_awkward_like(dask_data[0]): - - for afill in self._staged: - data_list.append(afill["args"]) - weights.append(afill["kwargs"]["weight"]) - samples.append(afill["kwargs"]["sample"]) + for afill in self._staged: + data_list.append(afill["args"]) + weights.append(afill["kwargs"]["weight"]) + samples.append(afill["kwargs"]["sample"]) - if all(weight is None for weight in weights): - weights = None + if all(weight is None for weight in weights): + weights = None - if not all(sample is None for sample in samples): - samples = None - - split_every = self._split_every - if split_every is None: - split_every = dask.config.get("histogram.aggregation.split-every", 8) - - fills = _partitioned_histogram_multifill( - data_list, self._histref, weights, samples - ) + if not all(sample is None for sample in samples): + samples = None - output_hist = _reduction(fills, split_every) - else: - - first_fill = self._staged.pop() + split_every = self._split_every or dask.config.get( + "histogram.aggregation.split-every", 8 + ) - output_hist = factory( - *first_fill["args"], - histref=self._histref, - weights=first_fill["kwargs"]["weight"], - sample=first_fill["kwargs"]["sample"], - ) + fills = _partitioned_histogram_multifill( + data_list, self._histref, weights, samples + ) - for afill in self._staged: - output_hist += factory( - *afill["args"], - histref=self._histref, - weights=afill["kwargs"]["weight"], - sample=afill["kwargs"]["sample"], - ) + output_hist = _reduction(fills, split_every) self._staged = None self._staged_result = output_hist diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py index 0d6d9b9..f46a0bd 100644 --- a/src/dask_histogram/core.py +++ b/src/dask_histogram/core.py @@ -406,6 +406,64 @@ def _blocked_dak_ma_w_s( return thehist.fill(*tuple(thedata), weight=theweights, sample=thesample) +def _blocked_multi( + repacker: Callable, + *flattened_inputs: tuple[Any], +) -> bh.Histogram: + + data_list, weights, samples, histref = repacker(flattened_inputs) + + weights = weights or (None for _ in range(len(data_list))) + samples = samples or (None for _ in range(len(data_list))) + + thehist = ( + clone(histref) + if not isinstance(histref, tuple) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + ) + + for ( + datatuple, + weight, + sample, + ) in zip(data_list, weights, samples): + data = datatuple + if len(data) == 1 and data[0].ndim == 2: + data = data[0].T + thehist.fill(*data, weight=weight, sample=sample) + + return thehist + + +def _blocked_multi_df( + repacker: Callable, + *flattened_inputs: tuple[Any], +) -> bh.Histogram: + + data_list, weights, samples, histref = repacker(flattened_inputs) + + weights = weights or (None for _ in range(len(data_list))) + samples = samples or (None for _ in range(len(data_list))) + + thehist = ( + clone(histref) + if not isinstance(histref, tuple) + else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) + ) + + for ( + datatuple, + weight, + sample, + ) in zip(data_list, weights, samples): + data = datatuple + if len(datatuple) == 1: + data = data[0] + thehist.fill(*(data[c] for c in data.columns), weight=weight, sample=sample) + + return thehist + + def _blocked_multi_dak( repacker: Callable, *flattened_inputs: tuple[Any], @@ -414,18 +472,22 @@ def _blocked_multi_dak( data_list, weights, samples, histref = repacker(flattened_inputs) + weights = weights or (None for _ in range(len(data_list))) + samples = samples or (None for _ in range(len(data_list))) + thehist = ( clone(histref) if not isinstance(histref, tuple) else bh.Histogram(*histref[0], storage=histref[1], metadata=histref[2]) ) - backend = ak.backend(*data_list[0]) - - for idata, data in enumerate(data_list): - weight = None if weights is None else weights[idata] - sample = None if samples is None else samples[idata] + backend = ak.backend(*flattened_inputs) + for ( + data, + weight, + sample, + ) in zip(data_list, weights, samples): if backend != "typetracer": thehist.fill(*data, weight=weight, sample=sample) else: @@ -926,9 +988,15 @@ def _partitioned_histogram_multifill( flattened_deps, repacker = unpack_collections(data, weights, samples, histref) - unpacked_multifill = partial(_blocked_multi_dak, repacker) - - graph = dak_pwl(unpacked_multifill, name, *flattened_deps) + if is_dask_awkward_like(flattened_deps[0]): + unpacked_multifill = partial(_blocked_multi_dak, repacker) + graph = dak_pwl(unpacked_multifill, name, *flattened_deps) + elif is_dataframe_like(flattened_deps[0]): + unpacked_multifill = partial(_blocked_multi_df, repacker) + graph = _partitionwise(unpacked_multifill, name, *flattened_deps) + else: + unpacked_multifill = partial(_blocked_multi, repacker) + graph = _partitionwise(unpacked_multifill, name, *flattened_deps) hlg = HighLevelGraph.from_collections(name, graph, dependencies=flattened_deps) return PartitionedHistogram( From 08c756b4dddbaef5f6a93fc7753bcb33c0888533 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 4 Mar 2024 10:59:37 -0600 Subject: [PATCH 5/7] remove commented dead code --- src/dask_histogram/boost.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py index 28dfc13..b947dd3 100644 --- a/src/dask_histogram/boost.py +++ b/src/dask_histogram/boost.py @@ -297,13 +297,12 @@ def fill( # type: ignore else: raise ValueError(f"Cannot interpret input data: {args}") - # new_fill = partitioned_factory(*args, histref=self._histref, weights=weight, sample=sample) new_fill = {"args": args, "kwargs": {"weight": weight, "sample": sample}} if self._staged is None: self._staged = [new_fill] else: self._staged.append(new_fill) - self._dask = None # self._staged.__dask_graph__() + self._dask = None self._dask_name = "__not_yet_calculated__" return self From eaf1ebed0adb0a9a052fb9a119ac5dd7fc084377 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 4 Mar 2024 15:43:29 -0600 Subject: [PATCH 6/7] do not allow adding of dask-boost-histograms --- src/dask_histogram/boost.py | 8 +++----- tests/test_boost.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/dask_histogram/boost.py b/src/dask_histogram/boost.py index b947dd3..3f79269 100644 --- a/src/dask_histogram/boost.py +++ b/src/dask_histogram/boost.py @@ -110,11 +110,9 @@ def _histref(self): ) def __iadd__(self, other): - if self.staged_fills() and other.staged_fills(): - self._staged += other._staged - elif not self.staged_fills() and other.staged_fills(): - self._staged = other._staged - return self + raise NotImplementedError( + "dask-boost-histograms are not addable, please sum them after computation!" + ) def __add__(self, other): return self.__iadd__(other) diff --git a/tests/test_boost.py b/tests/test_boost.py index 3268e70..0dd68f5 100644 --- a/tests/test_boost.py +++ b/tests/test_boost.py @@ -485,14 +485,19 @@ def test_add(use_weights): h2 = dhb.Histogram(dhb.axis.Regular(12, -3, 3), storage=store()) h2.fill(y, weight=yweights) - h3 = h1 + h2 + with pytest.raises(NotImplementedError): + h3 = h1 + h2 - h3 = h3.compute() + h3 = h1.compute() + h2.compute() h4 = dhb.Histogram(dhb.axis.Regular(12, -3, 3), storage=store()) h4.fill(x, weight=xweights) - h4 += h2 + + with pytest.raises(NotImplementedError): + h4 += h2 + h4 = h4.compute() + h4 += h2.compute() controlx = bh.Histogram(*h1.axes, storage=h1.storage_type()) controly = bh.Histogram(*h2.axes, storage=h2.storage_type()) From 941e0b538dcedbfc31332c14b1cdd688e0fd4b21 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Mon, 4 Mar 2024 15:46:36 -0600 Subject: [PATCH 7/7] hide dask-awkward partitionwise import --- src/dask_histogram/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dask_histogram/core.py b/src/dask_histogram/core.py index f46a0bd..4f3cec7 100644 --- a/src/dask_histogram/core.py +++ b/src/dask_histogram/core.py @@ -984,11 +984,12 @@ def _partitioned_histogram_multifill( name = f"hist-on-block-{tokenize(data, histref, weights, samples)}" from dask.base import unpack_collections - from dask_awkward.lib.core import partitionwise_layer as dak_pwl flattened_deps, repacker = unpack_collections(data, weights, samples, histref) if is_dask_awkward_like(flattened_deps[0]): + from dask_awkward.lib.core import partitionwise_layer as dak_pwl + unpacked_multifill = partial(_blocked_multi_dak, repacker) graph = dak_pwl(unpacked_multifill, name, *flattened_deps) elif is_dataframe_like(flattened_deps[0]):