diff --git a/application/AppGateway/Program.cs b/application/AppGateway/Program.cs index 4f86b5086..58778749b 100644 --- a/application/AppGateway/Program.cs +++ b/application/AppGateway/Program.cs @@ -32,6 +32,11 @@ ); } +builder.Services.AddSingleton<BlockInternalApiTransform>(); +reverseProxyBuilder.AddTransforms(context => + context.RequestTransforms.Add(context.Services.GetRequiredService<BlockInternalApiTransform>()) +); + builder.Services.AddNamedBlobStorages(builder, ("avatars-storage", "AVATARS_STORAGE_URL")); builder.WebHost.UseKestrel(option => option.AddServerHeader = false); diff --git a/application/AppGateway/Transformations/BlockInternalApiTransform.cs b/application/AppGateway/Transformations/BlockInternalApiTransform.cs new file mode 100644 index 000000000..da8286c09 --- /dev/null +++ b/application/AppGateway/Transformations/BlockInternalApiTransform.cs @@ -0,0 +1,16 @@ +using Yarp.ReverseProxy.Transforms; + +namespace PlatformPlatform.AppGateway.Transformations; + +public class BlockInternalApiTransform : RequestTransform +{ + public override async ValueTask ApplyAsync(RequestTransformContext context) + { + if (context.HttpContext.Request.Path.Value?.Contains("/internal-api/", StringComparison.OrdinalIgnoreCase) == true) + { + context.HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden; + context.HttpContext.Response.ContentType = "text/plain"; + await context.HttpContext.Response.WriteAsync("Access to internal API is forbidden."); + } + } +} diff --git a/application/AppGateway/Transformations/ManagedIdentityTransform.cs b/application/AppGateway/Transformations/ManagedIdentityTransform.cs index 13e770510..e625f5fb6 100644 --- a/application/AppGateway/Transformations/ManagedIdentityTransform.cs +++ b/application/AppGateway/Transformations/ManagedIdentityTransform.cs @@ -8,7 +8,10 @@ public class ManagedIdentityTransform(TokenCredential credential) { protected override string? GetValue(RequestTransformContext context) { - if (!context.HttpContext.Request.Path.StartsWithSegments("/avatars")) return null; + if (!context.HttpContext.Request.Path.StartsWithSegments("/avatars", StringComparison.OrdinalIgnoreCase)) + { + return null; + } var tokenRequestContext = new TokenRequestContext(["https://storage.azure.com/.default"]); var token = credential.GetToken(tokenRequestContext, context.HttpContext.RequestAborted);