Commit

Remove oneshot scanners.
Every operation gets at least two scans before loading, and acero nodes will further minimize memory usage.
coady committed Sep 29, 2024
1 parent 5be87e6 commit 574c0b7
Showing 3 changed files with 15 additions and 39 deletions.
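
A minimal sketch of the rationale, using plain pyarrow rather than graphique's own classes: a dataset source can be scanned repeatedly, whereas a scanner built from a batch iterator is single-use ("oneshot") and is exhausted after one pass.

import pyarrow as pa
import pyarrow.dataset as ds

table = pa.table({'x': [1, 2, 3]})

dataset = ds.dataset(table)      # in-memory dataset: re-scannable
dataset.count_rows()             # first scan
dataset.to_table(columns=['x'])  # second scan still works

scanner = ds.Scanner.from_batches(table.to_batches(), schema=table.schema)
scanner.count_rows()             # consumes the underlying batch iterator
# a second pass over `scanner` would fail or come back empty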
7 changes: 3 additions & 4 deletions graphique/core.py
@@ -334,10 +334,8 @@ def find(self, *values) -> Iterator[slice]:
 class Table(pa.Table):
     """Table interface as a namespace of functions."""

-    def map_batch(scanner: ds.Scanner, func: Callable, *rargs, **kwargs) -> pa.Table:
-        # TODO(apache/arrow#31612): replace with user defined function for multiple kernels
-        batches = [func(batch, *rargs, **kwargs) for batch in scanner.to_batches() if batch]
-        return pa.Table.from_batches(batches, None if batches else scanner.projected_schema)
+    def map_batch(self, func: Callable, *args, **kwargs) -> pa.Table:
+        return pa.Table.from_batches(func(batch, *args, **kwargs) for batch in self.to_batches())

     def columns(self) -> dict:
         """Return columns as a dictionary."""
@@ -639,6 +637,7 @@ class Nodes(ac.Declaration):
         'order_by': ac.OrderByNodeOptions,
         'hashjoin': ac.HashJoinNodeOptions,
     }
+    to_batches = ac.Declaration.to_reader  # source compatibility

     def __init__(self, name, *args, inputs=None, **options):
         super().__init__(name, self.option_map[name](*args, **options), inputs)
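
The effect of the core.py change, as a hedged sketch using only pyarrow (graphique's Nodes wraps acero declarations, so aliasing to_batches to Declaration.to_reader lets batch-wise helpers such as map_batch treat a query plan, a scanner, or a table alike):

import pyarrow as pa
import pyarrow.acero as ac
import pyarrow.compute as pc

table = pa.table({'x': [1, 2, 3]})
plan = ac.Declaration.from_sequence([
    ac.Declaration('table_source', ac.TableSourceNodeOptions(table)),
    ac.Declaration('filter', ac.FilterNodeOptions(pc.field('x') > 1)),
])
# to_reader() streams record batches much like Scanner.to_batches(),
# so the new map_batch pattern works without materializing the plan up front
result = pa.Table.from_batches(batch for batch in plan.to_reader())
assert result['x'].to_pylist() == [2, 3]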
29 changes: 8 additions & 21 deletions graphique/interface.py
@@ -22,7 +22,7 @@
 from .inputs import CountAggregate, Cumulative, Diff, Expression, Field, Filter
 from .inputs import HashAggregates, ListFunction, Pairwise, Projection, Rank
 from .inputs import ScalarAggregate, TDigestAggregate, VarianceAggregate, links, provisional
-from .models import Column, doc_field, selections
+from .models import Column, doc_field
 from .scalars import Long

 Source = Union[ds.Dataset, Nodes, ds.Scanner, pa.Table]
@@ -416,7 +416,7 @@ def apply(
         Applied functions load arrays into memory as needed. See `scan` for scalar functions,
         which do not require loading.
         """
-        table = T.map_batch(self.scanner(info), self.apply_list, list_)
+        table = T.map_batch(self.select(info), self.apply_list, list_)
         self.add_metric(info, table, mode='batch')
         columns = {}
         funcs = pc.cumulative_max, pc.cumulative_mean, pc.cumulative_min, pc.cumulative_prod
@@ -436,18 +436,16 @@ def flatten(self, info: Info, indices: str = '') -> Self:
         At least one list column must be referenced, and all list columns must have the same lengths.
         """
-        batches = T.flatten(self.scanner(info), indices)
-        batch = next(batches)
-        scanner = ds.Scanner.from_batches(itertools.chain([batch], batches), schema=batch.schema)
-        return type(self)(self.oneshot(info, scanner))
+        table = pa.Table.from_batches(T.flatten(self.select(info), indices))
+        return type(self)(self.add_metric(info, table, mode='batch'))

     @doc_field
     def tables(self, info: Info) -> list[Optional[Self]]:  # type: ignore
         """Return a list of tables by splitting list columns.
         At least one list column must be referenced, and all list columns must have the same lengths.
         """
-        for batch in self.scanner(info).to_batches():
+        for batch in self.select(info).to_batches():
             for row in T.split(batch):
                 yield None if row is None else type(self)(pa.Table.from_batches([row]))

@@ -502,15 +500,6 @@ def project(self, info: Info, columns: list[Projection]) -> dict:
             raise ValueError(f"projected columns need a name or alias: {projection['']}")
         return projection

-    @classmethod
-    def oneshot(cls, info: Info, scanner: ds.Scanner) -> Union[ds.Scanner, pa.Table]:
-        """Load oneshot scanner if needed."""
-        selected = selections(*info.selected_fields)
-        selected['type'] = selected['schema'] = 0
-        if sum(selected.values()) > 1:
-            return cls.add_metric(info, scanner.to_table(), mode='oneshot')
-        return scanner
-
     @doc_field(filter="selected rows", columns="projected columns")
     def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] = []) -> Self:  # type: ignore
         """Select rows and project columns without memory usage."""
@@ -519,7 +508,7 @@ def scan(self, info: Info, filter: Expression =
             return type(self)(self.source.filter(expr))
         scanner = self.scanner(info, filter=expr, columns=self.project(info, columns))
         if isinstance(self.source, ds.Scanner):
-            scanner = self.oneshot(info, scanner)
+            scanner = self.add_metric(info, scanner.to_table(), mode='batch')
         return type(self)(scanner)

     @doc_field(
@@ -571,7 +560,5 @@ def drop_null(self, info: Info) -> Self:
         """Remove missing values from referenced columns in the table."""
         if isinstance(self.source, pa.Table):
             return type(self)(pc.drop_null(self.to_table(info)))
-        scanner = self.scanner(info)
-        batches = map(pc.drop_null, scanner.to_batches())
-        scanner = ds.Scanner.from_batches(batches, schema=scanner.projected_schema)
-        return type(self)(self.oneshot(info, scanner))
+        table = T.map_batch(self.select(info), pc.drop_null)
+        return type(self)(self.add_metric(info, table, mode='batch'))
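
A before/after sketch of the resolver pattern above, in plain pyarrow (the graphique-specific select, add_metric, and T.map_batch calls are omitted): instead of re-wrapping transformed batches in another single-use scanner, the batches are loaded directly into a table.

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds

table = pa.table({'x': [1, None, 3]})

# before: another oneshot scanner, which may still have to be loaded later
scanner = ds.Scanner.from_batches(map(pc.drop_null, table.to_batches()), schema=table.schema)

# after: load once, batch by batch
result = pa.Table.from_batches(map(pc.drop_null, table.to_batches()))
assert result['x'].to_pylist() == [1, 3]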
18 changes: 4 additions & 14 deletions graphique/models.py
@@ -2,10 +2,8 @@
 GraphQL output types and resolvers.
 """

-import collections
 import functools
 import inspect
-import itertools
 from collections.abc import Callable
 from datetime import date, datetime, time, timedelta
 from decimal import Decimal
@@ -24,17 +22,9 @@
 T = TypeVar('T')


-def _selections(field):
-    for selection in field.selections:
-        if hasattr(selection, 'name'):
-            yield selection.name
-        else:
-            yield from _selections(selection)
-
-
-def selections(*fields) -> dict:
-    """Return counts of field name selections from strawberry `SelectedField`."""
-    return collections.Counter(itertools.chain(*map(_selections, fields)))
+def selections(*fields) -> set:
+    """Return field name selections from strawberry `SelectedField`."""
+    return {selection.name for field in fields for selection in field.selections}


 def doc_field(func: Optional[Callable] = None, **kwargs: str) -> StrawberryField:
@@ -260,7 +250,7 @@ def take_from(
     ) -> Optional[Annotated['Dataset', strawberry.lazy('.interface')]]:
         """Select indices from a table on the root Query type."""
         root = getattr(info.root_value, field)
-        return type(root)(root.scanner(info).take(self.array.combine_chunks()))
+        return type(root)(root.select(info).take(self.array.combine_chunks()))


 @Column.register(list)
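
A self-contained sketch of the simplified helper (SimpleNamespace stands in for strawberry's SelectedField here): nested selections are no longer flattened into a Counter; only top-level field names are collected as a set.

from types import SimpleNamespace

def selections(*fields) -> set:
    """Return field name selections from strawberry `SelectedField`."""
    return {selection.name for field in fields for selection in field.selections}

field = SimpleNamespace(selections=[SimpleNamespace(name='schema'), SimpleNamespace(name='type')])
assert selections(field) == {'schema', 'type'}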
