Commit

Support incremental appending

davidbrochart committed Mar 3, 2021
1 parent e2188f7 commit a468066
Showing 2 changed files with 80 additions and 30 deletions.
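This commit teaches a recipe to append to an existing Zarr store: inputs listed in the new processed_input_urls field are dry-run (logged but neither cached nor stored), while input_urls are written into the regions that follow the existing data. Below is a minimal sketch of the resulting two-pass workflow, modeled on the test added in this commit; the paths and storage objects are hypothetical stand-ins, not part of the commit itself.

from pangeo_forge import recipe

# `target` and `cache` would be pangeo-forge storage objects, e.g. the
# tmp_target / tmp_cache fixtures used in the test suite:
target, cache = ..., ...  # hypothetical placeholders

all_paths = ["day0.nc", "day1.nc", "day2.nc", "day3.nc"]  # hypothetical inputs
first, rest = all_paths[:2], all_paths[2:]

# First run: build the initial Zarr store from the first two inputs.
r = recipe.NetCDFtoZarrSequentialRecipe(
    input_urls=first,
    sequence_dim="time",
    inputs_per_chunk=1,
    nitems_per_input=1,
    target=target,
    input_cache=cache,
)
# ... execute r: prepare_target, cache each input, store each chunk, finalize_target

# Second run: append the remaining inputs. Declaring `first` as already
# processed makes those steps dry-run, so only `rest` is downloaded and written.
r = recipe.NetCDFtoZarrSequentialRecipe(
    processed_input_urls=first,
    input_urls=rest,
    sequence_dim="time",
    inputs_per_chunk=1,
    nitems_per_input=1,
    target=target,
    input_cache=cache,
)
# ... execute r again; the target now covers all four inputs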
73 changes: 45 additions & 28 deletions pangeo_forge/recipe.py
@@ -204,7 +204,7 @@ def _prepare_target():
                     chunk_key
                 ).chunk()  # make sure data are not in memory
                 init_dsets.append(chunk_ds)
-            # TODO: create csutomizable option for this step
+            # TODO: create customizable option for this step
             # How to combine attrs is particularly important. It seems like
             # xarray is missing a "minimal" option to only keep the attrs
             # that are the same among all input variables.
@@ -234,36 +234,47 @@ def _prepare_target():
     @property
     def cache_input(self) -> Callable:
         def cache_func(input_key: Hashable) -> None:
-            logger.info(f"Caching input {input_key}")
-            fname = self._inputs[input_key]
-            # TODO: add a check for whether the input is already cached?
-            with input_opener(fname, mode="rb", **self.fsspec_open_kwargs) as source:
-                with self.input_cache.open(fname, mode="wb") as target:
-                    # TODO: make this configurable? Would we ever want to change it?
-                    BLOCK_SIZE = 10_000_000  # 10 MB
-                    while True:
-                        data = source.read(BLOCK_SIZE)
-                        if not data:
-                            break
-                        target.write(data)
+            if self._inputs[input_key]["processed"]:
+                logger.info(f"Dry-run: caching input {input_key}")
+            else:
+                logger.info(f"Caching input {input_key}")
+                fname = self._inputs[input_key]["url"]
+                # TODO: add a check for whether the input is already cached?
+                with input_opener(fname, mode="rb", **self.fsspec_open_kwargs) as source:
+                    with self.input_cache.open(fname, mode="wb") as target:
+                        # TODO: make this configurable? Would we ever want to change it?
+                        BLOCK_SIZE = 10_000_000  # 10 MB
+                        while True:
+                            data = source.read(BLOCK_SIZE)
+                            if not data:
+                                break
+                            target.write(data)
 
         return cache_func

     @property
     def store_chunk(self) -> Callable:
         def _store_chunk(chunk_key):
-            ds_chunk = self.open_chunk(chunk_key)
+            write_region = self.region_for_chunk(chunk_key)
+            if all(
+                [
+                    self._inputs[input_key]["processed"]
+                    for input_key in self.inputs_for_chunk(chunk_key)
+                ]
+            ):
+                logger.info(f"Dry-run: storing chunk '{chunk_key}' to Zarr region {write_region}")
+            else:
+                ds_chunk = self.open_chunk(chunk_key)
 
-            def drop_vars(ds):
-                # writing a region means that all the variables MUST have sequence_dim
-                to_drop = [v for v in ds.variables if self.sequence_dim not in ds[v].dims]
-                return ds.drop_vars(to_drop)
+                def drop_vars(ds):
+                    # writing a region means that all the variables MUST have sequence_dim
+                    to_drop = [v for v in ds.variables if self.sequence_dim not in ds[v].dims]
+                    return ds.drop_vars(to_drop)
 
-            ds_chunk = drop_vars(ds_chunk)
-            target_mapper = self.target.get_mapper()
-            write_region = self.region_for_chunk(chunk_key)
-            logger.info(f"Storing chunk '{chunk_key}' to Zarr region {write_region}")
-            ds_chunk.to_zarr(target_mapper, region=write_region)
+                ds_chunk = drop_vars(ds_chunk)
+                target_mapper = self.target.get_mapper()
+                logger.info(f"Storing chunk '{chunk_key}' to Zarr region {write_region}")
+                ds_chunk.to_zarr(target_mapper, region=write_region)
 
         return _store_chunk

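Both steps above share the same skip logic: an input is dry-run when its "processed" flag is set, and a chunk is dry-run only when every input feeding it is processed (so a chunk mixing old and new inputs is written in full). Because chunk keys and write regions are laid out over the full sequence of processed plus new inputs, appended chunks land after the existing data. A small illustration of that layout, assuming inputs_per_chunk=1 and nitems_per_input=2 with two processed inputs followed by one new one; the numbers are hypothetical and mirror region_for_chunk only in spirit:

nitems_per_input = 2    # hypothetical
processed_input_nb = 2  # inputs already in the target
for chunk_index in range(3):  # one input per chunk
    start = chunk_index * nitems_per_input
    region = slice(start, start + nitems_per_input)
    dry_run = chunk_index < processed_input_nb
    print(f"chunk {chunk_index}: time region {region}, dry_run={dry_run}")
# chunks 0 and 1 are logged only; chunk 2 writes time slice(4, 6)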
@@ -297,7 +308,7 @@ def input_opener(self, fname: str):
             yield f
 
     def open_input(self, input_key: Hashable):
-        fname = self._inputs[input_key]
+        fname = self._inputs[input_key]["url"]
         with self.input_opener(fname) as f:
             logger.info(f"Opening input with Xarray {input_key}: '{fname}'")
             ds = xr.open_dataset(f, **self.xarray_open_kwargs)
@@ -357,7 +368,7 @@ def expand_target_dim(self, dim, dimsize):
         # now explicitly write the sequence coordinate to avoid missing data
         # when reopening
         if dim in zgroup:
-            zgroup[dim][:] = 0
+            zgroup[dim][self.nitems_per_input * self.processed_input_nb :] = 0  # noqa: E203
 
     def inputs_for_chunk(self, chunk_key):
         return self._chunks_inputs[chunk_key]
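This one-line change is what protects existing data during an append: when the target's sequence dimension is expanded, placeholder zeros are written only past the region covered by already-processed inputs, leaving the coordinate values written by the previous run intact. A worked example with hypothetical sizes:

nitems_per_input = 2
processed_input_nb = 5
start = nitems_per_input * processed_input_nb  # 10
# before: zgroup[dim][:] = 0    -> would clobber the first 10 coordinate values
# after:  zgroup[dim][10:] = 0  -> zeros only the newly appended region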
@@ -392,15 +403,20 @@ class NetCDFtoZarrSequentialRecipe(NetCDFtoZarrRecipe):
     """There is only one sequence of input files. Each file can contain
     many variables.
+    :param processed_input_urls: The inputs already used to generate the existing dataset.
     :param input_urls: The inputs used to generate the dataset.
     """
 
+    processed_input_urls: Iterable[str] = field(repr=False, default_factory=list)
     input_urls: Iterable[str] = field(repr=False, default_factory=list)
 
     def __post_init__(self):
         super().__post_init__()
-        input_pattern = ExplicitURLSequence(self.input_urls)
-        self._inputs = {k: v for k, v in input_pattern}
+        self.processed_input_nb = len(self.processed_input_urls)
+        input_pattern = ExplicitURLSequence(self.processed_input_urls + self.input_urls)
+        self._inputs = {
+            k: {"url": v, "processed": k < self.processed_input_nb} for k, v in input_pattern
+        }
         self._chunks_inputs = {
             k: v for k, v in enumerate(chunked_iterable(self._inputs, self.inputs_per_chunk))
         }
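For reference, this is the shape of the _inputs mapping the constructor now builds; enumerate stands in here for ExplicitURLSequence (which yields key/URL pairs), and the URLs are hypothetical:

processed = ["a.nc", "b.nc"]  # hypothetical processed_input_urls
new = ["c.nc"]                # hypothetical input_urls
processed_input_nb = len(processed)  # 2
inputs = {
    k: {"url": v, "processed": k < processed_input_nb}
    for k, v in enumerate(processed + new)
}
# {0: {'url': 'a.nc', 'processed': True},
#  1: {'url': 'b.nc', 'processed': True},
#  2: {'url': 'c.nc', 'processed': False}}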
@@ -428,7 +444,8 @@ class NetCDFtoZarrMultiVarSequentialRecipe(NetCDFtoZarrRecipe):
     def __post_init__(self):
         super().__post_init__()
         self._variables = self.input_pattern.keys["variable"]
-        self._inputs = {k: v for k, v in self.input_pattern}
+        self.processed_input_nb = 0  # TODO
+        self._inputs = {k: {"url": v, "processed": False} for k, v in self.input_pattern}
         # input keys are tuples like
         # ("temp", 0)
         # ("temp", 1)
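As the TODO notes, incremental appending is not wired up for the multi-variable recipe yet: every input is marked unprocessed, so the dry-run branches added above never trigger and every chunk is cached and written. A sketch of the consequence, with a hypothetical input entry:

inputs = {("temp", 0): {"url": "temp_0.nc", "processed": False}}
# store_chunk's skip condition can never hold while everything is unprocessed:
assert not all(v["processed"] for v in inputs.values())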
37 changes: 35 additions & 2 deletions tests/test_recipe.py
@@ -27,6 +27,41 @@ def _manually_execute_recipe(r):
     r.finalize_target()
 
 
+def test_NetCDFtoZarrSequentialRecipeIncremental(
+    daily_xarray_dataset, netcdf_local_paths, tmp_target, tmp_cache
+):
+
+    paths, items_per_file = netcdf_local_paths
+    n = len(paths) // 2
+
+    paths1 = paths[:n]
+    r = recipe.NetCDFtoZarrSequentialRecipe(
+        input_urls=paths1,
+        sequence_dim="time",
+        inputs_per_chunk=1,
+        nitems_per_input=items_per_file,
+        target=tmp_target,
+        input_cache=tmp_cache,
+    )
+    _manually_execute_recipe(r)
+
+    paths2 = paths[n:]
+    r = recipe.NetCDFtoZarrSequentialRecipe(
+        processed_input_urls=paths1,
+        input_urls=paths2,
+        sequence_dim="time",
+        inputs_per_chunk=1,
+        nitems_per_input=items_per_file,
+        target=tmp_target,
+        input_cache=tmp_cache,
+    )
+    _manually_execute_recipe(r)
+
+    ds_target = xr.open_zarr(tmp_target.get_mapper(), consolidated=True).load()
+    ds_expected = daily_xarray_dataset.compute()
+    assert ds_target.identical(ds_expected)
+
+
 @pytest.mark.parametrize(
     "username, password", [("foo", "bar"), ("foo", "wrong"),],  # noqa: E231
 )
@@ -164,6 +199,4 @@ def test_NetCDFtoZarrMultiVarSequentialRecipe(
     _manually_execute_recipe(r)
 
     ds_target = xr.open_zarr(tmp_target.get_mapper(), consolidated=True).compute()
-    print(ds_target)
-    print(daily_xarray_dataset)
     assert ds_target.identical(daily_xarray_dataset)
