diff --git a/src/storage/model.cpp b/src/storage/model.cpp index 27baa1d..510c936 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -119,6 +119,7 @@ void define_model(py::module& m) { .def_property_readonly("reward_models", [](Model& model) {return model.getRewardModels(); }, "Reward models") .def_property_readonly("transition_matrix", &getTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Transition matrix") .def_property_readonly("backward_transition_matrix", &getBackwardTransitionMatrix, py::return_value_policy::reference, py::keep_alive<1, 0>(), "Backward transition matrix") + .def("reduce_to_state_based_rewards", &Model::reduceToStateBasedRewards) .def("__str__", getModelInfoPrinter()) ; py::class_, std::shared_ptr>>(m, "SparseDtmc", "DTMC in sparse representation", model) @@ -138,6 +139,7 @@ void define_model(py::module& m) { .def_property_readonly("transition_rewards", [](RewardModel& rewardModel) {return rewardModel.getTransitionRewardMatrix();}) .def_property_readonly("state_rewards", [](RewardModel& rewardModel) {return rewardModel.getStateRewardVector();}) .def_property_readonly("state_action_rewards", [](RewardModel& rewardModel) {return rewardModel.getStateActionRewardVector();}) + .def("reduce_to_state_based_rewards", [](RewardModel& rewardModel, SparseMatrix const& transitions, bool onlyStateRewards){return rewardModel.reduceToStateBasedRewards(transitions, onlyStateRewards);}, py::arg("transition_matrix"), py::arg("only_state_rewards"), "Reduce to state-based rewards") ; diff --git a/tests/storage/test_model.py b/tests/storage/test_model.py index dec72ba..9f25182 100644 --- a/tests/storage/test_model.py +++ b/tests/storage/test_model.py @@ -43,6 +43,20 @@ class TestModel: assert not model.supports_parameters assert type(model) is stormpy.SparseDtmc + def test_reduce_to_state_based_rewards(self): + program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) + prop = "R=? [F \"done\"]" + properties = stormpy.parse_properties_for_prism_program(prop, program, None) + model = stormpy.build_model(program, properties) + model.reduce_to_state_based_rewards() + assert len(model.reward_models) == 1 + assert model.reward_models["coin_flips"].has_state_rewards + assert not model.reward_models["coin_flips"].has_state_action_rewards + for reward in model.reward_models["coin_flips"].state_rewards: + assert reward == 1.0 or reward == 0.0 + assert not model.reward_models["coin_flips"].has_transition_rewards + + def test_build_parametric_dtmc_from_prism_program(self): program = stormpy.parse_prism_program(get_example_path("pdtmc", "brp16_2.pm")) prop = "P=? [F s=5]"