From 79e229542e3846b13456f0d9c8db6bdac860aebb Mon Sep 17 00:00:00 2001 From: Dixon T E Date: Wed, 6 Nov 2024 04:57:08 +1100 Subject: [PATCH] Add WaitForCompletionOrCreateCheckStatusResponseAsync to Microsoft.Azure.Functions.Worker.DurableTaskClientExtensions (#2875) * Initial implementation of WaitForCompletionOrCreateCheckStatusResponseAsync * Support X-Forwarded-Host et al * Removed output of request headers used in my debugging * Set location header to include returnInternalServerErrorOnFailure=true if requested * update api and add unit test * update sortings * Remove unnecessary spaces * add back forword request handling and update test accordingly * update by comment * add summary * update test * remove x-original-forwarded as we shouldn't use this * default getinputsandoutputs to false * update test by comment --------- Co-authored-by: naiyuantian@microsoft.com --- .../DurableTaskClientExtensions.cs | 114 ++++++++- .../FunctionsDurableTaskClientTests.cs | 235 +++++++++++++++++- 2 files changed, 346 insertions(+), 3 deletions(-) diff --git a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs index bbd6222a8..251ebb2d7 100644 --- a/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs +++ b/src/Worker.Extensions.DurableTask/DurableTaskClientExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. See License.txt in the project root for license information. using System; +using System.Linq; using System.Net; using System.Threading; using System.Threading.Tasks; @@ -18,6 +19,70 @@ namespace Microsoft.Azure.Functions.Worker; /// public static class DurableTaskClientExtensions { + /// + /// Waits for the completion of the specified orchestration instance with a retry interval, controlled by the cancellation token. + /// If the orchestration does not complete within the required time, returns an HTTP response containing the class to manage instances. + /// + /// The . + /// The HTTP request that this response is for. + /// The ID of the orchestration instance to check. + /// The timeout between checks for output from the durable function. The default value is 1 second. + /// Optional parameter that configures the http response code returned. Defaults to false. + /// Optional parameter that configures whether to get the inputs and outputs of the orchestration. Defaults to false. + /// A token that signals if the wait should be canceled. If canceled, call CreateCheckStatusResponseAsync to return a reponse contains a HttpManagementPayload. + /// + public static async Task WaitForCompletionOrCreateCheckStatusResponseAsync( + this DurableTaskClient client, + HttpRequestData request, + string instanceId, + TimeSpan? retryInterval = null, + bool returnInternalServerErrorOnFailure = false, + bool getInputsAndOutputs = false, + CancellationToken cancellation = default + ) + { + TimeSpan retryIntervalLocal = retryInterval ?? TimeSpan.FromSeconds(1); + try + { + while (true) + { + var status = await client.GetInstanceAsync(instanceId, getInputsAndOutputs: getInputsAndOutputs); + if (status != null) + { + if (status.RuntimeStatus == OrchestrationRuntimeStatus.Completed || +#pragma warning disable CS0618 // Type or member is obsolete + status.RuntimeStatus == OrchestrationRuntimeStatus.Canceled || +#pragma warning restore CS0618 // Type or member is obsolete + status.RuntimeStatus == OrchestrationRuntimeStatus.Terminated || + status.RuntimeStatus == OrchestrationRuntimeStatus.Failed) + { + var response = request.CreateResponse( + (status.RuntimeStatus == OrchestrationRuntimeStatus.Failed && returnInternalServerErrorOnFailure) ? HttpStatusCode.InternalServerError : HttpStatusCode.OK); + await response.WriteAsJsonAsync(new + { + Name = status.Name, + InstanceId = status.InstanceId, + CreatedAt = status.CreatedAt, + LastUpdatedAt = status.LastUpdatedAt, + RuntimeStatus = status.RuntimeStatus.ToString(), // Convert enum to string + SerializedInput = status.SerializedInput, + SerializedOutput = status.SerializedOutput, + SerializedCustomStatus = status.SerializedCustomStatus + }, statusCode: response.StatusCode); + + return response; + } + } + await Task.Delay(retryIntervalLocal, cancellation); + } + } + // If the task is canceled, call CreateCheckStatusResponseAsync to return a response containing instance management URLs. + catch (OperationCanceledException) + { + return await CreateCheckStatusResponseAsync(client, request, instanceId); + } + } + /// /// Creates an HTTP response that is useful for checking the status of the specified instance. /// @@ -170,13 +235,13 @@ static string BuildUrl(string url, params string?[] queryValues) // The base URL could be null if: // 1. The DurableTaskClient isn't a FunctionsDurableTaskClient (which would have the baseUrl from bindings) // 2. There's no valid HttpRequestData provided - string? baseUrl = ((request != null) ? request.Url.GetLeftPart(UriPartial.Authority) : GetBaseUrl(client)); + string? baseUrl = ((request != null) ? GetBaseUrlFromRequest(request) : GetBaseUrl(client)); if (baseUrl == null) { throw new InvalidOperationException("Failed to create HTTP management payload as base URL is null. Either use Functions bindings or provide an HTTP request to create the HttpPayload."); } - + bool isFromRequest = request != null; string formattedInstanceId = Uri.EscapeDataString(instanceId); @@ -214,6 +279,51 @@ private static ObjectSerializer GetObjectSerializer(HttpResponseData response) ?? throw new InvalidOperationException("A serializer is not configured for the worker."); } + private static string? GetBaseUrlFromRequest(HttpRequestData request) + { + // Default to the scheme from the request URL + string proto = request.Url.Scheme; + string host = request.Url.Authority; + + // Check for "Forwarded" header + if (request.Headers.TryGetValues("Forwarded", out var forwardedHeaders)) + { + var forwardedDict = forwardedHeaders.FirstOrDefault()?.Split(';') + .Select(pair => pair.Split('=')) + .Where(pair => pair.Length == 2) + .ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim()); + + if (forwardedDict != null) + { + if (forwardedDict.TryGetValue("proto", out var forwardedProto)) + { + proto = forwardedProto; + } + if (forwardedDict.TryGetValue("host", out var forwardedHost)) + { + host = forwardedHost; + // Return if either proto or host (or both) were found in "Forwarded" header + return $"{proto}://{forwardedHost}"; + } + } + } + // Check for "X-Forwarded-Proto" and "X-Forwarded-Host" headers if "Forwarded" is not present + if (request.Headers.TryGetValues("X-Forwarded-Proto", out var protos)) + { + proto = protos.FirstOrDefault() ?? proto; + } + if (request.Headers.TryGetValues("X-Forwarded-Host", out var hosts)) + { + // Return base URL if either "X-Forwarded-Proto" or "X-Forwarded-Host" (or both) are found + host = hosts.FirstOrDefault() ?? host; + return $"{proto}://{host}"; + } + + // Construct and return the base URL from default fallback values + return $"{proto}://{host}"; + } + + private static string? GetQueryParams(DurableTaskClient client) { return client is FunctionsDurableTaskClient functions ? functions.QueryString : null; diff --git a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs index 6f975d2c5..1623f4559 100644 --- a/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs +++ b/test/Worker.Extensions.DurableTask.Tests/FunctionsDurableTaskClientTests.cs @@ -1,6 +1,10 @@ +using System.Net; +using Azure.Core.Serialization; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.DurableTask.Client; +using Microsoft.Extensions.Options; using Moq; +using Newtonsoft.Json; namespace Microsoft.Azure.Functions.Worker.Tests { @@ -9,7 +13,7 @@ namespace Microsoft.Azure.Functions.Worker.Tests /// public class FunctionsDurableTaskClientTests { - private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null) + private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? baseUrl = null, OrchestrationMetadata? orchestrationMetadata = null) { // construct mock client @@ -21,6 +25,12 @@ private FunctionsDurableTaskClient GetTestFunctionsDurableTaskClient(string? bas durableClientMock.Setup(x => x.TerminateInstanceAsync( It.IsAny(), It.IsAny(), It.IsAny())).Returns(completedTask); + if (orchestrationMetadata != null) + { + durableClientMock.Setup(x => x.GetInstancesAsync(orchestrationMetadata.InstanceId, It.IsAny(), It.IsAny())) + .ReturnsAsync(orchestrationMetadata); + } + DurableTaskClient durableClient = durableClientMock.Object; FunctionsDurableTaskClient client = new FunctionsDurableTaskClient(durableClient, queryString: null, httpBaseUrl: baseUrl); return client; @@ -82,6 +92,8 @@ public void CreateHttpManagementPayload_WithHttpRequestData() // Create mock HttpRequestData object. var mockFunctionContext = new Mock(); var mockHttpRequestData = new Mock(mockFunctionContext.Object); + var headers = new HttpHeadersCollection(); + mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers); mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri(requestUrl)); HttpManagementPayload payload = client.CreateHttpManagementPayload(instanceId, mockHttpRequestData.Object); @@ -89,6 +101,153 @@ public void CreateHttpManagementPayload_WithHttpRequestData() AssertHttpManagementPayload(payload, "http://localhost:7075/runtime/webhooks/durabletask", instanceId); } + /// + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns the expected response when the orchestration is completed. + /// The expected response should include OrchestrationMetadata in the body with an HttpStatusCode.OK. + /// + [Fact] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenCompleted() + { + string instanceId = "test-instance-id-completed"; + var expectedResult = new OrchestrationMetadata("TestCompleted", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Completed, + SerializedCustomStatus = "TestCustomStatus", + SerializedInput = "TestInput", + SerializedOutput = "TestOutput" + }; + + var client = this.GetTestFunctionsDurableTaskClient( orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId); + + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body); + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(orchestratorMetadata); + AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); + } + + /// + /// Test that the `WaitForCompletionOrCreateCheckStatusResponseAsync` method returns expected response when the orchestrator didn't finish within + /// the timeout period. The response body should contain a HttpManagementPayload with HttpStatusCode.Accepted. + /// + [Fact] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenRunning() + { + string instanceId = "test-instance-id-running"; + var expectedResult = new OrchestrationMetadata("TestRunning", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Running, + }; + + var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + HttpResponseData response; + using (CancellationTokenSource cts = new CancellationTokenSource(TimeSpan.FromSeconds(10))) + { + response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, cancellation: cts.Token); + }; + + Assert.NotNull(response); + Assert.Equal(HttpStatusCode.Accepted, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + HttpManagementPayload? payload; + using (var reader = new StreamReader(response.Body)) + { + payload = JsonConvert.DeserializeObject(await reader.ReadToEndAsync()); + } + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(payload); + AssertHttpManagementPayload(payload, "https://localhost:7075/runtime/webhooks/durabletask", instanceId); + } + + /// + /// Tests the `WaitForCompletionOrCreateCheckStatusResponseAsync` method to ensure it returns the correct HTTP status code + /// based on the `returnInternalServerErrorOnFailure` parameter when the orchestration has failed. + /// + [Theory] + [InlineData(true, HttpStatusCode.InternalServerError)] + [InlineData(false, HttpStatusCode.OK)] + public async Task TestWaitForCompletionOrCreateCheckStatusResponseAsync_WhenFailed(bool returnInternalServerErrorOnFailure, HttpStatusCode expected) + { + string instanceId = "test-instance-id-failed"; + var expectedResult = new OrchestrationMetadata("TestFailed", instanceId) + { + CreatedAt = DateTime.UtcNow, + LastUpdatedAt = DateTime.UtcNow, + RuntimeStatus = OrchestrationRuntimeStatus.Failed, + SerializedOutput = "Microsoft.DurableTask.TaskFailedException: Task 'SayHello' (#0) failed with an unhandled exception: Exception while executing function: Functions.SayHello", + SerializedInput = null + }; + + var client = this.GetTestFunctionsDurableTaskClient(orchestrationMetadata: expectedResult); + + HttpRequestData request = this.MockHttpRequestAndResponseData(); + + HttpResponseData response = await client.WaitForCompletionOrCreateCheckStatusResponseAsync(request, instanceId, returnInternalServerErrorOnFailure: returnInternalServerErrorOnFailure); + + Assert.NotNull(response); + Assert.Equal(expected, response.StatusCode); + + // Reset stream position for reading + response.Body.Position = 0; + var orchestratorMetadata = await System.Text.Json.JsonSerializer.DeserializeAsync(response.Body); + + // Assert the response content is not null and check the content is correct. + Assert.NotNull(orchestratorMetadata); + AssertOrhcestrationMetadata(expectedResult, orchestratorMetadata); + } + + /// + /// Tests the `GetBaseUrlFromRequest` can return the right base URL from the HttpRequestData with different forwarding or proxies. + /// This test covers the following scenarios: + /// - Using the "Forwarded" header + /// - Using "X-Forwarded-Proto" and "X-Forwarded-Host" headers + /// - Using only "X-Forwarded-Host" with default protocol + /// - no headers + /// + [Theory] + [InlineData("Forwarded", "proto=https;host=forwarded.example.com","","", "https://forwarded.example.com/runtime/webhooks/durabletask")] + [InlineData("X-Forwarded-Proto", "https", "X-Forwarded-Host", "xforwarded.example.com", "https://xforwarded.example.com/runtime/webhooks/durabletask")] + [InlineData("", "", "X-Forwarded-Host", "test.net", "https://test.net/runtime/webhooks/durabletask")] + [InlineData("", "", "", "", "https://localhost:7075/runtime/webhooks/durabletask")] // Default base URL for empty headers + public void TestHttpRequestDataForwardingHandling(string header1, string? value1, string header2, string value2, string expectedBaseUrl) + { + var headers = new HttpHeadersCollection(); + if (!string.IsNullOrEmpty(header1)) + { + headers.Add(header1, value1); + } + if (!string.IsNullOrEmpty(header2)) + { + headers.Add(header2, value2); + } + + var request = this.MockHttpRequestAndResponseData(headers); + var client = this.GetTestFunctionsDurableTaskClient(); + + var payload = client.CreateHttpManagementPayload("testInstanceId", request); + AssertHttpManagementPayload(payload, expectedBaseUrl, "testInstanceId"); + } + + + private static void AssertHttpManagementPayload(HttpManagementPayload payload, string BaseUrl, string instanceId) { Assert.Equal(instanceId, payload.Id); @@ -99,5 +258,79 @@ private static void AssertHttpManagementPayload(HttpManagementPayload payload, s Assert.Equal($"{BaseUrl}/instances/{instanceId}/suspend?reason={{{{text}}}}", payload.SuspendPostUri); Assert.Equal($"{BaseUrl}/instances/{instanceId}/resume?reason={{{{text}}}}", payload.ResumePostUri); } + + private static void AssertOrhcestrationMetadata(OrchestrationMetadata expectedResult, dynamic actualResult) + { + Assert.Equal(expectedResult.Name, actualResult.GetProperty("Name").GetString()); + Assert.Equal(expectedResult.InstanceId, actualResult.GetProperty("InstanceId").GetString()); + Assert.Equal(expectedResult.CreatedAt, actualResult.GetProperty("CreatedAt").GetDateTime()); + Assert.Equal(expectedResult.LastUpdatedAt, actualResult.GetProperty("LastUpdatedAt").GetDateTime()); + Assert.Equal(expectedResult.RuntimeStatus.ToString(), actualResult.GetProperty("RuntimeStatus").GetString()); + Assert.Equal(expectedResult.SerializedInput, actualResult.GetProperty("SerializedInput").GetString()); + Assert.Equal(expectedResult.SerializedOutput, actualResult.GetProperty("SerializedOutput").GetString()); + Assert.Equal(expectedResult.SerializedCustomStatus, actualResult.GetProperty("SerializedCustomStatus").GetString()); + } + + // Mocks the required HttpRequestData and HttpResponseData for testing purposes. + // This method sets up a mock HttpRequestData with a predefined URL and a mock HttpResponseDatav with a default status code and body. + // The headers of HttpRequestData can be provided as an optional parameter, otherwise an empty HttpHeadersCollection is used. + private HttpRequestData MockHttpRequestAndResponseData(HttpHeadersCollection? headers = null) + { + var mockObjectSerializer = new Mock(); + + // Setup the SerializeAsync method + mockObjectSerializer.Setup(s => s.SerializeAsync(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(async (stream, value, type, token) => + { + await System.Text.Json.JsonSerializer.SerializeAsync(stream, value, type, cancellationToken: token); + }); + + var workerOptions = new WorkerOptions + { + Serializer = mockObjectSerializer.Object + }; + var mockOptions = new Mock>(); + mockOptions.Setup(o => o.Value).Returns(workerOptions); + + // Mock the service provider + var mockServiceProvider = new Mock(); + + // Set up the service provider to return the mock IOptions + mockServiceProvider.Setup(sp => sp.GetService(typeof(IOptions))) + .Returns(mockOptions.Object); + + // Set up the service provider to return the mock ObjectSerializer + mockServiceProvider.Setup(sp => sp.GetService(typeof(ObjectSerializer))) + .Returns(mockObjectSerializer.Object); + + // Create a mock FunctionContext and assign the service provider + var mockFunctionContext = new Mock(); + mockFunctionContext.SetupGet(c => c.InstanceServices).Returns(mockServiceProvider.Object); + var mockHttpRequestData = new Mock(mockFunctionContext.Object); + + // Set up the URL property. + mockHttpRequestData.SetupGet(r => r.Url).Returns(new Uri("https://localhost:7075/orchestrators/E1_HelloSequence")); + + // If headers are provided, use them, otherwise create a new empty HttpHeadersCollection + headers ??= new HttpHeadersCollection(); + + // Setup the Headers property to return the empty headers + mockHttpRequestData.SetupGet(r => r.Headers).Returns(headers); + + var mockHttpResponseData = new Mock(mockFunctionContext.Object) + { + DefaultValue = DefaultValue.Mock + }; + + // Enable setting StatusCode and Body as mutable properties + mockHttpResponseData.SetupProperty(r => r.StatusCode, HttpStatusCode.OK); + mockHttpResponseData.SetupProperty(r => r.Body, new MemoryStream()); + + // Setup CreateResponse to return the configured HttpResponseData mock + mockHttpRequestData.Setup(r => r.CreateResponse()) + .Returns(mockHttpResponseData.Object); + + return mockHttpRequestData.Object; + } } }