diff --git a/src/models/Mdp.h b/src/models/Mdp.h index 9523132b0..0f97e304a 100644 --- a/src/models/Mdp.h +++ b/src/models/Mdp.h @@ -17,6 +17,7 @@ #include "src/storage/SparseMatrix.h" #include "src/settings/Settings.h" #include "src/models/AbstractNondeterministicModel.h" +#include "src/utility/set.h" namespace storm { @@ -125,25 +126,46 @@ public: return MDP; } -// /*! -// * Constructs an MDP by copying the given MDP and restricting the choices of each state to the ones whose label set -// * is contained in the given label set. -// * -// * @param originalModel The model to restrict. -// * @param enabledChoiceLabels A set of labels that determines which choices of the original model can be taken -// * and which ones need to be ignored. -// */ -// Mdp restrictChoiceLabels(Mdp const& originalModel, std::set const& enabledChoiceLabels) { -// // Only perform this operation if the given model has choice labels. -// if (!originalModel.hasChoiceLabels()) { -// throw storm::exceptions::InvalidArgumentException() << "Restriction to label set is impossible for unlabeled model."; -// } -// -// storm::storage::SparseMatrix transitionMatrix(); -// -// Mdp result; -// return result; -// } + /*! + * Constructs an MDP by copying the given MDP and restricting the choices of each state to the ones whose label set + * is contained in the given label set. + * + * @param originalModel The model to restrict. + * @param enabledChoiceLabels A set of labels that determines which choices of the original model can be taken + * and which ones need to be ignored. + */ + Mdp restrictChoiceLabels(Mdp const& originalModel, std::set const& enabledChoiceLabels) { + // Only perform this operation if the given model has choice labels. + if (!originalModel.hasChoiceLabels()) { + throw storm::exceptions::InvalidArgumentException() << "Restriction to label set is impossible for unlabeled model."; + } + + std::vector> const& choiceLabeling = this->getChoiceLabeling(); + + storm::storage::SparseMatrix transitionMatrix; + transitionMatrix.initialize(); + + // Check for each choice of each state, whether the choice labels are fully contained in the given label set. + for(uint_fast64_t state = 0; state < this->getNumberOfStates(); ++state) { + for (uint_fast64_t choice = this->getNondeterministicChoiceIndices()[state]; choice < this->getNondeterministicChoiceIndices()[state + 1]; ++choice) { + bool choiceValid = storm::utility::set::isSubsetOf(choiceLabeling[state], enabledChoiceLabels); + + // If the choice is valid, copy over all its elements. + if (choiceValid) { + typename storm::storage::SparseMatrix::Rows row = this->getTransitionMatrix().getRows(choice, choice); + for (typename storm::storage::SparseMatrix::ConstIterator rowIt = row.begin(), rowIte = row.end(); rowIt != rowIte; ++rowIt) { + transitionMatrix.insertNextValue(choice, rowIt.column(), rowIt.value(), true); + } + } else { + // If the choice may not be taken, we insert a self-loop to the state instead. + transitionMatrix.insertNextValue(choice, state, storm::utility::constGetOne(), true); + } + } + } + + Mdp restrictedMdp(std::move(transitionMatrix), storm::models::AtomicPropositionsLabeling(this->getStateLabeling()), std::vector(this->getNondeterministicChoiceIndices()), this->hasStateRewards() ? boost::optional>(this->getStateRewardVector()) : boost::optional>(), this->hasTransitionRewards() ? boost::optional>(this->getTransitionRewardMatrix()) : boost::optional>(), boost::optional>>(this->getChoiceLabeling())); + return restrictedMdp; + } /*! * Calculates a hash over all values contained in this Model. diff --git a/src/utility/set.h b/src/utility/set.h new file mode 100644 index 000000000..d264034eb --- /dev/null +++ b/src/utility/set.h @@ -0,0 +1,52 @@ +/* + * set.h + * + * Created on: 06.12.2012 + * Author: Christian Dehnert + */ + +#ifndef STORM_UTILITY_SET_H_ +#define STORM_UTILITY_SET_H_ + +#include + +#include "log4cplus/logger.h" +#include "log4cplus/loggingmacros.h" + +extern log4cplus::Logger logger; + +namespace storm { + namespace utility { + namespace set { + + template + bool isSubsetOf(std::set const& set1, std::set const& set2) { + // First, get a comparator object. + typename std::set::key_compare comparator = set1.key_comp(); + + for (typename std::set::const_iterator it1 = set1.begin(), it2 = set2.begin(); it1 != set1.end() && it2 != set2.end(); ++it1) { + // If the value in set1 is smaller than the value in set2, set1 is not a subset of set2. + if (comparator(*it1, *it2)) { + return false; + } + + // If the value in the second set is smaller, we need to move the iterator until the comparison is false. + while(comparator(*it2, *it1) && it2 != set2.end()) { + ++it2; + } + + // If we have reached the end of set2 or the element we found is actually larger than the one in set1 + // we know that the subset property is violated. + if (it2 == set2.end() || comparator(*it1, *it2)) { + return false; + } + + // Otherwise, we have found an equivalent element and can continue with the next one. + } + } + + } // namespace set + } // namespace utility +} // namespace storm + +#endif /* STORM_UTILITY_SET_H_ */