googleai: implement streaming and some generation options (tmc#504)
eliben authored Jan 9, 2024
1 parent d098fea commit d81d972
Showing 3 changed files with 124 additions and 25 deletions.
107 changes: 82 additions & 25 deletions llms/googleai/googleai_llm.go
@@ -1,16 +1,22 @@
// Package googleai implements a langchaingo provider for Google AI LLMs.
// See https://ai.google.dev/ for more details and documentation.
//
//nolint:goerr113, lll
package googleai

import (
"context"
"errors"
"fmt"
"io"
"log"
"net/http"
"strings"

"github.com/google/generative-ai-go/genai"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"google.golang.org/api/iterator"
"google.golang.org/api/option"
)

@@ -26,7 +32,7 @@ var (
ErrNoContentInResponse = errors.New("no content in generation response")
ErrUnknownPartInResponse = errors.New("unknown part type in generation response")
ErrInvalidMimeType = errors.New("invalid mime type on content")
ErrSystemRoleNotSupported = errors.New("system roles isn't supporeted yet")
ErrSystemRoleNotSupported = errors.New("system role isn't supported yet")
)

const (
@@ -57,26 +63,28 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) {
}

// GenerateContent calls the LLM with the provided parts.
//
//nolint:goerr113
func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint:lll
func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
opts := llms.CallOptions{
Model: g.opts.defaultModel,
Model: g.opts.defaultModel,
MaxTokens: int(g.opts.defaultMaxTokens),
Temperature: float64(g.opts.defaultTemperature),
}
for _, opt := range options {
opt(&opts)
}

model := g.client.GenerativeModel(opts.Model)
model.SetMaxOutputTokens(int32(opts.MaxTokens))
model.SetTemperature(float32(opts.Temperature))

if len(messages) == 1 {
theMessage := messages[0]
if theMessage.Role != schema.ChatMessageTypeHuman {
return nil, fmt.Errorf("got %v message role, want human", theMessage.Role)
}
return generateFromSingleMessage(ctx, model, theMessage.Parts)
return generateFromSingleMessage(ctx, model, theMessage.Parts, &opts)
}
return generateFromMessages(ctx, model, messages)
return generateFromMessages(ctx, model, messages, &opts)
}
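
For context, a minimal caller-side sketch of how these options flow in. llms.WithModel, llms.WithMaxTokens, and llms.WithTemperature are langchaingo's standard call options; the values here are arbitrary examples.

	// Caller-side sketch (values are arbitrary examples):
	resp, err := llm.GenerateContent(ctx, content,
		llms.WithModel("gemini-pro"), // overrides opts.defaultModel
		llms.WithMaxTokens(512),      // applied via model.SetMaxOutputTokens
		llms.WithTemperature(0.2),    // applied via model.SetTemperature
	)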

// downloadImageData downloads the content from the given URL and returns it as
@@ -181,8 +189,6 @@ func convertParts(parts []llms.ContentPart) ([]genai.Part, error) {
}

// convertContent converts between a langchain MessageContent and genai content.
//
//nolint:goerr113
func convertContent(content llms.MessageContent) (*genai.Content, error) {
parts, err := convertParts(content.Parts)
if err != nil {
@@ -213,25 +219,30 @@ func convertContent(content llms.MessageContent) (*genai.Content, error) {

// generateFromSingleMessage generates content from the parts of a single
// message.
func generateFromSingleMessage(ctx context.Context, model *genai.GenerativeModel, parts []llms.ContentPart) (*llms.ContentResponse, error) { //nolint:lll
func generateFromSingleMessage(ctx context.Context, model *genai.GenerativeModel, parts []llms.ContentPart, opts *llms.CallOptions) (*llms.ContentResponse, error) {
convertedParts, err := convertParts(parts)
if err != nil {
return nil, err
}

resp, err := model.GenerateContent(ctx, convertedParts...)
if err != nil {
return nil, err
}
if opts.StreamingFunc == nil {
// When no streaming is requested, just call GenerateContent and return
// the complete response with a list of candidates.
resp, err := model.GenerateContent(ctx, convertedParts...)
if err != nil {
return nil, err
}

if len(resp.Candidates) == 0 {
return nil, ErrNoContentInResponse
if len(resp.Candidates) == 0 {
return nil, ErrNoContentInResponse
}
return convertCandidates(resp.Candidates)
}
return convertCandidates(resp.Candidates)
iter := model.GenerateContentStream(ctx, convertedParts...)
return convertAndStreamFromIterator(ctx, iter, opts)
}
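
convertCandidates is called here but sits outside the hunks shown above. A rough sketch of what such a conversion could look like, assuming llms.ContentChoice carries Content and StopReason fields; the name and body are illustrative, not the file's actual implementation:

	// Illustrative sketch only; the real convertCandidates may differ.
	func convertCandidatesSketch(candidates []*genai.Candidate) (*llms.ContentResponse, error) {
		var choices []*llms.ContentChoice
		for _, candidate := range candidates {
			var sb strings.Builder
			for _, part := range candidate.Content.Parts {
				text, ok := part.(genai.Text)
				if !ok {
					return nil, ErrUnknownPartInResponse
				}
				sb.WriteString(string(text))
			}
			choices = append(choices, &llms.ContentChoice{
				Content:    sb.String(),
				StopReason: candidate.FinishReason.String(),
			})
		}
		return &llms.ContentResponse{Choices: choices}, nil
	}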

//nolint:goerr113
func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, messages []llms.MessageContent) (*llms.ContentResponse, error) { //nolint:lll
func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
history := make([]*genai.Content, 0, len(messages))
for _, mc := range messages {
content, err := convertContent(mc)
@@ -254,13 +265,59 @@ func generateFromMessages(ctx context.Context, model *genai.GenerativeModel, mes
session := model.StartChat()
session.History = history

resp, err := session.SendMessage(ctx, reqContent.Parts...)
if err != nil {
return nil, err
if opts.StreamingFunc == nil {
resp, err := session.SendMessage(ctx, reqContent.Parts...)
if err != nil {
return nil, err
}

if len(resp.Candidates) == 0 {
return nil, ErrNoContentInResponse
}
return convertCandidates(resp.Candidates)
}
iter := session.SendMessageStream(ctx, reqContent.Parts...)
return convertAndStreamFromIterator(ctx, iter, opts)
}
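
A caller-side sketch of this multi-message path, assuming the usual langchaingo message types; the conversation itself is just an example:

	// Hypothetical multi-turn call: all messages but the last become the
	// chat history, and the last message is the one sent to the model.
	msgs := []llms.MessageContent{
		{Role: schema.ChatMessageTypeHuman, Parts: []llms.ContentPart{llms.TextContent{Text: "Name a small dog breed"}}},
		{Role: schema.ChatMessageTypeAI, Parts: []llms.ContentPart{llms.TextContent{Text: "The Pomeranian is a small breed."}}},
		{Role: schema.ChatMessageTypeHuman, Parts: []llms.ContentPart{llms.TextContent{Text: "Tell me more about it"}}},
	}
	resp, err := llm.GenerateContent(context.Background(), msgs)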

// convertAndStreamFromIterator takes an iterator of GenerateContentResponse
// and produces a llms.ContentResponse reply from it, while streaming the
// resulting text into the opts-provided streaming function.
// Note that streaming is tricky in the face of multiple candidates, so this
// code assumes only a single candidate for now.
func convertAndStreamFromIterator(ctx context.Context, iter *genai.GenerateContentResponseIterator, opts *llms.CallOptions) (*llms.ContentResponse, error) {
candidate := &genai.Candidate{
Content: &genai.Content{},
}
DoStream:
for {
resp, err := iter.Next()
if errors.Is(err, iterator.Done) {
break
}
if err != nil {
log.Fatal(err)
}

if len(resp.Candidates) == 0 {
return nil, ErrNoContentInResponse
if len(resp.Candidates) != 1 {
return nil, fmt.Errorf("expect single candidate in stream mode; got %v", len(resp.Candidates))
}
respCandidate := resp.Candidates[0]
candidate.Content.Parts = append(candidate.Content.Parts, respCandidate.Content.Parts...)
candidate.Content.Role = respCandidate.Content.Role
candidate.FinishReason = respCandidate.FinishReason
candidate.SafetyRatings = respCandidate.SafetyRatings
candidate.CitationMetadata = respCandidate.CitationMetadata
candidate.TokenCount += respCandidate.TokenCount

for _, part := range respCandidate.Content.Parts {
if text, ok := part.(genai.Text); ok {
if opts.StreamingFunc(ctx, []byte(text)) != nil {
break DoStream
}
}
}
}
return convertCandidates(resp.Candidates)

return convertCandidates([]*genai.Candidate{candidate})
}
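
Because the loop above breaks out of DoStream as soon as StreamingFunc returns a non-nil error, a caller can cut a stream short. A sketch with an arbitrary byte budget; note that the loop swallows the callback's error, so the call still returns the partial response:

	// Hypothetical early-stop callback: stop streaming after ~1KB of text.
	var received int
	resp, err := llm.GenerateContent(ctx, content,
		llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
			received += len(chunk)
			if received > 1024 {
				return errors.New("stop streaming") // exits the DoStream loop
			}
			return nil
		}))
	// err is nil in the early-stop case; resp holds the text streamed so far.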
37 changes: 37 additions & 0 deletions llms/googleai/googleai_llm_test.go
@@ -48,6 +48,43 @@ func TestMultiContentText(t *testing.T) {
assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content))
}

func TestMultiContentTextStream(t *testing.T) {
t.Parallel()
llm := newClient(t)

parts := []llms.ContentPart{
llms.TextContent{Text: "I'm a pomeranian"},
llms.TextContent{Text: "Tell me more about my taxonomy"},
}
content := []llms.MessageContent{
{
Role: schema.ChatMessageTypeHuman,
Parts: parts,
},
}

var chunks [][]byte
var sb strings.Builder
rsp, err := llm.GenerateContent(context.Background(), content,
llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {
chunks = append(chunks, chunk)
sb.Write(chunk)
return nil
}))

require.NoError(t, err)

assert.NotEmpty(t, rsp.Choices)
// Check that the combined response contains what we expect
c1 := rsp.Choices[0]
assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content))

// Check that multiple chunks were received and they also have words
// we expect.
assert.GreaterOrEqual(t, len(chunks), 2)
assert.Regexp(t, "dog|canid|canine", sb.String())
}

func TestMultiContentTextChatSequence(t *testing.T) {
t.Parallel()
llm := newClient(t)
5 changes: 5 additions & 0 deletions llms/googleai/googleai_option.go
@@ -1,17 +1,22 @@
//nolint:gomnd
package googleai

// options is a set of options for GoogleAI clients.
type options struct {
apiKey string
defaultModel string
defaultEmbeddingModel string
defaultMaxTokens int32
defaultTemperature float32
}

func defaultOptions() options {
return options{
apiKey: "",
defaultModel: "gemini-pro",
defaultEmbeddingModel: "embedding-001",
defaultMaxTokens: 256,
defaultTemperature: 0.5,
}
}
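
The Option setters for the two new fields fall outside this hunk. A hedged sketch of plausible shapes, assuming Option is a func(*options); both names are assumptions, not confirmed by the diff:

	// Hypothetical setters; the commit's actual names are not visible here.
	func WithDefaultMaxTokens(maxTokens int) Option {
		return func(opts *options) {
			opts.defaultMaxTokens = int32(maxTokens)
		}
	}

	func WithDefaultTemperature(temperature float64) Option {
		return func(opts *options) {
			opts.defaultTemperature = float32(temperature)
		}
	}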

