diff --git a/fsspec/implementations/reference.py b/fsspec/implementations/reference.py
index 6904d7b60..d5d3f2968 100644
--- a/fsspec/implementations/reference.py
+++ b/fsspec/implementations/reference.py
@@ -7,7 +7,7 @@
 import os
 from itertools import chain
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 import fsspec.core
 
@@ -104,7 +104,13 @@ def pd(self):
         return pd
 
     def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
     ):
         """
 
@@ -126,16 +132,25 @@ def __init__(
             Encode urls as pandas.Categorical to reduce memory footprint if the ratio
             of the number of unique urls to total number of refs for each variable
             is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet","pyarrow"]
+            Engine choice for reading parquet files. (default is "fastparquet")
         """
+
         self.root = root
         self.chunk_sizes = {}
         self.out_root = out_root or self.root
         self.cat_thresh = categorical_threshold
+        self.engine = engine
         self.cache_size = cache_size
         self.url = self.root + "/{field}/refs.{record}.parq"
         # TODO: derive fs from `root`
         self.fs = fsspec.filesystem("file") if fs is None else fs
 
+        from importlib.util import find_spec
+
+        if self.engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("engine choice `pyarrow` is not installed.")
+
     def __getattr__(self, item):
         if item in ("_items", "record_size", "zmetadata"):
             self.setup()
@@ -158,7 +173,7 @@ def open_refs(field, record):
             """cached parquet file loader"""
             path = self.url.format(field=field, record=record)
             data = io.BytesIO(self.fs.cat_file(path))
-            df = self.pd.read_parquet(data, engine="fastparquet")
+            df = self.pd.read_parquet(data, engine=self.engine)
             refs = {c: df[c].to_numpy() for c in df.columns}
             return refs
 
@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):
 
         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
         self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)
+
+        if self.engine == "pyarrow":
+            df_backend_kwargs = {"write_statistics": False}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
         df.to_parquet(
             fn,
-            engine="fastparquet",
+            engine=self.engine,
             storage_options=storage_options
             or getattr(self.fs, "storage_options", None),
             compression="zstd",
             index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
-            # **kwargs,
+            **df_backend_kwargs,
         )
+
         partition.clear()
         self._items.pop((field, record))
 
@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
         base_url: str
             Location of the output
         """
+
        # write what we have so far and clear sub chunks
         for thing in list(self._items):
             if isinstance(thing, tuple):
diff --git a/fsspec/implementations/tests/test_reference.py b/fsspec/implementations/tests/test_reference.py
index 7bfd7d08e..620ffe715 100644
--- a/fsspec/implementations/tests/test_reference.py
+++ b/fsspec/implementations/tests/test_reference.py
@@ -761,18 +761,25 @@ def test_append_parquet(lazy_refs, m):
     assert lazy2["data/1"] == b"Adata"
 
 
-def test_deep_parq(m):
+@pytest.mark.parametrize("engine", ["fastparquet", "pyarrow"])
+def test_deep_parq(m, engine):
     pytest.importorskip("kerchunk")
     zarr = pytest.importorskip("zarr")
+
     lz = fsspec.implementations.reference.LazyReferenceMapper.create(
-        "memory://out.parq", fs=m
+        "memory://out.parq",
+        fs=m,
+        engine=engine,
     )
     g = zarr.open_group(lz, mode="w")
+
     g2 = g.create_group("instant")
     g2.create_dataset(name="one", data=[1, 2, 3])
     lz.flush()
 
-    lz = fsspec.implementations.reference.LazyReferenceMapper("memory://out.parq", fs=m)
+    lz = fsspec.implementations.reference.LazyReferenceMapper(
+        "memory://out.parq", fs=m, engine=engine
+    )
     g = zarr.open_group(lz)
     assert g.instant.one[:].tolist() == [1, 2, 3]
     assert sorted(_["name"] for _ in lz.ls("")) == [".zgroup", ".zmetadata", "instant"]
@@ -780,6 +787,7 @@ def test_deep_parq(m):
         "instant/.zgroup",
         "instant/one",
     ]
+
     assert sorted(_["name"] for _ in lz.ls("instant/one")) == [
         "instant/one/.zarray",
         "instant/one/0",