From 629448c31250b7d08e4293456dbcc34e43a2a3e8 Mon Sep 17 00:00:00 2001 From: dehnert Date: Wed, 9 Oct 2013 16:46:28 +0200 Subject: [PATCH] First working version of MaxSAT-based minimal command counterexample generation. Former-commit-id: 6dc49157f9fd9344c84c5ccbf69b26f77a9a36a7 --- counterexamples.h | 21 +++ src/adapters/ExplicitModelAdapter.h | 22 --- .../MILPMinimalLabelSetGenerator.h | 8 +- .../SMTMinimalCommandSetGenerator.h | 133 +++++++++++++++--- src/storm.cpp | 34 +++-- src/utility/IRUtility.h | 22 +++ 6 files changed, 183 insertions(+), 57 deletions(-) create mode 100644 counterexamples.h diff --git a/counterexamples.h b/counterexamples.h new file mode 100644 index 000000000..b548d6460 --- /dev/null +++ b/counterexamples.h @@ -0,0 +1,21 @@ +/* + * vector.h + * + * Created on: 06.12.2012 + * Author: Christian Dehnert + */ + +#ifndef STORM_UTILITY_COUNTEREXAMPLE_H_ +#define STORM_UTILITY_COUNTEREXAMPLE_H_ + +namespace storm { + namespace utility { + namespace counterexample{ + + + + } // namespace counterexample + } // namespace utility +} // namespace storm + +#endif /* STORM_UTILITY_COUNTEREXAMPLE_H_ */ diff --git a/src/adapters/ExplicitModelAdapter.h b/src/adapters/ExplicitModelAdapter.h index 9696d9a33..8eac31edb 100644 --- a/src/adapters/ExplicitModelAdapter.h +++ b/src/adapters/ExplicitModelAdapter.h @@ -125,28 +125,6 @@ namespace storm { } private: - /*! - * Sets some boolean variable in the given state object. - * - * @param state The state to modify. - * @param index The index of the boolean variable to modify. - * @param value The new value of the variable. - */ - static void setValue(StateType* state, uint_fast64_t index, bool value) { - std::get<0>(*state)[index] = value; - } - - /*! - * Set some integer variable in the given state object. - * - * @param state The state to modify. - * @param index index of the integer variable to modify. - * @param value The new value of the variable. - */ - static void setValue(StateType* state, uint_fast64_t index, int_fast64_t value) { - std::get<1>(*state)[index] = value; - } - /*! * Transforms a state into a somewhat readable string. * diff --git a/src/counterexamples/MILPMinimalLabelSetGenerator.h b/src/counterexamples/MILPMinimalLabelSetGenerator.h index 3392e16d8..fa7ca4180 100644 --- a/src/counterexamples/MILPMinimalLabelSetGenerator.h +++ b/src/counterexamples/MILPMinimalLabelSetGenerator.h @@ -1084,12 +1084,12 @@ namespace storm { * @param model The Gurobi model. * @param variableInformation A struct with information about the variables of the model. */ - static std::unordered_set getUsedLabelsInSolution(GRBenv* env, GRBmodel* model, VariableInformation const& variableInformation) { + static std::set getUsedLabelsInSolution(GRBenv* env, GRBmodel* model, VariableInformation const& variableInformation) { int error = 0; // Check whether the model was optimized, so we can read off the solution. if (checkGurobiModelIsOptimized(env, model)) { - std::unordered_set result; + std::set result; double value = 0; for (auto labelVariablePair : variableInformation.labelToVariableIndexMap) { @@ -1202,7 +1202,7 @@ namespace storm { public: - static std::unordered_set getMinimalLabelSet(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool checkThresholdFeasible = false, bool includeSchedulerCuts = false) { + static std::set getMinimalLabelSet(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool checkThresholdFeasible = false, bool includeSchedulerCuts = false) { #ifdef STORM_HAVE_GUROBI // (0) Check whether the MDP is indeed labeled. if (!labeledMdp.hasChoiceLabels()) { @@ -1239,7 +1239,7 @@ namespace storm { optimizeModel(environmentModelPair.first, environmentModelPair.second); // (4.5) Read off result from variables. - std::unordered_set usedLabelSet = getUsedLabelsInSolution(environmentModelPair.first, environmentModelPair.second, variableInformation); + std::set usedLabelSet = getUsedLabelsInSolution(environmentModelPair.first, environmentModelPair.second, variableInformation); // Display achieved probability. std::pair initialStateProbabilityPair = getReachabilityProbability(environmentModelPair.first, environmentModelPair.second, labeledMdp, variableInformation); diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index f2b0ead6e..41d3f4db1 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -172,6 +172,7 @@ namespace storm { // * identify labels that can directly precede a given action // * identify labels that directly reach a target state // * identify labels that can directly follow a given action + // * identify labels that can be found on each path to the target states. std::set initialLabels; std::map> precedingLabels; @@ -361,7 +362,7 @@ namespace storm { * @param solver The solver to use for the satisfiability evaluation. */ static void assertSymbolicCuts(storm::ir::Program const& program, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { - // TODO: + // FIXME: // find synchronization cuts // find forward/backward cuts @@ -388,9 +389,106 @@ namespace storm { z3::expr upperBound = expressionAdapter.translateExpression(integerVariable.getUpperBound()); upperBound = solverVariables.at(integerVariable.getName()) <= upperBound; localSolver.add(upperBound); - } + } + + // Construct an expression that exactly characterizes the initial state. + std::unique_ptr initialState(storm::utility::ir::getInitialState(program, programVariableInformation)); + z3::expr initialStateExpression = localContext.bool_val(true); + for (uint_fast64_t index = 0; index < programVariableInformation.booleanVariables.size(); ++index) { + if (std::get<0>(*initialState).at(programVariableInformation.booleanVariableToIndexMap.at(programVariableInformation.booleanVariables[index].getName()))) { + initialStateExpression = initialStateExpression && solverVariables.at(programVariableInformation.booleanVariables[index].getName()); + } else { + initialStateExpression = initialStateExpression && !solverVariables.at(programVariableInformation.booleanVariables[index].getName()); + } + } + for (uint_fast64_t index = 0; index < programVariableInformation.integerVariables.size(); ++index) { + storm::ir::IntegerVariable const& variable = programVariableInformation.integerVariables[index]; + initialStateExpression = initialStateExpression && (solverVariables.at(variable.getName()) == localContext.int_val(std::get<1>(*initialState).at(programVariableInformation.integerVariableToIndexMap.at(variable.getName())))); + } + + std::map> backwardImplications; + + // First check for possible backward cuts. + for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) { + storm::ir::Module const& module = program.getModule(moduleIndex); + + for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) { + storm::ir::Command const& command = module.getCommand(commandIndex); + + // If the label of the command is not relevant, skip it entirely. + if (relevancyInformation.relevantLabels.find(command.getGlobalIndex()) == relevancyInformation.relevantLabels.end()) continue; + + // Save the state of the solver so we can easily backtrack. + localSolver.push(); + + // Check if the command is enabled in the initial state. + localSolver.add(expressionAdapter.translateExpression(command.getGuard())); + localSolver.add(initialStateExpression); + + z3::check_result checkResult = localSolver.check(); + localSolver.pop(); + localSolver.push(); + + // If it is not and the action is not synchronizing, we can impose backward cuts. + if (checkResult == z3::unsat && command.getActionName() == "") { + localSolver.add(!expressionAdapter.translateExpression(command.getGuard())); + localSolver.push(); + + // We need to check all commands of the all modules, because they could enable the current + // command via a global variable. + for (uint_fast64_t otherModuleIndex = 0; otherModuleIndex < program.getNumberOfModules(); ++otherModuleIndex) { + storm::ir::Module const& otherModule = program.getModule(otherModuleIndex); + + for (uint_fast64_t otherCommandIndex = 0; otherCommandIndex < otherModule.getNumberOfCommands(); ++otherCommandIndex) { + storm::ir::Command const& otherCommand = otherModule.getCommand(otherCommandIndex); + + // We don't need to consider irrelevant commands and the command itself. + if (relevancyInformation.relevantLabels.find(otherCommand.getGlobalIndex()) == relevancyInformation.relevantLabels.end()) continue; + if (moduleIndex == otherModuleIndex && commandIndex == otherCommandIndex) continue; + + std::vector formulae; + formulae.reserve(otherCommand.getNumberOfUpdates()); + + localSolver.push(); + + for (uint_fast64_t updateIndex = 0; updateIndex < otherCommand.getNumberOfUpdates(); ++updateIndex) { + std::unique_ptr weakestPrecondition = storm::utility::ir::getWeakestPrecondition(command.getGuard(), {otherCommand.getUpdate(updateIndex)}); + + formulae.push_back(expressionAdapter.translateExpression(weakestPrecondition)); + } + + assertDisjunction(localContext, localSolver, formulae); + + // If the assertions were satisfiable, this means the other command could successfully + // enable the current command. + if (localSolver.check() == z3::sat) { + backwardImplications[command.getGlobalIndex()].insert(otherCommand.getGlobalIndex()); + } + + localSolver.pop(); + } + } + + // Remove the negated guard from the solver assertions. + localSolver.pop(); + } + + // Restore state of solver where only the variable bounds are asserted. + localSolver.pop(); + } + } - std::cout << localSolver << std::endl; + std::vector formulae; + for (auto const& labelImplicationsPair : backwardImplications) { + formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelImplicationsPair.first))); + + for (auto label : labelImplicationsPair.second) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label))); + } + + assertDisjunction(context, solver, formulae); + formulae.clear(); + } } /*! @@ -564,12 +662,16 @@ namespace storm { } // Check whether the assumptions are satisfiable. + LOG4CPLUS_DEBUG(logger, "Invoking satisfiability checking."); z3::check_result result = solver.check(assumptions); + LOG4CPLUS_DEBUG(logger, "Done invoking satisfiability checking."); - if (result == z3::check_result::sat) { + if (result == z3::sat) { return true; } else { + LOG4CPLUS_DEBUG(logger, "Computing unsat core."); z3::expr_vector unsatCore = solver.unsat_core(); + LOG4CPLUS_DEBUG(logger, "Computed unsat core."); std::vector blockingVariables; blockingVariables.reserve(unsatCore.size()); @@ -648,6 +750,8 @@ namespace storm { * @return The smallest set of labels such that the constraint system of the solver is still satisfiable. */ static std::set findSmallestCommandSet(z3::context& context, z3::solver& solver, VariableInformation& variableInformation, std::vector& softConstraints, uint_fast64_t& nextFreeVariableIndex) { + + solver.push(); for (uint_fast64_t i = 0; ; ++i) { if (fuMalikMaxsatStep(context, solver, variableInformation, softConstraints, nextFreeVariableIndex)) { break; @@ -661,16 +765,19 @@ namespace storm { for (auto const& labelIndexPair : variableInformation.labelToIndexMap) { z3::expr value = model.eval(variableInformation.labelVariables[labelIndexPair.second]); - std::stringstream resultStream; - resultStream << value; - if (resultStream.str() == "true") { + // Check whether the label variable was set or not. + if (eq(value, context.bool_val(true))) { result.insert(labelIndexPair.first); - } else if (resultStream.str() == "false") { + } else if (eq(value, context.bool_val(false))) { // Nothing to do in this case. + } else if (eq(value, variableInformation.labelVariables[labelIndexPair.second])) { + // If the variable is a "don't care", then we rather not take it, so nothing to do in this case + // as well. } else { - throw storm::exceptions::InvalidStateException() << "Could not retrieve value of boolean variable."; + throw storm::exceptions::InvalidStateException() << "Could not retrieve value of boolean variable from illegal value."; } } + solver.pop(); return result; } @@ -749,17 +856,9 @@ namespace storm { } else { done = true; } - std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands in iteration " << iterations << "." << std::endl; ++iterations; } while (!done); - std::cout << "Achieved probability: " << maximalReachabilityProbability << " with " << commandSet.size() << " commands." << std::endl; - std::cout << "Taken commands are:" << std::endl; - for (auto label : commandSet) { - std::cout << label << ", "; - } - std::cout << std::endl; - // (8) Return the resulting command set after undefining the constants. storm::utility::ir::undefineUndefinedConstants(program); return commandSet; diff --git a/src/storm.cpp b/src/storm.cpp index 80573a589..d0ddc0654 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -338,19 +338,19 @@ int main(const int argc, const char* argv[]) { model->printModelInformationToStream(std::cout); // Enable the following lines to test the MinimalLabelSetGenerator. -// if (model->getType() == storm::models::MDP) { -// std::shared_ptr> labeledMdp = model->as>(); -// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); -// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); -// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; -// std::unordered_set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); -// -// std::cout << "Found solution with " << labels.size() << " commands." << std::endl; -// for (uint_fast64_t label : labels) { -// std::cout << label << ", "; -// } -// std::cout << std::endl; -// } + if (model->getType() == storm::models::MDP) { + std::shared_ptr> labeledMdp = model->as>(); + storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); + storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); + storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; + std::set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true, true); + + std::cout << "Found solution with " << labels.size() << " commands." << std::endl; + for (uint_fast64_t label : labels) { + std::cout << label << ", "; + } + std::cout << std::endl; + } // Enable the following lines to test the SMTMinimalCommandSetGenerator. if (model->getType() == storm::models::MDP) { @@ -358,7 +358,13 @@ int main(const int argc, const char* argv[]) { storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; - storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true); + std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true); + + std::cout << "Found solution with " << labels.size() << " commands." << std::endl; + for (uint_fast64_t label : labels) { + std::cout << label << ", "; + } + std::cout << std::endl; } } diff --git a/src/utility/IRUtility.h b/src/utility/IRUtility.h index 63cf9519e..09f58dca8 100644 --- a/src/utility/IRUtility.h +++ b/src/utility/IRUtility.h @@ -369,6 +369,28 @@ namespace storm { LOG4CPLUS_DEBUG(logger, "Generated initial state."); return initialState; } + + /*! + * Sets some boolean variable in the given state object. + * + * @param state The state to modify. + * @param index The index of the boolean variable to modify. + * @param value The new value of the variable. + */ + static void setValue(StateType* state, uint_fast64_t index, bool value) { + std::get<0>(*state)[index] = value; + } + + /*! + * Set some integer variable in the given state object. + * + * @param state The state to modify. + * @param index index of the integer variable to modify. + * @param value The new value of the variable. + */ + static void setValue(StateType* state, uint_fast64_t index, int_fast64_t value) { + std::get<1>(*state)[index] = value; + } /*! * Defines the undefined constants of the given program using the given string.