Skip to content

Commit

Permalink
Add gateway tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tygern committed May 1, 2024
1 parent b8c9dbd commit 63bbf3d
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 20 deletions.
2 changes: 1 addition & 1 deletion internal/analyzer/analyze.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (a *Analyzer) Analyze(ctx context.Context) error {
slog.Info("Starting to analyze data")
defer slog.Info("Finished analyzing data")

ids, listErr := a.chunksGateway.UnprocessedIds()
ids, listErr := a.embeddingsGateway.UnprocessedIds()
if listErr != nil {
return fmt.Errorf("unable to list ids: %w", listErr)
}
Expand Down
9 changes: 9 additions & 0 deletions internal/analyzer/embeddings_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ func NewEmbeddingsGateway(db *sql.DB) *EmbeddingsGateway {
return &EmbeddingsGateway{db: db}
}

func (g *EmbeddingsGateway) UnprocessedIds() ([]string, error) {
return dbsupport.Query(
g.db,
`select chunks.id from chunks
left join public.embeddings e on chunks.id = e.chunk_id
where e.id is null`,
func(rows *sql.Rows, id *string) error { return rows.Scan(id) })
}

func (g *EmbeddingsGateway) Save(chunkId string, vector []float32) error {
_, err := g.db.Exec("insert into embeddings (chunk_id, embedding) values ($1, $2)", chunkId, pgvector.NewVector(vector))
return err
Expand Down
78 changes: 78 additions & 0 deletions internal/analyzer/embeddings_gateway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package analyzer_test

import (
"database/sql"
"github.com/initialcapacity/ai-starter/internal/analyzer"
"github.com/initialcapacity/ai-starter/pkg/dbsupport"
"github.com/initialcapacity/ai-starter/pkg/testsupport"
"github.com/pgvector/pgvector-go"
"github.com/stretchr/testify/assert"
"testing"
)

func TestEmbeddingsGateway_UnprocessedIds(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
testDb.ClearTables()

gateway := analyzer.NewEmbeddingsGateway(testDb.DB)

testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
testDb.Execute("insert into chunks (id, data_id, content) values ('cccccccc-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
vector := createVector(0)
testDb.Execute("insert into embeddings (chunk_id, embedding) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', $1)", pgvector.NewVector(vector))

ids, err := gateway.UnprocessedIds()
assert.NoError(t, err)

assert.Equal(t, []string{"cccccccc-2f3f-4bc9-8dba-ba397156cc16"}, ids)
}

func TestEmbeddingsGateway_Save(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
testDb.ClearTables()

gateway := analyzer.NewEmbeddingsGateway(testDb.DB)

testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")

err := gateway.Save("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", createVector(0))
assert.NoError(t, err)

chunkId, err := dbsupport.QueryOne(testDb.DB, "select chunk_id from embeddings", func(row *sql.Row, chunkId *string) error {
return row.Scan(chunkId)
})
assert.NoError(t, err)
assert.Equal(t, "bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", chunkId)
}

func TestEmbeddingsGateway_FindSimilar(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
testDb.ClearTables()

gateway := analyzer.NewEmbeddingsGateway(testDb.DB)

testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
testDb.Execute("insert into chunks (id, data_id, content) values ('cccccccc-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','another chunk')")

err := gateway.Save("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16", createVector(0))
assert.NoError(t, err)
err = gateway.Save("cccccccc-2f3f-4bc9-8dba-ba397156cc16", createVector(1))
assert.NoError(t, err)

similar, err := gateway.FindSimilar(createVector(1))
assert.NoError(t, err)

assert.Equal(t, analyzer.CitedChunkRecord{Content: "another chunk", Source: "https://example.com"}, similar)
}

func createVector(oneIndex int) []float32 {
embedding := make([]float32, 3072)
embedding[oneIndex] = 1
return embedding
}
9 changes: 0 additions & 9 deletions internal/collector/chunks_gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@ func NewChunksGateway(db *sql.DB) *ChunksGateway {
return &ChunksGateway{db: db}
}

func (g *ChunksGateway) UnprocessedIds() ([]string, error) {
return dbsupport.Query(
g.db,
`select chunks.id from chunks
left join public.embeddings e on chunks.id = e.chunk_id
where e.id is null`,
func(rows *sql.Rows, id *string) error { return rows.Scan(id) })
}

func (g *ChunksGateway) Get(id string) (ChunkRecord, error) {
return dbsupport.QueryOne(
g.db,
Expand Down
46 changes: 46 additions & 0 deletions internal/collector/chunks_gateway_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package collector_test

import (
"database/sql"
"github.com/initialcapacity/ai-starter/internal/collector"
"github.com/initialcapacity/ai-starter/pkg/dbsupport"
"github.com/initialcapacity/ai-starter/pkg/testsupport"
"github.com/stretchr/testify/assert"
"testing"
)

func TestChunksGateway_Save(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
testDb.ClearTables()

gateway := collector.NewChunksGateway(testDb.DB)

testDb.Execute("insert into data (id, source, content) values ('41345dc1-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
err := gateway.Save("41345dc1-2f3f-4bc9-8dba-ba397156cc16", "a chunk")
assert.NoError(t, err)

content, err := dbsupport.QueryOne(testDb.DB, "select content from chunks", func(row *sql.Row, content *string) error {
return row.Scan(content)
})
assert.NoError(t, err)
assert.Equal(t, "a chunk", content)
}

func TestChunksGateway_Get(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
testDb.ClearTables()

gateway := collector.NewChunksGateway(testDb.DB)

testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")

record, err := gateway.Get("bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16")
assert.NoError(t, err)
assert.Equal(t, collector.ChunkRecord{
DataId: "aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16",
Content: "a chunk",
}, record)
}
15 changes: 5 additions & 10 deletions internal/collector/chunks_service_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package collector_test

import (
"database/sql"
"github.com/initialcapacity/ai-starter/internal/collector"
"github.com/initialcapacity/ai-starter/pkg/dbsupport"
"github.com/initialcapacity/ai-starter/pkg/testsupport"
"github.com/stretchr/testify/assert"
"testing"
Expand All @@ -19,17 +21,10 @@ func TestChunksService_SaveChunks(t *testing.T) {
err := chunksService.SaveChunks("41345dc1-2f3f-4bc9-8dba-ba397156cc16", "some content")
assert.NoError(t, err)

ids, err := chunksGateway.UnprocessedIds()
content, err := dbsupport.Query(testDb.DB, "select content from chunks", func(rows *sql.Rows, content *string) error {
return rows.Scan(content)
})
assert.NoError(t, err)

var content []string
for _, id := range ids {
chunk, getErr := chunksGateway.Get(id)
assert.NoError(t, getErr)

content = append(content, chunk.Content)
}

testsupport.AssertContainsExactly(t, []string{"some c", "ontent"}, content)
}

Expand Down

0 comments on commit 63bbf3d

Please sign in to comment.