
Commit 26ea0f0: Updated mistral
Parent: 9247d07

13 files changed: +221 -351 lines


cmd/llm/main.go (+3 -3)

@@ -127,10 +127,10 @@ func main() {
 	if cli.AnthropicKey != "" {
 		opts = append(opts, agent.WithAnthropic(cli.AnthropicKey, clientopts...))
 	}
-	if cli.MistralKey != "" {
-		opts = append(opts, agent.WithMistral(cli.MistralKey, clientopts...))
-	}
 	*/
+	if cli.MistralKey != "" {
+		opts = append(opts, agent.WithMistral(cli.MistralKey, clientopts...))
+	}
 	if cli.OpenAIKey != "" {
 		opts = append(opts, agent.WithOpenAI(cli.OpenAIKey, clientopts...))
 	}
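The change here is a three-line move: the Mistral block previously sat above the closing */ of a commented-out region, and now sits below it, so a configured Mistral API key actually registers the provider. A minimal, self-contained sketch of this conditional functional-option pattern follows; the config and WithProvider names are hypothetical stand-ins, not the repo's agent types.

package main

import (
	"errors"
	"fmt"
)

// config and Opt are hypothetical stand-ins for the agent's option types.
type config struct {
	providers []string
}

type Opt func(*config) error

func WithProvider(name, key string) Opt {
	return func(c *config) error {
		if key == "" {
			return errors.New("missing API key for " + name)
		}
		c.providers = append(c.providers, name)
		return nil
	}
}

func main() {
	mistralKey := "example-key" // normally parsed from a flag or env var

	// Append the option only when its key is set, as main() does above.
	var opts []Opt
	if mistralKey != "" {
		opts = append(opts, WithProvider("mistral", mistralKey))
	}

	var cfg config
	for _, opt := range opts {
		if err := opt(&cfg); err != nil {
			fmt.Println("error:", err)
			return
		}
	}
	fmt.Println("providers:", cfg.providers)
}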
File renamed without changes.
File renamed without changes.

pkg/agent/opt.go (+2 -1)

@@ -5,6 +5,7 @@ import (
 	client "github.com/mutablelogic/go-client"
 	llm "github.com/mutablelogic/go-llm"
 	gemini "github.com/mutablelogic/go-llm/pkg/gemini"
+	"github.com/mutablelogic/go-llm/pkg/mistral"
 	ollama "github.com/mutablelogic/go-llm/pkg/ollama"
 	openai "github.com/mutablelogic/go-llm/pkg/openai"
 )
@@ -34,6 +35,7 @@ func WithAnthropic(key string, opts ...client.ClientOpt) llm.Opt {
 		}
 	}
 }
+*/

 func WithMistral(key string, opts ...client.ClientOpt) llm.Opt {
 	return func(o *llm.Opts) error {
@@ -45,7 +47,6 @@ func WithMistral(key string, opts ...client.ClientOpt) llm.Opt {
 		}
 	}
 }
-*/

 func WithOpenAI(key string, opts ...client.ClientOpt) llm.Opt {
 	return func(o *llm.Opts) error {
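Here the closing */ moves up so that only WithAnthropic remains commented out, and the new mistral import supports the re-enabled body of WithMistral. The signature shows the error-returning functional-option shape: the constructor captures the key and defers validation until the option is applied against llm.Opts. A self-contained sketch of that shape; the Opts fields and the registration step are assumptions, since the closure body is not visible in this hunk.

package main

import (
	"errors"
	"fmt"
)

// Opts and Opt stand in for llm.Opts and llm.Opt; the clients map and
// the registration step below are assumptions for illustration.
type Opts struct {
	clients map[string]string
}

type Opt func(*Opts) error

// WithMistral mirrors the constructor's shape in opt.go: capture the
// key now, validate and register only when the option is applied.
func WithMistral(key string) Opt {
	return func(o *Opts) error {
		if key == "" {
			return errors.New("WithMistral: empty API key")
		}
		if o.clients == nil {
			o.clients = make(map[string]string)
		}
		o.clients["mistral"] = key // hypothetical registration
		return nil
	}
}

func main() {
	var opts Opts
	if err := WithMistral("example-key")(&opts); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println("registered:", opts.clients)
}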
File renamed without changes.

pkg/mistral/completion.go (+127 -52)
@@ -20,7 +20,7 @@ type Response struct {
 	Created     uint64 `json:"created"`
 	Model       string `json:"model"`
 	Completions `json:"choices"`
-	Metrics     `json:"usage,omitempty"`
+	*Metrics    `json:"usage,omitempty"`
 }

 // Possible completions
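Switching the embedded Metrics to *Metrics changes both encoding and decoding behaviour: omitempty never drops a zero struct value, but it does drop a nil pointer, and when a decoded chunk carries no usage object the pointer stays nil, which is exactly what the streaming code below tests with delta.Metrics != nil. A self-contained illustration (the field names are illustrative, not copied from the package):

package main

import (
	"encoding/json"
	"fmt"
)

// Metrics here is illustrative; the field names are not the package's.
type Metrics struct {
	InputTokens  int `json:"input_tokens,omitempty"`
	OutputTokens int `json:"output_tokens,omitempty"`
}

type withValue struct {
	Usage Metrics `json:"usage,omitempty"` // struct value: never omitted
}

type withPointer struct {
	Usage *Metrics `json:"usage,omitempty"` // nil pointer: omitted
}

func main() {
	a, _ := json.Marshal(withValue{})
	b, _ := json.Marshal(withPointer{})
	fmt.Println(string(a)) // {"usage":{}}
	fmt.Println(string(b)) // {}

	// On decode, a missing "usage" key leaves the pointer nil, so a
	// caller can tell "no usage reported" apart from "zero tokens".
	var p withPointer
	_ = json.Unmarshal([]byte(`{}`), &p)
	fmt.Println(p.Usage == nil) // true
}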
@@ -54,78 +54,98 @@ func (r Response) String() string {
 	return string(data)
 }

+func (c Completion) String() string {
+	data, err := json.MarshalIndent(c, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
+func (m Metrics) String() string {
+	data, err := json.MarshalIndent(m, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
 ///////////////////////////////////////////////////////////////////////////////
 // PUBLIC METHODS

 type reqChatCompletion struct {
-	Model            string     `json:"model"`
-	Temperature      float64    `json:"temperature,omitempty"`
-	TopP             float64    `json:"top_p,omitempty"`
-	MaxTokens        uint64     `json:"max_tokens,omitempty"`
-	Stream           bool       `json:"stream,omitempty"`
-	StopSequences    []string   `json:"stop,omitempty"`
-	Seed             uint64     `json:"random_seed,omitempty"`
-	Messages         []*Message `json:"messages"`
-	Format           any        `json:"response_format,omitempty"`
-	Tools            []llm.Tool `json:"tools,omitempty"`
-	ToolChoice       any        `json:"tool_choice,omitempty"`
-	PresencePenalty  float64    `json:"presence_penalty,omitempty"`
-	FrequencyPenalty float64    `json:"frequency_penalty,omitempty"`
-	NumChoices       uint64     `json:"n,omitempty"`
-	Prediction       *Content   `json:"prediction,omitempty"`
-	SafePrompt       bool       `json:"safe_prompt,omitempty"`
+	Model            string           `json:"model"`
+	Temperature      float64          `json:"temperature,omitempty"`
+	TopP             float64          `json:"top_p,omitempty"`
+	MaxTokens        uint64           `json:"max_tokens,omitempty"`
+	Stream           bool             `json:"stream,omitempty"`
+	StopSequences    []string         `json:"stop,omitempty"`
+	Seed             uint64           `json:"random_seed,omitempty"`
+	Format           any              `json:"response_format,omitempty"`
+	Tools            []llm.Tool       `json:"tools,omitempty"`
+	ToolChoice       any              `json:"tool_choice,omitempty"`
+	PresencePenalty  float64          `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float64          `json:"frequency_penalty,omitempty"`
+	NumCompletions   uint64           `json:"n,omitempty"`
+	Prediction       *Content         `json:"prediction,omitempty"`
+	SafePrompt       bool             `json:"safe_prompt,omitempty"`
+	Messages         []llm.Completion `json:"messages"`
 }

+// Send a completion request with a single prompt, and return the next completion
 func (model *model) Completion(ctx context.Context, prompt string, opts ...llm.Opt) (llm.Completion, error) {
-	// TODO
-	return nil, llm.ErrNotImplemented
+	message, err := messagefactory{}.UserPrompt(prompt, opts...)
+	if err != nil {
+		return nil, err
+	}
+	return model.Chat(ctx, []llm.Completion{message}, opts...)
 }

-func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) {
+// Send a completion request with multiple completions, and return the next completion
+func (model *model) Chat(ctx context.Context, completions []llm.Completion, opts ...llm.Opt) (llm.Completion, error) {
 	// Apply options
 	opt, err := llm.ApplyOpts(opts...)
 	if err != nil {
 		return nil, err
 	}

-	// Append the system prompt at the beginning
-	messages := make([]*Message, 0, len(context.(*session).seq)+1)
+	// Create the completions including the system prompt
+	messages := make([]llm.Completion, 0, len(completions)+1)
 	if system := opt.SystemPrompt(); system != "" {
-		messages = append(messages, systemPrompt(system))
-	}
-
-	// Always append the first message of each completion
-	for _, message := range context.(*session).seq {
-		messages = append(messages, message)
+		messages = append(messages, messagefactory{}.SystemPrompt(system))
 	}
+	messages = append(messages, completions...)

 	// Request
 	req, err := client.NewJSONRequest(reqChatCompletion{
-		Model:            context.(*session).model.Name(),
+		Model:            model.Name(),
 		Temperature:      optTemperature(opt),
 		TopP:             optTopP(opt),
 		MaxTokens:        optMaxTokens(opt),
 		Stream:           optStream(opt),
 		StopSequences:    optStopSequences(opt),
 		Seed:             optSeed(opt),
-		Messages:         messages,
 		Format:           optFormat(opt),
-		Tools:            optTools(mistral, opt),
+		Tools:            optTools(model.Client, opt),
 		ToolChoice:       optToolChoice(opt),
 		PresencePenalty:  optPresencePenalty(opt),
 		FrequencyPenalty: optFrequencyPenalty(opt),
-		NumChoices:       optNumCompletions(opt),
+		NumCompletions:   optNumCompletions(opt),
 		Prediction:       optPrediction(opt),
 		SafePrompt:       optSafePrompt(opt),
+		Messages:         messages,
 	})
 	if err != nil {
 		return nil, err
 	}

+	// Response options
 	var response Response
 	reqopts := []client.RequestOpt{
 		client.OptPath("chat", "completions"),
 	}
+
+	// Streaming
 	if optStream(opt) {
 		reqopts = append(reqopts, client.OptTextStreamCallback(func(evt client.TextStreamEvent) error {
 			if err := streamEvent(&response, evt); err != nil {
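Chat now receives the message list directly instead of digging it out of an llm.Context, prepends the system prompt via the message factory, and serializes everything into the body POSTed to the chat/completions path. A self-contained sketch of roughly what goes over the wire, using a trimmed copy of the request struct; the model name and message shape are illustrative:

package main

import (
	"encoding/json"
	"fmt"
	"os"
)

// chatRequest is a trimmed-down copy of reqChatCompletion above, enough
// to show the wire format sent to the chat/completions endpoint.
type chatRequest struct {
	Model       string    `json:"model"`
	Temperature float64   `json:"temperature,omitempty"`
	MaxTokens   uint64    `json:"max_tokens,omitempty"`
	Stream      bool      `json:"stream,omitempty"`
	Messages    []message `json:"messages"`
}

type message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

func main() {
	req := chatRequest{
		Model:       "mistral-small-latest", // illustrative model name
		Temperature: 0.7,
		Messages: []message{
			{Role: "system", Content: "You are a helpful assistant."},
			{Role: "user", Content: "Why is the sky blue?"},
		},
	}
	enc := json.NewEncoder(os.Stdout)
	enc.SetIndent("", "  ")
	if err := enc.Encode(req); err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}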
@@ -139,7 +159,7 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
 	}

 	// Response
-	if err := mistral.DoWithContext(ctx, req, &response, reqopts...); err != nil {
+	if err := model.DoWithContext(ctx, req, &response, reqopts...); err != nil {
 		return nil, err
 	}

@@ -148,7 +168,7 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
 }

 ///////////////////////////////////////////////////////////////////////////////
-// PRIVATE METHODS
+// PRIVATE METHODS - STREAMING

 func streamEvent(response *Response, evt client.TextStreamEvent) error {
 	var delta Response
@@ -164,28 +184,32 @@ func streamEvent(response *Response, evt client.TextStreamEvent) error {
 	if delta.Id != "" {
 		response.Id = delta.Id
 	}
+	if delta.Type != "" {
+		response.Type = delta.Type
+	}
 	if delta.Created != 0 {
 		response.Created = delta.Created
 	}
 	if delta.Model != "" {
 		response.Model = delta.Model
 	}
+
+	// Append the delta to the response
 	for _, completion := range delta.Completions {
-		appendCompletion(response, &completion)
-	}
-	if delta.Metrics.InputTokens > 0 {
-		response.Metrics.InputTokens += delta.Metrics.InputTokens
-	}
-	if delta.Metrics.OutputTokens > 0 {
-		response.Metrics.OutputTokens += delta.Metrics.OutputTokens
+		if err := appendCompletion(response, &completion); err != nil {
+			return err
+		}
 	}
-	if delta.Metrics.TotalTokens > 0 {
-		response.Metrics.TotalTokens += delta.Metrics.TotalTokens
+
+	// Append the metrics to the response
+	if delta.Metrics != nil {
+		response.Metrics = delta.Metrics
 	}
 	return nil
 }

-func appendCompletion(response *Response, c *Completion) {
+func appendCompletion(response *Response, c *Completion) error {
+	// Append a new completion
 	for {
 		if c.Index < uint64(len(response.Completions)) {
 			break
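streamEvent is invoked once per streamed chunk: scalar fields overwrite when present, the usage block now replaces the stored pointer rather than summing token counts (replacement assumes the last reported usage is authoritative), and per-choice text fragments are concatenated by appendCompletion. A self-contained sketch of that fragment accumulation, with hypothetical types:

package main

import "fmt"

// response mirrors the accumulation in streamEvent/appendCompletion:
// each streamed delta carries a fragment, and the final text is the
// concatenation of fragments at the same choice index.
type response struct {
	choices []string
}

func (r *response) apply(index int, fragment string) {
	// Grow the choice list on demand, as appendCompletion does.
	for index >= len(r.choices) {
		r.choices = append(r.choices, "")
	}
	r.choices[index] += fragment
}

func main() {
	var r response
	for _, delta := range []struct {
		index    int
		fragment string
	}{
		{0, "The sky "},
		{0, "is blue "},
		{0, "because of Rayleigh scattering."},
	} {
		r.apply(delta.index, delta.fragment)
	}
	fmt.Println(r.choices[0])
}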
@@ -200,32 +224,75 @@ func appendCompletion(response *Response, c *Completion) {
 			},
 		})
 	}
-	// Add the completion delta
+
+	// Add the reason
 	if c.Reason != "" {
 		response.Completions[c.Index].Reason = c.Reason
 	}
+
+	// Get the completion
+	message := response.Completions[c.Index].Message
+	if message == nil {
+		return llm.ErrBadParameter
+	}
+
+	// Add the role
 	if role := c.Delta.Role(); role != "" {
-		response.Completions[c.Index].Message.RoleContent.Role = role
+		message.RoleContent.Role = role
 	}

-	// TODO: We only allow deltas which are strings at the moment...
-	if str, ok := c.Delta.Content.(string); ok && str != "" {
-		if text, ok := response.Completions[c.Index].Message.Content.(string); ok {
-			response.Completions[c.Index].Message.Content = text + str
+	// We only allow deltas which are strings at the moment
+	if c.Delta.Content != nil {
+		if str, ok := c.Delta.Content.(string); ok {
+			if text, ok := message.Content.(string); ok {
+				message.Content = text + str
+			} else {
+				message.Content = str
+			}
+		} else {
+			return llm.ErrNotImplemented.Withf("appendCompletion not implemented: %T", c.Delta.Content)
 		}
 	}
+
+	// Append tool calls
+	for i := range c.Delta.Calls {
+		if i >= len(message.Calls) {
+			message.Calls = append(message.Calls, toolcall{})
+		}
+	}
+
+	for i, call := range c.Delta.Calls {
+		if call.meta.Id != "" {
+			message.Calls[i].meta.Id = call.meta.Id
+		}
+		if call.meta.Index != 0 {
+			message.Calls[i].meta.Index = call.meta.Index
+		}
+		if call.meta.Type != "" {
+			message.Calls[i].meta.Type = call.meta.Type
+		}
+		if call.meta.Function.Name != "" {
+			message.Calls[i].meta.Function.Name = call.meta.Function.Name
+		}
+		if call.meta.Function.Arguments != "" {
+			message.Calls[i].meta.Function.Arguments += call.meta.Function.Arguments
+		}
+	}
+
+	// Return success
+	return nil
 }

 ///////////////////////////////////////////////////////////////////////////////
-// PUBLIC METHODS - COMPLETIONS
+// COMPLETIONS

 // Return the number of completions
 func (c Completions) Num() int {
 	return len(c)
 }

 // Return message for a specific completion
-func (c Completions) Message(index int) *Message {
+func (c Completions) Choice(index int) llm.Completion {
 	if index < 0 || index >= len(c) {
 		return nil
 	}
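appendCompletion now also merges streamed tool calls: identifiers, types, and function names arrive once per call, while the JSON arguments arrive as string fragments that must be concatenated before they parse. A self-contained sketch of that merge rule with a simplified call type (the real code keeps these fields under message.Calls[i].meta):

package main

import (
	"encoding/json"
	"fmt"
)

// call mirrors the merge rule above: scalar fields (ID, Name) are set
// whenever a delta carries them, while Arguments is appended, because
// the API streams the JSON argument object in fragments.
type call struct {
	ID, Name, Arguments string
}

func merge(dst *call, delta call) {
	if delta.ID != "" {
		dst.ID = delta.ID
	}
	if delta.Name != "" {
		dst.Name = delta.Name
	}
	dst.Arguments += delta.Arguments
}

func main() {
	var c call
	merge(&c, call{ID: "call_1", Name: "get_weather", Arguments: `{"city":`})
	merge(&c, call{Arguments: `"Berlin"}`})

	// Only after all fragments are merged is the argument string valid JSON.
	var args map[string]string
	if err := json.Unmarshal([]byte(c.Arguments), &args); err != nil {
		panic(err)
	}
	fmt.Println(c.Name, args["city"]) // get_weather Berlin
}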
@@ -249,6 +316,14 @@ func (c Completions) Text(index int) string {
 	return c[index].Message.Text(0)
 }

+// Return audio content for a specific completion
+func (c Completions) Audio(index int) *llm.Attachment {
+	if index < 0 || index >= len(c) {
+		return nil
+	}
+	return c[index].Message.Audio(0)
+}
+
 // Return the current session tool calls given the completion index.
 // Will return nil if no tool calls were returned.
 func (c Completions) ToolCalls(index int) []llm.ToolCall {
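The accessors added here (Choice, Audio) follow the same defensive convention as Text and ToolCalls: validate the index and return a zero value instead of panicking on an out-of-range choice, which keeps multi-choice (n > 1) handling simple for callers. A tiny self-contained sketch of that convention:

package main

import "fmt"

// completionSet mimics the bounds-checked accessors above: each accessor
// validates the index and returns a zero value rather than panicking.
type completionSet []string

func (c completionSet) Num() int { return len(c) }

func (c completionSet) Text(index int) string {
	if index < 0 || index >= len(c) {
		return ""
	}
	return c[index]
}

func main() {
	choices := completionSet{"first answer", "second answer"}
	for i := 0; i < choices.Num(); i++ {
		fmt.Printf("choice %d: %s\n", i, choices.Text(i))
	}
	fmt.Printf("out of range: %q\n", choices.Text(5))
}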
