diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index 8deb4ec..e89dfaa 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -207,6 +207,7 @@ function M.defaults() border_right_sign = "|", max_line_length = 120, edgy = nil, -- use global default + args = {}, sessions_window = { active_sign = " 󰄵 ", inactive_sign = " 󰄱 ", diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua index 9a358f4..0b9aa3f 100644 --- a/lua/ogpt/flows/chat/base.lua +++ b/lua/ogpt/flows/chat/base.lua @@ -12,6 +12,7 @@ local Signs = require("ogpt.signs") local Spinner = require("ogpt.spinner") local Session = require("ogpt.flows.chat.session") local UtilWindow = require("ogpt.util_window") +local template_helpers = require("ogpt.flows.actions.template_helpers") QUESTION, ANSWER, SYSTEM = 1, 2, 3 ROLE_ASSISTANT = "assistant" @@ -21,7 +22,7 @@ ROLE_USER = "user" local Chat = Object("Chat") function Chat:init(opts) - opts = opts or {} + self.opts = opts or {} self.input_extmark_id = nil self.active_panel = nil @@ -53,8 +54,9 @@ function Chat:init(opts) self.prompt_lines = 1 self.display_mode = Config.options.popup_layout.default - self.params = Config.get_chat_params(opts.provider) + self.params = Config.get_chat_params(self.opts.provider) + self.variables = {} self.session = Session.latest() self.provider = nil self.selectedIndex = 0 @@ -171,6 +173,66 @@ function Chat:isBusy() return self.is_running end +function Chat:update_variables() + local bufnr = vim.api.nvim_get_current_buf() + self.variables = vim.tbl_extend("force", self.variables, { + filetype = function() + return vim.api.nvim_buf_get_option(bufnr, "filetype") + end, + input = function() + return utils.get_selected_text(bufnr) + end, + selection = function() + return utils.get_selected_range(bufnr) + end, + }) + + -- pull in action defined args + self.variables = vim.tbl_extend("force", self.variables, self.opts.args or {}) + + -- add in plugin predefined template helpers + for helper, helper_fn in pairs(template_helpers) do + local _v = { [helper] = helper_fn } + self.variables = vim.tbl_extend("force", self.variables, _v) + end +end + +function Chat:render_template(variables, text) + variables = vim.tbl_extend("force", self.variables, variables or {}) + -- lazily render the final string. + -- it recursively loop on the template string until it does not find anymore + -- {{{}}} patterns + local stop = false + local depth = 2 + local result = text + -- deprecating warning of {{}} (double curly) + if vim.fn.match(result, [[\v\{\{([^}]+)\}\}(})@!]]) ~= -1 then + utils.log( + "You may be using the {{}}, please updated to {{{}}} (triple curly braces)in your custom actions.", + vim.log.levels.ERROR + ) + end + + local pattern = "%{%{%{(([%w_]+))%}%}%}" + repeat + for match in string.gmatch(result, pattern) do + local value = variables[match] + if value then + if type(value) == "function" then + value = value(self.variables) + end + local escaped_value = utils.escape_pattern(value) + result = string.gsub(result, "{{{" .. match .. "}}}", escaped_value) + else + utils.log("Cannot find {{{" .. match .. "}}}", vim.log.levels.ERROR) + stop = true + end + end + depth = depth - 1 + until not string.match(result, pattern) or stop or depth == 0 + return result +end + function Chat:add(type, text, usage) local idx = self.session:add_item({ type = type, @@ -178,6 +240,10 @@ function Chat:add(type, text, usage) usage = usage, ctx = { context = self.session:previous_context() }, }) + + self:update_variables() + text = self:render_template(Config.options.chat.args, text) + self:_add(type, text, usage, idx) self:render_role() end diff --git a/lua/ogpt/response.lua b/lua/ogpt/response.lua index 9aebec3..76b8b3a 100644 --- a/lua/ogpt/response.lua +++ b/lua/ogpt/response.lua @@ -30,8 +30,8 @@ function Response:init(provider, events) self.not_processed = Deque.new() self.not_processed_raw = Deque.new() self.raw_chunk_tx, self.raw_chunk_rx = channel.mpsc() - self.processed_raw_tx, self.processed_raw_rx = channel.mpsc() - self.processed_content_tx, self.processed_content_rx = channel.mpsc() + self.processed_raw_tx, self.processsed_raw_rx = channel.mpsc() + self.processed_content_tx, self.processsed_content_rx = channel.mpsc() self.response_state = nil self.chunk_regex = "" self:set_state(self.STATE_INIT) @@ -90,7 +90,7 @@ function Response:_process_added_chunk() chunk = chunk .. queued_chunk end - -- Run different strategies for processing responses here + -- Run different strategies for processsing responses here if self.provider.response_params.strategy == self.STRATEGY_CHUNK then self.processed_raw_tx.send(chunk) elseif @@ -123,7 +123,7 @@ function Response:_process_added_chunk() end function Response:pop_content() - local content = self.processed_content_rx.recv() + local content = self.processsed_content_rx.recv() if content[2] == "END" and (content[1] and content[1] == "") then content[1] = self:get_processed_text() end @@ -141,29 +141,18 @@ function Response:render() end function Response:pop_chunk() - -- -- pop the next chunk and add anything that is not process - -- local _value = self.not_processed - -- self.not_processed = "" - -- local _chunk = self.processed_raw_rx.recv() - -- utils.log("Got chunk... now appending to 'not_processed'", vim.log.levels.TRACE) - -- return _value .. _chunk - - local _chunk = self.processed_raw_rx.recv() + local _chunk = self.processsed_raw_rx.recv() utils.log("pushing processed raw to queue: " .. _chunk, vim.log.levels.TRACE) -- push on to queue self.not_processed:pushright(_chunk) utils.log("popping processed raw from queue: " .. _chunk, vim.log.levels.TRACE) - return self.not_processed:popleft() - - -- local chunk = _chunk - -- -- clear the queue each time to try to get a full chunk - -- for i, queued_chunk in self.not_processed:ipairs_left() do - -- chunk = chunk .. queued_chunk - -- utils.log("Adding to final processed output: " .. chunk, vim.log.levels.TRACE) - -- self.not_processed[i] = nil - -- end - -- return chunk + + local _data = {} + while not self.not_processed:is_empty() do + table.insert(_data, self.not_processed:popleft()) + end + return table.concat(_data, "") end function Response:get_accumulated_chunks() diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua index c8b68c1..a2835ff 100644 --- a/lua/ogpt/utils.lua +++ b/lua/ogpt/utils.lua @@ -456,4 +456,17 @@ function M.curl(args, on_exit) return job end +function M.get_visual_selection(bufnr) + local lines, start_row, start_col, end_row, end_col = M.get_visual_lines(bufnr) + + return lines, start_row, start_col, end_row, end_col +end + +function M.get_selected_text(bufnr) + -- selection using vim GV, after selection is made, it remains in + -- vim registry, and this function recall that selection + local lines, _, _, _, _ = M.get_visual_selection(bufnr) + return table.concat(lines, "\n") +end + return M