diff --git a/README.md b/README.md
index e6ddc5f..2b620bd 100644
--- a/README.md
+++ b/README.md
@@ -6,8 +6,8 @@
 ![Lua](https://img.shields.io/badge/Made%20with%20Lua-blueviolet.svg?style=for-the-badge&logo=lua)
 
 ## Features
-- **Multiple Providers**: OGPT.nvim can take multiple providers. Ollama, OpenAI, textgenui, more if there are pull requests
-- **Mix-match Provider**: default provider is used, but you can mix and match different provider AND specific model to different actions.
+- **Multiple Providers**: OGPT.nvim supports multiple providers: Ollama, OpenAI, textgenui, Gemini, and more as pull requests come in
+- **Mix-match Provider**: a default provider is used, but you can mix and match different providers AND specific models across different actions, at any point in your run or configuration.
 - **Interactive Q&A**: Engage in interactive question-and-answer sessions with the powerful gpt model (OGPT) using an intuitive interface.
 - **Persona-based Conversations**: Explore various perspectives and have conversations with different personas by selecting prompts from Awesome ChatGPT Prompts.
 - **Customizable Actions**: Execute a range of actions utilizing the gpt model, such as grammar
@@ -19,12 +19,11 @@ plugin configurations.
 
 ## Installation
 
-if you dont specify a provider, "ollama" will be the default provider. "http://localhost:11434" is
-your endpoint.
+If you do not specify a provider, `ollama` will be the default provider, and `http://localhost:11434`
+is your endpoint. `mistral:7b` will be your default model if the configuration is not updated.
 
 ```lua
-
--- Lazy
+-- Simple, minimal Lazy.nvim configuration
 {
   "huynle/ogpt.nvim",
     event = "VeryLazy",
@@ -47,7 +46,7 @@ your endpoint.
 
 ## Configuration
 
-`OGPT.nvim` comes with the following defaults, you can override any of the field by passing config as setup param.
+`OGPT.nvim` comes with the following defaults. You can override any of the fields by passing a config as setup parameters.
 
 https://github.com/huynle/ogpt.nvim/blob/main/lua/ogpt/config.lua
 
@@ -63,6 +62,158 @@ empowering you to generate natural language responses from Ollama's OGPT directl
 Custom Ollama API host with the configuration option `api_host_cmd` or environment variable
 called `$OLLAMA_API_HOST`. It's useful if you run Ollama remotely
 
+### Gemini, TextGenUI, OpenAI Setup
+* Not much here: you just have to get your API keys and provide them in your configuration. If your
+  configuration files are public, you probably want to set environment variables for your API
+keys.
+
+### Edgy.nvim Setup
+![edgy-example](assets/images/edgy-example.png.png)
+
+I prefer `edgy.nvim` over the floating windows. It allows for much better interaction. Here is an
+example of an 'edgy' configuration.
+
+```lua
+{
+{
+  "huynle/ogpt.nvim",
+  event = "VeryLazy",
+  opts = {
+    default_provider = "ollama",
+    edgy = true, -- enable this!
+ single_window = false -- set this to true if you want only one OGPT window to appear at a time + providers = { + ollama = { + api_host = os.getenv("OLLAMA_API_HOST") or "http://localhost:11434", + api_key = os.getenv("OLLAMA_API_KEY") or "", + } + } + }, + dependencies = { + "MunifTanjim/nui.nvim", + "nvim-lua/plenary.nvim", + "nvim-telescope/telescope.nvim" + } +}, +{ + "folke/edgy.nvim", + event = "VeryLazy", + init = function() + vim.opt.laststatus = 3 + vim.opt.splitkeep = "screen" -- or "topline" or "screen" + end, + }, + opts = { + exit_when_last = false, + animate = { + enabled = false, + }, + wo = { + winbar = true, + winfixwidth = true, + winfixheight = false, + winhighlight = "WinBar:EdgyWinBar,Normal:EdgyNormal", + spell = false, + signcolumn = "no", + }, + keys = { + -- -- close window + ["q"] = function(win) + win:close() + end, + -- close sidebar + ["Q"] = function(win) + win.view.edgebar:close() + end, + -- increase width + [""] = function(win) + win:resize("width", 3) + end, + -- decrease width + [""] = function(win) + win:resize("width", -3) + end, + -- increase height + [""] = function(win) + win:resize("height", 3) + end, + -- decrease height + [""] = function(win) + win:resize("height", -3) + end, + }, + right = { + { + title = "OGPT Popup", + ft = "ogpt-popup", + size = { width = 0.2 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT Parameters", + ft = "ogpt-parameters-window", + size = { height = 6 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT Template", + ft = "ogpt-template", + size = { height = 6 }, + }, + { + title = "OGPT Sesssions", + ft = "ogpt-sessions", + size = { height = 6 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT System Input", + ft = "ogpt-system-window", + size = { height = 6 }, + }, + { + title = "OGPT", + ft = "ogpt-window", + size = { height = 0.5 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT {{selection}}", + ft = "ogpt-selection", + size = { width = 80, height = 4 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPt {{instruction}}", + ft = "ogpt-instruction", + size = { width = 80, height = 4 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT Chat", + ft = "ogpt-input", + size = { width = 80, height = 4 }, + wo = { + wrap = true, + }, + }, + }, + }, + } +} +``` ## Usage @@ -76,15 +227,17 @@ Plugin exposes following commands: #### `OGPTActAs` `OGPTActAs` command which opens a prompt selection from [Awesome OGPT Prompts](https://github.com/f/awesome-chatgpt-prompts) to be used with the `mistral:7b` model. -#### `OGPTRun edit_with_instructions` `OGPTRun edit_with_instructions` command which opens +#### `OGPTRun edit_with_instructions` +`OGPTRun edit_with_instructions` command which opens interactive window to edit selected text or whole window using the `deepseek-coder:6.7b` model, you -can change in this in your config options. This model defined in `config.api_edit_params`. +can change in this in your config options. This model defined in `config..api_params`. + +Use `` (default keymap, can be customized) to open and close the parameter panels. Note +this screenshot is using `edgy.nvim` -#### `OGPTRun edit_code_with_instructions` -This command opens an interactive window to edit selected text or the entire window using the -`deepseek-coder:6.7b` model. You can modify this in your config options. The Ollama response will -be extracted for its code content, and if it doesn't contain any codeblock, it will default back to -the full response. 
+![edit_with_instruction_no_params](assets/images/edit_with_instruction_no_param_panel.png)
+
+![edit_with_instructions_with_params](assets/images/edit_with_instruction_with_params_panel.png)
 
 #### `OGPTRun`
 
@@ -128,10 +281,17 @@ opts = {
   ...
   actions = {
     grammar_correction = {
-      type = "popup",
+      -- type = "popup", -- could be a string or table to override
+      type = {
+        popup = { -- overrides the default popup options - https://github.com/huynle/ogpt.nvim/blob/main/lua/ogpt/config.lua#L147-L180
+          edgy = true
+        }
+      },
+      strategy = "replace",
+      provider = "ollama", -- defaults to "default_provider" if not provided
+      model = "mixtral:7b", -- defaults to "provider..model" if not provided
       template = "Correct the given text to standard {{lang}}:\n\n```{{input}}```",
       system = "You are a helpful note writing assistant, given a text input, correct the text only for grammar and spelling error. You are to keep all formatting the same, e.g. markdown bullets, should stay as a markdown bullet in the result, and indents should stay the same. Return ONLY the corrected text.",
-      strategy = "replace",
      params = {
        temperature = 0.3,
      },
@@ -155,6 +315,86 @@ available for further editing requests
 The `display` strategy shows the output in a float window.
 `append` and `replace` modify the text directly in the buffer with "a" or "r"
 
+
+#### Run With Options with Vim Commands
+On the fly, you can execute a command line to call OGPT. An example that replaces
+the grammar_correction call is provided below.
+`:OGPTRun grammar_correction {provider="openai", model="gpt-4"}`
+
+To make it even more dynamic, you can change it to have the provider/model or any parameters be
+entered by the user on the spot when the command is executed.
+`:OGPTRun grammar_correction {provider=vim.fn.input("Provider: "), type={popup={edgy=false}}}`
+
+Additionally, in the above example, `edgy.nvim` can be turned off, so that the response pops up
+inline where the cursor would be. For additional options for the popup, please read through
+https://github.com/huynle/ogpt.nvim/blob/main/lua/ogpt/config.lua#L147-L180
+
+For example, you can have it pop up and change `enter = false`, which leaves the cursor in the same
+location, instead of moving it to the popup.
+
+Additionally, for advanced users, this allows you to use Vim autocommands. For example, autocompletion
+can happen when the cursor is paused. Look at the various Template Helpers for these advanced
+options.
+
+### Template Helpers
+Currently, the given inputs to the API get scanned for `{{}}` for expansion.
+This is helpful when you want to give a little more context to your API requests, or simply to hook
+in additional function calls.
+
+#### Available Template Helpers
+Look at this file for the most up-to-date Template Helpers. If you have more template helpers,
+please make an MR; your contribution is appreciated!
+
+https://github.com/huynle/ogpt.nvim/blob/main/lua/ogpt/flows/actions/template_helpers.lua
+
+#### How to Use
+This is a custom action that I use all the time; it uses the visible windows
+as context to have the AI answer any inline questions.
+
+```lua
+....
+  -- Other OGPT configurations here
+  ....
+  actions = {
+    infill_visible_code = {
+      type = "popup",
+      template = [[
+      Given the following code snippets, please complete the code by infilling the rest of the code in between the two
+      code snippets for BEFORE and AFTER; these snippets are given below.
+
+
+      Code BEFORE infilling position:
+      ```{{filetype}}
+      {{visible_window_content}}
+      {{before_cursor}}
+      ```
+
+      Code AFTER infilling position:
+      ```{{filetype}}
+      {{after_cursor}}
+      ```
+
+
+      Within the given snippets, complete the instructions that are given in between the
+      triple percent signs '%%%-' and '-%%%'. Note that the instructions
+      could be multiline AND/OR could be in a comment block of the code!!!
+
+      Lastly, apply the following conditions to your response.
+      * The response should replace the '%%%-' and '-%%%' if the code snippet is to be reused.
+      * PLEASE respond ONLY with the answers to the given instructions.
+      ]],
+      strategy = "display",
+      -- provider = "textgenui",
+      -- model = "mixtral-8-7b",
+      -- params = {
+      --   max_new_tokens = 1000,
+      -- },
+    },
+    -- more actions here
+  }
+....
+```
+
 ### Interactive Chat Parameters
 * Change Model by Opening the Parameter panels default to (ctrl-o) or your way to it
@@ -181,8 +421,7 @@ https://github.com/huynle/ogpt.nvim/blob/main/lua/ogpt/config.lua#L174-L181
 When the setting window is opened (with ``), settings can be modified by pressing `Enter` on the
 related config. Settings are saved across sessions.
 
-### Example Comprehensive Lazy Configuration
-
+### Example `lazy.nvim` Configuration
 
 ```lua
 return {
@@ -533,30 +772,87 @@ return {
 
 ### Advanced setup
 
-#### Modify model REST URL
+### Reloading Actions for Faster Interaction
+
+When you are updating your actions frequently, I would recommend adding the following keys to your
+`lazy.nvim` `ogpt` configuration. This simply reloads `ogpt.nvim` on the spot so you can see your
+updated actions.
+
+```lua
+...
+  -- other config options here
+  keys = {
+    {
+      "ro",
+      "Lazy reload ogpt.nvim",
+      desc = "RELOAD ogpt",
+    },
+    ...
+  }
+  -- other config options here
+...
+```
+
+
+
+#### Defining Custom Model
+
+This is an example of how to set up an Ollama Mixtral model that might be sitting on a
+different server. Note in the example below you can:
+* Swap out the REST URL by directly replacing it with a URL string, or define a function that gets
+  called to update it dynamically.
+* `secret_model` is an alias to `mixtral-8-7b`, so in your `actions` you can use `secret_model`.
+  This is useful when you have multiple providers that offer models as capable as Mixtral, and you want
+to swap between providers based on your development environment, or for other reasons.
+* When defining a new model, like `mixtral-8-7b` in this example, the model will show up
+  in your model options for the `chat` and `edit` actions.
+* Since custom models might have obscure parameters or settings, the `params` field under your new
+  model definition is used to force the final override of your REST parameters.
+* `conform_messages_fn` is used to override the default provider `conform_messages` function. This
+  function allows massaging of the API request parameters to fit the specific model. This is
+really useful when you need to modify the messages to fit the model's trained template.
+* `conform_request_fn` is used to override the default provider `conform_request` function. This
+  function (or the provider default function) is called at the very end, right before making the
+API call. Final massaging can be done here.
 
 ```lua
 -- advanced model, can take the following structure
-local advanced_model = {
-  -- create a modify url specifically for mixtral to run
-  name = "mixtral-8-7b",
-  -- name = "mistral-7b-tgi-predictor-ai-factory",
-  modify_url = function(url)
-    -- given a URL, this function modifies the URL specifically to the model
-    -- This is useful when you have different models hosted on different subdomains like
-    -- https://model1.yourdomain.com/
-    -- https://model2.yourdomain.com/
-    local new_model = "mixtral-8-7b"
-    -- local new_model = "mistral-7b-tgi-predictor-ai-factory"
-    local host = url:match("https?://([^/]+)")
-    local subdomain, domain, tld = host:match("([^.]+)%.([^.]+)%.([^.]+)")
-    local _new_url = url:gsub(host, new_model .. "." .. domain .. "." .. tld)
-    return _new_url
-  end,
-  -- conform_fn = function(params)
-  --   -- Different models might have different instruction format
-  --   -- for example, Mixtral operates on ` [INST] Instruction [/INST] Model answer [INST] Follow-up instruction [/INST] `
-  -- end,
+providers = {
+  ollama = {
+    model = "secret_model", -- default model for ollama
+    models = {
+      ...
+      secret_model = "mixtral-8-7b",
+      ["mixtral-8-7b"] = {
+        params = {
+          -- the parameters here are FORCED into the final API REQUEST, OVERRIDING
+          -- anything that was set before
+          max_new_token = 200,
+        },
+        modify_url = function(url)
+          -- given a URL, this function modifies the URL specifically to the model
+          -- This is useful when you have different models hosted on different subdomains like
+          -- https://model1.yourdomain.com/
+          -- https://model2.yourdomain.com/
+          local new_model = "mixtral-8-7b"
+          -- local new_model = "mistral-7b-tgi-predictor-ai-factory"
+          local host = url:match("https?://([^/]+)")
+          local subdomain, domain, tld = host:match("([^.]+)%.([^.]+)%.([^.]+)")
+          local _new_url = url:gsub(host, new_model .. "." .. domain .. "." .. tld)
+          return _new_url
+        end,
+        -- conform_messages_fn = function(params)
+        --   Different models might have different instruction format
+        --   for example, Mixtral operates on ` [INST] Instruction [/INST] Model answer [INST] Follow-up instruction [/INST] `
+        --   look in the `providers` folder of the plugin for examples
+        -- end,
+        -- conform_request_fn = function(params)
+        --   API request might need custom format, this function allows that to happen
+        --   look in the `providers` folder of the plugin for examples
+        -- end,
+      }
+    }
+  }
 }
 ```
 
@@ -565,7 +861,90 @@ local advanced_model = {
 
 #### Quick Edit
 
 TBD
 
+#### Edgy.nvim Setup
+
+If you would like an `edgy.nvim` setup, then use something like this for your plugin setup options for
+`edgy.nvim`. After this is set, make sure you enable the `edgy = true` option in your
+configuration options for `ogpt.nvim`.
+ +```lua +opts = { + right = { + { + title = "OGPT Popup", + ft = "ogpt-popup", + size = { width = 0.2 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT Parameters", + ft = "ogpt-parameters-window", + size = { height = 6 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT Template", + ft = "ogpt-template", + size = { height = 6 }, + }, + { + title = "OGPT Sesssions", + ft = "ogpt-sessions", + size = { height = 6 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT System Input", + ft = "ogpt-system-window", + size = { height = 6 }, + }, + { + title = "OGPT", + ft = "ogpt-window", + size = { height = 0.5 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT {{selection}}", + ft = "ogpt-selection", + size = { width = 80, height = 4 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPt {{instruction}}", + ft = "ogpt-instruction", + size = { width = 80, height = 4 }, + wo = { + wrap = true, + }, + }, + { + title = "OGPT Chat", + ft = "ogpt-input", + size = { width = 80, height = 4 }, + wo = { + wrap = true, + }, + }, + }, +} +``` + + + ## OGPT planned work ++ [o] Response and Request objects. General interface for modularity, and additional provider + adoption. + [x] Use default provider, but can be overriden at anytime for specific action + [x] original functionality of ChatGPT.nvim to work with Ollama, TextGenUI(huggingface), OpenAI via `providers` + Look at the "default_provider" in the `config.lua`, default is `ollama` @@ -575,7 +954,7 @@ TBD + [x] Add/remove parameters in Chat and Edit + [x] Choose provider, as well as model for Chat and Edit + [x] Customizable actions, with specific provider and model -+ [ ] Another Windows for [Template](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#template), [System](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#system) ++ [x] Another Windows for [Template](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#template), [System](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#system) + [x] Framework to add more providers + [x] clean up documentation + [x] additional actions can be added to config options, or additional json. 
Look in "config.actions", and "config.actions_paths" diff --git a/assets/images/edgy-example.png.png b/assets/images/edgy-example.png.png new file mode 100644 index 0000000..144bebe Binary files /dev/null and b/assets/images/edgy-example.png.png differ diff --git a/assets/images/edit_with_instruction_no_param_panel.png b/assets/images/edit_with_instruction_no_param_panel.png new file mode 100644 index 0000000..296ff46 Binary files /dev/null and b/assets/images/edit_with_instruction_no_param_panel.png differ diff --git a/assets/images/edit_with_instruction_with_params_panel.png b/assets/images/edit_with_instruction_with_params_panel.png new file mode 100644 index 0000000..1037b72 Binary files /dev/null and b/assets/images/edit_with_instruction_with_params_panel.png differ diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index 2fa30a3..3137556 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -3,104 +3,88 @@ local Config = require("ogpt.config") local logger = require("ogpt.common.logger") local Object = require("ogpt.common.object") local utils = require("ogpt.utils") +local Response = require("ogpt.response") local Api = Object("Api") +Api.STATE_COMPLETED = "COMPLETED" + function Api:init(provider, action, opts) self.opts = opts self.provider = provider self.action = action end -function Api:completions(custom_params, cb) +function Api:completions(custom_params, cb, opts) + -- TODO: not working atm local params = vim.tbl_extend("keep", custom_params, Config.options.api_params) params.stream = false - self:make_call(self.COMPLETIONS_URL, params, cb) + self:make_call(self.COMPLETIONS_URL, params, cb, opts) end -function Api:chat_completions(custom_params, partial_result_fn, should_stop, opts) - local stream = custom_params.stream or false - local params, _completion_url = Config.expand_model(self, custom_params) +function Api:chat_completions(response, inputs) + local custom_params = inputs.custom_params + local partial_result_fn = inputs.partial_result_fn + local should_stop = inputs.should_stop or function() end + + -- local stream = custom_params.stream or false + local params, _completion_url, ctx = self.provider:expand_model(custom_params) - local ctx = {} ctx.params = params + ctx.provider = self.provider.name + ctx.model = custom_params.model utils.log("Request to: " .. _completion_url) utils.log(params) + response.ctx = ctx + response.rest_params = params + response.partial_result_cb = partial_result_fn + response:run_async() - if stream then - local raw_chunks = "" - local state = "START" - - partial_result_fn = vim.schedule_wrap(partial_result_fn) - - self:exec( - "curl", - { - "--silent", - "--show-error", - "--no-buffer", - _completion_url, - "-H", - "Content-Type: application/json", - "-H", - self.provider.envs.AUTHORIZATION_HEADER, - "-d", - vim.json.encode(params), - }, - function(chunk) - local ok, json = pcall(vim.json.decode, chunk) - if ok then - if json.error ~= nil then - local error_msg = { - "OGPT ERROR:", - self.provider.name, - vim.inspect(json.error) or "", - "Something went wrong.", - } - table.insert(error_msg, vim.inspect(params)) - -- local error_msg = "OGPT ERROR: " .. 
(json.error.message or "Something went wrong") - partial_result_fn(table.concat(error_msg, " "), "ERROR", ctx) - return - end - ctx, raw_chunks, state = - self.provider.process_line({ json = json, raw = chunk }, ctx, raw_chunks, state, partial_result_fn) - return - end + local on_complete = inputs.on_complete or function() + response:set_state(response.STATE_COMPLETED) + end + local on_start = inputs.on_start + or function() + -- utils.log("Start Exec of: Curl " .. vim.inspect(curl_args), vim.log.levels.DEBUG) + response:set_state(response.STATE_INPROGRESS) + end + local on_error = inputs.on_error + or function(msg) + -- utils.log("Error running curl: " .. msg or "", vim.log.levels.ERROR) + response:set_state(response.STATE_ERROR) + end + local on_stop = inputs.on_stop or function() + response:set_state(response.STATE_STOPPED) + end - for line in chunk:gmatch("[^\n]+") do - local raw_json = string.gsub(line, "^data:", "") - local _ok, _json = pcall(vim.json.decode, raw_json) - if _ok then - ctx, raw_chunks, state = - self.provider.process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn) - else - ctx, raw_chunks, state = - self.provider.process_line({ json = _json, raw = line }, ctx, raw_chunks, state, partial_result_fn) - end - end - end, - function(err, _) - partial_result_fn(err, "ERROR", ctx) - end, - should_stop, - function() - partial_result_fn(raw_chunks, "END", ctx) - end - ) - else - params.stream = false - self:make_call(self.provider.envs.CHAT_COMPLETIONS_URL, params, partial_result_fn) + -- if params.stream then + -- local accumulate = {} + local curl_args = { + "--silent", + "--show-error", + "--no-buffer", + _completion_url, + "-d", + vim.json.encode(params), + } + for _, header_item in ipairs(self.provider:request_headers()) do + table.insert(curl_args, header_item) end -end -function Api:edits(custom_params, cb) - local params = self.action.params - params.stream = true - params = vim.tbl_extend("force", params, custom_params) - self:chat_completions(params, cb) + self:exec("curl", curl_args, on_start, function(chunk) + response:add_chunk(chunk) + end, on_complete, on_error, on_stop, should_stop) end -function Api:make_call(url, params, cb, ctx, raw_chunks, state) +-- function Api:edits(custom_params, cb) +-- local params = self.action.params +-- params.stream = true +-- params = vim.tbl_extend("force", params, custom_params) +-- self:chat_completions(params, cb) +-- end + +function Api:make_call(url, params, cb, ctx, raw_chunks, state, opts) + -- TODO: to be deprecated ctx = ctx or {} raw_chunks = raw_chunks or "" state = state or "START" @@ -116,12 +100,9 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state) local curl_args = { url, - "-H", - "Content-Type: application/json", - "-H", - self.provider.envs.AUTHORIZATION_HEADER, "-d", "@" .. TMP_MSG_FILENAME, + table.unpack(self.provider:request_headers()), } self.job = job @@ -135,7 +116,7 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state) "An Error Occurred, when calling `curl " .. table.concat(curl_args, " ") .. 
"`", vim.log.levels.ERROR ) - cb("ERROR: API Error") + cb("ERROR: API Error", "ERROR") end local result = table.concat(response:result(), "\n") @@ -154,7 +135,8 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state) cb(table.concat(error_msg, " "), "ERROR", ctx) return end - ctx, raw_chunks, state = self.provider.process_line({ json = json, raw = result }, ctx, raw_chunks, state, cb) + ctx, raw_chunks, state = + self.provider:process_line({ json = json, raw = result }, ctx, raw_chunks, state, cb, opts) return end @@ -163,10 +145,10 @@ function Api:make_call(url, params, cb, ctx, raw_chunks, state) local _ok, _json = pcall(vim.json.decode, raw_json) if _ok then ctx, raw_chunks, state = - self.provider.process_line({ json = _json, raw = line }, ctx, raw_chunks, state, cb) + self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, cb, opts) else ctx, raw_chunks, state = - self.provider.process_line({ json = _json, raw = line }, ctx, raw_chunks, state, cb) + self.provider:process_line({ json = _json, raw = line }, ctx, raw_chunks, state, cb, opts) end end end), @@ -261,24 +243,29 @@ local function ensureUrlProtocol(str) return "https://" .. str end -function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) - local stdout = vim.loop.new_pipe() +function Api:exec(cmd, args, on_start, on_stdout_chunk, on_complete, on_error, on_stop, should_stop) local stderr = vim.loop.new_pipe() + local stdout = vim.loop.new_pipe() local stderr_chunks = {} local handle, err local function on_stdout_read(_, chunk) if chunk then vim.schedule(function() - if should_stop and should_stop() then + if should_stop() then if handle ~= nil then handle:kill(2) -- send SIGINT - stdout:close() - stderr:close() - handle:close() + pcall(function() + stdout:close() + end) + pcall(function() + stderr:close() + end) + pcall(function() + handle:close() + end) on_stop() end - return end on_stdout_chunk(chunk) end) @@ -291,6 +278,7 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) end end + on_start() handle, err = vim.loop.spawn(cmd, { args = args, stdio = { nil, stdout, stderr }, @@ -303,13 +291,15 @@ function Api:exec(cmd, args, on_stdout_chunk, on_complete, should_stop, on_stop) vim.schedule(function() if code ~= 0 then - on_complete(vim.trim(table.concat(stderr_chunks, ""))) + on_error() + else + on_complete() end end) end) if not handle then - on_complete(cmd .. " could not be started: " .. err) + on_error(cmd .. " could not be started: " .. 
err) else stdout:read_start(on_stdout_read) stderr:read_start(on_stderr_read) diff --git a/lua/ogpt/common/popup.lua b/lua/ogpt/common/popup.lua index 4968b07..5ba6d4a 100644 --- a/lua/ogpt/common/popup.lua +++ b/lua/ogpt/common/popup.lua @@ -9,6 +9,8 @@ function Popup:init(options, edgy) if options.edgy and options.border or edgy then self.edgy = true options.border = nil + else + options.buf_options.filetype = nil end Popup.super.init(self, options) end diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index c99d384..892a540 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -12,8 +12,13 @@ M.logs = {} function M.defaults() local defaults = { - debug = false, + -- options of 0-5, is trace, debug, info, warn, error, off, respectively + debug = { + log_level = 3, + notify_level = 3, + }, edgy = false, + single_window = false, yank_register = "+", default_provider = "ollama", providers = { @@ -23,8 +28,6 @@ function M.defaults() api_host = os.getenv("OPENAI_API_HOST") or "https://api.openai.com", api_key = os.getenv("OPENAI_API_KEY") or "", api_params = { - frequency_penalty = 0, - presence_penalty = 0, temperature = 0.5, top_p = 0.99, }, @@ -35,6 +38,20 @@ function M.defaults() top_p = 0.99, }, }, + gemini = { + enabled = true, + api_host = os.getenv("GEMINI_API_HOST"), + api_key = os.getenv("GEMINI_API_KEY"), + model = "gemini-pro", + api_params = { + temperature = 0.5, + topP = 0.99, + }, + api_chat_params = { + temperature = 0.5, + topP = 0.99, + }, + }, textgenui = { enabled = true, api_host = os.getenv("OGPT_API_HOST"), @@ -63,8 +80,8 @@ function M.defaults() -- }, }, api_params = { - frequency_penalty = 0, - presence_penalty = 0, + -- frequency_penalty = 0, + -- presence_penalty = 0, temperature = 0.5, top_p = 0.99, }, @@ -100,8 +117,8 @@ function M.defaults() -- used for `edit` and `edit_code` strategy in the actions model = nil, -- model = "mistral:7b", - frequency_penalty = 0, - presence_penalty = 0, + -- frequency_penalty = 0, + -- presence_penalty = 0, temperature = 0.5, top_p = 0.99, }, @@ -154,7 +171,7 @@ function M.defaults() }, keymaps = { close = { "", "q" }, - accept = "", + accept = "", append = "a", prepend = "p", yank_code = "c", @@ -241,13 +258,10 @@ function M.defaults() syntax = "markdown", }, }, - system_window = { + util_window = { border = { highlight = "FloatBorder", style = "rounded", - text = { - top = " SYSTEM ", - }, }, win_options = { wrap = true, @@ -255,10 +269,8 @@ function M.defaults() foldcolumn = "2", winhighlight = "Normal:Normal,FloatBorder:FloatBorder", }, - buf_options = { - filetype = "ogpt-system-window", - }, }, + input_window = { prompt = "  ", border = { @@ -286,7 +298,7 @@ function M.defaults() style = "rounded", text = { top_align = "center", - top = " Instruction ", + top = " {{instruction}} ", }, }, win_options = { @@ -323,8 +335,8 @@ function M.defaults() delay = true, extract_codeblock = true, params = { - frequency_penalty = 0, - presence_penalty = 0, + -- frequency_penalty = 0, + -- presence_penalty = 0, temperature = 0.5, top_p = 0.99, }, @@ -369,7 +381,7 @@ function M.setup(options) local function update_edgy_flag(chat_type) for _, t in ipairs(chat_type) do - if not M.options[t].edgy then + if M.options[t].edgy == nil then M.options[t].edgy = M.options.edgy end end @@ -393,20 +405,21 @@ function M.get_provider(provider_name, action, override) local Api = require("ogpt.api") override = override or {} provider_name = provider_name or M.options.default_provider - local provider = require("ogpt.provider." .. 
provider_name) - local envs = provider.load_envs(override.envs) - provider = vim.tbl_extend("force", provider, override) - provider.envs = envs + local provider = require("ogpt.provider." .. provider_name)(override) + provider:load_envs(override.envs) + -- provider = vim.tbl_extend("force", provider, override) + -- provider.envs = envs provider.api = Api(provider, action, {}) return provider end function M.get_action_params(provider, override) - provider = provider or M.options.default_provider - local default_params = M.options.providers[provider].api_params - default_params.model = default_params.model or M.options.providers[provider].model - default_params.provider = provider - return vim.tbl_extend("force", default_params, override or {}) + provider = provider or M.get_provider(M.options.default_provider) + local default_params = provider:get_action_params(override) + -- default_params.model = default_params.model or provider.model + -- default_params.provider = provider + -- return vim.tbl_extend("force", default_params, override or {}) + return default_params end function M.get_chat_params(provider, override) @@ -417,12 +430,15 @@ function M.get_chat_params(provider, override) return vim.tbl_extend("force", default_params, override or {}) end -function M.expand_model(api, params) - local provider_models = M.options.providers[api.provider.name].models or {} - params = M.get_action_params(api.provider.name, params) - local _model = params.model +function M.expand_model(api, params, ctx) + ctx = ctx or {} + -- local provider_models = M.options.providers[api.provider.name].models or {} + local provider_models = api.provider.models + -- params = M.get_action_params(api.provider, params) + params = api.provider:get_action_params(params) + local _model = params.model or api.provider.model - local _completion_url = api.provider.envs.CHAT_COMPLETIONS_URL + local _completion_url = api.provider:completion_url() local function _expand(name, _m) if type(_m) == "table" then @@ -439,7 +455,9 @@ function M.expand_model(api, params) end end end - params.model = _m.name or name + params.model = _m or name + elseif not vim.tbl_contains(provider_models, _m) then + params.model = _m else for _name, model in pairs(provider_models) do if _name == _m then @@ -452,26 +470,44 @@ function M.expand_model(api, params) end end end + return params end - _expand(nil, _model) + local _full_unfiltered_params = _expand(nil, _model) + -- ctx.tokens = _full_unfiltered_params.model.tokens or {} + -- final force override from the params that are set in the mode itself. + -- This will enforce specific model params, e.g. 
max_token, etc + local final_overrided_applied_params = + vim.tbl_extend("force", params, vim.tbl_get(_full_unfiltered_params, "model", "params") or {}) - params = M.expand_url(api, params) + params = M.conform_to_provider_request(api, final_overrided_applied_params) - return params, _completion_url + return params, _completion_url, ctx end -function M.expand_url(api, params) - params = M.get_action_params(api.provider.name, params) +function M.conform_to_provider_request(api, params) + params = M.get_action_params(api.provider, params) local _model = params.model - local _conform_fn = _model and _model.conform_fn + local _conform_messages_fn = _model and _model.conform_messages_fn + local _conform_request_fn = _model and _model.conform_request_fn - if _conform_fn then - params = _conform_fn(params) + if _conform_messages_fn then + params = _conform_messages_fn(api.provider, params) else - params = api.provider.conform(params) + params = api.provider:conform_messages(params) end + + if _conform_request_fn then + params = _conform_request_fn(api.provider, params) + else + params = api.provider:conform_request(params) + end + return params end +function M.get_local_model_definition(provider) + return M.options.providers[provider.name].models or {} +end + return M diff --git a/lua/ogpt/flows/actions/base.lua b/lua/ogpt/flows/actions/base.lua index 07ca17e..2d0d0b2 100644 --- a/lua/ogpt/flows/actions/base.lua +++ b/lua/ogpt/flows/actions/base.lua @@ -3,6 +3,7 @@ local Signs = require("ogpt.signs") local Spinner = require("ogpt.spinner") local utils = require("ogpt.utils") local Config = require("ogpt.config") +local template_helpers = require("ogpt.flows.actions.template_helpers") local BaseAction = Object("BaseAction") @@ -122,17 +123,45 @@ end function BaseAction:update_variables() self.variables = vim.tbl_extend("force", self.variables, { - filetype = self:get_filetype(), - input = self:get_selected_text(), + filetype = function() + return self:get_filetype() + end, + input = function() + return self:get_selected_text() + end, + selection = function() + return self:get_selected_text() + end, }) + for helper, helper_fn in pairs(template_helpers) do + local _v = { [helper] = helper_fn } + self.variables = vim.tbl_extend("force", self.variables, _v) + end end -function BaseAction:render_template() +function BaseAction:render_template(variables, templates) + variables = vim.tbl_extend("force", self.variables, variables or {}) + -- lazily render the final string. + -- it recursively loop on the template string until it does not find anymore + -- {{}} patterns + local stop = false + local depth = 2 local result = self.template - for key, value in pairs(self.variables) do - local escaped_value = utils.escape_pattern(value) - result = string.gsub(result, "{{" .. key .. "}}", escaped_value) - end + local pattern = "%{%{(([%w_]+))%}%}" + repeat + for match in string.gmatch(result, pattern) do + local value = variables[match] + if value then + value = type(value) == "function" and value() or value + local escaped_value = utils.escape_pattern(value) + result = string.gsub(result, "{{" .. match .. "}}", escaped_value) + else + utils.log("Cannot find {{" .. match .. 
"}}", vim.log.levels.ERROR) + stop = true + end + end + depth = depth - 1 + until not string.match(result, pattern) or stop or depth == 0 return result end @@ -140,7 +169,11 @@ function BaseAction:get_params() local messages = self.params.messages or {} local message = { role = "user", - content = self:render_template(), + content = { + { + text = self:render_template(), + }, + }, } table.insert(messages, message) return vim.tbl_extend("force", self.params, { @@ -190,4 +223,61 @@ function BaseAction:display_input_suffix(suffix) end end +function BaseAction:on_complete(response) + -- empty +end + +function BaseAction:addAnswerPartial(response) + local content = response:pop_content() + local text = content[1] + local state = content[2] + + if state == "ERROR" then + self:run_spinner(false) + utils.log("An Error Occurred: " .. text, vim.log.levels.ERROR) + self.output_panel:unmount() + return + end + + if state == "END" then + utils.log("Received END Flag", vim.log.levels.DEBUG) + if not utils.is_buf_exists(self.output_panel.bufnr) then + return + end + vim.api.nvim_buf_set_option(self.output_panel.bufnr, "modifiable", true) + vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, {}) -- clear the window, an put the final answer in + vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, vim.split(text, "\n")) + self:on_complete(response) + end + + if state == "START" then + self:run_spinner(false) + if not utils.is_buf_exists(self.output_panel.bufnr) then + return + end + vim.api.nvim_buf_set_option(self.output_panel.bufnr, "modifiable", true) + end + + if state == "START" or state == "CONTINUE" then + if not utils.is_buf_exists(self.output_panel.bufnr) then + return + end + vim.api.nvim_buf_set_option(self.output_panel.bufnr, "modifiable", true) + local lines = vim.split(text, "\n", {}) + local length = #lines + + for i, line in ipairs(lines) do + if self.output_panel.bufnr and vim.fn.bufexists(self.output_panel.bufnr) then + local currentLine = vim.api.nvim_buf_get_lines(self.output_panel.bufnr, -2, -1, false)[1] + if currentLine then + vim.api.nvim_buf_set_lines(self.output_panel.bufnr, -2, -1, false, { currentLine .. 
line }) + if i == length and i > 1 then + vim.api.nvim_buf_set_lines(self.output_panel.bufnr, -1, -1, false, { "" }) + end + end + end + end + end +end + return BaseAction diff --git a/lua/ogpt/flows/actions/edits/init.lua b/lua/ogpt/flows/actions/edits/init.lua index aa3e0bc..6eea2fe 100644 --- a/lua/ogpt/flows/actions/edits/init.lua +++ b/lua/ogpt/flows/actions/edits/init.lua @@ -1,10 +1,12 @@ local BaseAction = require("ogpt.flows.actions.base") +local Response = require("ogpt.response") local utils = require("ogpt.utils") local Config = require("ogpt.config") local Layout = require("ogpt.common.layout") local Popup = require("ogpt.common.popup") local ChatInput = require("ogpt.input") local Parameters = require("ogpt.parameters") +local UtilWindow = require("ogpt.util_window") local EditAction = BaseAction:extend("EditAction") @@ -16,7 +18,7 @@ function EditAction:init(name, opts) opts = opts or {} EditAction.super.init(self, opts) self.provider = Config.get_provider(opts.provider, self) - self.params = Config.get_action_params(self.provider.name, opts.params or {}) + self.params = Config.get_action_params(self.provider, opts.params or {}) self.system = type(opts.system) == "function" and opts.system() or opts.system or "" self.template = type(opts.template) == "function" and opts.template() or opts.template or "{{input}}" self.variables = opts.variables or {} @@ -25,7 +27,7 @@ function EditAction:init(name, opts) self.instructions_input = nil self.layout = nil - self.input_panel = nil + self.selection_panel = nil self.output_panel = nil self.parameters_panel = nil self.timer = nil @@ -67,6 +69,20 @@ function EditAction:run() end) end +function EditAction:on_complete(response) + -- on the completion, execute this function to extract out codeblocks + local output_txt = response:get_processed_text() + local nlcount = utils.count_newlines_at_end(output_txt) + if self.strategy == STRATEGY_EDIT_CODE then + output_txt = response:extract_code() + end + local output_txt_nlfixed = utils.replace_newlines_at_end(output_txt, nlcount) + local _output = utils.split_string_by_line(output_txt_nlfixed) + if self.output_panel.bufnr then + vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, _output) + end +end + function EditAction:edit_with_instructions(output_lines, selection, opts, ...) opts = opts or {} opts.params = opts.params or self.params @@ -81,16 +97,43 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) self.parameters_panel = Parameters({ type = "edits", - default_params = api_params, + default_params = vim.tbl_extend("force", api_params, { + provider = self.provider.name, + model = self.provider.model, + }), session = nil, parent = self, edgy = Config.options.edit.edgy, }) - self.input_panel = Popup(Config.options.input_window, Config.options.edit.edgy) + self.selection_panel = UtilWindow({ + filetype = "ogpt-selection", + display = "{{selection}}", + virtual_text = "No selection was made..", + }, Config.options.chat.edgy) + self.template_panel = UtilWindow({ + filetype = "ogpt-template", + display = "Template", + virtual_text = "Template is not defined.. 
will use the {{selection}}", + default_text = self.template, + on_change = function(text) + self.template = text + end, + }, Config.options.chat.edgy) + + self.system_role_panel = UtilWindow({ + filetype = "ogpt-system-window", + display = "System", + virtual_text = "Define your LLM system message here...", + default_text = opts.params.system, + on_change = function(text) + self.template = text + end, + }, Config.options.chat.edgy) + self.output_panel = Popup(Config.options.output_window, Config.options.edit.edgy) self.instructions_input = ChatInput(Config.options.instruction_window, { edgy = Config.options.edit.edgy, - prompt = Config.options.input_window.prompt, + prompt = Config.options.instruction_window.prompt, default_value = opts.instruction or "", on_close = function() -- if self.spinner:is_running() then @@ -102,13 +145,14 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) end end, on_submit = vim.schedule_wrap(function(instruction) + local response = Response(self.provider) -- clear input vim.api.nvim_buf_set_lines(self.instructions_input.bufnr, 0, -1, false, { "" }) vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, { "" }) -- show_progress() self:run_spinner(self.instructions_input, true) - local input = table.concat(vim.api.nvim_buf_get_lines(self.input_panel.bufnr, 0, -1, false), "\n") + local input = table.concat(vim.api.nvim_buf_get_lines(self.selection_panel.bufnr, 0, -1, false), "\n") -- if instruction is empty, try to get the original instruction from opts if instruction == "" then @@ -116,38 +160,14 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) end local messages = self:build_edit_messages(input, instruction, opts) local params = vim.tbl_extend("keep", { messages = messages }, self.parameters_panel.params) - self.provider.api:edits( - params, - utils.partial(utils.add_partial_completion, { - panel = self.output_panel, - on_complete = function(response) - -- on the completion, execute this function to extract out codeblocks - local nlcount = utils.count_newlines_at_end(response) - local output_txt = response - if opts.edit_code then - local code_response = utils.extract_code(response) - -- if the chat is to edit code, it will try to extract out the code from response - output_txt = response - if code_response then - output_txt = utils.match_indentation(response, code_response) - else - vim.notify("no codeblock detected", vim.log.levels.INFO) - end - if response.applied_changes then - vim.notify(response.applied_changes, vim.log.levels.INFO) - end - end - local output_txt_nlfixed = utils.replace_newlines_at_end(output_txt, nlcount) - local _output = utils.split_string_by_line(output_txt_nlfixed) - if self.output_panel.bufnr then - vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, _output) - end - end, - progress = function(flag) - self:run_spinner(flag) - end, - }) - ) + + params.stream = true + self.provider.api:chat_completions(response, { + custom_params = params, + partial_result_fn = function(...) + self:addAnswerPartial(...) + end, + }) end), }) @@ -161,28 +181,30 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) 
}, }, - -- Layout.Box({ - -- Layout.Box({ - -- Layout.Box(self.input_panel, { grow = 1 }), - -- Layout.Box(self.instructions_input, { size = 3 }), - -- }, { dir = "col", size = "50%" }), - -- Layout.Box(self.output_panel, { size = "50%" }), - -- }, { dir = "row" }), - Layout.Box({ Layout.Box({ - Layout.Box(self.input_panel, { grow = 1 }), + Layout.Box(self.selection_panel, { grow = 1 }), Layout.Box(self.instructions_input, { size = 3 }), }, { dir = "col", grow = 1 }), Layout.Box(self.output_panel, { grow = 1 }), - Layout.Box(self.parameters_panel, { size = 40 }), + Layout.Box({ + Layout.Box(self.parameters_panel, { grow = 1 }), + Layout.Box(self.system_role_panel, { size = 10 }), + Layout.Box(self.template_panel, { size = 8 }), + }, { dir = "col", grow = 1 }), }, { dir = "row" }), Config.options.edit.edgy ) -- accept output window - for _, window in ipairs({ self.input_panel, self.output_panel, self.instructions_input }) do + for _, window in ipairs({ + self.selection_panel, + self.output_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + }) do for _, mode in ipairs({ "n", "i" }) do window:map(mode, Config.options.edit.keymaps.accept, function() self.instructions_input.input_props.on_close() @@ -194,18 +216,31 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) end -- use output as input - for _, window in ipairs({ self.input_panel, self.output_panel, self.instructions_input }) do + for _, window in ipairs({ + self.selection_panel, + self.output_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + }) do for _, mode in ipairs({ "n", "i" }) do window:map(mode, Config.options.edit.keymaps.use_output_as_input, function() local lines = vim.api.nvim_buf_get_lines(self.output_panel.bufnr, 0, -1, false) - vim.api.nvim_buf_set_lines(self.input_panel.bufnr, 0, -1, false, lines) + vim.api.nvim_buf_set_lines(self.selection_panel.bufnr, 0, -1, false, lines) vim.api.nvim_buf_set_lines(self.output_panel.bufnr, 0, -1, false, {}) end, { noremap = true }) end end -- close - for _, window in ipairs({ self.input_panel, self.output_panel, self.instructions_input }) do + for _, window in ipairs({ + self.selection_panel, + self.output_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + self.parameters_panel, + }) do for _, mode in ipairs({ "n", "i" }) do window:map(mode, Config.options.edit.keymaps.close, function() self.spinner:stop() @@ -219,13 +254,20 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) -- toggle parameters local parameters_open = true - for _, popup in ipairs({ self.parameters_panel, self.instructions_input, self.input_panel, self.output_panel }) do + for _, popup in ipairs({ + self.parameters_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + self.selection_panel, + self.output_panel, + }) do for _, mode in ipairs({ "n", "i" }) do popup:map(mode, Config.options.edit.keymaps.toggle_parameters, function() if parameters_open then self.layout:update(Layout.Box({ Layout.Box({ - Layout.Box(self.input_panel, { grow = 1 }), + Layout.Box(self.selection_panel, { grow = 1 }), Layout.Box(self.instructions_input, { size = 3 }), }, { dir = "col", size = "50%" }), Layout.Box(self.output_panel, { size = "50%" }), @@ -235,14 +277,22 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) 
else self.layout:update(Layout.Box({ Layout.Box({ - Layout.Box(self.input_panel, { grow = 1 }), + Layout.Box(self.selection_panel, { grow = 1 }), Layout.Box(self.instructions_input, { size = 3 }), }, { dir = "col", grow = 1 }), Layout.Box(self.output_panel, { grow = 1 }), - Layout.Box(self.parameters_panel, { size = 40 }), + Layout.Box({ + Layout.Box(self.parameters_panel, { grow = 1 }), + Layout.Box(self.system_role_panel, { size = 10 }), + Layout.Box(self.template_panel, { size = 8 }), + }, { dir = "col", grow = 1 }), }, { dir = "row" })) self.parameters_panel:show() self.parameters_panel:mount() + self.template_panel:show() + self.template_panel:mount() + self.system_role_panel:show() + self.system_role_panel:mount() vim.api.nvim_set_current_win(self.parameters_panel.winid) vim.api.nvim_buf_set_option(self.parameters_panel.bufnr, "modifiable", false) @@ -251,7 +301,7 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) parameters_open = not parameters_open -- set input and output settings -- TODO - for _, window in ipairs({ self.input_panel, self.output_panel }) do + for _, window in ipairs({ self.selection_panel, self.output_panel }) do vim.api.nvim_buf_set_option(window.bufnr, "syntax", self.filetype) vim.api.nvim_win_set_option(window.winid, "number", true) end @@ -261,17 +311,24 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) -- cycle windows local active_panel = self.instructions_input - for _, popup in ipairs({ self.input_panel, self.output_panel, self.parameters_panel, self.instructions_input }) do + for _, popup in ipairs({ + self.selection_panel, + self.output_panel, + self.parameters_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + }) do for _, mode in ipairs({ "n", "i" }) do - if mode == "i" and (popup == self.input_panel or popup == self.output_panel) then + if mode == "i" and (popup == self.selection_panel or popup == self.output_panel) then goto continue end popup:map(mode, Config.options.edit.keymaps.cycle_windows, function() if active_panel == self.instructions_input then - vim.api.nvim_set_current_win(self.input_panel.winid) - active_panel = self.input_panel + vim.api.nvim_set_current_win(self.selection_panel.winid) + active_panel = self.selection_panel vim.api.nvim_command("stopinsert") - elseif active_panel == self.input_panel and mode ~= "i" then + elseif active_panel == self.selection_panel and mode ~= "i" then vim.api.nvim_set_current_win(self.output_panel.winid) active_panel = self.output_panel vim.api.nvim_command("stopinsert") @@ -284,6 +341,12 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) active_panel = self.instructions_input end elseif active_panel == self.parameters_panel then + vim.api.nvim_set_current_win(self.system_role_panel.winid) + active_panel = self.system_role_panel + elseif active_panel == self.system_role_panel then + vim.api.nvim_set_current_win(self.template_panel.winid) + active_panel = self.template_panel + elseif active_panel == self.template_panel then vim.api.nvim_set_current_win(self.instructions_input.winid) active_panel = self.instructions_input end @@ -294,11 +357,18 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) 
-- toggle diff mode local diff_mode = Config.options.edit.diff - for _, popup in ipairs({ self.parameters_panel, self.instructions_input, self.output_panel, self.input_panel }) do + for _, popup in ipairs({ + self.parameters_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + self.output_panel, + self.selection_panel, + }) do for _, mode in ipairs({ "n", "i" }) do popup:map(mode, Config.options.edit.keymaps.toggle_diff, function() diff_mode = not diff_mode - for _, winid in ipairs({ self.input_panel.winid, self.output_panel.winid }) do + for _, winid in ipairs({ self.selection_panel.winid, self.output_panel.winid }) do vim.api.nvim_set_current_win(winid) if diff_mode then vim.api.nvim_command("diffthis") @@ -312,7 +382,14 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) end -- set events - for _, popup in ipairs({ self.parameters_panel, self.instructions_input, self.output_panel, self.input_panel }) do + for _, popup in ipairs({ + self.parameters_panel, + self.instructions_input, + self.system_role_panel, + self.template_panel, + self.output_panel, + self.selection_panel, + }) do popup:on({ "BufUnload" }, function() self:set_loading(false) end) @@ -321,7 +398,7 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) self.layout:mount() -- set input if visual_lines then - vim.api.nvim_buf_set_lines(self.input_panel.bufnr, 0, -1, false, visual_lines) + vim.api.nvim_buf_set_lines(self.selection_panel.bufnr, 0, -1, false, visual_lines) end -- set output @@ -330,7 +407,7 @@ function EditAction:edit_with_instructions(output_lines, selection, opts, ...) end -- set input and output settings - for _, window in ipairs({ self.input_panel, self.output_panel }) do + for _, window in ipairs({ self.selection_panel, self.output_panel }) do vim.api.nvim_buf_set_option(window.bufnr, "syntax", "markdown") vim.api.nvim_win_set_option(window.winid, "number", true) end @@ -339,17 +416,15 @@ end function EditAction:build_edit_messages(input, instructions, opts) local _input = input - if opts.edit_code then - _input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````" - else - _input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````" - end - local variables = vim.tbl_extend("force", {}, { + + _input = "```" .. (opts.filetype or "") .. "\n" .. input .. "````" + + local variables = vim.tbl_extend("force", opts.variables, { instruction = instructions, input = _input, filetype = opts.filetype, - }, opts.variables) - local system_msg = opts.params.system or "" + }) + local system_msg = self.system local messages = { { role = "system", diff --git a/lua/ogpt/flows/actions/init.lua b/lua/ogpt/flows/actions/init.lua index d8c28f2..d2863bd 100644 --- a/lua/ogpt/flows/actions/init.lua +++ b/lua/ogpt/flows/actions/init.lua @@ -52,28 +52,33 @@ end function M.run_action(opts) local ACTIONS = M.read_actions() + local action_name = table.remove(opts.fargs, 1) - local action_opts = loadstring("return " .. opts.args)() or {} + local action_opts = loadstring("return " .. table.concat(opts.fargs, " "))() or {} - local action_name = opts.fargs[1] local item = ACTIONS[action_name] - - -- parse args - if item.args then - item.variables = {} - local i = 2 - for key, value in pairs(item.args) do - local arg = type(value.default) == "function" and value.default() or value.default or "" - -- TODO: validataion - item.variables[key] = arg - i = i + 1 - end + if not item then + vim.notify("Action '" .. action_name .. 
"' does not exist in OGPT Actions. Try checking your config table/JSON.") + return end - opts = vim.tbl_extend("force", {}, action_opts, item) + -- -- parse args + -- if item.args then + -- item.variables = {} + -- local i = 2 + -- for key, value in pairs(item.args) do + -- local arg = type(value.default) == "function" and value.default() or value.default or "" + -- -- TODO: validataion + -- item.variables[key] = arg + -- i = i + 1 + -- end + -- end + + opts = vim.tbl_extend("force", {}, item, action_opts) local class = classes_by_type[item.type] local action = class(action_name, opts) - action:run() + vim.schedule_wrap(action:run()) + return action end return M diff --git a/lua/ogpt/flows/actions/popup/init.lua b/lua/ogpt/flows/actions/popup/init.lua index 7b04701..5d4615c 100644 --- a/lua/ogpt/flows/actions/popup/init.lua +++ b/lua/ogpt/flows/actions/popup/init.lua @@ -1,4 +1,5 @@ local BaseAction = require("ogpt.flows.actions.base") +local Response = require("ogpt.response") local Spinner = require("ogpt.spinner") local PopupWindow = require("ogpt.flows.actions.popup.window") local utils = require("ogpt.utils") @@ -15,27 +16,40 @@ local STRATEGY_QUICK_FIX = "quick_fix" function PopupAction:init(name, opts) self.name = name or "" PopupAction.super.init(self, opts) + + local popup_options = Config.options.popup + if type(opts.type) == "table" then + popup_options = vim.tbl_extend("force", popup_options, opts.type.popup or {}) + end + self.provider = Config.get_provider(opts.provider, self) - self.params = Config.get_action_params(self.provider.name, opts.params or {}) + -- self.params = Config.get_action_params(self.provider, opts.params or {}) + self.params = self.provider:get_action_params(opts.params) self.system = type(opts.system) == "function" and opts.system() or opts.system or "" self.template = type(opts.template) == "function" and opts.template() or opts.template or "{{input}}" self.variables = opts.variables or {} + self.on_events = opts.on_events or {} self.strategy = opts.strategy or STRATEGY_DISPLAY self.ui = opts.ui or {} self.cur_win = vim.api.nvim_get_current_win() - self.edgy = Config.options.popup.edgy - self.popup = PopupWindow(Config.options.popup, Config.options.popup.edgy) + self.output_panel = PopupWindow(popup_options) self.spinner = Spinner:new(function(state) end) self:update_variables() - self.popup:on({ "BufUnload" }, function() + self.output_panel:on({ "BufUnload" }, function() self:set_loading(false) end) end +function PopupAction:close() + self.stop = true + self.output_panel:unmount() +end + function PopupAction:run() -- self.stop = false + local response = Response(self.provider) local params = self:get_params() local _, start_row, start_col, end_row, end_col = self:get_visual_selection() local opts = { @@ -52,26 +66,21 @@ function PopupAction:run() title = self.opts.title, args = self.opts.args, stop = function() - self.stop = true + self:close() + -- self.stop = true end, } if self.strategy == STRATEGY_DISPLAY then self:set_loading(true) - self.popup:mount(opts) + self.output_panel:mount(opts) params.stream = true - self.provider.api:chat_completions( - params, - utils.partial(utils.add_partial_completion, { - panel = self.popup, - progress = function(flag) - self:run_spinner(flag) - end, - on_complete = function(total_text) - -- print("completed: " .. total_text) - end, - }), - function() + self.provider.api:chat_completions(response, { + custom_params = params, + partial_result_fn = function(...) + self:addAnswerPartial(...) 
+ end, + should_stop = function() -- should stop function if self.stop then -- self.stop = false @@ -81,57 +90,59 @@ function PopupAction:run() else return false end - end - ) + end, + }) else self:set_loading(true) - self.provider.api:chat_completions(params, function(answer, usage) - self:on_result(answer, usage) - end) + self.provider.api:chat_completions(response, { + custom_params = params, + partial_result_fn = function(...) + self:on_result(...) + end, + should_stop = nil, + }) end end function PopupAction:on_result(answer, usage) - vim.schedule(function() - self:set_loading(false) - local lines = utils.split_string_by_line(answer) - local _, start_row, start_col, end_row, end_col = self:get_visual_selection() - local bufnr = self:get_bufnr() - if self.strategy == STRATEGY_PREPEND then - answer = answer .. "\n" .. self:get_selected_text() - vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) - elseif self.strategy == STRATEGY_APPEND then - answer = self:get_selected_text() .. "\n\n" .. answer .. "\n" - vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) - elseif self.strategy == STRATEGY_REPLACE then - answer = answer - vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) - elseif self.strategy == STRATEGY_QUICK_FIX then - if #lines == 1 and lines[1] == "" then - vim.notify("Your Code looks fine, no issues.", vim.log.levels.INFO) - return - end + self:set_loading(false) + local lines = utils.split_string_by_line(answer) + local _, start_row, start_col, end_row, end_col = self:get_visual_selection() + local bufnr = self:get_bufnr() + if self.strategy == STRATEGY_PREPEND then + answer = answer .. "\n" .. self:get_selected_text() + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + elseif self.strategy == STRATEGY_APPEND then + answer = self:get_selected_text() .. "\n\n" .. answer .. 
"\n" + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + elseif self.strategy == STRATEGY_REPLACE then + answer = answer + vim.api.nvim_buf_set_text(bufnr, start_row - 1, start_col - 1, end_row - 1, end_col, lines) + elseif self.strategy == STRATEGY_QUICK_FIX then + if #lines == 1 and lines[1] == "" then + vim.notify("Your Code looks fine, no issues.", vim.log.levels.INFO) + return + end - local entries = {} - for _, line in ipairs(lines) do - local lnum, text = line:match("(%d+):(.*)") - if lnum then - local entry = { filename = vim.fn.expand("%:p"), lnum = tonumber(lnum), text = text } - table.insert(entries, entry) - end - end - if entries then - vim.fn.setqflist(entries) - vim.cmd(Config.options.show_quickfixes_cmd) + local entries = {} + for _, line in ipairs(lines) do + local lnum, text = line:match("(%d+):(.*)") + if lnum then + local entry = { filename = vim.fn.expand("%:p"), lnum = tonumber(lnum), text = text } + table.insert(entries, entry) end end - - -- set the cursor onto the answer - if self.strategy == STRATEGY_APPEND then - local target_line = end_row + 3 - vim.api.nvim_win_set_cursor(0, { target_line, 0 }) + if entries then + vim.fn.setqflist(entries) + vim.cmd(Config.options.show_quickfixes_cmd) end - end) + end + + -- set the cursor onto the answer + if self.strategy == STRATEGY_APPEND then + local target_line = end_row + 3 + vim.api.nvim_win_set_cursor(0, { target_line, 0 }) + end end return PopupAction diff --git a/lua/ogpt/flows/actions/popup/keymaps.lua b/lua/ogpt/flows/actions/popup/keymaps.lua index 5522ff6..37330d7 100644 --- a/lua/ogpt/flows/actions/popup/keymaps.lua +++ b/lua/ogpt/flows/actions/popup/keymaps.lua @@ -7,13 +7,16 @@ function M.apply_map(popup, opts) -- accept output and replace popup:map("n", Config.options.popup.keymaps.accept, function() -- local _lines = vim.api.nvim_buf_get_lines(popup.bufnr, 0, -1, false) + local _lines = vim.api.nvim_buf_get_lines(popup.bufnr, 0, -1, false) + table.insert(_lines, "") + table.insert(_lines, "") vim.api.nvim_buf_set_text( opts.main_bufnr, opts.selection_idx.start_row - 1, opts.selection_idx.start_col - 1, opts.selection_idx.end_row - 1, opts.selection_idx.end_col, - opts.lines + _lines ) vim.cmd("q") end) @@ -25,9 +28,9 @@ function M.apply_map(popup, opts) table.insert(_lines, "") vim.api.nvim_buf_set_text( opts.main_bufnr, - opts.selection_idx.end_row - 1, + opts.selection_idx.start_row - 1, opts.selection_idx.start_col - 1, - opts.selection_idx.end_row - 1, + opts.selection_idx.start_row - 1, opts.selection_idx.start_col - 1, _lines ) diff --git a/lua/ogpt/flows/actions/popup/window.lua b/lua/ogpt/flows/actions/popup/window.lua index 7ae0c84..126e7f6 100644 --- a/lua/ogpt/flows/actions/popup/window.lua +++ b/lua/ogpt/flows/actions/popup/window.lua @@ -1,4 +1,5 @@ local Popup = require("ogpt.common.popup") +local popup_keymap = require("ogpt.flows.actions.popup.keymaps") local Config = require("ogpt.config") local event = require("nui.utils.autocmd").event local Utils = require("ogpt.utils") @@ -9,7 +10,7 @@ function PopupWindow:init(options, edgy) options = vim.tbl_deep_extend("keep", options or {}, Config.options.popup) self.options = options - PopupWindow.super.init(self, options, self.edgy) + PopupWindow.super.init(self, options, edgy) end function PopupWindow:update_popup_size(opts) @@ -77,99 +78,9 @@ function PopupWindow:mount(opts) PopupWindow.super.mount(self) self:update_popup_size(opts) + popup_keymap.apply_map(self, opts) - -- close - local keys = 
Config.options.popup.keymaps.close - if type(keys) ~= "table" then - keys = { keys } - end - for _, key in ipairs(keys) do - self:map("n", key, function() - if opts.stop and type(opts.stop) == "function" then - opts.stop() - end - self:unmount() - end) - end - - -- accept output and replace - self:map("n", Config.options.popup.keymaps.accept, function() - local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) - table.insert(_lines, "") - table.insert(_lines, "") - vim.api.nvim_buf_set_text( - opts.main_bufnr, - opts.selection_idx.start_row - 1, - opts.selection_idx.start_col - 1, - opts.selection_idx.end_row - 1, - opts.selection_idx.end_col, - _lines - ) - vim.cmd("q") - end) - - -- accept output and prepend - self:map("n", Config.options.popup.keymaps.prepend, function() - local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) - table.insert(_lines, "") - table.insert(_lines, "") - vim.api.nvim_buf_set_text( - opts.main_bufnr, - opts.selection_idx.end_row - 1, - opts.selection_idx.start_col - 1, - opts.selection_idx.end_row - 1, - opts.selection_idx.start_col - 1, - _lines - ) - vim.cmd("q") - end) - - -- accept output and append - self:map("n", Config.options.popup.keymaps.append, function() - local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) - table.insert(_lines, 1, "") - table.insert(_lines, "") - vim.api.nvim_buf_set_text( - opts.main_bufnr, - opts.selection_idx.end_row, - opts.selection_idx.start_col - 1, - opts.selection_idx.end_row, - opts.selection_idx.start_col - 1, - _lines - ) - vim.cmd("q") - end) - - -- yank code in output and close - self:map("n", Config.options.popup.keymaps.yank_code, function() - local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) - local _code = Utils.getSelectedCode(_lines) - vim.fn.setreg(Config.options.yank_register, _code) - - if vim.fn.mode() == "i" then - vim.api.nvim_command("stopinsert") - end - vim.cmd("q") - end) - - -- yank output and close - self:map("n", Config.options.popup.keymaps.yank_to_register, function() - local _lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) - vim.fn.setreg(Config.options.yank_register, _lines) - - if vim.fn.mode() == "i" then - vim.api.nvim_command("stopinsert") - end - vim.cmd("q") - end) - - -- -- unmount component when cursor leaves buffer - -- self:on(event.BufLeave, function() - -- action.stop = true - -- self:unmount() - -- end) - - -- unmount component when cursor leaves buffer + -- unmount component when closing window self:on(event.WinClosed, function() if opts.stop and type(opts.stop) == "function" then opts.stop() diff --git a/lua/ogpt/flows/actions/template_helpers.lua b/lua/ogpt/flows/actions/template_helpers.lua new file mode 100644 index 0000000..2107eda --- /dev/null +++ b/lua/ogpt/flows/actions/template_helpers.lua @@ -0,0 +1,124 @@ +M = {} + +-- write a function to get the content of all the visible window +M.visible_window_content = function() + -- get the list of all visible windows + local wins = vim.api.nvim_list_wins() + wins = vim.tbl_filter(function(win) + return vim.api.nvim_win_get_tabpage(win) == vim.api.nvim_get_current_tabpage() + end, wins) + + -- for each of the win in wins, i need to get the buffer and make sure that the list of buffers are unique + local _buffers = {} + for _, win in ipairs(wins) do + -- get the buffer number of the current window + local bufnr = vim.api.nvim_win_get_buf(win) + + -- check if the buffer is loaded and resolves to a file in the project + local file = vim.fn.bufname(bufnr) + if 
file ~= "" and vim.fn.filereadable(file) == 1 then + -- add the buffer to the list of unique buffers + if not vim.tbl_contains(_buffers, bufnr) then + table.insert(_buffers, bufnr) + end + end + end + + -- initialize an empty string for the final output + local final_string = "" + + -- iterate over all unique buffers + for _, bufnr in ipairs(_buffers) do + -- get the filepath and line number for the current buffer + local file = vim.fn.bufname(bufnr) + + -- check if the buffer is loaded and resolves to a file in the project + if file ~= "" and vim.fn.filereadable(file) == 1 then + -- get the lines in the buffer + local lines = vim.api.nvim_buf_get_lines(bufnr, 0, -1, false) + + -- concatenate the lines with newline characters + local content = table.concat(lines, "\n") + + -- add the filepath and content to the final output + local _file_string = table.concat({ content }, "\n") + final_string = final_string .. _file_string + end + end + + -- return the final output + return final_string +end + +M.quickfix_content = function() + -- get the content of the quickfix list + local final_string = "" + -- Get the number of items in the quickfix list + local qf_items = vim.fn.getqflist() or {} + for _, item in ipairs(qf_items) do + -- Get the filepath and line number for the current item + local _file_string = table.concat({ + -- "# filepath: " .. filepath, + -- "#content", + -- "```", + item.text, + -- "````", + -- "", + }, "\n") + final_string = final_string .. _file_string + end + return final_string +end + +M.quickfix_file_content = function() + --- Gets the content of the vim quickfix list as a single string. + -- + -- Returns: + -- string -- The content of all files in the quickfix list, concatenated together. + local final_string = "" + -- Get the number of items in the quickfix list + local qf_items = vim.fn.getqflist() or {} + local _buffers = {} + for _, item in pairs(qf_items) do + table.insert(_buffers, item.bufnr) + end + -- ensure that we only get the content of unique files + _buffers = vim.fn.uniq(_buffers) + + for _, bufnr in ipairs(_buffers) do + -- Get the filepath and line number for the current item + local filepath = vim.fn.bufname(bufnr) + local fp = io.open(filepath, "r") + local content = fp:read("*all") + fp:close() + local _file_string = table.concat({ + -- string.format("# filepath: %s", filepath), + -- "#content", + -- "```", + content, + -- "```", + -- "", + }, "\n") + final_string = final_string .. 
_file_string + end + return final_string +end + +M.before_cursor = function() + -- get the current line number + local current_line = vim.fn.line(".") + + -- get the lines before the current line + local lines_before = vim.api.nvim_buf_get_lines(0, 0, current_line - 1, false) + return table.concat(lines_before, "\n") +end + +M.after_cursor = function() + -- get the current line number + local current_line = vim.fn.line(".") + -- get the lines after the current line + local lines_after = vim.api.nvim_buf_get_lines(0, current_line, -1, false) + return table.concat(lines_after, "\n") +end + +return M diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua index b436b41..f023313 100644 --- a/lua/ogpt/flows/chat/base.lua +++ b/lua/ogpt/flows/chat/base.lua @@ -3,6 +3,7 @@ local Layout = require("ogpt.common.layout") local Popup = require("ogpt.common.popup") local ChatInput = require("ogpt.input") +local Response = require("ogpt.response") local Config = require("ogpt.config") local Parameters = require("ogpt.parameters") local Sessions = require("ogpt.flows.chat.sessions") @@ -10,7 +11,7 @@ local utils = require("ogpt.utils") local Signs = require("ogpt.signs") local Spinner = require("ogpt.spinner") local Session = require("ogpt.flows.chat.session") -local SystemWindow = require("ogpt.flows.chat.system_window") +local UtilWindow = require("ogpt.util_window") QUESTION, ANSWER, SYSTEM = 1, 2, 3 ROLE_ASSISTANT = "assistant" @@ -20,11 +21,14 @@ ROLE_USER = "user" local Chat = Object("Chat") function Chat:init(opts) + opts = opts or {} self.input_extmark_id = nil self.active_panel = nil self.selected_message_nsid = vim.api.nvim_create_namespace("OGPTNSSM") + self.is_running = false + -- quit indicator self.active = true self.focused = true @@ -50,17 +54,19 @@ function Chat:init(opts) self.params = Config.get_chat_params(opts.provider) self.session = Session.latest() - - self.provider = Config.get_provider(self.session.parameters.provider, self) + self.provider = nil self.selectedIndex = 0 self.role = ROLE_USER self.messages = {} self.spinner = Spinner:new(function(state) vim.schedule(function() - self:set_lines(-2, -1, false, { state .. " " .. Config.options.chat.loading_text }) + -- self:set_lines(-2, -1, false, { state .. " " .. 
Config.options.chat.loading_text }) + -- self:set_lines(-2, -1, false, { state }) self:display_input_suffix(state) end) - end) + end, { + -- text = Config.options.chat.loading_text, + }) end function Chat:welcome() @@ -69,6 +75,7 @@ function Chat:welcome() self:set_lines(0, -1, false, {}) self:set_cursor({ 1, 0 }) self:set_system_message(nil, true) + self.provider = Config.get_provider(self.session.parameters.provider, self) local conversation = self.session.conversation or {} if #conversation > 0 then @@ -137,24 +144,29 @@ function Chat:new_session() self.system_message = nil self.system_role_panel:set_text({}) self:welcome() + self.parameters_panel:reload_session_params(self.session) end function Chat:set_session(session) - vim.api.nvim_buf_clear_namespace(self.chat_window.bufnr, Config.namespace_id, 0, -1) + self.session = session or self.session - self.session = session + if vim.fn.bufexists(self.chat_window.bufnr) then + vim.api.nvim_buf_clear_namespace(self.chat_window.bufnr, Config.namespace_id, 0, -1) + end self.messages = {} self.selectedIndex = 0 self:set_lines(0, -1, false, {}) self:set_cursor({ 1, 0 }) self:welcome() - self:configure_parameters_panel(session) + self:configure_parameters_panel(self.session) + self.parameters_panel:reload_session_params(self.session) self:set_keymaps() end function Chat:isBusy() - return self.spinner:is_running() + -- return self.spinner:is_running() + return self.is_running end function Chat:add(type, text, usage) @@ -213,10 +225,25 @@ function Chat:addAnswer(text, usage) self:add(ANSWER, text, usage) end -function Chat:addAnswerPartial(text, state, ctx) +function Chat:addAnswerPartial(response) + local content = response:pop_content() + local text = content[1] + local state = content[2] + -- local state = response.state + -- local text = response.current_text + local ctx = response.ctx + self:stopSpinner() + + if state == "START" then + self.is_running = true + self:set_lines(-2, -1, false, { "" }) + vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true) + end + if state == "ERROR" then - self:stopSpinner() - return self:addAnswer(text, {}) + -- self:stopSpinner() + utils.log(text, vim.log.levels.ERROR) + -- return self:addAnswer(text, {}) end local start_line = 0 @@ -225,7 +252,12 @@ function Chat:addAnswerPartial(text, state, ctx) start_line = prev.end_line + (prev.type == ANSWER and 2 or 1) end - if state == "END" and text ~= "" then + -- if state == "END" and text == "" then + -- -- most likely, ended by the using raising the stop flag + -- self:stopSpinner() + if state == "END" then + -- self:stopSpinner() + vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true) local usage = {} local idx = self.session:add_item({ type = ANSWER, @@ -233,6 +265,7 @@ function Chat:addAnswerPartial(text, state, ctx) usage = usage, ctx = ctx or {}, }) + self.parameters_panel:reload_session_params() local lines = {} local nr_of_lines = 0 @@ -251,20 +284,15 @@ function Chat:addAnswerPartial(text, state, ctx) nr_of_lines = nr_of_lines, start_line = start_line, end_line = end_line, - context = ctx.context, + context = response:get_context(), }) self.selectedIndex = self.selectedIndex + 1 vim.api.nvim_buf_set_lines(self.chat_window.bufnr, -1, -1, false, { "", "" }) Signs.set_for_lines(self.chat_window.bufnr, start_line, end_line, "chat") end - if state == "START" then - self:stopSpinner() - self:set_lines(-2, -1, false, { "" }) - vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true) - end - if state == "START" or 
state == "CONTINUE" then + self.is_running = true vim.api.nvim_buf_set_option(self.chat_window.bufnr, "modifiable", true) local lines = vim.split(text, "\n", {}) local length = #lines @@ -285,6 +313,9 @@ function Chat:addAnswerPartial(text, state, ctx) end end end + -- assume after each partial answer, the API stopped streaming + -- gemini has no stop flag in its response + self.is_running = false end function Chat:get_total_tokens() @@ -465,10 +496,12 @@ function Chat:renderLastMessage() end function Chat:showProgess() + self.is_running = true self.spinner:start() end function Chat:stopSpinner() + -- just for the spinner, stop it, so we can add the response self.spinner:stop() self:display_input_suffix() end @@ -494,7 +527,15 @@ function Chat:toMessages() elseif msg.type == ANSWER then role = "assistant" end - table.insert(messages, { role = role, content = msg.text }) + table.insert(messages, { + role = role, + content = { + { + text = msg.text, + token = nil, + }, + }, + }) end -- return messages[#messages].content return messages @@ -683,6 +724,15 @@ function Chat:get_layout_params() return config, box end +-- self:stopSpinner() +-- function Chat:stop_output() +-- self.stop_flag = true +-- end +-- +function Chat:on_complete_response(response) + -- +end + function Chat:open() self.session.parameters = vim.tbl_extend("keep", self.session.parameters, self.params) self.parameters_panel = Parameters({ @@ -692,6 +742,7 @@ function Chat:open() parent = self, edgy = Config.options.chat.edgy, }) + self.sessions_panel = Sessions({ edgy = Config.options.chat.edgy, set_session_cb = function(session) @@ -699,20 +750,15 @@ function Chat:open() end, }) self.chat_window = Popup(Config.options.output_window, Config.options.chat.edgy) - self.system_role_panel = SystemWindow({ + self.system_role_panel = UtilWindow({ + filetype = "ogpt-system-window", + display = "System", + virtual_text = "Define your LLM system message here...", on_change = function(text) self:set_system_message(text) end, }, Config.options.chat.edgy) - self.stop = false - self.should_stop = function() - if self.stop then - self.stop = false - return true - else - return false - end - end + self.chat_input = ChatInput(Config.options.input_window, { edgy = Config.options.chat.edgy, prompt = Config.options.input_window.prompt, @@ -726,14 +772,22 @@ function Chat:open() end end), on_submit = function(value) - -- clear input - vim.api.nvim_buf_set_lines(self.chat_input.bufnr, 0, -1, false, { "" }) - if self:isBusy() then vim.notify("I'm busy, please wait a moment...", vim.log.levels.WARN) return end + -- create response object per api call + local response = Response(self.provider, { + on_start = function() + -- restart stop flag + self.stop_flag = false + end, + }) + + -- clear input + vim.api.nvim_buf_set_lines(self.chat_input.bufnr, 0, -1, false, { "" }) + self:addQuestion(value) if self.role == ROLE_USER then @@ -745,9 +799,17 @@ function Chat:open() messages = self:toMessages(), system = self.system_message, }, self.parameters_panel.params) - self.provider.api:chat_completions(params, function(answer, state, ctx) - self:addAnswerPartial(answer, state, ctx) - end, self.should_stop) + self.provider.api:chat_completions(response, { + custom_params = params, + partial_result_fn = function(...) + self:addAnswerPartial(...) 
+ end, + should_stop = function() + -- check the stop flag if it should stop + -- return not self.is_running + return self.stop_flag + end, + }) end end, }) @@ -759,6 +821,8 @@ function Chat:open() -- initialize self.layout:mount() self:welcome() + + self.parameters_panel:reload_session_params() self:set_events() end @@ -820,7 +884,8 @@ function Chat:set_keymaps() -- stop generating self:map(Config.options.chat.keymaps.stop_generating, function() - self.stop = true + self.stop_flag = true + -- self.is_running = false end, { self.chat_input }) -- close @@ -853,6 +918,8 @@ function Chat:set_keymaps() -- new session self:map(Config.options.chat.keymaps.new_session, function() + -- self.stop_flag = true + self.is_running = false self:new_session() self.sessions_panel:refresh() end, { self.parameters_panel, self.chat_input, self.chat_window }) diff --git a/lua/ogpt/flows/chat/session.lua b/lua/ogpt/flows/chat/session.lua index 536bb9c..80dacb7 100644 --- a/lua/ogpt/flows/chat/session.lua +++ b/lua/ogpt/flows/chat/session.lua @@ -88,10 +88,21 @@ function Session:add_item(item) item.ctx = nil end if ctx and ctx.params and ctx.params.options then + -- ollama uses "options" to pass in parameters in its api calls self.parameters = ctx.params.options self.parameters.model = ctx.params.model - item.context = ctx.context + else + -- self.parameters = vim.tbl_extend("force", self.parameters, vim.tbl_get(ctx, "params", "parameters") or {}) + self.parameters = vim.tbl_get(ctx, "params", "parameters") or {} or self.parameters end + + -- add provider anbd model into a session to track + self.parameters.provider = ctx.provider + self.parameters.model = ctx.model + + -- handling context token + item.context = ctx.context + if self.updated_at == self.name and item.type == 1 then self.name = item.text end diff --git a/lua/ogpt/input.lua b/lua/ogpt/input.lua index ff51f64..f4b8fb7 100644 --- a/lua/ogpt/input.lua +++ b/lua/ogpt/input.lua @@ -53,7 +53,8 @@ function Input:init(popup_options, options, edgy) self.input_props = props props.on_submit = function(value) - local target_cursor = vim.api.nvim_win_get_cursor(self._.position.win) + -- local target_cursor = vim.api.nvim_win_get_cursor(self._.position.win) + local target_cursor = vim.api.nvim_win_get_cursor(self.winid) local prompt_normal_mode = vim.fn.mode() == "n" diff --git a/lua/ogpt/models.lua b/lua/ogpt/models.lua index 5d576f6..51661b6 100644 --- a/lua/ogpt/models.lua +++ b/lua/ogpt/models.lua @@ -53,9 +53,8 @@ local finder = function(provider, opts) :new({ command = "curl", args = { - provider.envs.MODELS_URL, - "-H", - provider.envs.AUTHORIZATION_HEADER, + provider:models_url(), + table.unpack(provider:request_headers()), }, on_exit = vim.schedule_wrap(function(j, exit_code) if exit_code ~= 0 then @@ -81,16 +80,20 @@ local finder = function(provider, opts) process_result(v) end - if provider.parse_api_model_response then - provider.parse_api_model_response(json, process_single_model) - else - -- default processor for a REST response from a curl for models - for _, model in ipairs(json.models) do - local v = entry_maker(model) - num_results = num_results + 1 - results[num_results] = v - process_result(v) - end + local _models = {} + _models = vim.tbl_extend("force", _models, json.models or {}) + _models = vim.tbl_extend("force", _models, Config.get_local_model_definition(provider) or {}) + + provider:parse_api_model_response(json, process_single_model) + + -- default processor for a REST response from a curl for models + for name, properties in 
pairs(_models) do + local v = entry_maker({ + name = name, + }) + num_results = num_results + 1 + results[num_results] = v + process_result(v) end process_complete() diff --git a/lua/ogpt/module.lua b/lua/ogpt/module.lua index bf73bc2..42669da 100644 --- a/lua/ogpt/module.lua +++ b/lua/ogpt/module.lua @@ -1,12 +1,93 @@ -- module represents a lua module for the plugin -local M = {} +local M = { + chats = {}, + actions = {}, +} -local Chat = require("ogpt.flows.chat") +local Config = require("ogpt.config") +local Chat = require("ogpt.flows.chat.base") local Actions = require("ogpt.flows.actions") +local Session = require("ogpt.flows.chat.session") +local Prompts = require("ogpt.prompts") -M.open_chat = Chat.open -M.focus_chat = Chat.focus -M.open_chat_with_awesome_prompt = Chat.open_with_awesome_prompt +-- M.open_chat = Chat.open +-- M.focus_chat = Chat.focus +-- M.open_chat_with_awesome_prompt = Chat.open_with_awesome_prompt M.run_action = Actions.run_action +function M.clear_windows() + if not Config.options.single_window then + return + end + + for _, chat in ipairs(M.chats) do + chat:hide() + end + for _, action in ipairs(M.actions) do + action:close() + end +end + +function M.open_chat(opts) + if Config.options.single_window then + for _, chat in ipairs(M.chats) do + chat:hide() + end + for _, action in ipairs(M.actions) do + action:close() + end + end + + if M.chats ~= nil and M.chats.active then + M.chats:toggle() + else + M.chats = Chat(opts) + M.chats:open(opts) + end +end + +function M.focus_chat(opts) + if M.chats ~= nil then + if not M.chats.focused then + M.chats:hide(opts) + M.chats:show(opts) + end + else + M.chats = Chat(opts) + M.chats:open(opts) + end +end + +function M.open_chat_with_awesome_prompt(opts) + Prompts.selectAwesomePrompt({ + cb = vim.schedule_wrap(function(act, prompt) + -- create new named session + local session = Session({ name = act }) + session:save() + + local chat = Chat:new() + chat:open() + chat.chat_window.border:set_text("top", " OGPT - Acts as " .. act .. " ", "center") + + chat:set_system_message(prompt) + chat:open_system_panel() + end), + }) +end + +function M.run_action(opts) + if Config.options.single_window then + for _, chat in ipairs(M.chats) do + chat:hide() + end + for _, action in ipairs(M.actions) do + action:close() + end + end + local action_hdl = Actions.run_action(opts) + if action_hdl then + table.insert(M.actions, action_hdl) + end +end + return M diff --git a/lua/ogpt/parameters.lua b/lua/ogpt/parameters.lua index 76f3672..a192bb5 100644 --- a/lua/ogpt/parameters.lua +++ b/lua/ogpt/parameters.lua @@ -79,6 +79,7 @@ local params_order = { "low_vram", "main_gpu", "max_tokens", + "max_new_tokens", "mirostat", "mirostat_eta", "mirostat_tau", @@ -111,6 +112,8 @@ local params_validators = { embedding_only = model_validator(), f16_kv = model_validator(), frequency_penalty = float_validator(), + max_tokens = integer_validator(), + max_new_tokens = integer_validator(), mirostat = integer_validator(), mirostat_eta = float_validator(), mirostat_tau = float_validator(), @@ -177,7 +180,7 @@ end function Parameters:read_config(session) if not session then local home = os.getenv("HOME") or os.getenv("USERPROFILE") - local file = io.open(home .. "/" .. ".ogpt-" .. self.parent_type .. "-params.json", "rb") + local file = io.open(home .. "/" .. ".ogpt-" .. self.type .. 
"-params.json", "rb") if not file then return nil end @@ -191,9 +194,10 @@ function Parameters:read_config(session) end function Parameters:write_config(config, session) + session = session or self.session if not session then local home = os.getenv("HOME") or os.getenv("USERPROFILE") - local file, err = io.open(home .. "/" .. ".ogpt-" .. self.parent_type .. "-params.json", "w") + local file, err = io.open(home .. "/" .. ".ogpt-" .. self.type .. "-params.json", "w") if file ~= nil then local json_string = vim.json.encode(config) file:write(json_string) @@ -250,17 +254,17 @@ function Parameters:init(opts) Parameters.super.init(self, vim.tbl_extend("force", Config.options.parameters_window, opts), opts.edgy) self.vts = {} - local type = opts.type - local default_params = opts.default_params - local session = opts.session - local parent = opts.parent + self.type = opts.type + self.default_params = opts.default_params + self.session = opts.session + self.parent = opts.parent self.parent_type = type - local custom_params = self:read_config(session or {}) + local custom_params = self:read_config(self.session or {}) - self.params = vim.tbl_deep_extend("force", {}, default_params, custom_params or {}) - if session then - self.params = session.parameters + self.params = vim.tbl_deep_extend("force", {}, self.default_params, custom_params or {}) + if self.session then + self.params = self.session.parameters end self:refresh_panel() @@ -302,28 +306,37 @@ function Parameters:init(opts) local key = existing_order[row] if key == "model" then local models = require("ogpt.models") - models.select_model(parent.provider, { + models.select_model(self.parent.provider, { cb = function(display, value) - self:update_property(key, row, value, session) + self:update_property(key, row, value, self.session) end, }) elseif key == "provider" then local provider = require("ogpt.provider") provider.select_provider({ cb = function(display, value) - self:update_property(key, row, value, session) - parent.provider = Config.get_provider(value) + self:update_property(key, row, value, self.session) + self.parent.provider = Config.get_provider(value) end, }) else local value = self.params[key] self:open_edit_property_input(key, value, row, function(new_value) - self:update_property(key, row, Utils.process_string(new_value), session) + self:update_property(key, row, Utils.process_string(new_value), self.session) end) end end, {}) end +function Parameters:reload_session_params(session) + if session then + self.session = session + end + + self.params = self.session.parameters + self:refresh_panel() +end + function Parameters:update_property(key, row, new_value, session) if not key or not new_value then self.params[key] = nil diff --git a/lua/ogpt/provider/base.lua b/lua/ogpt/provider/base.lua new file mode 100644 index 0000000..eaad968 --- /dev/null +++ b/lua/ogpt/provider/base.lua @@ -0,0 +1,275 @@ +local Object = require("ogpt.common.object") +local Config = require("ogpt.config") +local utils = require("ogpt.utils") +local Response = require("ogpt.response") + +local Provider = Object("Provider") + +function Provider:init(opts) + self.name = string.lower(self.class.name) + opts = vim.tbl_extend("force", Config.options.providers[self.name], opts) + self.enabled = opts.enabled + self.model = opts.model + self.stream_response = true + self.models = opts.models or {} + self.api_host = opts.api_host + self.api_key = opts.api_key + self.api_params = opts.api_params + self.api_chat_params = opts.api_chat_params + self.response_params = 
{ + strategy = Response.STRATEGY_LINE_BY_LINE, + split_chunk_match_regex = nil, + -- split_chunk_match_regex = "[^%,]+", -- use to splitting by commas + -- split_chunk_match_regex = "[^\n]+", -- use to splitting by newline character + } + self.envs = {} + self.api_parameters = { + "model", + "messages", + "format", + "options", + "system", + "template", + "stream", + "raw", + } + + self.api_chat_request_options = { + "num_keep", + "seed", + "num_predict", + "top_k", + "top_p", + "tfs_z", + "typical_p", + "repeat_last_n", + "temperature", + "repeat_penalty", + "presence_penalty", + "frequency_penalty", + "mirostat", + "mirostat_tau", + "mirostat_eta", + "penalize_newline", + "stop", + "numa", + "num_ctx", + "num_batch", + "num_gqa", + "num_gpu", + "main_gpu", + "low_vram", + "f16_kv", + "logits_all", + "vocab_only", + "use_mmap", + "use_mlock", + "embedding_only", + "rope_frequency_base", + "rope_frequency_scale", + "num_thread", + } +end + +function Provider:load_envs(override) + local _envs = {} + _envs.OLLAMA_API_HOST = Config.options.providers.ollama.api_host + or os.getenv("OLLAMA_API_HOST") + or "http://localhost:11434" + _envs.OLLAMA_API_KEY = Config.options.providers.ollama.api_key or os.getenv("OLLAMA_API_KEY") or "" + _envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/tags") + _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/generate") + _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/chat") + _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OLLAMA_API_KEY or " ") + + self.envs = vim.tbl_extend("force", _envs, override or {}) + return self.envs +end + +function Provider:completion_url() + return self.envs.CHAT_COMPLETIONS_URL +end + +function Provider:models_url() + return self.envs.MODELS_URL +end + +function Provider:request_headers() + return { + "-H", + "Content-Type: application/json", + "-H", + self.envs.AUTHORIZATION_HEADER, + } +end + +function Provider:parse_api_model_response(json, cb) + -- Given a json object from the api, parse this and get the names of the model to be displayed + for _, model in ipairs(json.models) do + cb({ + name = model.name, + }) + end +end + +function Provider:conform_request(params) + -- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information + + local param_options = {} + + for key, value in pairs(params) do + if not vim.tbl_contains(self.api_parameters, key) then + if vim.tbl_contains(self.api_chat_request_options, key) then + -- move it to the options + param_options[key] = value + params[key] = nil + else + params[key] = nil + end + end + end + local _options = vim.tbl_extend("keep", param_options, params.options or {}) + if next(_options) ~= nil then + params.options = _options + end + return params +end + +function Provider:conform_messages(params) + -- ensure we only have one system message + local _to_remove_system_idx = {} + for idx, message in ipairs(params.messages) do + if message.role == "system" then + table.insert(_to_remove_system_idx, idx) + end + end + + -- Remove elements from the list based on indices + for i = #_to_remove_system_idx, 1, -1 do + table.remove(params.messages, _to_remove_system_idx[i]) + end + + -- conform to support text only model + local messages = params.messages + local conformed_messages = {} + for _, message in ipairs(messages) do + table.insert(conformed_messages, { + role = message.role, + content = utils.gather_text_from_parts(message.content), + }) + end + params.messages = 
conformed_messages
+
+  return params
+end
+
+function Provider:process_response(response)
+  local chunk = response:pop_chunk()
+  local ok, json = pcall(vim.json.decode, chunk)
+
+  if not ok then
+    utils.log("Cannot process ollama response: \n" .. vim.inspect(chunk))
+    json = {}
+    response:could_not_process(chunk)
+    return
+  end
+
+  -- given a JSON response from the STREAMING api, process it
+  if type(json) == "string" then
+    utils.log("got something weird. " .. json, vim.log.levels.ERROR)
+  elseif vim.tbl_isempty(json) then
+    utils.log("got nothing in json.")
+  elseif json.done then
+    if json.message then
+      response:add_processed_text(json.message.content, "CONTINUE")
+    end
+  elseif json.message then
+    response:add_processed_text(json.message.content, "CONTINUE")
+  else
+    utils.log("unexpected case in ollama", vim.log.levels.ERROR)
+  end
+end
+
+function Provider:get_action_params(opts)
+  return vim.tbl_extend(
+    "force",
+    { model = self.model }, -- add in model from provider default
+    self.api_params, -- override with provider api_params
+    opts or {}
+  ) -- override with final options
+end
+
+function Provider:expand_model(params, ctx)
+  params.stream = self.stream_response
+  params = self:get_action_params(params)
+  ctx = ctx or {}
+  local provider_models = self.models
+  local _model = params.model
+
+  local _completion_url = self:completion_url()
+
+  local function _expand(name, _m)
+    if type(_m) == "table" then
+      if _m.modify_url and type(_m.modify_url) == "function" then
+        _completion_url = _m.modify_url(_completion_url)
+      elseif _m.modify_url then
+        _completion_url = _m.modify_url
+      else
+        params.model = _m.name
+        for _, model in ipairs(provider_models) do
+          if model.name == _m.name then
+            _expand(model)
+            break
+          end
+        end
+      end
+      params.model = _m or name
+    else
+      for _name, model in pairs(provider_models) do
+        if _name == _m then
+          if type(model) == "table" then
+            _expand(_name, model)
+            break
+          elseif type(model) == "string" then
+            params.model = model
+          end
+        end
+      end
+    end
+    return params
+  end
+
+  local _full_unfiltered_params = _expand(nil, _model)
+  ctx.tokens = _full_unfiltered_params.model.tokens or {}
+  -- final force override from the params that are set in the model itself.
+  -- This will enforce specific model params, e.g.
max_token, etc + local final_overrided_applied_params = + vim.tbl_extend("force", params, vim.tbl_get(_full_unfiltered_params, "model", "params") or {}) + + params = self:conform_to_provider_request(final_overrided_applied_params) + + return params, _completion_url, ctx +end + +function Provider:conform_to_provider_request(params) + params = self:get_action_params(params) + local _model = params.model + local _conform_messages_fn = _model and _model.conform_messages_fn + local _conform_request_fn = _model and _model.conform_request_fn + + if _conform_messages_fn then + params = _conform_messages_fn(self, params) + else + params = self:conform_messages(params) + end + + if _conform_request_fn then + params = _conform_request_fn(self, params) + else + params = self:conform_request(params) + end + + return params +end + +return Provider diff --git a/lua/ogpt/provider/gemini.lua b/lua/ogpt/provider/gemini.lua new file mode 100644 index 0000000..f683574 --- /dev/null +++ b/lua/ogpt/provider/gemini.lua @@ -0,0 +1,147 @@ +local Config = require("ogpt.config") +local utils = require("ogpt.utils") +local Response = require("ogpt.response") +local ProviderBase = require("ogpt.provider.base") +local Gemini = ProviderBase:extend("Gemini") + +function Gemini:init(opts) + Gemini.super.init(self, opts) + self.response_params = { + strategy = Response.STRATEGY_REGEX, + -- split_chunk_match_regex = "^%,\n+", + -- split_chunk_match_regex = '"text": "(.-)"', + split_chunk_match_regex = '"parts": %[[\n%s]+(.-)%]', -- capture everything in the "content.parts" of the response + } + self.api_parameters = { + "contents", + } + self.api_chat_request_options = { + "stopSequences", + "temperature", + "topP", + "topK", + } +end + +function Gemini:load_envs(override) + local _envs = {} + _envs.GEMINI_API_KEY = Config.options.providers.gemini.api_key or os.getenv("GEMINI_API_KEY") or "" + _envs.AUTH = "key=" .. (_envs.GEMINI_API_KEY or " ") + _envs.MODEL = "gemini-pro" + _envs.GEMINI_API_HOST = Config.options.providers.gemini.api_host + or os.getenv("GEMINI_API_HOST") + or "https://generativelanguage.googleapis.com/v1beta" + _envs.MODELS_URL = utils.ensureUrlProtocol(_envs.GEMINI_API_HOST .. "/models") + self.envs = vim.tbl_extend("force", _envs, override or {}) + return self.envs +end + +function Gemini:parse_api_model_response(json, cb) + -- Given a json object from the api, parse this and get the names of the model to be displayed + for _, model in ipairs(json.models) do + cb({ + name = model.name, + }) + end +end + +function Gemini:completion_url() + if self.stream_response then + return utils.ensureUrlProtocol( + self.envs.GEMINI_API_HOST .. "/models/" .. self.model .. ":streamGenerateContent?" .. self.envs.AUTH + ) + end + return utils.ensureUrlProtocol( + self.envs.GEMINI_API_HOST .. "/models/" .. self.model .. ":generateContent?" .. self.envs.AUTH + ) +end + +function Gemini:models_url() + return utils.ensureUrlProtocol(self.envs.GEMINI_API_HOST .. "/models?" .. 
self.envs.AUTH) +end + +function Gemini:request_headers() + return { + "-H", + "Content-Type: application/json", + } +end + +function Gemini:conform_request(params) + --https://ai.google.dev/tutorials/rest_quickstart#multi-turn_conversations_chat + + local param_options = {} + + for key, value in pairs(params) do + if not vim.tbl_contains(self.api_parameters, key) then + if vim.tbl_contains(self.api_chat_request_options, key) then + -- move it to the options + param_options[key] = value + params[key] = nil + else + params[key] = nil + end + end + end + local _options = vim.tbl_extend("keep", param_options, params.options or {}) + if next(_options) ~= nil then + params.generationConfig = _options + end + return params +end + +function Gemini:conform_messages(params) + -- ensure we only have one system message + local _to_remove_system_idx = {} + for idx, message in ipairs(params.messages) do + if message.role == "system" then + table.insert(_to_remove_system_idx, idx) + end + end + -- Remove elements from the list based on indices + for i = #_to_remove_system_idx, 1, -1 do + table.remove(params.messages, _to_remove_system_idx[i]) + end + + local messages = params.messages + local _contents = {} + for _, content in ipairs(messages) do + table.insert(_contents, { + role = content.role == "assistant" and "model" or "user", + parts = { + text = utils.gather_text_from_parts(content.content), + }, + }) + end + + params.messages = nil + params.contents = _contents + return params +end + +function Gemini:process_response(response) + local chunk = response:pop_chunk() + utils.log("Popped chunk: " .. chunk) + + local ok, json = pcall(vim.json.decode, chunk) + + if not ok then + -- but it back if its not processed + utils.log("Could not process chunk: " .. chunk) + response:could_not_process(chunk) + return + end + + if type(json) == "string" then + utils.log("Something is going on, json is a string :" .. vim.inspect(json), vim.log.levels.ERROR) + else + local text = vim.tbl_get(json, "text") + if text then + response:add_processed_text(text, "CONTINUE") + else + utils.log("Mising 'text' from response: " .. vim.inspect(json)) + end + end +end + +return Gemini diff --git a/lua/ogpt/provider/init.lua b/lua/ogpt/provider/init.lua index f0dc2d8..4dc4e8e 100644 --- a/lua/ogpt/provider/init.lua +++ b/lua/ogpt/provider/init.lua @@ -2,7 +2,6 @@ local pickers = require("telescope.pickers") local conf = require("telescope.config").values local actions = require("telescope.actions") local action_state = require("telescope.actions.state") -local job = require("plenary.job") local Utils = require("ogpt.utils") local Config = require("ogpt.config") diff --git a/lua/ogpt/provider/ollama.lua b/lua/ogpt/provider/ollama.lua index 3def8ba..a67e8b0 100644 --- a/lua/ogpt/provider/ollama.lua +++ b/lua/ogpt/provider/ollama.lua @@ -1,131 +1,6 @@ -local Config = require("ogpt.config") -local utils = require("ogpt.utils") +local ProviderBase = require("ogpt.provider.base") -local M = {} +-- ollama is first class citizen on OGPT +local Ollama = ProviderBase:extend("Ollama") -M.name = "ollama" -M.envs = {} -M.models = {} - -function M.load_envs(envs) - local _envs = {} - _envs.OLLAMA_API_HOST = Config.options.providers.ollama.api_host - or os.getenv("OLLAMA_API_HOST") - or "http://localhost:11434" - _envs.OLLAMA_API_KEY = Config.options.providers.ollama.api_key or os.getenv("OLLAMA_API_KEY") or "" - _envs.MODELS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. 
"/api/tags") - _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/generate") - _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OLLAMA_API_HOST .. "/api/chat") - _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OLLAMA_API_KEY or " ") - M.envs = vim.tbl_extend("force", M.envs, _envs) - M.envs = vim.tbl_extend("force", M.envs, envs or {}) - return M.envs -end - -M.api_parameters = { - "model", - "messages", - "format", - "options", - "system", - "template", - "stream", - "raw", -} - -M.api_chat_request_options = { - "num_keep", - "seed", - "num_predict", - "top_k", - "top_p", - "tfs_z", - "typical_p", - "repeat_last_n", - "temperature", - "repeat_penalty", - "presence_penalty", - "frequency_penalty", - "mirostat", - "mirostat_tau", - "mirostat_eta", - "penalize_newline", - "stop", - "numa", - "num_ctx", - "num_batch", - "num_gqa", - "num_gpu", - "main_gpu", - "low_vram", - "f16_kv", - "logits_all", - "vocab_only", - "use_mmap", - "use_mlock", - "embedding_only", - "rope_frequency_base", - "rope_frequency_scale", - "num_thread", -} - -function M.parse_api_model_response(json, cb) - -- Given a json object from the api, parse this and get the names of the model to be displayed - for _, model in ipairs(json.models) do - cb({ - name = model.name, - }) - end -end - -function M.conform(params) - -- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information - - local param_options = {} - - for key, value in pairs(params) do - if not vim.tbl_contains(M.api_parameters, key) then - if vim.tbl_contains(M.api_chat_request_options, key) then - -- move it to the options - param_options[key] = value - params[key] = nil - else - params[key] = nil - end - end - end - local _options = vim.tbl_extend("keep", param_options, params.options or {}) - if next(_options) ~= nil then - params.options = _options - end - return params -end - -function M.process_line(content, ctx, raw_chunks, state, cb) - local _json = content.json - local raw = content.raw - -- given a JSON response from the STREAMING api, processs it - if _json and _json.done then - if _json.message then - -- for stream=false case - cb(_json.message.content, state, ctx) - raw_chunks = raw_chunks .. _json.message.content - state = "CONTINUE" - else - ctx.context = _json.context - cb(raw_chunks, "END", ctx) - end - else - if not vim.tbl_isempty(_json) then - if _json and _json.message then - cb(_json.message.content, state, ctx) - raw_chunks = raw_chunks .. 
_json.message.content - state = "CONTINUE" - end - end - end - - return ctx, raw_chunks, state -end - -return M +return Ollama diff --git a/lua/ogpt/provider/openai.lua b/lua/ogpt/provider/openai.lua index 6d49e45..b4b5ce8 100644 --- a/lua/ogpt/provider/openai.lua +++ b/lua/ogpt/provider/openai.lua @@ -1,13 +1,26 @@ local Config = require("ogpt.config") local utils = require("ogpt.utils") +local ProviderBase = require("ogpt.provider.base") +local Response = require("ogpt.response") +local Openai = ProviderBase:extend("openai") -local M = {} - -M.name = "openai" - -M.envs = {} +function Openai:init(opts) + Openai.super.init(self, opts) + self.name = "openai" + self.api_parameters = { + "model", + "messages", + "stream", + "temperature", + "presence_penalty", + "frequency_penalty", + "top_p", + "max_tokens", + } + self.api_chat_request_options = {} +end -function M.load_envs() +function Openai:load_envs(override) local _envs = {} _envs.OPENAI_API_HOST = Config.options.providers.openai.api_host or os.getenv("OPENAI_API_HOST") @@ -17,22 +30,11 @@ function M.load_envs() _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENAI_API_HOST .. "/v1/completions") _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.OPENAI_API_HOST .. "/v1/chat/completions") _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.OPENAI_API_KEY or " ") - M.envs = vim.tbl_extend("force", M.envs, _envs) - return M.envs + self.envs = vim.tbl_extend("force", _envs, override or {}) + return self.envs end -M._api_chat_parameters = { - "model", - "messages", - "stream", - "temperature", - "presence_penalty", - "frequency_penalty", - "top_p", - "max_tokens", -} - -function M.parse_api_model_response(json, cb) +function Openai:parse_api_model_response(json, cb) local data = json.data or {} for _, model in ipairs(data) do cb({ @@ -41,76 +43,44 @@ function M.parse_api_model_response(json, cb) end end -function M.conform(params) - params = M._conform_messages(params) - +function Openai:conform_request(params) for key, value in pairs(params) do - if not vim.tbl_contains(M._api_chat_parameters, key) then - utils.log("Did not process " .. key .. " for " .. M.name) + if not vim.tbl_contains(self.api_parameters, key) then + utils.log("Did not process " .. key .. " for " .. 
self.name) params[key] = nil end end return params end -function M._conform_messages(params) - -- ensure we only have one system message - local _to_remove_system_idx = {} - for idx, message in ipairs(params.messages) do - if message.role == "system" then - table.insert(_to_remove_system_idx, idx) - end - end - -- Remove elements from the list based on indices - for i = #_to_remove_system_idx, 1, -1 do - table.remove(params.messages, _to_remove_system_idx[i]) - end - - -- https://platform.openai.com/docs/api-reference/chat - if params.system then - table.insert(params.messages, 1, { - role = "system", - content = params.system, - }) +function Openai:process_response(response) + local chunk = response:pop_chunk() + local raw_json = string.gsub(chunk, "^data:", "") + local _ok, _json = pcall(vim.json.decode, raw_json) + if _ok then + self:_process_line({ json = _json, raw = chunk }, response) + else + self:_process_line({ json = nil, raw = chunk }, response) end - return params end -function M.process_line(content, ctx, raw_chunks, state, cb) +function Openai:_process_line(content, response) local _json = content.json - local raw = content.raw - -- given a JSON response from the STREAMING api, processs it - if _json and _json.done then - ctx.context = _json.context - cb(raw_chunks, "END", ctx) - elseif type(_json) == "string" and string.find(_json, "[DONE]") then - cb(raw_chunks, "END", ctx) - else - if - not vim.tbl_isempty(_json) - and _json - and _json.choices - and _json.choices[1] - and _json.choices[1].delta - and _json.choices[1].delta.content - then - cb(_json.choices[1].delta.content, state) - raw_chunks = raw_chunks .. _json.choices[1].delta.content - state = "CONTINUE" - elseif - not vim.tbl_isempty(_json) - and _json - and _json.choices - and _json.choices[1] - and _json.choices[1].message - and _json.choices[1].message.content - then - cb(_json.choices[1].message.content, state) - raw_chunks = raw_chunks .. _json.choices[1].message.content + local _raw = content.raw + if _json then + local text_delta = vim.tbl_get(_json, "choices", 1, "delta", "content") + local text = vim.tbl_get(_json, "choices", 1, "message", "content") + if text_delta then + response:add_processed_text(text_delta, "CONTINUE") + elseif text then + -- done end + elseif not _json and string.find(_raw, "[DONE]") then + -- done + else + response:could_not_process(_raw) + utils.log("Could not process chunk for openai: " .. 
_raw, vim.log.levels.DEBUG) end - - return ctx, raw_chunks, state end -return M +return Openai diff --git a/lua/ogpt/provider/textgenui.lua b/lua/ogpt/provider/textgenui.lua index ea2a92e..b1015de 100644 --- a/lua/ogpt/provider/textgenui.lua +++ b/lua/ogpt/provider/textgenui.lua @@ -1,30 +1,39 @@ local Config = require("ogpt.config") local utils = require("ogpt.utils") -local M = {} - -M.name = "textgenui" - -M.request_params = { - "inputs", - "parameters", - "stream", -} - -M.model_params = { - "seed", - "top_k", - "top_p", - "top_n_tokens", - "typical_p", - "stop", - "details", - "max_new_tokens", - "repetition_penalty", -} - -M.envs = {} - -function M.load_envs() +local Response = require("ogpt.response") + +local ProviderBase = require("ogpt.provider.base") +local Textgenui = ProviderBase:extend("Textgenui") + +function Textgenui:init(opts) + Textgenui.super.init(self, opts) + self.name = "textgenui" + self.api_parameters = { + "inputs", + "parameters", + -- "stream", + } + self.api_chat_request_options = { + "best_of", + "decoder_input_details", + "details", + "do_sample", + "max_new_tokens", + "repetition_penalty", + "return_full_text", + "seed", + "stop", + "temperature", + "top_k", + "top_n_tokens", + "top_p", + "truncate", + "typical_p", + "watermark", + } +end + +function Textgenui:load_envs(override) local _envs = {} _envs.TEXTGEN_API_HOST = Config.options.providers.textgenui.api_host or os.getenv("TEXTGEN_API_HOST") @@ -34,23 +43,25 @@ function M.load_envs() _envs.COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST) _envs.CHAT_COMPLETIONS_URL = utils.ensureUrlProtocol(_envs.TEXTGEN_API_HOST) _envs.AUTHORIZATION_HEADER = "Authorization: Bearer " .. (_envs.TEXTGEN_API_KEY or " ") - M.envs = vim.tbl_extend("force", M.envs, _envs) - return M.envs + self.envs = vim.tbl_extend("force", _envs, override or {}) + return self.envs end -M.textgenui_options = { "seed", "top_k", "top_p", "stop" } +function Textgenui:completion_url() + if self.stream_response then + return self.envs.TEXTGEN_API_HOST .. "/generate_stream" + end + return self.envs.TEXTGEN_API_HOST .. 
"/generate" +end -function M.conform(params) +function Textgenui:conform_request(params) params = params or {} - -- textgenui uses "inputs" - params.inputs = M._conform_messages(params.messages or {}) - local param_options = {} for key, value in pairs(params) do - if not vim.tbl_contains(M.request_params, key) then - if vim.tbl_contains(M.model_params, key) then + if not vim.tbl_contains(self.api_parameters, key) then + if vim.tbl_contains(self.api_chat_request_options, key) then param_options[key] = value params[key] = nil else @@ -65,53 +76,91 @@ function M.conform(params) return params end -function M._conform_messages(messages) +function Textgenui:conform_messages(params) + local messages = params.messages or {} -- https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1 local tokens = { BOS = "", EOS = "", + BOSYS = "<>", + EOSYS = "<>", INST_START = "[INST]", INST_END = "[/INST]", } local _input = { tokens.BOS } for i, message in ipairs(messages) do + local text = utils.gather_text_from_parts(message.content) if i < #messages then -- Stop before the last item - if message.role == "user" then + if message.role == "system" then + table.insert(_input, tokens.BOSYS) + table.insert(_input, text) + table.insert(_input, tokens.EOSYS) + elseif message.role == "user" then table.insert(_input, tokens.INST_START) - table.insert(_input, message.content) + table.insert(_input, text) table.insert(_input, tokens.INST_END) - elseif message.role == "system" then - table.insert(_input, message.content) + elseif message.role == "assistant" then + table.insert(_input, text) end else table.insert(_input, tokens.EOS) + table.insert(_input, tokens.BOS) table.insert(_input, tokens.INST_START) - table.insert(_input, message.content) + table.insert(_input, text) table.insert(_input, tokens.INST_END) end end local final_string = table.concat(_input, " ") - return final_string + params.inputs = final_string + return params end -function M.process_line(content, ctx, raw_chunks, state, cb) - local _json = content.json - local raw = content.raw - if _json.token then - if _json.token.text == "" then - ctx.context = _json.context - cb(raw_chunks, "END", ctx) - else - cb(_json.token.text, state, ctx) - raw_chunks = raw_chunks .. _json.token.text - state = "CONTINUE" - end - elseif _json.error then - cb(_json.error, "ERROR", ctx) +function Textgenui:process_response(response) + local ctx = response.ctx + local chunk = response:pop_chunk() + + local raw_json = string.gsub(chunk, "^data:", "") + + local ok, _json = pcall(vim.json.decode, raw_json) + + if not ok then + utils.log("Something went wrong with parsing Textgetui json: " .. 
vim.inspect(response.current_raw_chunk)) + response:could_not_process(chunk) + _json = {} + end + _json = _json or {} + + if _json.error ~= nil then + local error_msg = { + "OGPT ERROR:", + self.provider.name, + vim.inspect(_json.error) or "", + "Something went wrong.", + } + table.insert(error_msg, vim.inspect(response.rest_params)) + response.error = error_msg + response:add_processed_text(error_msg, "ERROR") + return + end + + if not _json.token then + return + end + + if _json.token.text and string.find(_json.token.text, "") then + -- Done + elseif + _json.token.text + and vim.tbl_get(ctx, "tokens", "end_of_result") + and string.find(_json.token.text, vim.tbl_get(ctx, "tokens", "end_of_result")) + then + ctx.context = _json.context + -- done + elseif _json.token.generated_text then + -- done else - print(_json) + response:add_processed_text(_json.token.text, "CONTINUE") end - return ctx, raw_chunks, state end -return M +return Textgenui diff --git a/lua/ogpt/response.lua b/lua/ogpt/response.lua new file mode 100644 index 0000000..df071f4 --- /dev/null +++ b/lua/ogpt/response.lua @@ -0,0 +1,190 @@ +local Object = require("ogpt.common.object") +local utils = require("ogpt.utils") +local async = require("plenary.async.async") +local channel = require("plenary.async.control").channel + +local Response = Object("Response") + +Response.STRATEGY_LINE_BY_LINE = "line" +Response.STRATEGY_CHUNK = "chunk" +Response.STRATEGY_REGEX = "regex" +Response.STATE_INIT = "initialized" +Response.STATE_INPROGRESS = "inprogress" +Response.STATE_COMPLETED = "completed" +Response.STATE_ERRORED = "errored" +Response.STATE_STOPPED = "stopped" + +function Response:init(provider, events) + self.events = events or {} + self.accumulated_chunks = {} + self.processed_text = {} + self.rest_params = {} + self.ctx = {} + self.partial_result_cb = nil + self.in_progress = false + self.strategy = provider.rest_strategy + self.provider = provider + self.not_processed = "" + self.raw_chunk_tx, self.raw_chunk_rx = channel.mpsc() + self.processed_raw_tx, self.processsed_raw_rx = channel.mpsc() + self.processed_content_tx, self.processsed_content_rx = channel.mpsc() + self.response_state = nil + self.chunk_regex = "" + self:set_state(self.STATE_INIT) +end + +function Response:set_processed_text(text) + self.processed_text = text +end + +function Response:add_chunk(chunk) + utils.log("Pushed chunk: " .. chunk, vim.log.levels.TRACE) + self.raw_chunk_tx.send(chunk) +end + +function Response:could_not_process(chunk) + self.not_processed = chunk +end + +function Response:run_async() + async.run(function() + while true do + self:_process_added_chunk() + end + end) + + async.run(function() + while true do + self.provider:process_response(self) + end + end) + + async.run(function() + while true do + self:render() + end + end) +end + +function Response:_process_added_chunk() + local chunk = self.raw_chunk_rx.recv() + utils.log("recv'd chunk: " .. chunk, vim.log.levels.TRACE) + + if self.provider.response_params.strategy == self.STRATEGY_CHUNK then + self.processed_raw_tx.send(chunk) + return + end + -- everything regex related below this line + local _split_regex = "[^\n]+" -- default regex, for line-by-line strategy + if self.provider.response_params.strategy == self.STRATEGY_REGEX then + _split_regex = self.provider.response_params.split_chunk_match_regex or _split_regex + end + + local success = false + for line in chunk:gmatch(_split_regex) do + success = true + utils.log("Chunk processed using regex: " .. 
+function Response:run_async()
+  async.run(function()
+    while true do
+      self:_process_added_chunk()
+    end
+  end)
+
+  async.run(function()
+    while true do
+      self.provider:process_response(self)
+    end
+  end)
+
+  async.run(function()
+    while true do
+      self:render()
+    end
+  end)
+end
+
+function Response:_process_added_chunk()
+  local chunk = self.raw_chunk_rx.recv()
+  utils.log("recv'd chunk: " .. chunk, vim.log.levels.TRACE)
+
+  if self.provider.response_params.strategy == self.STRATEGY_CHUNK then
+    self.processed_raw_tx.send(chunk)
+    return
+  end
+  -- everything regex related below this line
+  local _split_regex = "[^\n]+" -- default regex, for line-by-line strategy
+  if self.provider.response_params.strategy == self.STRATEGY_REGEX then
+    _split_regex = self.provider.response_params.split_chunk_match_regex or _split_regex
+  end
+
+  local success = false
+  for line in chunk:gmatch(_split_regex) do
+    success = true
+    utils.log("Chunk processed using regex: " .. _split_regex, vim.log.levels.DEBUG)
+    self.processed_raw_tx.send(line)
+  end
+
+  if not success then
+    utils.log("Chunk COULD NOT BE PROCESSED by regex: '" .. _split_regex .. "' : " .. chunk, vim.log.levels.DEBUG)
+  end
+end
+
+function Response:pop_content()
+  local content = self.processsed_content_rx.recv()
+  if content[2] == "END" and (content[1] and content[1] == "") then
+    content[1] = self:get_processed_text()
+  end
+  return content
+end
+
+function Response:render()
+  self.partial_result_cb(self)
+end
+
+function Response:pop_chunk()
+  utils.log("Try to pop chunk...", vim.log.levels.TRACE)
+  -- pop the next chunk and prepend anything that was not processed earlier
+  local _value = self.not_processed
+  self.not_processed = ""
+  local _chunk = self.processsed_raw_rx.recv()
+  utils.log("Got chunk... now appending to 'not_processed'", vim.log.levels.TRACE)
+  return _value .. _chunk
+end
+
+function Response:get_accumulated_chunks()
+  return table.concat(self.accumulated_chunks, "")
+end
+
+function Response:add_processed_text(text, flag)
+  text = text or ""
+  if vim.tbl_isempty(self.processed_text) then
+    -- remove the first space found in most llm responses
+    text = string.gsub(text, "^ ", "")
+  end
+  table.insert(self.processed_text, text)
+  self.processed_content_tx.send({ text, flag })
+end
+
+function Response:get_processed_text()
+  return vim.trim(table.concat(self.processed_text, ""))
+end
+
+function Response:get_processed_text_by_lines()
+  return vim.split(self:get_processed_text(), "\n")
+end
+
+function Response:get_context()
+  return self.ctx
+end
+
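+-- State handling: set_state() records the new state and monitor_state() reacts to it.
+-- A Response starts in STATE_INIT (firing events.on_start); STATE_COMPLETED and
+-- STATE_STOPPED emit a final "END" marker and reset the state back to STATE_INIT.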
+function Response:set_state(state)
+  self.response_state = state
+  self:monitor_state()
+end
+
+function Response:monitor_state()
+  if self.response_state == self.STATE_INIT then
+    utils.log("Response Initialized.", vim.log.levels.DEBUG)
+    if self.events.on_start then
+      self.events.on_start()
+    end
+    --
+  elseif self.response_state == self.STATE_INPROGRESS then
+    utils.log("Response In-progress.", vim.log.levels.DEBUG)
+    --
+  elseif self.response_state == self.STATE_ERRORED then
+    utils.log("Response Errored.", vim.log.levels.DEBUG)
+    --
+  elseif self.response_state == self.STATE_COMPLETED then
+    utils.log("Response Completed.", vim.log.levels.DEBUG)
+    self:add_processed_text("", "END")
+    self:set_state(self.STATE_INIT)
+  elseif self.response_state == self.STATE_STOPPED then
+    utils.log("Response Stopped.", vim.log.levels.DEBUG)
+    self:add_processed_text("", "END")
+    self:set_state(self.STATE_INIT)
+  elseif self.response_state then
+    utils.log("unknown state: " .. self.response_state, vim.log.levels.ERROR)
+  end
+end
+
+function Response:extract_code()
+  local response_text = self:get_processed_text()
+  local code_response = utils.extract_code(response_text)
+  -- if the chat is editing code, try to extract the code block from the response
+  local output_txt = response_text
+  if code_response then
+    output_txt = utils.match_indentation(response_text, code_response)
+  else
+    vim.notify("no codeblock detected", vim.log.levels.INFO)
+  end
+  if response_text.applied_changes then
+    vim.notify(response_text.applied_changes, vim.log.levels.INFO)
+  end
+  return output_txt
+end
+
+return Response
diff --git a/lua/ogpt/flows/chat/system_window.lua b/lua/ogpt/util_window.lua
similarity index 53%
rename from lua/ogpt/flows/chat/system_window.lua
rename to lua/ogpt/util_window.lua
index 3d9607b..2da0236
--- a/lua/ogpt/flows/chat/system_window.lua
+++ b/lua/ogpt/util_window.lua
@@ -1,18 +1,31 @@
 local Popup = require("ogpt.common.popup")
 local Config = require("ogpt.config")
 
-local SystemWindow = Popup:extend("SystemWindow")
+local UtilWindow = Popup:extend("UtilWindow")
 
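+-- UtilWindow is a small generic pane used for the auxiliary OGPT buffers.
+-- An illustrative (not prescriptive) instantiation:
+--   UtilWindow({
+--     display = "OGPT Template",         -- border title
+--     filetype = "ogpt-template",        -- buffer filetype
+--     default_text = "",                 -- initial buffer content
+--     virtual_text = "Enter text here",  -- placeholder shown while the buffer is empty
+--     on_change = function(text) end,    -- called whenever the buffer changes
+--   }, edgy)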
-function SystemWindow:init(options, edgy)
+function UtilWindow:init(options, edgy)
+  local popup_defaults = {
+    border = {
+      text = {
+        top = options.display or "UTIL",
+      },
+    },
+    buf_options = {
+      filetype = options.filetype or "ogpt-util-window",
+    },
+  }
+  local popup_options = vim.tbl_deep_extend("force", Config.options.util_window, popup_defaults)
+
+  self.virtual_text = options.virtual_text
+  self.default_text = options.default_text or ""
   self.working = false
   self.on_change = options.on_change
-  options = vim.tbl_deep_extend("force", options or {}, Config.options.system_window)
-
-  SystemWindow.super.init(self, options, edgy)
+  UtilWindow.super.init(self, popup_options, edgy)
+  self:set_text(vim.fn.split(self.default_text, "\n"))
 end
 
-function SystemWindow:toggle_placeholder()
+function UtilWindow:toggle_placeholder()
   local lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false)
   local text = table.concat(lines, "\n")
 
@@ -24,7 +37,7 @@ function SystemWindow:toggle_placeholder()
     self.extmark = vim.api.nvim_buf_set_extmark(self.bufnr, Config.namespace_id, 0, 0, {
       virt_text = {
        {
-          "You are a helpful assistant.",
+          self.virtual_text or "",
           "@comment",
         },
       },
@@ -33,14 +46,14 @@ function SystemWindow:toggle_placeholder()
   end
 end
 
-function SystemWindow:set_text(text)
+function UtilWindow:set_text(text)
   self.working = true
   vim.api.nvim_buf_set_lines(self.bufnr, 0, -1, false, text)
   self.working = false
 end
 
-function SystemWindow:mount()
-  SystemWindow.super.mount(self)
+function UtilWindow:mount()
+  UtilWindow.super.mount(self)
   self:toggle_placeholder()
 
   vim.api.nvim_buf_attach(self.bufnr, false, {
@@ -50,10 +63,12 @@ function SystemWindow:mount()
       if not self.working then
         local lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false)
         local text = table.concat(lines, "\n")
-        self.on_change(text)
+        if self.on_change and type(self.on_change) == "function" then
+          self.on_change(text)
+        end
       end
     end,
   })
 end
 
-return SystemWindow
+return UtilWindow
diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua
index 017a708..0c68a6b 100644
--- a/lua/ogpt/utils.lua
+++ b/lua/ogpt/utils.lua
@@ -1,4 +1,5 @@
 local Config = require("ogpt.config")
+local Path = require("plenary.path")
 local M = {}
 
 local ESC_FEEDKEY = vim.api.nvim_replace_termcodes("<Esc>", true, false, true)
@@ -253,56 +254,6 @@ function M.trim(s)
   return (s:gsub("^%s*(.-)%s*$", "%1"))
 end
 
-function M.add_partial_completion(opts, text, state)
-  local panel = opts.panel
-  local progress = opts.progress
-
-  if state == "ERROR" then
-    if progress then
-      progress(false)
-    end
-    M.log("An Error Occurred: " .. text, vim.log.levels.ERROR)
-    panel:unmount()
-    return
-  end
-
-  local start_line = 0
-  if state == "END" and text ~= "" then
-    if not opts.on_complete then
-      return
-    end
-    return opts.on_complete(text)
-  end
-
-  if state == "START" then
-    if progress then
-      progress(false)
-    end
-    if M.is_buf_exists(panel.bufnr) then
-      vim.api.nvim_buf_set_option(panel.bufnr, "modifiable", true)
-    end
-    text = M.trim(text)
-  end
-
-  if state == "START" or state == "CONTINUE" then
-    local lines = vim.split(text, "\n", {})
-    local length = #lines
-    local buffer = panel.bufnr
-
-    for i, line in ipairs(lines) do
-      if buffer and vim.fn.bufexists(buffer) then
-        local currentLine = vim.api.nvim_buf_get_lines(buffer, -2, -1, false)[1]
-        if currentLine then
-          vim.api.nvim_buf_set_lines(buffer, -2, -1, false, { currentLine .. line })
-          if i == length and i > 1 then
-            vim.api.nvim_buf_set_lines(buffer, -1, -1, false, { "" })
-          end
-        end
-      end
-    end
-  end
-end
-
 function M.process_string(inputString)
   -- Check if the inputString contains a comma
   if inputString:find(",") then
@@ -386,12 +337,30 @@ function M.format_table(tbl, indent)
   return result
 end
 
+local log_filename =
+  Path:new(vim.fn.stdpath("state")):joinpath("ogpt", "ogpt-" .. os.date("%Y-%m-%d") .. ".log"):absolute() -- convert Path object to string
+
+function M.write_to_log(msg)
+  local file = io.open(log_filename, "ab")
+  if file then
+    file:write(os.date("[%Y-%m-%d %H:%M:%S] "))
+    file:write(msg .. "\n")
+    file:close()
+  else
+    vim.notify("Failed to open log file for writing", vim.log.levels.ERROR)
+  end
+end
+
 function M.log(msg, level)
+  level = level or vim.log.levels.INFO
+
   msg = vim.inspect(msg)
-  level = level or vim.log.levels.DEBUG
-  Config.logs[#Config.logs + 1] = { msg = msg, level = level }
-  if Config.options.debug then
-    vim.notify(msg, level, { title = "OGPT Debug" })
+  if level >= Config.options.debug.log_level then
+    M.write_to_log(msg)
+  end
+
+  if level >= Config.options.debug.notify_level then
+    vim.notify(msg, level, { title = "OGPT Debug" })
   end
 end
 
@@ -403,4 +372,16 @@ function M.shallow_copy(t)
   return t2
 end
 
+function M.gather_text_from_parts(parts)
+  if type(parts) == "string" then
+    return parts
+  else
+    local _text = {}
+    for _, part in ipairs(parts) do
+      table.insert(_text, part.text)
+    end
+    return table.concat(_text, " ")
+  end
+end
+
 return M
diff --git a/plugin/ogpt.lua b/plugin/ogpt.lua
index b84fe33..0c479f6 100644
--- a/plugin/ogpt.lua
+++ b/plugin/ogpt.lua
@@ -31,6 +31,15 @@ end, {
   end,
 })
 
+-- Forwards a Lua table literal (e.g. :OGPTRunWithOpts { ... }) straight to require("ogpt").run_action()
+vim.api.nvim_create_user_command("OGPTRunWithOpts", function(params)
+  local args = loadstring("return " .. params.args)()
+  require("ogpt").run_action(args)
+end, {
+  nargs = "?",
+  force = true,
+  complete = "lua",
+})
+
 vim.api.nvim_create_user_command("OGPTCompleteCode", function(opts)
   opts = loadstring("return " .. opts.args)() or {}
   require("ogpt").complete_code(opts)