diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index 0b25c93..27594af 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -17,6 +17,8 @@ function Api.chat_completions(custom_params, cb, should_stop, opts) local ctx = {} -- add params before conform ctx.params = params + local _model = params.model + params.model = nil if stream then params = Utils.conform_to_ollama(params) local raw_chunks = "" @@ -30,7 +32,7 @@ function Api.chat_completions(custom_params, cb, should_stop, opts) "--silent", "--show-error", "--no-buffer", - Api.CHAT_COMPLETIONS_URL, + Utils.update_url_route(Api.CHAT_COMPLETIONS_URL, _model), "-H", "Content-Type: application/json", "-H", @@ -275,8 +277,8 @@ function Api.setup() loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value) Api.OLLAMA_API_HOST = value Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags") - Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate") - Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate") + Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/") + Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/") end, "http://localhost:11434") loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value) diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index 0eda46a..838b29a 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -165,16 +165,16 @@ function M.defaults() }, api_params = { - model = "mistral:7b", + model = "mixtral-8-7b-moe-instruct-tgi-predictor-ai-factory", temperature = 0.8, - top_p = 1, + top_p = 0.9, }, api_edit_params = { model = "mistral:7b", frequency_penalty = 0, presence_penalty = 0, temperature = 0.5, - top_p = 1, + top_p = 0.9, }, use_openai_functions_for_edits = false, actions = { diff --git a/lua/ogpt/flows/actions/chat/init.lua b/lua/ogpt/flows/actions/chat/init.lua index a74d423..dc0ac55 100644 --- a/lua/ogpt/flows/actions/chat/init.lua +++ b/lua/ogpt/flows/actions/chat/init.lua @@ -34,8 +34,9 @@ local STRATEGY_QUICK_FIX = "quick_fix" function ChatAction:init(opts) self.super:init(opts) self.params = opts.params or {} - self.system = opts.system or "" - self.template = opts.template or "{{input}}" + self.system = type(opts.system) == "function" and opts.system() or opts.system or "" + self.template = type(opts.template) == "function" and opts.template() or opts.template or "{{input}}" + -- self.instruction = type(opts.instruction) == "function" and opts.instruction() or opts.instruction or "" self.variables = opts.variables or {} self.strategy = opts.strategy or STRATEGY_APPEND self.ui = opts.ui or {} @@ -52,7 +53,11 @@ function ChatAction:render_template() data = vim.tbl_extend("force", {}, data, self.variables) local result = self.template for key, value in pairs(data) do - result = result:gsub("{{" .. key .. "}}", value) + if type(value) == "function" then + value = value() + end + local escaped_value = Utils.escape_pattern(value) + result = string.gsub(result, "{{" .. key .. "}}", escaped_value) end return result end @@ -66,7 +71,7 @@ function ChatAction:get_params() table.insert(messages, message) return vim.tbl_extend("force", Config.options.api_params, self.params, { messages = messages, - system = self.system, + system = self.system and "" == self.system and nil or self.system, }) end diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua index 378d192..d41a538 100644 --- a/lua/ogpt/utils.lua +++ b/lua/ogpt/utils.lua @@ -186,15 +186,15 @@ end function M._conform_to_ollama_api(params) local ollama_parameters = { - "model", - "prompt", - "format", - "options", - "system", - "template", - "context", + -- "model", + "inputs", + -- "format", + "parameters", + -- "system", + -- "template", + -- "context", "stream", - "raw", + -- "raw", } -- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information @@ -207,9 +207,9 @@ function M._conform_to_ollama_api(params) params[key] = nil end end - local _options = vim.tbl_extend("keep", param_options, params.options or {}) + local _options = vim.tbl_extend("keep", param_options, params.parameters or {}) if next(_options) ~= nil then - params.options = _options + params.parameters = _options end return params end @@ -218,8 +218,8 @@ function M.conform_to_ollama(params) if params.messages then local messages = params.messages params.messages = nil - params.system = params.system or "" - params.prompt = params.prompt or "" + -- params.system = params.system or "" + params.inputs = params.inputs or "" for _, message in ipairs(messages) do if message.role == "system" then params.system = params.system .. "\n" .. message.content .. "\n" @@ -228,7 +228,7 @@ function M.conform_to_ollama(params) for _, message in ipairs(messages) do if message.role == "user" then - params.prompt = params.prompt .. "\n" .. message.content .. "\n" + params.inputs = params.inputs .. "\n" .. message.content .. "\n" end end end @@ -277,4 +277,27 @@ function M.tableToString(tbl, indent) return str end +function M.update_url_route(url, new_route) + -- Extract the base URL parts + local base_url_parts = { url:match("^(https?://([^/]*))") } + local base_protocol = base_url_parts[1] + local base_domain = base_url_parts[2] + + -- Construct the new URL + local new_url = base_protocol .. "://" .. new_route .. "." .. base_domain + + -- Preserve any path, query string, or fragment from the original URL + local path_and_query = url:match("^https?://[^/]*(/.*)$") + if path_and_query then + new_url = new_url .. path_and_query + end + + return new_url +end + +function M.escape_pattern(text) + -- https://stackoverflow.com/a/34953646/4780010 + return text:gsub("([^%w])", "%%%1") +end + return M