diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index 4e6bce1..b716cca 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -3,6 +3,7 @@ local Config = require("ogpt.config") local logger = require("ogpt.common.logger") local Object = require("ogpt.common.object") local utils = require("ogpt.utils") +local Response = require("ogpt.response") local Api = Object("Api") @@ -27,11 +28,15 @@ function Api:chat_completions(custom_params, partial_result_fn, should_stop, opt ctx.model = custom_params.model utils.log("Request to: " .. _completion_url) utils.log(params) - local raw_chunks = "" - local state = "START" + -- local raw_chunks = "" + -- local state = "START" + partial_result_fn = vim.schedule_wrap(partial_result_fn) + local response = Response() + response.ctx = ctx + response.rest_params = params + response.partial_result_cb = partial_result_fn if stream then - -- partial_result_fn = vim.schedule_wrap(partial_result_fn) local accumulate = {} self:exec( @@ -48,94 +53,26 @@ function Api:chat_completions(custom_params, partial_result_fn, should_stop, opt function(chunk) local chunk_og = chunk table.insert(accumulate, chunk_og) - local content = { - raw = chunk, - accumulate = accumulate, - content = raw_chunks or "", - state = state or "START", - ctx = ctx, - } - ctx, raw_chunks, state = self.provider:process_raw(content, partial_result_fn, opts) - -- - -- local ok, json = pcall(vim.json.decode, chunk) - -- if not ok then - -- -- gemini is missing bracket on returns - -- chunk = string.gsub(chunk, "^%[", "") - -- chunk = string.gsub(chunk, "^%,", "") - -- chunk = string.gsub(chunk, "%]$", "") - -- chunk = vim.trim(chunk, "\n") - -- chunk = vim.trim(chunk, "\r") - -- ok, json = pcall(vim.json.decode, chunk) - -- end - -- - -- -- if ok then - -- -- vim.print("okay") - -- -- else - -- -- vim.print("not okay") - -- -- end - -- - -- if ok then - -- if json.error ~= nil then - -- local error_msg = { - -- "OGPT ERROR:", - -- self.provider.name, - -- vim.inspect(json.error) or "", - -- "Something went wrong.", - -- } - -- table.insert(error_msg, vim.inspect(params)) - -- -- local error_msg = "OGPT ERROR: " .. 
(json.error.message or "Something went wrong") - -- partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx) - -- return - -- end - -- ctx, raw_chunks, state = - -- self.provider:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, partial_result_fn, opts) - -- return - -- end - -- - -- -- openai - -- for line in chunk:gmatch("[^\n]+") do - -- local raw_json = string.gsub(line, "^data:", "") - -- local _ok, _json = pcall(vim.json.decode, raw_json) - -- if _ok then - -- ctx, raw_chunks, state = - -- self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts) - -- else - -- ctx, raw_chunks, state = self.provider:process_line( - -- { json = nil, raw = chunk_og }, - -- ctx, - -- raw_chunks, - -- state, - -- partial_result_fn, - -- opts - -- ) - -- end - -- end - -- - -- if not ok then - -- -- if not ok, try to keep process the accumulated stdout - -- table.insert(accumulate, chunk_og) - -- ok, json = pcall(vim.json.decode, table.concat(accumulate, "")) - -- if ok then - -- ctx, raw_chunks, state = self.provider:process_line( - -- { json = json, raw = accumulate }, - -- ctx, - -- raw_chunks, - -- state, - -- partial_result_fn, - -- opts - -- ) - -- end - -- end + response:add_raw_chunk(chunk) + -- local content = { + -- raw = chunk, + -- accumulate = accumulate, + -- content = raw_chunks or "", + -- state = state or "START", + -- ctx = ctx, + -- params = params, + -- } + self.provider:process_raw(response) end, function(_text, _state, _ctx) - partial_result_fn(_text, _state, _ctx) + -- partial_result_fn(_text, _state, _ctx) -- if opts.on_stop then -- opts.on_stop() -- end end, should_stop, function() - partial_result_fn(raw_chunks, "END", ctx) + -- partial_result_fn(raw_chunks, "END", ctx) -- if opts.on_stop then -- opts.on_stop() -- end @@ -186,7 +123,7 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state, opts) "An Error Occurred, when calling `curl " .. table.concat(curl_args, " ") .. 
"`", vim.log.levels.ERROR ) - cb("ERROR: API Error") + cb("ERROR: API Error", "ERROR") end local result = table.concat(response:result(), "\n") @@ -320,28 +257,27 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) local handle, err local function on_stdout_read(_, chunk) - vim.schedule(function() - if chunk then - if should_stop and should_stop() then - if handle ~= nil then - handle:kill(2) -- send SIGINT - pcall(function() - stdout:close() - end) - pcall(function() - stderr:close() - end) - pcall(function() - handle:close() - end) - -- on_stop() - on_complete("", "END") - end - return + if chunk then + -- vim.schedule(function() + if should_stop and should_stop() then + if handle ~= nil then + handle:kill(2) -- send SIGINT + pcall(function() + stdout:close() + end) + pcall(function() + stderr:close() + end) + pcall(function() + handle:close() + end) + -- on_stop() + on_complete("", "END") end - on_stdout_chunk(chunk) end - end) + -- end) + on_stdout_chunk(chunk) + end end local function on_stderr_read(_, chunk) diff --git a/lua/ogpt/flows/actions/popup/init.lua b/lua/ogpt/flows/actions/popup/init.lua index 9a98250..07396db 100644 --- a/lua/ogpt/flows/actions/popup/init.lua +++ b/lua/ogpt/flows/actions/popup/init.lua @@ -103,46 +103,44 @@ function PopupAction:run() end function PopupAction:on_result(answer, usage) - vim.schedule(function() - self:set_loading(false) - local lines = utils.split_string_by_line(answer) - local _, start_row, start_col, end_row, end_col = self:get_visual_selection() - local bufnr = self:get_bufnr() - if self.strategy == STRATEGY_PREPEND then - answer = answer .. "\n" .. self:get_selected_text() - vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) - elseif self.strategy == STRATEGY_APPEND then - answer = self:get_selected_text() .. "\n\n" .. answer .. "\n" - vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) - elseif self.strategy == STRATEGY_REPLACE then - answer = answer - vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) - elseif self.strategy == STRATEGY_QUICK_FIX then - if #lines == 1 and lines[1] == "" then - vim.notify("Your Code looks fine, no issues.", vim.log.levels.INFO) - return - end + self:set_loading(false) + local lines = utils.split_string_by_line(answer) + local _, start_row, start_col, end_row, end_col = self:get_visual_selection() + local bufnr = self:get_bufnr() + if self.strategy == STRATEGY_PREPEND then + answer = answer .. "\n" .. self:get_selected_text() + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + elseif self.strategy == STRATEGY_APPEND then + answer = self:get_selected_text() .. "\n\n" .. answer .. 
"\n" + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + elseif self.strategy == STRATEGY_REPLACE then + answer = answer + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + elseif self.strategy == STRATEGY_QUICK_FIX then + if #lines == 1 and lines[1] == "" then + vim.notify("Your Code looks fine, no issues.", vim.log.levels.INFO) + return + end - local entries = {} - for _, line in ipairs(lines) do - local lnum, text = line:match("(%d+):(.*)") - if lnum then - local entry = { filename = vim.fn.expand("%:p"), lnum = tonumber(lnum), text = text } - table.insert(entries, entry) - end - end - if entries then - vim.fn.setqflist(entries) - vim.cmd(Config.options.show_quickfixes_cmd) + local entries = {} + for _, line in ipairs(lines) do + local lnum, text = line:match("(%d+):(.*)") + if lnum then + local entry = { filename = vim.fn.expand("%:p"), lnum = tonumber(lnum), text = text } + table.insert(entries, entry) end end - - -- set the cursor onto the answer - if self.strategy == STRATEGY_APPEND then - local target_line = end_row + 3 - vim.api.nvim_win_set_cursor(0, { target_line, 0 }) + if entries then + vim.fn.setqflist(entries) + vim.cmd(Config.options.show_quickfixes_cmd) end - end) + end + + -- set the cursor onto the answer + if self.strategy == STRATEGY_APPEND then + local target_line = end_row + 3 + vim.api.nvim_win_set_cursor(0, { target_line, 0 }) + end end return PopupAction diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua index 2f06089..58adf55 100644 --- a/lua/ogpt/flows/chat/base.lua +++ b/lua/ogpt/flows/chat/base.lua @@ -221,10 +221,12 @@ function Chat:addAnswer(text, usage) self:add(ANSWER, text, usage) end -function Chat:addAnswerPartial(text, state, ctx) +function Chat:addAnswerPartial(response, state) + local text = response.current_text + local ctx = response.ctx if state == "ERROR" then self:stopSpinner() - return self:addAnswer(text, {}) + -- return self:addAnswer(text, {}) end local start_line = 0 @@ -237,6 +239,7 @@ function Chat:addAnswerPartial(text, state, ctx) -- -- most likely, ended by the using raising the stop flag -- self:stopSpinner() elseif state == "END" and text ~= "" then + self:stopSpinner() local usage = {} local idx = self.session:add_item({ type = ANSWER, @@ -258,12 +261,12 @@ function Chat:addAnswerPartial(text, state, ctx) idx = idx, usage = usage or {}, type = ANSWER, - text = text, + text = response:get_total_processed_text(), lines = lines, nr_of_lines = nr_of_lines, start_line = start_line, end_line = end_line, - context = ctx.context, + context = response:get_context(), }) self.selectedIndex = self.selectedIndex + 1 vim.api.nvim_buf_set_lines(self.chat_window.bufnr, -1, -1, false, { "", "" }) diff --git a/lua/ogpt/provider/openai.lua b/lua/ogpt/provider/openai.lua index c570d0d..9b5fb60 100644 --- a/lua/ogpt/provider/openai.lua +++ b/lua/ogpt/provider/openai.lua @@ -55,48 +55,69 @@ function Openai:conform_request(params) return params end -function Openai:process_raw(content, cb, opts) - local chunk = content.raw - local state = content.state - local ctx = content.ctx - local raw_chunks = content.content +function Openai:process_raw(response) + local chunk = response.current_raw_chunk + -- local state = response.state + -- local ctx = response.ctx + -- local raw_chunks = response.processed_text + -- local params = response.params + -- local accumulate = response.accumulate_chunks + + local ok, json = pcall(vim.json.decode, chunk) + -- if 
ok then
+  -- if json.error ~= nil then
+  -- local error_msg = {
+  -- "OGPT ERROR:",
+  -- self.provider.name,
+  -- vim.inspect(json.error) or "",
+  -- "Something went wrong.",
+  -- }
+  -- table.insert(error_msg, vim.inspect(params))
+  -- -- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong")
+  -- cb(table.concat(error_msg, " "), "ERROR", ctx)
+  -- return
+  -- end
+  -- ctx, raw_chunks, state = self:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, cb, opts)
+  -- return
+  -- end
 
-  -- openai
   for line in chunk:gmatch("[^\n]+") do
     local raw_json = string.gsub(line, "^data:", "")
     local _ok, _json = pcall(vim.json.decode, raw_json)
     if _ok then
-      return self:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, cb)
-      -- else
-      --   ctx, raw_chunks, state =
-      --     self:process_line({ json = nil, raw = chunk_og }, ctx, raw_chunks, state, partial_result_fn, opts)
+      self:process_line({ json = _json, raw = line }, response)
+    else
+      self:process_line({ json = nil, raw = line }, response)
     end
   end
 end
 
-function Openai:process_line(content, ctx, raw_chunks, state, cb)
+function Openai:process_line(content, response)
+  local ctx = response.ctx
+  -- local total_text = response.processed_text
+  local state = response.state
+  local cb = response.partial_result_cb
   local _json = content.json
-  local raw = content.raw
-  -- given a JSON response from the STREAMING api, processs it
-  if _json and _json.done then
-    ctx.context = _json.context
-    cb(raw_chunks, "END", ctx)
-  elseif type(_json) == "string" and string.find(_json, "[DONE]") then
-    cb(raw_chunks, "END", ctx)
-  else
+  local _raw = content.raw
+  if _json then
     local text_delta = vim.tbl_get(_json, "choices", 1, "delta", "content")
     local text = vim.tbl_get(_json, "choices", 1, "message", "content")
     if text_delta then
-      cb(_json.choices[1].delta.content, state)
-      raw_chunks = raw_chunks .. _json.choices[1].delta.content
+      response:add_processed_text(text_delta)
       state = "CONTINUE"
+      cb(response, state)
     elseif text then
-      cb(_json.choices[1].message.content, state)
-      raw_chunks = raw_chunks .. _json.choices[1].message.content
+      response:add_processed_text(text)
+      cb(response, "END")
     end
+  elseif not _json and string.find(_raw, "[DONE]", 1, true) then
+    cb(response, "END")
+  else
+    utils.log("Something NOT handled by openai: _json\n" .. vim.inspect(_json))
+    utils.log("Something NOT handled by openai: _raw\n" .. vim.inspect(_raw))
   end
-  return ctx, raw_chunks, state
+  -- return ctx, total_text, state
 end
 
 return Openai
diff --git a/lua/ogpt/response.lua b/lua/ogpt/response.lua
new file mode 100644
index 0000000..8121008
--- /dev/null
+++ b/lua/ogpt/response.lua
@@ -0,0 +1,35 @@
+local Object = require("ogpt.common.object")
+
+local Response = Object("Response")
+
+function Response:init(opts)
+  self.accumulated_chunks = {}
+  self.current_raw_chunk = ""
+  self.processed_text = {}
+  self.current_text = ""
+  self.state = "START"
+  self.rest_params = {}
+  self.ctx = {}
+  self.partial_result_cb = nil
+  self.error = nil
+end
+
+function Response:add_raw_chunk(chunk)
+  table.insert(self.accumulated_chunks, chunk)
+  self.current_raw_chunk = chunk
+end
+
+function Response:add_processed_text(text)
+  table.insert(self.processed_text, text)
+  self.current_text = text
+end
+
+function Response:get_total_processed_text()
+  return table.concat(self.processed_text, "")
+end
+
+function Response:get_context()
+  return self.ctx
+end
+
+return Response
diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua
index c442f8c..683ae79 100644
--- a/lua/ogpt/utils.lua
+++ b/lua/ogpt/utils.lua
@@ -269,7 +269,11 @@ function M.add_partial_completion(opts, text, state)
   end
 
   local start_line = 0
-  if state == "END" and text ~= "" then
+
+  if state == "END" and text == "" then
+    -- -- most likely, ended by the user raising the stop flag
+    -- self:stopSpinner()
+  elseif state == "END" and text ~= "" then
     if not opts.on_complete then
       return
     end
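
The new lua/ogpt/response.lua object replaces the ctx / raw_chunks / state triple that used to be threaded through Api, the providers, and the UI callbacks. The following sketch is not part of the patch; the chunk string and model name are made up for illustration, and only the Response API shown above comes from this diff. It mirrors what Api:chat_completions and Openai:process_line now do: raw curl output goes in via add_raw_chunk, decoded deltas via add_processed_text, and consumers such as Chat:addAnswerPartial read current_text, ctx, and get_total_processed_text() from the object handed to partial_result_cb.

-- Illustrative usage sketch of the Response object introduced by this patch.
local Response = require("ogpt.response")

local response = Response()
response.ctx = { model = "gpt-3.5-turbo" } -- hypothetical; normally filled in by Api:chat_completions
response.partial_result_cb = vim.schedule_wrap(function(res, state)
  -- UI consumers (e.g. Chat:addAnswerPartial) read res.current_text and res.ctx
  print(state, res.current_text)
end)

-- the curl stdout handler records every raw chunk it receives
response:add_raw_chunk('data: {"choices":[{"delta":{"content":"Hel"}}]}')
-- the provider decodes the chunk and appends the extracted delta
response:add_processed_text("Hel")
response.partial_result_cb(response, "CONTINUE")

response:add_processed_text("lo!")
response.partial_result_cb(response, "END")

print(response:get_total_processed_text()) -- prints "Hello!"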