diff --git a/lib/stormpy/pomdp/__init__.py b/lib/stormpy/pomdp/__init__.py index d490e77b4..389565c0a 100644 --- a/lib/stormpy/pomdp/__init__.py +++ b/lib/stormpy/pomdp/__init__.py @@ -64,15 +64,18 @@ def create_nondeterminstic_belief_tracker(model, reduction_timeout, track_timeou return pomdp.NondeterministicBeliefTrackerDoubleSparse(model, opts) -def create_observation_trace_unfolder(model, risk_assessment, expr_manager): +def create_observation_trace_unfolder(model, risk_assessment, expr_manager, rejection_sampling = True): """ :param model: :param risk_assessment: :param expr_manager: + :param rejection_sampling: :return: """ + options = pomdp.ObservationTraceUnfolderOptions() + options.rejection_sampling = rejection_sampling if model.is_exact: - return pomdp.ObservationTraceUnfolderExact(model, risk_assessment, expr_manager) + return pomdp.ObservationTraceUnfolderExact(model, risk_assessment, expr_manager, options) else: - return pomdp.ObservationTraceUnfolderDouble(model, risk_assessment, expr_manager) \ No newline at end of file + return pomdp.ObservationTraceUnfolderDouble(model, risk_assessment, expr_manager, options) \ No newline at end of file diff --git a/src/pomdp/transformations.cpp b/src/pomdp/transformations.cpp index 07fa32abc..8ea950f65 100644 --- a/src/pomdp/transformations.cpp +++ b/src/pomdp/transformations.cpp @@ -32,9 +32,11 @@ std::shared_ptr> apply_unk } template -std::shared_ptr> unfold_trace(storm::models::sparse::Pomdp const& pomdp, std::shared_ptr& exprManager, std::vector const& observationTrace, std::vector const& riskDef ) { - storm::pomdp::ObservationTraceUnfolder transformer(pomdp, exprManager); - return transformer.transform(observationTrace, riskDef); +std::shared_ptr> unfold_trace(storm::models::sparse::Pomdp const& pomdp, std::shared_ptr& exprManager, std::vector const& observationTrace, std::vector const& riskDef, bool rejectionSampling=true) { + storm::pomdp::ObservationTraceUnfolderOptions options = storm::pomdp::ObservationTraceUnfolderOptions(); + options.rejectionSampling = rejectionSampling; + storm::pomdp::ObservationTraceUnfolder transformer(pomdp, riskDef, exprManager, options); + return transformer.transform(observationTrace); } // STANDARD, SIMPLE_LINEAR, SIMPLE_LINEAR_INVERSE, SIMPLE_LOG, FULL @@ -47,6 +49,11 @@ void define_transformations_nt(py::module &m) { .value("full", storm::transformer::PomdpFscApplicationMode::FULL) ; + py::class_ unfolderOptions(m, "ObservationTraceUnfolderOptions", "Options for the ObservationTraceUnfolder"); + unfolderOptions.def(py::init<>()); + unfolderOptions.def_readwrite("rejection_sampling", &storm::pomdp::ObservationTraceUnfolderOptions::rejectionSampling); + + } template @@ -55,12 +62,12 @@ void define_transformations(py::module& m, std::string const& vtSuffix) { m.def(("_unfold_memory_" + vtSuffix).c_str(), &unfold_memory, "Unfold memory into a POMDP", py::arg("pomdp"), py::arg("memorystructure"), py::arg("memorylabels") = false, py::arg("keep_state_valuations")=false); m.def(("_make_simple_"+ vtSuffix).c_str(), &make_simple, "Make POMDP simple", py::arg("pomdp"), py::arg("keep_state_valuations")=false); m.def(("_apply_unknown_fsc_" + vtSuffix).c_str(), &apply_unknown_fsc, "Apply unknown FSC",py::arg("pomdp"), py::arg("application_mode")=storm::transformer::PomdpFscApplicationMode::SIMPLE_LINEAR); - //m.def(("_unfold_trace_" + vtSuffix).c_str(), &unfold_trace, "Unfold observed trace", py::arg("pomdp"), py::arg("expression_manager"),py::arg("observation_trace"), py::arg("risk_definition")); py::class_> unfolder(m, ("ObservationTraceUnfolder" + vtSuffix).c_str(), "Unfolds observation traces in models"); - unfolder.def(py::init const&, std::vector const&, std::shared_ptr&>(), py::arg("model"), py::arg("risk"), py::arg("expression_manager")); + unfolder.def(py::init const&, std::vector const&, std::shared_ptr&, storm::pomdp::ObservationTraceUnfolderOptions const&>(), py::arg("model"), py::arg("risk"), py::arg("expression_manager"), py::arg("options")); + unfolder.def("is_rejection_sampling_set", &storm::pomdp::ObservationTraceUnfolder::isRejectionSamplingSet); unfolder.def("transform", &storm::pomdp::ObservationTraceUnfolder::transform, py::arg("trace")); - unfolder.def("extend", &storm::pomdp::ObservationTraceUnfolder::extend, py::arg("new_observation")); + unfolder.def("extend", &storm::pomdp::ObservationTraceUnfolder::extend, py::arg("new_observations")); unfolder.def("reset", &storm::pomdp::ObservationTraceUnfolder::reset, py::arg("new_observation")); }