#include "storm/storage/dd/bisimulation/PartialQuotientExtractor.h"

#include <chrono>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "storm/storage/dd/DdManager.h"

#include "storm/models/symbolic/Mdp.h"
#include "storm/models/symbolic/StochasticTwoPlayerGame.h"
#include "storm/models/symbolic/StandardRewardModel.h"

#include "storm/settings/SettingsManager.h"

#include "storm/utility/macros.h"
#include "storm/exceptions/NotSupportedException.h"

namespace storm {
    namespace dd {
        namespace bisimulation {

            namespace {
                /*!
                 * Returns the union of the two given sets of meta variables.
                 *
                 * @param first The first set.
                 * @param second The second set.
                 * @return A new set containing every variable of either input.
                 */
                std::set<storm::expressions::Variable> joinVariableSets(std::set<storm::expressions::Variable> const& first, std::set<storm::expressions::Variable> const& second) {
                    std::set<storm::expressions::Variable> result(first);
                    result.insert(second.begin(), second.end());
                    return result;
                }
            }

            /*!
             * Creates an extractor for partial quotients of the given model.
             *
             * Reads the quotient format from the bisimulation settings.
             *
             * @param model The model whose partial quotients are to be extracted. Only a reference is stored.
             * @throws NotSupportedException if the configured quotient format is not DD-based.
             */
            template<storm::dd::DdType DdType, typename ValueType>
            PartialQuotientExtractor<DdType, ValueType>::PartialQuotientExtractor(storm::models::symbolic::Model<DdType, ValueType> const& model) : model(model) {
                auto const& settings = storm::settings::getModule<storm::settings::modules::BisimulationSettings>();
                this->quotientFormat = settings.getQuotientFormat();

                STORM_LOG_THROW(this->quotientFormat == storm::settings::modules::BisimulationSettings::QuotientFormat::Dd, storm::exceptions::NotSupportedException, "Only DD-based partial quotient extraction is currently supported.");
            }

            /*!
             * Extracts the partial quotient induced by the given partition.
             *
             * @param partition The partition of the model's states into blocks.
             * @param preservationInformation The labels, expressions and reward models to carry over to the quotient.
             * @return The (non-null) partial quotient model.
             * @throws NotSupportedException if the quotient format is not DD-based or no quotient could be extracted.
             */
            template<storm::dd::DdType DdType, typename ValueType>
            std::shared_ptr<storm::models::Model<ValueType>> PartialQuotientExtractor<DdType, ValueType>::extract(Partition<DdType, ValueType> const& partition, PreservationInformation<DdType, ValueType> const& preservationInformation) {
                auto start = std::chrono::high_resolution_clock::now();
                std::shared_ptr<storm::models::Model<ValueType>> result;

                STORM_LOG_THROW(this->quotientFormat == storm::settings::modules::BisimulationSettings::QuotientFormat::Dd, storm::exceptions::NotSupportedException, "Only DD-based partial quotient extraction is currently supported.");
                result = extractDdQuotient(partition, preservationInformation);

                auto end = std::chrono::high_resolution_clock::now();
                STORM_LOG_TRACE("Quotient extraction completed in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");

                STORM_LOG_THROW(result, storm::exceptions::NotSupportedException, "Quotient could not be extracted.");

                return result;
            }

            /*!
             * Builds the DD-based partial quotient.
             *
             * The resulting model's states are the blocks of the partition. The concrete states of the original model
             * are kept as nondeterminism variables. Consequently:
             * - a DTMC yields an MDP whose choices are the original row variables;
             * - an MDP yields a stochastic two-player game. The original row variables are passed as the first group
             *   of nondeterminism variables and the original nondeterminism variables as the second.
             *
             * @param partition The partition of the model's states into blocks.
             * @param preservationInformation The labels, expressions and reward models to carry over to the quotient.
             * @return The partial quotient.
             * @throws NotSupportedException if the model is neither a DTMC nor an MDP.
             */
            template<storm::dd::DdType DdType, typename ValueType>
            std::shared_ptr<storm::models::symbolic::Model<DdType, ValueType>> PartialQuotientExtractor<DdType, ValueType>::extractDdQuotient(Partition<DdType, ValueType> const& partition, PreservationInformation<DdType, ValueType> const& preservationInformation) {
                auto modelType = model.getType();
                STORM_LOG_THROW(modelType == storm::models::ModelType::Dtmc || modelType == storm::models::ModelType::Mdp, storm::exceptions::NotSupportedException, "Cannot extract partial quotient for this model type.");

                // Sanity checks.
                STORM_LOG_ASSERT(partition.getNumberOfStates() == model.getNumberOfStates(), "Mismatching partition size.");
                STORM_LOG_ASSERT(partition.getStates().renameVariables(model.getColumnVariables(), model.getRowVariables()) == model.getReachableStates(), "Mismatching partition.");

                // The block variable and its primed copy become the row and column variables of the quotient.
                std::set<storm::expressions::Variable> blockVariableSet = {partition.getBlockVariable()};
                std::set<storm::expressions::Variable> blockPrimeVariableSet = {partition.getPrimedBlockVariable()};
                std::vector<std::pair<storm::expressions::Variable, storm::expressions::Variable>> blockMetaVariablePairs = {std::make_pair(partition.getBlockVariable(), partition.getPrimedBlockVariable())};

                storm::dd::Bdd<DdType> partitionAsBdd = partition.storedAsBdd() ? partition.asBdd() : partition.asAdd().notZero();

                // Reachable states, initial states and labels of the quotient.
                auto start = std::chrono::high_resolution_clock::now();
                // The partition's states are expressed over the column variables; move them to the row variables so
                // the BDD relates (concrete state, block).
                partitionAsBdd = partitionAsBdd.renameVariables(model.getColumnVariables(), model.getRowVariables());
                storm::dd::Bdd<DdType> reachableStates = partitionAsBdd.existsAbstract(model.getRowVariables());
                storm::dd::Bdd<DdType> initialStates = (model.getInitialStates() && partitionAsBdd).existsAbstract(model.getRowVariables());

                std::map<std::string, storm::dd::Bdd<DdType>> preservedLabelBdds;
                for (auto const& label : preservationInformation.getLabels()) {
                    preservedLabelBdds.emplace(label, (model.getStates(label) && partitionAsBdd).existsAbstract(model.getRowVariables()));
                }
                for (auto const& expression : preservationInformation.getExpressions()) {
                    std::stringstream stream;
                    stream << expression;
                    std::string expressionAsString = stream.str();

                    if (preservedLabelBdds.find(expressionAsString) != preservedLabelBdds.end()) {
                        STORM_LOG_WARN("Duplicate label '" << expressionAsString << "', dropping second label definition.");
                    } else {
                        preservedLabelBdds.emplace(expressionAsString, (model.getStates(expression) && partitionAsBdd).existsAbstract(model.getRowVariables()));
                    }
                }
                auto end = std::chrono::high_resolution_clock::now();
                STORM_LOG_TRACE("Quotient labels extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");

                // Transition matrix of the quotient.
                start = std::chrono::high_resolution_clock::now();
                std::set<storm::expressions::Variable> blockAndRowVariables = joinVariableSets(blockVariableSet, model.getRowVariables());
                std::set<storm::expressions::Variable> blockPrimeAndColumnVariables = joinVariableSets(blockPrimeVariableSet, model.getColumnVariables());
                storm::dd::Add<DdType, ValueType> partitionAsAdd = partitionAsBdd.template toAdd<ValueType>();
                // Summing out the concrete successor (column variables) yields, for every concrete source state,
                // the probability of moving into each (primed) successor block.
                storm::dd::Add<DdType, ValueType> quotientTransitionMatrix = model.getTransitionMatrix().multiplyMatrix(partitionAsAdd.renameVariables(blockAndRowVariables, blockPrimeAndColumnVariables), model.getColumnVariables());
                // Attach the source block. The concrete source state is not abstracted, which makes the quotient partial.
                quotientTransitionMatrix = quotientTransitionMatrix * partitionAsAdd;
                end = std::chrono::high_resolution_clock::now();

                // Check quotient matrix for sanity.
                if (std::is_same<ValueType, storm::RationalNumber>::value) {
                    STORM_LOG_ASSERT(quotientTransitionMatrix.greater(storm::utility::one<ValueType>()).isZero(), "Illegal entries in quotient matrix.");
                } else {
                    // Allow a small tolerance for inexact number types.
                    STORM_LOG_ASSERT(quotientTransitionMatrix.greater(storm::utility::one<ValueType>() + storm::utility::convertNumber<ValueType>(1e-6)).isZero(), "Illegal entries in quotient matrix.");
                }
                STORM_LOG_TRACE("Quotient transition matrix extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");

                // A block is a deadlock if it has no outgoing quotient transition at all. We therefore abstract from
                // every variable except the (unprimed) block variable.
                storm::dd::Bdd<DdType> quotientTransitionMatrixBdd = quotientTransitionMatrix.notZero();
                std::set<storm::expressions::Variable> nonSourceVariables = joinVariableSets(blockPrimeVariableSet, model.getRowVariables());
                if (modelType == storm::models::ModelType::Mdp) {
                    // The quotient matrix of an MDP still carries the original nondeterminism variables. Without
                    // abstracting them, the deadlock set would be a relation over (block, choice) instead of a set of blocks.
                    nonSourceVariables = joinVariableSets(nonSourceVariables, model.getNondeterminismVariables());
                }
                storm::dd::Bdd<DdType> deadlockStates = !quotientTransitionMatrixBdd.existsAbstract(nonSourceVariables) && reachableStates;

                // Reward models of the quotient: state and state-action rewards are lifted to (concrete state, block).
                start = std::chrono::high_resolution_clock::now();
                std::unordered_map<std::string, storm::models::symbolic::StandardRewardModel<DdType, ValueType>> quotientRewardModels;
                for (auto const& rewardModelName : preservationInformation.getRewardModelNames()) {
                    auto const& rewardModel = model.getRewardModel(rewardModelName);
                    STORM_LOG_WARN_COND(!rewardModel.hasTransitionRewards(), "Transition rewards of reward model '" << rewardModelName << "' are not carried over to the partial quotient.");

                    boost::optional<storm::dd::Add<DdType, ValueType>> quotientStateRewards;
                    if (rewardModel.hasStateRewards()) {
                        quotientStateRewards = rewardModel.getStateRewardVector() * partitionAsAdd;
                    }

                    boost::optional<storm::dd::Add<DdType, ValueType>> quotientStateActionRewards;
                    if (rewardModel.hasStateActionRewards()) {
                        quotientStateActionRewards = rewardModel.getStateActionRewardVector() * partitionAsAdd;
                    }

                    quotientRewardModels.emplace(rewardModelName, storm::models::symbolic::StandardRewardModel<DdType, ValueType>(quotientStateRewards, quotientStateActionRewards, boost::none));
                }
                end = std::chrono::high_resolution_clock::now();
                STORM_LOG_TRACE("Reward models extracted in " << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms.");

                if (modelType == storm::models::ModelType::Dtmc) {
                    // The original row variables act as the nondeterminism variables of the quotient MDP.
                    return std::make_shared<storm::models::symbolic::Mdp<DdType, ValueType>>(model.getManager().asSharedPointer(), reachableStates, initialStates, deadlockStates, quotientTransitionMatrix, blockVariableSet, blockPrimeVariableSet, blockMetaVariablePairs, model.getRowVariables(), preservedLabelBdds, quotientRewardModels);
                }

                // The model type is Mdp (guaranteed by the guard at the top).
                std::set<storm::expressions::Variable> allNondeterminismVariables = joinVariableSets(model.getRowVariables(), model.getNondeterminismVariables());
                return std::make_shared<storm::models::symbolic::StochasticTwoPlayerGame<DdType, ValueType>>(model.getManager().asSharedPointer(), reachableStates, initialStates, deadlockStates, quotientTransitionMatrix, blockVariableSet, blockPrimeVariableSet, blockMetaVariablePairs, model.getRowVariables(), model.getNondeterminismVariables(), allNondeterminismVariables, preservedLabelBdds, quotientRewardModels);
            }

            template class PartialQuotientExtractor<storm::dd::DdType::CUDD, double>;
            template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, double>;

#ifdef STORM_HAVE_CARL
            template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalNumber>;
            template class PartialQuotientExtractor<storm::dd::DdType::Sylvan, storm::RationalFunction>;
#endif

        }
    }
}