Pulling the chat history into LLMChatHistory #244

Open · wants to merge 7 commits into base: release/v2.3.0
Changes from 2 commits
196 changes: 63 additions & 133 deletions Runtime/LLMCharacter.cs
@@ -32,10 +32,8 @@ public class LLMCharacter : MonoBehaviour
     [Remote] public int numRetries = 10;
     /// <summary> allows to use a server with API key </summary>
     [Remote] public string APIKey;
-    /// <summary> file to save the chat history.
-    /// The file is saved only for Chat calls with addToHistory set to true.
-    /// The file will be saved within the persistentDataPath directory (see https://docs.unity3d.com/ScriptReference/Application-persistentDataPath.html). </summary>
-    [LLM] public string save = "";
+    /// <summary> file to save the cache. </summary>
+    [LLM] public string cacheFilename = "";
     /// <summary> toggle to save the LLM cache. This speeds up the prompt calculation but also requires ~100MB of space per character. </summary>
     [LLM] public bool saveCache = false;
     /// <summary> select to log the constructed prompt the Unity Editor. </summary>
@@ -113,21 +111,20 @@ public class LLMCharacter : MonoBehaviour
     public Dictionary<int, string> logitBias = null;

     /// <summary> the name of the player </summary>
-    [Chat] public string playerName = "user";
+    [Chat] public string playerRole = "user";
     /// <summary> the name of the AI </summary>
-    [Chat] public string AIName = "assistant";
+    [Chat] public string aiRole = "assistant";
     /// <summary> a description of the AI role. This defines the LLMCharacter system prompt </summary>
-    [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
+    [TextArea(5, 10), Chat] public string systemPrompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
     /// <summary> option to set the number of tokens to retain from the prompt (nKeep) based on the LLMCharacter system prompt </summary>
     public bool setNKeepToPrompt = true;

     /// \cond HIDE
-    public List<ChatMessage> chat;
-    private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
+    private LLMChatHistory chatHistory;
[Review comment from a collaborator]: chatHistory should be public, to allow modifying it or switching to another one.
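A minimal sketch of what that suggestion could look like, using the field and type names from the diff above (illustrative only, not part of the PR):

    // Illustrative sketch of the reviewer's suggestion (not in this PR):
    // expose the history through a public property so callers can read it
    // or swap in a different LLMChatHistory at runtime.
    public LLMChatHistory ChatHistory
    {
        get { return chatHistory; }
        set { chatHistory = value; }
    }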

     private string chatTemplate;
     private ChatTemplate template = null;
     public string grammarString;
-    private List<(string, string)> requestHeaders;
+    private List<(string, string)> requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
     private List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
     /// \endcond

@@ -140,12 +137,10 @@ public class LLMCharacter : MonoBehaviour
     /// - the chat template is constructed
     /// - the number of tokens to keep are based on the system prompt (if setNKeepToPrompt=true)
     /// </summary>
-    public void Awake()
+    public async void Awake()
     {
         // Start the LLM server in a cross-platform way
         if (!enabled) return;
-
-        requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
         if (!remote)
         {
             AssignLLM();
@@ -164,6 +159,7 @@ public void Awake()

         InitGrammar();
         InitHistory();
+        await LoadCache();
     }

     void OnValidate()
@@ -217,70 +213,34 @@ void SortBySceneAndHierarchy(LLM[] array)

     protected void InitHistory()
     {
-        InitPrompt();
-        _ = LoadHistory();
-    }
+        // Check if we have a chat history component available
+        chatHistory = GetComponent<LLMChatHistory>();

-    protected async Task LoadHistory()
-    {
-        if (save == "" || !File.Exists(GetJsonSavePath(save))) return;
-        await chatLock.WaitAsync(); // Acquire the lock
-        try
-        {
-            await Load(save);
-        }
-        finally
-        {
-            chatLock.Release(); // Release the lock
+        // If not, go ahead and create one.
+        if (chatHistory == null) {
+            chatHistory = gameObject.AddComponent<LLMChatHistory>();
         }
     }

-    public virtual string GetSavePath(string filename)
-    {
-        return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
-    }
-
-    public virtual string GetJsonSavePath(string filename)
+    public virtual string GetCacheSavePath()
     {
-        return GetSavePath(filename + ".json");
-    }
-
-    public virtual string GetCacheSavePath(string filename)
-    {
-        return GetSavePath(filename + ".cache");
-    }
-
-    private void InitPrompt(bool clearChat = true)
-    {
-        if (chat != null)
-        {
-            if (clearChat) chat.Clear();
-        }
-        else
-        {
-            chat = new List<ChatMessage>();
-        }
-        ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt };
-        if (chat.Count == 0)
-        {
-            chat.Add(promptMessage);
-        }
-        else
-        {
-            chat[0] = promptMessage;
-        }
+        return Path.Combine(Application.persistentDataPath, cacheFilename + ".cache").Replace('\\', '/');
     }

     /// <summary>
     /// Set the system prompt for the LLMCharacter.
     /// </summary>
     /// <param name="newPrompt"> the system prompt </param>
     /// <param name="clearChat"> whether to clear (true) or keep (false) the current chat history on top of the system prompt. </param>
-    public void SetPrompt(string newPrompt, bool clearChat = true)
+    public async Task SetPrompt(string newPrompt, bool clearChat = true)
     {
-        prompt = newPrompt;
+        systemPrompt = newPrompt;
         nKeep = -1;
-        InitPrompt(clearChat);
+
+        if (clearChat) {
+            // Clear any existing messages
+            await chatHistory?.Clear();
+        }
     }

     private bool CheckTemplate()
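Note for callers: SetPrompt is now async and returns a Task, so existing call sites need to await it. A hypothetical call site, assuming an llmCharacter reference:

    // Hypothetical call site: SetPrompt now clears the LLMChatHistory component
    // when clearChat is true, and must be awaited since it returns a Task.
    await llmCharacter.SetPrompt("You are a helpful pirate.", clearChat: true);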
@@ -293,12 +253,16 @@ private bool CheckTemplate()
         return true;
     }

+    private ChatMessage GetSystemPromptMessage() {
+        return new ChatMessage() { role = LLMConstants.SYSTEM_ROLE, content = systemPrompt };
+    }
+
     private async Task<bool> InitNKeep()
     {
         if (setNKeepToPrompt && nKeep == -1)
         {
             if (!CheckTemplate()) return false;
-            string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){chat[0]}, playerName, "", false);
+            string systemPrompt = template.ComputePrompt(new List<ChatMessage>(){GetSystemPromptMessage()}, playerRole, aiRole, false);
             List<int> tokens = await Tokenize(systemPrompt);
             if (tokens == null) return false;
             SetNKeep(tokens);
@@ -360,7 +324,7 @@ public async void SetGrammar(string path)
     List<string> GetStopwords()
     {
         if (!CheckTemplate()) return null;
-        List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
+        List<string> stopAll = new List<string>(template.GetStop(playerRole, aiRole));
         if (stop != null) stopAll.AddRange(stop);
         return stopAll;
     }
@@ -400,20 +364,19 @@ ChatRequest GenerateRequest(string prompt)
         return chatRequest;
     }

-    public void AddMessage(string role, string content)
+    public async Task AddPlayerMessage(string content)
     {
-        // add the question / answer to the chat list, update prompt
-        chat.Add(new ChatMessage { role = role, content = content });
+        await chatHistory.AddMessage(playerRole, content);
     }

-    public void AddPlayerMessage(string content)
+    public async Task AddAIMessage(string content)
     {
-        AddMessage(playerName, content);
+        await chatHistory.AddMessage(aiRole, content);
     }

-    public void AddAIMessage(string content)
+    public LLMChatHistory GetChatHistory()
     {
-        AddMessage(AIName, content);
+        return chatHistory;
     }

     protected string ChatContent(ChatResult result)
@@ -490,44 +453,33 @@ protected string SlotContent(SlotResult result)
     /// <returns>the LLM response</returns>
     public async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true)
     {
-        // handle a chat message by the user
-        // call the callback function while the answer is received
-        // call the completionCallback function when the answer is fully received
         await LoadTemplate();
         if (!CheckTemplate()) return null;
         if (!await InitNKeep()) return null;

+        var playerMessage = new ChatMessage() { role = playerRole, content = query };
+
-        string json;
-        await chatLock.WaitAsync();
-        try
-        {
-            AddPlayerMessage(query);
-            string prompt = template.ComputePrompt(chat, playerName, AIName);
-            json = JsonUtility.ToJson(GenerateRequest(prompt));
-            chat.RemoveAt(chat.Count - 1);
-        }
-        finally
-        {
-            chatLock.Release();
-        }
+        // Setup the full list of messages for the current request
+        List<ChatMessage> promptMessages = chatHistory ? chatHistory.GetChatMessages() : new List<ChatMessage>();
+        promptMessages.Insert(0, GetSystemPromptMessage());
+        promptMessages.Add(playerMessage);

-        string result = await CompletionRequest(json, callback);
+        // Prepare the request
+        string formattedPrompt = template.ComputePrompt(promptMessages, playerRole, aiRole);
+        string requestJson = JsonUtility.ToJson(GenerateRequest(formattedPrompt));
+
+        // Call the LLM
+        string result = await CompletionRequest(requestJson, callback);

-        if (addToHistory && result != null)
+        // Update our chat history if required
+        if (addToHistory && chatHistory && result != null)
         {
-            await chatLock.WaitAsync();
-            try
-            {
-                AddPlayerMessage(query);
-                AddAIMessage(result);
-            }
-            finally
-            {
-                chatLock.Release();
-            }
-            if (save != "") _ = Save(save);
+            await AddPlayerMessage(query);
+            await AddAIMessage(result);
         }

+        await SaveCache();
+
         completionCallback?.Invoke();
         return result;
     }
@@ -623,7 +575,7 @@ public async Task<List<float>> Embeddings(string query, Callback<List<float>> ca
         return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
     }

-    protected async Task<string> Slot(string filepath, string action)
+    private async Task<string> Slot(string filepath, string action)
     {
         SlotRequest slotRequest = new SlotRequest();
         slotRequest.id_slot = slot;
@@ -634,46 +586,26 @@ protected async Task<string> Slot(string filepath, string action)
     }

     /// <summary>
-    /// Saves the chat history and cache to the provided filename / relative path.
+    /// Saves the cache to the provided filename / relative path.
     /// </summary>
-    /// <param name="filename">filename / relative path to save the chat history</param>
+    /// <param name="filename">filename / relative path to save the cache</param>
     /// <returns></returns>
-    public virtual async Task<string> Save(string filename)
+    public virtual async Task<string> SaveCache()
     {
-        string filepath = GetJsonSavePath(filename);
-        string dirname = Path.GetDirectoryName(filepath);
-        if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
-        string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
-        File.WriteAllText(filepath, json);
-
-        string cachepath = GetCacheSavePath(filename);
         if (remote || !saveCache) return null;
-        string result = await Slot(cachepath, "save");
+        string result = await Slot(GetCacheSavePath(), "save");
         return result;
     }

     /// <summary>
-    /// Load the chat history and cache from the provided filename / relative path.
+    /// Load the cache from the provided filename / relative path.
     /// </summary>
-    /// <param name="filename">filename / relative path to load the chat history from</param>
+    /// <param name="filename">filename / relative path to load the cache from</param>
     /// <returns></returns>
-    public virtual async Task<string> Load(string filename)
+    public virtual async Task<string> LoadCache()
     {
-        string filepath = GetJsonSavePath(filename);
-        if (!File.Exists(filepath))
-        {
-            LLMUnitySetup.LogError($"File {filepath} does not exist.");
-            return null;
-        }
-        string json = File.ReadAllText(filepath);
-        List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
-        InitPrompt(true);
-        chat.AddRange(chatHistory);
-        LLMUnitySetup.Log($"Loaded {filepath}");
-
-        string cachepath = GetCacheSavePath(filename);
-        if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
-        string result = await Slot(cachepath, "restore");
+        if (remote || !saveCache || !File.Exists(GetCacheSavePath())) return null;
+        string result = await Slot(GetCacheSavePath(), "restore");
         return result;
     }

Expand Down Expand Up @@ -826,11 +758,9 @@ protected async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoi
{
result = default;
error = request.error;
if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break;
nabrown737 marked this conversation as resolved.
Show resolved Hide resolved
}
}
tryNr--;
if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr));
nabrown737 marked this conversation as resolved.
Show resolved Hide resolved
}

if (error != null) LLMUnitySetup.LogError(error);
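For orientation, a minimal usage sketch of the API as it stands after this PR. It relies only on members visible in the diff above (Chat, GetChatHistory); the surrounding MonoBehaviour and Inspector wiring are assumed for illustration:

    using UnityEngine;

    public class ChatExample : MonoBehaviour
    {
        // Assumed to be assigned in the Inspector.
        public LLMCharacter llmCharacter;

        async void Start()
        {
            // Chat() builds the prompt from the attached LLMChatHistory, appends
            // the player message, and (with addToHistory=true) writes both the
            // query and the reply back to the history component.
            string reply = await llmCharacter.Chat("Hello there!");
            Debug.Log(reply);

            // The history component itself is reachable through the new accessor.
            LLMChatHistory history = llmCharacter.GetChatHistory();
        }
    }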