Nodes supported as a data source.
coady committed Sep 28, 2024
1 parent a1ef5e3 commit 5be87e6
Showing 3 changed files with 52 additions and 34 deletions.
graphique/core.py: 2 additions & 0 deletions
@@ -677,6 +677,8 @@ def apply(self, name: str, *args, **options) -> Self:
         """Add a node by name."""
         return type(self)(name, *args, inputs=[self], **options)
 
+    filter = functools.partialmethod(apply, 'filter')
+
     def project(self, columns: Union[Mapping[str, pc.Expression], Iterable[str]]) -> Self:
         """Add `project` node from columns names with optional expressions."""
         if isinstance(columns, Mapping):
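The added `filter` shorthand pre-binds the node name with `functools.partialmethod`, so `nodes.filter(expr)` is equivalent to `nodes.apply('filter', expr)`. A minimal, self-contained sketch of the pattern, using a hypothetical `Chain` class as a stand-in for `Nodes`:

    import functools

    class Chain:
        """Hypothetical stand-in for Nodes: a named node chained off its inputs."""

        def __init__(self, name, *args, inputs=(), **options):
            self.name, self.args, self.options = name, args, options
            self.inputs = list(inputs)

        def apply(self, name, *args, **options):
            """Add a node by name."""
            return type(self)(name, *args, inputs=[self], **options)

        # pre-bind the node name: `chain.filter(expr)` == `chain.apply('filter', expr)`
        filter = functools.partialmethod(apply, 'filter')

    chain = Chain('scan').filter('<expression>')
    assert chain.name == 'filter' and chain.inputs[0].name == 'scan'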
graphique/interface.py: 48 additions & 33 deletions
@@ -8,7 +8,7 @@
 import collections
 import inspect
 import itertools
-from collections.abc import Callable, Iterable, Iterator, Mapping
+from collections.abc import Callable, Iterable, Iterator, Mapping, Sized
 from datetime import timedelta
 from typing import Annotated, Optional, Union, no_type_check
 import pyarrow as pa
@@ -18,14 +18,14 @@
 from strawberry import Info
 from strawberry.extensions.utils import get_path_from_info
 from typing_extensions import Self
-from .core import Batch, Column as C, ListChunk, Table as T
+from .core import Batch, Column as C, ListChunk, Nodes, Table as T
 from .inputs import CountAggregate, Cumulative, Diff, Expression, Field, Filter
 from .inputs import HashAggregates, ListFunction, Pairwise, Projection, Rank
 from .inputs import ScalarAggregate, TDigestAggregate, VarianceAggregate, links, provisional
 from .models import Column, doc_field, selections
 from .scalars import Long
 
-Source = Union[ds.Dataset, ds.Scanner, pa.Table]
+Source = Union[ds.Dataset, Nodes, ds.Scanner, pa.Table]
 
 
 def references(field) -> Iterator:
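With `Nodes` added to the `Source` union, a lazily planned Acero query can serve as the root data source alongside datasets, scanners, and tables. For background, a rough sketch of the raw `pyarrow.acero` machinery that `Nodes` appears to wrap (pyarrow's own API, with toy data; graphique's wrapper adds the chaining seen above):

    import pyarrow as pa
    import pyarrow.compute as pc
    import pyarrow.dataset as ds
    from pyarrow import acero

    dataset = ds.dataset(pa.table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']}))
    decl = acero.Declaration.from_sequence([
        acero.Declaration('scan', acero.ScanNodeOptions(dataset)),
        acero.Declaration('filter', acero.FilterNodeOptions(pc.field('x') > 1)),
        # project keeps `x` and drops the scan node's augmented fields
        acero.Declaration('project', acero.ProjectNodeOptions([pc.field('x')], ['x'])),
    ])
    assert decl.to_table().column_names == ['x']  # nothing executes until this call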
@@ -71,23 +71,43 @@ def references(self, info: Info, level: int = 0) -> set:
         fields = itertools.chain(*[field.selections for field in fields])
         return set(itertools.chain(*map(references, fields))) & set(self.schema().names)
 
+    def select(self, info: Info) -> Source:
+        """Return source with only the columns necessary to proceed."""
+        names = list(self.references(info))
+        if len(names) >= len(self.schema().names):
+            return self.source
+        if isinstance(self.source, ds.Dataset):
+            return Nodes.scan(self.source, names)
+        if isinstance(self.source, Nodes):
+            return self.source.project(names)
+        if isinstance(self.source, ds.Scanner):
+            schema = self.source.projected_schema
+            return ds.Scanner.from_batches(self.source.to_batches(), schema=schema, columns=names)
+        return self.source.select(names)
+
     def scanner(self, info: Info, **options) -> ds.Scanner:
         """Return scanner with only the columns necessary to proceed."""
         options.setdefault('columns', list(self.references(info)))
-        dataset = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source
-        if isinstance(dataset, ds.Dataset):
-            return dataset.scanner(**options)
-        options['schema'] = dataset.projected_schema
-        return ds.Scanner.from_batches(dataset.to_batches(), **options)
+        source = ds.dataset(self.source) if isinstance(self.source, pa.Table) else self.source
+        if isinstance(source, ds.Dataset):
+            return source.scanner(**options)
+        if isinstance(source, Nodes):
+            if 'filter' in options:  # pragma: no branch
+                source = source.filter(options['filter'])
+            if 'columns' in options:  # pragma: no branch
+                source = source.project(options['columns'])
+            return source.scanner()
+        options['schema'] = source.projected_schema
+        return ds.Scanner.from_batches(source.to_batches(), **options)
 
-    def select(self, info: Info, length: Optional[int] = None) -> pa.Table:
+    def to_table(self, info: Info, length: Optional[int] = None) -> pa.Table:
         """Return table with only the rows and columns necessary to proceed."""
-        if isinstance(self.source, pa.Table):
-            return self.source.select(self.references(info))
-        scanner = self.scanner(info)
+        source = self.select(info)
+        if isinstance(source, pa.Table):
+            return source
         if length is None:
-            return self.add_metric(info, scanner.to_table(), mode='read')
-        return self.add_metric(info, scanner.head(length), mode='head')
+            return self.add_metric(info, source.to_table(), mode='read')
+        return self.add_metric(info, source.head(length), mode='head')
 
     @classmethod
     @no_type_check
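The renaming above splits responsibilities: the old `select` (which returned a table) is now `to_table`, while the new `select` only narrows columns and keeps the source lazy. The `ds.Scanner` branch re-wraps the batch stream because an existing scanner cannot be projected in place; a self-contained sketch of that re-wrapping, with toy data:

    import pyarrow as pa
    import pyarrow.dataset as ds

    dataset = ds.dataset(pa.table({'x': [1, 2], 'y': [3, 4]}))
    scanner = dataset.scanner()
    # re-wrap the stream of batches, projecting down to the needed columns
    narrowed = ds.Scanner.from_batches(
        scanner.to_batches(), schema=scanner.projected_schema, columns=['x'])
    assert narrowed.to_table().column_names == ['x']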
@@ -99,12 +119,12 @@ def resolve_reference(cls, info: Info, **keys) -> Self:
 
     def columns(self, info: Info) -> dict:
         """fields for each column"""
-        table = self.select(info)
+        table = self.to_table(info)
         return {name: Column.cast(table[name]) for name in table.schema.names}
 
     def row(self, info: Info, index: int = 0) -> dict:
         """Return scalar values at index."""
-        table = self.select(info, index + 1 if index >= 0 else None)
+        table = self.to_table(info, index + 1 if index >= 0 else None)
         row = {}
         for name in table.schema.names:
             scalar = table[name][index]
@@ -169,11 +189,6 @@ def optional(self) -> Optional[Self]:
         """
         return self
 
-    @staticmethod
-    def add_context(info: Info, key: str, **data):  # pragma: no cover
-        """Add data to context with path info."""
-        info.context.setdefault(key, []).append(dict(data, path=get_path_from_info(info)))
-
     @staticmethod
     def add_metric(info: Info, table: pa.Table, **data):
         """Add memory usage and other metrics to context with path info."""
@@ -184,15 +199,15 @@ def add_metric(info: Info, table: pa.Table, **data):
     @doc_field
     def length(self) -> Long:
         """number of rows"""
-        return len(self.source) if hasattr(self.source, '__len__') else self.source.count_rows()
+        return len(self.source) if isinstance(self.source, Sized) else self.source.count_rows()
 
     @doc_field
     def any(self, info: Info, length: Long = 1) -> bool:
         """Return whether there are at least `length` rows.
 
         May be significantly faster than `length` for out-of-core data.
         """
-        table = self.select(info, length)
+        table = self.to_table(info, length)
         return len(table) >= length
 
     @doc_field
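Replacing `hasattr(self.source, '__len__')` with `isinstance(self.source, Sized)` is equivalent, since `Sized` recognizes any class defining `__len__` via its subclass hook; tables qualify, while datasets and scanners fall through to `count_rows()`. A quick check with toy data:

    from collections.abc import Sized
    import pyarrow as pa
    import pyarrow.dataset as ds

    table = pa.table({'x': [1, 2]})
    assert isinstance(table, Sized) and len(table) == 2
    # datasets define count_rows() rather than __len__
    assert not isinstance(ds.dataset(table), Sized)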
@@ -225,7 +240,7 @@ def slice(
         self, info: Info, offset: Long = 0, length: Optional[Long] = None, reverse: bool = False
     ) -> Self:
         """Return zero-copy slice of table."""
-        table = self.select(info, length and (offset + length if offset >= 0 else None))
+        table = self.to_table(info, length and (offset + length if offset >= 0 else None))
         table = table[offset:][:length]  # `slice` bug: ARROW-15412
         return type(self)(table[::-1] if reverse else table)

@@ -252,7 +267,7 @@ def group(
         refs = {agg.name for values in aggs.values() for agg in values}
         fragments = set(T.fragment_keys(self.source))
         if isinstance(source, ds.Scanner):
-            source = self.select(info)
+            source = self.to_table(info)
         if fragments and set(by) <= fragments:
             if set(by) == fragments:
                 return type(self)(self.fragments(info, counts, aggregate))
@@ -310,7 +325,7 @@ def runs(
         Differs from `group` by relying on adjacency, and is typically faster. Other columns are
         transformed into list columns. See `column` and `tables` to further access lists.
         """
-        table = self.select(info)
+        table = self.to_table(info)
         predicates = {}
         for diff in map(dict, split):
             name = diff.pop('name')
@@ -339,7 +354,7 @@ def sort(
         """
         kwargs = dict(length=length, null_placement=null_placement)
         if isinstance(self.source, pa.Table) or length is None:
-            table = self.select(info)
+            table = self.to_table(info)
         else:
             expr, by = T.rank_keys(self.source, length, *by, dense=False)
             scanner = self.scanner(info, filter=expr)
@@ -360,7 +375,7 @@ def rank(self, info: Info, by: list[str], max: int = 1) -> Self:
         if not by:
             return type(self)(source)
         if not isinstance(source, ds.Dataset):
-            source = self.select(info)
+            source = self.to_table(info)
         return type(self)(T.rank(source, max, *by))
 
     @staticmethod
@@ -459,7 +474,7 @@ def aggregate(
         variance: doc_argument(list[VarianceAggregate], func=pc.variance) = [],
     ) -> Self:
         """Return table with scalar aggregate functions applied to list columns."""
-        table = self.select(info)
+        table = self.to_table(info)
         columns = T.columns(table)
         agg_fields: dict = collections.defaultdict(dict)
         keys: tuple = 'approximate_median', 'count', 'count_distinct', 'distinct', 'first', 'last'
@@ -529,7 +544,7 @@ def join(
     ) -> Self:
         """Provisional: [join](https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.join) this table with another table on the root Query type."""
         left, right = (
-            root.source if isinstance(root.source, ds.Dataset) else root.select(info)
+            root.source if isinstance(root.source, ds.Dataset) else root.to_table(info)
             for root in (self, getattr(info.root_value, right))
         )
         table = left.join(
@@ -548,14 +563,14 @@
     @doc_field
     def take(self, info: Info, indices: list[Long]) -> Self:
         """Select rows from indices."""
-        table = self.scanner(info).take(indices)
+        table = self.select(info).take(indices)
         return type(self)(self.add_metric(info, table, mode='take'))
 
     @doc_field
     def drop_null(self, info: Info) -> Self:
         """Remove missing values from referenced columns in the table."""
         if isinstance(self.source, pa.Table):
-            return type(self)(pc.drop_null(self.select(info)))
+            return type(self)(pc.drop_null(self.to_table(info)))
         scanner = self.scanner(info)
         batches = map(pc.drop_null, scanner.to_batches())
         scanner = ds.Scanner.from_batches(batches, schema=scanner.projected_schema)
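For non-table sources, `drop_null` stays streaming: `pc.drop_null` is mapped over record batches and the result is re-wrapped as a scanner. A condensed, runnable sketch with toy data:

    import pyarrow as pa
    import pyarrow.compute as pc
    import pyarrow.dataset as ds

    scanner = ds.dataset(pa.table({'x': [1, None, 3]})).scanner()
    batches = map(pc.drop_null, scanner.to_batches())  # one batch at a time
    clean = ds.Scanner.from_batches(batches, schema=scanner.projected_schema)
    assert clean.to_table()['x'].to_pylist() == [1, 3]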
graphique/service.py: 2 additions & 1 deletion
@@ -12,6 +12,7 @@
 import pyarrow as pa
 import pyarrow.dataset as ds
 from starlette.config import Config
+from graphique.core import Nodes
 from graphique.inputs import Expression
 from graphique import GraphQL
 
@@ -32,7 +33,7 @@
 if FILTERS is not None:
     root = root.to_table(columns=COLUMNS, filter=Expression.from_query(**FILTERS).to_arrow())
 elif COLUMNS:
-    root = root.scanner(columns=COLUMNS)
+    root = Nodes.scan(root, columns=COLUMNS)
 
 if FEDERATED:
     app = GraphQL.federated({FEDERATED: root}, debug=DEBUG)
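With this change the service root keeps a `Nodes` plan rather than a scanner when only a column projection is configured; a declaration can presumably be re-planned for each query, whereas a scanner over a batch stream is consumed once. A condensed sketch of the startup logic, with hypothetical values standing in for the starlette `Config` settings:

    import pyarrow.dataset as ds
    from graphique.core import Nodes

    root = ds.dataset('data/', format='parquet')  # hypothetical PARQUET_PATH
    COLUMNS = ['x', 'y']                          # hypothetical COLUMNS setting
    if COLUMNS:
        root = Nodes.scan(root, columns=COLUMNS)  # lazy plan instead of a one-shot scanner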