Skip to content

Commit

Permalink
Add failed policy requirement names to UnauthorizedAccessException (#460
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jasongin authored Oct 18, 2024
1 parent f46cdab commit 04f0200
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
10 changes: 9 additions & 1 deletion cs/src/Management/TunnelManagementClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class TunnelManagementClient : ITunnelManagementClient
private const string TunnelAuthenticationScheme = "Tunnel";
private const string RequestIdHeaderName = "VsSaaS-Request-Id";
private const string CheckAvailableSubPath = ":checkNameAvailability";
private const string EnterprisePolicyFailureHeaderName = "X-Enterprise-Policy-Failure";
private const int CreateNameRetries = 3;

private static readonly string[] ManageAccessTokenScope =
Expand Down Expand Up @@ -640,7 +641,7 @@ private string UserLimitsPath
}

// Enterprise Policies
if (response.Headers.Contains("X-Enterprise-Policy-Failure"))
if (response.Headers.Contains(EnterprisePolicyFailureHeaderName))
{
errorMessage = problemDetails!.Title + ": " + problemDetails.Detail;
}
Expand Down Expand Up @@ -728,6 +729,13 @@ private string UserLimitsPath
ex.SetAuthenticationSchemes(authHeaderValues);
}

// Propagate failed policy requirement names.
if (response.Headers.TryGetValues(
EnterprisePolicyFailureHeaderName, out var policyFailureValues))
{
ex.SetEnterprisePolicyRequirements(policyFailureValues);
}

throw ex;

case HttpStatusCode.NotFound:
Expand Down
41 changes: 39 additions & 2 deletions cs/src/Management/UnauthorizedAccessExceptionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ namespace Microsoft.DevTunnels.Management;
public static class UnauthorizedAccessExceptionExtensions
{
private const string AuthenticationSchemesKey = "AuthenticationSchemes";
private const string EnterprisePolicyRequirementsKey = "EnterprisePolicyRequirements";

/// <summary>
/// Gets the list of schemes that may be used to authenticate, when an
/// <see cref="UnauthorizedAccessException" /> was thrown for an unauthenticated request.
/// </summary>
public static IEnumerable<AuthenticationHeaderValue>? GetAuthenticationSchemes(
public static IEnumerable<AuthenticationHeaderValue> GetAuthenticationSchemes(
this UnauthorizedAccessException ex)
{
Requires.NotNull(ex, nameof(ex));
Expand All @@ -32,7 +33,7 @@ public static class UnauthorizedAccessExceptionExtensions
return authenticationSchemes?
.Select((s) => AuthenticationHeaderValue.TryParse(s, out var value) ? value : null!)
.Where((s) => s != null)
.ToArray();
.ToArray() ?? Enumerable.Empty<AuthenticationHeaderValue>();
}
}

Expand All @@ -58,4 +59,40 @@ internal static void SetAuthenticationSchemes(
ex.Data[AuthenticationSchemesKey] = authenticationSchemes?.ToArray();
}
}

/// <summary>
/// Gets the list of enterprise policy requirements that caused the
/// <see cref="UnauthorizedAccessException" />.
/// </summary>
/// <remarks>
/// Each item is a non-localized string policy requirement name, such as:
/// "DisableAnonymousAccessRequirement",
/// "DisableDevTunnelsRequirement",
/// "RestrictedTenantAccessRequirement"
/// </remarks>
public static IEnumerable<string> GetEnterprisePolicyRequirements(
this UnauthorizedAccessException ex)
{
Requires.NotNull(ex, nameof(ex));

lock (ex.Data)
{
return ex.Data[EnterprisePolicyRequirementsKey] as string[] ??
Enumerable.Empty<string>();
}
}

/// <summary>
/// Sets the list of enterprise policy requirements that caused the
/// <see cref="UnauthorizedAccessException" />.
/// </summary>
public static void SetEnterprisePolicyRequirements(
this UnauthorizedAccessException ex,
IEnumerable<string>? enterprisePolicyRequirements)
{
lock (ex.Data)
{
ex.Data[EnterprisePolicyRequirementsKey] = enterprisePolicyRequirements?.ToArray();
}
}
}
36 changes: 36 additions & 0 deletions cs/test/TunnelsSDK.Test/TunnelManagementClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,42 @@ public async Task HandleFirewallResponse()
Assert.Contains(tunnelServiceUri.Host, ex.Message);
}

[Fact]
public async Task HandlePolicyFailureResponse()
{
const string policyRequirement1 = "DisableAnonymousAccessRequirement";
const string policyRequirement2 = "DisableAnonymousAccessRequirement";

var handler = new MockHttpMessageHandler(
(message, ct) =>
{
var result = new HttpResponseMessage(HttpStatusCode.Unauthorized)
{
RequestMessage = message,
};
result.Headers.Add("X-Enterprise-Policy-Failure", policyRequirement1);
result.Headers.Add("X-Enterprise-Policy-Failure", policyRequirement2);
return Task.FromResult(result);
});

var client = new TunnelManagementClient(this.userAgent, null, this.tunnelServiceUri, handler);

var requestTunnel = new Tunnel
{
TunnelId = TunnelId,
ClusterId = ClusterId,
};

var ex = await Assert.ThrowsAsync<UnauthorizedAccessException>(
() => client.GetTunnelAsync(requestTunnel, options: null, this.timeout));
Assert.Collection(
ex.GetEnterprisePolicyRequirements(),
(r) => Assert.Equal(policyRequirement1, r),
(r) => Assert.Equal(policyRequirement2, r));
}



private sealed class MockHttpMessageHandler : DelegatingHandler
{
private readonly Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler;
Expand Down

0 comments on commit 04f0200

Please sign in to comment.