From d81d972f9b8726461210369f60552ef61c5fd6dd Mon Sep 17 00:00:00 2001
From: Eli Bendersky
Date: Tue, 9 Jan 2024 06:39:02 -0800
Subject: [PATCH] googleai: implement streaming and some generation options
 (#504)

Re #410
---
 llms/googleai/googleai_llm.go      | 106 ++++++++++++++++++++++-------
 llms/googleai/googleai_llm_test.go |  37 ++++++++++
 llms/googleai/googleai_option.go   |   5 ++
 3 files changed, 123 insertions(+), 25 deletions(-)

diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go
index 92243583a..63278802c 100644
--- a/llms/googleai/googleai_llm.go
+++ b/llms/googleai/googleai_llm.go
@@ -1,3 +1,7 @@
+// Package googleai implements a langchaingo provider for Google AI LLMs.
+// See https://ai.google.dev/ for more details and documentation.
+//
+//nolint:goerr113, lll
 package googleai
 
 import (
@@ -5,12 +9,13 @@ import (
 	"errors"
 	"fmt"
 	"io"
 	"net/http"
 	"strings"
 
 	"github.com/google/generative-ai-go/genai"
 	"github.com/tmc/langchaingo/llms"
 	"github.com/tmc/langchaingo/schema"
+	"google.golang.org/api/iterator"
 	"google.golang.org/api/option"
 )
@@ -26,7 +32,7 @@ var (
 	ErrNoContentInResponse    = errors.New("no content in generation response")
 	ErrUnknownPartInResponse  = errors.New("unknown part type in generation response")
 	ErrInvalidMimeType        = errors.New("invalid mime type on content")
-	ErrSystemRoleNotSupported = errors.New("system roles isn't supporeted yet")
+	ErrSystemRoleNotSupported = errors.New("system role isn't supported yet")
 )
 
 const (
@@ -57,26 +63,28 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) {
 }
 
 // GenerateContent calls the LLM with the provided parts.
-//
-//nolint:goerr113
-func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint:lll
+func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
 	opts := llms.CallOptions{
-		Model: g.opts.defaultModel,
+		Model:       g.opts.defaultModel,
+		MaxTokens:   int(g.opts.defaultMaxTokens),
+		Temperature: float64(g.opts.defaultTemperature),
 	}
 	for _, opt := range options {
 		opt(&opts)
 	}
 
 	model := g.client.GenerativeModel(opts.Model)
+	model.SetMaxOutputTokens(int32(opts.MaxTokens))
+	model.SetTemperature(float32(opts.Temperature))
 
 	if len(messages) == 1 {
 		theMessage := messages[0]
 		if theMessage.Role != schema.ChatMessageTypeHuman {
 			return nil, fmt.Errorf("got %v message role, want human", theMessage.Role)
 		}
-		return generateFromSingleMessage(ctx, model, theMessage.Parts)
+		return generateFromSingleMessage(ctx, model, theMessage.Parts, &opts)
 	}
 
-	return generateFromMessages(ctx, model, messages)
+	return generateFromMessages(ctx, model, messages, &opts)
 }
 
 // downloadImageData downloads the content from the given URL and returns it as
@@ -181,10 +189,8 @@ func convertParts(parts []llms.ContentPart) ([]genai.Part, error) {
 }
 
 // convertContent converts between a langchain MessageContent and genai content.
-//
-//nolint:goerr113
 func convertContent(content llms.MessageContent) (*genai.Content, error) {
 	parts, err := convertParts(content.Parts)
 	if err != nil {
 		return nil, err
 	}
@@ -213,25 +219,30 @@ func convertContent(content llms.MessageContent) (*genai.Content, error) {
 
 // generateFromSingleMessage generates content from the parts of a single
 // message.
-func generateFromSingleMessage(ctx context.Context, model *genai.GenerativeModel, parts []llms.ContentPart) (*llms.ContentResponse, error) { //nolint:lll
+func generateFromSingleMessage(ctx context.Context, model *genai.GenerativeModel, parts []llms.ContentPart, opts *llms.CallOptions) (*llms.ContentResponse, error) {
 	convertedParts, err := convertParts(parts)
 	if err != nil {
 		return nil, err
 	}
 
-	resp, err := model.GenerateContent(ctx, convertedParts...)
-	if err != nil {
-		return nil, err
-	}
+	if opts.StreamingFunc == nil {
+		// When no streaming is requested, just call GenerateContent and return
+		// the complete response with a list of candidates.
+		resp, err := model.GenerateContent(ctx, convertedParts...)
+		if err != nil {
+			return nil, err
+		}
 
-	if len(resp.Candidates) == 0 {
-		return nil, ErrNoContentInResponse
+		if len(resp.Candidates) == 0 {
+			return nil, ErrNoContentInResponse
+		}
+		return convertCandidates(resp.Candidates)
 	}
 
-	return convertCandidates(resp.Candidates)
+	iter := model.GenerateContentStream(ctx, convertedParts...)
+	return convertAndStreamFromIterator(ctx, iter, opts)
 }
 
-//nolint:goerr113
-func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, messages []llms.MessageContent) (*llms.ContentResponse, error) { //nolint:lll
+func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
 	history := make([]*genai.Content, 0, len(messages))
 	for _, mc := range messages {
 		content, err := convertContent(mc)
@@ -254,13 +265,59 @@ func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, mes
 	session := model.StartChat()
 	session.History = history
 
-	resp, err := session.SendMessage(ctx, reqContent.Parts...)
-	if err != nil {
-		return nil, err
+	if opts.StreamingFunc == nil {
+		resp, err := session.SendMessage(ctx, reqContent.Parts...)
+		if err != nil {
+			return nil, err
+		}
+
+		if len(resp.Candidates) == 0 {
+			return nil, ErrNoContentInResponse
+		}
+		return convertCandidates(resp.Candidates)
 	}
+
+	iter := session.SendMessageStream(ctx, reqContent.Parts...)
+	return convertAndStreamFromIterator(ctx, iter, opts)
+}
+
+// convertAndStreamFromIterator takes an iterator of GenerateContentResponse
+// and produces an llms.ContentResponse reply from it, while streaming the
+// resulting text into the opts-provided streaming function.
+// Note that this is tricky in the face of multiple candidates, so this code
+// assumes only a single candidate for now.
+func convertAndStreamFromIterator(ctx context.Context, iter *genai.GenerateContentResponseIterator, opts *llms.CallOptions) (*llms.ContentResponse, error) {
+	candidate := &genai.Candidate{
+		Content: &genai.Content{},
+	}
+DoStream:
+	for {
+		resp, err := iter.Next()
+		if errors.Is(err, iterator.Done) {
+			break
+		}
+		if err != nil {
+			return nil, err
+		}
 
-	if len(resp.Candidates) == 0 {
-		return nil, ErrNoContentInResponse
+		if len(resp.Candidates) != 1 {
+			return nil, fmt.Errorf("expect single candidate in stream mode; got %v", len(resp.Candidates))
+		}
+		respCandidate := resp.Candidates[0]
+		candidate.Content.Parts = append(candidate.Content.Parts, respCandidate.Content.Parts...)
+		candidate.Content.Role = respCandidate.Content.Role
+		candidate.FinishReason = respCandidate.FinishReason
+		candidate.SafetyRatings = respCandidate.SafetyRatings
+		candidate.CitationMetadata = respCandidate.CitationMetadata
+		candidate.TokenCount += respCandidate.TokenCount
+
+		for _, part := range respCandidate.Content.Parts {
+			if text, ok := part.(genai.Text); ok {
+				if opts.StreamingFunc(ctx, []byte(text)) != nil {
+					break DoStream
+				}
+			}
+		}
 	}
 
-	return convertCandidates(resp.Candidates)
+	return convertCandidates([]*genai.Candidate{candidate})
 }
diff --git a/llms/googleai/googleai_llm_test.go b/llms/googleai/googleai_llm_test.go
index 1b46d12a6..2c4739bce 100644
--- a/llms/googleai/googleai_llm_test.go
+++ b/llms/googleai/googleai_llm_test.go
@@ -48,6 +48,43 @@ func TestMultiContentText(t *testing.T) {
 	assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content))
 }
 
+func TestMultiContentTextStream(t *testing.T) {
+	t.Parallel()
+	llm := newClient(t)
+
+	parts := []llms.ContentPart{
+		llms.TextContent{Text: "I'm a pomeranian"},
+		llms.TextContent{Text: "Tell me more about my taxonomy"},
+	}
+	content := []llms.MessageContent{
+		{
+			Role:  schema.ChatMessageTypeHuman,
+			Parts: parts,
+		},
+	}
+
+	var chunks [][]byte
+	var sb strings.Builder
+	rsp, err := llm.GenerateContent(context.Background(), content,
+		llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
+			chunks = append(chunks, chunk)
+			sb.Write(chunk)
+			return nil
+		}))
+
+	require.NoError(t, err)
+
+	assert.NotEmpty(t, rsp.Choices)
+	// Check that the combined response contains what we expect.
+	c1 := rsp.Choices[0]
+	assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content))
+
+	// Check that multiple chunks were received and that they also contain
+	// words we expect.
+	assert.GreaterOrEqual(t, len(chunks), 2)
+	assert.Regexp(t, "dog|canid|canine", strings.ToLower(sb.String()))
+}
+
 func TestMultiContentTextChatSequence(t *testing.T) {
 	t.Parallel()
 	llm := newClient(t)
diff --git a/llms/googleai/googleai_option.go b/llms/googleai/googleai_option.go
index 6f799b0be..38e591db4 100644
--- a/llms/googleai/googleai_option.go
+++ b/llms/googleai/googleai_option.go
@@ -1,3 +1,4 @@
+//nolint:gomnd
 package googleai
 
 // options is a set of options for GoogleAI clients.
@@ -5,6 +6,8 @@ type options struct {
 	apiKey                string
 	defaultModel          string
 	defaultEmbeddingModel string
+	defaultMaxTokens      int32
+	defaultTemperature    float32
 }
 
 func defaultOptions() options {
@@ -12,6 +15,8 @@ func defaultOptions() options {
 		apiKey:                "",
 		defaultModel:          "gemini-pro",
 		defaultEmbeddingModel: "embedding-001",
+		defaultMaxTokens:      256,
+		defaultTemperature:    0.5,
 	}
 }
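
For reviewers, a minimal sketch of how the new streaming path is exercised from
client code, mirroring the test above. The WithAPIKey option and the
GENAI_API_KEY environment variable name are illustrative assumptions and not
part of this patch:

    package main

    import (
    	"context"
    	"fmt"
    	"log"
    	"os"

    	"github.com/tmc/langchaingo/llms"
    	"github.com/tmc/langchaingo/llms/googleai"
    	"github.com/tmc/langchaingo/schema"
    )

    func main() {
    	ctx := context.Background()

    	// Assumption: the client is configured via the package's existing
    	// WithAPIKey option; GENAI_API_KEY is an illustrative variable name.
    	llm, err := googleai.NewGoogleAI(ctx,
    		googleai.WithAPIKey(os.Getenv("GENAI_API_KEY")))
    	if err != nil {
    		log.Fatal(err)
    	}

    	content := []llms.MessageContent{
    		{
    			Role:  schema.ChatMessageTypeHuman,
    			Parts: []llms.ContentPart{llms.TextContent{Text: "Tell me about pomeranians"}},
    		},
    	}

    	// Supplying a streaming function routes the call through
    	// GenerateContentStream / SendMessageStream; each text chunk is
    	// passed to the callback as it arrives.
    	resp, err := llm.GenerateContent(ctx, content,
    		llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
    			fmt.Print(string(chunk))
    			return nil // a non-nil error here stops the stream
    		}))
    	if err != nil {
    		log.Fatal(err)
    	}

    	// The assembled single-candidate response is still returned at the end.
    	fmt.Println("\nfull reply length:", len(resp.Choices[0].Content))
    }

Note the design choice visible in convertAndStreamFromIterator: even when
streaming, the chunks are accumulated into one candidate and returned as a
regular ContentResponse, so callers that inspect rsp.Choices keep working
unchanged.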