diff --git a/src/CommonLib/ConcurrentHashSet.cs b/src/CommonLib/ConcurrentHashSet.cs new file mode 100644 index 00000000..670175ce --- /dev/null +++ b/src/CommonLib/ConcurrentHashSet.cs @@ -0,0 +1,60 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace SharpHoundCommonLib; + +/// +/// A concurrent implementation of a hashset using a ConcurrentDictionary as the backing structure. +/// +public class ConcurrentHashSet : IDisposable{ + private ConcurrentDictionary _backingDictionary; + + public ConcurrentHashSet() { + _backingDictionary = new ConcurrentDictionary(); + } + + public ConcurrentHashSet(StringComparer comparison) { + _backingDictionary = new ConcurrentDictionary(comparison); + } + + /// + /// Attempts to add an item to the set. Returns true if adding was successful, false otherwise + /// + /// + /// + public bool Add(string item) { + return _backingDictionary.TryAdd(item, byte.MinValue); + } + + /// + /// Attempts to remove an item from the set. Returns true of removing was successful, false otherwise + /// + /// + /// + public bool Remove(string item) { + return _backingDictionary.TryRemove(item, out _); + } + + /// + /// Checks if the given item is in the set + /// + /// + /// + public bool Contains(string item) { + return _backingDictionary.ContainsKey(item); + } + + /// + /// Returns all values in the set + /// + /// + public IEnumerable Values() { + return _backingDictionary.Keys; + } + + public void Dispose() { + _backingDictionary = null; + GC.SuppressFinalize(this); + } +} \ No newline at end of file diff --git a/src/CommonLib/ILdapUtils.cs b/src/CommonLib/ILdapUtils.cs index 56ebd782..d97e8a43 100644 --- a/src/CommonLib/ILdapUtils.cs +++ b/src/CommonLib/ILdapUtils.cs @@ -169,5 +169,10 @@ IAsyncEnumerable> RangedRetrieval(string distinguishedName, /// The naming context being retrieved /// A tuple containing success state as well as the resolved distinguished name if successful Task<(bool Success, string Path)> GetNamingContextPath(string domain, NamingContext context); + + /// + /// Resets temporary caches in LDAPUtils + /// + void ResetUtils(); } } \ No newline at end of file diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs index 2eb3f7da..17293d3b 100644 --- a/src/CommonLib/LdapUtils.cs +++ b/src/CommonLib/LdapUtils.cs @@ -4,11 +4,9 @@ using System.DirectoryServices; using System.DirectoryServices.AccountManagement; using System.DirectoryServices.ActiveDirectory; -using System.DirectoryServices.Protocols; using System.Linq; using System.Net; using System.Net.Sockets; -using System.Runtime.CompilerServices; using System.Security.Principal; using System.Text; using System.Text.RegularExpressions; @@ -24,13 +22,13 @@ using Domain = System.DirectoryServices.ActiveDirectory.Domain; using Group = SharpHoundCommonLib.OutputTypes.Group; using SearchScope = System.DirectoryServices.Protocols.SearchScope; -using SecurityMasks = System.DirectoryServices.Protocols.SecurityMasks; namespace SharpHoundCommonLib { public class LdapUtils : ILdapUtils { //This cache is indexed by domain sid - private static readonly ConcurrentDictionary DomainCache = new(); - private static readonly ConcurrentDictionary DomainControllers = new(); + private static ConcurrentDictionary _domainCache = new(); + private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase); + private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase); private static readonly ConcurrentDictionary DomainToForestCache = new(StringComparer.OrdinalIgnoreCase); @@ -55,11 +53,6 @@ private readonly ConcurrentDictionary private ConnectionPoolManager _connectionPool; - private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2); - private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20); - private const int BackoffDelayMultiplier = 2; - private const int MaxRetries = 3; - private static readonly byte[] NameRequest = { 0x80, 0x94, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x43, 0x4b, 0x41, @@ -99,7 +92,8 @@ public IAsyncEnumerable> Query(LdapQueryParameters return _connectionPool.Query(queryParameters, cancellationToken); } - public IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, CancellationToken cancellationToken = new()) { + public IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, + CancellationToken cancellationToken = new()) { return _connectionPool.PagedQuery(queryParameters, cancellationToken); } @@ -119,12 +113,24 @@ public IAsyncEnumerable> Query(LdapQueryParameters return (true, principal); } + if (_unresolvablePrincipals.Contains(identifier)) { + return (false, new TypedPrincipal(identifier, Label.Base)); + } + if (identifier.StartsWith("S-")) { var result = await LookupSidType(identifier, objectDomain); + if (!result.Success) { + _unresolvablePrincipals.Add(identifier); + } + return (result.Success, new TypedPrincipal(identifier, result.Type)); } var (success, type) = await LookupGuidType(identifier, objectDomain); + if (!success) { + _unresolvablePrincipals.Add(identifier); + } + return (success, new TypedPrincipal(identifier, type)); } @@ -297,7 +303,8 @@ public IAsyncEnumerable> Query(LdapQueryParameters LDAPFilter = new LdapFilter().AddAllObjects().GetFilter(), }; - var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); + var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail()) + .FirstOrDefaultAsync(); if (result.IsSuccess && result.Value.TryGetProperty(LDAPProperties.RootDomainNamingContext, out var rootNamingContext)) { return (true, Helpers.DistinguishedNameToDomain(rootNamingContext).ToUpper()); @@ -306,12 +313,6 @@ public IAsyncEnumerable> Query(LdapQueryParameters return (false, null); } - private static TimeSpan GetNextBackoff(int retryCount) { - return TimeSpan.FromSeconds(Math.Min( - MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount), - MaxBackoffDelay.TotalSeconds)); - } - public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) { string domainSid; try { @@ -465,7 +466,7 @@ private static TimeSpan GetNextBackoff(int retryCount) { /// public bool GetDomain(string domainName, out Domain domain) { var cacheKey = domainName ?? _nullCacheKey; - if (DomainCache.TryGetValue(cacheKey, out domain)) return true; + if (_domainCache.TryGetValue(cacheKey, out domain)) return true; try { DirectoryContext context; @@ -482,7 +483,7 @@ public bool GetDomain(string domainName, out Domain domain) { domain = Domain.GetDomain(context); if (domain == null) return false; - DomainCache.TryAdd(cacheKey, domain); + _domainCache.TryAdd(cacheKey, domain); return true; } catch (Exception e) { _log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName); @@ -491,7 +492,7 @@ public bool GetDomain(string domainName, out Domain domain) { } public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) { - if (DomainCache.TryGetValue(domainName, out domain)) return true; + if (_domainCache.TryGetValue(domainName, out domain)) return true; try { DirectoryContext context; @@ -508,7 +509,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai domain = Domain.GetDomain(context); if (domain == null) return false; - DomainCache.TryAdd(domainName, domain); + _domainCache.TryAdd(domainName, domain); return true; } catch (Exception e) { Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName, @@ -525,8 +526,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai /// /// public bool GetDomain(out Domain domain) { - var cacheKey = _nullCacheKey; - if (DomainCache.TryGetValue(cacheKey, out domain)) return true; + if (_domainCache.TryGetValue(_nullCacheKey, out domain)) return true; try { var context = _ldapConfig.Username != null @@ -535,7 +535,7 @@ public bool GetDomain(out Domain domain) { : new DirectoryContext(DirectoryContextType.Domain); domain = Domain.GetDomain(context); - DomainCache.TryAdd(cacheKey, domain); + _domainCache.TryAdd(_nullCacheKey, domain); return true; } catch (Exception e) { _log.LogDebug(e, "GetDomain call failed for blank domain"); @@ -579,7 +579,7 @@ public bool GetDomain(out Domain domain) { return (false, string.Empty); } - if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (true, sid); + if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (sid != null, sid); //Immediately start with NetWkstaGetInfo as it's our most reliable indicator if successful if (await GetWorkstationInfo(strippedHost) is (true, var workstationInfo)) { @@ -658,6 +658,7 @@ public bool GetDomain(out Domain domain) { //pass } + _hostResolutionMap.TryAdd(strippedHost, null); return (false, ""); } @@ -693,7 +694,7 @@ public bool GetDomain(out Domain domain) { if (await GetWellKnownPrincipal(sid, domain) is (true, var principal)) { sids.Add(principal.ObjectIdentifier); } else { - sids.Add(sid); + sids.Add(sid); } } else { return (false, Array.Empty()); @@ -825,9 +826,10 @@ public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() { } public async Task IsDomainController(string computerObjectId, string domainName) { - if (DomainControllers.ContainsKey(computerObjectId)) { + if (_domainControllers.Contains(computerObjectId)) { return true; } + var resDomain = await GetDomainNameFromSid(domainName) is (false, var tempDomain) ? tempDomain : domainName; var filter = new LdapFilter().AddFilter(CommonFilters.SpecificSID(computerObjectId), true) .AddFilter(CommonFilters.DomainControllers, true); @@ -837,8 +839,9 @@ public async Task IsDomainController(string computerObjectId, string domai LDAPFilter = filter.GetFilter(), }).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); if (result.IsSuccess) { - DomainControllers.TryAdd(computerObjectId, new byte()); + _domainControllers.Add(computerObjectId); } + return result.IsSuccess; } @@ -847,6 +850,10 @@ public async Task IsDomainController(string computerObjectId, string domai return (true, principal); } + if (_unresolvablePrincipals.Contains(distinguishedName)) { + return (false, default); + } + var domain = Helpers.DistinguishedNameToDomain(distinguishedName); var result = await Query(new LdapQueryParameters { DomainName = domain, @@ -890,13 +897,13 @@ public async Task IsDomainController(string computerObjectId, string domai return (false, default); } catch { + _unresolvablePrincipals.Add(distinguishedName); return (false, default); } } } - public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn) - { + public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn) { var configPath = CommonPaths.CreateDNPath(CommonPaths.DirectoryServicePath, dn); var queryParameters = new LdapQueryParameters { Attributes = new[] { LDAPProperties.DSHeuristics }, @@ -907,16 +914,18 @@ public async Task IsDomainController(string computerObjectId, string domai SearchBase = configPath }; - var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync(); + var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail()) + .FirstOrDefaultAsync(); if (result.IsSuccess && result.Value.TryGetProperty(LDAPProperties.DSHeuristics, out var dsh)) { return (true, dsh); } + return (false, null); } public void AddDomainController(string domainControllerSID) { - DomainControllers.TryAdd(domainControllerSID, new byte()); + _domainControllers.Add(domainControllerSID); } public async IAsyncEnumerable GetWellKnownPrincipalOutput() { @@ -948,11 +957,12 @@ public async IAsyncEnumerable GetWellKnownPrincipalOutput() { yield return entdc; } } - + private async IAsyncEnumerable GetEnterpriseDCGroups() { var grouped = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase); var forestSidToName = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); - foreach (var domainSid in DomainControllers.GroupBy(x => new SecurityIdentifier(x.Key).AccountDomainSid.Value)) { + foreach (var domainSid in _domainControllers.Values().GroupBy(x => + new SecurityIdentifier(x).AccountDomainSid.Value)) { if (await GetDomainNameFromSid(domainSid.Key) is (true, var domainName) && await GetForest(domainName) is (true, var forestName) && await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) { @@ -960,15 +970,15 @@ await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) { if (!grouped.ContainsKey(forestDomainSid)) { grouped[forestDomainSid] = new List(); } - + foreach (var k in domainSid) { - grouped[forestDomainSid].Add(k.Key); + grouped[forestDomainSid].Add(k); } } } foreach (var f in grouped) { - var group = new Group() { ObjectIdentifier = $"{f.Key}-S-1-5-9" }; + var group = new Group { ObjectIdentifier = $"{f.Key}-S-1-5-9" }; group.Properties.Add("name", $"ENTERPRISE DOMAIN CONTROLLERS@{forestSidToName[f.Key]}".ToUpper()); group.Properties.Add("domainsid", f.Key); group.Properties.Add("domain", forestSidToName[f.Key]); @@ -1040,6 +1050,14 @@ public void SetLdapConfig(LdapConfig config) { return (false, default); } + public void ResetUtils() { + _unresolvablePrincipals = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); + _domainCache = new ConcurrentDictionary(); + _domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase); + _connectionPool?.Dispose(); + _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner); + } + private IDirectoryObject CreateDirectoryEntry(string path) { if (_ldapConfig.Username != null) { return new DirectoryEntry(path, _ldapConfig.Username, _ldapConfig.Password).ToDirectoryObject(); @@ -1122,11 +1140,11 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN if (!directoryObject.GetObjectIdentifier(out var objectIdentifier)) { return (false, default); } - + var res = new ResolvedSearchResult { ObjectId = objectIdentifier }; - + //If the object is deleted, we can short circuit the rest of this logic as we don't really care about anything else if (directoryObject.IsDeleted()) { res.Deleted = true; @@ -1140,7 +1158,7 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN utils.AddDomainController(objectIdentifier); } } - + string domain; if (directoryObject.TryGetDistinguishedName(out var distinguishedName)) { diff --git a/src/CommonLib/Processors/CertAbuseProcessor.cs b/src/CommonLib/Processors/CertAbuseProcessor.cs index aff16f8f..8c1cc787 100644 --- a/src/CommonLib/Processors/CertAbuseProcessor.cs +++ b/src/CommonLib/Processors/CertAbuseProcessor.cs @@ -107,8 +107,10 @@ public async Task ProcessRegistryEnrollmentPermissions(str } var (resSuccess, resolvedPrincipal) = await GetRegistryPrincipal(new SecurityIdentifier(principalSid), principalDomain, computerName, isDomainController, computerObjectId, machineSid); if (!resSuccess) { - resolvedPrincipal.ObjectType = Label.Base; - resolvedPrincipal.ObjectIdentifier = principalSid; + resolvedPrincipal = new TypedPrincipal { + ObjectType = Label.Base, + ObjectIdentifier = principalSid + }; } var isInherited = rule.IsInherited(); diff --git a/src/CommonLib/Processors/LdapPropertyProcessor.cs b/src/CommonLib/Processors/LdapPropertyProcessor.cs index 3c5c63eb..8577f0fe 100644 --- a/src/CommonLib/Processors/LdapPropertyProcessor.cs +++ b/src/CommonLib/Processors/LdapPropertyProcessor.cs @@ -2,13 +2,12 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; -using System.Reflection; using System.Runtime.InteropServices; using System.Security.AccessControl; using System.Security.Cryptography.X509Certificates; using System.Security.Principal; using System.Threading.Tasks; -using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging; using SharpHoundCommonLib.Enums; using SharpHoundCommonLib.LDAPQueries; using SharpHoundCommonLib.OutputTypes; @@ -866,8 +865,8 @@ private enum IsTextUnicodeFlags { public class ParsedCertificate { public string Thumbprint { get; set; } public string Name { get; set; } - public string[] Chain { get; set; } = Array.Empty(); - public bool HasBasicConstraints { get; set; } = false; + public string[] Chain { get; set; } + public bool HasBasicConstraints { get; set; } public int BasicConstraintPathLength { get; set; } public ParsedCertificate(byte[] rawCertificate) { @@ -877,21 +876,26 @@ public ParsedCertificate(byte[] rawCertificate) { Name = string.IsNullOrEmpty(name) ? Thumbprint : name; // Chain - X509Chain chain = new X509Chain(); - chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck; - chain.Build(parsedCertificate); - var temp = new List(); - foreach (X509ChainElement cert in chain.ChainElements) temp.Add(cert.Certificate.Thumbprint); - Chain = temp.ToArray(); + try { + var chain = new X509Chain(); + chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck; + chain.Build(parsedCertificate); + var temp = new List(); + foreach (var cert in chain.ChainElements) temp.Add(cert.Certificate.Thumbprint); + Chain = temp.ToArray(); + } catch (Exception e) { + Logging.LogProvider.CreateLogger("ParsedCertificate").LogWarning(e, "Failed to read certificate chain for certificate {Name} with Algo {Algorithm}", name, parsedCertificate.SignatureAlgorithm.FriendlyName); + Chain = Array.Empty(); + } + // Extensions - X509ExtensionCollection extensions = parsedCertificate.Extensions; - List certificateExtensions = new List(); - foreach (X509Extension extension in extensions) { - CertificateExtension certificateExtension = new CertificateExtension(extension); + var extensions = parsedCertificate.Extensions; + foreach (var extension in extensions) { + var certificateExtension = new CertificateExtension(extension); switch (certificateExtension.Oid.Value) { case CAExtensionTypes.BasicConstraints: - X509BasicConstraintsExtension ext = (X509BasicConstraintsExtension)extension; + var ext = (X509BasicConstraintsExtension)extension; HasBasicConstraints = ext.HasPathLengthConstraint; BasicConstraintPathLength = ext.PathLengthConstraint; break; diff --git a/src/CommonLib/SharpHoundCommonLib.csproj b/src/CommonLib/SharpHoundCommonLib.csproj index d8e0b937..85cffc92 100644 --- a/src/CommonLib/SharpHoundCommonLib.csproj +++ b/src/CommonLib/SharpHoundCommonLib.csproj @@ -9,7 +9,7 @@ Common library for C# BloodHound enumeration tasks GPL-3.0-only https://github.com/BloodHoundAD/SharpHoundCommon - 4.0.6 + 4.0.7 SharpHoundCommonLib SharpHoundCommonLib diff --git a/test/unit/Facades/MockLdapUtils.cs b/test/unit/Facades/MockLdapUtils.cs index 90e8cce3..1b62adde 100644 --- a/test/unit/Facades/MockLdapUtils.cs +++ b/test/unit/Facades/MockLdapUtils.cs @@ -1016,6 +1016,10 @@ public void SetLdapConfig(LdapConfig config) { throw new NotImplementedException(); } + public void ResetUtils() { + throw new NotImplementedException(); + } + public Domain GetDomain(string domainName = null) { throw new NotImplementedException();