From c03726e8516fa631c0ae839d1f29a50c6be9bbb0 Mon Sep 17 00:00:00 2001 From: Julien Viet Date: Fri, 26 Jan 2024 14:32:55 +0100 Subject: [PATCH] The CombinerExecutor class creates an instance of FastThreadLocal for each combiner executor leading to an increase of the InternalThreadLocalMap index. Consequently each thread local map of FastThreadLocalThread will get a new map sized accordingly, leading to an eventual memory leak. Use a static FastThreadLocal in CombinerExecutor instead of a instance field, the data structure stored in this thread local map keeps track of the CombinerExecutor running in order to allow interleaved execution of CombinerExecutor post tasks without interfering each other. The structure is optimized for the most frequent case. --- .../core/net/impl/pool/CombinerExecutor.java | 49 +++++++++--- .../net/impl/pool/ConnectionPoolTest.java | 76 ++++++++++++++++++- .../net/impl/pool/SynchronizationTest.java | 68 ++++++++++++++++- 3 files changed, 180 insertions(+), 13 deletions(-) diff --git a/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java b/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java index 5acdd797c43..052904d2eec 100644 --- a/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java +++ b/src/main/java/io/vertx/core/net/impl/pool/CombinerExecutor.java @@ -13,6 +13,8 @@ import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.internal.PlatformDependent; +import java.util.HashMap; +import java.util.Map; import java.util.Queue; import java.util.concurrent.atomic.AtomicInteger; @@ -29,11 +31,19 @@ public class CombinerExecutor implements Executor { private final AtomicInteger s = new AtomicInteger(); private final S state; - protected static final class InProgressTail { + protected static final class InProgressTail { + + final CombinerExecutor combiner; Task task; + Map, Task> others; + + public InProgressTail(CombinerExecutor combiner, Task task) { + this.combiner = combiner; + this.task = task; + } } - private final FastThreadLocal current = new FastThreadLocal<>(); + private static final FastThreadLocal> current = new FastThreadLocal<>(); public CombinerExecutor(S state) { this.state = state; @@ -72,23 +82,42 @@ public void submit(Action action) { } } while (!q.isEmpty() && s.compareAndSet(0, 1)); if (head != null) { - InProgressTail inProgress = current.get(); + InProgressTail inProgress = (InProgressTail) current.get(); if (inProgress == null) { - inProgress = new InProgressTail(); + inProgress = new InProgressTail<>(this, tail); current.set(inProgress); - inProgress.task = tail; try { // from now one cannot trust tail anymore head.runNextTasks(); + assert inProgress.others == null || inProgress.others.isEmpty(); } finally { current.remove(); } } else { - assert inProgress.task != null; - Task oldNextTail = inProgress.task.replaceNext(head); - assert oldNextTail == null; - inProgress.task = tail; - + if (inProgress.combiner == this) { + Task oldNextTail = inProgress.task.replaceNext(head); + assert oldNextTail == null; + inProgress.task = tail; + } else { + Map, Task> map = inProgress.others; + if (map == null) { + map = inProgress.others = new HashMap<>(1); + } + Task task = map.get(this); + if (task == null) { + map.put(this, tail); + try { + // from now one cannot trust tail anymore + head.runNextTasks(); + } finally { + map.remove(this); + } + } else { + Task oldNextTail = task.replaceNext(head); + assert oldNextTail == null; + map.put(this, tail); + } + } } } } diff --git a/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java b/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java index 0ed540cf6f8..0b512b77efb 100644 --- a/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java +++ b/src/test/java/io/vertx/core/net/impl/pool/ConnectionPoolTest.java @@ -958,14 +958,14 @@ public void testPostTasksTrampoline() throws Exception { List res = Collections.synchronizedList(new LinkedList<>()); AtomicInteger seq = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(1 + numAcquires); + int[] count = new int[1]; ConnectionPool pool = ConnectionPool.pool(new PoolConnector() { - int count = 0; int reentrancy = 0; @Override public Future> connect(ContextInternal context, Listener listener) { assertEquals(0, reentrancy++); try { - int val = count++; + int val = count[0]++; if (val == 0) { // Queue extra requests for (int i = 0;i < numAcquires;i++) { @@ -975,6 +975,7 @@ public Future> connect(ContextInternal context, Listen latch.countDown(); })); } + assertEquals(1, count[0]); } return Future.failedFuture("failure"); } finally { @@ -995,10 +996,81 @@ public boolean isValid(Connection connection) { })); }); awaitLatch(latch); + assertEquals(1 + numAcquires, count[0]); List expected = IntStream.concat(IntStream.range(1, numAcquires + 1), IntStream.of(0)).boxed().collect(Collectors.toList()); assertEquals(expected, res); } + @Test + public void testConcurrentPostTasksTrampoline() throws Exception { + AtomicReference> ref1 = new AtomicReference<>(); + AtomicReference> ref2 = new AtomicReference<>(); + ContextInternal ctx = vertx.createEventLoopContext(); + List res = Collections.synchronizedList(new LinkedList<>()); + CountDownLatch latch = new CountDownLatch(4); + ConnectionPool pool1 = ConnectionPool.pool(new PoolConnector<>() { + int count = 0; + int reentrancy = 0; + @Override + public Future> connect(ContextInternal context, Listener listener) { + assertEquals(0, reentrancy++); + try { + int val = count++; + if (val == 0) { + ref1.get().acquire(ctx, 0, onFailure(err -> { + res.add(1); + latch.countDown(); + })); + ref2.get().acquire(ctx, 0, onFailure(err -> { + res.add(2); + latch.countDown(); + })); + } + return Future.failedFuture("failure"); + } finally { + reentrancy--; + } + } + @Override + public boolean isValid(Connection connection) { + return true; + } + }, new int[]{1}, 2); + ConnectionPool pool2 = ConnectionPool.pool(new PoolConnector<>() { + int count = 0; + int reentrancy = 0; + @Override + public Future> connect(ContextInternal context, Listener listener) { + assertEquals(0, reentrancy++); + try { + int val = count++; + if (val == 0) { + ref2.get().acquire(ctx, 0, onFailure(err -> { + res.add(3); + latch.countDown(); + })); + ref1.get().acquire(ctx, 0, onFailure(err -> { + res.add(4); + latch.countDown(); + })); + } + return Future.failedFuture("failure"); + } finally { + reentrancy--; + } + } + @Override + public boolean isValid(Connection connection) { + return true; + } + }, new int[]{1}, 2); + ref1.set(pool1); + ref2.set(pool2); + pool1.acquire(ctx, 0, onFailure(err -> res.add(0))); + awaitLatch(latch); +// assertEquals(Arrays.asList(0, 2, 1, 3, 4), res); + } + static class Connection { public Connection() { } diff --git a/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java b/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java index ff3fbc97a67..2bd1903eaeb 100644 --- a/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java +++ b/src/test/java/io/vertx/core/net/impl/pool/SynchronizationTest.java @@ -10,12 +10,16 @@ */ package io.vertx.core.net.impl.pool; +import io.netty.util.concurrent.FastThreadLocal; import io.vertx.test.core.AsyncTestBase; import org.junit.Assume; import org.junit.Test; import java.lang.management.ManagementFactory; import java.lang.management.ThreadMXBean; +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicBoolean; @@ -71,13 +75,57 @@ public void run() { assertFalse(isReentrant2.get()); } + @Test + public void testActionReentrancy2() throws Exception { + List log = new LinkedList<>(); + Executor combiner1 = new CombinerExecutor<>(new Object()); + Executor combiner2 = new CombinerExecutor<>(new Object()); + int[] reentrancy = new int[2]; + combiner1.submit(state1 -> taskOf(() -> { + assertEquals(0, reentrancy[0]++); + combiner1.submit(state2 -> taskOf(() -> { + assertEquals(0, reentrancy[0]++); + log.add(0); + reentrancy[0]--; + })); + combiner2.submit(state2 -> taskOf(() -> { + assertEquals(0, reentrancy[1]++); + log.add(1); + combiner1.submit(state3 -> taskOf(() -> { + assertEquals(0, reentrancy[0]++); + log.add(2); + reentrancy[0]--; + })); + combiner2.submit(state3 -> taskOf(() -> { + assertEquals(0, reentrancy[1]++); + log.add(3); + reentrancy[1]--; + })); + reentrancy[1]--; + })); + reentrancy[0]--; + })); + assertEquals(0, reentrancy[0]); + assertEquals(0, reentrancy[1]); + assertEquals(Arrays.asList(1, 3, 0, 2), log); + } + + static Task taskOf(Runnable runnable) { + return new Task() { + @Override + public void run() { + runnable.run(); + } + }; + } + @Test public void testFoo() throws Exception { Assume.assumeFalse(io.vertx.core.impl.Utils.isWindows()); int numThreads = 8; int numIter = 1_000 * 100; Executor sync = new CombinerExecutor<>(new Object()); - Executor.Action action = s -> { + Executor.Action action = s -> { burnCPU(10); return null; }; @@ -166,4 +214,22 @@ public void run() { }); assertEquals(3, order.get()); } + + @Test + public void testFastThreadLocalStability() { + CombinerExecutor executor = new CombinerExecutor<>(null); + int expected = io.netty.util.internal.InternalThreadLocalMap.lastVariableIndex(); + AtomicInteger counter = new AtomicInteger(); + for (int i = 0;i < 1000;i++) { + executor = new CombinerExecutor<>(null); + executor.submit(state -> new Task() { + @Override + public void run() { + counter.incrementAndGet(); + } + }); + assertEquals(i + 1, counter.get()); + } + assertEquals(expected, io.netty.util.internal.InternalThreadLocalMap.lastVariableIndex()); + } }