diff --git a/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java b/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java index 5acdd797c43..052904d2eec 100644 --- a/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java +++ b/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java @@ -13,6 +13,8 @@ import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.internal.PlatformDependent; +import java.util.HashMap; +import java.util.Map; import java.util.Queue; import java.util.concurrent.atomic.AtomicInteger; @@ -29,11 +31,19 @@ public class CombinerExecutor implements Executor { private final AtomicInteger s = new AtomicInteger(); private final S state; - protected static final class InProgressTail { + protected static final class InProgressTail { + + final CombinerExecutor combiner; Task task; + Map, Task> others; + + public InProgressTail(CombinerExecutor combiner, Task task) { + this.combiner = combiner; + this.task = task; + } } - private final FastThreadLocal current = new FastThreadLocal<>(); + private static final FastThreadLocal> current = new FastThreadLocal<>(); public CombinerExecutor(S state) { this.state = state; @@ -72,23 +82,42 @@ public void submit(Action action) { } } while (!q.isEmpty() && s.compareAndSet(0, 1)); if (head != null) { - InProgressTail inProgress = current.get(); + InProgressTail inProgress = (InProgressTail) current.get(); if (inProgress == null) { - inProgress = new InProgressTail(); + inProgress = new InProgressTail<>(this, tail); current.set(inProgress); - inProgress.task = tail; try { // from now one cannot trust tail anymore head.runNextTasks(); + assert inProgress.others == null || inProgress.others.isEmpty(); } finally { current.remove(); } } else { - assert inProgress.task != null; - Task oldNextTail = inProgress.task.replaceNext(head); - assert oldNextTail == null; - inProgress.task = tail; - + if (inProgress.combiner == this) { + Task oldNextTail = inProgress.task.replaceNext(head); + assert oldNextTail == null; + inProgress.task = tail; + } else { + Map, Task> map = inProgress.others; + if (map == null) { + map = inProgress.others = new HashMap<>(1); + } + Task task = map.get(this); + if (task == null) { + map.put(this, tail); + try { + // from now one cannot trust tail anymore + head.runNextTasks(); + } finally { + map.remove(this); + } + } else { + Task oldNextTail = task.replaceNext(head); + assert oldNextTail == null; + map.put(this, tail); + } + } } } } diff --git a/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java b/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java index 0ed540cf6f8..0b512b77efb 100644 --- a/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java +++ b/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java @@ -958,14 +958,14 @@ public void testPostTasksTrampoline() throws Exception { List res = Collections.synchronizedList(new LinkedList<>()); AtomicInteger seq = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(1 + numAcquires); + int[] count = new int[1]; ConnectionPool pool = ConnectionPool.pool(new PoolConnector() { - int count = 0; int reentrancy = 0; @Override public Future> connect(ContextInternal context, Listener listener) { assertEquals(0, reentrancy++); try { - int val = count++; + int val = count[0]++; if (val == 0) { // Queue extra requests for (int i = 0;i < numAcquires;i++) { @@ -975,6 +975,7 @@ public Future> connect(ContextInternal context, Listen latch.countDown(); })); } + assertEquals(1, count[0]); } return Future.failedFuture("failure"); } finally { @@ -995,10 +996,81 @@ public boolean isValid(Connection connection) { })); }); awaitLatch(latch); + assertEquals(1 + numAcquires, count[0]); List expected = IntStream.concat(IntStream.range(1, numAcquires + 1), IntStream.of(0)).boxed().collect(Collectors.toList()); assertEquals(expected, res); } + @Test + public void testConcurrentPostTasksTrampoline() throws Exception { + AtomicReference> ref1 = new AtomicReference<>(); + AtomicReference> ref2 = new AtomicReference<>(); + ContextInternal ctx = vertx.createEventLoopContext(); + List res = Collections.synchronizedList(new LinkedList<>()); + CountDownLatch latch = new CountDownLatch(4); + ConnectionPool pool1 = ConnectionPool.pool(new PoolConnector<>() { + int count = 0; + int reentrancy = 0; + @Override + public Future> connect(ContextInternal context, Listener listener) { + assertEquals(0, reentrancy++); + try { + int val = count++; + if (val == 0) { + ref1.get().acquire(ctx, 0, onFailure(err -> { + res.add(1); + latch.countDown(); + })); + ref2.get().acquire(ctx, 0, onFailure(err -> { + res.add(2); + latch.countDown(); + })); + } + return Future.failedFuture("failure"); + } finally { + reentrancy--; + } + } + @Override + public boolean isValid(Connection connection) { + return true; + } + }, new int[]{1}, 2); + ConnectionPool pool2 = ConnectionPool.pool(new PoolConnector<>() { + int count = 0; + int reentrancy = 0; + @Override + public Future> connect(ContextInternal context, Listener listener) { + assertEquals(0, reentrancy++); + try { + int val = count++; + if (val == 0) { + ref2.get().acquire(ctx, 0, onFailure(err -> { + res.add(3); + latch.countDown(); + })); + ref1.get().acquire(ctx, 0, onFailure(err -> { + res.add(4); + latch.countDown(); + })); + } + return Future.failedFuture("failure"); + } finally { + reentrancy--; + } + } + @Override + public boolean isValid(Connection connection) { + return true; + } + }, new int[]{1}, 2); + ref1.set(pool1); + ref2.set(pool2); + pool1.acquire(ctx, 0, onFailure(err -> res.add(0))); + awaitLatch(latch); +// assertEquals(Arrays.asList(0, 2, 1, 3, 4), res); + } + static class Connection { public Connection() { } diff --git a/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java b/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java index ff3fbc97a67..2bd1903eaeb 100644 --- a/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java +++ b/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java @@ -10,12 +10,16 @@ */ package io.vertx.core.net.impl.pool; +import io.netty.util.concurrent.FastThreadLocal; import io.vertx.test.core.AsyncTestBase; import org.junit.Assume; import org.junit.Test; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; @@ -71,13 +75,57 @@ public void run() { assertFalse(isReentrant2.get()); } + @Test + public void testActionReentrancy2() throws Exception { + List log = new LinkedList<>(); + Executor combiner1 = new CombinerExecutor<>(new Object()); + Executor combiner2 = new CombinerExecutor<>(new Object()); + int[] reentrancy = new int[2]; + combiner1.submit(state1 -> taskOf(() -> { + assertEquals(0, reentrancy[0]++); + combiner1.submit(state2 -> taskOf(() -> { + assertEquals(0, reentrancy[0]++); + log.add(0); + reentrancy[0]--; + })); + combiner2.submit(state2 -> taskOf(() -> { + assertEquals(0, reentrancy[1]++); + log.add(1); + combiner1.submit(state3 -> taskOf(() -> { + assertEquals(0, reentrancy[0]++); + log.add(2); + reentrancy[0]--; + })); + combiner2.submit(state3 -> taskOf(() -> { + assertEquals(0, reentrancy[1]++); + log.add(3); + reentrancy[1]--; + })); + reentrancy[1]--; + })); + reentrancy[0]--; + })); + assertEquals(0, reentrancy[0]); + assertEquals(0, reentrancy[1]); + assertEquals(Arrays.asList(1, 3, 0, 2), log); + } + + static Task taskOf(Runnable runnable) { + return new Task() { + @Override + public void run() { + runnable.run(); + } + }; + } + @Test public void testFoo() throws Exception { Assume.assumeFalse(io.vertx.core.impl.Utils.isWindows()); int numThreads = 8; int numIter = 1_000 * 100; Executor sync = new CombinerExecutor<>(new Object()); - Executor.Action action = s -> { + Executor.Action action = s -> { burnCPU(10); return null; }; @@ -166,4 +214,22 @@ public void run() { }); assertEquals(3, order.get()); } + + @Test + public void testFastThreadLocalStability() { + CombinerExecutor executor = new CombinerExecutor<>(null); + int expected = io.netty.util.internal.InternalThreadLocalMap.lastVariableIndex(); + AtomicInteger counter = new AtomicInteger(); + for (int i = 0;i < 1000;i++) { + executor = new CombinerExecutor<>(null); + executor.submit(state -> new Task() { + @Override + public void run() { + counter.incrementAndGet(); + } + }); + assertEquals(i + 1, counter.get()); + } + assertEquals(expected, io.netty.util.internal.InternalThreadLocalMap.lastVariableIndex()); + } }