Adds engine option to LazyReferenceMapper (#1692)
norlandrhagen authored Oct 9, 2024
1 parent f2c7717 commit 176efbe
Showing 2 changed files with 45 additions and 11 deletions.
42 changes: 34 additions & 8 deletions fsspec/implementations/reference.py
@@ -7,7 +7,7 @@
 import os
 from itertools import chain
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal

 import fsspec.core
@@ -104,7 +104,13 @@ def pd(self):
         return pd

     def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
     ):
         """
@@ -126,16 +132,25 @@ def __init__(
             Encode urls as pandas.Categorical to reduce memory footprint if the ratio
             of the number of unique urls to total number of refs for each variable
             is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet","pyarrow"]
+            Engine choice for reading parquet files. (default is "fastparquet")
         """

         self.root = root
         self.chunk_sizes = {}
         self.out_root = out_root or self.root
         self.cat_thresh = categorical_threshold
+        self.engine = engine
         self.cache_size = cache_size
         self.url = self.root + "/{field}/refs.{record}.parq"
         # TODO: derive fs from `root`
         self.fs = fsspec.filesystem("file") if fs is None else fs

+        from importlib.util import find_spec
+
+        if self.engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("engine choice `pyarrow` is not installed.")
+
     def __getattr__(self, item):
         if item in ("_items", "record_size", "zmetadata"):
             self.setup()
@@ -158,7 +173,7 @@ def open_refs(field, record):
             """cached parquet file loader"""
             path = self.url.format(field=field, record=record)
             data = io.BytesIO(self.fs.cat_file(path))
-            df = self.pd.read_parquet(data, engine="fastparquet")
+            df = self.pd.read_parquet(data, engine=self.engine)
             refs = {c: df[c].to_numpy() for c in df.columns}
             return refs
@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):

         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
         self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)

+        if self.engine == "pyarrow":
+            df_backend_kwargs = {"write_statistics": False}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
         df.to_parquet(
             fn,
-            engine="fastparquet",
+            engine=self.engine,
             storage_options=storage_options
             or getattr(self.fs, "storage_options", None),
             compression="zstd",
             index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
-            # **kwargs,
+            **df_backend_kwargs,
         )

         partition.clear()
         self._items.pop((field, record))

@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
         base_url: str
             Location of the output
         """
+
         # write what we have so far and clear sub chunks
         for thing in list(self._items):
             if isinstance(thing, tuple):
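Taken together, the reference.py changes make the parquet backend selectable at construction time. A minimal usage sketch, not part of the commit itself: it assumes a build containing this change and that the optional pyarrow dependency is installed, and the memory:// path is purely illustrative, mirroring the parametrized test below.

import fsspec
from fsspec.implementations.reference import LazyReferenceMapper

m = fsspec.filesystem("memory")

# Build a parquet-backed reference store that reads and writes its
# refs.*.parq files via pyarrow instead of the default fastparquet.
refs = LazyReferenceMapper.create("memory://out.parq", fs=m, engine="pyarrow")

# ... add references, then persist any buffered partitions.
refs.flush()

Passing engine="fastparquet", or omitting the argument entirely, keeps the previous behaviour, so existing callers are unaffected.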
14 changes: 11 additions & 3 deletions fsspec/implementations/tests/test_reference.py
@@ -761,25 +761,33 @@ def test_append_parquet(lazy_refs, m):
     assert lazy2["data/1"] == b"Adata"


-def test_deep_parq(m):
+@pytest.mark.parametrize("engine", ["fastparquet", "pyarrow"])
+def test_deep_parq(m, engine):
     pytest.importorskip("kerchunk")
     zarr = pytest.importorskip("zarr")

     lz = fsspec.implementations.reference.LazyReferenceMapper.create(
-        "memory://out.parq", fs=m
+        "memory://out.parq",
+        fs=m,
+        engine=engine,
     )
     g = zarr.open_group(lz, mode="w")

     g2 = g.create_group("instant")
     g2.create_dataset(name="one", data=[1, 2, 3])
     lz.flush()

-    lz = fsspec.implementations.reference.LazyReferenceMapper("memory://out.parq", fs=m)
+    lz = fsspec.implementations.reference.LazyReferenceMapper(
+        "memory://out.parq", fs=m, engine=engine
+    )
     g = zarr.open_group(lz)
     assert g.instant.one[:].tolist() == [1, 2, 3]
     assert sorted(_["name"] for _ in lz.ls("")) == [".zgroup", ".zmetadata", "instant"]
     assert sorted(_["name"] for _ in lz.ls("instant")) == [
         "instant/.zgroup",
         "instant/one",
     ]

     assert sorted(_["name"] for _ in lz.ls("instant/one")) == [
         "instant/one/.zarray",
         "instant/one/0",
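Because the test is now parametrized over both engines, either backend can be exercised in isolation; assuming the optional kerchunk, zarr, fastparquet, and pyarrow test dependencies are installed, something like `pytest fsspec/implementations/tests/test_reference.py -k test_deep_parq` runs both variants.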
