diff --git a/nlmod/cache.py b/nlmod/cache.py index 93575d78..ad193d06 100644 --- a/nlmod/cache.py +++ b/nlmod/cache.py @@ -53,7 +53,7 @@ def clear_cache(cachedir): logger.info(f"removed {fname_nc}") -def cache_netcdf(func): +def cache_netcdf(coords_2d=False, coords_3d=False, coords_time=False, datavars=None, coords=None, attrs=None): """decorator to read/write the result of a function from/to a file to speed up function calls with the same arguments. Should only be applied to functions that: @@ -81,125 +81,159 @@ def cache_netcdf(func): to the decorated function. This assumes that the decorated function has a docstring with a "Returns" heading. If this is not the case an error is raised when trying to decorate the function. - """ - # add cachedir and cachename to docstring - _update_docstring_and_signature(func) + If all kwargs are left to their defaults, the function caches the full dataset. - @functools.wraps(func) - def decorator(*args, cachedir=None, cachename=None, **kwargs): - # 1 check if cachedir and name are provided - if cachedir is None or cachename is None: - return func(*args, **kwargs) - - if not cachename.endswith(".nc"): - cachename += ".nc" - - fname_cache = os.path.join(cachedir, cachename) # netcdf file - fname_pickle_cache = fname_cache.replace(".nc", ".pklz") - - # create dictionary with function arguments - func_args_dic = {f"arg{i}": args[i] for i in range(len(args))} - func_args_dic.update(kwargs) + Parameters + ---------- + ds : xr.Dataset + Dataset with dimensions and coordinates. + coords_2d : bool, optional + Shorthand for adding 2D coordinates. The default is False. + coords_3d : bool, optional + Shorthand for adding 3D coordinates. The default is False. + coords_time : bool, optional + Shorthand for adding time coordinates. The default is False. + datavars : list, optional + List of data variables to check for. The default is an empty list. + coords : list, optional + List of coordinates to check for. The default is an empty list. + attrs : list, optional + List of attributes to check for. The default is an empty list. + """ - # remove xarray dataset from function arguments - dataset = None - for key in list(func_args_dic.keys()): - if isinstance(func_args_dic[key], xr.Dataset): - if dataset is not None: - raise TypeError( - "function was called with multiple xarray dataset arguments" + def decorator(func): + # add cachedir and cachename to docstring + _update_docstring_and_signature(func) + + @functools.wraps(func) + def wrapper(*args, cachedir=None, cachename=None, **kwargs): + # 1 check if cachedir and name are provided + if cachedir is None or cachename is None: + return func(*args, **kwargs) + + if not cachename.endswith(".nc"): + cachename += ".nc" + + fname_cache = os.path.join(cachedir, cachename) # netcdf file + fname_pickle_cache = fname_cache.replace(".nc", ".pklz") + + # create dictionary with function arguments + func_args_dic = {f"arg{i}": args[i] for i in range(len(args))} + func_args_dic.update(kwargs) + + # remove xarray dataset from function arguments + dataset = None + for key in list(func_args_dic.keys()): + if isinstance(func_args_dic[key], xr.Dataset): + if dataset is not None: + raise TypeError( + "Function was called with multiple xarray dataset arguments. Currently unsupported." 
+ ) + dataset_received = func_args_dic.pop(key) + dataset = ds_contains( + dataset_received, + coords_2d=coords_2d, + coords_3d=coords_3d, + coords_time=coords_time, + datavars=datavars, + coords=coords, + attrs=attrs) + + # only use cache if the cache file and the pickled function arguments exist + if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache): + # check if you can read the pickle, there are several reasons why a + # pickle can not be read. + try: + with open(fname_pickle_cache, "rb") as f: + func_args_dic_cache = pickle.load(f) + pickle_check = True + except (pickle.UnpicklingError, ModuleNotFoundError): + logger.info("could not read pickle, not using cache") + pickle_check = False + argument_check = False + + # check if the module where the function is defined was changed + # after the cache was created + time_mod_func = _get_modification_time(func) + time_mod_cache = os.path.getmtime(fname_cache) + modification_check = time_mod_cache > time_mod_func + + if not modification_check: + logger.info( + f"module of function {func.__name__} recently modified, not using cache" ) - dataset = func_args_dic.pop(key) - # only use cache if the cache file and the pickled function arguments exist - if os.path.exists(fname_cache) and os.path.exists(fname_pickle_cache): - # check if you can read the pickle, there are several reasons why a - # pickle can not be read. - try: - with open(fname_pickle_cache, "rb") as f: - func_args_dic_cache = pickle.load(f) - pickle_check = True - except (pickle.UnpicklingError, ModuleNotFoundError): - logger.info("could not read pickle, not using cache") - pickle_check = False - argument_check = False + with xr.open_dataset(fname_cache) as cached_ds: + cached_ds.load() - # check if the module where the function is defined was changed - # after the cache was created - time_mod_func = _get_modification_time(func) - time_mod_cache = os.path.getmtime(fname_cache) - modification_check = time_mod_cache > time_mod_func + if pickle_check: + # Ensure that the pickle pairs with the netcdf, see #66. 
+ func_args_dic["_nc_hash"] = dask.base.tokenize(cached_ds) - if not modification_check: - logger.info( - f"module of function {func.__name__} recently modified, not using cache" - ) + if dataset is not None: + # Check the coords of the dataset argument + func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(dict(dataset.coords)) - cached_ds = xr.open_dataset(fname_cache) + # Check the data_vars of the dataset argument + func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(dict(dataset.data_vars)) - if pickle_check: - # add netcdf hash to function arguments dic, see #66 - func_args_dic["_nc_hash"] = dask.base.tokenize(cached_ds) - - # check if cache was created with same function arguments as - # function call - argument_check = _same_function_arguments( - func_args_dic, func_args_dic_cache - ) - - cached_ds = _check_for_data_array(cached_ds) - if modification_check and argument_check and pickle_check: - if dataset is None: - logger.info(f"using cached data -> {cachename}") - return cached_ds + # check if cache was created with same function arguments as + # function call + argument_check = _same_function_arguments( + func_args_dic, func_args_dic_cache + ) - # check if cached dataset has same dimension and coordinates - # as current dataset - if _check_ds(dataset, cached_ds): + cached_ds = _check_for_data_array(cached_ds) + if modification_check and argument_check and pickle_check: logger.info(f"using cached data -> {cachename}") return cached_ds - # create cache - result = func(*args, **kwargs) - logger.info(f"caching data -> {cachename}") + # create cache + result = func(*args, **kwargs) + logger.info(f"caching data -> {cachename}") + + if isinstance(result, xr.DataArray): + # set the DataArray as a variable in a new Dataset + result = xr.Dataset({"__xarray_dataarray_variable__": result}) + + if isinstance(result, xr.Dataset): + # close cached netcdf (otherwise it is impossible to overwrite) + if os.path.exists(fname_cache): + with xr.open_dataset(fname_cache) as cached_ds: + cached_ds.load() + + # write netcdf cache + # check if dataset is chunked for writing with dask.delayed + first_data_var = list(result.data_vars.keys())[0] + if result[first_data_var].chunks: + delayed = result.to_netcdf(fname_cache, compute=False) + with ProgressBar(): + delayed.compute() + # close and reopen dataset to ensure data is read from + # disk, and not from opendap + result.close() + result = xr.open_dataset(fname_cache, chunks="auto") + else: + result.to_netcdf(fname_cache) - if isinstance(result, xr.DataArray): - # set the DataArray as a variable in a new Dataset - result = xr.Dataset({"__xarray_dataarray_variable__": result}) + # add netcdf hash to function arguments dic, see #66 + with xr.open_dataset(fname_cache) as temp: + func_args_dic["_nc_hash"] = dask.base.tokenize(temp) - if isinstance(result, xr.Dataset): - # close cached netcdf (otherwise it is impossible to overwrite) - if os.path.exists(fname_cache): - cached_ds = xr.open_dataset(fname_cache) - cached_ds.close() + # Add dataset argument hash to pickle + if dataset is not None: + func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(dict(dataset.coords)) + func_args_dic["_dataset_data_vars_hash"] = dask.base.tokenize(dict(dataset.data_vars)) - # write netcdf cache - # check if dataset is chunked for writing with dask.delayed - first_data_var = list(result.data_vars.keys())[0] - if result[first_data_var].chunks: - delayed = result.to_netcdf(fname_cache, compute=False) - with ProgressBar(): - delayed.compute() - # close and 
reopen dataset to ensure data is read from - # disk, and not from opendap - result.close() - result = xr.open_dataset(fname_cache, chunks="auto") + # pickle function arguments + with open(fname_pickle_cache, "wb") as fpklz: + pickle.dump(func_args_dic, fpklz) else: - result.to_netcdf(fname_cache) - - # add netcdf hash to function arguments dic, see #66 - temp = xr.open_dataset(fname_cache) - func_args_dic["_nc_hash"] = dask.base.tokenize(temp) - temp.close() - - # pickle function arguments - with open(fname_pickle_cache, "wb") as fpklz: - pickle.dump(func_args_dic, fpklz) - else: - raise TypeError(f"expected xarray Dataset, got {type(result)} instead") - result = _check_for_data_array(result) - return result + raise TypeError(f"expected xarray Dataset, got {type(result)} instead") + result = _check_for_data_array(result) + return result + return wrapper return decorator @@ -318,39 +352,6 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs): return decorator -def _check_ds(ds, ds2): - """Check if two datasets have the same dimensions and coordinates. - - Parameters - ---------- - ds : xr.Dataset - dataset with dimensions and coordinates - ds2 : xr.Dataset - dataset with dimensions and coordinates. This is typically - a cached dataset. - - Returns - ------- - bool - True if the two datasets have the same grid and time discretization. - """ - - for coord in ds2.coords: - if coord in ds.coords: - try: - xr.testing.assert_identical(ds[coord], ds2[coord]) - except AssertionError: - logger.info( - f"coordinate {coord} has different values in cached dataset, not using cache" - ) - return False - else: - logger.info(f"dimension {coord} only present in cache, not using cache") - return False - - return True - - def _same_function_arguments(func_args_dic, func_args_dic_cache): """checks if two dictionaries with function arguments are identical by checking: @@ -577,3 +578,98 @@ def _check_for_data_array(ds): if spatial_ref is not None: ds = ds.assign_coords({"spatial_ref": spatial_ref}) return ds + + +def ds_contains(ds, coords_2d=False, coords_3d=False, coords_time=False, datavars=None, coords=None, attrs=None): + """ + Returns a Dataset containing only the required data. + + If all kwargs are left to their defaults, the function returns the full dataset. + + Parameters + ---------- + ds : xr.Dataset + Dataset with dimensions and coordinates. + coords_2d : bool, optional + Shorthand for adding 2D coordinates. The default is False. + coords_3d : bool, optional + Shorthand for adding 3D coordinates. The default is False. + coords_time : bool, optional + Shorthand for adding time coordinates. The default is False. + datavars : list, optional + List of data variables to check for. The default is an empty list. + coords : list, optional + List of coordinates to check for. The default is an empty list. + attrs : list, optional + List of attributes to check for. The default is an empty list. + + Returns + ------- + ds : xr.Dataset + A Dataset containing only the required data. 
+
+    """
+    # Return the full dataset if not configured
+    if ds is None:
+        raise ValueError("No dataset provided")
+    elif not coords_2d and not coords_3d and not coords_time and not datavars and not coords and not attrs:
+        return ds
+    else:
+        # Initialize lists
+        if datavars is None:
+            datavars = []
+        if coords is None:
+            coords = []
+        if attrs is None:
+            attrs = []
+
+        # Add coords, datavars and attrs via shorthands
+        if coords_2d or coords_3d:
+            coords.append("x")
+            coords.append("y")
+            attrs.append("extent")
+
+            if "gridtype" in ds.attrs:
+                attrs.append("gridtype")
+
+            if "angrot" in ds.attrs:
+                attrs.append("angrot")
+
+        if coords_3d:
+            coords.append("layer")
+            datavars.append("top")
+            datavars.append("botm")
+
+        if coords_time:
+            coords.append("time")
+            datavars.append("steady")
+            datavars.append("nstp")
+            datavars.append("tsmult")
+            attrs.append("start")
+            attrs.append("time_units")
+
+        # User-friendly error messages
+        if "northsea" in datavars and "northsea" not in ds.data_vars:
+            raise ValueError("Northsea not in dataset. Run nlmod.read.rws.add_northsea() first.")
+
+        if "time" in coords and "time" not in ds.coords:
+            raise ValueError("time not in dataset. Run nlmod.time.set_ds_time() first.")
+
+        # User-unfriendly error messages
+        for datavar in datavars:
+            if datavar not in ds.data_vars:
+                raise ValueError(f"{datavar} not in dataset.data_vars")
+
+        for coord in coords:
+            if coord not in ds.coords:
+                raise ValueError(f"{coord} not in dataset.coords")
+
+        for attr in attrs:
+            if attr not in ds.attrs:
+                raise ValueError(f"{attr} not in dataset.attrs")
+
+        # Return only the required data
+        return xr.Dataset(
+            data_vars={k: ds.data_vars[k] for k in datavars},
+            coords={k: ds.coords[k] for k in coords},
+            attrs={k: ds.attrs[k] for k in attrs})
diff --git a/nlmod/dims/grid.py b/nlmod/dims/grid.py
index 8893ac05..a5ecdaf1 100644
--- a/nlmod/dims/grid.py
+++ b/nlmod/dims/grid.py
@@ -1852,7 +1852,7 @@ def get_vertices(ds, vert_per_cid=4, epsilon=0, rotated=False):
     return vertices_da


-@cache.cache_netcdf
+@cache.cache_netcdf(coords_2d=True)
 def mask_model_edge(ds, idomain=None):
     """get data array which is 1 for every active cell (defined by idomain) at the
     boundaries of the model (xmin, xmax, ymin, ymax). Other cells are 0.
diff --git a/nlmod/read/ahn.py b/nlmod/read/ahn.py
index ce5484bb..1bd2fda8 100644
--- a/nlmod/read/ahn.py
+++ b/nlmod/read/ahn.py
@@ -20,7 +20,7 @@
 logger = logging.getLogger(__name__)


-@cache.cache_netcdf
+@cache.cache_netcdf(coords_2d=True)
 def get_ahn(ds=None, identifier="AHN4_DTM_5m", method="average", extent=None):
     """Get a model dataset with ahn variable.

@@ -193,7 +193,7 @@ def get_ahn_along_line(line, ahn=None, dx=None, num=None, method="linear", plot=
     return z


-@cache.cache_netcdf
+@cache.cache_netcdf()
 def get_latest_ahn_from_wcs(
     extent=None,
     identifier="dsm_05m",
@@ -309,7 +309,7 @@ def get_ahn4_tiles(extent=None):
     return gdf


-@cache.cache_netcdf
+@cache.cache_netcdf()
 def get_ahn1(extent, identifier="ahn1_5m", as_data_array=True):
     """Download AHN1.

@@ -336,7 +336,7 @@ def get_ahn1(extent, identifier="ahn1_5m", as_data_array=True):
     return da


-@cache.cache_netcdf
+@cache.cache_netcdf()
 def get_ahn2(extent, identifier="ahn2_5m", as_data_array=True):
     """Download AHN2.

@@ -360,7 +360,7 @@ def get_ahn2(extent, identifier="ahn2_5m", as_data_array=True):
     return _download_and_combine_tiles(tiles, identifier, extent, as_data_array)


-@cache.cache_netcdf
+@cache.cache_netcdf()
 def get_ahn3(extent, identifier="AHN3_5m_DTM", as_data_array=True):
     """Download AHN3.
@@ -383,7 +383,7 @@ def get_ahn3(extent, identifier="AHN3_5m_DTM", as_data_array=True): return _download_and_combine_tiles(tiles, identifier, extent, as_data_array) -@cache.cache_netcdf +@cache.cache_netcdf() def get_ahn4(extent, identifier="AHN4_DTM_5m", as_data_array=True): """Download AHN4. diff --git a/nlmod/read/geotop.py b/nlmod/read/geotop.py index 5ded82f0..b610f005 100644 --- a/nlmod/read/geotop.py +++ b/nlmod/read/geotop.py @@ -59,7 +59,7 @@ def get_kh_kv_table(kind="Brabant"): return df -@cache.cache_netcdf +@cache.cache_netcdf() def to_model_layers( geotop_ds, strat_props=None, @@ -233,7 +233,7 @@ def to_model_layers( return ds -@cache.cache_netcdf +@cache.cache_netcdf() def get_geotop(extent, url=GEOTOP_URL, probabilities=False): """Get a slice of the geotop netcdf url within the extent, set the x and y coordinates to match the cell centers and keep only the strat and lithok diff --git a/nlmod/read/jarkus.py b/nlmod/read/jarkus.py index e0672a5f..6654352d 100644 --- a/nlmod/read/jarkus.py +++ b/nlmod/read/jarkus.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) -@cache.cache_netcdf +@cache.cache_netcdf() def get_bathymetry(ds, northsea, kind="jarkus", method="average"): """get bathymetry of the Northsea from the jarkus dataset. @@ -92,7 +92,7 @@ def get_bathymetry(ds, northsea, kind="jarkus", method="average"): return ds_out -@cache.cache_netcdf +@cache.cache_netcdf() def get_dataset_jarkus(extent, kind="jarkus", return_tiles=False, time=-1): """Get bathymetry from Jarkus within a certain extent. If return_tiles is False, the following actions are performed: diff --git a/nlmod/read/knmi.py b/nlmod/read/knmi.py index 512f34a1..f6e0b27e 100644 --- a/nlmod/read/knmi.py +++ b/nlmod/read/knmi.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -@cache.cache_netcdf +@cache.cache_netcdf(coords_2d=True, coords_time=True) def get_recharge(ds, method="linear", most_common_station=False): """add multiple recharge packages to the groundwater flow model with knmi data by following these steps: diff --git a/nlmod/read/regis.py b/nlmod/read/regis.py index 729d7b44..90787dcc 100644 --- a/nlmod/read/regis.py +++ b/nlmod/read/regis.py @@ -16,7 +16,7 @@ # REGIS_URL = 'https://www.dinodata.nl/opendap/hyrax/REGIS/REGIS.nc' -@cache.cache_netcdf +@cache.cache_netcdf() def get_combined_layer_models( extent, regis_botm_layer="AKc", @@ -93,7 +93,7 @@ def get_combined_layer_models( return combined_ds -@cache.cache_netcdf +@cache.cache_netcdf() def get_regis( extent, botm_layer="AKc", diff --git a/nlmod/read/rws.py b/nlmod/read/rws.py index 7af2a991..01d5b4e2 100644 --- a/nlmod/read/rws.py +++ b/nlmod/read/rws.py @@ -37,7 +37,7 @@ def get_gdf_surface_water(ds): return gdf_swater -@cache.cache_netcdf +@cache.cache_netcdf(coords_2d=True) def get_surface_water(ds, da_basename): """create 3 data-arrays from the shapefile with surface water: @@ -91,7 +91,7 @@ def get_surface_water(ds, da_basename): return ds_out -@cache.cache_netcdf +@cache.cache_netcdf(coords_2d=True) def get_northsea(ds, da_name="northsea"): """Get Dataset which is 1 at the northsea and 0 everywhere else. Sea is defined by rws surface water shapefile. 
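To make the new decorator-factory API concrete, a minimal usage sketch (not part of the diff; get_example_top, the cache name and the temporary cache directory are hypothetical, while nlmod.get_ds is called the same way as in the tests below). With coords_2d=True only the x/y coordinates and the grid attributes of the dataset argument take part in the cache check, so unrelated data variables can change without invalidating the cache:

import tempfile

import nlmod
from nlmod import cache


@cache.cache_netcdf(coords_2d=True)
def get_example_top(ds):
    # hypothetical consumer function: derive a small Dataset from the model grid
    return ds[["top"]]


model_ds = nlmod.get_ds(
    [119_900, 120_000, 441_900, 442_000],
    delr=100.0, delc=100.0, top=0.0, botm=[-1.0, -2.0], kh=10.0, kv=1.0,
)
cachedir = tempfile.gettempdir()

# first call computes the result and writes example_top.nc + example_top.pklz
out1 = get_example_top(model_ds, cachedir=cachedir, cachename="example_top")

# second call reuses the cache; only the coords_2d selection of model_ds is hashed,
# so e.g. modifying model_ds["kh"] beforehand would not force a recompute
out2 = get_example_top(model_ds, cachedir=cachedir, cachename="example_top")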
diff --git a/tests/test_006_caching.py b/tests/test_006_caching.py index 741c1ffd..5bdfb3e0 100644 --- a/tests/test_006_caching.py +++ b/tests/test_006_caching.py @@ -1,96 +1,86 @@ +import os import tempfile -import pytest -import test_001_model - import nlmod -tmpdir = tempfile.gettempdir() - - -def test_ds_check_true(): - # two models with the same grid and time dicretisation - ds = test_001_model.get_ds_from_cache("small_model") - ds2 = ds.copy() - - check = nlmod.cache._check_ds(ds, ds2) - - assert check - - -def test_ds_check_time_false(): - # two models with a different time discretisation - ds = test_001_model.get_ds_from_cache("small_model") - ds2 = test_001_model.get_ds_time_steady(tmpdir) - - check = nlmod.cache._check_ds(ds, ds2) - - assert not check - - -def test_ds_check_time_attributes_false(): - # two models with a different time discretisation - ds = test_001_model.get_ds_from_cache("small_model") - ds2 = ds.copy() - - ds2.time.attrs["time_units"] = "MONTHS" - - check = nlmod.cache._check_ds(ds, ds2) - - assert not check - -def test_cache_data_array(): +def test_cache_ahn_data_array(): + """Test caching of AHN data array. Does not have dataset as argument.""" extent = [119_900, 120_000, 441_900, 442_000] - ahn_no_cache = nlmod.read.ahn.get_ahn4(extent) - ahn_cached = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename="ahn4.nc") - ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename="ahn4.nc") - assert ahn_cached.equals(ahn_no_cache) - assert ahn_cache.equals(ahn_no_cache) - - -@pytest.mark.slow -def test_ds_check_grid_false(tmpdir): - # two models with a different grid and same time dicretisation - ds = test_001_model.get_ds_from_cache("small_model") - ds2 = test_001_model.get_ds_time_transient(tmpdir) - extent = [99100.0, 99400.0, 489100.0, 489400.0] - regis_ds = nlmod.read.regis.get_combined_layer_models( - extent, - use_regis=True, - use_geotop=False, - cachedir=tmpdir, - cachename="comb.nc", + cache_name = "ahn4.nc" + + with tempfile.TemporaryDirectory() as tmpdir: + assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet1" + ahn_no_cache = nlmod.read.ahn.get_ahn4(extent) + assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet2" + + ahn_cached = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name) + assert os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should have existed by now" + assert ahn_cached.equals(ahn_no_cache) + modification_time1 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + + # Check if the cache is used. If not, cache is rewritten and modification time changes + ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name) + assert ahn_cache.equals(ahn_no_cache) + modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + assert modification_time1 == modification_time2, "Cache should not be rewritten" + + # Different extent should not lead to using the cache + extent = [119_800, 120_000, 441_900, 442_000] + ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name) + modification_time3 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + assert modification_time1 != modification_time3, "Cache should have been rewritten" + + +def test_cache_northsea_data_array(): + """Test caching of AHN data array. 
Does have dataset as argument.""" + from nlmod.read.rws import get_northsea + ds1 = nlmod.get_ds( + [119_700, 120_000, 441_900, 442_000], + delr=100., + delc=100., + top=0., + botm=[-1., -2.], + kh=10., + kv=1., ) - ds2 = nlmod.base.to_model_ds(regis_ds, delr=50.0, delc=50.0) - - check = nlmod.cache._check_ds(ds, ds2) - - assert not check - - -@pytest.mark.skip("too slow") -def test_use_cached_regis(tmpdir): - extent = [98700.0, 99000.0, 489500.0, 489700.0] - regis_ds1 = nlmod.read.regis.get_regis(extent, cachedir=tmpdir, cachename="reg.nc") - - regis_ds2 = nlmod.read.regis.get_regis(extent, cachedir=tmpdir, cachename="reg.nc") - - assert regis_ds1.equals(regis_ds2) - - -@pytest.mark.skip("too slow") -def test_do_not_use_cached_regis(tmpdir): - # cache regis - extent = [98700.0, 99000.0, 489500.0, 489700.0] - regis_ds1 = nlmod.read.regis.get_regis( - extent, cachedir=tmpdir, cachename="regis.nc" - ) - - # do not use cache because extent is different - extent = [99100.0, 99400.0, 489100.0, 489400.0] - regis_ds2 = nlmod.read.regis.get_regis( - extent, cachedir=tmpdir, cachename="regis.nc" + ds2 = nlmod.get_ds( + [119_800, 120_000, 441_900, 444_000], + delr=100., + delc=100., + top=0., + botm=[-1., -3.], + kh=10., + kv=1., ) - assert not regis_ds1.equals(regis_ds2) + cache_name = "northsea.nc" + + with tempfile.TemporaryDirectory() as tmpdir: + assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet1" + out1_no_cache = get_northsea(ds1) + assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet2" + + out1_cached = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name) + assert os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should exist by now" + assert out1_cached.equals(out1_no_cache) + modification_time1 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + + # Check if the cache is used. If not, cache is rewritten and modification time changes + out1_cache = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name) + assert out1_cache.equals(out1_no_cache) + modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + assert modification_time1 == modification_time2, "Cache should not be rewritten" + + # Only properties of `coords_2d` determine if the cache is used. Cache should still be used. + ds1["toppertje"] = ds1.top + 1 + out1_cache = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name) + assert out1_cache.equals(out1_no_cache) + modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + assert modification_time1 == modification_time2, "Cache should not be rewritten" + + # Different extent should not lead to using the cache + out2_cache = get_northsea(ds2, cachedir=tmpdir, cachename=cache_name) + modification_time3 = os.path.getmtime(os.path.join(tmpdir, cache_name)) + assert modification_time1 != modification_time3, "Cache should have been rewritten" + assert not out2_cache.equals(out1_no_cache)
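For completeness, a standalone sketch of ds_contains itself (assuming, as the tests above do, that nlmod.get_ds attaches x/y coordinates and an extent attribute to the model dataset). The helper pares the dataset down to exactly the pieces that are hashed for the cache check and raises an informative error when a requested part is missing:

import nlmod
from nlmod.cache import ds_contains

ds = nlmod.get_ds(
    [119_900, 120_000, 441_900, 442_000],
    delr=100.0, delc=100.0, top=0.0, botm=[-1.0, -2.0], kh=10.0, kv=1.0,
)

# keep only the 2D grid information: x/y coordinates plus the extent attribute
# (and gridtype/angrot when present); data variables such as kh and kv are dropped
trimmed = ds_contains(ds, coords_2d=True)
print(list(trimmed.coords), list(trimmed.data_vars), list(trimmed.attrs))

# asking for time information on a dataset without a time discretisation fails
# with the user-friendly message raised in cache.py
try:
    ds_contains(ds, coords_2d=True, coords_time=True)
except ValueError as err:
    print(err)  # time not in dataset. Run nlmod.time.set_ds_time() first.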