1
- import { streamJSON } from "@continuedev/fetch" ;
1
+ import { streamSse } from "@continuedev/fetch" ;
2
2
import {
3
3
ChatMessage ,
4
4
Chunk ,
5
5
CompletionOptions ,
6
6
LLMOptions ,
7
+ MessageContent ,
7
8
} from "../../index.js" ;
8
9
import { renderChatMessage , stripImages } from "../../util/messageContent.js" ;
9
10
import { BaseLLM } from "../index.js" ;
10
11
11
12
class Cohere extends BaseLLM {
12
13
static providerName = "cohere" ;
13
14
static defaultOptions : Partial < LLMOptions > = {
14
- apiBase : "https://api.cohere.ai/v1 " ,
15
+ apiBase : "https://api.cohere.ai/v2 " ,
15
16
maxEmbeddingBatchSize : 96 ,
16
17
} ;
17
18
static maxStopSequences = 5 ;
18
19
19
20
private _convertMessages ( msgs : ChatMessage [ ] ) : any [ ] {
20
21
const messages = [ ] ;
22
+ let lastToolPlan : MessageContent | undefined ;
21
23
for ( const m of msgs ) {
22
- if ( m . role === "system" || ! m . content ) {
24
+ if ( ! m . content ) {
23
25
continue ;
24
26
}
25
- messages . push ( {
26
- role : m . role === "assistant" ? "chatbot" : m . role ,
27
- message : m . content ,
28
- } ) ;
27
+ switch ( m . role ) {
28
+ case "user" :
29
+ if ( typeof m . content === "string" ) {
30
+ messages . push ( {
31
+ role : m . role ,
32
+ content : m . content ,
33
+ } ) ;
34
+ break ;
35
+ }
36
+
37
+ messages . push ( {
38
+ role : m . role ,
39
+ content : m . content . map ( ( part ) => {
40
+ if ( part . type === "imageUrl" ) {
41
+ return {
42
+ type : "image_url" ,
43
+ image_url : { url : part . imageUrl . url } ,
44
+ } ;
45
+ }
46
+ return part ;
47
+ } ) ,
48
+ } ) ;
49
+ break ;
50
+ case "thinking" :
51
+ lastToolPlan = m . content ;
52
+ break ;
53
+ case "assistant" :
54
+ if ( m . toolCalls ) {
55
+ if ( ! lastToolPlan ) {
56
+ throw new Error ( "No tool plan found" ) ;
57
+ }
58
+ messages . push ( {
59
+ role : m . role ,
60
+ tool_calls : m . toolCalls . map ( ( toolCall ) => ( {
61
+ id : toolCall . id ,
62
+ type : "function" ,
63
+ function : {
64
+ name : toolCall . function ?. name ,
65
+ arguments : toolCall . function ?. arguments ,
66
+ } ,
67
+ } ) ) ,
68
+ // Ideally the tool plan would be in this message, but it is
69
+ // split in another, usually the previous, this one's content is
70
+ // a space.
71
+ // tool_plan: m.content,
72
+ tool_plan : lastToolPlan ,
73
+ } ) ;
74
+ lastToolPlan = undefined ;
75
+ break ;
76
+ }
77
+ messages . push ( {
78
+ role : m . role ,
79
+ content : m . content ,
80
+ } ) ;
81
+ break ;
82
+ case "system" :
83
+ messages . push ( {
84
+ role : m . role ,
85
+ content : stripImages ( m . content ) ,
86
+ } ) ;
87
+ break ;
88
+ case "tool" :
89
+ messages . push ( {
90
+ role : m . role ,
91
+ content : m . content ,
92
+ tool_call_id : m . toolCallId ,
93
+ } ) ;
94
+ break ;
95
+ default :
96
+ break ;
97
+ }
29
98
}
30
99
return messages ;
31
100
}
@@ -41,7 +110,14 @@ class Cohere extends BaseLLM {
41
110
stop_sequences : options . stop ?. slice ( 0 , Cohere . maxStopSequences ) ,
42
111
frequency_penalty : options . frequencyPenalty ,
43
112
presence_penalty : options . presencePenalty ,
44
- raw_prompting : options . raw ,
113
+ tools : options . tools ?. map ( ( tool ) => ( {
114
+ type : "function" ,
115
+ function : {
116
+ name : tool . function . name ,
117
+ parameters : tool . function . parameters ,
118
+ description : tool . function . description ,
119
+ } ,
120
+ } ) ) ,
45
121
} ;
46
122
}
47
123
@@ -67,19 +143,12 @@ class Cohere extends BaseLLM {
67
143
...this . requestOptions ?. headers ,
68
144
} ;
69
145
70
- let preamble : string | undefined = undefined ;
71
- const systemMessage = messages . find ( ( m ) => m . role === "system" ) ?. content ;
72
- if ( systemMessage ) {
73
- preamble = stripImages ( systemMessage ) ;
74
- }
75
146
const resp = await this . fetch ( new URL ( "chat" , this . apiBase ) , {
76
147
method : "POST" ,
77
148
headers,
78
149
body : JSON . stringify ( {
79
150
...this . _convertArgs ( options ) ,
80
- message : messages . pop ( ) ?. content ,
81
- chat_history : this . _convertMessages ( messages ) ,
82
- preamble,
151
+ messages : this . _convertMessages ( messages ) ,
83
152
} ) ,
84
153
signal,
85
154
} ) ;
@@ -90,13 +159,97 @@ class Cohere extends BaseLLM {
90
159
91
160
if ( options . stream === false ) {
92
161
const data = await resp . json ( ) ;
93
- yield { role : "assistant" , content : data . text } ;
162
+ if ( data . message . tool_calls ) {
163
+ yield {
164
+ // Use the "thinking" role for `tool_plan`, since there is no such
165
+ // role in the Cohere API at the moment and it is a "a
166
+ // chain-of-thought style reflection".
167
+ role : "thinking" ,
168
+ content : data . message . tool_plan ,
169
+ } ;
170
+ yield {
171
+ role : "assistant" ,
172
+ content : "" ,
173
+ toolCalls : data . message . tool_calls . map ( ( toolCall : any ) => ( {
174
+ id : toolCall . id ,
175
+ type : "function" ,
176
+ function : {
177
+ name : toolCall . function ?. name ,
178
+ arguments : toolCall . function ?. arguments ,
179
+ } ,
180
+ } ) ) ,
181
+ } ;
182
+ return ;
183
+ }
184
+ yield { role : "assistant" , content : data . message . content [ 0 ] . text } ;
94
185
return ;
95
186
}
96
187
97
- for await ( const value of streamJSON ( resp ) ) {
98
- if ( value . event_type === "text-generation" ) {
99
- yield { role : "assistant" , content : value . text } ;
188
+ let lastToolUseId : string | undefined ;
189
+ let lastToolUseName : string | undefined ;
190
+ for await ( const value of streamSse ( resp ) ) {
191
+ // https://docs.cohere.com/v2/docs/streaming#stream-events
192
+ switch ( value . type ) {
193
+ // https://docs.cohere.com/v2/docs/streaming#content-delta
194
+ case "content-delta" :
195
+ yield {
196
+ role : "assistant" ,
197
+ content : value . delta . message . content . text ,
198
+ } ;
199
+ break ;
200
+ // https://docs.cohere.com/reference/chat-stream#request.body.messages.assistant.tool_plan
201
+ case "tool-plan-delta" :
202
+ // Use the "thinking" role for `tool_plan`, since there is no such
203
+ // role in the Cohere API at the moment and it is a "a
204
+ // chain-of-thought style reflection".
205
+ yield {
206
+ role : "thinking" ,
207
+ content : value . delta . message . tool_plan ,
208
+ } ;
209
+ break ;
210
+ case "tool-call-start" :
211
+ lastToolUseId = value . delta . message . tool_calls . id ;
212
+ lastToolUseName = value . delta . message . tool_calls . function . name ;
213
+ yield {
214
+ role : "assistant" ,
215
+ content : "" ,
216
+ toolCalls : [
217
+ {
218
+ id : lastToolUseId ,
219
+ type : "function" ,
220
+ function : {
221
+ name : lastToolUseName ,
222
+ arguments : value . delta . message . tool_calls . function . arguments ,
223
+ } ,
224
+ } ,
225
+ ] ,
226
+ } ;
227
+ break ;
228
+ case "tool-call-delta" :
229
+ if ( ! lastToolUseId || ! lastToolUseName ) {
230
+ throw new Error ( "No tool use found" ) ;
231
+ }
232
+ yield {
233
+ role : "assistant" ,
234
+ content : "" ,
235
+ toolCalls : [
236
+ {
237
+ id : lastToolUseId ,
238
+ type : "function" ,
239
+ function : {
240
+ name : lastToolUseName ,
241
+ arguments : value . delta . message . tool_calls . function . arguments ,
242
+ } ,
243
+ } ,
244
+ ] ,
245
+ } ;
246
+ break ;
247
+ case "tool-call-end" :
248
+ lastToolUseId = undefined ;
249
+ lastToolUseName = undefined ;
250
+ break ;
251
+ default :
252
+ break ;
100
253
}
101
254
}
102
255
}
0 commit comments