@@ -20,7 +20,7 @@ type Response struct {
20
20
Created uint64 `json:"created"`
21
21
Model string `json:"model"`
22
22
Completions `json:"choices"`
23
- Metrics `json:"usage,omitempty"`
23
+ * Metrics `json:"usage,omitempty"`
24
24
}
25
25
26
26
// Possible completions
@@ -54,78 +54,98 @@ func (r Response) String() string {
54
54
return string (data )
55
55
}
56
56
57
+ func (c Completion ) String () string {
58
+ data , err := json .MarshalIndent (c , "" , " " )
59
+ if err != nil {
60
+ return err .Error ()
61
+ }
62
+ return string (data )
63
+ }
64
+
65
+ func (m Metrics ) String () string {
66
+ data , err := json .MarshalIndent (m , "" , " " )
67
+ if err != nil {
68
+ return err .Error ()
69
+ }
70
+ return string (data )
71
+ }
72
+
57
73
///////////////////////////////////////////////////////////////////////////////
58
74
// PUBLIC METHODS
59
75
60
76
type reqChatCompletion struct {
61
- Model string `json:"model"`
62
- Temperature float64 `json:"temperature,omitempty"`
63
- TopP float64 `json:"top_p,omitempty"`
64
- MaxTokens uint64 `json:"max_tokens,omitempty"`
65
- Stream bool `json:"stream,omitempty"`
66
- StopSequences []string `json:"stop,omitempty"`
67
- Seed uint64 `json:"random_seed,omitempty"`
68
- Messages [] * Message `json:"messages "`
69
- Format any `json:"response_format ,omitempty"`
70
- Tools []llm. Tool `json:"tools ,omitempty"`
71
- ToolChoice any `json:"tool_choice ,omitempty"`
72
- PresencePenalty float64 `json:"presence_penalty ,omitempty"`
73
- FrequencyPenalty float64 `json:"frequency_penalty ,omitempty"`
74
- NumChoices uint64 `json:"n ,omitempty"`
75
- Prediction * Content `json:"prediction ,omitempty"`
76
- SafePrompt bool `json:"safe_prompt,omitempty "`
77
+ Model string `json:"model"`
78
+ Temperature float64 `json:"temperature,omitempty"`
79
+ TopP float64 `json:"top_p,omitempty"`
80
+ MaxTokens uint64 `json:"max_tokens,omitempty"`
81
+ Stream bool `json:"stream,omitempty"`
82
+ StopSequences []string `json:"stop,omitempty"`
83
+ Seed uint64 `json:"random_seed,omitempty"`
84
+ Format any `json:"response_format,omitempty "`
85
+ Tools []llm. Tool `json:"tools ,omitempty"`
86
+ ToolChoice any `json:"tool_choice ,omitempty"`
87
+ PresencePenalty float64 `json:"presence_penalty ,omitempty"`
88
+ FrequencyPenalty float64 `json:"frequency_penalty ,omitempty"`
89
+ NumCompletions uint64 `json:"n ,omitempty"`
90
+ Prediction * Content `json:"prediction ,omitempty"`
91
+ SafePrompt bool `json:"safe_prompt ,omitempty"`
92
+ Messages []llm. Completion `json:"messages "`
77
93
}
78
94
95
+ // Send a completion request with a single prompt, and return the next completion
79
96
func (model * model ) Completion (ctx context.Context , prompt string , opts ... llm.Opt ) (llm.Completion , error ) {
80
- // TODO
81
- return nil , llm .ErrNotImplemented
97
+ message , err := messagefactory {}.UserPrompt (prompt , opts ... )
98
+ if err != nil {
99
+ return nil , err
100
+ }
101
+ return model .Chat (ctx , []llm.Completion {message }, opts ... )
82
102
}
83
103
84
- func (mistral * Client ) ChatCompletion (ctx context.Context , context llm.Context , opts ... llm.Opt ) (* Response , error ) {
104
+ // Send a completion request with multiple completions, and return the next completion
105
+ func (model * model ) Chat (ctx context.Context , completions []llm.Completion , opts ... llm.Opt ) (llm.Completion , error ) {
85
106
// Apply options
86
107
opt , err := llm .ApplyOpts (opts ... )
87
108
if err != nil {
88
109
return nil , err
89
110
}
90
111
91
- // Append the system prompt at the beginning
92
- messages := make ([]* Message , 0 , len (context .( * session ). seq )+ 1 )
112
+ // Create the completions including the system prompt
113
+ messages := make ([]llm. Completion , 0 , len (completions )+ 1 )
93
114
if system := opt .SystemPrompt (); system != "" {
94
- messages = append (messages , systemPrompt (system ))
95
- }
96
-
97
- // Always append the first message of each completion
98
- for _ , message := range context .(* session ).seq {
99
- messages = append (messages , message )
115
+ messages = append (messages , messagefactory {}.SystemPrompt (system ))
100
116
}
117
+ messages = append (messages , completions ... )
101
118
102
119
// Request
103
120
req , err := client .NewJSONRequest (reqChatCompletion {
104
- Model : context .( * session ). model .Name (),
121
+ Model : model .Name (),
105
122
Temperature : optTemperature (opt ),
106
123
TopP : optTopP (opt ),
107
124
MaxTokens : optMaxTokens (opt ),
108
125
Stream : optStream (opt ),
109
126
StopSequences : optStopSequences (opt ),
110
127
Seed : optSeed (opt ),
111
- Messages : messages ,
112
128
Format : optFormat (opt ),
113
- Tools : optTools (mistral , opt ),
129
+ Tools : optTools (model . Client , opt ),
114
130
ToolChoice : optToolChoice (opt ),
115
131
PresencePenalty : optPresencePenalty (opt ),
116
132
FrequencyPenalty : optFrequencyPenalty (opt ),
117
- NumChoices : optNumCompletions (opt ),
133
+ NumCompletions : optNumCompletions (opt ),
118
134
Prediction : optPrediction (opt ),
119
135
SafePrompt : optSafePrompt (opt ),
136
+ Messages : messages ,
120
137
})
121
138
if err != nil {
122
139
return nil , err
123
140
}
124
141
142
+ // Response options
125
143
var response Response
126
144
reqopts := []client.RequestOpt {
127
145
client .OptPath ("chat" , "completions" ),
128
146
}
147
+
148
+ // Streaming
129
149
if optStream (opt ) {
130
150
reqopts = append (reqopts , client .OptTextStreamCallback (func (evt client.TextStreamEvent ) error {
131
151
if err := streamEvent (& response , evt ); err != nil {
@@ -139,7 +159,7 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
139
159
}
140
160
141
161
// Response
142
- if err := mistral .DoWithContext (ctx , req , & response , reqopts ... ); err != nil {
162
+ if err := model .DoWithContext (ctx , req , & response , reqopts ... ); err != nil {
143
163
return nil , err
144
164
}
145
165
@@ -148,7 +168,7 @@ func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context,
148
168
}
149
169
150
170
///////////////////////////////////////////////////////////////////////////////
151
- // PRIVATE METHODS
171
+ // PRIVATE METHODS - STREAMING
152
172
153
173
func streamEvent (response * Response , evt client.TextStreamEvent ) error {
154
174
var delta Response
@@ -164,28 +184,32 @@ func streamEvent(response *Response, evt client.TextStreamEvent) error {
164
184
if delta .Id != "" {
165
185
response .Id = delta .Id
166
186
}
187
+ if delta .Type != "" {
188
+ response .Type = delta .Type
189
+ }
167
190
if delta .Created != 0 {
168
191
response .Created = delta .Created
169
192
}
170
193
if delta .Model != "" {
171
194
response .Model = delta .Model
172
195
}
196
+
197
+ // Append the delta to the response
173
198
for _ , completion := range delta .Completions {
174
- appendCompletion (response , & completion )
175
- }
176
- if delta .Metrics .InputTokens > 0 {
177
- response .Metrics .InputTokens += delta .Metrics .InputTokens
178
- }
179
- if delta .Metrics .OutputTokens > 0 {
180
- response .Metrics .OutputTokens += delta .Metrics .OutputTokens
199
+ if err := appendCompletion (response , & completion ); err != nil {
200
+ return err
201
+ }
181
202
}
182
- if delta .Metrics .TotalTokens > 0 {
183
- response .Metrics .TotalTokens += delta .Metrics .TotalTokens
203
+
204
+ // Apend the metrics to the response
205
+ if delta .Metrics != nil {
206
+ response .Metrics = delta .Metrics
184
207
}
185
208
return nil
186
209
}
187
210
188
- func appendCompletion (response * Response , c * Completion ) {
211
+ func appendCompletion (response * Response , c * Completion ) error {
212
+ // Append a new completion
189
213
for {
190
214
if c .Index < uint64 (len (response .Completions )) {
191
215
break
@@ -200,32 +224,75 @@ func appendCompletion(response *Response, c *Completion) {
200
224
},
201
225
})
202
226
}
203
- // Add the completion delta
227
+
228
+ // Add the reason
204
229
if c .Reason != "" {
205
230
response .Completions [c .Index ].Reason = c .Reason
206
231
}
232
+
233
+ // Get the completion
234
+ message := response .Completions [c .Index ].Message
235
+ if message == nil {
236
+ return llm .ErrBadParameter
237
+ }
238
+
239
+ // Add the role
207
240
if role := c .Delta .Role (); role != "" {
208
- response . Completions [ c . Index ]. Message .RoleContent .Role = role
241
+ message .RoleContent .Role = role
209
242
}
210
243
211
- // TODO: We only allow deltas which are strings at the moment...
212
- if str , ok := c .Delta .Content .(string ); ok && str != "" {
213
- if text , ok := response .Completions [c .Index ].Message .Content .(string ); ok {
214
- response .Completions [c .Index ].Message .Content = text + str
244
+ // We only allow deltas which are strings at the moment
245
+ if c .Delta .Content != nil {
246
+ if str , ok := c .Delta .Content .(string ); ok {
247
+ if text , ok := message .Content .(string ); ok {
248
+ message .Content = text + str
249
+ } else {
250
+ message .Content = str
251
+ }
252
+ } else {
253
+ return llm .ErrNotImplemented .Withf ("appendCompletion not implemented: %T" , c .Delta .Content )
215
254
}
216
255
}
256
+
257
+ // Append tool calls
258
+ for i := range c .Delta .Calls {
259
+ if i >= len (message .Calls ) {
260
+ message .Calls = append (message .Calls , toolcall {})
261
+ }
262
+ }
263
+
264
+ for i , call := range c .Delta .Calls {
265
+ if call .meta .Id != "" {
266
+ message .Calls [i ].meta .Id = call .meta .Id
267
+ }
268
+ if call .meta .Index != 0 {
269
+ message .Calls [i ].meta .Index = call .meta .Index
270
+ }
271
+ if call .meta .Type != "" {
272
+ message .Calls [i ].meta .Type = call .meta .Type
273
+ }
274
+ if call .meta .Function .Name != "" {
275
+ message .Calls [i ].meta .Function .Name = call .meta .Function .Name
276
+ }
277
+ if call .meta .Function .Arguments != "" {
278
+ message .Calls [i ].meta .Function .Arguments += call .meta .Function .Arguments
279
+ }
280
+ }
281
+
282
+ // Return success
283
+ return nil
217
284
}
218
285
219
286
///////////////////////////////////////////////////////////////////////////////
220
- // PUBLIC METHODS - COMPLETIONS
287
+ // COMPLETIONS
221
288
222
289
// Return the number of completions
223
290
func (c Completions ) Num () int {
224
291
return len (c )
225
292
}
226
293
227
294
// Return message for a specific completion
228
- func (c Completions ) Message (index int ) * Message {
295
+ func (c Completions ) Choice (index int ) llm. Completion {
229
296
if index < 0 || index >= len (c ) {
230
297
return nil
231
298
}
@@ -249,6 +316,14 @@ func (c Completions) Text(index int) string {
249
316
return c [index ].Message .Text (0 )
250
317
}
251
318
319
+ // Return audio content for a specific completion
320
+ func (c Completions ) Audio (index int ) * llm.Attachment {
321
+ if index < 0 || index >= len (c ) {
322
+ return nil
323
+ }
324
+ return c [index ].Message .Audio (0 )
325
+ }
326
+
252
327
// Return the current session tool calls given the completion index.
253
328
// Will return nil if no tool calls were returned.
254
329
func (c Completions ) ToolCalls (index int ) []llm.ToolCall {
0 commit comments