From 8a531328b24f60e2b3b681bf1d4fd39b891e95bd Mon Sep 17 00:00:00 2001 From: Huy Le Date: Sun, 11 Feb 2024 15:02:42 -0700 Subject: [PATCH] gemini is working --- lua/ogpt/api.lua | 136 +++++++++++++++++--------- lua/ogpt/config.lua | 8 ++ lua/ogpt/flows/actions/popup/init.lua | 6 +- lua/ogpt/provider/base.lua | 38 ++++++- lua/ogpt/provider/gemini.lua | 111 ++++++++++++++------- lua/ogpt/provider/openai.lua | 39 +++++--- lua/ogpt/utils.lua | 1 + 7 files changed, 238 insertions(+), 101 deletions(-) diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index 97b0693..4e6bce1 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -27,12 +27,12 @@ function Api:chat_completions(custom_params, partial_result_fn, should_stop, opt ctx.model = custom_params.model utils.log("Request to: " .. _completion_url) utils.log(params) + local raw_chunks = "" + local state = "START" if stream then - local raw_chunks = "" - local state = "START" - - partial_result_fn = vim.schedule_wrap(partial_result_fn) + -- partial_result_fn = vim.schedule_wrap(partial_result_fn) + local accumulate = {} self:exec( "curl", @@ -41,44 +41,91 @@ function Api:chat_completions(custom_params, partial_result_fn, should_stop, opt "--show-error", "--no-buffer", _completion_url, - "-H", - "Content-Type: application/json", - "-H", - self.provider.envs.AUTHORIZATION_HEADER, "-d", vim.json.encode(params), + table.unpack(self.provider:request_headers()), -- has to be the last item in the list }, function(chunk) - local ok, json = pcall(vim.json.decode, chunk) - if ok then - if json.error ~= nil then - local error_msg = { - "OGPT ERROR:", - self.provider.name, - vim.inspect(json.error) or "", - "Something went wrong.", - } - table.insert(error_msg, vim.inspect(params)) - -- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong") - partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx) - return - end - ctx, raw_chunks, state = - self.provider:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, partial_result_fn, opts) - return - end - - for line in chunk:gmatch("[^\n]+") do - local raw_json = string.gsub(line, "^data:", "") - local _ok, _json = pcall(vim.json.decode, raw_json) - if _ok then - ctx, raw_chunks, state = - self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts) - else - ctx, raw_chunks, state = - self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts) - end - end + local chunk_og = chunk + table.insert(accumulate, chunk_og) + local content = { + raw = chunk, + accumulate = accumulate, + content = raw_chunks or "", + state = state or "START", + ctx = ctx, + } + ctx, raw_chunks, state = self.provider:process_raw(content, partial_result_fn, opts) + -- + -- local ok, json = pcall(vim.json.decode, chunk) + -- if not ok then + -- -- gemini is missing bracket on returns + -- chunk = string.gsub(chunk, "^%[", "") + -- chunk = string.gsub(chunk, "^%,", "") + -- chunk = string.gsub(chunk, "%]$", "") + -- chunk = vim.trim(chunk, "\n") + -- chunk = vim.trim(chunk, "\r") + -- ok, json = pcall(vim.json.decode, chunk) + -- end + -- + -- -- if ok then + -- -- vim.print("okay") + -- -- else + -- -- vim.print("not okay") + -- -- end + -- + -- if ok then + -- if json.error ~= nil then + -- local error_msg = { + -- "OGPT ERROR:", + -- self.provider.name, + -- vim.inspect(json.error) or "", + -- "Something went wrong.", + -- } + -- table.insert(error_msg, vim.inspect(params)) + -- -- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong") + -- partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx) + -- return + -- end + -- ctx, raw_chunks, state = + -- self.provider:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, partial_result_fn, opts) + -- return + -- end + -- + -- -- openai + -- for line in chunk:gmatch("[^\n]+") do + -- local raw_json = string.gsub(line, "^data:", "") + -- local _ok, _json = pcall(vim.json.decode, raw_json) + -- if _ok then + -- ctx, raw_chunks, state = + -- self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn, opts) + -- else + -- ctx, raw_chunks, state = self.provider:process_line( + -- { json = nil, raw = chunk_og }, + -- ctx, + -- raw_chunks, + -- state, + -- partial_result_fn, + -- opts + -- ) + -- end + -- end + -- + -- if not ok then + -- -- if not ok, try to keep process the accumulated stdout + -- table.insert(accumulate, chunk_og) + -- ok, json = pcall(vim.json.decode, table.concat(accumulate, "")) + -- if ok then + -- ctx, raw_chunks, state = self.provider:process_line( + -- { json = json, raw = accumulate }, + -- ctx, + -- raw_chunks, + -- state, + -- partial_result_fn, + -- opts + -- ) + -- end + -- end end, function(_text, _state, _ctx) partial_result_fn(_text, _state, _ctx) @@ -123,11 +170,9 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state, opts) local curl_args = { url, - table.unpack(self.provider:request_headers()), - "-H", - "Content-Type: application/json", "-d", "@" .. TMP_MSG_FILENAME, + table.unpack(self.provider:request_headers()), } self.job = job @@ -275,8 +320,8 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) local handle, err local function on_stdout_read(_, chunk) - if chunk then - vim.schedule(function() + vim.schedule(function() + if chunk then if should_stop and should_stop() then if handle ~= nil then handle:kill(2) -- send SIGINT @@ -295,8 +340,8 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) return end on_stdout_chunk(chunk) - end) - end + end + end) end local function on_stderr_read(_, chunk) @@ -305,6 +350,7 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) end end + utils.log("executing: " .. vim.inspect(cmd) .. " " .. vim.inspect(args), vim.log.levels.DEBUG) handle, err = vim.loop.spawn(cmd, { args = args, stdio = { nil, stdout, stderr }, diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index 8756da3..9bdd580 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -39,6 +39,14 @@ function M.defaults() api_host = os.getenv("GEMINI_API_HOST"), api_key = os.getenv("GEMINI_API_KEY"), model = "gemini-pro", + api_params = { + temperature = 0.5, + topP = 0.99, + }, + api_chat_params = { + temperature = 0.5, + topP = 0.99, + }, }, textgenui = { enabled = true, diff --git a/lua/ogpt/flows/actions/popup/init.lua b/lua/ogpt/flows/actions/popup/init.lua index 20df5ad..9a98250 100644 --- a/lua/ogpt/flows/actions/popup/init.lua +++ b/lua/ogpt/flows/actions/popup/init.lua @@ -75,7 +75,11 @@ function PopupAction:run() self:run_spinner(flag) end, on_complete = function(total_text) - -- print("completed: " .. total_text) + -- utils.log("request completed - processed text is:\n" .. total_text, vim.log.levels.DEBUG) + if vim.fn.bufexists(self.popup.bufnr) then + vim.api.nvim_buf_set_option(self.popup.bufnr, "modifiable", true) + vim.api.nvim_buf_set_lines(self.popup.bufnr, -2, -1, false, vim.split(total_text, "\n", {})) + end end, }), function() diff --git a/lua/ogpt/provider/base.lua b/lua/ogpt/provider/base.lua index e3b26a3..caf3a35 100644 --- a/lua/ogpt/provider/base.lua +++ b/lua/ogpt/provider/base.lua @@ -5,7 +5,7 @@ local utils = require("ogpt.utils") local Provider = Object("Provider") function Provider:init(opts) - self.name = self.class.name + self.name = string.lower(self.class.name) opts = vim.tbl_extend("force", Config.options.providers[self.name], opts) self.enabled = opts.enabled self.model = opts.model @@ -175,6 +175,42 @@ function Provider:conform_messages(params) return params end +function Provider:process_raw(content, cb, opts) + local chunk = content.raw + local state = content.state + local raw_chunks = content.content + local accumulate = content.accumulate + local ctx = content.ctx + local ok, json = pcall(vim.json.decode, chunk) + + -- if not ok then + -- -- gemini is missing bracket on returns + -- chunk = string.gsub(chunk, "^%[", "") + -- chunk = string.gsub(chunk, "^%,", "") + -- chunk = string.gsub(chunk, "%]$", "") + -- chunk = vim.trim(chunk, "\n") + -- chunk = vim.trim(chunk, "\r") + -- ok, json = pcall(vim.json.decode, chunk) + -- end + + if ok then + if json.error ~= nil then + local error_msg = { + "OGPT ERROR:", + self.provider.name, + vim.inspect(json.error) or "", + "Something went wrong.", + } + table.insert(error_msg, vim.inspect(params)) + -- local error_msg = "OGPT ERROR: " .. (json.error.message or "Something went wrong") + cb(table.concat(error_msg, " "), "ERROR", ctx) + -- return + return { ctx, raw_chunks, state } + end + return self:process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, cb) + end +end + function Provider:process_line(content, ctx, raw_chunks, state, cb) local _json = content.json local raw = content.raw diff --git a/lua/ogpt/provider/gemini.lua b/lua/ogpt/provider/gemini.lua index e2bbf80..b2c7beb 100644 --- a/lua/ogpt/provider/gemini.lua +++ b/lua/ogpt/provider/gemini.lua @@ -6,7 +6,6 @@ local Gemini = ProviderBase:extend("Gemini") function Gemini:init(opts) Gemini.super.init(self, opts) - self.name = "openai" self.api_parameters = { "contents", } @@ -42,7 +41,7 @@ end function Gemini:completion_url() return utils.ensureUrlProtocol( - self.envs.GEMINI_API_HOST .. "/" .. self.model .. ":streamGenerateContent?" .. self.envs.AUTH + self.envs.GEMINI_API_HOST .. "/models/" .. self.model .. ":streamGenerateContent?" .. self.envs.AUTH ) end @@ -52,8 +51,8 @@ end function Gemini:request_headers() return { - -- "-H", - -- "Content-Type: application/json", + "-H", + "Content-Type: application/json", } end @@ -93,49 +92,85 @@ function Gemini:conform_messages(params) table.remove(params.messages, _to_remove_system_idx[i]) end - -- https://ai.google.dev/tutorials/rest_quickstart#text-only_input - if params.system then - table.insert(params.messages, 1, { - role = "system", - content = params.system, + local messages = params.messages + local _contents = {} + for _, content in ipairs(messages) do + table.insert(_contents, { + role = content.role == "assistant" and "model" or "user", + parts = { + text = utils.gather_text_from_parts(content.content), + }, }) end + + -- -- https://ai.google.dev/tutorials/rest_quickstart#text-only_input + -- if params.system then + -- table.insert(params.messages, 1, { + -- role = "system", + -- content = params.system, + -- }) + -- end + + params.messages = nil + params.contents = _contents return params end +function Gemini:process_raw(content, cb) + local chunk = content.raw + local state = content.state + local raw_chunks = content.content + local accumulate = content.accumulate + + local ok, json = pcall(vim.json.decode, chunk) + if not ok then + -- gemini is missing bracket on returns + chunk = string.gsub(chunk, "^%[", "") + chunk = string.gsub(chunk, "^%,", "") + chunk = string.gsub(chunk, "%]$", "") + chunk = vim.trim(chunk, "\n") + chunk = vim.trim(chunk, "\r") + ok, json = pcall(vim.json.decode, chunk) + end + + if ok then + return self:process_line({ json = json, raw = accumulate }, ctx, raw_chunks, state, cb) + else + -- if not ok, try to keep process the accumulated stdout + ok, json = pcall(vim.json.decode, table.concat(accumulate, "")) + if ok then + return self:process_line({ json = json, raw = accumulate }, ctx, raw_chunks, state, cb) + end + end +end + function Gemini:process_line(content, ctx, raw_chunks, state, cb) - local _json = content.json + local _json = content.json or {} local raw = content.raw - local text = _json.candidates[0].content.parts[0].text - -- given a JSON response from the STREAMING api, processs it - if _json and _json.done then - ctx.context = _json.context - cb(raw_chunks, "END", ctx) - elseif type(_json) == "string" and string.find(_json, "[DONE]") then - cb(raw_chunks, "END", ctx) - else - if - not vim.tbl_isempty(_json) - and _json - and _json.choices - and _json.choices[1] - and _json.choices[1].delta - and _json.choices[1].delta.content - then - cb(_json.choices[1].delta.content, state) - raw_chunks = raw_chunks .. _json.choices[1].delta.content + if type(_json) == "string" then + utils.log("Something is going on, _json is a string, expecing a table..", vim.log.levels.ERROR) + elseif vim.tbl_isempty(_json) then + if raw == "]" then + cb(raw_chunks, "END", ctx) + else + cb("Could not process the following raw chunk:\n" .. raw, "ERROR", ctx) + end + elseif _json then + local text = vim.tbl_get(_json, "candidates", 1, "content", "parts", 1, "text") + if text then + cb(text, state) + raw_chunks = raw_chunks .. text state = "CONTINUE" - elseif - not vim.tbl_isempty(_json) - and _json - and _json.choices - and _json.choices[1] - and _json.choices[1].message - and _json.choices[1].message.content - then - cb(_json.choices[1].message.content, state) - raw_chunks = raw_chunks .. _json.choices[1].message.content + else + local total_text = {} + for _, part in ipairs(_json) do + text = vim.tbl_get(part, "candidates", 1, "content", "parts", 1, "text") + if text then + table.insert(total_text, text) + end + end + cb(table.concat(total_text, ""), "END") end end diff --git a/lua/ogpt/provider/openai.lua b/lua/ogpt/provider/openai.lua index c796f52..c570d0d 100644 --- a/lua/ogpt/provider/openai.lua +++ b/lua/ogpt/provider/openai.lua @@ -55,6 +55,25 @@ function Openai:conform_request(params) return params end +function Openai:process_raw(content, cb, opts) + local chunk = content.raw + local state = content.state + local ctx = content.ctx + local raw_chunks = content.content + + -- openai + for line in chunk:gmatch("[^\n]+") do + local raw_json = string.gsub(line, "^data:", "") + local _ok, _json = pcall(vim.json.decode, raw_json) + if _ok then + return self:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, cb) + -- else + -- ctx, raw_chunks, state = + -- self:process_line({ json = nil, raw = chunk_og }, ctx, raw_chunks, state, partial_result_fn, opts) + end + end +end + function Openai:process_line(content, ctx, raw_chunks, state, cb) local _json = content.json local raw = content.raw @@ -65,25 +84,13 @@ function Openai:process_line(content, ctx, raw_chunks, state, cb) elseif type(_json) == "string" and string.find(_json, "[DONE]") then cb(raw_chunks, "END", ctx) else - if - not vim.tbl_isempty(_json) - and _json - and _json.choices - and _json.choices[1] - and _json.choices[1].delta - and _json.choices[1].delta.content - then + local text_delta = vim.tbl_get(_json, "choices", 1, "delta", "content") + local text = vim.tbl_get(_json, "choices", 1, "message", "content") + if text_delta then cb(_json.choices[1].delta.content, state) raw_chunks = raw_chunks .. _json.choices[1].delta.content state = "CONTINUE" - elseif - not vim.tbl_isempty(_json) - and _json - and _json.choices - and _json.choices[1] - and _json.choices[1].message - and _json.choices[1].message.content - then + elseif text then cb(_json.choices[1].message.content, state) raw_chunks = raw_chunks .. _json.choices[1].message.content end diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua index 0f1fb74..c442f8c 100644 --- a/lua/ogpt/utils.lua +++ b/lua/ogpt/utils.lua @@ -257,6 +257,7 @@ end function M.add_partial_completion(opts, text, state) local panel = opts.panel local progress = opts.progress + local on_complete = opts.on_complete if state == "ERROR" then if progress then