diff --git a/src/storm/abstraction/AbstractionInformation.cpp b/src/storm/abstraction/AbstractionInformation.cpp index 52a409b28..9c2250dc5 100644 --- a/src/storm/abstraction/AbstractionInformation.cpp +++ b/src/storm/abstraction/AbstractionInformation.cpp @@ -67,6 +67,7 @@ namespace storm { allPredicateIdentities &= predicateIdentities.back(); sourceVariables.insert(newMetaVariable.first); successorVariables.insert(newMetaVariable.second); + orderedSourceVariables.push_back(newMetaVariable.first); orderedSuccessorVariables.push_back(newMetaVariable.second); ddVariableIndexToPredicateIndexMap[predicateIdentities.back().getIndex()] = predicateIndex; return predicateIndex; @@ -121,6 +122,22 @@ namespace storm { return predicates; } + template + std::vector AbstractionInformation::getPredicates(storm::storage::BitVector const& predicateValuation) const { + STORM_LOG_ASSERT(predicateValuation.size() == this->getNumberOfPredicates(), "Size of predicate valuation does not match number of predicates."); + + std::vector result; + for (uint64_t index = 0; index < this->getNumberOfPredicates(); ++index) { + if (predicateValuation[index]) { + result.push_back(this->getPredicateByIndex(index)); + } else { + result.push_back(!this->getPredicateByIndex(index)); + } + } + + return result; + } + template storm::expressions::Expression const& AbstractionInformation::getPredicateByIndex(uint_fast64_t index) const { return predicates[index]; @@ -268,6 +285,11 @@ namespace storm { return orderedSuccessorVariables; } + template + std::vector const& AbstractionInformation::getOrderedSourceVariables() const { + return orderedSourceVariables; + } + template storm::dd::Bdd const& AbstractionInformation::getAllPredicateIdentities() const { return allPredicateIdentities; @@ -404,6 +426,25 @@ namespace storm { return result; } + template + storm::storage::BitVector AbstractionInformation::decodeState(storm::dd::Bdd const& state) const { + STORM_LOG_ASSERT(state.getNonZeroCount() == 1, "Wrong number of non-zero entries."); + + storm::storage::BitVector statePredicates(this->getNumberOfPredicates()); + + storm::dd::Add add = state.template toAdd(); + auto it = add.begin(); + auto stateValuePair = *it; + for (uint_fast64_t index = 0; index < this->getOrderedSourceVariables().size(); ++index) { + auto const& successorVariable = this->getOrderedSourceVariables()[index]; + if (stateValuePair.first.getBooleanValue(successorVariable)) { + statePredicates.set(index); + } + } + + return statePredicates; + } + template std::map AbstractionInformation::decodeChoiceToUpdateSuccessorMapping(storm::dd::Bdd const& choice) const { std::map result; @@ -412,28 +453,12 @@ namespace storm { for (auto const& successorValuePair : lowerChoiceAsAdd) { uint_fast64_t updateIndex = this->decodeAux(successorValuePair.first, 0, this->getAuxVariableCount()); -#ifdef LOCAL_DEBUG - std::cout << "update idx: " << updateIndex << std::endl; -#endif storm::storage::BitVector successor(this->getNumberOfPredicates()); for (uint_fast64_t index = 0; index < this->getOrderedSuccessorVariables().size(); ++index) { auto const& successorVariable = this->getOrderedSuccessorVariables()[index]; -#ifdef LOCAL_DEBUG - std::cout << successorVariable.getName() << " has value"; -#endif if (successorValuePair.first.getBooleanValue(successorVariable)) { successor.set(index); -#ifdef LOCAL_DEBUG - std::cout << " true"; -#endif - } else { -#ifdef LOCAL_DEBUG - std::cout << " false"; -#endif } -#ifdef LOCAL_DEBUG - std::cout << std::endl; -#endif } result[updateIndex] = successor; @@ -442,40 +467,26 @@ namespace storm { } template - std::pair AbstractionInformation::decodeStateAndUpdate(storm::dd::Bdd const& state) const { - storm::storage::BitVector successor(this->getNumberOfPredicates()); + std::tuple AbstractionInformation::decodeStatePlayer1ChoiceAndUpdate(storm::dd::Bdd const& stateChoiceAndUpdate) const { + stateChoiceAndUpdate.template toAdd().exportToDot("out.dot"); + STORM_LOG_ASSERT(stateChoiceAndUpdate.getNonZeroCount() == 1, "Wrong number of non-zero entries."); - storm::dd::Add stateAsAdd = state.template toAdd(); - uint_fast64_t updateIndex = 0; - for (auto const& stateValuePair : stateAsAdd) { - uint_fast64_t updateIndex = this->decodeAux(stateValuePair.first, 0, this->getAuxVariableCount()); - -#ifdef LOCAL_DEBUG - std::cout << "update idx: " << updateIndex << std::endl; -#endif - storm::storage::BitVector successor(this->getNumberOfPredicates()); - for (uint_fast64_t index = 0; index < this->getOrderedSuccessorVariables().size(); ++index) { - auto const& successorVariable = this->getOrderedSuccessorVariables()[index]; -#ifdef LOCAL_DEBUG - std::cout << successorVariable.getName() << " has value"; -#endif - if (stateValuePair.first.getBooleanValue(successorVariable)) { - successor.set(index); -#ifdef LOCAL_DEBUG - std::cout << " true"; -#endif - } else { -#ifdef LOCAL_DEBUG - std::cout << " false"; -#endif - } -#ifdef LOCAL_DEBUG - std::cout << std::endl; -#endif + storm::storage::BitVector statePredicates(this->getNumberOfPredicates()); + + storm::dd::Add add = stateChoiceAndUpdate.template toAdd(); + auto it = add.begin(); + auto stateValuePair = *it; + uint64_t choiceIndex = this->decodePlayer1Choice(stateValuePair.first, this->getPlayer1VariableCount()); + uint64_t updateIndex = this->decodeAux(stateValuePair.first, 0, this->getAuxVariableCount()); + for (uint_fast64_t index = 0; index < this->getOrderedSourceVariables().size(); ++index) { + auto const& successorVariable = this->getOrderedSourceVariables()[index]; + + if (stateValuePair.first.getBooleanValue(successorVariable)) { + statePredicates.set(index); } } - return std::make_pair(successors, updateIndex); + return std::make_tuple(statePredicates, choiceIndex, updateIndex); } template class AbstractionInformation; diff --git a/src/storm/abstraction/AbstractionInformation.h b/src/storm/abstraction/AbstractionInformation.h index 6bfb728e1..4ba2e4a70 100644 --- a/src/storm/abstraction/AbstractionInformation.h +++ b/src/storm/abstraction/AbstractionInformation.h @@ -141,6 +141,11 @@ namespace storm { */ std::vector const& getPredicates() const; + /*! + * Retrieves a list of expression that corresponds to the given predicate valuation. + */ + std::vector getPredicates(storm::storage::BitVector const& predicateValuation) const; + /*! * Retrieves the predicate with the given index. * @@ -341,6 +346,13 @@ namespace storm { */ std::set const& getSuccessorVariables() const; + /*! + * Retrieves the ordered collection of source meta variables. + * + * @return All source meta variables. + */ + std::vector const& getOrderedSourceVariables() const; + /*! * Retrieves the ordered collection of successor meta variables. * @@ -434,16 +446,22 @@ namespace storm { */ std::vector> declareNewVariables(std::vector> const& oldPredicates, std::set const& newPredicates) const; + /*! + * Decodes the given state (given as a BDD over the source variables) into a a bit vector indicating the + * truth values of the predicates in the state. + */ + storm::storage::BitVector decodeState(storm::dd::Bdd const& state) const; + /*! * Decodes the choice in the form of a BDD over the destination variables. */ std::map decodeChoiceToUpdateSuccessorMapping(storm::dd::Bdd const& choice) const; /*! - * Decodes the given state-and-update BDD (state as source variables) into a bit vector indicating the truth values of - * the predicates in the state and the update index. + * Decodes the given BDD (over source, player 1 and aux variables) into a bit vector indicating the truth + * values of the predicates in the state and the choice/update indices. */ - std::pair decodeStateAndUpdate(storm::dd::Bdd const& stateAndUpdate) const; + std::tuple decodeStatePlayer1ChoiceAndUpdate(storm::dd::Bdd const& stateChoiceAndUpdate) const; private: /*! @@ -504,7 +522,10 @@ namespace storm { /// The set of all successor variables. std::set successorVariables; - + + /// An ordered collection of the source variables. + std::vector orderedSourceVariables; + /// An ordered collection of the successor variables. std::vector orderedSuccessorVariables; diff --git a/src/storm/abstraction/MenuGameAbstractor.h b/src/storm/abstraction/MenuGameAbstractor.h index c4657e4fb..ee5e383a4 100644 --- a/src/storm/abstraction/MenuGameAbstractor.h +++ b/src/storm/abstraction/MenuGameAbstractor.h @@ -27,6 +27,7 @@ namespace storm { virtual storm::expressions::Expression const& getGuard(uint64_t player1Choice) const = 0; virtual std::pair getPlayer1ChoiceRange() const = 0; virtual std::map getVariableUpdates(uint64_t player1Choice, uint64_t auxiliaryChoice) const = 0; + virtual storm::expressions::Expression getInitialExpression() const = 0; /// Methods to refine the abstraction. virtual void refine(RefinementCommand const& command) = 0; diff --git a/src/storm/abstraction/MenuGameRefiner.cpp b/src/storm/abstraction/MenuGameRefiner.cpp index 90ddd3f77..ddb5bd1d4 100644 --- a/src/storm/abstraction/MenuGameRefiner.cpp +++ b/src/storm/abstraction/MenuGameRefiner.cpp @@ -5,6 +5,9 @@ #include "storm/storage/dd/DdManager.h" #include "storm/utility/dd.h" +#include "storm/utility/solver.h" + +#include "storm/solver/MathsatSmtSolver.h" #include "storm/settings/SettingsManager.h" #include "storm/settings/modules/AbstractionSettings.h" @@ -24,6 +27,18 @@ namespace storm { return predicates; } + template + struct PivotStateCandidatesResult { + storm::dd::Bdd reachableTransitionsMin; + storm::dd::Bdd reachableTransitionsMax; + storm::dd::Bdd pivotStates; + }; + + template + PivotStateResult::PivotStateResult(storm::dd::Bdd const& pivotState, storm::OptimizationDirection fromDirection) : pivotState(pivotState), fromDirection(fromDirection) { + // Intentionally left empty. + } + template MenuGameRefiner::MenuGameRefiner(MenuGameAbstractor& abstractor, std::unique_ptr&& smtSolver) : abstractor(abstractor), splitPredicates(storm::settings::getModule().isSplitPredicatesSet()), splitGuards(storm::settings::getModule().isSplitGuardsSet()), splitter(), equivalenceChecker(std::move(smtSolver)) { @@ -47,17 +62,17 @@ namespace storm { storm::dd::Bdd getMostProbablePathSpanningTree(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& targetState, storm::dd::Bdd const& transitionFilter) { storm::dd::Add maxProbabilities = game.getInitialStates().template toAdd(); - storm::dd::Add border = game.getInitialStates().template toAdd(); + storm::dd::Bdd border = game.getInitialStates(); storm::dd::Bdd spanningTree = game.getManager().getBddZero(); - storm::dd::Add transitionMatrix = ((transitionFilter && game.getExtendedTransitionMatrix().maxAbstractRepresentative(game.getProbabilisticBranchingVariables())).template toAdd() * game.getExtendedTransitionMatrix()); - transitionMatrix = transitionMatrix.sumAbstract(game.getNondeterminismVariables()); + storm::dd::Add transitionMatrix = ((transitionFilter && game.getExtendedTransitionMatrix().maxAbstractRepresentative(game.getProbabilisticBranchingVariables())).template toAdd() * game.getExtendedTransitionMatrix()).sumAbstract(game.getPlayer2Variables()); std::set variablesToAbstract(game.getRowVariables()); + variablesToAbstract.insert(game.getPlayer1Variables().begin(), game.getPlayer1Variables().end()); variablesToAbstract.insert(game.getProbabilisticBranchingVariables().begin(), game.getProbabilisticBranchingVariables().end()); while (!border.isZero() && (border && targetState).isZero()) { // Determine the new maximal probabilities to all states. - storm::dd::Add tmp = border * transitionMatrix * maxProbabilities; + storm::dd::Add tmp = border.template toAdd() * transitionMatrix * maxProbabilities; storm::dd::Bdd newMaxProbabilityChoices = tmp.maxAbstractRepresentative(variablesToAbstract); storm::dd::Add newMaxProbabilities = tmp.maxAbstract(variablesToAbstract).swapVariables(game.getRowColumnMetaVariablePairs()); @@ -72,14 +87,14 @@ namespace storm { spanningTree |= updateStates.swapVariables(game.getRowColumnMetaVariablePairs()) && newMaxProbabilityChoices; // Continue exploration from states that have been updated. - border = updateStates.template toAdd(); + border = updateStates; } return spanningTree; } template - std::pair, storm::OptimizationDirection> pickPivotState(storm::dd::Bdd const& initialStates, storm::dd::Bdd const& transitionsMin, storm::dd::Bdd const& transitionsMax, std::set const& rowVariables, std::set const& columnVariables, storm::dd::Bdd const& pivotStates, boost::optional> const& quantitativeResult = boost::none) { + PivotStateResult pickPivotState(storm::dd::Bdd const& initialStates, storm::dd::Bdd const& transitionsMin, storm::dd::Bdd const& transitionsMax, std::set const& rowVariables, std::set const& columnVariables, storm::dd::Bdd const& pivotStates, boost::optional> const& quantitativeResult = boost::none) { // Set up used variables. storm::dd::Bdd frontierMin = initialStates; @@ -91,14 +106,14 @@ namespace storm { bool foundPivotState = !frontierPivotStates.isZero(); if (foundPivotState) { STORM_LOG_TRACE("Picked pivot state from " << frontierPivotStates.getNonZeroCount() << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); - return std::make_pair(frontierPivotStates.existsAbstractRepresentative(rowVariables), storm::OptimizationDirection::Minimize); + return PivotStateResult(frontierPivotStates.existsAbstractRepresentative(rowVariables), storm::OptimizationDirection::Minimize); } else { - // Otherwise, we perform a simulatenous BFS in the sense that we make one step in both the min and max // transitions and check for pivot states we encounter. while (!foundPivotState) { frontierMin = frontierMin.relationalProduct(transitionsMin, rowVariables, columnVariables); frontierMax = frontierMax.relationalProduct(transitionsMax, rowVariables, columnVariables); + ++level; storm::dd::Bdd frontierMinPivotStates = frontierMin && pivotStates; storm::dd::Bdd frontierMaxPivotStates = frontierMax && pivotStates; @@ -122,7 +137,7 @@ namespace storm { } STORM_LOG_TRACE("Picked pivot state with difference " << diffValue << " from " << numberOfPivotStateCandidatesOnLevel << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); - return std::make_pair(direction == storm::OptimizationDirection::Minimize ? diffMin.maxAbstractRepresentative(rowVariables) : diffMax.maxAbstractRepresentative(rowVariables), direction); + return PivotStateResult(direction == storm::OptimizationDirection::Minimize ? diffMin.maxAbstractRepresentative(rowVariables) : diffMax.maxAbstractRepresentative(rowVariables), direction); } else { STORM_LOG_TRACE("Picked pivot state from " << numberOfPivotStateCandidatesOnLevel << " candidates on level " << level << ", " << pivotStates.getNonZeroCount() << " candidates in total."); @@ -133,15 +148,14 @@ namespace storm { direction = storm::OptimizationDirection::Maximize; } - return std::make_pair(direction == storm::OptimizationDirection::Minimize ? frontierMinPivotStates.existsAbstractRepresentative(rowVariables) : frontierMaxPivotStates.existsAbstractRepresentative(rowVariables), direction); + return PivotStateResult(direction == storm::OptimizationDirection::Minimize ? frontierMinPivotStates.existsAbstractRepresentative(rowVariables) : frontierMaxPivotStates.existsAbstractRepresentative(rowVariables), direction); } } - ++level; } } STORM_LOG_ASSERT(false, "This point must not be reached, because then no pivot state could be found."); - return std::make_pair(storm::dd::Bdd(), storm::OptimizationDirection::Minimize); + return PivotStateResult(storm::dd::Bdd(), storm::OptimizationDirection::Minimize); } template @@ -206,17 +220,10 @@ namespace storm { return RefinementPredicates(fromGuard ? RefinementPredicates::Source::Guard : RefinementPredicates::Source::WeakestPrecondition, {newPredicate}); } - template - struct PivotStateResult { - storm::dd::Bdd reachableTransitionsMin; - storm::dd::Bdd reachableTransitionsMax; - storm::dd::Bdd pivotStates; - }; - template - PivotStateResult computePivotStates(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& transitionMatrixBdd, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) { + PivotStateCandidatesResult computePivotStates(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& transitionMatrixBdd, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) { - PivotStateResult result; + PivotStateCandidatesResult result; // Build the fragment of transitions that is reachable by either the min or the max strategies. result.reachableTransitionsMin = (transitionMatrixBdd && minPlayer1Strategy && minPlayer2Strategy).existsAbstract(game.getNondeterminismVariables()); @@ -279,18 +286,127 @@ namespace storm { } template - storm::expressions::Expression MenuGameRefiner::buildTraceFormula(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& spanningTree, storm::dd::Bdd const& pivotState) const { + std::vector> MenuGameRefiner::buildTrace(storm::expressions::ExpressionManager& expressionManager, storm::abstraction::MenuGame const& game, storm::dd::Bdd const& spanningTree, storm::dd::Bdd const& pivotState) const { + std::vector> result; + + // Prepare some variables. AbstractionInformation const& abstractionInformation = abstractor.get().getAbstractionInformation(); + std::set variablesToAbstract(game.getColumnVariables()); + variablesToAbstract.insert(game.getPlayer1Variables().begin(), game.getPlayer1Variables().end()); + variablesToAbstract.insert(game.getProbabilisticBranchingVariables().begin(), game.getProbabilisticBranchingVariables().end()); - storm::dd::Bdd currentState = pivotState; + std::map oldToNewVariables; + for (auto const& variable : abstractionInformation.getExpressionManager().getVariables()) { + oldToNewVariables[variable] = expressionManager.getVariable(variable.getName()); + } + std::map lastSubstitution; + for (auto const& variable : oldToNewVariables) { + lastSubstitution[variable.second] = variable.second; + } + std::map stepVariableToCopiedVariableMap; + // Start with the target state part of the trace. + storm::storage::BitVector decodedTargetState = abstractionInformation.decodeState(pivotState); + result.emplace_back(abstractionInformation.getPredicates(decodedTargetState)); + for (auto& predicate : result.back()) { + predicate = predicate.changeManager(expressionManager); + } + + pivotState.template toAdd().exportToDot("pivot.dot"); + + // Perform a backward search for an initial state. + storm::dd::Bdd currentState = pivotState; + uint64_t cnt = 0; while ((currentState && game.getInitialStates()).isZero()) { storm::dd::Bdd predecessorTransition = currentState.swapVariables(game.getRowColumnMetaVariablePairs()) && spanningTree; + std::tuple decodedPredecessor = abstractionInformation.decodeStatePlayer1ChoiceAndUpdate(predecessorTransition); + std::cout << "got predecessor " << std::get<0>(decodedPredecessor) << ", choice " << std::get<1>(decodedPredecessor) << " and update " << std::get<2>(decodedPredecessor) << std::endl; + // predecessorTransition.template toAdd().exportToDot("pred_" + std::to_string(cnt) + ".dot"); + // Create a new copy of each variable to use for this step. + std::map substitution; + for (auto const& variablePair : oldToNewVariables) { + storm::expressions::Variable variableCopy = expressionManager.declareVariableCopy(variablePair.second); + substitution[variablePair.second] = variableCopy; + stepVariableToCopiedVariableMap[variableCopy] = variablePair.second; + } + + // Retrieve the variable updates that the predecessor needs to perform to get to the current state. + auto variableUpdates = abstractor.get().getVariableUpdates(std::get<1>(decodedPredecessor), std::get<2>(decodedPredecessor)); + for (auto const& update : variableUpdates) { + storm::expressions::Variable newVariable = oldToNewVariables.at(update.first); + if (update.second.hasBooleanType()) { + result.back().push_back(storm::expressions::iff(lastSubstitution.at(oldToNewVariables.at(update.first)), update.second.changeManager(expressionManager).substitute(substitution))); + } else { + result.back().push_back(lastSubstitution.at(oldToNewVariables.at(update.first)) == update.second.changeManager(expressionManager).substitute(substitution)); + } + } + + // Add the guard of the choice. + result.back().push_back(abstractor.get().getGuard(std::get<1>(decodedPredecessor)).changeManager(expressionManager).substitute(substitution)); + + // Retrieve the predicate valuation in the predecessor. + result.emplace_back(abstractionInformation.getPredicates(std::get<0>(decodedPredecessor))); + for (auto& predicate : result.back()) { + predicate = predicate.changeManager(expressionManager).substitute(substitution); + } + + // Move backwards one step. + lastSubstitution = std::move(substitution); + currentState = predecessorTransition.existsAbstract(variablesToAbstract); + ++cnt; } - return storm::expressions::Expression(); + result.back().push_back(abstractor.get().getInitialExpression().changeManager(expressionManager).substitute(lastSubstitution)); + return result; + } + + template + boost::optional> MenuGameRefiner::derivePredicatesFromInterpolation(storm::abstraction::MenuGame const& game, PivotStateResult const& pivotStateResult, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const { + + // Compute the most probable path from any initial state to the pivot state. + storm::dd::Bdd spanningTree = getMostProbablePathSpanningTree(game, pivotStateResult.pivotState, pivotStateResult.fromDirection == storm::OptimizationDirection::Minimize ? minPlayer1Strategy && minPlayer2Strategy : maxPlayer1Strategy && maxPlayer2Strategy); + + // Create a new expression manager that we can use for the interpolation. + std::shared_ptr interpolationManager = abstractor.get().getAbstractionInformation().getExpressionManager().clone(); + + // Build the trace of the most probable path in terms of which predicates hold in each step. + std::vector> trace = buildTrace(*interpolationManager, game, spanningTree, pivotStateResult.pivotState); + + // Now encode the trace as an SMT problem. + storm::solver::MathsatSmtSolver interpolatingSolver(*interpolationManager, storm::solver::MathsatSmtSolver::Options(true, false, true)); + + uint64_t stepCounter = 0; + for (auto const& step : trace) { + std::cout << "group " << stepCounter << std::endl; + interpolatingSolver.setInterpolationGroup(stepCounter); + for (auto const& predicate : step) { + std::cout << predicate << std::endl; + interpolatingSolver.add(predicate); + } + ++stepCounter; + } + + storm::solver::SmtSolver::CheckResult result = interpolatingSolver.check(); + if (result == storm::solver::SmtSolver::CheckResult::Unsat) { + STORM_LOG_TRACE("Trace formula is unsatisfiable. Starting interpolation."); + + std::vector interpolants; + std::vector prefix; + for (uint64_t step = stepCounter; step > 1; --step) { + prefix.push_back(step - 1); + storm::expressions::Expression interpolant = interpolatingSolver.getInterpolant(prefix); + STORM_LOG_ASSERT(!interpolant.isTrue() && !interpolant.isFalse(), "Expected other interpolant."); + interpolants.push_back(interpolant); + } + return boost::make_optional(interpolants); + } else { + STORM_LOG_TRACE("Trace formula is satisfiable."); + std::cout << interpolatingSolver.getModelAsValuation().toString(true) << std::endl; + } + + return boost::none; } template @@ -307,28 +423,30 @@ namespace storm { minPlayer1Strategy = (maxPlayer1Strategy && qualitativeResult.prob0Min.getPlayer2States()).existsAbstract(game.getPlayer1Variables()).ite(maxPlayer1Strategy, minPlayer1Strategy); // Compute all reached pivot states. - PivotStateResult pivotStateResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + PivotStateCandidatesResult pivotStateCandidatesResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); // We can only refine in case we have a reachable player 1 state with a player 2 successor (under either // player 1's min or max strategy) such that from this player 2 state, both prob0 min and prob1 max define // strategies and they differ. Hence, it is possible that we arrive at a point where no suitable pivot state // is found. In this case, we abort the qualitative refinement here. - if (pivotStateResult.pivotStates.isZero()) { + if (pivotStateCandidatesResult.pivotStates.isZero()) { return false; } - STORM_LOG_ASSERT(!pivotStateResult.pivotStates.isZero(), "Unable to proceed without pivot state candidates."); + STORM_LOG_ASSERT(!pivotStateCandidatesResult.pivotStates.isZero(), "Unable to proceed without pivot state candidates."); // Now that we have the pivot state candidates, we need to pick one. - std::pair, storm::OptimizationDirection> pivotState = pickPivotState(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates); + PivotStateResult pivotStateResult = pickPivotState(game.getInitialStates(), pivotStateCandidatesResult.reachableTransitionsMin, pivotStateCandidatesResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateCandidatesResult.pivotStates); - // FIXME. - storm::dd::Bdd spanningTree = getMostProbablePathSpanningTree(game, pivotState.first, pivotState.second == storm::OptimizationDirection::Minimize ? minPlayer1Strategy && minPlayer2Strategy : maxPlayer1Strategy && maxPlayer2Strategy); - storm::expressions::Expression traceFormula = buildTraceFormula(game, spanningTree, pivotState.first); - - exit(-1); + boost::optional> interpolationPredicates = derivePredicatesFromInterpolation(game, pivotStateResult, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + if (interpolationPredicates) { + std::cout << "Got interpolation predicates!" << std::endl; + for (auto const& pred : interpolationPredicates.get()) { + std::cout << "pred: " << pred << std::endl; + } + } // Derive predicate based on the selected pivot state. - RefinementPredicates predicates = derivePredicatesFromPivotState(game, pivotState.first, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + RefinementPredicates predicates = derivePredicatesFromPivotState(game, pivotStateResult.pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); std::vector preparedPredicates = preprocessPredicates(predicates.getPredicates(), (predicates.getSource() == RefinementPredicates::Source::Guard && splitGuards) || (predicates.getSource() == RefinementPredicates::Source::WeakestPrecondition && splitPredicates)); performRefinement(createGlobalRefinement(preparedPredicates)); return true; @@ -344,19 +462,23 @@ namespace storm { storm::dd::Bdd maxPlayer2Strategy = quantitativeResult.max.player2Strategy; // Compute all reached pivot states. - PivotStateResult pivotStateResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + PivotStateCandidatesResult pivotStateCandidatesResult = computePivotStates(game, transitionMatrixBdd, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); - STORM_LOG_ASSERT(!pivotStateResult.pivotStates.isZero(), "Unable to refine without pivot state candidates."); + STORM_LOG_ASSERT(!pivotStateCandidatesResult.pivotStates.isZero(), "Unable to refine without pivot state candidates."); // Now that we have the pivot state candidates, we need to pick one. - std::pair, storm::OptimizationDirection> pivotState = pickPivotState(game.getInitialStates(), pivotStateResult.reachableTransitionsMin, pivotStateResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateResult.pivotStates); + PivotStateResult pivotStateResult = pickPivotState(game.getInitialStates(), pivotStateCandidatesResult.reachableTransitionsMin, pivotStateCandidatesResult.reachableTransitionsMax, game.getRowVariables(), game.getColumnVariables(), pivotStateCandidatesResult.pivotStates); - // FIXME. - getMostProbablePathSpanningTree(game, pivotState.first, pivotState.second == storm::OptimizationDirection::Minimize ? minPlayer1Strategy && minPlayer2Strategy : maxPlayer1Strategy && maxPlayer2Strategy); - exit(-1); + boost::optional> interpolationPredicates = derivePredicatesFromInterpolation(game, pivotStateResult, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + if (interpolationPredicates) { + std::cout << "Got interpolation predicates!" << std::endl; + for (auto const& pred : interpolationPredicates.get()) { + std::cout << "pred: " << pred << std::endl; + } + } // Derive predicate based on the selected pivot state. - RefinementPredicates predicates = derivePredicatesFromPivotState(game, pivotState.first, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); + RefinementPredicates predicates = derivePredicatesFromPivotState(game, pivotStateResult.pivotState, minPlayer1Strategy, minPlayer2Strategy, maxPlayer1Strategy, maxPlayer2Strategy); std::vector preparedPredicates = preprocessPredicates(predicates.getPredicates(), (predicates.getSource() == RefinementPredicates::Source::Guard && splitGuards) || (predicates.getSource() == RefinementPredicates::Source::WeakestPrecondition && splitPredicates)); performRefinement(createGlobalRefinement(preparedPredicates)); return true; diff --git a/src/storm/abstraction/MenuGameRefiner.h b/src/storm/abstraction/MenuGameRefiner.h index c631ed3d8..3260e254e 100644 --- a/src/storm/abstraction/MenuGameRefiner.h +++ b/src/storm/abstraction/MenuGameRefiner.h @@ -4,6 +4,8 @@ #include #include +#include + #include "storm/abstraction/RefinementCommand.h" #include "storm/abstraction/QualitativeResultMinMax.h" #include "storm/abstraction/QuantitativeResultMinMax.h" @@ -41,6 +43,14 @@ namespace storm { std::vector predicates; }; + template + struct PivotStateResult { + PivotStateResult(storm::dd::Bdd const& pivotState, storm::OptimizationDirection fromDirection); + + storm::dd::Bdd pivotState; + storm::OptimizationDirection fromDirection; + }; + template class MenuGameRefiner { public: @@ -82,7 +92,8 @@ namespace storm { */ std::vector createGlobalRefinement(std::vector const& predicates) const; - storm::expressions::Expression buildTraceFormula(storm::abstraction::MenuGame const& game, storm::dd::Bdd const& spanningTree, storm::dd::Bdd const& pivotState) const; + boost::optional> derivePredicatesFromInterpolation(storm::abstraction::MenuGame const& game, PivotStateResult const& pivotStateResult, storm::dd::Bdd const& minPlayer1Strategy, storm::dd::Bdd const& minPlayer2Strategy, storm::dd::Bdd const& maxPlayer1Strategy, storm::dd::Bdd const& maxPlayer2Strategy) const; + std::vector> buildTrace(storm::expressions::ExpressionManager& expressionManager, storm::abstraction::MenuGame const& game, storm::dd::Bdd const& spanningTree, storm::dd::Bdd const& pivotState) const; void performRefinement(std::vector const& refinementCommands) const; diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp index 9727cb7c0..552e00486 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.cpp @@ -121,6 +121,11 @@ namespace storm { return std::make_pair(0, modules.front().getCommands().size()); } + template + storm::expressions::Expression PrismMenuGameAbstractor::getInitialExpression() const { + return program.get().getInitialStatesExpression(); + } + template storm::dd::Bdd PrismMenuGameAbstractor::getStates(storm::expressions::Expression const& predicate) { STORM_LOG_ASSERT(currentGame != nullptr, "Game was not properly created."); diff --git a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h index 6c8835b9c..ae7bd5550 100644 --- a/src/storm/abstraction/prism/PrismMenuGameAbstractor.h +++ b/src/storm/abstraction/prism/PrismMenuGameAbstractor.h @@ -84,6 +84,11 @@ namespace storm { */ std::pair getPlayer1ChoiceRange() const override; + /*! + * Retrieves the expression that characterizes the initial states. + */ + storm::expressions::Expression getInitialExpression() const override; + /*! * Retrieves the set of states (represented by a BDD) satisfying the given predicate, assuming that it * was either given as an initial predicate or used as a refining predicate later. diff --git a/src/storm/storage/expressions/ChangeManagerVisitor.cpp b/src/storm/storage/expressions/ChangeManagerVisitor.cpp new file mode 100644 index 000000000..d650131e3 --- /dev/null +++ b/src/storm/storage/expressions/ChangeManagerVisitor.cpp @@ -0,0 +1,68 @@ +#include "storm/storage/expressions/ChangeManagerVisitor.h" + +#include "storm/storage/expressions/Expressions.h" + +namespace storm { + namespace expressions { + + ChangeManagerVisitor::ChangeManagerVisitor(ExpressionManager const& manager) : manager(manager) { + // Intentionally left empty. + } + + Expression ChangeManagerVisitor::changeManager(storm::expressions::Expression const& expression) { + return Expression(boost::any_cast>(expression.accept(*this, boost::none))); + } + + boost::any ChangeManagerVisitor::visit(IfThenElseExpression const& expression, boost::any const& data) { + auto newCondition = boost::any_cast>(expression.getCondition()->accept(*this, data)); + auto newThen = boost::any_cast>(expression.getThenExpression()->accept(*this, data)); + auto newElse = boost::any_cast>(expression.getElseExpression()->accept(*this, data)); + return std::shared_ptr(new IfThenElseExpression(manager, expression.getType(), newCondition, newThen, newElse)); + } + + boost::any ChangeManagerVisitor::visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) { + auto newFirstOperand = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + auto newSecondOperand = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + return std::shared_ptr(new BinaryBooleanFunctionExpression(manager, expression.getType(), newFirstOperand, newSecondOperand, expression.getOperatorType())); + } + + boost::any ChangeManagerVisitor::visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) { + auto newFirstOperand = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + auto newSecondOperand = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + return std::shared_ptr(new BinaryNumericalFunctionExpression(manager, expression.getType(), newFirstOperand, newSecondOperand, expression.getOperatorType())); + } + + boost::any ChangeManagerVisitor::visit(BinaryRelationExpression const& expression, boost::any const& data) { + auto newFirstOperand = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + auto newSecondOperand = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + return std::shared_ptr(new BinaryRelationExpression(manager, expression.getType(), newFirstOperand, newSecondOperand, expression.getRelationType())); + } + + boost::any ChangeManagerVisitor::visit(VariableExpression const& expression, boost::any const& data) { + return std::shared_ptr(new VariableExpression(manager.getVariable(expression.getVariableName()))); + } + + boost::any ChangeManagerVisitor::visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) { + auto newOperand = boost::any_cast>(expression.getOperand()->accept(*this, data)); + return std::shared_ptr(new UnaryBooleanFunctionExpression(manager, expression.getType(), newOperand, expression.getOperatorType())); + } + + boost::any ChangeManagerVisitor::visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) { + auto newOperand = boost::any_cast>(expression.getOperand()->accept(*this, data)); + return std::shared_ptr(new UnaryNumericalFunctionExpression(manager, expression.getType(), newOperand, expression.getOperatorType())); + } + + boost::any ChangeManagerVisitor::visit(BooleanLiteralExpression const& expression, boost::any const& data) { + return std::shared_ptr(new BooleanLiteralExpression(manager, expression.getValue())); + } + + boost::any ChangeManagerVisitor::visit(IntegerLiteralExpression const& expression, boost::any const& data) { + return std::shared_ptr(new IntegerLiteralExpression(manager, expression.getValue())); + } + + boost::any ChangeManagerVisitor::visit(RationalLiteralExpression const& expression, boost::any const& data) { + return std::shared_ptr(new RationalLiteralExpression(manager, expression.getValue())); + } + + } +} diff --git a/src/storm/storage/expressions/ChangeManagerVisitor.h b/src/storm/storage/expressions/ChangeManagerVisitor.h new file mode 100644 index 000000000..12f732f2b --- /dev/null +++ b/src/storm/storage/expressions/ChangeManagerVisitor.h @@ -0,0 +1,33 @@ +#pragma once + +#include "storm/storage/expressions/ExpressionVisitor.h" +#include "storm/storage/expressions/ExpressionManager.h" + +namespace storm { + namespace expressions { + + class Expression; + + class ChangeManagerVisitor : public ExpressionVisitor { + public: + ChangeManagerVisitor(ExpressionManager const& manager); + + Expression changeManager(storm::expressions::Expression const& expression); + + virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; + virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + + private: + ExpressionManager const& manager; + }; + + } +} diff --git a/src/storm/storage/expressions/Expression.cpp b/src/storm/storage/expressions/Expression.cpp index ba6a26a80..0a2f200e3 100644 --- a/src/storm/storage/expressions/Expression.cpp +++ b/src/storm/storage/expressions/Expression.cpp @@ -6,6 +6,7 @@ #include "storm/storage/expressions/SubstitutionVisitor.h" #include "storm/storage/expressions/LinearityCheckVisitor.h" #include "storm/storage/expressions/SyntacticalEqualityCheckVisitor.h" +#include "storm/storage/expressions/ChangeManagerVisitor.h" #include "storm/storage/expressions/Expressions.h" #include "storm/exceptions/InvalidTypeException.h" #include "storm/exceptions/InvalidArgumentException.h" @@ -32,6 +33,11 @@ namespace storm { // Intentionally left empty. } + Expression Expression::changeManager(ExpressionManager const& newExpressionManager) const { + ChangeManagerVisitor visitor(newExpressionManager); + return visitor.changeManager(*this); + } + Expression Expression::substitute(std::map const& identifierToExpressionMap) const { return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); } diff --git a/src/storm/storage/expressions/Expression.h b/src/storm/storage/expressions/Expression.h index 9b6c27e8e..3d6da87c9 100644 --- a/src/storm/storage/expressions/Expression.h +++ b/src/storm/storage/expressions/Expression.h @@ -81,6 +81,11 @@ namespace storm { Expression& operator=(Expression&&) = default; #endif + /*! + * Converts the expression to an expression over the variables of the provided expression manager. + */ + Expression changeManager(ExpressionManager const& newExpressionManager) const; + /*! * Substitutes all occurrences of the variables according to the given map. Note that this substitution is * done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not diff --git a/src/storm/storage/expressions/ExpressionManager.cpp b/src/storm/storage/expressions/ExpressionManager.cpp index 82d40ec08..13fa5be6d 100644 --- a/src/storm/storage/expressions/ExpressionManager.cpp +++ b/src/storm/storage/expressions/ExpressionManager.cpp @@ -56,6 +56,10 @@ namespace storm { // Intentionally left empty. } + std::shared_ptr ExpressionManager::clone() const { + return std::shared_ptr(new ExpressionManager(*this)); + } + Expression ExpressionManager::boolean(bool value) const { return Expression(std::shared_ptr(new BooleanLiteralExpression(*this, value))); } @@ -125,6 +129,10 @@ namespace storm { return nameIndexPair != nameToIndexMapping.end(); } + Variable ExpressionManager::declareVariableCopy(Variable const& variable) { + return declareFreshVariable(variable.getType(), true, "_" + variable.getName() + "_"); + } + Variable ExpressionManager::declareVariable(std::string const& name, storm::expressions::Type const& variableType, bool auxiliary) { STORM_LOG_THROW(!variableExists(name), storm::exceptions::InvalidArgumentException, "Variable with name '" << name << "' already exists."); return declareOrGetVariable(name, variableType, auxiliary); @@ -187,7 +195,9 @@ namespace storm { nameToIndexMapping[name] = newIndex; indexToNameMapping[newIndex] = name; indexToTypeMapping[newIndex] = variableType; - return Variable(this->getSharedPointer(), newIndex); + Variable result(this->getSharedPointer(), newIndex); + variableSet.insert(result); + return result; } } @@ -197,6 +207,10 @@ namespace storm { return Variable(this->getSharedPointer(), nameIndexPair->second); } + std::set const& ExpressionManager::getVariables() const { + return variableSet; + } + Expression ExpressionManager::getVariableExpression(std::string const& name) const { return Expression(getVariable(name)); } diff --git a/src/storm/storage/expressions/ExpressionManager.h b/src/storm/storage/expressions/ExpressionManager.h index 4ae5bb295..05cc11a2c 100644 --- a/src/storm/storage/expressions/ExpressionManager.h +++ b/src/storm/storage/expressions/ExpressionManager.h @@ -71,16 +71,14 @@ namespace storm { */ ExpressionManager(); - // Explicitly delete copy construction/assignment, since the manager is supposed to be stored as a pointer - // of some sort. This is because the expression classes store a reference to the manager and it must - // therefore be guaranteed that they do not become invalid, because the manager has been copied. - ExpressionManager(ExpressionManager const& other) = delete; - ExpressionManager& operator=(ExpressionManager const& other) = delete; -#ifndef WINDOWS + /*! + * Creates a new expression manager with the same set of variables. + */ + std::shared_ptr clone() const; + // Create default instantiations for the move construction/assignment. ExpressionManager(ExpressionManager&& other) = default; ExpressionManager& operator=(ExpressionManager&& other) = default; -#endif /*! * Creates an expression that characterizes the given boolean literal. @@ -147,7 +145,15 @@ namespace storm { * @return The rational type. */ Type const& getRationalType() const; - + + /*! + * Declares a variable that is a copy of the provided variable (i.e. has the same type). + * + * @param variable The variable of which to create a copy. + * @return The newly declared variable. + */ + Variable declareVariableCopy(Variable const& variable); + /*! * Declares a variable with a name that must not yet exist and its corresponding type. Note that the name * must not start with two underscores since these variables are reserved for internal use only. @@ -218,6 +224,11 @@ namespace storm { */ Variable getVariable(std::string const& name) const; + /*! + * Retrieves the set of all variables known to this manager. + */ + std::set const& getVariables() const; + /*! * Retrieves whether a variable with the given name is known to the manager. * @@ -361,6 +372,12 @@ namespace storm { friend std::ostream& operator<<(std::ostream& out, ExpressionManager const& manager); private: + // Explicitly make copy construction/assignment private, since the manager is supposed to be stored as a pointer + // of some sort. This is because the expression classes store a reference to the manager and it must + // therefore be guaranteed that they do not become invalid, because the manager has been copied. + ExpressionManager(ExpressionManager const& other) = default; + ExpressionManager& operator=(ExpressionManager const& other) = default; + /*! * Checks whether the given variable name is valid. * @@ -405,6 +422,9 @@ namespace storm { */ uint_fast64_t getNumberOfAuxiliaryVariables(storm::expressions::Type const& variableType) const; + // The set of all known variables. + std::set variableSet; + // A mapping from all variable names (auxiliary + normal) to their indices. std::unordered_map nameToIndexMapping;