Skip to content

Commit a3a04d8

Browse files
committed
fix gemini provider
1 parent 792e2b1 commit a3a04d8

File tree

1 file changed

+38
-71
lines changed

1 file changed

+38
-71
lines changed

internal/llm/provider/gemini.go

+38 -71
Original file line number | Diff line number | Diff line change
@@ -54,19 +54,6 @@ func newGeminiClient(opts providerClientOptions) GeminiClient {
5454

5555
func (g *geminiClient) convertMessages(messages []message.Message) []*genai.Content {
5656
var history []*genai.Content
57-
58-
// Add system message first
59-
history = append(history, &genai.Content{
60-
Parts: []genai.Part{genai.Text(g.providerOptions.systemMessage)},
61-
Role: "user",
62-
})
63-
64-
// Add a system response to acknowledge the system message
65-
history = append(history, &genai.Content{
66-
Parts: []genai.Part{genai.Text("I'll help you with that.")},
67-
Role: "model",
68-
})
69-
7057
for _, msg := range messages {
7158
switch msg.Role {
7259
case message.User:
@@ -154,14 +141,11 @@ func (g *geminiClient) convertTools(tools []tools.BaseTool) []*genai.Tool {
154141
}
155142

156143
func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishReason {
157-
reasonStr := reason.String()
158144
switch {
159-
case reasonStr == "STOP":
145+
case reason == genai.FinishReasonStop:
160146
return message.FinishReasonEndTurn
161-
case reasonStr == "MAX_TOKENS":
147+
case reason == genai.FinishReasonMaxTokens:
162148
return message.FinishReasonMaxTokens
163-
case strings.Contains(reasonStr, "FUNCTION") || strings.Contains(reasonStr, "TOOL"):
164-
return message.FinishReasonToolUse
165149
default:
166150
return message.FinishReasonUnknown
167151
}
@@ -170,7 +154,11 @@ func (g *geminiClient) finishReason(reason genai.FinishReason) message.FinishRea
170154
func (g *geminiClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
171155
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
172156
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
173-
157+
model.SystemInstruction = &genai.Content{
158+
Parts: []genai.Part{
159+
genai.Text(g.providerOptions.systemMessage),
160+
},
161+
}
174162
// Convert tools
175163
if len(tools) > 0 {
176164
model.Tools = g.convertTools(tools)
@@ -188,19 +176,13 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
188176
attempts := 0
189177
for {
190178
attempts++
179+
var toolCalls []message.ToolCall
191180
chat := model.StartChat()
192181
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
193182

194183
lastMsg := geminiMessages[len(geminiMessages)-1]
195-
var lastText string
196-
for _, part := range lastMsg.Parts {
197-
if text, ok := part.(genai.Text); ok {
198-
lastText = string(text)
199-
break
200-
}
201-
}
202184

203-
resp, err := chat.SendMessage(ctx, genai.Text(lastText))
185+
resp, err := chat.SendMessage(ctx, lastMsg.Parts...)
204186
// If there is an error we are going to see if we can retry the call
205187
if err != nil {
206188
retry, after, retryErr := g.shouldRetry(attempts, err)
@@ -220,7 +202,6 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
220202
}
221203

222204
content := ""
223-
var toolCalls []message.ToolCall
224205

225206
if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
226207
for _, part := range resp.Candidates[0].Content.Parts {
@@ -231,28 +212,37 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
231212
id := "call_" + uuid.New().String()
232213
args, _ := json.Marshal(p.Args)
233214
toolCalls = append(toolCalls, message.ToolCall{
234-
ID: id,
235-
Name: p.Name,
236-
Input: string(args),
237-
Type: "function",
215+
ID: id,
216+
Name: p.Name,
217+
Input: string(args),
218+
Type: "function",
219+
Finished: true,
238220
})
239221
}
240222
}
241223
}
224+
finishReason := g.finishReason(resp.Candidates[0].FinishReason)
225+
if len(toolCalls) > 0 {
226+
finishReason = message.FinishReasonToolUse
227+
}
242228

243229
return &ProviderResponse{
244230
Content: content,
245231
ToolCalls: toolCalls,
246232
Usage: g.usage(resp),
247-
FinishReason: g.finishReason(resp.Candidates[0].FinishReason),
233+
FinishReason: finishReason,
248234
}, nil
249235
}
250236
}
251237

252238
func (g *geminiClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
253239
model := g.client.GenerativeModel(g.providerOptions.model.APIModel)
254240
model.SetMaxOutputTokens(int32(g.providerOptions.maxTokens))
255-
241+
model.SystemInstruction = &genai.Content{
242+
Parts: []genai.Part{
243+
genai.Text(g.providerOptions.systemMessage),
244+
},
245+
}
256246
// Convert tools
257247
if len(tools) > 0 {
258248
model.Tools = g.convertTools(tools)
@@ -276,18 +266,10 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
276266
for {
277267
attempts++
278268
chat := model.StartChat()
279-
chat.History = geminiMessages[:len(geminiMessages)-1] // All but last message
280-
269+
chat.History = geminiMessages[:len(geminiMessages)-1]
281270
lastMsg := geminiMessages[len(geminiMessages)-1]
282-
var lastText string
283-
for _, part := range lastMsg.Parts {
284-
if text, ok := part.(genai.Text); ok {
285-
lastText = string(text)
286-
break
287-
}
288-
}
289271

290-
iter := chat.SendMessageStream(ctx, genai.Text(lastText))
272+
iter := chat.SendMessageStream(ctx, lastMsg.Parts...)
291273

292274
currentContent := ""
293275
toolCalls := []message.ToolCall{}
@@ -330,23 +312,23 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
330312
for _, part := range resp.Candidates[0].Content.Parts {
331313
switch p := part.(type) {
332314
case genai.Text:
333-
newText := string(p)
334-
delta := newText[len(currentContent):]
315+
delta := string(p)
335316
if delta != "" {
336317
eventChan <- ProviderEvent{
337318
Type: EventContentDelta,
338319
Content: delta,
339320
}
340-
currentContent = newText
321+
currentContent += delta
341322
}
342323
case genai.FunctionCall:
343324
id := "call_" + uuid.New().String()
344325
args, _ := json.Marshal(p.Args)
345326
newCall := message.ToolCall{
346-
ID: id,
347-
Name: p.Name,
348-
Input: string(args),
349-
Type: "function",
327+
ID: id,
328+
Name: p.Name,
329+
Input: string(args),
330+
Type: "function",
331+
Finished: true,
350332
}
351333

352334
isNew := true
@@ -368,37 +350,22 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
368350
eventChan <- ProviderEvent{Type: EventContentStop}
369351

370352
if finalResp != nil {
353+
finishReason := g.finishReason(finalResp.Candidates[0].FinishReason)
354+
if len(toolCalls) > 0 {
355+
finishReason = message.FinishReasonToolUse
356+
}
371357
eventChan <- ProviderEvent{
372358
Type: EventComplete,
373359
Response: &ProviderResponse{
374360
Content: currentContent,
375361
ToolCalls: toolCalls,
376362
Usage: g.usage(finalResp),
377-
FinishReason: g.finishReason(finalResp.Candidates[0].FinishReason),
363+
FinishReason: finishReason,
378364
},
379365
}
380366
return
381367
}
382368

383-
// If we get here, we need to retry
384-
if attempts > maxRetries {
385-
eventChan <- ProviderEvent{
386-
Type: EventError,
387-
Error: fmt.Errorf("maximum retry attempts reached: %d retries", maxRetries),
388-
}
389-
return
390-
}
391-
392-
// Wait before retrying
393-
select {
394-
case <-ctx.Done():
395-
if ctx.Err() != nil {
396-
eventChan <- ProviderEvent{Type: EventError, Error: ctx.Err()}
397-
}
398-
return
399-
case <-time.After(time.Duration(2000*(1<<(attempts-1))) * time.Millisecond):
400-
continue
401-
}
402369
}
403370
}()
404371

0 commit comments

Comments
 (0)