diff --git a/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java b/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java index 7646fcdfc9..219bc67c24 100644 --- a/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java +++ b/spark/src/main/java/org/apache/spark/CometTaskMemoryManager.java @@ -20,6 +20,10 @@ package org.apache.spark; import java.io.IOException; +import java.util.concurrent.atomic.AtomicLong; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.MemoryMode; @@ -30,40 +34,46 @@ * memory manager. This assumes Spark's off-heap memory mode is enabled. */ public class CometTaskMemoryManager { + + private static final Logger logger = LoggerFactory.getLogger(CometTaskMemoryManager.class); + /** The id uniquely identifies the native plan this memory manager is associated to */ private final long id; private final TaskMemoryManager internal; private final NativeMemoryConsumer nativeMemoryConsumer; - private long used; + private final AtomicLong used = new AtomicLong(); public CometTaskMemoryManager(long id) { this.id = id; this.internal = TaskContext$.MODULE$.get().taskMemoryManager(); this.nativeMemoryConsumer = new NativeMemoryConsumer(); - this.used = 0; } // Called by Comet native through JNI. // Returns the actual amount of memory (in bytes) granted. public long acquireMemory(long size) { long acquired = internal.acquireExecutionMemory(size, nativeMemoryConsumer); + used.addAndGet(acquired); if (acquired < size) { // If memory manager is not able to acquire the requested size, log memory usage internal.showMemoryUsage(); } - used += acquired; return acquired; } // Called by Comet native through JNI public void releaseMemory(long size) { - used -= size; + long newUsed = used.addAndGet(-size); + if (newUsed < 0) { + logger.error( + "Used memory is negative: " + newUsed + " after releasing memory chunk of: " + size); + } internal.releaseExecutionMemory(size, nativeMemoryConsumer); } public long getUsed() { - return used; + return used.get(); } /**