Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Feb 29, 2024
1 parent 12439de commit 9bd7dca
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
6 changes: 3 additions & 3 deletions lilac/embeddings/vector_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import abc
import os
import pickle
from typing import Iterable, Optional, Sequence, Type, cast
from typing import Iterable, Iterator, Optional, Sequence, Type, cast

import numpy as np

Expand Down Expand Up @@ -50,7 +50,7 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
pass

@abc.abstractmethod
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
"""Return the embeddings for given keys.
Args:
Expand Down Expand Up @@ -160,7 +160,7 @@ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]:
all_vector_keys.append([(*path_key, i) for i in range(len(spans))])

flat_vector_keys = [key for vector_keys in all_vector_keys for key in (vector_keys or [])]
all_vectors = iter(self._vector_store.get(flat_vector_keys))
all_vectors = self._vector_store.get(flat_vector_keys)
for spans in all_spans:
yield [{'span': span, 'vector': next(all_vectors)} for span in spans]

Expand Down
6 changes: 4 additions & 2 deletions lilac/signals/semantic_similarity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None:
pass

@override
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray:
def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]:
keys = keys or []
return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys])
yield from [
np.array(EMBEDDINGS[tuple(path_key)][cast(int, index)]) for *path_key, index in keys
]

@override
def delete(self, base_path: str) -> None:
Expand Down

0 comments on commit 9bd7dca

Please sign in to comment.