Skip to content

Commit

Permalink
Dev: general update (#42)
Browse files Browse the repository at this point in the history
* updating code yank and replacement in popup

* adding openrouter

* updating so that openai have system prompt

* Fixing openai
  • Loading branch information
huynle authored Dec 4, 2024
1 parent f87735d commit 0fae02d
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 26 deletions.
11 changes: 6 additions & 5 deletions lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ function M.defaults()
},
gemini = {
enabled = true,
api_host = os.getenv("GEMINI_API_HOST"),
api_key = os.getenv("GEMINI_API_KEY"),
api_host = os.getenv("GEMINI_API_HOST") or "https://generativelanguage.googleapis.com/v1beta",
api_key = os.getenv("GEMINI_API_KEY") or "",
model = "gemini-pro",
api_params = {
temperature = 0.5,
Expand All @@ -70,7 +70,8 @@ function M.defaults()
},
anthropic = {
enabled = true,
api_key = os.getenv("ANTHROPIC_API_KEY"),
api_host = os.getenv("ANTHROPIC_API_HOST") or "https://api.anthropic.com/v1/messages",
api_key = os.getenv("ANTHROPIC_API_KEY") or "",
model = "claude-3-opus-20240229",
api_params = {
temperature = 0.5,
Expand All @@ -85,8 +86,8 @@ function M.defaults()
},
textgenui = {
enabled = true,
api_host = os.getenv("OGPT_API_HOST"),
api_key = os.getenv("OGPT_API_KEY"),
api_host = os.getenv("TEXTGEN_API_HOST"),
api_key = os.getenv("TEXTGEN_API_KEY"),
model = {
-- create a modify url specifically for mixtral to run
name = "mixtral-8-7b",
Expand Down
1 change: 0 additions & 1 deletion lua/ogpt/module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ function M.open_chat_with_awesome_prompt(opts)
chat.chat_window.border:set_text("top", " OGPT - Acts as " .. act .. " ", "center")

chat:set_system_message(prompt)
chat:open_system_panel()
end),
})
end
Expand Down
4 changes: 1 addition & 3 deletions lua/ogpt/provider/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,10 @@ end

function Gemini:load_envs(override)
local _envs = {}
_envs.GEMINI_API_HOST = Config.options.providers.gemini.api_host
_envs.GEMINI_API_KEY = Config.options.providers.gemini.api_key or os.getenv("GEMINI_API_KEY") or ""
_envs.AUTH = "key=" .. (_envs.GEMINI_API_KEY or " ")
_envs.MODEL = "gemini-pro"
_envs.GEMINI_API_HOST = Config.options.providers.gemini.api_host
or os.getenv("GEMINI_API_HOST")
or "https://generativelanguage.googleapis.com/v1beta"
_envs.MODELS_URL = utils.ensureUrlProtocol(_envs.GEMINI_API_HOST .. "/models")
self.envs = vim.tbl_extend("force", _envs, override or {})
return self.envs
Expand Down
42 changes: 35 additions & 7 deletions lua/ogpt/provider/openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,14 @@ function Openai:init(opts)
"messages",
"stream",
"temperature",
"presence_penalty",
"frequency_penalty",
"top_p",
"max_tokens",
}
self.api_chat_request_options = {}
end

function Openai:load_envs(override)
local _envs = {}
_envs.OPENAI_API_HOST = Config.options.providers.openai.api_host
or os.getenv("OPENAI_API_HOST")
or "https://api.openai.com"
_envs.OPENAI_API_KEY = Config.options.providers.openai.api_key or os.getenv("OPENAI_API_KEY") or ""
_envs.OPENAI_API_KEY = Config.options.providers.openai.api_key
_envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OPENAI_API_HOST .. "/v1/models")
_envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENAI_API_HOST .. "/v1/completions")
_envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENAI_API_HOST .. "/v1/chat/completions")
Expand All @@ -50,12 +44,46 @@ function Openai:parse_api_model_response(res, cb)
end

function Openai:conform_request(params)
local _to_remove_system_idx = {}
for idx, message in ipairs(params.messages) do
if message.role == "system" then
table.insert(_to_remove_system_idx, idx)
end
end

-- Remove elements from the list based on indices
for i = #_to_remove_system_idx, 1, -1 do
table.remove(params.messages, _to_remove_system_idx[i])
end

-- conform to support text only model
local messages = params.messages
local conformed_messages = {}
for _, message in ipairs(messages) do
table.insert(conformed_messages, {
role = message.role,
content = utils.gather_text_from_parts(message.content),
})
end

-- Insert the updated params.system string at the beginning of conformed_messages
if params.system then
table.insert(conformed_messages, 1, {
role = "system",
content = params.system,
})
end

params.messages = conformed_messages

-- general clean up, remove things that shouldnt be here
for key, value in pairs(params) do
if not vim.tbl_contains(self.api_parameters, key) then
utils.log("Did not process " .. key .. " for " .. self.name)
params[key] = nil
end
end

return params
end

Expand Down
4 changes: 1 addition & 3 deletions lua/ogpt/provider/openrouter.lua
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ end
function OpenRouter:load_envs(override)
local _envs = {}
_envs.OPENROUTER_API_HOST = Config.options.providers.openrouter.api_host
or os.getenv("OPENROUTER_API_HOST")
or "https://api.openrouter.com"
_envs.OPENROUTER_API_KEY = Config.options.providers.openrouter.api_key or os.getenv("OPENROUTER_API_KEY") or ""
_envs.OPENROUTER_API_KEY = Config.options.providers.openrouter.api_key
_envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OPENROUTER_API_HOST .. "/v1/models")
_envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENROUTER_API_HOST .. "/v1/completions")
_envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENROUTER_API_HOST .. "/v1/chat/completions")
Expand Down
27 changes: 20 additions & 7 deletions lua/ogpt/provider/textgenui.lua
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,16 @@ function Textgenui:load_envs(override)
or os.getenv("TEXTGEN_API_HOST")
or "https://api.textgen.com"
_envs.TEXTGEN_API_KEY = Config.options.providers.textgenui.api_key or os.getenv("TEXTGEN_API_KEY") or ""
_envs.MODELS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST .. "/api/tags")
_envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST)
_envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST)
_envs.MODELS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST .. "/v1/models")
_envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST .. "/v1/completions")
_envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST .. "/v1/chat/completions")
_envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.TEXTGEN_API_KEY or " ")
self.envs = vim.tbl_extend("force", _envs, override or {})
return self.envs
end

function Textgenui:completion_url()
if self.stream_response then
return self.envs.TEXTGEN_API_HOST .. "/generate_stream"
end
return self.envs.TEXTGEN_API_HOST .. "/generate"
return self.envs.COMPLETIONS_URL
end

function Textgenui:conform_request(params)
Expand Down Expand Up @@ -164,4 +161,20 @@ function Textgenui:process_response(response)
end
end

function Textgenui:parse_api_model_response(res, cb)
local response = table.concat(res, "\n")
local ok, json = pcall(vim.fn.json_decode, response)

if not ok then
vim.print("OGPT ERROR: something happened when trying request for models from " .. self:models_url())
else
-- Given a json object from the api, parse this and get the names of the model to be displayed
for _, model in ipairs(json.data) do
cb({
name = model.id,
})
end
end
end

return Textgenui

0 comments on commit 0fae02d

Please sign in to comment.