Skip to content

Commit

Permalink
Fix race conditions in D2 cluster subsetting. Refactor subsetting cac…
Browse files Browse the repository at this point in the history
…he into SubsettingState
  • Loading branch information
rickzx authored Jun 3, 2021
1 parent 6d5516d commit 0105db4
Show file tree
Hide file tree
Showing 23 changed files with 724 additions and 259 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ and what APIs have changed, if applicable.

## [Unreleased]

## [29.18.15] - 2021-06-02
- Fix race conditions in D2 cluster subsetting. Refactor subsetting cache to SubsettingState.

## [29.18.14] - 2021-05-27
- Use class.getClassLoader() instead of thread.getContextClassLoader() to get the class loader.

Expand Down Expand Up @@ -4966,7 +4969,8 @@ patch operations can re-use these classes for generating patch messages.

## [0.14.1]

[Unreleased]: https://github.com/linkedin/rest.li/compare/v29.18.14...master
[Unreleased]: https://github.com/linkedin/rest.li/compare/v29.18.15...master
[29.18.15]: https://github.com/linkedin/rest.li/compare/v29.18.14...v29.18.15
[29.18.14]: https://github.com/linkedin/rest.li/compare/v29.18.13...v29.18.14
[29.18.13]: https://github.com/linkedin/rest.li/compare/v29.18.12...v29.18.13
[29.18.12]: https://github.com/linkedin/rest.li/compare/v29.18.11...v29.18.12
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.linkedin.d2.balancer.properties.ServiceProperties;
import com.linkedin.d2.balancer.properties.UriProperties;
import com.linkedin.d2.balancer.strategies.LoadBalancerStrategy;
import com.linkedin.d2.balancer.subsetting.SubsettingState;
import com.linkedin.d2.balancer.util.partitions.PartitionAccessor;
import com.linkedin.d2.discovery.event.PropertyEventThread.PropertyEventShutdownCallback;
import com.linkedin.r2.transport.common.bridge.client.TransportClient;
Expand Down Expand Up @@ -91,12 +92,13 @@ public interface LoadBalancerState
List<SchemeStrategyPair> getStrategiesForService(String serviceName,
List<String> prioritizedSchemes);

default Map<URI, TrackerClient> getClientsSubset(String serviceName,
default SubsettingState.SubsetItem getClientsSubset(String serviceName,
int minClusterSubsetSize,
int partitionId,
Map<URI, TrackerClient> potentialClients)
Map<URI, TrackerClient> potentialClients,
long version)
{
return potentialClients;
return new SubsettingState.SubsetItem(false, potentialClients);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ public interface TrackerClient extends LoadBalancerClient
*/
TransportClient getTransportClient();

/**
 * Sets whether the host should skip performing slow start.
 *
 * <p>The default implementation has an empty body, so an implementation that does not
 * override this method silently ignores the supplied value.
 *
 * @param doNotSlowStart Should the host skip performing slow start
 */
default void setDoNotSlowStart(boolean doNotSlowStart)
{
}

/**
* @return Should the host skip performing slow start
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,11 @@ public class TrackerClientImpl implements TrackerClient
private final Map<Integer, PartitionData> _partitionData;
private final URI _uri;
private final Predicate<Integer> _isErrorStatus;
private final boolean _doNotSlowStart;
private final ConcurrentMap<Integer, Double> _subsetWeightMap;
final CallTracker _callTracker;

private boolean _doNotSlowStart;

private volatile CallTracker.CallStats _latestCallStats;

public TrackerClientImpl(URI uri, Map<Integer, PartitionData> partitionDataMap, TransportClient transportClient,
Expand Down Expand Up @@ -203,6 +204,12 @@ public void onResponse(TransportResponse<RestResponse> response)
}
}

/**
 * Stores the supplied flag in {@code _doNotSlowStart}, replacing any previous value.
 *
 * <p>NOTE(review): {@code _doNotSlowStart} is declared as a plain (non-volatile, non-final)
 * boolean and this write is not synchronized; confirm whether the flag is written and read
 * on different threads and, if so, whether stronger visibility guarantees are needed.
 *
 * @param doNotSlowStart Should the host skip performing slow start
 */
@Override
public void setDoNotSlowStart(boolean doNotSlowStart)
{
  this._doNotSlowStart = doNotSlowStart;
}

@Override
public boolean doNotSlowStart()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import com.linkedin.d2.balancer.properties.ServiceProperties;
import com.linkedin.d2.balancer.properties.UriProperties;
import com.linkedin.d2.balancer.strategies.LoadBalancerStrategy;
import com.linkedin.d2.balancer.subsetting.SubsettingState;
import com.linkedin.d2.balancer.util.ClientFactoryProvider;
import com.linkedin.d2.balancer.util.ClusterInfoProvider;
import com.linkedin.d2.balancer.util.HostOverrideList;
Expand Down Expand Up @@ -285,8 +286,10 @@ public <K> MapKeyResult<Ring<URI>, K> getRings(URI serviceUri, Iterable<K> keys)
Ring<URI> ring = null;
for (LoadBalancerState.SchemeStrategyPair pair : orderedStrategies)
{
Map<URI, TrackerClient> clients = getPotentialClients(serviceName, service, cluster, uris, pair.getScheme(), partitionId);
ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, clients);
SubsettingState.SubsetItem subsetItem = getPotentialClients(serviceName, service,
cluster, uris, pair.getScheme(), partitionId, uriItem.getVersion());
ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, subsetItem.getWeightedSubset(),
subsetItem.shouldForceUpdate());

if (!ring.isEmpty())
{
Expand Down Expand Up @@ -336,9 +339,10 @@ public Map<Integer, Ring<URI>> getRings(URI serviceUri) throws ServiceUnavailabl
{
for (LoadBalancerState.SchemeStrategyPair pair : orderedStrategies)
{
Map<URI, TrackerClient> trackerClients = getPotentialClients(serviceName, service, cluster, uris,
pair.getScheme(), partitionId);
Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, trackerClients);
SubsettingState.SubsetItem subsetItem = getPotentialClients(serviceName, service, cluster, uris,
pair.getScheme(), partitionId, uriItem.getVersion());
Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, subsetItem.getWeightedSubset(),
subsetItem.shouldForceUpdate());
// ring will never be null; it can be empty
ringMap.put(partitionId, ring);

Expand Down Expand Up @@ -550,12 +554,14 @@ public <K> HostToKeyMapper<K> getPartitionInformation(URI serviceUri, Collection
{
for (LoadBalancerState.SchemeStrategyPair pair : orderedStrategies)
{
Map<URI, TrackerClient> trackerClients = getPotentialClients(serviceName, service, cluster, uris,
pair.getScheme(), partitionId);
SubsettingState.SubsetItem subsetItem = getPotentialClients(serviceName, service, cluster, uris,
pair.getScheme(), partitionId, uriItem.getVersion());
Map<URI, TrackerClient> trackerClients = subsetItem.getWeightedSubset();
int size = Math.min(trackerClients.size(), limitHostPerPartition);
List<URI> rankedUri = new ArrayList<>(size);

Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, trackerClients);
Ring<URI> ring = pair.getStrategy().getRing(uriItem.getVersion(), partitionId, trackerClients,
subsetItem.shouldForceUpdate());
Iterator<URI> iterator = ring.getIterator(hash);

while (iterator.hasNext() && rankedUri.size() < size)
Expand Down Expand Up @@ -709,25 +715,26 @@ public void getLoadBalancedServiceProperties(String serviceName, boolean waitFor
}

// supports partitioning
private Map<URI, TrackerClient> getPotentialClients(String serviceName,
private SubsettingState.SubsetItem getPotentialClients(String serviceName,
ServiceProperties serviceProperties,
ClusterProperties clusterProperties,
UriProperties uris,
String scheme,
int partitionId)
int partitionId,
long version)
{
Set<URI> possibleUris = uris.getUriBySchemeAndPartition(scheme, partitionId);

Map<URI, TrackerClient> clientsToBalance = getPotentialClients(serviceName, serviceProperties, clusterProperties, possibleUris);
Map<URI, TrackerClient> clientsSubset = serviceProperties.isEnableClusterSubsetting() ?
SubsettingState.SubsetItem subsetItem = serviceProperties.isEnableClusterSubsetting() ?
_state.getClientsSubset(serviceName, serviceProperties.getMinClusterSubsetSize(),
partitionId, clientsToBalance) : clientsToBalance;
partitionId, clientsToBalance, version) : new SubsettingState.SubsetItem(false, clientsToBalance);

if (clientsSubset.isEmpty())
if (subsetItem.getWeightedSubset().isEmpty())
{
info(_log, "Can not find a host for service: ", serviceName, ", scheme: ", scheme, ", partition: ", partitionId);
}
return clientsSubset;
return subsetItem;
}

private Map<URI, TrackerClient> getPotentialClients(String serviceName,
Expand Down Expand Up @@ -839,11 +846,13 @@ private TrackerClient chooseTrackerClient(Request request, RequestContext reques
LoadBalancerStrategy strategy = pair.getStrategy();
String scheme = pair.getScheme();


clientsToLoadBalance = getPotentialClients(serviceName, serviceProperties, cluster, uris, scheme, partitionId);
SubsettingState.SubsetItem subsetItem = getPotentialClients(serviceName, serviceProperties, cluster,
uris, scheme, partitionId, uriItem.getVersion());
clientsToLoadBalance = subsetItem.getWeightedSubset();

trackerClient =
strategy.getTrackerClient(request, requestContext, uriItem.getVersion(), partitionId, clientsToLoadBalance);
strategy.getTrackerClient(request, requestContext, uriItem.getVersion(), partitionId, clientsToLoadBalance,
subsetItem.shouldForceUpdate());

debug(_log,
"load balancer strategy for ",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
import com.linkedin.d2.balancer.strategies.degrader.DegraderLoadBalancerStrategyV3;
import com.linkedin.d2.balancer.strategies.relative.RelativeLoadBalancerStrategy;
import com.linkedin.d2.balancer.subsetting.DeterministicSubsettingMetadataProvider;
import com.linkedin.d2.balancer.subsetting.SubsettingStrategy;
import com.linkedin.d2.balancer.subsetting.SubsettingStrategyFactory;
import com.linkedin.d2.balancer.subsetting.SubsettingState;
import com.linkedin.d2.balancer.subsetting.SubsettingStrategyFactoryImpl;
import com.linkedin.d2.balancer.util.ClientFactoryProvider;
import com.linkedin.d2.balancer.util.LoadBalancerUtil;
Expand All @@ -59,7 +58,6 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -144,9 +142,7 @@ public class SimpleLoadBalancerState implements LoadBalancerState, ClientFactory
private final SSLParameters _sslParameters;
private final boolean _isSSLEnabled;
private final SslSessionValidatorFactory _sslSessionValidatorFactory;

private final SubsettingStrategyFactory _subsettingStrategyFactory;
private final ConcurrentMap<String, ConcurrentMap<Integer, Map<URI, TrackerClient>>> _weightedSubsetsCache;
private final SubsettingState _subsettingState;

/*
* Concurrency considerations:
Expand Down Expand Up @@ -314,13 +310,12 @@ public SimpleLoadBalancerState(ScheduledExecutorService executorService,
_clusterListeners = Collections.synchronizedList(new ArrayList<>());
if (deterministicSubsettingMetadataProvider != null)
{
_subsettingStrategyFactory = new SubsettingStrategyFactoryImpl(deterministicSubsettingMetadataProvider, this);
_subsettingState = new SubsettingState(new SubsettingStrategyFactoryImpl(), deterministicSubsettingMetadataProvider);
}
else
{
_subsettingStrategyFactory = SubsettingStrategyFactory.NO_OP_SUBSETTING_STRATEGY_FACTORY;
_subsettingState = null;
}
_weightedSubsetsCache = new ConcurrentHashMap<>();
}

public void register(final SimpleLoadBalancerStateListener listener)
Expand Down Expand Up @@ -674,57 +669,30 @@ public void setDelayedExecution(long delayedExecution)
}

@Override
public Map<URI, TrackerClient> getClientsSubset(String serviceName,
public SubsettingState.SubsetItem getClientsSubset(String serviceName,
int minClusterSubsetSize,
int partitionId,
Map<URI, TrackerClient> potentialClients)
Map<URI, TrackerClient> potentialClients,
long version)
{
SubsettingStrategy<URI> subsettingStrategy = _subsettingStrategyFactory.get(serviceName, minClusterSubsetSize, partitionId);

if (subsettingStrategy == null)
if (_subsettingState == null)
{
return potentialClients;
}

// If cluster version is not changed, return the cached subset if possible
if (!subsettingStrategy.isSubsetChanged(_version.get()) &&
_weightedSubsetsCache.containsKey(serviceName) &&
_weightedSubsetsCache.get(serviceName).containsKey(partitionId))
{
return _weightedSubsetsCache.get(serviceName).get(partitionId);
}

Map<URI, Double> weightMap = potentialClients.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().getPartitionWeight(partitionId)));
Map<URI, Double> subsetMap = subsettingStrategy.getWeightedSubset(weightMap, _version.get());

if (subsetMap == null)
{
return potentialClients;
return new SubsettingState.SubsetItem(false, potentialClients);
}
else
{
Map<URI, TrackerClient> subsetClients = new HashMap<>();
for (Map.Entry<URI, Double> entry: subsetMap.entrySet())
{
URI uri = entry.getKey();
TrackerClient client = potentialClients.get(uri);
client.setSubsetWeight(partitionId, subsetMap.get(uri));
subsetClients.put(uri, client);
}

_weightedSubsetsCache.computeIfAbsent(serviceName, k -> new ConcurrentHashMap<>());
_weightedSubsetsCache.get(serviceName).put(partitionId, subsetClients);

debug(_log, "cluster subset updated for service ", serviceName, ": [",
subsetClients.values().stream()
.limit(LOG_SUBSET_MAX_SIZE)
.map(client -> client.getUri() + ":" + client.getSubsetWeight(partitionId))
.collect(Collectors.joining(",")),
" (total ", subsetClients.size(), ")]"
SubsettingState.SubsetItem subsetItem = _subsettingState
.getClientsSubset(serviceName, minClusterSubsetSize, partitionId, potentialClients, version, this);

debug(_log, "get cluster subset for service ", serviceName, ": [",
subsetItem.getWeightedSubset().values().stream()
.limit(LOG_SUBSET_MAX_SIZE)
.map(client -> client.getUri() + ":" + client.getSubsetWeight(partitionId))
.collect(Collectors.joining(",")),
" (total ", subsetItem.getWeightedSubset().size(), ")], shouldForceUpdate = ", subsetItem.shouldForceUpdate()
);

return subsetClients;
return subsetItem;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@ TrackerClient getTrackerClient(Request request,
int partitionId,
Map<URI, TrackerClient> trackerClients);

/**
 * Overload of {@link #getTrackerClient(Request, RequestContext, long, int, Map)} that
 * additionally accepts a {@code shouldForceUpdate} flag.
 *
 * <p>This default implementation ignores {@code shouldForceUpdate} and forwards the
 * remaining arguments unchanged to the five-argument overload, so a strategy that does not
 * override this method behaves exactly as it does through that overload.
 *
 * @param request passed through to the five-argument overload
 * @param requestContext passed through to the five-argument overload
 * @param clusterGenerationId passed through to the five-argument overload
 * @param partitionId passed through to the five-argument overload
 * @param trackerClients passed through to the five-argument overload
 * @param shouldForceUpdate not used by this default implementation
 * @return whatever the five-argument overload returns (may be {@code null})
 */
@Nullable
default TrackerClient getTrackerClient(Request request,
    RequestContext requestContext,
    long clusterGenerationId,
    int partitionId,
    Map<URI, TrackerClient> trackerClients,
    boolean shouldForceUpdate)
{
  // shouldForceUpdate is intentionally dropped here; only overriding strategies see it.
  final TrackerClient chosenClient =
      getTrackerClient(request, requestContext, clusterGenerationId, partitionId, trackerClients);
  return chosenClient;
}

/**
* Returns a ring that can be used to choose a host. The ring will contain all the
* tracker clients passed as the argument.
Expand All @@ -79,6 +90,15 @@ Ring<URI> getRing(long clusterGenerationId,
int partitionId,
Map<URI, TrackerClient> trackerClients);

/**
 * Overload of {@link #getRing(long, int, Map)} that additionally accepts a
 * {@code shouldForceUpdate} flag.
 *
 * <p>This default implementation ignores {@code shouldForceUpdate} and forwards the
 * remaining arguments unchanged to the three-argument overload, so a strategy that does not
 * override this method behaves exactly as it does through that overload.
 *
 * @param clusterGenerationId passed through to the three-argument overload
 * @param partitionId passed through to the three-argument overload
 * @param trackerClients passed through to the three-argument overload
 * @param shouldForceUpdate not used by this default implementation
 * @return the ring produced by the three-argument overload
 */
@Nonnull
default Ring<URI> getRing(long clusterGenerationId,
    int partitionId,
    Map<URI, TrackerClient> trackerClients,
    boolean shouldForceUpdate)
{
  // shouldForceUpdate is intentionally dropped here; only overriding strategies see it.
  final Ring<URI> delegatedRing = getRing(clusterGenerationId, partitionId, trackerClients);
  return delegatedRing;
}

/**
* Return the hashFunction which will be applied on {@code Request} to find the host for routing purpose
* @return the hashFunction
Expand Down
Loading

0 comments on commit 0105db4

Please sign in to comment.