Skip to content

Commit

Permalink
Extract query service
Browse files Browse the repository at this point in the history
  • Loading branch information
tygern committed Jun 3, 2024
1 parent f5166b6 commit cc67e6f
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 28 deletions.
3 changes: 2 additions & 1 deletion internal/app/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ import (

func Handlers(aiClient ai.Client, db *sql.DB) func(mux *http.ServeMux) {
embeddingsGateway := analyzer.NewEmbeddingsGateway(db)
queryService := NewQueryService(embeddingsGateway, aiClient)

return func(mux *http.ServeMux) {
mux.HandleFunc("GET /", Index())
mux.HandleFunc("POST /", Query(aiClient, embeddingsGateway))
mux.HandleFunc("POST /", Query(queryService))
mux.HandleFunc("GET /health", Health)

static, _ := fs.Sub(Resources, "resources/static")
Expand Down
31 changes: 4 additions & 27 deletions internal/app/index.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package app

import (
"fmt"
"github.com/initialcapacity/ai-starter/internal/analyzer"
"github.com/initialcapacity/ai-starter/pkg/ai"
"github.com/initialcapacity/ai-starter/pkg/deferrable"
"github.com/initialcapacity/ai-starter/pkg/websupport"
"log/slog"
Expand All @@ -22,7 +19,7 @@ type model struct {
Source string
}

func Query(aiClient ai.Client, embeddingsGateway *analyzer.EmbeddingsGateway) http.HandlerFunc {
func Query(queryService *QueryService) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := r.ParseForm()
if err != nil {
Expand All @@ -32,36 +29,16 @@ func Query(aiClient ai.Client, embeddingsGateway *analyzer.EmbeddingsGateway) ht
}

query := r.Form.Get("query")
embedding, err := aiClient.CreateEmbedding(r.Context(), query)
result, err := queryService.FetchResponse(r.Context(), query)
if err != nil {
slog.Error("unable to create embedding", err)
w.WriteHeader(http.StatusInternalServerError)
return
}

record, err := embeddingsGateway.FindSimilar(embedding)
if err != nil {
slog.Error("unable to find similar embedding", err)
w.WriteHeader(http.StatusInternalServerError)
return
}

responseChannel, err := aiClient.GetChatCompletion(r.Context(), []ai.ChatMessage{
{Role: ai.System, Content: "You are a reporter for a major world newspaper."},
{Role: ai.System, Content: "Write your response as if you were writing a short, high-quality news article for your paper. Limit your response to one paragraph."},
{Role: ai.System, Content: fmt.Sprintf("Use the following article for context: %s", record.Content)},
{Role: ai.User, Content: query},
})
if err != nil {
slog.Error("unable fetch chat completion", err)
w.WriteHeader(http.StatusInternalServerError)
return
}

_ = websupport.Render(w, Resources, "response", model{
Query: query,
Response: deferrable.New(w, responseChannel),
Source: record.Source,
Response: deferrable.New(w, result.Response),
Source: result.Source,
})
}
}
58 changes: 58 additions & 0 deletions internal/app/query_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package app

import (
"context"
"fmt"
"github.com/initialcapacity/ai-starter/internal/analyzer"
"github.com/initialcapacity/ai-starter/pkg/ai"
"log/slog"
)

type QueryService struct {
embeddingsGateway *analyzer.EmbeddingsGateway
aiClient aiClient
}

func NewQueryService(embeddingsGateway *analyzer.EmbeddingsGateway, aiClient aiClient) *QueryService {
return &QueryService{embeddingsGateway: embeddingsGateway, aiClient: aiClient}
}

func (q *QueryService) FetchResponse(ctx context.Context, query string) (QueryResult, error) {
embedding, err := q.aiClient.CreateEmbedding(ctx, query)
if err != nil {
slog.Error("unable to create embedding", err)
return QueryResult{}, err
}

record, err := q.embeddingsGateway.FindSimilar(embedding)
if err != nil {
slog.Error("unable to find similar embedding", err)
return QueryResult{}, err
}

response, err := q.aiClient.GetChatCompletion(ctx, []ai.ChatMessage{
{Role: ai.System, Content: "You are a reporter for a major world newspaper."},
{Role: ai.System, Content: "Write your response as if you were writing a short, high-quality news article for your paper. Limit your response to one paragraph."},
{Role: ai.System, Content: fmt.Sprintf("Use the following article for context: %s", record.Content)},
{Role: ai.User, Content: query},
})
if err != nil {
slog.Error("unable fetch chat completion", err)
return QueryResult{}, err
}

return QueryResult{
Source: record.Source,
Response: response,
}, nil
}

type QueryResult struct {
Response chan string
Source string
}

type aiClient interface {
GetChatCompletion(ctx context.Context, messages []ai.ChatMessage) (chan string, error)
CreateEmbedding(ctx context.Context, text string) ([]float32, error)
}
92 changes: 92 additions & 0 deletions internal/app/query_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package app_test

import (
"context"
"errors"
"github.com/initialcapacity/ai-starter/internal/analyzer"
"github.com/initialcapacity/ai-starter/internal/app"
"github.com/initialcapacity/ai-starter/pkg/ai"
"github.com/initialcapacity/ai-starter/pkg/testsupport"
"github.com/pgvector/pgvector-go"
"github.com/stretchr/testify/assert"
"testing"
)

func TestQueryService_FetchResponse(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
insertData(testDb)

service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{})

result, err := service.FetchResponse(context.Background(), "Does this sound good?")
assert.NoError(t, err)
message := <-result.Response

assert.Equal(t, "https://example.com", result.Source)
assert.Equal(t, "Sounds good", message)
}

func TestQueryService_FetchResponse_EmbeddingError(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
insertData(testDb)
service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{embeddingError: errors.New("bad news")})

_, err := service.FetchResponse(context.Background(), "Does this sound good?")

assert.EqualError(t, err, "bad news")
}

func TestQueryService_FetchResponse_NoEmbeddings(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{})

_, err := service.FetchResponse(context.Background(), "Does this sound good?")

assert.EqualError(t, err, "sql: no rows in result set")
}

func TestQueryService_FetchResponse_CompletionError(t *testing.T) {
testDb := testsupport.NewTestDb(t)
defer testDb.Close()
insertData(testDb)
service := app.NewQueryService(analyzer.NewEmbeddingsGateway(testDb.DB), fakeAi{completionError: errors.New("bad news")})

_, err := service.FetchResponse(context.Background(), "Does this sound good?")

assert.EqualError(t, err, "bad news")
}

func insertData(testDb *testsupport.TestDb) {
testDb.Execute("insert into data (id, source, content) values ('aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16', 'https://example.com', 'some content')")
testDb.Execute("insert into chunks (id, data_id, content) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', 'aaaaaaaa-2f3f-4bc9-8dba-ba397156cc16','a chunk')")
testDb.Execute("insert into embeddings (chunk_id, embedding) values ('bbbbbbbb-2f3f-4bc9-8dba-ba397156cc16', $1)", pgvector.NewVector(testsupport.CreateVector(0)))

}

type fakeAi struct {
embeddingError error
completionError error
}

func (f fakeAi) GetChatCompletion(_ context.Context, _ []ai.ChatMessage) (chan string, error) {
if f.embeddingError != nil {
return nil, f.embeddingError
}

response := make(chan string)
go func() {
response <- "Sounds good"
}()
return response, nil
}

func (f fakeAi) CreateEmbedding(_ context.Context, _ string) ([]float32, error) {
if f.completionError != nil {
return nil, f.completionError
}

return testsupport.CreateVector(0), nil
}

0 comments on commit cc67e6f

Please sign in to comment.