Skip to content

Commit

Permalink
simplying gemini parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
huynle committed Feb 16, 2024
2 parents 36b6003 + c0f93cd commit e6432eb
Show file tree
Hide file tree
Showing 12 changed files with 542 additions and 467 deletions.
154 changes: 74 additions & 80 deletions lua/ogpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,91 +7,84 @@ local Response = require("ogpt.response")

local Api = Object("Api")

Api.STATE_COMPLETED = "COMPLETED"

function Api:init(provider, action, opts)
self.opts = opts
self.provider = provider
self.action = action
end

function Api:completions(custom_params, cb, opts)
-- TODO: not working atm
local params = vim.tbl_extend("keep", custom_params, Config.options.api_params)
params.stream = false
self:make_call(self.COMPLETIONS_URL, params, cb, opts)
end

function Api:chat_completions(custom_params, partial_result_fn, should_stop, opts)
local stream = custom_params.stream or false
function Api:chat_completions(response, inputs)
local custom_params = inputs.custom_params
local partial_result_fn = inputs.partial_result_fn
local should_stop = inputs.should_stop or function() end

-- local stream = custom_params.stream or false
local params, _completion_url, ctx = self.provider:expand_model(custom_params)

ctx.params = params
ctx.provider = self.provider.name
ctx.model = custom_params.model
utils.log("Request to: " .. _completion_url)
utils.log(params)
-- local raw_chunks = ""
-- local state = "START"
partial_result_fn = vim.schedule_wrap(partial_result_fn)
local response = Response()
response.ctx = ctx
response.rest_params = params
response.partial_result_cb = partial_result_fn
response:run_async()

if stream then
local accumulate = {}

self:exec(
"curl",
{
"--silent",
"--show-error",
"--no-buffer",
_completion_url,
"-d",
vim.json.encode(params),
table.unpack(self.provider:request_headers()), -- has to be the last item in the list
},
function(chunk)
local chunk_og = chunk
table.insert(accumulate, chunk_og)
response:add_raw_chunk(chunk)
-- local content = {
-- raw = chunk,
-- accumulate = accumulate,
-- content = raw_chunks or "",
-- state = state or "START",
-- ctx = ctx,
-- params = params,
-- }
self.provider:process_raw(response)
end,
function(_text, _state, _ctx)
-- partial_result_fn(_text, _state, _ctx)
-- if opts.on_stop then
-- opts.on_stop()
-- end
end,
should_stop,
function()
-- partial_result_fn(raw_chunks, "END", ctx)
-- if opts.on_stop then
-- opts.on_stop()
-- end
end
)
else
params.stream = false
self:make_call(self.provider.envs.CHAT_COMPLETIONS_URL, params, partial_result_fn, opts)
local on_complete = inputs.on_complete or function()
response:set_state(response.STATE_COMPLETED)
end
local on_start = inputs.on_start
or function()
-- utils.log("Start Exec of: Curl " .. vim.inspect(curl_args), vim.log.levels.DEBUG)
response:set_state(response.STATE_INPROGRESS)
end
local on_error = inputs.on_error
or function(msg)
-- utils.log("Error running curl: " .. msg or "", vim.log.levels.ERROR)
response:set_state(response.STATE_ERROR)
end
local on_stop = inputs.on_stop or function()
response:set_state(response.STATE_STOPPED)
end
end

function Api:edits(custom_params, cb)
local params = self.action.params
params.stream = true
params = vim.tbl_extend("force", params, custom_params)
self:chat_completions(params, cb)
-- if params.stream then
-- local accumulate = {}
local curl_args = {
"--silent",
"--show-error",
"--no-buffer",
_completion_url,
"-d",
vim.json.encode(params),
}
for _, header_item in ipairs(self.provider:request_headers()) do
table.insert(curl_args, header_item)
end

self:exec("curl", curl_args, on_start, function(chunk)
response:add_chunk(chunk)
end, on_complete, on_error, on_stop, should_stop)
end

-- function Api:edits(custom_params, cb)
-- local params = self.action.params
-- params.stream = true
-- params = vim.tbl_extend("force", params, custom_params)
-- self:chat_completions(params, cb)
-- end

function Api:make_call(url, params, cb, ctx, raw_chunks, state, opts)
-- TODO: to be deprecated
ctx = ctx or {}
raw_chunks = raw_chunks or ""
state = state or "START"
Expand Down Expand Up @@ -250,33 +243,32 @@ local function ensureUrlProtocol(str)
return "https://" .. str
end

function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
local stdout = vim.loop.new_pipe()
function Api:exec(cmd, args, on_start, on_stdout_chunk, on_complete, on_error, on_stop, should_stop)
local stderr = vim.loop.new_pipe()
local stdout = vim.loop.new_pipe()
local stderr_chunks = {}

local handle, err
local function on_stdout_read(_, chunk)
if chunk then
-- vim.schedule(function()
if should_stop and should_stop() then
if handle ~= nil then
handle:kill(2) -- send SIGINT
pcall(function()
stdout:close()
end)
pcall(function()
stderr:close()
end)
pcall(function()
handle:close()
end)
-- on_stop()
on_complete("", "END")
vim.schedule(function()
if should_stop() then
if handle ~= nil then
handle:kill(2) -- send SIGINT
pcall(function()
stdout:close()
end)
pcall(function()
stderr:close()
end)
pcall(function()
handle:close()
end)
on_stop()
end
end
end
-- end)
on_stdout_chunk(chunk)
on_stdout_chunk(chunk)
end)
end
end

Expand All @@ -286,7 +278,7 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)
end
end

utils.log("executing: " .. vim.inspect(cmd) .. " " .. vim.inspect(args), vim.log.levels.DEBUG)
on_start()
handle, err = vim.loop.spawn(cmd, {
args = args,
stdio = { nil, stdout, stderr },
Expand All @@ -299,13 +291,15 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop)

vim.schedule(function()
if code ~= 0 then
on_complete(vim.trim(table.concat(stderr_chunks, "")), "ERROR")
on_error()
else
on_complete()
end
end)
end)

if not handle then
on_complete(cmd .. " could not be started: " .. err, "ERROR")
on_error(cmd .. " could not be started: " .. err)
else
stdout:read_start(on_stdout_read)
stderr:read_start(on_stderr_read)
Expand Down
6 changes: 5 additions & 1 deletion lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ M.logs = {}

function M.defaults()
local defaults = {
debug = false,
-- options of 0-5, is trace, debug, info, warn, error, off, respectively
debug = {
log_level = 3,
notify_level = 3,
},
edgy = false,
single_window = false,
yank_register = "+",
Expand Down
57 changes: 57 additions & 0 deletions lua/ogpt/flows/actions/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,61 @@ function BaseAction:display_input_suffix(suffix)
end
end

function BaseAction:on_complete(response)
-- empty
end

function BaseAction:addAnswerPartial(response)
local content = response:pop_content()
local text = content[1]
local state = content[2]

if state == "ERROR" then
self:run_spinner(false)
utils.log("An Error Occurred: " .. text, vim.log.levels.ERROR)
self.output_panel:unmount()
return
end

if state == "END" then
utils.log("Received END Flag", vim.log.levels.DEBUG)
if not utils.is_buf_exists(self.output_panel.bufnr) then
return
end
vim.api.nvim_buf_set_option(self.output_panel.bufnr, "modifiable", true)
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, {}) -- clear the window, an put the final answer in
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, vim.split(text, "\n"))
self:on_complete(response)
end

if state == "START" then
self:run_spinner(false)
if not utils.is_buf_exists(self.output_panel.bufnr) then
return
end
vim.api.nvim_buf_set_option(self.output_panel.bufnr, "modifiable", true)
end

if state == "START" or state == "CONTINUE" then
if not utils.is_buf_exists(self.output_panel.bufnr) then
return
end
vim.api.nvim_buf_set_option(self.output_panel.bufnr, "modifiable", true)
local lines = vim.split(text, "\n", {})
local length = #lines

for i, line in ipairs(lines) do
if self.output_panel.bufnr and vim.fn.bufexists(self.output_panel.bufnr) then
local currentLine = vim.api.nvim_buf_get_lines(self.output_panel.bufnr, -2, -1, false)[1]
if currentLine then
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, -2, -1, false, { currentLine .. line })
if i == length and i > 1 then
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, -1, -1, false, { "" })
end
end
end
end
end
end

return BaseAction
62 changes: 25 additions & 37 deletions lua/ogpt/flows/actions/edits/init.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
local BaseAction = require("ogpt.flows.actions.base")
local Response = require("ogpt.response")
local utils = require("ogpt.utils")
local Config = require("ogpt.config")
local Layout = require("ogpt.common.layout")
Expand Down Expand Up @@ -68,6 +69,20 @@ function EditAction:run()
end)
end

function EditAction:on_complete(response)
-- on the completion, execute this function to extract out codeblocks
local output_txt = response:get_processed_text()
local nlcount = utils.count_newlines_at_end(output_txt)
if self.strategy == STRATEGY_EDIT_CODE then
output_txt = response:extract_code()
end
local output_txt_nlfixed = utils.replace_newlines_at_end(output_txt, nlcount)
local _output = utils.split_string_by_line(output_txt_nlfixed)
if self.output_panel.bufnr then
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, _output)
end
end

function EditAction:edit_with_instructions(output_lines, selection, opts, ...)
opts = opts or {}
opts.params = opts.params or self.params
Expand Down Expand Up @@ -130,6 +145,7 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...)
end
end,
on_submit = vim.schedule_wrap(function(instruction)
local response = Response(self.provider)
-- clear input
vim.api.nvim_buf_set_lines(self.instructions_input.bufnr, 0, -1, false, { "" })
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, { "" })
Expand All @@ -146,38 +162,12 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...)
local params = vim.tbl_extend("keep", { messages = messages }, self.parameters_panel.params)

params.stream = true
self.provider.api:chat_completions(
params,
utils.partial(utils.add_partial_completion, {
panel = self.output_panel,
on_complete = function(response)
-- on the completion, execute this function to extract out codeblocks
local nlcount = utils.count_newlines_at_end(response)
local output_txt = response
if opts.edit_code then
local code_response = utils.extract_code(response)
-- if the chat is to edit code, it will try to extract out the code from response
output_txt = response
if code_response then
output_txt = utils.match_indentation(response, code_response)
else
vim.notify("no codeblock detected", vim.log.levels.INFO)
end
if response.applied_changes then
vim.notify(response.applied_changes, vim.log.levels.INFO)
end
end
local output_txt_nlfixed = utils.replace_newlines_at_end(output_txt, nlcount)
local _output = utils.split_string_by_line(output_txt_nlfixed)
if self.output_panel.bufnr then
vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, _output)
end
end,
progress = function(flag)
self:run_spinner(flag)
end,
})
)
self.provider.api:chat_completions(response, {
custom_params = params,
partial_result_fn = function(...)
self:addAnswerPartial(...)
end,
})
end),
})

Expand Down Expand Up @@ -426,11 +416,9 @@ end

function EditAction:build_edit_messages(input, instructions, opts)
local _input = input
if opts.edit_code then
_input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````"
else
_input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````"
end

_input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````"

local variables = vim.tbl_extend("force", opts.variables, {
instruction = instructions,
input = _input,
Expand Down
Loading

0 comments on commit e6432eb

Please sign in to comment.