Grouping simplified.
Fragment optimization is no longer faster than parallelization when columns are read.
coady committed Oct 6, 2024
1 parent 70fb2b4 commit 5ffe5a5
Showing 6 changed files with 59 additions and 205 deletions.
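
In short, the old `Table.group` helper is removed in favor of grouping directly on the `Nodes` declaration (the Acero-based scanner). A minimal before/after sketch, mirroring the updated benchmark in tests/test_bench.py below; the sample table here is made up for illustration:

```python
import pyarrow as pa
from graphique.core import Nodes

# hypothetical sample table, for illustration only
table = pa.table({
    'state': ['NY', 'NY', 'CA'],
    'county': ['Kings', 'Kings', 'Alameda'],
    'city': ['Brooklyn', 'Brooklyn', 'Oakland'],
})

# before this commit: T.group(table, 'state', 'county', 'city')
# after: group through the Nodes declaration and materialize on demand
grouped = Nodes('table_source', table).group('state', 'county', 'city').to_table()
```
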
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -5,8 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## Unreleased
### Changed
* Acero engine used for scanning
* Grouping defaults to parallelized but unordered
* Partitioning supports facet counts and arbitrary functions
* Partitioning supports arbitrary functions
* `group` optimized for dictionary arrays
* `rank` optimized for out-of-core

103 changes: 29 additions & 74 deletions graphique/core.py
@@ -12,7 +12,7 @@
import itertools
import operator
import json
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Callable, Iterable, Iterator, Mapping
from dataclasses import dataclass
from typing import Optional, Union, get_type_hints
import numpy as np
@@ -57,26 +57,14 @@ class Agg:
associatives = {'all', 'any', 'first', 'last', 'max', 'min', 'one', 'product', 'sum'}
associatives |= {'count'} # transformed to be associative
ordered = {'first', 'last'}
count_all: tuple = [], 'hash_count_all', None

def __init__(self, name: str, alias: str = '', **options):
self.name = name
self.alias = alias or name
self.options = options

def astuple(self, func: str) -> tuple:
options = self.option_map[func.rpartition('hash_')[-1]](**self.options)
return self.name, func, options, self.alias

one = operator.itemgetter(0)
list = staticmethod(lambda array: array)

@staticmethod
def distinct(array: Array, mode: str = 'only_valid') -> pa.Array:
values = pc.unique(array)
if not values.null_count or mode == 'all':
return values
return values.drop_null() if mode == 'only_valid' else pa.array([None], array.type)
def func_options(self, func: str) -> pc.FunctionOptions:
return self.option_map[func.removeprefix('hash_')](**self.options)


@dataclass(frozen=True)
@@ -405,44 +393,6 @@ def runs(self, *names: str, **predicates: tuple) -> tuple:
table = Table.union(scalars, Table.from_offsets(lists, offsets))
return table, Column.diff(offsets)

def group(
self, *names: str, counts: str = '', ordered: bool = False, **funcs: Sequence[Agg]
) -> pa.Table:
"""Group by and aggregate.
Args:
*names: columns to group by
counts: alias for optional row counts
ordered: do not use threads
**funcs: aggregate funcs with columns options
"""
prefix = 'hash_' if names else ''
aggs = {}
for func in funcs:
for agg in funcs[func]:
*value, name = agg.astuple(prefix + func)
aggs[name] = tuple(value)
if counts:
aggs[counts] = Agg.count_all
if isinstance(self, pa.Table):
self = ds.dataset(self)
use_threads = not ordered and Agg.ordered.isdisjoint(funcs)
return Nodes.group(self, *names, **aggs).to_table(use_threads)

def aggregate(self, counts: str = '', **funcs: Sequence[Agg]) -> dict:
"""Return aggregated scalars as a row of data."""
row = {counts: len(self)} if counts else {}
for key in ('one', 'list', 'distinct'): # hash only functions
func, aggs = getattr(Agg, key), funcs.pop(key, [])
row |= {agg.alias: func(self[agg.name], **agg.options) for agg in aggs}
if funcs:
table = Table.group(self, **funcs) # type: ignore
row |= {name: table[name][0] for name in table.schema.names}
for name, value in row.items():
if isinstance(value, pa.ChunkedArray):
row[name] = value.combine_chunks()
return row

def list_fields(self) -> set:
return {field.name for field in self.schema if Column.is_list_type(field)}

@@ -543,19 +493,26 @@ def _(self: pa.RecordBatch, k, *names):
return Table.min_max(self, *names)
return next(Table.rank(ds.dataset(self), k, *names).to_batches())

def get_fragments(self) -> Iterator[ds.Fragment]:
"""Support filtered datasets if it only references partition keys."""
expr = self._scan_options.get('filter')
if expr is not None: # raise ValueError if filter references other fields
ds.dataset([], schema=self.partitioning.schema).scanner(filter=expr)
return self._get_fragments(expr)

def fragment_keys(self) -> list:
"""Filtered partitioned datasets may not have fragments."""
with contextlib.suppress(AttributeError, ValueError):
Table.get_fragments(self)
return self.partitioning.schema.names
return []
def fragments(self, *names, counts: str = '') -> pa.Table:
"""Return selected fragment keys in a table."""
try:
expr = self._scan_options.get('filter')
if expr is not None: # raise ValueError if filter references other fields
ds.dataset([], schema=self.partitioning.schema).scanner(filter=expr)
except (AttributeError, ValueError):
return pa.table({})
fragments = self._get_fragments(expr)
parts = [ds.get_partition_keys(frag.partition_expression) for frag in fragments]
names, table = set(names), pa.Table.from_pylist(parts) # type: ignore
keys = [name for name in table.schema.names if name in names]
table = table.group_by(keys, use_threads=False).aggregate([])
if not counts:
return table
if not table.schema:
return table.append_column(counts, pa.array([self.count_rows()]))
exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()]
column = [self.filter(expr).count_rows() for expr in exprs]
return table.append_column(counts, pa.array(column))

def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple:
"""Return expression and unmatched fields for partitioned dataset which filters by rank.
@@ -565,21 +522,19 @@ def rank_keys(self, k: int, *names: str, dense: bool = True) -> tuple:
*names: columns to rank by
dense: use dense rank; false indicates sorting
"""
schema = set(Table.fragment_keys(self))
keys = dict(itertools.takewhile(lambda key: key[0] in schema, map(sort_key, names)))
keys = dict(map(sort_key, names))
table = Table.fragments(self, *keys, counts='' if dense else '_')
keys = {name: keys[name] for name in table.schema.names if name in keys}
if not keys:
return None, names
parts = [ds.get_partition_keys(frag.partition_expression) for frag in self.get_fragments()]
table = pa.Table.from_pylist(parts).group_by(keys).aggregate([])
if dense:
table = table.take(pc.select_k_unstable(table, k, keys.items()))
else:
table = table.sort_by(keys.items())
totals = itertools.accumulate(table['_'].to_pylist())
counts = (count for count, total in enumerate(totals, 1) if total >= k)
table = table[: next(counts, None)].remove_column(len(table) - 1)
exprs = [bit_all(pc.field(key) == row[key] for key in row) for row in table.to_pylist()]
totals = itertools.accumulate(self.filter(expr).count_rows() for expr in exprs)
counts = (count for count, total in enumerate(totals, 1) if total >= k)
if not dense:
table = table[: next(counts, None)]
remaining = names[len(keys) :]
if remaining or not dense: # fields with a single value are no longer needed
selectors = [len(table[key].unique()) > 1 for key in keys]
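
For context on the new `fragments` helper above: it returns the distinct partition keys for the selected names, optionally with row counts per group. A brief usage sketch against a partitioned dataset; the path and partition key here are hypothetical:

```python
import pyarrow.dataset as ds
from graphique.core import Table as T

# assumes a dataset hive-partitioned on 'state'; the path is made up
dataset = ds.dataset('data/zipcodes', partitioning='hive')

# distinct partition keys for the selected names, with an optional counts column
keys = T.fragments(dataset, 'state', counts='counts')
```
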
73 changes: 19 additions & 54 deletions graphique/interface.py
@@ -18,7 +18,7 @@
from strawberry import Info
from strawberry.extensions.utils import get_path_from_info
from typing_extensions import Self
from .core import Batch, Column as C, ListChunk, Nodes, Table as T
from .core import Agg, Batch, Column as C, ListChunk, Nodes, Table as T
from .inputs import CountAggregate, Cumulative, Diff, Expression, Field, Filter
from .inputs import HashAggregates, ListFunction, Pairwise, Projection, Rank
from .inputs import ScalarAggregate, TDigestAggregate, VarianceAggregate, links, provisional
@@ -247,58 +247,23 @@ def group(
ordered: bool = False,
aggregate: HashAggregates = {}, # type: ignore
) -> Self:
"""Return table grouped by columns.
See `column` for accessing any column which has changed type. See `tables` to split on any
aggregated list columns.
"""
source, aggs = self.source, dict(aggregate)
refs = {agg.name for values in aggs.values() for agg in values}
fragments = set(T.fragment_keys(self.source))
if isinstance(source, ds.Scanner):
source = self.to_table(info)
if fragments and set(by) <= fragments:
if set(by) == fragments:
return type(self)(self.fragments(info, counts, aggregate))
if fragments.isdisjoint(refs) and set(aggs) <= Field.associatives:
source = self.fragments(info, counts, aggregate)
aggs.setdefault('sum', []).extend(Field(agg.alias) for agg in aggs.pop('count', []))
if counts:
aggs['sum'].append(Field(counts))
counts = ''
for agg in itertools.chain(*aggs.values()):
agg.name = agg.alias
loaded = isinstance(source, pa.Table)
table = T.group(source, *by, counts=counts, ordered=ordered, **aggs)
return type(self)(table if loaded else self.add_metric(info, table, mode='group'))

def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {}) -> pa.Table: # type: ignore
"""Return table from scanning fragments and grouping by partitions.
Requires a partitioned dataset. Faster and less memory intensive than `group`.
"""
schema = self.source.partitioning.schema # requires a Dataset
aggs = dict(aggregate)
names = self.references(info, level=1)
names.update(agg.name for value in aggs.values() for agg in value)
projection = {name: pc.field(name) for name in names - set(schema.names)}
columns = collections.defaultdict(list)
for fragment in T.get_fragments(self.source):
row = ds.get_partition_keys(fragment.partition_expression)
if projection:
table = fragment.to_table(columns=projection)
row |= T.aggregate(table, counts=counts, **aggs)
elif counts:
row[counts] = fragment.count_rows()
arrays = {name: value for name, value in row.items() if isinstance(value, pa.Array)}
row |= T.columns(pa.RecordBatch.from_pydict(arrays))
for name in row:
columns[name].append(row[name])
for name, values in columns.items():
if isinstance(values[0], pa.Array):
columns[name] = ListChunk.from_scalars(values)
columns |= {field.name: pa.array(columns[field.name], field.type) for field in schema}
return self.add_metric(info, pa.table(columns), mode='fragment')
if not any(aggregate.keys()):
fragments = T.fragments(self.source, *by, counts=counts)
if set(fragments.schema.names) >= set(by):
return type(self)(fragments)
prefix = 'hash_' if by else ''
aggs: dict = {counts: ([], prefix + 'count_all', None)} if counts else {}
for func, values in dict(aggregate).items():
ordered = ordered or func in Agg.ordered
for agg in values:
aggs[agg.alias] = (agg.name, prefix + func, agg.func_options(func))
source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source
if isinstance(source, pa.Table):
source = ds.dataset(source)
source = Nodes.group(source, *by, **aggs)
if ordered:
source = self.add_metric(info, source.to_table(use_threads=False), mode='group')
return type(self)(source)

@doc_field(
by="column names",
@@ -476,7 +441,7 @@ def aggregate(
else:
columns[agg.alias] = func(table[agg.name], **agg.options)
for name, aggs in agg_fields.items():
funcs = {key: agg.astuple(key)[2] for key, agg in aggs.items()}
funcs = {key: agg.func_options(key) for key, agg in aggs.items()}
batch = ListChunk.aggregate(table[name], **funcs)
columns.update(zip([agg.alias for agg in aggs.values()], batch))
return type(self)(pa.table(columns))
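
In the rewritten resolver, aggregates are handed to `Nodes.group` as `(target, function, options)` tuples keyed by output column name. A sketch of an equivalent direct call, mirroring test_nodes in tests/test_core.py below; the dataset path is assumed:

```python
import pyarrow.dataset as ds
from graphique.core import Nodes

dataset = ds.dataset('data/zipcodes', partitioning='hive')  # hypothetical path

# the keyword is the output column; the value is (targets, 'hash_' function, FunctionOptions or None)
table = Nodes.group(dataset, 'county', 'city', counts=([], 'hash_count_all', None)).to_table()
```
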
4 changes: 2 additions & 2 deletions tests/test_bench.py
@@ -1,10 +1,10 @@
import pytest
from graphique.core import Table as T
from graphique.core import Nodes, Table as T


@pytest.mark.benchmark
def test_group(table):
T.group(table, 'state', 'county', 'city')
Nodes('table_source', table).group('state', 'county', 'city')
T.runs(table, 'state', 'county', 'city')


73 changes: 4 additions & 69 deletions tests/test_core.py
@@ -2,7 +2,7 @@
import pyarrow.compute as pc
import pyarrow.dataset as ds
import pytest
from graphique.core import Agg, ListChunk, Nodes, Column as C, Table as T
from graphique.core import ListChunk, Nodes, Column as C, Table as T
from graphique.scalars import parse_duration, duration_isoformat


@@ -41,19 +41,11 @@ def test_dictionary(table):


def test_chunks():
array = pa.chunked_array([list('aba'), list('bcb')])
table = pa.table({'col': array})
groups = T.group(table, 'col', counts='counts')
assert dict(zip(*groups.to_pydict().values())) == {'a': 2, 'b': 3, 'c': 1}
array = pa.chunked_array([pa.array(list(chunk)).dictionary_encode() for chunk in ('aba', 'ca')])
assert C.index(array, 'a') == 0
assert C.index(array, 'c') == 3
assert C.index(array, 'a', start=3) == 4
assert C.index(array, 'b', start=2) == -1
table = pa.table({'col': array})
tbl = T.group(table, 'col', ordered=True, count_distinct=[Agg('col', 'count')])
assert tbl['col'].to_pylist() == list('abc')
assert tbl['count'].to_pylist() == [1] * 3


def test_lists():
@@ -81,6 +73,8 @@ def test_lists():
assert not T.from_offsets(pa.table({}), pa.array([0]))
array = ListChunk.from_counts(pa.array([3, None, 2]), list('abcde'))
assert array.to_pylist() == [list('abc'), None, list('de')]
with pytest.raises(ValueError):
T.list_value_length(pa.table({'x': array, 'y': array[::-1]}))


def test_membership():
@@ -95,7 +89,7 @@ def test_nodes(table):
assert Nodes.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA']
(column,) = Nodes.scan(dataset, columns={'_': pc.field('state')}).to_table()
assert column.unique().to_pylist() == ['CA']
table = Nodes.group(dataset, 'county', 'city', counts=Agg.count_all).to_table()
table = Nodes.group(dataset, 'county', 'city', counts=([], 'hash_count_all', None)).to_table()
assert len(table) == 1241
assert pc.sum(table['counts']).as_py() == 2647
scanner = Nodes.scan(dataset, columns=['state'])
@@ -106,65 +100,6 @@
assert scanner.take([0, 2]) == pa.table({'state': ['CA'] * 2})


def test_group(table):
groups = T.group(table, 'state', list=[Agg('county'), Agg('city')])
assert len(groups) == 52
assert groups['state'][0].as_py() == 'NY'
assert len(pa.Table.from_batches(T.flatten(groups))) == len(table)
table = T.filter_list(groups, pc.field('county') == pc.field('city'))
assert len(pc.list_flatten(table['city'])) == 2805
groups = T.map_list(groups, T.sort, 'county')
assert groups['county'][0].values[0].as_py() == 'Albany'
groups = T.map_list(groups, T.sort, '-county', '-city', length=1, null_placement='at_start')
assert groups['county'][0].values.to_pylist() == ['Yates']
assert groups['city'][0].values.to_pylist() == ['Rushville']
groups = groups.append_column('other', pa.array([[0]] * len(groups)))
with pytest.raises(ValueError):
T.map_list(groups, T.sort, 'county')
groups = T.group(table, first=[Agg('state')])
assert groups['state'].to_pylist() == ['NY']


def test_aggregate(table):
tbl = T.union(table, table.select([0]).rename_columns(['test']))
assert tbl.schema.names == table.schema.names + ['test']
groups = T.group(table, 'state', 'county')
assert len(groups) == 3216
assert groups.schema.names == ['state', 'county']
groups = T.group(table, 'state', counts='counts', first=[Agg('county')])
assert len(groups) == 52
assert groups['state'][0].as_py() == 'NY'
assert groups['counts'][0].as_py() == 2205
assert groups['county'][0].as_py() == 'Suffolk'
groups = T.group(table, 'state', last=[Agg('city', 'last')], min=[Agg('zipcode')])
assert groups['last'][0].as_py() == 'Elmira'
assert groups['zipcode'][0].as_py() == 501
groups = T.group(table, 'state', max=[Agg('zipcode', 'max', skip_nulls=False)])
assert groups['max'][0].as_py() == 14925
groups = T.group(table, 'state', min_max=[Agg('zipcode')])
assert groups['zipcode'][0].as_py() == {'min': 501, 'max': 14925}
groups = T.group(
table, 'state', approximate_median=[Agg('longitude')], tdigest=[Agg('latitude')]
)
assert groups['longitude'][0].as_py() == pytest.approx(-74.25370)
assert groups['latitude'][0].as_py() == [pytest.approx(42.34672)]
row = T.aggregate(table, min=[Agg('state')], list=[Agg('zipcode')])
assert row['state'].as_py() == 'AK'
assert row['zipcode'] == table['zipcode'].combine_chunks()
row = T.aggregate(table, counts='counts')
assert row['counts'] == 41700
row = T.aggregate(table, first=[Agg('zipcode')])
assert row['zipcode'].as_py() == 501
row = T.aggregate(table, last=[Agg('zipcode')])
assert row['zipcode'].as_py() == 99950
nulls = pa.table({'': [0, None, 0]})
row = T.aggregate(nulls, list=[Agg('')])
assert row[''].to_pylist() == [0, None, 0]
row = T.aggregate(nulls, distinct=[Agg('', 'd1'), Agg('', 'd2', mode='all')])
assert row['d1'].to_pylist() == [0]
assert row['d2'].to_pylist() == [0, None]


def test_runs(table):
groups, counts = T.runs(table, 'state')
assert len(groups) == len(counts) == 66