diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs index 84275f473..633c49058 100644 --- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs +++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; -using System.Diagnostics; using System.Linq; using System.Net; using System.Threading; @@ -26,69 +25,59 @@ public static class DurableTaskClientExtensions /// The . /// The HTTP request that this response is for. /// The ID of the orchestration instance to check. - /// The cancellation token. - /// Total allowed timeout for output from the durable function. The default value is 10 seconds. /// The timeout between checks for output from the durable function. The default value is 1 second. /// Optional parameter that configures the http response code returned. Defaults to false. + /// Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to true. + /// A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload. /// - public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync(this DurableTaskClient client, + public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync( + this DurableTaskClient client, HttpRequestData request, string instanceId, - CancellationToken cancellation = default, - TimeSpan? timeout = null, TimeSpan? retryInterval = null, - bool returnInternalServerErrorOnFailure = false + bool returnInternalServerErrorOnFailure = false, + bool getInputsAndOutputs = true, + CancellationToken cancellation = default ) { - TimeSpan timeoutLocal = timeout ?? TimeSpan.FromSeconds(10); TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1); - - if (retryIntervalLocal > timeoutLocal) - { - throw new ArgumentException($"Total timeout {timeoutLocal.TotalSeconds} should be bigger than retry timeout {retryIntervalLocal.TotalSeconds}"); - } - - Stopwatch stopwatch = Stopwatch.StartNew(); - while (true) + try { - var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: true); - if (status != null) + while (true) { - if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed || + var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs); + if (status != null) + { + if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed || #pragma warning disable CS0618 // Type or member is obsolete - status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled || + status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled || #pragma warning restore CS0618 // Type or member is obsolete - status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated || - status.RuntimeStatus == OrchestrationRuntimeStatus.Failed) - { - var response = request.CreateResponse( - (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)? HttpStatusCode.InternalServerError: HttpStatusCode.OK); - await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId) + status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated || + status.RuntimeStatus == OrchestrationRuntimeStatus.Failed) { - CreatedAt = status.CreatedAt, - LastUpdatedAt = status.LastUpdatedAt, - RuntimeStatus = status.RuntimeStatus, - SerializedInput = status.SerializedInput, - SerializedOutput = status.SerializedOutput, - SerializedCustomStatus = status.SerializedCustomStatus, - }, statusCode: response.StatusCode); - - return response; + var response = request.CreateResponse( + (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK); + await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId) + { + CreatedAt = status.CreatedAt, + LastUpdatedAt = status.LastUpdatedAt, + RuntimeStatus = status.RuntimeStatus, + SerializedInput = status.SerializedInput, + SerializedOutput = status.SerializedOutput, + SerializedCustomStatus = status.SerializedCustomStatus, + }, statusCode: response.StatusCode); + + return response; + } } - } - - TimeSpan elapsed = stopwatch.Elapsed; - if (elapsed < timeout) - { - TimeSpan remainingTime = timeoutLocal!.Subtract(elapsed); - await Task.Delay(remainingTime > retryIntervalLocal ? retryIntervalLocal : remainingTime); - } - else - { - return await CreateCheckStatusResponseAsync(client, request, instanceId, cancellation: cancellation); + await Task.Delay(retryIntervalLocal, cancellation); } } - } + catch (OperationCanceledException) + { + return await CreateCheckStatusResponseAsync(client, request, instanceId); + } + } /// /// Creates an HTTP response that is useful for checking the status of the specified instance. @@ -290,43 +279,54 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response) { // Default to the scheme from the request URL string proto = request.Url.Scheme; - string baseUrl; + string host = request.Url.Authority; // Check for "Forwarded" header - if (request.Headers.TryGetValues("Forwarded", out var forwarded)) + if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders)) { - var forwardedDict = (forwarded.FirstOrDefault() ?? "").Split(';') + var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';') .Select(pair => pair.Split('=')) - .Where(pair => pair.Length == 2) // Ensure valid key-value pairs + .Where(pair => pair.Length == 2) .ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim()); - if (forwardedDict.TryGetValue("proto", out var forwardedProto)) - { - proto = forwardedProto; - } - - if (forwardedDict.TryGetValue("host", out var forwardedHost)) + if (forwardedDict != null) { - baseUrl = $"{proto}://{forwardedHost}"; - return baseUrl; + if (forwardedDict.TryGetValue("proto", out var forwardedProto)) + { + proto = forwardedProto; + } + if (forwardedDict.TryGetValue("host", out var forwardedHost)) + { + host = forwardedHost; + // Return if either proto or host (or both) were found in "Forwarded" header + return $"{proto}://{forwardedHost}"; + } } } - - // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers + // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos)) { - proto = protos.First(); + proto = protos.FirstOrDefault() ?? proto; } - if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts)) { - baseUrl = $"{proto}://{hosts.First()}"; - return baseUrl; + // Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found + host = hosts.FirstOrDefault() ?? host; + return $"{proto}://{host}"; + } + + // Fallback to "X-Original-Proto" and "X-Original-Host" headers if neither of the above produced a returnable URL + if (request.Headers.TryGetValues("X-Original-Proto", out var originalProtos)) + { + proto = originalProtos.First(); + } + if (request.Headers.TryGetValues("X-Original-Host", out var originalHosts)) + { + host = originalHosts.First(); } - // Fallback to using the request's URL if no forwarding headers are found - baseUrl = $"{proto}://{request.Url.Authority}"; - return baseUrl; + // Construct and return the base URL from guaranteed fallback values + return $"{proto}://{host}"; } diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs index a52e1a713..1daafede6 100644 --- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs +++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs @@ -138,8 +138,8 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenComp } /// - /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestration is still running. - /// The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted. + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestrator didn't finish within + /// the timeout period. The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted. /// [Fact] public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning() @@ -155,8 +155,8 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunn var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); HttpRequestData request = this.MockHttpRequestAndResponseData(); - - HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId); + CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, cancellation : cts.Token); Assert.NotNull(response); Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); @@ -211,6 +211,42 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFail AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); } + /// + /// Tests the `GetBaseUrlFromRequest` can return the right base URL from the HttpRequestData with different forwarding or proxies. + /// This test covers the following scenarios: + /// - Using the "Forwarded" header + /// - Using "X-Forwarded-Proto" and "X-Forwarded-Host" headers + /// - Using only "X-Forwarded-Host" with default protocol + /// - Using "X-Original-Proto" and "X-Original-Host" headers + /// - no headers + /// + [Theory] + [InlineData("Forwarded", "proto=https;host=forwarded.example.com","","", "https://forwarded.example.com/runtime/webhooks/durabletask")] + [InlineData("X-Forwarded-Proto", "https", "X-Forwarded-Host", "xforwarded.example.com", "https://xforwarded.example.com/runtime/webhooks/durabletask")] + [InlineData("", "", "X-Forwarded-Host", "test.net", "https://test.net/runtime/webhooks/durabletask")] + [InlineData("X-Original-Proto", "https", "X-Original-Host", "original.example.com", "https://original.example.com/runtime/webhooks/durabletask")] + [InlineData("", "", "", "", "http://localhost:7075/runtime/webhooks/durabletask")] // Default base URL for empty headers + public void TestHttpRequestDataForwardingHandling(string header1, string? value1, string header2, string value2, string expectedBaseUrl) + { + var headers = new HttpHeadersCollection(); + if (!string.IsNullOrEmpty(header1)) + { + headers.Add(header1, value1); + } + if (!string.IsNullOrEmpty(header2)) + { + headers.Add(header2, value2); + } + + var request = this.MockHttpRequestAndResponseData(headers); + var client = this.GetTestFunctionsDurableTaskClient(); + + var payload = client.CreateHttpManagementPayload("testInstanceId", request); + AssertHttpManagementPayload(payload, expectedBaseUrl, "testInstanceId"); + } + + + private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId) { Assert.Equal(instanceId, payload.Id); @@ -235,7 +271,8 @@ private static void AssertOrhcestrationMetadata( OrchestrationMetadata expected, // Mocks the required HttpRequestData and HttpResponseData for testing purposes. // This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body. - private HttpRequestData MockHttpRequestAndResponseData() + // The headers of HttpRequestData can be provided as an optional parameter, otherwise an empty HttpHeadersCollection is used. + private HttpRequestData MockHttpRequestAndResponseData(HttpHeadersCollection? headers = null) { var mockObjectSerializer = new Mock(); @@ -271,8 +308,9 @@ private HttpRequestData MockHttpRequestAndResponseData() // Set up the URL property. mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence")); - - var headers = new HttpHeadersCollection(); + + // If headers are provided, use them, otherwise create a new empty HttpHeadersCollection + headers ??= new HttpHeadersCollection(); // Setup the Headers property to return the empty headers mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers);