-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: init vector store simliartiy code
- Loading branch information
Showing
9 changed files
with
217 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
use crate::embedding::Embedding; | ||
use crate::similarity::Similarity; | ||
|
||
pub struct CosineSimilarity; | ||
|
||
impl Similarity for CosineSimilarity { | ||
fn similarity_score(&self, vector_a: &Embedding, vector_b: &Embedding) -> f32 { | ||
if vector_a.len() != vector_b.len() { | ||
panic!( | ||
"Length of vector a ({}) must be equal to the length of vector b ({})", | ||
vector_a.len(), | ||
vector_b.len() | ||
); | ||
} | ||
|
||
let dot_product: f32 = vector_a | ||
.iter() | ||
.zip(vector_b.iter()) // Use vector_b.iter() directly | ||
.map(|(a, b)| a * b) | ||
.sum(); | ||
|
||
let norm_a: f32 = vector_a.iter().map(|x| x * x).sum(); | ||
let norm_b: f32 = vector_b.iter().map(|x| x * x).sum(); | ||
|
||
dot_product / (f32::sqrt(norm_a) * f32::sqrt(norm_b)) | ||
} | ||
} | ||
|
||
impl CosineSimilarity { | ||
pub fn between(embedding: &Embedding, reference_embedding: &Embedding) -> f32 { | ||
CosineSimilarity.similarity_score(embedding, reference_embedding) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
use std::cmp::Ordering; | ||
use crate::embedding::Embedding; | ||
|
||
#[derive(Debug, Clone)] | ||
pub struct EmbeddingMatch<Embedded: Clone + Ord> { | ||
score: f32, | ||
embedding_id: String, | ||
embedding: Embedding, | ||
embedded: Embedded, | ||
} | ||
|
||
impl<Embedded: Clone + Ord> EmbeddingMatch<Embedded> { | ||
pub(crate) fn new(score: f32, embedding_id: String, embedding: Embedding, embedded: Embedded) -> Self { | ||
EmbeddingMatch { | ||
score, | ||
embedding_id, | ||
embedding, | ||
embedded, | ||
} | ||
} | ||
} | ||
|
||
impl<Embedded: Clone + Ord> PartialOrd for EmbeddingMatch<Embedded> { | ||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { | ||
self.score.partial_cmp(&other.score) | ||
} | ||
} | ||
|
||
impl<Embedded: Clone + Ord> PartialEq for EmbeddingMatch<Embedded> { | ||
fn eq(&self, other: &Self) -> bool { | ||
self.score == other.score | ||
} | ||
} | ||
|
||
impl<Embedded: Clone + Ord> Ord for EmbeddingMatch<Embedded> { | ||
fn cmp(&self, other: &Self) -> Ordering { | ||
self.score.partial_cmp(&other.score).unwrap_or(Ordering::Equal) | ||
} | ||
} | ||
|
||
impl<Embedded: Clone + Ord> Eq for EmbeddingMatch<Embedded> { | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
use crate::embedding::Embedding; | ||
use crate::embedding_match::EmbeddingMatch; | ||
|
||
trait EmbeddingStore<Embedded: Clone + Ord> { | ||
// Adds an embedding to the store and returns its unique identifier. | ||
fn add(&mut self, embedding: Embedding) -> String; | ||
|
||
// Adds an embedding to the store with a specified identifier. | ||
fn add_with_id(&mut self, id: String, embedding: Embedding); | ||
|
||
// Adds an embedding to the store and associates it with the provided embedded data. | ||
fn add_with_embedded(&mut self, embedding: Embedding, embedded: Embedded) -> String; | ||
|
||
// Adds a list of embeddings to the store and returns a list of unique identifiers. | ||
fn add_all(&mut self, embeddings: Vec<Embedding>) -> Vec<String>; | ||
|
||
// Adds a list of embeddings to the store and associates them with a list of embedded data. | ||
fn add_all_with_embedded(&mut self, embeddings: Vec<Embedding>, embedded: Vec<Embedded>) -> Vec<String>; | ||
|
||
// Find relevant embeddings in the store based on a reference embedding, with a maximum number of results. | ||
// An optional minimum score can be specified to filter results. | ||
fn find_relevant( | ||
&self, | ||
reference_embedding: Embedding, | ||
max_results: usize, | ||
min_score: f32, | ||
) -> Vec<EmbeddingMatch<Embedded>>; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
use std::collections::BinaryHeap; | ||
|
||
use crate::cosine_similarity::CosineSimilarity; | ||
use crate::embedding::Embedding; | ||
use crate::embedding_match::EmbeddingMatch; | ||
use crate::relevance_score::RelevanceScore; | ||
|
||
|
||
#[derive(Clone)] | ||
struct Entry<Embedded: Clone + Ord> { | ||
id: String, | ||
embedding: Embedding, | ||
embedded: Option<Embedded>, | ||
} | ||
|
||
impl<Embedded: Clone + Ord> Entry<Embedded> { | ||
fn new(id: String, embedding: Embedding, embedded: Option<Embedded>) -> Self { | ||
Entry { id, embedding, embedded } | ||
} | ||
} | ||
|
||
pub struct InMemoryEmbeddingStore<Embedded: Clone + Ord> { | ||
entries: Vec<Entry<Embedded>>, | ||
} | ||
|
||
// Implement methods for InMemoryEmbeddingStore | ||
impl<Embedded: Clone + Ord> InMemoryEmbeddingStore<Embedded> { | ||
fn new() -> Self { | ||
InMemoryEmbeddingStore { entries: Vec::new() } | ||
} | ||
|
||
fn add(&mut self, embedding: Embedding) -> String { | ||
let id = uuid::Uuid::new_v4().to_string(); | ||
self.add_with_id(id.clone(), embedding); | ||
id | ||
} | ||
|
||
fn add_with_id(&mut self, id: String, embedding: Embedding) { | ||
self.add_with_embedded(id, embedding, None); | ||
} | ||
|
||
fn add_with_embedded(&mut self, id: String, embedding: Embedding, embedded: Option<Embedded>) -> String { | ||
let entry = Entry::new(id.clone(), embedding, embedded); | ||
self.entries.push(entry); | ||
id | ||
} | ||
|
||
fn add_all(&mut self, embeddings: Vec<Embedding>) -> Vec<String> { | ||
embeddings | ||
.into_iter() | ||
.map(|embedding| self.add(embedding)) | ||
.collect() | ||
} | ||
|
||
fn add_all_with_embedded(&mut self, embeddings: Vec<Embedding>, embedded: Vec<Embedded>) -> Vec<String> { | ||
assert_eq!(embeddings.len(), embedded.len(), "The list of embeddings and embedded must have the same size"); | ||
|
||
embeddings | ||
.into_iter() | ||
.zip(embedded) | ||
.map(|(embedding, embedded)| self.add_with_embedded(uuid::Uuid::new_v4().to_string(), embedding, Some(embedded))) | ||
.collect() | ||
} | ||
|
||
fn find_relevant(&self, reference_embedding: Embedding, max_results: usize, min_score: f32) -> Vec<EmbeddingMatch<Embedded>> { | ||
let mut matches = BinaryHeap::new(); | ||
|
||
for entry in &self.entries { | ||
let cosine_similarity = CosineSimilarity::between(&entry.embedding, &reference_embedding); | ||
let score = RelevanceScore::from_cosine_similarity(cosine_similarity); | ||
|
||
if score >= min_score { | ||
matches.push(EmbeddingMatch::new(score, entry.id.clone(), entry.embedding.clone(), entry.embedded.clone().unwrap())); | ||
|
||
if matches.len() > max_results { | ||
matches.pop(); | ||
} | ||
} | ||
} | ||
|
||
let mut result: Vec<_> = matches.into_sorted_vec(); | ||
result.reverse(); | ||
result | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
pub struct RelevanceScore; | ||
|
||
impl RelevanceScore { | ||
pub fn from_cosine_similarity(cosine_similarity: f32) -> f32 { | ||
(cosine_similarity + 1.0) / 2.0 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
use crate::embedding::Embedding; | ||
|
||
pub trait Similarity { | ||
fn similarity_score(&self, set1: &Embedding, set2: &Embedding) -> f32; | ||
} |