From eee29c50543daeb6e1fb3c7a593e3f54238f5b3c Mon Sep 17 00:00:00 2001 From: Kevin Jones Date: Fri, 24 Jan 2025 12:22:54 -0500 Subject: [PATCH 1/2] Special case List in SelectPrimaryIdentity. --- .../System/Security/Claims/ClaimsPrincipal.cs | 26 ++++++- .../tests/ClaimsPrincipalTests.cs | 67 +++++++++++++++++++ 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs index de8f7d89725c68..2fd38469dd25c9 100644 --- a/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs +++ b/src/libraries/System.Security.Claims/src/System/Security/Claims/ClaimsPrincipal.cs @@ -61,11 +61,31 @@ protected ClaimsPrincipal(SerializationInfo info, StreamingContext context) { ArgumentNullException.ThrowIfNull(identities); - foreach (ClaimsIdentity identity in identities) + // If the identities value is exactly a List, special case it so that + // the enumerator allocation can be skipped. Doing this for List is the 99% + // case because it is normally used on the _identities value, which is a List. + if (identities.GetType() == typeof(List)) { - if (identity != null) + List identitiesList = (identities as List)!; + + for (int i = 0; i < identitiesList.Count; i++) + { + ClaimsIdentity identity = identitiesList[i]; + + if (identity != null) + { + return identity; + } + } + } + else + { + foreach (ClaimsIdentity identity in identities) { - return identity; + if (identity != null) + { + return identity; + } } } diff --git a/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs b/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs index c51d2dc37202da..841e134b4b2a76 100644 --- a/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs +++ b/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Collections; using System.Collections.Generic; using System.IO; using System.Linq; @@ -242,6 +243,72 @@ public void Current_FallsBackToThread_UnauthenticatedPrincipalPolicy() }).Dispose(); } + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void PrimaryIdentitySelector_Default() + { + RemoteExecutor.Invoke(static () => + { + ClaimsIdentity identity0 = null; + ClaimsIdentity identity1 = new([new Claim("type", "value")]); + ClaimsIdentity identity2 = new([new Claim("type", "value")]); + IEnumerable identities = [identity0, identity1, identity2]; + Func, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector; + + Assert.Same(identity1, selector(identities)); + Assert.Null(selector([])); + AssertExtensions.Throws("identities", () => selector(null)); + }).Dispose(); + } + + [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + public void PrimaryIdentitySelector_DefaultDoesNotSpecialCaseInterfaceList() + { + RemoteExecutor.Invoke(static () => + { + ClaimsIdentity identity0 = null; + ClaimsIdentity identity1 = new([new Claim("type", "value")]); + ClaimsIdentity identity2 = new([new Claim("type", "value")]); + ClaimsIdentityList identities = [identity0, identity1, identity2]; + Func, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector; + + Assert.Same(identity1, selector(identities)); + Assert.True(identities.EnumeratedAtLeastOnce, nameof(identities.EnumeratedAtLeastOnce)); + Assert.Null(selector(new ClaimsIdentityList())); + }).Dispose(); + } + + private sealed class ClaimsIdentityList : IList + { + private readonly List _claimsIdentities = []; + + public bool EnumeratedAtLeastOnce { get; set; } + + public ClaimsIdentity this[int index] + { + get => _claimsIdentities[index]; + set => _claimsIdentities[index] = value; + } + + public int Count => _claimsIdentities.Count; + public bool IsReadOnly => ((ICollection)_claimsIdentities).IsReadOnly; + public void Add(ClaimsIdentity item) => _claimsIdentities.Add(item); + public void Clear() => _claimsIdentities.Clear(); + public bool Contains(ClaimsIdentity item) => _claimsIdentities.Contains(item); + public void CopyTo(ClaimsIdentity[] array, int arrayIndex) => _claimsIdentities.CopyTo(array, arrayIndex); + public int IndexOf(ClaimsIdentity item) => _claimsIdentities.IndexOf(item); + public void Insert(int index, ClaimsIdentity item) => _claimsIdentities.Insert(index, item); + public bool Remove(ClaimsIdentity item) => _claimsIdentities.Remove(item); + public void RemoveAt(int index) => _claimsIdentities.RemoveAt(index); + + public IEnumerator GetEnumerator() + { + EnumeratedAtLeastOnce = true; + return _claimsIdentities.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); + } + private class NonClaimsPrincipal : IPrincipal { public IIdentity Identity { get; set; } From cef9f9d165133b1605bebba4444d767c79461f61 Mon Sep 17 00:00:00 2001 From: Kevin Jones Date: Fri, 24 Jan 2025 14:09:26 -0500 Subject: [PATCH 2/2] Code review feedback --- .../tests/ClaimsPrincipalTests.cs | 35 +++++-------------- 1 file changed, 8 insertions(+), 27 deletions(-) diff --git a/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs b/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs index 841e134b4b2a76..a3cc102e212c1e 100644 --- a/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs +++ b/src/libraries/System.Security.Claims/tests/ClaimsPrincipalTests.cs @@ -261,7 +261,7 @@ public void PrimaryIdentitySelector_Default() } [ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] - public void PrimaryIdentitySelector_DefaultDoesNotSpecialCaseInterfaceList() + public void PrimaryIdentitySelector_DefaultOnlySpecialCasesList() { RemoteExecutor.Invoke(static () => { @@ -269,44 +269,25 @@ public void PrimaryIdentitySelector_DefaultDoesNotSpecialCaseInterfaceList() ClaimsIdentity identity1 = new([new Claim("type", "value")]); ClaimsIdentity identity2 = new([new Claim("type", "value")]); ClaimsIdentityList identities = [identity0, identity1, identity2]; - Func, ClaimsIdentity> selector = ClaimsPrincipal.PrimaryIdentitySelector; + Func selector = ClaimsPrincipal.PrimaryIdentitySelector; Assert.Same(identity1, selector(identities)); - Assert.True(identities.EnumeratedAtLeastOnce, nameof(identities.EnumeratedAtLeastOnce)); + Assert.Equal(1, identities.GetEnumeratorCount); Assert.Null(selector(new ClaimsIdentityList())); }).Dispose(); } - private sealed class ClaimsIdentityList : IList + private sealed class ClaimsIdentityList : List, IEnumerable { private readonly List _claimsIdentities = []; - public bool EnumeratedAtLeastOnce { get; set; } + public int GetEnumeratorCount { get; private set; } - public ClaimsIdentity this[int index] + public new IEnumerator GetEnumerator() { - get => _claimsIdentities[index]; - set => _claimsIdentities[index] = value; + GetEnumeratorCount++; + return base.GetEnumerator(); } - - public int Count => _claimsIdentities.Count; - public bool IsReadOnly => ((ICollection)_claimsIdentities).IsReadOnly; - public void Add(ClaimsIdentity item) => _claimsIdentities.Add(item); - public void Clear() => _claimsIdentities.Clear(); - public bool Contains(ClaimsIdentity item) => _claimsIdentities.Contains(item); - public void CopyTo(ClaimsIdentity[] array, int arrayIndex) => _claimsIdentities.CopyTo(array, arrayIndex); - public int IndexOf(ClaimsIdentity item) => _claimsIdentities.IndexOf(item); - public void Insert(int index, ClaimsIdentity item) => _claimsIdentities.Insert(index, item); - public bool Remove(ClaimsIdentity item) => _claimsIdentities.Remove(item); - public void RemoveAt(int index) => _claimsIdentities.RemoveAt(index); - - public IEnumerator GetEnumerator() - { - EnumeratedAtLeastOnce = true; - return _claimsIdentities.GetEnumerator(); - } - - IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)this).GetEnumerator(); } private class NonClaimsPrincipal : IPrincipal