Partitioning utility supports sorting.
coady committed Jan 20, 2024
1 parent 87626e8 · commit d8689d3
Showing 1 changed file: graphique/shell.py (20 additions & 6 deletions)
--- a/graphique/shell.py
+++ b/graphique/shell.py
@@ -11,11 +11,17 @@
 
 import argparse
 import shutil
+from collections.abc import Iterable
 from pathlib import Path
 import pyarrow.dataset as ds
 from tqdm import tqdm  # type: ignore
 
 
+def sort_key(name: str) -> tuple:
+    """Parse sort order."""
+    return name.lstrip('-'), ('descending' if name.startswith('-') else 'ascending')
+
+
 def write_batches(scanner: ds.Scanner, base_dir: str, *partitioning: str, **options):
     """Partition dataset by batches."""
     options.update(format='parquet', partitioning=partitioning)
@@ -27,24 +33,31 @@ def write_batches(scanner: ds.Scanner, base_dir: str, *partitioning: str, **options):
             pbar.update(len(batch))
 
 
-def write_fragments(dataset: ds.Dataset, base_dir: str, **options):
+def write_fragments(dataset: ds.Dataset, base_dir: str, sorting=(), **options):
     """Rewrite partition files by fragment to consolidate."""
     options['format'] = 'parquet'
     exprs = {Path(frag.path).parent: frag.partition_expression for frag in dataset.get_fragments()}
     offset = len(dataset.partitioning.schema)
     for path in tqdm(exprs, desc="Fragments"):
         part_dir = Path(base_dir, *path.parts[-offset:])
-        ds.write_dataset(dataset.filter(exprs[path]), part_dir, **options)
+        part = dataset.filter(exprs[path])
+        ds.write_dataset(part.sort_by(sorting) if sorting else part, part_dir, **options)
 
 
 def partition(
-    scanner: ds.Scanner, base_dir: str, *partitioning: str, fragments: bool = False, **options
+    scanner: ds.Scanner,
+    base_dir: str,
+    *partitioning: str,
+    fragments: bool = False,
+    sort: Iterable[str] = (),
+    **options,
 ):
     """Partition dataset by keys."""
     temp = Path(base_dir) / 'temp'
     write_batches(scanner, str(temp), *partitioning)
     dataset = ds.dataset(temp, partitioning='hive')
-    if fragments:
-        write_fragments(dataset, base_dir, **options)
+    if fragments or sort:
+        write_fragments(dataset, base_dir, tuple(map(sort_key, sort)), **options)
     else:
         options.update(partitioning_flavor='hive', existing_data_behavior='overwrite_or_ignore')
         with tqdm(desc="Partitions"):
@@ -57,8 +70,9 @@ def partition(
 parser.add_argument('dest', help="destination path")
 parser.add_argument('partitioning', nargs='+', help="partition keys")
 parser.add_argument('--fragments', action='store_true', help="iterate over fragments")
+parser.add_argument('--sort', nargs='*', default=(), help="sort keys; will load fragments")
 
 if __name__ == '__main__':
     args = parser.parse_args()
     dataset = ds.dataset(args.src, partitioning='hive')
-    partition(dataset, args.dest, *args.partitioning, fragments=args.fragments)
+    partition(dataset, args.dest, *args.partitioning, fragments=args.fragments, sort=args.sort)
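
As a quick illustration of the new option, here is a minimal sketch not taken from the commit: the paths and the 'year'/'timestamp' column names are made up, and it assumes graphique.shell is importable. sort_key maps a possibly '-'-prefixed name to the (name, order) pair that pyarrow's sort_by expects, and a non-empty sort routes partition through the fragment-rewriting path:

    import pyarrow.dataset as ds
    from graphique.shell import partition, sort_key

    # A leading '-' requests descending order; bare names sort ascending.
    assert sort_key('timestamp') == ('timestamp', 'ascending')
    assert sort_key('-timestamp') == ('timestamp', 'descending')

    # Hypothetical paths and columns: partition by 'year', writing each
    # consolidated fragment sorted by descending 'timestamp'. Note that a
    # non-empty `sort` forces write_fragments even without fragments=True.
    dataset = ds.dataset('data/events', partitioning='hive')
    partition(dataset, 'data/partitioned', 'year', sort=['-timestamp'])

From the shell, the equivalent invocation would presumably be `python -m graphique.shell data/events data/partitioned year --sort=-timestamp`, using the `=` form so argparse does not mistake the leading '-' of a descending key for a flag.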
