#include "src/models/sparse/Mdp.h"

#include <algorithm>
#include <utility>

#include "src/exceptions/InvalidArgumentException.h"
#include "src/utility/constants.h"
#include "src/utility/vector.h"
#include "src/adapters/CarlAdapter.h"

#include "src/models/sparse/StandardRewardModel.h"

namespace storm {
    namespace models {
        namespace sparse {

            /*!
             * Constructs an MDP from copies of the given components.
             *
             * @param transitionMatrix The (row-grouped) transition matrix.
             * @param stateLabeling The state labeling.
             * @param rewardModels The reward models, keyed by name.
             * @param optionalChoiceLabeling An optional labeling of the choices.
             * @throws storm::exceptions::InvalidArgumentException (via STORM_LOG_THROW) if the stored transition
             *         matrix is not probabilistic according to SparseMatrix::isProbabilistic().
             */
            template
            Mdp::Mdp(storm::storage::SparseMatrix const& transitionMatrix,
                     storm::models::sparse::StateLabeling const& stateLabeling,
                     std::unordered_map const& rewardModels,
                     boost::optional> const& optionalChoiceLabeling)
            : NondeterministicModel(storm::models::ModelType::Mdp, transitionMatrix, stateLabeling, rewardModels, optionalChoiceLabeling) {
                // Validate the matrix the model actually stores (kept consistent with the move constructor below).
                STORM_LOG_THROW(this->getTransitionMatrix().isProbabilistic(), storm::exceptions::InvalidArgumentException, "The probability matrix is invalid.");
            }

            /*!
             * Constructs an MDP by moving the given components into the model.
             *
             * @param transitionMatrix The (row-grouped) transition matrix.
             * @param stateLabeling The state labeling.
             * @param rewardModels The reward models, keyed by name.
             * @param optionalChoiceLabeling An optional labeling of the choices.
             * @throws storm::exceptions::InvalidArgumentException (via STORM_LOG_THROW) if the stored transition
             *         matrix is not probabilistic according to SparseMatrix::isProbabilistic().
             */
            template
            Mdp::Mdp(storm::storage::SparseMatrix&& transitionMatrix,
                     storm::models::sparse::StateLabeling&& stateLabeling,
                     std::unordered_map&& rewardModels,
                     boost::optional>&& optionalChoiceLabeling)
            : NondeterministicModel(storm::models::ModelType::Mdp, std::move(transitionMatrix), std::move(stateLabeling), std::move(rewardModels), std::move(optionalChoiceLabeling)) {
                // The parameter 'transitionMatrix' was handed to the base class via std::move and may now be in a
                // moved-from state, so it must not be inspected here. Validate the matrix stored in the model instead.
                STORM_LOG_THROW(this->getTransitionMatrix().isProbabilistic(), storm::exceptions::InvalidArgumentException, "The probability matrix is invalid.");
            }

            /*!
             * Builds a new MDP that keeps only those choices whose labels are all contained in the given label set.
             * Any state left without a kept choice receives a single self-loop choice with an empty label set, so
             * every state still owns exactly one (non-empty) row group.
             *
             * @param enabledChoiceLabels The set of allowed labels.
             * @return The restricted MDP.
             * @throws storm::exceptions::InvalidArgumentException (via STORM_LOG_THROW) if this model has no
             *         choice labeling.
             */
            template
            Mdp Mdp::restrictChoiceLabels(LabelSet const& enabledChoiceLabels) const {
                STORM_LOG_THROW(this->hasChoiceLabeling(), storm::exceptions::InvalidArgumentException, "Restriction to label set is impossible for unlabeled model.");

                std::vector const& choiceLabeling = this->getChoiceLabeling();
                // Hoisted out of the loops below; 'auto const&' is safe whether these accessors return by
                // reference or by value (lifetime extension).
                auto const& transitionMatrix = this->getTransitionMatrix();
                auto const& rowGroupIndices = transitionMatrix.getRowGroupIndices();

                // NOTE(review): constructor arguments presumably are (rows, columns, entries, forceDimensions,
                // hasCustomRowGrouping) - TODO confirm against SparseMatrixBuilder.
                storm::storage::SparseMatrixBuilder transitionMatrixBuilder(0, transitionMatrix.getColumnCount(), 0, true, true);
                std::vector newChoiceLabeling;

                // Check for each choice of each state whether its labels are fully contained in the given label set.
                // std::includes requires both ranges to be sorted; this holds if LabelSet is an ordered set (TODO confirm).
                uint_fast64_t currentRow = 0;
                for (uint_fast64_t state = 0; state < this->getNumberOfStates(); ++state) {
                    bool stateHasValidChoice = false;
                    for (uint_fast64_t choice = rowGroupIndices[state]; choice < rowGroupIndices[state + 1]; ++choice) {
                        bool choiceValid = std::includes(enabledChoiceLabels.begin(), enabledChoiceLabels.end(), choiceLabeling[choice].begin(), choiceLabeling[choice].end());
                        if (!choiceValid) {
                            continue;
                        }

                        // The first kept choice of a state opens that state's row group.
                        if (!stateHasValidChoice) {
                            transitionMatrixBuilder.newRowGroup(currentRow);
                        }
                        stateHasValidChoice = true;

                        // Copy all entries of the kept choice into the new row.
                        for (auto const& entry : transitionMatrix.getRow(choice)) {
                            transitionMatrixBuilder.addNextValue(currentRow, entry.getColumn(), entry.getValue());
                        }
                        newChoiceLabeling.emplace_back(choiceLabeling[choice]);
                        ++currentRow;
                    }

                    // If no choice of the current state may be taken, insert a self-loop to the state instead.
                    if (!stateHasValidChoice) {
                        transitionMatrixBuilder.newRowGroup(currentRow);
                        transitionMatrixBuilder.addNextValue(currentRow, state, storm::utility::one());
                        newChoiceLabeling.emplace_back();
                        ++currentRow;
                    }
                }

                // NOTE(review): the reward models are copied unchanged although rows were dropped and self-loops
                // added above. If a reward model stores per-choice (state-action) rewards, its size will no longer
                // match the new matrix (compare restrictChoices, which calls restrictActions). TODO confirm.
                Mdp restrictedMdp(transitionMatrixBuilder.build(), storm::models::sparse::StateLabeling(this->getStateLabeling()), std::unordered_map(this->getRewardModels()), boost::optional>(std::move(newChoiceLabeling)));
                return restrictedMdp;
            }

            /*!
             * Builds a new MDP that keeps only the choices selected by the given bit vector. The transition matrix
             * rows, each reward model (via restrictActions) and, if present, the choice labeling are all restricted
             * with the same bit vector.
             *
             * @param enabledChoices Selects the choices (presumably one bit per row of the transition matrix) to keep.
             * @return The restricted MDP.
             */
            template
            Mdp Mdp::restrictChoices(storm::storage::BitVector const& enabledChoices) const {
                storm::storage::SparseMatrix restrictedTransitions = this->getTransitionMatrix().restrictRows(enabledChoices);

                std::unordered_map newRewardModels;
                for (auto const& rewardModel : this->getRewardModels()) {
                    newRewardModels.emplace(rewardModel.first, rewardModel.second.restrictActions(enabledChoices));
                }

                // The labeling stays empty if this model has none.
                boost::optional> newChoiceLabeling;
                if (this->hasChoiceLabeling()) {
                    newChoiceLabeling = storm::utility::vector::filterVector(this->getChoiceLabeling(), enabledChoices);
                }

                // The restricted matrix and reward models are local, so move them into the new model instead of
                // copying them; the state labeling is shared with this model and therefore copied.
                return Mdp(std::move(restrictedTransitions), storm::models::sparse::StateLabeling(this->getStateLabeling()), std::move(newRewardModels), std::move(newChoiceLabeling));
            }

            /*!
             * Returns the global choice index of a (state, action) pair. This is the index of the state's first
             * choice (taken from getNondeterministicChoiceIndices()) plus the action offset.
             * NOTE(review): the action offset is not checked against the number of choices of the state.
             *
             * @param stateactPair The state and the (state-local) action offset.
             * @return The global choice index.
             */
            template
            uint_least64_t Mdp::getChoiceIndex(storm::storage::StateActionPair const& stateactPair) const {
                return this->getNondeterministicChoiceIndices()[stateactPair.getState()] + stateactPair.getAction();
            }

            template class Mdp;
            template class Mdp;
            template class Mdp;
            template class Mdp>;
            template class Mdp;
        } // namespace sparse
    } // namespace models
} // namespace storm