diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index b0a5ca6..474d659 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -18,7 +18,7 @@ function Api:completions(custom_params, cb) self:make_call(self.COMPLETIONS_URL, params, cb) end -function Api:chat_completions(custom_params, cb, should_stop, opts) +function Api:chat_completions(custom_params, partial_result_fn, should_stop, opts) local stream = custom_params.stream or false local params, _completion_url = Config.expand_model(self, custom_params) @@ -31,7 +31,7 @@ function Api:chat_completions(custom_params, cb, should_stop, opts) local raw_chunks = "" local state = "START" - cb = vim.schedule_wrap(cb) + partial_result_fn = vim.schedule_wrap(partial_result_fn) self:exec( "curl", @@ -59,10 +59,10 @@ function Api:chat_completions(custom_params, cb, should_stop, opts) } table.insert(error_msg, vim.inspect(params)) -- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong") - cb(table.concat(error_msg, " "), "ERROR", ctx) + partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx) return end - ctx, raw_chunks, state = self.provider.process_line(json, ctx, raw_chunks, state, cb) + ctx, raw_chunks, state = self.provider.process_line(json, ctx, raw_chunks, state, partial_result_fn) return end @@ -70,21 +70,21 @@ function Api:chat_completions(custom_params, cb, should_stop, opts) local raw_json = string.gsub(line, "^data:", "") local _ok, _json = pcall(vim.json.decode, raw_json) if _ok then - ctx, raw_chunks, state = self.provider.process_line(_json, ctx, raw_chunks, state, cb) + ctx, raw_chunks, state = self.provider.process_line(_json, ctx, raw_chunks, state, partial_result_fn) end end end, function(err, _) - cb(err, "ERROR", ctx) + partial_result_fn(err, "ERROR", ctx) end, should_stop, function() - cb(raw_chunks, "END", ctx) + partial_result_fn(raw_chunks, "END", ctx) end ) else params.stream = false - self:make_call(self.provider.envs.CHAT_COMPLETIONS_URL, params, cb) + self:make_call(self.provider.envs.CHAT_COMPLETIONS_URL, params, partial_result_fn) end end diff --git a/lua/ogpt/common/layouts.lua b/lua/ogpt/common/layouts.lua new file mode 100644 index 0000000..19a07f1 --- /dev/null +++ b/lua/ogpt/common/layouts.lua @@ -0,0 +1,92 @@ +local Config = require("ogpt.config") +local Layout = require("nui.layout") + +local M = {} + +M.edit_with_nui_layout = function(layout, input, instruction, output, parameters, opts) + opts = opts or {} + local _boxes + if opts.show_parameters then + _boxes = Layout.Box({ + Layout.Box({ + Layout.Box(input, { grow = 1 }), + Layout.Box(instruction, { size = 3 }), + }, { dir = "col", grow = 1 }), + Layout.Box(output, { grow = 1 }), + Layout.Box(parameters, { size = 40 }), + }, { dir = "row" }) + else + _boxes = Layout.Box({ + Layout.Box({ + Layout.Box(input, { grow = 1 }), + Layout.Box(instruction, { size = 3 }), + }, { dir = "col", size = "50%" }), + Layout.Box(output, { size = "50%" }), + }, { dir = "row" }) + end + + if not layout then + layout = Layout({ + relative = "editor", + position = "50%", + size = { + width = Config.options.popup_layout.center.width, + height = Config.options.popup_layout.center.height, + }, + }, _boxes) + else + layout:update(_boxes) + end + + layout:mount() + + if opts.show_parameters then + parameters:show() + parameters:mount() + + vim.api.nvim_set_current_win(parameters.winid) + vim.api.nvim_buf_set_option(parameters.bufnr, "modifiable", false) + vim.api.nvim_win_set_option(parameters.winid, "cursorline", true) + else + parameters:hide() + vim.api.nvim_set_current_win(instruction.winid) + end + + return layout +end + +M.edit_with_no_layout = function(layout, input, instruction, output, parameters, opts) + opts = opts or {} + opts = vim.tbl_extend("force", opts, { + buf = { + vars = { + ogpt = true, + }, + }, + }) + + input:mount() + vim.api.nvim_buf_set_var(input.bufnr, "ogpt_input", true) + output:mount() + vim.api.nvim_buf_set_var(output.bufnr, "ogpt_output", true) + instruction:mount() + vim.api.nvim_buf_set_var(instruction.bufnr, "ogpt_instruction", true) + + vim.api.nvim_buf_set_var(parameters.bufnr, "ogpt_parameters", true) + + if opts.show_parameters then + parameters:show() + parameters:mount() + + vim.api.nvim_set_current_win(parameters.winid) + vim.api.nvim_buf_set_option(parameters.bufnr, "modifiable", false) + vim.api.nvim_win_set_option(parameters.winid, "cursorline", true) + else + parameters:hide() + -- vim.api.nvim_set_current_win(instruction.winid) + end + + return nil +end + +return M diff --git a/lua/ogpt/common/simple_window.lua b/lua/ogpt/common/simple_window.lua new file mode 100644 index 0000000..2c45288 --- /dev/null +++ b/lua/ogpt/common/simple_window.lua @@ -0,0 +1,282 @@ +local classes = require("ogpt.common.classes") +local utils = require("ogpt.utils") +local Config = require("ogpt.config") + +local SimpleView = classes.class() + +function SimpleView:init(visitor, opts) + self.visitor = visitor + -- Set default values for the options table if it's not provided. + opts = vim.tbl_extend("force", { + buf = { + swapfile = false, + bufhidden = "wipe", + filetype = "markdown", + vars = {}, + }, + win = { + wrap = true, + cursorline = true, + }, + enter = false, + keymaps = {}, + new_win = false, + }, opts or {}) + self.opts = opts + + self.name = visitor.name or opts.buf.filetype + self.bufnr = nil +end + +-- Define a function that creates a new window with the given options. +-- The function returns the buffer and window handles. +-- function SimpleView:show(opts) +-- self.visible = true +-- opts = vim.tbl_extend("force", self.opts, opts or {}) +-- +-- -- Save the handle of the window from which we open the navigation. +-- local start_win = vim.api.nvim_get_current_win() +-- +-- -- Get the buffer handle. +-- -- local buf = vim.fn.bufnr(self.name) +-- local buf = self.bufnr +-- +-- -- Get the window handle. +-- local win +-- +-- -- If the buffer already exists, find the window that displays it and return its handle. +-- if buf ~= -1 then +-- for _, win_id in ipairs(vim.api.nvim_list_wins()) do +-- local bufnr = vim.api.nvim_win_get_buf(win_id) +-- if bufnr == buf then +-- vim.api.nvim_set_current_win(win_id) +-- vim.api.nvim_set_current_win(start_win) +-- return buf, win_id +-- end +-- end +-- end +-- +-- -- Reset the current window to the one from which we opened the navigation. +-- vim.api.nvim_set_current_win(start_win) +-- +-- -- Return the buffer and window handles. +-- return buf, win +-- end + +function SimpleView:unmount() + -- local buf, win = self:mount() + -- Close the window + local force = true + vim.api.nvim_win_close(self.winid, force) +end + +-- mount can take name or filename +function SimpleView:show() + self:mount() +end + +-- mount can take name or filename +function SimpleView:hide() + local force = true + local winid = vim.fn.bufwinid(self.bufnr) + if winid ~= -1 then + vim.api.nvim_win_close(self.winid, force) + end +end + +-- mount can take name or filename +function SimpleView:mount(name, opts) + name = name or self.name + -- Save the handle of the window from which we open the navigation. + local start_win = vim.api.nvim_get_current_win() + + -- Try to get the buffer handle. + local buf = vim.fn.bufnr(name) + if buf ~= -1 and vim.fn.bufexists(buf) then + -- buffer is still sitting out there and is valid + else + -- pull it from vim global + buf = vim.g[name] + end + + local previous_winid + if not self.opts.new_win then + previous_winid = self.winid + end + + -- If the buffer already exists, find the window that displays it and return its handle. + if buf and vim.fn.bufexists(buf) then + for _, win_id in ipairs(vim.api.nvim_list_wins()) do + local bufnr = vim.api.nvim_win_get_buf(win_id) + if bufnr == buf then + -- if not self.opts.enter then + -- vim.api.nvim_set_current_win(start_win) + -- else + -- vim.api.nvim_set_current_win(win_id) + -- end + previous_winid = win_id + break + -- self.bufnr = buf + -- self.winid = win_id + -- return buf, win_id + end + end + end + + -- -- If the buffer already exists, find the window that displays it and return its handle. + -- if buf and vim.fn.bufexists(buf) then + -- for _, win_id in ipairs(vim.api.nvim_list_wins()) do + -- local bufnr = vim.api.nvim_win_get_buf(win_id) + -- if bufnr == buf then + -- if not self.opts.enter then + -- vim.api.nvim_set_current_win(start_win) + -- else + -- vim.api.nvim_set_current_win(win_id) + -- end + -- self.bufnr = buf + -- self.winid = win_id + -- return buf, win_id + -- end + -- end + -- end + + -- Open a new vertical window at the far right. + -- vim.api.nvim_command("botright " .. "vnew") + if vim.fn.filereadable(name) == 1 then + if previous_winid then + vim.api.nvim_set_current_win(previous_winid) + vim.api.nvim_command("edit " .. name) + else + vim.api.nvim_command("vnew " .. name) + end + self.bufnr = vim.api.nvim_get_current_buf() + else + if previous_winid and not self.opts.new_win and (self.bufnr ~= nil and vim.fn.bufwinid(self.bufnr) ~= -1) then + vim.api.nvim_set_current_win(previous_winid) + self.bufnr = vim.api.nvim_get_current_buf() + vim.api.nvim_buf_set_lines(self.bufnr, -1, -1, false, {}) + else + vim.api.nvim_command("vnew") + self.bufnr = vim.api.nvim_get_current_buf() + -- Set the buffer's filetype to the filetype specified in the options table. + vim.api.nvim_buf_set_option(self.bufnr, "filetype", self.opts.buf.filetype) + -- Set the name of the buffer to the buffer name specified in the options table. + vim.api.nvim_buf_set_name(self.bufnr, name) + end + end + + -- Get the buffer and window handles of the new window. + self.winid = vim.api.nvim_get_current_win() + + -- -- Set the buffer type to "nofile" to prevent it from being saved. + vim.api.nvim_buf_set_option(self.bufnr, "buftype", "nofile") + + -- Disable swapfile for the buffer. + vim.api.nvim_buf_set_option(self.bufnr, "swapfile", false) + + -- Set the buffer's hidden option to "wipe" to destroy it when it's hidden. + vim.api.nvim_buf_set_option(self.bufnr, "bufhidden", "delete") + + -- -- Set so that the cusor does not jump + -- vim.api.nvim_buf_set_option(self.bufnr, "switchbuf", "useopen") + + -- -- Set the name of the buffer to the buffer name specified in the options table. + -- vim.api.nvim_buf_set_name(self.bufnr, name or self.name) + + -- Set buffer variables as specified in the options table. + for key, value in pairs(self.opts.buf.vars or {}) do + vim.api.nvim_buf_set_var(self.bufnr, key, value) + end + + -- Set the window options as specified in the options table. + -- vim.api.nvim_win_set_option(win, "wrap", opts.win.wrap) + -- vim.api.nvim_win_set_option(win, "cursorline", opts.win.cursorline) + + -- Set the keymaps for the window as specified in the options table. + for keymap, command in pairs(self.opts.keymaps) do + vim.keymap.set("n", keymap, command, { noremap = true, buffer = self.bufnr }) + end + + if not self.opts.enter then + vim.api.nvim_set_current_win(start_win) + end + vim.g[name] = self.bufnr +end + +function SimpleView:map(mode, key, command) + mode = vim.tbl_islist(mode) and mode or { mode } + vim.keymap.set(mode, key, command, { buffer = self.bufnr }) +end + +function SimpleView:apply_map(opts) + -- accept output and replace + self:map("n", Config.options.popup.keymaps.accept, function() + -- local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) + vim.api.nvim_buf_set_text( + opts.main_bufnr, + opts.selection_idx.start_row - 1, + opts.selection_idx.start_col - 1, + opts.selection_idx.end_row - 1, + opts.selection_idx.end_col, + opts.lines + ) + vim.cmd("q") + end) + + -- accept output and prepend + self:map("n", Config.options.popup.keymaps.prepend, function() + local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) + table.insert(_lines, "") + table.insert(_lines, "") + vim.api.nvim_buf_set_text( + opts.main_bufnr, + opts.selection_idx.end_row - 1, + opts.selection_idx.start_col - 1, + opts.selection_idx.end_row - 1, + opts.selection_idx.start_col - 1, + _lines + ) + vim.cmd("q") + end) + + -- accept output and append + self:map("n", Config.options.popup.keymaps.append, function() + local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) + table.insert(_lines, 1, "") + table.insert(_lines, "") + vim.api.nvim_buf_set_text( + opts.main_bufnr, + opts.selection_idx.end_row, + opts.selection_idx.start_col - 1, + opts.selection_idx.end_row, + opts.selection_idx.start_col - 1, + _lines + ) + vim.cmd("q") + end) + + -- yank code in output and close + self:map("n", Config.options.popup.keymaps.yank_code, function() + local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) + local _code = utils.getSelectedCode(_lines) + vim.fn.setreg(Config.options.yank_register, _code) + + if vim.fn.mode() == "i" then + vim.api.nvim_command("stopinsert") + end + vim.cmd("q") + end) + + -- yank output and close + self:map("n", Config.options.popup.keymaps.yank_to_register, function() + local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) + vim.fn.setreg(Config.options.yank_register, _lines) + + if vim.fn.mode() == "i" then + vim.api.nvim_command("stopinsert") + end + vim.cmd("q") + end) +end + +return SimpleView diff --git a/lua/ogpt/flows/actions/base.lua b/lua/ogpt/flows/actions/base.lua index dc04c99..2c761b1 100644 --- a/lua/ogpt/flows/actions/base.lua +++ b/lua/ogpt/flows/actions/base.lua @@ -25,9 +25,9 @@ end function BaseAction:post_init() self.popup = PopupWindow() self.spinner = Spinner:new(function(state) - vim.schedule(function() - self:display_input_suffix(state) - end) + -- vim.schedule(function() + -- self:display_input_suffix(state) + -- end) end) self:update_variables() @@ -111,12 +111,14 @@ function BaseAction:set_loading(state) end self:mark_selection_with_signs() self.spinner:start() + self.stop = false else self.spinner:stop() Signs.del(bufnr) if self.extmark_id then vim.api.nvim_buf_del_extmark(bufnr, namespace_id, self.extmark_id) end + self.stop = true end end diff --git a/lua/ogpt/flows/actions/edits/init.lua b/lua/ogpt/flows/actions/edits/init.lua index 060d997..45c4d9d 100644 --- a/lua/ogpt/flows/actions/edits/init.lua +++ b/lua/ogpt/flows/actions/edits/init.lua @@ -24,7 +24,6 @@ function EditAction:init(name, opts) self.variables = opts.variables or {} self.strategy = opts.strategy or STRATEGY_EDIT self.ui = opts.ui or {} - self:post_init() end diff --git a/lua/ogpt/flows/actions/init.lua b/lua/ogpt/flows/actions/init.lua index 3a69000..539a6d9 100644 --- a/lua/ogpt/flows/actions/init.lua +++ b/lua/ogpt/flows/actions/init.lua @@ -2,6 +2,7 @@ local M = {} -- local CompletionAction = require("ogpt.flows.actions.completions") local EditAction = require("ogpt.flows.actions.edits") +local EditSimpleAction = require("ogpt.flows.actions.simple_edit") local PopupAction = require("ogpt.flows.actions.popup") local Config = require("ogpt.config") @@ -9,6 +10,7 @@ local classes_by_type = { chat = PopupAction, -- completion = CompletionAction, edit = EditAction, + simple_edit = EditSimpleAction, popup = PopupAction, } diff --git a/lua/ogpt/flows/actions/popup/init.lua b/lua/ogpt/flows/actions/popup/init.lua index 2d675dd..a31294e 100644 --- a/lua/ogpt/flows/actions/popup/init.lua +++ b/lua/ogpt/flows/actions/popup/init.lua @@ -2,6 +2,7 @@ local classes = require("ogpt.common.classes") local BaseAction = require("ogpt.flows.actions.base") local utils = require("ogpt.utils") local Config = require("ogpt.config") +local SimpleWindow = require("ogpt.common.simple_window") local PopupAction = classes.class(BaseAction) @@ -9,6 +10,7 @@ local STRATEGY_REPLACE = "replace" local STRATEGY_APPEND = "append" local STRATEGY_PREPEND = "prepend" local STRATEGY_DISPLAY = "display" +local STRATEGY_DISPLAY_WINDOW = "display_window" local STRATEGY_QUICK_FIX = "quick_fix" function PopupAction:init(name, opts) @@ -27,13 +29,81 @@ function PopupAction:init(name, opts) end function PopupAction:run() - self.stop = false + -- self.stop = false local params = self:get_params() local _, start_row, start_col, end_row, end_col = self:get_visual_selection() + local opts = { + name = self.name, + cur_win = self.cur_win, + main_bufnr = self:get_bufnr(), + selection_idx = { + start_row = start_row, + start_col = start_col, + end_row = end_row, + end_col = end_col, + }, + default_ui = self.ui, + title = self.opts.title, + args = self.opts.args, + stop = function() + self.stop = true + end, + } if self.strategy == STRATEGY_DISPLAY then - self:run_spinner(true) - self.popup:mount({ + self:set_loading(true) + self.popup:mount(opts) + params.stream = true + self.provider.api:chat_completions( + params, + utils.partial(utils.add_partial_completion, { + panel = self.popup, + progress = function(flag) + self:run_spinner(flag) + end, + on_complete = function(total_text) + -- print("completed: " .. total_text) + end, + }), + function() + -- should stop function + if self.stop then + -- self.stop = false + -- self:run_spinner(false) + self:set_loading(false) + return true + else + return false + end + end + ) + elseif self.strategy == STRATEGY_DISPLAY_WINDOW then + self.popup = SimpleWindow.new(self, { + new_win = false, + buf = { + filetype = "markdown", + vars = { + ogpt = true, + }, + }, + }) + self.popup:apply_map(opts) + + local keys = Config.options.popup.keymaps.close + if type(keys) ~= "table" then + keys = { keys } + end + for _, key in ipairs(keys) do + self.popup:map("n", key, function() + if opts.stop and type(opts.stop) == "function" then + opts.stop() + end + self.popup:unmount() + end) + end + + self:set_loading(true) + self.popup:mount(self.name, { name = self.name, cur_win = self.cur_win, main_bufnr = self:get_bufnr(), @@ -51,7 +121,29 @@ function PopupAction:run() end, }) params.stream = true - self:call_api(self.popup, params) + self.provider.api:chat_completions( + params, + utils.partial(utils.add_partial_completion, { + panel = self.popup, + progress = function(flag) + self:run_spinner(flag) + end, + on_complete = function(total_text) + -- print("completed: " .. total_text) + end, + }), + function() + -- should stop function + if self.stop then + -- self.stop = false + -- self:run_spinner(false) + self:set_loading(false) + return true + else + return false + end + end + ) else self:set_loading(true) self.provider.api:chat_completions(params, function(answer, usage) @@ -60,31 +152,6 @@ function PopupAction:run() end end -function PopupAction:call_api(panel, params) - self.provider.api:chat_completions( - params, - utils.partial(utils.add_partial_completion, { - panel = panel, - progress = function(flag) - self:run_spinner(flag) - end, - on_complete = function(total_text) - -- print("completed: " .. total_text) - end, - }), - function() - -- should stop function - if self.stop then - self.stop = false - self:run_spinner(false) - return true - else - return false - end - end - ) -end - function PopupAction:on_result(answer, usage) vim.schedule(function() self:set_loading(false) diff --git a/lua/ogpt/flows/actions/simple_edit/init.lua b/lua/ogpt/flows/actions/simple_edit/init.lua new file mode 100644 index 0000000..eb333a2 --- /dev/null +++ b/lua/ogpt/flows/actions/simple_edit/init.lua @@ -0,0 +1,327 @@ +local classes = require("ogpt.common.classes") +local SimpleWindow = require("ogpt.common.simple_window") +local BaseAction = require("ogpt.flows.actions.base") +local layouts = require("ogpt.common.layouts") +local utils = require("ogpt.utils") +local Config = require("ogpt.config") +local Layout = require("nui.layout") +local Split = require("nui.split") +local Popup = require("nui.popup") +local ChatInput = require("ogpt.input") +local SimpleParameters = require("ogpt.simple_parameters") + +local EditAction = classes.class(BaseAction) + +local STRATEGY_EDIT = "edit" +local STRATEGY_EDIT_CODE = "edit_code" + +function EditAction:init(name, opts) + self.name = name or "" + opts = opts or {} + self.super:init(opts) + self.provider = Config.get_provider(opts.provider, self) + self.params = Config.get_action_params(self.provider.name, opts.params or {}) + self.system = type(opts.system) == "function" and opts.system() or opts.system or "" + self.template = type(opts.template) == "function" and opts.template() or opts.template or "{{input}}" + self.variables = opts.variables or {} + self.strategy = opts.strategy or STRATEGY_EDIT + self.ui = opts.ui or {} + self:post_init() +end + +function EditAction:run() + vim.schedule(function() + if self.strategy == STRATEGY_EDIT_CODE and self.opts.delay then + self:edit_with_instructions({}, { self:get_visual_selection() }, { + template = self.template, + variables = self.variables, + edit_code = true, + filetype = self:get_filetype(), + }) + elseif self.strategy == STRATEGY_EDIT and self.opts.delay then + self:edit_with_instructions({}, { self:get_visual_selection() }, { + template = self.template, + variables = self.variables, + edit_code = false, + }) + elseif self.strategy == STRATEGY_EDIT then + self:edit_with_instructions({}, { self:get_visual_selection() }, { + template = self.template, + variables = self.variables, + -- params = self:get_params(), + }) + elseif self.strategy == STRATEGY_EDIT_CODE then + self:edit_with_instructions({}, { self:get_visual_selection() }, { + template = self.template, + variables = self.variables, + -- params = self:get_params(), + edit_code = true, + filetype = self:get_filetype(), + }) + end + end) +end + +local instructions_input, layout, input_window, output_window, output, timer, filetype, bufnr, extmark_id + +local setup_and_mount = function(lines, output_lines, ...) + -- set input + if lines then + vim.api.nvim_buf_set_lines(input_window.bufnr, 0, -1, false, lines) + end + + -- set output + if output_lines then + vim.api.nvim_buf_set_lines(output_window.bufnr, 0, -1, false, output_lines) + end + + -- set input and output settings + for _, window in ipairs({ input_window, output_window }) do + vim.api.nvim_buf_set_option(window.bufnr, "filetype", "markdown") + vim.api.nvim_win_set_option(window.winid, "number", true) + end +end + +function EditAction:edit_with_instructions(output_lines, selection, opts, ...) + opts = opts or {} + opts.params = opts.params or self.params + local api_params = opts.params + + bufnr = self:get_bufnr() + + filetype = vim.api.nvim_buf_get_option(bufnr, "filetype") + + local visual_lines, start_row, start_col, end_row, end_col + if selection == nil then + visual_lines, start_row, start_col, end_row, end_col = utils.get_visual_lines(bufnr) + else + visual_lines, start_row, start_col, end_row, end_col = unpack(selection) + end + local parameters_panel = SimpleParameters.get_parameters_panel("edits", api_params) + input_window = Popup(Config.options.popup_window) + output_window = Popup(Config.options.popup_window) + -- instructions_input = ChatInput(Config.options.popup_input, { + instructions_input = SimpleWindow.new(self, { + prompt = Config.options.popup_input.prompt, + default_value = opts.instruction or "", + on_close = function() + -- if self.spinner:is_running() then + -- self.spinner:stop() + -- end + self:run_spinner(false) + if timer ~= nil then + timer:stop() + end + end, + on_submit = vim.schedule_wrap(function(instruction) + -- clear input + vim.api.nvim_buf_set_lines(instructions_input.bufnr, 0, -1, false, { "" }) + vim.api.nvim_buf_set_lines(output_window.bufnr, 0, -1, false, { "" }) + -- show_progress() + self:run_spinner(true) + + local input = table.concat(vim.api.nvim_buf_get_lines(input_window.bufnr, 0, -1, false), "\n") + + -- if instruction is empty, try to get the original instruction from opts + if instruction == "" then + instruction = opts.instruction or "" + end + local messages = self:build_edit_messages(input, instruction, opts) + local params = vim.tbl_extend("keep", { messages = messages }, SimpleParameters.params) + self.provider.api:edits( + params, + utils.partial(utils.add_partial_completion, { + panel = output_window, + on_complete = function(response) + -- on the completion, execute this function to extract out codeblocks + local nlcount = utils.count_newlines_at_end(response) + local output_txt = response + if opts.edit_code then + local code_response = utils.extract_code(response) + -- if the chat is to edit code, it will try to extract out the code from response + output_txt = response + if code_response then + output_txt = utils.match_indentation(response, code_response) + else + vim.notify("no codeblock detected", vim.log.levels.INFO) + end + if response.applied_changes then + vim.notify(response.applied_changes, vim.log.levels.INFO) + end + end + local output_txt_nlfixed = utils.replace_newlines_at_end(output_txt, nlcount) + local _output = utils.split_string_by_line(output_txt_nlfixed) + if output_window.bufnr then + vim.api.nvim_buf_set_lines(output_window.bufnr, 0, -1, false, _output) + end + end, + progress = function(flag) + self:run_spinner(flag) + end, + }) + ) + end), + }) + instructions_input:map("n", "", function() + local instructions = vim.api.nvim_buf_get_lines(instructions_input.bufnr, 0, -1, false) + instructions_input.opts.on_submit(table.concat(instructions, "\n")) + end) + + layout = layouts.edit_with_no_layout(layout, input_window, instructions_input, output_window, parameters_panel, { + show_parameters = false, + }) + + -- accept output window + for _, window in ipairs({ input_window, output_window, instructions_input }) do + for _, mode in ipairs({ "n", "i" }) do + window:map(mode, Config.options.edit.keymaps.accept, function() + instructions_input.input_props.on_close() + local lines = vim.api.nvim_buf_get_lines(output_window.bufnr, 0, -1, false) + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + vim.notify("Successfully applied the change!", vim.log.levels.INFO) + end, { noremap = true }) + end + end + + -- use output as input + for _, window in ipairs({ input_window, output_window, instructions_input }) do + for _, mode in ipairs({ "n", "i" }) do + window:map(mode, Config.options.edit.keymaps.use_output_as_input, function() + local lines = vim.api.nvim_buf_get_lines(output_window.bufnr, 0, -1, false) + vim.api.nvim_buf_set_lines(input_window.bufnr, 0, -1, false, lines) + vim.api.nvim_buf_set_lines(output_window.bufnr, 0, -1, false, {}) + end, { noremap = true }) + end + end + + -- close + for _, window in ipairs({ input_window, output_window, instructions_input }) do + for _, mode in ipairs({ "n", "i" }) do + window:map(mode, Config.options.edit.keymaps.close, function() + self.spinner:stop() + if vim.fn.mode() == "i" then + vim.api.nvim_command("stopinsert") + end + vim.cmd("q") + end, { noremap = true }) + end + end + + -- toggle parameters + local parameters_open = false + for _, popup in ipairs({ parameters_panel, instructions_input, input_window, output_window }) do + for _, mode in ipairs({ "n", "i" }) do + popup:map(mode, Config.options.edit.keymaps.toggle_parameters, function() + if parameters_open then + layouts.edit_with_no_layout(layout, input_window, instructions_input, output_window, parameters_panel, { + show_parameters = false, + }) + else + layouts.edit_with_no_layout(layout, input_window, instructions_input, output_window, parameters_panel, { + show_parameters = true, + }) + SimpleParameters.refresh_panel() + end + parameters_open = not parameters_open + -- set input and output settings + -- TODO + for _, window in ipairs({ input_window, output_window }) do + vim.api.nvim_buf_set_option(window.bufnr, "filetype", filetype) + vim.api.nvim_win_set_option(window.winid, "number", true) + end + end, {}) + end + end + + -- cycle windows + local active_panel = instructions_input + for _, popup in ipairs({ input_window, output_window, parameters_panel, instructions_input }) do + for _, mode in ipairs({ "n", "i" }) do + if mode == "i" and (popup == input_window or popup == output_window) then + goto continue + end + popup:map(mode, Config.options.edit.keymaps.cycle_windows, function() + if active_panel == instructions_input then + vim.api.nvim_set_current_win(input_window.winid) + active_panel = input_window + vim.api.nvim_command("stopinsert") + elseif active_panel == input_window and mode ~= "i" then + vim.api.nvim_set_current_win(output_window.winid) + active_panel = output_window + vim.api.nvim_command("stopinsert") + elseif active_panel == output_window and mode ~= "i" then + if parameters_open then + vim.api.nvim_set_current_win(parameters_panel.winid) + active_panel = parameters_panel + else + vim.api.nvim_set_current_win(instructions_input.winid) + active_panel = instructions_input + end + elseif active_panel == parameters_panel then + vim.api.nvim_set_current_win(instructions_input.winid) + active_panel = instructions_input + end + end, {}) + ::continue:: + end + end + + -- toggle diff mode + local diff_mode = Config.options.edit.diff + for _, popup in ipairs({ parameters_panel, instructions_input, output_window, input_window }) do + for _, mode in ipairs({ "n", "i" }) do + popup:map(mode, Config.options.edit.keymaps.toggle_diff, function() + diff_mode = not diff_mode + for _, winid in ipairs({ input_window.winid, output_window.winid }) do + vim.api.nvim_set_current_win(winid) + if diff_mode then + vim.api.nvim_command("diffthis") + else + vim.api.nvim_command("diffoff") + end + vim.api.nvim_set_current_win(instructions_input.winid) + end + end, {}) + end + end + + setup_and_mount(visual_lines, output_lines) +end + +function EditAction:build_edit_messages(input, instructions, opts) + local _input = input + if opts.edit_code then + _input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````" + else + _input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````" + end + local variables = vim.tbl_extend("force", {}, { + instruction = instructions, + input = _input, + filetype = opts.filetype, + }, opts.variables) + local system_msg = opts.params.system or "" + local messages = { + { + role = "system", + content = system_msg, + }, + { + role = "user", + content = self:render_template(variables, opts.template), + }, + } + + return messages +end + +function EditAction:render_template(variables, template) + local result = template + for key, value in pairs(variables) do + local escaped_value = utils.escape_pattern(value) + result = string.gsub(result, "{{" .. key .. "}}", escaped_value) + end + return result +end + +return EditAction diff --git a/lua/ogpt/parameters.lua b/lua/ogpt/parameters.lua index 7712618..7237c7c 100644 --- a/lua/ogpt/parameters.lua +++ b/lua/ogpt/parameters.lua @@ -1,4 +1,5 @@ local pickers = require("telescope.pickers") +local SimpleWindow = require("ogpt.common.simple_window") local Utils = require("ogpt.utils") local conf = require("telescope.config").values local actions = require("telescope.actions") @@ -256,6 +257,8 @@ M.get_parameters_panel = function(type, default_params, session, parent) end M.panel = Popup(Config.options.parameters_window) + -- M.panel = SimpleWindow.new(Config.options.parameters_window) + -- M.panel:mount() M.refresh_panel() M.panel:map("n", "d", function() diff --git a/lua/ogpt/simple_parameters.lua b/lua/ogpt/simple_parameters.lua new file mode 100644 index 0000000..a339831 --- /dev/null +++ b/lua/ogpt/simple_parameters.lua @@ -0,0 +1,382 @@ +local pickers = require("telescope.pickers") +local SimpleWindow = require("ogpt.common.simple_window") +local Utils = require("ogpt.utils") +local conf = require("telescope.config").values +local actions = require("telescope.actions") +local action_state = require("telescope.actions.state") + +local M = {} +M.vts = {} + +local Popup = require("nui.popup") +local Config = require("ogpt.config") + +local namespace_id = vim.api.nvim_create_namespace("OGPTNS") + +local float_validator = function(min, max) + return function(value) + return tonumber(value) + end +end + +local bool_validator = function(min, max) + return function(value) + local stringtoboolean = { ["true"] = true, ["false"] = false } + return stringtoboolean(value) + end +end + +local integer_validator = function(min, max) + return function(value) + return tonumber(value) + end +end + +local model_validator = function() + return function(value) + return value + end +end + +local function preview_command(entry, bufnr, width) + vim.api.nvim_buf_call(bufnr, function() + local preview = Utils.wrapTextToTable(entry.value, width - 5) + table.insert(preview, 1, "---") + table.insert(preview, 1, entry.display) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, true, preview) + end) +end + +local finder = function(opts) + return setmetatable({ + close = function() + -- TODO: check if we need to make some cleanup + end, + }, { + __call = function(_, prompt, process_result, process_complete) + local _params = { + values = opts.parameters, + } + for _, param in ipairs(_params.values) do + process_result({ + value = param, + display = param, + ordinal = param, + preview_command = preview_command, + }) + end + process_complete() + end, + }) +end + +local params_order = { + "provider", + "model", + "embedding_only", + "f16_kv", + "frequency_penalty", + "logits_all", + "low_vram", + "main_gpu", + "max_tokens", + "mirostat", + "mirostat_eta", + "mirostat_tau", + "num_batch", + "num_ctx", + "num_gpu", + "num_gqa", + "num_keep", + "num_predict", + "num_thread", + "presence_penalty", + "repeat_last_n", + "repeat_penalty", + "rope_frequency_base", + "rope_frequency_scale", + "seed", + "stop", + "temperature", + "tfs_z", + "top_k", + "top_p", + "typical_p", + "use_mlock", + "use_mmap", + "vocab_only", +} +local params_validators = { + provider = model_validator(), + model = model_validator(), + embedding_only = model_validator(), + f16_kv = model_validator(), + frequency_penalty = float_validator(), + mirostat = integer_validator(), + mirostat_eta = float_validator(), + mirostat_tau = float_validator(), + num_batch = integer_validator(), + num_ctx = integer_validator(), + num_gpu = integer_validator(), + num_gqa = integer_validator(), + num_keep = integer_validator(), + num_predict = integer_validator(), + num_thread = integer_validator(), + presence_penalty = float_validator(), + repeat_last_n = integer_validator(), + repeat_penalty = float_validator(), + seed = integer_validator(), + stop = model_validator(), + temperature = float_validator(), + tfs_z = float_validator(), + top_k = float_validator(), + top_p = float_validator(), + logits_all = bool_validator(), + vocab_only = bool_validator(), + use_mmap = bool_validator(), + use_mlock = bool_validator(), + low_vram = bool_validator(), +} + +local function write_virtual_text(bufnr, ns, line, chunks, mode) + mode = mode or "extmark" + if mode == "extmark" then + return vim.api.nvim_buf_set_extmark(bufnr, ns, line, 0, { virt_text = chunks, virt_text_pos = "overlay" }) + elseif mode == "vt" then + pcall(vim.api.nvim_buf_set_virtual_text, bufnr, ns, line, chunks, {}) + end +end + +function M.select_parameter(opts) + opts = opts or {} + pickers + .new(opts, { + sorting_strategy = "ascending", + layout_config = { + height = 0.5, + }, + results_title = "Select Additional Parameter", + prompt_prefix = Config.options.popup_input.prompt, + selection_caret = Config.options.chat.answer_sign .. " ", + prompt_title = "Parameter", + finder = finder({ + parameters = params_order, + }), + sorter = conf.generic_sorter(opts), + attach_mappings = function(prompt_bufnr) + actions.select_default:replace(function() + actions.close(prompt_bufnr) + local selection = action_state.get_selected_entry() + opts.cb(selection.display, vim.fn.input("value: ")) + end) + return true + end, + }) + :find() +end + +M.read_config = function(session) + if not session then + local home = os.getenv("HOME") or os.getenv("USERPROFILE") + local file = io.open(home .. "/" .. ".ogpt-" .. M.type .. "-params.json", "rb") + if not file then + return nil + end + + local jsonString = file:read("*a") + file:close() + return vim.json.decode(jsonString) + else + return session.parameters + end +end + +M.write_config = function(config, session) + if not session then + local home = os.getenv("HOME") or os.getenv("USERPROFILE") + local file, err = io.open(home .. "/" .. ".ogpt-" .. M.type .. "-params.json", "w") + if file ~= nil then + local json_string = vim.json.encode(config) + file:write(json_string) + file:close() + else + vim.notify("Cannot save parameters: " .. err, vim.log.levels.ERROR) + end + else + session.parameters = config + session:save() + end +end + +M.refresh_panel = function() + -- write details as virtual text + local details = {} + for _, key in pairs(params_order) do + if M.params[key] ~= nil then + local display_text = M.params[key] + if type(display_text) == "table" then + if display_text.name then + display_text = display_text.name + else + display_text = table.concat(M.params[key], ", ") + end + end + + local vt = { + { Config.options.parameters_window.setting_sign .. key .. ": ", "ErrorMsg" }, + { display_text .. "", "Identifier" }, + } + table.insert(details, vt) + end + end + + vim.api.nvim_buf_clear_namespace(M.panel.bufnr, namespace_id, 0, -1) + + local line = 1 + local empty_lines = {} + for _ = 1, #details do + table.insert(empty_lines, "") + end + + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", true) + vim.api.nvim_buf_set_lines(M.panel.bufnr, line - 1, line - 1 + #empty_lines, false, empty_lines) + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", false) + for _, d in ipairs(details) do + M.vts[line - 1] = write_virtual_text(M.panel.bufnr, namespace_id, line - 1, d) + line = line + 1 + end +end + +M.get_parameters_panel = function(type, default_params, session, parent) + M.type = type + M.name = "ogpt_parameters" + local custom_params = M.read_config(session or {}) + + M.params = vim.tbl_deep_extend("force", {}, default_params, custom_params or {}) + if session then + M.params = session.parameters + end + + -- M.panel = Popup(Config.options.parameters_window) + M.panel = SimpleWindow.new(M, Config.options.parameters_window) + M.panel:mount() + + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", true) + + M.refresh_panel() + + M.panel:map("n", "d", function() + local row, _ = unpack(vim.api.nvim_win_get_cursor(M.panel.winid)) + + local existing_order = {} + for _, key in ipairs(params_order) do + if M.params[key] ~= nil then + table.insert(existing_order, key) + end + end + + local key = existing_order[row] + M.update_property(key, row, nil, session) + M.refresh_panel() + end) + + M.panel:map("n", "a", function() + local row, _ = unpack(vim.api.nvim_win_get_cursor(M.panel.winid)) + M.select_parameter({ + cb = function(key, value) + M.update_property(key, row + 1, value, session) + end, + }) + end) + + M.panel:map("n", "", function() + local row, _ = unpack(vim.api.nvim_win_get_cursor(M.panel.winid)) + + local existing_order = {} + for _, key in ipairs(params_order) do + if M.params[key] ~= nil then + table.insert(existing_order, key) + end + end + + local key = existing_order[row] + if key == "model" then + local models = require("ogpt.models") + models.select_model(parent.provider, { + cb = function(display, value) + M.update_property(key, row, value, session) + end, + }) + elseif key == "provider" then + local provider = require("ogpt.provider") + provider.select_provider({ + cb = function(display, value) + M.update_property(key, row, value, session) + parent.provider = Config.get_provider(value) + end, + }) + else + local value = M.params[key] + M.open_edit_property_input(key, value, row, function(new_value) + M.update_property(key, row, Utils.process_string(new_value), session) + end) + end + end, {}) + + return M.panel +end + +M.update_property = function(key, row, new_value, session) + if not key or not new_value then + M.params[key] = nil + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", true) + vim.api.nvim_del_current_line() + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", false) + else + M.params[key] = params_validators[key](new_value) + end + M.write_config(M.params, session) + M.refresh_panel() +end + +M.get_panel = function(session, parent) + return M.get_parameters_panel(" ", session.parameters or {}, session, parent) +end + +M.open_edit_property_input = function(key, value, row, cb) + -- convert table to string first + if type(value) == "table" then + value = table.concat(value, ", ") + end + + local Input = require("nui.input") + + local input = Input({ + relative = { + type = "win", + winid = M.panel.winid, + }, + position = { + row = row - 1, + col = 0, + }, + size = { + width = 38, + }, + border = { + style = "none", + }, + win_options = { + winhighlight = "Normal:Normal,FloatBorder:Normal", + }, + }, { + prompt = Config.options.popup_input.prompt .. key .. ": ", + default_value = "" .. value, + on_submit = cb, + }) + + -- mount/open the component + input:mount() +end + +return M