add hash check flag and improve logging of cache
OnnoEbbens committed Dec 19, 2024
1 parent 608d741 commit 3959ade
Showing 1 changed file with 67 additions and 54 deletions.
nlmod/cache.py
@@ -42,7 +42,8 @@ def clear_cache(cachedir):
 
             # remove pklz file
             os.remove(os.path.join(cachedir, fname))
-            logger.info(f"removed {fname}")
+            msg = f"removed {fname}"
+            logger.info(msg)
 
             # remove netcdf file
             fpath_nc = os.path.join(cachedir, fname_nc)
@@ -51,7 +52,8 @@ def clear_cache(cachedir):
                 cached_ds = xr.open_dataset(fpath_nc)
                 cached_ds.close()
                 os.remove(fpath_nc)
-                logger.info(f"removed {fname_nc}")
+                msg = f"removed {fname_nc}"
+                logger.info(msg)
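The change repeated throughout this commit is the same two-line pattern: build the log message first, then hand it to the logger, which keeps long f-strings off the call line. A minimal self-contained illustration (the logger name and file name are made up):

import logging

logger = logging.getLogger("nlmod.cache")

fname = "regis.pklz"  # hypothetical cache file name
# assemble the message first, then log it
msg = f"removed {fname}"
logger.info(msg)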


@@ -62,11 +64,10 @@ def cache_netcdf(
     datavars=None,
     coords=None,
     attrs=None,
+    nc_hash=True,
 ):
-    """Decorator to read/write the result of a function from/to a file to speed
-    up function calls with the same arguments. Should only be applied to
-    functions that:
+    """Decorator to read/write the result of a function from/to a file to speed up
+    function calls with the same arguments. Should only be applied to functions that:
     - return an Xarray Dataset
     - have no more than one xarray dataset as function argument
     - have function arguments of types that can be checked using the
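For orientation, a sketch of how a function wrapped with this decorator might be called. Only names visible in this diff are real (cache_netcdf, nc_hash, and the cachedir/cachename keywords of the wrapper further down); the function, its argument, and the cache names are made up:

import xarray as xr

from nlmod.cache import cache_netcdf


@cache_netcdf(nc_hash=True)
def make_dataset(n):
    """Build a small example dataset.

    Returns
    -------
    ds : xarray.Dataset
    """
    return xr.Dataset({"a": ("x", list(range(n)))})


# the first call computes the result and writes a netcdf plus a pickle of the
# function arguments; the identical second call is answered from the cache
ds = make_dataset(10, cachedir="cache", cachename="demo")
ds = make_dataset(10, cachedir="cache", cachename="demo")

Note the "Returns" header in the docstring: _update_docstring_and_signature (edited at the bottom of this diff) warns and skips functions without one.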
@@ -111,6 +112,9 @@ def cache_netcdf(
         List of coordinates to check for. The default is an empty list.
     attrs : list, optional
         List of attributes to check for. The default is an empty list.
+    nc_hash : bool, optional
+        Check whether the pickled function arguments belong to the cached
+        netcdf file. The default is True.
     """
 
     def decorator(func):
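The hash guarded by nc_hash is nothing more than the sha256 digest of the netcdf file's raw bytes, stored under the "_nc_hash" key of the pickled function arguments so the pickle and the netcdf can be paired. A minimal sketch (the helper name is hypothetical):

import hashlib


def file_sha256(path):
    """Hex digest of a file's bytes, as stored under "_nc_hash"."""
    with open(path, "rb") as f:
        return hashlib.sha256(f.read()).hexdigest()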
@@ -200,17 +204,19 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
                 modification_check = time_mod_cache > time_mod_func
 
                 if not modification_check:
-                    logger.info(
-                        f"module of function {func.__name__} recently modified, not using cache"
-                    )
+                    msg = f"module of function {func.__name__} recently modified, not using cache"
+                    logger.info(msg)
 
                 with xr.open_dataset(fname_cache) as cached_ds:
                     cached_ds.load()
 
                 if pickle_check:
                     # Ensure that the pickle pairs with the netcdf, see #66.
-                    cache_bytes = open(fname_cache, 'rb').read()
-                    func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()
+                    if nc_hash:
+                        cache_bytes = open(fname_cache, "rb").read()
+                        func_args_dic["_nc_hash"] = hashlib.sha256(
+                            cache_bytes
+                        ).hexdigest()
 
                     if dataset is not None:
                         # Check the coords of the dataset argument
@@ -231,12 +237,14 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
 
                 cached_ds = _check_for_data_array(cached_ds)
                 if modification_check and argument_check and pickle_check:
-                    logger.info(f"using cached data -> {cachename}")
+                    msg = f"using cached data -> {cachename}"
+                    logger.info(msg)
                     return cached_ds
 
             # create cache
             result = func(*args_adj, **kwargs_adj)
-            logger.info(f"caching data -> {cachename}")
+            msg = f"caching data -> {cachename}"
+            logger.info(msg)
 
             if isinstance(result, xr.DataArray):
                 # set the DataArray as a variable in a new Dataset
@@ -263,10 +271,11 @@ def wrapper(*args, cachedir=None, cachename=None, **kwargs):
             result.to_netcdf(fname_cache)
 
             # add netcdf hash to function arguments dic, see #66
-            cache_bytes = open(fname_cache, 'rb').read()
-            func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()
+            if nc_hash:
+                cache_bytes = open(fname_cache, "rb").read()
+                func_args_dic["_nc_hash"] = hashlib.sha256(cache_bytes).hexdigest()
 
-            # Add dataset argument hash to pickle
+            # Add dataset argument hash to function arguments dic
             if dataset is not None:
                 func_args_dic["_dataset_coords_hash"] = dask.base.tokenize(
                     dict(dataset.coords)
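Both sides of the guard together: the digest written here at cache creation is recomputed on the next call (previous hunks) and compared by _same_function_arguments. A toy round trip, assuming hypothetical paths and an uncompressed pickle for simplicity:

import hashlib
import pickle


def is_paired(nc_path, pkl_path):
    """Check that a pickled argument dict still belongs to its netcdf file."""
    with open(nc_path, "rb") as f:
        nc_hash = hashlib.sha256(f.read()).hexdigest()
    with open(pkl_path, "rb") as f:
        cached_args = pickle.load(f)
    # a netcdf rewritten after the pickle was saved no longer matches the
    # recorded digest, so the cache is treated as invalid
    return cached_args.get("_nc_hash") == nc_hash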
@@ -351,9 +360,8 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
             modification_check = time_mod_cache > time_mod_func
 
             if not modification_check:
-                logger.info(
-                    f"module of function {func.__name__} recently modified, not using cache"
-                )
+                msg = f"module of function {func.__name__} recently modified, not using cache"
+                logger.info(msg)
 
             # check if you can read the cached pickle, there are
             # several reasons why a pickle can not be read.
@@ -375,12 +383,14 @@ def decorator(*args, cachedir=None, cachename=None, **kwargs):
             )
 
             if modification_check and argument_check and pickle_check:
-                logger.info(f"using cached data -> {cachename}")
+                msg = f"using cached data -> {cachename}"
+                logger.info(msg)
                 return cached_pklz
 
         # create cache
         result = func(*args, **kwargs)
-        logger.info(f"caching data -> {cachename}")
+        msg = f"caching data -> {cachename}"
+        logger.info(msg)
 
         if isinstance(result, pd.DataFrame):
             # write pklz cache
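This hunk edits the pickle-based companion decorator (cache_pickle in nlmod, used for functions returning a DataFrame). Assuming that name, a hypothetical usage mirroring the netcdf example above:

import pandas as pd

from nlmod.cache import cache_pickle


@cache_pickle
def make_table(n):
    """Build a small example table.

    Returns
    -------
    df : pandas.DataFrame
    """
    return pd.DataFrame({"a": range(n)})


# written to a .pklz cache on the first call, read back on the second
df = make_table(5, cachedir="cache", cachename="table")
df = make_table(5, cachedir="cache", cachename="table")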
@@ -428,20 +438,23 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache):
     data was created using the same function arguments as the requested
     data.
+
+    Notes
+    -----
+    Keys that end with '_hash' are assumed to be hashes rather than function
+    arguments; they are only compared for equality.
     """
     for key, item in func_args_dic.items():
         # check if cache and function call have same argument names
         if key not in func_args_dic_cache:
-            logger.info(
-                "cache was created using different function arguments, do not use cached data"
-            )
+            msg = f"cache was created using different function argument {key}, do not use cached data"
+            logger.info(msg)
             return False
 
         # check if cache and function call have same argument types
         if not isinstance(item, type(func_args_dic_cache[key])):
-            logger.info(
-                "cache was created using different function argument types, do not use cached data"
-            )
+            msg = f"cache was created using different function argument type for {key}: {type(func_args_dic_cache[key])}, do not use cached data"
+            logger.info(msg)
             return False
 
         # check if cache and function call have same argument values
@@ -450,34 +463,33 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache):
             pass
         elif isinstance(item, (numbers.Number, bool, str, bytes, list, tuple)):
             if item != func_args_dic_cache[key]:
-                logger.info(
-                    "cache was created using different function argument values, do not use cached data"
-                )
+                if key.endswith("_hash") and isinstance(item, str):
+                    msg = f"cached hashes do not match: {key}, do not use cached data"
+                    logger.info(msg)
+                else:
+                    msg = f"cache was created using different function argument: {key}, do not use cached data"
+                    logger.info(msg)
                 return False
         elif isinstance(item, np.ndarray):
             if not np.allclose(item, func_args_dic_cache[key]):
-                logger.info(
-                    "cache was created using different numpy array values, do not use cached data"
-                )
+                msg = f"cache was created using different numpy array for: {key}, do not use cached data"
+                logger.info(msg)
                 return False
         elif isinstance(item, (pd.DataFrame, pd.Series, xr.DataArray)):
             if not item.equals(func_args_dic_cache[key]):
-                logger.info(
-                    "cache was created using different DataFrame/Series/DataArray, do not use cached data"
-                )
+                msg = f"cache was created using different DataFrame/Series/DataArray for: {key}, do not use cached data"
+                logger.info(msg)
                 return False
         elif isinstance(item, dict):
             # recursive checking
             if not _same_function_arguments(item, func_args_dic_cache[key]):
-                logger.info(
-                    "cache was created using different dictionaries, do not use cached data"
-                )
+                msg = f"cache was created using a different dictionary for: {key}, do not use cached data"
+                logger.info(msg)
                 return False
         elif isinstance(item, (flopy.mf6.ModflowGwf, flopy.modflow.mf.Modflow)):
             if str(item) != str(func_args_dic_cache[key]):
-                logger.info(
-                    "cache was created using different groundwater flow model, do not use cached data"
-                )
+                msg = f"cache was created using different groundwater flow model for: {key}, do not use cached data"
+                logger.info(msg)
                 return False
 
         elif isinstance(item, flopy.utils.gridintersect.GridIntersect):
@@ -498,24 +510,25 @@ def _same_function_arguments(func_args_dic, func_args_dic_cache):
                 or mfgrid1.keys() != mfgrid2.keys()
                 or not is_same_length_props
             ):
-                logger.info(
-                    "cache was created using different gridintersect, do not use cached data"
-                )
+                msg = f"cache was created using different gridintersect object: {key}, do not use cached data"
+                logger.info(msg)
                 return False
 
             is_other_props_equal = all(
                 np.all(v == mfgrid2[k]) for k, v in mfgrid1.items()
            )
 
             if not is_other_props_equal:
-                logger.info(
-                    "cache was created using different gridintersect, do not use cached data"
-                )
+                msg = f"cache was created using different gridintersect object: {key}, do not use cached data"
+                logger.info(msg)
                 return False
 
         else:
-            logger.info("cannot check if cache is valid, assuming invalid cache")
-            logger.info(f"function argument of type {type(item)}")
+            msg = (
+                f"cannot check if cache argument {key} is valid, assuming invalid cache,"
+                f" function argument of type {type(item)}"
+            )
+            logger.info(msg)
             return False
 
     return True
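A toy check of the dispatch above, using hypothetical argument names; equal scalars and arrays pass, while a changed value invalidates the cache:

import numpy as np

from nlmod.cache import _same_function_arguments

new_args = {"delr": 100.0, "extent": np.array([0.0, 1000.0, 0.0, 1000.0])}
cached_args = {"delr": 100.0, "extent": np.array([0.0, 1000.0, 0.0, 1000.0])}
# scalars are compared with !=, numpy arrays with np.allclose
assert _same_function_arguments(new_args, cached_args)

# a changed scalar value means the cache cannot be used
assert not _same_function_arguments({"delr": 50.0}, {"delr": 100.0})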
@@ -585,12 +598,12 @@ def _update_docstring_and_signature(func):
     # add cachedir and cachename to docstring
     original_doc = func.__doc__
     if original_doc is None:
-        logger.warning(f'Function "{func.__name__}" has no docstring')
+        msg = f'Function "{func.__name__}" has no docstring'
+        logger.warning(msg)
         return
     if "Returns" not in original_doc:
-        logger.warning(
-            f'Function "{func.__name__}" has no "Returns" header in docstring'
-        )
+        msg = f'Function "{func.__name__}" has no "Returns" header in docstring'
+        logger.warning(msg)
         return
     before, after = original_doc.split("Returns")
     mod_before = (