Skip to content

Commit 032f884

Browse files
committed
add a basic caching implementation in general and for Dataset object to start with
1 parent 601ecaf commit 032f884

File tree

3 files changed

+76
-26
lines changed

3 files changed

+76
-26
lines changed

langbrainscore/interface/cacheable.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11

22

3-
from abc import ABC, abstractmethod, abstractclassmethod
3+
import pickle
44
import typing
5+
from abc import ABC, abstractclassmethod, abstractmethod
56
from pathlib import Path
6-
from diskcache import Cache
7-
7+
import yaml
8+
import xarray as xr
9+
from langbrainscore.utils.cache import get_cache_root
10+
from langbrainscore.utils.logging import log
811

912
T = typing.TypeVar('T')
1013

@@ -13,24 +16,57 @@ class _Cacheable(ABC):
1316
A class used to define a common interface for Object caching in LangBrainscore
1417
'''
1518

16-
@abstractclassmethod
17-
def _get_netcdf_cacheable_objects(cls) -> typing.Iterable[str]:
18-
return ()
19+
# @abstractclassmethod
20+
# @classmethod
21+
def _get_xarray_objects(self) -> typing.Iterable[str]:
22+
'''
23+
returns all (visible) attributes of self that are instances of xarray
24+
'''
25+
keys = []
26+
for key, ob in vars(self).items():
27+
if isinstance(ob, xr.DataArray):
28+
keys += [key]
29+
return keys
1930

20-
@abstractclassmethod
21-
def _get_meta_attributes(cls) -> typing.Iterable[str]:
22-
return ()
31+
# @abstractclassmethod
32+
# def _get_meta_attributes(cls) -> typing.Iterable[str]:
33+
# return ()
2334

24-
def to_cache(self, filename) -> None:
35+
def to_cache(self, identifier_string: str, overwrite = True,
36+
xarray_serialization_backend='to_zarr', cache_dir = None) -> None:
2537
'''
2638
dump this object to cache. this method implementation will serve
2739
as the default implementation. it is recommended that this be left
2840
as-is for compatibility with caching across the library.
2941
'''
30-
NotImplemented
42+
root = Path(cache_dir or get_cache_root()).expanduser().resolve()
43+
root /= identifier_string
44+
root /= self.__class__.__name__
45+
root.mkdir(exist_ok=True, parents=True)
46+
log(f'caching {self} to {root}')
47+
48+
with (root / 'xarray_object_names.yml').open('w') as f:
49+
yaml.dump(self._get_xarray_objects(), f, yaml.SafeDumper)
50+
51+
kwargs = {}
52+
if overwrite and 'zarr' in xarray_serialization_backend:
53+
kwargs.update({'mode':'w'})
54+
for ob_name in self._get_xarray_objects():
55+
ob = getattr(self, ob_name)
56+
tgt_dir = root / (ob_name + '.xr')
57+
58+
dump_object = getattr(ob.to_dataset(name='data'), xarray_serialization_backend)
59+
dump_object(tgt_dir, **kwargs)
60+
61+
with (root / f'{self.__class__.__name__}.pkl').open('wb') as f:
62+
pickle.dump(self, f)
63+
3164

3265
# NB comment from Guido: https://github.com/python/typing/issues/58#issuecomment-194569410
33-
def from_cache(cls: T, filename) -> typing.Callable[..., T]:
66+
@classmethod
67+
def from_cache(cls, identifier_string: str,
68+
xarray_deserialization_backend='open_zarr',
69+
cache_dir = None) -> T:
3470
'''
3571
construct an object from cache. subclasses must start with the
3672
object returned by a call to this method like so:
@@ -41,12 +77,22 @@ def from_cache(cls: T, filename) -> typing.Callable[..., T]:
4177
return ob
4278
4379
'''
44-
return NotImplemented
80+
root = Path(cache_dir or get_cache_root()).expanduser().resolve()
81+
root /= identifier_string
82+
root /= cls.__name__
83+
root.mkdir(exist_ok=True, parents=True)
84+
log(f'loading cache for {cls} from {root} {identifier_string}')
85+
86+
with (root/f'{cls.__name__}.pkl').open('rb') as f:
87+
ob = pickle.load(f)
88+
89+
with (root / 'xarray_object_names.yml').open('r') as f:
90+
xarray_object_names = yaml.load(f, yaml.SafeLoader)
4591

46-
C = Cache()
47-
ob = cls()
48-
for attr in cls._get_meta_attributes():
49-
thing = None # retrieve from cache
50-
setattr(ob, attr, thing)
92+
for attr in xarray_object_names:
93+
tgt_dir = root / (attr + '.xr')
94+
load_object = getattr(xr, xarray_deserialization_backend)
95+
xarray_instance = load_object(tgt_dir)
96+
setattr(ob, attr, xarray_instance.data)
5197

52-
return ob
98+
return ob

langbrainscore/interface/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import xarray as xr
22
from langbrainscore.interface.cacheable import _Cacheable
3+
from abc import ABC
34

4-
5-
# TODO: class _Dataset(_Cacheable, ABC)
6-
class _Dataset:
5+
class _Dataset(_Cacheable, ABC):
76
"""
87
wrapper class for xarray DataArray that confirms format adheres to interface.
98
"""
109

10+
1111
def __init__(self, xr_obj: xr.DataArray) -> "_Dataset":
1212
"""
1313
accepts an xarray with the following core

langbrainscore/utils/cache.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@
55

66
import typing
77
from pathlib import Path
8+
import os
89

910
def get_cache_root(prefix: typing.Union[str, Path] = '~/.cache') -> Path:
1011
'''
1112
1213
'''
14+
if 'LBS_CACHE' in os.environ:
15+
prefix = os.environ['LBS_CACHE']
16+
1317
prefix = Path(prefix).expanduser().resolve()
1418
root = prefix / 'langbrainscore'
1519

16-
act = root / 'encoder_activations'
17-
data = root / 'datasets'
18-
results = root / 'results'
20+
# act = root / 'encoder_activations'
21+
# data = root / 'datasets'
22+
# results = root / 'results'
1923

2024
root.mkdir(parents=True, exist_ok=True)
21-
return prefix
25+
return root

0 commit comments

Comments
 (0)