Skip to content

Commit

Permalink
refactor how "id" and "set" work
Browse files Browse the repository at this point in the history
  • Loading branch information
sdan committed Apr 7, 2024
1 parent e82e1f2 commit d653341
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 53 deletions.
40 changes: 25 additions & 15 deletions tests/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_add_text(self):
start_time = time.time()
text = "This is a test text. " * 100
metadata = {"source": "test", "author": "John Doe", "timestamp": "2023-06-08"}
self.vlite.add(text, metadata=metadata)
self.vlite.add(text, metadata=metadata, item_id="test_text_1")
end_time = time.time()
TestVLite.test_times["add_single_text"] = end_time - start_time
print(f"Count of texts in the collection: {self.vlite.count()}")
Expand All @@ -22,12 +22,12 @@ def test_add_texts(self):
start_time = time.time()
text_512tokens = "underreckoning fleckiness hairstane paradigmatic eligibility sublevate xviii achylia reremice flung outpurl questing gilia unosmotic unsuckled plecopterid excludable phenazine fricando unfledgedness spiritsome incircle desmogenous subclavate redbug semihoral district chrysocolla protocoled servius readings propolises javali dujan stickman attendee hambone obtusipennate tightropes monitorially signaletics diestrums preassigning spriggy yestermorning margaritic tankfuls aseptify linearity hilasmic twinning tokonoma seminormalness cerebrospinant refroid doghouse kochab dacryocystalgia saltbushes newcomer provoker berberid platycoria overpersuaded reoverflow constrainable headless forgivably syzygal purled reese polyglottonic decennary embronze pluripotent equivocally myoblasts thymelaeaceous confervae perverted preanticipate mammalogical desalinizing tackets misappearance subflexuose concludence effluviums runtish gras cuckolded hemostasia coatroom chelidon policizer trichinised frontstall impositions unta outrance scholium fibrochondritis furcates fleaweed housefront helipads hemachate snift appellativeness knobwood superinclination tsures haberdasheries unparliamented reexecution nontangential waddied desolated subdistinctively undiscernibleness swishiest dextral progs koprino bruisingly unloanably bardash uncuckoldedunderreckoning fleckiness hairstane paradigmatic eligibility sublevate xviii achylia reremice flung outpurl questing gilia unosmotic unsuckled plecopterid excludable phenazine fricando unfledgedness spiritsome incircle desmogenous subclavate redbug semihoral district chrysocolla spriggy yestermorning margaritic tankfuls aseptify linearity hilasmic twinning tokonoma seminormalness cerebrospinant refroequivocally myoblasts thymelaeaceous confervae perverted preantiest dextral progs koprino bruisingly unloanably bardash uncuckolded"
metadata = {"source": "test_512tokens", "category": "random"}
self.vlite.add(text_512tokens, metadata=metadata)
self.vlite.add(text_512tokens, metadata=metadata, item_id="test_text_2")

with open(os.path.join(os.path.dirname(__file__), "data/text-8192tokens.txt"), "r") as file:
text_8192tokens = file.read()
metadata = {"source": "test_8192tokens", "category": "long_text"}
self.vlite.add(text_8192tokens, metadata=metadata)
self.vlite.add(text_8192tokens, metadata=metadata, item_id="test_text_3")

end_time = time.time()
TestVLite.test_times["add_multiple_texts"] = end_time - start_time
Expand All @@ -36,7 +36,7 @@ def test_add_texts(self):
def test_add_pdf(self):
    """Time adding a pre-chunked PDF (data/attention.pdf) to the collection.

    NOTE(review): the scraped diff showed both the pre-change and post-change
    `add` calls; only the post-change call (with `item_id`) is kept here,
    otherwise the PDF would be inserted twice.
    """
    print(f"[test_add_pdf] Count of chunks in the collection: {self.vlite.count()}")
    start_time = time.time()
    # need_chunks=False: process_pdf already returns chunked text.
    self.vlite.add(process_pdf(os.path.join(os.path.dirname(__file__), 'data/attention.pdf')), need_chunks=False, item_id="test_pdf_1")
    end_time = time.time()
    TestVLite.test_times["add_pdf"] = end_time - start_time
    print(f"[test_add_pdf] after Count of chunks in the collection: {self.vlite.count()}")
Expand All @@ -62,7 +62,7 @@ def test_retrieve(self):
]
start_time = time.time()
for query in queries:
results = self.vlite.retrieve(query)
results = self.vlite.retrieve(query, top_k=3)
print(f"Query: {query}")
print(f"Top 3 results:")
for text, similarity, metadata in results[:3]:
Expand All @@ -74,36 +74,46 @@ def test_retrieve(self):
TestVLite.test_times["retrieve"] = end_time - start_time

def test_delete(self):
    """Time deleting two freshly added items by their item IDs.

    NOTE(review): the scraped diff interleaved the old metadata-id based
    `add`/`delete` calls with the new `item_id` based ones; keeping both
    would add four items and delete IDs that no longer exist. Only the
    post-commit calls are kept.
    """
    self.vlite.add("This is a test text.", item_id="test_delete_1")
    self.vlite.add("Another test text.", item_id="test_delete_2")
    start_time = time.time()
    self.vlite.delete(['test_delete_1', 'test_delete_2'])
    end_time = time.time()
    TestVLite.test_times["delete"] = end_time - start_time
    print(f"Count of texts in the collection: {self.vlite.count()}")

def test_update(self):
    """Time updating the text and metadata of one existing item.

    NOTE(review): duplicate pre-change lines ("test_text_3") from the
    scraped diff are dropped; only the post-commit "test_update_1" flow
    remains.
    """
    self.vlite.add("This is a test text.", item_id="test_update_1")
    start_time = time.time()
    self.vlite.update("test_update_1", text="This is an updated text.", metadata={"updated": True})
    end_time = time.time()
    TestVLite.test_times["update"] = end_time - start_time
    print(f"Count of texts in the collection: {self.vlite.count()}")

def test_get(self):
    """Time fetching items both by explicit IDs and by a metadata filter.

    NOTE(review): the scraped diff contained both the old metadata-id
    `add`/`get` calls and the new `item_id` based ones; keeping both would
    insert six items and query stale IDs. Only the post-commit calls are
    kept.
    """
    self.vlite.add("Text 1", item_id="test_get_1", metadata={"category": "A"})
    self.vlite.add("Text 2", item_id="test_get_2", metadata={"category": "B"})
    self.vlite.add("Text 3", item_id="test_get_3", metadata={"category": "A"})

    start_time = time.time()
    # Lookup by ID list.
    items = self.vlite.get(ids=["test_get_1", "test_get_3"])
    print(f"Items with IDs 'test_get_1' and 'test_get_3': {items}")

    # Lookup by metadata equality filter.
    items = self.vlite.get(where={"category": "A"})
    print(f"Items with category 'A': {items}")
    end_time = time.time()
    TestVLite.test_times["get"] = end_time - start_time

def test_set(self):
    """Time VLite.set: overwrite an item's text/metadata, then read it back."""
    self.vlite.add("Original text", item_id="test_set_1", metadata={"original": True})

    t_begin = time.time()
    self.vlite.set("test_set_1", text="Updated text", metadata={"updated": True})
    t_finish = time.time()
    TestVLite.test_times["set"] = t_finish - t_begin

    # Fetch the item again to show the applied changes.
    items = self.vlite.get(ids=["test_set_1"])
    print(f"Updated item: {items}")

def test_count(self):
start_time = time.time()
count = self.vlite.count()
Expand Down
106 changes: 68 additions & 38 deletions vlite/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def __init__(self, collection=None, device='cpu', model_name='mixedbread-ai/mxba
except FileNotFoundError:
print(f"Collection file {self.collection} not found. Initializing empty attributes.")

def add(self, data, metadata=None, need_chunks=True, fast=True):
print("Adding text to the collection...", self.collection)
def add(self, data, metadata=None, item_id=None, need_chunks=True, fast=True):
data = [data] if not isinstance(data, list) else data
results = []
all_chunks = []
Expand All @@ -42,14 +41,14 @@ def add(self, data, metadata=None, need_chunks=True, fast=True):
if isinstance(item, dict):
text_content = item['text']
item_metadata = item.get('metadata', {})
item_id = item_metadata.get('id', str(uuid4()))
else:
text_content = item
item_metadata = {}

if item_id is None:
item_id = str(uuid4())

item_metadata.update(metadata or {})
item_metadata['id'] = item_id

if need_chunks:
chunks = chop_and_chunk(text_content, fast=fast)
Expand All @@ -63,17 +62,17 @@ def add(self, data, metadata=None, need_chunks=True, fast=True):

encoded_data = self.model.embed(all_chunks, device=self.device)
binary_encoded_data = self.model.quantize(encoded_data, precision="binary")
for idx, (chunk, binary_vector, metadata, item_id) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata, all_ids)):

for idx, (chunk, binary_vector, metadata) in enumerate(zip(all_chunks, binary_encoded_data, all_metadata)):
chunk_id = f"{item_id}_{idx}"
self.index[chunk_id] = {
'text': chunk,
'metadata': metadata,
'binary_vector': binary_vector.tolist()
}

if item_id not in [result[0] for result in results]:
results.append((item_id, binary_encoded_data, metadata))
if item_id not in [result[0] for result in results]:
results.append((item_id, binary_encoded_data, metadata))

self.save()
print("Text added successfully.")
Expand Down Expand Up @@ -111,35 +110,19 @@ def search(self, query_binary_vector, top_k, metadata=None):
# Apply metadata filter on the retrieved top_k items
if metadata:
filtered_ids = []
for item_id in top_k_ids:
item_metadata = self.index[item_id]['metadata']
for chunk_id in top_k_ids:
item_id = chunk_id.split('_')[0]
item_metadata = self.index[chunk_id]['metadata']
if all(item_metadata.get(key) == value for key, value in metadata.items()):
filtered_ids.append(item_id)
filtered_ids.append(chunk_id)
top_k_ids = filtered_ids[:top_k]

# Get the similarity scores for the top_k items
top_k_scores = binary_similarities[top_k_indices]

return list(zip(top_k_ids, top_k_scores))

def delete(self, ids):
    """Remove entries from the collection by their exact index keys.

    Args:
        ids: A single ID string or a list of ID strings.

    Returns:
        int: The number of entries actually removed.
    """
    if isinstance(ids, str):
        ids = [ids]

    deleted_count = 0
    for item_id in ids:  # renamed from `id` to avoid shadowing the builtin
        if item_id in self.index:
            del self.index[item_id]
            deleted_count += 1

    # Persist only when something actually changed.
    if deleted_count > 0:
        self.save()
        print(f"Deleted {deleted_count} item(s) from the collection.")
    else:
        print("No items found with the specified IDs.")

    return deleted_count

def update(self, id, text=None, metadata=None, vector=None):
if id in self.index:
if text is not None:
Expand All @@ -157,13 +140,57 @@ def update(self, id, text=None, metadata=None, vector=None):
else:
print(f"Item with ID '{id}' not found.")
return False

def delete(self, ids):
    """Delete items (all of their chunks) by logical item ID.

    Chunks are stored under keys of the form "<item_id>_<idx>", so each
    requested ID removes every index key carrying that prefix.

    Args:
        ids: A single item ID or a list of item IDs.

    Returns:
        int: The number of chunk entries removed (not logical items).
    """
    if isinstance(ids, str):
        ids = [ids]

    deleted_count = 0
    for item_id in ids:  # renamed from `id` to avoid shadowing the builtin
        # Snapshot the matching keys first so we never mutate the dict
        # while iterating over it; the redundant per-key membership
        # re-check from the original is dropped (keys come from the index).
        chunk_ids = [chunk_id for chunk_id in self.index
                     if chunk_id.startswith(f"{item_id}_")]
        for chunk_id in chunk_ids:
            del self.index[chunk_id]
            deleted_count += 1

    if deleted_count > 0:
        self.save()
        print(f"Deleted {deleted_count} item(s) from the collection.")
    else:
        print("No items found with the specified IDs.")

    return deleted_count



def get(self, ids=None, where=None):
if ids is not None:
id_set = set(ids)
items = [(self.index[id]['text'], self.index[id]['metadata']) for id in self.index if id in id_set]
if isinstance(ids, str):
ids = [ids]
items = []
for id in ids:
item_chunks = []
item_metadata = {}
for chunk_id, chunk_data in self.index.items():
if chunk_id.startswith(f"{id}_"):
item_chunks.append(chunk_data['text'])
item_metadata.update(chunk_data['metadata'])
if item_chunks:
item_text = ' '.join(item_chunks)
items.append((item_text, item_metadata))
else:
items = [(self.index[id]['text'], self.index[id]['metadata']) for id in self.index]
items = []
item_dict = {}
for chunk_id, chunk_data in self.index.items():
item_id = chunk_id.split('_')[0]
if item_id not in item_dict:
item_dict[item_id] = {'chunks': [], 'metadata': {}}
item_dict[item_id]['chunks'].append(chunk_data['text'])
item_dict[item_id]['metadata'].update(chunk_data['metadata'])
for item_id, item_data in item_dict.items():
item_text = ' '.join(item_data['chunks'])
item_metadata = item_data['metadata']
items.append((item_text, item_metadata))

if where is not None:
items = [item for item in items if all(item[1].get(key) == value for key, value in where.items())]
Expand All @@ -172,16 +199,19 @@ def get(self, ids=None, where=None):

def set(self, id, text=None, metadata=None, vector=None):
    """Overwrite attributes on every chunk belonging to item `id`.

    NOTE(review): the scraped diff interleaved the old whole-item branch
    with the new chunk-prefix branch; only the post-commit chunk-based
    logic is kept here. Parameter name `id` is preserved (public
    interface) even though it shadows the builtin.

    Args:
        id: Logical item ID; chunks live under keys "<id>_<idx>".
        text: If given, replaces the text of every chunk.
        metadata: If given, merged into each chunk's metadata dict.
        vector: If given, stored on each chunk under 'vector'.
    """
    print(f"Setting attributes for item with ID: {id}")
    chunk_ids = [chunk_id for chunk_id in self.index if chunk_id.startswith(f"{id}_")]
    if chunk_ids:
        for chunk_id in chunk_ids:
            if text is not None:
                self.index[chunk_id]['text'] = text
            if metadata is not None:
                self.index[chunk_id]['metadata'].update(metadata)
            if vector is not None:
                self.index[chunk_id]['vector'] = vector
        self.save()
    else:
        print(f"Item with ID {id} not found.")


def count(self):
    """Return how many chunk records the in-memory index currently holds."""
    index_size = len(self.index)
    return index_size
Expand Down

0 comments on commit d653341

Please sign in to comment.