From f35b9788a78e104c9bcba7ec9e9b56cbce7e9e2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Steffen=20M=C3=A4rcker?= Date: Wed, 8 May 2019 11:36:08 +0200 Subject: [PATCH] Implement reduceTransitions for MDPs to allow efficient accumulation over transitions --- .../explicit/DTMCFromMDPAndMDStrategy.java | 38 +++++- prism/src/explicit/MDP.java | 112 ++++++++++++++---- prism/src/explicit/MDPSparse.java | 52 ++++++++ 3 files changed, 175 insertions(+), 27 deletions(-) diff --git a/prism/src/explicit/DTMCFromMDPAndMDStrategy.java b/prism/src/explicit/DTMCFromMDPAndMDStrategy.java index fe9ca18ec8..9d1e002d21 100644 --- a/prism/src/explicit/DTMCFromMDPAndMDStrategy.java +++ b/prism/src/explicit/DTMCFromMDPAndMDStrategy.java @@ -187,9 +187,45 @@ public void forEachTransition(int s, TransitionConsumer c) if (!strat.isChoiceDefined(s)) { return; } - mdp.forEachTransition(s, strat.getChoiceIndex(s), c::accept); + mdp.forEachTransition(s, strat.getChoiceIndex(s), c); } + @Override + public T reduceTransitions(int state, T init, ObjTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); + } + + @Override + public double reduceTransitions(int state, double init, DoubleTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); + } + + @Override + public int reduceTransitions(int state, int init, IntTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); + } + + @Override + public long reduceTransitions(int state, long init, LongTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); +} + @Override public double mvMultSingle(int s, double vect[]) { diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java index 613d0a2040..1b3791578a 100644 --- a/prism/src/explicit/MDP.java +++ b/prism/src/explicit/MDP.java @@ -35,6 +35,11 @@ import java.util.PrimitiveIterator.OfInt; import common.IterableStateSet; +import explicit.DTMC.DoubleTransitionFunction; +import explicit.DTMC.IntTransitionFunction; +import explicit.DTMC.LongTransitionFunction; +import explicit.DTMC.ObjTransitionFunction; +import explicit.DTMC.TransitionConsumer; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; import prism.PrismUtils; @@ -53,16 +58,6 @@ public interface MDP extends MDPGeneric */ public Iterator> getTransitionsIterator(int s, int i); - /** - * Functional interface for a consumer, - * accepting transitions (s,t,d), i.e., - * from state s to state t with value d. - */ - @FunctionalInterface - public interface TransitionConsumer { - void accept(int s, int t, double d); - } - /** * Iterate over the outgoing transitions of state {@code s} and choice {@code i} * and call the accept method of the consumer for each of them: @@ -88,6 +83,83 @@ public default void forEachTransition(int s, int i, TransitionConsumer c) } } + /** + * Iterate over the outgoing transitions of state {@code state} and choice {@code c} + * and apply the reducing function {@code fn} + * to the intermediate result and the transition: + *
+ * Call {@code apply(r,s,t,d)} where + * {@code r} is the intermediate result, + * {@code t} is the successor state and, + * {@code d} = P(s,c,t) is the probability from {@code s} to {@code t} with choice {@code c}, + * The return value of apply is the intermediate result for the next transition. + *

+ * Default implementation: The default implementation relies on iterating over the + * iterator returned by {@code getTransitionsIterator()}. + *

Note: This method is the base for the default implementation of the numerical + * computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to + * provide a specialised implementation for this method that avoids using the Iterator mechanism. + * + * @param state the state + * @param choice the choice + * @param init initial result value + * @param fn the reducing function + */ + public default T reduceTransitions(int state, int choice, T init, ObjTransitionFunction fn) + { + T result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code reduce} for {@code double} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn) + { + double result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code reduce} for {@code int} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn) + { + int result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code reduce} for {@code long} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn) + { + long result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + /** * Functional interface for a function * mapping transitions (s,t,d), i.e., @@ -105,24 +177,12 @@ public interface TransitionToDoubleFunction { *
* Return sum_t f(s, t, P(s,i,t)), where t ranges over the i-successors of s. * - * @param s the state s - * @param c the consumer + * @param state the state s + * @param choice the consumer */ - public default double sumOverTransitions(final int s, final int i, final TransitionToDoubleFunction f) + public default double sumOverTransitions(int state, int choice, TransitionToDoubleFunction f) { - class Sum { - double sum = 0.0; - - void accept(int s, int t, double d) - { - sum += f.apply(s, t, d); - } - } - - Sum sum = new Sum(); - forEachTransition(s, i, sum::accept); - - return sum.sum; + return reduceTransitions(state, choice, 0.0, (r, s, t, d) -> r + f.apply(s, t, d)); } /** diff --git a/prism/src/explicit/MDPSparse.java b/prism/src/explicit/MDPSparse.java index 4180d0dfc8..e6b2aba5a7 100644 --- a/prism/src/explicit/MDPSparse.java +++ b/prism/src/explicit/MDPSparse.java @@ -41,6 +41,10 @@ import java.util.TreeMap; import common.IterableStateSet; +import explicit.DTMC.DoubleTransitionFunction; +import explicit.DTMC.IntTransitionFunction; +import explicit.DTMC.LongTransitionFunction; +import explicit.DTMC.ObjTransitionFunction; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; import parser.State; @@ -574,6 +578,54 @@ public SuccessorsIterator getSuccessors(final int s, final int i) // Accessors (for MDP) + @Override + public T reduceTransitions(int state, int choice, T init, ObjTransitionFunction fn) + { + T result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + + @Override + public double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn) + { + double result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + + @Override + public int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn) + { + int result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + + @Override + public long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn) + { + long result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + @Override public int getNumTransitions(int s, int i) {