Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for TextGenUI API #13

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 49 additions & 40 deletions lua/ogpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,19 @@ local Utils = require("ogpt.utils")

local Api = {}

function Api.get_provider()
local provider
if type(Config.options.default_provider) == "string" then
provider = require("ogpt.provider." .. Config.options.default_provider)
else
provider = require("ogpt.provider." .. Config.options.default_provider.name)
provider.envs = vim.tbl_extend("force", provider.envs, Config.options.default_provider)
end
local envs = provider.load_envs()
Api = vim.tbl_extend("force", Api, envs)
return provider
end

function Api.completions(custom_params, cb)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
params.stream = false
Expand All @@ -14,11 +27,26 @@ end
function Api.chat_completions(custom_params, cb, should_stop, opts)
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
local stream = params.stream or false
local _model = params.model

local _completion_url = Api.CHAT_COMPLETIONS_URL
if type(_model) == "table" then
if _model.modify_url and type(_model.modify_url) == "function" then
_completion_url = _model.modify_url(_completion_url)
else
_completion_url = _model.modify_url
end
end

if _model and _model.conform_fn then
params = _model.conform_fn(params)
else
params = Api.provider.conform(params)
end

local ctx = {}
-- add params before conform
ctx.params = params
if stream then
params = Utils.conform_to_ollama(params)
local raw_chunks = ""
local state = "START"

Expand All @@ -30,7 +58,7 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
"--silent",
"--show-error",
"--no-buffer",
Api.CHAT_COMPLETIONS_URL,
_completion_url,
"-H",
"Content-Type: application/json",
"-H",
Expand All @@ -39,33 +67,18 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
vim.json.encode(params),
},
function(chunk)
local process_line = function(_ok, _json)
if _json and _json.done then
ctx.context = _json.context
cb(raw_chunks, "END", ctx)
else
if _ok and not vim.tbl_isempty(_json) then
if _json and _json.message then
cb(_json.message.content, state, ctx)
raw_chunks = raw_chunks .. _json.message.content
state = "CONTINUE"
end
end
end
end

local ok, json = pcall(vim.json.decode, chunk)
if ok and json ~= nil then
if json.error ~= nil then
cb(json.error, "ERROR", ctx)
return
end
process_line(ok, json)
ctx, raw_chunks, state = Api.provider.process_line(_ok, _json, ctx, raw_chunks, state, cb)
else
for line in chunk:gmatch("[^\n]+") do
local raw_json = string.gsub(line, "^data: ", "")
local raw_json = string.gsub(line, "^data:", "")
local _ok, _json = pcall(vim.json.decode, raw_json)
process_line(_ok, _json)
ctx, raw_chunks, state = Api.provider.process_line(_ok, _json, ctx, raw_chunks, state, cb)
end
end
end,
Expand Down Expand Up @@ -93,8 +106,6 @@ function Api.edits(custom_params, cb)
end

function Api.make_call(url, params, cb)
params = Utils.conform_to_ollama(params)

TMP_MSG_FILENAME = os.tmpname()
local f = io.open(TMP_MSG_FILENAME, "w+")
if f == nil then
Expand Down Expand Up @@ -274,23 +285,21 @@ local function ensureUrlProtocol(str)
end

function Api.setup()
loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value)
Api.OLLAMA_API_HOST = value
Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags")
Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate")
Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/chat")
end, "http://localhost:11434")

loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
Api.OLLAMA_API_KEY = value
loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE")
if Api["OPENAI_API_TYPE"] == "azure" then
loadAzureConfigs()
Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY
else
Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY
end
end, " ")
local provider = Api.get_provider()
Api.provider = provider

-- loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", provider.make_url, "http://localhost:11434")

-- loadApiKey("OLLAMA_API_KEY", "OLLAMA_API_KEY", "api_key_cmd", function(value)
-- Api.OLLAMA_API_KEY = value
-- loadConfigFromEnv("OPENAI_API_TYPE", "OPENAI_API_TYPE")
-- if Api["OPENAI_API_TYPE"] == "azure" then
-- loadAzureConfigs()
-- Api.AUTHORIZATION_HEADER = "api-key: " .. Api.OLLAMA_API_KEY
-- else
-- Api.AUTHORIZATION_HEADER = "Authorization: Bearer " .. Api.OLLAMA_API_KEY
-- end
-- end, " ")
end

function Api.exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
Expand Down
5 changes: 5 additions & 0 deletions lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ local M = {}
function M.defaults()
local defaults = {
api_key_cmd = nil,
default_provider = {
name = "ollama",
api_host = os.getenv("OLLAMA_API_HOST"),
api_key = os.getenv("OLLAMA_API_KEY"),
},
yank_register = "+",
edit_with_instructions = {
diff = false,
Expand Down
82 changes: 0 additions & 82 deletions lua/ogpt/flows/chat/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -282,78 +282,6 @@ function Chat:addAnswerPartial(text, state, ctx)
end
end

-- function Chat:addAnswerPartial(text, state, ctx)
-- if state == "ERROR" then
-- return self:addAnswer(text, {})
-- end
--
-- local start_line = 0
-- if self.selectedIndex > 0 then
-- local prev = self.messages[self.selectedIndex]
-- start_line = prev.end_line + (prev.type == ANSWER and 2 or 1)
-- end
--
-- if state == "END" then
-- local usage = {}
-- local idx = self.session:add_item({
-- type = ANSWER,
-- text = text,
-- ctx = ctx or {},
-- usage = usage,
-- })
--
-- local lines = {}
-- local nr_of_lines = 0
-- for line in string.gmatch(text, "[^\n]+") do
-- nr_of_lines = nr_of_lines + 1
-- table.insert(lines, line)
-- end
--
-- local end_line = start_line + nr_of_lines - 1
-- table.insert(self.messages, {
-- idx = idx,
-- usage = usage or {},
-- type = ANSWER,
-- text = text,
-- lines = lines,
-- nr_of_lines = nr_of_lines,
-- start_line = start_line,
-- end_line = end_line,
-- context = ctx.context,
-- })
-- self.selectedIndex = self.selectedIndex + 1
-- vim.api.nvim_buf_set_lines(self.chat_window.bufnr, -1, -1, false, { "", "" })
-- Signs.set_for_lines(self.chat_window.bufnr, start_line, end_line, "chat")
-- end
--
-- if state == "START" then
-- self:stopSpinner()
-- self:set_lines(-2, -1, false, { "" })
-- vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true)
-- end
--
-- if state == "START" or state == "CONTINUE" then
-- local lines = vim.split(text, "\n", {})
-- local length = #lines
-- local buffer = self.chat_window.bufnr
-- local win = self.chat_window.winid
--
-- for i, line in ipairs(lines) do
-- local currentLine = vim.api.nvim_buf_get_lines(buffer, -2, -1, false)[1]
-- vim.api.nvim_buf_set_lines(buffer, -2, -1, false, { currentLine .. line })
--
-- local last_line_num = vim.api.nvim_buf_line_count(buffer)
-- Signs.set_for_lines(self.chat_window.bufnr, start_line, last_line_num - 1, "chat")
-- if i == length and i > 1 then
-- vim.api.nvim_buf_set_lines(buffer, -1, -1, false, { "" })
-- end
-- if self:is_buf_visiable() then
-- vim.api.nvim_win_set_cursor(win, { last_line_num, 0 })
-- end
-- end
-- end
-- end

function Chat:get_total_tokens()
local total_tokens = 0
for i = 1, #self.messages, 1 do
Expand Down Expand Up @@ -567,16 +495,6 @@ function Chat:toMessages()
return messages
end

function Chat:toOllama(messages)
local output = ""
for _, entry in ipairs(messages) do
if entry.content then
output = output .. entry.role .. ": " .. entry.content .. "\n\n"
end
end
return output
end

function Chat:count()
local count = 0
for _ in pairs(self.messages) do
Expand Down
110 changes: 110 additions & 0 deletions lua/ogpt/provider/ollama.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
local utils = require("ogpt.utils")

local M = {}

M.envs = {
api_host = os.getenv("OGPT_API_HOST"),
api_key = os.getenv("OGPT_API_KEY"),
}

function M.load_envs()
local _envs = {}
_envs.OLLAMA_API_HOST = M.envs.api_host
_envs.OLLAMA_API_KEY = M.envs.api_key
_envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/tags")
_envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/generate")
_envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/chat")
_envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OLLAMA_API_KEY or " ")
M.envs = vim.tbl_extend("force", M.envs, _envs)
return M.envs
end

M.ollama_options = {
"num_keep",
"seed",
"num_predict",
"top_k",
"top_p",
"tfs_z",
"typical_p",
"repeat_last_n",
"temperature",
"repeat_penalty",
"presence_penalty",
"frequency_penalty",
"mirostat",
"mirostat_tau",
"mirostat_eta",
"penalize_newline",
"stop",
"numa",
"num_ctx",
"num_batch",
"num_gqa",
"num_gpu",
"main_gpu",
"low_vram",
"f16_kv",
"logits_all",
"vocab_only",
"use_mmap",
"use_mlock",
"embedding_only",
"rope_frequency_base",
"rope_frequency_scale",
"num_thread",
}

function M.conform(params)
local ollama_parameters = {
"model",
-- "prompt",
"messages",
"format",
"options",
"system",
"template",
-- "context",
"stream",
"raw",
}

-- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information

local param_options = {}

for key, value in pairs(params) do
if not vim.tbl_contains(ollama_parameters, key) then
if vim.tbl_contains(M.ollama_options, key) then
param_options[key] = value
params[key] = nil
else
params[key] = nil
end
end
end
local _options = vim.tbl_extend("keep", param_options, params.options or {})
if next(_options) ~= nil then
params.options = _options
end
return params
end

function M.process_line(_ok, _json, ctx, raw_chunks, state, cb)
if _json and _json.done then
ctx.context = _json.context
cb(raw_chunks, "END", ctx)
else
if _ok and not vim.tbl_isempty(_json) then
if _json and _json.message then
cb(_json.message.content, state, ctx)
raw_chunks = raw_chunks .. _json.message.content
state = "CONTINUE"
end
end
end

return ctx, raw_chunks, state
end

return M
Loading
Loading