diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua
index ff4a843..7fd3c77 100644
--- a/lua/ogpt/api.lua
+++ b/lua/ogpt/api.lua
@@ -5,6 +5,19 @@ local Utils = require("ogpt.utils")
local Api = {}
+function Api.get_provider()
+ local provider
+ if type(Config.options.default_provider) == "string" then
+ provider = require("ogpt.provider." .. Config.options.default_provider)
+ else
+ provider = require("ogpt.provider." .. Config.options.default_provider.name)
+ provider.envs = vim.tbl_extend("force", provider.envs, Config.options.default_provider)
+ end
+ local envs = provider.load_envs()
+ Api = vim.tbl_extend("force", Api, envs)
+ return provider
+end
+
function Api.completions(custom_params, cb)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
params.stream = false
@@ -14,11 +27,26 @@ end
function Api.chat_completions(custom_params, cb, should_stop, opts)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
local stream = params.stream or false
+ local _model = params.model
+
+ local _completion_url = Api.CHAT_COMPLETIONS_URL
+  if type(_model) == "table" and _model.modify_url then
+    if type(_model.modify_url) == "function" then
+      _completion_url = _model.modify_url(_completion_url)
+    else
+      _completion_url = _model.modify_url
+    end
+  end
+
+ if _model and _model.conform_fn then
+ params = _model.conform_fn(params)
+ else
+ params = Api.provider.conform(params)
+ end
+
local ctx = {}
- -- add params before conform
ctx.params = params
if stream then
- params = Utils.conform_to_ollama(params)
local raw_chunks = ""
local state = "START"
@@ -30,7 +58,7 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
"--silent",
"--show-error",
"--no-buffer",
- Api.CHAT_COMPLETIONS_URL,
+ _completion_url,
"-H",
"Content-Type: application/json",
"-H",
@@ -39,33 +67,18 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
vim.json.encode(params),
},
function(chunk)
- local process_line = function(_ok, _json)
- if _json and _json.done then
- ctx.context = _json.context
- cb(raw_chunks, "END", ctx)
- else
- if _ok and not vim.tbl_isempty(_json) then
- if _json and _json.message then
- cb(_json.message.content, state, ctx)
- raw_chunks = raw_chunks .. _json.message.content
- state = "CONTINUE"
- end
- end
- end
- end
-
local ok, json = pcall(vim.json.decode, chunk)
if ok and json ~= nil then
if json.error ~= nil then
cb(json.error, "ERROR", ctx)
return
end
- process_line(ok, json)
+          ctx, raw_chunks, state = Api.provider.process_line(ok, json, ctx, raw_chunks, state, cb)
else
for line in chunk:gmatch("[^\n]+") do
- local raw_json = string.gsub(line, "^data: ", "")
+ local raw_json = string.gsub(line, "^data:", "")
local _ok, _json = pcall(vim.json.decode, raw_json)
- process_line(_ok, _json)
+ ctx, raw_chunks, state = Api.provider.process_line(_ok, _json, ctx, raw_chunks, state, cb)
end
end
end,
@@ -93,8 +106,6 @@ function Api.edits(custom_params, cb)
end
function Api.make_call(url, params, cb)
- params = Utils.conform_to_ollama(params)
-
TMP_MSG_FILENAME = os.tmpname()
local f = io.open(TMP_MSG_FILENAME, "w+")
if f == nil then
@@ -274,23 +285,21 @@ local function ensureUrlProtocol(str)
end
function Api.setup()
- loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value)
- Api.OLLAMA_API_HOST = value
- Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags")
- Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate")
- Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/chat")
- end, "http://localhost:11434")
-
- loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
- Api.OLLAMA_API_KEY = value
- loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE")
- if Api["OPENAI_API_TYPE"] == "azure" then
- loadAzureConfigs()
- Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY
- else
- Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY
- end
- end, " ")
+ local provider = Api.get_provider()
+ Api.provider = provider
+
+ -- loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", provider.make_url, "http://localhost:11434")
+
+ -- loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
+ -- Api.OLLAMA_API_KEY = value
+ -- loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE")
+ -- if Api["OPENAI_API_TYPE"] == "azure" then
+ -- loadAzureConfigs()
+ -- Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY
+ -- else
+ -- Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY
+ -- end
+ -- end, " ")
end
function Api.exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua
index a5b0424..055c7d6 100644
--- a/lua/ogpt/config.lua
+++ b/lua/ogpt/config.lua
@@ -9,6 +9,11 @@ local M = {}
function M.defaults()
local defaults = {
api_key_cmd = nil,
+ default_provider = {
+ name = "ollama",
+ api_host = os.getenv("OLLAMA_API_HOST"),
+ api_key = os.getenv("OLLAMA_API_KEY"),
+ },
yank_register = "+",
edit_with_instructions = {
diff = false,
diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua
index f5991ef..a2b00b9 100644
--- a/lua/ogpt/flows/chat/base.lua
+++ b/lua/ogpt/flows/chat/base.lua
@@ -282,78 +282,6 @@ function Chat:addAnswerPartial(text, state, ctx)
end
end
--- function Chat:addAnswerPartial(text, state, ctx)
--- if state == "ERROR" then
--- return self:addAnswer(text, {})
--- end
---
--- local start_line = 0
--- if self.selectedIndex > 0 then
--- local prev = self.messages[self.selectedIndex]
--- start_line = prev.end_line + (prev.type == ANSWER and 2 or 1)
--- end
---
--- if state == "END" then
--- local usage = {}
--- local idx = self.session:add_item({
--- type = ANSWER,
--- text = text,
--- ctx = ctx or {},
--- usage = usage,
--- })
---
--- local lines = {}
--- local nr_of_lines = 0
--- for line in string.gmatch(text, "[^\n]+") do
--- nr_of_lines = nr_of_lines + 1
--- table.insert(lines, line)
--- end
---
--- local end_line = start_line + nr_of_lines - 1
--- table.insert(self.messages, {
--- idx = idx,
--- usage = usage or {},
--- type = ANSWER,
--- text = text,
--- lines = lines,
--- nr_of_lines = nr_of_lines,
--- start_line = start_line,
--- end_line = end_line,
--- context = ctx.context,
--- })
--- self.selectedIndex = self.selectedIndex + 1
--- vim.api.nvim_buf_set_lines(self.chat_window.bufnr, -1, -1, false, { "", "" })
--- Signs.set_for_lines(self.chat_window.bufnr, start_line, end_line, "chat")
--- end
---
--- if state == "START" then
--- self:stopSpinner()
--- self:set_lines(-2, -1, false, { "" })
--- vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true)
--- end
---
--- if state == "START" or state == "CONTINUE" then
--- local lines = vim.split(text, "\n", {})
--- local length = #lines
--- local buffer = self.chat_window.bufnr
--- local win = self.chat_window.winid
---
--- for i, line in ipairs(lines) do
--- local currentLine = vim.api.nvim_buf_get_lines(buffer, -2, -1, false)[1]
--- vim.api.nvim_buf_set_lines(buffer, -2, -1, false, { currentLine .. line })
---
--- local last_line_num = vim.api.nvim_buf_line_count(buffer)
--- Signs.set_for_lines(self.chat_window.bufnr, start_line, last_line_num - 1, "chat")
--- if i == length and i > 1 then
--- vim.api.nvim_buf_set_lines(buffer, -1, -1, false, { "" })
--- end
--- if self:is_buf_visiable() then
--- vim.api.nvim_win_set_cursor(win, { last_line_num, 0 })
--- end
--- end
--- end
--- end
-
function Chat:get_total_tokens()
local total_tokens = 0
for i = 1, #self.messages, 1 do
@@ -567,16 +495,6 @@ function Chat:toMessages()
return messages
end
-function Chat:toOllama(messages)
- local output = ""
- for _, entry in ipairs(messages) do
- if entry.content then
- output = output .. entry.role .. ": " .. entry.content .. "\n\n"
- end
- end
- return output
-end
-
function Chat:count()
local count = 0
for _ in pairs(self.messages) do
diff --git a/lua/ogpt/provider/ollama.lua b/lua/ogpt/provider/ollama.lua
new file mode 100644
index 0000000..e79d70c
--- /dev/null
+++ b/lua/ogpt/provider/ollama.lua
@@ -0,0 +1,110 @@
+local utils = require("ogpt.utils")
+
+local M = {}
+
+M.envs = {
+ api_host = os.getenv("OGPT_API_HOST"),
+ api_key = os.getenv("OGPT_API_KEY"),
+}
+
+function M.load_envs()
+ local _envs = {}
+  _envs.OLLAMA_API_HOST = M.envs.api_host or "http://localhost:11434"
+ _envs.OLLAMA_API_KEY = M.envs.api_key
+ _envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/tags")
+ _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/generate")
+ _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/chat")
+ _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OLLAMA_API_KEY or " ")
+ M.envs = vim.tbl_extend("force", M.envs, _envs)
+ return M.envs
+end
+
+M.ollama_options = {
+ "num_keep",
+ "seed",
+ "num_predict",
+ "top_k",
+ "top_p",
+ "tfs_z",
+ "typical_p",
+ "repeat_last_n",
+ "temperature",
+ "repeat_penalty",
+ "presence_penalty",
+ "frequency_penalty",
+ "mirostat",
+ "mirostat_tau",
+ "mirostat_eta",
+ "penalize_newline",
+ "stop",
+ "numa",
+ "num_ctx",
+ "num_batch",
+ "num_gqa",
+ "num_gpu",
+ "main_gpu",
+ "low_vram",
+ "f16_kv",
+ "logits_all",
+ "vocab_only",
+ "use_mmap",
+ "use_mlock",
+ "embedding_only",
+ "rope_frequency_base",
+ "rope_frequency_scale",
+ "num_thread",
+}
+
+function M.conform(params)
+ local ollama_parameters = {
+ "model",
+ -- "prompt",
+ "messages",
+ "format",
+ "options",
+ "system",
+ "template",
+ -- "context",
+ "stream",
+ "raw",
+ }
+
+ -- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information
+
+ local param_options = {}
+
+ for key, value in pairs(params) do
+ if not vim.tbl_contains(ollama_parameters, key) then
+ if vim.tbl_contains(M.ollama_options, key) then
+ param_options[key] = value
+ params[key] = nil
+ else
+ params[key] = nil
+ end
+ end
+ end
+ local _options = vim.tbl_extend("keep", param_options, params.options or {})
+ if next(_options) ~= nil then
+ params.options = _options
+ end
+ return params
+end
+
+function M.process_line(_ok, _json, ctx, raw_chunks, state, cb)
+ if _json and _json.done then
+ ctx.context = _json.context
+ cb(raw_chunks, "END", ctx)
+ else
+ if _ok and not vim.tbl_isempty(_json) then
+ if _json and _json.message then
+ cb(_json.message.content, state, ctx)
+ raw_chunks = raw_chunks .. _json.message.content
+ state = "CONTINUE"
+ end
+ end
+ end
+
+ return ctx, raw_chunks, state
+end
+
+return M
diff --git a/lua/ogpt/provider/textgenui.lua b/lua/ogpt/provider/textgenui.lua
new file mode 100644
index 0000000..a21e43d
--- /dev/null
+++ b/lua/ogpt/provider/textgenui.lua
@@ -0,0 +1,109 @@
+local utils = require("ogpt.utils")
+local M = {}
+
+M.envs = {
+ api_host = os.getenv("OGPT_API_HOST"),
+ api_key = os.getenv("OGPT_API_KEY"),
+}
+
+function M.load_envs()
+ local _envs = {}
+ _envs.MODELS_URL = utils.ensureUrlProtocol(M.envs.api_host .. "/api/tags")
+ _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(M.envs.api_host)
+ _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(M.envs.api_host)
+ _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (M.envs.api_key or " ")
+ M.envs = vim.tbl_extend("force", M.envs, _envs)
+ return M.envs
+end
+
+M.textgenui_options = { "seed", "top_k", "top_p", "stop" }
+
+function M.conform_to_textgenui_api(params)
+ local model_params = {
+ "seed",
+ "top_k",
+ "top_p",
+ "stop",
+ }
+
+ local request_params = {
+ "inputs",
+ "parameters",
+ "stream",
+ }
+
+ local param_options = {}
+
+ for key, value in pairs(params) do
+ if not vim.tbl_contains(request_params, key) then
+ if vim.tbl_contains(model_params, key) then
+ param_options[key] = value
+ params[key] = nil
+ else
+ params[key] = nil
+ end
+ end
+ end
+ local _options = vim.tbl_extend("keep", param_options, params.options or {})
+ if next(_options) ~= nil then
+ params.parameters = _options
+ end
+ return params
+end
+function M.update_messages(messages)
+ -- https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
+ local tokens = {
+ BOS = "",
+ EOS = "",
+ INST_START = "[INST]",
+ INST_END = "[/INST]",
+ }
+ local _input = { tokens.BOS }
+ for i, message in ipairs(messages) do
+ if i < #messages then -- Stop before the last item
+ if message.role == "user" then
+ table.insert(_input, tokens.INST_START)
+ table.insert(_input, message.content)
+ table.insert(_input, tokens.INST_END)
+ elseif message.role == "system" then
+ table.insert(_input, message.content)
+        if i == #messages - 1 then
+ table.insert(_input, tokens.EOS)
+ end
+ end
+ else
+ table.insert(_input, tokens.INST_START)
+ table.insert(_input, message.content)
+ table.insert(_input, tokens.INST_END)
+ end
+ end
+ local final_string = table.concat(_input, " ")
+ return final_string
+end
+
+function M.conform(params)
+ params = params or {}
+ params.inputs = M.update_messages(params.messages or {})
+ return M.conform_to_textgenui_api(params)
+end
+
+function M.process_line(_ok, _json, ctx, raw_chunks, state, cb)
+ if not _ok then
+    return ctx, raw_chunks, state
+ end
+  if _json and _json.details and (_json.details ~= vim.NIL) and (_json.details.finished_reason == "eos_token") then
+ ctx.context = _json.context
+ cb(raw_chunks, "END", ctx)
+ else
+ if _ok and not vim.tbl_isempty(_json) then
+ if _json and _json.token then
+ cb(_json.token.text, state, ctx)
+ raw_chunks = raw_chunks .. _json.token.text
+ state = "CONTINUE"
+ end
+ end
+ end
+ return ctx, raw_chunks, state
+end
+
+return M
diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua
index 25e5565..c7e8fa2 100644
--- a/lua/ogpt/utils.lua
+++ b/lua/ogpt/utils.lua
@@ -38,6 +38,13 @@ M.ollama_options = {
"num_thread",
}
+M.textgenui_options = {
+ "seed",
+ "top_k",
+ "top_p",
+ "stop",
+}
+
function M.split(text)
local t = {}
for str in string.gmatch(text, "%S+") do
@@ -241,6 +248,63 @@ function M.conform_to_ollama(params)
return M._conform_to_ollama_api(params)
end
+function M._conform_to_textgenui_api(params)
+ local model_params = {
+ "seed",
+ "top_k",
+ "top_p",
+ "stop",
+ }
+
+ local request_params = {
+ "inputs",
+ "parameters",
+ "stream",
+ }
+
+ local param_options = {}
+
+ for key, value in pairs(params) do
+ if not vim.tbl_contains(request_params, key) then
+ if vim.tbl_contains(model_params, key) then
+ param_options[key] = value
+ params[key] = nil
+ else
+ params[key] = nil
+ end
+ end
+ end
+ local _options = vim.tbl_extend("keep", param_options, params.options or {})
+ if next(_options) ~= nil then
+ params.parameters = _options
+ end
+ return params
+end
+
+function M.conform_to_textgenui(params)
+ -- conform to mixtral
+ -- [INST] Instruction [/INST] Model answer [INST] Follow-up instruction [/INST]
+ if params.messages then
+ local messages = params.messages
+ params.messages = nil
+ -- params.system = params.system or ""
+ params.inputs = params.inputs or ""
+ -- for _, message in ipairs(messages) do
+ -- if message.role == "system" then
+ -- params.system = params.system .. "\n" .. message.content .. "\n"
+ -- end
+ -- end
+
+ for _, message in ipairs(messages) do
+ if message.role == "user" then
+ params.inputs = params.inputs .. "\n" .. message.content .. "\n"
+ end
+ end
+ end
+
+ return M._conform_to_textgenui_api(params)
+end
+
function M.extract_code(text)
-- Iterate through all code blocks in the message using a regular expression pattern
local lastCodeBlock
@@ -394,4 +458,33 @@ function M.escape_pattern(text)
return text:gsub("([^%w])", "%%%1")
end
+function M.update_url_route(url, new_model)
+ local host = url:match("https?://([^/]+)")
+  local _, domain, tld = host:match("([^.]+)%.([^.]+)%.([^.]+)")
+ local _new_url = url:gsub(host, new_model .. "." .. domain .. "." .. tld)
+ return _new_url
+end
+
+function M.to_model_string(messages)
+ local output = ""
+ for _, entry in ipairs(messages) do
+ if entry.content then
+ output = output .. entry.role .. ": " .. entry.content .. "\n\n"
+ end
+ end
+ return output
+end
+
+function M.startsWith(str, start)
+ return string.sub(str, 1, string.len(start)) == start
+end
+
+function M.ensureUrlProtocol(str)
+ if M.startsWith(str, "https://") or M.startsWith(str, "http://") then
+ return str
+ end
+
+ return "https://" .. str
+end
+
return M