Skip to content

Commit

Permalink
gemini is working
Browse files Browse the repository at this point in the history
  • Loading branch information
huynle committed Feb 11, 2024
1 parent 0638f8b commit 8a53132
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 101 deletions.
136 changes: 91 additions & 45 deletions lua/ogpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ function Api:chat_completions(custom_params, partial_result_fn, should_stop, opt
ctx.model = custom_params.model
utils.log("Request to: " .. _completion_url)
utils.log(params)
local raw_chunks = ""
local state = "START"

if stream then
local raw_chunks = ""
local state = "START"

partial_result_fn = vim.schedule_wrap(partial_result_fn)
-- partial_result_fn = vim.schedule_wrap(partial_result_fn)
local accumulate = {}

self:exec(
"curl",
Expand All @@ -41,44 +41,91 @@ function Api:chat_completions(custom_params, partial_result_fn, should_stop, opt
"--show-error",
"--no-buffer",
_completion_url,
"-H",
"Content-Type: application/json",
"-H",
self.provider.envs.AUTHORIZATION_HEADER,
"-d",
vim.json.encode(params),
table.unpack(self.provider:request_headers()), -- has to be the last item in the list
},
function(chunk)
local ok, json = pcall(vim.json.decode, chunk)
if ok then
if json.error ~= nil then
local error_msg = {
"OGPT ERROR:",
self.provider.name,
vim.inspect(json.error) or "",
"Something went wrong.",
}
table.insert(error_msg, vim.inspect(params))
-- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong")
partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx)
return
end
ctx, raw_chunks, state =
self.provider:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, partial_result_fn, opts)
return
end

for line in chunk:gmatch("[^\n]+") do
local raw_json = string.gsub(line, "^data:", "")
local _ok, _json = pcall(vim.json.decode, raw_json)
if _ok then
ctx, raw_chunks, state =
self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts)
else
ctx, raw_chunks, state =
self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts)
end
end
local chunk_og = chunk
table.insert(accumulate, chunk_og)
local content = {
raw = chunk,
accumulate = accumulate,
content = raw_chunks or "",
state = state or "START",
ctx = ctx,
}
ctx, raw_chunks, state = self.provider:process_raw(content, partial_result_fn, opts)
--
-- local ok, json = pcall(vim.json.decode, chunk)
-- if not ok then
-- -- gemini is missing bracket on returns
-- chunk = string.gsub(chunk, "^%[", "")
-- chunk = string.gsub(chunk, "^%,", "")
-- chunk = string.gsub(chunk, "%]$", "")
-- chunk = vim.trim(chunk, "\n")
-- chunk = vim.trim(chunk, "\r")
-- ok, json = pcall(vim.json.decode, chunk)
-- end
--
-- -- if ok then
-- -- vim.print("okay")
-- -- else
-- -- vim.print("not okay")
-- -- end
--
-- if ok then
-- if json.error ~= nil then
-- local error_msg = {
-- "OGPT ERROR:",
-- self.provider.name,
-- vim.inspect(json.error) or "",
-- "Something went wrong.",
-- }
-- table.insert(error_msg, vim.inspect(params))
-- -- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong")
-- partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx)
-- return
-- end
-- ctx, raw_chunks, state =
-- self.provider:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, partial_result_fn, opts)
-- return
-- end
--
-- -- openai
-- for line in chunk:gmatch("[^\n]+") do
-- local raw_json = string.gsub(line, "^data:", "")
-- local _ok, _json = pcall(vim.json.decode, raw_json)
-- if _ok then
-- ctx, raw_chunks, state =
-- self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts)
-- else
-- ctx, raw_chunks, state = self.provider:process_line(
-- { json = nil, raw = chunk_og },
-- ctx,
-- raw_chunks,
-- state,
-- partial_result_fn,
-- opts
-- )
-- end
-- end
--
-- if not ok then
-- -- if not ok, try to keep process the accumulated stdout
-- table.insert(accumulate, chunk_og)
-- ok, json = pcall(vim.json.decode, table.concat(accumulate, ""))
-- if ok then
-- ctx, raw_chunks, state = self.provider:process_line(
-- { json = json, raw = accumulate },
-- ctx,
-- raw_chunks,
-- state,
-- partial_result_fn,
-- opts
-- )
-- end
-- end
end,
function(_text, _state, _ctx)
partial_result_fn(_text, _state, _ctx)
Expand Down Expand Up @@ -123,11 +170,9 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state, opts)

local curl_args = {
url,
table.unpack(self.provider:request_headers()),
"-H",
"Content-Type: application/json",
"-d",
"@" .. TMP_MSG_FILENAME,
table.unpack(self.provider:request_headers()),
}

self.job = job
Expand Down Expand Up @@ -275,8 +320,8 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)

local handle, err
local function on_stdout_read(_, chunk)
if chunk then
vim.schedule(function()
vim.schedule(function()
if chunk then
if should_stop and should_stop() then
if handle ~= nil then
handle:kill(2) -- send SIGINT
Expand All @@ -295,8 +340,8 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
return
end
on_stdout_chunk(chunk)
end)
end
end
end)
end

local function on_stderr_read(_, chunk)
Expand All @@ -305,6 +350,7 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
end
end

utils.log("executing: " .. vim.inspect(cmd) .. " " .. vim.inspect(args), vim.log.levels.DEBUG)
handle, err = vim.loop.spawn(cmd, {
args = args,
stdio = { nil, stdout, stderr },
Expand Down
8 changes: 8 additions & 0 deletions lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ function M.defaults()
api_host = os.getenv("GEMINI_API_HOST"),
api_key = os.getenv("GEMINI_API_KEY"),
model = "gemini-pro",
api_params = {
temperature = 0.5,
topP = 0.99,
},
api_chat_params = {
temperature = 0.5,
topP = 0.99,
},
},
textgenui = {
enabled = true,
Expand Down
6 changes: 5 additions & 1 deletion lua/ogpt/flows/actions/popup/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ function PopupAction:run()
self:run_spinner(flag)
end,
on_complete = function(total_text)
-- print("completed: " .. total_text)
-- utils.log("request completed - processed text is:\n" .. total_text, vim.log.levels.DEBUG)
if vim.fn.bufexists(self.popup.bufnr) then
vim.api.nvim_buf_set_option(self.popup.bufnr, "modifiable", true)
vim.api.nvim_buf_set_lines(self.popup.bufnr, -2, -1, false, vim.split(total_text, "\n", {}))
end
end,
}),
function()
Expand Down
38 changes: 37 additions & 1 deletion lua/ogpt/provider/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ local utils = require("ogpt.utils")
local Provider = Object("Provider")

function Provider:init(opts)
self.name = self.class.name
self.name = string.lower(self.class.name)
opts = vim.tbl_extend("force", Config.options.providers[self.name], opts)
self.enabled = opts.enabled
self.model = opts.model
Expand Down Expand Up @@ -175,6 +175,42 @@ function Provider:conform_messages(params)
return params
end

function Provider:process_raw(content, cb, opts)
local chunk = content.raw
local state = content.state
local raw_chunks = content.content
local accumulate = content.accumulate
local ctx = content.ctx
local ok, json = pcall(vim.json.decode, chunk)

-- if not ok then
-- -- gemini is missing bracket on returns
-- chunk = string.gsub(chunk, "^%[", "")
-- chunk = string.gsub(chunk, "^%,", "")
-- chunk = string.gsub(chunk, "%]$", "")
-- chunk = vim.trim(chunk, "\n")
-- chunk = vim.trim(chunk, "\r")
-- ok, json = pcall(vim.json.decode, chunk)
-- end

if ok then
if json.error ~= nil then
local error_msg = {
"OGPT ERROR:",
self.provider.name,
vim.inspect(json.error) or "",
"Something went wrong.",
}
table.insert(error_msg, vim.inspect(params))
-- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong")
cb(table.concat(error_msg, " "), "ERROR", ctx)
-- return
return { ctx, raw_chunks, state }
end
return self:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, cb)
end
end

function Provider:process_line(content, ctx, raw_chunks, state, cb)
local _json = content.json
local raw = content.raw
Expand Down
Loading

0 comments on commit 8a53132

Please sign in to comment.