Skip to content

Commit

Permalink
Update cache.py (#326)
Browse files Browse the repository at this point in the history
* Update cache.py

The netcdf cache function validates the cache by comparing the ds argument and other function arguments to the pickled arguments. If they match, the cache can be used.

Currently, just the coordinates of the argument ds and the output ds had to match, introducing two errors:
- If the data_vars differ and are used the cache is falsely valid
- The coordintates of the ds argument has to match the coordinates of the output ds. This limits the use of the cache function.

The PR compares the hash of the coords and data_vars of the ds argument to those that were stored in the pickle together with the cached output ds.

Ideally, the cache.cache_netcdf() accepts arguments that specify specifically which data_vars and coords need to be included in the validation check. Beyond the scope of this pr.

- Included tests
  • Loading branch information
bdestombe authored Apr 15, 2024
1 parent cde1359 commit 68ee919
Show file tree
Hide file tree
Showing 9 changed files with 325 additions and 239 deletions.
368 changes: 232 additions & 136 deletions nlmod/cache.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion nlmod/dims/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1852,7 +1852,7 @@ def get_vertices(ds, vert_per_cid=4, epsilon=0, rotated=False):
return vertices_da


@cache.cache_netcdf
@cache.cache_netcdf(coords_2d=True)
def mask_model_edge(ds, idomain=None):
"""get data array which is 1 for every active cell (defined by idomain) at
the boundaries of the model (xmin, xmax, ymin, ymax). Other cells are 0.
Expand Down
12 changes: 6 additions & 6 deletions nlmod/read/ahn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
logger = logging.getLogger(__name__)


@cache.cache_netcdf
@cache.cache_netcdf(coords_2d=True)
def get_ahn(ds=None, identifier="AHN4_DTM_5m", method="average", extent=None):
"""Get a model dataset with ahn variable.
Expand Down Expand Up @@ -193,7 +193,7 @@ def get_ahn_along_line(line, ahn=None, dx=None, num=None, method="linear", plot=
return z


@cache.cache_netcdf
@cache.cache_netcdf()
def get_latest_ahn_from_wcs(
extent=None,
identifier="dsm_05m",
Expand Down Expand Up @@ -309,7 +309,7 @@ def get_ahn4_tiles(extent=None):
return gdf


@cache.cache_netcdf
@cache.cache_netcdf()
def get_ahn1(extent, identifier="ahn1_5m", as_data_array=True):
"""Download AHN1.
Expand All @@ -336,7 +336,7 @@ def get_ahn1(extent, identifier="ahn1_5m", as_data_array=True):
return da


@cache.cache_netcdf
@cache.cache_netcdf()
def get_ahn2(extent, identifier="ahn2_5m", as_data_array=True):
"""Download AHN2.
Expand All @@ -360,7 +360,7 @@ def get_ahn2(extent, identifier="ahn2_5m", as_data_array=True):
return _download_and_combine_tiles(tiles, identifier, extent, as_data_array)


@cache.cache_netcdf
@cache.cache_netcdf()
def get_ahn3(extent, identifier="AHN3_5m_DTM", as_data_array=True):
"""Download AHN3.
Expand All @@ -383,7 +383,7 @@ def get_ahn3(extent, identifier="AHN3_5m_DTM", as_data_array=True):
return _download_and_combine_tiles(tiles, identifier, extent, as_data_array)


@cache.cache_netcdf
@cache.cache_netcdf()
def get_ahn4(extent, identifier="AHN4_DTM_5m", as_data_array=True):
"""Download AHN4.
Expand Down
4 changes: 2 additions & 2 deletions nlmod/read/geotop.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_kh_kv_table(kind="Brabant"):
return df


@cache.cache_netcdf
@cache.cache_netcdf()
def to_model_layers(
geotop_ds,
strat_props=None,
Expand Down Expand Up @@ -233,7 +233,7 @@ def to_model_layers(
return ds


@cache.cache_netcdf
@cache.cache_netcdf()
def get_geotop(extent, url=GEOTOP_URL, probabilities=False):
"""Get a slice of the geotop netcdf url within the extent, set the x and y
coordinates to match the cell centers and keep only the strat and lithok
Expand Down
4 changes: 2 additions & 2 deletions nlmod/read/jarkus.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
logger = logging.getLogger(__name__)


@cache.cache_netcdf
@cache.cache_netcdf()
def get_bathymetry(ds, northsea, kind="jarkus", method="average"):
"""get bathymetry of the Northsea from the jarkus dataset.
Expand Down Expand Up @@ -92,7 +92,7 @@ def get_bathymetry(ds, northsea, kind="jarkus", method="average"):
return ds_out


@cache.cache_netcdf
@cache.cache_netcdf()
def get_dataset_jarkus(extent, kind="jarkus", return_tiles=False, time=-1):
"""Get bathymetry from Jarkus within a certain extent. If return_tiles is False, the
following actions are performed:
Expand Down
2 changes: 1 addition & 1 deletion nlmod/read/knmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
logger = logging.getLogger(__name__)


@cache.cache_netcdf
@cache.cache_netcdf(coords_2d=True, coords_time=True)
def get_recharge(ds, method="linear", most_common_station=False):
"""add multiple recharge packages to the groundwater flow model with knmi
data by following these steps:
Expand Down
4 changes: 2 additions & 2 deletions nlmod/read/regis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# REGIS_URL = 'https://www.dinodata.nl/opendap/hyrax/REGIS/REGIS.nc'


@cache.cache_netcdf
@cache.cache_netcdf()
def get_combined_layer_models(
extent,
regis_botm_layer="AKc",
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_combined_layer_models(
return combined_ds


@cache.cache_netcdf
@cache.cache_netcdf()
def get_regis(
extent,
botm_layer="AKc",
Expand Down
4 changes: 2 additions & 2 deletions nlmod/read/rws.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_gdf_surface_water(ds):
return gdf_swater


@cache.cache_netcdf
@cache.cache_netcdf(coords_2d=True)
def get_surface_water(ds, da_basename):
"""create 3 data-arrays from the shapefile with surface water:
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_surface_water(ds, da_basename):
return ds_out


@cache.cache_netcdf
@cache.cache_netcdf(coords_2d=True)
def get_northsea(ds, da_name="northsea"):
"""Get Dataset which is 1 at the northsea and 0 everywhere else. Sea is
defined by rws surface water shapefile.
Expand Down
164 changes: 77 additions & 87 deletions tests/test_006_caching.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,86 @@
import os
import tempfile

import pytest
import test_001_model

import nlmod

tmpdir = tempfile.gettempdir()


def test_ds_check_true():
# two models with the same grid and time dicretisation
ds = test_001_model.get_ds_from_cache("small_model")
ds2 = ds.copy()

check = nlmod.cache._check_ds(ds, ds2)

assert check


def test_ds_check_time_false():
# two models with a different time discretisation
ds = test_001_model.get_ds_from_cache("small_model")
ds2 = test_001_model.get_ds_time_steady(tmpdir)

check = nlmod.cache._check_ds(ds, ds2)

assert not check


def test_ds_check_time_attributes_false():
# two models with a different time discretisation
ds = test_001_model.get_ds_from_cache("small_model")
ds2 = ds.copy()

ds2.time.attrs["time_units"] = "MONTHS"

check = nlmod.cache._check_ds(ds, ds2)

assert not check


def test_cache_data_array():
def test_cache_ahn_data_array():
"""Test caching of AHN data array. Does not have dataset as argument."""
extent = [119_900, 120_000, 441_900, 442_000]
ahn_no_cache = nlmod.read.ahn.get_ahn4(extent)
ahn_cached = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename="ahn4.nc")
ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename="ahn4.nc")
assert ahn_cached.equals(ahn_no_cache)
assert ahn_cache.equals(ahn_no_cache)


@pytest.mark.slow
def test_ds_check_grid_false(tmpdir):
# two models with a different grid and same time dicretisation
ds = test_001_model.get_ds_from_cache("small_model")
ds2 = test_001_model.get_ds_time_transient(tmpdir)
extent = [99100.0, 99400.0, 489100.0, 489400.0]
regis_ds = nlmod.read.regis.get_combined_layer_models(
extent,
use_regis=True,
use_geotop=False,
cachedir=tmpdir,
cachename="comb.nc",
cache_name = "ahn4.nc"

with tempfile.TemporaryDirectory() as tmpdir:
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet1"
ahn_no_cache = nlmod.read.ahn.get_ahn4(extent)
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet2"

ahn_cached = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name)
assert os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should have existed by now"
assert ahn_cached.equals(ahn_no_cache)
modification_time1 = os.path.getmtime(os.path.join(tmpdir, cache_name))

# Check if the cache is used. If not, cache is rewritten and modification time changes
ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name)
assert ahn_cache.equals(ahn_no_cache)
modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name))
assert modification_time1 == modification_time2, "Cache should not be rewritten"

# Different extent should not lead to using the cache
extent = [119_800, 120_000, 441_900, 442_000]
ahn_cache = nlmod.read.ahn.get_ahn4(extent, cachedir=tmpdir, cachename=cache_name)
modification_time3 = os.path.getmtime(os.path.join(tmpdir, cache_name))
assert modification_time1 != modification_time3, "Cache should have been rewritten"


def test_cache_northsea_data_array():
"""Test caching of AHN data array. Does have dataset as argument."""
from nlmod.read.rws import get_northsea
ds1 = nlmod.get_ds(
[119_700, 120_000, 441_900, 442_000],
delr=100.,
delc=100.,
top=0.,
botm=[-1., -2.],
kh=10.,
kv=1.,
)
ds2 = nlmod.base.to_model_ds(regis_ds, delr=50.0, delc=50.0)

check = nlmod.cache._check_ds(ds, ds2)

assert not check


@pytest.mark.skip("too slow")
def test_use_cached_regis(tmpdir):
extent = [98700.0, 99000.0, 489500.0, 489700.0]
regis_ds1 = nlmod.read.regis.get_regis(extent, cachedir=tmpdir, cachename="reg.nc")

regis_ds2 = nlmod.read.regis.get_regis(extent, cachedir=tmpdir, cachename="reg.nc")

assert regis_ds1.equals(regis_ds2)


@pytest.mark.skip("too slow")
def test_do_not_use_cached_regis(tmpdir):
# cache regis
extent = [98700.0, 99000.0, 489500.0, 489700.0]
regis_ds1 = nlmod.read.regis.get_regis(
extent, cachedir=tmpdir, cachename="regis.nc"
)

# do not use cache because extent is different
extent = [99100.0, 99400.0, 489100.0, 489400.0]
regis_ds2 = nlmod.read.regis.get_regis(
extent, cachedir=tmpdir, cachename="regis.nc"
ds2 = nlmod.get_ds(
[119_800, 120_000, 441_900, 444_000],
delr=100.,
delc=100.,
top=0.,
botm=[-1., -3.],
kh=10.,
kv=1.,
)

assert not regis_ds1.equals(regis_ds2)
cache_name = "northsea.nc"

with tempfile.TemporaryDirectory() as tmpdir:
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet1"
out1_no_cache = get_northsea(ds1)
assert not os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should not exist yet2"

out1_cached = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name)
assert os.path.exists(os.path.join(tmpdir, cache_name)), "Cache should exist by now"
assert out1_cached.equals(out1_no_cache)
modification_time1 = os.path.getmtime(os.path.join(tmpdir, cache_name))

# Check if the cache is used. If not, cache is rewritten and modification time changes
out1_cache = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name)
assert out1_cache.equals(out1_no_cache)
modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name))
assert modification_time1 == modification_time2, "Cache should not be rewritten"

# Only properties of `coords_2d` determine if the cache is used. Cache should still be used.
ds1["toppertje"] = ds1.top + 1
out1_cache = get_northsea(ds1, cachedir=tmpdir, cachename=cache_name)
assert out1_cache.equals(out1_no_cache)
modification_time2 = os.path.getmtime(os.path.join(tmpdir, cache_name))
assert modification_time1 == modification_time2, "Cache should not be rewritten"

# Different extent should not lead to using the cache
out2_cache = get_northsea(ds2, cachedir=tmpdir, cachename=cache_name)
modification_time3 = os.path.getmtime(os.path.join(tmpdir, cache_name))
assert modification_time1 != modification_time3, "Cache should have been rewritten"
assert not out2_cache.equals(out1_no_cache)

0 comments on commit 68ee919

Please sign in to comment.