Skip to content

Commit 4040c73

Browse files
committed
feat: support Qwen models provided by Alibaba Cloud
This implementation relies on the Alibaba Cloud official SDK. API reference: https://www.alibabacloud.com/help/en/model-studio/use-qwen-by-calling-api How to obtain an API-KEY: https://www.alibabacloud.com/help/en/model-studio/get-api-key Signed-off-by: jiangsier-xyz <[email protected]>
1 parent e723371 commit 4040c73

File tree

21 files changed

+3671
-1
lines changed

21 files changed

+3671
-1
lines changed

models/spring-ai-qwen/pom.xml

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<!--
3+
~ Copyright 2023-2024 the original author or authors.
4+
~
5+
~ Licensed under the Apache License, Version 2.0 (the "License");
6+
~ you may not use this file except in compliance with the License.
7+
~ You may obtain a copy of the License at
8+
~
9+
~ https://www.apache.org/licenses/LICENSE-2.0
10+
~
11+
~ Unless required by applicable law or agreed to in writing, software
12+
~ distributed under the License is distributed on an "AS IS" BASIS,
13+
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
~ See the License for the specific language governing permissions and
15+
~ limitations under the License.
16+
-->
17+
18+
<project xmlns="http://maven.apache.org/POM/4.0.0"
19+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
20+
<modelVersion>4.0.0</modelVersion>
21+
<parent>
22+
<groupId>org.springframework.ai</groupId>
23+
<artifactId>spring-ai-parent</artifactId>
24+
<version>1.0.0-SNAPSHOT</version>
25+
<relativePath>../../pom.xml</relativePath>
26+
</parent>
27+
<artifactId>spring-ai-qwen</artifactId>
28+
<packaging>jar</packaging>
29+
<name>Spring AI Model - Qwen</name>
30+
<description>Qwen models support</description>
31+
<url>https://github.com/spring-projects/spring-ai</url>
32+
33+
<scm>
34+
<url>https://github.com/spring-projects/spring-ai</url>
35+
<connection>git://github.com/spring-projects/spring-ai.git</connection>
36+
<developerConnection>[email protected]:spring-projects/spring-ai.git</developerConnection>
37+
</scm>
38+
39+
<properties>
40+
<dashscope.version>2.18.4</dashscope.version>
41+
</properties>
42+
43+
<dependencies>
44+
45+
<!-- production dependencies -->
46+
<dependency>
47+
<groupId>org.springframework.ai</groupId>
48+
<artifactId>spring-ai-client-chat</artifactId>
49+
<version>${project.parent.version}</version>
50+
</dependency>
51+
52+
<dependency>
53+
<groupId>org.springframework.ai</groupId>
54+
<artifactId>spring-ai-core</artifactId>
55+
<version>${project.parent.version}</version>
56+
</dependency>
57+
58+
<dependency>
59+
<groupId>org.springframework</groupId>
60+
<artifactId>spring-context-support</artifactId>
61+
</dependency>
62+
63+
<dependency>
64+
<groupId>org.slf4j</groupId>
65+
<artifactId>slf4j-api</artifactId>
66+
</dependency>
67+
68+
<dependency>
69+
<groupId>com.alibaba</groupId>
70+
<artifactId>dashscope-sdk-java</artifactId>
71+
<version>${dashscope.version}</version>
72+
<exclusions>
73+
<exclusion>
74+
<groupId>org.slf4j</groupId>
75+
<artifactId>slf4j-simple</artifactId>
76+
</exclusion>
77+
</exclusions>
78+
</dependency>
79+
80+
<!-- test dependencies -->
81+
<dependency>
82+
<groupId>org.springframework.ai</groupId>
83+
<artifactId>spring-ai-test</artifactId>
84+
<version>${project.version}</version>
85+
<scope>test</scope>
86+
</dependency>
87+
</dependencies>
88+
89+
</project>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
package org.springframework.ai.qwen;
2+
3+
import io.micrometer.observation.Observation;
4+
import io.micrometer.observation.ObservationRegistry;
5+
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
6+
import org.springframework.ai.chat.model.ChatModel;
7+
import org.springframework.ai.chat.model.ChatResponse;
8+
import org.springframework.ai.chat.model.MessageAggregator;
9+
import org.springframework.ai.chat.observation.ChatModelObservationContext;
10+
import org.springframework.ai.chat.observation.ChatModelObservationConvention;
11+
import org.springframework.ai.chat.observation.ChatModelObservationDocumentation;
12+
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
13+
import org.springframework.ai.chat.prompt.ChatOptions;
14+
import org.springframework.ai.chat.prompt.Prompt;
15+
import org.springframework.ai.model.ModelOptionsUtils;
16+
import org.springframework.ai.model.function.FunctionCallingOptions;
17+
import org.springframework.ai.model.tool.DefaultToolExecutionEligibilityPredicate;
18+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
19+
import org.springframework.ai.model.tool.ToolCallingManager;
20+
import org.springframework.ai.model.tool.ToolExecutionEligibilityPredicate;
21+
import org.springframework.ai.model.tool.ToolExecutionResult;
22+
import org.springframework.ai.observation.conventions.AiProvider;
23+
import org.springframework.ai.qwen.api.QwenApi;
24+
import org.springframework.ai.qwen.api.QwenModel;
25+
import org.springframework.util.Assert;
26+
import reactor.core.publisher.Flux;
27+
import reactor.core.scheduler.Schedulers;
28+
29+
import static org.springframework.ai.qwen.api.QwenApiHelper.getOrDefault;
30+
31+
public class QwenChatModel implements ChatModel {
32+
33+
private static final ChatModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultChatModelObservationConvention();
34+
35+
private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build();
36+
37+
private final QwenApi qwenApi;
38+
39+
private final QwenChatOptions defaultOptions;
40+
41+
private final ObservationRegistry observationRegistry;
42+
43+
private final ToolCallingManager toolCallingManager;
44+
45+
private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate;
46+
47+
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
48+
49+
public QwenChatModel(QwenApi openAiApi, QwenChatOptions defaultOptions, ToolCallingManager toolCallingManager,
50+
ObservationRegistry observationRegistry) {
51+
this(openAiApi, defaultOptions, toolCallingManager, observationRegistry,
52+
new DefaultToolExecutionEligibilityPredicate());
53+
}
54+
55+
public QwenChatModel(QwenApi qwenApi, QwenChatOptions defaultOptions, ToolCallingManager toolCallingManager,
56+
ObservationRegistry observationRegistry,
57+
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
58+
Assert.notNull(qwenApi, "qwenApi cannot be null");
59+
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
60+
Assert.notNull(observationRegistry, "observationRegistry cannot be null");
61+
Assert.notNull(toolExecutionEligibilityPredicate, "toolExecutionEligibilityPredicate cannot be null");
62+
this.qwenApi = qwenApi;
63+
this.defaultOptions = defaultOptions;
64+
this.toolCallingManager = getOrDefault(toolCallingManager, DEFAULT_TOOL_CALLING_MANAGER);
65+
this.observationRegistry = observationRegistry;
66+
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
67+
}
68+
69+
public static Builder builder() {
70+
return new Builder();
71+
}
72+
73+
@Override
74+
public ChatResponse call(Prompt prompt) {
75+
Prompt requestPrompt = buildRequestPrompt(prompt);
76+
return internalCall(requestPrompt, null);
77+
}
78+
79+
@Override
80+
public Flux<ChatResponse> stream(Prompt prompt) {
81+
Prompt requestPrompt = buildRequestPrompt(prompt);
82+
return this.internalStream(requestPrompt, null);
83+
}
84+
85+
@Override
86+
public ChatOptions getDefaultOptions() {
87+
return QwenChatOptions.fromOptions(this.defaultOptions);
88+
}
89+
90+
/**
91+
* Use the provided convention for reporting observation data
92+
* @param observationConvention The provided convention
93+
*/
94+
public void setObservationConvention(ChatModelObservationConvention observationConvention) {
95+
Assert.notNull(observationConvention, "observationConvention cannot be null");
96+
this.observationConvention = observationConvention;
97+
}
98+
99+
private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
100+
ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
101+
.prompt(prompt)
102+
.provider(AiProvider.ALIBABA.value())
103+
.requestOptions(prompt.getOptions())
104+
.build();
105+
106+
ChatResponse response = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION
107+
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
108+
this.observationRegistry)
109+
.observe(() -> {
110+
ChatResponse chatResponse = qwenApi.call(prompt, previousChatResponse);
111+
observationContext.setResponse(chatResponse);
112+
return chatResponse;
113+
});
114+
115+
if (toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
116+
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
117+
if (toolExecutionResult.returnDirect()) {
118+
// return tool execution result directly to the client
119+
return ChatResponse.builder()
120+
.from(response)
121+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
122+
.build();
123+
}
124+
else {
125+
// send the tool execution result back to the model
126+
return internalCall(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
127+
response);
128+
}
129+
}
130+
131+
return response;
132+
}
133+
134+
private Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
135+
return Flux.deferContextual(contextView -> {
136+
final ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
137+
.prompt(prompt)
138+
.provider(AiProvider.ALIBABA.value())
139+
.requestOptions(prompt.getOptions())
140+
.build();
141+
142+
Observation observation = ChatModelObservationDocumentation.CHAT_MODEL_OPERATION.observation(
143+
this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
144+
this.observationRegistry);
145+
146+
observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
147+
148+
// @formatter:off
149+
Flux<ChatResponse> chatResponse = this.qwenApi.streamCall(prompt, previousChatResponse)
150+
.flatMap(response -> {
151+
if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
152+
return Flux.defer(() -> {
153+
var toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response);
154+
if (toolExecutionResult.returnDirect()) {
155+
// return tool execution result directly to the client
156+
return Flux.just(ChatResponse.builder().from(response)
157+
.generations(ToolExecutionResult.buildGenerations(toolExecutionResult))
158+
.build());
159+
} else {
160+
// send the tool execution result back to the model.
161+
return this.internalStream(
162+
new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()),
163+
response);
164+
}
165+
}).subscribeOn(Schedulers.boundedElastic());
166+
}
167+
else {
168+
return Flux.just(response);
169+
}
170+
})
171+
.doOnError(observation::error)
172+
.doFinally(s -> observation.stop())
173+
.contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation));
174+
// @formatter:on
175+
176+
return new MessageAggregator().aggregate(chatResponse, observationContext::setResponse);
177+
});
178+
}
179+
180+
private Prompt buildRequestPrompt(Prompt prompt) {
181+
// process runtime options
182+
QwenChatOptions runtimeOptions = null;
183+
if (prompt.getOptions() != null) {
184+
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
185+
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
186+
QwenChatOptions.class);
187+
}
188+
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
189+
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
190+
QwenChatOptions.class);
191+
}
192+
else {
193+
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
194+
QwenChatOptions.class);
195+
}
196+
}
197+
198+
QwenChatOptions requestOptions = QwenChatOptions.fromOptions(this.defaultOptions).overrideWith(runtimeOptions);
199+
200+
ToolCallingChatOptions.validateToolCallbacks(requestOptions.getToolCallbacks());
201+
202+
return new Prompt(prompt.getInstructions(), requestOptions);
203+
}
204+
205+
public static final class Builder {
206+
207+
private QwenApi qwenApi;
208+
209+
private QwenChatOptions defaultOptions = QwenChatOptions.builder().model(QwenModel.QWEN_MAX.getName()).build();
210+
211+
private ToolCallingManager toolCallingManager;
212+
213+
private ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate = new DefaultToolExecutionEligibilityPredicate();
214+
215+
private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
216+
217+
private Builder() {
218+
}
219+
220+
public Builder qwenApi(QwenApi qwenApi) {
221+
this.qwenApi = qwenApi;
222+
return this;
223+
}
224+
225+
public Builder defaultOptions(QwenChatOptions defaultOptions) {
226+
this.defaultOptions = defaultOptions;
227+
return this;
228+
}
229+
230+
public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
231+
this.toolCallingManager = toolCallingManager;
232+
return this;
233+
}
234+
235+
public Builder toolExecutionEligibilityPredicate(
236+
ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate) {
237+
this.toolExecutionEligibilityPredicate = toolExecutionEligibilityPredicate;
238+
return this;
239+
}
240+
241+
public Builder observationRegistry(ObservationRegistry observationRegistry) {
242+
this.observationRegistry = observationRegistry;
243+
return this;
244+
}
245+
246+
public QwenChatModel build() {
247+
return new QwenChatModel(this.qwenApi, this.defaultOptions, this.toolCallingManager,
248+
this.observationRegistry, this.toolExecutionEligibilityPredicate);
249+
}
250+
251+
}
252+
253+
}

0 commit comments

Comments
 (0)