Skip to content

Commit

Permalink
Support for SecurityTokenDescriptor.Claims in JwtSecurity/Saml/Saml2 …
Browse files Browse the repository at this point in the history
…Tokens
  • Loading branch information
RojaEnnam authored and brentschmaltz committed May 22, 2020
1 parent 203eca5 commit e4f8a0c
Show file tree
Hide file tree
Showing 14 changed files with 962 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@ public class JwtTokenUtilities
JwtHeaderParameterNames.Zip
};

internal static Dictionary<string, object> CreateDictionaryFromClaims(IEnumerable<Claim> claims)
/// <summary>
/// Creates a dictionary from a list of Claim's.
/// </summary>
/// <param name="claims"> A list of claims.</param>
/// <returns> A Dictionary representing claims.</returns>
internal static IDictionary<string, object> CreateDictionaryFromClaims(IEnumerable<Claim> claims)
{
var payload = new Dictionary<string, object>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ internal static class LogMessages
// signature creation / validation
internal const string IDX11312 = "IDX11312: Unable to validate token. A SamlSamlAttributeStatement can only have one SamlAttribute of type 'Actor'. This special SamlAttribute is used in delegation scenarios.";
internal const string IDX11313 = "IDX11313: Unable to process Saml attribute. A SamlSubject must contain either or both of Name and ConfirmationMethod.";
internal const string IDX11314 = "IDX11314: The AttributeValueXsiType of a SAML Attribute must be a string of the form 'prefix#suffix', where prefix and suffix are non-empty strings. Found: '{0}'";

// SamlSerializer reading
internal const string IDX11100 = "IDX11100: Saml Only one element of type '{0}' is supported.";
Expand Down
6 changes: 3 additions & 3 deletions src/Microsoft.IdentityModel.Tokens.Saml/Saml/SamlAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,16 @@ public string AttributeValueXsiType

int indexOfHash = value.IndexOf('#');
if (indexOfHash == -1)
throw LogExceptionMessage(new SecurityTokenInvalidAudienceException("value, SR.GetString(SR.ID4254)")); ;
throw LogExceptionMessage(new SecurityTokenInvalidAudienceException(FormatInvariant(LogMessages.IDX11314, value)));

string prefix = value.Substring(0, indexOfHash);
if (prefix.Length == 0)
throw LogExceptionMessage(new ArgumentException("value SR.GetString(SR.ID4254)"));
throw LogExceptionMessage(new ArgumentException(FormatInvariant(LogMessages.IDX11314, value)));

string suffix = value.Substring(indexOfHash + 1);
if (suffix.Length == 0)
{
throw LogExceptionMessage(new ArgumentException("value, SR.GetString(SR.ID4254)"));
throw LogExceptionMessage(new ArgumentException(FormatInvariant(LogMessages.IDX11314, value)));
}

_attributeValueXsiType = value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,12 @@ protected virtual SamlAttributeStatement CreateAttributeStatement(SamlSubject su
if (tokenDescriptor == null)
throw LogArgumentNullException(nameof(tokenDescriptor));

if (tokenDescriptor.Subject != null)
IEnumerable<Claim> claims = SamlTokenUtilities.GetAllClaims(tokenDescriptor.Claims, tokenDescriptor.Subject != null ? tokenDescriptor.Subject.Claims : null);

if (claims != null && claims.Any())
{
var attributes = new List<SamlAttribute>();
foreach (var claim in tokenDescriptor.Subject.Claims)
foreach (var claim in claims)
{
if (claim != null && claim.Type != ClaimTypes.NameIdentifier)
{
Expand All @@ -293,7 +295,7 @@ protected virtual SamlAttributeStatement CreateAttributeStatement(SamlSubject su
}
}

AddActorToAttributes(attributes, tokenDescriptor.Subject.Actor);
AddActorToAttributes(attributes, tokenDescriptor.Subject?.Actor);

var consolidatedAttributes = ConsolidateAttributes(attributes);
if (consolidatedAttributes.Count > 0)
Expand Down Expand Up @@ -450,9 +452,12 @@ protected virtual SamlSubject CreateSubject(SecurityTokenDescriptor tokenDescrip

var samlSubject = new SamlSubject();
Claim identityClaim = null;
if (tokenDescriptor.Subject != null && tokenDescriptor.Subject.Claims != null)

IEnumerable<Claim> claims = SamlTokenUtilities.GetAllClaims(tokenDescriptor.Claims, tokenDescriptor.Subject != null ? tokenDescriptor.Subject.Claims : null);

if (claims != null && claims.Any())
{
foreach (var claim in tokenDescriptor.Subject.Claims)
foreach (var claim in claims)
{
if (claim.Type == ClaimTypes.NameIdentifier)
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using Microsoft.IdentityModel.Logging;
using Microsoft.IdentityModel.Xml;

//------------------------------------------------------------------------------
//
// Copyright (c) Microsoft Corporation.
Expand Down Expand Up @@ -27,19 +26,20 @@
//
//------------------------------------------------------------------------------

using System;
using System.Security.Claims;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.IdentityModel.Xml;
using System;
using Microsoft.IdentityModel.Logging;
using TokenLogMessages = Microsoft.IdentityModel.Tokens.LogMessages;

namespace Microsoft.IdentityModel.Tokens.Saml
{
/// <summary>
/// A class which contains useful methods for processing saml tokens.
/// </summary>
public class SamlTokenUtilities
internal class SamlTokenUtilities
{
/// <summary>
/// Returns a <see cref="SecurityKey"/> to use when validating the signature of a token.
Expand Down Expand Up @@ -102,5 +102,89 @@ internal static IEnumerable<SecurityKey> GetKeysForTokenSignatureValidation(stri
}
return null;
}

/// <summary>
/// Creates <see cref="Claim"/>'s from <paramref name="claimsCollection"/>.
/// </summary>
/// <param name="claimsCollection"> A dictionary that represents a set of claims.</param>
/// <returns> A collection of <see cref="Claim"/>'s created from the <paramref name="claimsCollection"/>.</returns>
internal static IEnumerable<Claim> CreateClaimsFromDictionary(IDictionary<string, object> claimsCollection)
{
if (claimsCollection == null)
return null;

var claims = new List<Claim>();
foreach (var claim in claimsCollection)
{
string claimType = claim.Key;
object claimValue = claim.Value;
if (claimValue != null)
{
var valueType = GetXsiTypeForValue(claimValue);
if (valueType == null && claimValue is IEnumerable claimList)
{
foreach (var item in claimList)
{
valueType = GetXsiTypeForValue(item);
if (valueType == null && item is IEnumerable)
throw new NotSupportedException(LogHelper.FormatInvariant(TokenLogMessages.IDX10105, claimType));

claims.Add(new Claim(claimType, item.ToString(), valueType));
}
}
else
{
claims.Add(new Claim(claimType, claimValue.ToString(), valueType));
}
}
}

return claims;
}

/// <summary>
/// Merges <paramref name="claims"/> and <paramref name="subjectClaims"/>
/// </summary>
/// <param name="claims"> A dictionary of claims.</param>
/// <param name="subjectClaims"> A collection of <see cref="Claim"/>'s</param>
/// <returns> A merged list of <see cref="Claim"/>'s.</returns>
internal static IEnumerable<Claim> GetAllClaims(IDictionary<string, object> claims, IEnumerable<Claim> subjectClaims)
{
if (claims == null)
return subjectClaims;
else
return TokenUtilities.MergeClaims(CreateClaimsFromDictionary(claims), subjectClaims);
}

/// <summary>
/// Gets the value type of the <see cref="Claim"/> from its value <paramref name="value"/>
/// </summary>
/// <param name="value"> The <see cref="Claim"/> value.</param>
/// <returns> The value type of the <see cref="Claim"/>.</returns>
internal static string GetXsiTypeForValue(object value)
{
if (value != null)
{
if (value is string)
return ClaimValueTypes.String;

if (value is bool)
return ClaimValueTypes.Boolean;

if (value is int)
return ClaimValueTypes.Integer32;

if (value is long)
return ClaimValueTypes.Integer64;

if (value is double)
return ClaimValueTypes.Double;

if (value is DateTime)
return ClaimValueTypes.DateTime;
}

return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.IO;
using System.Linq;
using System.Security.Claims;
using System.Text;
using System.Xml;
Expand Down Expand Up @@ -681,20 +682,26 @@ protected virtual Saml2AttributeStatement CreateAttributeStatement(SecurityToken
throw LogArgumentNullException(nameof(tokenDescriptor.Subject));

var attributes = new List<Saml2Attribute>();
foreach (Claim claim in tokenDescriptor.Subject.Claims)

IEnumerable<Claim> claims = SamlTokenUtilities.GetAllClaims(tokenDescriptor.Claims, tokenDescriptor.Subject != null ? tokenDescriptor.Subject.Claims : null);

if (claims != null && claims.Any())
{
if (claim != null)
foreach (Claim claim in claims)
{
switch (claim.Type)
if (claim != null)
{
// TODO - should these really be filtered?
case ClaimTypes.AuthenticationInstant:
case ClaimTypes.AuthenticationMethod:
case ClaimTypes.NameIdentifier:
break;
default:
attributes.Add(CreateAttribute(claim));
break;
switch (claim.Type)
{
// TODO - should these really be filtered?
case ClaimTypes.AuthenticationInstant:
case ClaimTypes.AuthenticationMethod:
case ClaimTypes.NameIdentifier:
break;
default:
attributes.Add(CreateAttribute(claim));
break;
}
}
}
}
Expand Down Expand Up @@ -895,9 +902,11 @@ protected virtual Saml2Subject CreateSubject(SecurityTokenDescriptor tokenDescri
string nameIdentifierSpProviderId = null;
string nameIdentifierSpNameQualifier = null;

if (tokenDescriptor.Subject != null && tokenDescriptor.Subject.Claims != null)
IEnumerable<Claim> claims = SamlTokenUtilities.GetAllClaims(tokenDescriptor.Claims, tokenDescriptor.Subject != null ? tokenDescriptor.Subject.Claims : null);

if (claims != null && claims.Any())
{
foreach (var claim in tokenDescriptor.Subject.Claims)
foreach (var claim in claims)
{
if (claim.Type == ClaimTypes.NameIdentifier)
{
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.IdentityModel.Tokens/LogMessages.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ internal static class LogMessages
public const string IDX10102 = "IDX10102: NameClaimType cannot be null or whitespace.";
public const string IDX10103 = "IDX10103: RoleClaimType cannot be null or whitespace.";
public const string IDX10104 = "IDX10104: TokenLifetimeInMinutes must be greater than zero. value: '{0}'";
public const string IDX10105 = "IDX10105: ClaimValue that is a collection of collections is not supported. Such ClaimValue is found for ClaimType : '{0}'";

// token validation
public const string IDX10204 = "IDX10204: Unable to validate issuer. validationParameters.ValidIssuer is null or whitespace AND validationParameters.ValidIssuers is null.";
Expand Down Expand Up @@ -227,6 +228,7 @@ internal static class LogMessages
public const string IDX10812 = "IDX10812: Unable to create a {0} from the properties found in the JsonWebKey: '{1}'.";
public const string IDX10813 = "IDX10813: Unable to create a {0} from the properties found in the JsonWebKey: '{1}', Exception '{2}'.";


#pragma warning restore 1591
}
}
30 changes: 27 additions & 3 deletions src/Microsoft.IdentityModel.Tokens/TokenUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Security.Claims;
using Microsoft.IdentityModel.Logging;
using TokenLogMessages = Microsoft.IdentityModel.Tokens.LogMessages;

Expand All @@ -38,7 +37,7 @@ namespace Microsoft.IdentityModel.Tokens
/// <summary>
/// A class which contains useful methods for processing tokens.
/// </summary>
public class TokenUtilities
internal class TokenUtilities
{
/// <summary>
/// Returns all <see cref="SecurityKey"/> provided in validationParameters.
Expand All @@ -55,5 +54,30 @@ internal static IEnumerable<SecurityKey> GetAllSigningKeys(TokenValidationParame
foreach (SecurityKey key in validationParameters.IssuerSigningKeys)
yield return key;
}

/// <summary>
/// Merges claims. If a claim with same type exists in both <paramref name="claims"/> and <paramref name="subjectClaims"/>, the one in claims will be kept.
/// </summary>
/// <param name="claims"> Collection of <see cref="Claim"/>'s.</param>
/// <param name="subjectClaims"> Collection of <see cref="Claim"/>'s.</param>
/// <returns> A Merged list of <see cref="Claim"/>'s.</returns>
internal static IEnumerable<Claim> MergeClaims(IEnumerable<Claim> claims, IEnumerable<Claim> subjectClaims)
{
if (claims == null)
return subjectClaims;

if (subjectClaims == null)
return claims;

List<Claim> result = claims.ToList();

foreach (Claim claim in subjectClaims)
{
if (!claims.Any(i => i.Type == claim.Type))
result.Add(claim);
}

return result;
}
}
}
Loading

0 comments on commit e4f8a0c

Please sign in to comment.