Skip to content

Commit

Permalink
SQS: Distributed Tracing support (#2591)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Alex Hemsath <[email protected]>
  • Loading branch information
tippmar-nr and nr-ahemsath authored Jul 2, 2024
1 parent 995b377 commit 4a6c869
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
// SPDX-License-Identifier: Apache-2.0

using System;
using System.Collections;
using System.Collections.Generic;
using NewRelic.Agent.Api;
using NewRelic.Agent.Api.Experimental;
using NewRelic.Agent.Extensions.Providers.Wrapper;
using NewRelic.Reflection;

namespace NewRelic.Agent.Extensions.AwsSdk
{
public static class SqsHelper
{
private static Func<object, IDictionary> _getMessageAttributes;
private static Func<object> _messageAttributeValueTypeFactory;

public const string VendorName = "SQS";

private class SqsAttributes
Expand Down Expand Up @@ -48,25 +54,57 @@ public SqsAttributes(string url)
public static ISegment GenerateSegment(ITransaction transaction, MethodCall methodCall, string url, MessageBrokerAction action)
{
var attr = new SqsAttributes(url);
return transaction.StartMessageBrokerSegment(methodCall, MessageBrokerDestinationType.Queue, action, VendorName, attr.QueueName);
var segment = transaction.StartMessageBrokerSegment(methodCall, MessageBrokerDestinationType.Queue, action, VendorName, attr.QueueName);
segment.GetExperimentalApi().MakeLeaf();

return segment;
}

public static void InsertDistributedTraceHeaders(ITransaction transaction, dynamic webRequest)
public static void InsertDistributedTraceHeaders(ITransaction transaction, object sendMessageRequest)
{
var setHeaders = new Action<dynamic, string, string>((wr, key, value) =>
var headersInserted = 0;

var setHeaders = new Action<object, string, string>((smr, key, value) =>
{
var headers = wr.Headers as IDictionary<string, object>;
var getMessageAttributes = _getMessageAttributes ??=
VisibilityBypasser.Instance.GeneratePropertyAccessor<IDictionary>(
smr.GetType(), "MessageAttributes");
if (headers == null)
{
headers = new Dictionary<string, object>();
wr.Headers = headers;
}
var messageAttributes = getMessageAttributes(smr);
// SQS is limited to no more than 10 attributes; if we can't add up to 3 attributes, don't add any
if ((messageAttributes.Count + 3 - headersInserted) > 10)
return;
// create a new MessageAttributeValue instance
var messageAttributeValueTypeFactory = _messageAttributeValueTypeFactory ??= VisibilityBypasser.Instance.GenerateTypeFactory(smr.GetType().Assembly.FullName, "Amazon.SQS.Model.MessageAttributeValue");
object newMessageAttributeValue = messageAttributeValueTypeFactory.Invoke();
var dataTypePropertySetter = VisibilityBypasser.Instance.GeneratePropertySetter<string>(newMessageAttributeValue, "DataType");
dataTypePropertySetter("String");
var stringValuePropertySetter = VisibilityBypasser.Instance.GeneratePropertySetter<string>(newMessageAttributeValue, "StringValue");
stringValuePropertySetter(value);
messageAttributes.Add(key, newMessageAttributeValue);
++headersInserted;
});

transaction.InsertDistributedTraceHeaders(sendMessageRequest, setHeaders);

}
public static void AcceptDistributedTraceHeaders(ITransaction transaction, dynamic messageAttributes)
{
var getHeaders = new Func<IDictionary, string, IEnumerable<string>>((maDict, key) =>
{
if (!maDict.Contains(key))
return [];
headers[key] = value;
return [(string)((dynamic)maDict[key]).StringValue];
});

transaction.InsertDistributedTraceHeaders(webRequest, setHeaders);
transaction.AcceptDistributedTraceHeaders((IDictionary)messageAttributes, getHeaders, TransportType.Queue);

}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Copyright 2020 New Relic, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

using System.Collections.Concurrent;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using NewRelic.Agent.Api;
using NewRelic.Agent.Extensions.AwsSdk;
using NewRelic.Agent.Extensions.Providers.Wrapper;
using NewRelic.Reflection;

namespace NewRelic.Providers.Wrapper.AwsSdk
{
Expand All @@ -12,6 +17,12 @@ public class AwsSdkPipelineWrapper : IWrapper
public bool IsTransactionRequired => true;

private const string WrapperName = "AwsSdkPipelineWrapper";
private static readonly ConcurrentDictionary<Type, Func<object, object>> _getRequestResponseFromGeneric = new();

private const string NEWRELIC_TRACE_HEADER = "newrelic";
private const string W3C_TRACEPARENT_HEADER = "traceparent";
private const string W3C_TRACESTATE_HEADER = "tracestate";


public CanWrapResponse CanWrap(InstrumentedMethodInfo methodInfo)
{
Expand All @@ -23,6 +34,15 @@ public AfterWrappedMethodDelegate BeforeWrappedMethod(InstrumentedMethodCall ins
// Get the IExecutionContext (the only parameter)
dynamic executionContext = instrumentedMethodCall.MethodCall.MethodArguments[0];

var isAsync = instrumentedMethodCall.IsAsync ||
instrumentedMethodCall.InstrumentedMethodInfo.Method.MethodName == "InvokeAsync";

if (isAsync)
{
transaction.AttachToAsync();
transaction.DetachFromPrimary(); //Remove from thread-local type storage
}

// Get the IRequestContext
if (executionContext.RequestContext == null)
{
Expand All @@ -31,16 +51,6 @@ public AfterWrappedMethodDelegate BeforeWrappedMethod(InstrumentedMethodCall ins
}
dynamic requestContext = executionContext.RequestContext;

if (requestContext.ServiceMetaData == null)
{
agent.Logger.Debug("AwsSdkPipelineWrapper: requestContext.ServiceMetaData is null. Returning NoOp delegate.");
return Delegates.NoOp;
}
dynamic metadata = requestContext.ServiceMetaData;

// check for null first if we decide to use this property
// string requestId = metadata.ServiceId; // SQS?

// Get the AmazonWebServiceRequest being invoked. The name will tell us the type of request
if (requestContext.OriginalRequest == null)
{
Expand All @@ -53,13 +63,11 @@ public AfterWrappedMethodDelegate BeforeWrappedMethod(InstrumentedMethodCall ins
agent.Logger.Finest("AwsSdkPipelineWrapper: Request type is " + requestType);

MessageBrokerAction action;
var insertDistributedTraceHeaders = false;
switch (requestType)
{
case "SendMessageRequest":
case "SendMessageBatchRequest":
action = MessageBrokerAction.Produce;
insertDistributedTraceHeaders = true;
break;
case "ReceiveMessageRequest":
action = MessageBrokerAction.Consume;
Expand All @@ -74,19 +82,67 @@ public AfterWrappedMethodDelegate BeforeWrappedMethod(InstrumentedMethodCall ins

string requestQueueUrl = request.QueueUrl;
ISegment segment = SqsHelper.GenerateSegment(transaction, instrumentedMethodCall.MethodCall, requestQueueUrl, action);
if (insertDistributedTraceHeaders)
if (action == MessageBrokerAction.Produce)
{
// This needs to happen at the end
if (requestContext.Request == null)
agent.Logger.Finest("AwsSdkPipelineWrapper: requestContext.Request is null, unable to insert distributed trace headers.");
if (request.MessageAttributes == null)
{
agent.Logger.Debug("AwsSdkPipelineWrapper: requestContext.OriginalRequest.MessageAttributes is null, unable to insert distributed trace headers.");
}
else
{
dynamic webRequest = requestContext.Request;
SqsHelper.InsertDistributedTraceHeaders(transaction, webRequest);
SqsHelper.InsertDistributedTraceHeaders(transaction, request);
}
}

return Delegates.GetDelegateFor(segment);
// modify the request to ask for DT headers in the response message attributes
if (action == MessageBrokerAction.Consume)
{
if (request.MessageAttributeNames == null)
request.MessageAttributeNames = new List<string>();

request.MessageAttributeNames.Add(NEWRELIC_TRACE_HEADER);
request.MessageAttributeNames.Add(W3C_TRACESTATE_HEADER);
request.MessageAttributeNames.Add(W3C_TRACEPARENT_HEADER);
}


if (isAsync)
{
return Delegates.GetAsyncDelegateFor<Task>(agent, segment, true, ProcessResponse, TaskContinuationOptions.ExecuteSynchronously);

void ProcessResponse(Task responseTask)
{
if (!ValidTaskResponse(responseTask) || (segment == null) || action != MessageBrokerAction.Consume)
return;

// taskResult is a ReceiveMessageResponse
var taskResultGetter = _getRequestResponseFromGeneric.GetOrAdd(responseTask.GetType(), t => VisibilityBypasser.Instance.GeneratePropertyAccessor<object>(t, "Result"));
dynamic receiveMessageResponse = taskResultGetter(responseTask);

// accept distributed trace headers from the first message in the response
SqsHelper.AcceptDistributedTraceHeaders(transaction, receiveMessageResponse.Messages[0].MessageAttributes);
}
}

return Delegates.GetDelegateFor(
onComplete: segment.End,
onSuccess: () =>
{
if (action != MessageBrokerAction.Consume)
return;
var ec = executionContext;
var response = ec.ResponseContext.Response; // response is a ReceiveMessageResponse
// accept distributed trace headers from the first message in the response
SqsHelper.AcceptDistributedTraceHeaders(transaction, response.Messages[0].MessageAttributes);
}
);
}

private static bool ValidTaskResponse(Task response)
{
return response?.Status == TaskStatus.RanToCompletion;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// Copyright 2020 New Relic, Inc. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Amazon.SQS.Model;
using NewRelic.Agent.Api;
using NewRelic.Agent.Api.Experimental;
using NewRelic.Agent.Extensions.AwsSdk;
using NewRelic.Agent.Extensions.Providers.Wrapper;
using NUnit.Framework;
using Telerik.JustMock;

namespace Agent.Extensions.Tests.Helpers
{
[TestFixture]
public class SqsHelperTests
{
private ITransaction _mockTransaction;

[SetUp]
public void SetUp()
{
_mockTransaction = Mock.Create<ITransaction>();

Mock.Arrange(() => _mockTransaction.InsertDistributedTraceHeaders(Arg.IsAny<object>(), Arg.IsAny<Action<object, string, string>>()))
.DoInstead((object carrier, Action<object, string, string> setter) =>
{
setter(carrier, "traceparent", "traceparentvalue");
setter(carrier, "tracestate", "tracestatevalue");
});
}


[Test]
public void InsertDistributedTraceHeaders_ValidRequest_InsertsHeaders()
{
// Arrange
var sendMessageRequest = new MockMessageRequest
{
MessageAttributes = new Dictionary<string, MessageAttributeValue>
{
{ "key1", new MessageAttributeValue { DataType = "String", StringValue = "value1" } },
}
};

// Act
SqsHelper.InsertDistributedTraceHeaders(_mockTransaction, sendMessageRequest);

// Assert
Assert.That(sendMessageRequest.MessageAttributes, Has.Count.EqualTo(3));
Assert.That(sendMessageRequest.MessageAttributes, Contains.Key("traceparent"));
Assert.That(sendMessageRequest.MessageAttributes, Contains.Key("tracestate"));
Assert.That(sendMessageRequest.MessageAttributes["traceparent"].StringValue, Is.EqualTo("traceparentvalue"));
Assert.That(sendMessageRequest.MessageAttributes["tracestate"].StringValue, Is.EqualTo("tracestatevalue"));
}
[Test]
[TestCase(7, true)]
[TestCase(8, false)]
public void InsertDistributedTraceHeaders_AttributeLimit_ExceedsLimitGracefully(int attributeCount, bool dtHeadersShouldBeAdded)
{
// Arrange
var sendMessageRequest = new MockMessageRequest
{
MessageAttributes = new Dictionary<string, MessageAttributeValue>()
};

// Pre-populate the message attributes to reach the limit
for (int i = 0; i < attributeCount; i++)
{
sendMessageRequest.MessageAttributes.Add($"key{i}", new MessageAttributeValue { DataType = "String", StringValue = $"value{i}" });
}

// Act
SqsHelper.InsertDistributedTraceHeaders(_mockTransaction, sendMessageRequest);

// Assert
if (dtHeadersShouldBeAdded)
{
Assert.That(sendMessageRequest.MessageAttributes, Has.Count.EqualTo(attributeCount + 2));
Assert.That(sendMessageRequest.MessageAttributes, Does.ContainKey("traceparent"));
Assert.That(sendMessageRequest.MessageAttributes, Does.ContainKey("tracestate"));
}
else
{
// assert that no additional headers were added
Assert.That(sendMessageRequest.MessageAttributes, Has.Count.EqualTo(attributeCount));
Assert.That(sendMessageRequest.MessageAttributes, Does.Not.ContainKey("traceparent"));
Assert.That(sendMessageRequest.MessageAttributes, Does.Not.ContainKey("tracestate"));
}
}

[Test]
public void AcceptDistributedTraceHeaders_HeadersPresent_AppliesHeaders()
{
// Arrange
var messageRequest = new MockMessageRequest
{
MessageAttributes = new Dictionary<string, MessageAttributeValue>
{
{ "traceparent", new MessageAttributeValue { DataType = "String", StringValue = "00-abcdef1234567890abcdef1234567890-abcdef123456-01" } },
{ "tracestate", new MessageAttributeValue { DataType = "String", StringValue = "congo=t61rcWkgMzE" } }
}
};

var results = new Dictionary<string, string>();

Mock.Arrange(() => _mockTransaction.AcceptDistributedTraceHeaders<IDictionary>(Arg.IsAny<IDictionary>(), Arg.IsAny<Func<IDictionary, string, IEnumerable<string>>>(), Arg.IsAny<TransportType>()))
.DoInstead((IDictionary carrier, Func<IDictionary, string, IEnumerable<string>> getter, TransportType _) =>
{
var value = getter(carrier, "newrelic").SingleOrDefault();
if (!string.IsNullOrEmpty(value))
results["newrelic"] = value;
value = getter(carrier, "traceparent").SingleOrDefault();
if (!string.IsNullOrEmpty(value))
results["traceparent"] = value;
value = getter(carrier, "tracestate").SingleOrDefault();
if (!string.IsNullOrEmpty(value))
results["tracestate"] = value;
});

// Act
SqsHelper.AcceptDistributedTraceHeaders(_mockTransaction, messageRequest.MessageAttributes);

// Assert
Assert.That(results, Has.Count.EqualTo(2));
Assert.That(results, Contains.Key("traceparent").WithValue("00-abcdef1234567890abcdef1234567890-abcdef123456-01"));
Assert.That(results, Contains.Key("tracestate").WithValue("congo=t61rcWkgMzE"));
Assert.That(results, Does.Not.ContainKey("newrelic"));
}
}
}

namespace Amazon.SQS.Model
{
public class MockMessageRequest
{
public Dictionary<string, MessageAttributeValue> MessageAttributes { get; set; }
}

public class MessageAttributeValue // name and namespace are required for reflection in SqsHelper.InsertDistributedTraceHeaders
{
public string DataType { get; set; }
public string StringValue { get; set; }
}
}

0 comments on commit 4a6c869

Please sign in to comment.