diff --git a/src/mod_storage.cpp b/src/mod_storage.cpp index e91907c..003646b 100644 --- a/src/mod_storage.cpp +++ b/src/mod_storage.cpp @@ -15,5 +15,6 @@ PYBIND11_PLUGIN(storage) { define_bitvector(m); define_model(m); define_sparse_matrix(m); + define_model_instantiator(m); return m.ptr(); } diff --git a/src/storage/model.cpp b/src/storage/model.cpp index c8a8945..f411bd2 100644 --- a/src/storage/model.cpp +++ b/src/storage/model.cpp @@ -22,6 +22,14 @@ storm::storage::SparseMatrix& getTransitionMatrix(storm::models::spar return model.getTransitionMatrix(); } +std::set probabilityVariables(storm::models::sparse::Model const& model) { + return storm::models::sparse::getProbabilityParameters(model); +} + +std::set rewardVariables(storm::models::sparse::Model const& model) { + return storm::models::sparse::getRewardParameters(model); +} + // Define python bindings void define_model(py::module& m) { @@ -64,7 +72,8 @@ void define_model(py::module& m) { ; py::class_, std::shared_ptr>> modelRatFunc(m, "SparseParametricModel", "A probabilistic model where transitions are represented by rational functions and saved in a sparse matrix", modelBase); - modelRatFunc.def("collect_probability_parameters", &storm::models::sparse::getProbabilityParameters, "Collect parameters") + modelRatFunc.def("collect_probability_parameters", &probabilityVariables, "Collect parameters") + .def("collect_reward_parameters", &rewardVariables, "Collect reward parameters") .def_property_readonly("labels", [](storm::models::sparse::Model& model) { return model.getStateLabeling().getLabels(); }, "Labels")