From cc67e6f67b1ab8bd3b0763583ed73100b2cccc6a Mon Sep 17 00:00:00 2001 From: Tyson Gern Date: Mon, 3 Jun 2024 05:58:12 -0600 Subject: [PATCH] Extract query service --- internal/app/handlers.go | 3 +- internal/app/index.go | 31 ++-------- internal/app/query_service.go | 58 +++++++++++++++++++ internal/app/query_service_test.go | 92 ++++++++++++++++++++++++++++++ 4 files changed, 156 insertions(+), 28 deletions(-) create mode 100644 internal/app/query_service.go create mode 100644 internal/app/query_service_test.go diff --git a/internal/app/handlers.go b/internal/app/handlers.go index c6328c4..d93eb00 100644 --- a/internal/app/handlers.go +++ b/internal/app/handlers.go @@ -10,10 +10,11 @@ import ( func Handlers(aiClient ai.Client, db *sql.DB) func(mux *http.ServeMux) { embeddingsGateway := analyzer.NewEmbeddingsGateway(db) + queryService := NewQueryService(embeddingsGateway, aiClient) return func(mux *http.ServeMux) { mux.HandleFunc("GET /", Index()) - mux.HandleFunc("POST /", Query(aiClient, embeddingsGateway)) + mux.HandleFunc("POST /", Query(queryService)) mux.HandleFunc("GET /health", Health) static, _ := fs.Sub(Resources, "resources/static") diff --git a/internal/app/index.go b/internal/app/index.go index 2a15dac..9e28693 100644 --- a/internal/app/index.go +++ b/internal/app/index.go @@ -1,9 +1,6 @@ package app import ( - "fmt" - "github.com/initialcapacity/ai-starter/internal/analyzer" - "github.com/initialcapacity/ai-starter/pkg/ai" "github.com/initialcapacity/ai-starter/pkg/deferrable" "github.com/initialcapacity/ai-starter/pkg/websupport" "log/slog" @@ -22,7 +19,7 @@ type model struct { Source string } -func Query(aiClient ai.Client, embeddingsGateway *analyzer.EmbeddingsGateway) http.HandlerFunc { +func Query(queryService *QueryService) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { err := r.ParseForm() if err != nil { @@ -32,36 +29,16 @@ func Query(aiClient ai.Client, embeddingsGateway *analyzer.EmbeddingsGateway) ht } query := r.Form.Get("query") - embedding, err := aiClient.CreateEmbedding(r.Context(), query) + result, err := queryService.FetchResponse(r.Context(), query) if err != nil { - slog.Error("unable to create embedding", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - record, err := embeddingsGateway.FindSimilar(embedding) - if err != nil { - slog.Error("unable to find similar embedding", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - responseChannel, err := aiClient.GetChatCompletion(r.Context(), []ai.ChatMessage{ - {Role: ai.System, Content: "You are a reporter for a major world newspaper."}, - {Role: ai.System, Content: "Write your response as if you were writing a short, high-quality news article for your paper. Limit your response to one paragraph."}, - {Role: ai.System, Content: fmt.Sprintf("Use the following article for context: %s", record.Content)}, - {Role: ai.User, Content: query}, - }) - if err != nil { - slog.Error("unable fetch chat completion", err) w.WriteHeader(http.StatusInternalServerError) return } _ = websupport.Render(w, Resources, "response", model{ Query: query, - Response: deferrable.New(w, responseChannel), - Source: record.Source, + Response: deferrable.New(w, result.Response), + Source: result.Source, }) } } diff --git a/internal/app/query_service.go b/internal/app/query_service.go new file mode 100644 index 0000000..cb06f13 --- /dev/null +++ b/internal/app/query_service.go @@ -0,0 +1,58 @@ +package app + +import ( + "context" + "fmt" + "github.com/initialcapacity/ai-starter/internal/analyzer" + "github.com/initialcapacity/ai-starter/pkg/ai" + "log/slog" +) + +type QueryService struct { + embeddingsGateway *analyzer.EmbeddingsGateway + aiClient aiClient +} + +func NewQueryService(embeddingsGateway *analyzer.EmbeddingsGateway, aiClient aiClient) *QueryService { + return &QueryService{embeddingsGateway: embeddingsGateway, aiClient: aiClient} +} + +func (q *QueryService) FetchResponse(ctx context.Context, query string) (QueryResult, error) { + embedding, err := q.aiClient.CreateEmbedding(ctx, query) + if err != nil { + slog.Error("unable to create embedding", err) + return QueryResult{}, err + } + + record, err := q.embeddingsGateway.FindSimilar(embedding) + if err != nil { + slog.Error("unable to find similar embedding", err) + return QueryResult{}, err + } + + response, err := q.aiClient.GetChatCompletion(ctx, []ai.ChatMessage{ + {Role: ai.System, Content: "You are a reporter for a major world newspaper."}, + {Role: ai.System, Content: "Write your response as if you were writing a short, high-quality news article for your paper. Limit your response to one paragraph."}, + {Role: ai.System, Content: fmt.Sprintf("Use the following article for context: %s", record.Content)}, + {Role: ai.User, Content: query}, + }) + if err != nil { + slog.Error("unable fetch chat completion", err) + return QueryResult{}, err + } + + return QueryResult{ + Source: record.Source, + Response: response, + }, nil +} + +type QueryResult struct { + Response chan string + Source string +} + +type aiClient interface { + GetChatCompletion(ctx context.Context, messages []ai.ChatMessage) (chan string, error) + CreateEmbedding(ctx context.Context, text string) ([]float32, error) +} diff --git a/internal/app/query_service_test.go b/internal/app/query_service_test.go new file mode 100644 index 0000000..20c89d9 --- /dev/null +++ b/internal/app/query_service_test.go @@ -0,0 +1,92 @@ +package app_test + +import ( + "context" + "errors" + "github.com/initialcapacity/ai-starter/internal/analyzer" + "github.com/initialcapacity/ai-starter/internal/app" + "github.com/initialcapacity/ai-starter/pkg/ai" + "github.com/initialcapacity/ai-starter/pkg/testsupport" + "github.com/pgvector/pgvector-go" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestQueryService_FetchResponse(t *testing.T) { + testDb := testsupport.NewTestDb(t) + defer testDb.Close() + insertData(testDb) + + service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{}) + + result, err := service.FetchResponse(context.Background(), "Does this sound good?") + assert.NoError(t, err) + message := <-result.Response + + assert.Equal(t, "https://example.com", result.Source) + assert.Equal(t, "Sounds good", message) +} + +func TestQueryService_FetchResponse_EmbeddingError(t *testing.T) { + testDb := testsupport.NewTestDb(t) + defer testDb.Close() + insertData(testDb) + service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{embeddingError: errors.New("bad news")}) + + _, err := service.FetchResponse(context.Background(), "Does this sound good?") + + assert.EqualError(t, err, "bad news") +} + +func TestQueryService_FetchResponse_NoEmbeddings(t *testing.T) { + testDb := testsupport.NewTestDb(t) + defer testDb.Close() + service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{}) + + _, err := service.FetchResponse(context.Background(), "Does this sound good?") + + assert.EqualError(t, err, "sql: no rows in result set") +} + +func TestQueryService_FetchResponse_CompletionError(t *testing.T) { + testDb := testsupport.NewTestDb(t) + defer testDb.Close() + insertData(testDb) + service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{completionError: errors.New("bad news")}) + + _, err := service.FetchResponse(context.Background(), "Does this sound good?") + + assert.EqualError(t, err, "bad news") +} + +func insertData(testDb *testsupport.TestDb) { + testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')") + testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')") + testDb.Execute("insert into embeddings (chunk_id, embedding) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', $1)", pgvector.NewVector(testsupport.CreateVector(0))) + +} + +type fakeAi struct { + embeddingError error + completionError error +} + +func (f fakeAi) GetChatCompletion(_ context.Context, _ []ai.ChatMessage) (chan string, error) { + if f.embeddingError != nil { + return nil, f.embeddingError + } + + response := make(chan string) + go func() { + response <- "Sounds good" + }() + return response, nil +} + +func (f fakeAi) CreateEmbedding(_ context.Context, _ string) ([]float32, error) { + if f.completionError != nil { + return nil, f.completionError + } + + return testsupport.CreateVector(0), nil +}