Skip to content

Commit

Permalink
Add project OpenAI.ChatGpt.EntityFrameworkCore; Add PersistentChatMes…
Browse files Browse the repository at this point in the history
…sage class
  • Loading branch information
rodion-m committed Mar 21, 2023
1 parent bf52335 commit d1efb43
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 50 deletions.
14 changes: 14 additions & 0 deletions OpenAI.ChatGpt.EntityFrameworkCore/ChatGptDbContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Microsoft.EntityFrameworkCore;
using OpenAI.ChatGpt.Models;

namespace OpenAI.ChatGpt.EntityFrameworkCore;

public class ChatGptDbContext : DbContext
{
public DbSet<Topic> Topics => Set<Topic>();
public DbSet<PersistentChatMessage> Messages => Set<PersistentChatMessage>();

public ChatGptDbContext(DbContextOptions<ChatGptDbContext> options) : base(options)
{
}
}
70 changes: 70 additions & 0 deletions OpenAI.ChatGpt.EntityFrameworkCore/EfMessageStore.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using Microsoft.EntityFrameworkCore;
using OpenAI.ChatGpt.Models;

namespace OpenAI.ChatGpt.EntityFrameworkCore;

public class EfMessageStore : IMessageStore
{
private readonly ChatGptDbContext _dbContext;

public EfMessageStore(ChatGptDbContext dbContext)
{
_dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
}

public async Task<IEnumerable<Topic>> GetTopics(string userId, CancellationToken cancellationToken)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
return await _dbContext.Topics.ToListAsync(cancellationToken: cancellationToken);
}

public Task<Topic> GetTopic(string userId, Guid topicId, CancellationToken cancellationToken)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
return _dbContext.Topics.FirstAsync(
it => it.Id == topicId && it.UserId == userId,
cancellationToken: cancellationToken);
}

public async Task AddTopic(Topic topic, CancellationToken cancellationToken)
{
if (topic == null) throw new ArgumentNullException(nameof(topic));
await _dbContext.Topics.AddAsync(topic, cancellationToken);
await _dbContext.SaveChangesAsync(cancellationToken);
}

public async Task SaveMessages(
string userId,
Guid topicId,
IEnumerable<PersistentChatMessage> messages,
CancellationToken cancellationToken)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
if (messages == null) throw new ArgumentNullException(nameof(messages));
await _dbContext.AddRangeAsync(messages, cancellationToken);
await _dbContext.SaveChangesAsync(cancellationToken);
}

public async Task<IEnumerable<PersistentChatMessage>> GetMessages(
string userId, Guid topicId, CancellationToken cancellationToken)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
return await _dbContext.Messages
.Where(it => it.TopicId == topicId && it.UserId == userId)
.ToListAsync(cancellationToken: cancellationToken);
}

public Task<Topic?> GetLastTopicOrNull(string userId, CancellationToken cancellationToken)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
return _dbContext.Topics
.Where(it => it.UserId == userId)
.OrderByDescending(it => it.CreatedAt)
.FirstOrDefaultAsync(cancellationToken: cancellationToken);
}

public Task EnsureStorageCreated()
{
return _dbContext.Database.EnsureCreatedAsync();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<TargetFrameworks>net6.0;net7.0</TargetFrameworks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.EntityFrameworkCore" Version="7.0.4" />
<PackageReference Include="Microsoft.EntityFrameworkCore.Design" Version="7.0.4">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\OpenAI.ChatGpt\OpenAI.ChatGpt.csproj" />
</ItemGroup>

</Project>
8 changes: 4 additions & 4 deletions OpenAI.ChatGpt/Chat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Task<string> GetNextMessageResponse(

private async Task<string> GetNextMessageResponse(
UserOrSystemMessage message,
CancellationToken cancellationToken)
CancellationToken cancellationToken)
{
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_cts.Token.Register(() => IsWriting = false);
Expand Down Expand Up @@ -109,10 +109,10 @@ await _messageStore.SaveMessages(
_isNew = false;
}

private Task<IEnumerable<ChatCompletionMessage>> LoadHistory(CancellationToken cancellationToken)
private async Task<IEnumerable<ChatCompletionMessage>> LoadHistory(CancellationToken cancellationToken)
{
if (_isNew) return Task.FromResult(Enumerable.Empty<ChatCompletionMessage>());
return _messageStore.GetMessages(UserId, ChatId, cancellationToken);
if (_isNew) return Enumerable.Empty<ChatCompletionMessage>();
return await _messageStore.GetMessages(UserId, ChatId, cancellationToken);
}

public void Stop()
Expand Down
17 changes: 12 additions & 5 deletions OpenAI.ChatGpt/ChatGPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,29 @@ public async Task<Chat> StartNewTopic(
string? name = null,
ChatCompletionsConfig? config = null,
UserOrSystemMessage? initialDialog = null,
DateTimeOffset? createdAt = null,
CancellationToken cancellationToken = default)
{
createdAt ??= DateTimeOffset.Now;
config = ChatCompletionsConfig.CombineOrDefault(_config, config);
var topic = new Topic(_messageStore.NewTopicId(), _userId, name, createdAt.Value, config);
await _messageStore.AddTopic(_userId, topic, cancellationToken);
var topic = new Topic(_messageStore.NewTopicId(), _userId, name, _messageStore.Now(), config);
await _messageStore.AddTopic(topic, cancellationToken);
if (initialDialog is not null)
{
await _messageStore.SaveMessages(_userId, topic.Id, initialDialog.GetMessages(), cancellationToken);
var messages = ConvertToPersistentMessages(initialDialog, topic);
await _messageStore.SaveMessages(_userId, topic.Id, messages, cancellationToken);
}

_currentChat = CreateChat(topic, true);
return _currentChat;
}

private IEnumerable<PersistentChatMessage> ConvertToPersistentMessages(ChatCompletionMessage dialog, Topic topic)
{
return dialog.GetMessages()
.Select(m => new PersistentChatMessage(
_messageStore.NewMessageId(), _userId, topic.Id, _messageStore.Now(), m)
);
}

public async Task<Chat> SetTopic(Guid topicId, CancellationToken cancellationToken = default)
{
var topic = await _messageStore.GetTopic(_userId, topicId, cancellationToken);
Expand Down
30 changes: 15 additions & 15 deletions OpenAI.ChatGpt/ChatGPTFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Net.Http.Headers;
using Microsoft.Extensions.Options;
using Microsoft.Extensions.Options;
using OpenAI.ChatGpt.Models;

namespace OpenAI.ChatGpt;
Expand Down Expand Up @@ -34,7 +33,7 @@ public ChatGPTFactory(
}

public ChatGPTFactory(
IOptions<ChatGPTCredentials> credentials,
IOptions<ChatGptCredentials> credentials,
IOptions<ChatCompletionsConfig> config,
IMessageStore messageStore)
{
Expand All @@ -61,10 +60,16 @@ public static ChatGPTFactory CreateInMemory(string apiKey, ChatCompletionsConfig
return new ChatGPTFactory(apiKey, new InMemoryMessageStore(), config);
}

public ChatGPT Create(string userId, ChatCompletionsConfig? config = null)
public async Task<ChatGPT> Create(
string userId,
ChatCompletionsConfig? config = null,
bool ensureStorageCreated = true)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
// one of config or _config must be not null:
if (ensureStorageCreated)
{
await _messageStore.EnsureStorageCreated();
}
return new ChatGPT(
_client,
userId,
Expand All @@ -73,8 +78,12 @@ public ChatGPT Create(string userId, ChatCompletionsConfig? config = null)
);
}

public ChatGPT Create(ChatCompletionsConfig? config = null)
public async Task<ChatGPT> Create(ChatCompletionsConfig? config = null, bool ensureStorageCreated = true)
{
if (ensureStorageCreated)
{
await _messageStore.EnsureStorageCreated();
}
return new ChatGPT(
_client,
_messageStore,
Expand All @@ -86,13 +95,4 @@ public void Dispose()
{
_client.Dispose();
}
}

internal class ChatGPTCredentials
{
public string ApiKey { get; set; }
public string? ApiHost { get; set; }

public AuthenticationHeaderValue AuthHeader()
=> new("Bearer", ApiKey);
}
19 changes: 12 additions & 7 deletions OpenAI.ChatGpt/IMessageStore.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,25 @@ public interface IMessageStore
{
Task<IEnumerable<Topic>> GetTopics(string userId, CancellationToken cancellationToken);
Task<Topic> GetTopic(string userId, Guid topicId, CancellationToken cancellationToken);
Task AddTopic(string userId, Topic topic, CancellationToken cancellationToken);
Task AddTopic(Topic topic, CancellationToken cancellationToken);

Task SaveMessages(
string userId,
Guid topicId,
IEnumerable<ChatCompletionMessage> messages,
IEnumerable<PersistentChatMessage> messages,
CancellationToken cancellationToken
);

Task<IEnumerable<ChatCompletionMessage>> GetMessages(
Task<IEnumerable<PersistentChatMessage>> GetMessages(
string userId,
Guid topicId,
CancellationToken cancellationToken
);

Task<Topic?> GetLastTopicOrNull(string userId, CancellationToken cancellationToken);

Task SaveMessages(string userId,
Task SaveMessages(
string userId,
Guid topicId,
UserOrSystemMessage message,
string assistantMessage,
Expand All @@ -33,13 +34,17 @@ Task SaveMessages(string userId,
if (userId == null) throw new ArgumentNullException(nameof(userId));
if (message == null) throw new ArgumentNullException(nameof(message));
if (assistantMessage == null) throw new ArgumentNullException(nameof(assistantMessage));
var enumerable = new ChatCompletionMessage[]
var now = Now();
var enumerable = new PersistentChatMessage[]
{
message,
new AssistantMessage(assistantMessage)
new(NewMessageId(), userId, topicId, now, message),
new(NewMessageId(), userId, topicId, now, ChatCompletionRoles.Assistant, assistantMessage),
};
return SaveMessages(userId, topicId, enumerable, cancellationToken);
}

Guid NewTopicId() => Guid.NewGuid();
Guid NewMessageId() => Guid.NewGuid();
DateTimeOffset Now() => DateTimeOffset.Now;
Task EnsureStorageCreated();
}
29 changes: 15 additions & 14 deletions OpenAI.ChatGpt/InMemoryMessageStore.cs
Original file line number Diff line number Diff line change
@@ -1,47 +1,46 @@
using System.Collections.Concurrent;
using OpenAI.ChatGpt.Models;
using OpenAI.Models.ChatCompletion;

namespace OpenAI.ChatGpt;

internal class InMemoryMessageStore : IMessageStore
{
private readonly ConcurrentDictionary<string, Dictionary<Guid, Topic>> _users = new();
private readonly ConcurrentDictionary<string, Dictionary<Guid, List<ChatCompletionMessage>>>
private readonly ConcurrentDictionary<string, Dictionary<Guid, List<PersistentChatMessage>>>
_messages = new();

public Task SaveMessages(
string userId,
Guid topicId,
IEnumerable<ChatCompletionMessage> messages,
IEnumerable<PersistentChatMessage> messages,
CancellationToken cancellationToken)
{
if (!_messages.TryGetValue(userId, out var userMessages))
{
userMessages = new Dictionary<Guid, List<ChatCompletionMessage>>();
userMessages = new Dictionary<Guid, List<PersistentChatMessage>>();
_messages.TryAdd(userId, userMessages);
}

if (!userMessages.TryGetValue(topicId, out var chatMessages))
{
chatMessages = new List<ChatCompletionMessage>();
chatMessages = new List<PersistentChatMessage>();
userMessages.TryAdd(topicId, chatMessages);
}

chatMessages.AddRange(messages);
return Task.CompletedTask;
}

public Task<IEnumerable<ChatCompletionMessage>> GetMessages(string userId, Guid topicId, CancellationToken cancellationToken)
public Task<IEnumerable<PersistentChatMessage>> GetMessages(string userId, Guid topicId, CancellationToken cancellationToken)
{
if (!_messages.TryGetValue(userId, out var userMessages))
{
return Task.FromResult(Enumerable.Empty<ChatCompletionMessage>());
return Task.FromResult(Enumerable.Empty<PersistentChatMessage>());
}

if (!userMessages.TryGetValue(topicId, out var chatMessages))
{
return Task.FromResult(Enumerable.Empty<ChatCompletionMessage>());
return Task.FromResult(Enumerable.Empty<PersistentChatMessage>());
}

return Task.FromResult(chatMessages.AsEnumerable());
Expand All @@ -57,7 +56,7 @@ public Task<IEnumerable<ChatCompletionMessage>> GetMessages(string userId, Guid
var lastTopic = userChats.Values.MaxBy(x => x.CreatedAt);
return Task.FromResult(lastTopic);
}

public Task<IEnumerable<Topic>> GetTopics(string userId, CancellationToken cancellationToken)
{
if (!_users.TryGetValue(userId, out var topics))
Expand All @@ -83,15 +82,17 @@ public Task<Topic> GetTopic(string userId, Guid topicId, CancellationToken cance
return Task.FromResult(topic);
}

public Task AddTopic(string userId, Topic topic, CancellationToken cancellationToken)
public Task AddTopic(Topic topic, CancellationToken cancellationToken)
{
if (!_users.TryGetValue(userId, out var userChats))
if (!_users.TryGetValue(topic.UserId, out var userTopics))
{
userChats = new Dictionary<Guid, Topic>();
_users.TryAdd(userId, userChats);
userTopics = new Dictionary<Guid, Topic>();
_users.TryAdd(topic.UserId, userTopics);
}

userChats.Add(topic.Id, topic);
userTopics.Add(topic.Id, topic);
return Task.CompletedTask;
}

public Task EnsureStorageCreated() => Task.CompletedTask;
}
7 changes: 7 additions & 0 deletions OpenAI.ChatGpt/Models/ChatGptCredentials.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
namespace OpenAI.ChatGpt.Models;

public class ChatGptCredentials
{
public string ApiKey { get; set; }
public string? ApiHost { get; set; }
}
Loading

0 comments on commit d1efb43

Please sign in to comment.