Skip to content

Commit

Permalink
[PEx] Correct timelines tracking and RL-based choice selection
Browse files Browse the repository at this point in the history
  • Loading branch information
aman-goel committed Aug 8, 2024
1 parent 0174966 commit 83a90b2
Show file tree
Hide file tree
Showing 16 changed files with 143 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ private static void setup() {
RandomNumberGenerator.setup(PExplicitGlobal.getConfig().getRandomSeed());
MemoryMonitor.setup(PExplicitGlobal.getConfig().getMemLimit());
TimeMonitor.setup(PExplicitGlobal.getConfig().getTimeLimit());
// initialize stats writer
StatWriter.Initialize();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ private static void process(boolean resume) throws Exception {

String schFile = PExplicitGlobal.getConfig().getOutputFolder() + "/" + PExplicitGlobal.getConfig().getProjectName() + "_0_0.schedule";
PExplicitLogger.logInfo(String.format("Writing buggy trace in %s", schFile));
scheduler.schedule.writeToFile(schFile);
scheduler.getSchedule().writeToFile(schFile);

ReplayScheduler replayer = new ReplayScheduler(scheduler.schedule);
ReplayScheduler replayer = new ReplayScheduler(scheduler.getSchedule());
PExplicitGlobal.setScheduler(replayer);
try {
replayer.run();
Expand Down Expand Up @@ -188,9 +188,6 @@ public static void replay() throws Exception {
}

public static void run() throws Exception {
// initialize stats writer
StatWriter.Initialize();

if (PExplicitGlobal.getConfig().getSearchStrategyMode() == SearchStrategyMode.Replay) {
replay();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public class PExplicitOptions {
.hasArg()
.argName("Mode (string)")
.build();
addOption(choiceSelect);
addHiddenOption(choiceSelect);

/*
* Help menu options
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
package pexplicit.runtime.machine;

import lombok.Getter;
import lombok.Setter;
import org.apache.commons.lang3.tuple.ImmutablePair;
import pexplicit.values.PEvent;

import java.io.Serializable;
import java.util.List;
import java.util.Set;

/**
* Represents the local state of a machine
*/
@Getter
public class MachineLocalState implements Serializable {
/**
* List of values of all local variables (including internal variables like currentState, FIFO queue, etc.)
*/
@Getter
@Setter
private List<Object> locals;
private final List<Object> locals;
private final Set<PEvent> observedEvents;
private final Set<ImmutablePair<PEvent, PEvent>> happensBeforePairs;
private final int timelineHash;

public MachineLocalState(List<Object> locals) {
public MachineLocalState(List<Object> locals, Set<PEvent> observedEvents, Set<ImmutablePair<PEvent, PEvent>> happensBeforePairs) {
this.locals = locals;
this.observedEvents = observedEvents;
this.happensBeforePairs = happensBeforePairs;
this.timelineHash = happensBeforePairs.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,6 @@ public List<String> getLocalVarNames() {
result.add("_started");
result.add("_halted");

result.add("_observedEvents");
result.add("_happensBefore");

result.add("_blockedBy");
result.add("_blockedStateExit");
result.add("_blockedNewStateEntry");
Expand Down Expand Up @@ -218,9 +215,6 @@ public List<Object> getLocalVarValues() {
result.add(started);
result.add(halted);

result.add(observedEvents);
result.add(happensBeforePairs);

result.add(blockedBy);
result.add(blockedStateExit);
result.add(blockedNewStateEntry);
Expand Down Expand Up @@ -262,9 +256,6 @@ public List<Object> copyLocalVarValues() {
result.add(started);
result.add(halted);

result.add(new HashSet<>(observedEvents));
result.add(new HashSet<>(happensBeforePairs));

result.add(blockedBy);
result.add(blockedStateExit);
result.add(blockedNewStateEntry);
Expand Down Expand Up @@ -307,9 +298,6 @@ protected int setLocalVarValues(List<Object> values) {
started = (boolean) values.get(idx++);
halted = (boolean) values.get(idx++);

observedEvents = (Set<PEvent>) values.get(idx++);
happensBeforePairs = (Set<ImmutablePair<PEvent, PEvent>>) values.get(idx++);

blockedBy = (PContinuation) values.get(idx++);
blockedStateExit = (State) values.get(idx++);
blockedNewStateEntry = (State) values.get(idx++);
Expand All @@ -336,11 +324,13 @@ protected int setLocalVarValues(List<Object> values) {
}

public MachineLocalState copyMachineState() {
return new MachineLocalState(copyLocalVarValues());
return new MachineLocalState(copyLocalVarValues(), observedEvents, happensBeforePairs);
}

public void setMachineState(MachineLocalState input) {
setLocalVarValues(input.getLocals());
observedEvents = input.getObservedEvents();
happensBeforePairs = input.getHappensBeforePairs();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class Schedule implements Serializable {
* Step state at the start of a scheduler step.
* Used in stateful backtracking
*/
@Getter
@Setter
private transient StepState stepBeginState = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public abstract class Scheduler implements SchedulerInterface {
/**
* Current schedule
*/
public final Schedule schedule;
@Getter
protected final Schedule schedule;
/**
* Step number
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import pexplicit.runtime.scheduler.Scheduler;
import pexplicit.runtime.scheduler.choice.ScheduleChoice;
import pexplicit.runtime.scheduler.choice.SearchUnit;
import pexplicit.runtime.scheduler.explicit.choiceselector.ChoiceSelectorMode;
import pexplicit.runtime.scheduler.explicit.choiceselector.ChoiceSelectorQL;
import pexplicit.runtime.scheduler.explicit.strategy.*;
import pexplicit.utils.exceptions.PExplicitRuntimeException;
Expand Down Expand Up @@ -160,6 +159,19 @@ protected void runIteration() throws TimeoutException {
if (scheduleTerminated) {
// schedule terminated, check for deadlock
checkDeadlock();
// update timeline
Integer timelineHash = stepState.getTimelineHash();
if (!timelines.contains(timelineHash)) {
// add new timeline
timelines.add(timelineHash);
// print new timeline
// stepState.printTimeline(timelineHash, choiceNumber, String.format("%d. New timeline %d @%d::%d",
// SearchStatistics.iteration, timelines.size(), stepNumber, choiceNumber));
if (PExplicitGlobal.getChoiceSelector() instanceof ChoiceSelectorQL choiceSelectorQL) {
// reward new timeline
choiceSelectorQL.rewardNewTimeline(this);
}
}
}
if (!skipLiveness) {
// check for liveness
Expand Down Expand Up @@ -198,16 +210,6 @@ protected void runStep() throws TimeoutException {
return;
}

// update timeline
Integer timelineHash = stepState.getTimelineHash();
if (!timelines.contains(timelineHash)) {
// stepState.printTimeline(timelineHash, timelines.size());
timelines.add(timelineHash);
}
if (PExplicitGlobal.getConfig().getChoiceSelectorMode() == ChoiceSelectorMode.QL) {
PExplicitGlobal.getChoiceSelector().startStep(this);
}

if (PExplicitGlobal.getConfig().getStatefulBacktrackingMode() != StatefulBacktrackingMode.None
&& stepNumber != 0) {
schedule.setStepBeginState(stepState.copyState());
Expand Down Expand Up @@ -348,7 +350,7 @@ public PMachine getNextScheduleChoice() {
}

// pick a choice
int selected = PExplicitGlobal.getChoiceSelector().selectChoice(choices);
int selected = PExplicitGlobal.getChoiceSelector().selectChoice(this, choices);
result = PExplicitGlobal.getGlobalMachine(choices.get(selected));
PExplicitLogger.logCurrentScheduleChoice(result, stepNumber, choiceNumber);

Expand Down Expand Up @@ -412,7 +414,7 @@ public PValue<?> getNextDataChoice(List<PValue<?>> input_choices) {
}

// pick a choice
int selected = PExplicitGlobal.getChoiceSelector().selectChoice(choices);
int selected = PExplicitGlobal.getChoiceSelector().selectChoice(this, choices);
result = choices.get(selected);
PExplicitLogger.logCurrentDataChoice(result, stepNumber, choiceNumber);

Expand Down Expand Up @@ -495,6 +497,7 @@ private void setChildTask(SearchUnit unit, int choiceNum, SearchTask parentTask,

newTask.writeToFile();
parentTask.addChild(newTask);
searchStrategy.getPendingTasks().add(newTask.getId());
searchStrategy.addNewTask(newTask);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ public int getMachineCount(Class<? extends PMachine> type) {
return result;
}

public void printTimeline(int hash, int idx) {
PExplicitLogger.logVerbose(String.format("---- Timeline %d @%d ------", idx, hash));
public void printTimeline(int hash, int idx, String comment) {
PExplicitLogger.logVerbose(String.format("----\n%s\tTimeline %d @%d\n-----", comment, idx, hash));
for (PMachine m : machineSet) {
PExplicitLogger.logVerbose(String.format(" %s -> %s", m, m.getHappensBeforePairs()));
}
Expand All @@ -102,7 +102,10 @@ public Integer getTimelineHash() {
List<Integer> features = new ArrayList<>();
for (PMachine m : machineSet) {
features.add(m.hashCode());
features.add(m.getHappensBeforePairs().hashCode());
MachineLocalState ms = machineLocalStates.get(m);
if (ms != null) {
features.add(ms.getTimelineHash());
}
}
return features.hashCode();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,88 +1,84 @@
package pexplicit.runtime.scheduler.explicit.choiceselector;

import lombok.Getter;
import pexplicit.runtime.PExplicitGlobal;
import pexplicit.runtime.logger.PExplicitLogger;
import pexplicit.runtime.machine.PMachine;
import pexplicit.runtime.scheduler.Schedule;
import pexplicit.runtime.scheduler.choice.Choice;
import pexplicit.runtime.scheduler.choice.ScheduleChoice;
import pexplicit.runtime.scheduler.explicit.ExplicitSearchScheduler;
import pexplicit.runtime.scheduler.explicit.StatefulBacktrackingMode;
import pexplicit.utils.random.RandomNumberGenerator;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

public class ChoiceQL implements Serializable {
@Getter
private static final int defaultQValue = 0;
private static final double ALPHA = 0.5;
private static final double GAMMA = 0.05;
private static final double defaultQValue = 1.0;
private static final double ALPHA = 0.3;
private static final double GAMMA = 0.2;
private static final double STEP_PENALTY_REWARD = -1.0;
private static final double NEW_TIMELINE_REWARD = 1.0;
private final ChoiceQTable<Integer, Object> qValues;
private final List<Object> currActions = new ArrayList<>();
/**
* Details about the current step
*/
private Integer currState = 0;
private int currNumTimelines = 0;

public ChoiceQL() {
qValues = new ChoiceQTable();
}

private void rewardAction(Object action, int reward) {
ChoiceQTable.ChoiceQStateEntry stateEntry = qValues.get(currState);
private void rewardAction(int state, Object action, double reward) {
ChoiceQTable.ChoiceQStateEntry stateEntry = qValues.get(state);
ChoiceQTable.ChoiceQClassEntry classEntry = stateEntry.get(action.getClass());
int maxQ = classEntry.getMaxQ();
int oldVal = classEntry.get(action);
int newVal = (int) ((1 - ALPHA) * oldVal + ALPHA * (reward + GAMMA * maxQ));
double maxQ = classEntry.getMaxQ();
double oldVal = classEntry.get(action);
double newVal = ((1 - ALPHA) * oldVal + ALPHA * (reward + GAMMA * maxQ));
classEntry.update(action, newVal);
}

private void setStateTimelineAbstraction(ExplicitSearchScheduler sch) {
List<Integer> features = new ArrayList<>();
for (PMachine m : sch.getStepState().getMachineSet()) {
features.add(m.hashCode());
features.add(m.getHappensBeforePairs().hashCode());
public int select(int state, List<?> choices) {
// Compute the total and minimum weight
double totalWeight = 0.0;
double minWeight = Double.MAX_VALUE;
for (int i = 0; i < choices.size(); i++) {
Object choice = choices.get(i);
double weight = qValues.get(state, choice.getClass(), choice);
totalWeight += weight;
if (weight < minWeight) {
minWeight = weight;
}
}
currState = features.hashCode();
}

public void startStep(ExplicitSearchScheduler sch) {
// printQTable();

// set reward amount
int reward = -100;
if (sch.getTimelines().size() > currNumTimelines) {
reward = 100;
}
// reward last actions
for (Object action : currActions) {
rewardAction(action, reward);
// Now choose a weighted random item
int idx = 0;
for (double r = RandomNumberGenerator.getInstance().getRandomDouble() * totalWeight; idx < choices.size() - 1; idx++) {
Object choice = choices.get(idx);
double weight = qValues.get(state, choice.getClass(), choice);
r -= weight;
if (r <= 0.0) {
break;
}
}
return idx;
}

// set number of timelines at start of step
currNumTimelines = sch.getTimelines().size();

// set state at start of step
setStateTimelineAbstraction(sch);

// reset current actions at start of step
currActions.clear();
public void penalizeSelected(int state, Object action) {
// give a negative reward to the selected choice
rewardAction(state, action, STEP_PENALTY_REWARD);
}

public int selectChoice(List<?> choices) {
int maxVal = Integer.MIN_VALUE;
int selected = 0;
for (int i = 0; i < choices.size(); i++) {
Object choice = choices.get(i);
int val = qValues.get(currState, choice.getClass(), choice);
if (val > maxVal) {
maxVal = val;
selected = i;
public void rewardScheduleChoices(ExplicitSearchScheduler sch) {
Schedule schedule = sch.getSchedule();
for (int cIdx : sch.getSearchStrategy().getCurrTask().getSearchUnits().keySet()) {
int state = 0;
Choice choice = schedule.getChoice(cIdx);
if (PExplicitGlobal.getConfig().getStatefulBacktrackingMode() != StatefulBacktrackingMode.None) {
ScheduleChoice scheduleChoice = schedule.getScheduleChoiceAt(cIdx);
if (scheduleChoice != null && scheduleChoice.getChoiceState() != null) {
state = scheduleChoice.getChoiceState().getTimelineHash();
}
}
rewardAction(state, choice.getCurrent(), NEW_TIMELINE_REWARD);
}
return selected;
}

public void addChoice(Object choice) {
currActions.add(choice);
}

public int getNumStates() {
Expand Down Expand Up @@ -121,10 +117,10 @@ public void printQTable() {
}
Object bestAction = classEntry.getBestAction();
if (bestAction != null) {
int maxQ = classEntry.get(bestAction);
double maxQ = classEntry.get(bestAction);
PExplicitLogger.logVerbose(
String.format(
" %s [%s] -> %s -> %d\t%s",
" %s [%s] -> %s -> %.2f\t%s",
stateStr, cls.getSimpleName(), bestAction, maxQ, classEntry));
}
}
Expand Down
Loading

0 comments on commit 83a90b2

Please sign in to comment.