Skip to content

Commit 15e772f

Browse files
committed
Updated mistral
1 parent c6feba6 commit 15e772f

File tree

8 files changed

+246
-94
lines changed

8 files changed

+246
-94
lines changed

Diff for: pkg/mistral/chat_completion.go

+8-9
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,8 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
7878
}
7979

8080
// Always append the first message of each completion
81-
for _, completion := range context.(*session).seq {
82-
if completion.Num() == 0 {
83-
continue
84-
}
85-
messages = append(messages, completion.Message(0))
81+
for _, message := range context.(*session).seq {
82+
messages = append(messages, message)
8683
}
8784

8885
// Request
@@ -179,17 +176,19 @@ func appendCompletion(response *Response, c *Completion) {
179176
response.Completions = append(response.Completions, Completion{
180177
Index: c.Index,
181178
Message: &Message{
182-
Role: c.Delta.Role,
183-
Content: "",
179+
RoleContent: RoleContent{
180+
Role: c.Delta.Role(),
181+
Content: "",
182+
},
184183
},
185184
})
186185
}
187186
// Add the completion delta
188187
if c.Reason != "" {
189188
response.Completions[c.Index].Reason = c.Reason
190189
}
191-
if c.Delta.Role != "" {
192-
response.Completions[c.Index].Message.Role = c.Delta.Role
190+
if role := c.Delta.Role(); role != "" {
191+
response.Completions[c.Index].Message.RoleContent.Role = role
193192
}
194193

195194
// TODO: We only allow deltas which are strings at the moment...

Diff for: pkg/mistral/chat_completion_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package mistral_test
22

33
import (
44
"context"
5-
"errors"
5+
"fmt"
66
"os"
77
"strings"
88
"testing"
@@ -243,6 +243,6 @@ func (weather) Description() string {
243243
return "Get the weather for a city"
244244
}
245245

246-
func (weather) Run(ctx context.Context) (any, error) {
247-
return nil, errors.New("I couldn't retrieve the weather for that city")
246+
func (w weather) Run(ctx context.Context) (any, error) {
247+
return fmt.Sprintf("The weather in %q is sunny and warm", w.City), nil
248248
}

Diff for: pkg/mistral/message.go

+75-40
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package mistral
22

33
import (
4+
"encoding/json"
5+
46
"github.com/mutablelogic/go-llm"
7+
"github.com/mutablelogic/go-llm/pkg/tool"
58
)
69

710
///////////////////////////////////////////////////////////////////////////////
@@ -12,6 +15,21 @@ type Completions []Completion
1215

1316
var _ llm.Completion = Completions{}
1417

18+
// Message with text or object content
19+
type Message struct {
20+
RoleContent
21+
ToolCallArray `json:"tool_calls,omitempty"`
22+
}
23+
24+
type RoleContent struct {
25+
Role string `json:"role,omitempty"` // assistant, user, tool, system
26+
Id string `json:"tool_call_id,omitempty"` // tool call - when role is tool
27+
Name string `json:"name,omitempty"` // function name - when role is tool
28+
Content any `json:"content,omitempty"` // string or array of text, reference, image_url
29+
}
30+
31+
var _ llm.Completion = (*Message)(nil)
32+
1533
// Completion Variation
1634
type Completion struct {
1735
Index uint64 `json:"index"`
@@ -20,23 +38,15 @@ type Completion struct {
2038
Reason string `json:"finish_reason,omitempty"`
2139
}
2240

23-
// Message with text or object content
24-
type Message struct {
25-
Role string `json:"role,omitempty"` // assistant, user, tool, system
26-
Prefix bool `json:"prefix,omitempty"`
27-
Content any `json:"content,omitempty"`
28-
ToolCalls `json:"tool_calls,omitempty"`
29-
}
30-
3141
type Content struct {
32-
Type string `json:"type"` // text, reference, image_url
42+
Type string `json:"type,omitempty"` // text, reference, image_url
3343
*Text `json:"text,omitempty"` // text content
3444
*Prediction `json:"content,omitempty"` // prediction
3545
*Image `json:"image_url,omitempty"` // image_url
3646
}
3747

3848
// A set of tool calls
39-
type ToolCalls []ToolCall
49+
type ToolCallArray []ToolCall
4050

4151
// text content
4252
type Text string
@@ -78,62 +88,87 @@ func NewImageAttachment(a *llm.Attachment) *Content {
7888
}
7989

8090
///////////////////////////////////////////////////////////////////////////////
81-
// PUBLIC METHODS
91+
// PUBLIC METHODS - MESSAGE
8292

83-
// Return the number of completions
84-
func (c Completions) Num() int {
85-
return len(c)
93+
func (m Message) Num() int {
94+
return 1
8695
}
8796

88-
// Return the role of the completion
89-
func (c Completions) Role() string {
90-
// The role should be the same for all completions, let's use the first one
91-
if len(c) == 0 {
92-
return ""
93-
}
94-
return c[0].Message.Role
97+
func (m Message) Role() string {
98+
return m.RoleContent.Role
9599
}
96100

97-
// Return the text content for a specific completion
98-
func (c Completions) Text(index int) string {
99-
if index < 0 || index >= len(c) {
101+
func (m Message) Text(index int) string {
102+
if index != 0 {
100103
return ""
101104
}
102-
completion := c[index].Message
103-
if text, ok := completion.Content.(string); ok {
105+
// If content is text, return it
106+
if text, ok := m.Content.(string); ok {
104107
return text
105108
}
106-
// TODO: Will the text be in other forms?
109+
// For other kinds, return empty string for the moment
107110
return ""
108111
}
109112

110-
// Return the current session tool calls given the completion index.
111-
// Will return nil if no tool calls were returned.
112-
func (c Completions) ToolCalls(index int) []llm.ToolCall {
113-
if index < 0 || index >= len(c) {
114-
return nil
115-
}
116-
117-
// Get the completion
118-
completion := c[index].Message
119-
if completion == nil {
113+
func (m Message) ToolCalls(index int) []llm.ToolCall {
114+
if index != 0 {
120115
return nil
121116
}
122117

123118
// Make the tool calls
124-
calls := make([]llm.ToolCall, 0, len(completion.ToolCalls))
125-
for _, call := range completion.ToolCalls {
126-
calls = append(calls, &toolcall{call})
119+
calls := make([]llm.ToolCall, 0, len(m.ToolCallArray))
120+
for _, call := range m.ToolCallArray {
121+
var args map[string]any
122+
if call.Function.Arguments != "" {
123+
if err := json.Unmarshal([]byte(call.Function.Arguments), &args); err != nil {
124+
return nil
125+
}
126+
}
127+
calls = append(calls, tool.NewCall(call.Id, call.Function.Name, args))
127128
}
128129

129130
// Return success
130131
return calls
131132
}
132133

134+
///////////////////////////////////////////////////////////////////////////////
135+
// PUBLIC METHODS - COMPLETIONS
136+
137+
// Return the number of completions
138+
func (c Completions) Num() int {
139+
return len(c)
140+
}
141+
133142
// Return message for a specific completion
134143
func (c Completions) Message(index int) *Message {
135144
if index < 0 || index >= len(c) {
136145
return nil
137146
}
138147
return c[index].Message
139148
}
149+
150+
// Return the role of the completion
151+
func (c Completions) Role() string {
152+
// The role should be the same for all completions, let's use the first one
153+
if len(c) == 0 {
154+
return ""
155+
}
156+
return c[0].Message.Role()
157+
}
158+
159+
// Return the text content for a specific completion
160+
func (c Completions) Text(index int) string {
161+
if index < 0 || index >= len(c) {
162+
return ""
163+
}
164+
return c[index].Message.Text(0)
165+
}
166+
167+
// Return the current session tool calls given the completion index.
168+
// Will return nil if no tool calls were returned.
169+
func (c Completions) ToolCalls(index int) []llm.ToolCall {
170+
if index < 0 || index >= len(c) {
171+
return nil
172+
}
173+
return c[index].Message.ToolCalls(0)
174+
}

Diff for: pkg/mistral/model.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
// TYPES
1313

1414
type model struct {
15+
*Client
1516
meta Model
1617
}
1718

@@ -65,7 +66,7 @@ func (c *Client) ListModels(ctx context.Context) ([]llm.Model, error) {
6566
// Make models
6667
result := make([]llm.Model, 0, len(response.Data))
6768
for _, meta := range response.Data {
68-
result = append(result, &model{meta: meta})
69+
result = append(result, &model{c, meta})
6970
}
7071

7172
// Return models

0 commit comments

Comments
 (0)