Skip to content

Commit

Permalink
Add message size checks (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
slang25 authored Nov 2, 2024
1 parent 8cc3666 commit 0cd099b
Show file tree
Hide file tree
Showing 6 changed files with 784 additions and 273 deletions.
48 changes: 48 additions & 0 deletions src/LocalSqsSnsMessaging/SnsClient/InMemorySnsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ public sealed partial class InMemorySnsClient : IAmazonSimpleNotificationService
{
private readonly InMemoryAwsBus _bus;
private readonly Lazy<ISimpleNotificationServicePaginatorFactory> _paginators;

private const int MaxMessageSize = 262144;

internal InMemorySnsClient(InMemoryAwsBus bus)
{
Expand Down Expand Up @@ -390,17 +392,63 @@ public Task<PublishResponse> PublishAsync(string topicArn, string message, strin
public Task<PublishResponse> PublishAsync(PublishRequest request, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(request);

var messageSize = CalculateMessageSize(request.Message, request.MessageAttributes);
if (messageSize > MaxMessageSize)
{
throw new InvalidParameterException($"Message size has exceeded the limit of {MaxMessageSize} bytes.");
}

var topic = GetTopicByArn(request.TopicArn);
var result = topic.PublishAction.Execute(request);

return Task.FromResult(result);
}

private static int CalculateMessageSize(string message, Dictionary<string, MessageAttributeValue>? messageAttributes)
{
var totalSize = 0;

// Add message body size
totalSize += Encoding.UTF8.GetByteCount(message);

// Add message attributes size
if (messageAttributes != null)
{
foreach (var (key, attributeValue) in messageAttributes)
{
// Add attribute name size
totalSize += Encoding.UTF8.GetByteCount(key);

// Add data type size (including any custom type prefix)
totalSize += Encoding.UTF8.GetByteCount(attributeValue.DataType);

// Add value size based on the type
if (attributeValue.BinaryValue != null)
{
totalSize += (int)attributeValue.BinaryValue.Length;
}
else if (attributeValue.StringValue != null)
{
totalSize += Encoding.UTF8.GetByteCount(attributeValue.StringValue);
}
}
}

return totalSize;
}


public Task<PublishBatchResponse> PublishBatchAsync(PublishBatchRequest request,
CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(request);
var totalSize = request.PublishBatchRequestEntries
.Sum(requestEntry => CalculateMessageSize(requestEntry.Message, requestEntry.MessageAttributes));
if (totalSize > MaxMessageSize)
{
throw new InvalidParameterException($"Message size has exceeded the limit of {MaxMessageSize} bytes.");
}

var topic = GetTopicByArn(request.TopicArn);
var result = topic.PublishAction.ExecuteBatch(request);
Expand Down
53 changes: 51 additions & 2 deletions src/LocalSqsSnsMessaging/SqsClient/InMemorySqsClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public sealed partial class InMemorySqsClient : IAmazonSQS
private readonly InMemoryAwsBus _bus;
private readonly Lazy<ISQSPaginatorFactory> _paginators;

private const int MaxMessageSize = 262144;
private static readonly string[] InternalAttributes = [
QueueAttributeName.ApproximateNumberOfMessages,
QueueAttributeName.ApproximateNumberOfMessagesDelayed,
Expand Down Expand Up @@ -470,6 +471,13 @@ public Task<SendMessageResponse> SendMessageAsync(SendMessageRequest request,
}

var message = CreateMessage(request.MessageBody, request.MessageAttributes, request.MessageSystemAttributes);
var totalSize = CalculateMessageSize(message.Body, message.MessageAttributes);

if (totalSize > MaxMessageSize)
{
throw new AmazonSQSException(
$"Message size ({totalSize} bytes) exceeds the maximum allowed size ({MaxMessageSize} bytes)");
}

if (queue.IsFifo)
{
Expand Down Expand Up @@ -523,6 +531,39 @@ public Task<SendMessageResponse> SendMessageAsync(SendMessageRequest request,
}.SetCommonProperties());
}

private static int CalculateMessageSize(string messageBody, Dictionary<string, MessageAttributeValue>? messageAttributes)
{
var totalSize = 0;

// Add message body size
totalSize += Encoding.UTF8.GetByteCount(messageBody);

// Add message attributes size
if (messageAttributes != null)
{
foreach (var (key, attributeValue) in messageAttributes)
{
// Add attribute name size
totalSize += Encoding.UTF8.GetByteCount(key);

// Add data type size (including any custom type prefix)
totalSize += Encoding.UTF8.GetByteCount(attributeValue.DataType);

// Add value size based on the type
if (attributeValue.BinaryValue != null)
{
totalSize += (int)attributeValue.BinaryValue.Length;
}
else if (attributeValue.StringValue != null)
{
totalSize += Encoding.UTF8.GetByteCount(attributeValue.StringValue);
}
}
}

return totalSize;
}

private static void EnqueueFifoMessage(SqsQueueResource queue, string messageGroupId, Message message)
{
queue.MessageGroups.AddOrUpdate(messageGroupId,
Expand Down Expand Up @@ -1075,10 +1116,18 @@ public Task<SendMessageBatchResponse> SendMessageBatchAsync(SendMessageBatchRequ
Failed = []
};

var totalSize = request.Entries.Sum(e => CalculateMessageSize(e.MessageBody, e.MessageAttributes));

if (totalSize > MaxMessageSize)
{
throw new BatchRequestTooLongException(
$"Batch size ({totalSize} bytes) exceeds the maximum allowed size ({MaxMessageSize} bytes)");
}

foreach (var entry in request.Entries)
{
var message = CreateMessage(entry.MessageBody, entry.MessageAttributes, entry.MessageSystemAttributes);

if (entry.DelaySeconds > 0)
{
message.Attributes["DelaySeconds"] = entry.DelaySeconds.ToString(NumberFormatInfo.InvariantInfo);
Expand All @@ -1096,7 +1145,7 @@ public Task<SendMessageBatchResponse> SendMessageBatchAsync(SendMessageBatchRequ
MD5OfMessageBody = message.MD5OfBody
});
}

return Task.FromResult(response.SetCommonProperties());
}

Expand Down
Loading

0 comments on commit 0cd099b

Please sign in to comment.