Skip to content

Commit

Permalink
feat: init vector store similarity code
Browse files Browse the repository at this point in the history
  • Loading branch information
phodal committed Nov 13, 2023
1 parent 4f061d6 commit 3b4e1f7
Show file tree
Hide file tree
Showing 9 changed files with 217 additions and 2 deletions.
1 change: 1 addition & 0 deletions inference_core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ tokenizers = { version = "0.14.1", default-features = false, features = ["progre
ndarray = "0.15.6"

uniffi = { version = "0.25" }
uuid = { version = "1.5.0", features = ["v4"] }

[lib]
crate-type = ["lib", "cdylib"]
Expand Down
33 changes: 33 additions & 0 deletions inference_core/src/cosine_similarity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use crate::embedding::Embedding;
use crate::similarity::Similarity;

/// Cosine-similarity measure between two embeddings.
pub struct CosineSimilarity;

impl Similarity for CosineSimilarity {
    /// Returns the cosine of the angle between `vector_a` and `vector_b`: their dot
    /// product divided by the product of their Euclidean norms.
    ///
    /// When either vector has a zero norm (including two empty vectors) the cosine is
    /// undefined; `0.0` is returned instead of letting the division produce `NaN`.
    ///
    /// # Panics
    ///
    /// Panics if the two vectors do not have the same length.
    fn similarity_score(&self, vector_a: &Embedding, vector_b: &Embedding) -> f32 {
        assert!(
            vector_a.len() == vector_b.len(),
            "Length of vector a ({}) must be equal to the length of vector b ({})",
            vector_a.len(),
            vector_b.len()
        );

        let dot_product: f32 = vector_a
            .iter()
            .zip(vector_b.iter())
            .map(|(a, b)| a * b)
            .sum();

        let squared_norm_a: f32 = vector_a.iter().map(|x| x * x).sum();
        let squared_norm_b: f32 = vector_b.iter().map(|x| x * x).sum();

        let denominator = f32::sqrt(squared_norm_a) * f32::sqrt(squared_norm_b);
        if denominator == 0.0 {
            // At least one vector has no direction; avoid 0/0 = NaN.
            return 0.0;
        }

        dot_product / denominator
    }
}

impl CosineSimilarity {
    /// Convenience wrapper: cosine similarity between `embedding` and
    /// `reference_embedding` without needing a `CosineSimilarity` value.
    ///
    /// # Panics
    ///
    /// Panics if the two embeddings do not have the same length.
    pub fn between(embedding: &Embedding, reference_embedding: &Embedding) -> f32 {
        CosineSimilarity.similarity_score(embedding, reference_embedding)
    }
}
13 changes: 11 additions & 2 deletions inference_core/src/document.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
use std::collections::HashMap;
use crate::embedding::Embedding;

struct Document {
/// A document record: an identifier, its metadata, its text, and an embedding.
pub struct Document {
/// Identifier of this document.
id: String,
/// Metadata attached to this document.
metadata: Metadata,
/// Text of the document (presumably the raw content the embedding was computed from — confirm against the constructor).
text: String,
/// Embedding stored alongside the text.
vector: Embedding,
}

struct Metadata {
/// String-keyed, string-valued metadata, as carried by `Document`.
#[derive(Debug, Default)]
pub struct Metadata {
    metadata: HashMap<String, String>,
}

impl Metadata {
    /// Creates an empty `Metadata`; equivalent to `Metadata::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}


impl Document {
fn from(string_value: String) -> Self {
Self {
Expand Down
42 changes: 42 additions & 0 deletions inference_core/src/embedding_match.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use std::cmp::Ordering;
use crate::embedding::Embedding;

/// One search hit: a stored embedding, its id, the data it was associated with, and
/// the score it obtained against the query.
///
/// Equality and ordering look at `score` only; the id, the embedding and the embedded
/// payload do not take part in comparisons.
#[derive(Debug, Clone)]
pub struct EmbeddingMatch<Embedded: Clone + Ord> {
    score: f32,
    embedding_id: String,
    embedding: Embedding,
    embedded: Embedded,
}

impl<Embedded: Clone + Ord> EmbeddingMatch<Embedded> {
    /// Builds a match from its score, the id and value of the stored embedding, and
    /// the payload associated with that embedding.
    pub(crate) fn new(score: f32, embedding_id: String, embedding: Embedding, embedded: Embedded) -> Self {
        Self {
            score,
            embedding_id,
            embedding,
            embedded,
        }
    }
}

impl<Embedded: Clone + Ord> Ord for EmbeddingMatch<Embedded> {
    /// Orders matches by score using `f32::total_cmp`, which is a genuine total order
    /// (NaN included). Treating incomparable scores as `Equal` would break the
    /// transitivity that sorted collections such as `BinaryHeap` rely on.
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.total_cmp(&other.score)
    }
}

impl<Embedded: Clone + Ord> PartialOrd for EmbeddingMatch<Embedded> {
    /// Always `Some`, and always in agreement with [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<Embedded: Clone + Ord> PartialEq for EmbeddingMatch<Embedded> {
    /// Two matches are equal exactly when `cmp` reports `Ordering::Equal`, keeping
    /// `==` reflexive as `Eq` requires.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<Embedded: Clone + Ord> Eq for EmbeddingMatch<Embedded> {}
28 changes: 28 additions & 0 deletions inference_core/src/embedding_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::embedding::Embedding;
use crate::embedding_match::EmbeddingMatch;

/// Operations offered by a store of embeddings, each optionally associated with a
/// piece of `Embedded` data.
trait EmbeddingStore<Embedded>
where
    Embedded: Clone + Ord,
{
    /// Adds an embedding to the store and returns its unique identifier.
    fn add(&mut self, embedding: Embedding) -> String;

    /// Adds an embedding to the store under the identifier supplied by the caller.
    fn add_with_id(&mut self, id: String, embedding: Embedding);

    /// Adds an embedding together with the data it is associated with and returns
    /// the identifier of the new entry.
    fn add_with_embedded(&mut self, embedding: Embedding, embedded: Embedded) -> String;

    /// Adds several embeddings and returns one identifier per embedding.
    fn add_all(&mut self, embeddings: Vec<Embedding>) -> Vec<String>;

    /// Adds several embeddings, each associated with an item of `embedded`, and
    /// returns one identifier per embedding.
    fn add_all_with_embedded(
        &mut self,
        embeddings: Vec<Embedding>,
        embedded: Vec<Embedded>,
    ) -> Vec<String>;

    /// Finds the embeddings relevant to `reference_embedding`.
    ///
    /// At most `max_results` matches are returned, and `min_score` is the minimum
    /// score used to filter the results.
    fn find_relevant(
        &self,
        reference_embedding: Embedding,
        max_results: usize,
        min_score: f32,
    ) -> Vec<EmbeddingMatch<Embedded>>;
}
5 changes: 5 additions & 0 deletions inference_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ pub mod embed;
pub mod memory_store;
pub mod document;
pub mod embedding;
pub mod embedding_store;
mod embedding_match;
mod similarity;
mod relevance_score;
mod cosine_similarity;

fn hello_name(name: String) -> String {
format!("Hello {}", name)
Expand Down
85 changes: 85 additions & 0 deletions inference_core/src/memory_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
use std::collections::BinaryHeap;

use crate::cosine_similarity::CosineSimilarity;
use crate::embedding::Embedding;
use crate::embedding_match::EmbeddingMatch;
use crate::relevance_score::RelevanceScore;


/// A single record kept by the in-memory store: an id, the embedding itself, and the
/// data associated with it, if any.
#[derive(Clone)]
struct Entry<Embedded>
where
    Embedded: Clone + Ord,
{
    id: String,
    embedding: Embedding,
    embedded: Option<Embedded>,
}

impl<Embedded> Entry<Embedded>
where
    Embedded: Clone + Ord,
{
    /// Bundles the three parts of a record into an `Entry`.
    fn new(id: String, embedding: Embedding, embedded: Option<Embedded>) -> Self {
        Self {
            id,
            embedding,
            embedded,
        }
    }
}

/// An embedding store that keeps every entry in a `Vec` and answers queries with a
/// linear scan over all entries.
pub struct InMemoryEmbeddingStore<Embedded: Clone + Ord> {
    entries: Vec<Entry<Embedded>>,
}

impl<Embedded: Clone + Ord> InMemoryEmbeddingStore<Embedded> {
    /// Creates an empty store.
    fn new() -> Self {
        InMemoryEmbeddingStore { entries: Vec::new() }
    }

    /// Stores `embedding` under a freshly generated UUID v4 and returns that id.
    ///
    /// The entry has no embedded payload.
    fn add(&mut self, embedding: Embedding) -> String {
        let id = uuid::Uuid::new_v4().to_string();
        self.add_with_id(id.clone(), embedding);
        id
    }

    /// Stores `embedding` under the caller-supplied `id`, without an embedded payload.
    fn add_with_id(&mut self, id: String, embedding: Embedding) {
        self.add_with_embedded(id, embedding, None);
    }

    /// Stores `embedding` under `id` together with an optional payload and returns
    /// `id` back to the caller.
    fn add_with_embedded(&mut self, id: String, embedding: Embedding, embedded: Option<Embedded>) -> String {
        let entry = Entry::new(id.clone(), embedding, embedded);
        self.entries.push(entry);
        id
    }

    /// Stores every embedding (each under a new UUID v4, without payload) and returns
    /// the generated ids in input order.
    fn add_all(&mut self, embeddings: Vec<Embedding>) -> Vec<String> {
        embeddings
            .into_iter()
            .map(|embedding| self.add(embedding))
            .collect()
    }

    /// Stores each embedding paired with the payload at the same position and returns
    /// the generated ids in input order.
    ///
    /// # Panics
    ///
    /// Panics if `embeddings` and `embedded` differ in length.
    fn add_all_with_embedded(&mut self, embeddings: Vec<Embedding>, embedded: Vec<Embedded>) -> Vec<String> {
        assert_eq!(embeddings.len(), embedded.len(), "The list of embeddings and embedded must have the same size");

        embeddings
            .into_iter()
            .zip(embedded)
            .map(|(embedding, embedded)| self.add_with_embedded(uuid::Uuid::new_v4().to_string(), embedding, Some(embedded)))
            .collect()
    }

    /// Returns up to `max_results` entries whose relevance score against
    /// `reference_embedding` is at least `min_score`, best score first.
    ///
    /// The score is the cosine similarity mapped through
    /// `RelevanceScore::from_cosine_similarity`. A score that is NaN never satisfies
    /// `score >= min_score` and is therefore left out.
    ///
    /// Entries stored without an embedded payload (via `add`, `add_with_id` or
    /// `add_all`) cannot be represented as an `EmbeddingMatch<Embedded>` and are
    /// skipped rather than causing a panic.
    ///
    /// # Panics
    ///
    /// Panics if a considered entry's embedding and `reference_embedding` differ in
    /// length (raised by `CosineSimilarity::between`).
    fn find_relevant(
        &self,
        reference_embedding: Embedding,
        max_results: usize,
        min_score: f32,
    ) -> Vec<EmbeddingMatch<Embedded>> {
        // `BinaryHeap` is a max-heap. Wrapping items in `Reverse` puts the *lowest*
        // score on top, so popping on overflow evicts the worst candidate and the heap
        // always holds the best `max_results` seen so far.
        let mut best = BinaryHeap::new();

        for entry in &self.entries {
            let embedded = match &entry.embedded {
                Some(embedded) => embedded,
                None => continue,
            };

            let cosine_similarity = CosineSimilarity::between(&entry.embedding, &reference_embedding);
            let score = RelevanceScore::from_cosine_similarity(cosine_similarity);

            if score >= min_score {
                best.push(std::cmp::Reverse(EmbeddingMatch::new(
                    score,
                    entry.id.clone(),
                    entry.embedding.clone(),
                    embedded.clone(),
                )));

                if best.len() > max_results {
                    best.pop();
                }
            }
        }

        // Ascending order of `Reverse<_>` is descending order of score.
        best.into_sorted_vec()
            .into_iter()
            .map(|std::cmp::Reverse(found)| found)
            .collect()
    }
}
7 changes: 7 additions & 0 deletions inference_core/src/relevance_score.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/// Namespace for turning raw similarity values into relevance scores.
pub struct RelevanceScore;

impl RelevanceScore {
    /// Converts a cosine similarity into a relevance score by adding one and halving,
    /// so that an input of `-1.0` maps to `0.0` and an input of `1.0` maps to `1.0`.
    pub fn from_cosine_similarity(cosine_similarity: f32) -> f32 {
        let shifted = cosine_similarity + 1.0;
        shifted / 2.0
    }
}
5 changes: 5 additions & 0 deletions inference_core/src/similarity.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use crate::embedding::Embedding;

/// A measure that scores how similar two embeddings are.
pub trait Similarity {
/// Returns the similarity score of `set1` and `set2` as an `f32`.
fn similarity_score(&self, set1: &Embedding, set2: &Embedding) -> f32;
}

0 comments on commit 3b4e1f7

Please sign in to comment.