Skip to content

Commit

Permalink
Merge pull request #247 from ArtesiaWater/add_caching_data_array
Browse files Browse the repository at this point in the history
Add caching for DataArrays as well
  • Loading branch information
rubencalje authored Aug 25, 2023
2 parents 9161091 + 46ccfbb commit f3fa3ee
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
41 changes: 40 additions & 1 deletion nlmod/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
func_args_dic, func_args_dic_cache
)

cached_ds = _check_for_data_array(cached_ds)
if modification_check and argument_check and pickle_check:
if dataset is None:
logger.info(f"using cached data -> {cachename}")
Expand All @@ -162,6 +163,10 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
result = func(*args, **kwargs)
logger.info(f"caching data -> {cachename}")

if isinstance(result, xr.DataArray):
# set the DataArray as a variable in a new Dataset
result = xr.Dataset({"__xarray_dataarray_variable__": result})

if isinstance(result, xr.Dataset):
# close cached netcdf (otherwise it is impossible to overwrite)
if os.path.exists(fname_cache):
Expand Down Expand Up @@ -192,7 +197,7 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
pickle.dump(func_args_dic, fpklz)
else:
raise TypeError(f"expected xarray Dataset, got {type(result)} instead")

result = _check_for_data_array(result)
return result

return decorator
Expand Down Expand Up @@ -398,3 +403,37 @@ def _update_docstring_and_signature(func):
new_doc = "".join((mod_before, after))
func.__doc__ = new_doc
return


def _check_for_data_array(ds):
"""
Check if the saved NetCDF-file represents a DataArray or a Dataset, and return this
data-variable.
The file contains a DataArray when a variable called "__xarray_dataarray_variable__"
is present in the Dataset. If so, return a DataArray, otherwise return the Dataset.
By saving the DataArray, the coordinate "spatial_ref" was saved as a separate
variable. Therefore, add this variable as a coordinate to the DataArray again.
Parameters
----------
ds : xr.Dataset
Dataset with dimensions and coordinates.
Returns
-------
ds : xr.Dataset or xr.DataArray
A Dataset or DataArray containing the cached data.
"""
if "__xarray_dataarray_variable__" in ds:
if "spatial_ref" in ds:
spatial_ref = ds.spatial_ref
else:
spatial_ref = None
# the method returns a DataArray, so we return only this DataArray
ds = ds["__xarray_dataarray_variable__"]
if spatial_ref is not None:
ds = ds.assign_coords({"spatial_ref": spatial_ref})
return ds
4 changes: 4 additions & 0 deletions nlmod/read/ahn.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def get_ahn4_tiles(extent=None):
return gdf


@cache.cache_netcdf
def get_ahn1(extent, identifier="ahn1_5m", as_data_array=True):
"""Download AHN1.
Expand All @@ -267,6 +268,7 @@ def get_ahn1(extent, identifier="ahn1_5m", as_data_array=True):
return da


@cache.cache_netcdf
def get_ahn2(extent, identifier="ahn2_5m", as_data_array=True):
"""Download AHN2.
Expand All @@ -290,6 +292,7 @@ def get_ahn2(extent, identifier="ahn2_5m", as_data_array=True):
return _download_and_combine_tiles(tiles, identifier, extent, as_data_array)


@cache.cache_netcdf
def get_ahn3(extent, identifier="AHN3_5m_DTM", as_data_array=True):
"""Download AHN3.
Expand All @@ -312,6 +315,7 @@ def get_ahn3(extent, identifier="AHN3_5m_DTM", as_data_array=True):
return _download_and_combine_tiles(tiles, identifier, extent, as_data_array)


@cache.cache_netcdf
def get_ahn4(extent, identifier="AHN4_DTM_5m", as_data_array=True):
"""Download AHN4.
Expand Down
7 changes: 7 additions & 0 deletions tests/test_006_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def test_ds_check_time_attributes_false():
assert not check


def test_cache_data_array():
    """Call ``get_ahn4`` twice with the same cache directory and cache name.

    The test currently only checks that both calls complete without raising;
    the comparison of the two results is still disabled (see TODO below).
    ``tmpdir`` is resolved from the enclosing module scope.
    """
    extent = [119_900, 120_000, 441_900, 442_000]
    cache_kwargs = {"cachedir": tmpdir, "cachename": "ahn4.nc"}
    # NOTE(review): presumably the second call is served from the cache file
    # written by the first one -- confirm against cache.cache_netcdf
    ahn_org = nlmod.read.ahn.get_ahn4(extent, **cache_kwargs)
    ahn = nlmod.read.ahn.get_ahn4(extent, **cache_kwargs)
    # TODO: enable the comparison below once both results are known to match
    # assert ahn.equals(ahn_org)


@pytest.mark.slow
def test_ds_check_grid_false(tmpdir):
# two models with a different grid and same time discretisation
Expand Down

0 comments on commit f3fa3ee

Please sign in to comment.