Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ollama support #35

Merged
merged 4 commits into from
Apr 7, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
public enum ModelServiceEnum {
OPENAI("OpenAI", "OpenAI Service"),
LLAMA("LLaMA", "Code LLaMA (Locally)"),
AIGATEWAY("AIGateway", "AI Gateway");
AIGATEWAY("AIGateway", "AI Gateway"),
OLLAMA("Ollama", "Ollama Service");

// model name
private final String name;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package com.zhongan.devpilot.integrations.llms;

import com.google.gson.Gson;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not recommend to use gson, please use jackson.

import com.intellij.openapi.project.Project;
import com.zhongan.devpilot.gui.toolwindows.chat.DevPilotChatToolWindowService;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotChatCompletionRequest;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotChatCompletionResponse;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotMessage;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotSuccessStreamingResponse;
import com.zhongan.devpilot.util.JsonUtils;
import com.zhongan.devpilot.util.OkhttpUtils;
import com.zhongan.devpilot.webview.model.MessageModel;

Expand Down Expand Up @@ -47,7 +47,8 @@ public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Null
return;
}

var response = JsonUtils.fromJson(data, DevPilotSuccessStreamingResponse.class);
// var response = JsonUtils.fromJson(data, DevPilotSuccessStreamingResponse.class);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't originally intend to modify this, but the existing code is causing errors. Does it not produce any errors when you run the code ?

Unrecognized field "system_fingerprint" (class com.zhongan.devpilot.integrations.llms.entity.DevPilotSuccessStreamingResponse), not marked as ignorable

there's a property named system_fingerprint in data, but DevPilotSuccessStreamingResponse has no such property. how about add a system_fingerprint property to class DevPilotSuccessStreamingResponse?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can use such "@JsonIgnoreProperties(ignoreUnknown = true)" to ignore unused field

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I have already modified it according to your suggestion and pushed.

var response = new Gson().fromJson(data, DevPilotSuccessStreamingResponse.class);

if (response == null) {
interruptSend();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.zhongan.devpilot.enums.ModelServiceEnum;
import com.zhongan.devpilot.integrations.llms.aigateway.AIGatewayServiceProvider;
import com.zhongan.devpilot.integrations.llms.llama.LlamaServiceProvider;
import com.zhongan.devpilot.integrations.llms.ollama.OllamaServiceProvider;
import com.zhongan.devpilot.integrations.llms.openai.OpenAIServiceProvider;
import com.zhongan.devpilot.settings.state.DevPilotLlmSettingsState;

Expand All @@ -23,6 +24,8 @@ public LlmProvider getLlmProvider(Project project) {
return project.getService(LlamaServiceProvider.class);
case AIGATEWAY:
return project.getService(AIGatewayServiceProvider.class);
case OLLAMA:
return project.getService(OllamaServiceProvider.class);
}

return project.getService(AIGatewayServiceProvider.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package com.zhongan.devpilot.integrations.llms.ollama;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.project.Project;
import com.zhongan.devpilot.gui.toolwindows.chat.DevPilotChatToolWindowService;
import com.zhongan.devpilot.integrations.llms.LlmProvider;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotChatCompletionRequest;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotChatCompletionResponse;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotFailedResponse;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotMessage;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotSuccessResponse;
import com.zhongan.devpilot.settings.state.OllamaSettingsState;
import com.zhongan.devpilot.util.DevPilotMessageBundle;
import com.zhongan.devpilot.util.OkhttpUtils;
import com.zhongan.devpilot.util.UserAgentUtils;
import com.zhongan.devpilot.webview.model.MessageModel;

import java.io.IOException;
import java.util.Objects;
import java.util.function.Consumer;

import org.apache.commons.lang3.StringUtils;

import okhttp3.Call;
import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;

@Service(Service.Level.PROJECT)
public final class OllamaServiceProvider implements LlmProvider {

private final ObjectMapper objectMapper = new ObjectMapper();

private EventSource es;

private DevPilotChatToolWindowService toolWindowService;

private MessageModel resultModel = new MessageModel();

@Override
public String chatCompletion(Project project, DevPilotChatCompletionRequest chatCompletionRequest, Consumer<String> callback) {
var host = OllamaSettingsState.getInstance().getModelHost();
var modelName = OllamaSettingsState.getInstance().getModelName();
var service = project.getService(DevPilotChatToolWindowService.class);
this.toolWindowService = service;

if (StringUtils.isEmpty(host)) {
service.callErrorInfo("Chat completion failed: host is empty");
return "";
}

if (StringUtils.isEmpty(modelName)) {
service.callErrorInfo("Chat completion failed: ollama model name is empty");
return "";
}

if (host.endsWith("/")) {
host = host.substring(0, host.length() - 1);
}

chatCompletionRequest.setModel(modelName);

try {
var request = new Request.Builder()
.url(host + "/v1/chat/completions")
.header("User-Agent", UserAgentUtils.getUserAgent())
.post(RequestBody.create(objectMapper.writeValueAsString(chatCompletionRequest), MediaType.parse("application/json")))
.build();

this.es = this.buildEventSource(request, service, callback);
} catch (Exception e) {
service.callErrorInfo("Chat completion failed: " + e.getMessage());
return "";
}

return "";
}

@Override
public void interruptSend() {
if (es != null) {
es.cancel();
// remember the broken message
if (resultModel != null && !StringUtils.isEmpty(resultModel.getContent())) {
resultModel.setStreaming(false);
toolWindowService.addMessage(resultModel);
}

toolWindowService.callWebView();
// after interrupt, reset result model
resultModel = null;
}
}

@Override
public DevPilotChatCompletionResponse chatCompletionSync(DevPilotChatCompletionRequest chatCompletionRequest) {
var host = OllamaSettingsState.getInstance().getModelHost();
var modelName = OllamaSettingsState.getInstance().getModelName();

if (StringUtils.isEmpty(host)) {
return DevPilotChatCompletionResponse.failed("Chat completion failed: host is empty");
}

if (StringUtils.isEmpty(modelName)) {
return DevPilotChatCompletionResponse.failed("Chat completion failed: ollama model name is empty");
}

chatCompletionRequest.setModel(modelName);

okhttp3.Response response;

try {
var request = new Request.Builder()
.url(host + "/v1/chat/completions")
.header("User-Agent", UserAgentUtils.getUserAgent())
.post(RequestBody.create(objectMapper.writeValueAsString(chatCompletionRequest), MediaType.parse("application/json")))
.build();

Call call = OkhttpUtils.getClient().newCall(request);
response = call.execute();
} catch (Exception e) {
return DevPilotChatCompletionResponse.failed("Chat completion failed: " + e.getMessage());
}

try {
return parseResult(chatCompletionRequest, response);
} catch (IOException e) {
return DevPilotChatCompletionResponse.failed("Chat completion failed: " + e.getMessage());
}
}

private DevPilotChatCompletionResponse parseResult(DevPilotChatCompletionRequest chatCompletionRequest, okhttp3.Response response) throws IOException {
if (response == null) {
return DevPilotChatCompletionResponse.failed(DevPilotMessageBundle.get("devpilot.chatWindow.response.null"));
}

var result = Objects.requireNonNull(response.body()).string();

if (response.isSuccessful()) {
var message = objectMapper.readValue(result, DevPilotSuccessResponse.class)
.getChoices()
.get(0)
.getMessage();
var devPilotMessage = new DevPilotMessage();
devPilotMessage.setRole("assistant");
devPilotMessage.setContent(message.getContent());
chatCompletionRequest.getMessages().add(devPilotMessage);
return DevPilotChatCompletionResponse.success(message.getContent());

} else {
return DevPilotChatCompletionResponse.failed(objectMapper.readValue(result, DevPilotFailedResponse.class)
.getError()
.getMessage());
}
}

}
Loading
Loading