Skip to content

Commit

Permalink
Add thread-safe message scheduling and related tests (#1638)
Browse files Browse the repository at this point in the history
Introduce `ScheduledMediumMessageQueue` for thread-safe scheduling of messages. Updated `Dispatcher` to use the new queue and modified the scheduling logic for improved reliability. Added extensive unit tests to ensure correctness of message scheduling and publishing behavior under various scenarios.
  • Loading branch information
amimelia authored Jan 19, 2025
1 parent 8575109 commit e3d851f
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 22 deletions.
88 changes: 88 additions & 0 deletions src/DotNetCore.CAP/Internal/ScheduledMediumMessageQueue.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using DotNetCore.CAP.Persistence;

namespace DotNetCore.CAP.Internal;

public class ScheduledMediumMessageQueue
{
private readonly SortedSet<(long, MediumMessage)> _queue = new(Comparer<(long, MediumMessage)>.Create((a, b) =>
{
int result = a.Item1.CompareTo(b.Item1);
return result == 0 ? String.Compare(a.Item2.DbId, b.Item2.DbId, StringComparison.Ordinal) : result;
}));

private readonly SemaphoreSlim _semaphore = new(0);
private readonly object _lock = new();

public void Enqueue(MediumMessage message, long sendTime)
{
lock (_lock)
{
_queue.Add((sendTime, message));
}

_semaphore.Release();
}

public int Count
{
get
{
lock (_lock)
{
return _queue.Count;
}
}
}

public IEnumerable<MediumMessage> UnorderedItems
{
get
{
lock (_lock)
{
return _queue.Select(x => x.Item2).ToList();
}
}
}

public async IAsyncEnumerable<MediumMessage> GetConsumingEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
{
while (!cancellationToken.IsCancellationRequested)
{
await _semaphore.WaitAsync(cancellationToken);

(long, MediumMessage)? nextItem = null;

lock (_lock)
{
if (_queue.Count > 0)
{
var topMessage = _queue.First();
var timeLeft = topMessage.Item1 - DateTime.Now.Ticks;
if (timeLeft < 500000) // 50ms
{
nextItem = topMessage;
_queue.Remove(topMessage);
}
}
}

if (nextItem is not null)
{
yield return nextItem.Value.Item2;
}
else
{
// Re-release the semaphore if no item is ready yet
_semaphore.Release();
await Task.Delay(50, cancellationToken);
}
}
}
}
36 changes: 15 additions & 21 deletions src/DotNetCore.CAP/Processor/IDispatcher.Default.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,19 @@ namespace DotNetCore.CAP.Processor;

public class Dispatcher : IDispatcher
{
private readonly CancellationTokenSource _delayCts = new();
private readonly ISubscribeExecutor _executor;
private readonly ILogger<Dispatcher> _logger;
private readonly CapOptions _options;
private readonly IMessageSender _sender;
private readonly IDataStorage _storage;
private readonly PriorityQueue<MediumMessage, long> _schedulerQueue;
private readonly ScheduledMediumMessageQueue _schedulerQueue = new();
private readonly bool _enableParallelExecute;
private readonly bool _enableParallelSend;
private readonly int _pChannelSize;

private CancellationTokenSource? _tasksCts;
private Channel<MediumMessage> _publishedChannel = default!;
private Channel<(MediumMessage, ConsumerExecutorDescriptor?)> _receivedChannel = default!;
private long _nextSendTime = DateTime.MaxValue.Ticks;

public Dispatcher(ILogger<Dispatcher> logger, IMessageSender sender, IOptions<CapOptions> options,
ISubscribeExecutor executor, IDataStorage storage)
Expand All @@ -41,7 +39,6 @@ public Dispatcher(ILogger<Dispatcher> logger, IMessageSender sender, IOptions<Ca
_sender = sender;
_options = options.Value;
_executor = executor;
_schedulerQueue = new PriorityQueue<MediumMessage, long>();
_storage = storage;
_enableParallelExecute = options.Value.EnableSubscriberParallelExecute;
_enableParallelSend = options.Value.EnablePublishParallelSend;
Expand All @@ -52,7 +49,6 @@ public async Task Start(CancellationToken stoppingToken)
{
stoppingToken.ThrowIfCancellationRequested();
_tasksCts = CancellationTokenSource.CreateLinkedTokenSource(stoppingToken, CancellationToken.None);
_tasksCts.Token.Register(() => _delayCts.Cancel());

_publishedChannel = Channel.CreateBounded<MediumMessage>(new BoundedChannelOptions(_pChannelSize)
{
Expand Down Expand Up @@ -88,7 +84,7 @@ await Task.WhenAll(Enumerable.Range(0, _options.SubscriberParallelExecuteThreadC
{
if (_schedulerQueue.Count == 0) return;

var messageIds = _schedulerQueue.UnorderedItems.Select(x => x.Element.DbId).ToArray();
var messageIds = _schedulerQueue.UnorderedItems.Select(x => x.DbId).ToArray();
_storage.ChangePublishStateToDelayedAsync(messageIds).GetAwaiter().GetResult();
_logger.LogDebug("Update storage to delayed success of delayed message in memory queue!");
}
Expand All @@ -102,29 +98,32 @@ await Task.WhenAll(Enumerable.Range(0, _options.SubscriberParallelExecuteThreadC
{
try
{
while (_schedulerQueue.TryPeek(out _, out _nextSendTime))
await foreach (var nextMessage in _schedulerQueue.GetConsumingEnumerable(_tasksCts.Token))
{
var delayTime = _nextSendTime - DateTime.Now.Ticks;

if (delayTime > 500000) //50ms
{
await Task.Delay(new TimeSpan(delayTime), _delayCts.Token);
}
_tasksCts.Token.ThrowIfCancellationRequested();

await _sender.SendAsync(_schedulerQueue.Dequeue()).ConfigureAwait(false);
await _sender.SendAsync(nextMessage).ConfigureAwait(false);
}

_tasksCts.Token.WaitHandle.WaitOne(100);
}
catch (OperationCanceledException)
{
//Ignore
}
catch (Exception ex)
{
_logger.LogWarning(ex,
"Scheduled message publishing failed unexpectedly, which will stop future scheduled " +
"messages from publishing. See more details here: https://github.com/dotnetcore/CAP/issues/1637. " +
"Exception: {Message}",
ex.Message);
throw;
}
}
}, _tasksCts.Token).ConfigureAwait(false);
}

public async ValueTask EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null)
public async Task EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null)
{
message.ExpiresAt = publishTime;

Expand All @@ -135,11 +134,6 @@ public async ValueTask EnqueueToScheduler(MediumMessage message, DateTime publis
await _storage.ChangePublishStateAsync(message, StatusName.Queued, transaction);

_schedulerQueue.Enqueue(message, publishTime.Ticks);

if (publishTime.Ticks < _nextSendTime)
{
_delayCts.Cancel();
}
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion src/DotNetCore.CAP/Transport/IDispatcher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ public interface IDispatcher : IProcessingServer

ValueTask EnqueueToExecute(MediumMessage message, ConsumerExecutorDescriptor? descriptor = null);

ValueTask EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null);
Task EnqueueToScheduler(MediumMessage message, DateTime publishTime, object? transaction = null);
}
179 changes: 179 additions & 0 deletions test/DotNetCore.CAP.Test/DispatcherTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using DotNetCore.CAP.Internal;
using DotNetCore.CAP.Messages;
using DotNetCore.CAP.Persistence;
using DotNetCore.CAP.Processor;
using DotNetCore.CAP.Test.Helpers;
using FluentAssertions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using NSubstitute;
using Xunit;

namespace DotNetCore.CAP.Test;

public class DispatcherTests
{
private readonly ILogger<Dispatcher> _logger;
private readonly ISubscribeExecutor _executor;
private readonly IDataStorage _storage;

public DispatcherTests()
{
_logger = Substitute.For<ILogger<Dispatcher>>();
_executor = Substitute.For<ISubscribeExecutor>();
_storage = Substitute.For<IDataStorage>();
}

[Fact]
public async Task EnqueueToPublish_ShouldInvokeSend_WhenParallelSendDisabled()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});

var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messageId = "testId";

// Act
await dispatcher.Start(cts.Token);
await dispatcher.EnqueueToPublish(CreateTestMessage(messageId));
await cts.CancelAsync();

// Assert
sender.Count.Should().Be(1);
sender.ReceivedMessages.First().DbId.Should().Be(messageId);
}

[Fact]
public async Task EnqueueToPublish_ShouldBeThreadSafe_WhenParallelSendDisabled()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});
var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messages = Enumerable.Range(1, 100)
.Select(i => CreateTestMessage(i.ToString()))
.ToArray();

// Act
await dispatcher.Start(cts.Token);

var tasks = messages
.Select(msg => Task.Run(() => dispatcher.EnqueueToPublish(msg), CancellationToken.None));
await Task.WhenAll(tasks);
await cts.CancelAsync();

// Assert
sender.Count.Should().Be(100);
var receivedMessages = sender.ReceivedMessages.Select(m => m.DbId).Order().ToList();
var expected = messages.Select(m => m.DbId).Order().ToList();
expected.Should().Equal(receivedMessages);
}

[Fact]
public async Task EnqueueToScheduler_ShouldBeThreadSafe_WhenDelayLessThenMinute()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});
var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messages = Enumerable.Range(1, 10000)
.Select(i => CreateTestMessage(i.ToString()))
.ToArray();

// Act
await dispatcher.Start(cts.Token);
var dateTime = DateTime.Now.AddSeconds(1);
await Parallel.ForEachAsync(messages, CancellationToken.None,
async (m, ct) => { await dispatcher.EnqueueToScheduler(m, dateTime); });

await Task.Delay(1500, CancellationToken.None);

await cts.CancelAsync();

// Assert
sender.Count.Should().Be(10000);

var receivedMessages = sender.ReceivedMessages.Select(m => m.DbId).Order().ToList();
var expected = messages.Select(m => m.DbId).Order().ToList();
expected.Should().Equal(receivedMessages);
}

[Fact]
public async Task EnqueueToScheduler_ShouldSendMessagesInCorrectOrder_WhenEarlierMessageIsSentLater()
{
// Arrange
var sender = new TestThreadSafeMessageSender();
var options = Options.Create(new CapOptions
{
EnableSubscriberParallelExecute = true,
EnablePublishParallelSend = false,
SubscriberParallelExecuteThreadCount = 2,
SubscriberParallelExecuteBufferFactor = 2
});
var dispatcher = new Dispatcher(_logger, sender, options, _executor, _storage);

using var cts = new CancellationTokenSource();
var messages = Enumerable.Range(1, 3)
.Select(i => CreateTestMessage(i.ToString()))
.ToArray();

// Act
await dispatcher.Start(cts.Token);
var dateTime = DateTime.Now;

await dispatcher.EnqueueToScheduler(messages[0], dateTime.AddSeconds(1));
await dispatcher.EnqueueToScheduler(messages[1], dateTime.AddMilliseconds(200));
await dispatcher.EnqueueToScheduler(messages[2], dateTime.AddMilliseconds(100));

await Task.Delay(1200, CancellationToken.None);
await cts.CancelAsync();

// Assert
sender.ReceivedMessages.Select(m => m.DbId).Should().Equal(["3", "2", "1"]);
}


private MediumMessage CreateTestMessage(string id = "1")
{
return new MediumMessage()
{
DbId = id,
Origin = new Message(
headers: new Dictionary<string, string>()
{
{ "cap-msg-id", id }
},
value: new MessageValue("[email protected]", "User"))
};
}
}
2 changes: 2 additions & 0 deletions test/DotNetCore.CAP.Test/DotNetCore.CAP.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="FluentAssertions" Version="7.0.0" />
<PackageReference Include="NSubstitute" Version="5.3.0" />
</ItemGroup>

<ItemGroup>
Expand Down
Loading

0 comments on commit e3d851f

Please sign in to comment.