diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index ff4a843..7fd3c77 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -5,6 +5,19 @@ local Utils = require("ogpt.utils") local Api = {} +function Api.get_provider() + local provider + if type(Config.options.default_provider) == "string" then + provider = require("ogpt.provider." .. Config.options.default_provider) + else + provider = require("ogpt.provider." .. Config.options.default_provider.name) + provider.envs = vim.tbl_extend("force", provider.envs, Config.options.default_provider) + end + local envs = provider.load_envs() + for key, value in pairs(envs) do Api[key] = value end + return provider +end + function Api.completions(custom_params, cb) local params = vim.tbl_extend("keep", custom_params, Config.options.api_params) params.stream = false @@ -14,11 +27,26 @@ end function Api.chat_completions(custom_params, cb, should_stop, opts) local params = vim.tbl_extend("keep", custom_params, Config.options.api_params) local stream = params.stream or false + local _model = params.model + + local _completion_url = Api.CHAT_COMPLETIONS_URL + if type(_model) == "table" then + if _model.modify_url and type(_model.modify_url) == "function" then + _completion_url = _model.modify_url(_completion_url) + elseif _model.modify_url then + _completion_url = _model.modify_url + end + end + + if _model and _model.conform_fn then + params = _model.conform_fn(params) + else + params = Api.provider.conform(params) + end + local ctx = {} - -- add params before conform ctx.params = params if stream then - params = Utils.conform_to_ollama(params) local raw_chunks = "" local state = "START" @@ -30,7 +58,7 @@ function Api.chat_completions(custom_params, cb, should_stop, opts) "--silent", "--show-error", "--no-buffer", - Api.CHAT_COMPLETIONS_URL, + _completion_url, "-H", "Content-Type: application/json", "-H", @@ -39,33 +67,18 @@ function Api.chat_completions(custom_params, cb, should_stop, opts) vim.json.encode(params), }, function(chunk) - local process_line = 
function(_ok, _json) - if _json and _json.done then - ctx.context = _json.context - cb(raw_chunks, "END", ctx) - else - if _ok and not vim.tbl_isempty(_json) then - if _json and _json.message then - cb(_json.message.content, state, ctx) - raw_chunks = raw_chunks .. _json.message.content - state = "CONTINUE" - end - end - end - end - local ok, json = pcall(vim.json.decode, chunk) if ok and json ~= nil then if json.error ~= nil then cb(json.error, "ERROR", ctx) return end - process_line(ok, json) + ctx, raw_chunks, state = Api.provider.process_line(ok, json, ctx, raw_chunks, state, cb) else for line in chunk:gmatch("[^\n]+") do - local raw_json = string.gsub(line, "^data: ", "") + local raw_json = string.gsub(line, "^data:", "") local _ok, _json = pcall(vim.json.decode, raw_json) - process_line(_ok, _json) + ctx, raw_chunks, state = Api.provider.process_line(_ok, _json, ctx, raw_chunks, state, cb) end end end, @@ -93,8 +106,6 @@ function Api.edits(custom_params, cb) end function Api.make_call(url, params, cb) - params = Utils.conform_to_ollama(params) - TMP_MSG_FILENAME = os.tmpname() local f = io.open(TMP_MSG_FILENAME, "w+") if f == nil then @@ -274,23 +285,21 @@ local function ensureUrlProtocol(str) end function Api.setup() - loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value) - Api.OLLAMA_API_HOST = value - Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags") - Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate") - Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/chat") - end, "http://localhost:11434") - - loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value) - Api.OLLAMA_API_KEY = value - loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE") - if Api["OPENAI_API_TYPE"] == "azure" then - loadAzureConfigs() - Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY - else - Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. 
Api.OLLAMA_API_KEY - end - end, " ") + local provider = Api.get_provider() + Api.provider = provider + + -- loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", provider.make_url, "http://localhost:11434") + + -- loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value) + -- Api.OLLAMA_API_KEY = value + -- loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE") + -- if Api["OPENAI_API_TYPE"] == "azure" then + -- loadAzureConfigs() + -- Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY + -- else + -- Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY + -- end + -- end, " ") end function Api.exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index a5b0424..055c7d6 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -9,6 +9,11 @@ local M = {} function M.defaults() local defaults = { api_key_cmd = nil, + default_provider = { + name = "ollama", + api_host = os.getenv("OLLAMA_API_HOST"), + api_key = os.getenv("OLLAMA_API_KEY"), + }, yank_register = "+", edit_with_instructions = { diff = false, diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua index f5991ef..a2b00b9 100644 --- a/lua/ogpt/flows/chat/base.lua +++ b/lua/ogpt/flows/chat/base.lua @@ -282,78 +282,6 @@ function Chat:addAnswerPartial(text, state, ctx) end end --- function Chat:addAnswerPartial(text, state, ctx) --- if state == "ERROR" then --- return self:addAnswer(text, {}) --- end --- --- local start_line = 0 --- if self.selectedIndex > 0 then --- local prev = self.messages[self.selectedIndex] --- start_line = prev.end_line + (prev.type == ANSWER and 2 or 1) --- end --- --- if state == "END" then --- local usage = {} --- local idx = self.session:add_item({ --- type = ANSWER, --- text = text, --- ctx = ctx or {}, --- usage = usage, --- }) --- --- local lines = {} --- local nr_of_lines = 0 --- for line in string.gmatch(text, "[^\n]+") do --- 
nr_of_lines = nr_of_lines + 1 --- table.insert(lines, line) --- end --- --- local end_line = start_line + nr_of_lines - 1 --- table.insert(self.messages, { --- idx = idx, --- usage = usage or {}, --- type = ANSWER, --- text = text, --- lines = lines, --- nr_of_lines = nr_of_lines, --- start_line = start_line, --- end_line = end_line, --- context = ctx.context, --- }) --- self.selectedIndex = self.selectedIndex + 1 --- vim.api.nvim_buf_set_lines(self.chat_window.bufnr, -1, -1, false, { "", "" }) --- Signs.set_for_lines(self.chat_window.bufnr, start_line, end_line, "chat") --- end --- --- if state == "START" then --- self:stopSpinner() --- self:set_lines(-2, -1, false, { "" }) --- vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true) --- end --- --- if state == "START" or state == "CONTINUE" then --- local lines = vim.split(text, "\n", {}) --- local length = #lines --- local buffer = self.chat_window.bufnr --- local win = self.chat_window.winid --- --- for i, line in ipairs(lines) do --- local currentLine = vim.api.nvim_buf_get_lines(buffer, -2, -1, false)[1] --- vim.api.nvim_buf_set_lines(buffer, -2, -1, false, { currentLine .. line }) --- --- local last_line_num = vim.api.nvim_buf_line_count(buffer) --- Signs.set_for_lines(self.chat_window.bufnr, start_line, last_line_num - 1, "chat") --- if i == length and i > 1 then --- vim.api.nvim_buf_set_lines(buffer, -1, -1, false, { "" }) --- end --- if self:is_buf_visiable() then --- vim.api.nvim_win_set_cursor(win, { last_line_num, 0 }) --- end --- end --- end --- end - function Chat:get_total_tokens() local total_tokens = 0 for i = 1, #self.messages, 1 do @@ -567,16 +495,6 @@ function Chat:toMessages() return messages end -function Chat:toOllama(messages) - local output = "" - for _, entry in ipairs(messages) do - if entry.content then - output = output .. entry.role .. ": " .. entry.content .. 
"\n\n" - end - end - return output -end - function Chat:count() local count = 0 for _ in pairs(self.messages) do count = count + 1 end return count end diff --git a/lua/ogpt/provider/ollama.lua b/lua/ogpt/provider/ollama.lua new file mode 100644 index 0000000..e79d70c --- /dev/null +++ b/lua/ogpt/provider/ollama.lua @@ -0,0 +1,110 @@ +local utils = require("ogpt.utils") + +local M = {} + +M.envs = { + api_host = os.getenv("OGPT_API_HOST"), + api_key = os.getenv("OGPT_API_KEY"), +} + +function M.load_envs() + local _envs = {} + _envs.OLLAMA_API_HOST = M.envs.api_host or "http://localhost:11434" + _envs.OLLAMA_API_KEY = M.envs.api_key + _envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/tags") + _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/generate") + _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/chat") + _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OLLAMA_API_KEY or " ") + M.envs = vim.tbl_extend("force", M.envs, _envs) + return M.envs +end + +M.ollama_options = { + "num_keep", + "seed", + "num_predict", + "top_k", + "top_p", + "tfs_z", + "typical_p", + "repeat_last_n", + "temperature", + "repeat_penalty", + "presence_penalty", + "frequency_penalty", + "mirostat", + "mirostat_tau", + "mirostat_eta", + "penalize_newline", + "stop", + "numa", + "num_ctx", + "num_batch", + "num_gqa", + "num_gpu", + "main_gpu", + "low_vram", + "f16_kv", + "logits_all", + "vocab_only", + "use_mmap", + "use_mlock", + "embedding_only", + "rope_frequency_base", + "rope_frequency_scale", + "num_thread", +} + +function M.conform(params) + local ollama_parameters = { + "model", + -- "prompt", + "messages", + "format", + "options", + "system", + "template", + -- "context", + "stream", + "raw", + } + + -- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information + + local param_options = {} + + for key, value in pairs(params) do + if not vim.tbl_contains(ollama_parameters, key) then + if vim.tbl_contains(M.ollama_options, 
key) then + param_options[key] = value + params[key] = nil + else + params[key] = nil + end + end + end + local _options = vim.tbl_extend("keep", param_options, params.options or {}) + if next(_options) ~= nil then + params.options = _options + end + return params +end + +function M.process_line(_ok, _json, ctx, raw_chunks, state, cb) + if _json and _json.done then + ctx.context = _json.context + cb(raw_chunks, "END", ctx) + else + if _ok and not vim.tbl_isempty(_json) then + if _json and _json.message then + cb(_json.message.content, state, ctx) + raw_chunks = raw_chunks .. _json.message.content + state = "CONTINUE" + end + end + end + + return ctx, raw_chunks, state +end + +return M diff --git a/lua/ogpt/provider/textgenui.lua b/lua/ogpt/provider/textgenui.lua new file mode 100644 index 0000000..a21e43d --- /dev/null +++ b/lua/ogpt/provider/textgenui.lua @@ -0,0 +1,109 @@ +local utils = require("ogpt.utils") +local M = {} + +M.envs = { + api_host = os.getenv("OGPT_API_HOST"), + api_key = os.getenv("OGPT_API_KEY"), +} + +function M.load_envs() + local _envs = {} + _envs.MODELS_URL = utils.ensureUrlProtocol(M.envs.api_host .. "/api/tags") + _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(M.envs.api_host) + _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(M.envs.api_host) + _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. 
(M.envs.api_key or " ") + M.envs = vim.tbl_extend("force", M.envs, _envs) + return M.envs +end + +M.textgenui_options = { "seed", "top_k", "top_p", "stop" } + +function M.conform_to_textgenui_api(params) + local model_params = { + "seed", + "top_k", + "top_p", + "stop", + } + + local request_params = { + "inputs", + "parameters", + "stream", + } + + local param_options = {} + + for key, value in pairs(params) do + if not vim.tbl_contains(request_params, key) then + if vim.tbl_contains(model_params, key) then + param_options[key] = value + params[key] = nil + else + params[key] = nil + end + end + end + local _options = vim.tbl_extend("keep", param_options, params.options or {}) + if next(_options) ~= nil then + params.parameters = _options + end + return params +end +function M.update_messages(messages) + -- https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1 + local tokens = { + BOS = "", + EOS = "", + INST_START = "[INST]", + INST_END = "[/INST]", + } + local _input = { tokens.BOS } + for i, message in ipairs(messages) do + if i < #messages then -- Stop before the last item + if message.role == "user" then + table.insert(_input, tokens.INST_START) + table.insert(_input, message.content) + table.insert(_input, tokens.INST_END) + elseif message.role == "system" then + table.insert(_input, message.content) + if i == #messages - 1 then + table.insert(_input, tokens.EOS) + end + end + else + table.insert(_input, tokens.INST_START) + table.insert(_input, message.content) + table.insert(_input, tokens.INST_END) + end + end + local final_string = table.concat(_input, " ") + return final_string +end + +function M.conform(params) + params = params or {} + params.inputs = M.update_messages(params.messages or {}) + return M.conform_to_textgenui_api(params) +end + +function M.process_line(_ok, _json, ctx, raw_chunks, state, cb) + if not _ok then + return ctx, raw_chunks, state + end + if _json and _json.details and (_json.details ~= vim.NIL) and (_json.details.finished_reason == "eos_token") then + 
ctx.context = _json.context + cb(raw_chunks, "END", ctx) + else + if _ok and not vim.tbl_isempty(_json) then + if _json and _json.token then + cb(_json.token.text, state, ctx) + raw_chunks = raw_chunks .. _json.token.text + state = "CONTINUE" + end + end + end + return ctx, raw_chunks, state +end + +return M diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua index 25e5565..c7e8fa2 100644 --- a/lua/ogpt/utils.lua +++ b/lua/ogpt/utils.lua @@ -38,6 +38,13 @@ M.ollama_options = { "num_thread", } +M.textgenui_options = { + "seed", + "top_k", + "top_p", + "stop", +} + function M.split(text) local t = {} for str in string.gmatch(text, "%S+") do @@ -241,6 +248,63 @@ function M.conform_to_ollama(params) return M._conform_to_ollama_api(params) end +function M._conform_to_textgenui_api(params) + local model_params = { + "seed", + "top_k", + "top_p", + "stop", + } + + local request_params = { + "inputs", + "parameters", + "stream", + } + + local param_options = {} + + for key, value in pairs(params) do + if not vim.tbl_contains(request_params, key) then + if vim.tbl_contains(model_params, key) then + param_options[key] = value + params[key] = nil + else + params[key] = nil + end + end + end + local _options = vim.tbl_extend("keep", param_options, params.options or {}) + if next(_options) ~= nil then + params.parameters = _options + end + return params +end + +function M.conform_to_textgenui(params) + -- conform to mixtral + -- [INST] Instruction [/INST] Model answer [INST] Follow-up instruction [/INST] + if params.messages then + local messages = params.messages + params.messages = nil + -- params.system = params.system or "" + params.inputs = params.inputs or "" + -- for _, message in ipairs(messages) do + -- if message.role == "system" then + -- params.system = params.system .. "\n" .. message.content .. "\n" + -- end + -- end + + for _, message in ipairs(messages) do + if message.role == "user" then + params.inputs = params.inputs .. "\n" .. message.content .. 
"\n" + end + end + end + + return M._conform_to_textgenui_api(params) +end + function M.extract_code(text) -- Iterate through all code blocks in the message using a regular expression pattern local lastCodeBlock @@ -394,4 +458,33 @@ function M.escape_pattern(text) return text:gsub("([^%w])", "%%%1") end +function M.update_url_route(url, new_model) + local host = url:match("https?://([^/]+)") + local subdomain, domain, tld = host:match("([^.]+)%.([^.]+)%.([^.]+)") + local _new_url = url:gsub(host, new_model .. "." .. domain .. "." .. tld) + return _new_url +end + +function M.to_model_string(messages) + local output = "" + for _, entry in ipairs(messages) do + if entry.content then + output = output .. entry.role .. ": " .. entry.content .. "\n\n" + end + end + return output +end + +function M.startsWith(str, start) + return string.sub(str, 1, string.len(start)) == start +end + +function M.ensureUrlProtocol(str) + if M.startsWith(str, "https://") or M.startsWith(str, "http://") then + return str + end + + return "https://" .. str +end + return M