From 7181f1cdc3cc3ef769abd8ea986e7071e7dda01b Mon Sep 17 00:00:00 2001 From: Aric Coady Date: Mon, 11 Dec 2023 21:15:51 -0800 Subject: [PATCH] Dict union operator. Test fix for variance in partition grouping. --- graphique/core.py | 8 ++++---- graphique/interface.py | 8 ++++---- graphique/shell.py | 8 ++------ tests/test_dataset.py | 5 +++-- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/graphique/core.py b/graphique/core.py index bc001b2..f7b2f95 100644 --- a/graphique/core.py +++ b/graphique/core.py @@ -347,7 +347,7 @@ def union(*tables: Batch) -> Batch: """Return table with union of columns.""" columns: dict = {} for table in tables: - columns.update(Table.columns(table)) + columns |= Table.columns(table) return type(tables[0]).from_pydict(columns) def range(self, name: str, lower=None, upper=None, **includes) -> pa.Table: @@ -428,10 +428,10 @@ def aggregate(self, counts: str = '', **funcs: Sequence[Agg]) -> dict: row = {counts: len(self)} if counts else {} for key in ('one', 'list', 'distinct'): # hash only functions func, aggs = getattr(Agg, key), funcs.pop(key, []) - row.update({agg.alias: func(self[agg.name], **agg.options) for agg in aggs}) + row |= {agg.alias: func(self[agg.name], **agg.options) for agg in aggs} if funcs: table = Table.group(self, **funcs) # type: ignore - row.update({name: table[name][0] for name in table.schema.names}) + row |= {name: table[name][0] for name in table.schema.names} for name, value in row.items(): if isinstance(value, pa.ChunkedArray): row[name] = value.combine_chunks() @@ -614,7 +614,7 @@ def split(self) -> Iterator[Optional[pa.RecordBatch]]: yield None else: row = {name: pa.repeat(self[name][index], count) for name in scalars} - row.update({name: self[name][index].values for name in lists}) + row |= {name: self[name][index].values for name in lists} yield pa.RecordBatch.from_pydict(row) def size(self) -> str: diff --git a/graphique/interface.py b/graphique/interface.py index 278529f..871beed 100644 --- a/graphique/interface.py +++ b/graphique/interface.py @@ -277,11 +277,11 @@ def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {} row = ds.get_partition_keys(fragment.partition_expression) if projection: table = fragment.to_table(columns=projection) - row.update(T.aggregate(table, counts=counts, **aggs)) + row |= T.aggregate(table, counts=counts, **aggs) elif counts: row[counts] = fragment.count_rows() arrays = {name: value for name, value in row.items() if isinstance(value, pa.Array)} - row.update(T.columns(pa.RecordBatch.from_pydict(arrays))) + row |= T.columns(pa.RecordBatch.from_pydict(arrays)) for name in row: columns[name].append(row[name]) for name, values in columns.items(): @@ -289,7 +289,7 @@ def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {} columns[name] = C.from_scalars(values) elif isinstance(values[0], pa.Array): columns[name] = ListChunk.from_scalars(values) - columns.update({field.name: pa.array(columns[field.name], field.type) for field in schema}) + columns |= {field.name: pa.array(columns[field.name], field.type) for field in schema} return self.add_metric(info, pa.table(columns), mode='fragment') @doc_field( @@ -482,7 +482,7 @@ def aggregate( def project(self, info: Info, columns: list[Projection]) -> dict: """Return projected columns, including all references from below fields.""" projection = {name: pc.field(name) for name in self.references(info, level=1)} - projection.update({col.alias or '.'.join(col.name): col.to_arrow() for col in columns}) + projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns} if '' in projection: raise ValueError(f"projected columns need a name or alias: {projection['']}") return projection diff --git a/graphique/shell.py b/graphique/shell.py index a694dba..3b8c12f 100644 --- a/graphique/shell.py +++ b/graphique/shell.py @@ -17,12 +17,8 @@ def write_batches(scanner: ds.Scanner, base_dir: str, *partitioning: str, **options): """Partition dataset by batches.""" - options.update( - format=options.get('format', 'parquet'), - partitioning=partitioning, - partitioning_flavor=options.get('partitioning_flavor', 'hive'), - existing_data_behavior='overwrite_or_ignore', - ) + options.update(format='parquet', partitioning=partitioning) + options.update(partitioning_flavor='hive', existing_data_behavior='overwrite_or_ignore') with tqdm(total=scanner.count_rows(), desc="Batches") as pbar: for index, batch in enumerate(scanner.to_batches()): options['basename_template'] = f'part-{index}-{{i}}.parquet' diff --git a/tests/test_dataset.py b/tests/test_dataset.py index e8a8743..43e62f6 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -121,8 +121,9 @@ def test_list(partclient): '''{ group(by: "north", aggregate: {distinct: {name: "west"}}) { tables { row { north } columns { west { length } } } } }''' ) - table = data['group']['tables'][0] - assert table == {'row': {'north': 0}, 'columns': {'west': {'length': 2}}} + tables = data['group']['tables'] + assert {table['row']['north'] for table in tables} == {0, 1} + assert [table['columns'] for table in tables] == [{'west': {'length': 2}}] * 2 def test_fragments(partclient):