diff --git a/src/storage/model.cpp b/src/storage/model.cpp index 45594d6..1170dcc 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -16,6 +16,8 @@ #include "storm/models/symbolic/MarkovAutomaton.h" #include "storm/models/symbolic/StandardRewardModel.h" +#include "storm/storage/Scheduler.h" + #include #include #include @@ -105,6 +107,7 @@ void define_model(py::module& m) { py::class_> modelBase(m, "_ModelBase", "Base class for all models"); modelBase.def_property_readonly("nr_states", &ModelBase::getNumberOfStates, "Number of states") .def_property_readonly("nr_transitions", &ModelBase::getNumberOfTransitions, "Number of transitions") + .def_property_readonly("nr_choices", &ModelBase::getNumberOfChoices, "Number of choices") .def_property_readonly("model_type", &ModelBase::getType, "Model type") .def_property_readonly("supports_parameters", &ModelBase::supportsParameters, "Flag whether model supports parameters") .def_property_readonly("has_parameters", &ModelBase::hasParameters, "Flag whether model has parameters") @@ -183,11 +186,14 @@ void define_sparse_model(py::module& m) { .def_property_readonly("backward_transition_matrix", &SparseModel::getBackwardTransitions, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Backward transition matrix") .def("reduce_to_state_based_rewards", &SparseModel::reduceToStateBasedRewards) .def("__str__", getModelInfoPrinter()) + .def("to_dot", [](SparseModel& model) { std::stringstream ss; model.writeDotToStream(ss); return ss.str(); }, "Write dot to a string") ; py::class_, std::shared_ptr>>(m, "SparseDtmc", "DTMC in sparse representation", model) .def("__str__", getModelInfoPrinter("DTMC")) ; py::class_, std::shared_ptr>>(m, "SparseMdp", "MDP in sparse representation", model) + .def_property_readonly("nondeterministic_choice_indices", [](SparseMdp const& mdp) { return mdp.getNondeterministicChoiceIndices(); }) + .def("apply_scheduler", [](SparseMdp const& mdp, storm::storage::Scheduler const& scheduler, bool dropUnreachableStates) { return mdp.applyScheduler(scheduler, dropUnreachableStates); } , "apply scheduler", "scheduler"_a, "drop_unreachable_states"_a = true) .def("__str__", getModelInfoPrinter("MDP")) ; py::class_, std::shared_ptr>>(m, "SparsePomdp", "POMDP in sparse representation", model) @@ -209,6 +215,7 @@ void define_sparse_model(py::module& m) { .def_property_readonly("transition_rewards", [](SparseRewardModel& rewardModel) {return rewardModel.getTransitionRewardMatrix();}) .def_property_readonly("state_rewards", [](SparseRewardModel& rewardModel) {return rewardModel.getStateRewardVector();}) .def("get_state_reward", [](SparseRewardModel& rewardModel, uint64_t state) {return rewardModel.getStateReward(state);}) + .def("get_zero_reward_states", &SparseRewardModel::getStatesWithZeroReward, "get states where all rewards are zero", py::arg("transition_matrix")) .def("get_state_action_reward", [](SparseRewardModel& rewardModel, uint64_t action_index) {return rewardModel.getStateActionReward(action_index);}) .def_property_readonly("state_action_rewards", [](SparseRewardModel& rewardModel) {return rewardModel.getStateActionRewardVector();}) .def("reduce_to_state_based_rewards", [](SparseRewardModel& rewardModel, storm::storage::SparseMatrix const& transitions, bool onlyStateRewards){return rewardModel.reduceToStateBasedRewards(transitions, onlyStateRewards);}, py::arg("transition_matrix"), py::arg("only_state_rewards"), "Reduce to state-based rewards") @@ -237,6 +244,8 @@ void define_sparse_model(py::module& m) { .def("__str__", getModelInfoPrinter("ParametricDTMC")) ; py::class_, std::shared_ptr>>(m, "SparseParametricMdp", "pMDP in sparse representation", modelRatFunc) + .def_property_readonly("nondeterministic_choice_indices", [](SparseMdp const& mdp) { return mdp.getNondeterministicChoiceIndices(); }) + .def("apply_scheduler", [](SparseMdp const& mdp, storm::storage::Scheduler const& scheduler, bool dropUnreachableStates) { return mdp.applyScheduler(scheduler, dropUnreachableStates); } , "apply scheduler", "scheduler"_a, "drop_unreachable_states"_a = true) .def("__str__", getModelInfoPrinter("ParametricMDP")) ; py::class_, std::shared_ptr>>(m, "SparseParametricCtmc", "pCTMC in sparse representation", modelRatFunc)