Skip to content

Commit

Permalink
Generalized scanning and projection.
Browse files Browse the repository at this point in the history
  • Loading branch information
coady committed Oct 12, 2024
1 parent 6fe9f9e commit 5a3f440
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 44 deletions.
35 changes: 13 additions & 22 deletions graphique/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ class Agg:
'tdigest': pc.TDigestOptions,
'variance': pc.VarianceOptions,
}

associatives = {'all', 'any', 'first', 'last', 'max', 'min', 'one', 'product', 'sum'}
associatives |= {'count'} # transformed to be associative
ordered = {'first', 'last'}

def __init__(self, name: str, alias: str = '', **options):
Expand Down Expand Up @@ -597,14 +594,18 @@ class Nodes(ac.Declaration):
def __init__(self, name, *args, inputs=None, **options):
super().__init__(name, self.option_map[name](*args, **options), inputs)

@classmethod
def scan(cls, dataset: ds.Dataset, columns: Optional[Iterable] = None) -> Self:
    """Build a source node that scans `dataset`.

    The dataset's own filter expression, read from its `_scan_options`, is
    re-applied as a `filter` node when present. When `columns` is given, the
    result is additionally passed through `project`; otherwise the
    (optionally filtered) scan node is returned as is.
    """
    node = cls('scan', dataset, columns=columns)
    predicate = dataset._scan_options.get('filter')
    if predicate is not None:
        node = node.apply('filter', predicate)
    if columns is None:
        return node
    return node.project(columns)
def scan(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self:
    """Return projected source node, supporting datasets and tables.

    `self` may be a dataset, a table, or an existing node (callers invoke this
    as `Nodes.scan(source, columns)`):

    - a `ds.Dataset` becomes a `scan` node, followed by a `filter` node when
      the dataset carries a filter in its `_scan_options`;
    - a `pa.Table` becomes a `table_source` node;
    - anything else is used as the input node unchanged.

    `columns` is either a mapping of output names to expressions, or an
    iterable of column names which are turned into field references.
    The result always ends with a `project` node.
    """
    if not isinstance(columns, Mapping):
        # Materialize once: the names are handed to the scan options and then
        # iterated again for the projection, so a one-shot iterable would be
        # exhausted (or rejected) before the `project` node is built.
        columns = list(columns)
    node = self
    if isinstance(self, ds.Dataset):
        # Read the filter before wrapping; it is looked up on the dataset itself.
        expr = self._scan_options.get('filter')
        node = Nodes('scan', self, columns=columns)
        if expr is not None:
            node = node.apply('filter', expr)
    elif isinstance(self, pa.Table):
        node = Nodes('table_source', self)
    if isinstance(columns, Mapping):
        return node.apply('project', columns.values(), columns)
    return node.apply('project', map(pc.field, columns))

@property
def schema(self) -> pa.Schema:
Expand Down Expand Up @@ -633,12 +634,6 @@ def apply(self, name: str, *args, **options) -> Self:

filter = functools.partialmethod(apply, 'filter')

def project(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self:
    """Append a `project` node selecting the given columns.

    A mapping supplies output names (keys) with their expressions (values);
    a plain iterable of names is converted to field references.
    """
    if not isinstance(columns, Mapping):
        return self.apply('project', map(pc.field, columns))
    expressions = columns.values()
    return self.apply('project', expressions, columns)

def group(self, *names, **aggs: tuple) -> Self:
"""Add `aggregate` node with dictionary support.
Expand All @@ -653,8 +648,4 @@ def group(self, *names, **aggs: tuple) -> Self:
field = self.schema.field(name)
if pa.types.is_dictionary(field.type):
columns[name] = columns[name].cast(field.type.value_type)
if isinstance(self, ds.Dataset):
self = Nodes.scan(self, columns)
else:
self = self.project(columns)
return self.apply('aggregate', aggregates, names)
return Nodes.scan(self, columns).apply('aggregate', aggregates, names)
22 changes: 5 additions & 17 deletions graphique/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,12 @@ def select(self, info: Info) -> Source:
names = list(self.references(info))
if len(names) >= len(self.schema().names):
return self.source
if isinstance(self.source, ds.Dataset):
return Nodes.scan(self.source, names)
if isinstance(self.source, Nodes):
return self.source.project(names)
if isinstance(self.source, ds.Scanner):
schema = self.source.projected_schema
return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names)
return self.source.select(names)
if isinstance(self.source, pa.Table):
return self.source.select(names)
return Nodes.scan(self.source, names)

def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table:
"""Return table with only the rows and columns necessary to proceed."""
Expand Down Expand Up @@ -143,9 +141,7 @@ def filter(self, info: Info, **queries: Filter) -> Self:
source = T.range(source, name, lower, upper, **includes)
if len(query.pop('eq', [])) != 1 or query:
break
self = type(self)(source)
expr = Expression.from_query(**queries)
return self if expr.to_arrow() is None else self.scan(info, filter=expr)
return type(self)(source).scan(info, filter=Expression.from_query(**queries))

@doc_field
def type(self) -> str:
Expand Down Expand Up @@ -258,8 +254,6 @@ def group(
for agg in values:
aggs[agg.alias] = (agg.name, prefix + func, agg.func_options(func))
source = self.to_table(info) if isinstance(self.source, ds.Scanner) else self.source
if isinstance(source, pa.Table):
source = ds.dataset(source)
source = Nodes.group(source, *by, **aggs)
if ordered:
source = self.add_metric(info, source.to_table(use_threads=False), mode='group')
Expand Down Expand Up @@ -461,13 +455,7 @@ def scan(self, info: Info, filter: Expression = {}, columns: list[Projection] =
scanner = ds.Scanner.from_batches(self.source.to_batches(), **options)
return type(self)(self.add_metric(info, scanner.to_table(), mode='batch'))
source = self.source if expr is None else self.source.filter(expr)
if isinstance(source, ds.Dataset):
return type(self)(Nodes.scan(source, projection) if columns else source)
if isinstance(source, pa.Table):
if not columns:
return type(self)(source.select(list(projection)))
source = Nodes('table_source', source)
return type(self)(source.project(projection))
return type(self)(Nodes.scan(source, projection) if columns else source)

@doc_field(
right="name of right table; must be on root Query type",
Expand Down
1 change: 0 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def test_membership():

def test_nodes(table):
dataset = ds.dataset(table).filter(pc.field('state') == 'CA')
assert Nodes.scan(dataset).to_table()['state'].unique().to_pylist() == ['CA']
(column,) = Nodes.scan(dataset, columns={'_': pc.field('state')}).to_table()
assert column.unique().to_pylist() == ['CA']
table = Nodes.group(dataset, 'county', 'city', counts=([], 'hash_count_all', None)).to_table()
Expand Down
6 changes: 2 additions & 4 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,9 @@ def test_federation(fedclient):

data = fedclient.execute(
"""{ _entities(representations: {__typename: "ZipcodesTable", zipcode: 90001}) {
... on ZipcodesTable { length row { state } schema { names } } } }"""
... on ZipcodesTable { length type row { state } } } }"""
)
assert data == {
'_entities': [{'length': 1, 'row': {'state': 'CA'}, 'schema': {'names': ['state']}}]
}
assert data == {'_entities': [{'length': 1, 'type': 'Nodes', 'row': {'state': 'CA'}}]}
data = fedclient.execute("""{ states { filter(state: {eq: "CA"}) { columns { indices {
takeFrom(field: "zipcodes") { __typename column(name: "state") { length } } } } } } }""")
table = data['states']['filter']['columns']['indices']['takeFrom']
Expand Down

0 comments on commit 5a3f440

Please sign in to comment.