Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unhandled exceptions, cache unresolvable entities #163

Merged
merged 6 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/CommonLib/ConcurrentHashSet.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;

namespace SharpHoundCommonLib;

/// <summary>
/// A concurrent implementation of a hashset using a ConcurrentDictionary as the backing structure.
/// </summary>
public class ConcurrentHashSet : IDisposable{
private ConcurrentDictionary<string, byte> _backingDictionary;

public ConcurrentHashSet() {
_backingDictionary = new ConcurrentDictionary<string, byte>();
}

public ConcurrentHashSet(StringComparer comparison) {
_backingDictionary = new ConcurrentDictionary<string, byte>(comparison);
}

/// <summary>
/// Attempts to add an item to the set. Returns true if adding was successful, false otherwise
/// </summary>
/// <param name="item"></param>
/// <returns></returns>
public bool Add(string item) {
return _backingDictionary.TryAdd(item, byte.MinValue);
}

/// <summary>
/// Attempts to remove an item from the set. Returns true of removing was successful, false otherwise
/// </summary>
/// <param name="item"></param>
/// <returns></returns>
public bool Remove(string item) {
return _backingDictionary.TryRemove(item, out _);
}

/// <summary>
/// Checks if the given item is in the set
/// </summary>
/// <param name="item"></param>
/// <returns></returns>
public bool Contains(string item) {
return _backingDictionary.ContainsKey(item);
}

/// <summary>
/// Returns all values in the set
/// </summary>
/// <returns></returns>
public IEnumerable<string> Values() {
return _backingDictionary.Keys;
}

public void Dispose() {
_backingDictionary = null;
GC.SuppressFinalize(this);
}
}
5 changes: 5 additions & 0 deletions src/CommonLib/ILdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,5 +169,10 @@ IAsyncEnumerable<Result<string>> RangedRetrieval(string distinguishedName,
/// <param name="context">The naming context being retrieved</param>
/// <returns>A tuple containing success state as well as the resolved distinguished name if successful</returns>
Task<(bool Success, string Path)> GetNamingContextPath(string domain, NamingContext context);

/// <summary>
/// Resets temporary caches in LDAPUtils
/// </summary>
void ResetUtils();
}
}
100 changes: 59 additions & 41 deletions src/CommonLib/LdapUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
using System.DirectoryServices;
using System.DirectoryServices.AccountManagement;
using System.DirectoryServices.ActiveDirectory;
using System.DirectoryServices.Protocols;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Security.Principal;
using System.Text;
using System.Text.RegularExpressions;
Expand All @@ -24,13 +22,13 @@
using Domain = System.DirectoryServices.ActiveDirectory.Domain;
using Group = SharpHoundCommonLib.OutputTypes.Group;
using SearchScope = System.DirectoryServices.Protocols.SearchScope;
using SecurityMasks = System.DirectoryServices.Protocols.SecurityMasks;

namespace SharpHoundCommonLib {
public class LdapUtils : ILdapUtils {
//This cache is indexed by domain sid
private static readonly ConcurrentDictionary<string, Domain> DomainCache = new();
private static readonly ConcurrentDictionary<string, byte> DomainControllers = new();
private static ConcurrentDictionary<string, Domain> _domainCache = new();
private static ConcurrentHashSet _domainControllers = new(StringComparer.OrdinalIgnoreCase);
private static ConcurrentHashSet _unresolvablePrincipals = new(StringComparer.OrdinalIgnoreCase);

private static readonly ConcurrentDictionary<string, string> DomainToForestCache =
new(StringComparer.OrdinalIgnoreCase);
Expand All @@ -55,11 +53,6 @@ private readonly ConcurrentDictionary<string, string>

private ConnectionPoolManager _connectionPool;

private static readonly TimeSpan MinBackoffDelay = TimeSpan.FromSeconds(2);
private static readonly TimeSpan MaxBackoffDelay = TimeSpan.FromSeconds(20);
private const int BackoffDelayMultiplier = 2;
private const int MaxRetries = 3;

private static readonly byte[] NameRequest = {
0x80, 0x94, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x20, 0x43, 0x4b, 0x41,
Expand Down Expand Up @@ -99,7 +92,8 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
return _connectionPool.Query(queryParameters, cancellationToken);
}

public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters, CancellationToken cancellationToken = new()) {
public IAsyncEnumerable<LdapResult<IDirectoryObject>> PagedQuery(LdapQueryParameters queryParameters,
CancellationToken cancellationToken = new()) {
rvazarkar marked this conversation as resolved.
Show resolved Hide resolved
return _connectionPool.PagedQuery(queryParameters, cancellationToken);
}

Expand All @@ -119,12 +113,24 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
return (true, principal);
}

if (_unresolvablePrincipals.Contains(identifier)) {
return (false, new TypedPrincipal(identifier, Label.Base));
}

if (identifier.StartsWith("S-")) {
var result = await LookupSidType(identifier, objectDomain);
if (!result.Success) {
_unresolvablePrincipals.Add(identifier);
}

return (result.Success, new TypedPrincipal(identifier, result.Type));
}

var (success, type) = await LookupGuidType(identifier, objectDomain);
if (!success) {
_unresolvablePrincipals.Add(identifier);
}

return (success, new TypedPrincipal(identifier, type));
}

Expand Down Expand Up @@ -297,7 +303,8 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
LDAPFilter = new LdapFilter().AddAllObjects().GetFilter(),
};

var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail())
.FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.RootDomainNamingContext, out var rootNamingContext)) {
return (true, Helpers.DistinguishedNameToDomain(rootNamingContext).ToUpper());
Expand All @@ -306,12 +313,6 @@ public IAsyncEnumerable<LdapResult<IDirectoryObject>> Query(LdapQueryParameters
return (false, null);
}

private static TimeSpan GetNextBackoff(int retryCount) {
return TimeSpan.FromSeconds(Math.Min(
MinBackoffDelay.TotalSeconds * Math.Pow(BackoffDelayMultiplier, retryCount),
MaxBackoffDelay.TotalSeconds));
}

public async Task<(bool Success, string DomainName)> GetDomainNameFromSid(string sid) {
string domainSid;
try {
Expand Down Expand Up @@ -465,7 +466,7 @@ private static TimeSpan GetNextBackoff(int retryCount) {
/// <returns></returns>
public bool GetDomain(string domainName, out Domain domain) {
var cacheKey = domainName ?? _nullCacheKey;
if (DomainCache.TryGetValue(cacheKey, out domain)) return true;
if (_domainCache.TryGetValue(cacheKey, out domain)) return true;

try {
DirectoryContext context;
Expand All @@ -482,7 +483,7 @@ public bool GetDomain(string domainName, out Domain domain) {

domain = Domain.GetDomain(context);
if (domain == null) return false;
DomainCache.TryAdd(cacheKey, domain);
_domainCache.TryAdd(cacheKey, domain);
return true;
} catch (Exception e) {
_log.LogDebug(e, "GetDomain call failed for domain name {Name}", domainName);
Expand All @@ -491,7 +492,7 @@ public bool GetDomain(string domainName, out Domain domain) {
}

public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domain domain) {
if (DomainCache.TryGetValue(domainName, out domain)) return true;
if (_domainCache.TryGetValue(domainName, out domain)) return true;

try {
DirectoryContext context;
Expand All @@ -508,7 +509,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai

domain = Domain.GetDomain(context);
if (domain == null) return false;
DomainCache.TryAdd(domainName, domain);
_domainCache.TryAdd(domainName, domain);
return true;
} catch (Exception e) {
Logging.Logger.LogDebug("Static GetDomain call failed for domain {DomainName}: {Error}", domainName,
Expand All @@ -525,8 +526,7 @@ public static bool GetDomain(string domainName, LdapConfig ldapConfig, out Domai
/// <param name="domainName"></param>
/// <returns></returns>
public bool GetDomain(out Domain domain) {
var cacheKey = _nullCacheKey;
if (DomainCache.TryGetValue(cacheKey, out domain)) return true;
if (_domainCache.TryGetValue(_nullCacheKey, out domain)) return true;

try {
var context = _ldapConfig.Username != null
Expand All @@ -535,7 +535,7 @@ public bool GetDomain(out Domain domain) {
: new DirectoryContext(DirectoryContextType.Domain);

domain = Domain.GetDomain(context);
DomainCache.TryAdd(cacheKey, domain);
_domainCache.TryAdd(_nullCacheKey, domain);
return true;
} catch (Exception e) {
_log.LogDebug(e, "GetDomain call failed for blank domain");
Expand Down Expand Up @@ -579,7 +579,7 @@ public bool GetDomain(out Domain domain) {
return (false, string.Empty);
}

if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (true, sid);
if (_hostResolutionMap.TryGetValue(strippedHost, out var sid)) return (sid != null, sid);

//Immediately start with NetWkstaGetInfo as it's our most reliable indicator if successful
if (await GetWorkstationInfo(strippedHost) is (true, var workstationInfo)) {
Expand Down Expand Up @@ -658,6 +658,7 @@ public bool GetDomain(out Domain domain) {
//pass
}

_hostResolutionMap.TryAdd(strippedHost, null);
return (false, "");
}

Expand Down Expand Up @@ -693,7 +694,7 @@ public bool GetDomain(out Domain domain) {
if (await GetWellKnownPrincipal(sid, domain) is (true, var principal)) {
sids.Add(principal.ObjectIdentifier);
} else {
sids.Add(sid);
sids.Add(sid);
}
} else {
return (false, Array.Empty<string>());
Expand Down Expand Up @@ -825,9 +826,10 @@ public ActiveDirectorySecurityDescriptor MakeSecurityDescriptor() {
}

public async Task<bool> IsDomainController(string computerObjectId, string domainName) {
if (DomainControllers.ContainsKey(computerObjectId)) {
if (_domainControllers.Contains(computerObjectId)) {
return true;
}

var resDomain = await GetDomainNameFromSid(domainName) is (false, var tempDomain) ? tempDomain : domainName;
var filter = new LdapFilter().AddFilter(CommonFilters.SpecificSID(computerObjectId), true)
.AddFilter(CommonFilters.DomainControllers, true);
Expand All @@ -837,8 +839,9 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
LDAPFilter = filter.GetFilter(),
}).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
if (result.IsSuccess) {
DomainControllers.TryAdd(computerObjectId, new byte());
_domainControllers.Add(computerObjectId);
}

return result.IsSuccess;
}

Expand All @@ -847,6 +850,10 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
return (true, principal);
}

if (_unresolvablePrincipals.Contains(distinguishedName)) {
return (false, default);
}

var domain = Helpers.DistinguishedNameToDomain(distinguishedName);
var result = await Query(new LdapQueryParameters {
DomainName = domain,
Expand Down Expand Up @@ -890,13 +897,13 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai

return (false, default);
} catch {
_unresolvablePrincipals.Add(distinguishedName);
return (false, default);
}
}
}

public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn)
{
public async Task<(bool Success, string DSHeuristics)> GetDSHueristics(string domain, string dn) {
var configPath = CommonPaths.CreateDNPath(CommonPaths.DirectoryServicePath, dn);
var queryParameters = new LdapQueryParameters {
Attributes = new[] { LDAPProperties.DSHeuristics },
Expand All @@ -907,16 +914,18 @@ public async Task<bool> IsDomainController(string computerObjectId, string domai
SearchBase = configPath
};

var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail()).FirstOrDefaultAsync();
var result = await Query(queryParameters).DefaultIfEmpty(LdapResult<IDirectoryObject>.Fail())
.FirstOrDefaultAsync();
if (result.IsSuccess &&
result.Value.TryGetProperty(LDAPProperties.DSHeuristics, out var dsh)) {
return (true, dsh);
}

return (false, null);
}

public void AddDomainController(string domainControllerSID) {
DomainControllers.TryAdd(domainControllerSID, new byte());
_domainControllers.Add(domainControllerSID);
}

public async IAsyncEnumerable<OutputBase> GetWellKnownPrincipalOutput() {
Expand Down Expand Up @@ -948,27 +957,28 @@ public async IAsyncEnumerable<OutputBase> GetWellKnownPrincipalOutput() {
yield return entdc;
}
}

private async IAsyncEnumerable<Group> GetEnterpriseDCGroups() {
var grouped = new ConcurrentDictionary<string, List<string>>(StringComparer.OrdinalIgnoreCase);
var forestSidToName = new ConcurrentDictionary<string, string>(StringComparer.OrdinalIgnoreCase);
foreach (var domainSid in DomainControllers.GroupBy(x => new SecurityIdentifier(x.Key).AccountDomainSid.Value)) {
foreach (var domainSid in _domainControllers.Values().GroupBy(x =>
new SecurityIdentifier(x).AccountDomainSid.Value)) {
if (await GetDomainNameFromSid(domainSid.Key) is (true, var domainName) &&
await GetForest(domainName) is (true, var forestName) &&
await GetDomainSidFromDomainName(forestName) is (true, var forestDomainSid)) {
forestSidToName.TryAdd(forestDomainSid, forestName);
if (!grouped.ContainsKey(forestDomainSid)) {
grouped[forestDomainSid] = new List<string>();
}

foreach (var k in domainSid) {
grouped[forestDomainSid].Add(k.Key);
grouped[forestDomainSid].Add(k);
}
}
}

foreach (var f in grouped) {
var group = new Group() { ObjectIdentifier = $"{f.Key}-S-1-5-9" };
var group = new Group { ObjectIdentifier = $"{f.Key}-S-1-5-9" };
group.Properties.Add("name", $"ENTERPRISE DOMAIN CONTROLLERS@{forestSidToName[f.Key]}".ToUpper());
group.Properties.Add("domainsid", f.Key);
group.Properties.Add("domain", forestSidToName[f.Key]);
Expand Down Expand Up @@ -1040,6 +1050,14 @@ public void SetLdapConfig(LdapConfig config) {
return (false, default);
}

public void ResetUtils() {
_unresolvablePrincipals = new ConcurrentHashSet();
_domainCache = new ConcurrentDictionary<string, Domain>();
_domainControllers = new ConcurrentHashSet();
_connectionPool?.Dispose();
_connectionPool = new ConnectionPoolManager(_ldapConfig, scanner: _portScanner);
}

private IDirectoryObject CreateDirectoryEntry(string path) {
if (_ldapConfig.Username != null) {
return new DirectoryEntry(path, _ldapConfig.Username, _ldapConfig.Password).ToDirectoryObject();
Expand Down Expand Up @@ -1122,11 +1140,11 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN
if (!directoryObject.GetObjectIdentifier(out var objectIdentifier)) {
return (false, default);
}

var res = new ResolvedSearchResult {
ObjectId = objectIdentifier
};

//If the object is deleted, we can short circuit the rest of this logic as we don't really care about anything else
if (directoryObject.IsDeleted()) {
res.Deleted = true;
Expand All @@ -1140,7 +1158,7 @@ internal static bool ResolveLabel(string objectIdentifier, string distinguishedN
utils.AddDomainController(objectIdentifier);
}
}

string domain;

if (directoryObject.TryGetDistinguishedName(out var distinguishedName)) {
Expand Down
Loading
Loading