diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
index 84275f473..633c49058 100644
--- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
+++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT License. See License.txt in the project root for license information.
using System;
-using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Threading;
@@ -26,69 +25,59 @@ public static class DurableTaskClientExtensions
/// The .
/// The HTTP request that this response is for.
/// The ID of the orchestration instance to check.
- /// The cancellation token.
- /// Total allowed timeout for output from the durable function. The default value is 10 seconds.
/// The timeout between checks for output from the durable function. The default value is 1 second.
/// Optional parameter that configures the http response code returned. Defaults to false.
+ /// Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to true.
+ /// A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload.
///
- public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync(this DurableTaskClient client,
+ public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync(
+ this DurableTaskClient client,
HttpRequestData request,
string instanceId,
- CancellationToken cancellation = default,
- TimeSpan? timeout = null,
TimeSpan? retryInterval = null,
- bool returnInternalServerErrorOnFailure = false
+ bool returnInternalServerErrorOnFailure = false,
+ bool getInputsAndOutputs = true,
+ CancellationToken cancellation = default
)
{
- TimeSpan timeoutLocal = timeout ?? TimeSpan.FromSeconds(10);
TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1);
-
- if (retryIntervalLocal > timeoutLocal)
- {
- throw new ArgumentException($"Total timeout {timeoutLocal.TotalSeconds} should be bigger than retry timeout {retryIntervalLocal.TotalSeconds}");
- }
-
- Stopwatch stopwatch = Stopwatch.StartNew();
- while (true)
+ try
{
- var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: true);
- if (status != null)
+ while (true)
{
- if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
+ var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs);
+ if (status != null)
+ {
+ if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed ||
#pragma warning disable CS0618 // Type or member is obsolete
- status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled ||
#pragma warning restore CS0618 // Type or member is obsolete
- status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
- status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
- {
- var response = request.CreateResponse(
- (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure)? HttpStatusCode.InternalServerError: HttpStatusCode.OK);
- await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated ||
+ status.RuntimeStatus == OrchestrationRuntimeStatus.Failed)
{
- CreatedAt = status.CreatedAt,
- LastUpdatedAt = status.LastUpdatedAt,
- RuntimeStatus = status.RuntimeStatus,
- SerializedInput = status.SerializedInput,
- SerializedOutput = status.SerializedOutput,
- SerializedCustomStatus = status.SerializedCustomStatus,
- }, statusCode: response.StatusCode);
-
- return response;
+ var response = request.CreateResponse(
+ (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK);
+ await response.WriteAsJsonAsync(new OrchestrationMetadata(status.Name, status.InstanceId)
+ {
+ CreatedAt = status.CreatedAt,
+ LastUpdatedAt = status.LastUpdatedAt,
+ RuntimeStatus = status.RuntimeStatus,
+ SerializedInput = status.SerializedInput,
+ SerializedOutput = status.SerializedOutput,
+ SerializedCustomStatus = status.SerializedCustomStatus,
+ }, statusCode: response.StatusCode);
+
+ return response;
+ }
}
- }
-
- TimeSpan elapsed = stopwatch.Elapsed;
- if (elapsed < timeout)
- {
- TimeSpan remainingTime = timeoutLocal!.Subtract(elapsed);
- await Task.Delay(remainingTime > retryIntervalLocal ? retryIntervalLocal : remainingTime);
- }
- else
- {
- return await CreateCheckStatusResponseAsync(client, request, instanceId, cancellation: cancellation);
+ await Task.Delay(retryIntervalLocal, cancellation);
}
}
- }
+ catch (OperationCanceledException)
+ {
+ return await CreateCheckStatusResponseAsync(client, request, instanceId);
+ }
+ }
///
/// Creates an HTTP response that is useful for checking the status of the specified instance.
@@ -290,43 +279,54 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response)
{
// Default to the scheme from the request URL
string proto = request.Url.Scheme;
- string baseUrl;
+ string host = request.Url.Authority;
// Check for "Forwarded" header
- if (request.Headers.TryGetValues("Forwarded", out var forwarded))
+ if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders))
{
- var forwardedDict = (forwarded.FirstOrDefault() ?? "").Split(';')
+ var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';')
.Select(pair => pair.Split('='))
- .Where(pair => pair.Length == 2) // Ensure valid key-value pairs
+ .Where(pair => pair.Length == 2)
.ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());
- if (forwardedDict.TryGetValue("proto", out var forwardedProto))
- {
- proto = forwardedProto;
- }
-
- if (forwardedDict.TryGetValue("host", out var forwardedHost))
+ if (forwardedDict != null)
{
- baseUrl = $"{proto}://{forwardedHost}";
- return baseUrl;
+ if (forwardedDict.TryGetValue("proto", out var forwardedProto))
+ {
+ proto = forwardedProto;
+ }
+ if (forwardedDict.TryGetValue("host", out var forwardedHost))
+ {
+ host = forwardedHost;
+ // Return if either proto or host (or both) were found in "Forwarded" header
+ return $"{proto}://{forwardedHost}";
+ }
}
}
-
- // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers
+ // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present
if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos))
{
- proto = protos.First();
+ proto = protos.FirstOrDefault() ?? proto;
}
-
if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts))
{
- baseUrl = $"{proto}://{hosts.First()}";
- return baseUrl;
+ // Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found
+ host = hosts.FirstOrDefault() ?? host;
+ return $"{proto}://{host}";
+ }
+
+ // Fallback to "X-Original-Proto" and "X-Original-Host" headers if neither of the above produced a returnable URL
+ if (request.Headers.TryGetValues("X-Original-Proto", out var originalProtos))
+ {
+ proto = originalProtos.First();
+ }
+ if (request.Headers.TryGetValues("X-Original-Host", out var originalHosts))
+ {
+ host = originalHosts.First();
}
- // Fallback to using the request's URL if no forwarding headers are found
- baseUrl = $"{proto}://{request.Url.Authority}";
- return baseUrl;
+ // Construct and return the base URL from guaranteed fallback values
+ return $"{proto}://{host}";
}
diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
index a52e1a713..1daafede6 100644
--- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
+++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs
@@ -138,8 +138,8 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenComp
}
///
- /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestration is still running.
- /// The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
+ /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestrator didn't finish within
+ /// the timeout period. The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted.
///
[Fact]
public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning()
@@ -155,8 +155,8 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunn
var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult);
HttpRequestData request = this.MockHttpRequestAndResponseData();
-
- HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId);
+ CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));
+ HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, cancellation : cts.Token);
Assert.NotNull(response);
Assert.Equal(HttpStatusCode.Accepted, response.StatusCode);
@@ -211,6 +211,42 @@ public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFail
AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata);
}
+ ///
+ /// Tests the `GetBaseUrlFromRequest` can return the right base URL from the HttpRequestData with different forwarding or proxies.
+ /// This test covers the following scenarios:
+ /// - Using the "Forwarded" header
+ /// - Using "X-Forwarded-Proto" and "X-Forwarded-Host" headers
+ /// - Using only "X-Forwarded-Host" with default protocol
+ /// - Using "X-Original-Proto" and "X-Original-Host" headers
+ /// - no headers
+ ///
+ [Theory]
+ [InlineData("Forwarded", "proto=https;host=forwarded.example.com","","", "https://forwarded.example.com/runtime/webhooks/durabletask")]
+ [InlineData("X-Forwarded-Proto", "https", "X-Forwarded-Host", "xforwarded.example.com", "https://xforwarded.example.com/runtime/webhooks/durabletask")]
+ [InlineData("", "", "X-Forwarded-Host", "test.net", "https://test.net/runtime/webhooks/durabletask")]
+ [InlineData("X-Original-Proto", "https", "X-Original-Host", "original.example.com", "https://original.example.com/runtime/webhooks/durabletask")]
+ [InlineData("", "", "", "", "http://localhost:7075/runtime/webhooks/durabletask")] // Default base URL for empty headers
+ public void TestHttpRequestDataForwardingHandling(string header1, string? value1, string header2, string value2, string expectedBaseUrl)
+ {
+ var headers = new HttpHeadersCollection();
+ if (!string.IsNullOrEmpty(header1))
+ {
+ headers.Add(header1, value1);
+ }
+ if (!string.IsNullOrEmpty(header2))
+ {
+ headers.Add(header2, value2);
+ }
+
+ var request = this.MockHttpRequestAndResponseData(headers);
+ var client = this.GetTestFunctionsDurableTaskClient();
+
+ var payload = client.CreateHttpManagementPayload("testInstanceId", request);
+ AssertHttpManagementPayload(payload, expectedBaseUrl, "testInstanceId");
+ }
+
+
+
private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId)
{
Assert.Equal(instanceId, payload.Id);
@@ -235,7 +271,8 @@ private static void AssertOrhcestrationMetadata( OrchestrationMetadata expected,
// Mocks the required HttpRequestData and HttpResponseData for testing purposes.
// This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body.
- private HttpRequestData MockHttpRequestAndResponseData()
+ // The headers of HttpRequestData can be provided as an optional parameter, otherwise an empty HttpHeadersCollection is used.
+ private HttpRequestData MockHttpRequestAndResponseData(HttpHeadersCollection? headers = null)
{
var mockObjectSerializer = new Mock();
@@ -271,8 +308,9 @@ private HttpRequestData MockHttpRequestAndResponseData()
// Set up the URL property.
mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("http://localhost:7075/orchestrators/E1_HelloSequence"));
-
- var headers = new HttpHeadersCollection();
+
+ // If headers are provided, use them, otherwise create a new empty HttpHeadersCollection
+ headers ??= new HttpHeadersCollection();
// Setup the Headers property to return the empty headers
mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers);