Creating new cache used full ds instead of reduced ds
Required to ensure that the cached function does not use data_vars that are not explicitly required.
bdestombe committed Apr 28, 2024
1 parent bef2118 commit 003c6e7
Showing 1 changed file with 78 additions and 56 deletions.
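In short: the decorator now hands the wrapped function the reduced dataset produced by ds_contains, not the full one. A hedged usage sketch (the function and variable names below are illustrative, not part of this commit):

import xarray as xr
from nlmod.cache import cache_netcdf

# After this commit the wrapped function receives the *reduced* dataset, so it
# can only see what the decorator declares: here the x/y coords, "area" and
# "extent" (via coords_2d=True) plus the explicitly listed data_var "kh".
@cache_netcdf(coords_2d=True, datavars=["kh"])
def transmissivity(ds):
    return xr.Dataset({"kD": ds["kh"] * ds["area"]})

# result = transmissivity(model_ds, cachedir="cache", cachename="kD.nc")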
134 changes: 78 additions & 56 deletions nlmod/cache.py
@@ -54,7 +54,7 @@ def clear_cache(cachedir):


 def cache_netcdf(coords_2d=False, coords_3d=False, coords_time=False, datavars=None, coords=None, attrs=None):
-    """decorator to read/write the result of a function from/to a file to speed
+    """Decorator to read/write the result of a function from/to a file to speed
     up function calls with the same arguments. Should only be applied to
     functions that:
@@ -118,27 +118,54 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
             fname_cache = os.path.join(cachedir, cachename)  # netcdf file
             fname_pickle_cache = fname_cache.replace(".nc", ".pklz")

-            # create dictionary with function arguments
-            func_args_dic = {f"arg{i}": args[i] for i in range(len(args))}
-            func_args_dic.update(kwargs)
+            # adjust args and kwargs with minimal dataset
+            args_adj = []
+            kwargs_adj = {}

-            # remove xarray dataset from function arguments
-            dataset = None
-            for key in list(func_args_dic.keys()):
-                if isinstance(func_args_dic[key], xr.Dataset):
-                    if dataset is not None:
-                        raise TypeError(
-                            "Function was called with multiple xarray dataset arguments. Currently unsupported."
-                        )
-                    dataset_received = func_args_dic.pop(key)
-                    dataset = ds_contains(
-                        dataset_received,
+            datasets = []
+            func_args_dic = {}
+
+            for i, arg in enumerate(args):
+                if isinstance(arg, xr.Dataset):
+                    arg_adj = ds_contains(
+                        arg,
                         coords_2d=coords_2d,
                         coords_3d=coords_3d,
                         coords_time=coords_time,
                         datavars=datavars,
                         coords=coords,
                         attrs=attrs)
+                    args_adj.append(arg_adj)
+                    datasets.append(arg_adj)
+                else:
+                    args_adj.append(arg)
+                    func_args_dic[f"arg{i}"] = arg
+
+            for key, arg in kwargs.items():
+                if isinstance(arg, xr.Dataset):
+                    arg_adj = ds_contains(
+                        arg,
+                        coords_2d=coords_2d,
+                        coords_3d=coords_3d,
+                        coords_time=coords_time,
+                        datavars=datavars,
+                        coords=coords,
+                        attrs=attrs)
+                    kwargs_adj[key] = arg_adj
+                    datasets.append(arg_adj)
+                else:
+                    kwargs_adj[key] = arg
+                    func_args_dic[key] = arg
+
+            if len(datasets) == 0:
+                dataset = None
+            elif len(datasets) == 1:
+                dataset = datasets[0]
+            else:
+                msg = "Function was called with multiple xarray dataset arguments. Currently unsupported."
+                raise NotImplementedError(
+                    msg
+                )

             # only use cache if the cache file and the pickled function arguments exist
             if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache):
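Condensed into a standalone sketch (a hypothetical helper; reduce stands in for the ds_contains call made with the decorator's flags), the new bookkeeping does this:

import xarray as xr

def adjust_args(args, kwargs, reduce):
    # Datasets are swapped for their reduced form; everything else is recorded
    # in func_args_dic for the later cache comparison.
    datasets, args_adj, kwargs_adj, func_args_dic = [], [], {}, {}
    for i, arg in enumerate(args):
        if isinstance(arg, xr.Dataset):
            arg = reduce(arg)
            datasets.append(arg)
        else:
            func_args_dic[f"arg{i}"] = arg
        args_adj.append(arg)
    for key, arg in kwargs.items():
        if isinstance(arg, xr.Dataset):
            arg = reduce(arg)
            datasets.append(arg)
        else:
            func_args_dic[key] = arg
        kwargs_adj[key] = arg
    if len(datasets) > 1:
        msg = "Function was called with multiple xarray dataset arguments. Currently unsupported."
        raise NotImplementedError(msg)
    return args_adj, kwargs_adj, func_args_dic, datasets[0] if datasets else None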
@@ -190,7 +217,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
                 return cached_ds

             # create cache
-            result = func(*args, **kwargs)
+            result = func(*args_adj, **kwargs_adj)
             logger.info(f"caching data -> {cachename}")

             if isinstance(result, xr.DataArray):
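Calling func with args_adj/kwargs_adj is the heart of the fix: a function that touches an undeclared variable now fails instead of silently reading it from the full dataset. A hedged illustration (names invented):

import xarray as xr
from nlmod.cache import cache_netcdf

@cache_netcdf(coords_2d=True, datavars=["kh"])
def thickness(ds):
    # "botm" is not declared in datavars: before this commit the function body
    # silently read it from the full ds; now ds_contains() has stripped it and
    # this raises a KeyError when the cache is (re)built.
    return xr.Dataset({"thk": ds["botm"]})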
@@ -205,7 +232,7 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):

                 # write netcdf cache
                 # check if dataset is chunked for writing with dask.delayed
-                first_data_var = list(result.data_vars.keys())[0]
+                first_data_var = next(iter(result.data_vars.keys()))
                 if result[first_data_var].chunks:
                     delayed = result.to_netcdf(fname_cache, compute=False)
                     with ProgressBar():
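The replacement is a small idiom fix: next(iter(...)) fetches one variable name without materializing the full key list. For instance:

import xarray as xr

ds = xr.Dataset({"kh": ("x", [1.0, 2.0]), "kv": ("x", [0.1, 0.2])})
first_data_var = next(iter(ds.data_vars))  # "kh", no intermediate list built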
@@ -230,16 +257,16 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
                 with open(fname_pickle_cache, "wb") as fpklz:
                     pickle.dump(func_args_dic, fpklz)
             else:
-                raise TypeError(f"expected xarray Dataset, got {type(result)} instead")
-            result = _check_for_data_array(result)
-            return result
+                msg = f"expected xarray Dataset, got {type(result)} instead"
+                raise TypeError(msg)
+            return _check_for_data_array(result)
         return wrapper

     return decorator


 def cache_pickle(func):
-    """decorator to read/write the result of a function from/to a file to speed
+    """Decorator to read/write the result of a function from/to a file to speed
     up function calls with the same arguments. Should only be applied to
     functions that:
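For comparison, a hedged sketch of how the pickle variant is applied; per the error message further down, the wrapped function must return a DataFrame (the function name and cachename extension below are assumptions):

import pandas as pd
from nlmod.cache import cache_pickle

@cache_pickle
def observations(path):
    return pd.read_csv(path)

# df = observations("obs.csv", cachedir="cache", cachename="obs.pklz")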
@@ -262,7 +289,6 @@ def cache_pickle(func):
     docstring with a "Returns" heading. If this is not the case an error is
     raised when trying to decorate the function.
     """
-
     # add cachedir and cachename to docstring
     _update_docstring_and_signature(func)

@@ -346,22 +372,23 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
             with open(fname_pickle_cache, "wb") as fpklz:
                 pickle.dump(func_args_dic, fpklz)
         else:
-            raise TypeError(f"expected DataFrame, got {type(result)} instead")
+            msg = f"expected DataFrame, got {type(result)} instead"
+            raise TypeError(msg)
         return result

     return decorator


 def _same_function_arguments(func_args_dic, func_args_dic_cache):
-    """checks if two dictionaries with function arguments are identical by
+    """Checks if two dictionaries with function arguments are identical by
     checking:
     1. if they have the same keys
     2. if the items have the same type
     3. if the items have the same values (only possible for the types: int,
                                           float, bool, str, bytes, list,
                                           tuple, dict, np.ndarray,
                                           xr.DataArray,
-                                          flopy.mf6.ModflowGwf)
+                                          flopy.mf6.ModflowGwf).

     Parameters
     ----------
@@ -381,7 +408,7 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache):
     """
     for key, item in func_args_dic.items():
         # check if cache and function call have same argument names
-        if key not in func_args_dic_cache.keys():
+        if key not in func_args_dic_cache:
             logger.info(
                 "cache was created using different function arguments, do not use cached data"
             )
@@ -510,16 +537,9 @@ def _update_docstring_and_signature(func):
         cur_param = cur_param[:-1]
     else:
         add_kwargs = None
-    new_param = cur_param + (
-        inspect.Parameter(
-            "cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
-        ),
-        inspect.Parameter(
-            "cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None
-        ),
-    )
+    new_param = (*cur_param, inspect.Parameter("cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None), inspect.Parameter("cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None))
     if add_kwargs is not None:
-        new_param = new_param + (add_kwargs,)
+        new_param = (*new_param, add_kwargs)
     sig = sig.replace(parameters=new_param)
     func.__signature__ = sig

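The collapsed one-liner builds the same tuple as before via unpacking. The signature-extension technique itself, as a self-contained sketch:

import inspect

def f(a, b=1):
    return a + b

sig = inspect.signature(f)
params = (*sig.parameters.values(),
          inspect.Parameter("cachedir", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None),
          inspect.Parameter("cachename", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None))
f.__signature__ = sig.replace(parameters=params)
print(inspect.signature(f))  # (a, b=1, cachedir=None, cachename=None)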
@@ -541,7 +561,7 @@ def _update_docstring_and_signature(func):
         " filename of netcdf cache. If None no cache is used."
         " Default is None.\n\n Returns"
     )
-    new_doc = "".join((mod_before, after))
+    new_doc = f"{mod_before}{after}"
    func.__doc__ = new_doc
    return

@@ -569,10 +589,7 @@ def _check_for_data_array(ds):
     """
     if "__xarray_dataarray_variable__" in ds:
-        if "spatial_ref" in ds:
-            spatial_ref = ds.spatial_ref
-        else:
-            spatial_ref = None
+        spatial_ref = ds.spatial_ref if "spatial_ref" in ds else None
         # the method returns a DataArray, so we return only this DataArray
         ds = ds["__xarray_dataarray_variable__"]
         if spatial_ref is not None:
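For context on the branch above: when an unnamed DataArray is written to netcdf, xarray stores it under the reserved name "__xarray_dataarray_variable__", which is what ends up in the cache file. A minimal demonstration:

import xarray as xr

da = xr.DataArray([1.0, 2.0], dims="x")                    # no name
ds = da.to_dataset(name="__xarray_dataarray_variable__")   # as stored on disk
recovered = ds["__xarray_dataarray_variable__"]            # back to a DataArray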
@@ -611,25 +628,25 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar
     """
     # Return the full dataset if not configured
     if ds is None:
-        raise ValueError("No dataset provided")
-    elif not coords_2d and not coords_3d and not datavars and not coords and not attrs:
+        msg = "No dataset provided"
+        raise ValueError(msg)
+    if not coords_2d and not coords_3d and not datavars and not coords and not attrs:
         return ds
-    else:
-        # Initialize lists
-        if datavars is None:
-            datavars = []
-        if coords is None:
-            coords = []
-        if attrs is None:
-            attrs = []
+    # Initialize lists
+    if datavars is None:
+        datavars = []
+    if coords is None:
+        coords = []
+    if attrs is None:
+        attrs = []

     # Add coords, datavars and attrs via shorthands
     if coords_2d or coords_3d:
         coords.append("x")
         coords.append("y")
         datavars.append("area")
         attrs.append("extent")

     if "gridtype" in ds.attrs:
         attrs.append("gridtype")

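A hedged call sketch of the shorthand expansion above (model_ds stands in for an existing model dataset; the listed datavars are illustrative):

from nlmod.cache import ds_contains

# coords_2d=True expands to coords "x"/"y", data_var "area" and attr "extent";
# "gridtype" is added automatically when present in ds.attrs. Undeclared names
# are stripped, and unknown requested names raise a ValueError, which is the
# whole point of the reduction.
ds_small = ds_contains(model_ds, coords_2d=True, datavars=["kh", "northsea"])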
@@ -651,23 +668,28 @@ def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavar

     # User-friendly error messages
     if "northsea" in datavars and "northsea" not in ds.data_vars:
-        raise ValueError("Northsea not in dataset. Run nlmod.read.rws.add_northsea() first.")
+        msg = "Northsea not in dataset. Run nlmod.read.rws.add_northsea() first."
+        raise ValueError(msg)

     if "time" in coords and "time" not in ds.coords:
-        raise ValueError("time not in dataset. Run nlmod.time.set_ds_time() first.")
+        msg = "time not in dataset. Run nlmod.time.set_ds_time() first."
+        raise ValueError(msg)

     # User-unfriendly error messages
     for datavar in datavars:
         if datavar not in ds.data_vars:
-            raise ValueError(f"{datavar} not in dataset.data_vars")
+            msg = f"{datavar} not in dataset.data_vars"
+            raise ValueError(msg)

     for coord in coords:
         if coord not in ds.coords:
-            raise ValueError(f"{coord} not in dataset.coords")
+            msg = f"{coord} not in dataset.coords"
+            raise ValueError(msg)

     for attr in attrs:
         if attr not in ds.attrs:
-            raise ValueError(f"{attr} not in dataset.attrs")
+            msg = f"{attr} not in dataset.attrs"
+            raise ValueError(msg)

     # Return only the required data
     return xr.Dataset(
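The xr.Dataset(...) construction is truncated in this view; presumably it selects only the validated names. A hedged guess at that selection pattern (an assumption, not code from the diff):

    return xr.Dataset(
        data_vars={k: ds.data_vars[k] for k in datavars},
        coords={k: ds.coords[k] for k in coords},
        attrs={k: ds.attrs[k] for k in attrs},
    )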
