From 5de77eed7739b101b6c2639399852e0a582e24e5 Mon Sep 17 00:00:00 2001 From: Huy Le Date: Wed, 13 Nov 2024 13:44:00 -0700 Subject: [PATCH] Fixing openai --- lua/ogpt/provider/openai.lua | 37 ++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/lua/ogpt/provider/openai.lua b/lua/ogpt/provider/openai.lua index 5bbb617..9cf9f62 100644 --- a/lua/ogpt/provider/openai.lua +++ b/lua/ogpt/provider/openai.lua @@ -10,13 +10,8 @@ function Openai:init(opts) self.api_parameters = { "model", "messages", - "prompt", "stream", "temperature", - "presence_penalty", - "frequency_penalty", - "top_p", - "max_tokens", } self.api_chat_request_options = {} end @@ -51,16 +46,46 @@ function Openai:parse_api_model_response(res, cb) end function Openai:conform_request(params) + local _to_remove_system_idx = {} + for idx, message in ipairs(params.messages) do + if message.role == "system" then + table.insert(_to_remove_system_idx, idx) + end + end + + -- Remove elements from the list based on indices + for i = #_to_remove_system_idx, 1, -1 do + table.remove(params.messages, _to_remove_system_idx[i]) + end + + -- conform to support text only model + local messages = params.messages + local conformed_messages = {} + for _, message in ipairs(messages) do + table.insert(conformed_messages, { + role = message.role, + content = utils.gather_text_from_parts(message.content), + }) + end + + -- Insert the updated params.system string at the beginning of conformed_messages if params.system then - params["prompt"] = params.system + table.insert(conformed_messages, 1, { + role = "system", + content = params.system, + }) end + params.messages = conformed_messages + + -- general clean up, remove things that shouldnt be here for key, value in pairs(params) do if not vim.tbl_contains(self.api_parameters, key) then utils.log("Did not process " .. key .. " for " .. self.name) params[key] = nil end end + return params end