Skip to content

Commit

Permalink
update by comment
Browse files Browse the repository at this point in the history
  • Loading branch information
nytian committed Nov 4, 2024
1 parent ffdd114 commit 6826026
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 75 deletions.
136 changes: 68 additions & 68 deletions src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Threading;
Expand All @@ -26,69 +25,59 @@ public static class DurableTaskClientExtensions
/// <param name="client">The <see cref="DurableTaskClient"/>.</param>
/// <param name="request">The HTTP request that this response is for.</param>
/// <param name="instanceId">The ID of the orchestration instance to check.</param>
/// <param name="cancellation">The cancellation token.</param>
/// <param name="timeout">Total allowed timeout for output from the durable function. The default value is 10 seconds.</param>
/// <param name="retryInterval">The timeout between checks for output from the durable function. The default value is 1 second.</param>
/// <param name="returnInternalServerErrorOnFailure">Optional parameter that configures the http response code returned. Defaults to <c>false</c>.</param>
/// <param name="getInputsAndOutputs">Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to <c>true</c>.</param>
/// <param name="cancellation">A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload.</param>
/// <returns></returns>
public static async Task<HttpResponseData> WaitForCompletionOrCreateCheckStatusResponseAsync(this DurableTaskClient client,
public static async Task<HttpResponseData> WaitForCompletionOrCreateCheckStatusResponseAsync(
this DurableTaskClient client,
HttpRequestData request,
string instanceId,
CancellationToken cancellation = default,
TimeSpan? timeout = null,
TimeSpan? retryInterval = null,
bool returnInternalServerErrorOnFailure = false
bool returnInternalServerErrorOnFailure = false,
bool getInputsAndOutputs = true,
CancellationToken cancellation = default
)
{
TimeSpan timeoutLocal = timeout ?? TimeSpan.FromSeconds(10);
TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);

if (retryIntervalLocal > timeoutLocal)
{
throw new ArgumentException($"Total timeout {timeoutLocal.TotalSeconds} should be bigger than retry timeout {retryIntervalLocal.TotalSeconds}");
}

Stopwatch stopwatch = Stopwatch.StartNew();
while (true)
try
{
var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: true);
if (status != null)
while (true)
{
if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs);
if (status != null)
{
if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
#pragma warning disable CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
#pragma warning restore CS0618 // Type or member is obsolete
status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
var response = request.CreateResponse(
(status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)? HttpStatusCode.InternalServerError: HttpStatusCode.OK);
await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
CreatedAt = status.CreatedAt,
LastUpdatedAt = status.LastUpdatedAt,
RuntimeStatus = status.RuntimeStatus,
SerializedInput = status.SerializedInput,
SerializedOutput = status.SerializedOutput,
SerializedCustomStatus = status.SerializedCustomStatus,
}, statusCode: response.StatusCode);

return response;
var response = request.CreateResponse(
(status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK);
await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
{
CreatedAt = status.CreatedAt,
LastUpdatedAt = status.LastUpdatedAt,
RuntimeStatus = status.RuntimeStatus,
SerializedInput = status.SerializedInput,
SerializedOutput = status.SerializedOutput,
SerializedCustomStatus = status.SerializedCustomStatus,
}, statusCode: response.StatusCode);

return response;
}
}
}

TimeSpan elapsed = stopwatch.Elapsed;
if (elapsed < timeout)
{
TimeSpan remainingTime = timeoutLocal!.Subtract(elapsed);
await Task.Delay(remainingTime > retryIntervalLocal ? retryIntervalLocal : remainingTime);
}
else
{
return await CreateCheckStatusResponseAsync(client, request, instanceId, cancellation: cancellation);
await Task.Delay(retryIntervalLocal, cancellation);
}
}
}
catch (OperationCanceledException)
{
return await CreateCheckStatusResponseAsync(client, request, instanceId);
}
}

/// <summary>
/// Creates an HTTP response that is useful for checking the status of the specified instance.
Expand Down Expand Up @@ -290,43 +279,54 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response)
{
// Default to the scheme from the request URL
string proto = request.Url.Scheme;
string baseUrl;
string host = request.Url.Authority;

// Check for "Forwarded" header
if (request.Headers.TryGetValues("Forwarded", out var forwarded))
if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders))
{
var forwardedDict = (forwarded.FirstOrDefault() ?? "").Split(';')
var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';')
.Select(pair => pair.Split('='))
.Where(pair => pair.Length == 2) // Ensure valid key-value pairs
.Where(pair => pair.Length == 2)
.ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());

if (forwardedDict.TryGetValue("proto", out var forwardedProto))
{
proto = forwardedProto;
}

if (forwardedDict.TryGetValue("host", out var forwardedHost))
if (forwardedDict != null)
{
baseUrl = $"{proto}://{forwardedHost}";
return baseUrl;
if (forwardedDict.TryGetValue("proto", out var forwardedProto))
{
proto = forwardedProto;
}
if (forwardedDict.TryGetValue("host", out var forwardedHost))
{
host = forwardedHost;
// Return if either proto or host (or both) were found in "Forwarded" header
return $"{proto}://{forwardedHost}";
}
}
}

// Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers
// Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present
if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos))
{
proto = protos.First();
proto = protos.FirstOrDefault() ?? proto;
}

if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts))
{
baseUrl = $"{proto}://{hosts.First()}";
return baseUrl;
// Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found
host = hosts.FirstOrDefault() ?? host;
return $"{proto}://{host}";
}

// Fallback to "X-Original-Proto" and "X-Original-Host" headers if neither of the above produced a returnable URL
if (request.Headers.TryGetValues("X-Original-Proto", out var originalProtos))
{
proto = originalProtos.First();
}
if (request.Headers.TryGetValues("X-Original-Host", out var originalHosts))
{
host = originalHosts.First();
}

// Fallback to using the request's URL if no forwarding headers are found
baseUrl = $"{proto}://{request.Url.Authority}";
return baseUrl;
// Construct and return the base URL from guaranteed fallback values
return $"{proto}://{host}";
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenComp
}

/// <summary>
/// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestration is still running.
/// The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
/// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestrator didn't finish within
/// the timeout period. The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
/// </summary>
[Fact]
public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning()
Expand All @@ -155,8 +155,8 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunn
var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);

HttpRequestData request = this.MockHttpRequestAndResponseData();

HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);
CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, cancellation : cts.Token);

Assert.NotNull(response);
Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
Expand Down Expand Up @@ -211,6 +211,42 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFail
AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
}

/// <summary>
/// Tests the `GetBaseUrlFromRequest` can return the right base URL from the HttpRequestData with different forwarding or proxies.
/// This test covers the following scenarios:
/// - Using the "Forwarded" header
/// - Using "X-Forwarded-Proto" and "X-Forwarded-Host" headers
/// - Using only "X-Forwarded-Host" with default protocol
/// - Using "X-Original-Proto" and "X-Original-Host" headers
/// - no headers
/// </summary>
[Theory]
[InlineData("Forwarded", "proto=https;host=forwarded.example.com","","", "https://forwarded.example.com/runtime/webhooks/durabletask")]
[InlineData("X-Forwarded-Proto", "https", "X-Forwarded-Host", "xforwarded.example.com", "https://xforwarded.example.com/runtime/webhooks/durabletask")]
[InlineData("", "", "X-Forwarded-Host", "test.net", "https://test.net/runtime/webhooks/durabletask")]
[InlineData("X-Original-Proto", "https", "X-Original-Host", "original.example.com", "https://original.example.com/runtime/webhooks/durabletask")]
[InlineData("", "", "", "", "http://localhost:7075/runtime/webhooks/durabletask")] // Default base URL for empty headers
public void TestHttpRequestDataForwardingHandling(string header1, string? value1, string header2, string value2, string expectedBaseUrl)
{
var headers = new HttpHeadersCollection();
if (!string.IsNullOrEmpty(header1))
{
headers.Add(header1, value1);
}
if (!string.IsNullOrEmpty(header2))
{
headers.Add(header2, value2);
}

var request = this.MockHttpRequestAndResponseData(headers);
var client = this.GetTestFunctionsDurableTaskClient();

var payload = client.CreateHttpManagementPayload("testInstanceId", request);
AssertHttpManagementPayload(payload, expectedBaseUrl, "testInstanceId");
}



private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId)
{
Assert.Equal(instanceId, payload.Id);
Expand All @@ -235,7 +271,8 @@ private static void AssertOrhcestrationMetadata( OrchestrationMetadata expected,

// Mocks the required HttpRequestData and HttpResponseData for testing purposes.
// This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body.
private HttpRequestData MockHttpRequestAndResponseData()
// The headers of HttpRequestData can be provided as an optional parameter, otherwise an empty HttpHeadersCollection is used.
private HttpRequestData MockHttpRequestAndResponseData(HttpHeadersCollection? headers = null)
{
var mockObjectSerializer = new Mock<ObjectSerializer>();

Expand Down Expand Up @@ -271,8 +308,9 @@ private HttpRequestData MockHttpRequestAndResponseData()

// Set up the URL property.
mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence"));

var headers = new HttpHeadersCollection();

// If headers are provided, use them, otherwise create a new empty HttpHeadersCollection
headers ??= new HttpHeadersCollection();

// Setup the Headers property to return the empty headers
mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers);
Expand Down

0 comments on commit 6826026

Please sign in to comment.