Skip to content

Commit

Permalink
Implement reduceTransitions for MDPs to allow efficient accumulation …
Browse files Browse the repository at this point in the history
…over transitions
  • Loading branch information
Steffen Märcker committed May 8, 2019
1 parent 874be3c commit 3b005bd
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 27 deletions.
38 changes: 37 additions & 1 deletion prism/src/explicit/DTMCFromMDPAndMDStrategy.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,45 @@ public void forEachTransition(int s, TransitionConsumer c)
if (!strat.isChoiceDefined(s)) {
return;
}
mdp.forEachTransition(s, strat.getChoiceIndex(s), c::accept);
mdp.forEachTransition(s, strat.getChoiceIndex(s), c);
}

@Override
public <T> T reduceTransitions(int state, T init, ObjTransitionFunction<T> fn)
{
if (!strat.isChoiceDefined(state)) {
return init;
}
return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
}

@Override
public double reduceTransitions(int state, double init, DoubleTransitionFunction fn)
{
if (!strat.isChoiceDefined(state)) {
return init;
}
return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
}

@Override
public int reduceTransitions(int state, int init, IntTransitionFunction fn)
{
if (!strat.isChoiceDefined(state)) {
return init;
}
return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
}

@Override
public long reduceTransitions(int state, long init, LongTransitionFunction fn)
{
if (!strat.isChoiceDefined(state)) {
return init;
}
return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
}

@Override
public double mvMultSingle(int s, double vect[])
{
Expand Down
112 changes: 86 additions & 26 deletions prism/src/explicit/MDP.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
import java.util.PrimitiveIterator.OfInt;

import common.IterableStateSet;
import explicit.DTMC.DoubleTransitionFunction;
import explicit.DTMC.IntTransitionFunction;
import explicit.DTMC.LongTransitionFunction;
import explicit.DTMC.ObjTransitionFunction;
import explicit.DTMC.TransitionConsumer;
import explicit.rewards.MCRewards;
import explicit.rewards.MDPRewards;
import prism.PrismUtils;
Expand All @@ -53,16 +58,6 @@ public interface MDP extends MDPGeneric<Double>
*/
public Iterator<Entry<Integer, Double>> getTransitionsIterator(int s, int i);

/**
* Functional interface for a consumer,
* accepting transitions (s,t,d), i.e.,
* from state s to state t with value d.
*/
@FunctionalInterface
public interface TransitionConsumer {
void accept(int s, int t, double d);
}

/**
* Iterate over the outgoing transitions of state {@code s} and choice {@code i}
* and call the accept method of the consumer for each of them:
Expand All @@ -88,6 +83,83 @@ public default void forEachTransition(int s, int i, TransitionConsumer c)
}
}

/**
* Iterate over the outgoing transitions of state {@code state} and choice {@code c}
* and apply the reducing function {@code fn}
* to the intermediate result and the transition:
* <br/>
* Call {@code apply(r,s,t,d)} where
* {@code r} is the intermediate result,
* {@code t} is the successor state and,
* {@code d} = P(s,c,t) is the probability from {@code s} to {@code t} with choice {@code c},
* The return value of apply is the intermediate result for the next transition.
* <p>
* <i>Default implementation</i>: The default implementation relies on iterating over the
* iterator returned by {@code getTransitionsIterator()}.
* <p><i>Note</i>: This method is the base for the default implementation of the numerical
* computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to
* provide a specialised implementation for this method that avoids using the Iterator mechanism.
*
* @param state the state
* @param choice the choice
* @param init initial result value
* @param fn the reducing function
*/
public default <T> T reduceTransitions(int state, int choice, T init, ObjTransitionFunction<T> fn)
{
T result = init;
for (Iterator<Entry<Integer, Double>> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
Entry<Integer, Double> e = it.next();
result = fn.apply(result, state, e.getKey(), e.getValue());
}
return result;
}

/**
* Primitive specialisation of {@code reduce} for {@code double} values.
*
* @see #reduceTransitions(int, Object, ObjTransitionFunction)
*/
public default double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn)
{
double result = init;
for (Iterator<Entry<Integer, Double>> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
Entry<Integer, Double> e = it.next();
result = fn.apply(result, state, e.getKey(), e.getValue());
}
return result;
}

/**
* Primitive specialisation of {@code reduce} for {@code int} values.
*
* @see #reduceTransitions(int, Object, ObjTransitionFunction)
*/
public default int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn)
{
int result = init;
for (Iterator<Entry<Integer, Double>> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
Entry<Integer, Double> e = it.next();
result = fn.apply(result, state, e.getKey(), e.getValue());
}
return result;
}

/**
* Primitive specialisation of {@code reduce} for {@code long} values.
*
* @see #reduceTransitions(int, Object, ObjTransitionFunction)
*/
public default long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn)
{
long result = init;
for (Iterator<Entry<Integer, Double>> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
Entry<Integer, Double> e = it.next();
result = fn.apply(result, state, e.getKey(), e.getValue());
}
return result;
}

/**
* Functional interface for a function
* mapping transitions (s,t,d), i.e.,
Expand All @@ -105,24 +177,12 @@ public interface TransitionToDoubleFunction {
* <br>
* Return sum_t f(s, t, P(s,i,t)), where t ranges over the i-successors of s.
*
* @param s the state s
* @param c the consumer
* @param state the state s
* @param choice the consumer
*/
public default double sumOverTransitions(final int s, final int i, final TransitionToDoubleFunction f)
public default double sumOverTransitions(int state, int choice, TransitionToDoubleFunction f)
{
class Sum {
double sum = 0.0;

void accept(int s, int t, double d)
{
sum += f.apply(s, t, d);
}
}

Sum sum = new Sum();
forEachTransition(s, i, sum::accept);

return sum.sum;
return reduceTransitions(state, choice, 0.0, (r, s, t, d) -> r + f.apply(s, t, d));
}

/**
Expand Down
52 changes: 52 additions & 0 deletions prism/src/explicit/MDPSparse.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
import java.util.TreeMap;

import common.IterableStateSet;
import explicit.DTMC.DoubleTransitionFunction;
import explicit.DTMC.IntTransitionFunction;
import explicit.DTMC.LongTransitionFunction;
import explicit.DTMC.ObjTransitionFunction;
import explicit.rewards.MCRewards;
import explicit.rewards.MDPRewards;
import parser.State;
Expand Down Expand Up @@ -574,6 +578,54 @@ public SuccessorsIterator getSuccessors(final int s, final int i)

// Accessors (for MDP)

@Override
public <T> T reduceTransitions(int state, int choice, T init, ObjTransitionFunction<T> fn)
{
T result = init;
int start = choiceStarts[rowStarts[state] + choice];
int stop = choiceStarts[rowStarts[state] + choice + 1];
for (int col = start; col < stop; col++) {
result = fn.apply(result, state, cols[col], nonZeros[col]);
}
return result;
}

@Override
public double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn)
{
double result = init;
int start = choiceStarts[rowStarts[state] + choice];
int stop = choiceStarts[rowStarts[state] + choice + 1];
for (int col = start; col < stop; col++) {
result = fn.apply(result, state, cols[col], nonZeros[col]);
}
return result;
}

@Override
public int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn)
{
int result = init;
int start = choiceStarts[rowStarts[state] + choice];
int stop = choiceStarts[rowStarts[state] + choice + 1];
for (int col = start; col < stop; col++) {
result = fn.apply(result, state, cols[col], nonZeros[col]);
}
return result;
}

@Override
public long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn)
{
long result = init;
int start = choiceStarts[rowStarts[state] + choice];
int stop = choiceStarts[rowStarts[state] + choice + 1];
for (int col = start; col < stop; col++) {
result = fn.apply(result, state, cols[col], nonZeros[col]);
}
return result;
}

@Override
public int getNumTransitions(int s, int i)
{
Expand Down

0 comments on commit 3b005bd

Please sign in to comment.