Browse Source

Made all examples from the MILP-paper work. Most of them are really slow though.

Former-commit-id: 1f3f5afb9a
tempestpy_adaptions
dehnert 11 years ago
parent
commit
c31dbc85a7
  1. 14
      src/adapters/Z3ExpressionAdapter.h
  2. 102
      src/counterexamples/SMTMinimalCommandSetGenerator.h
  3. 2
      src/ir/Command.cpp
  4. 8
      src/parser/prismparser/ConstIntegerExpressionGrammar.cpp
  5. 2
      src/parser/prismparser/ConstIntegerExpressionGrammar.h
  6. 2
      src/parser/prismparser/PrismGrammar.cpp
  7. 26
      src/storm.cpp

14
src/adapters/Z3ExpressionAdapter.h

@ -58,7 +58,7 @@ namespace storm {
stack.push(leftResult || rightResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << expression->getFunctionType() << "'.";
<< "Unknown boolean binary operator: '" << expression->getFunctionType() << "' in expression " << expression->toString() << ".";
}
}
@ -85,8 +85,14 @@ namespace storm {
case storm::ir::expressions::BinaryNumericalFunctionExpression::DIVIDE:
stack.push(leftResult / rightResult);
break;
case storm::ir::expressions::BinaryNumericalFunctionExpression::MIN:
stack.push(ite(leftResult <= rightResult, leftResult, rightResult));
break;
case storm::ir::expressions::BinaryNumericalFunctionExpression::MAX:
stack.push(ite(leftResult >= rightResult, leftResult, rightResult));
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << expression->getFunctionType() << "'.";
<< "Unknown numerical binary operator: '" << expression->getFunctionType() << "' in expression " << expression->toString() << ".";
}
}
@ -119,7 +125,7 @@ namespace storm {
stack.push(leftResult >= rightResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean binary operator: '" << expression->getRelationType() << "'.";
<< "Unknown boolean binary operator: '" << expression->getRelationType() << "' in expression " << expression->toString() << ".";
}
}
@ -177,7 +183,7 @@ namespace storm {
stack.push(!childResult);
break;
default: throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: "
<< "Unknown boolean unary operator: '" << expression->getFunctionType() << "'.";
<< "Unknown boolean binary operator: '" << expression->getFunctionType() << "' in expression " << expression->toString() << ".";
}
}

102
src/counterexamples/SMTMinimalCommandSetGenerator.h

@ -49,7 +49,7 @@ namespace storm {
std::set<uint_fast64_t> knownLabels;
// A list of relevant choices for each relevant state.
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
std::map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
};
struct VariableInformation {
@ -253,10 +253,12 @@ namespace storm {
targetLabels.insert(label);
}
}
// Iterate over predecessors and add all choices that target the current state to the preceding
// label set of all labels of all relevant choices of the current state.
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) {
}
// Iterate over predecessors and add all choices that target the current state to the preceding
// label set of all labels of all relevant choices of the current state.
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) {
if (relevancyInformation.relevantStates.get(*predecessorIt)) {
for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) {
bool choiceTargetsCurrentState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) {
@ -266,9 +268,11 @@ namespace storm {
}
if (choiceTargetsCurrentState) {
for (auto labelToAdd : choiceLabeling[predecessorChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
precedingLabels[labelForWhichToAdd].insert(labelToAdd);
for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) {
for (auto labelToAdd : choiceLabeling[predecessorChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
precedingLabels[labelForWhichToAdd].insert(labelToAdd);
}
}
}
}
@ -277,8 +281,11 @@ namespace storm {
}
}
LOG4CPLUS_DEBUG(logger, "Successfully gathered data for explicit cuts.");
std::vector<z3::expr> formulae;
LOG4CPLUS_DEBUG(logger, "Asserting initial label is taken.");
// Start by asserting that we take at least one initial label. We may do so only if there is no initial
// label that is already known. Otherwise this condition would be too strong.
std::set<uint_fast64_t> intersection;
@ -294,6 +301,7 @@ namespace storm {
intersection.clear();
}
LOG4CPLUS_DEBUG(logger, "Asserting target label is taken.");
// Likewise, if no target label is known, we may assert that there is at least one.
std::set_intersection(targetLabels.begin(), targetLabels.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(intersection, intersection.begin()));
if (intersection.empty()) {
@ -306,6 +314,7 @@ namespace storm {
intersection.clear();
}
LOG4CPLUS_DEBUG(logger, "Asserting taken labels are followed by another label if they are not a target label.");
// Now assert that for each non-target label, we take a following label.
for (auto const& labelSetPair : followingLabels) {
formulae.clear();
@ -330,21 +339,38 @@ namespace storm {
}
}
LOG4CPLUS_DEBUG(logger, "Asserting synchronization cuts.");
// Finally, assert that if we take one of the synchronizing labels, we also take one of the combinations
// the label appears in.
for (auto const& labelSynchronizingSetsPair : synchronizingLabels) {
formulae.clear();
formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSynchronizingSetsPair.first)));
if (relevancyInformation.knownLabels.find(labelSynchronizingSetsPair.first) == relevancyInformation.knownLabels.end()) {
formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSynchronizingSetsPair.first)));
}
// We need to be careful, because there may be one synchronisation set out of which all labels are
// known, which means we must not assert anything.
bool allImplicantsKnownForOneSet = false;
for (auto const& synchronizingSet : labelSynchronizingSetsPair.second) {
z3::expr currentCombination = context.bool_val(true);
bool allImplicantsKnownForCurrentSet = true;
for (auto label : synchronizingSet) {
currentCombination = currentCombination && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
if (relevancyInformation.knownLabels.find(label) == relevancyInformation.knownLabels.end()) {
currentCombination = currentCombination && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label));
}
}
formulae.push_back(currentCombination);
// If all implicants of the current set are known, we do not need to further build the constraint.
if (allImplicantsKnownForCurrentSet) {
allImplicantsKnownForOneSet = true;
break;
}
}
assertDisjunction(context, solver, formulae);
if (!allImplicantsKnownForOneSet) {
assertDisjunction(context, solver, formulae);
}
}
}
@ -371,23 +397,25 @@ namespace storm {
// Iterate over predecessors and add all choices that target the current state to the preceding
// label set of all labels of all relevant choices of the current state.
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) {
for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) {
bool choiceTargetsCurrentState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) {
if (*successorIt == currentState) {
choiceTargetsCurrentState = true;
}
}
if (choiceTargetsCurrentState) {
if (choiceLabeling.at(predecessorChoice).size() > 1) {
for (auto label : choiceLabeling.at(currentChoice)) {
hasSynchronizingPredecessor.insert(label);
if (relevancyInformation.relevantStates.get(*predecessorIt)) {
for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) {
bool choiceTargetsCurrentState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) {
if (*successorIt == currentState) {
choiceTargetsCurrentState = true;
}
}
for (auto labelToAdd : choiceLabeling[predecessorChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
precedingLabels[labelForWhichToAdd].insert(labelToAdd);
if (choiceTargetsCurrentState) {
if (choiceLabeling.at(predecessorChoice).size() > 1) {
for (auto label : choiceLabeling.at(currentChoice)) {
hasSynchronizingPredecessor.insert(label);
}
}
for (auto labelToAdd : choiceLabeling[predecessorChoice]) {
for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) {
precedingLabels[labelForWhichToAdd].insert(labelToAdd);
}
}
}
}
@ -396,9 +424,6 @@ namespace storm {
}
}
// FIXME: The following procedure to assert backward cuts is not correct in the presence of synchronizing
// actions, because it may be the case that several synchronizing commands are necessary to enable another
// command and not just one.
storm::utility::ir::VariableInformation programVariableInformation = storm::utility::ir::createVariableInformation(program);
// Create a context and register all variables of the program with their correct type.
@ -451,6 +476,10 @@ namespace storm {
// If the label of the command is not relevant, skip it entirely.
if (relevancyInformation.relevantLabels.find(command.getGlobalIndex()) == relevancyInformation.relevantLabels.end()) continue;
// If the label has a synchronizing predecessor, we also need to skip it, because the following
// procedure can only consider predecessors in isolation.
if(hasSynchronizingPredecessor.find(command.getGlobalIndex()) != hasSynchronizingPredecessor.end()) continue;
// Save the state of the solver so we can easily backtrack.
localSolver.push();
@ -926,6 +955,7 @@ namespace storm {
static std::set<uint_fast64_t> getMinimalCommandSet(storm::ir::Program program, std::string const& constantDefinitionString, storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool checkThresholdFeasible = false) {
#ifdef STORM_HAVE_Z3
auto startTime = std::chrono::high_resolution_clock::now();
auto endTime = std::chrono::high_resolution_clock::now();
storm::utility::ir::defineUndefinedConstants(program, constantDefinitionString);
@ -965,6 +995,7 @@ namespace storm {
variableInformation.auxiliaryVariables.push_back(assertLessOrEqualKRelaxed(context, solver, variableInformation.adderVariables, 0));
// (7) Add constraints that cut off a lot of suboptimal solutions.
LOG4CPLUS_DEBUG(logger, "Asserting cuts.");
assertExplicitCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver);
LOG4CPLUS_DEBUG(logger, "Asserted explicit cuts.");
assertSymbolicCuts(program, labeledMdp, variableInformation, relevancyInformation, context, solver);
@ -981,10 +1012,11 @@ namespace storm {
uint_fast64_t iterations = 0;
uint_fast64_t currentBound = 0;
maximalReachabilityProbability = 0;
auto iterationTimer = std::chrono::high_resolution_clock::now();
do {
LOG4CPLUS_DEBUG(logger, "Computing minimal command set.");
commandSet = findSmallestCommandSet(context, solver, variableInformation, currentBound);
LOG4CPLUS_DEBUG(logger, "Computed minimal command set of size " << commandSet.size() << ".");
LOG4CPLUS_DEBUG(logger, "Computed minimal command set of size " << (commandSet.size() + relevancyInformation.knownLabels.size()) << ".");
// Restrict the given MDP to the current set of labels and compute the reachability probability.
commandSet.insert(relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end());
@ -1006,14 +1038,20 @@ namespace storm {
done = true;
}
++iterations;
endTime = std::chrono::high_resolution_clock::now();
if (std::chrono::duration_cast<std::chrono::seconds>(endTime - iterationTimer).count() > 5) {
std::cout << "Performed " << iterations << " iterations in " << std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count() << "s. Current command set size is " << commandSet.size() << "." << std::endl;
iterationTimer = std::chrono::high_resolution_clock::now();
}
} while (!done);
LOG4CPLUS_INFO(logger, "Found minimal label set after " << iterations << " iterations.");
// (9) Return the resulting command set after undefining the constants.
storm::utility::ir::undefineUndefinedConstants(program);
auto endTime = std::chrono::high_resolution_clock::now();
LOG4CPLUS_WARN(logger, "Computed minimal command set of size " << commandSet.size() << " in " << std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count() << "ms.");
endTime = std::chrono::high_resolution_clock::now();
std::cout << "Computed minimal command set of size " << commandSet.size() << " in " << std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count() << "ms (" << iterations << " iterations)." << std::endl;
return commandSet;

2
src/ir/Command.cpp

@ -27,7 +27,7 @@ namespace storm {
: actionName(oldCommand.getActionName()), guardExpression(oldCommand.guardExpression->clone(renaming, variableState)), globalIndex(newGlobalIndex) {
auto renamingPair = renaming.find(this->actionName);
if (renamingPair != renaming.end()) {
this->actionName = renamingPair->first;
this->actionName = renamingPair->second;
}
this->updates.reserve(oldCommand.getNumberOfUpdates());
for (Update const& update : oldCommand.updates) {

8
src/parser/prismparser/ConstIntegerExpressionGrammar.cpp

@ -19,9 +19,15 @@ namespace storm {
[qi::_val = phoenix::bind(&BaseGrammar::createIntMult, this, qi::_val, qi::_1)];
constantIntegerMultExpression.name("constant integer expression");
constantAtomicIntegerExpression %= (qi::lit("(") >> constantIntegerExpression >> qi::lit(")") | integerConstantExpression);
constantAtomicIntegerExpression %= (constantIntegerMinMaxExpression | constantIntegerFloorCeilExpression | qi::lit("(") >> constantIntegerExpression >> qi::lit(")") | integerConstantExpression);
constantAtomicIntegerExpression.name("constant integer expression");
constantIntegerMinMaxExpression = ((qi::lit("min")[qi::_a = true] | qi::lit("max")[qi::_a = false]) >> qi::lit("(") >> constantIntegerExpression >> qi::lit(",") >> constantIntegerExpression >> qi::lit(")"))[qi::_val = phoenix::bind(&BaseGrammar::createIntMinMax, this, qi::_a, qi::_1, qi::_2)];
constantIntegerMinMaxExpression.name("integer min/max expression");
constantIntegerFloorCeilExpression = ((qi::lit("floor")[qi::_a = true] | qi::lit("ceil")[qi::_a = false]) >> qi::lit("(") >> constantIntegerExpression >> qi::lit(")"))[qi::_val = phoenix::bind(&BaseGrammar::createIntFloorCeil, this, qi::_a, qi::_1)];
constantIntegerFloorCeilExpression.name("integer floor/ceil expression");
integerConstantExpression %= (this->state->integerConstants_ | integerLiteralExpression);
integerConstantExpression.name("integer constant or literal");

2
src/parser/prismparser/ConstIntegerExpressionGrammar.h

@ -30,6 +30,8 @@ namespace storm {
qi::rule<Iterator, std::shared_ptr<BaseExpression>(), Skipper> constantAtomicIntegerExpression;
qi::rule<Iterator, std::shared_ptr<BaseExpression>(), Skipper> integerConstantExpression;
qi::rule<Iterator, std::shared_ptr<BaseExpression>(), Skipper> integerLiteralExpression;
qi::rule<Iterator, std::shared_ptr<BaseExpression>(), qi::locals<bool>, Skipper> constantIntegerMinMaxExpression;
qi::rule<Iterator, std::shared_ptr<BaseExpression>(), qi::locals<bool>, Skipper> constantIntegerFloorCeilExpression;
};

2
src/parser/prismparser/PrismGrammar.cpp

@ -245,7 +245,7 @@ namespace storm {
// This block defines all entities that are needed for parsing a program.
modelTypeDefinition = modelType_;
modelTypeDefinition.name("model type");
start = (
start = (qi::eps >
modelTypeDefinition >
constantDefinitionList(qi::_a, qi::_b, qi::_c) >
globalVariableDefinitionList(qi::_d) >

26
src/storm.cpp

@ -355,14 +355,26 @@ int main(const int argc, const char* argv[]) {
// Enable the following lines to test the SMTMinimalCommandSetGenerator.
if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true);
// storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff");
// storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered");
// std::set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, ~collisionStates, deliveredStates, 0.5, true, true);
// Build stuff for coin example.
// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true);
// Build stuff for csma example.
storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff");
storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered");
std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, ~collisionStates, deliveredStates, 0.5, true);
// Build stuff for firewire example.
// storm::storage::BitVector const& electedStates = labeledMdp->getLabeledStates("elected");
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.01, true);
// Build stuff for wlan example.
// storm::storage::BitVector const& oneCollisionStates = labeledMdp->getLabeledStates("oneCollision");
// storm::storage::BitVector const& twoCollisionStates = labeledMdp->getLabeledStates("twoCollisions");
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), twoCollisionStates, 0.1, true);
}
}

Loading…
Cancel
Save