Skip to content

Commit

Permalink
Performance improvement for concurrent operations and matrix multipli…
Browse files Browse the repository at this point in the history
…cation.

- Replaced ThreadManager.concurrentLoop methods with ThreadManager.concurrentOperation and ThreadManager.concurrentBlockedOperation. These methods offer a slight performance enhancement over the old methods as they have remove some overhead related to IntStream's.

- The way some values are indexed in matrix multiplication methods has been updated to reduce the total number of array axes and the number of index computations. This will should result in a slight performance improvement for large matrices.
  • Loading branch information
jacobdwatters committed Aug 9, 2024
1 parent d7313b0 commit 7a7c1ca
Show file tree
Hide file tree
Showing 27 changed files with 1,404 additions and 1,007 deletions.
34 changes: 5 additions & 29 deletions src/main/java/org/flag4j/concurrency/Configurations.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
/**
* Configurations for standard and concurrent operations.
*/
public abstract class Configurations {
public final class Configurations {
private Configurations() {
throw new IllegalStateException(ErrorMessages.getUtilityClassErrMsg());
}
Expand All @@ -41,29 +41,24 @@ private Configurations() {
/**
* The default block size for blocked algorithms.
*/
private static final int DEFAULT_BLOCK_SIZE = 64;
public static final int DEFAULT_BLOCK_SIZE = 64;
/**
* The default minimum recursive size for recursive algorithms.
*/
private static final int DEFAULT_MIN_RECURSIVE_SIZE = 128;

public static final int DEFAULT_MIN_RECURSIVE_SIZE = 128;
/**
* The block size to use in blocked algorithms.
*/
private static int blockSize = DEFAULT_BLOCK_SIZE;

/**
* The minimum size of tensor/matrix/vector to make recursive calls on in recursive algorithms.
*/
private static int minRecursiveSize = DEFAULT_MIN_RECURSIVE_SIZE;


/**
* Sets the number of threads for use in concurrent operations as the number of processors available to the Java
* virtual machine. Note that this value may change during runtime. This method will include logical cores so the value
* returned may be higher than the number of physical cores on the machine if hyper-threading is enabled.
* <br><br>
* This is implemented as: <code>numThreads = {@link Runtime#availableProcessors() Runtime.getRuntime().availableProcessors()};</code>
* @implNote This is implemented as:
* <code>numThreads = {@link Runtime#availableProcessors() Runtime.getRuntime().availableProcessors()};</code>
* @return The new value of numThreads, i.e. the number of available processors.
*/
public static int setNumThreadsAsAvailableProcessors() {
Expand Down Expand Up @@ -108,30 +103,11 @@ public static void setBlockSize(int blockSize) {
}


/**
* Gets the minimum size of tensor/matrix/vector to make recursive calls on in recursive algorithms.
* @return minimum size of tensor/matrix/vector to make recursive calls on in recursive algorithms.
*/
public static int getMinRecursiveSize() {
return minRecursiveSize;
}


/**
* Sets the minimum size of tensor/matrix/vector to make recursive calls on in recursive algorithms.
* @param minRecursiveSize New minimum size.
*/
public static void setMinRecursiveSize(int minRecursiveSize) {
Configurations.minRecursiveSize = Math.max(1, minRecursiveSize);
}


/**
* Resets all configurations to their default values.
*/
public static void resetAll() {
ThreadManager.setParallelismLevel(DEFAULT_NUM_THREADS);
blockSize = DEFAULT_BLOCK_SIZE;
minRecursiveSize = DEFAULT_MIN_RECURSIVE_SIZE;
}
}
40 changes: 40 additions & 0 deletions src/main/java/org/flag4j/concurrency/TensorOperation.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* MIT License
*
* Copyright (c) 2024. Jacob Watters
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

package org.flag4j.concurrency;


/**
* Functional interface for general tensor operation.
*/
@FunctionalInterface
public interface TensorOperation {

/**
* Applies a tensor operation over the specified index range.
* @param startIdx Staring index for operation.
* @param endIdx Ending index for operation.
*/
void apply(int startIdx, int endIdx);
}
140 changes: 101 additions & 39 deletions src/main/java/org/flag4j/concurrency/ThreadManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,16 @@

import org.flag4j.util.ErrorMessages;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.function.IntConsumer;
import java.util.logging.Level;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.*;
import java.util.logging.Logger;
import java.util.stream.IntStream;

/**
* This class contains the base thread pool for all concurrent operations and several methods for managing the
* pool.
*/
public class ThreadManager {
public final class ThreadManager {
private ThreadManager() {
// Hide default constructor for utility class.
throw new IllegalStateException(ErrorMessages.getUtilityClassErrMsg());
Expand Down Expand Up @@ -68,7 +64,7 @@ private ThreadManager() {
/**
* Thread pool for managing threads executing concurrent operations.
*/
protected static ExecutorService threadPool = Executors.newFixedThreadPool(parallelismLevel, daemonFactory);
private static ThreadPoolExecutor threadPool = (ThreadPoolExecutor) Executors.newFixedThreadPool(parallelismLevel, daemonFactory);


/**
Expand All @@ -78,7 +74,7 @@ private ThreadManager() {
*/
protected static void setParallelismLevel(int parallelismLevel) {
ThreadManager.parallelismLevel = Math.max(parallelismLevel, 1);
threadPool = Executors.newFixedThreadPool(parallelismLevel, daemonFactory);
threadPool.setCorePoolSize(parallelismLevel);
}


Expand All @@ -92,44 +88,110 @@ public static int getParallelismLevel() {


/**
* Applies a concurrent loop to a function.
* @param startIndex Starting index for concurrent loop (inclusive).
* @param endIndex Ending index for concurrent loop (exclusive).
* @param function Function to apply each iteration. Function may be dependent on iteration index but should
* individual iterations should be independent of each other.
* Computes a specified tensor operation concurrently by evenly dividing work amoung available threads (specified by
* {@link Configurations#getNumThreads()}).
* @param totalSize Total size of the outer loop for the operation.
* @param operation Operation to be computed.
*/
public static void concurrentLoop(int startIndex, int endIndex, IntConsumer function) {
try {
threadPool.submit(() -> IntStream.range(startIndex, endIndex).parallel().forEach(function)).get();
} catch (InterruptedException | ExecutionException e) {
threadLogger.setLevel(Level.WARNING);
threadLogger.warning(e.getMessage());
Thread.currentThread().interrupt();
public static void concurrentOperation(final int totalSize, final TensorOperation operation) {
// Calculate chunk size.
int chunkSize = (totalSize + parallelismLevel - 1) / parallelismLevel;
List<Future<?>> futures = new ArrayList<>(parallelismLevel);

for(int threadIndex = 0; threadIndex < parallelismLevel; threadIndex++) {
final int startIdx = threadIndex * chunkSize;
final int endIdx = Math.min(startIdx + chunkSize, totalSize);

if(startIdx >= endIdx) break; // No more indices to process.

futures.add(ThreadManager.threadPool.submit(() -> {
operation.apply(startIdx, endIdx);
}));
}

// Wait for all tasks to complete.
for(Future<?> future : futures) {
try {
future.get(); // Ensure all tasks are complete.
} catch (InterruptedException | ExecutionException e) {
// An exception occured.
threadLogger.warning(e.getMessage());
Thread.currentThread().interrupt();
}
}
}


/**
* Applies a concurrent strided-loop to a function.
* @param startIndex Starting index for concurrent loop (inclusive).
* @param endIndex Ending index for concurrent loop (exclusive).
* @param step Step size for the index variable of the loop (i.e. the stride size).
* @param function Function to apply each iteration. Function may be dependent on iteration index but should
* individual iterations should be independent of each other.
* Computes a specified blocked tensor operation concurrently by evenly dividing work amoung available threads (specified by
* {@link Configurations#getNumThreads()}).
* @param totalSize Total size of the outer loop for the operation.
* @param blockSize Size of the block used in the blocekdOperation.
* @param blockedOperation Operation to be computed.
*/
public static void concurrentBlockedOperation(final int totalSize, final int blockSize, final TensorOperation blockedOperation) {
// Calculate chunk size for blocks.
int numBlocks = (totalSize + blockSize - 1) / blockSize;
List<Future<?>> futures = new ArrayList<>(parallelismLevel);

for(int blockIndex = 0; blockIndex < numBlocks; blockIndex++) {
final int startBlock = blockIndex * blockSize;
final int endBlock = Math.min(startBlock + blockSize, totalSize);

futures.add(threadPool.submit(() -> {
blockedOperation.apply(startBlock, endBlock);
}));
}

// Wait for all tasks to complete.
for(Future<?> future : futures) {
try {
future.get(); // Ensure all tasks are complete.
} catch (InterruptedException | ExecutionException e) {
// An exception occured.
threadLogger.warning(e.getMessage());
Thread.currentThread().interrupt();
}
}
}

// TODO: TEMP FOR TESTING.
/**
* Executes a concurrent operation on a given range of indices.
* The operation is split across multiple threads, each handling a subset of the range.
*
* @param totalTasks The total number of tasks (e.g., rows in a matrix) to be processed.
* @param task A lambda expression or function that takes three arguments: start index, end index, and thread ID.
* This function represents the work to be done by each thread for its assigned range.
*/
public static void concurrentLoop(int startIndex, int endIndex, int step, IntConsumer function) {
if(step <= 0)
throw new IllegalArgumentException(ErrorMessages.getNegValueErr(startIndex));
public static void concurrentOperation(int totalTasks, TriConsumer<Integer, Integer, Integer> task) {
int numThreads = Runtime.getRuntime().availableProcessors();
ExecutorService executor = Executors.newFixedThreadPool(numThreads);

int tasksPerThread = (totalTasks + numThreads - 1) / numThreads; // Ceiling division

for (int threadId = 0; threadId < numThreads; threadId++) {
int startIdx = threadId * tasksPerThread;
int endIdx = Math.min(startIdx + tasksPerThread, totalTasks);

if (startIdx < endIdx) {
final int finalThreadId = threadId;
executor.submit(() -> task.accept(startIdx, endIdx, finalThreadId));
}
}

executor.shutdown();

try {
int range = endIndex - startIndex;
int iterations = range/step + ((range%step == 0) ? 0 : 1);
threadPool.submit(() -> IntStream.range(0, iterations).parallel().forEach(
i -> function.accept(startIndex + i*step))
).get();
} catch (InterruptedException | ExecutionException e) {
threadLogger.setLevel(Level.WARNING);
threadLogger.warning(e.getMessage());
executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Thread execution interrupted", e);
}
}

@FunctionalInterface
public interface TriConsumer<T, U, V> {
void accept(T t, U u, V v);
}
}
9 changes: 0 additions & 9 deletions src/main/java/org/flag4j/core/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

package org.flag4j.core;

import org.flag4j.arrays.dense.Tensor;
import org.flag4j.util.ArrayUtils;
import org.flag4j.util.ParameterChecks;

Expand Down Expand Up @@ -311,12 +310,4 @@ public String toString() {

return joiner.toString();
}


public static void main(String[] args) {
Shape s = new Shape();
Tensor t = new Tensor(s);

System.out.println(t.entries.length);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ public static CNumber[] elemDivConcurrent(CNumber[] src1, Shape shape1, CNumber[
ParameterChecks.assertEqualShape(shape1, shape2);
CNumber[] product = new CNumber[src1.length];

ThreadManager.concurrentLoop(0, product.length,
(i)->product[i] = src1[i].div(src2[i]));
ThreadManager.concurrentOperation(product.length, (start, end)->{
for(int i=start; i<end; i++) {
product[i] = src1[i].div(src2[i]);
}
});

return product;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public class ComplexDenseElemMult {
/**
* Minimum number of entries in each tensor to apply concurrent algorithm.
*/
private static final int CONCURRENT_THRESHOLD = 50625;
private static final int CONCURRENT_THRESHOLD = 50_000;


private ComplexDenseElemMult() {
Expand Down Expand Up @@ -81,9 +81,11 @@ public static CNumber[] elemMultConcurrent(CNumber[] src1, Shape shape1, CNumber
ParameterChecks.assertEqualShape(shape1, shape2);
CNumber[] product = new CNumber[src1.length];

ThreadManager.concurrentLoop(0, product.length,
(i)->product[i] = src1[i].mult(src2[i])
);
ThreadManager.concurrentOperation(product.length, ((startIdx, endIdx) -> {
for(int i=startIdx; i<endIdx; i++) {
product[i] = src1[i].mult(src2[i]);
}
}));

return product;
}
Expand Down
Loading

0 comments on commit 7a7c1ca

Please sign in to comment.