From ddc43877ac8dc65ea8025ca0bb131c049b816d1e Mon Sep 17 00:00:00 2001 From: Huy Le Date: Tue, 21 Nov 2023 06:52:23 -0700 Subject: [PATCH 1/2] fixing --- lua/ogpt/code_edits.lua | 40 +++++++++++++------------- lua/ogpt/flows/chat/session.lua | 9 +++--- lua/ogpt/{models.lua => templates.lua} | 0 3 files changed, 24 insertions(+), 25 deletions(-) rename lua/ogpt/{models.lua => templates.lua} (100%) diff --git a/lua/ogpt/code_edits.lua b/lua/ogpt/code_edits.lua index 6dc4546..e21e2e8 100644 --- a/lua/ogpt/code_edits.lua +++ b/lua/ogpt/code_edits.lua @@ -139,7 +139,7 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) end local api_params = Config.options.api_edit_params local use_functions_for_edits = Config.options.use_openai_functions_for_edits - local settings_panel = Parameters.get_parameters_panel("edits", api_params) + local parameters_panel = Parameters.get_parameters_panel("edits", api_params) input_window = Popup(Config.options.popup_window) output_window = Popup(Config.options.popup_window) instructions_input = ChatInput(Config.options.popup_input, { @@ -227,12 +227,12 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) end, { noremap = true }) end - -- toggle settings - local settings_open = false - for _, popup in ipairs({ settings_panel, instructions_input }) do + -- toggle parameters + local parameters_open = false + for _, popup in ipairs({ parameters_panel, instructions_input }) do for _, mode in ipairs({ "n", "i" }) do - popup:map(mode, Config.options.edit_with_instructions.keymaps.toggle_settings, function() - if settings_open then + popup:map(mode, Config.options.edit_with_instructions.keymaps.toggle_parameters, function() + if parameters_open then layout:update(Layout.Box({ Layout.Box({ Layout.Box(input_window, { grow = 1 }), @@ -240,7 +240,7 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) }, { dir = "col", size = "50%" }), Layout.Box(output_window, { size = "50%" }), }, { dir = "row" })) - settings_panel:hide() + parameters_panel:hide() vim.api.nvim_set_current_win(instructions_input.winid) else layout:update(Layout.Box({ @@ -249,16 +249,16 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) Layout.Box(instructions_input, { size = 3 }), }, { dir = "col", grow = 1 }), Layout.Box(output_window, { grow = 1 }), - Layout.Box(settings_panel, { size = 40 }), + Layout.Box(parameters_panel, { size = 40 }), }, { dir = "row" })) - settings_panel:show() - settings_panel:mount() + parameters_panel:show() + parameters_panel:mount() - vim.api.nvim_set_current_win(settings_panel.winid) - vim.api.nvim_buf_set_option(settings_panel.bufnr, "modifiable", false) - vim.api.nvim_win_set_option(settings_panel.winid, "cursorline", true) + vim.api.nvim_set_current_win(parameters_panel.winid) + vim.api.nvim_buf_set_option(parameters_panel.bufnr, "modifiable", false) + vim.api.nvim_win_set_option(parameters_panel.winid, "cursorline", true) end - settings_open = not settings_open + parameters_open = not parameters_open -- set input and output settings -- TODO for _, window in ipairs({ input_window, output_window }) do @@ -271,7 +271,7 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) -- cycle windows local active_panel = instructions_input - for _, popup in ipairs({ input_window, output_window, settings_panel, instructions_input }) do + for _, popup in ipairs({ input_window, output_window, parameters_panel, instructions_input }) do for _, mode in ipairs({ "n", "i" }) do if mode == "i" and (popup == input_window or popup == output_window) then goto continue @@ -286,14 +286,14 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) active_panel = output_window vim.api.nvim_command("stopinsert") elseif active_panel == output_window and mode ~= "i" then - if settings_open then - vim.api.nvim_set_current_win(settings_panel.winid) - active_panel = settings_panel + if parameters_open then + vim.api.nvim_set_current_win(parameters_panel.winid) + active_panel = parameters_panel else vim.api.nvim_set_current_win(instructions_input.winid) active_panel = instructions_input end - elseif active_panel == settings_panel then + elseif active_panel == parameters_panel then vim.api.nvim_set_current_win(instructions_input.winid) active_panel = instructions_input end @@ -304,7 +304,7 @@ M.edit_with_instructions = function(output_lines, bufnr, selection, ...) -- toggle diff mode local diff_mode = Config.options.edit_with_instructions.diff - for _, popup in ipairs({ settings_panel, instructions_input }) do + for _, popup in ipairs({ parameters_panel, instructions_input }) do for _, mode in ipairs({ "n", "i" }) do popup:map(mode, Config.options.edit_with_instructions.keymaps.toggle_diff, function() diff_mode = not diff_mode diff --git a/lua/ogpt/flows/chat/session.lua b/lua/ogpt/flows/chat/session.lua index 9dde5c3..d6c95ab 100644 --- a/lua/ogpt/flows/chat/session.lua +++ b/lua/ogpt/flows/chat/session.lua @@ -59,14 +59,14 @@ function Session:to_export() return { name = self.name, updated_at = self.updated_at, - settings = self.parameters, + parameters = self.parameters, conversation = self.conversation, } end function Session:previous_context() - if #self.conversation > 1 then - return self.conversation[#self.conversation].context + if #self.conversation > 2 then + return self.conversation[#self.conversation - 1].context end return {} end @@ -79,7 +79,6 @@ function Session:add_item(item) if ctx and ctx.params and ctx.params.options then self.parameters = ctx.params.options self.parameters.model = ctx.params.model - self.parameters.model = ctx.params.model item.context = ctx.context end if self.updated_at == self.name and item.type == 1 then @@ -140,7 +139,7 @@ function Session:load() local data = vim.json.decode(jsonString) self.name = data.name self.updated_at = data.updated_at or get_current_date() - self.parameters = data.settings + self.parameters = data.parameters self.conversation = data.conversation end diff --git a/lua/ogpt/models.lua b/lua/ogpt/templates.lua similarity index 100% rename from lua/ogpt/models.lua rename to lua/ogpt/templates.lua From f5b1374ec2c9e15138113acad40d6c959b52e151 Mon Sep 17 00:00:00 2001 From: Huy Le Date: Tue, 21 Nov 2023 14:15:01 -0700 Subject: [PATCH 2/2] pull models from API, and can now add/remove settings --- lua/ogpt/api.lua | 3 +- lua/ogpt/config.lua | 10 +- lua/ogpt/flows/chat/base.lua | 2 +- lua/ogpt/flows/chat/session.lua | 7 +- lua/ogpt/models.lua | 113 +++++++++++++++++ lua/ogpt/parameters.lua | 207 ++++++++++++++++++++++++++------ lua/ogpt/utils.lua | 73 +++++------ 7 files changed, 334 insertions(+), 81 deletions(-) create mode 100644 lua/ogpt/models.lua diff --git a/lua/ogpt/api.lua b/lua/ogpt/api.lua index 942b1ab..fcdc58d 100644 --- a/lua/ogpt/api.lua +++ b/lua/ogpt/api.lua @@ -57,7 +57,7 @@ function Api.chat_completions(custom_params, cb, should_stop) local ok, json = pcall(vim.json.decode, chunk) if ok and json ~= nil then if json.error ~= nil then - cb(json.error.message, "ERROR", ctx) + cb(json.error, "ERROR", ctx) return end process_line(ok, json) @@ -274,6 +274,7 @@ end function Api.setup() loadApiHost("OLLAMA_API_HOST", "OLLAMA_API_HOST", "api_host_cmd", function(value) Api.OLLAMA_API_HOST = value + Api.MODELS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/tags") Api.COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate") Api.CHAT_COMPLETIONS_URL = ensureUrlProtocol(Api.OLLAMA_API_HOST .. "/api/generate") end, "http://localhost:11434/api/generate") diff --git a/lua/ogpt/config.lua b/lua/ogpt/config.lua index 8f907ac..afbf60a 100644 --- a/lua/ogpt/config.lua +++ b/lua/ogpt/config.lua @@ -26,8 +26,8 @@ function M.defaults() loading_text = "Loading, please wait ...", question_sign = "", -- 🙂 answer_sign = "ﮧ", -- 🤖 - border_left_sign = "", - border_right_sign = "", + border_left_sign = "|", + border_right_sign = "|", max_line_length = 120, sessions_window = { active_sign = "  ", @@ -54,7 +54,7 @@ function M.defaults() cycle_modes = "", next_message = "", prev_message = "", - select_session = "", + select_session = "", rename_session = "r", delete_session = "d", draft_message = "", @@ -141,8 +141,6 @@ function M.defaults() }, api_params = { model = "mistral:7b", - frequency_penalty = 0, - presence_penalty = 0, -- max_tokens = 300, temperature = 0.8, top_p = 1, @@ -152,7 +150,7 @@ function M.defaults() model = "codellama:13b", frequency_penalty = 0, presence_penalty = 0, - temperature = 0, + temperature = 0.5, top_p = 1, -- n = 1, }, diff --git a/lua/ogpt/flows/chat/base.lua b/lua/ogpt/flows/chat/base.lua index 20363ff..f1bd907 100644 --- a/lua/ogpt/flows/chat/base.lua +++ b/lua/ogpt/flows/chat/base.lua @@ -828,7 +828,7 @@ function Chat:set_keymaps() self:map(Config.options.chat.keymaps.new_session, function() self:new_session() Sessions:refresh() - end, { self.parameters_panel, self.chat_input }) + end, { self.parameters_panel, self.chat_input, self.chat_window }) -- cycle panes self:map(Config.options.chat.keymaps.cycle_windows, function() diff --git a/lua/ogpt/flows/chat/session.lua b/lua/ogpt/flows/chat/session.lua index d6c95ab..84d8b9f 100644 --- a/lua/ogpt/flows/chat/session.lua +++ b/lua/ogpt/flows/chat/session.lua @@ -65,8 +65,11 @@ function Session:to_export() end function Session:previous_context() - if #self.conversation > 2 then - return self.conversation[#self.conversation - 1].context + for ith = #self.conversation, 1, -1 do + local context = self.conversation[ith].context + if context then + return context + end end return {} end diff --git a/lua/ogpt/models.lua b/lua/ogpt/models.lua new file mode 100644 index 0000000..4800362 --- /dev/null +++ b/lua/ogpt/models.lua @@ -0,0 +1,113 @@ +local pickers = require("telescope.pickers") +local conf = require("telescope.config").values +local actions = require("telescope.actions") +local action_state = require("telescope.actions.state") +local job = require("plenary.job") +local Api = require("ogpt.api") + +local Utils = require("ogpt.utils") +local Config = require("ogpt.config") + +local function preview_command(entry, bufnr, width) + vim.api.nvim_buf_call(bufnr, function() + local preview = Utils.wrapTextToTable(entry.value, width - 5) + table.insert(preview, 1, "---") + table.insert(preview, 1, entry.display) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, true, preview) + end) +end + +local function entry_maker(model) + return { + value = model.name, + display = model.name, + ordinal = model.digest, + preview_command = preview_command, + } +end + +local finder = function(opts) + local job_started = false + local job_completed = false + local results = {} + local num_results = 0 + + return setmetatable({ + close = function() + -- TODO: check if we need to make some cleanup + end, + }, { + __call = function(_, prompt, process_result, process_complete) + if job_completed then + local current_count = num_results + for index = 1, current_count do + if process_result(results[index]) then + break + end + end + process_complete() + end + + if not job_started then + job_started = true + job + :new({ + command = "curl", + args = { + opts.url, + }, + on_exit = vim.schedule_wrap(function(j, exit_code) + if exit_code ~= 0 then + vim.notify("An Error Occurred, cannot fetch list of prompts ...", vim.log.levels.ERROR) + process_complete() + end + + local response = table.concat(j:result(), "\n") + local json = vim.fn.json_decode(response) + + for _, model in ipairs(json.models) do + local v = entry_maker(model) + num_results = num_results + 1 + results[num_results] = v + process_result(v) + end + + process_complete() + job_completed = true + end), + }) + :start() + end + end, + }) +end +-- + +local M = {} +function M.select_model(opts) + opts = opts or {} + pickers + .new(opts, { + sorting_strategy = "ascending", + layout_config = { + height = 0.5, + }, + results_title = "Select Ollama Model", + prompt_prefix = Config.options.popup_input.prompt, + selection_caret = Config.options.chat.answer_sign .. " ", + prompt_title = "Models", + finder = finder({ url = Api.MODELS_URL }), + sorter = conf.generic_sorter(opts), + attach_mappings = function(prompt_bufnr) + actions.select_default:replace(function() + actions.close(prompt_bufnr) + local selection = action_state.get_selected_entry() + opts.cb(selection.display, selection.value) + end) + return true + end, + }) + :find() +end + +return M diff --git a/lua/ogpt/parameters.lua b/lua/ogpt/parameters.lua index a25d71e..37d1ad5 100644 --- a/lua/ogpt/parameters.lua +++ b/lua/ogpt/parameters.lua @@ -1,3 +1,9 @@ +local pickers = require("telescope.pickers") +local Utils = require("ogpt.utils") +local conf = require("telescope.config").values +local actions = require("telescope.actions") +local action_state = require("telescope.actions.state") + local M = {} M.vts = {} @@ -12,14 +18,55 @@ local float_validator = function(min, max) end end +local bool_validator = function(min, max) + return function(value) + local stringtoboolean = { ["true"] = true, ["false"] = false } + return stringtoboolean(value) + end +end + local integer_validator = function(min, max) return function(value) return tonumber(value) end end -local model_validator = function(value) - return value +local model_validator = function() + return function(value) + return value + end +end + +local function preview_command(entry, bufnr, width) + vim.api.nvim_buf_call(bufnr, function() + local preview = Utils.wrapTextToTable(entry.value, width - 5) + table.insert(preview, 1, "---") + table.insert(preview, 1, entry.display) + vim.api.nvim_buf_set_lines(bufnr, 0, -1, true, preview) + end) +end + +local finder = function(opts) + return setmetatable({ + close = function() + -- TODO: check if we need to make some cleanup + end, + }, { + __call = function(_, prompt, process_result, process_complete) + local _params = { + values = opts.parameters, + } + for _, param in ipairs(_params.values) do + process_result({ + value = param, + display = param, + ordinal = param, + preview_command = preview_command, + }) + end + process_complete() + end, + }) end local params_order = { @@ -41,8 +88,6 @@ local params_order = { "num_keep", "num_predict", "num_thread", - "numa", - "penalize_newline", "presence_penalty", "repeat_last_n", "repeat_penalty", @@ -60,12 +105,34 @@ local params_order = { "vocab_only", } local params_validators = { - model = model_validator, - frequency_penalty = float_validator(-2, 2), - presence_penalty = float_validator(-2, 2), - max_tokens = integer_validator(0, 4096), - temperature = float_validator(0, 1), - top_p = float_validator(0, 1), + model = model_validator(), + embedding_only = model_validator(), + f16_kv = model_validator(), + frequency_penalty = float_validator(), + mirostat = integer_validator(), + mirostat_eta = float_validator(), + mirostat_tau = float_validator(), + num_batch = integer_validator(), + num_ctx = integer_validator(), + num_gpu = integer_validator(), + num_gqa = integer_validator(), + num_keep = integer_validator(), + num_predict = integer_validator(), + num_thread = integer_validator(), + presence_penalty = float_validator(), + repeat_last_n = integer_validator(), + repeat_penalty = float_validator(), + seed = integer_validator(), + stop = model_validator(), + temperature = float_validator(), + tfs_z = float_validator(), + top_k = float_validator(), + top_p = float_validator(), + logits_all = bool_validator(), + vocab_only = bool_validator(), + use_mmap = bool_validator(), + use_mlock = bool_validator(), + low_vram = bool_validator(), } local function write_virtual_text(bufnr, ns, line, chunks, mode) @@ -77,6 +144,34 @@ local function write_virtual_text(bufnr, ns, line, chunks, mode) end end +function M.select_parameter(opts) + opts = opts or {} + pickers + .new(opts, { + sorting_strategy = "ascending", + layout_config = { + height = 0.5, + }, + results_title = "OGPT Acts As ...", + prompt_prefix = Config.options.popup_input.prompt, + selection_caret = Config.options.chat.answer_sign .. " ", + prompt_title = "Parameter", + finder = finder({ + parameters = params_order, + }), + sorter = conf.generic_sorter(opts), + attach_mappings = function(prompt_bufnr) + actions.select_default:replace(function() + actions.close(prompt_bufnr) + local selection = action_state.get_selected_entry() + opts.cb(selection.display, vim.fn.input("value: ")) + end) + return true + end, + }) + :find() +end + M.read_config = function(session) if not session then local home = os.getenv("HOME") or os.getenv("USERPROFILE") @@ -110,13 +205,7 @@ M.write_config = function(config, session) end end -M.get_parameters_panel = function(type, default_params, session) - M.type = type - local custom_params = M.read_config(session or {}) - M.params = vim.tbl_deep_extend("force", {}, default_params, custom_params or {}) - - M.panel = Popup(Config.options.parameters_window) - +M.refresh_panel = function() -- write details as virtual text local details = {} for _, key in pairs(params_order) do @@ -129,19 +218,36 @@ M.get_parameters_panel = function(type, default_params, session) end end + vim.api.nvim_buf_clear_namespace(M.panel.bufnr, namespace_id, 0, -1) + local line = 1 local empty_lines = {} for _ = 1, #details do table.insert(empty_lines, "") end + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", true) vim.api.nvim_buf_set_lines(M.panel.bufnr, line - 1, line - 1 + #empty_lines, false, empty_lines) + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", false) for _, d in ipairs(details) do M.vts[line - 1] = write_virtual_text(M.panel.bufnr, namespace_id, line - 1, d) line = line + 1 end +end - M.panel:map("n", "", function() +M.get_parameters_panel = function(type, default_params, session) + M.type = type + local custom_params = M.read_config(session or {}) + + M.params = vim.tbl_deep_extend("force", {}, default_params, custom_params or {}) + if session then + M.params = session.parameters + end + + M.panel = Popup(Config.options.parameters_window) + M.refresh_panel() + + M.panel:map("n", "d", function() local row, _ = unpack(vim.api.nvim_win_get_cursor(M.panel.winid)) local existing_order = {} @@ -152,30 +258,61 @@ M.get_parameters_panel = function(type, default_params, session) end local key = existing_order[row] - local value = M.params[key] + M.update_property(key, row, nil, session) + M.refresh_panel() + end) - M.open_edit_property_input(key, value, row, function(new_value) - M.params[key] = params_validators[key](new_value) - local vt = { + M.panel:map("n", "a", function() + local row, _ = unpack(vim.api.nvim_win_get_cursor(M.panel.winid)) + M.select_parameter({ + cb = function(key, value) + M.update_property(key, row + 1, value, session) + end, + }) + end) - { Config.options.parameters_window.setting_sign .. key .. ": ", "ErrorMsg" }, - { M.params[key] .. "", "Identifier" }, - } - vim.api.nvim_buf_del_extmark(M.panel.bufnr, namespace_id, M.vts[row - 1]) - M.vts[row - 1] = vim.api.nvim_buf_set_extmark( - M.panel.bufnr, - namespace_id, - row - 1, - 0, - { virt_text = vt, virt_text_pos = "overlay" } - ) - M.write_config(M.params, session) - end) + M.panel:map("n", "", function() + local row, _ = unpack(vim.api.nvim_win_get_cursor(M.panel.winid)) + + local existing_order = {} + for _, key in ipairs(params_order) do + if M.params[key] ~= nil then + table.insert(existing_order, key) + end + end + + local key = existing_order[row] + if key == "model" then + local models = require("ogpt.models") + models.select_model({ + cb = function(display, value) + M.update_property(key, row, value, session) + end, + }) + else + local value = M.params[key] + M.open_edit_property_input(key, value, row, function(new_value) + M.update_property(key, row, new_value, session) + end) + end end, {}) return M.panel end +M.update_property = function(key, row, new_value, session) + if not key or not new_value then + M.params[key] = nil + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", true) + vim.api.nvim_del_current_line() + vim.api.nvim_buf_set_option(M.panel.bufnr, "modifiable", false) + else + M.params[key] = params_validators[key](new_value) + end + M.write_config(M.params, session) + M.refresh_panel() +end + M.get_panel = function(session) return M.get_parameters_panel(" ", session.parameters or {}, session) end diff --git a/lua/ogpt/utils.lua b/lua/ogpt/utils.lua index f0559f4..261dba8 100644 --- a/lua/ogpt/utils.lua +++ b/lua/ogpt/utils.lua @@ -2,6 +2,42 @@ local M = {} local ESC_FEEDKEY = vim.api.nvim_replace_termcodes("", true, false, true) +M.ollama_options = { + "num_keep", + "seed", + "num_predict", + "top_k", + "top_p", + "tfs_z", + "typical_p", + "repeat_last_n", + "temperature", + "repeat_penalty", + "presence_penalty", + "frequency_penalty", + "mirostat", + "mirostat_tau", + "mirostat_eta", + "penalize_newline", + "stop", + "numa", + "num_ctx", + "num_batch", + "num_gqa", + "num_gpu", + "main_gpu", + "low_vram", + "f16_kv", + "logits_all", + "vocab_only", + "use_mmap", + "use_mlock", + "embedding_only", + "rope_frequency_base", + "rope_frequency_scale", + "num_thread", +} + function M.split(text) local t = {} for str in string.gmatch(text, "%S+") do @@ -160,46 +196,11 @@ function M._conform_to_ollama_api(params) } -- https://github.com/jmorganca/ollama/blob/main/docs/api.md#show-model-information - local ollama_options = { - "num_keep", - "seed", - "num_predict", - "top_k", - "top_p", - "tfs_z", - "typical_p", - "repeat_last_n", - "temperature", - "repeat_penalty", - "presence_penalty", - "frequency_penalty", - "mirostat", - "mirostat_tau", - "mirostat_eta", - "penalize_newline", - "stop", - "numa", - "num_ctx", - "num_batch", - "num_gqa", - "num_gpu", - "main_gpu", - "low_vram", - "f16_kv", - "logits_all", - "vocab_only", - "use_mmap", - "use_mlock", - "embedding_only", - "rope_frequency_base", - "rope_frequency_scale", - "num_thread", - } local param_options = {} for key, value in pairs(params) do - if not vim.tbl_contains(ollama_parameters, key) and vim.tbl_contains(ollama_options, key) then + if not vim.tbl_contains(ollama_parameters, key) and vim.tbl_contains(M.ollama_options, key) then param_options[key] = value params[key] = nil end