Skip to content

Commit

Permalink
base64 encode request (#64)
Browse files Browse the repository at this point in the history
* base64 encode request

* base64 doc

* handle immutable map

* avoid npe

* add encoding tag
  • Loading branch information
xiangtianyu authored Aug 12, 2024
1 parent 3410061 commit c71b93d
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 77 deletions.
3 changes: 2 additions & 1 deletion BUILD_PLUGIN.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
If you want to build your own complete plugin like DevPilot, there are some required condition:
1. AI gateway: support multi LLM model and provide api for plugin [Gateway repo](https://github.com/openpilot-hub/devpilot-gateway)
2. Auth System: support authorization check for login user (You can close it by setting `DefaultConst.AUTH_ON` to false)
3. Telemetry System: upload user behavior data for analysis (You can close it by setting `DefaultConst.TELEMETRY_ON` to false)
3. Telemetry System: upload user behavior data for analysis (You can close it by setting `DefaultConst.TELEMETRY_ON` to false)
4. Request Encoding: request will be encoded by base64 for security (You can close it by setting `DefaultConst.REQUEST_ENCODING_ON` to false))
3 changes: 2 additions & 1 deletion BUILD_PLUGIN_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
如果想构建一个完整的属于你自己的DevPilot应用,需要有如下几个条件:
1. AI网关:用于兼容不同的LLM模型,并提供API给插件使用 [网关仓库](https://github.com/openpilot-hub/devpilot-gateway)
2. 权限系统:用于校验插件用户的登录和使用权限(可以通过设置`DefaultConst.AUTH_ON`为false来关闭)
3. 指标系统:用于处理用户上报的使用数据用于分析(可以通过设置`DefaultConst.TELEMETRY_ON`为false来关闭)
3. 指标系统:用于处理用户上报的使用数据用于分析(可以通过设置`DefaultConst.TELEMETRY_ON`为false来关闭)
4. 请求编码:插件的请求内容会适度的编码为base64格式(可以通过设置`DefaultConst.REQUEST_ENCODING_ON`为false来关闭)
2 changes: 2 additions & 0 deletions src/main/java/com/zhongan/devpilot/constant/DefaultConst.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,6 @@ private DefaultConst() {
public static final boolean AUTH_ON = true;

public static final boolean TELEMETRY_ON = true;

public static final boolean REQUEST_ENCODING_ON = true;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,10 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.gson.Gson;
import com.intellij.lang.Language;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.project.ProjectUtil;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiDocumentManager;
import com.zhongan.devpilot.actions.notifications.DevPilotNotification;
import com.zhongan.devpilot.gui.toolwindows.chat.DevPilotChatToolWindowService;
import com.zhongan.devpilot.integrations.llms.LlmProvider;
Expand All @@ -27,14 +20,14 @@
import com.zhongan.devpilot.settings.state.LanguageSettingsState;
import com.zhongan.devpilot.util.DevPilotMessageBundle;
import com.zhongan.devpilot.util.EditorUtils;
import com.zhongan.devpilot.util.GatewayRequestUtils;
import com.zhongan.devpilot.util.GitUtil;
import com.zhongan.devpilot.util.LoginUtils;
import com.zhongan.devpilot.util.OkhttpUtils;
import com.zhongan.devpilot.util.UserAgentUtils;
import com.zhongan.devpilot.webview.model.MessageModel;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -50,7 +43,6 @@
import okhttp3.sse.EventSource;

import static com.zhongan.devpilot.constant.DefaultConst.AI_GATEWAY_INSTRUCT_COMPLETION;
import static com.zhongan.devpilot.util.VirtualFileUtil.getRelativeFilePath;

@Service(Service.Level.PROJECT)
public final class AIGatewayServiceProvider implements LlmProvider {
Expand Down Expand Up @@ -97,7 +89,7 @@ public String chatCompletion(Project project, DevPilotChatCompletionRequest chat
}
}
var request = requestBuilder
.post(RequestBody.create(objectMapper.writeValueAsString(chatCompletionRequest), MediaType.parse("application/json")))
.post(RequestBody.create(GatewayRequestUtils.chatRequestJson(chatCompletionRequest), MediaType.parse("application/json")))
.build();

DevPilotNotification.debug(LoginUtils.getLoginType() + "---" + UserAgentUtils.buildUserAgent());
Expand Down Expand Up @@ -152,7 +144,7 @@ public DevPilotChatCompletionResponse chatCompletionSync(DevPilotChatCompletionR
Response response;

try {
String requestBody = objectMapper.writeValueAsString(chatCompletionRequest);
String requestBody = GatewayRequestUtils.chatRequestJson(chatCompletionRequest);
DevPilotNotification.debug("Send Request :[" + requestBody + "].");

var request = new Request.Builder()
Expand Down Expand Up @@ -223,34 +215,9 @@ public DevPilotMessage instructCompletion(DevPilotInstructCompletionRequest inst
return null;
}

int offset = instructCompletionRequest.getOffset();
Editor editor = instructCompletionRequest.getEditor();
final Document[] document = new Document[1];
final Language[] language = new Language[1];
final VirtualFile[] virtualFile = new VirtualFile[1];
final String[] relativePath = new String[1];

ApplicationManager.getApplication().runReadAction(() -> {
document[0] = editor.getDocument();
language[0] = PsiDocumentManager.getInstance(editor.getProject()).getPsiFile(document[0]).getLanguage();
virtualFile[0] = FileDocumentManager.getInstance().getFile(document[0]);
relativePath[0] = getRelativeFilePath(editor.getProject(), virtualFile[0]);
});

String text = document[0].getText();

Map<String, String> map = new HashMap<>();
map.put("document", text);
map.put("position", String.valueOf(offset));
map.put("language", language[0].getID());
map.put("filePath", relativePath[0]);
map.put("completionType", instructCompletionRequest.getCompletionType());
ObjectMapper objectMapper = new ObjectMapper();

Response response;
String json;
String json = GatewayRequestUtils.completionRequestJson(instructCompletionRequest);
try {
json = objectMapper.writeValueAsString(map);
var request = new Request.Builder()
.url(host + AI_GATEWAY_INSTRUCT_COMPLETION)
.header("User-Agent", UserAgentUtils.buildUserAgent())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ public class DevPilotChatCompletionRequest {

String version = "V240801";

String encoding = null;

boolean stream = true;

List<DevPilotMessage> messages = new ArrayList<>();
Expand Down Expand Up @@ -35,4 +37,11 @@ public void setVersion(String version) {
this.version = version;
}

public String getEncoding() {
return encoding;
}

public void setEncoding(String encoding) {
this.encoding = encoding;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ public class DevPilotInstructCompletionRequest {

String completionType = "inline";

String encoding = null;

public Editor getEditor() {
return editor;
}
Expand Down Expand Up @@ -95,4 +97,12 @@ public String getCompletionType() {
public void setCompletionType(String completionType) {
this.completionType = completionType;
}

public String getEncoding() {
return encoding;
}

public void setEncoding(String encoding) {
this.encoding = encoding;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,9 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.gson.Gson;
import com.intellij.lang.Language;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.components.Service;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiDocumentManager;
import com.zhongan.devpilot.actions.notifications.DevPilotNotification;
import com.zhongan.devpilot.gui.toolwindows.chat.DevPilotChatToolWindowService;
import com.zhongan.devpilot.integrations.llms.LlmProvider;
Expand All @@ -24,14 +17,13 @@
import com.zhongan.devpilot.integrations.llms.entity.DevPilotSuccessResponse;
import com.zhongan.devpilot.settings.state.LanguageSettingsState;
import com.zhongan.devpilot.util.DevPilotMessageBundle;
import com.zhongan.devpilot.util.GatewayRequestUtils;
import com.zhongan.devpilot.util.LoginUtils;
import com.zhongan.devpilot.util.OkhttpUtils;
import com.zhongan.devpilot.util.UserAgentUtils;
import com.zhongan.devpilot.webview.model.MessageModel;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;

Expand All @@ -46,7 +38,6 @@

import static com.zhongan.devpilot.constant.DefaultConst.AI_GATEWAY_INSTRUCT_COMPLETION;
import static com.zhongan.devpilot.constant.DefaultConst.TRIAL_DEFAULT_HOST;
import static com.zhongan.devpilot.util.VirtualFileUtil.getRelativeFilePath;

@Service(Service.Level.PROJECT)
public final class TrialServiceProvider implements LlmProvider {
Expand Down Expand Up @@ -75,7 +66,7 @@ public String chatCompletion(Project project, DevPilotChatCompletionRequest chat
.url(TRIAL_DEFAULT_HOST + "/v1/chat/completions")
.header("User-Agent", UserAgentUtils.buildUserAgent())
.header("Auth-Type", "wx")
.post(RequestBody.create(objectMapper.writeValueAsString(chatCompletionRequest), MediaType.parse("application/json")))
.post(RequestBody.create(GatewayRequestUtils.chatRequestJson(chatCompletionRequest), MediaType.parse("application/json")))
.build();

this.es = this.buildEventSource(request, service, callback);
Expand All @@ -100,7 +91,7 @@ public DevPilotChatCompletionResponse chatCompletionSync(DevPilotChatCompletionR
.url(TRIAL_DEFAULT_HOST + "/v1/chat/completions")
.header("User-Agent", UserAgentUtils.buildUserAgent())
.header("Auth-Type", "wx")
.post(RequestBody.create(objectMapper.writeValueAsString(chatCompletionRequest), MediaType.parse("application/json")))
.post(RequestBody.create(GatewayRequestUtils.chatRequestJson(chatCompletionRequest), MediaType.parse("application/json")))
.build();

var call = OkhttpUtils.getClient().newCall(request);
Expand All @@ -123,34 +114,9 @@ public DevPilotMessage instructCompletion(DevPilotInstructCompletionRequest inst
return null;
}

int offset = instructCompletionRequest.getOffset();
Editor editor = instructCompletionRequest.getEditor();
final Document[] document = new Document[1];
final Language[] language = new Language[1];
final VirtualFile[] virtualFile = new VirtualFile[1];
final String[] relativePath = new String[1];

ApplicationManager.getApplication().runReadAction(() -> {
document[0] = editor.getDocument();
language[0] = PsiDocumentManager.getInstance(editor.getProject()).getPsiFile(document[0]).getLanguage();
virtualFile[0] = FileDocumentManager.getInstance().getFile(document[0]);
relativePath[0] = getRelativeFilePath(editor.getProject(), virtualFile[0]);
});

String text = document[0].getText();

Map<String, String> map = new HashMap<>();
map.put("document", text);
map.put("position", String.valueOf(offset));
map.put("language", language[0].getID());
map.put("filePath", relativePath[0]);
map.put("completionType", instructCompletionRequest.getCompletionType());
ObjectMapper objectMapper = new ObjectMapper();

Response response;
String json;
String json = GatewayRequestUtils.completionRequestJson(instructCompletionRequest);
try {
json = objectMapper.writeValueAsString(map);
var request = new Request.Builder()
.url(TRIAL_DEFAULT_HOST + AI_GATEWAY_INSTRUCT_COMPLETION)
.header("User-Agent", UserAgentUtils.buildUserAgent())
Expand Down
13 changes: 13 additions & 0 deletions src/main/java/com/zhongan/devpilot/util/Base64Utils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.zhongan.devpilot.util;

import java.util.Base64;

public class Base64Utils {
public static String base64Encoding(String code) {
return Base64.getEncoder().encodeToString(code.getBytes());
}

public static String base64Decoding(String code) {
return new String(Base64.getDecoder().decode(code));
}
}
82 changes: 82 additions & 0 deletions src/main/java/com/zhongan/devpilot/util/GatewayRequestUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.zhongan.devpilot.util;

import com.intellij.lang.Language;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.fileEditor.FileDocumentManager;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiDocumentManager;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotChatCompletionRequest;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotInstructCompletionRequest;
import com.zhongan.devpilot.integrations.llms.entity.DevPilotMessage;

import java.util.HashMap;
import java.util.Map;

import static com.zhongan.devpilot.constant.DefaultConst.REQUEST_ENCODING_ON;
import static com.zhongan.devpilot.util.VirtualFileUtil.getRelativeFilePath;

public class GatewayRequestUtils {
public static String completionRequestJson(DevPilotInstructCompletionRequest instructCompletionRequest) {
int offset = instructCompletionRequest.getOffset();
Editor editor = instructCompletionRequest.getEditor();
final Document[] document = new Document[1];
final Language[] language = new Language[1];
final VirtualFile[] virtualFile = new VirtualFile[1];
final String[] relativePath = new String[1];

ApplicationManager.getApplication().runReadAction(() -> {
document[0] = editor.getDocument();
language[0] = PsiDocumentManager.getInstance(editor.getProject()).getPsiFile(document[0]).getLanguage();
virtualFile[0] = FileDocumentManager.getInstance().getFile(document[0]);
relativePath[0] = getRelativeFilePath(editor.getProject(), virtualFile[0]);
});

String text = document[0].getText();

if (isRequestEncoding()) {
instructCompletionRequest.setEncoding("base64");
text = Base64Utils.base64Encoding(text);
}

Map<String, String> map = new HashMap<>();
map.put("document", text);
map.put("position", String.valueOf(offset));
map.put("language", language[0].getID());
map.put("filePath", relativePath[0]);
map.put("completionType", instructCompletionRequest.getCompletionType());
map.put("encoding", instructCompletionRequest.getEncoding());

return JsonUtils.toJson(map);
}

public static String chatRequestJson(DevPilotChatCompletionRequest chatCompletionRequest) {
if (isRequestEncoding()) {
chatCompletionRequest.setEncoding("base64");

var messageList = chatCompletionRequest.getMessages();

for (DevPilotMessage devPilotMessage : messageList) {
if (devPilotMessage.getContent() != null) {
devPilotMessage.setContent(Base64Utils.base64Encoding(devPilotMessage.getContent()));
}

// avoid immutable map
if (devPilotMessage.getPromptData() != null) {
var promptData = new HashMap<>(devPilotMessage.getPromptData());
for (Map.Entry<String, String> entry : promptData.entrySet()) {
entry.setValue(Base64Utils.base64Encoding(entry.getValue()));
}
devPilotMessage.setPromptData(promptData);
}
}
}

return JsonUtils.toJson(chatCompletionRequest);
}

private static boolean isRequestEncoding() {
return REQUEST_ENCODING_ON;
}
}

0 comments on commit c71b93d

Please sign in to comment.