From fda9c43e869f6d8e75d71dc2f8a354308c2f61dd Mon Sep 17 00:00:00 2001 From: dehnert Date: Thu, 10 Oct 2013 17:22:35 +0200 Subject: [PATCH] Fix for SMT-based minimal command set generator. Minor fixes to string output of expression classes. Former-commit-id: 316a762d74788eab5379418c9b83b1f04b4ac44a --- examples/mdp/consensus/coin4.nm | 12 +- .../SMTMinimalCommandSetGenerator.h | 117 ++++++++++++++---- src/ir/Module.cpp | 10 ++ src/ir/Module.h | 6 + src/ir/Program.cpp | 8 +- src/ir/Program.h | 6 + .../BinaryBooleanFunctionExpression.cpp | 4 +- .../BinaryNumericalFunctionExpression.cpp | 4 +- .../expressions/BinaryRelationExpression.cpp | 4 +- src/ir/expressions/ConstantExpression.h | 5 +- .../UnaryBooleanFunctionExpression.cpp | 3 +- .../UnaryNumericalFunctionExpression.cpp | 9 +- src/storm.cpp | 26 ++-- 13 files changed, 155 insertions(+), 59 deletions(-) diff --git a/examples/mdp/consensus/coin4.nm b/examples/mdp/consensus/coin4.nm index 8639b867b..fa1c5ba00 100644 --- a/examples/mdp/consensus/coin4.nm +++ b/examples/mdp/consensus/coin4.nm @@ -29,18 +29,18 @@ module process1 // flip coin [] (pc1=0) -> 0.5 : (coin1'=0) & (pc1'=1) + 0.5 : (coin1'=1) & (pc1'=1); // write tails -1 (reset coin to add regularity) - [] (pc1=1) & (coin1=0) & (counter>0) -> (counter'=counter-1) & (pc1'=2) & (coin1'=0); + [] (pc1=1) & (coin1=0) & (counter>0) -> 1 : (counter'=counter-1) & (pc1'=2) & (coin1'=0); // write heads +1 (reset coin to add regularity) - [] (pc1=1) & (coin1=1) & (counter (counter'=counter+1) & (pc1'=2) & (coin1'=0); + [] (pc1=1) & (coin1=1) & (counter 1 : (counter'=counter+1) & (pc1'=2) & (coin1'=0); // check // decide tails - [] (pc1=2) & (counter<=left) -> (pc1'=3) & (coin1'=0); + [] (pc1=2) & (counter<=left) -> 1 : (pc1'=3) & (coin1'=0); // decide heads - [] (pc1=2) & (counter>=right) -> (pc1'=3) & (coin1'=1); + [] (pc1=2) & (counter>=right) -> 1 : (pc1'=3) & (coin1'=1); // flip again - [] (pc1=2) & (counter>left) & (counter (pc1'=0); + [] (pc1=2) & (counter>left) & (counter 1 : (pc1'=0); // loop (all loop together when done) - [done] (pc1=3) -> (pc1'=3); + [done] (pc1=3) -> 1 : (pc1'=3); endmodule diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index ad33fbef8..95ef99789 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -45,6 +45,7 @@ namespace storm { struct VariableInformation { std::vector labelVariables; + std::vector originalAuxiliaryVariables; std::vector auxiliaryVariables; std::map labelToIndexMap; }; @@ -134,6 +135,8 @@ namespace storm { variableInformation.auxiliaryVariables.push_back(context.bool_const(variableName.str().c_str())); } + variableInformation.originalAuxiliaryVariables = variableInformation.auxiliaryVariables; + return variableInformation; } @@ -305,11 +308,55 @@ namespace storm { * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. */ - static void assertSymbolicCuts(storm::ir::Program const& program, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { - // FIXME: - // find synchronization cuts - // find forward/backward cuts + static void assertSymbolicCuts(storm::ir::Program const& program, storm::models::Mdp const& labeledMdp, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { + + // Initially, we look for synchronisation implications. +// for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) { +// storm::ir::Module const& module = program.getModule(moduleIndex); +// +// for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) { +// storm::ir::Command const& command = module.getCommand(commandIndex); +// +// // If the command is unlabeled, there are no synchronisation cuts to apply. +// if (command.getActionName() == "") continue; +// } +// } + std::map> precedingLabels; + + // Get some data from the MDP for convenient access. + storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); + std::vector> const& choiceLabeling = labeledMdp.getChoiceLabeling(); + storm::storage::SparseMatrix backwardTransitions = labeledMdp.getBackwardTransitions(); + + for (auto currentState : relevancyInformation.relevantStates) { + for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) { + // Iterate over predecessors and add all choices that target the current state to the preceding + // label set of all labels of all relevant choices of the current state. + for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) { + for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) { + bool choiceTargetsCurrentState = false; + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) { + if (*successorIt == currentState) { + choiceTargetsCurrentState = true; + } + } + + if (choiceTargetsCurrentState) { + for (auto labelToAdd : choiceLabeling[predecessorChoice]) { + for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) { + precedingLabels[labelForWhichToAdd].insert(labelToAdd); + } + } + } + } + } + } + } + + // FIXME: The following procedure to assert backward cuts is not correct in the presence of synchronizing + // actions, because it may be the case that several synchronizing commands are necessary to enable another + // command and not just one. storm::utility::ir::VariableInformation programVariableInformation = storm::utility::ir::createVariableInformation(program); // Create a context and register all variables of the program with their correct type. @@ -352,13 +399,13 @@ namespace storm { std::map> backwardImplications; - // First check for possible backward cuts. + // Now check for possible backward cuts. for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) { storm::ir::Module const& module = program.getModule(moduleIndex); for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) { storm::ir::Command const& command = module.getCommand(commandIndex); - + // If the label of the command is not relevant, skip it entirely. if (relevancyInformation.relevantLabels.find(command.getGlobalIndex()) == relevancyInformation.relevantLabels.end()) continue; @@ -373,8 +420,7 @@ namespace storm { localSolver.pop(); localSolver.push(); - // If it is not and the action is not synchronizing, we can impose backward cuts. - if (checkResult == z3::unsat && command.getActionName() == "") { + if (checkResult == z3::unsat) { localSolver.add(!expressionAdapter.translateExpression(command.getGuard())); localSolver.push(); @@ -426,7 +472,10 @@ namespace storm { for (auto const& labelImplicationsPair : backwardImplications) { formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelImplicationsPair.first))); - for (auto label : labelImplicationsPair.second) { + std::set actualImplications; + std::set_intersection(labelImplicationsPair.second.begin(), labelImplicationsPair.second.end(), precedingLabels.at(labelImplicationsPair.first).begin(), precedingLabels.at(labelImplicationsPair.first).end(), std::inserter(actualImplications, actualImplications.begin())); + + for (auto label : actualImplications) { formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label))); } @@ -599,10 +648,10 @@ namespace storm { * @param variableInformation A structure with information about the variables for the labels. * @return True iff the constraint system was satisfiable. */ - static bool fuMalikMaxsatStep(z3::context& context, z3::solver& solver, VariableInformation& variableInformation, std::vector& softConstraints, uint_fast64_t& nextFreeVariableIndex) { + static bool fuMalikMaxsatStep(z3::context& context, z3::solver& solver, std::vector& auxiliaryVariables, std::vector& softConstraints, uint_fast64_t& nextFreeVariableIndex) { z3::expr_vector assumptions(context); - for (auto const& auxVariable : variableInformation.auxiliaryVariables) { - assumptions.push_back(!auxVariable); + for (auto const& auxiliaryVariable : auxiliaryVariables) { + assumptions.push_back(!auxiliaryVariable); } // Check whether the assumptions are satisfiable. @@ -640,11 +689,11 @@ namespace storm { variableName.str(""); variableName << "a" << nextFreeVariableIndex; ++nextFreeVariableIndex; - variableInformation.auxiliaryVariables[softConstraintIndex] = context.bool_const(variableName.str().c_str()); + auxiliaryVariables[softConstraintIndex] = context.bool_const(variableName.str().c_str()); softConstraints[softConstraintIndex] = softConstraints[softConstraintIndex] || blockingVariables.back(); - solver.add(softConstraints[softConstraintIndex] || variableInformation.auxiliaryVariables[softConstraintIndex]); + solver.add(softConstraints[softConstraintIndex] || auxiliaryVariables[softConstraintIndex]); } } } @@ -695,9 +744,13 @@ namespace storm { */ static std::set findSmallestCommandSet(z3::context& context, z3::solver& solver, VariableInformation& variableInformation, std::vector& softConstraints, uint_fast64_t& nextFreeVariableIndex) { - solver.push(); + // Copy the original auxiliary variables and soft constraints so the procedure can modify the copies. + // std::vector auxiliaryVariables(variableInformation.auxiliaryVariables); + // std::vector tmpSoftConstraints(softConstraints); + // solver.push(); + for (uint_fast64_t i = 0; ; ++i) { - if (fuMalikMaxsatStep(context, solver, variableInformation, softConstraints, nextFreeVariableIndex)) { + if (fuMalikMaxsatStep(context, solver, variableInformation.auxiliaryVariables, softConstraints, nextFreeVariableIndex)) { break; } } @@ -707,21 +760,21 @@ namespace storm { z3::model model = solver.get_model(); for (auto const& labelIndexPair : variableInformation.labelToIndexMap) { - z3::expr value = model.eval(variableInformation.labelVariables[labelIndexPair.second]); + z3::expr auxValue = model.eval(variableInformation.originalAuxiliaryVariables[labelIndexPair.second]); - // Check whether the label variable was set or not. - if (eq(value, context.bool_val(true))) { + // Check whether the auxiliary variable was set or not. + if (eq(auxValue, context.bool_val(true))) { result.insert(labelIndexPair.first); - } else if (eq(value, context.bool_val(false))) { + } else if (eq(auxValue, context.bool_val(false))) { // Nothing to do in this case. - } else if (eq(value, variableInformation.labelVariables[labelIndexPair.second])) { - // If the variable is a "don't care", then we rather not take it, so nothing to do in this case - // as well. + } else if (eq(auxValue, variableInformation.originalAuxiliaryVariables[labelIndexPair.second])) { + // If the auxiliary variable is a don't care, then we don't take the corresponding command. } else { throw storm::exceptions::InvalidStateException() << "Could not retrieve value of boolean variable from illegal value."; } } - solver.pop(); + + // solver.pop(); return result; } @@ -756,7 +809,7 @@ namespace storm { // (6) Add constraints that cut off a lot of suboptimal solutions. assertExplicitCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver); - assertSymbolicCuts(program, variableInformation, relevancyInformation, context, solver); + assertSymbolicCuts(program, labeledMdp, variableInformation, relevancyInformation, context, solver); // (7) Find the smallest set of commands that satisfies all constraints. If the probability of // satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned. @@ -780,7 +833,14 @@ namespace storm { bool done = false; uint_fast64_t iterations = 0; do { + LOG4CPLUS_DEBUG(logger, "Computing minimal command set."); commandSet = findSmallestCommandSet(context, solver, variableInformation, softConstraints, nextFreeVariableIndex); + LOG4CPLUS_DEBUG(logger, "Computed minimal command set of size " << commandSet.size() << "."); + std::cout << "solution: " << std::endl; + for (auto label : commandSet) { + std::cout << label << ", "; + } + std::cout << std::endl; // Restrict the given MDP to the current set of labels and compute the reachability probability. storm::models::Mdp subMdp = labeledMdp.restrictChoiceLabels(commandSet); @@ -802,7 +862,12 @@ namespace storm { } ++iterations; } while (!done); - LOG4CPLUS_ERROR(logger, "Found minimal label set after " << iterations << " iterations."); + LOG4CPLUS_INFO(logger, "Found minimal label set after " << iterations << " iterations."); + + // Verify the results. + storm::ir::Program programCopy(program); + programCopy.restrictCommands(commandSet); + std::cout << programCopy.toString() << std::endl; // (8) Return the resulting command set after undefining the constants. storm::utility::ir::undefineUndefinedConstants(program); diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 3d92acd10..c8790d2f4 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -186,5 +186,15 @@ namespace storm { } } + void Module::restrictCommands(std::set const& indexSet) { + std::vector newCommands; + for (auto const& command : commands) { + if (indexSet.find(command.getGlobalIndex()) != indexSet.end()) { + newCommands.push_back(std::move(command)); + } + } + commands = std::move(newCommands); + } + } // namespace ir } // namespace storm diff --git a/src/ir/Module.h b/src/ir/Module.h index ff96f266a..d0bcec060 100644 --- a/src/ir/Module.h +++ b/src/ir/Module.h @@ -181,6 +181,12 @@ namespace storm { */ std::set const& getCommandsByAction(std::string const& action) const; + /*! + * Deletes all commands with indices not in the given set from the module. + * + * @param indexSet The set of indices for which to keep the commands. + */ + void restrictCommands(std::set const& indexSet); private: /*! diff --git a/src/ir/Program.cpp b/src/ir/Program.cpp index b4cd0fad3..659c9b3d4 100644 --- a/src/ir/Program.cpp +++ b/src/ir/Program.cpp @@ -167,7 +167,7 @@ namespace storm { } for (auto const& label : labels) { - result << "label " << label.first << " = " << label.second->toString() <<";" << std::endl; + result << "label \"" << label.first << "\" = " << label.second->toString() <<";" << std::endl; } return result.str(); @@ -290,5 +290,11 @@ namespace storm { return this->globalIntegerVariableToIndexMap.at(variableName); } + void Program::restrictCommands(std::set const& indexSet) { + for (auto& module : modules) { + module.restrictCommands(indexSet); + } + } + } // namespace ir } // namepsace storm diff --git a/src/ir/Program.h b/src/ir/Program.h index 2ddf31da3..01bd25d04 100644 --- a/src/ir/Program.h +++ b/src/ir/Program.h @@ -261,6 +261,12 @@ namespace storm { */ uint_fast64_t getGlobalIndexOfIntegerVariable(std::string const& variableName) const; + /*! + * Deletes all commands with indices not in the given set from the program. + * + * @param indexSet The set of indices for which to keep the commands. + */ + void restrictCommands(std::set const& indexSet); private: // The type of the model. ModelType modelType; diff --git a/src/ir/expressions/BinaryBooleanFunctionExpression.cpp b/src/ir/expressions/BinaryBooleanFunctionExpression.cpp index c55af9b38..83cb1d82f 100644 --- a/src/ir/expressions/BinaryBooleanFunctionExpression.cpp +++ b/src/ir/expressions/BinaryBooleanFunctionExpression.cpp @@ -52,12 +52,12 @@ namespace storm { std::string BinaryBooleanFunctionExpression::toString() const { std::stringstream result; - result << this->getLeft()->toString(); + result << "(" << this->getLeft()->toString(); switch (functionType) { case AND: result << " & "; break; case OR: result << " | "; break; } - result << this->getRight()->toString(); + result << this->getRight()->toString() << ")"; return result.str(); } diff --git a/src/ir/expressions/BinaryNumericalFunctionExpression.cpp b/src/ir/expressions/BinaryNumericalFunctionExpression.cpp index e11d82538..51878942f 100644 --- a/src/ir/expressions/BinaryNumericalFunctionExpression.cpp +++ b/src/ir/expressions/BinaryNumericalFunctionExpression.cpp @@ -75,14 +75,14 @@ namespace storm { std::string BinaryNumericalFunctionExpression::toString() const { std::stringstream result; - result << this->getLeft()->toString(); + result << "(" << this->getLeft()->toString(); switch (functionType) { case PLUS: result << " + "; break; case MINUS: result << " - "; break; case TIMES: result << " * "; break; case DIVIDE: result << " / "; break; } - result << this->getRight()->toString(); + result << this->getRight()->toString() << ")"; return result.str(); } diff --git a/src/ir/expressions/BinaryRelationExpression.cpp b/src/ir/expressions/BinaryRelationExpression.cpp index 87ef42141..5c08e7f33 100644 --- a/src/ir/expressions/BinaryRelationExpression.cpp +++ b/src/ir/expressions/BinaryRelationExpression.cpp @@ -56,7 +56,7 @@ namespace storm { std::string BinaryRelationExpression::toString() const { std::stringstream result; - result << this->getLeft()->toString(); + result << "(" << this->getLeft()->toString(); switch (relationType) { case EQUAL: result << " = "; break; case NOT_EQUAL: result << " != "; break; @@ -65,7 +65,7 @@ namespace storm { case GREATER: result << " > "; break; case GREATER_OR_EQUAL: result << " >= "; break; } - result << this->getRight()->toString(); + result << this->getRight()->toString() << ")"; return result.str(); } diff --git a/src/ir/expressions/ConstantExpression.h b/src/ir/expressions/ConstantExpression.h index dd0db446c..dc1763c33 100644 --- a/src/ir/expressions/ConstantExpression.h +++ b/src/ir/expressions/ConstantExpression.h @@ -75,9 +75,10 @@ namespace storm { virtual std::string toString() const override { std::stringstream result; - result << this->getConstantName(); if (this->valueStructPointer->defined) { - result << "[" << this->valueStructPointer->value << "]"; + result << this->valueStructPointer->value; + } else { + result << this->getConstantName(); } return result.str(); } diff --git a/src/ir/expressions/UnaryBooleanFunctionExpression.cpp b/src/ir/expressions/UnaryBooleanFunctionExpression.cpp index 79eaa754c..3b0e6ca03 100644 --- a/src/ir/expressions/UnaryBooleanFunctionExpression.cpp +++ b/src/ir/expressions/UnaryBooleanFunctionExpression.cpp @@ -48,10 +48,11 @@ namespace storm { std::string UnaryBooleanFunctionExpression::toString() const { std::stringstream result; + result << "("; switch (functionType) { case NOT: result << "!"; break; } - result << "(" << this->getChild()->toString() << ")"; + result << this->getChild()->toString() << ")"; return result.str(); } diff --git a/src/ir/expressions/UnaryNumericalFunctionExpression.cpp b/src/ir/expressions/UnaryNumericalFunctionExpression.cpp index 228731749..e4ac21e13 100644 --- a/src/ir/expressions/UnaryNumericalFunctionExpression.cpp +++ b/src/ir/expressions/UnaryNumericalFunctionExpression.cpp @@ -62,13 +62,14 @@ namespace storm { } std::string UnaryNumericalFunctionExpression::toString() const { - std::string result = ""; + std::stringstream result; + result << "("; switch (functionType) { - case MINUS: result += "-"; break; + case MINUS: result << "-"; break; } - result += this->getChild()->toString(); + result << this->getChild()->toString() << ")"; - return result; + return result.str(); } } // namespace expressions diff --git a/src/storm.cpp b/src/storm.cpp index d0ddc0654..599db5eda 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -353,19 +353,19 @@ int main(const int argc, const char* argv[]) { } // Enable the following lines to test the SMTMinimalCommandSetGenerator. - if (model->getType() == storm::models::MDP) { - std::shared_ptr> labeledMdp = model->as>(); - storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); - storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); - storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; - std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true); - - std::cout << "Found solution with " << labels.size() << " commands." << std::endl; - for (uint_fast64_t label : labels) { - std::cout << label << ", "; - } - std::cout << std::endl; - } +// if (model->getType() == storm::models::MDP) { +// std::shared_ptr> labeledMdp = model->as>(); +// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); +// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); +// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; +// std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true); +// +// std::cout << "Found solution with " << labels.size() << " commands." << std::endl; +// for (uint_fast64_t label : labels) { +// std::cout << label << ", "; +// } +// std::cout << std::endl; +// } } // Perform clean-up and terminate.