diff --git a/datasketch/lsh.py b/datasketch/lsh.py index 709271b5..fc7da432 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -18,7 +18,14 @@ def _integration(f, a, b): # For when no scipy installed integrate = _integration - +def _ensure_bytestring(bytes_or_str): + if isinstance(bytes_or_str, str): + return bytes_or_str.encode('utf-8') + elif isinstance(bytes_or_str, bytes): + return bytes_or_str + else: + raise ValueError("basename must be either bytes or string type") + def _false_positive_probability(threshold, b, r): _probability = lambda s : 1 - (1 - s**float(r))**float(b) a, err = integrate(_probability, 0.0, threshold) @@ -115,7 +122,7 @@ def __init__(self, threshold=0.9, num_perm=128, weights=(0.5, 0.5), self.prepickle = storage_config['type'] == 'redis' if prepickle is None else prepickle - basename = storage_config.get('basename', _random_name(11)) + basename = _ensure_bytestring(storage_config.get('basename', _random_name(11))) self.hashtables = [ unordered_storage(storage_config, name=b''.join([basename, b'_bucket_', bytes([i])])) for i in range(self.b)] diff --git a/datasketch/storage.py b/datasketch/storage.py index 6453408a..38c44ade 100644 --- a/datasketch/storage.py +++ b/datasketch/storage.py @@ -9,6 +9,16 @@ except ImportError: redis = None +try: + from pynamodb.models import Model + from pynamodb.models import MetaModel + from pynamodb.connection.util import pythonic + from pynamodb.attributes import UnicodeAttribute, BinaryAttribute + import ulid + ddb = True +except ImportError: + ddb = None + def ordered_storage(config, name=None): '''Return ordered storage system based on the specified config. 
if ddb is not None:
    # Byte values that may appear unescaped in a generated table name:
    # ASCII letters, digits, underscore, hyphen and dot.
    _SAFE_TABLE_NAME_BYTES = frozenset(
        b'abcdefghijklmnopqrstuvwxyz'
        b'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        b'0123456789_-.')

    def make_safe_table_name(name):
        '''Derive a table name made only of letters, digits, ``_``, ``-``
        and ``.`` from a storage name.

        Every byte outside that character set is written as ``x`` followed by
        its two-digit lowercase hex value. This yields the same result as the
        previous ``repr()``-slicing implementation for control bytes such as
        ``\\x00`` (-> ``x00``), but no longer maps tab/newline/carriage-return
        to the plain letters ``t``/``n``/``r`` and no longer truncates ``str``
        input.

        Args:
            name: Storage name as ``bytes``, ``bytearray`` or ``str``; any
                other object is converted with ``str()`` first.

        Returns:
            str: The escaped table name.
        '''
        if isinstance(name, (bytes, bytearray)):
            raw = bytes(name)
        elif isinstance(name, str):
            raw = name.encode('utf-8')
        else:
            raw = str(name).encode('utf-8')
        return ''.join(
            chr(byte) if byte in _SAFE_TABLE_NAME_BYTES else 'x%02x' % byte
            for byte in raw)

    class ModelMeta(MetaModel):
        '''Metaclass that builds the ``Meta`` configuration class of a model
        from the ``table_name`` and ``config`` class keyword arguments.

        ``config`` must provide the keys ``region``, ``read_capacity`` and
        ``write_capacity``.
        '''

        def __new__(cls, name, bases, d, **kwargs):
            config = kwargs['config']
            d['Meta'] = type('Meta', (), {
                'table_name': make_safe_table_name(kwargs['table_name']),
                'region': config['region'],
                'read_capacity_units': config['read_capacity'],
                'write_capacity_units': config['write_capacity'],
            })
            return MetaModel.__new__(cls, name, bases, d)

        def __init__(self, *args, **kwargs):
            # The extra class keywords were consumed in __new__ and must not
            # be forwarded to the parent metaclass initializer.
            kwargs.pop('table_name', None)
            kwargs.pop('config', None)
            super().__init__(*args, **kwargs)

    class DDBSet(Model):
        '''Item layout for unordered storage: one item per (set, member).'''
        set_name = BinaryAttribute(hash_key=True)
        value = UnicodeAttribute(range_key=True)

    class DDBList(Model):
        '''Item layout for ordered storage: one item per appended value,
        with ``insert_order`` as the range key.'''
        set_name = UnicodeAttribute(hash_key=True)
        insert_order = UnicodeAttribute(range_key=True)
        value = BinaryAttribute()

    class DDBListStorage(OrderedStorage):
        '''Ordered storage backed by a table described by `DDBList`.

        Args:
            config (dict): Storage configuration; must contain ``region``,
                ``read_capacity`` and ``write_capacity``.
            name: Storage name from which the table name is derived.
        '''

        def __init__(self, config, name=None):
            self.name = name

            class ListModel(DDBList, metaclass=ModelMeta,
                            table_name=name, config=config):
                pass

            self.model_class = ListModel
            self._ensure_table()

        def _ensure_table(self):
            '''Create the backing table if it does not exist yet.'''
            if not self.model_class.exists():
                # No explicit capacity arguments: hard-coding them here would
                # override the units configured on the model's Meta class.
                # NOTE(review): relies on create_table falling back to the
                # Meta capacity values - confirm against the PynamoDB docs.
                self.model_class.create_table(wait=True)

        def _hash_key(self, key):
            '''Return ``key`` in the form stored in ``set_name`` (text).'''
            return key.decode('utf-8') if isinstance(key, bytes) else key

        def keys(self):
            '''Return each distinct key once, in scan order.'''
            # A key owns one item per inserted value, so a raw scan would
            # repeat it; dict.fromkeys removes repeats and keeps the order.
            return list(dict.fromkeys(
                item.set_name for item in self.model_class.scan()))

        def get(self, key):
            '''Return the values stored under ``key`` in ascending
            ``insert_order``.'''
            return [item.value for item in self.model_class.query(
                self._hash_key(key), scan_index_forward=True)]

        def remove(self, *keys):
            '''Delete every item stored under each of ``keys``.'''
            for key in keys:
                for item in self.model_class.query(self._hash_key(key)):
                    item.delete()

        def remove_val(self, key, val):
            '''Delete the items under ``key`` whose value equals ``val``.'''
            # ``value`` is not a key attribute of this model, so the
            # comparison is passed as a filter, not a range-key condition.
            matches = self.model_class.query(
                self._hash_key(key),
                filter_condition=self.model_class.value == val)
            for item in matches:
                item.delete()

        def insert(self, key, *vals, **kwargs):
            '''Append ``vals`` under ``key``; each value gets a new ULID as
            its ``insert_order``.'''
            hash_key = self._hash_key(key)
            with self.model_class.batch_write() as batch:
                for val in vals:
                    batch.save(self.model_class(
                        hash_key, ulid.new().str, value=val))

        def size(self):
            '''Return the count reported by the model for the whole table.

            NOTE(review): this counts items, not distinct keys - confirm that
            this matches what callers of ``size`` expect.
            '''
            return self.model_class.count()

        def itemcounts(self, **kwargs):
            '''Return a dict mapping each key to its number of stored items.'''
            return {key: self.model_class.count(self._hash_key(key))
                    for key in self.keys()}

        def has_key(self, key):
            '''Return True if at least one item is stored under ``key``.'''
            # Truthiness works for both the list returned here and the set
            # returned by DDBSetStorage.get.
            return bool(self.get(key))

    class DDBSetStorage(UnorderedStorage, DDBListStorage):
        '''Unordered storage backed by a table described by `DDBSet`.

        Args:
            config (dict): Storage configuration; must contain ``region``,
                ``read_capacity`` and ``write_capacity``.
            name: Storage name from which the table name is derived.
        '''

        def __init__(self, config, name=None):
            self.name = name

            class SetModel(DDBSet, metaclass=ModelMeta,
                           table_name=name, config=config):
                pass

            self.model_class = SetModel
            self._ensure_table()

        def _hash_key(self, key):
            '''Return ``key`` unchanged; ``set_name`` is a binary attribute
            in this model.'''
            return key

        @staticmethod
        def _member(val):
            '''Return ``val`` as text, the type of the ``value`` attribute.'''
            return val.decode('utf-8') if isinstance(val, bytes) else val

        def get(self, key):
            '''Return the set of values stored under ``key``.'''
            return {item.value
                    for item in self.model_class.query(self._hash_key(key))}

        def remove_val(self, key, val):
            '''Delete the member ``val`` from the set stored under ``key``.'''
            # ``value`` is the range key of this model, so the comparison is
            # a range-key condition.
            matches = self.model_class.query(
                self._hash_key(key),
                self.model_class.value == self._member(val))
            for item in matches:
                item.delete()

        def insert(self, key, *vals, **kwargs):
            '''Add ``vals`` to the set stored under ``key``.'''
            hash_key = self._hash_key(key)
            with self.model_class.batch_write() as batch:
                for val in vals:
                    batch.save(self.model_class(hash_key, self._member(val)))