diff --git a/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractLimiter.java b/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractLimiter.java
index 0796b65..e4cb9e0 100644
--- a/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractLimiter.java
+++ b/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractLimiter.java
@@ -73,6 +73,10 @@ public BuilderT metricRegistry(MetricRegistry registry) {
          * Due to the builders not having access to the ContextT, it is the duty of subclasses to ensure that
          * implementations are type safe.
          *
+         * Predicates should not rely strictly on state of the Limiter (such as inflight count) when evaluating
+         * whether to bypass. There is no guarantee that the state will be synchronized or consistent with respect to
+         * the bypass predicate, and the bypass predicate may be called by multiple threads concurrently.
+         *
          * @param shouldBypass Predicate condition to bypass limit
          * @return Chainable builder
          */
diff --git a/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiter.java b/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiter.java
index 1f0d085..36bb6ad 100644
--- a/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiter.java
+++ b/concurrency-limits-core/src/main/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiter.java
@@ -30,7 +30,6 @@ import java.util.Optional;
 
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.locks.ReentrantLock;
 import java.util.function.Function;
 
 public abstract class AbstractPartitionedLimiter<ContextT> extends AbstractLimiter<ContextT> {
@@ -105,10 +104,10 @@ public Limiter build() {
 
     static class Partition {
         private final String name;
+        private final AtomicInteger busy = new AtomicInteger(0);
 
         private double percent = 0.0;
-        private int limit = 0;
-        private int busy = 0;
+        private volatile int limit = 0;
         private long backoffMillis = 0;
 
         private MetricRegistry.SampleListener inflightDistribution;
@@ -134,17 +133,33 @@ void updateLimit(int totalLimit) {
         }
 
         boolean isLimitExceeded() {
-            return busy >= limit;
+            return busy.get() >= limit;
         }
 
         void acquire() {
-            busy++;
-            inflightDistribution.addSample(busy);
+            int nowBusy = busy.incrementAndGet();
+            inflightDistribution.addSample(nowBusy);
+        }
+
+        /**
+         * Try to acquire a slot, returning false if the limit is exceeded.
+         * @return true if a slot was acquired, false if the partition limit has been reached
+         */
+        boolean tryAcquire() {
+            int current = busy.get();
+            while (current < limit) {
+                if (busy.compareAndSet(current, current + 1)) {
+                    inflightDistribution.addSample(current + 1);
+                    return true;
+                }
+                current = busy.get();
+            }
+            return false;
         }
 
         void release() {
-            busy--;
+            busy.decrementAndGet();
         }
 
         int getLimit() {
@@ -152,7 +167,7 @@ int getLimit() {
         }
 
         public int getInflight() {
-            return busy;
+            return busy.get();
         }
 
         double getPercent() {
@@ -166,14 +181,13 @@ void createMetrics(MetricRegistry registry) {
 
         @Override
         public String toString() {
-            return "Partition [pct=" + percent + ", limit=" + limit + ", busy=" + busy + "]";
+            return "Partition [pct=" + percent + ", limit=" + limit + ", busy=" + busy.get() + "]";
        }
     }
 
     private final Map<String, Partition> partitions;
     private final Partition unknownPartition;
     private final List<Function<ContextT, String>> partitionResolvers;
-    private final ReentrantLock lock = new ReentrantLock();
     private final AtomicInteger delayedThreads = new AtomicInteger();
     private final int maxDelayedThreads;
 
@@ -211,63 +225,67 @@ private Partition resolvePartition(ContextT context) {
 
     @Override
     public Optional<Listener> acquire(ContextT context) {
-        final Partition partition = resolvePartition(context);
-
-        try {
-            lock.lock();
-            if (shouldBypass(context)){
-                return createBypassListener();
-            }
-            if (getInflight() >= getLimit() && partition.isLimitExceeded()) {
-                lock.unlock();
-                if (partition.backoffMillis > 0 && delayedThreads.get() < maxDelayedThreads) {
-                    try {
-                        delayedThreads.incrementAndGet();
-                        TimeUnit.MILLISECONDS.sleep(partition.backoffMillis);
-                    } catch (InterruptedException e) {
-                        Thread.currentThread().interrupt();
-                    } finally {
-                        delayedThreads.decrementAndGet();
-                    }
-                }
+        if (shouldBypass(context)){
+            return createBypassListener();
+        }
 
-                return createRejectedListener();
-            }
+        final Partition partition = resolvePartition(context);
+        // This is a little unusual in that the partition is not a hard limit. It is
+        // only a limit that is applied if the global limit is exceeded. This allows
+        // for excess capacity in each partition to allow for bursting over the limit,
+        // but only if there is spare global capacity.
+
+        final boolean overLimit;
+        if (getInflight() >= getLimit()) {
+            // over global limit, so respect partition limit
+            boolean couldAcquire = partition.tryAcquire();
+            overLimit = !couldAcquire;
+        } else {
+            // we are below global limit, so no need to respect partition limit
             partition.acquire();
-            final Listener listener = createListener();
-            return Optional.of(new Listener() {
-                @Override
-                public void onSuccess() {
-                    listener.onSuccess();
-                    releasePartition(partition);
-                }
+            overLimit = false;
+        }
 
-                @Override
-                public void onIgnore() {
-                    listener.onIgnore();
-                    releasePartition(partition);
+        if (overLimit) {
+            if (partition.backoffMillis > 0 && delayedThreads.get() < maxDelayedThreads) {
+                try {
+                    delayedThreads.incrementAndGet();
+                    TimeUnit.MILLISECONDS.sleep(partition.backoffMillis);
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                } finally {
+                    delayedThreads.decrementAndGet();
                 }
+            }
 
-                @Override
-                public void onDropped() {
-                    listener.onDropped();
-                    releasePartition(partition);
-                }
-            });
-        } finally {
-            if (lock.isHeldByCurrentThread())
-                lock.unlock();
+            return createRejectedListener();
         }
+
+        final Listener listener = createListener();
+        return Optional.of(new Listener() {
+            @Override
+            public void onSuccess() {
+                listener.onSuccess();
+                releasePartition(partition);
+            }
+
+            @Override
+            public void onIgnore() {
+                listener.onIgnore();
+                releasePartition(partition);
+            }
+
+            @Override
+            public void onDropped() {
+                listener.onDropped();
+                releasePartition(partition);
+            }
+        });
     }
 
     private void releasePartition(Partition partition) {
-        try {
-            lock.lock();
-            partition.release();
-        } finally {
-            lock.unlock();
-        }
+        partition.release();
     }
 
     @Override
diff --git a/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiterTest.java b/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiterTest.java
index a827c6a..505a7f8 100644
--- a/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiterTest.java
+++ b/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/AbstractPartitionedLimiterTest.java
@@ -1,12 +1,21 @@
 package com.netflix.concurrency.limits.limiter;
 
 import com.netflix.concurrency.limits.Limiter;
+import com.netflix.concurrency.limits.Limiter.Listener;
 import com.netflix.concurrency.limits.limit.FixedLimit;
 import com.netflix.concurrency.limits.limit.SettableLimit;
 
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.Arrays;
+import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.function.Predicate;
@@ -227,4 +236,86 @@ public void testBypassSimpleLimiter() {
             Assert.assertTrue(limiter.acquire("admin").isPresent());
         }
     }
+
+    @Test
+    public void testConcurrentPartitions() throws InterruptedException {
+        final int THREAD_COUNT = 5;
+        final int ITERATIONS = 500;
+        final int LIMIT = 20;
+
+        AbstractPartitionedLimiter<String> limiter = (AbstractPartitionedLimiter<String>) TestPartitionedLimiter.newBuilder()
+                .limit(FixedLimit.of(LIMIT))
+                .partitionResolver(Function.identity())
+                .partition("A", 0.5)
+                .partition("B", 0.3)
+                .partition("C", 0.2)
+                .build();
+
+        ExecutorService executor = Executors.newFixedThreadPool(THREAD_COUNT * 3);
+        CountDownLatch startLatch = new CountDownLatch(1);
+        CountDownLatch endLatch = new CountDownLatch(THREAD_COUNT * 3);
+        Map<String, AtomicInteger> successCounts = new ConcurrentHashMap<>();
+        Map<String, AtomicInteger> rejectionCounts = new ConcurrentHashMap<>();
+        Map<String, AtomicInteger> maxConcurrents = new ConcurrentHashMap<>();
+        AtomicInteger globalMaxInflight = new AtomicInteger(0);
+
+        for (String partition : Arrays.asList("A", "B", "C")) {
+            successCounts.put(partition, new AtomicInteger(0));
+            rejectionCounts.put(partition, new AtomicInteger(0));
+            maxConcurrents.put(partition, new AtomicInteger(0));
+
+            for (int i = 0; i < THREAD_COUNT; i++) {
+                executor.submit(() -> {
+                    try {
+                        startLatch.await();
+                        for (int j = 0; j < ITERATIONS; j++) {
+                            Optional<Listener> listener = limiter.acquire(partition);
+                            if (listener.isPresent()) {
+                                try {
+                                    int current = limiter.getPartition(partition).getInflight();
+                                    maxConcurrents.get(partition).updateAndGet(max -> Math.max(max, current));
+                                    successCounts.get(partition).incrementAndGet();
+                                    globalMaxInflight.updateAndGet(max -> Math.max(max, limiter.getInflight()));
+                                    Thread.sleep(1); // Simulate some work
+                                } finally {
+                                    listener.get().onSuccess();
+                                }
+                            } else {
+                                rejectionCounts.get(partition).incrementAndGet();
+                            }
+                        }
+                    } catch (InterruptedException e) {
+                        Thread.currentThread().interrupt();
+                    } finally {
+                        endLatch.countDown();
+                    }
+                });
+            }
+        }
+
+        startLatch.countDown();
+        endLatch.await();
+        executor.shutdown();
+        executor.awaitTermination(10, TimeUnit.SECONDS);
+
+        StringBuilder resultSummary = new StringBuilder();
+        for (String partition : Arrays.asList("A", "B", "C")) {
+            int successCount = successCounts.get(partition).get();
+            int rejectionCount = rejectionCounts.get(partition).get();
+            int maxConcurrent = maxConcurrents.get(partition).get();
+
+            resultSummary.append(String.format("%s(success=%d,reject=%d,maxConcurrent=%d) ",
+                    partition, successCount, rejectionCount, maxConcurrent));
+
+            Assert.assertTrue("Max concurrent for " + partition + " should not exceed global limit. " + resultSummary,
+                    maxConcurrent <= LIMIT);
+            Assert.assertEquals("Total attempts for " + partition + " should equal success + rejections. " + resultSummary,
+                    THREAD_COUNT * ITERATIONS,
+                    successCount + rejectionCount);
+        }
+
+        Assert.assertTrue("Global max inflight should not exceed total limit. " + resultSummary,
+                globalMaxInflight.get() <= LIMIT + THREAD_COUNT);
+    }
+
 }
diff --git a/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/SimpleLimiterTest.java b/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/SimpleLimiterTest.java
index cc8c264..f6065c6 100644
--- a/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/SimpleLimiterTest.java
+++ b/concurrency-limits-core/src/test/java/com/netflix/concurrency/limits/limiter/SimpleLimiterTest.java
@@ -1,11 +1,19 @@
 package com.netflix.concurrency.limits.limiter;
 
 import com.netflix.concurrency.limits.Limiter;
+import com.netflix.concurrency.limits.Limiter.Listener;
 import com.netflix.concurrency.limits.limit.FixedLimit;
+import com.netflix.concurrency.limits.limiter.AbstractPartitionedLimiterTest.TestPartitionedLimiter;
+
 import org.junit.Assert;
 import org.junit.Test;
 
 import java.util.Optional;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 
 public class SimpleLimiterTest {
 
@@ -82,4 +90,66 @@ public void testSimpleBypassLimiterDefault() {
 
         Assert.assertFalse(limiter.acquire("admin").isPresent());
     }
+    @Test
+    public void testConcurrentSimple() throws InterruptedException {
+        final int THREAD_COUNT = 100;
+        final int ITERATIONS = 1000;
+        final int LIMIT = 10;
+
+        SimpleLimiter<String> limiter = (SimpleLimiter<String>) TestPartitionedLimiter.newBuilder()
+                .limit(FixedLimit.of(LIMIT))
+                .partition("default", 1.0)
+                .build();
+
+        ExecutorService executor = Executors.newFixedThreadPool(THREAD_COUNT);
+        CountDownLatch startLatch = new CountDownLatch(1);
+        CountDownLatch endLatch = new CountDownLatch(THREAD_COUNT);
+        AtomicInteger successCount = new AtomicInteger(0);
+        AtomicInteger rejectionCount = new AtomicInteger(0);
+        AtomicInteger maxConcurrent = new AtomicInteger(0);
+
+        for (int i = 0; i < THREAD_COUNT; i++) {
+            executor.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < ITERATIONS; j++) {
+                        Optional<Listener> listener = limiter.acquire("default");
+                        if (listener.isPresent()) {
+                            try {
+                                int current = limiter.getInflight();
+                                maxConcurrent.updateAndGet(max -> Math.max(max, current));
+                                successCount.incrementAndGet();
+                                Thread.sleep(1); // Simulate some work
+                            } finally {
+                                listener.get().onSuccess();
+                            }
+                        } else {
+                            rejectionCount.incrementAndGet();
+                        }
+                    }
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt();
+                } finally {
+                    endLatch.countDown();
+                }
+            });
+        }
+
+        startLatch.countDown();
+        endLatch.await();
+        executor.shutdown();
+        executor.awaitTermination(10, TimeUnit.SECONDS);
+
+        StringBuilder resultBuilder = new StringBuilder();
+        resultBuilder.append("Success count: ").append(successCount.get())
+                .append(" | Rejection count: ").append(rejectionCount.get())
+                .append(" | Max concurrent: ").append(maxConcurrent.get());
+        String results = resultBuilder.toString();
+
+        Assert.assertTrue("Max concurrent should not exceed limit. " + results,
+                maxConcurrent.get() <= LIMIT);
+        Assert.assertEquals("Total attempts should equal success + rejections. " + results,
+                THREAD_COUNT * ITERATIONS, successCount.get() + rejectionCount.get());
+    }
+
 }