Skip to content

Commit

Permalink
getting textgenui api to work
Browse files Browse the repository at this point in the history
  • Loading branch information
huynle committed Jan 8, 2024
1 parent 1e5c815 commit 868809f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 23 deletions.
8 changes: 5 additions & 3 deletions lua/ogpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
local ctx = {}
-- add params before conform
ctx.params = params
local _model = params.model
params.model = nil
if stream then
params = Utils.conform_to_ollama(params)
local raw_chunks = ""
Expand All @@ -30,7 +32,7 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
"--silent",
"--show-error",
"--no-buffer",
Api.CHAT_COMPLETIONS_URL,
Utils.update_url_route(Api.CHAT_COMPLETIONS_URL, _model),
"-H",
"Content-Type: application/json",
"-H",
Expand Down Expand Up @@ -275,8 +277,8 @@ function Api.setup()
loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value)
Api.OLLAMA_API_HOST = value
Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags")
Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate")
Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate")
Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/")
Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/")
end, "http://localhost:11434")

loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
Expand Down
6 changes: 3 additions & 3 deletions lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,16 @@ function M.defaults()
},

api_params = {
model = "mistral:7b",
model = "mixtral-8-7b-moe-instruct-tgi-predictor-ai-factory",
temperature = 0.8,
top_p = 1,
top_p = 0.9,
},
api_edit_params = {
model = "mistral:7b",
frequency_penalty = 0,
presence_penalty = 0,
temperature = 0.5,
top_p = 1,
top_p = 0.9,
},
use_openai_functions_for_edits = false,
actions = {
Expand Down
13 changes: 9 additions & 4 deletions lua/ogpt/flows/actions/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ local STRATEGY_QUICK_FIX = "quick_fix"
function ChatAction:init(opts)
self.super:init(opts)
self.params = opts.params or {}
self.system = opts.system or ""
self.template = opts.template or "{{input}}"
self.system = type(opts.system) == "function" and opts.system() or opts.system or ""
self.template = type(opts.template) == "function" and opts.template() or opts.template or "{{input}}"
-- self.instruction = type(opts.instruction) == "function" and opts.instruction() or opts.instruction or ""
self.variables = opts.variables or {}
self.strategy = opts.strategy or STRATEGY_APPEND
self.ui = opts.ui or {}
Expand All @@ -52,7 +53,11 @@ function ChatAction:render_template()
data = vim.tbl_extend("force", {}, data, self.variables)
local result = self.template
for key, value in pairs(data) do
result = result:gsub("{{" .. key .. "}}", value)
if type(value) == "function" then
value = value()
end
local escaped_value = Utils.escape_pattern(value)
result = string.gsub(result, "{{" .. key .. "}}", escaped_value)
end
return result
end
Expand All @@ -66,7 +71,7 @@ function ChatAction:get_params()
table.insert(messages, message)
return vim.tbl_extend("force", Config.options.api_params, self.params, {
messages = messages,
system = self.system,
system = self.system and "" == self.system and nil or self.system,
})
end

Expand Down
49 changes: 36 additions & 13 deletions lua/ogpt/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -186,15 +186,15 @@ end

function M._conform_to_ollama_api(params)
local ollama_parameters = {
"model",
"prompt",
"format",
"options",
"system",
"template",
"context",
-- "model",
"inputs",
-- "format",
"parameters",
-- "system",
-- "template",
-- "context",
"stream",
"raw",
-- "raw",
}

-- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information
Expand All @@ -207,9 +207,9 @@ function M._conform_to_ollama_api(params)
params[key] = nil
end
end
local _options = vim.tbl_extend("keep", param_options, params.options or {})
local _options = vim.tbl_extend("keep", param_options, params.parameters or {})
if next(_options) ~= nil then
params.options = _options
params.parameters = _options
end
return params
end
Expand All @@ -218,8 +218,8 @@ function M.conform_to_ollama(params)
if params.messages then
local messages = params.messages
params.messages = nil
params.system = params.system or ""
params.prompt = params.prompt or ""
-- params.system = params.system or ""
params.inputs = params.inputs or ""
for _, message in ipairs(messages) do
if message.role == "system" then
params.system = params.system .. "\n" .. message.content .. "\n"
Expand All @@ -228,7 +228,7 @@ function M.conform_to_ollama(params)

for _, message in ipairs(messages) do
if message.role == "user" then
params.prompt = params.prompt .. "\n" .. message.content .. "\n"
params.inputs = params.inputs .. "\n" .. message.content .. "\n"
end
end
end
Expand Down Expand Up @@ -277,4 +277,27 @@ function M.tableToString(tbl, indent)
return str
end

function M.update_url_route(url, new_route)
-- Extract the base URL parts
local base_url_parts = { url:match("^(https?://([^/]*))") }
local base_protocol = base_url_parts[1]
local base_domain = base_url_parts[2]

-- Construct the new URL
local new_url = base_protocol .. "://" .. new_route .. "." .. base_domain

-- Preserve any path, query string, or fragment from the original URL
local path_and_query = url:match("^https?://[^/]*(/.*)$")
if path_and_query then
new_url = new_url .. path_and_query
end

return new_url
end

function M.escape_pattern(text)
-- https://stackoverflow.com/a/34953646/4780010
return text:gsub("([^%w])", "%%%1")
end

return M

0 comments on commit 868809f

Please sign in to comment.