Skip to content

Commit 4aa838a

Browse files
authored
Merge pull request #6054 from maxbrunet/feat/cohere-refresh-2025
✨ refresh Cohere support
2 parents 4d63b4c + 6cf992e commit 4aa838a

File tree

11 files changed

+531
-63
lines changed

11 files changed

+531
-63
lines changed

core/llm/autodetect.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const PROVIDER_HANDLES_TEMPLATING: string[] = [
5252
"msty",
5353
"anthropic",
5454
"bedrock",
55+
"cohere",
5556
"sagemaker",
5657
"continue-proxy",
5758
"mistral",
@@ -65,6 +66,7 @@ const PROVIDER_HANDLES_TEMPLATING: string[] = [
6566
const PROVIDER_SUPPORTS_IMAGES: string[] = [
6667
"openai",
6768
"ollama",
69+
"cohere",
6870
"gemini",
6971
"msty",
7072
"anthropic",
@@ -89,6 +91,8 @@ const MODEL_SUPPORTS_IMAGES: string[] = [
8991
"gpt-4o-mini",
9092
"gpt-4-vision",
9193
"claude-3",
94+
"c4ai-aya-vision-8b",
95+
"c4ai-aya-vision-32b",
9296
"gemini-ultra",
9397
"gemini-1.5-pro",
9498
"gemini-1.5-flash",
@@ -140,6 +144,7 @@ function modelSupportsImages(
140144
const PARALLEL_PROVIDERS: string[] = [
141145
"anthropic",
142146
"bedrock",
147+
"cohere",
143148
"sagemaker",
144149
"deepinfra",
145150
"gemini",
@@ -176,6 +181,7 @@ function autodetectTemplateType(model: string): TemplateType | undefined {
176181
if (
177182
lower.includes("gpt") ||
178183
lower.includes("command") ||
184+
lower.includes("aya") ||
179185
lower.includes("chat-bison") ||
180186
lower.includes("pplx") ||
181187
lower.includes("gemini") ||

core/llm/llms/Cohere.ts

Lines changed: 173 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,100 @@
1-
import { streamJSON } from "@continuedev/fetch";
1+
import { streamSse } from "@continuedev/fetch";
22
import {
33
ChatMessage,
44
Chunk,
55
CompletionOptions,
66
LLMOptions,
7+
MessageContent,
78
} from "../../index.js";
89
import { renderChatMessage, stripImages } from "../../util/messageContent.js";
910
import { BaseLLM } from "../index.js";
1011

1112
class Cohere extends BaseLLM {
1213
static providerName = "cohere";
1314
static defaultOptions: Partial<LLMOptions> = {
14-
apiBase: "https://api.cohere.ai/v1",
15+
apiBase: "https://api.cohere.ai/v2",
1516
maxEmbeddingBatchSize: 96,
1617
};
1718
static maxStopSequences = 5;
1819

1920
private _convertMessages(msgs: ChatMessage[]): any[] {
2021
const messages = [];
22+
let lastToolPlan: MessageContent | undefined;
2123
for (const m of msgs) {
22-
if (m.role === "system" || !m.content) {
24+
if (!m.content) {
2325
continue;
2426
}
25-
messages.push({
26-
role: m.role === "assistant" ? "chatbot" : m.role,
27-
message: m.content,
28-
});
27+
switch (m.role) {
28+
case "user":
29+
if (typeof m.content === "string") {
30+
messages.push({
31+
role: m.role,
32+
content: m.content,
33+
});
34+
break;
35+
}
36+
37+
messages.push({
38+
role: m.role,
39+
content: m.content.map((part) => {
40+
if (part.type === "imageUrl") {
41+
return {
42+
type: "image_url",
43+
image_url: { url: part.imageUrl.url },
44+
};
45+
}
46+
return part;
47+
}),
48+
});
49+
break;
50+
case "thinking":
51+
lastToolPlan = m.content;
52+
break;
53+
case "assistant":
54+
if (m.toolCalls) {
55+
if (!lastToolPlan) {
56+
throw new Error("No tool plan found");
57+
}
58+
messages.push({
59+
role: m.role,
60+
tool_calls: m.toolCalls.map((toolCall) => ({
61+
id: toolCall.id,
62+
type: "function",
63+
function: {
64+
name: toolCall.function?.name,
65+
arguments: toolCall.function?.arguments,
66+
},
67+
})),
68+
// Ideally the tool plan would be in this message, but it is
69+
// split in another, usually the previous, this one's content is
70+
// a space.
71+
// tool_plan: m.content,
72+
tool_plan: lastToolPlan,
73+
});
74+
lastToolPlan = undefined;
75+
break;
76+
}
77+
messages.push({
78+
role: m.role,
79+
content: m.content,
80+
});
81+
break;
82+
case "system":
83+
messages.push({
84+
role: m.role,
85+
content: stripImages(m.content),
86+
});
87+
break;
88+
case "tool":
89+
messages.push({
90+
role: m.role,
91+
content: m.content,
92+
tool_call_id: m.toolCallId,
93+
});
94+
break;
95+
default:
96+
break;
97+
}
2998
}
3099
return messages;
31100
}
@@ -41,7 +110,14 @@ class Cohere extends BaseLLM {
41110
stop_sequences: options.stop?.slice(0, Cohere.maxStopSequences),
42111
frequency_penalty: options.frequencyPenalty,
43112
presence_penalty: options.presencePenalty,
44-
raw_prompting: options.raw,
113+
tools: options.tools?.map((tool) => ({
114+
type: "function",
115+
function: {
116+
name: tool.function.name,
117+
parameters: tool.function.parameters,
118+
description: tool.function.description,
119+
},
120+
})),
45121
};
46122
}
47123

@@ -67,19 +143,12 @@ class Cohere extends BaseLLM {
67143
...this.requestOptions?.headers,
68144
};
69145

70-
let preamble: string | undefined = undefined;
71-
const systemMessage = messages.find((m) => m.role === "system")?.content;
72-
if (systemMessage) {
73-
preamble = stripImages(systemMessage);
74-
}
75146
const resp = await this.fetch(new URL("chat", this.apiBase), {
76147
method: "POST",
77148
headers,
78149
body: JSON.stringify({
79150
...this._convertArgs(options),
80-
message: messages.pop()?.content,
81-
chat_history: this._convertMessages(messages),
82-
preamble,
151+
messages: this._convertMessages(messages),
83152
}),
84153
signal,
85154
});
@@ -90,13 +159,97 @@ class Cohere extends BaseLLM {
90159

91160
if (options.stream === false) {
92161
const data = await resp.json();
93-
yield { role: "assistant", content: data.text };
162+
if (data.message.tool_calls) {
163+
yield {
164+
// Use the "thinking" role for `tool_plan`, since there is no such
165+
// role in the Cohere API at the moment and it is "a
166+
// chain-of-thought style reflection".
167+
role: "thinking",
168+
content: data.message.tool_plan,
169+
};
170+
yield {
171+
role: "assistant",
172+
content: "",
173+
toolCalls: data.message.tool_calls.map((toolCall: any) => ({
174+
id: toolCall.id,
175+
type: "function",
176+
function: {
177+
name: toolCall.function?.name,
178+
arguments: toolCall.function?.arguments,
179+
},
180+
})),
181+
};
182+
return;
183+
}
184+
yield { role: "assistant", content: data.message.content[0].text };
94185
return;
95186
}
96187

97-
for await (const value of streamJSON(resp)) {
98-
if (value.event_type === "text-generation") {
99-
yield { role: "assistant", content: value.text };
188+
let lastToolUseId: string | undefined;
189+
let lastToolUseName: string | undefined;
190+
for await (const value of streamSse(resp)) {
191+
// https://docs.cohere.com/v2/docs/streaming#stream-events
192+
switch (value.type) {
193+
// https://docs.cohere.com/v2/docs/streaming#content-delta
194+
case "content-delta":
195+
yield {
196+
role: "assistant",
197+
content: value.delta.message.content.text,
198+
};
199+
break;
200+
// https://docs.cohere.com/reference/chat-stream#request.body.messages.assistant.tool_plan
201+
case "tool-plan-delta":
202+
// Use the "thinking" role for `tool_plan`, since there is no such
203+
// role in the Cohere API at the moment and it is "a
204+
// chain-of-thought style reflection".
205+
yield {
206+
role: "thinking",
207+
content: value.delta.message.tool_plan,
208+
};
209+
break;
210+
case "tool-call-start":
211+
lastToolUseId = value.delta.message.tool_calls.id;
212+
lastToolUseName = value.delta.message.tool_calls.function.name;
213+
yield {
214+
role: "assistant",
215+
content: "",
216+
toolCalls: [
217+
{
218+
id: lastToolUseId,
219+
type: "function",
220+
function: {
221+
name: lastToolUseName,
222+
arguments: value.delta.message.tool_calls.function.arguments,
223+
},
224+
},
225+
],
226+
};
227+
break;
228+
case "tool-call-delta":
229+
if (!lastToolUseId || !lastToolUseName) {
230+
throw new Error("No tool use found");
231+
}
232+
yield {
233+
role: "assistant",
234+
content: "",
235+
toolCalls: [
236+
{
237+
id: lastToolUseId,
238+
type: "function",
239+
function: {
240+
name: lastToolUseName,
241+
arguments: value.delta.message.tool_calls.function.arguments,
242+
},
243+
},
244+
],
245+
};
246+
break;
247+
case "tool-call-end":
248+
lastToolUseId = undefined;
249+
lastToolUseName = undefined;
250+
break;
251+
default:
252+
break;
100253
}
101254
}
102255
}

core/llm/toolSupport.test.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,19 @@ describe("PROVIDER_TOOL_SUPPORT", () => {
109109
});
110110
});
111111

112+
describe("cohere", () => {
113+
const supportsFn = PROVIDER_TOOL_SUPPORT["cohere"];
114+
115+
it("should return true for Command models", () => {
116+
expect(supportsFn("command-r")).toBe(true);
117+
expect(supportsFn("command-a")).toBe(true);
118+
});
119+
120+
it("should return false for other models", () => {
121+
expect(supportsFn("c4ai-aya-expanse-32b")).toBe(false);
122+
});
123+
});
124+
112125
describe("gemini", () => {
113126
const supportsFn = PROVIDER_TOOL_SUPPORT["gemini"];
114127

@@ -221,6 +234,7 @@ describe("PROVIDER_TOOL_SUPPORT", () => {
221234
expect(supportsFn("qwen2")).toBe(true);
222235
expect(supportsFn("mixtral-8x7b")).toBe(true);
223236
expect(supportsFn("command-r")).toBe(true);
237+
expect(supportsFn("command-a")).toBe(true);
224238
expect(supportsFn("smollm2")).toBe(true);
225239
expect(supportsFn("hermes3")).toBe(true);
226240
expect(supportsFn("athene-v2")).toBe(true);

core/llm/toolSupport.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ export const PROVIDER_TOOL_SUPPORT: Record<string, (model: string) => boolean> =
7171

7272
return false;
7373
},
74+
cohere: (model) => {
75+
return model.toLowerCase().startsWith("command");
76+
},
7477
gemini: (model) => {
7578
// All gemini models support function calling
7679
return model.toLowerCase().includes("gemini");
@@ -144,6 +147,7 @@ export const PROVIDER_TOOL_SUPPORT: Record<string, (model: string) => boolean> =
144147
"qwen3",
145148
"mixtral",
146149
"command-r",
150+
"command-a",
147151
"smollm2",
148152
"hermes3",
149153
"athene-v2",
@@ -226,6 +230,7 @@ export const PROVIDER_TOOL_SUPPORT: Record<string, (model: string) => boolean> =
226230
"qwen/qwen3",
227231
"qwen/qwen-",
228232
"cohere/command-r",
233+
"cohere/command-a",
229234
"ai21/jamba-1.6",
230235
"mistralai/mistral",
231236
"mistralai/ministral",

0 commit comments

Comments (0)