diff --git a/src/storm/models/sparse/Model.cpp b/src/storm/models/sparse/Model.cpp index d115e49f7..b00ef9b6f 100644 --- a/src/storm/models/sparse/Model.cpp +++ b/src/storm/models/sparse/Model.cpp @@ -360,6 +360,17 @@ namespace storm { std::set getProbabilityParameters(Model const& model) { return storm::storage::getVariables(model.getTransitionMatrix()); } + + + + std::set getRewardParameters(Model const& model) { + std::set result; + for(auto rewModel : model.getRewardModels()) { + std::set tmp = getRewardModelParameters(rewModel.second); + result.insert(tmp.begin(), tmp.end()); + } + return result; + } #endif template class Model; diff --git a/src/storm/models/sparse/Model.h b/src/storm/models/sparse/Model.h index 51f43eb57..583ffbb93 100644 --- a/src/storm/models/sparse/Model.h +++ b/src/storm/models/sparse/Model.h @@ -374,6 +374,7 @@ namespace storm { #ifdef STORM_HAVE_CARL std::set getProbabilityParameters(Model const& model); + std::set getRewardParameters(Model const& model); #endif } // namespace sparse } // namespace models diff --git a/src/storm/models/sparse/StandardRewardModel.cpp b/src/storm/models/sparse/StandardRewardModel.cpp index c82041437..8b7e8c349 100644 --- a/src/storm/models/sparse/StandardRewardModel.cpp +++ b/src/storm/models/sparse/StandardRewardModel.cpp @@ -298,7 +298,24 @@ namespace storm { << std::noboolalpha; return out; } - + + std::set getRewardModelParameters(StandardRewardModel const& rewModel) { + std::set vars; + if (rewModel.hasTransitionRewards()) { + vars = storm::storage::getVariables(rewModel.getTransitionRewardMatrix()); + } + if (rewModel.hasStateActionRewards()) { + std::set tmp = storm::utility::vector::getVariables(rewModel.getStateActionRewardVector()); + vars.insert(tmp.begin(), tmp.end()); + } + if (rewModel.hasStateRewards()) { + std::set tmp = storm::utility::vector::getVariables(rewModel.getStateRewardVector()); + vars.insert(tmp.begin(), tmp.end()); + } + return vars; + + } + // Explicitly instantiate the class. template std::vector StandardRewardModel::getTotalRewardVector(storm::storage::SparseMatrix const& transitionMatrix) const; template std::vector StandardRewardModel::getTotalRewardVector(uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& filter) const; diff --git a/src/storm/models/sparse/StandardRewardModel.h b/src/storm/models/sparse/StandardRewardModel.h index 6651b789b..8fcfa48ef 100644 --- a/src/storm/models/sparse/StandardRewardModel.h +++ b/src/storm/models/sparse/StandardRewardModel.h @@ -1,11 +1,11 @@ -#ifndef STORM_MODELS_SPARSE_STANDARDREWARDMODEL_H_ -#define STORM_MODELS_SPARSE_STANDARDREWARDMODEL_H_ +#pragma once #include #include #include "storm/storage/SparseMatrix.h" #include "storm/utility/OsDetection.h" +#include "storm/adapters/CarlAdapter.h" namespace storm { namespace models { @@ -296,8 +296,8 @@ namespace storm { template std::ostream& operator<<(std::ostream& out, StandardRewardModel const& rewardModel); + + std::set getRewardModelParameters(StandardRewardModel const& rewModel); } } } - -#endif /* STORM_MODELS_SPARSE_STANDARDREWARDMODEL_H_ */ diff --git a/src/storm/utility/vector.h b/src/storm/utility/vector.h index f0cf1330e..dc22fa310 100644 --- a/src/storm/utility/vector.h +++ b/src/storm/utility/vector.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "storm/storage/BitVector.h" #include "storm/utility/constants.h" @@ -822,6 +823,14 @@ namespace storm { return std::any_of(v.begin(), v.end(), [](T value){return !storm::utility::isZero(value);}); } + inline std::set getVariables(std::vector const& vector) { + std::set result; + for(auto const& entry : vector) { + entry.gatherVariables(result); + } + return result; + } + /*! * Output vector as string. *