Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance and reduce allocations in hot paths #159

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 13 additions & 25 deletions src/Infrastructure/CacheExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using System.Threading.Tasks;

namespace IdentityModel.AspNetCore.OAuth2Introspection
Expand All @@ -20,53 +19,44 @@ internal static class CacheExtensions

static CacheExtensions()
{

#if NET6_0_OR_GREATER
Options = new JsonSerializerOptions
{
IgnoreReadOnlyFields = true,
IgnoreReadOnlyProperties = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
};
Converters = { new ClaimConverter() },
#if NET6_0_OR_GREATER
DefaultIgnoreCondition = System.Text.Json.Serialization.JsonIgnoreCondition.WhenWritingNull,
#else
Options = new JsonSerializerOptions
{
IgnoreReadOnlyFields = true,
IgnoreReadOnlyProperties = true,
IgnoreNullValues = true
};
IgnoreNullValues = true,
#endif

Options.Converters.Add(new ClaimConverter());
};
}

public static async Task<IEnumerable<Claim>> GetClaimsAsync(this IDistributedCache cache, OAuth2IntrospectionOptions options, string token)
public static async Task<IEnumerable<Claim>> GetClaimsAsync(this IDistributedCache cache, string cacheKey)
{
var cacheKey = options.CacheKeyGenerator(options,token);
var bytes = await cache.GetAsync(cacheKey).ConfigureAwait(false);

if (bytes == null)
{
return null;
}

var json = Encoding.UTF8.GetString(bytes);
return JsonSerializer.Deserialize<IEnumerable<Claim>>(json, Options);
return JsonSerializer.Deserialize<IEnumerable<Claim>>(bytes, Options);
}

public static async Task SetClaimsAsync(this IDistributedCache cache, OAuth2IntrospectionOptions options, string token, IEnumerable<Claim> claims, TimeSpan duration, ILogger logger)
public static async Task SetClaimsAsync(this IDistributedCache cache, string cacheKey, IEnumerable<Claim> claims, TimeSpan duration, ILogger logger)
{
var expClaim = claims.FirstOrDefault(c => c.Type == JwtClaimTypes.Expiration);

if (expClaim == null)
{
logger.LogWarning("No exp claim found on introspection response, can't cache.");
Log.NoExpClaimFound(logger, null);
return;
}

var now = DateTimeOffset.UtcNow;
var expiration = DateTimeOffset.FromUnixTimeSeconds(long.Parse(expClaim.Value));
logger.LogDebug("Token will expire in {expiration}", expiration);

Log.TokenExpiresOn(logger, expiration, null);

if (expiration <= now)
{
Expand All @@ -84,11 +74,9 @@ public static async Task SetClaimsAsync(this IDistributedCache cache, OAuth2Intr
absoluteLifetime = now.Add(duration);
}

var json = JsonSerializer.Serialize(claims, Options);
var bytes = Encoding.UTF8.GetBytes(json);
var bytes = JsonSerializer.SerializeToUtf8Bytes(claims, Options);

logger.LogDebug("Setting cache item expiration to {expiration}", absoluteLifetime);
var cacheKey = options.CacheKeyGenerator(options, token);
Log.SettingToCache(logger, absoluteLifetime, null);
await cache.SetAsync(cacheKey, bytes, new DistributedCacheEntryOptions { AbsoluteExpiration = absoluteLifetime }).ConfigureAwait(false);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/Infrastructure/CacheUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public static class CacheUtils
/// <returns></returns>
public static Func<OAuth2IntrospectionOptions,string, string> CacheKeyFromToken()
{
return (options, token) => $"{options.CacheKeyPrefix}{token.Sha256()}";
return (options, token) => token.Sha256(options.CacheKeyPrefix);
}
}
}
65 changes: 59 additions & 6 deletions src/Infrastructure/StringExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
#if NET6_0_OR_GREATER
using System.Buffers;
#endif

namespace IdentityModel.AspNetCore.OAuth2Introspection
{
Expand Down Expand Up @@ -33,17 +36,67 @@ public static bool IsPresent(this string value)
return !string.IsNullOrWhiteSpace(value);
}

internal static string Sha256(this string input)
/// <summary>
/// Returns Base64 UTF8 bytes of <paramref name="input"/> appended to <paramref name="prefix"/>.
/// If <paramref name="input"/> is missing, returns only prefix.
/// </summary>
internal static string Sha256(this string input, string prefix)
{
if (input.IsMissing()) return string.Empty;
if (input.IsMissing()) return prefix;

using (var sha = SHA256.Create())
#if !NET6_0_OR_GREATER
using var sha = SHA256.Create();
var bytes = Encoding.UTF8.GetBytes(input);
var hash = sha.ComputeHash(bytes);
return prefix + Convert.ToBase64String(hash);
#else
const int Base64Sha256Len = 44; // base64 sha256 is always 44 chars
return string.Create(prefix.Length + Base64Sha256Len, (input, prefix), _sha256WithPrefix);
#endif
}

#if NET6_0_OR_GREATER
private static readonly SpanAction<char, (string input, string prefix)> _sha256WithPrefix = Sha256WithPrefix;

/// <summary>
/// Writes prefix with input's sha256 hash as base64 appended to the span.
/// </summary>
private static void Sha256WithPrefix(Span<char> destination, (string input, string prefix) state)
{
const int Sha256Len = 32; // sha256 is always 32 bytes
const int MaxStackAlloc = 256;

var (input, prefix) = state;

// use a rented buffer if input as bytes would be dangerously long to stackalloc
byte[] rented = null;

try
{
var bytes = Encoding.UTF8.GetBytes(input);
var hash = sha.ComputeHash(bytes);
int maxUtf8Len = Encoding.UTF8.GetMaxByteCount(input.Length);

return Convert.ToBase64String(hash);
Span<byte> utf8buffer = maxUtf8Len > MaxStackAlloc
? (rented = ArrayPool<byte>.Shared.Rent(maxUtf8Len))
: stackalloc byte[maxUtf8Len];

int utf8Written = Encoding.UTF8.GetBytes(input, utf8buffer);

Span<byte> hashBuffer = stackalloc byte[Sha256Len];
int hashedCount = SHA256.HashData(utf8buffer[..utf8Written], hashBuffer);
Debug.Assert(hashedCount == Sha256Len);

if (prefix.Length != 0)
prefix.CopyTo(destination);

bool b64success = Convert.TryToBase64Chars(hashBuffer, destination[prefix.Length..], out var b64written);
Debug.Assert(b64success);
}
finally
{
if (rented != null)
ArrayPool<byte>.Shared.Return(rented);
}
}
#endif
}
}
35 changes: 21 additions & 14 deletions src/Infrastructure/TokenRetrieval.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

using Microsoft.AspNetCore.Http;
using Microsoft.Net.Http.Headers;
using System;
using System.Linq;

namespace IdentityModel.AspNetCore.OAuth2Introspection
{
Expand All @@ -13,24 +13,26 @@ namespace IdentityModel.AspNetCore.OAuth2Introspection
public static class TokenRetrieval
{
/// <summary>
/// Reads the token from the authrorization header.
/// Reads the token from the authorization header.
/// </summary>
/// <param name="scheme">The scheme (defaults to Bearer).</param>
/// <returns></returns>
public static Func<HttpRequest, string> FromAuthorizationHeader(string scheme = "Bearer")
public static Func<HttpRequest, string> FromAuthorizationHeader(
string scheme = OAuth2IntrospectionDefaults.AuthenticationScheme)
{
string schemePrefix = scheme + " ";

return request =>
{
string authorization = request.Headers["Authorization"].FirstOrDefault();

if (string.IsNullOrEmpty(authorization))
if (request.Headers.TryGetValue(HeaderNames.Authorization, out var value) &&
value.Count != 0)
{
return null;
}
string authorization = value[0];

if (authorization.StartsWith(scheme + " ", StringComparison.OrdinalIgnoreCase))
{
return authorization.Substring(scheme.Length + 1).Trim();
if (!string.IsNullOrEmpty(authorization) &&
authorization.StartsWith(schemePrefix, StringComparison.OrdinalIgnoreCase))
{
return new string(authorization.AsSpan(schemePrefix.Length).Trim());
}
}

return null;
Expand All @@ -41,10 +43,15 @@ public static Func<HttpRequest, string> FromAuthorizationHeader(string scheme =
/// Reads the token from a query string parameter.
/// </summary>
/// <param name="name">The name (defaults to access_token).</param>
/// <returns></returns>
public static Func<HttpRequest, string> FromQueryString(string name = "access_token")
{
return request => request.Query[name].FirstOrDefault();
return request =>
{
if (request.Query.TryGetValue(name, out var value) && value.Count > 0)
return value[0];

return null;
};
}
}
}
44 changes: 44 additions & 0 deletions src/Log.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using System;
using Microsoft.Extensions.Logging;

namespace IdentityModel.AspNetCore.OAuth2Introspection
{
internal static class Log
{
public static readonly Action<ILogger, Exception> NoExpClaimFound
= LoggerMessage.Define(
LogLevel.Warning,
1,
"No exp claim found on introspection response, can't cache");

public static readonly Action<ILogger, DateTimeOffset, Exception> TokenExpiresOn
= LoggerMessage.Define<DateTimeOffset>(
LogLevel.Debug,
2,
"Token will expire on {Expiration}");

public static readonly Action<ILogger, DateTimeOffset, Exception> SettingToCache
= LoggerMessage.Define<DateTimeOffset>(
LogLevel.Debug,
3,
"Setting cache item expiration to {Expiration}");

public static readonly Action<ILogger, Exception> SkippingDotToken
= LoggerMessage.Define(
LogLevel.Trace,
4,
"Token contains a dot - skipped because SkipTokensWithDots is set");

public static readonly Action<ILogger, Exception> TokenNotCached
= LoggerMessage.Define(
LogLevel.Trace,
5,
"Token is not cached");

public static readonly Action<ILogger, string, Exception> IntrospectionError
= LoggerMessage.Define<string>(
LogLevel.Error,
6,
"Error returned from introspection endpoint: {Error}");
}
}
Loading