Skip to content

Commit

Permalink
Allow to clear default entries in collections in options (#1353)
Browse files Browse the repository at this point in the history
* Fixed: configuration changes are ignored by middleware/handlers that access the collection of all the various endpoint option types

* Fixes and test coverage for overriding HTTP verbs on refresh endpoint

* Fixes and test coverage for clearing exposure on management endpoints

* Fixes and test coverage for clearing keys to sanitize on management endpoints

* Fixes and test coverage for clearing health endpoint groups on management endpoints

* Fixed: don't expose endpoint on all verbs when using conventional routing while none configured
Fixed: when using conventional routing with multiple actuators, only the first middleware ever executes (so for all non-first endpoints, the first (wrong) middleware executes)
  • Loading branch information
bart-vmware authored Sep 6, 2024
1 parent 7676649 commit 50b22fa
Show file tree
Hide file tree
Showing 27 changed files with 727 additions and 117 deletions.
14 changes: 9 additions & 5 deletions src/Management/src/Abstractions/Configuration/EndpointOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ public virtual string? Path
/// <summary>
/// Gets the list of HTTP verbs that are allowed for this endpoint.
/// </summary>
public IList<string> AllowedVerbs { get; }
public IList<string> AllowedVerbs { get; private set; } = new List<string>();

protected EndpointOptions()
internal HashSet<string> GetSafeAllowedVerbs()
{
// Caution: Mapping with an empty string in the list results in exposing the endpoint at ALL verbs.
// And duplicate verbs that only differ in case result in an ambiguous match error when mapping routes.
return AllowedVerbs.Where(verb => verb.Length > 0).ToHashSet(StringComparer.OrdinalIgnoreCase);
}

internal void ApplyDefaultAllowedVerbs()
{
// ReSharper disable once VirtualMemberCallInConstructor
#pragma warning disable S1699 // Constructors should only call non-overridable methods
AllowedVerbs = GetDefaultAllowedVerbs();
#pragma warning restore S1699 // Constructors should only call non-overridable methods
}

/// <summary>
Expand Down
13 changes: 9 additions & 4 deletions src/Management/src/Endpoint/ActuatorEndpointMapper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,23 @@ public void Map(IEndpointRouteBuilder endpointRouteBuilder, ActuatorConventionBu
InnerMap(middleware => endpointRouteBuilder.CreateApplicationBuilder().UseMiddleware(middleware.GetType()).Build(),
(middleware, requestPath, pipeline) =>
{
IEndpointConventionBuilder builder = endpointRouteBuilder.MapMethods(requestPath, middleware.EndpointOptions.AllowedVerbs, pipeline);
conventionBuilder.Add(builder);
HashSet<string> allowedVerbs = middleware.EndpointOptions.GetSafeAllowedVerbs();
if (allowedVerbs.Count > 0)
{
IEndpointConventionBuilder endpointConventionBuilder = endpointRouteBuilder.MapMethods(requestPath, allowedVerbs, pipeline);
conventionBuilder.Add(endpointConventionBuilder);
}
});
}

public void Map(IRouteBuilder routeBuilder)
{
ArgumentNullException.ThrowIfNull(routeBuilder);

InnerMap(middleware => routeBuilder.ApplicationBuilder.UseMiddleware(middleware.GetType()).Build(), (middleware, requestPath, pipeline) =>
InnerMap(middleware => routeBuilder.ApplicationBuilder.New().UseMiddleware(middleware.GetType()).Build(), (middleware, requestPath, pipeline) =>
{
foreach (string verb in middleware.EndpointOptions.AllowedVerbs)
foreach (string verb in middleware.EndpointOptions.GetSafeAllowedVerbs())
{
routeBuilder.MapVerb(verb, requestPath, pipeline);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,7 @@ internal static void ConfigureEndpointOptions<TOptions, TConfigureOptions>(this
ArgumentNullException.ThrowIfNull(services);

services.ConfigureOptionsWithChangeTokenSource<TOptions, TConfigureOptions>();

services.TryAddEnumerable(
ServiceDescriptor.Singleton<EndpointOptions, TOptions>(provider => provider.GetRequiredService<IOptionsMonitor<TOptions>>().CurrentValue));
services.TryAddEnumerable(ServiceDescriptor.Singleton<IEndpointOptionsMonitorProvider, EndpointOptionsMonitorProvider<TOptions>>());
}

internal static void ConfigureOptionsWithChangeTokenSource<TOptions, TConfigureOptions>(this IServiceCollection services)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,34 +19,36 @@ internal sealed class CloudFoundryEndpointHandler : ICloudFoundryEndpointHandler
{
private readonly IOptionsMonitor<ManagementOptions> _managementOptionsMonitor;
private readonly IOptionsMonitor<CloudFoundryEndpointOptions> _endpointOptionsMonitor;
private readonly EndpointOptions[] _endpointOptionsArray;
private readonly IEndpointOptionsMonitorProvider[] _optionsMonitorProviderArray;
private readonly ILogger<HypermediaService> _hypermediaServiceLogger;

public EndpointOptions Options => _endpointOptionsMonitor.CurrentValue;

public CloudFoundryEndpointHandler(IOptionsMonitor<ManagementOptions> managementOptionsMonitor,
IOptionsMonitor<CloudFoundryEndpointOptions> endpointOptionsMonitor, IEnumerable<EndpointOptions> endpointOptionsCollection,
IOptionsMonitor<CloudFoundryEndpointOptions> endpointOptionsMonitor, IEnumerable<IEndpointOptionsMonitorProvider> endpointOptionsMonitorProviders,
ILoggerFactory loggerFactory)
{
ArgumentNullException.ThrowIfNull(managementOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsCollection);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitorProviders);
ArgumentNullException.ThrowIfNull(loggerFactory);

EndpointOptions[] endpointOptionsArray = endpointOptionsCollection.ToArray();
ArgumentGuard.ElementsNotNull(endpointOptionsArray);
IEndpointOptionsMonitorProvider[] optionsMonitorProviderArray = endpointOptionsMonitorProviders.ToArray();
ArgumentGuard.ElementsNotNull(optionsMonitorProviderArray);

_managementOptionsMonitor = managementOptionsMonitor;
_endpointOptionsMonitor = endpointOptionsMonitor;
_endpointOptionsArray = endpointOptionsArray;
_optionsMonitorProviderArray = optionsMonitorProviderArray;
_hypermediaServiceLogger = loggerFactory.CreateLogger<HypermediaService>();
}

public async Task<Links> InvokeAsync(string baseUrl, CancellationToken cancellationToken)
{
ArgumentException.ThrowIfNullOrWhiteSpace(baseUrl);

var hypermediaService = new HypermediaService(_managementOptionsMonitor, _endpointOptionsMonitor, _endpointOptionsArray, _hypermediaServiceLogger);
var hypermediaService =
new HypermediaService(_managementOptionsMonitor, _endpointOptionsMonitor, _optionsMonitorProviderArray, _hypermediaServiceLogger);

Links result = hypermediaService.Invoke(baseUrl);
return await Task.FromResult(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@ internal sealed class CloudFoundrySecurityMiddleware
{
private readonly IOptionsMonitor<ManagementOptions> _managementOptionsMonitor;
private readonly IOptionsMonitor<CloudFoundryEndpointOptions> _endpointOptionsMonitor;
private readonly EndpointOptions[] _endpointOptionsArray;
private readonly IEndpointOptionsMonitorProvider[] _endpointOptionsMonitorProviderArray;
private readonly RequestDelegate? _next;
private readonly ILogger<CloudFoundrySecurityMiddleware> _logger;
private readonly PermissionsProvider _permissionsProvider;

public CloudFoundrySecurityMiddleware(IOptionsMonitor<ManagementOptions> managementOptionsMonitor,
IOptionsMonitor<CloudFoundryEndpointOptions> endpointOptionsMonitor, IEnumerable<EndpointOptions> endpointOptionsCollection,
IOptionsMonitor<CloudFoundryEndpointOptions> endpointOptionsMonitor, IEnumerable<IEndpointOptionsMonitorProvider> endpointOptionsMonitorProviders,
PermissionsProvider permissionsProvider, ILogger<CloudFoundrySecurityMiddleware> logger, RequestDelegate? next)
{
ArgumentNullException.ThrowIfNull(managementOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsCollection);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitorProviders);
ArgumentNullException.ThrowIfNull(permissionsProvider);
ArgumentNullException.ThrowIfNull(logger);

EndpointOptions[] endpointOptionsArray = endpointOptionsCollection.ToArray();
ArgumentGuard.ElementsNotNull(endpointOptionsArray);
IEndpointOptionsMonitorProvider[] endpointOptionsMonitorProviderArray = endpointOptionsMonitorProviders.ToArray();
ArgumentGuard.ElementsNotNull(endpointOptionsMonitorProviderArray);

_managementOptionsMonitor = managementOptionsMonitor;
_endpointOptionsMonitor = endpointOptionsMonitor;
_endpointOptionsArray = endpointOptionsArray.Where(options => options is not HypermediaEndpointOptions).ToArray();
_endpointOptionsMonitorProviderArray = endpointOptionsMonitorProviderArray;
_permissionsProvider = permissionsProvider;
_logger = logger;
_next = next;
Expand Down Expand Up @@ -125,7 +125,8 @@ internal Task<SecurityResult> GetPermissionsAsync(HttpContext context)

private EndpointOptions? FindTargetEndpoint(PathString requestPath)
{
foreach (EndpointOptions endpointOptions in _endpointOptionsArray)
foreach (EndpointOptions endpointOptions in _endpointOptionsMonitorProviderArray.Select(provider => provider.Get())
.Where(options => options is not HypermediaEndpointOptions))
{
string basePath = ConfigureManagementOptions.DefaultCloudFoundryPath;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,17 @@ public override void Configure(EnvironmentEndpointOptions options)

// It's not possible to distinguish between null and an empty list in configuration.
// See https://github.com/dotnet/extensions/issues/1341.
// As a workaround, we interpret a single empty string element to clear the defaults.
if (options.KeysToSanitize.Count == 0)
{
foreach (string defaultKey in DefaultKeysToSanitize)
{
options.KeysToSanitize.Add(defaultKey);
}
}
else if (options.KeysToSanitize is [""])
{
options.KeysToSanitize.Clear();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Steeltoe.Management.Endpoint.Actuators.Environment;
public sealed class EnvironmentEndpointOptions : EndpointOptions
{
/// <summary>
/// Gets the list of keys to sanitize.
/// Gets the list of keys to sanitize. Allows regular expressions.
/// </summary>
public IList<string> KeysToSanitize { get; } = new List<string>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,21 @@ public override void Configure(HealthEndpointOptions options)
};
}

options.Groups.TryAdd("liveness", new HealthGroupOptions
if (options.Groups.Count == 0)
{
Include = "liveness"
});
options.Groups["liveness"] = new HealthGroupOptions
{
Include = "liveness"
};

options.Groups.TryAdd("readiness", new HealthGroupOptions
options.Groups["readiness"] = new HealthGroupOptions
{
Include = "readiness"
};
}
else if (options.Groups.Count == 1 && options.Groups.TryGetValue(string.Empty, out HealthGroupOptions? group) && string.IsNullOrEmpty(group.Include))
{
Include = "readiness"
});
options.Groups.Clear();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,34 @@ internal sealed class ActuatorEndpointHandler : IActuatorEndpointHandler
{
private readonly IOptionsMonitor<ManagementOptions> _managementOptionsMonitor;
private readonly IOptionsMonitor<HypermediaEndpointOptions> _endpointOptionsMonitor;
private readonly EndpointOptions[] _endpointOptionsArray;
private readonly IEndpointOptionsMonitorProvider[] _endpointOptionsMonitorProviderArray;
private readonly ILogger<HypermediaService> _hypermediaServiceLogger;

public EndpointOptions Options => _endpointOptionsMonitor.CurrentValue;

public ActuatorEndpointHandler(IOptionsMonitor<ManagementOptions> managementOptionsMonitor,
IOptionsMonitor<HypermediaEndpointOptions> endpointOptionsMonitor, IEnumerable<EndpointOptions> endpointOptionsCollection, ILoggerFactory loggerFactory)
IOptionsMonitor<HypermediaEndpointOptions> endpointOptionsMonitor, IEnumerable<IEndpointOptionsMonitorProvider> endpointOptionsMonitorProviders,
ILoggerFactory loggerFactory)
{
ArgumentNullException.ThrowIfNull(managementOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsCollection);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitorProviders);
ArgumentNullException.ThrowIfNull(loggerFactory);

EndpointOptions[] endpointOptionsArray = endpointOptionsCollection.ToArray();
ArgumentGuard.ElementsNotNull(endpointOptionsArray);
IEndpointOptionsMonitorProvider[] endpointOptionsMonitorProviderArray = endpointOptionsMonitorProviders.ToArray();
ArgumentGuard.ElementsNotNull(endpointOptionsMonitorProviderArray);

_managementOptionsMonitor = managementOptionsMonitor;
_endpointOptionsMonitor = endpointOptionsMonitor;
_endpointOptionsArray = endpointOptionsArray;
_endpointOptionsMonitorProviderArray = endpointOptionsMonitorProviderArray;
_hypermediaServiceLogger = loggerFactory.CreateLogger<HypermediaService>();
}

public Task<Links> InvokeAsync(string baseUrl, CancellationToken cancellationToken)
{
ArgumentException.ThrowIfNullOrWhiteSpace(baseUrl);

var service = new HypermediaService(_managementOptionsMonitor, _endpointOptionsMonitor, _endpointOptionsArray, _hypermediaServiceLogger);
var service = new HypermediaService(_managementOptionsMonitor, _endpointOptionsMonitor, _endpointOptionsMonitorProviderArray, _hypermediaServiceLogger);
Links result = service.Invoke(baseUrl);
return Task.FromResult(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,38 @@ internal sealed class HypermediaService
{
private readonly IOptionsMonitor<ManagementOptions> _managementOptionsMonitor;
private readonly EndpointOptions _endpointOptions;
private readonly ICollection<EndpointOptions> _endpointOptionsCollection;
private readonly ICollection<IEndpointOptionsMonitorProvider> _endpointOptionsMonitorProviders;
private readonly ILogger<HypermediaService> _logger;

public HypermediaService(IOptionsMonitor<ManagementOptions> managementOptionsMonitor,
IOptionsMonitor<HypermediaEndpointOptions> hypermediaEndpointOptionsMonitor, ICollection<EndpointOptions> endpointOptionsCollection,
ILogger<HypermediaService> logger)
IOptionsMonitor<HypermediaEndpointOptions> hypermediaEndpointOptionsMonitor,
ICollection<IEndpointOptionsMonitorProvider> endpointOptionsMonitorProviders, ILogger<HypermediaService> logger)
{
ArgumentNullException.ThrowIfNull(managementOptionsMonitor);
ArgumentNullException.ThrowIfNull(hypermediaEndpointOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsCollection);
ArgumentGuard.ElementsNotNull(endpointOptionsCollection);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitorProviders);
ArgumentGuard.ElementsNotNull(endpointOptionsMonitorProviders);
ArgumentNullException.ThrowIfNull(logger);

_managementOptionsMonitor = managementOptionsMonitor;
_endpointOptions = hypermediaEndpointOptionsMonitor.CurrentValue;
_endpointOptionsCollection = endpointOptionsCollection;
_endpointOptionsMonitorProviders = endpointOptionsMonitorProviders;
_logger = logger;
}

public HypermediaService(IOptionsMonitor<ManagementOptions> managementOptionsMonitor,
IOptionsMonitor<CloudFoundryEndpointOptions> cloudFoundryEndpointOptionsMonitor, ICollection<EndpointOptions> endpointOptionsCollection,
ILogger<HypermediaService> logger)
IOptionsMonitor<CloudFoundryEndpointOptions> cloudFoundryEndpointOptionsMonitor,
ICollection<IEndpointOptionsMonitorProvider> endpointOptionsMonitorProviders, ILogger<HypermediaService> logger)
{
ArgumentNullException.ThrowIfNull(managementOptionsMonitor);
ArgumentNullException.ThrowIfNull(cloudFoundryEndpointOptionsMonitor);
ArgumentNullException.ThrowIfNull(endpointOptionsCollection);
ArgumentGuard.ElementsNotNull(endpointOptionsCollection);
ArgumentNullException.ThrowIfNull(endpointOptionsMonitorProviders);
ArgumentGuard.ElementsNotNull(endpointOptionsMonitorProviders);
ArgumentNullException.ThrowIfNull(logger);

_managementOptionsMonitor = managementOptionsMonitor;
_endpointOptions = cloudFoundryEndpointOptionsMonitor.CurrentValue;
_endpointOptionsCollection = endpointOptionsCollection;
_endpointOptionsMonitorProviders = endpointOptionsMonitorProviders;
_logger = logger;
}

Expand All @@ -65,7 +65,7 @@ public Links Invoke(string baseUrl)

Link? selfLink = null;

foreach (EndpointOptions endpointOptions in _endpointOptionsCollection)
foreach (EndpointOptions endpointOptions in _endpointOptionsMonitorProviders.Select(provider => provider.Get()))
{
if (!endpointOptions.IsEnabled(_managementOptionsMonitor.CurrentValue) || !endpointOptions.IsExposed(_managementOptionsMonitor.CurrentValue))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,26 @@

namespace Steeltoe.Management.Endpoint.Actuators.Trace;

internal sealed class ConfigureTraceEndpointOptions : ConfigureEndpointOptions<TraceEndpointOptions>, IConfigureNamedOptions<TraceEndpointOptions>
internal sealed class ConfigureTraceEndpointOptions(IConfiguration configuration)
: ConfigureEndpointOptions<TraceEndpointOptions>(configuration, ManagementInfoPrefixV2, EndpointIdV2), IConfigureNamedOptions<TraceEndpointOptions>
{
private const string EndpointIdV1 = "trace";
private const string EndpointIdV2 = "httptrace";
private const string ManagementInfoPrefixV1 = "management:endpoints:trace";
private const string ManagementInfoPrefix = "management:endpoints:httptrace";
private const string ManagementInfoPrefixV2 = "management:endpoints:httptrace";
private const int DefaultCapacity = 100;
private readonly IConfiguration _configuration;

public ConfigureTraceEndpointOptions(IConfiguration configuration)
: base(configuration, ManagementInfoPrefix, "httptrace")
{
ArgumentNullException.ThrowIfNull(configuration);

_configuration = configuration;
}

public void Configure(string? name, TraceEndpointOptions options)
{
ArgumentNullException.ThrowIfNull(options);

if (name == TraceEndpointOptionNames.V2.ToString() || string.IsNullOrEmpty(name))
{
Configure(options);
ConfigureAtKey(Configuration, ManagementInfoPrefixV2, EndpointIdV2, options);
}
else
{
_configuration.GetSection(ManagementInfoPrefixV1).Bind(options);

if (string.IsNullOrEmpty(options.Id))
{
options.Id = "trace";
}
ConfigureAtKey(Configuration, ManagementInfoPrefixV1, EndpointIdV1, options);
}

if (options.Capacity == -1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,22 @@ protected ConfigureEndpointOptions(IConfiguration configuration, string prefix,
}

public virtual void Configure(T options)
{
ConfigureAtKey(Configuration, _prefix, _id, options);
}

protected static void ConfigureAtKey(IConfiguration configuration, string configurationKey, string endpointId, T options)
{
ArgumentNullException.ThrowIfNull(options);

Configuration.GetSection(_prefix).Bind(options);
configuration.GetSection(configurationKey).Bind(options);

options.Id ??= _id;
options.Id ??= endpointId;

if (options.AllowedVerbs.Count == 0)
{
options.ApplyDefaultAllowedVerbs();
}

if (!Enum.IsDefined(options.RequiredPermissions))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ public void Configure(Exposure options)
{
ReplaceCollection(options.Include, DefaultIncludes);
}
else
{
if (options.Include is [""])
{
ReplaceCollection(options.Include, Array.Empty<string>());
}

if (options.Exclude is [""])
{
ReplaceCollection(options.Exclude, Array.Empty<string>());
}
}
}

private static List<string>? GetListFromConfigurationCsvString(IConfigurationSection section, string key)
Expand Down
Loading

0 comments on commit 50b22fa

Please sign in to comment.