Dict union operator.
Test fix for variance in partition grouping.
coady committed Dec 12, 2023
1 parent bde7b94 · commit 7181f1c
Showing 4 changed files with 13 additions and 16 deletions.
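Throughout this commit, dict.update calls are replaced with the in-place union operator |= from PEP 584 (Python 3.9+). A minimal sketch of the equivalence, using throwaway dicts:

    a = {'x': 1, 'y': 2}
    b = {'y': 3, 'z': 4}
    merged = a | b  # new dict; the right operand wins on duplicate keys
    a |= b          # in-place union, equivalent to a.update(b)
    assert merged == a == {'x': 1, 'y': 3, 'z': 4}

Note that d | e requires both operands to be dicts, while d |= e, like update, also accepts any mapping or iterable of key-value pairs.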
8 changes: 4 additions & 4 deletions graphique/core.py
@@ -347,7 +347,7 @@ def union(*tables: Batch) -> Batch:
         """Return table with union of columns."""
         columns: dict = {}
         for table in tables:
-            columns.update(Table.columns(table))
+            columns |= Table.columns(table)
         return type(tables[0]).from_pydict(columns)

     def range(self, name: str, lower=None, upper=None, **includes) -> pa.Table:
@@ -428,10 +428,10 @@ def aggregate(self, counts: str = '', **funcs: Sequence[Agg]) -> dict:
         row = {counts: len(self)} if counts else {}
         for key in ('one', 'list', 'distinct'):  # hash only functions
             func, aggs = getattr(Agg, key), funcs.pop(key, [])
-            row.update({agg.alias: func(self[agg.name], **agg.options) for agg in aggs})
+            row |= {agg.alias: func(self[agg.name], **agg.options) for agg in aggs}
         if funcs:
             table = Table.group(self, **funcs)  # type: ignore
-            row.update({name: table[name][0] for name in table.schema.names})
+            row |= {name: table[name][0] for name in table.schema.names}
         for name, value in row.items():
             if isinstance(value, pa.ChunkedArray):
                 row[name] = value.combine_chunks()
@@ -614,7 +614,7 @@ def split(self) -> Iterator[Optional[pa.RecordBatch]]:
                 yield None
             else:
                 row = {name: pa.repeat(self[name][index], count) for name in scalars}
-                row.update({name: self[name][index].values for name in lists})
+                row |= {name: self[name][index].values for name in lists}
                 yield pa.RecordBatch.from_pydict(row)

     def size(self) -> str:
8 changes: 4 additions & 4 deletions graphique/interface.py
@@ -277,19 +277,19 @@ def fragments(self, info: Info, counts: str = '', aggregate: HashAggregates = {}
             row = ds.get_partition_keys(fragment.partition_expression)
             if projection:
                 table = fragment.to_table(columns=projection)
-                row.update(T.aggregate(table, counts=counts, **aggs))
+                row |= T.aggregate(table, counts=counts, **aggs)
             elif counts:
                 row[counts] = fragment.count_rows()
             arrays = {name: value for name, value in row.items() if isinstance(value, pa.Array)}
-            row.update(T.columns(pa.RecordBatch.from_pydict(arrays)))
+            row |= T.columns(pa.RecordBatch.from_pydict(arrays))
             for name in row:
                 columns[name].append(row[name])
         for name, values in columns.items():
             if isinstance(values[0], pa.Scalar):
                 columns[name] = C.from_scalars(values)
             elif isinstance(values[0], pa.Array):
                 columns[name] = ListChunk.from_scalars(values)
-        columns.update({field.name: pa.array(columns[field.name], field.type) for field in schema})
+        columns |= {field.name: pa.array(columns[field.name], field.type) for field in schema}
         return self.add_metric(info, pa.table(columns), mode='fragment')

     @doc_field(
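For context, ds.get_partition_keys turns a fragment's partition expression back into a plain dict of partition key values, which the union operator then merges with the computed aggregates. A hedged sketch against a hypothetical hive-partitioned dataset:

    import pyarrow.dataset as ds

    dataset = ds.dataset('data/', partitioning='hive')  # hypothetical path
    for fragment in dataset.get_fragments():
        keys = ds.get_partition_keys(fragment.partition_expression)
        # e.g. {'north': 0} for a directory like data/north=0/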
@@ -482,7 +482,7 @@ def aggregate(
     def project(self, info: Info, columns: list[Projection]) -> dict:
         """Return projected columns, including all references from below fields."""
         projection = {name: pc.field(name) for name in self.references(info, level=1)}
-        projection.update({col.alias or '.'.join(col.name): col.to_arrow() for col in columns})
+        projection |= {col.alias or '.'.join(col.name): col.to_arrow() for col in columns}
         if '' in projection:
             raise ValueError(f"projected columns need a name or alias: {projection['']}")
         return projection
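Here the union operator lets explicit projections override the default field references registered under the same name. A minimal sketch with hypothetical column names:

    import pyarrow.compute as pc

    projection = {'city': pc.field('city')}  # defaults from referenced fields
    projection |= {'total': pc.field('price') * pc.field('quantity')}  # explicit projection wins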
8 changes: 2 additions & 6 deletions graphique/shell.py
@@ -17,12 +17,8 @@

 def write_batches(scanner: ds.Scanner, base_dir: str, *partitioning: str, **options):
     """Partition dataset by batches."""
-    options.update(
-        format=options.get('format', 'parquet'),
-        partitioning=partitioning,
-        partitioning_flavor=options.get('partitioning_flavor', 'hive'),
-        existing_data_behavior='overwrite_or_ignore',
-    )
+    options.update(format='parquet', partitioning=partitioning)
+    options.update(partitioning_flavor='hive', existing_data_behavior='overwrite_or_ignore')
     with tqdm(total=scanner.count_rows(), desc="Batches") as pbar:
         for index, batch in enumerate(scanner.to_batches()):
             options['basename_template'] = f'part-{index}-{{i}}.parquet'
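For reference, these option names match the keyword arguments of pyarrow.dataset.write_dataset. A minimal sketch of an equivalent direct call, where batch, base_dir, and the partition names are stand-ins for the function's arguments:

    import pyarrow.dataset as ds

    ds.write_dataset(
        batch,  # one pyarrow.RecordBatch from scanner.to_batches()
        base_dir,
        format='parquet',
        partitioning=['year', 'month'],  # hypothetical *partitioning names
        partitioning_flavor='hive',
        existing_data_behavior='overwrite_or_ignore',
        basename_template='part-0-{i}.parquet',
    )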
5 changes: 3 additions & 2 deletions tests/test_dataset.py
@@ -121,8 +121,9 @@ def test_list(partclient):
         '''{ group(by: "north", aggregate: {distinct: {name: "west"}}) {
         tables { row { north } columns { west { length } } } } }'''
     )
-    table = data['group']['tables'][0]
-    assert table == {'row': {'north': 0}, 'columns': {'west': {'length': 2}}}
+    tables = data['group']['tables']
+    assert {table['row']['north'] for table in tables} == {0, 1}
+    assert [table['columns'] for table in tables] == [{'west': {'length': 2}}] * 2


 def test_fragments(partclient):
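The fix above accounts for nondeterministic group order across dataset partitions: rather than assuming group 0 comes first, the test now compares the set of group keys and the full list of column aggregates. The pattern in isolation, with a made-up result:

    tables = [{'row': {'north': 1}}, {'row': {'north': 0}}]  # order may vary per run
    assert {t['row']['north'] for t in tables} == {0, 1}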