Skip to content

Commit

Permalink
Add IInternalClock abstractions; Add ServiceCollectionExtensions
Browse files Browse the repository at this point in the history
  • Loading branch information
rodion-m committed Apr 15, 2023
1 parent 553b9e9 commit 9df3c1d
Show file tree
Hide file tree
Showing 20 changed files with 338 additions and 37 deletions.
1 change: 0 additions & 1 deletion OpenAI.ChatGpt.EntityFrameworkCore/EfChatHistoryStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ namespace OpenAI.ChatGpt.EntityFrameworkCore;
public class EfChatHistoryStorage : IChatHistoryStorage
{
private readonly ChatGptDbContext _dbContext;


public EfChatHistoryStorage(ChatGptDbContext dbContext)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using OpenAI.ChatGpt.Interfaces;
using static OpenAI.ChatGpt.Extensions.ServiceCollectionExtensions;

namespace OpenAI.ChatGpt.EntityFrameworkCore.Extensions;

public static class ServiceCollectionExtensions
{
/// <summary>
/// Adds the <see cref="IChatHistoryStorage"/> implementation using Entity Framework Core.
/// </summary>
public static IServiceCollection AddChatGptEntityFrameworkIntegration(
this IServiceCollection services,
Action<DbContextOptionsBuilder> optionsAction,
string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault,
string completionsConfigSectionPath = CompletionsConfigSectionPathDefault)
{
ArgumentNullException.ThrowIfNull(services);
ArgumentNullException.ThrowIfNull(optionsAction);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(credentialsConfigSectionPath));
}
if (string.IsNullOrWhiteSpace(completionsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(completionsConfigSectionPath));
}

services.AddChatGptIntegrationCore(credentialsConfigSectionPath, completionsConfigSectionPath);
services.AddDbContext<ChatGptDbContext>(optionsAction);
services.AddScoped<IChatHistoryStorage, EfChatHistoryStorage>();
return services;
}
}
12 changes: 8 additions & 4 deletions OpenAI.ChatGpt/Chat.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Runtime.CompilerServices;
using System.Text;
using OpenAI.ChatGpt.Interfaces;
using OpenAI.ChatGpt.Internal;
using OpenAI.ChatGpt.Models;
using OpenAI.Models.ChatCompletion;

Expand All @@ -19,21 +20,24 @@ public class Chat : IDisposable
public bool IsCancelled => _cts?.IsCancellationRequested ?? false;

private readonly IChatHistoryStorage _chatHistoryStorage;
private readonly IInternalClock _clock;
private readonly OpenAiClient _client;
private bool _isNew;
private CancellationTokenSource? _cts;

internal Chat(
IChatHistoryStorage chatHistoryStorage,
IInternalClock clock,
OpenAiClient client,
string userId,
Topic topic,
bool isNew)
{
UserId = userId ?? throw new ArgumentNullException(nameof(userId));
Topic = topic ?? throw new ArgumentNullException(nameof(topic));
_chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage));
_clock = clock ?? throw new ArgumentNullException(nameof(clock));
_client = client ?? throw new ArgumentNullException(nameof(client));
UserId = userId ?? throw new ArgumentNullException(nameof(userId));
Topic = topic ?? throw new ArgumentNullException(nameof(topic));
_isNew = isNew;
}

Expand Down Expand Up @@ -65,7 +69,7 @@ private async Task<string> GetNextMessageResponse(
);

await _chatHistoryStorage.SaveMessages(
UserId, ChatId, message, response, _cts.Token);
UserId, ChatId, message, response, _clock.GetCurrentTime(), _cts.Token);
IsWriting = false;
_isNew = false;

Expand Down Expand Up @@ -105,7 +109,7 @@ private async IAsyncEnumerable<string> StreamNextMessageResponse(
}

await _chatHistoryStorage.SaveMessages(
UserId, ChatId, message, sb.ToString(), _cts.Token);
UserId, ChatId, message, sb.ToString(), _clock.GetCurrentTime(), _cts.Token);
IsWriting = false;
_isNew = false;
}
Expand Down
24 changes: 16 additions & 8 deletions OpenAI.ChatGpt/ChatGPT.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using OpenAI.ChatGpt.Interfaces;
using OpenAI.ChatGpt.Internal;
using OpenAI.ChatGpt.Models;
using OpenAI.Models.ChatCompletion;

Expand All @@ -10,6 +11,7 @@ public class ChatGPT : IDisposable
{
private readonly string _userId;
private readonly IChatHistoryStorage _chatHistoryStorage;
private readonly IInternalClock _clock;
private readonly ChatCompletionsConfig? _config;
private readonly OpenAiClient _client;
private Chat? _currentChat;
Expand All @@ -19,13 +21,15 @@ public class ChatGPT : IDisposable
/// </summary>
public ChatGPT(
OpenAiClient client,
string userId,
IChatHistoryStorage chatHistoryStorage,
IChatHistoryStorage chatHistoryStorage,
IInternalClock clock,
string userId,
ChatCompletionsConfig? config)
{
_client = client ?? throw new ArgumentNullException(nameof(client));
_userId = userId ?? throw new ArgumentNullException(nameof(userId));
_chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage));
_clock = clock ?? throw new ArgumentNullException(nameof(clock));
_config = config;
}

Expand All @@ -34,11 +38,13 @@ public ChatGPT(
/// </summary>
public ChatGPT(
OpenAiClient client,
IChatHistoryStorage chatHistoryStorage,
IChatHistoryStorage chatHistoryStorage,
IInternalClock clock,
ChatCompletionsConfig? config)
{
_client = client ?? throw new ArgumentNullException(nameof(client));
_chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage));
_clock = clock ?? throw new ArgumentNullException(nameof(clock));
_userId = Guid.Empty.ToString();
_config = config;
}
Expand All @@ -49,11 +55,12 @@ public ChatGPT(
public static Task<Chat> CreateInMemoryChat(
string apiKey,
ChatCompletionsConfig? config = null,
UserOrSystemMessage? initialDialog = null)
UserOrSystemMessage? initialDialog = null,
IInternalClock? clock = null)
{
if (apiKey == null) throw new ArgumentNullException(nameof(apiKey));
var client = new OpenAiClient(apiKey);
var chatGpt = new ChatGPT(client, new InMemoryChatHistoryStorage(), config);
var chatGpt = new ChatGPT(client, new InMemoryChatHistoryStorage(), clock ?? new InternalClockUtc(), config);
return chatGpt.StartNewTopic(initialDialog: initialDialog);
}

Expand Down Expand Up @@ -82,8 +89,9 @@ public async Task<Chat> StartNewTopic(
CancellationToken cancellationToken = default)
{
config = ChatCompletionsConfig.CombineOrDefault(_config, config);
var topic = new Topic(_chatHistoryStorage.NewTopicId(), _userId, name, _chatHistoryStorage.Now(), config);
var topic = new Topic(_chatHistoryStorage.NewTopicId(), _userId, name, _clock.GetCurrentTime(), config);
await _chatHistoryStorage.AddTopic(topic, cancellationToken);
initialDialog ??= config.GetInitialDialogOrNull();
if (initialDialog is not null)
{
var messages = ConvertToPersistentMessages(initialDialog, topic);
Expand All @@ -98,7 +106,7 @@ private IEnumerable<PersistentChatMessage> ConvertToPersistentMessages(ChatCompl
{
return dialog.GetMessages()
.Select(m => new PersistentChatMessage(
_chatHistoryStorage.NewMessageId(), _userId, topic.Id, _chatHistoryStorage.Now(), m)
_chatHistoryStorage.NewMessageId(), _userId, topic.Id, _clock.GetCurrentTime(), m)
);
}

Expand All @@ -122,7 +130,7 @@ private Task<Chat> SetTopic(Topic topic, CancellationToken cancellationToken = d
private Chat CreateChat(Topic topic, bool isNew)
{
if (topic == null) throw new ArgumentNullException(nameof(topic));
return new Chat(_chatHistoryStorage, _client, _userId, topic, isNew);
return new Chat(_chatHistoryStorage, _clock, _client, _userId, topic, isNew);
}

public async Task<IReadOnlyList<Topic>> GetTopics(CancellationToken cancellationToken = default)
Expand Down
42 changes: 33 additions & 9 deletions OpenAI.ChatGpt/ChatGPTFactory.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.Extensions.Options;
using OpenAI.ChatGpt.Interfaces;
using OpenAI.ChatGpt.Internal;
using OpenAI.ChatGpt.Models;

namespace OpenAI.ChatGpt;
Expand All @@ -21,44 +22,64 @@ public class ChatGPTFactory : IDisposable
private readonly OpenAiClient _client;
private readonly ChatCompletionsConfig _config;
private readonly IChatHistoryStorage _chatHistoryStorage;
private readonly IInternalClock _clock;
private bool _ensureStorageCreatedCalled;

public ChatGPTFactory(
HttpClient httpClient,
IHttpClientFactory httpClientFactory,
IOptions<ChatGptCredentials> credentials,
IOptions<ChatCompletionsConfig> config,
IChatHistoryStorage chatHistoryStorage)
IChatHistoryStorage chatHistoryStorage,
IInternalClock clock)
{
if (httpClient == null) throw new ArgumentNullException(nameof(httpClient));
if (httpClientFactory == null) throw new ArgumentNullException(nameof(httpClientFactory));
if (credentials?.Value == null) throw new ArgumentNullException(nameof(credentials));
_config = config?.Value ?? throw new ArgumentNullException(nameof(config));
_chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage));
_client = new OpenAiClient(httpClient);
_clock = clock ?? throw new ArgumentNullException(nameof(clock));
_client = CreateOpenAiClient(httpClientFactory, credentials);
}

public ChatGPTFactory(
IOptions<ChatGptCredentials> credentials,
IOptions<ChatCompletionsConfig> config,
IChatHistoryStorage chatHistoryStorage)
IChatHistoryStorage chatHistoryStorage,
IInternalClock clock)
{
if (credentials?.Value == null) throw new ArgumentNullException(nameof(credentials));
_config = config?.Value ?? throw new ArgumentNullException(nameof(config));
_chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage));
_clock = clock ?? throw new ArgumentNullException(nameof(clock));
_client = new OpenAiClient(credentials.Value.ApiKey);
}

public ChatGPTFactory(
string apiKey,
IChatHistoryStorage chatHistoryStorage,
IChatHistoryStorage chatHistoryStorage,
IInternalClock? clock = null,
ChatCompletionsConfig? config = null)
{
if (apiKey == null) throw new ArgumentNullException(nameof(apiKey));
_client = new OpenAiClient(apiKey);
_config = config ?? ChatCompletionsConfig.Default;
_chatHistoryStorage = chatHistoryStorage ?? throw new ArgumentNullException(nameof(chatHistoryStorage));
_clock = clock ?? new InternalClockUtc();
}

private OpenAiClient CreateOpenAiClient(
IHttpClientFactory httpClientFactory,
IOptions<ChatGptCredentials> credentials)
{
var httpClient = httpClientFactory.CreateClient(nameof(ChatGPTFactory));
httpClient.DefaultRequestHeaders.Authorization = credentials.Value.GetAuthHeader();
httpClient.BaseAddress = new Uri(credentials.Value.ApiHost);
return new OpenAiClient(httpClient);
}

public static ChatGPTFactory CreateInMemory(string apiKey, ChatCompletionsConfig? config = null)
{
if (apiKey == null) throw new ArgumentNullException(nameof(apiKey));
return new ChatGPTFactory(apiKey, new InMemoryChatHistoryStorage(), config);
return new ChatGPTFactory(apiKey, new InMemoryChatHistoryStorage(), new InternalClockUtc(), config);
}

public async Task<ChatGPT> Create(
Expand All @@ -68,14 +89,16 @@ public async Task<ChatGPT> Create(
CancellationToken cancellationToken = default)
{
if (userId == null) throw new ArgumentNullException(nameof(userId));
if (ensureStorageCreated)
if (ensureStorageCreated && !_ensureStorageCreatedCalled)
{
await _chatHistoryStorage.EnsureStorageCreated(cancellationToken);
_ensureStorageCreatedCalled = true;
}
return new ChatGPT(
_client,
userId,
_chatHistoryStorage,
_clock,
userId,
ChatCompletionsConfig.Combine(_config, config)
);
}
Expand All @@ -92,6 +115,7 @@ public async Task<ChatGPT> Create(
return new ChatGPT(
_client,
_chatHistoryStorage,
_clock,
ChatCompletionsConfig.Combine(_config, config)
);
}
Expand Down
65 changes: 65 additions & 0 deletions OpenAI.ChatGpt/Extensions/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using Microsoft.Extensions.DependencyInjection;
using OpenAI.ChatGpt.Interfaces;
using OpenAI.ChatGpt.Internal;
using OpenAI.ChatGpt.Models;

namespace OpenAI.ChatGpt.Extensions;

public static class ServiceCollectionExtensions
{
public const string CredentialsConfigSectionPathDefault = "ChatGptCredentials";
public const string CompletionsConfigSectionPathDefault = "ChatCompletionsConfig";

public static IServiceCollection AddChatGptInMemoryIntegration(
this IServiceCollection services,
string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault,
string completionsConfigSectionPath = CompletionsConfigSectionPathDefault)
{
ArgumentNullException.ThrowIfNull(services);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(credentialsConfigSectionPath));
}
if (string.IsNullOrWhiteSpace(completionsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(completionsConfigSectionPath));
}
services.AddChatGptIntegrationCore(credentialsConfigSectionPath, completionsConfigSectionPath);
services.AddSingleton<IChatHistoryStorage, InMemoryChatHistoryStorage>();
return services;
}
public static IServiceCollection AddChatGptIntegrationCore(
this IServiceCollection services,
string credentialsConfigSectionPath = CredentialsConfigSectionPathDefault,
string completionsConfigSectionPath = CompletionsConfigSectionPathDefault)
{
ArgumentNullException.ThrowIfNull(services);
if (string.IsNullOrWhiteSpace(credentialsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(credentialsConfigSectionPath));
}
if (string.IsNullOrWhiteSpace(completionsConfigSectionPath))
{
throw new ArgumentException("Value cannot be null or whitespace.",
nameof(completionsConfigSectionPath));
}

services.AddOptions<ChatGptCredentials>()
.BindConfiguration(credentialsConfigSectionPath)
.ValidateDataAnnotations();
services.AddOptions<ChatCompletionsConfig>()
.BindConfiguration(completionsConfigSectionPath)
.Configure(_ => { })
.ValidateDataAnnotations();

services.AddHttpClient();

services.AddSingleton<IInternalClock, InternalClockUtc>();
services.AddSingleton<ChatGPTFactory>();

return services;
}
}
6 changes: 6 additions & 0 deletions OpenAI.ChatGpt/InMemoryChatHistoryStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@

namespace OpenAI.ChatGpt;

/// <summary>
/// Represents an in-memory storage for managing messages and topics.
/// </summary>
/// <remarks>
/// Thread safe for different users. Not thread safe for the same user.
/// </remarks>
public class InMemoryChatHistoryStorage : IChatHistoryStorage
{
private readonly ConcurrentDictionary<string, Dictionary<Guid, Topic>> _users = new();
Expand Down
Loading

0 comments on commit 9df3c1d

Please sign in to comment.