Skip to content

Commit

Permalink
feat(security): add config to disallow parts of schema (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyberhck authored Aug 12, 2022
1 parent c60f21b commit cc47be5
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 6 deletions.
8 changes: 4 additions & 4 deletions API/API.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="HotChocolate.ApolloFederation" Version="12.12.1" />
<PackageReference Include="HotChocolate.AspNetCore" Version="12.12.1" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="6.0.0" />
<PackageReference Include="HotChocolate.ApolloFederation" Version="12.12.1"/>
<PackageReference Include="HotChocolate.AspNetCore" Version="12.12.1"/>
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="6.0.0"/>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Business\Business.csproj" />
<ProjectReference Include="..\Business\Business.csproj"/>
</ItemGroup>

</Project>
64 changes: 64 additions & 0 deletions API/Attributes/Protected.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
using System.Reflection;
using API.Configs;
using Business.Exceptions;
using HotChocolate.Resolvers;
using HotChocolate.Types.Descriptors;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Options;

namespace API.Attributes;

public class ProtectedAttribute : ObjectFieldDescriptorAttribute
{
private readonly string _operation;

public ProtectedAttribute(string operation)
{
_operation = operation;
}

public override void OnConfigure(IDescriptorContext context, IObjectFieldDescriptor descriptor, MemberInfo member)
{
descriptor.Use(next => ctx =>
{
var operationRule = GetSecurityRequirements(ctx);
if (operationRule == null)
{
return next.Invoke(ctx);
}
var headerValue = GetHeaderValue(ctx, operationRule);
if (headerValue == null || headerValue != operationRule.Value)
{
throw new OperationNotPermitted("you do not have access to perform this operation");
}
return next.Invoke(ctx);
});
}

private string? GetHeaderValue(IMiddlewareContext ctx, SecurityRequirements rule)
{
ctx.ContextData.TryGetValue("HttpContext", out var ctxData);
if (ctxData == null || ctxData.GetType() != typeof(DefaultHttpContext))
{
throw new ApplicationException("http context not configured correctly");
}

var httpContext = (DefaultHttpContext) ctxData;
if (!httpContext.Request.Headers.TryGetValue(rule.Header, out var headerValue))
{
return null;
}

return headerValue;
}
private SecurityRequirements? GetSecurityRequirements(IMiddlewareContext ctx)
{
var securityOptions = ctx.Service<IOptions<Security>>().Value;
if (securityOptions == null || securityOptions.Rules == null)
{
return null;
}
securityOptions.Rules.TryGetValue(_operation, out var securityRequirements);
return securityRequirements;
}
}
12 changes: 12 additions & 0 deletions API/Configs/Security.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace API.Configs;

public class Security
{
public Dictionary<string, SecurityRequirements>? Rules { set; get; }
}

public class SecurityRequirements
{
public string Header { set; get; }
public string Value { set; get; }
}
3 changes: 2 additions & 1 deletion API/ExceptionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ public IError OnError(IError error)
{
return error.Exception == null
? error
: error.WithMessage(error.Exception.Message).WithCode(ToSnakeCase(error.Exception.GetType().Name).ToUpper());
: error.WithMessage(error.Exception.Message)
.WithCode(ToSnakeCase(error.Exception.GetType().Name).ToUpper());
}

private static string ToSnakeCase(string text)
Expand Down
2 changes: 2 additions & 0 deletions API/Mutation.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using API.Attributes;
using API.Types;
using Business;
using Key = API.Types.Key;
Expand All @@ -6,6 +7,7 @@ namespace API;

public class Mutation
{
[Protected("registerPublicKey")]
public async Task<Key> RegisterPublicKey([Service] IKeyService keyService, string publicKey)
{
return (await keyService.Create(publicKey)).ToGraphQl();
Expand Down
2 changes: 2 additions & 0 deletions API/Query.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using API.Attributes;
using API.Types;
using Business;
using Key = API.Types.Key;
Expand All @@ -6,6 +7,7 @@ namespace API;

public class Query
{
[Protected("keys")]
public async Task<IEnumerable<Key>> Keys([Service] IKeyService keyService)
{
return (await keyService.FetchAllKeys()).Select(x => x.ToGraphQl());
Expand Down
8 changes: 8 additions & 0 deletions Business/Exceptions/OperationNotPermitted.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace Business.Exceptions;

public class OperationNotPermitted : Exception
{
public OperationNotPermitted(string? message) : base(message)
{
}
}
3 changes: 2 additions & 1 deletion Business/KeyService.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Storage;
using KeyNotFoundException = Business.Exceptions.KeyNotFoundException;

namespace Business;

Expand All @@ -22,7 +23,7 @@ public KeyService(IKeyRepository keyRepository)
public async Task<Key> FindById(string id)
{
var key = await _keyRepository.FindById(id);
if (key == null) throw new Exceptions.KeyNotFoundException("key not found");
if (key == null) throw new KeyNotFoundException("key not found");

return key.ToViewModel();
}
Expand Down
2 changes: 2 additions & 0 deletions Startup/StartupExtensions/Configuration.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using API.Configs;
using Storage;

namespace Startup.StartupExtensions;
Expand All @@ -7,5 +8,6 @@ public static class Configuration
public static void AddConfiguration(this IServiceCollection services, IConfiguration configuration)
{
services.Configure<DatabaseConfig>(configuration.GetSection("DatabaseConfig"));
services.Configure<Security>(configuration.GetSection("Security"));
}
}
1 change: 1 addition & 0 deletions Startup/StartupExtensions/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ public static void AddDependencies(this IServiceCollection services)
services.AddSingleton<IKeyRepository, KeyRepository>();
services.AddSingleton<IKeyService, KeyService>();
services.AddDbContext<ApplicationContext>(ServiceLifetime.Singleton);
services.AddHttpContextAccessor();
}
}

0 comments on commit cc47be5

Please sign in to comment.