Browse Source

Fix for SMT-based minimal command set generator. Minor fixes to string output of expression classes.

Former-commit-id: 316a762d74
main
dehnert 12 years ago
parent
commit
fda9c43e86
  1. 12
      examples/mdp/consensus/coin4.nm
  2. 117
      src/counterexamples/SMTMinimalCommandSetGenerator.h
  3. 10
      src/ir/Module.cpp
  4. 6
      src/ir/Module.h
  5. 8
      src/ir/Program.cpp
  6. 6
      src/ir/Program.h
  7. 4
      src/ir/expressions/BinaryBooleanFunctionExpression.cpp
  8. 4
      src/ir/expressions/BinaryNumericalFunctionExpression.cpp
  9. 4
      src/ir/expressions/BinaryRelationExpression.cpp
  10. 5
      src/ir/expressions/ConstantExpression.h
  11. 3
      src/ir/expressions/UnaryBooleanFunctionExpression.cpp
  12. 9
      src/ir/expressions/UnaryNumericalFunctionExpression.cpp
  13. 26
      src/storm.cpp

12
examples/mdp/consensus/coin4.nm

@ -29,18 +29,18 @@ module process1
// flip coin
[] (pc1=0) -> 0.5 : (coin1'=0) & (pc1'=1) + 0.5 : (coin1'=1) & (pc1'=1);
// write tails -1 (reset coin to add regularity)
[] (pc1=1) & (coin1=0) & (counter>0) -> (counter'=counter-1) & (pc1'=2) & (coin1'=0);
[] (pc1=1) & (coin1=0) & (counter>0) -> 1 : (counter'=counter-1) & (pc1'=2) & (coin1'=0);
// write heads +1 (reset coin to add regularity)
[] (pc1=1) & (coin1=1) & (counter<range) -> (counter'=counter+1) & (pc1'=2) & (coin1'=0);
[] (pc1=1) & (coin1=1) & (counter<range) -> 1 : (counter'=counter+1) & (pc1'=2) & (coin1'=0);
// check
// decide tails
[] (pc1=2) & (counter<=left) -> (pc1'=3) & (coin1'=0);
[] (pc1=2) & (counter<=left) -> 1 : (pc1'=3) & (coin1'=0);
// decide heads
[] (pc1=2) & (counter>=right) -> (pc1'=3) & (coin1'=1);
[] (pc1=2) & (counter>=right) -> 1 : (pc1'=3) & (coin1'=1);
// flip again
[] (pc1=2) & (counter>left) & (counter<right) -> (pc1'=0);
[] (pc1=2) & (counter>left) & (counter<right) -> 1 : (pc1'=0);
// loop (all loop together when done)
[done] (pc1=3) -> (pc1'=3);
[done] (pc1=3) -> 1 : (pc1'=3);
endmodule

117
src/counterexamples/SMTMinimalCommandSetGenerator.h

@ -45,6 +45,7 @@ namespace storm {
struct VariableInformation {
std::vector<z3::expr> labelVariables;
std::vector<z3::expr> originalAuxiliaryVariables;
std::vector<z3::expr> auxiliaryVariables;
std::map<uint_fast64_t, uint_fast64_t> labelToIndexMap;
};
@ -134,6 +135,8 @@ namespace storm {
variableInformation.auxiliaryVariables.push_back(context.bool_const(variableName.str().c_str()));
}
variableInformation.originalAuxiliaryVariables = variableInformation.auxiliaryVariables;
return variableInformation;
}
@ -305,11 +308,55 @@ namespace storm {
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
*/
static void assertSymbolicCuts(storm::ir::Program const& program, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) {
// FIXME:
// find synchronization cuts
// find forward/backward cuts
static void assertSymbolicCuts(storm::ir::Program const& program, storm::models::Mdp<T> const& labeledMdp, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) {
// Initially, we look for synchronisation implications.
// for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) {
// storm::ir::Module const& module = program.getModule(moduleIndex);
//
// for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) {
// storm::ir::Command const& command = module.getCommand(commandIndex);
//
// // If the command is unlabeled, there are no synchronisation cuts to apply.
// if (command.getActionName() == "") continue;
// }
// }
std::map<uint_fast64_t, std::set<uint_fast64_t>> precedingLabels;
// Get some data from the MDP for convenient access.
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
std::vector<std::set<uint_fast64_t>> const& choiceLabeling = labeledMdp.getChoiceLabeling();
storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
for (auto currentState : relevancyInformation.relevantStates) {
for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) {
// Iterate over predecessors and add all choices that target the current state to the preceding
// label set of all labels of all relevant choices of the current state.
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) {
for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) {
bool choiceTargetsCurrentState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) {
if (*successorIt == currentState) {
choiceTargetsCurrentState = true;
}
}
if (choiceTargetsCurrentState) {
for (auto labelToAdd : choiceLabeling[predecessorChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
precedingLabels[labelForWhichToAdd].insert(labelToAdd);
}
}
}
}
}
}
}
// FIXME: The following procedure to assert backward cuts is not correct in the presence of synchronizing
// actions, because it may be the case that several synchronizing commands are necessary to enable another
// command and not just one.
storm::utility::ir::VariableInformation programVariableInformation = storm::utility::ir::createVariableInformation(program);
// Create a context and register all variables of the program with their correct type.
@ -352,13 +399,13 @@ namespace storm {
std::map<uint_fast64_t, std::set<uint_fast64_t>> backwardImplications;
// First check for possible backward cuts.
// Now check for possible backward cuts.
for (uint_fast64_t moduleIndex = 0; moduleIndex < program.getNumberOfModules(); ++moduleIndex) {
storm::ir::Module const& module = program.getModule(moduleIndex);
for (uint_fast64_t commandIndex = 0; commandIndex < module.getNumberOfCommands(); ++commandIndex) {
storm::ir::Command const& command = module.getCommand(commandIndex);
// If the label of the command is not relevant, skip it entirely.
if (relevancyInformation.relevantLabels.find(command.getGlobalIndex()) == relevancyInformation.relevantLabels.end()) continue;
@ -373,8 +420,7 @@ namespace storm {
localSolver.pop();
localSolver.push();
// If it is not and the action is not synchronizing, we can impose backward cuts.
if (checkResult == z3::unsat && command.getActionName() == "") {
if (checkResult == z3::unsat) {
localSolver.add(!expressionAdapter.translateExpression(command.getGuard()));
localSolver.push();
@ -426,7 +472,10 @@ namespace storm {
for (auto const& labelImplicationsPair : backwardImplications) {
formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelImplicationsPair.first)));
for (auto label : labelImplicationsPair.second) {
std::set<uint_fast64_t> actualImplications;
std::set_intersection(labelImplicationsPair.second.begin(), labelImplicationsPair.second.end(), precedingLabels.at(labelImplicationsPair.first).begin(), precedingLabels.at(labelImplicationsPair.first).end(), std::inserter(actualImplications, actualImplications.begin()));
for (auto label : actualImplications) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)));
}
@ -599,10 +648,10 @@ namespace storm {
* @param variableInformation A structure with information about the variables for the labels.
* @return True iff the constraint system was satisfiable.
*/
static bool fuMalikMaxsatStep(z3::context& context, z3::solver& solver, VariableInformation& variableInformation, std::vector<z3::expr>& softConstraints, uint_fast64_t& nextFreeVariableIndex) {
static bool fuMalikMaxsatStep(z3::context& context, z3::solver& solver, std::vector<z3::expr>& auxiliaryVariables, std::vector<z3::expr>& softConstraints, uint_fast64_t& nextFreeVariableIndex) {
z3::expr_vector assumptions(context);
for (auto const& auxVariable : variableInformation.auxiliaryVariables) {
assumptions.push_back(!auxVariable);
for (auto const& auxiliaryVariable : auxiliaryVariables) {
assumptions.push_back(!auxiliaryVariable);
}
// Check whether the assumptions are satisfiable.
@ -640,11 +689,11 @@ namespace storm {
variableName.str("");
variableName << "a" << nextFreeVariableIndex;
++nextFreeVariableIndex;
variableInformation.auxiliaryVariables[softConstraintIndex] = context.bool_const(variableName.str().c_str());
auxiliaryVariables[softConstraintIndex] = context.bool_const(variableName.str().c_str());
softConstraints[softConstraintIndex] = softConstraints[softConstraintIndex] || blockingVariables.back();
solver.add(softConstraints[softConstraintIndex] || variableInformation.auxiliaryVariables[softConstraintIndex]);
solver.add(softConstraints[softConstraintIndex] || auxiliaryVariables[softConstraintIndex]);
}
}
}
@ -695,9 +744,13 @@ namespace storm {
*/
static std::set<uint_fast64_t> findSmallestCommandSet(z3::context& context, z3::solver& solver, VariableInformation& variableInformation, std::vector<z3::expr>& softConstraints, uint_fast64_t& nextFreeVariableIndex) {
solver.push();
// Copy the original auxiliary variables and soft constraints so the procedure can modify the copies.
// std::vector<z3::expr> auxiliaryVariables(variableInformation.auxiliaryVariables);
// std::vector<z3::expr> tmpSoftConstraints(softConstraints);
// solver.push();
for (uint_fast64_t i = 0; ; ++i) {
if (fuMalikMaxsatStep(context, solver, variableInformation, softConstraints, nextFreeVariableIndex)) {
if (fuMalikMaxsatStep(context, solver, variableInformation.auxiliaryVariables, softConstraints, nextFreeVariableIndex)) {
break;
}
}
@ -707,21 +760,21 @@ namespace storm {
z3::model model = solver.get_model();
for (auto const& labelIndexPair : variableInformation.labelToIndexMap) {
z3::expr value = model.eval(variableInformation.labelVariables[labelIndexPair.second]);
z3::expr auxValue = model.eval(variableInformation.originalAuxiliaryVariables[labelIndexPair.second]);
// Check whether the label variable was set or not.
if (eq(value, context.bool_val(true))) {
// Check whether the auxiliary variable was set or not.
if (eq(auxValue, context.bool_val(true))) {
result.insert(labelIndexPair.first);
} else if (eq(value, context.bool_val(false))) {
} else if (eq(auxValue, context.bool_val(false))) {
// Nothing to do in this case.
} else if (eq(value, variableInformation.labelVariables[labelIndexPair.second])) {
// If the variable is a "don't care", then we rather not take it, so nothing to do in this case
// as well.
} else if (eq(auxValue, variableInformation.originalAuxiliaryVariables[labelIndexPair.second])) {
// If the auxiliary variable is a don't care, then we don't take the corresponding command.
} else {
throw storm::exceptions::InvalidStateException() << "Could not retrieve value of boolean variable from illegal value.";
}
}
solver.pop();
// solver.pop();
return result;
}
@ -756,7 +809,7 @@ namespace storm {
// (6) Add constraints that cut off a lot of suboptimal solutions.
assertExplicitCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver);
assertSymbolicCuts(program, variableInformation, relevancyInformation, context, solver);
assertSymbolicCuts(program, labeledMdp, variableInformation, relevancyInformation, context, solver);
// (7) Find the smallest set of commands that satisfies all constraints. If the probability of
// satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned.
@ -780,7 +833,14 @@ namespace storm {
bool done = false;
uint_fast64_t iterations = 0;
do {
LOG4CPLUS_DEBUG(logger, "Computing minimal command set.");
commandSet = findSmallestCommandSet(context, solver, variableInformation, softConstraints, nextFreeVariableIndex);
LOG4CPLUS_DEBUG(logger, "Computed minimal command set of size " << commandSet.size() << ".");
std::cout << "solution: " << std::endl;
for (auto label : commandSet) {
std::cout << label << ", ";
}
std::cout << std::endl;
// Restrict the given MDP to the current set of labels and compute the reachability probability.
storm::models::Mdp<T> subMdp = labeledMdp.restrictChoiceLabels(commandSet);
@ -802,7 +862,12 @@ namespace storm {
}
++iterations;
} while (!done);
LOG4CPLUS_ERROR(logger, "Found minimal label set after " << iterations << " iterations.");
LOG4CPLUS_INFO(logger, "Found minimal label set after " << iterations << " iterations.");
// Verify the results.
storm::ir::Program programCopy(program);
programCopy.restrictCommands(commandSet);
std::cout << programCopy.toString() << std::endl;
// (8) Return the resulting command set after undefining the constants.
storm::utility::ir::undefineUndefinedConstants(program);

10
src/ir/Module.cpp

@ -186,5 +186,15 @@ namespace storm {
}
}
void Module::restrictCommands(std::set<uint_fast64_t> const& indexSet) {
std::vector<storm::ir::Command> newCommands;
for (auto const& command : commands) {
if (indexSet.find(command.getGlobalIndex()) != indexSet.end()) {
newCommands.push_back(std::move(command));
}
}
commands = std::move(newCommands);
}
} // namespace ir
} // namespace storm

6
src/ir/Module.h

@ -181,6 +181,12 @@ namespace storm {
*/
std::set<uint_fast64_t> const& getCommandsByAction(std::string const& action) const;
/*!
* Deletes all commands with indices not in the given set from the module.
*
* @param indexSet The set of indices for which to keep the commands.
*/
void restrictCommands(std::set<uint_fast64_t> const& indexSet);
private:
/*!

8
src/ir/Program.cpp

@ -167,7 +167,7 @@ namespace storm {
}
for (auto const& label : labels) {
result << "label " << label.first << " = " << label.second->toString() <<";" << std::endl;
result << "label \"" << label.first << "\" = " << label.second->toString() <<";" << std::endl;
}
return result.str();
@ -290,5 +290,11 @@ namespace storm {
return this->globalIntegerVariableToIndexMap.at(variableName);
}
void Program::restrictCommands(std::set<uint_fast64_t> const& indexSet) {
for (auto& module : modules) {
module.restrictCommands(indexSet);
}
}
} // namespace ir
} // namepsace storm

6
src/ir/Program.h

@ -261,6 +261,12 @@ namespace storm {
*/
uint_fast64_t getGlobalIndexOfIntegerVariable(std::string const& variableName) const;
/*!
* Deletes all commands with indices not in the given set from the program.
*
* @param indexSet The set of indices for which to keep the commands.
*/
void restrictCommands(std::set<uint_fast64_t> const& indexSet);
private:
// The type of the model.
ModelType modelType;

4
src/ir/expressions/BinaryBooleanFunctionExpression.cpp

@ -52,12 +52,12 @@ namespace storm {
std::string BinaryBooleanFunctionExpression::toString() const {
std::stringstream result;
result << this->getLeft()->toString();
result << "(" << this->getLeft()->toString();
switch (functionType) {
case AND: result << " & "; break;
case OR: result << " | "; break;
}
result << this->getRight()->toString();
result << this->getRight()->toString() << ")";
return result.str();
}

4
src/ir/expressions/BinaryNumericalFunctionExpression.cpp

@ -75,14 +75,14 @@ namespace storm {
std::string BinaryNumericalFunctionExpression::toString() const {
std::stringstream result;
result << this->getLeft()->toString();
result << "(" << this->getLeft()->toString();
switch (functionType) {
case PLUS: result << " + "; break;
case MINUS: result << " - "; break;
case TIMES: result << " * "; break;
case DIVIDE: result << " / "; break;
}
result << this->getRight()->toString();
result << this->getRight()->toString() << ")";
return result.str();
}

4
src/ir/expressions/BinaryRelationExpression.cpp

@ -56,7 +56,7 @@ namespace storm {
std::string BinaryRelationExpression::toString() const {
std::stringstream result;
result << this->getLeft()->toString();
result << "(" << this->getLeft()->toString();
switch (relationType) {
case EQUAL: result << " = "; break;
case NOT_EQUAL: result << " != "; break;
@ -65,7 +65,7 @@ namespace storm {
case GREATER: result << " > "; break;
case GREATER_OR_EQUAL: result << " >= "; break;
}
result << this->getRight()->toString();
result << this->getRight()->toString() << ")";
return result.str();
}

5
src/ir/expressions/ConstantExpression.h

@ -75,9 +75,10 @@ namespace storm {
virtual std::string toString() const override {
std::stringstream result;
result << this->getConstantName();
if (this->valueStructPointer->defined) {
result << "[" << this->valueStructPointer->value << "]";
result << this->valueStructPointer->value;
} else {
result << this->getConstantName();
}
return result.str();
}

3
src/ir/expressions/UnaryBooleanFunctionExpression.cpp

@ -48,10 +48,11 @@ namespace storm {
std::string UnaryBooleanFunctionExpression::toString() const {
std::stringstream result;
result << "(";
switch (functionType) {
case NOT: result << "!"; break;
}
result << "(" << this->getChild()->toString() << ")";
result << this->getChild()->toString() << ")";
return result.str();
}

9
src/ir/expressions/UnaryNumericalFunctionExpression.cpp

@ -62,13 +62,14 @@ namespace storm {
}
std::string UnaryNumericalFunctionExpression::toString() const {
std::string result = "";
std::stringstream result;
result << "(";
switch (functionType) {
case MINUS: result += "-"; break;
case MINUS: result << "-"; break;
}
result += this->getChild()->toString();
result << this->getChild()->toString() << ")";
return result;
return result.str();
}
} // namespace expressions

26
src/storm.cpp

@ -353,19 +353,19 @@ int main(const int argc, const char* argv[]) {
}
// Enable the following lines to test the SMTMinimalCommandSetGenerator.
if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true);
std::cout << "Found solution with " << labels.size() << " commands." << std::endl;
for (uint_fast64_t label : labels) {
std::cout << label << ", ";
}
std::cout << std::endl;
}
// if (model->getType() == storm::models::MDP) {
// std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true);
//
// std::cout << "Found solution with " << labels.size() << " commands." << std::endl;
// for (uint_fast64_t label : labels) {
// std::cout << label << ", ";
// }
// std::cout << std::endl;
// }
}
// Perform clean-up and terminate.

Loading…
Cancel
Save