Skip to content

Commit

Permalink
Adding support for TexgenUI (Huggingface API) (#14)
Browse files Browse the repository at this point in the history
* adding textgenui

* things are working

* update to take on textgenui API, and allow for custom functions to run for different model implementation

* updating doc

* more updates

* more updates

* getting things updated

* updating

* updating to have textgenui

* updating readme
  • Loading branch information
huynle authored Jan 11, 2024
1 parent de42b1d commit 7490b6a
Show file tree
Hide file tree
Showing 10 changed files with 855 additions and 251 deletions.
478 changes: 388 additions & 90 deletions README.md

Large diffs are not rendered by default.

107 changes: 60 additions & 47 deletions lua/ogpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ local Utils = require("ogpt.utils")

local Api = {}

function Api.get_provider()
local provider
if type(Config.options.default_provider) == "string" then
provider = require("ogpt.provider." .. Config.options.default_provider)
else
provider = require("ogpt.provider." .. Config.options.default_provider.name)
provider.envs = vim.tbl_extend("force", provider.envs, Config.options.default_provider)
end
local envs = provider.load_envs()
Api = vim.tbl_extend("force", Api, envs)
return provider
end

function Api.completions(custom_params, cb)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
params.stream = false
Expand All @@ -14,11 +27,30 @@ end
function Api.chat_completions(custom_params, cb, should_stop, opts)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
local stream = params.stream or false
local _model = params.model

local _completion_url = Api.CHAT_COMPLETIONS_URL
if type(_model) == "table" then
if _model.modify_url and type(_model.modify_url) == "function" then
_completion_url = _model.modify_url(_completion_url)
else
_completion_url = _model.modify_url
end
end

if _model and _model.conform_fn then
params = _model.conform_fn(params)
else
params = Api.provider.conform(params)
end

local ctx = {}
-- add params before conform
ctx.params = params
if Config.options.debug then
vim.notify("Request to: " .. _completion_url, vim.log.levels.DEBUG, { title = "OGPT Debug" })
end

if stream then
params = Utils.conform_to_ollama(params)
local raw_chunks = ""
local state = "START"

Expand All @@ -30,7 +62,7 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
"--silent",
"--show-error",
"--no-buffer",
Api.CHAT_COMPLETIONS_URL,
_completion_url,
"-H",
"Content-Type: application/json",
"-H",
Expand All @@ -39,37 +71,24 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
vim.json.encode(params),
},
function(chunk)
local process_line = function(_ok, _json)
if _json and _json.done then
ctx.context = _json.context
cb(raw_chunks, "END", ctx)
else
if _ok and not vim.tbl_isempty(_json) then
if _json and _json.message then
cb(_json.message.content, state, ctx)
raw_chunks = raw_chunks .. _json.message.content
state = "CONTINUE"
end
end
end
end

local ok, json = pcall(vim.json.decode, chunk)
if ok and json ~= nil then
if ok then
if json.error ~= nil then
cb(json.error, "ERROR", ctx)
return
end
process_line(ok, json)
else
for line in chunk:gmatch("[^\n]+") do
local raw_json = string.gsub(line, "^data: ", "")
local _ok, _json = pcall(vim.json.decode, raw_json)
process_line(_ok, _json)
ctx, raw_chunks, state = Api.provider.process_line(json, ctx, raw_chunks, state, cb)
return
end

for line in chunk:gmatch("[^\n]+") do
local raw_json = string.gsub(line, "^data:", "")
local _ok, _json = pcall(vim.json.decode, raw_json)
if _ok then
ctx, raw_chunks, state = Api.provider.process_line(_json, ctx, raw_chunks, state, cb)
end
end
end,

function(err, _)
cb(err, "ERROR", ctx)
end,
Expand All @@ -86,15 +105,11 @@ end

function Api.edits(custom_params, cb)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_edit_params)
-- params.stream = params.stream or false
-- Api.make_call(Api.CHAT_COMPLETIONS_URL, params, cb)
params.stream = true
Api.chat_completions(params, cb)
end

function Api.make_call(url, params, cb)
params = Utils.conform_to_ollama(params)

TMP_MSG_FILENAME = os.tmpname()
local f = io.open(TMP_MSG_FILENAME, "w+")
if f == nil then
Expand Down Expand Up @@ -274,23 +289,21 @@ local function ensureUrlProtocol(str)
end

function Api.setup()
loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value)
Api.OLLAMA_API_HOST = value
Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags")
Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate")
Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/chat")
end, "http://localhost:11434")

loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
Api.OLLAMA_API_KEY = value
loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE")
if Api["OPENAI_API_TYPE"] == "azure" then
loadAzureConfigs()
Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY
else
Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY
end
end, " ")
local provider = Api.get_provider()
Api.provider = provider

-- loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", provider.make_url, "http://localhost:11434")

-- loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
-- Api.OLLAMA_API_KEY = value
-- loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE")
-- if Api["OPENAI_API_TYPE"] == "azure" then
-- loadAzureConfigs()
-- Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY
-- else
-- Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY
-- end
-- end, " ")
end

function Api.exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
Expand Down
48 changes: 43 additions & 5 deletions lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,22 @@ WELCOME_MESSAGE = [[
local M = {}
function M.defaults()
local defaults = {
debug = false,
api_key_cmd = nil,
default_provider = {
-- can also support `textgenui`
name = "ollama",
api_host = os.getenv("OLLAMA_API_HOST"),
api_key = os.getenv("OLLAMA_API_KEY"),
},
yank_register = "+",
edit_with_instructions = {
diff = false,
keymaps = {
close = "<C-c>",
accept = "<C-y>",
toggle_diff = "<C-d>",
toggle_parameters = "<C-o>",
accept = "<C-y>", -- accept the output and write to original buffer
toggle_diff = "<C-d>", -- view the diff between left and right panes and use diff-mode
toggle_parameters = "<C-o>", -- Toggle parameters window
cycle_windows = "<Tab>",
use_output_as_input = "<C-u>",
},
Expand Down Expand Up @@ -177,16 +184,38 @@ function M.defaults()
},

api_params = {
-- takes a string or a table
model = "mistral:7b",
-- model = {
-- -- create a modify url specifically for mixtral to run
-- name = "mixtral-8-7b",
-- modify_url = function(url)
-- -- given a URL, this function modifies the URL specifically to the model
-- -- This is useful when you have different models hosted on different subdomains like
-- -- https://model1.yourdomain.com/
-- -- https://model2.yourdomain.com/
-- local new_model = "mixtral-8-7b"
-- local host = url:match("https?://([^/]+)")
-- local subdomain, domain, tld = host:match("([^.]+)%.([^.]+)%.([^.]+)")
-- local _new_url = url:gsub(host, new_model .. "." .. domain .. "." .. tld)
-- return _new_url
-- end,
-- conform_fn = function(params)
-- -- Different models might have different instruction format
-- -- for example, Mixtral operates on `<s> [INST] Instruction [/INST] Model answer</s> [INST] Follow-up instruction [/INST] `
-- end,
-- },

temperature = 0.8,
top_p = 1,
top_p = 0.99,
},
api_edit_params = {
-- used for `edit` and `edit_code` strategy in the actions
model = "mistral:7b",
frequency_penalty = 0,
presence_penalty = 0,
temperature = 0.5,
top_p = 1,
top_p = 0.99,
},
actions = {

Expand Down Expand Up @@ -247,6 +276,15 @@ M.namespace_id = vim.api.nvim_create_namespace("OGPTNS")
function M.setup(options)
options = options or {}
M.options = vim.tbl_deep_extend("force", {}, M.defaults(), options)
local _complete_replace = {
"actions",
}

for _, to_replace in pairs(_complete_replace) do
for key, item in pairs(options[to_replace]) do
M.options[to_replace][key] = item
end
end
end

return M
23 changes: 14 additions & 9 deletions lua/ogpt/flows/actions/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,21 @@ function ChatAction:init(opts)
self:display_input_suffix(state)
end)
end)

self:update_variables()
end

function ChatAction:render_template()
local input = self.strategy == STRATEGY_QUICK_FIX and self:get_selected_text_with_line_numbers()
or self:get_selected_text()
local data = {
function ChatAction:update_variables()
self.variables = vim.tbl_extend("force", self.variables, {
filetype = self:get_filetype(),
input = input,
}
data = vim.tbl_extend("force", {}, data, self.variables)
input = self.strategy == STRATEGY_QUICK_FIX and self:get_selected_text_with_line_numbers()
or self:get_selected_text(),
})
end

function ChatAction:render_template()
local result = self.template
for key, value in pairs(data) do
for key, value in pairs(self.variables) do
local escaped_value = Utils.escape_pattern(value)
result = string.gsub(result, "{{" .. key .. "}}", escaped_value)
end
Expand Down Expand Up @@ -140,10 +143,12 @@ function ChatAction:call_api(panel, params)
params,
Utils.partial(Utils.add_partial_completion, {
panel = panel,
-- finalize_opts = opts,
progress = function(flag)
self:run_spinner(flag)
end,
on_complete = function(total_text)
-- print("completed: " .. total_text)
end,
}),
function()
-- should stop function
Expand Down
10 changes: 7 additions & 3 deletions lua/ogpt/flows/actions/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ local read_actions_from_file = function(filename)
local json_string = file:read("*a")
file:close()

return vim.json.decode(json_string)
local ok, json = pcall(vim.json.decode, json_string)
if ok then
return json
end
end

function M.read_actions()
local actions = {}
local actions = Config.options.actions
-- local actions = {}
local paths = {}

-- add default actions
Expand All @@ -45,7 +49,7 @@ function M.read_actions()
end
end
end
return vim.tbl_extend("keep", Config.options.actions, actions)
return actions
end

function M.run_action(opts)
Expand Down
Loading

0 comments on commit 7490b6a

Please sign in to comment.