-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update cache.py The netcdf cache function validates the cache by comparing the ds argument and other function arguments to the pickled arguments. If they match, the cache can be used. Currently, just the coordinates of the argument ds and the output ds had to match, introducing two errors: - If the data_vars differ and are used, the cache is falsely considered valid - The coordinates of the ds argument have to match the coordinates of the output ds. This limits the use of the cache function. The PR compares the hash of the coords and data_vars of the ds argument to those that were stored in the pickle together with the cached output ds. Ideally, cache.cache_netcdf() would accept arguments that specify explicitly which data_vars and coords need to be included in the validation check. That is beyond the scope of this PR. - Included tests
- Loading branch information
Showing
9 changed files
with
325 additions
and
239 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +1,86 @@ | ||
import os | ||
import tempfile | ||
|
||
import pytest | ||
import test_001_model | ||
|
||
import nlmod | ||
|
||
tmpdir = tempfile.gettempdir() | ||
|
||
|
||
def test_ds_check_true(): | |
# Two models with the same grid and time discretisation: ds2 is a copy of ds, | |
# so the result of nlmod.cache._check_ds is asserted to be truthy. | |
ds = test_001_model.get_ds_from_cache("small_model") | |
ds2 = ds.copy() | |
|
||
check = nlmod.cache._check_ds(ds, ds2) | |
|
||
assert check | |
|
||
|
||
def test_ds_check_time_false(): | |
# Two models with a different time discretisation: ds is the cached | |
# "small_model" dataset and ds2 comes from get_ds_time_steady, so the result | |
# of nlmod.cache._check_ds is asserted to be falsy. | |
ds = test_001_model.get_ds_from_cache("small_model") | |
ds2 = test_001_model.get_ds_time_steady(tmpdir) | |
|
||
check = nlmod.cache._check_ds(ds, ds2) | |
|
||
assert not check | |
|
||
|
||
def test_ds_check_time_attributes_false(): | |
# Two models that differ only in an attribute of the time coordinate: ds2 is | |
# a copy of ds whose time.attrs["time_units"] is set to "MONTHS". | |
ds = test_001_model.get_ds_from_cache("small_model") | |
ds2 = ds.copy() | |
|
||
# Change the attribute on the copy; the check below is asserted to be falsy. | |
ds2.time.attrs["time_units"] = "MONTHS" | |
|
||
check = nlmod.cache._check_ds(ds, ds2) | |
|
||
assert not check | |
|
||
|
||
def test_cache_data_array(): | ||
def test_cache_ahn_data_array(): | ||
"""Test caching of AHN data array. Does not have dataset as argument.""" | ||
extent = [119_900, 120_000, 441_900, 442_000] | ||
ahn_no_cache = nlmod.read.ahn.get_ahn4(extent) | ||
ahn_cached = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename="ahn4.nc") | ||
ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename="ahn4.nc") | ||
assert ahn_cached.equals(ahn_no_cache) | ||
assert ahn_cache.equals(ahn_no_cache) | ||
|
||
|
||
@pytest.mark.slow | ||
def test_ds_check_grid_false(tmpdir): | ||
# two models with a different grid and same time discretisation | |
ds = test_001_model.get_ds_from_cache("small_model") | ||
ds2 = test_001_model.get_ds_time_transient(tmpdir) | ||
extent = [99100.0, 99400.0, 489100.0, 489400.0] | ||
regis_ds = nlmod.read.regis.get_combined_layer_models( | ||
extent, | ||
use_regis=True, | ||
use_geotop=False, | ||
cachedir=tmpdir, | ||
cachename="comb.nc", | ||
cache_name = "ahn4.nc" | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet1" | ||
ahn_no_cache = nlmod.read.ahn.get_ahn4(extent) | ||
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet2" | ||
|
||
ahn_cached = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name) | ||
assert os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should have existed by now" | ||
assert ahn_cached.equals(ahn_no_cache) | ||
modification_time1 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
|
||
# Check if the cache is used. If not, cache is rewritten and modification time changes | ||
ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name) | ||
assert ahn_cache.equals(ahn_no_cache) | ||
modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
assert modification_time1 == modification_time2, "Cache should not be rewritten" | ||
|
||
# Different extent should not lead to using the cache | ||
extent = [119_800, 120_000, 441_900, 442_000] | ||
ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name) | ||
modification_time3 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
assert modification_time1 != modification_time3, "Cache should have been rewritten" | ||
|
||
|
||
def test_cache_northsea_data_array(): | ||
"""Test caching of the North Sea data array. Does have dataset as argument.""" | |
from nlmod.read.rws import get_northsea | ||
ds1 = nlmod.get_ds( | ||
[119_700, 120_000, 441_900, 442_000], | ||
delr=100., | ||
delc=100., | ||
top=0., | ||
botm=[-1., -2.], | ||
kh=10., | ||
kv=1., | ||
) | ||
ds2 = nlmod.base.to_model_ds(regis_ds, delr=50.0, delc=50.0) | ||
|
||
check = nlmod.cache._check_ds(ds, ds2) | ||
|
||
assert not check | ||
|
||
|
||
@pytest.mark.skip("too slow") | ||
def test_use_cached_regis(tmpdir): | ||
extent = [98700.0, 99000.0, 489500.0, 489700.0] | ||
regis_ds1 = nlmod.read.regis.get_regis(extent, cachedir=tmpdir, cachename="reg.nc") | ||
|
||
regis_ds2 = nlmod.read.regis.get_regis(extent, cachedir=tmpdir, cachename="reg.nc") | ||
|
||
assert regis_ds1.equals(regis_ds2) | ||
|
||
|
||
@pytest.mark.skip("too slow") | ||
def test_do_not_use_cached_regis(tmpdir): | ||
# cache regis | ||
extent = [98700.0, 99000.0, 489500.0, 489700.0] | ||
regis_ds1 = nlmod.read.regis.get_regis( | ||
extent, cachedir=tmpdir, cachename="regis.nc" | ||
) | ||
|
||
# do not use cache because extent is different | ||
extent = [99100.0, 99400.0, 489100.0, 489400.0] | ||
regis_ds2 = nlmod.read.regis.get_regis( | ||
extent, cachedir=tmpdir, cachename="regis.nc" | ||
ds2 = nlmod.get_ds( | ||
[119_800, 120_000, 441_900, 444_000], | ||
delr=100., | ||
delc=100., | ||
top=0., | ||
botm=[-1., -3.], | ||
kh=10., | ||
kv=1., | ||
) | ||
|
||
assert not regis_ds1.equals(regis_ds2) | ||
cache_name = "northsea.nc" | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet1" | ||
out1_no_cache = get_northsea(ds1) | ||
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet2" | ||
|
||
out1_cached = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name) | ||
assert os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should exist by now" | ||
assert out1_cached.equals(out1_no_cache) | ||
modification_time1 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
|
||
# Check if the cache is used. If not, cache is rewritten and modification time changes | ||
out1_cache = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name) | ||
assert out1_cache.equals(out1_no_cache) | ||
modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
assert modification_time1 == modification_time2, "Cache should not be rewritten" | ||
|
||
# Only properties of `coords_2d` determine if the cache is used. Cache should still be used. | ||
ds1["toppertje"] = ds1.top + 1 | ||
out1_cache = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name) | ||
assert out1_cache.equals(out1_no_cache) | ||
modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
assert modification_time1 == modification_time2, "Cache should not be rewritten" | ||
|
||
# Different extent should not lead to using the cache | ||
out2_cache = get_northsea(ds2, cachedir=tmpdir, cachename=cache_name) | ||
modification_time3 = os.path.getmtime(os.path.join(tmpdir, cache_name)) | ||
assert modification_time1 != modification_time3, "Cache should have been rewritten" | ||
assert not out2_cache.equals(out1_no_cache) |