Skip to content

Commit b79b51a

Browse files
author
localhost
committed
Adding saving to irregular cache
1 parent 4094cb5 commit b79b51a

File tree

1 file changed

+22
-5
lines changed

1 file changed

+22
-5
lines changed

adlframework/caches/nparr_cache.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pdb
33
import pickle
4+
import os
45
from adlframework.cache import Cache
56
from adlframework.utils import get_logger
67

@@ -86,8 +87,6 @@ def load(self):
8687
Currently, only saves data.
8788
To-Do: save labels too.
8889
'''
89-
import pickle
90-
import os
9190
dtf = self.cache_file+'_data'
9291
lf = self.cache_file+'_label'
9392
df = self.cache_file+'_dict'
@@ -113,9 +112,6 @@ def double_arr_size(self):
113112
self.labels = self.new_labels
114113

115114
class IrregularNPArrCache(Cache):
116-
import tables
117-
import string
118-
import random
119115
'''
120116
TO-DO: Written for 1-d. Generalize to N-D.
121117
Reference: https://kastnerkyle.github.io/posts/using-pytables-for-larger-than-ram-data-processing/
@@ -124,6 +120,7 @@ def __init__(self, cache_file=None, compress=True):
124120
self.data = []
125121
self.labels = []
126122
self.id_to_index = {}
123+
self.cache_file = cache_file
127124

128125

129126
''' Necessary classes '''
@@ -147,3 +144,23 @@ def retrieve(self, id_):
147144
idx = self.id_to_index[id_]
148145
return self.data[idx], self.labels[idx]
149146

147+
def load(self):
148+
'''
149+
Reads data, labels, and id_to_index as tuple from pickle
150+
'''
151+
if self.cache_file != None:
152+
if os.path.exists(self.cache_file):
153+
with open(self.cache_file, "wb") as f:
154+
self.data, self.labels, self.id_to_index = pickle.load(f)
155+
else:
156+
logger.warn('Cache file specified doesn\'t exist. Will continue...')
157+
158+
def save(self):
159+
'''
160+
Save data, labels, and id_to_index as tuple in pickle
161+
'''
162+
if self.cache_file != None:
163+
with open(self.cache_file, "wb") as f:
164+
pickle.dump((self.data, self.labels, self.id_to_index), f)
165+
else:
166+
logger.warn('No cache file specified. Will lose cache on exit.')

0 commit comments

Comments
 (0)