1
1
2
2
3
- from abc import ABC , abstractmethod , abstractclassmethod
3
+ import pickle
4
4
import typing
5
+ from abc import ABC , abstractclassmethod , abstractmethod
5
6
from pathlib import Path
6
- from diskcache import Cache
7
-
7
+ import yaml
8
+ import xarray as xr
9
+ from langbrainscore .utils .cache import get_cache_root
10
+ from langbrainscore .utils .logging import log
8
11
9
12
# type variable used as the return annotation of `from_cache`
T = typing .TypeVar ('T' )
10
13
@@ -13,24 +16,57 @@ class _Cacheable(ABC):
13
16
A class used to define a common interface for Object caching in LangBrainscore
14
17
'''
15
18
16
- @abstractclassmethod
17
- def _get_netcdf_cacheable_objects (cls ) -> typing .Iterable [str ]:
18
- return ()
19
+ # @abstractclassmethod
20
+ # @classmethod
21
def _get_xarray_objects(self) -> typing.Iterable[str]:
    '''
    list the names of the instance attributes of self whose values are
    `xr.DataArray` objects.

    attributes are discovered through `vars(self)`, so only entries of the
    instance `__dict__` are considered; names are returned in the order
    `vars(self)` yields them.
    '''
    return [
        attr_name
        for attr_name, value in vars(self).items()
        if isinstance(value, xr.DataArray)
    ]
19
30
20
- @abstractclassmethod
21
- def _get_meta_attributes (cls ) -> typing .Iterable [str ]:
22
- return ()
31
+ # @abstractclassmethod
32
+ # def _get_meta_attributes(cls) -> typing.Iterable[str]:
33
+ # return ()
23
34
24
def to_cache(self, identifier_string: str, overwrite=True,
             xarray_serialization_backend='to_zarr', cache_dir=None) -> None:
    '''
    dump this object to cache. this method implementation will serve
    as the default implementation. it is recommended that this be left
    as-is for compatibility with caching across the library.

    everything is written below
    `<cache_dir or get_cache_root()>/<identifier_string>/<ClassName>/`:
        - `xarray_object_names.yml`: names returned by `_get_xarray_objects`
        - `<attribute name>.xr`: one store per xarray attribute, written by
          calling the method named `xarray_serialization_backend` on the
          attribute converted to a dataset with a single variable `data`
        - `<ClassName>.pkl`: pickle of the whole object

    identifier_string: sub-directory name identifying this cache entry
    overwrite: when true and the backend name contains 'zarr', `mode='w'`
        is passed to the backend; it has no effect for other backends
    xarray_serialization_backend: name of the dataset method used to
        write each xarray attribute (e.g. 'to_zarr')
    cache_dir: cache root to use instead of `get_cache_root()`
    '''
    root = Path(cache_dir or get_cache_root()).expanduser().resolve()
    root = root / identifier_string / self.__class__.__name__
    root.mkdir(exist_ok=True, parents=True)
    log(f'caching {self} to {root}')

    # materialise the names once: the manifest and the dump loop below must
    # see the same names, even if `_get_xarray_objects` returns a one-shot
    # iterable
    xarray_object_names = list(self._get_xarray_objects())

    with (root / 'xarray_object_names.yml').open('w', encoding='utf-8') as f:
        yaml.dump(xarray_object_names, f, yaml.SafeDumper)

    kwargs = {}
    if overwrite and 'zarr' in xarray_serialization_backend:
        kwargs['mode'] = 'w'

    for ob_name in xarray_object_names:
        tgt_dir = root / (ob_name + '.xr')
        # wrap the array in a dataset under the fixed variable name 'data'
        dataset = getattr(self, ob_name).to_dataset(name='data')
        dump_object = getattr(dataset, xarray_serialization_backend)
        dump_object(tgt_dir, **kwargs)

    # NOTE: the pickle contains the full object, xarray attributes included
    with (root / f'{self.__class__.__name__}.pkl').open('wb') as f:
        pickle.dump(self, f)
63
+
31
64
32
65
# NB comment from Guido: https://github.com/python/typing/issues/58#issuecomment-194569410
33
@classmethod
def from_cache(cls: typing.Type[T], identifier_string: str,
               xarray_deserialization_backend='open_zarr',
               cache_dir=None) -> T:
    '''
    construct an object from cache. subclasses must start with the
    object returned by a call to this method, e.g.:

        ob = super().from_cache(identifier_string)
        # ... subclass-specific restoration ...
        return ob

    the cache entry is read from
    `<cache_dir or get_cache_root()>/<identifier_string>/<ClassName>/`:
    the object is unpickled from `<ClassName>.pkl`, then every attribute
    named in `xarray_object_names.yml` is replaced by the `data` variable
    of the store `<attribute name>.xr`, opened with the `xarray` function
    named by `xarray_deserialization_backend`.

    identifier_string: sub-directory name identifying the cache entry
    xarray_deserialization_backend: name of the `xarray` module-level
        function used to open each store (e.g. 'open_zarr')
    cache_dir: cache root to use instead of `get_cache_root()`

    raises FileNotFoundError if the pickle or the names file is missing.
    '''
    root = Path(cache_dir or get_cache_root()).expanduser().resolve()
    root = root / identifier_string / cls.__name__
    # reading must not create directories: a missing entry surfaces as
    # FileNotFoundError from the `open` calls below
    log(f'loading cache for {cls} from {root}')

    # SECURITY: unpickling executes arbitrary code; only load cache
    # directories that come from a trusted source
    with (root / f'{cls.__name__}.pkl').open('rb') as f:
        ob = pickle.load(f)

    with (root / 'xarray_object_names.yml').open('r', encoding='utf-8') as f:
        # an empty file loads as None; treat it as "no xarray attributes"
        xarray_object_names = yaml.load(f, yaml.SafeLoader) or []

    load_object = getattr(xr, xarray_deserialization_backend)
    for attr in xarray_object_names:
        tgt_dir = root / (attr + '.xr')
        xarray_instance = load_object(tgt_dir)
        # the array is stored as the dataset variable named 'data'
        setattr(ob, attr, xarray_instance.data)

    return ob
0 commit comments