diff --git a/src/storm/abstraction/prism/CommandAbstractor.cpp b/src/storm/abstraction/prism/CommandAbstractor.cpp index 1af0e8d12..7198b2f2b 100644 --- a/src/storm/abstraction/prism/CommandAbstractor.cpp +++ b/src/storm/abstraction/prism/CommandAbstractor.cpp @@ -106,30 +106,130 @@ namespace storm { ++index; } - // Proceed by relating the blocks via assignments until nothing changes anymore. + // Merge all blocks that are related via the right-hand side of assignments. + for (auto const& update : command.get().getUpdates()) { + for (auto const& assignment : update.getAssignments()) { + std::set rhsVariables = assignment.getExpression().getVariables(); + + if (!rhsVariables.empty()) { + uint64_t blockToKeep = variableToLocalBlockIndex.at(*rhsVariables.begin()); + for (auto const& variable : rhsVariables) { + uint64_t block = variableToLocalBlockIndex.at(variable); + if (block != blockToKeep) { + for (auto const& blockIndex : relevantBlockPartition[block]) { + for (auto const& variable : localExpressionInformation.getVariableBlockWithIndex(blockIndex)) { + variableToLocalBlockIndex[variable] = blockToKeep; + } + } + relevantBlockPartition[blockToKeep].insert(relevantBlockPartition[block].begin(), relevantBlockPartition[block].end()); + relevantBlockPartition[block].clear(); + } + } + } + } + } + + // Proceed by relating the blocks via assignment-variables and the expressions of their assigned expressions. bool changed = false; do { + changed = false; for (auto const& update : command.get().getUpdates()) { for (auto const& assignment : update.getAssignments()) { std::set rhsVariables = assignment.getExpression().getVariables(); + if (!rhsVariables.empty()) { + storm::expressions::Variable const& representativeVariable = *rhsVariables.begin(); + uint64_t representativeBlock = variableToLocalBlockIndex.at(representativeVariable); + uint64_t assignmentVariableBlock = variableToLocalBlockIndex.at(assignment.getVariable()); + + // If the blocks are different, we merge them now + if (assignmentVariableBlock != representativeBlock) { + changed = true; + + for (auto const& blockIndex : relevantBlockPartition[assignmentVariableBlock]) { + for (auto const& variable : localExpressionInformation.getVariableBlockWithIndex(blockIndex)) { + variableToLocalBlockIndex[variable] = representativeBlock; + } + } + relevantBlockPartition[representativeBlock].insert(relevantBlockPartition[assignmentVariableBlock].begin(), relevantBlockPartition[assignmentVariableBlock].end()); + relevantBlockPartition[assignmentVariableBlock].clear(); + + } + } } } } while (changed); - // if the decomposition has size 1, use the plain technique from before - - // otherwise, enumerate the abstract guard so we do this only once - - // then enumerate the solutions for each of the blocks of the decomposition - - // multiply the results - - // multiply with the abstract guard - - // multiply with missing identities + // Now remove all blocks that are empty and obtain the partition. + std::vector> cleanedRelevantBlockPartition; + for (auto& element : relevantBlockPartition) { + if (!element.empty()) { + cleanedRelevantBlockPartition.emplace_back(std::move(element)); + } + } + relevantBlockPartition = std::move(cleanedRelevantBlockPartition); - // cache and return result + // if the decomposition has size 1, use the plain technique from before + if (relevantBlockPartition.size() == 1) { + STORM_LOG_TRACE("Relevant block partition size is one, falling back to regular computation."); + recomputeCachedBdd(); + } else { + // otherwise, enumerate the abstract guard so we do this only once + std::set relatedGuardPredicates = localExpressionInformation.getRelatedExpressions(command.get().getGuardExpression().getVariables()); + std::vector guardDecisionVariables; + std::vector> guardVariablesAndPredicates; + for (auto const& element : relevantPredicatesAndVariables.first) { + if (relatedGuardPredicates.find(element.second) != relatedGuardPredicates.end()) { + guardDecisionVariables.push_back(element.first); + guardVariablesAndPredicates.push_back(element); + } + } + uint64_t numberOfSolutions = 0; + abstractGuard = this->getAbstractionInformation().getDdManager().getBddZero(); + smtSolver->allSat(decisionVariables, [this,&guardVariablesAndPredicates,&numberOfSolutions] (storm::solver::SmtSolver::ModelReference const& model) { + abstractGuard |= getSourceStateBdd(model, guardVariablesAndPredicates); + ++numberOfSolutions; + return true; + }); + STORM_LOG_TRACE("Enumerated " << numberOfSolutions << " for abstract guard."); + + // then enumerate the solutions for each of the blocks of the decomposition + for (auto const& block : relevantBlockPartition) { + std::set relevantPredicates; + for (auto const& innerBlock : block) { + relevantPredicates.insert(localExpressionInformation.getExpressionBlock(innerBlock).begin(), localExpressionInformation.getExpressionBlock(innerBlock).end()); + } + + std::vector decisionVariables; + std::vector>> variablesAndPredicates; + for (uint64_t updateIndex = 0; updateIndex < command.get().getNumberOfUpdates(); ++updateIndex) { + variablesAndPredicates.emplace_back(); + for (auto const& element : relevantPredicatesAndVariables.second[updateIndex]) { + if (relevantPredicates.find(element.second) != relevantPredicates.end()) { + decisionVariables.push_back(element.first); + variablesAndPredicates.back().push_back(element); + } + } + } + + std::unordered_map, std::vector>> sourceToDistributionsMap; + numberOfSolutions = 0; + smtSolver->allSat(decisionVariables, [&sourceToDistributionsMap,this,&numberOfSolutions] (storm::solver::SmtSolver::ModelReference const& model) { + sourceToDistributionsMap[getSourceStateBdd(model, relevantPredicatesAndVariables.first)].push_back(getDistributionBdd(model, relevantPredicatesAndVariables.second)); + ++numberOfSolutions; + return true; + }); + } + + // multiply the results + + // multiply with the abstract guard + + // multiply with missing identities + + // cache and return result + + } } template @@ -141,7 +241,7 @@ namespace storm { std::unordered_map, std::vector>> sourceToDistributionsMap; uint64_t numberOfSolutions = 0; smtSolver->allSat(decisionVariables, [&sourceToDistributionsMap,this,&numberOfSolutions] (storm::solver::SmtSolver::ModelReference const& model) { - sourceToDistributionsMap[getSourceStateBdd(model)].push_back(getDistributionBdd(model)); + sourceToDistributionsMap[getSourceStateBdd(model, relevantPredicatesAndVariables.first)].push_back(getDistributionBdd(model, relevantPredicatesAndVariables.second)); ++numberOfSolutions; return true; }); @@ -283,9 +383,9 @@ namespace storm { } template - storm::dd::Bdd CommandAbstractor::getSourceStateBdd(storm::solver::SmtSolver::ModelReference const& model) const { + storm::dd::Bdd CommandAbstractor::getSourceStateBdd(storm::solver::SmtSolver::ModelReference const& model, std::vector> const& variablePredicates) const { storm::dd::Bdd result = this->getAbstractionInformation().getDdManager().getBddOne(); - for (auto const& variableIndexPair : relevantPredicatesAndVariables.first) { + for (auto const& variableIndexPair : variablePredicates) { if (model.getBooleanValue(variableIndexPair.first)) { result &= this->getAbstractionInformation().encodePredicateAsSource(variableIndexPair.second); } else { @@ -298,14 +398,14 @@ namespace storm { } template - storm::dd::Bdd CommandAbstractor::getDistributionBdd(storm::solver::SmtSolver::ModelReference const& model) const { + storm::dd::Bdd CommandAbstractor::getDistributionBdd(storm::solver::SmtSolver::ModelReference const& model, std::vector>> const& variablePredicates) const { storm::dd::Bdd result = this->getAbstractionInformation().getDdManager().getBddZero(); for (uint_fast64_t updateIndex = 0; updateIndex < command.get().getNumberOfUpdates(); ++updateIndex) { storm::dd::Bdd updateBdd = this->getAbstractionInformation().getDdManager().getBddOne(); // Translate block variables for this update into a successor block. - for (auto const& variableIndexPair : relevantPredicatesAndVariables.second[updateIndex]) { + for (auto const& variableIndexPair : variablePredicates[updateIndex]) { if (model.getBooleanValue(variableIndexPair.first)) { updateBdd &= this->getAbstractionInformation().encodePredicateAsSuccessor(variableIndexPair.second); } else { diff --git a/src/storm/abstraction/prism/CommandAbstractor.h b/src/storm/abstraction/prism/CommandAbstractor.h index c9eafce8f..9046fd053 100644 --- a/src/storm/abstraction/prism/CommandAbstractor.h +++ b/src/storm/abstraction/prism/CommandAbstractor.h @@ -147,7 +147,7 @@ namespace storm { * @param model The model to translate. * @return The source state encoded as a DD. */ - storm::dd::Bdd getSourceStateBdd(storm::solver::SmtSolver::ModelReference const& model) const; + storm::dd::Bdd getSourceStateBdd(storm::solver::SmtSolver::ModelReference const& model, std::vector> const& variablePredicates) const; /*! * Translates the given model to a distribution over successor states. @@ -155,7 +155,7 @@ namespace storm { * @param model The model to translate. * @return The source state encoded as a DD. */ - storm::dd::Bdd getDistributionBdd(storm::solver::SmtSolver::ModelReference const& model) const; + storm::dd::Bdd getDistributionBdd(storm::solver::SmtSolver::ModelReference const& model, std::vector>> const& variablePredicates) const; /*! * Recomputes the cached BDD. This needs to be triggered if any relevant predicates change.