Skip to content

Commit

Permalink
Extract CreateVector to testsupport
Browse files Browse the repository at this point in the history
  • Loading branch information
tygern committed May 2, 2024
1 parent cd81d79 commit d22576c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
4 changes: 2 additions & 2 deletions internal/analyzer/analyze_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

func TestAnalyzer_Analyze(t *testing.T) {
vector := createVector(0)
vector := testsupport.CreateVector(0)
endpoint, server := testsupport.StartTestServer(t, func(mux *http.ServeMux) {
testsupport.Handle(mux, "/embeddings", fmt.Sprintf(`{
"data": [
Expand All @@ -38,7 +38,7 @@ func TestAnalyzer_Analyze(t *testing.T) {
err := a.Analyze(context.Background())
assert.NoError(t, err)

chunk1, err := embeddingsGateway.FindSimilar(createVector(0))
chunk1, err := embeddingsGateway.FindSimilar(testsupport.CreateVector(0))
assert.NoError(t, err)
assert.Equal(t, analyzer.CitedChunkRecord{Content: "chunk1", Source: "https://example.com"}, chunk1)
}
Expand Down
16 changes: 5 additions & 11 deletions internal/analyzer/embeddings_gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestEmbeddingsGateway_UnprocessedIds(t *testing.T) {
testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
testDb.Execute("insert into chunks (id, data_id, content) values ('cccccccc-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
vector := createVector(0)
vector := testsupport.CreateVector(0)
testDb.Execute("insert into embeddings (chunk_id, embedding) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', $1)", pgvector.NewVector(vector))

ids, err := gateway.UnprocessedIds()
Expand All @@ -39,7 +39,7 @@ func TestEmbeddingsGateway_Save(t *testing.T) {
testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")

err := gateway.Save("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", createVector(0))
err := gateway.Save("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", testsupport.CreateVector(0))
assert.NoError(t, err)

chunkId, err := dbsupport.QueryOne(testDb.DB, "select chunk_id from embeddings", func(row *sql.Row, chunkId *string) error {
Expand All @@ -60,19 +60,13 @@ func TestEmbeddingsGateway_FindSimilar(t *testing.T) {
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
testDb.Execute("insert into chunks (id, data_id, content) values ('cccccccc-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','another chunk')")

err := gateway.Save("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", createVector(0))
err := gateway.Save("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", testsupport.CreateVector(0))
assert.NoError(t, err)
err = gateway.Save("cccccccc-2f3f-4bc9-8dba-ba397156cc16", createVector(1))
err = gateway.Save("cccccccc-2f3f-4bc9-8dba-ba397156cc16", testsupport.CreateVector(1))
assert.NoError(t, err)

similar, err := gateway.FindSimilar(createVector(1))
similar, err := gateway.FindSimilar(testsupport.CreateVector(1))
assert.NoError(t, err)

assert.Equal(t, analyzer.CitedChunkRecord{Content: "another chunk", Source: "https://example.com"}, similar)
}

func createVector(oneIndex int) []float32 {
embedding := make([]float32, 3072)
embedding[oneIndex] = 1
return embedding
}
7 changes: 7 additions & 0 deletions pkg/testsupport/vector_support.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package testsupport

func CreateVector(oneIndex int) []float32 {
embedding := make([]float32, 3072)
embedding[oneIndex] = 1
return embedding
}

0 comments on commit d22576c

Please sign in to comment.