Skip to content

Commit

Permalink
Fix unhandled exceptions, cache unresolvable entities (#163)
Browse files Browse the repository at this point in the history
* fix: catch exceptions on cert chain parsing and enterpriseca permissions (BED-4822)

* chore: fix unhandled exception

* fix: cache unresolvable principals so we don't retry a million times

* fix: use proper string comparison on reset

* chore: bump version
  • Loading branch information
rvazarkar committed Sep 17, 2024
1 parent dd20a7b commit 0850cf4
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 59 deletions.
60 changes: 60 additions & 0 deletions src/CommonLib/ConcurrentHashSet.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

namespace SharpHoundCommonLib;

/// <summary>
/// A concurrent implementation of a hashset using a ConcurrentDictionary as the backing structure.
/// </summary>
public class ConcurrentHashSet : IDisposable{
private ConcurrentDictionary<string, byte> _backingDictionary;

public ConcurrentHashSet() {
_backingDictionary = new ConcurrentDictionary<string, byte>();
}

public ConcurrentHashSet(StringComparer comparison) {
_backingDictionary = new ConcurrentDictionary<string, byte>(comparison);
}

/// <summary>
/// Attempts to add an item to the set. Returns true if adding was successful, false otherwise
/// </summary>
/// <param name="item"></param>
/// <returns></returns>
public bool Add(string item) {
return _backingDictionary.TryAdd(item, byte.MinValue);
}

/// <summary>
/// Attempts to remove an item from the set. Returns true of removing was successful, false otherwise
/// </summary>
/// <param name="item"></param>
/// <returns></returns>
public bool Remove(string item) {
return _backingDictionary.TryRemove(item, out _);
}

/// <summary>
/// Checks if the given item is in the set
/// </summary>
/// <param name="item"></param>
/// <returns></returns>
public bool Contains(string item) {
return _backingDictionary.ContainsKey(item);
}

/// <summary>
/// Returns all values in the set
/// </summary>
/// <returns></returns>
public IEnumerable<string> Values() {
return _backingDictionary.Keys;
}

public void Dispose() {
_backingDictionary = null;
GC.SuppressFinalize(this);
}
}
5 changes: 5 additions & 0 deletions src/CommonLib/ILdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,10 @@ IAsyncEnumerable<Result<string>> RangedRetrieval(string distinguishedName,
/// <param name="context">The naming context being retrieved</param>
/// <returns>A tuple containing success state as well as the resolved distinguished name if successful</returns>
Task<(bool Success, string Path)> GetNamingContextPath(string domain, NamingContext context);

/// <summary>
/// Resets temporary caches in LDAPUtils
/// </summary>
void ResetUtils();
}
}
100 changes: 59 additions & 41 deletions src/CommonLib/LdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
using System.DirectoryServices;
using System.DirectoryServices.AccountManagement;
using System.DirectoryServices.ActiveDirectory;
using System.DirectoryServices.Protocols;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Security.Principal;
using System.Text;
using System.Text.RegularExpressions;
Expand All @@ -24,13 +22,13 @@
using Domain = System.DirectoryServices.ActiveDirectory.Domain;
using Group = SharpHoundCommonLib.OutputTypes.Group;
using SearchScope = System.DirectoryServices.Protocols.SearchScope;
using SecurityMasks = System.DirectoryServices.Protocols.SecurityMasks;

namespace SharpHoundCommonLib {
public class LdapUtils : ILdapUtils {
//This cache is indexed by domain sid
private static readonly ConcurrentDictionary<string, Domain> DomainCache = new();
private static readonly ConcurrentDictionary<string, byte> DomainControllers = new();
private static ConcurrentDictionary<string, Domain> _domainCache = new();
private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase);
private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase);

private static readonly ConcurrentDictionary<string, string> DomainToForestCache =
new(StringComparer.OrdinalIgnoreCase);
Expand All @@ -55,11 +53,6 @@ private readonly ConcurrentDictionary<string, string>

private ConnectionPoolManager _connectionPool;

private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2);
private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20);
private const int BackoffDelayMultiplier = 2;
private const int MaxRetries = 3;

private static readonly byte[] NameRequest = {
0x80, 0x94, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x20, 0x43, 0x4b, 0x41,
Expand Down Expand Up @@ -99,7 +92,8 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
return _connectionPool.Query(queryParameters, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters, CancellationToken cancellationToken = new()) {
public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
return _connectionPool.PagedQuery(queryParameters, cancellationToken);
}

Expand All @@ -119,12 +113,24 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
return (true, principal);
}

if (_unresolvablePrincipals.Contains(identifier)) {
return (false, new TypedPrincipal(identifier, Label.Base));
}

if (identifier.StartsWith("S-")) {
var result = await LookupSidType(identifier, objectDomain);
if (!result.Success) {
_unresolvablePrincipals.Add(identifier);
}

return (result.Success, new TypedPrincipal(identifier, result.Type));
}

var (success, type) = await LookupGuidType(identifier, objectDomain);
if (!success) {
_unresolvablePrincipals.Add(identifier);
}

return (success, new TypedPrincipal(identifier, type));
}

Expand Down Expand Up @@ -297,7 +303,8 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
LDAPFilter = new LdapFilter().AddAllObjects().GetFilter(),
};

var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail())
.FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.RootDomainNamingContext, out var rootNamingContext)) {
return (true, Helpers.DistinguishedNameToDomain(rootNamingContext).ToUpper());
Expand All @@ -306,12 +313,6 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
return (false, null);
}

private static TimeSpan GetNextBackoff(int retryCount) {
return TimeSpan.FromSeconds(Math.Min(
MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount),
MaxBackoffDelay.TotalSeconds));
}

public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) {
string domainSid;
try {
Expand Down Expand Up @@ -465,7 +466,7 @@ private static TimeSpan GetNextBackoff(int retryCount) {
/// <returns></returns>
public bool GetDomain(string domainName, out Domain domain) {
var cacheKey = domainName ?? _nullCacheKey;
if (DomainCache.TryGetValue(cacheKey, out domain)) return true;
if (_domainCache.TryGetValue(cacheKey, out domain)) return true;

try {
DirectoryContext context;
Expand All @@ -482,7 +483,7 @@ public bool GetDomain(string domainName, out Domain domain) {

domain = Domain.GetDomain(context);
if (domain == null) return false;
DomainCache.TryAdd(cacheKey, domain);
_domainCache.TryAdd(cacheKey, domain);
return true;
} catch (Exception e) {
_log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName);
Expand All @@ -491,7 +492,7 @@ public bool GetDomain(string domainName, out Domain domain) {
}

public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) {
if (DomainCache.TryGetValue(domainName, out domain)) return true;
if (_domainCache.TryGetValue(domainName, out domain)) return true;

try {
DirectoryContext context;
Expand All @@ -508,7 +509,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai

domain = Domain.GetDomain(context);
if (domain == null) return false;
DomainCache.TryAdd(domainName, domain);
_domainCache.TryAdd(domainName, domain);
return true;
} catch (Exception e) {
Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName,
Expand All @@ -525,8 +526,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai
/// <param name="domainName"></param>
/// <returns></returns>
public bool GetDomain(out Domain domain) {
var cacheKey = _nullCacheKey;
if (DomainCache.TryGetValue(cacheKey, out domain)) return true;
if (_domainCache.TryGetValue(_nullCacheKey, out domain)) return true;

try {
var context = _ldapConfig.Username != null
Expand All @@ -535,7 +535,7 @@ public bool GetDomain(out Domain domain) {
: new DirectoryContext(DirectoryContextType.Domain);

domain = Domain.GetDomain(context);
DomainCache.TryAdd(cacheKey, domain);
_domainCache.TryAdd(_nullCacheKey, domain);
return true;
} catch (Exception e) {
_log.LogDebug(e, "GetDomain call failed for blank domain");
Expand Down Expand Up @@ -579,7 +579,7 @@ public bool GetDomain(out Domain domain) {
return (false, string.Empty);
}

if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (true, sid);
if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (sid != null, sid);

//Immediately start with NetWkstaGetInfo as it's our most reliable indicator if successful
if (await GetWorkstationInfo(strippedHost) is (true, var workstationInfo)) {
Expand Down Expand Up @@ -658,6 +658,7 @@ public bool GetDomain(out Domain domain) {
//pass
}

_hostResolutionMap.TryAdd(strippedHost, null);
return (false, "");
}

Expand Down Expand Up @@ -693,7 +694,7 @@ public bool GetDomain(out Domain domain) {
if (await GetWellKnownPrincipal(sid, domain) is (true, var principal)) {
sids.Add(principal.ObjectIdentifier);
} else {
sids.Add(sid);
sids.Add(sid);
}
} else {
return (false, Array.Empty<string>());
Expand Down Expand Up @@ -825,9 +826,10 @@ public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() {
}

public async Task<bool> IsDomainController(string computerObjectId, string domainName) {
if (DomainControllers.ContainsKey(computerObjectId)) {
if (_domainControllers.Contains(computerObjectId)) {
return true;
}

var resDomain = await GetDomainNameFromSid(domainName) is (false, var tempDomain) ? tempDomain : domainName;
var filter = new LdapFilter().AddFilter(CommonFilters.SpecificSID(computerObjectId), true)
.AddFilter(CommonFilters.DomainControllers, true);
Expand All @@ -837,8 +839,9 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
LDAPFilter = filter.GetFilter(),
}).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
if (result.IsSuccess) {
DomainControllers.TryAdd(computerObjectId, new byte());
_domainControllers.Add(computerObjectId);
}

return result.IsSuccess;
}

Expand All @@ -847,6 +850,10 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
return (true, principal);
}

if (_unresolvablePrincipals.Contains(distinguishedName)) {
return (false, default);
}

var domain = Helpers.DistinguishedNameToDomain(distinguishedName);
var result = await Query(new LdapQueryParameters {
DomainName = domain,
Expand Down Expand Up @@ -890,13 +897,13 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai

return (false, default);
} catch {
_unresolvablePrincipals.Add(distinguishedName);
return (false, default);
}
}
}

public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn)
{
public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn) {
var configPath = CommonPaths.CreateDNPath(CommonPaths.DirectoryServicePath, dn);
var queryParameters = new LdapQueryParameters {
Attributes = new[] { LDAPProperties.DSHeuristics },
Expand All @@ -907,16 +914,18 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
SearchBase = configPath
};

var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail())
.FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.DSHeuristics, out var dsh)) {
return (true, dsh);
}

return (false, null);
}

public void AddDomainController(string domainControllerSID) {
DomainControllers.TryAdd(domainControllerSID, new byte());
_domainControllers.Add(domainControllerSID);
}

public async IAsyncEnumerable<OutputBase> GetWellKnownPrincipalOutput() {
Expand Down Expand Up @@ -948,27 +957,28 @@ public async IAsyncEnumerable<OutputBase> GetWellKnownPrincipalOutput() {
yield return entdc;
}
}

private async IAsyncEnumerable<Group> GetEnterpriseDCGroups() {
var grouped = new ConcurrentDictionary<string, List<string>>(StringComparer.OrdinalIgnoreCase);
var forestSidToName = new ConcurrentDictionary<string, string>(StringComparer.OrdinalIgnoreCase);
foreach (var domainSid in DomainControllers.GroupBy(x => new SecurityIdentifier(x.Key).AccountDomainSid.Value)) {
foreach (var domainSid in _domainControllers.Values().GroupBy(x =>
new SecurityIdentifier(x).AccountDomainSid.Value)) {
if (await GetDomainNameFromSid(domainSid.Key) is (true, var domainName) &&
await GetForest(domainName) is (true, var forestName) &&
await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) {
forestSidToName.TryAdd(forestDomainSid, forestName);
if (!grouped.ContainsKey(forestDomainSid)) {
grouped[forestDomainSid] = new List<string>();
}

foreach (var k in domainSid) {
grouped[forestDomainSid].Add(k.Key);
grouped[forestDomainSid].Add(k);
}
}
}

foreach (var f in grouped) {
var group = new Group() { ObjectIdentifier = $"{f.Key}-S-1-5-9" };
var group = new Group { ObjectIdentifier = $"{f.Key}-S-1-5-9" };
group.Properties.Add("name", $"ENTERPRISE DOMAIN CONTROLLERS@{forestSidToName[f.Key]}".ToUpper());
group.Properties.Add("domainsid", f.Key);
group.Properties.Add("domain", forestSidToName[f.Key]);
Expand Down Expand Up @@ -1040,6 +1050,14 @@ public void SetLdapConfig(LdapConfig config) {
return (false, default);
}

public void ResetUtils() {
_unresolvablePrincipals = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase);
_domainCache = new ConcurrentDictionary<string, Domain>();
_domainControllers = new ConcurrentHashSet(StringComparer.OrdinalIgnoreCase);
_connectionPool?.Dispose();
_connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner);
}

private IDirectoryObject CreateDirectoryEntry(string path) {
if (_ldapConfig.Username != null) {
return new DirectoryEntry(path, _ldapConfig.Username, _ldapConfig.Password).ToDirectoryObject();
Expand Down Expand Up @@ -1122,11 +1140,11 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN
if (!directoryObject.GetObjectIdentifier(out var objectIdentifier)) {
return (false, default);
}

var res = new ResolvedSearchResult {
ObjectId = objectIdentifier
};

//If the object is deleted, we can short circuit the rest of this logic as we don't really care about anything else
if (directoryObject.IsDeleted()) {
res.Deleted = true;
Expand All @@ -1140,7 +1158,7 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN
utils.AddDomainController(objectIdentifier);
}
}

string domain;

if (directoryObject.TryGetDistinguishedName(out var distinguishedName)) {
Expand Down
Loading

0 comments on commit 0850cf4

Please sign in to comment.