Skip to content

Commit

Permalink
Adds the score to the Embedding record
Browse files Browse the repository at this point in the history
  • Loading branch information
javiertoledo committed Sep 6, 2023
1 parent a5dc39f commit 7000e61
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 26 deletions.
2 changes: 2 additions & 0 deletions .idea/gradle.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions core/build.gradle
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
plugins {
id 'java'
id "io.freefair.lombok" version "8.3"
}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
package com.theagilemonkeys.ellmental.core.schema;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.RequiredArgsConstructor;

import java.util.List;
import java.util.Map;
import java.util.UUID;

/**
* Embeddings represent a point in the embeddings space, representing the semantics of a given text.
*/
public record Embedding(
// Unique identifier of this embedding.
UUID id,
// Coordinates of the point in the embeddings space.
List<Double> vector,
// String key/value metadata attached to the embedding.
Map<String, String> metadata
) {}
// Lombok generates two constructors for this class:
//   - @RequiredArgsConstructor: (id, vector, metadata) -- the final fields only; score stays null.
//   - @AllArgsConstructor: (id, vector, metadata, score).
// @Data additionally generates getters, a setter for the non-final score field,
// and equals/hashCode/toString.
@Data
@RequiredArgsConstructor
@AllArgsConstructor
public class Embedding {
/** Unique identifier of this embedding. */
public final UUID id;
/** Coordinates of the point in the embeddings space. */
public final List<Double> vector;
/** String key/value metadata attached to the embedding. */
public final Map<String, String> metadata;
/**
* Score associated with this embedding. Mutable and nullable: it is only populated when the
* four-argument constructor or the generated setter is used.
*/
public Double score;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ public static void main(String[] args) {
Embedding embedding = embeddingsSpaceComponent.save("Hello World!");
System.out.printf("Saved embedding: %s", embedding);

System.out.printf("Got embedding: %s", embeddingsSpaceComponent.get(embedding.id()));
System.out.printf("Got embedding: %s", embeddingsSpaceComponent.get(embedding.id));

embeddingsSpaceComponent.delete(embedding.id());
embeddingsSpaceComponent.delete(embedding.id);

System.out.printf("Got embedding after delete: %s", embeddingsSpaceComponent.get(embedding.id()));
System.out.printf("Got embedding after delete: %s", embeddingsSpaceComponent.get(embedding.id));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,25 @@ public void testGenerateEmbedding() {
Embedding embedding = openAIEmbeddingsModel.generateEmbedding("The Agile Monkeys rule!");

// The id is not null and is a valid UUID
assertNotNull(embedding.id());
assertDoesNotThrow(() -> UUID.fromString(embedding.id().toString()));
assertNotNull(embedding.id);
assertDoesNotThrow(() -> UUID.fromString(embedding.id.toString()));

// The embedding is properly set
assertEquals(embedding.vector().size(), testValues.testGenerateEmbeddingExpectedValue.size());
assertArrayEquals(embedding.vector().toArray(), testValues.testGenerateEmbeddingExpectedValue.toArray());
assertEquals(embedding.vector.size(), testValues.testGenerateEmbeddingExpectedValue.size());
assertArrayEquals(embedding.vector.toArray(), testValues.testGenerateEmbeddingExpectedValue.toArray());

// The original input is retrievable from the metadata
String input = embedding.metadata().get("input");
String input = embedding.metadata.get("input");
assertEquals(input, "The Agile Monkeys rule!");

// The source and model are set to the right parameters
String source = embedding.metadata().get("source");
String source = embedding.metadata.get("source");
assertEquals(source, "OpenAI");
String model = embedding.metadata().get("model");
String model = embedding.metadata.get("model");
assertEquals(model, OpenAIEmbeddingsModel.embeddingOpenAiModel);

// createdAt is set and is a valid date
String createdAt = embedding.metadata().get("createdAt");
String createdAt = embedding.metadata.get("createdAt");
assertNotNull(createdAt);
assertDoesNotThrow(() -> java.time.LocalDateTime.parse(createdAt));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public List<Embedding> similaritySearch(Embedding reference, int limit) {
limit,
true,
true,
reference.vector(),
reference.vector,
this.namespace);
String requestBodyJson = new Gson().toJson(body);

Expand All @@ -118,7 +118,7 @@ public List<Embedding> similaritySearch(Embedding reference, int limit) {
.matches
.stream()
.sorted((a, b) -> Double.compare(b.score, a.score))
.map(match -> new Embedding(match.id, match.values, match.metadata))
.map(match -> new Embedding(match.id, match.values, match.metadata, match.score))
.collect(Collectors.toList());
} catch (IOException e) {
log.error("VectorStore error on upsert: {}", e.getMessage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class UpsertVectorSchema {

public UpsertVectorSchema(Embedding embedding, String namespace) {
this.vectors = new ArrayList<>();
this.vectors.add(new Vector(embedding.id(), embedding.vector(), embedding.metadata()));
this.vectors.add(new Vector(embedding.id, embedding.vector, embedding.metadata));
this.namespace = namespace;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,16 @@ public void testSimilaritySearch() throws IOException {
assertEquals(2, embeddings.size());

// Embeddings come in the right order and the UUIDs were parsed correctly
assertEquals(fakeUUID2, embeddings.get(0).id());
assertEquals(fakeUUID1, embeddings.get(1).id());
assertEquals(fakeUUID2, embeddings.get(0).id);
assertEquals(fakeUUID1, embeddings.get(1).id);

// Embeddings values were parsed correctly
assertEquals(List.of(0.4, 0.5, 0.6), embeddings.get(0).vector());
assertEquals(List.of(0.1, 0.2, 0.3), embeddings.get(1).vector());
assertEquals(List.of(0.4, 0.5, 0.6), embeddings.get(0).vector);
assertEquals(List.of(0.1, 0.2, 0.3), embeddings.get(1).vector);

// Embeddings metadata was parsed correctly
assertEquals(Map.of("key2", "value2"), embeddings.get(0).metadata());
assertEquals(Map.of("key1", "value1"), embeddings.get(1).metadata());
assertEquals(Map.of("key2", "value2"), embeddings.get(0).metadata);
assertEquals(Map.of("key1", "value1"), embeddings.get(1).metadata);
}

@Test
Expand Down Expand Up @@ -148,7 +148,7 @@ public void testGet() throws IOException {
Embedding embedding = pineconeEmbeddingsStore.get(uuid);

assertNotNull(embedding);
assertEquals(vector.getValues(), embedding.vector());
assertEquals(vector.getValues(), embedding.vector);
}

@Test
Expand Down

0 comments on commit 7000e61

Please sign in to comment.