diff --git a/src/CommonLib/ConcurrentHashSet.cs b/src/CommonLib/ConcurrentHashSet.cs
new file mode 100644
index 00000000..670175ce
--- /dev/null
+++ b/src/CommonLib/ConcurrentHashSet.cs
@@ -0,0 +1,60 @@
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+
+namespace SharpHoundCommonLib;
+
+///
+/// A concurrent implementation of a hashset using a ConcurrentDictionary as the backing structure.
+///
+public class ConcurrentHashSet : IDisposable{
+ private ConcurrentDictionary _backingDictionary;
+
+ public ConcurrentHashSet() {
+ _backingDictionary = new ConcurrentDictionary();
+ }
+
+ public ConcurrentHashSet(StringComparer comparison) {
+ _backingDictionary = new ConcurrentDictionary(comparison);
+ }
+
+ ///
+ /// Attempts to add an item to the set. Returns true if adding was successful, false otherwise
+ ///
+ ///
+ ///
+ public bool Add(string item) {
+ return _backingDictionary.TryAdd(item, byte.MinValue);
+ }
+
+ ///
+ /// Attempts to remove an item from the set. Returns true of removing was successful, false otherwise
+ ///
+ ///
+ ///
+ public bool Remove(string item) {
+ return _backingDictionary.TryRemove(item, out _);
+ }
+
+ ///
+ /// Checks if the given item is in the set
+ ///
+ ///
+ ///
+ public bool Contains(string item) {
+ return _backingDictionary.ContainsKey(item);
+ }
+
+ ///
+ /// Returns all values in the set
+ ///
+ ///
+ public IEnumerable Values() {
+ return _backingDictionary.Keys;
+ }
+
+ public void Dispose() {
+ _backingDictionary = null;
+ GC.SuppressFinalize(this);
+ }
+}
\ No newline at end of file
diff --git a/src/CommonLib/ILdapUtils.cs b/src/CommonLib/ILdapUtils.cs
index 56ebd782..d97e8a43 100644
--- a/src/CommonLib/ILdapUtils.cs
+++ b/src/CommonLib/ILdapUtils.cs
@@ -169,5 +169,10 @@ IAsyncEnumerable> RangedRetrieval(string distinguishedName,
/// The naming context being retrieved
/// A tuple containing success state as well as the resolved distinguished name if successful
Task<(bool Success, string Path)> GetNamingContextPath(string domain, NamingContext context);
+
+ ///
+ /// Resets temporary caches in LDAPUtils
+ ///
+ void ResetUtils();
}
}
\ No newline at end of file
diff --git a/src/CommonLib/LdapUtils.cs b/src/CommonLib/LdapUtils.cs
index 2eb3f7da..17293d3b 100644
--- a/src/CommonLib/LdapUtils.cs
+++ b/src/CommonLib/LdapUtils.cs
@@ -4,11 +4,9 @@
using System.DirectoryServices;
using System.DirectoryServices.AccountManagement;
using System.DirectoryServices.ActiveDirectory;
-using System.DirectoryServices.Protocols;
using System.Linq;
using System.Net;
using System.Net.Sockets;
-using System.Runtime.CompilerServices;
using System.Security.Principal;
using System.Text;
using System.Text.RegularExpressions;
@@ -24,13 +22,13 @@
using Domain = System.DirectoryServices.ActiveDirectory.Domain;
using Group = SharpHoundCommonLib.OutputTypes.Group;
using SearchScope = System.DirectoryServices.Protocols.SearchScope;
-using SecurityMasks = System.DirectoryServices.Protocols.SecurityMasks;
namespace SharpHoundCommonLib {
public class LdapUtils : ILdapUtils {
//This cache is indexed by domain sid
- private static readonly ConcurrentDictionary DomainCache = new();
- private static readonly ConcurrentDictionary DomainControllers = new();
+ private static ConcurrentDictionary _domainCache = new();
+ private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase);
+ private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase);
private static readonly ConcurrentDictionary DomainToForestCache =
new(StringComparer.OrdinalIgnoreCase);
@@ -55,11 +53,6 @@ private readonly ConcurrentDictionary
private ConnectionPoolManager _connectionPool;
- private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2);
- private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20);
- private const int BackoffDelayMultiplier = 2;
- private const int MaxRetries = 3;
-
private static readonly byte[] NameRequest = {
0x80, 0x94, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x20, 0x43, 0x4b, 0x41,
@@ -99,7 +92,8 @@ public IAsyncEnumerable> Query(LdapQueryParameters
return _connectionPool.Query(queryParameters, cancellationToken);
}
- public IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters, CancellationToken cancellationToken = new()) {
+ public IAsyncEnumerable> PagedQuery(LdapQueryParameters queryParameters,
+ CancellationToken cancellationToken = new()) {
return _connectionPool.PagedQuery(queryParameters, cancellationToken);
}
@@ -119,12 +113,24 @@ public IAsyncEnumerable> Query(LdapQueryParameters
return (true, principal);
}
+ if (_unresolvablePrincipals.Contains(identifier)) {
+ return (false, new TypedPrincipal(identifier, Label.Base));
+ }
+
if (identifier.StartsWith("S-")) {
var result = await LookupSidType(identifier, objectDomain);
+ if (!result.Success) {
+ _unresolvablePrincipals.Add(identifier);
+ }
+
return (result.Success, new TypedPrincipal(identifier, result.Type));
}
var (success, type) = await LookupGuidType(identifier, objectDomain);
+ if (!success) {
+ _unresolvablePrincipals.Add(identifier);
+ }
+
return (success, new TypedPrincipal(identifier, type));
}
@@ -297,7 +303,8 @@ public IAsyncEnumerable> Query(LdapQueryParameters
LDAPFilter = new LdapFilter().AddAllObjects().GetFilter(),
};
- var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync();
+ var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail())
+ .FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.RootDomainNamingContext, out var rootNamingContext)) {
return (true, Helpers.DistinguishedNameToDomain(rootNamingContext).ToUpper());
@@ -306,12 +313,6 @@ public IAsyncEnumerable> Query(LdapQueryParameters
return (false, null);
}
- private static TimeSpan GetNextBackoff(int retryCount) {
- return TimeSpan.FromSeconds(Math.Min(
- MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount),
- MaxBackoffDelay.TotalSeconds));
- }
-
public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) {
string domainSid;
try {
@@ -465,7 +466,7 @@ private static TimeSpan GetNextBackoff(int retryCount) {
///
public bool GetDomain(string domainName, out Domain domain) {
var cacheKey = domainName ?? _nullCacheKey;
- if (DomainCache.TryGetValue(cacheKey, out domain)) return true;
+ if (_domainCache.TryGetValue(cacheKey, out domain)) return true;
try {
DirectoryContext context;
@@ -482,7 +483,7 @@ public bool GetDomain(string domainName, out Domain domain) {
domain = Domain.GetDomain(context);
if (domain == null) return false;
- DomainCache.TryAdd(cacheKey, domain);
+ _domainCache.TryAdd(cacheKey, domain);
return true;
} catch (Exception e) {
_log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName);
@@ -491,7 +492,7 @@ public bool GetDomain(string domainName, out Domain domain) {
}
public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) {
- if (DomainCache.TryGetValue(domainName, out domain)) return true;
+ if (_domainCache.TryGetValue(domainName, out domain)) return true;
try {
DirectoryContext context;
@@ -508,7 +509,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai
domain = Domain.GetDomain(context);
if (domain == null) return false;
- DomainCache.TryAdd(domainName, domain);
+ _domainCache.TryAdd(domainName, domain);
return true;
} catch (Exception e) {
Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName,
@@ -525,8 +526,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai
///
///
public bool GetDomain(out Domain domain) {
- var cacheKey = _nullCacheKey;
- if (DomainCache.TryGetValue(cacheKey, out domain)) return true;
+ if (_domainCache.TryGetValue(_nullCacheKey, out domain)) return true;
try {
var context = _ldapConfig.Username != null
@@ -535,7 +535,7 @@ public bool GetDomain(out Domain domain) {
: new DirectoryContext(DirectoryContextType.Domain);
domain = Domain.GetDomain(context);
- DomainCache.TryAdd(cacheKey, domain);
+ _domainCache.TryAdd(_nullCacheKey, domain);
return true;
} catch (Exception e) {
_log.LogDebug(e, "GetDomain call failed for blank domain");
@@ -579,7 +579,7 @@ public bool GetDomain(out Domain domain) {
return (false, string.Empty);
}
- if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (true, sid);
+ if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (sid != null, sid);
//Immediately start with NetWkstaGetInfo as it's our most reliable indicator if successful
if (await GetWorkstationInfo(strippedHost) is (true, var workstationInfo)) {
@@ -658,6 +658,7 @@ public bool GetDomain(out Domain domain) {
//pass
}
+ _hostResolutionMap.TryAdd(strippedHost, null);
return (false, "");
}
@@ -693,7 +694,7 @@ public bool GetDomain(out Domain domain) {
if (await GetWellKnownPrincipal(sid, domain) is (true, var principal)) {
sids.Add(principal.ObjectIdentifier);
} else {
- sids.Add(sid);
+ sids.Add(sid);
}
} else {
return (false, Array.Empty());
@@ -825,9 +826,10 @@ public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() {
}
public async Task IsDomainController(string computerObjectId, string domainName) {
- if (DomainControllers.ContainsKey(computerObjectId)) {
+ if (_domainControllers.Contains(computerObjectId)) {
return true;
}
+
var resDomain = await GetDomainNameFromSid(domainName) is (false, var tempDomain) ? tempDomain : domainName;
var filter = new LdapFilter().AddFilter(CommonFilters.SpecificSID(computerObjectId), true)
.AddFilter(CommonFilters.DomainControllers, true);
@@ -837,8 +839,9 @@ public async Task IsDomainController(string computerObjectId, string domai
LDAPFilter = filter.GetFilter(),
}).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync();
if (result.IsSuccess) {
- DomainControllers.TryAdd(computerObjectId, new byte());
+ _domainControllers.Add(computerObjectId);
}
+
return result.IsSuccess;
}
@@ -847,6 +850,10 @@ public async Task IsDomainController(string computerObjectId, string domai
return (true, principal);
}
+ if (_unresolvablePrincipals.Contains(distinguishedName)) {
+ return (false, default);
+ }
+
var domain = Helpers.DistinguishedNameToDomain(distinguishedName);
var result = await Query(new LdapQueryParameters {
DomainName = domain,
@@ -890,13 +897,13 @@ public async Task IsDomainController(string computerObjectId, string domai
return (false, default);
} catch {
+ _unresolvablePrincipals.Add(distinguishedName);
return (false, default);
}
}
}
- public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn)
- {
+ public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn) {
var configPath = CommonPaths.CreateDNPath(CommonPaths.DirectoryServicePath, dn);
var queryParameters = new LdapQueryParameters {
Attributes = new[] { LDAPProperties.DSHeuristics },
@@ -907,16 +914,18 @@ public async Task IsDomainController(string computerObjectId, string domai
SearchBase = configPath
};
- var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail()).FirstOrDefaultAsync();
+ var result = await Query(queryParameters).DefaultIfEmpty(LdapResult.Fail())
+ .FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.DSHeuristics, out var dsh)) {
return (true, dsh);
}
+
return (false, null);
}
public void AddDomainController(string domainControllerSID) {
- DomainControllers.TryAdd(domainControllerSID, new byte());
+ _domainControllers.Add(domainControllerSID);
}
public async IAsyncEnumerable GetWellKnownPrincipalOutput() {
@@ -948,11 +957,12 @@ public async IAsyncEnumerable GetWellKnownPrincipalOutput() {
yield return entdc;
}
}
-
+
private async IAsyncEnumerable GetEnterpriseDCGroups() {
var grouped = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase);
var forestSidToName = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase);
- foreach (var domainSid in DomainControllers.GroupBy(x => new SecurityIdentifier(x.Key).AccountDomainSid.Value)) {
+ foreach (var domainSid in _domainControllers.Values().GroupBy(x =>
+ new SecurityIdentifier(x).AccountDomainSid.Value)) {
if (await GetDomainNameFromSid(domainSid.Key) is (true, var domainName) &&
await GetForest(domainName) is (true, var forestName) &&
await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) {
@@ -960,15 +970,15 @@ await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) {
if (!grouped.ContainsKey(forestDomainSid)) {
grouped[forestDomainSid] = new List();
}
-
+
foreach (var k in domainSid) {
- grouped[forestDomainSid].Add(k.Key);
+ grouped[forestDomainSid].Add(k);
}
}
}
foreach (var f in grouped) {
- var group = new Group() { ObjectIdentifier = $"{f.Key}-S-1-5-9" };
+ var group = new Group { ObjectIdentifier = $"{f.Key}-S-1-5-9" };
group.Properties.Add("name", $"ENTERPRISE DOMAIN CONTROLLERS@{forestSidToName[f.Key]}".ToUpper());
group.Properties.Add("domainsid", f.Key);
group.Properties.Add("domain", forestSidToName[f.Key]);
@@ -1040,6 +1050,14 @@ public void SetLdapConfig(LdapConfig config) {
return (false, default);
}
+ public void ResetUtils() {
+ _unresolvablePrincipals = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase);
+ _domainCache = new ConcurrentDictionary();
+ _domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase);
+ _connectionPool?.Dispose();
+ _connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner);
+ }
+
private IDirectoryObject CreateDirectoryEntry(string path) {
if (_ldapConfig.Username != null) {
return new DirectoryEntry(path, _ldapConfig.Username, _ldapConfig.Password).ToDirectoryObject();
@@ -1122,11 +1140,11 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN
if (!directoryObject.GetObjectIdentifier(out var objectIdentifier)) {
return (false, default);
}
-
+
var res = new ResolvedSearchResult {
ObjectId = objectIdentifier
};
-
+
//If the object is deleted, we can short circuit the rest of this logic as we don't really care about anything else
if (directoryObject.IsDeleted()) {
res.Deleted = true;
@@ -1140,7 +1158,7 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN
utils.AddDomainController(objectIdentifier);
}
}
-
+
string domain;
if (directoryObject.TryGetDistinguishedName(out var distinguishedName)) {
diff --git a/src/CommonLib/Processors/CertAbuseProcessor.cs b/src/CommonLib/Processors/CertAbuseProcessor.cs
index aff16f8f..8c1cc787 100644
--- a/src/CommonLib/Processors/CertAbuseProcessor.cs
+++ b/src/CommonLib/Processors/CertAbuseProcessor.cs
@@ -107,8 +107,10 @@ public async Task ProcessRegistryEnrollmentPermissions(str
}
var (resSuccess, resolvedPrincipal) = await GetRegistryPrincipal(new SecurityIdentifier(principalSid), principalDomain, computerName, isDomainController, computerObjectId, machineSid);
if (!resSuccess) {
- resolvedPrincipal.ObjectType = Label.Base;
- resolvedPrincipal.ObjectIdentifier = principalSid;
+ resolvedPrincipal = new TypedPrincipal {
+ ObjectType = Label.Base,
+ ObjectIdentifier = principalSid
+ };
}
var isInherited = rule.IsInherited();
diff --git a/src/CommonLib/Processors/LdapPropertyProcessor.cs b/src/CommonLib/Processors/LdapPropertyProcessor.cs
index 3c5c63eb..8577f0fe 100644
--- a/src/CommonLib/Processors/LdapPropertyProcessor.cs
+++ b/src/CommonLib/Processors/LdapPropertyProcessor.cs
@@ -2,13 +2,12 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
-using System.Reflection;
using System.Runtime.InteropServices;
using System.Security.AccessControl;
using System.Security.Cryptography.X509Certificates;
using System.Security.Principal;
using System.Threading.Tasks;
-using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging;
using SharpHoundCommonLib.Enums;
using SharpHoundCommonLib.LDAPQueries;
using SharpHoundCommonLib.OutputTypes;
@@ -866,8 +865,8 @@ private enum IsTextUnicodeFlags {
public class ParsedCertificate {
public string Thumbprint { get; set; }
public string Name { get; set; }
- public string[] Chain { get; set; } = Array.Empty();
- public bool HasBasicConstraints { get; set; } = false;
+ public string[] Chain { get; set; }
+ public bool HasBasicConstraints { get; set; }
public int BasicConstraintPathLength { get; set; }
public ParsedCertificate(byte[] rawCertificate) {
@@ -877,21 +876,26 @@ public ParsedCertificate(byte[] rawCertificate) {
Name = string.IsNullOrEmpty(name) ? Thumbprint : name;
// Chain
- X509Chain chain = new X509Chain();
- chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
- chain.Build(parsedCertificate);
- var temp = new List();
- foreach (X509ChainElement cert in chain.ChainElements) temp.Add(cert.Certificate.Thumbprint);
- Chain = temp.ToArray();
+ try {
+ var chain = new X509Chain();
+ chain.ChainPolicy.RevocationMode = X509RevocationMode.NoCheck;
+ chain.Build(parsedCertificate);
+ var temp = new List();
+ foreach (var cert in chain.ChainElements) temp.Add(cert.Certificate.Thumbprint);
+ Chain = temp.ToArray();
+ } catch (Exception e) {
+ Logging.LogProvider.CreateLogger("ParsedCertificate").LogWarning(e, "Failed to read certificate chain for certificate {Name} with Algo {Algorithm}", name, parsedCertificate.SignatureAlgorithm.FriendlyName);
+ Chain = Array.Empty();
+ }
+
// Extensions
- X509ExtensionCollection extensions = parsedCertificate.Extensions;
- List certificateExtensions = new List();
- foreach (X509Extension extension in extensions) {
- CertificateExtension certificateExtension = new CertificateExtension(extension);
+ var extensions = parsedCertificate.Extensions;
+ foreach (var extension in extensions) {
+ var certificateExtension = new CertificateExtension(extension);
switch (certificateExtension.Oid.Value) {
case CAExtensionTypes.BasicConstraints:
- X509BasicConstraintsExtension ext = (X509BasicConstraintsExtension)extension;
+ var ext = (X509BasicConstraintsExtension)extension;
HasBasicConstraints = ext.HasPathLengthConstraint;
BasicConstraintPathLength = ext.PathLengthConstraint;
break;
diff --git a/src/CommonLib/SharpHoundCommonLib.csproj b/src/CommonLib/SharpHoundCommonLib.csproj
index d8e0b937..85cffc92 100644
--- a/src/CommonLib/SharpHoundCommonLib.csproj
+++ b/src/CommonLib/SharpHoundCommonLib.csproj
@@ -9,7 +9,7 @@
Common library for C# BloodHound enumeration tasks
GPL-3.0-only
https://github.com/BloodHoundAD/SharpHoundCommon
- 4.0.6
+ 4.0.7
SharpHoundCommonLib
SharpHoundCommonLib
diff --git a/test/unit/Facades/MockLdapUtils.cs b/test/unit/Facades/MockLdapUtils.cs
index 90e8cce3..1b62adde 100644
--- a/test/unit/Facades/MockLdapUtils.cs
+++ b/test/unit/Facades/MockLdapUtils.cs
@@ -1016,6 +1016,10 @@ public void SetLdapConfig(LdapConfig config) {
throw new NotImplementedException();
}
+ public void ResetUtils() {
+ throw new NotImplementedException();
+ }
+
public Domain GetDomain(string domainName = null)
{
throw new NotImplementedException();