Skip to content

Commit

Permalink
Merge pull request #204 from jasonk000/jkoch/less-locking-partitioned-limiter
Browse files Browse the repository at this point in the history

Reduce contention a little in AbstractPartitionedLimiter
  • Loading branch information
umairk79 authored Sep 3, 2024
2 parents 075bc76 + 09a1e6f commit 58dab7f
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ public BuilderT metricRegistry(MetricRegistry registry) {
* Due to the builders not having access to the ContextT, it is the duty of subclasses to ensure that
* implementations are type safe.
*
* Predicates should not rely strictly on state of the Limiter (such as inflight count) when evaluating
* whether to bypass. There is no guarantee that the state will be synchronized or consistent with respect to
* the bypass predicate, and the bypass predicate may be called by multiple threads concurrently.
*
* @param shouldBypass Predicate condition to bypass limit
* @return Chainable builder
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

public abstract class AbstractPartitionedLimiter<ContextT> extends AbstractLimiter<ContextT> {
Expand Down Expand Up @@ -105,10 +104,10 @@ public Limiter<ContextT> build() {

static class Partition {
    private final String name;
    // Inflight count. AtomicInteger so acquire/tryAcquire/release can run
    // concurrently without holding a shared lock.
    private final AtomicInteger busy = new AtomicInteger(0);

    private double percent = 0.0;
    // volatile: written by updateLimit(), read lock-free by acquiring threads.
    private volatile int limit = 0;
    private long backoffMillis = 0;
    private MetricRegistry.SampleListener inflightDistribution;
Expand All @@ -134,25 +133,41 @@ void updateLimit(int totalLimit) {
}

/**
 * @return true if this partition's inflight count has reached its limit
 */
boolean isLimitExceeded() {
    return busy.get() >= limit;
}

/**
 * Unconditionally take a slot and record the resulting inflight count in
 * the inflight distribution metric.
 */
void acquire() {
    int nowBusy = busy.incrementAndGet();
    inflightDistribution.addSample(nowBusy);
}

/**
 * Try to acquire a slot, returning false if the limit is exceeded.
 *
 * @return true if a slot was acquired (inflight count incremented and the
 *         new value sampled into the inflight metric); false if {@code busy}
 *         had already reached {@code limit}
 */
boolean tryAcquire() {
    int current = busy.get();
    // CAS loop: retry while a fresh read of busy is below the (volatile)
    // limit; give up as soon as the partition appears full.
    while (current < limit) {
        if (busy.compareAndSet(current, current + 1)) {
            inflightDistribution.addSample(current + 1);
            return true;
        }
        current = busy.get();
    }

    return false;
}

/**
 * Release a previously acquired slot.
 */
void release() {
    busy.decrementAndGet();
}

/** @return this partition's current limit (volatile read). */
int getLimit() {
    return limit;
}

/** @return the number of requests currently inflight in this partition */
public int getInflight() {
    return busy.get();
}

double getPercent() {
Expand All @@ -166,14 +181,13 @@ void createMetrics(MetricRegistry registry) {

@Override
public String toString() {
    return "Partition [pct=" + percent + ", limit=" + limit + ", busy=" + busy.get() + "]";
}
}

private final Map<String, Partition> partitions;
private final Partition unknownPartition;
private final List<Function<ContextT, String>> partitionResolvers;
// Number of threads currently sleeping in the rejection backoff; capped at
// maxDelayedThreads so rejections don't tie up the whole caller pool.
private final AtomicInteger delayedThreads = new AtomicInteger();
private final int maxDelayedThreads;

Expand Down Expand Up @@ -211,63 +225,67 @@ private Partition resolvePartition(ContextT context) {

@Override
public Optional<Listener> acquire(ContextT context) {
final Partition partition = resolvePartition(context);

try {
lock.lock();
if (shouldBypass(context)){
return createBypassListener();
}
if (getInflight() >= getLimit() && partition.isLimitExceeded()) {
lock.unlock();
if (partition.backoffMillis > 0 && delayedThreads.get() < maxDelayedThreads) {
try {
delayedThreads.incrementAndGet();
TimeUnit.MILLISECONDS.sleep(partition.backoffMillis);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
delayedThreads.decrementAndGet();
}
}
if (shouldBypass(context)){
return createBypassListener();
}

return createRejectedListener();
}
final Partition partition = resolvePartition(context);

// This is a little unusual in that the partition is not a hard limit. It is
// only a limit that it is applied if the global limit is exceeded. This allows
// for excess capacity in each partition to allow for bursting over the limit,
// but only if there is spare global capacity.

final boolean overLimit;
if (getInflight() >= getLimit()) {
// over global limit, so respect partition limit
boolean couldAcquire = partition.tryAcquire();
overLimit = !couldAcquire;
} else {
// we are below global limit, so no need to respect partition limit
partition.acquire();
final Listener listener = createListener();
return Optional.of(new Listener() {
@Override
public void onSuccess() {
listener.onSuccess();
releasePartition(partition);
}
overLimit = false;
}

@Override
public void onIgnore() {
listener.onIgnore();
releasePartition(partition);
if (overLimit) {
if (partition.backoffMillis > 0 && delayedThreads.get() < maxDelayedThreads) {
try {
delayedThreads.incrementAndGet();
TimeUnit.MILLISECONDS.sleep(partition.backoffMillis);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
delayedThreads.decrementAndGet();
}
}

@Override
public void onDropped() {
listener.onDropped();
releasePartition(partition);
}
});
} finally {
if (lock.isHeldByCurrentThread())
lock.unlock();
return createRejectedListener();
}

final Listener listener = createListener();
return Optional.of(new Listener() {
@Override
public void onSuccess() {
listener.onSuccess();
releasePartition(partition);
}

@Override
public void onIgnore() {
listener.onIgnore();
releasePartition(partition);
}

@Override
public void onDropped() {
listener.onDropped();
releasePartition(partition);
}
});
}

/**
 * Return the partition's slot. No locking required: Partition.busy is an
 * AtomicInteger, so release is a single atomic decrement.
 */
private void releasePartition(Partition partition) {
    partition.release();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
package com.netflix.concurrency.limits.limiter;

import com.netflix.concurrency.limits.Limiter;
import com.netflix.concurrency.limits.Limiter.Listener;
import com.netflix.concurrency.limits.limit.FixedLimit;
import com.netflix.concurrency.limits.limit.SettableLimit;
import org.junit.Assert;
import org.junit.Test;

import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Predicate;

Expand Down Expand Up @@ -227,4 +236,86 @@ public void testBypassSimpleLimiter() {
Assert.assertTrue(limiter.acquire("admin").isPresent());
}
}

/**
 * Hammers the limiter from many threads across three partitions at once and
 * checks the lock-free accounting stays consistent: per-partition inflight
 * never exceeds the global limit, every attempt is counted exactly once
 * (success + rejection == attempts), and global inflight stays within a
 * small tolerance of the configured limit.
 */
@Test
public void testConcurrentPartitions() throws InterruptedException {
    final int THREAD_COUNT = 5;
    final int ITERATIONS = 500;
    final int LIMIT = 20;

    AbstractPartitionedLimiter<String> limiter = (AbstractPartitionedLimiter<String>) TestPartitionedLimiter.newBuilder()
            .limit(FixedLimit.of(LIMIT))
            .partitionResolver(Function.identity())
            .partition("A", 0.5)
            .partition("B", 0.3)
            .partition("C", 0.2)
            .build();

    ExecutorService executor = Executors.newFixedThreadPool(THREAD_COUNT * 3);
    CountDownLatch startLatch = new CountDownLatch(1);
    CountDownLatch endLatch = new CountDownLatch(THREAD_COUNT * 3);
    Map<String, AtomicInteger> successCounts = new ConcurrentHashMap<>();
    Map<String, AtomicInteger> rejectionCounts = new ConcurrentHashMap<>();
    Map<String, AtomicInteger> maxConcurrents = new ConcurrentHashMap<>();
    AtomicInteger globalMaxInflight = new AtomicInteger(0);

    for (String partition : Arrays.asList("A", "B", "C")) {
        successCounts.put(partition, new AtomicInteger(0));
        rejectionCounts.put(partition, new AtomicInteger(0));
        maxConcurrents.put(partition, new AtomicInteger(0));

        for (int i = 0; i < THREAD_COUNT; i++) {
            executor.submit(() -> {
                try {
                    // All workers start together to maximize contention.
                    startLatch.await();
                    for (int j = 0; j < ITERATIONS; j++) {
                        Optional<Listener> listener = limiter.acquire(partition);
                        if (listener.isPresent()) {
                            try {
                                int current = limiter.getPartition(partition).getInflight();
                                maxConcurrents.get(partition).updateAndGet(max -> Math.max(max, current));
                                successCounts.get(partition).incrementAndGet();
                                globalMaxInflight.updateAndGet(max -> Math.max(max, limiter.getInflight()));
                                Thread.sleep(1); // Simulate some work
                            } finally {
                                // Always release the slot, even if an assertion/interrupt fires.
                                listener.get().onSuccess();
                            }
                        } else {
                            rejectionCounts.get(partition).incrementAndGet();
                        }
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    endLatch.countDown();
                }
            });
        }
    }

    startLatch.countDown();
    endLatch.await();
    executor.shutdown();
    executor.awaitTermination(10, TimeUnit.SECONDS);

    StringBuilder resultSummary = new StringBuilder();
    for (String partition : Arrays.asList("A", "B", "C")) {
        int successCount = successCounts.get(partition).get();
        int rejectionCount = rejectionCounts.get(partition).get();
        int maxConcurrent = maxConcurrents.get(partition).get();

        resultSummary.append(String.format("%s(success=%d,reject=%d,maxConcurrent=%d) ",
                partition, successCount, rejectionCount, maxConcurrent));

        Assert.assertTrue("Max concurrent for " + partition + " should not exceed global limit. " + resultSummary,
                maxConcurrent <= LIMIT);
        Assert.assertEquals("Total attempts for " + partition + " should equal success + rejections. " + resultSummary,
                THREAD_COUNT * ITERATIONS,
                successCount + rejectionCount);
    }

    // The acquire-then-check pattern can briefly overshoot the global limit;
    // a tolerance of THREAD_COUNT absorbs that race.
    Assert.assertTrue("Global max inflight should not exceed total limit. " + resultSummary,
            globalMaxInflight.get() <= LIMIT + THREAD_COUNT);
}

}
Loading

0 comments on commit 58dab7f

Please sign in to comment.