From a3e930eb4a1f54161e2244f88170d320b7148d63 Mon Sep 17 00:00:00 2001 From: nabrown Date: Sat, 14 Sep 2024 14:15:43 -0400 Subject: [PATCH 1/7] Pulling the chat history out of the character component and into LLMChatHistory --- Runtime/LLMCharacter.cs | 213 ++++++------------ Runtime/LLMChatHistory.cs | 139 ++++++++++++ Runtime/LLMChatHistory.cs.meta | 11 + Runtime/LLMChatTemplates.cs | 6 +- Runtime/LLMConstants.cs | 6 + Runtime/LLMConstants.cs.meta | 11 + Samples~/AndroidDemo/Scene.unity | 27 ++- Samples~/ChatBot/Scene.unity | 27 ++- .../KnowledgeBaseGame/KnowledgeBaseGame.cs | 2 +- Samples~/KnowledgeBaseGame/Scene.unity | 37 +-- Samples~/MultipleCharacters/Scene.unity | 51 ++++- Samples~/SimpleInteraction/Scene.unity | 27 ++- .../SimpleInteraction/SimpleInteraction.cs | 4 +- Tests/Runtime/TestLLM.cs | 60 ++--- Tests/Runtime/TestLLMChatHistory.cs | 64 ++++++ Tests/Runtime/TestLLMChatHistory.cs.meta | 11 + 16 files changed, 465 insertions(+), 231 deletions(-) create mode 100644 Runtime/LLMChatHistory.cs create mode 100644 Runtime/LLMChatHistory.cs.meta create mode 100644 Runtime/LLMConstants.cs create mode 100644 Runtime/LLMConstants.cs.meta create mode 100644 Tests/Runtime/TestLLMChatHistory.cs create mode 100644 Tests/Runtime/TestLLMChatHistory.cs.meta diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 8b7ee452..21ca102d 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -32,10 +32,8 @@ public class LLMCharacter : MonoBehaviour [Remote] public int numRetries = 10; /// allows to use a server with API key [Remote] public string APIKey; - /// file to save the chat history. - /// The file is saved only for Chat calls with addToHistory set to true. - /// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). - [LLM] public string save = ""; + /// file to save the cache. + [LLM] public string cacheFilename = ""; /// toggle to save the LLM cache. This speeds up the prompt calculation but also requires ~100MB of space per character. [LLM] public bool saveCache = false; /// select to log the constructed prompt the Unity Editor. @@ -47,8 +45,6 @@ public class LLMCharacter : MonoBehaviour [ModelAdvanced] public string grammar = null; /// option to cache the prompt as it is being created by the chat to avoid reprocessing the entire prompt every time (default: true) [ModelAdvanced] public bool cachePrompt = true; - /// specify which slot of the server to use for computation (affects caching) - [ModelAdvanced] public int slot = -1; /// seed for reproducibility. For random results every time set to -1. [ModelAdvanced] public int seed = 0; /// number of tokens to predict (-1 = infinity, -2 = until context filled). @@ -113,21 +109,21 @@ public class LLMCharacter : MonoBehaviour public Dictionary logitBias = null; /// the name of the player - [Chat] public string playerName = "user"; + [Chat] public string playerRole = "user"; /// the name of the AI - [Chat] public string AIName = "assistant"; + [Chat] public string aiRole = "assistant"; /// a description of the AI role. This defines the LLMCharacter system prompt - [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."; + [TextArea(5, 10), Chat] public string systemPrompt = "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions."; /// option to set the number of tokens to retain from the prompt (nKeep) based on the LLMCharacter system prompt public bool setNKeepToPrompt = true; /// \cond HIDE - public List chat; - private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1); + private LLMChatHistory chatHistory; private string chatTemplate; private ChatTemplate template = null; public string grammarString; - private List<(string, string)> requestHeaders; + protected int id_slot = -1; + private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; private List WIPRequests = new List(); /// \endcond @@ -140,12 +136,10 @@ public class LLMCharacter : MonoBehaviour /// - the chat template is constructed /// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true) /// - public void Awake() + public async void Awake() { // Start the LLM server in a cross-platform way if (!enabled) return; - - requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; if (!remote) { AssignLLM(); @@ -154,22 +148,17 @@ public void Awake() LLMUnitySetup.LogError($"No LLM assigned or detected for LLMCharacter {name}!"); return; } - int slotFromServer = llm.Register(this); - if (slot == -1) slot = slotFromServer; - } - else - { - if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey)); + id_slot = llm.Register(this); } InitGrammar(); InitHistory(); + await LoadCache(); } void OnValidate() { AssignLLM(); - if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set"); } void Reset() @@ -217,58 +206,18 @@ void SortBySceneAndHierarchy(LLM[] array) protected void InitHistory() { - InitPrompt(); - _ = LoadHistory(); - } + // Check if we have a chat history component available + chatHistory = GetComponent(); - protected async Task LoadHistory() - { - if (save == "" || !File.Exists(GetJsonSavePath(save))) return; - await chatLock.WaitAsync(); // Acquire the lock - try - { - await Load(save); - } - finally - { - chatLock.Release(); // Release the lock + // If not, go ahead and create one. + if (chatHistory == null) { + chatHistory = gameObject.AddComponent(); } } - public virtual string GetSavePath(string filename) - { - return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/'); - } - - public virtual string GetJsonSavePath(string filename) - { - return GetSavePath(filename + ".json"); - } - - public virtual string GetCacheSavePath(string filename) - { - return GetSavePath(filename + ".cache"); - } - - private void InitPrompt(bool clearChat = true) + public virtual string GetCacheSavePath() { - if (chat != null) - { - if (clearChat) chat.Clear(); - } - else - { - chat = new List(); - } - ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt }; - if (chat.Count == 0) - { - chat.Add(promptMessage); - } - else - { - chat[0] = promptMessage; - } + return Path.Combine(Application.persistentDataPath, cacheFilename + ".cache").Replace('\\', '/'); } /// @@ -276,11 +225,15 @@ private void InitPrompt(bool clearChat = true) /// /// the system prompt /// whether to clear (true) or keep (false) the current chat history on top of the system prompt. 
- public void SetPrompt(string newPrompt, bool clearChat = true) + public async Task SetPrompt(string newPrompt, bool clearChat = true) { - prompt = newPrompt; + systemPrompt = newPrompt; nKeep = -1; - InitPrompt(clearChat); + + if (clearChat) { + // Clear any existing messages + await chatHistory?.Clear(); + } } private bool CheckTemplate() @@ -293,12 +246,16 @@ private bool CheckTemplate() return true; } + private ChatMessage GetSystemPromptMessage() { + return new ChatMessage() { role = LLMConstants.SYSTEM_ROLE, content = systemPrompt }; + } + private async Task InitNKeep() { if (setNKeepToPrompt && nKeep == -1) { if (!CheckTemplate()) return false; - string systemPrompt = template.ComputePrompt(new List(){chat[0]}, playerName, "", false); + string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerRole, aiRole, false); List tokens = await Tokenize(systemPrompt); if (tokens == null) return false; SetNKeep(tokens); @@ -360,7 +317,7 @@ public async void SetGrammar(string path) List GetStopwords() { if (!CheckTemplate()) return null; - List stopAll = new List(template.GetStop(playerName, AIName)); + List stopAll = new List(template.GetStop(playerRole, aiRole)); if (stop != null) stopAll.AddRange(stop); return stopAll; } @@ -371,7 +328,7 @@ ChatRequest GenerateRequest(string prompt) ChatRequest chatRequest = new ChatRequest(); if (debugPrompt) LLMUnitySetup.Log(prompt); chatRequest.prompt = prompt; - chatRequest.id_slot = slot; + chatRequest.id_slot = id_slot; chatRequest.temperature = temperature; chatRequest.top_k = topK; chatRequest.top_p = topP; @@ -400,20 +357,19 @@ ChatRequest GenerateRequest(string prompt) return chatRequest; } - public void AddMessage(string role, string content) + public async Task AddPlayerMessage(string content) { - // add the question / answer to the chat list, update prompt - chat.Add(new ChatMessage { role = role, content = content }); + await chatHistory.AddMessage(playerRole, content); } - public void AddPlayerMessage(string content) + public async Task AddAIMessage(string content) { - AddMessage(playerName, content); + await chatHistory.AddMessage(aiRole, content); } - public void AddAIMessage(string content) + public LLMChatHistory GetChatHistory() { - AddMessage(AIName, content); + return chatHistory; } protected string ChatContent(ChatResult result) @@ -490,44 +446,33 @@ protected string SlotContent(SlotResult result) /// the LLM response public async Task Chat(string query, Callback callback = null, EmptyCallback completionCallback = null, bool addToHistory = true) { - // handle a chat message by the user - // call the callback function while the answer is received - // call the completionCallback function when the answer is fully received await LoadTemplate(); if (!CheckTemplate()) return null; if (!await InitNKeep()) return null; + + var playerMessage = new ChatMessage() { role = playerRole, content = query }; - string json; - await chatLock.WaitAsync(); - try - { - AddPlayerMessage(query); - string prompt = template.ComputePrompt(chat, playerName, AIName); - json = JsonUtility.ToJson(GenerateRequest(prompt)); - chat.RemoveAt(chat.Count - 1); - } - finally - { - chatLock.Release(); - } + // Setup the full list of messages for the current request + List promptMessages = chatHistory ? 
chatHistory.GetChatMessages() : new List(); + promptMessages.Insert(0, GetSystemPromptMessage()); + promptMessages.Add(playerMessage); - string result = await CompletionRequest(json, callback); + // Prepare the request + string formattedPrompt = template.ComputePrompt(promptMessages, playerRole, aiRole); + string requestJson = JsonUtility.ToJson(GenerateRequest(formattedPrompt)); - if (addToHistory && result != null) + // Call the LLM + string result = await CompletionRequest(requestJson, callback); + + // Update our chat history if required + if (addToHistory && chatHistory && result != null) { - await chatLock.WaitAsync(); - try - { - AddPlayerMessage(query); - AddAIMessage(result); - } - finally - { - chatLock.Release(); - } - if (save != "") _ = Save(save); + await AddPlayerMessage(query); + await AddAIMessage(result); } + await SaveCache(); + completionCallback?.Invoke(); return result; } @@ -623,10 +568,10 @@ public async Task> Embeddings(string query, Callback> ca return await PostRequest>(json, "embeddings", EmbeddingsContent, callback); } - protected async Task Slot(string filepath, string action) + private async Task Slot(string filepath, string action) { SlotRequest slotRequest = new SlotRequest(); - slotRequest.id_slot = slot; + slotRequest.id_slot = id_slot; slotRequest.filepath = filepath; slotRequest.action = action; string json = JsonUtility.ToJson(slotRequest); @@ -634,46 +579,26 @@ protected async Task Slot(string filepath, string action) } /// - /// Saves the chat history and cache to the provided filename / relative path. + /// Saves the cache to the provided filename / relative path. /// - /// filename / relative path to save the chat history + /// filename / relative path to save the cache /// - public virtual async Task Save(string filename) + public virtual async Task SaveCache() { - string filepath = GetJsonSavePath(filename); - string dirname = Path.GetDirectoryName(filepath); - if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname); - string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) }); - File.WriteAllText(filepath, json); - - string cachepath = GetCacheSavePath(filename); if (remote || !saveCache) return null; - string result = await Slot(cachepath, "save"); + string result = await Slot(GetCacheSavePath(), "save"); return result; } /// - /// Load the chat history and cache from the provided filename / relative path. + /// Load the cache from the provided filename / relative path. 
/// - /// filename / relative path to load the chat history from + /// filename / relative path to load the cache from /// - public virtual async Task Load(string filename) + public virtual async Task LoadCache() { - string filepath = GetJsonSavePath(filename); - if (!File.Exists(filepath)) - { - LLMUnitySetup.LogError($"File {filepath} does not exist."); - return null; - } - string json = File.ReadAllText(filepath); - List chatHistory = JsonUtility.FromJson(json).chat; - InitPrompt(true); - chat.AddRange(chatHistory); - LLMUnitySetup.Log($"Loaded {filepath}"); - - string cachepath = GetCacheSavePath(filename); - if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null; - string result = await Slot(cachepath, "restore"); + if (remote || !saveCache || !File.Exists(GetCacheSavePath())) return null; + string result = await Slot(GetCacheSavePath(), "restore"); return result; } @@ -698,7 +623,7 @@ protected Ret ConvertContent(string response, ContentCallback= 0) llm.CancelRequest(slot); + if (id_slot >= 0) llm.CancelRequest(id_slot); } protected void CancelRequestsRemote() @@ -826,11 +751,9 @@ protected async Task PostRequestRemote(string json, string endpoi { result = default; error = request.error; - if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break; } } tryNr--; - if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr)); } if (error != null) LLMUnitySetup.LogError(error); diff --git a/Runtime/LLMChatHistory.cs b/Runtime/LLMChatHistory.cs new file mode 100644 index 00000000..7e560909 --- /dev/null +++ b/Runtime/LLMChatHistory.cs @@ -0,0 +1,139 @@ +/// @file +/// @brief File implementing the LLMChatHistory. +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using UnityEngine; + +namespace LLMUnity +{ + /// @ingroup llm + /// + /// Manages a single instance of a chat history. + /// + public class LLMChatHistory : MonoBehaviour + { + /// + /// The name of the file where this chat history will be saved. + /// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). + /// + public string ChatHistoryFilename = string.Empty; + + /// + /// The current chat history + /// + protected List _chatHistory; + + /// + /// Ensures we're not trying to update the chat while saving or loading + /// + protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1); + + /// + /// The Unity Awake function that initializes the state before the application starts. + /// + public async void Awake() + { + // If a filename has been provided for the chat history, attempt to load it + if (ChatHistoryFilename != string.Empty) { + await Load(); + } + else { + _chatHistory = new List(); + } + } + + /// + /// Appends a new message to the end of this chat. 
+ /// + public async Task AddMessage(string role, string content) + { + await WithChatLock(async () => { + await Task.Run(() => _chatHistory.Add(new ChatMessage { role = role, content = content })); + }); + + // Save our newly updated chat history to the file system + _ = Save(); + } + + public List GetChatMessages() { + return new List(_chatHistory); + } + + /// + /// Saves the chat history to the file system + /// + public async Task Save() + { + // If no filename has been provided, create one + if (ChatHistoryFilename == string.Empty) { + ChatHistoryFilename = $"chat_{Guid.NewGuid()}"; + } + + string filePath = GetChatHistoryFilePath(); + string directoryName = Path.GetDirectoryName(filePath); + + // Ensure the directory exists + if (!Directory.Exists(directoryName)) Directory.CreateDirectory(directoryName); + + // Save the chat history as json + await WithChatLock(async () => { + string json = JsonUtility.ToJson(new ChatListWrapper { chat = _chatHistory }); + await File.WriteAllTextAsync(filePath, json); + }); + } + + /// + /// Load the chat history from the file system + /// + public async Task Load() + { + string filePath = GetChatHistoryFilePath(); + + if (!File.Exists(filePath)) + { + LLMUnitySetup.LogError($"File {filePath} does not exist."); + return; + } + + // Load the chat from the json file + await WithChatLock(async () => { + string json = await File.ReadAllTextAsync(filePath); + _chatHistory = JsonUtility.FromJson(json).chat; + LLMUnitySetup.Log($"Loaded {filePath}"); + }); + } + + /// + /// Clears out the current chat history. + /// + public async Task Clear() { + await WithChatLock(async () => { + await Task.Run(() => _chatHistory.Clear()); + }); + + _ = Save(); + } + + public bool IsEmpty() { + return _chatHistory?.Count == 0; + } + + protected string GetChatHistoryFilePath() + { + return Path.Combine(Application.persistentDataPath, ChatHistoryFilename + ".json").Replace('\\', '/'); + } + + protected async Task WithChatLock(Func action) { + await chatLock.WaitAsync(); + try { + await action(); + } + finally { + chatLock.Release(); + } + } + } +} \ No newline at end of file diff --git a/Runtime/LLMChatHistory.cs.meta b/Runtime/LLMChatHistory.cs.meta new file mode 100644 index 00000000..3ca7ffbd --- /dev/null +++ b/Runtime/LLMChatHistory.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: b9aef079d11e8894bae3ae510742c32f +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/LLMChatTemplates.cs b/Runtime/LLMChatTemplates.cs index 0820078d..ec092c8c 100644 --- a/Runtime/LLMChatTemplates.cs +++ b/Runtime/LLMChatTemplates.cs @@ -174,7 +174,7 @@ public virtual string ComputePrompt(List messages, string playerNam { string chatPrompt = PromptPrefix(); int start = 0; - if (messages[0].role == "system") + if (messages[0].role == LLMConstants.SYSTEM_ROLE) { chatPrompt += RequestPrefix() + SystemPrefix() + messages[0].content + SystemSuffix(); start = 1; @@ -356,7 +356,7 @@ public class GemmaTemplate : ChatTemplate public override string ComputePrompt(List messages, string playerName, string AIName, bool endWithPrefix = true) { List messagesSystemPrompt = messages; - if (messages[0].role == "system") + if (messages[0].role == LLMConstants.SYSTEM_ROLE) { string firstUserMessage = messages[0].content; int start = 1; @@ -466,7 +466,7 @@ public class Phi3Template : ChatTemplate public override string ComputePrompt(List messages, string 
playerName, string AIName, bool endWithPrefix = true) { List messagesSystemPrompt = messages; - if (messages[0].role == "system") + if (messages[0].role == LLMConstants.SYSTEM_ROLE) { string firstUserMessage = messages[0].content; int start = 1; diff --git a/Runtime/LLMConstants.cs b/Runtime/LLMConstants.cs new file mode 100644 index 00000000..082a696f --- /dev/null +++ b/Runtime/LLMConstants.cs @@ -0,0 +1,6 @@ +namespace LLMUnity { + static class LLMConstants { + + public const string SYSTEM_ROLE = "system"; + } +} \ No newline at end of file diff --git a/Runtime/LLMConstants.cs.meta b/Runtime/LLMConstants.cs.meta new file mode 100644 index 00000000..a2312884 --- /dev/null +++ b/Runtime/LLMConstants.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 54899fb2b22da0d4480f36b31baa3536 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Samples~/AndroidDemo/Scene.unity b/Samples~/AndroidDemo/Scene.unity index d7a58ae7..62a53cda 100644 --- a/Samples~/AndroidDemo/Scene.unity +++ b/Samples~/AndroidDemo/Scene.unity @@ -38,7 +38,6 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: @@ -654,6 +653,7 @@ GameObject: m_Component: - component: {fileID: 498662972} - component: {fileID: 498662973} + - component: {fileID: 498662974} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -693,7 +693,8 @@ MonoBehaviour: llm: {fileID: 1047848254} host: localhost port: 13333 - save: + numRetries: -1 + cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 @@ -720,13 +721,25 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerName: user - AIName: assistant - prompt: A chat between a curious human and an artificial intelligence assistant. + playerRole: user + aiRole: assistant + systemPrompt: A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 - chat: [] grammarString: +--- !u!114 &498662974 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 498662970} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} + m_Name: + m_EditorClassIdentifier: + ChatHistoryFilename: --- !u!1 &659217390 GameObject: m_ObjectHideFlags: 0 @@ -1335,6 +1348,8 @@ MonoBehaviour: model: chatTemplate: chatml lora: + loraWeights: + flashAttention: 0 --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 diff --git a/Samples~/ChatBot/Scene.unity b/Samples~/ChatBot/Scene.unity index bea05e22..dbecb667 100644 --- a/Samples~/ChatBot/Scene.unity +++ b/Samples~/ChatBot/Scene.unity @@ -38,7 +38,6 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: @@ -671,6 +670,8 @@ MonoBehaviour: model: chatTemplate: chatml lora: + loraWeights: + flashAttention: 0 --- !u!1 &1051131186 GameObject: m_ObjectHideFlags: 0 @@ -1092,6 +1093,7 @@ GameObject: m_Component: - component: {fileID: 1844795170} - component: {fileID: 1844795171} + - component: {fileID: 1844795172} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -1131,7 +1133,8 @@ MonoBehaviour: llm: {fileID: 817827756} host: localhost port: 13333 - save: + numRetries: -1 + cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 @@ -1158,13 +1161,25 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerName: user - AIName: assistant - prompt: A chat between a curious human and an artificial intelligence assistant. + playerRole: user + aiRole: assistant + systemPrompt: A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 - chat: [] grammarString: +--- !u!114 &1844795172 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 1844795168} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} + m_Name: + m_EditorClassIdentifier: + ChatHistoryFilename: --- !u!1 &2011827136 GameObject: m_ObjectHideFlags: 0 diff --git a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs index e6f3bc29..67509a6e 100644 --- a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs +++ b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs @@ -145,7 +145,7 @@ protected override void DropdownChange(int selection) Debug.Log($"{currentBotName}: {currentBot.NumPhrases()} phrases available"); // set the LLMCharacter name - llmCharacter.AIName = currentBotName; + llmCharacter.aiRole = currentBotName; } void SetAIText(string text) diff --git a/Samples~/KnowledgeBaseGame/Scene.unity b/Samples~/KnowledgeBaseGame/Scene.unity index ff10abb9..5f43dbe0 100644 --- a/Samples~/KnowledgeBaseGame/Scene.unity +++ b/Samples~/KnowledgeBaseGame/Scene.unity @@ -38,7 +38,6 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: @@ -4260,6 +4259,7 @@ GameObject: m_Component: - component: {fileID: 1275496424} - component: {fileID: 1275496423} + - component: {fileID: 1275496425} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -4284,7 +4284,8 @@ MonoBehaviour: llm: {fileID: 2142407556} host: localhost port: 13333 - save: + numRetries: -1 + cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 @@ -4311,18 +4312,13 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerName: Detective - AIName: - prompt: 'You are a robot working at a house where a diamond was stolen and a detective - asks you questions about the robbery. - - Answer the question provided at the - section "Question" based on the possible answers at the section "Answers". - - Rephrase - the answer, do not copy it directly.' 
+ playerRole: Detective + aiRole: + systemPrompt: "You are a robot working at a house where a diamond was stolen and + a detective\r asks you questions about the robbery.\r\n\r\nAnswer the question + provided at the\r section \"Question\" based on the possible answers at the section + \"Answers\".\r\n\r\nRephrase\r the answer, do not copy it directly.'" setNKeepToPrompt: 1 - chat: [] grammarString: --- !u!4 &1275496424 Transform: @@ -4339,6 +4335,19 @@ Transform: m_Children: [] m_Father: {fileID: 0} m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} +--- !u!114 &1275496425 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 1275496422} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} + m_Name: + m_EditorClassIdentifier: + ChatHistoryFilename: --- !u!1 &1278063793 GameObject: m_ObjectHideFlags: 0 @@ -7548,6 +7557,8 @@ MonoBehaviour: model: chatTemplate: chatml lora: + loraWeights: + flashAttention: 0 --- !u!4 &2142407557 Transform: m_ObjectHideFlags: 0 diff --git a/Samples~/MultipleCharacters/Scene.unity b/Samples~/MultipleCharacters/Scene.unity index 75117583..709ce3e8 100644 --- a/Samples~/MultipleCharacters/Scene.unity +++ b/Samples~/MultipleCharacters/Scene.unity @@ -38,7 +38,6 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: @@ -556,6 +555,7 @@ GameObject: m_Component: - component: {fileID: 714802013} - component: {fileID: 714802014} + - component: {fileID: 714802015} m_Layer: 0 m_Name: LLMCharacter2 m_TagString: Untagged @@ -595,7 +595,8 @@ MonoBehaviour: llm: {fileID: 1047848254} host: localhost port: 13333 - save: + numRetries: -1 + cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 @@ -622,14 +623,26 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerName: Human - AIName: Adam - prompt: A chat between a curious human and an artificial intelligence assistant + playerRole: Human + aiRole: Adam + systemPrompt: A chat between a curious human and an artificial intelligence assistant named Adam. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 - chat: [] grammarString: +--- !u!114 &714802015 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 714802011} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} + m_Name: + m_EditorClassIdentifier: + ChatHistoryFilename: --- !u!1 &726528676 GameObject: m_ObjectHideFlags: 0 @@ -1527,6 +1540,8 @@ MonoBehaviour: model: chatTemplate: chatml lora: + loraWeights: + flashAttention: 0 --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 @@ -1922,6 +1937,7 @@ GameObject: m_Component: - component: {fileID: 1493015759} - component: {fileID: 1493015760} + - component: {fileID: 1493015761} m_Layer: 0 m_Name: LLMCharacter1 m_TagString: Untagged @@ -1961,7 +1977,8 @@ MonoBehaviour: llm: {fileID: 1047848254} host: localhost port: 13333 - save: + numRetries: -1 + cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 @@ -1988,14 +2005,26 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerName: Human - AIName: Eve - prompt: A chat between a curious human and an artificial intelligence assistant + playerRole: Human + aiRole: Eve + systemPrompt: A chat between a curious human and an artificial intelligence assistant named Eve. The assistant gives helpful, detailed, and polite answers to the human's questions. setNKeepToPrompt: 1 - chat: [] grammarString: +--- !u!114 &1493015761 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 1493015757} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} + m_Name: + m_EditorClassIdentifier: + ChatHistoryFilename: --- !u!1 &1609985808 GameObject: m_ObjectHideFlags: 0 diff --git a/Samples~/SimpleInteraction/Scene.unity b/Samples~/SimpleInteraction/Scene.unity index b4e5e246..4a2dbf79 100644 --- a/Samples~/SimpleInteraction/Scene.unity +++ b/Samples~/SimpleInteraction/Scene.unity @@ -38,7 +38,6 @@ RenderSettings: m_ReflectionIntensity: 1 m_CustomReflection: {fileID: 0} m_Sun: {fileID: 0} - m_IndirectSpecularColor: {r: 0.44657832, g: 0.49641222, b: 0.57481664, a: 1} m_UseRadianceAmbientProbe: 0 --- !u!157 &3 LightmapSettings: @@ -444,6 +443,7 @@ GameObject: m_Component: - component: {fileID: 498662972} - component: {fileID: 498662973} + - component: {fileID: 498662974} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -483,7 +483,8 @@ MonoBehaviour: llm: {fileID: 1047848254} host: localhost port: 13333 - save: + numRetries: -1 + cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 @@ -510,13 +511,25 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerName: user - AIName: assistant - prompt: A chat between a curious human and an artificial intelligence assistant. + playerRole: user + aiRole: assistant + systemPrompt: A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 - chat: [] grammarString: +--- !u!114 &498662974 +MonoBehaviour: + m_ObjectHideFlags: 0 + m_CorrespondingSourceObject: {fileID: 0} + m_PrefabInstance: {fileID: 0} + m_PrefabAsset: {fileID: 0} + m_GameObject: {fileID: 498662970} + m_Enabled: 1 + m_EditorHideFlags: 0 + m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} + m_Name: + m_EditorClassIdentifier: + ChatHistoryFilename: --- !u!1 &724531319 GameObject: m_ObjectHideFlags: 0 @@ -1050,6 +1063,8 @@ MonoBehaviour: model: chatTemplate: chatml lora: + loraWeights: + flashAttention: 0 --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 diff --git a/Samples~/SimpleInteraction/SimpleInteraction.cs b/Samples~/SimpleInteraction/SimpleInteraction.cs index 9c088c3b..7c9a2cbc 100644 --- a/Samples~/SimpleInteraction/SimpleInteraction.cs +++ b/Samples~/SimpleInteraction/SimpleInteraction.cs @@ -16,11 +16,11 @@ void Start() playerText.Select(); } - void onInputFieldSubmit(string message) + async void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - _ = llmCharacter.Chat(message, SetAIText, AIReplyComplete); + await llmCharacter.Chat(message, SetAIText, AIReplyComplete); } public void SetAIText(string text) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 747e8af6..288c3dd0 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -77,6 +77,7 @@ public class TestLLM protected GameObject gameObject; protected LLM llm; protected LLMCharacter llmCharacter; + protected LLMChatHistory llmChatHistory; protected Exception error = null; protected string prompt; protected string query; @@ -192,9 +193,9 @@ public virtual LLMCharacter CreateLLMCharacter() { LLMCharacter llmCharacter = gameObject.AddComponent(); llmCharacter.llm = llm; - llmCharacter.playerName = "Instruction"; - llmCharacter.AIName = "Response"; - llmCharacter.prompt = prompt; + llmCharacter.playerRole = "Instruction"; + llmCharacter.aiRole = "Response"; + llmCharacter.systemPrompt = prompt; llmCharacter.temperature = 0; llmCharacter.seed = 0; llmCharacter.stream = false; @@ -233,29 +234,29 @@ public virtual async Task Tests() { await llmCharacter.Tokenize("I", TestTokens); await llmCharacter.Warmup(); - TestInitParameters(tokens1, 1); + TestInitParameters(tokens1, 0); TestWarmup(); await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); TestPostChat(3); - llmCharacter.SetPrompt(llmCharacter.prompt); - llmCharacter.AIName = "False response"; + await llmCharacter.SetPrompt(llmCharacter.systemPrompt); + llmCharacter.aiRole = "False response"; await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); TestPostChat(3); await llmCharacter.Chat("bye!"); TestPostChat(5); prompt = "How are you?"; - llmCharacter.SetPrompt(prompt); + await llmCharacter.SetPrompt(prompt); await llmCharacter.Chat("hi"); - TestInitParameters(tokens2, 3); + TestInitParameters(tokens2, 2); List embeddings = await llmCharacter.Embeddings("hi how are you?"); TestEmbeddings(embeddings); } - public void TestInitParameters(int nkeep, int chats) + public void TestInitParameters(int nKeep, int expectedMessageCount) { - Assert.AreEqual(llmCharacter.nKeep, nkeep); - Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.AIName).Length > 0); - Assert.AreEqual(llmCharacter.chat.Count, chats); + Assert.AreEqual(llmCharacter.nKeep, nKeep); + Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerRole, 
llmCharacter.aiRole).Length > 0); + Assert.AreEqual(llmCharacter.GetChatHistory().GetChatMessages().Count, expectedMessageCount); } public void TestTokens(List tokens) @@ -265,18 +266,18 @@ public void TestTokens(List tokens) public void TestWarmup() { - Assert.That(llmCharacter.chat.Count == 1); + //Assert.That(llmCharacter.chat.Count == 1); } - public void TestChat(string reply, string replyGT) + public void TestChat(string generatedReply, string expectedReply) { - Debug.Log(reply.Trim()); - Assert.That(reply.Trim() == replyGT); + Debug.Log(generatedReply.Trim()); + Assert.That(generatedReply.Trim() == expectedReply); } public void TestPostChat(int num) { - Assert.That(llmCharacter.chat.Count == num); + //Assert.That(llmCharacter.chat.Count == num); } public void TestEmbeddings(List embeddings) @@ -449,7 +450,7 @@ public class TestLLMCharacter_Save : TestLLM public override LLMCharacter CreateLLMCharacter() { LLMCharacter llmCharacter = base.CreateLLMCharacter(); - llmCharacter.save = saveName; + llmCharacter.cacheFilename = saveName; llmCharacter.saveCache = true; return llmCharacter; } @@ -457,31 +458,14 @@ public override LLMCharacter CreateLLMCharacter() public override async Task Tests() { await base.Tests(); - TestSave(); + TestSaveCache(); } - public void TestSave() + public void TestSaveCache() { - string jsonPath = llmCharacter.GetJsonSavePath(saveName); - string cachePath = llmCharacter.GetCacheSavePath(saveName); - Assert.That(File.Exists(jsonPath)); + string cachePath = llmCharacter.GetCacheSavePath(); Assert.That(File.Exists(cachePath)); - string json = File.ReadAllText(jsonPath); - File.Delete(jsonPath); File.Delete(cachePath); - - List chatHistory = JsonUtility.FromJson(json).chat; - Assert.AreEqual(chatHistory.Count, 2); - Assert.AreEqual(chatHistory[0].role, llmCharacter.playerName); - Assert.AreEqual(chatHistory[0].content, "hi"); - Assert.AreEqual(chatHistory[1].role, llmCharacter.AIName); - - Assert.AreEqual(llmCharacter.chat.Count, chatHistory.Count + 1); - for (int i = 0; i < chatHistory.Count; i++) - { - Assert.AreEqual(chatHistory[i].role, llmCharacter.chat[i + 1].role); - Assert.AreEqual(chatHistory[i].content, llmCharacter.chat[i + 1].content); - } } } } diff --git a/Tests/Runtime/TestLLMChatHistory.cs b/Tests/Runtime/TestLLMChatHistory.cs new file mode 100644 index 00000000..1205bd14 --- /dev/null +++ b/Tests/Runtime/TestLLMChatHistory.cs @@ -0,0 +1,64 @@ +using System.Collections.Generic; +using System.Threading.Tasks; +using LLMUnity; +using NUnit.Framework; +using UnityEngine; +using UnityEngine.TestTools; + +namespace LLMUnityTests +{ + public class TestLLMChatHistory + { + private GameObject _gameObject; + private LLMChatHistory _chatHistory; + + + [SetUp] + public void Setup() + { + // Create a new GameObject + _gameObject = new GameObject("TestObject"); + + // Add the component X to the GameObject + _chatHistory = _gameObject.AddComponent(); + } + + [Test] + public async void TestSaveAndLoad() + { + // 1. ARRANGE + // Add a few messages to save + await _chatHistory.AddMessage("user", "hello"); + await _chatHistory.AddMessage("ai", "hi"); + + // Save them off and grab the generated filename (since we didn't supply one) + await _chatHistory.Save(); + string filename = _chatHistory.ChatHistoryFilename; + + // 2. 
ACT + // Destroy the current chat history + Object.Destroy(_chatHistory); + + // Recreate the chat history and load from the same file + _chatHistory = _gameObject.AddComponent(); + _chatHistory.ChatHistoryFilename = filename; + await _chatHistory.Load(); + + // 3. ASSERT + // Validate the messages were loaded + List loadedMessages = _chatHistory.GetChatMessages(); + Assert.AreEqual(loadedMessages.Count, 2); + Assert.AreEqual(loadedMessages[0].role, "user"); + Assert.AreEqual(loadedMessages[0].content, "hello"); + Assert.AreEqual(loadedMessages[1].role, "ai"); + Assert.AreEqual(loadedMessages[1].content, "hi"); + } + + [TearDown] + public void Teardown() + { + // Cleanup the GameObject after the test + Object.Destroy(_gameObject); + } + } +} \ No newline at end of file diff --git a/Tests/Runtime/TestLLMChatHistory.cs.meta b/Tests/Runtime/TestLLMChatHistory.cs.meta new file mode 100644 index 00000000..59db3f27 --- /dev/null +++ b/Tests/Runtime/TestLLMChatHistory.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: ed6dca1c56c54d04caf905b0b1caf269 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: From 525d547694c412817bba746271b3d059f472cf44 Mon Sep 17 00:00:00 2001 From: nabrown Date: Sat, 14 Sep 2024 16:40:49 -0400 Subject: [PATCH 2/7] Fixing merge mistake --- Runtime/LLMCharacter.cs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 21ca102d..c2bbf419 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -45,6 +45,8 @@ public class LLMCharacter : MonoBehaviour [ModelAdvanced] public string grammar = null; /// option to cache the prompt as it is being created by the chat to avoid reprocessing the entire prompt every time (default: true) [ModelAdvanced] public bool cachePrompt = true; + /// specify which slot of the server to use for computation (affects caching) + [ModelAdvanced] public int slot = -1; /// seed for reproducibility. For random results every time set to -1. [ModelAdvanced] public int seed = 0; /// number of tokens to predict (-1 = infinity, -2 = until context filled). 
@@ -122,7 +124,6 @@ public class LLMCharacter : MonoBehaviour private string chatTemplate; private ChatTemplate template = null; public string grammarString; - protected int id_slot = -1; private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; private List WIPRequests = new List(); /// \endcond @@ -148,7 +149,12 @@ public async void Awake() LLMUnitySetup.LogError($"No LLM assigned or detected for LLMCharacter {name}!"); return; } - id_slot = llm.Register(this); + int slotFromServer = llm.Register(this); + if (slot == -1) slot = slotFromServer; + } + else + { + if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey)); } InitGrammar(); @@ -159,6 +165,7 @@ public async void Awake() void OnValidate() { AssignLLM(); + if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set"); } void Reset() @@ -328,7 +335,7 @@ ChatRequest GenerateRequest(string prompt) ChatRequest chatRequest = new ChatRequest(); if (debugPrompt) LLMUnitySetup.Log(prompt); chatRequest.prompt = prompt; - chatRequest.id_slot = id_slot; + chatRequest.id_slot = slot; chatRequest.temperature = temperature; chatRequest.top_k = topK; chatRequest.top_p = topP; @@ -571,7 +578,7 @@ public async Task> Embeddings(string query, Callback> ca private async Task Slot(string filepath, string action) { SlotRequest slotRequest = new SlotRequest(); - slotRequest.id_slot = id_slot; + slotRequest.id_slot = slot; slotRequest.filepath = filepath; slotRequest.action = action; string json = JsonUtility.ToJson(slotRequest); @@ -623,7 +630,7 @@ protected Ret ConvertContent(string response, ContentCallback= 0) llm.CancelRequest(id_slot); + if (slot >= 0) llm.CancelRequest(slot); } protected void CancelRequestsRemote() From b5b8d1e48dd4ce72c1c46f08eff0888a4c13d075 Mon Sep 17 00:00:00 2001 From: nabrown Date: Sat, 21 Sep 2024 12:54:43 -0400 Subject: [PATCH 3/7] PR updates --- Runtime/LLMCharacter.cs | 30 ++++++------ Runtime/LLMChatHistory.cs | 14 ++++-- Samples~/AndroidDemo/Scene.unity | 28 +++++------ Samples~/ChatBot/Scene.unity | 26 ++++------ .../KnowledgeBaseGame/KnowledgeBaseGame.cs | 2 +- Samples~/KnowledgeBaseGame/Scene.unity | 26 ++++------ Samples~/MultipleCharacters/Scene.unity | 47 ++++++------------- Samples~/SimpleInteraction/Scene.unity | 26 ++++------ .../SimpleInteraction/SimpleInteraction.cs | 2 +- Tests/Runtime/TestLLM.cs | 22 ++++----- 10 files changed, 92 insertions(+), 131 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index c2bbf419..1091b4c4 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -110,21 +110,22 @@ public class LLMCharacter : MonoBehaviour /// By providing a token ID and a positive or negative bias value, you can increase or decrease the probability of that token being generated. public Dictionary logitBias = null; + /// the chat history component that this character uses to store it's chat messages + [Chat] public LLMChatHistory chatHistory; /// the name of the player - [Chat] public string playerRole = "user"; + [Chat] public string playerName = "user"; /// the name of the AI - [Chat] public string aiRole = "assistant"; + [Chat] public string aiName = "assistant"; /// a description of the AI role. 
This defines the LLMCharacter system prompt [TextArea(5, 10), Chat] public string systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."; /// option to set the number of tokens to retain from the prompt (nKeep) based on the LLMCharacter system prompt public bool setNKeepToPrompt = true; /// \cond HIDE - private LLMChatHistory chatHistory; private string chatTemplate; private ChatTemplate template = null; public string grammarString; - private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; + private List<(string, string)> requestHeaders; private List WIPRequests = new List(); /// \endcond @@ -141,6 +142,8 @@ public async void Awake() { // Start the LLM server in a cross-platform way if (!enabled) return; + + requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") }; if (!remote) { AssignLLM(); @@ -213,10 +216,7 @@ void SortBySceneAndHierarchy(LLM[] array) protected void InitHistory() { - // Check if we have a chat history component available - chatHistory = GetComponent(); - - // If not, go ahead and create one. + // If no specific chat history object has been assigned to this character, create one. if (chatHistory == null) { chatHistory = gameObject.AddComponent(); } @@ -262,7 +262,7 @@ private async Task InitNKeep() if (setNKeepToPrompt && nKeep == -1) { if (!CheckTemplate()) return false; - string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerRole, aiRole, false); + string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerName, aiName, false); List tokens = await Tokenize(systemPrompt); if (tokens == null) return false; SetNKeep(tokens); @@ -324,7 +324,7 @@ public async void SetGrammar(string path) List GetStopwords() { if (!CheckTemplate()) return null; - List stopAll = new List(template.GetStop(playerRole, aiRole)); + List stopAll = new List(template.GetStop(playerName, aiName)); if (stop != null) stopAll.AddRange(stop); return stopAll; } @@ -366,12 +366,12 @@ ChatRequest GenerateRequest(string prompt) public async Task AddPlayerMessage(string content) { - await chatHistory.AddMessage(playerRole, content); + await chatHistory.AddMessage(playerName, content); } public async Task AddAIMessage(string content) { - await chatHistory.AddMessage(aiRole, content); + await chatHistory.AddMessage(aiName, content); } public LLMChatHistory GetChatHistory() @@ -457,7 +457,7 @@ public async Task Chat(string query, Callback callback = null, E if (!CheckTemplate()) return null; if (!await InitNKeep()) return null; - var playerMessage = new ChatMessage() { role = playerRole, content = query }; + var playerMessage = new ChatMessage() { role = playerName, content = query }; // Setup the full list of messages for the current request List promptMessages = chatHistory ? 
chatHistory.GetChatMessages() : new List(); @@ -465,7 +465,7 @@ public async Task Chat(string query, Callback callback = null, E promptMessages.Add(playerMessage); // Prepare the request - string formattedPrompt = template.ComputePrompt(promptMessages, playerRole, aiRole); + string formattedPrompt = template.ComputePrompt(promptMessages, playerName, aiName); string requestJson = JsonUtility.ToJson(GenerateRequest(formattedPrompt)); // Call the LLM @@ -575,7 +575,7 @@ public async Task> Embeddings(string query, Callback> ca return await PostRequest>(json, "embeddings", EmbeddingsContent, callback); } - private async Task Slot(string filepath, string action) + protected async Task Slot(string filepath, string action) { SlotRequest slotRequest = new SlotRequest(); slotRequest.id_slot = slot; diff --git a/Runtime/LLMChatHistory.cs b/Runtime/LLMChatHistory.cs index 7e560909..fda00202 100644 --- a/Runtime/LLMChatHistory.cs +++ b/Runtime/LLMChatHistory.cs @@ -21,6 +21,11 @@ public class LLMChatHistory : MonoBehaviour /// public string ChatHistoryFilename = string.Empty; + /// + /// If true, this component will automatically save a copy of its data to the filesystem with each update. + /// + public bool EnableAutoSave = true; + /// /// The current chat history /// @@ -54,8 +59,9 @@ await WithChatLock(async () => { await Task.Run(() => _chatHistory.Add(new ChatMessage { role = role, content = content })); }); - // Save our newly updated chat history to the file system - _ = Save(); + if (EnableAutoSave) { + _ = Save(); + } } public List GetChatMessages() { @@ -114,7 +120,9 @@ await WithChatLock(async () => { await Task.Run(() => _chatHistory.Clear()); }); - _ = Save(); + if (EnableAutoSave) { + _ = Save(); + } } public bool IsEmpty() { diff --git a/Samples~/AndroidDemo/Scene.unity b/Samples~/AndroidDemo/Scene.unity index 62a53cda..0d4e29d0 100644 --- a/Samples~/AndroidDemo/Scene.unity +++ b/Samples~/AndroidDemo/Scene.unity @@ -653,7 +653,6 @@ GameObject: m_Component: - component: {fileID: 498662972} - component: {fileID: 498662973} - - component: {fileID: 498662974} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -694,12 +693,14 @@ MonoBehaviour: host: localhost port: 13333 numRetries: -1 + APIKey: cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 grammar: cachePrompt: 1 + slot: -1 seed: 0 numPredict: 256 temperature: 0.2 @@ -721,25 +722,13 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerRole: user - aiRole: assistant + chatHistory: {fileID: 0} + playerName: user + aiName: assistant systemPrompt: A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 grammarString: ---- !u!114 &498662974 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_CorrespondingSourceObject: {fileID: 0} - m_PrefabInstance: {fileID: 0} - m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 498662970} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} - m_Name: - m_EditorClassIdentifier: - ChatHistoryFilename: --- !u!1 &659217390 GameObject: m_ObjectHideFlags: 0 @@ -909,7 +898,7 @@ MonoBehaviour: m_FloatArgument: 0 m_StringArgument: m_BoolArgument: 0 - m_CallState: 0 + m_CallState: 2 --- !u!114 &724531322 MonoBehaviour: m_ObjectHideFlags: 0 @@ -1350,6 +1339,11 @@ MonoBehaviour: lora: loraWeights: flashAttention: 0 + APIKey: + SSLCert: + SSLCertPath: + SSLKey: + SSLKeyPath: --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 diff --git a/Samples~/ChatBot/Scene.unity b/Samples~/ChatBot/Scene.unity index dbecb667..10faae40 100644 --- a/Samples~/ChatBot/Scene.unity +++ b/Samples~/ChatBot/Scene.unity @@ -672,6 +672,11 @@ MonoBehaviour: lora: loraWeights: flashAttention: 0 + APIKey: + SSLCert: + SSLCertPath: + SSLKey: + SSLKeyPath: --- !u!1 &1051131186 GameObject: m_ObjectHideFlags: 0 @@ -1093,7 +1098,6 @@ GameObject: m_Component: - component: {fileID: 1844795170} - component: {fileID: 1844795171} - - component: {fileID: 1844795172} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -1134,12 +1138,14 @@ MonoBehaviour: host: localhost port: 13333 numRetries: -1 + APIKey: cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 grammar: cachePrompt: 1 + slot: -1 seed: 0 numPredict: 256 temperature: 0.2 @@ -1161,25 +1167,13 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerRole: user - aiRole: assistant + chatHistory: {fileID: 0} + playerName: user + aiName: assistant systemPrompt: A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 grammarString: ---- !u!114 &1844795172 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_CorrespondingSourceObject: {fileID: 0} - m_PrefabInstance: {fileID: 0} - m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 1844795168} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} - m_Name: - m_EditorClassIdentifier: - ChatHistoryFilename: --- !u!1 &2011827136 GameObject: m_ObjectHideFlags: 0 diff --git a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs index 67509a6e..ce6ea74b 100644 --- a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs +++ b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs @@ -145,7 +145,7 @@ protected override void DropdownChange(int selection) Debug.Log($"{currentBotName}: {currentBot.NumPhrases()} phrases available"); // set the LLMCharacter name - llmCharacter.aiRole = currentBotName; + llmCharacter.aiName = currentBotName; } void SetAIText(string text) diff --git a/Samples~/KnowledgeBaseGame/Scene.unity b/Samples~/KnowledgeBaseGame/Scene.unity index 5f43dbe0..67338df7 100644 --- a/Samples~/KnowledgeBaseGame/Scene.unity +++ b/Samples~/KnowledgeBaseGame/Scene.unity @@ -4259,7 +4259,6 @@ GameObject: m_Component: - component: {fileID: 1275496424} - component: {fileID: 1275496423} - - component: {fileID: 1275496425} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -4285,12 +4284,14 @@ MonoBehaviour: host: localhost port: 13333 numRetries: -1 + APIKey: cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 grammar: cachePrompt: 1 + slot: -1 seed: 0 numPredict: 256 temperature: 0.2 @@ -4312,8 +4313,9 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerRole: Detective - aiRole: + chatHistory: {fileID: 0} + playerName: Detective + aiName: systemPrompt: "You are a robot working at a house where a diamond was stolen and a detective\r asks you questions about the robbery.\r\n\r\nAnswer the question provided at the\r section \"Question\" based on the possible answers at the section @@ -4335,19 +4337,6 @@ Transform: m_Children: [] m_Father: {fileID: 0} m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0} ---- !u!114 &1275496425 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_CorrespondingSourceObject: {fileID: 0} - m_PrefabInstance: {fileID: 0} - m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 1275496422} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} - m_Name: - m_EditorClassIdentifier: - ChatHistoryFilename: --- !u!1 &1278063793 GameObject: m_ObjectHideFlags: 0 @@ -7559,6 +7548,11 @@ MonoBehaviour: lora: loraWeights: flashAttention: 0 + APIKey: + SSLCert: + SSLCertPath: + SSLKey: + SSLKeyPath: --- !u!4 &2142407557 Transform: m_ObjectHideFlags: 0 diff --git a/Samples~/MultipleCharacters/Scene.unity b/Samples~/MultipleCharacters/Scene.unity index 709ce3e8..b9017152 100644 --- a/Samples~/MultipleCharacters/Scene.unity +++ b/Samples~/MultipleCharacters/Scene.unity @@ -555,7 +555,6 @@ GameObject: m_Component: - component: {fileID: 714802013} - component: {fileID: 714802014} - - component: {fileID: 714802015} m_Layer: 0 m_Name: LLMCharacter2 m_TagString: Untagged @@ -596,12 +595,14 @@ MonoBehaviour: host: localhost port: 13333 numRetries: -1 + APIKey: cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 grammar: cachePrompt: 1 + slot: -1 seed: 0 numPredict: 256 temperature: 0.2 @@ -623,26 +624,14 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerRole: Human - aiRole: Adam + 
chatHistory: {fileID: 0} + playerName: Human + aiName: Adam systemPrompt: A chat between a curious human and an artificial intelligence assistant named Adam. The assistant gives helpful, detailed, and polite answers to the human's questions. setNKeepToPrompt: 1 grammarString: ---- !u!114 &714802015 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_CorrespondingSourceObject: {fileID: 0} - m_PrefabInstance: {fileID: 0} - m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 714802011} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} - m_Name: - m_EditorClassIdentifier: - ChatHistoryFilename: --- !u!1 &726528676 GameObject: m_ObjectHideFlags: 0 @@ -1542,6 +1531,11 @@ MonoBehaviour: lora: loraWeights: flashAttention: 0 + APIKey: + SSLCert: + SSLCertPath: + SSLKey: + SSLKeyPath: --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 @@ -1937,7 +1931,6 @@ GameObject: m_Component: - component: {fileID: 1493015759} - component: {fileID: 1493015760} - - component: {fileID: 1493015761} m_Layer: 0 m_Name: LLMCharacter1 m_TagString: Untagged @@ -1978,12 +1971,14 @@ MonoBehaviour: host: localhost port: 13333 numRetries: -1 + APIKey: cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 grammar: cachePrompt: 1 + slot: -1 seed: 0 numPredict: 256 temperature: 0.2 @@ -2005,26 +2000,14 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerRole: Human - aiRole: Eve + chatHistory: {fileID: 0} + playerName: Human + aiName: Eve systemPrompt: A chat between a curious human and an artificial intelligence assistant named Eve. The assistant gives helpful, detailed, and polite answers to the human's questions. setNKeepToPrompt: 1 grammarString: ---- !u!114 &1493015761 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_CorrespondingSourceObject: {fileID: 0} - m_PrefabInstance: {fileID: 0} - m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 1493015757} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} - m_Name: - m_EditorClassIdentifier: - ChatHistoryFilename: --- !u!1 &1609985808 GameObject: m_ObjectHideFlags: 0 diff --git a/Samples~/SimpleInteraction/Scene.unity b/Samples~/SimpleInteraction/Scene.unity index 4a2dbf79..1440d9d9 100644 --- a/Samples~/SimpleInteraction/Scene.unity +++ b/Samples~/SimpleInteraction/Scene.unity @@ -443,7 +443,6 @@ GameObject: m_Component: - component: {fileID: 498662972} - component: {fileID: 498662973} - - component: {fileID: 498662974} m_Layer: 0 m_Name: LLMCharacter m_TagString: Untagged @@ -484,12 +483,14 @@ MonoBehaviour: host: localhost port: 13333 numRetries: -1 + APIKey: cacheFilename: saveCache: 0 debugPrompt: 0 stream: 1 grammar: cachePrompt: 1 + slot: -1 seed: 0 numPredict: 256 temperature: 0.2 @@ -511,25 +512,13 @@ MonoBehaviour: ignoreEos: 0 nKeep: -1 stop: [] - playerRole: user - aiRole: assistant + chatHistory: {fileID: 0} + playerName: user + aiName: assistant systemPrompt: A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
setNKeepToPrompt: 1 grammarString: ---- !u!114 &498662974 -MonoBehaviour: - m_ObjectHideFlags: 0 - m_CorrespondingSourceObject: {fileID: 0} - m_PrefabInstance: {fileID: 0} - m_PrefabAsset: {fileID: 0} - m_GameObject: {fileID: 498662970} - m_Enabled: 1 - m_EditorHideFlags: 0 - m_Script: {fileID: 11500000, guid: b9aef079d11e8894bae3ae510742c32f, type: 3} - m_Name: - m_EditorClassIdentifier: - ChatHistoryFilename: --- !u!1 &724531319 GameObject: m_ObjectHideFlags: 0 @@ -1065,6 +1054,11 @@ MonoBehaviour: lora: loraWeights: flashAttention: 0 + APIKey: + SSLCert: + SSLCertPath: + SSLKey: + SSLKeyPath: --- !u!4 &1047848255 Transform: m_ObjectHideFlags: 0 diff --git a/Samples~/SimpleInteraction/SimpleInteraction.cs b/Samples~/SimpleInteraction/SimpleInteraction.cs index 7c9a2cbc..910e5f90 100644 --- a/Samples~/SimpleInteraction/SimpleInteraction.cs +++ b/Samples~/SimpleInteraction/SimpleInteraction.cs @@ -20,7 +20,7 @@ async void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; - await llmCharacter.Chat(message, SetAIText, AIReplyComplete); + _ = llmCharacter.Chat(message, SetAIText, AIReplyComplete); } public void SetAIText(string text) diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 288c3dd0..76d57c13 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -193,8 +193,8 @@ public virtual LLMCharacter CreateLLMCharacter() { LLMCharacter llmCharacter = gameObject.AddComponent(); llmCharacter.llm = llm; - llmCharacter.playerRole = "Instruction"; - llmCharacter.aiRole = "Response"; + llmCharacter.playerName = "Instruction"; + llmCharacter.aiName = "Response"; llmCharacter.systemPrompt = prompt; llmCharacter.temperature = 0; llmCharacter.seed = 0; @@ -235,15 +235,14 @@ public virtual async Task Tests() await llmCharacter.Tokenize("I", TestTokens); await llmCharacter.Warmup(); TestInitParameters(tokens1, 0); - TestWarmup(); await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); - TestPostChat(3); + TestPostChat(2); await llmCharacter.SetPrompt(llmCharacter.systemPrompt); - llmCharacter.aiRole = "False response"; + llmCharacter.aiName = "False response"; await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); - TestPostChat(3); + TestPostChat(2); await llmCharacter.Chat("bye!"); - TestPostChat(5); + TestPostChat(4); prompt = "How are you?"; await llmCharacter.SetPrompt(prompt); await llmCharacter.Chat("hi"); @@ -255,7 +254,7 @@ public virtual async Task Tests() public void TestInitParameters(int nKeep, int expectedMessageCount) { Assert.AreEqual(llmCharacter.nKeep, nKeep); - Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerRole, llmCharacter.aiRole).Length > 0); + Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.aiName).Length > 0); Assert.AreEqual(llmCharacter.GetChatHistory().GetChatMessages().Count, expectedMessageCount); } @@ -264,11 +263,6 @@ public void TestTokens(List tokens) Assert.AreEqual(tokens, new List {40}); } - public void TestWarmup() - { - //Assert.That(llmCharacter.chat.Count == 1); - } - public void TestChat(string generatedReply, string expectedReply) { Debug.Log(generatedReply.Trim()); @@ -277,7 +271,7 @@ public void TestChat(string generatedReply, string expectedReply) public void TestPostChat(int num) { - //Assert.That(llmCharacter.chat.Count == num); + Assert.AreEqual(num, llmCharacter.chatHistory.GetChatMessages().Count); } public void TestEmbeddings(List embeddings) From 
994b62462f45968baacf2c2872c337a7594c756f Mon Sep 17 00:00:00 2001 From: nabrown Date: Sat, 21 Sep 2024 13:09:20 -0400 Subject: [PATCH 4/7] Updating README --- README.md | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index add4b72f..d0c26f2a 100644 --- a/README.md +++ b/README.md @@ -151,19 +151,18 @@ It is also a good idea to enable the `Download on Build` option in the LLM GameO
Save / Load your chat history -To automatically save / load your chat history, you can specify the `Save` parameter of the LLMCharacter to the filename (or relative path) of your choice. -The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). -This also saves the state of the LLM which means that the previously cached prompt does not need to be recomputed. +Your `LLMCharacter` components will automatically create corresponding `LLMChatHistory` components to store their chat histories. +- If you don't want to save the chat history, set the `EnableAutoSave` property of the `LLMChatHistory` to false. +- You can specify the filename to use by setting the `ChatHistoryFilename` property of the `LLMChatHistory`. The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). To manually save your chat history, you can use: ``` c# - llmCharacter.Save("filename"); + llmChatHistory.Save(); ``` and to load the history: ``` c# - llmCharacter.Load("filename"); + llmChatHistory.Load(); ``` -where filename the filename or relative path of your choice.
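For reference, a minimal setup sketch based on the `LLMChatHistory` API described in the README text above (the filename `my_chat` and the surrounding async context are illustrative, not part of the package docs):

``` c#
// Attach an explicit chat history instead of relying on the auto-created one.
LLMChatHistory history = gameObject.AddComponent<LLMChatHistory>();
history.ChatHistoryFilename = "my_chat"; // stored as my_chat.json under persistentDataPath
history.EnableAutoSave = false;          // opt out of saving after every message
llmCharacter.chatHistory = history;      // the character now records its messages here

// Persist or restore on demand (from an async method):
await history.Save();
await history.Load();
```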
@@ -452,8 +451,8 @@ If the user's GPU is not supported, the LLM will fall back to the CPU - `Port` port of the LLM server (if `Remote` is set) - `Num Retries` number of HTTP request retries from the LLM server (if `Remote` is set) - `API key` API key of the LLM server (if `Remote` is set) --
`Save` save filename or relative path If set, the chat history and LLM state (if save cache is enabled) is automatically saved to file specified.
The chat history is saved with a json suffix and the LLM state with a cache suffix.
Both files are saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
-- `Save Cache` select to save the LLM state along with the chat history. The LLM state is typically around 100MB+. +-
`Cache Filename` save filename or relative path. If set, the LLM state (if `Save Cache` is enabled) is automatically saved to the file specified.
The LLM state is saved with a cache suffix.
The file is saved in the [persistentDataPath folder of Unity](https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html).
+- `Save Cache` select to save the LLM state. The LLM state is typically around 100MB+. - `Debug Prompt` select to log the constructed prompts in the Unity Editor #### 🗨️ Chat Settings From d598c7266e4a90c30478cbcea45e18a166784900 Mon Sep 17 00:00:00 2001 From: nabrown Date: Mon, 23 Sep 2024 08:55:55 -0400 Subject: [PATCH 5/7] updating so that changing out the chat history invalidates the cache --- Runtime/LLMCharacter.cs | 23 +++++++++++++++++------ Tests/Runtime/TestLLM.cs | 2 +- Tests/Runtime/TestLLMChatHistory.cs | 4 ++-- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 1091b4c4..83159225 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -111,7 +111,13 @@ public class LLMCharacter : MonoBehaviour public Dictionary logitBias = null; /// the chat history component that this character uses to store its chat messages - [Chat] public LLMChatHistory chatHistory; + public LLMChatHistory chatHistory { + get { return _chatHistory; } + set { + _chatHistory = value; + isCacheInvalid = true; + } + } /// the name of the player [Chat] public string playerName = "user"; /// the name of the AI @@ -122,11 +128,14 @@ public LLMChatHistory chatHistory { public bool setNKeepToPrompt = true; /// \cond HIDE + [SerializeField, Chat] + private LLMChatHistory _chatHistory; private string chatTemplate; private ChatTemplate template = null; public string grammarString; private List<(string, string)> requestHeaders; private List WIPRequests = new List(); + private bool isCacheInvalid = false; /// \endcond /// @@ -374,11 +383,6 @@ public async Task AddAIMessage(string content) await chatHistory.AddMessage(aiName, content); } - public LLMChatHistory GetChatHistory() - { - return chatHistory; - } - protected string ChatContent(ChatResult result) { // get content from a chat result received from the endpoint @@ -605,6 +609,13 @@ public virtual async Task SaveCache() public virtual async Task LoadCache() { if (remote || !saveCache || !File.Exists(GetCacheSavePath())) return null; + + // If the cache has become invalid, don't bother loading this time. + if (isCacheInvalid) { + isCacheInvalid = false; + return null; + } + string result = await Slot(GetCacheSavePath(), "restore"); return result; } diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 76d57c13..faa1dc4a 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -255,7 +255,7 @@ public void TestInitParameters(int nKeep, int expectedMessageCount) { Assert.AreEqual(llmCharacter.nKeep, nKeep); Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.aiName).Length > 0); - Assert.AreEqual(llmCharacter.GetChatHistory().GetChatMessages().Count, expectedMessageCount); + Assert.AreEqual(llmCharacter.chatHistory?.GetChatMessages().Count, expectedMessageCount); } diff --git a/Tests/Runtime/TestLLMChatHistory.cs b/Tests/Runtime/TestLLMChatHistory.cs index 1205bd14..b3b937c7 100644 --- a/Tests/Runtime/TestLLMChatHistory.cs +++ b/Tests/Runtime/TestLLMChatHistory.cs @@ -37,10 +37,10 @@ public async void TestSaveAndLoad() // 2.
ACT // Destroy the current chat history - Object.Destroy(_chatHistory); + Object.Destroy(_gameObject); // Recreate the chat history and load from the same file - _chatHistory = _gameObject.AddComponent(); + Setup(); _chatHistory.ChatHistoryFilename = filename; await _chatHistory.Load(); From f968397cc11f9657d87a32dc16eacb2e5d69b900 Mon Sep 17 00:00:00 2001 From: nabrown Date: Fri, 11 Oct 2024 09:01:13 -0400 Subject: [PATCH 6/7] PR Fixes --- Runtime/LLMCharacter.cs | 73 +++++++++++++------ .../KnowledgeBaseGame/KnowledgeBaseGame.cs | 2 +- .../SimpleInteraction/SimpleInteraction.cs | 2 +- Tests/Runtime/TestLLM.cs | 12 +-- 4 files changed, 57 insertions(+), 32 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index 83159225..a8b56728 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -32,8 +32,8 @@ public class LLMCharacter : MonoBehaviour [Remote] public int numRetries = 10; /// allows to use a server with API key [Remote] public string APIKey; - /// file to save the cache. - [LLM] public string cacheFilename = ""; + /// filename to use when saving the cache or chat history. + [LLM] public string save = ""; /// toggle to save the LLM cache. This speeds up the prompt calculation but also requires ~100MB of space per character. [LLM] public bool saveCache = false; /// select to log the constructed prompt the Unity Editor. @@ -121,9 +121,9 @@ public LLMChatHistory chatHistory { /// the name of the player [Chat] public string playerName = "user"; /// the name of the AI - [Chat] public string aiName = "assistant"; + [Chat] public string AIName = "assistant"; /// a description of the AI role. This defines the LLMCharacter system prompt - [TextArea(5, 10), Chat] public string systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."; + [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."; /// option to set the number of tokens to retain from the prompt (nKeep) based on the LLMCharacter system prompt public bool setNKeepToPrompt = true; @@ -170,7 +170,7 @@ public async void Awake() } InitGrammar(); - InitHistory(); + await InitHistory(); await LoadCache(); } @@ -223,17 +223,19 @@ void SortBySceneAndHierarchy(LLM[] array) } } - protected void InitHistory() + protected async Task InitHistory() { // If no specific chat history object has been assigned to this character, create one. if (chatHistory == null) { chatHistory = gameObject.AddComponent(); + chatHistory.ChatHistoryFilename = save; + await chatHistory.Load(); } } public virtual string GetCacheSavePath() { - return Path.Combine(Application.persistentDataPath, cacheFilename + ".cache").Replace('\\', '/'); + return Path.Combine(Application.persistentDataPath, save + ".cache").Replace('\\', '/'); } /// @@ -243,7 +245,7 @@ public virtual string GetCacheSavePath() /// whether to clear (true) or keep (false) the current chat history on top of the system prompt. 
public async Task SetPrompt(string newPrompt, bool clearChat = true) { - systemPrompt = newPrompt; + prompt = newPrompt; nKeep = -1; if (clearChat) { @@ -263,7 +265,7 @@ private bool CheckTemplate() } private ChatMessage GetSystemPromptMessage() { - return new ChatMessage() { role = LLMConstants.SYSTEM_ROLE, content = systemPrompt }; + return new ChatMessage() { role = LLMConstants.SYSTEM_ROLE, content = prompt }; } private async Task InitNKeep() @@ -271,7 +273,7 @@ private async Task InitNKeep() if (setNKeepToPrompt && nKeep == -1) { if (!CheckTemplate()) return false; - string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerName, aiName, false); + string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerName, AIName, false); List tokens = await Tokenize(systemPrompt); if (tokens == null) return false; SetNKeep(tokens); @@ -333,7 +335,7 @@ public async void SetGrammar(string path) List GetStopwords() { if (!CheckTemplate()) return null; - List stopAll = new List(template.GetStop(playerName, aiName)); + List stopAll = new List(template.GetStop(playerName, AIName)); if (stop != null) stopAll.AddRange(stop); return stopAll; } @@ -380,7 +382,7 @@ public async Task AddPlayerMessage(string content) public async Task AddAIMessage(string content) { - await chatHistory.AddMessage(aiName, content); + await chatHistory.AddMessage(AIName, content); } protected string ChatContent(ChatResult result) @@ -469,14 +471,14 @@ public async Task Chat(string query, Callback callback = null, E promptMessages.Add(playerMessage); // Prepare the request - string formattedPrompt = template.ComputePrompt(promptMessages, playerName, aiName); + string formattedPrompt = template.ComputePrompt(promptMessages, playerName, AIName); string requestJson = JsonUtility.ToJson(GenerateRequest(formattedPrompt)); // Call the LLM string result = await CompletionRequest(requestJson, callback); // Update our chat history if required - if (addToHistory && chatHistory && result != null) + if (addToHistory && result != null) { await AddPlayerMessage(query); await AddAIMessage(result); @@ -598,23 +600,19 @@ public virtual async Task SaveCache() { if (remote || !saveCache) return null; string result = await Slot(GetCacheSavePath(), "save"); + + // We now have a valid cache + isCacheInvalid = false; + return result; } /// - /// Load the cache from the provided filename / relative path. + /// Load the prompt cache. /// - /// filename / relative path to load the cache from - /// public virtual async Task LoadCache() { - if (remote || !saveCache || !File.Exists(GetCacheSavePath())) return null; - - // If the cache has become invalid, don't bother loading this time. 
- if (isCacheInvalid) { - isCacheInvalid = false; - return null; - } + if (remote || !saveCache || isCacheInvalid || !File.Exists(GetCacheSavePath())) return null; string result = await Slot(GetCacheSavePath(), "restore"); return result; @@ -784,6 +782,33 @@ protected async Task PostRequest(string json, string endpoint, Co if (remote) return await PostRequestRemote(json, endpoint, getContent, callback); return await PostRequestLocal(json, endpoint, getContent, callback); } + + #region Obsolete Functions + + [Obsolete] + public virtual async Task Save(string filename) { + + if (chatHistory) { + await chatHistory.Save(); + } + + return await SaveCache(); + } + + [Obsolete] + public virtual async Task Load(string filename) { + + if (chatHistory) { + chatHistory.ChatHistoryFilename = filename; + await chatHistory.Load(); + } + + save = filename; + return await LoadCache(); + } + + #endregion + } /// \cond HIDE diff --git a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs index ce6ea74b..e6f3bc29 100644 --- a/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs +++ b/Samples~/KnowledgeBaseGame/KnowledgeBaseGame.cs @@ -145,7 +145,7 @@ protected override void DropdownChange(int selection) Debug.Log($"{currentBotName}: {currentBot.NumPhrases()} phrases available"); // set the LLMCharacter name - llmCharacter.aiName = currentBotName; + llmCharacter.AIName = currentBotName; } void SetAIText(string text) diff --git a/Samples~/SimpleInteraction/SimpleInteraction.cs b/Samples~/SimpleInteraction/SimpleInteraction.cs index 910e5f90..9c088c3b 100644 --- a/Samples~/SimpleInteraction/SimpleInteraction.cs +++ b/Samples~/SimpleInteraction/SimpleInteraction.cs @@ -16,7 +16,7 @@ void Start() playerText.Select(); } - async void onInputFieldSubmit(string message) + void onInputFieldSubmit(string message) { playerText.interactable = false; AIText.text = "..."; diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index faa1dc4a..57040110 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -194,8 +194,8 @@ public virtual LLMCharacter CreateLLMCharacter() LLMCharacter llmCharacter = gameObject.AddComponent(); llmCharacter.llm = llm; llmCharacter.playerName = "Instruction"; - llmCharacter.aiName = "Response"; - llmCharacter.systemPrompt = prompt; + llmCharacter.AIName = "Response"; + llmCharacter.prompt = prompt; llmCharacter.temperature = 0; llmCharacter.seed = 0; llmCharacter.stream = false; @@ -237,8 +237,8 @@ public virtual async Task Tests() TestInitParameters(tokens1, 0); await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply1)); TestPostChat(2); - await llmCharacter.SetPrompt(llmCharacter.systemPrompt); - llmCharacter.aiName = "False response"; + await llmCharacter.SetPrompt(llmCharacter.prompt); + llmCharacter.AIName = "False response"; await llmCharacter.Chat(query, (string reply) => TestChat(reply, reply2)); TestPostChat(2); await llmCharacter.Chat("bye!"); @@ -254,7 +254,7 @@ public virtual async Task Tests() public void TestInitParameters(int nKeep, int expectedMessageCount) { Assert.AreEqual(llmCharacter.nKeep, nKeep); - Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.aiName).Length > 0); + Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.AIName).Length > 0); Assert.AreEqual(llmCharacter.chatHistory?.GetChatMessages().Count, expectedMessageCount); } @@ -444,7 +444,7 @@ public class TestLLMCharacter_Save : 
TestLLM public override LLMCharacter CreateLLMCharacter() { LLMCharacter llmCharacter = base.CreateLLMCharacter(); - llmCharacter.cacheFilename = saveName; + llmCharacter.save = saveName; llmCharacter.saveCache = true; return llmCharacter; } From f70d6fedb144b95071b5e7a9dc80fa8042928feb Mon Sep 17 00:00:00 2001 From: nabrown Date: Tue, 15 Oct 2024 16:56:00 -0400 Subject: [PATCH 7/7] PR Fixes --- Runtime/LLMCharacter.cs | 31 +++++++++++++++++++++++------ Runtime/LLMChatHistory.cs | 31 +++++++++++++++++++---------- Tests/Runtime/TestLLM.cs | 4 ++-- Tests/Runtime/TestLLMChatHistory.cs | 2 +- 4 files changed, 48 insertions(+), 20 deletions(-) diff --git a/Runtime/LLMCharacter.cs b/Runtime/LLMCharacter.cs index a8b56728..960f6318 100644 --- a/Runtime/LLMCharacter.cs +++ b/Runtime/LLMCharacter.cs @@ -273,7 +273,7 @@ private async Task InitNKeep() if (setNKeepToPrompt && nKeep == -1) { if (!CheckTemplate()) return false; - string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerName, AIName, false); + string systemPrompt = template.ComputePrompt(new List(){GetSystemPromptMessage()}, playerName, "", false); List tokens = await Tokenize(systemPrompt); if (tokens == null) return false; SetNKeep(tokens); @@ -375,14 +375,19 @@ ChatRequest GenerateRequest(string prompt) return chatRequest; } + public async Task AddMessage(string role, string content) + { + await chatHistory.AddMessage(role, content); + } + public async Task AddPlayerMessage(string content) { - await chatHistory.AddMessage(playerName, content); + await AddMessage(playerName, content); } public async Task AddAIMessage(string content) { - await chatHistory.AddMessage(AIName, content); + await AddMessage(AIName, content); } protected string ChatContent(ChatResult result) @@ -466,7 +471,9 @@ public async Task Chat(string query, Callback callback = null, E var playerMessage = new ChatMessage() { role = playerName, content = query }; // Setup the full list of messages for the current request - List promptMessages = chatHistory ? chatHistory.GetChatMessages() : new List(); + List promptMessages = chatHistory ? 
+ await chatHistory.GetChatMessages() : + new List<ChatMessage>(); promptMessages.Insert(0, GetSystemPromptMessage()); promptMessages.Add(playerMessage); @@ -480,8 +487,12 @@ public async Task Chat(string query, Callback callback = null, E // Update our chat history if required if (addToHistory && result != null) { - await AddPlayerMessage(query); - await AddAIMessage(result); + await _chatHistory.AddMessages( + new List<ChatMessage> { + new ChatMessage { role = playerName, content = query }, + new ChatMessage { role = AIName, content = result } + } + ); } await SaveCache(); @@ -767,9 +778,11 @@ protected async Task PostRequestRemote(string json, string endpoi { result = default; error = request.error; + if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break; } } tryNr--; + if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr)); } if (error != null) LLMUnitySetup.LogError(error); @@ -807,6 +820,12 @@ public virtual async Task Load(string filename) { return await LoadCache(); } + [Obsolete] + public virtual string GetSavePath(string filename) + { + return _chatHistory.GetChatHistoryFilePath(); + } + #endregion } diff --git a/Runtime/LLMChatHistory.cs b/Runtime/LLMChatHistory.cs index fda00202..df49d5d8 100644 --- a/Runtime/LLMChatHistory.cs +++ b/Runtime/LLMChatHistory.cs @@ -34,7 +34,7 @@ public class LLMChatHistory : MonoBehaviour /// /// Ensures we're not trying to update the chat while saving or loading /// - protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1); + protected SemaphoreSlim _chatLock = new SemaphoreSlim(1, 1); /// /// The Unity Awake function that initializes the state before the application starts. /// @@ -50,13 +50,10 @@ public async void Awake() } } - /// - /// Appends a new message to the end of this chat. - /// - public async Task AddMessage(string role, string content) + public async Task AddMessages(List<ChatMessage> messages) { await WithChatLock(async () => { - await Task.Run(() => _chatHistory.Add(new ChatMessage { role = role, content = content })); + await Task.Run(() => _chatHistory.AddRange(messages)); }); if (EnableAutoSave) { @@ -64,8 +61,20 @@ await WithChatLock(async () => { } } - public List<ChatMessage> GetChatMessages() { - return new List<ChatMessage>(_chatHistory); + public async Task AddMessage(string role, string content) + { + await AddMessages(new List<ChatMessage> { new ChatMessage { role = role, content = content } }); + } + + public async Task<List<ChatMessage>> GetChatMessages() { + + List<ChatMessage> chatMessages = null; + + await WithChatLock(async () => { + await Task.Run(() => chatMessages = new List<ChatMessage>(_chatHistory)); + }); + + return chatMessages; } /// @@ -129,18 +138,18 @@ public bool IsEmpty() { return _chatHistory?.Count == 0; } - protected string GetChatHistoryFilePath() + public string GetChatHistoryFilePath() { return Path.Combine(Application.persistentDataPath, ChatHistoryFilename + ".json").Replace('\\', '/'); } protected async Task WithChatLock(Func<Task> action) { - await chatLock.WaitAsync(); + await _chatLock.WaitAsync(); try { await action(); } finally { - chatLock.Release(); + _chatLock.Release(); } } } diff --git a/Tests/Runtime/TestLLM.cs b/Tests/Runtime/TestLLM.cs index 57040110..76544e92 100644 --- a/Tests/Runtime/TestLLM.cs +++ b/Tests/Runtime/TestLLM.cs @@ -255,7 +255,7 @@ public void TestInitParameters(int nKeep, int expectedMessageCount) { Assert.AreEqual(llmCharacter.nKeep, nKeep); Assert.That(ChatTemplate.GetTemplate(llm.chatTemplate).GetStop(llmCharacter.playerName, llmCharacter.AIName).Length > 0); - Assert.AreEqual(llmCharacter.chatHistory?.GetChatMessages().Count, expectedMessageCount); +
Assert.AreEqual(llmCharacter.chatHistory?.GetChatMessages().Result.Count, expectedMessageCount); } public void TestTokens(List tokens) @@ -271,7 +271,7 @@ public void TestChat(string generatedReply, string expectedReply) public void TestPostChat(int num) { - Assert.AreEqual(num, llmCharacter.chatHistory.GetChatMessages().Count); + Assert.AreEqual(num, llmCharacter.chatHistory.GetChatMessages().Result.Count); } public void TestEmbeddings(List embeddings) diff --git a/Tests/Runtime/TestLLMChatHistory.cs b/Tests/Runtime/TestLLMChatHistory.cs index b3b937c7..42033762 100644 --- a/Tests/Runtime/TestLLMChatHistory.cs +++ b/Tests/Runtime/TestLLMChatHistory.cs @@ -46,7 +46,7 @@ public async void TestSaveAndLoad() // 3. ASSERT // Validate the messages were loaded - List loadedMessages = _chatHistory.GetChatMessages(); + List loadedMessages = await _chatHistory.GetChatMessages(); Assert.AreEqual(loadedMessages.Count, 2); Assert.AreEqual(loadedMessages[0].role, "user"); Assert.AreEqual(loadedMessages[0].content, "hello");
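A short usage sketch for the now-async `GetChatMessages` introduced in this series: production code should await it rather than block on `.Result` as the test assertions above do. The `ChatHistoryDebug.LogHistory` helper below is hypothetical, and the `LLMUnity` namespace import is an assumption about the package layout:

``` c#
using System.Collections.Generic;
using System.Threading.Tasks;
using UnityEngine;
using LLMUnity; // assumed package namespace for LLMChatHistory / ChatMessage

public static class ChatHistoryDebug
{
    // Hypothetical helper: awaits the async accessor instead of blocking with .Result,
    // so the main thread is not stalled while the chat lock is held.
    public static async Task LogHistory(LLMChatHistory history)
    {
        List<ChatMessage> messages = await history.GetChatMessages();
        foreach (ChatMessage message in messages)
            Debug.Log($"{message.role}: {message.content}");
    }
}
```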