Skip to content

Commit

Permalink
getting things updated
Browse files Browse the repository at this point in the history
  • Loading branch information
huynle committed Jan 9, 2024
1 parent a040ea8 commit 995bc04
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 25 deletions.
16 changes: 11 additions & 5 deletions lua/ogpt/api.lua
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,27 @@ function Api.chat_completions(custom_params, cb, should_stop, opts)
},
function(chunk)
local ok, json = pcall(vim.json.decode, chunk)
if ok and json ~= nil then
if ok then
if json.error ~= nil then
cb(json.error, "ERROR", ctx)
return
end
ctx, raw_chunks, state = Api.provider.process_line(_ok, _json, ctx, raw_chunks, state, cb)
else
ctx, raw_chunks, state = Api.provider.process_line(ok, json, ctx, raw_chunks, state, cb)
return
end

if chunk:match("[^\n]+") then
for line in chunk:gmatch("[^\n]+") do
local raw_json = string.gsub(line, "^data:", "")
local _ok, _json = pcall(vim.json.decode, raw_json)
ctx, raw_chunks, state = Api.provider.process_line(_ok, _json, ctx, raw_chunks, state, cb)
if _ok then
ctx, raw_chunks, state = Api.provider.process_line(_json, ctx, raw_chunks, state, cb)
end
end
else
print("didnt get it")
end
end,

function(err, _)
cb(err, "ERROR", ctx)
end,
Expand Down
4 changes: 3 additions & 1 deletion lua/ogpt/flows/actions/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,12 @@ function ChatAction:call_api(panel, params)
params,
Utils.partial(Utils.add_partial_completion, {
panel = panel,
-- finalize_opts = opts,
progress = function(flag)
self:run_spinner(flag)
end,
on_complete = function(total_text)
-- print("completed: " .. total_text)
end,
}),
function()
-- should stop function
Expand Down
5 changes: 4 additions & 1 deletion lua/ogpt/flows/actions/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ local read_actions_from_file = function(filename)
local json_string = file:read("*a")
file:close()

return vim.json.decode(json_string)
local ok, json = pcall(vim.json.decode, json_string)
if ok then
return json
end
end

function M.read_actions()
Expand Down
30 changes: 18 additions & 12 deletions lua/ogpt/flows/chat/session.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,14 @@ function Session:delete()
end

function Session:to_export()
return {
local _val = {
name = self.name,
updated_at = self.updated_at,
parameters = self.parameters,
conversation = self.conversation,
}
_val.parameters.model = self.parameters.model and self.parameters.model.name or self.parameters.model
return _val
end

function Session:previous_context()
Expand Down Expand Up @@ -144,11 +146,13 @@ function Session:load()
local jsonString = file:read("*a")
file:close()

local data = vim.json.decode(jsonString)
self.name = data.name
self.updated_at = data.updated_at or get_current_date()
self.parameters = data.parameters
self.conversation = data.conversation
local ok, data = pcall(vim.json.decode, jsonString)
if ok then
self.name = data.name
self.updated_at = data.updated_at or get_current_date()
self.parameters = data.parameters
self.conversation = data.conversation
end
end

--
Expand All @@ -173,12 +177,14 @@ function Session.list_sessions()
if updated_at == nil then
updated_at = filename
end

table.insert(sessions, {
filename = filename,
name = name,
ts = parse_date_time(updated_at),
})
local ok, ts = pcall(parse_date_time, updated_at)
if ok then
table.insert(sessions, {
filename = filename,
name = name,
ts = ts,
})
end
end

table.sort(sessions, function(a, b)
Expand Down
4 changes: 2 additions & 2 deletions lua/ogpt/provider/ollama.lua
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ function M.conform(params)
return params
end

function M.process_line(_ok, _json, ctx, raw_chunks, state, cb)
function M.process_line(_json, ctx, raw_chunks, state, cb)
if _json and _json.done then
ctx.context = _json.context
cb(raw_chunks, "END", ctx)
else
if _ok and not vim.tbl_isempty(_json) then
if not vim.tbl_isempty(_json) then
if _json and _json.message then
cb(_json.message.content, state, ctx)
raw_chunks = raw_chunks .. _json.message.content
Expand Down
13 changes: 9 additions & 4 deletions lua/ogpt/provider/textgenui.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ function M.conform_to_textgenui_api(params)
"seed",
"top_k",
"top_p",
"top_n_tokens",
"typical_p",
"stop",
"details",
"max_new_tokens",
"repetition_penalty",
}

local request_params = {
Expand Down Expand Up @@ -84,13 +88,14 @@ end

function M.conform(params)
params = params or {}
params.details = true
params.max_new_tokens = 1024

params.inputs = M.update_messages(params.messages or {})
return M.conform_to_textgenui_api(params)
end

function M.process_line(_ok, _json, ctx, raw_chunks, state, cb)
if _ok and not vim.tbl_isempty(_json) and _json and _json.token then
function M.process_line(_json, ctx, raw_chunks, state, cb)
if _json.token then
if _json.token.text == "</s>" then
ctx.context = _json.context
cb(raw_chunks, "END", ctx)
Expand All @@ -100,7 +105,7 @@ function M.process_line(_ok, _json, ctx, raw_chunks, state, cb)
state = "CONTINUE"
end
else
return
print(_json)
end
return ctx, raw_chunks, state
end
Expand Down
3 changes: 3 additions & 0 deletions lua/ogpt/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ function M.add_partial_completion(opts, text, state)
end

if state == "START" or state == "CONTINUE" then
vim.api.nvim_buf_set_option(panel.bufnr, "modifiable", true)
local lines = vim.split(text, "\n", {})
local length = #lines
local buffer = panel.bufnr
Expand All @@ -409,6 +410,8 @@ function M.add_partial_completion(opts, text, state)
end
end
end
else
print("stuc")
end
end

Expand Down

0 comments on commit 995bc04

Please sign in to comment.