Skip to content

Commit

Permalink
Avoid repeated calculations of state count under each filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Sothatsit committed Feb 22, 2024
1 parent 6c547ff commit 44c2e12
Showing 1 changed file with 20 additions and 31 deletions.
51 changes: 20 additions & 31 deletions src/main/java/net/royalur/lut/LutTrainer.java
Original file line number Diff line number Diff line change
Expand Up @@ -407,34 +407,21 @@ private double performTrainingIterationSection(

private double performTrainingIteration(
Lut<R> lut,
int stateCount,
Function<FastSimpleGame, Boolean> stateFilter
) {
// Calculate the number of states.
AtomicInteger stateCount = new AtomicInteger(0);
loopLightGameStates((game) -> {
if (game.isFinished || !stateFilter.apply(game))
return;

stateCount.incrementAndGet();
});
int threadCount = Runtime.getRuntime().availableProcessors();

if (stateCount.get() < threadCount) {
return performTrainingIterationSection(
lut, stateFilter, 0, stateCount.get()
);
}

// Split up the keys between threads for processing.
int statesPerThread = (stateCount.get() + threadCount - 1) / threadCount;
int threadCount = Runtime.getRuntime().availableProcessors();
int statesPerThread = (stateCount + threadCount - 1) / threadCount;
AtomicReference<Double> maxChange = new AtomicReference<>(0.0d);

List<Thread> threads = new ArrayList<>();
AtomicReference<Exception> error = new AtomicReference<>();

for (int threadNo = 0; threadNo < threadCount; ++threadNo) {

int fromIndex = statesPerThread * threadNo;
int toIndex = Math.min(stateCount.get(), statesPerThread * (threadNo + 1));
int toIndex = statesPerThread * (threadNo + 1);

Thread thread = new Thread(() -> {
try {
Expand All @@ -454,6 +441,8 @@ private double performTrainingIteration(
threads.add(thread);
thread.start();
}

// Wait for all processing to complete.
try {
for (Thread thread : threads) {
thread.join();
Expand Down Expand Up @@ -487,21 +476,20 @@ public Lut<R> train(
int pieceCount = settings.getStartingPieceCount();
for (int minScore = pieceCount - 1; minScore >= 0; --minScore) {
for (int maxScore = pieceCount - 1; maxScore >= minScore; --maxScore) {

int minScoreFinal = minScore;
int maxScoreFinal = maxScore;
Function<FastSimpleGame, Boolean> stateFilter = game -> {
int min = Math.min(game.light.score, game.dark.score);
int max = Math.max(game.light.score, game.dark.score);
return min == minScoreFinal && max == maxScoreFinal;
};
int stateCount = countStates(stateFilter);

double maxChange;
do {
long start = System.nanoTime();

int minScoreFinal = minScore;
int maxScoreFinal = maxScore;
maxChange = performTrainingIteration(
lut,
game -> {
int min = Math.min(game.light.score, game.dark.score);
int max = Math.max(game.light.score, game.dark.score);
return min == minScoreFinal && max == maxScoreFinal;
}
);

maxChange = performTrainingIteration(lut, stateCount, stateFilter);
double durationMs = (System.nanoTime() - start) / 1e6;
System.out.printf(
"%d. scores = [%d, %d], max diff = %.3f (%s ms)\n",
Expand All @@ -524,11 +512,12 @@ public Lut<R> train(
System.out.println();
System.out.println("Finished progressive value iteration!");
System.out.println("Starting full value iteration for 10 steps...");
int stateCount = countStates();
for (int index = 0; index < 10; ++index) {
long start = System.nanoTime();

double maxChange = performTrainingIteration(
lut, game -> true
lut, stateCount, game -> true
);
double durationMs = (System.nanoTime() - start) / 1e6;
System.out.printf(
Expand Down

0 comments on commit 44c2e12

Please sign in to comment.