diff --git a/src/storage/model.cpp b/src/storage/model.cpp index b3f824d..9be8f96 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -212,7 +212,7 @@ void define_sparse_model(py::module& m) { ; py::class_, std::shared_ptr>>(m, "SparsePomdp", "POMDP in sparse representation", mdp) .def(py::init>(), py::arg("other_model")) - .def(py::init const&, bool>(), py::arg("components"), py::arg("canonicFlag")=false) //todo tests + .def(py::init const&, bool>(), py::arg("components"), py::arg("canonic_flag")=false) //todo tests .def("__str__", &getModelInfoPrinter) .def("get_observation", &SparsePomdp::getObservation, py::arg("state")) .def_property_readonly("observations", &SparsePomdp::getObservations) @@ -234,7 +234,6 @@ void define_sparse_model(py::module& m) { ; py::class_>(m, "SparseRewardModel", "Reward structure for sparse models") - //todo init? .def_property_readonly("has_state_rewards", &SparseRewardModel::hasStateRewards) .def_property_readonly("has_state_action_rewards", &SparseRewardModel::hasStateActionRewards) .def_property_readonly("has_transition_rewards", &SparseRewardModel::hasTransitionRewards) diff --git a/src/storage/modelcomponents.cpp b/src/storage/modelcomponents.cpp index 95a738c..2adbcb6 100644 --- a/src/storage/modelcomponents.cpp +++ b/src/storage/modelcomponents.cpp @@ -25,15 +25,16 @@ template using SparseModelComponents = storm::storage::spars void define_sparse_modelcomponents(py::module& m) { - py::class_>(m, "SparseModelComponents", "ModelComponents description..") //todo + // shared_ptr? todo + py::class_, std::shared_ptr>>(m, "SparseModelComponents", "ModelComponents description..") //todo .def(py::init const&, StateLabeling const&, std::unordered_map> const&, bool, boost::optional const&, boost::optional> const&>(), - py::arg("transition_matrix"), py::arg("state_labeling") = storm::models::sparse::StateLabeling(), + py::arg("transition_matrix") = SparseMatrix(), py::arg("state_labeling") = storm::models::sparse::StateLabeling(), py::arg("reward_models") = std::unordered_map>(), py::arg("rate_transitions") = false, py::arg("markovian_states") = boost::none, py::arg("player1_matrix") = boost::none) - //.def(py::init<>()) // for rvalue ? todo + // General components (for all model types) .def_readwrite("transition_matrix", &SparseModelComponents::transitionMatrix) diff --git a/tests/storage/test_modelcomponents.py b/tests/storage/test_modelcomponents.py new file mode 100644 index 0000000..2334016 --- /dev/null +++ b/tests/storage/test_modelcomponents.py @@ -0,0 +1,65 @@ +import stormpy +import stormpy.logic +from helpers.helper import get_example_path +import pytest + + +class TestSparseModel: + def test_init_default(self): + components = stormpy.SparseModelComponents() + + assert components.state_labeling.get_labels() == set() + assert components.reward_models == {} + assert components.transition_matrix.nr_rows == 0 + assert components.transition_matrix.nr_columns == 0 + assert components.markovian_states is None + assert components.player1_matrix is None + assert not components.rate_transitions + + # def test_init(self): + # todo Build simple transition matrix etc + # transition_matrix = + + + def test_dtmc_modelcomponents(self): + program = stormpy.parse_prism_program(get_example_path("dtmc", "die.pm")) + model = stormpy.build_model(program) + + components = stormpy.SparseModelComponents(transition_matrix=model.transition_matrix, + state_labeling=model.labeling, + reward_models=model.reward_models) + + dtmc = stormpy.storage.SparseDtmc(components) + + assert dtmc.model_type == stormpy.ModelType.DTMC + assert dtmc.initial_states == [0] + assert dtmc.nr_states == 13 + for state in dtmc.states: + assert len(state.actions) <= 1 + assert dtmc.labeling.get_labels() == {'init', 'deadlock', 'done', 'one', 'two', 'three', 'four', 'five', 'six'} + assert dtmc.nr_transitions == 20 + assert len(dtmc.reward_models) == 1 + assert not dtmc.reward_models["coin_flips"].has_state_rewards + assert dtmc.reward_models["coin_flips"].has_state_action_rewards + for reward in dtmc.reward_models["coin_flips"].state_action_rewards: + assert reward == 1.0 or reward == 0.0 + assert not dtmc.reward_models["coin_flips"].has_transition_rewards + assert not dtmc.supports_parameters + + + def test_pmdp_modelcomponents(self): + program = stormpy.parse_prism_program(get_example_path("pmdp", "two_dice.nm")) + model = stormpy.build_parametric_model(program) + + + def test_ma_modelcomponents(self): + program = stormpy.parse_prism_program(get_example_path("ma", "simple.ma"), False, True) + formulas = stormpy.parse_properties_for_prism_program("Pmax=? [ F<=2 s=2 ]", program) + model = stormpy.build_model(program, formulas) + #todo create mc + + assert model.nr_states == 4 + assert model.nr_transitions == 7 + assert model.model_type == stormpy.ModelType.MA + assert not model.supports_parameters + assert type(model) is stormpy.SparseMA