Skip to content

Commit

Permalink
添加自动菜单,解决解释sql时没有数据库名的bug,解决解释sql时openai超时的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
hejianjun committed Feb 1, 2024
1 parent 462efdb commit e503072
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,17 @@ const ChatInput = (props: IProps) => {
};

const renderSelectTable = () => {
const { tables, onSelectTableSyncModel, selectedTables, onSelectTables } = props;
const { tables, onSelectTableSyncModel, selectedTables, onSelectTables,syncTableModel } = props;
const options = (tables || []).map((t) => ({ value: t, label: t }));
return (
<div className={styles.aiSelectedTable}>
<Radio.Group
onChange={(v) => onSelectTableSyncModel(v.target.value)}
// value={syncTableModel}
value={SyncModelType.MANUAL}
value={syncTableModel}
style={{ marginBottom: '8px' }}
>
<Space direction="horizontal">
{/* <Radio value={SyncModelType.AUTO}>自动</Radio> */}
<Radio value={SyncModelType.AUTO}>自动</Radio>
<Radio value={SyncModelType.MANUAL}>手动</Radio>
</Space>
</Radio.Group>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ const SelectBoundInfo = memo((props: IProps) => {
boundInfo.databaseName,
boundInfo.schemaName,
);
setSelectedTables(tableNameListTemp.slice(0, 1));
//setSelectedTables(tableNameListTemp.slice(0, 1));
}
}, [allTableList, isActive]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ private SseEmitter chatWithOpenAi(ChatQueryRequest queryRequest, SseEmitter sseE
messages.add(currentMessage);
buildSseEmitter(sseEmitter, uid);
ConnectInfo connectInfo = Chat2DBContext.getConnectInfo();
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, messages, connectInfo);
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter, messages, connectInfo, queryRequest);
ToolsFunction function = ToolsFunction.builder()
.name("get_table_columns")
.description("获取指定表的字段名,类型")
Expand Down Expand Up @@ -799,8 +799,16 @@ public String queryDatabaseSchema(ChatQueryRequest queryRequest) {
*/
public String queryDatabaseSchema2(ChatQueryRequest queryRequest) {
MetaData metaSchema = Chat2DBContext.getMetaData();
List<Table> tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null);
return tables.stream().map(Table::getName).collect(Collectors.joining(","));
try {
List<Table> tables = metaSchema.tables(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), null);
return tables.stream()
.map(table -> StringUtils.isBlank(table.getComment()) ? table.getName()
: table.getName() + "(" + table.getComment() + ")")
.collect(Collectors.joining(","));
} catch (Exception e) {
log.error("query table error:{}, do nothing", e.getMessage());
return "";
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

import ai.chat2db.server.domain.api.model.Config;
import ai.chat2db.server.domain.api.service.ConfigService;
Expand Down Expand Up @@ -93,7 +94,17 @@ public static void refresh() {
log.info("refresh openai apikey:{}", maskApiKey(apikey));
if (Objects.nonNull(host) && Objects.nonNull(port)) {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(host, port));
OkHttpClient okHttpClient = new OkHttpClient.Builder().proxy(proxy).build();
OkHttpClient okHttpClient = new OkHttpClient.Builder()
// 设置连接超时为10秒
.connectTimeout(10, TimeUnit.SECONDS)
// 设置读取超时为30秒
.readTimeout(30, TimeUnit.SECONDS)
// 设置写入超时为15秒
.writeTimeout(15, TimeUnit.SECONDS)
// 设置整个调用的超时为1分钟
.callTimeout(1, TimeUnit.MINUTES)
.proxy(proxy)
.build();
OPEN_AI_STREAM_CLIENT = OpenAiStreamClient.builder().apiHost(apiHost).apiKey(
Lists.newArrayList(apikey)).okHttpClient(okHttpClient).build();
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.chat2db.server.web.api.controller.ai.openai.listener;

import ai.chat2db.server.web.api.controller.ai.openai.client.OpenAIClient;
import ai.chat2db.server.web.api.controller.ai.request.ChatQueryRequest;
import ai.chat2db.server.web.api.controller.ai.response.ChatCompletionResponse;
import ai.chat2db.spi.MetaData;
import ai.chat2db.spi.sql.Chat2DBContext;
Expand Down Expand Up @@ -39,13 +40,16 @@ public class OpenAIEventSourceListener extends EventSourceListener {

private final ConnectInfo connectInfo;

private final ChatQueryRequest queryRequest;

private List<ToolCalls> toolCalls = new ArrayList<>();


public OpenAIEventSourceListener(SseEmitter sseEmitter, List<Message> messages, ConnectInfo connectInfo) {
public OpenAIEventSourceListener(SseEmitter sseEmitter, List<Message> messages, ConnectInfo connectInfo, ChatQueryRequest queryRequest) {
this.sseEmitter = sseEmitter;
this.messages = messages;
this.connectInfo = connectInfo;
this.queryRequest = queryRequest;
}

public static List<ToolCalls> mergeToolCallsLists(List<ToolCalls> list1, List<ToolCalls> list2) {
Expand Down Expand Up @@ -142,7 +146,7 @@ public void onEvent(EventSource eventSource, String id, String type, String data
JSONObject arguments = JSONObject.parse(function.getArguments());
if ("get_table_columns".equals(functionName)) {
MetaData metaSchema = Chat2DBContext.getMetaData();
String ddl = metaSchema.tableDDL(Chat2DBContext.getConnection(), connectInfo.getDatabaseName(), connectInfo.getSchemaName(), arguments.getString("table_name"));
String ddl = metaSchema.tableDDL(Chat2DBContext.getConnection(), queryRequest.getDatabaseName(), queryRequest.getSchemaName(), arguments.getString("table_name"));
messages.add(Message.builder().role(BaseMessage.Role.TOOL)
.toolCallId(callId)
.name(functionName)
Expand Down

0 comments on commit e503072

Please sign in to comment.