merge from master #38

Open · wants to merge 27 commits into base: redis

Changes from all commits (27 commits)
84d590b  change from list to dictionary (KempinskiB, Sep 12, 2022)
6c6a6fd  reverting back (KempinskiB, Sep 12, 2022)
e6c558e  Merge pull request #36 from argmaxml/redis (urigoren, Sep 14, 2022)
94f2d63  Merge pull request #37 from argmaxml/redis (urigoren, Sep 17, 2022)
8b12d62  Added testing for sklearn engine (gadm, Sep 19, 2022)
73d7817  Merge pull request #39 from argmaxml/sklearn_test (urigoren, Sep 20, 2022)
4e5a5f7  Merge pull request #41 from argmaxml/redis (urigoren, Sep 21, 2022)
3728b6c  fixing requirements (KempinskiB, Sep 22, 2022)
91521a2  Merge pull request #42 from argmaxml/requirements_fix (urigoren, Sep 22, 2022)
5856e21  Index bug (urigoren, Sep 22, 2022)
c7be0cf  Merge pull request #43 from argmaxml/hnswmock (urigoren, Sep 22, 2022)
74d2c95  fix named ids in sklearn searcher (urigoren, Nov 13, 2022)
ede2739  redis similarity works (urigoren, Nov 19, 2022)
c8bee59  init fix (urigoren, Nov 19, 2022)
31046b0  returning pruned keys (urigoren, Nov 19, 2022)
d34af6b  default overwrite is True (urigoren, Nov 19, 2022)
3eb3971  propagating redis credentials (urigoren, Nov 19, 2022)
b50f54b  drop if exists (urigoren, Nov 19, 2022)
01e41cf  sklearn does not take **kwargs (urigoren, Nov 19, 2022)
255968f  better documentation (urigoren, Nov 20, 2022)
492bd80  index_factory bugfix (urigoren, Nov 20, 2022)
c87cf2a  strategy edge case (urigoren, Nov 22, 2022)
f2e3f5b  simplify list comprehension (urigoren, Nov 22, 2022)
11fbd89  get and set redis vectors (urigoren, Nov 24, 2022)
4711f60  bugfix (urigoren, Nov 24, 2022)
8bf9327  bugfix (urigoren, Nov 24, 2022)
e97b839  bugfix (urigoren, Nov 24, 2022)
3 changes: 2 additions & 1 deletion recsplain/__init__.py
@@ -1,4 +1,5 @@
__version__="0.0.83"
__version__="0.0.104"
from .similarity_helpers import SciKitNearestNeighbors, RedisIndex
from .strategies import BaseStrategy, AvgUserStrategy, RedisStrategy
from .encoders import PartitionSchema
from .endpoint import run_server
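
Note: with the expanded top-level exports above, the helpers can be used straight from the package root. Below is a minimal sketch (not from this PR) of the sklearn engine, whose search now returns the ids passed to add_items rather than raw row indices, per commit 74d2c95 and the similarity_helpers.py diff further down; it assumes scikit-learn is installed and recsplain >= 0.0.104:

import numpy as np
from recsplain import SciKitNearestNeighbors

# Build a small cosine index over random 4-d vectors.
nn = SciKitNearestNeighbors(space="cosine", dim=4)
data = np.random.random((20, 4))
nn.add_items(data, ids=[f"item{i}" for i in range(20)])

# search() returns (scores, names): names are the supplied string ids,
# not positional indices.
scores, names = nn.search(data[:2], k=3)
print(names[0])  # e.g. ['item0', 'item7', 'item13']
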
1 change: 0 additions & 1 deletion recsplain/encoders.py
@@ -522,4 +522,3 @@ def json_encode(self, value):
def encode(self, value):
val = self.get_feature(value)
return self.json_encode(val)

141 changes: 123 additions & 18 deletions recsplain/similarity_helpers.py
@@ -1,4 +1,5 @@
import sys
from typing import Dict
import numpy as np
import collections
from sklearn.neighbors import NearestNeighbors
@@ -9,11 +10,14 @@
except ModuleNotFoundError:
print ("hnswlib not found")
HNSWMock = collections.namedtuple("HNSWMock", ("Index", "max_elements"))
hnswlib = HNSWMock(None,0)
class MockHnsw:
def __init__(self, *args, **kwargs) -> None:
pass
hnswlib = HNSWMock(MockHnsw(),0)
try:
import faiss
available_engines.add("faiss")
except ModuleNotFoundError:
except Exception:
print ("faiss not found")
faiss = None
try:
@@ -45,7 +49,7 @@ def parse_server_name(sname):


class FaissIndexFactory:
def __init__(self, space, dim, index_factory, **kwargs):
def __init__(self, space:str, dim:int, index_factory:str, **kwargs):
if index_factory == '':
index_factory = 'Flat'
if space in ['ip', 'cosine']:
@@ -90,7 +94,7 @@ def load_index(self, fname):
self.index = faiss.read_index(fname)

class LazyHnsw(hnswlib.Index):
def __init__(self, space, dim, index_factory=None,max_elements=1024, ef_construction=200, M=16):
def __init__(self, space:str, dim:int, max_elements=1024, ef_construction=200, M=16,**kwargs):
super().__init__(space, dim)
self.init_max_elements = max_elements
self.init_ef_construction = ef_construction
@@ -151,7 +155,7 @@ def get_current_count(self):


class SciKitNearestNeighbors:
def __init__(self, space, dim, index_factory=None, **kwargs):
def __init__(self, space:str, dim:int, **kwargs):
if space=="ip":
self.space = "cosine"
sys.stderr.write("Warning: ip is not supported by sklearn, falling back to cosine")
@@ -161,7 +165,7 @@ def __init__(self, space:str, dim:int, **kwargs):
self.items = []
self.ids = []
self.fitted = False
self.index = NearestNeighbors(metric=self.space,n_jobs=-1,n_neighbors=10, **kwargs)
self.index = NearestNeighbors(metric=self.space,n_jobs=-1,n_neighbors=10)

def __len__(self):
return len(self.items)
@@ -175,6 +179,8 @@ def init(self, **kwargs):

def add_items(self, data, ids=None, num_threads=-1):
self.items.extend(data)
if ids is None:
ids = list(range(len(self.items),len(self.items)+len(data)))
self.ids.extend(ids)
self.fitted = False

@@ -186,7 +192,8 @@ def search(self, data, k=1):
self.index.fit(self.items)
self.fitted = True
scores, idx = self.index.kneighbors(data ,k, return_distance=True)
return (scores, idx)
names = [[self.ids[i] for i in ids] for ids in idx]
return scores, names

def get_max_elements(self):
return -1
@@ -196,17 +203,31 @@ def get_current_count(self):


class RedisIndex:
def __init__(self, space, dim, index_factory=None,redis_credentials=None,max_elements=1024, ef_construction=200, M=16):
def __init__(self, space:str, dim:int, redis_credentials=None,max_elements=1024, ef_construction=200, M=16, overwrite=True,**kwargs):
self.space = space
self.dim = dim
self.max_elements = max_elements
self.ef_construction = ef_construction
self.M = M
if kwargs.get("index_name") is None:
self.index_name = "idx"
else:
self.index_name = kwargs.get("index_name")
if redis_credentials is None:
raise Exception("Redis credentials must be provided")
self.redis = Redis(**redis_credentials)
self.pipe = None
if overwrite:
try:
self.redis.ft(self.index_name).info()
index_exists = True
except:
index_exists = False
if index_exists:
self.redis.ft(self.index_name).dropindex(delete_documents=True)
self.init_hnsw()
# applicable only for user events
self.user_keys=[]

def __enter__(self):
self.pipe = self.redis.pipeline()
@@ -221,24 +242,37 @@ def __len__(self):
def __itemgetter__(self, item):
return super().get_items([item])[0]

def user_keys(self):
"""Get all user keys"""
return [s.decode()[5:] for s in self.redis.keys("user:*")]

def item_keys(self):
"""Get all item keys"""
return [s.decode()[5:] for s in self.redis.keys("item:*")]

def vector_keys(self):
"""Get all vector keys"""
return [s.decode()[4:] for s in self.redis.keys("vec:*")]

def search(self, data, k=1,partition=None):
"""Search the nearest neighbors of the given vectors, and a given partition."""
query_vector = np.array(data).astype(np.float32).tobytes()

#prepare the query
p = "(@partition:{"+partition+"})" if partition is not None else "*"
q = Query(f'{p}=>[KNN {k} embedding $vec_param AS vector_score]').sort_by('vector_score').paging(0,k).return_fields('vector_score','item_id').dialect(2)
q = Query(f'{p}=>[KNN {k} @embedding $vec_param AS vector_score]').sort_by('vector_score').paging(0,k).return_fields('vector_score','item_id').dialect(2)
params_dict = {"vec_param": query_vector}
results = self.redis.ft().search(q, query_params = params_dict)
results = self.redis.ft(self.index_name).search(q, query_params = params_dict)
scores, ids = [], []
for item in results.docs:
scores.append(item.vector_score)
ids.append(item.item_id)
return (scores, ids)
return scores, ids

def add_items(self, data, ids=None, partition=None):
"""Add items and ids to the index, if a partition is not defined it defaults to NONE"""
self.pipe = self.redis.pipeline(transaction=False)
if partition is None:
partition=""
partition="NONE"
for datum, id in zip(data, ids):
key='item:'+ str(id)
emb = np.array(datum).astype(np.float32).tobytes()
@@ -248,20 +282,91 @@ def add_items(self, data, ids=None, partition=None):
self.pipe = None

def get_items(self, ids=None):
"""Get items by id"""
ret = []
for id in ids:
ret.append(self.redis.hget("item:"+str(id), "embedding"))
return ret
ret.append(np.frombuffer(self.redis.hget("item:"+str(id), "embedding"), dtype=np.float32))
return np.vstack(ret)

def add_user_event(self, user_id: str, data: Dict[str, str],ttl: int = 60*60*24):
"""
Adds a user event to the index. The event is stored in a Redis list under the key user:{user_id};
the fields are defined by the `user_keys` attribute.
"""
if not any(self.user_keys):
raise Exception("User keys must be set before adding user events")
vals = []
for key in self.user_keys:
v = data.get(key,"")
# ad hoc int trimming
try:
if v==int(v):
v=int(v)
except:
pass
vals.append(v)
val = '|'.join(map(str, vals))
if self.pipe:
self.pipe.rpush("user:"+str(user_id), val)
if ttl:
self.pipe.expire("user:"+str(user_id), ttl)
else:
self.redis.rpush("user:"+str(user_id), val)
if ttl:
self.redis.expire("user:"+str(user_id), ttl)
return self
def del_user(self, user_id):
"""Delete a user key from Redis"""
if self.pipe:
self.pipe.delete("user:"+str(user_id))
else:
self.redis.delete("user:"+str(user_id))

def get_user_events(self, user_id: str):
"""Gets a list of user events by key"""
if not any(self.user_keys):
raise Exception("User keys must be set before getting user events")
ret = self.redis.lrange("user:"+str(user_id), 0, -1)
return [dict(zip(self.user_keys,x.decode().split('|'))) for x in ret]

def set_vector(self, key, arr, prefix="vec:"):
"""Sets a numpy array as a vector in redis"""
emb = np.array(arr).astype(np.float32).tobytes()
self.redis.set(prefix+str(key), emb)
return self

def get_vector(self, key, prefix="vec:"):
"""Gets a numpy array from redis"""
return np.frombuffer(self.redis.get(prefix+str(key)), dtype=np.float32)


def init_hnsw(self, **kwargs):
self.redis.ft().create_index([
self.redis.ft(self.index_name).create_index([
VectorField("embedding", "HNSW", {"TYPE": "FLOAT32", "DIM": self.dim, "DISTANCE_METRIC": self.space, "INITIAL_CAP": self.max_elements, "M": self.M, "EF_CONSTRUCTION": self.ef_construction}),
TextField("item_id"),
TagField("partition")
])

def get_current_count(self):
raise NotImplementedError("RedisIndex is not implemented yet")
"""Get number of items in index"""
return int(self.redis.ft(self.index_name).info()["num_docs"])

def get_max_elements(self):
return self.max_elements
"""Get max elements in index"""
return self.max_elements

def info(self):
"""Get Redis info as dict"""
return self.redis.ft(self.index_name).info()

if __name__=="__main__":
# docker run -p 6379:6379 redislabs/redisearch:2.4.5
sim = RedisIndex(space='cosine',dim=32,redis_credentials={"host":"127.0.0.1", "port": 6379}, overwrite=True)
data=np.random.random((100,32))
aids=["a"+str(1+i) for i in range(100)]
bids=["b"+str(101+i) for i in range(100)]
sim.add_items(data,aids,partition="a")
sim.add_items(data,bids,partition="b")
# print(sim.search(data[0],k=10,partition=None))
# print(sim.get_items(aids[:10]))
print (sim.item_keys())
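
The user-event and raw-vector helpers added in this diff are not exercised by the __main__ demo above. A minimal sketch of how they compose, assuming the same local redislabs/redisearch container; note that the user_keys list assigned in __init__ shadows the user_keys() method above, so configuring it here is a plain attribute assignment:

import numpy as np
from recsplain import RedisIndex

sim = RedisIndex(space="cosine", dim=32,
                 redis_credentials={"host": "127.0.0.1", "port": 6379})

# Events are dicts keyed by user_keys; values are '|'-joined and RPUSHed
# onto the list at user:<user_id>, with an optional TTL in seconds.
sim.user_keys = ["event", "item"]
sim.add_user_event("u1", {"event": "click", "item": "a7"}, ttl=3600)
print(sim.get_user_events("u1"))   # [{'event': 'click', 'item': 'a7'}]

# Raw vectors are stored as float32 bytes under vec:<key>.
sim.set_vector("u1", np.random.random(32))
print(sim.get_vector("u1").shape)  # (32,)
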
41 changes: 26 additions & 15 deletions recsplain/strategies.py
@@ -38,7 +38,7 @@ def __init__(self, model_dir=None, similarity_engine=None ,engine_params={}):
def init_schema(self, **kwargs):
self.schema = PartitionSchema(**kwargs)
self.partitions[self.schema.base_strategy_id()] = [self.IndexEngine(self.schema.metric, self.schema.dim,
self.schema.index_factory,
index_factory=self.schema.index_factory,
**self.engine_params)
for _ in self.schema.partitions]
enc_sizes = {k: len(v) for k, v in self.schema.encoders[self.schema.base_strategy_id()].items()}
@@ -47,7 +47,7 @@ def add_variant(self, variant):
def add_variant(self, variant):
variant = self.schema.add_variant(variant)
self.partitions[variant['id']] = [self.IndexEngine(self.schema.metric, self.schema.dim,
self.schema.index_factory, **self.engine_params)
index_factory=self.schema.index_factory, **self.engine_params)
for _ in self.schema.partitions]
# enc_sizes = {k: len(v) for k, v in self.schema.encoders[self.schema.base_strategy_id()].items()}
return variant#, enc_sizes
@@ -131,8 +131,11 @@ def query_by_partition_and_vector(self, partition_num, strategy_id, vec, k, expl
try:
vec = vec.reshape(1, -1).astype('float32') # for faiss
distances, num_ids = self.partitions[strategy_id][partition_num].search(vec, k=k)
indices = np.where(num_ids != -1)
distances, num_ids = distances[indices], num_ids[indices]
if hasattr(distances[0],"__iter__"):
distances=distances[0]
num_ids=num_ids[0]
distances = [d for d,i in zip(distances,num_ids) if i>=0]
num_ids = [i for i in num_ids if i>=0]
except Exception as e:
raise Exception("Error in querying: " + str(e))
if len(num_ids) == 0:
@@ -228,7 +231,7 @@ def load_model(self, model_name):
with (model_dir/"schema.json").open('r') as f:
schema_dict=json.load(f)
self.schema = PartitionSchema(**schema_dict)
self.partitions = {strategy['id']: [self.IndexEngine(self.schema.metric, self.schema.dim, self.schema.index_factory,
self.partitions = {strategy['id']: [self.IndexEngine(self.schema.metric, self.schema.dim, index_factory=self.schema.index_factory,
**self.engine_params) for _ in self.schema.partitions] for strategy in self.schema.strategies}
model_dir.mkdir(parents=True, exist_ok=True)
with (model_dir/"index_labels.json").open('r') as f:
@@ -344,16 +347,17 @@ def user_query(self, user_data, item_history, k, strategy_id=None, user_coldstar


class RedisStrategy(BaseStrategy):
def __init__(self, model_dir=None, similarity_engine=None, engine_params={}, redis_credentials=None, user_prefix="user:", value_sep="|", user_keys=[],event_key="event",item_key="item",event_weights={}):
super().__init__(model_dir, similarity_engine, engine_params)
def __init__(self, model_dir=None, similarity_engine=None, engine_params={}, redis_credentials=None, user_prefix="user:",vector_prefix="vec:", value_sep="|", user_keys=[],event_key="event",item_key="item",event_weights={}):
super().__init__(model_dir, similarity_engine, dict(engine_params,redis_credentials=redis_credentials))
assert Redis is not None, "RedisStrategy requires redis-py to be installed"
assert redis_credentials is not None, "RedisStrategy requires redis credentials"
assert len(user_keys)>0, "user_keys not supplied"
assert event_key in user_keys, "event_key not in user_keys"
assert item_key in user_keys, "item_key not in user_keys"
self.redis = Redis(**redis_credentials)
self.sep = value_sep
self. user_prefix = user_prefix
self.user_prefix = user_prefix
self.vector_prefix = vector_prefix
self.event_key=event_key
self.user_keys=user_keys
self.item_key=item_key
@@ -366,6 +370,14 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.pipe.execute()
self.pipe = None
def set_vector(self, key, arr):
"""Sets a numpy array as a vector in redis"""
emb = np.array(arr).astype(np.float32).tobytes()
self.redis.set(self.vector_prefix+str(key), emb)
return self
def get_vector(self, key):
"""Gets a numpy array from redis"""
return np.frombuffer(self.redis.get(self.vector_prefix+str(key)), dtype=np.float32)
def del_user(self, user_id):
if self.pipe:
self.pipe.delete(self.user_prefix+str(user_id))
@@ -408,16 +420,15 @@ def user_partition_num(self, user_data):
def user_query(self, user_data, k, strategy_id=None, user_coldstart_item=None, user_coldstart_weight=1,user_id=None):
if not strategy_id:
strategy_id = self.schema.base_strategy_id()
if user_coldstart_item is None:
n = 0
vec = np.zeros(self.schema.dim)
else:
vec = np.zeros(self.schema.dim)
n = 0
if user_coldstart_item is not None:
n = user_coldstart_weight
if hasattr(user_coldstart_item, "__call__"):
item = user_coldstart_item(user_data)
if type(user_coldstart_item) == str:
vec = self.get_vector(user_coldstart_item)
elif type(user_coldstart_item) == dict:
item = user_coldstart_item
vec = self.schema.encode(item, strategy_id)
vec = self.schema.encode(item, strategy_id)
user_partition_num = self.user_partition_num(user_data)
col_mapping = self.schema.component_breakdown()
labels, distances = [], []
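
For context on the constructor changes above, a hypothetical RedisStrategy setup (names and weights are illustrative, not from this PR): the credentials are now merged into engine_params, so the underlying RedisIndex receives them automatically. This sketch assumes redis-py is installed, a local Redis with RediSearch is running, and that the default engine resolution applies when similarity_engine is omitted:

import numpy as np
from recsplain import RedisStrategy

strategy = RedisStrategy(
    redis_credentials={"host": "127.0.0.1", "port": 6379},
    user_keys=["event", "item"],   # must include event_key and item_key
    event_key="event",
    item_key="item",
    event_weights={"click": 1.0, "purchase": 5.0},  # hypothetical weights
)

# Pre-seed a named vector; a string user_coldstart_item passed to user_query
# is now resolved through get_vector, so this key can act as a cold-start profile.
strategy.set_vector("warm_start", np.zeros(32, dtype=np.float32))
vec = strategy.get_vector("warm_start")
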
2 changes: 1 addition & 1 deletion setup.py
@@ -18,7 +18,7 @@
"joblib>=0.17.0",
"tqdm>=4.62.3",
"pandas>=1.3.0",
"scikit-learning>=0.19.0",
"scikit-learn>=0.19.0",
],
long_description="https://github.com/argmaxml/recsplain/blob/master/README.md",
long_description_content_type="text/markdown",