Browse Source

Further work on MaxSAT-based minimal counterexample generator.

Former-commit-id: 847a6e202c
tempestpy_adaptions
dehnert 11 years ago
parent
commit
d6c59e2ca3
  1. 270
      src/counterexamples/SMTMinimalCommandSetGenerator.h
  2. 26
      src/storm.cpp

270
src/counterexamples/SMTMinimalCommandSetGenerator.h

@ -8,6 +8,8 @@
#ifndef STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_ #ifndef STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_
#define STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_ #define STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_
#include <queue>
// To detect whether the usage of Z3 is possible, this include is neccessary. // To detect whether the usage of Z3 is possible, this include is neccessary.
#include "storm-config.h" #include "storm-config.h"
@ -153,110 +155,231 @@ namespace storm {
for (uint_fast64_t index = 0; index < variableInformation.labelVariables.size(); ++index) { for (uint_fast64_t index = 0; index < variableInformation.labelVariables.size(); ++index) {
solver.add(!variableInformation.labelVariables[index] || variableInformation.auxiliaryVariables[index]); solver.add(!variableInformation.labelVariables[index] || variableInformation.auxiliaryVariables[index]);
} }
std::vector<std::set<uint_fast64_t>> const& choiceLabeling = labeledMdp.getChoiceLabeling();
}
/*!
* Asserts cuts that rule out a lot of suboptimal solutions.
*
* @param labeledMdp The labeled MDP for which to compute the cuts.
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
*/
static void assertCuts(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) {
// Walk through the MDP and:
// identify labels enabled in initial states
// identify labels that can directly precede a given action
// identify labels that directly reach a target state
// identify labels that can directly follow a given action
// TODO: identify which labels need to synchronize
std::set<uint_fast64_t> initialLabels;
std::map<uint_fast64_t, std::set<uint_fast64_t>> precedingLabels;
std::set<uint_fast64_t> targetLabels;
std::map<uint_fast64_t, std::set<uint_fast64_t>> followingLabels;
std::map<uint_fast64_t, std::set<uint_fast64_t>> synchronizingLabels;
// Get some data from the MDP for convenient access.
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix(); storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
// Assert that at least one of the labels of one of the relevant initial states is taken.
std::vector<z3::expr> expressionVector;
bool firstAssignment = true;
for (auto state : labeledMdp.getInitialStates()) {
if (relevancyInformation.relevantStates.get(state)) {
for (auto const& choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
for (auto const& label : choiceLabeling[choice]) {
z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
if (firstAssignment) {
expressionVector.push_back(labelExpression);
firstAssignment = false;
} else {
expressionVector.back() = expressionVector.back() && labelExpression;
std::vector<uint_fast64_t> const& nondeterministicChoiceIndices = labeledMdp.getNondeterministicChoiceIndices();
storm::storage::BitVector const& initialStates = labeledMdp.getInitialStates();
std::vector<std::set<uint_fast64_t>> const& choiceLabeling = labeledMdp.getChoiceLabeling();
storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
for (auto currentState : relevancyInformation.relevantStates) {
for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) {
// If the state is initial, we need to add all the choice labels to the initial label set.
if (initialStates.get(currentState)) {
for (auto label : choiceLabeling[currentChoice]) {
initialLabels.insert(label);
}
}
// Iterate over successors and add relevant choices of relevant successors to the following label set.
bool canReachTargetState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(currentChoice), successorIte = transitionMatrix.constColumnIteratorEnd(currentChoice); successorIt != successorIte; ++successorIt) {
if (relevancyInformation.relevantStates.get(*successorIt)) {
for (auto relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*successorIt)) {
for (auto labelToAdd : choiceLabeling[relevantChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
followingLabels[labelForWhichToAdd].insert(labelToAdd);
}
}
}
} else if (psiStates.get(*successorIt)) {
canReachTargetState = true;
}
}
// If the choice can reach a target state directly, we add all the labels to the target label set.
if (canReachTargetState) {
for (auto label : choiceLabeling[currentChoice]) {
targetLabels.insert(label);
}
}
// Iterate over predecessors and add all choices that target the current state to the preceding
// label set of all labels of all relevant choices of the current state.
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) {
for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) {
bool choiceTargetsCurrentState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) {
if (*successorIt == currentState) {
choiceTargetsCurrentState = true;
}
}
if (choiceTargetsCurrentState) {
for (auto labelToAdd : choiceLabeling[predecessorChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
precedingLabels[labelForWhichToAdd].insert(labelToAdd);
}
}
} }
} }
} }
} }
} }
assertDisjunction(context, solver, expressionVector);
// Assert that at least one of the labels that are selected can reach a target state in one step.
storm::storage::SparseMatrix<bool> backwardTransitions = labeledMdp.getBackwardTransitions();
// Compute the set of predecessors of target states.
std::unordered_set<uint_fast64_t> predecessors;
std::vector<z3::expr> formulae;
// Start by asserting that we take at least one initial label.
for (auto label : initialLabels) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)));
}
assertDisjunction(context, solver, formulae);
formulae.clear();
// Also assert that we take at least one target label.
for (auto label : targetLabels) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)));
}
assertDisjunction(context, solver, formulae);
// Now assert that for each non-target label, we take a following label.
for (auto const& labelSetPair : followingLabels) {
formulae.clear();
if (targetLabels.find(labelSetPair.first) == targetLabels.end()) {
formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSetPair.first)));
for (auto followingLabel : labelSetPair.second) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(followingLabel)));
}
}
if (formulae.size() > 0) {
assertDisjunction(context, solver, formulae);
}
}
// Consequently, assert that for each non-initial label, we take preceding command.
for (auto const& labelSetPair : precedingLabels) {
formulae.clear();
if (initialLabels.find(labelSetPair.first) == initialLabels.end()) {
formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSetPair.first)));
for (auto followingLabel : labelSetPair.second) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(followingLabel)));
}
}
if (formulae.size() > 0) {
assertDisjunction(context, solver, formulae);
}
}
// Now we compute the set of labels that is present on all paths from the initial to the target states.
std::vector<std::set<uint_fast64_t>> analysisInformation(labeledMdp.getNumberOfStates(), relevancyInformation.relevantLabels);
std::queue<std::pair<uint_fast64_t, uint_fast64_t>> worklist;
// Initially, put all predecessors of target states in the worklist and empty the analysis information
// them.
for (auto state : psiStates) { for (auto state : psiStates) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) {
if (state != *predecessorIt) {
predecessors.insert(*predecessorIt);
analysisInformation[state] = std::set<uint_fast64_t>();
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state), predecessorIte = backwardTransitions.constColumnIteratorEnd(state); predecessorIt != predecessorIte; ++predecessorIt) {
if (relevancyInformation.relevantStates.get(*predecessorIt)) {
worklist.push(std::make_pair(*predecessorIt, state));
} }
} }
} }
expressionVector.clear();
firstAssignment = true;
for (auto state : predecessors) {
for (auto choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(choice); successorIt != transitionMatrix.constColumnIteratorEnd(choice); ++successorIt) {
if (psiStates.get(*successorIt)) {
for (auto const& label : choiceLabeling[choice]) {
z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
if (firstAssignment) {
expressionVector.push_back(labelExpression);
firstAssignment = false;
} else {
expressionVector.back() = expressionVector.back() && labelExpression;
}
}
// Iterate as long as the worklist is non-empty.
while (!worklist.empty()) {
std::pair<uint_fast64_t, uint_fast64_t> const& currentStateTargetStatePair = worklist.front();
uint_fast64_t currentState = currentStateTargetStatePair.first;
uint_fast64_t targetState = currentStateTargetStatePair.second;
// Iterate over the successor states for all choices and compute new analysis information.
std::set<uint_fast64_t> intersection;
for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(currentChoice), successorIte = transitionMatrix.constColumnIteratorEnd(currentChoice); successorIt != successorIte; ++successorIt) {
// If we can reach the target state with this choice, we need to intersect the current
// analysis information with the union of the new analysis information of the target state
// and the choice labels.
if (*successorIt == targetState) {
std::set_intersection(analysisInformation[currentState].begin(), analysisInformation[currentState].end(), analysisInformation[targetState].begin(), analysisInformation[targetState].end(), std::inserter(intersection, intersection.begin()));
std::set<uint_fast64_t> choiceLabelIntersection;
std::set_intersection(analysisInformation[currentState].begin(), analysisInformation[currentState].end(), choiceLabeling[currentChoice].begin(), choiceLabeling[currentChoice].end(), std::inserter(intersection, intersection.begin()));
} }
} }
} }
// If the analysis information changed, we need to update it and put all the predecessors of this
// state in the worklist.
if (analysisInformation[currentState] != intersection) {
analysisInformation[currentState] = std::move(intersection);
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) {
worklist.push(std::make_pair(*predecessorIt, currentState));
}
}
worklist.pop();
} }
assertDisjunction(context, solver, expressionVector);
}
/*!
* Asserts cuts that rule out a lot of suboptimal solutions.
*
* @param program The program for which to derive the cuts.
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
*/
static void assertCuts(storm::ir::Program const& program, z3::context& context, z3::solver& solver) {
// TODO.
// Now build the intersection over the analysis information of all initial states.
std::set<uint_fast64_t> knownLabels(relevancyInformation.relevantLabels);
std::set<uint_fast64_t> tempIntersection;
for (auto initialState : labeledMdp.getInitialStates()) {
std::set_intersection(knownLabels.begin(), knownLabels.end(), analysisInformation[initialState].begin(), analysisInformation[initialState].end(), std::inserter(tempIntersection, tempIntersection.begin()));
std::swap(knownLabels, tempIntersection);
tempIntersection.clear();
}
formulae.clear();
for (auto label : knownLabels) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)));
}
assertConjunction(context, solver, formulae);
} }
/*! /*!
* Asserts that the disjunction of the given formulae holds.
* Asserts that the disjunction of the given formulae holds. If the content of the disjunction is empty,
* this corresponds to asserting false.
* *
* @param context The Z3 context in which to build the expressions. * @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation. * @param solver The solver to use for the satisfiability evaluation.
* @param formulaVector A vector of expressions that shall form the disjunction. * @param formulaVector A vector of expressions that shall form the disjunction.
*/ */
static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) { static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) {
z3::expr disjunction(context);
for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) {
if (i == 0) {
disjunction = formulaVector[i];
} else {
disjunction = disjunction || formulaVector[i];
}
z3::expr disjunction = context.bool_val(false);
for (auto expr : formulaVector) {
disjunction = disjunction || expr;
} }
solver.add(disjunction); solver.add(disjunction);
} }
/*! /*!
* Asserts that the conjunction of the given formulae holds.
* Asserts that the conjunction of the given formulae holds. If the content of the conjunction is empty,
* this corresponds to asserting true.
* *
* @param context The Z3 context in which to build the expressions. * @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation. * @param solver The solver to use for the satisfiability evaluation.
* @param formulaVector A vector of expressions that shall form the conjunction. * @param formulaVector A vector of expressions that shall form the conjunction.
*/ */
static void assertConjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) { static void assertConjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) {
z3::expr conjunction(context);
for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) {
if (i == 0) {
conjunction = formulaVector[i];
} else {
conjunction = conjunction && formulaVector[i];
}
z3::expr conjunction = context.bool_val(true);
for (auto expr : formulaVector) {
conjunction = conjunction && expr;
} }
solver.add(conjunction); solver.add(conjunction);
} }
@ -383,13 +506,6 @@ namespace storm {
static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& literals) { static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& literals) {
std::vector<z3::expr> counter = createCounterCircuit(context, literals); std::vector<z3::expr> counter = createCounterCircuit(context, literals);
assertLessOrEqualOne(context, solver, counter); assertLessOrEqualOne(context, solver, counter);
// for (uint_fast64_t i = 0; i < blockingVariables.size(); ++i) {
// for (uint_fast64_t j = i + 1; j < blockingVariables.size(); ++j) {
// solver.add(!blockingVariables[i] || !blockingVariables[j]);
// }
// }
} }
/*! /*!
@ -545,7 +661,7 @@ namespace storm {
assertInitialConstraints(program, labeledMdp, psiStates, context, solver, variableInformation, relevancyInformation); assertInitialConstraints(program, labeledMdp, psiStates, context, solver, variableInformation, relevancyInformation);
// (6) Add constraints that cut off a lot of suboptimal solutions. // (6) Add constraints that cut off a lot of suboptimal solutions.
assertCuts(program, context, solver);
assertCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver);
// (7) Find the smallest set of commands that satisfies all constraints. If the probability of // (7) Find the smallest set of commands that satisfies all constraints. If the probability of
// satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned. // satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned.

26
src/storm.cpp

@ -338,19 +338,19 @@ int main(const int argc, const char* argv[]) {
model->printModelInformationToStream(std::cout); model->printModelInformationToStream(std::cout);
// Enable the following lines to test the MinimalLabelSetGenerator. // Enable the following lines to test the MinimalLabelSetGenerator.
if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
std::unordered_set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
std::cout << "Found solution with " << labels.size() << " commands." << std::endl;
for (uint_fast64_t label : labels) {
std::cout << label << ", ";
}
std::cout << std::endl;
}
// if (model->getType() == storm::models::MDP) {
// std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
// std::unordered_set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
//
// std::cout << "Found solution with " << labels.size() << " commands." << std::endl;
// for (uint_fast64_t label : labels) {
// std::cout << label << ", ";
// }
// std::cout << std::endl;
// }
// Enable the following lines to test the SMTMinimalCommandSetGenerator. // Enable the following lines to test the SMTMinimalCommandSetGenerator.
if (model->getType() == storm::models::MDP) { if (model->getType() == storm::models::MDP) {

Loading…
Cancel
Save