diff --git a/lua/ogpt/common/layout.lua b/lua/ogpt/common/layout.lua index 8409bdd..949b350 100644 --- a/lua/ogpt/common/layout.lua +++ b/lua/ogpt/common/layout.lua @@ -29,6 +29,16 @@ function Layout:init(boxes, options, box, edgy) end end +function Layout:mount_all(...) + if not self.edgy then + Layout.super.mount(self, ...) + else + for _, box in ipairs(self.boxes) do + box:mount(...) + end + end +end + function Layout:mount(...) if not self.edgy then Layout.super.mount(self, ...) diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index 61789f8..4670f7a 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -38,6 +38,22 @@ function M.defaults() top_p = 0.99, }, }, + openrouter = { + enabled = true, + model = "gpt-4", + api_host = os.getenv("OPENROUTER_API_HOST") or "https://openrouter.ai/api", + api_key = os.getenv("OPENROUTER_API_KEY") or "", + api_params = { + temperature = 0.5, + top_p = 0.99, + }, + api_chat_params = { + frequency_penalty = 0.8, + presence_penalty = 0.5, + temperature = 0.8, + top_p = 0.99, + }, + }, gemini = { enabled = true, api_host = os.getenv("GEMINI_API_HOST"), diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua index 0b9aa3f..00233e5 100644 --- a/lua/ogpt/flows/chat/base.lua +++ b/lua/ogpt/flows/chat/base.lua @@ -1040,6 +1040,7 @@ function Chat:set_keymaps() -- toggle system self:map(Config.options.chat.keymaps.toggle_system_role_open, function() if self.system_role_open and self.active_panel == self.system_role_panel then + self.system_role_panel:hide() self:set_active_panel(self.chat_input) end @@ -1048,6 +1049,7 @@ function Chat:set_keymaps() self:redraw() if self.system_role_open then + self.system_role_panel:show() self:set_active_panel(self.system_role_panel) end end) diff --git a/lua/ogpt/models.lua b/lua/ogpt/models.lua index 715dc4c..2a16236 100644 --- a/lua/ogpt/models.lua +++ b/lua/ogpt/models.lua @@ -58,7 +58,11 @@ local finder = function(provider, opts) process_single_model({ name = model }) end - if not job_started and provider:models_url() then + if job_started and job_completed then + process_complete() + end + + if provider:models_url() then local args = {} vim.list_extend(args, { provider:models_url() }) vim.list_extend(args, provider:request_headers()) @@ -79,8 +83,6 @@ local finder = function(provider, opts) job_started = true job:new(job_opts):start() - else - process_complete() end end, }) diff --git a/lua/ogpt/provider/base.lua b/lua/ogpt/provider/base.lua index aa807e8..3ef6af6 100644 --- a/lua/ogpt/provider/base.lua +++ b/lua/ogpt/provider/base.lua @@ -95,8 +95,10 @@ function Provider:models_url() end function Provider:request_headers(params) - local _model = params.model - local _conform_headers_fn = _model and _model.conform_headers_fn + local _conform_headers_fn + if vim.tbl_get(params or {}, "model", "conform_headers_fn") then + _conform_headers_fn = params.model.conform_headers_fn + end if _conform_headers_fn then params = _conform_headers_fn(self, params) diff --git a/lua/ogpt/provider/openrouter.lua b/lua/ogpt/provider/openrouter.lua new file mode 100644 index 0000000..0b486ca --- /dev/null +++ b/lua/ogpt/provider/openrouter.lua @@ -0,0 +1,132 @@ +local Config = require("ogpt.config") +local utils = require("ogpt.utils") +local ProviderBase = require("ogpt.provider.base") +local Response = require("ogpt.response") +local OpenRouter = ProviderBase:extend("openrouter") + +function OpenRouter:init(opts) + OpenRouter.super.init(self, opts) + self.name = "openrouter" + self.api_parameters = { + "messages", + "model", + + "stop", + "stream", + + "max_tokens", + "temperature", + "top_p", + "top_k", + "frequency_penalty", + "presence_penalty", + "repetition_penalty", + "seed", + } + self.api_chat_request_options = {} +end + +function OpenRouter:load_envs(override) + local _envs = {} + _envs.OPENROUTER_API_HOST = Config.options.providers.openrouter.api_host + or os.getenv("OPENROUTER_API_HOST") + or "https://api.openrouter.com" + _envs.OPENROUTER_API_KEY = Config.options.providers.openrouter.api_key or os.getenv("OPENROUTER_API_KEY") or "" + _envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OPENROUTER_API_HOST .. "/v1/models") + _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENROUTER_API_HOST .. "/v1/completions") + _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENROUTER_API_HOST .. "/v1/chat/completions") + _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OPENROUTER_API_KEY or " ") + self.envs = vim.tbl_extend("force", _envs, override or {}) + return self.envs +end + +function OpenRouter:parse_api_model_response(res, cb) + local response = table.concat(res, "\n") + local ok, json = pcall(vim.fn.json_decode, response) + if not ok then + vim.print("OGPT ERROR: something happened when trying request for models from " .. self:models_url()) + else + local data = json.data or {} + for _, model in ipairs(data) do + cb({ + name = model.id, + }) + end + end +end + +function OpenRouter:conform_request(params) + local _to_remove_system_idx = {} + for idx, message in ipairs(params.messages) do + if message.role == "system" then + table.insert(_to_remove_system_idx, idx) + end + end + + -- Remove elements from the list based on indices + for i = #_to_remove_system_idx, 1, -1 do + table.remove(params.messages, _to_remove_system_idx[i]) + end + + -- conform to support text only model + local messages = params.messages + local conformed_messages = {} + for _, message in ipairs(messages) do + table.insert(conformed_messages, { + role = message.role, + content = utils.gather_text_from_parts(message.content), + }) + end + + -- Insert the updated params.system string at the beginning of conformed_messages + if params.system then + table.insert(conformed_messages, 1, { + role = "system", + content = params.system, + }) + end + + params.messages = conformed_messages + + -- general clean up, remove things that shouldnt be here + for key, value in pairs(params) do + if not vim.tbl_contains(self.api_parameters, key) then + utils.log("Did not process " .. key .. " for " .. self.name) + params[key] = nil + end + end + + return params +end + +function OpenRouter:process_response(response) + local chunk = response:pop_chunk() + local raw_json = string.gsub(chunk, "^data:", "") + local _ok, _json = pcall(vim.json.decode, raw_json) + if _ok then + self:_process_line({ json = _json, raw = chunk }, response) + else + self:_process_line({ json = nil, raw = chunk }, response) + end +end + +function OpenRouter:_process_line(content, response) + local _json = content.json + local _raw = content.raw + if _json then + local text_delta = vim.tbl_get(_json, "choices", 1, "delta", "content") + local text = vim.tbl_get(_json, "choices", 1, "message", "content") + if text_delta then + response:add_processed_text(text_delta, "CONTINUE") + elseif text then + -- done + end + elseif not _json and string.find(_raw, "[DONE]") then + -- done + else + response:could_not_process(_raw) + utils.log("Could not process chunk for openrouter: " .. _raw, vim.log.levels.DEBUG) + end +end + +return OpenRouter