Skip to content

Commit

Permalink
fixing response object for reliable parsing of incoming json messages
Browse files Browse the repository at this point in the history
  • Loading branch information
huynle committed Oct 9, 2024
2 parents 9579eb7 + fe87e33 commit 5468f7c
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 24 deletions.
1 change: 1 addition & 0 deletions lua/ogpt/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ function M.defaults()
border_right_sign = "|",
max_line_length = 120,
edgy = nil, -- use global default
args = {},
sessions_window = {
active_sign = " 󰄵 ",
inactive_sign = " 󰄱 ",
Expand Down
70 changes: 68 additions & 2 deletions lua/ogpt/flows/chat/base.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ local Signs = require("ogpt.signs")
local Spinner = require("ogpt.spinner")
local Session = require("ogpt.flows.chat.session")
local UtilWindow = require("ogpt.util_window")
local template_helpers = require("ogpt.flows.actions.template_helpers")

QUESTION, ANSWER, SYSTEM = 1, 2, 3
ROLE_ASSISTANT = "assistant"
Expand All @@ -21,7 +22,7 @@ ROLE_USER = "user"
local Chat = Object("Chat")

function Chat:init(opts)
opts = opts or {}
self.opts = opts or {}
self.input_extmark_id = nil

self.active_panel = nil
Expand Down Expand Up @@ -53,8 +54,9 @@ function Chat:init(opts)
self.prompt_lines = 1

self.display_mode = Config.options.popup_layout.default
self.params = Config.get_chat_params(opts.provider)
self.params = Config.get_chat_params(self.opts.provider)

self.variables = {}
self.session = Session.latest()
self.provider = nil
self.selectedIndex = 0
Expand Down Expand Up @@ -171,13 +173,77 @@ function Chat:isBusy()
return self.is_running
end

function Chat:update_variables()
local bufnr = vim.api.nvim_get_current_buf()
self.variables = vim.tbl_extend("force", self.variables, {
filetype = function()
return vim.api.nvim_buf_get_option(bufnr, "filetype")
end,
input = function()
return utils.get_selected_text(bufnr)
end,
selection = function()
return utils.get_selected_range(bufnr)
end,
})

-- pull in action defined args
self.variables = vim.tbl_extend("force", self.variables, self.opts.args or {})

-- add in plugin predefined template helpers
for helper, helper_fn in pairs(template_helpers) do
local _v = { [helper] = helper_fn }
self.variables = vim.tbl_extend("force", self.variables, _v)
end
end

function Chat:render_template(variables, text)
variables = vim.tbl_extend("force", self.variables, variables or {})
-- lazily render the final string.
-- it recursively loop on the template string until it does not find anymore
-- {{{}}} patterns
local stop = false
local depth = 2
local result = text
-- deprecating warning of {{}} (double curly)
if vim.fn.match(result, [[\v\{\{([^}]+)\}\}(})@!]]) ~= -1 then
utils.log(
"You may be using the {{<template_helper>}}, please updated to {{{<template_helpers>}}} (triple curly braces)in your custom actions.",
vim.log.levels.ERROR
)
end

local pattern = "%{%{%{(([%w_]+))%}%}%}"
repeat
for match in string.gmatch(result, pattern) do
local value = variables[match]
if value then
if type(value) == "function" then
value = value(self.variables)
end
local escaped_value = utils.escape_pattern(value)
result = string.gsub(result, "{{{" .. match .. "}}}", escaped_value)
else
utils.log("Cannot find {{{" .. match .. "}}}", vim.log.levels.ERROR)
stop = true
end
end
depth = depth - 1
until not string.match(result, pattern) or stop or depth == 0
return result
end

function Chat:add(type, text, usage)
local idx = self.session:add_item({
type = type,
text = text,
usage = usage,
ctx = { context = self.session:previous_context() },
})

self:update_variables()
text = self:render_template(Config.options.chat.args, text)

self:_add(type, text, usage, idx)
self:render_role()
end
Expand Down
33 changes: 11 additions & 22 deletions lua/ogpt/response.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ function Response:init(provider, events)
self.not_processed = Deque.new()
self.not_processed_raw = Deque.new()
self.raw_chunk_tx, self.raw_chunk_rx = channel.mpsc()
self.processed_raw_tx, self.processed_raw_rx = channel.mpsc()
self.processed_content_tx, self.processed_content_rx = channel.mpsc()
self.processed_raw_tx, self.processsed_raw_rx = channel.mpsc()
self.processed_content_tx, self.processsed_content_rx = channel.mpsc()
self.response_state = nil
self.chunk_regex = ""
self:set_state(self.STATE_INIT)
Expand Down Expand Up @@ -90,7 +90,7 @@ function Response:_process_added_chunk()
chunk = chunk .. queued_chunk
end

-- Run different strategies for processing responses here
-- Run different strategies for processsing responses here
if self.provider.response_params.strategy == self.STRATEGY_CHUNK then
self.processed_raw_tx.send(chunk)
elseif
Expand Down Expand Up @@ -123,7 +123,7 @@ function Response:_process_added_chunk()
end

function Response:pop_content()
local content = self.processed_content_rx.recv()
local content = self.processsed_content_rx.recv()
if content[2] == "END" and (content[1] and content[1] == "") then
content[1] = self:get_processed_text()
end
Expand All @@ -141,29 +141,18 @@ function Response:render()
end

function Response:pop_chunk()
-- -- pop the next chunk and add anything that is not process
-- local _value = self.not_processed
-- self.not_processed = ""
-- local _chunk = self.processed_raw_rx.recv()
-- utils.log("Got chunk... now appending to 'not_processed'", vim.log.levels.TRACE)
-- return _value .. _chunk

local _chunk = self.processed_raw_rx.recv()
local _chunk = self.processsed_raw_rx.recv()
utils.log("pushing processed raw to queue: " .. _chunk, vim.log.levels.TRACE)

-- push on to queue
self.not_processed:pushright(_chunk)
utils.log("popping processed raw from queue: " .. _chunk, vim.log.levels.TRACE)
return self.not_processed:popleft()

-- local chunk = _chunk
-- -- clear the queue each time to try to get a full chunk
-- for i, queued_chunk in self.not_processed:ipairs_left() do
-- chunk = chunk .. queued_chunk
-- utils.log("Adding to final processed output: " .. chunk, vim.log.levels.TRACE)
-- self.not_processed[i] = nil
-- end
-- return chunk

local _data = {}
while not self.not_processed:is_empty() do
table.insert(_data, self.not_processed:popleft())
end
return table.concat(_data, "")
end

function Response:get_accumulated_chunks()
Expand Down
13 changes: 13 additions & 0 deletions lua/ogpt/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -456,4 +456,17 @@ function M.curl(args, on_exit)
return job
end

function M.get_visual_selection(bufnr)
local lines, start_row, start_col, end_row, end_col = M.get_visual_lines(bufnr)

return lines, start_row, start_col, end_row, end_col
end

function M.get_selected_text(bufnr)
-- selection using vim GV, after selection is made, it remains in
-- vim registry, and this function recall that selection
local lines, _, _, _, _ = M.get_visual_selection(bufnr)
return table.concat(lines, "\n")
end

return M

0 comments on commit 5468f7c

Please sign in to comment.