Adds engine option to LazyReferenceMapper #1692

Merged: 16 commits, Oct 9, 2024
42 changes: 34 additions & 8 deletions fsspec/implementations/reference.py
@@ -7,7 +7,7 @@
 import os
 from itertools import chain
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal

 import fsspec.core

@@ -104,7 +104,13 @@ def pd(self):
         return pd

     def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
     ):
         """

@@ -126,16 +132,25 @@ def __init__(
             Encode urls as pandas.Categorical to reduce memory footprint if the ratio
             of the number of unique urls to total number of refs for each variable
             is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet", "pyarrow"]
+            Engine to use for reading and writing parquet files. (default "fastparquet")
         """
+
         self.root = root
         self.chunk_sizes = {}
         self.out_root = out_root or self.root
         self.cat_thresh = categorical_threshold
+        self.engine = engine
         self.cache_size = cache_size
         self.url = self.root + "/{field}/refs.{record}.parq"
         # TODO: derive fs from `root`
         self.fs = fsspec.filesystem("file") if fs is None else fs

+        from importlib.util import find_spec
+
+        if self.engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("pyarrow engine requested, but pyarrow is not installed")
+
     def __getattr__(self, item):
         if item in ("_items", "record_size", "zmetadata"):
             self.setup()
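For context, the engine is chosen at construction time and defaults to "fastparquet", so existing callers are unaffected. A minimal usage sketch, grounded in the test below (the "memory://refs.parq" location is illustrative only):

```python
import fsspec
from fsspec.implementations.reference import LazyReferenceMapper

m = fsspec.filesystem("memory")

# Create a new parquet-backed reference store, writing with pyarrow.
lz = LazyReferenceMapper.create("memory://refs.parq", fs=m, engine="pyarrow")

# Re-opening the same store later accepts the same option.
lz2 = LazyReferenceMapper("memory://refs.parq", fs=m, engine="pyarrow")
```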
@@ -158,7 +173,7 @@ def open_refs(field, record):
             """cached parquet file loader"""
             path = self.url.format(field=field, record=record)
             data = io.BytesIO(self.fs.cat_file(path))
-            df = self.pd.read_parquet(data, engine="fastparquet")
+            df = self.pd.read_parquet(data, engine=self.engine)
             refs = {c: df[c].to_numpy() for c in df.columns}
             return refs

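Because a store written with one engine may later be opened with the other, this change assumes the two engines round-trip each other's files. A small self-contained check of that assumption (the sample columns are made up; requires pandas plus both engines):

```python
import io

import pandas as pd

# A made-up reference-style table: one row per chunk.
df = pd.DataFrame({"path": ["a.bin", "b.bin"], "offset": [0, 100], "size": [100, 50]})

buf = io.BytesIO()
df.to_parquet(buf, engine="pyarrow", index=False)
buf.seek(0)

# Mirrors open_refs() above: read back with the other engine.
out = pd.read_parquet(buf, engine="fastparquet")
assert out.equals(df)
```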
@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):

         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
         self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)

+        if self.engine == "pyarrow":
+            df_backend_kwargs = {"write_statistics": False}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
         df.to_parquet(
             fn,
-            engine="fastparquet",
+            engine=self.engine,
             storage_options=storage_options
             or getattr(self.fs, "storage_options", None),
             compression="zstd",
             index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
+            # **kwargs,
+            **df_backend_kwargs,
         )

         partition.clear()
         self._items.pop((field, record))

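The per-engine split is needed because pandas forwards these extra keywords straight to the backend: stats=, object_encoding= and has_nulls= are fastparquet-only options, while pyarrow's counterpart for suppressing column statistics is write_statistics=. Roughly what the pyarrow branch reduces to underneath (file name and sample data illustrative):

```python
import pyarrow as pa
import pyarrow.parquet as pq

# A made-up single-chunk reference table.
table = pa.table({"path": ["a.bin"], "offset": [0], "size": [4]})

# zstd compression, no per-column statistics: the pyarrow analogue
# of fastparquet's stats=False.
pq.write_table(table, "refs.0.parq", compression="zstd", write_statistics=False)
```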
@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
         base_url: str
             Location of the output
         """
+
         # write what we have so far and clear sub chunks
         for thing in list(self._items):
             if isinstance(thing, tuple):
14 changes: 11 additions & 3 deletions fsspec/implementations/tests/test_reference.py
@@ -761,25 +761,33 @@ def test_append_parquet(lazy_refs, m):
     assert lazy2["data/1"] == b"Adata"


-def test_deep_parq(m):
+@pytest.mark.parametrize("engine", ["fastparquet", "pyarrow"])
+def test_deep_parq(m, engine):
     pytest.importorskip("kerchunk")
     zarr = pytest.importorskip("zarr")

     lz = fsspec.implementations.reference.LazyReferenceMapper.create(
-        "memory://out.parq", fs=m
+        "memory://out.parq",
+        fs=m,
+        engine=engine,
     )
     g = zarr.open_group(lz, mode="w")

     g2 = g.create_group("instant")
     g2.create_dataset(name="one", data=[1, 2, 3])
     lz.flush()

-    lz = fsspec.implementations.reference.LazyReferenceMapper("memory://out.parq", fs=m)
+    lz = fsspec.implementations.reference.LazyReferenceMapper(
+        "memory://out.parq", fs=m, engine=engine
+    )
     g = zarr.open_group(lz)
     assert g.instant.one[:].tolist() == [1, 2, 3]
     assert sorted(_["name"] for _ in lz.ls("")) == [".zgroup", ".zmetadata", "instant"]
     assert sorted(_["name"] for _ in lz.ls("instant")) == [
         "instant/.zgroup",
         "instant/one",
     ]

+    assert sorted(_["name"] for _ in lz.ls("instant/one")) == [
+        "instant/one/.zarray",
+        "instant/one/0",
+    ]
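One note on the parametrization: only kerchunk and zarr are guarded by importorskip, so on an environment without pyarrow the "pyarrow" case fails at construction with the new ImportError rather than being skipped. Locally, both cases can be run with, for example, `pytest fsspec/implementations/tests/test_reference.py -k test_deep_parq`.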