diff --git a/src/counterexamples/CounterexampleOptions.cpp b/src/counterexamples/CounterexampleOptions.cpp index 8fb95adfb..eb9c1c25c 100644 --- a/src/counterexamples/CounterexampleOptions.cpp +++ b/src/counterexamples/CounterexampleOptions.cpp @@ -6,5 +6,6 @@ bool CounterexampleOptionsRegistered = storm::settings::Settings::registerNewMod techniques.push_back("milp"); instance->addOption(storm::settings::OptionBuilder("Counterexample", "mincmd", "", "Computes a counterexample for the given symbolic model in terms of a minimal command set.").addArgument(storm::settings::ArgumentBuilder::createStringArgument("propertyFile", "The file containing the properties for which counterexamples are to be generated.").addValidationFunctionString(storm::settings::ArgumentValidators::existingReadableFileValidator()).build()).addArgument(storm::settings::ArgumentBuilder::createStringArgument("method", "Sets which technique is used to derive the counterexample. Must be either \"milp\" or \"sat\".").setDefaultValueString("sat").addValidationFunctionString(storm::settings::ArgumentValidators::stringInListValidator(techniques)).build()).build()); instance->addOption(storm::settings::OptionBuilder("Counterexample", "stats", "s", "Sets whether to display statistics for certain functionalities.").build()); + instance->addOption(storm::settings::OptionBuilder("Counterexample", "encreach", "", "Sets whether to encode reachability for SAT-based minimal command counterexample generation.").build()); return true; }); diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 35eb65789..c3f2d2db1 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -895,15 +895,22 @@ namespace storm { */ static void assertReachabilityCuts(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { + if (!variableInformation.hasReachabilityVariables) { + throw storm::exceptions::InvalidStateException() << "Impossible to assert reachability cuts without the necessary variables."; + } + // Get some data from the MDP for convenient access. storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); std::vector> const& choiceLabeling = labeledMdp.getChoiceLabeling(); storm::storage::SparseMatrix backwardTransitions = labeledMdp.getBackwardTransitions(); - + + // First, we add the formulas that encode + // (1) if an incoming transition is chosen, an outgoing one is chosen as well (for non-initial states) + // (2) an outgoing transition out of the initial states is taken. + z3::expr initialStateExpression = context.bool_val(false); for (auto relevantState : relevancyInformation.relevantStates) { - // Only consider the state if it's not an initial state. if (!labeledMdp.getInitialStates().get(relevantState)) { - + // Assert the constraints (1). storm::storage::VectorSet relevantPredecessors; for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(relevantState), predecessorIte = backwardTransitions.constColumnIteratorEnd(relevantState); predecessorIt != predecessorIte; ++predecessorIt) { if (relevantState != *predecessorIt && relevancyInformation.relevantStates.get(*predecessorIt)) { @@ -912,16 +919,72 @@ namespace storm { } storm::storage::VectorSet relevantSuccessors; - for (auto const& relevantChoices : relevancyInformation.relevantChoicesForRelevantStates.at(relevantState)) { - for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(relevantChoices); successorIt != transitionMatrix.constColumnIteratorEnd(relevantChoices); ++successorIt) { + for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(relevantState)) { + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(relevantChoice); successorIt != transitionMatrix.constColumnIteratorEnd(relevantChoice); ++successorIt) { + if (relevantState != *successorIt && (relevancyInformation.relevantStates.get(*successorIt) || psiStates.get(*successorIt))) { + relevantSuccessors.insert(*successorIt); + } + } + } + + z3::expr expression = context.bool_val(true); + for (auto predecessor : relevantPredecessors) { + expression = expression && !variableInformation.statePairVariables.at(variableInformation.statePairToIndexMap.at(std::make_pair(predecessor, relevantState))); + } + for (auto successor : relevantSuccessors) { + expression = expression || variableInformation.statePairVariables.at(variableInformation.statePairToIndexMap.at(std::make_pair(relevantState, successor))); + } + + solver.add(expression); + } else { + // Assert the constraints (2). + storm::storage::VectorSet relevantSuccessors; + for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(relevantState)) { + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(relevantChoice); successorIt != transitionMatrix.constColumnIteratorEnd(relevantChoice); ++successorIt) { if (relevantState != *successorIt && (relevancyInformation.relevantStates.get(*successorIt) || psiStates.get(*successorIt))) { relevantSuccessors.insert(*successorIt); } } } - // TODO: build the constraints + for (auto successor : relevantSuccessors) { + initialStateExpression = initialStateExpression || variableInformation.statePairVariables.at(variableInformation.statePairToIndexMap.at(std::make_pair(relevantState, successor))); + } + } + } + solver.add(initialStateExpression); + + // Finally, add constraints that + // (1) if a transition is selected, a valid labeling is selected as well. + // (2) enforce that if a transition from s to s' is selected, the ordering variables become strictly larger. + for (auto const& statePairIndexPair : variableInformation.statePairToIndexMap) { + uint_fast64_t sourceState = statePairIndexPair.first.first; + uint_fast64_t targetState = statePairIndexPair.first.second; + + // Assert constraint for (1). + storm::storage::VectorSet choicesForStatePair; + for (auto const& relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(sourceState)) { + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(relevantChoice); successorIt != transitionMatrix.constColumnIteratorEnd(relevantChoice); ++successorIt) { + if (*successorIt == targetState) { + choicesForStatePair.insert(relevantChoice); + } + } + } + z3::expr labelExpression = !variableInformation.statePairVariables.at(statePairIndexPair.second); + for (auto choice : choicesForStatePair) { + z3::expr choiceExpression = context.bool_val(true); + for (auto element : choiceLabeling.at(choice)) { + if (!relevancyInformation.knownLabels.contains(element)) { + choiceExpression = choiceExpression && variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(element)); + } + } + labelExpression = labelExpression || choiceExpression; } + solver.add(labelExpression); + + // Assert constraint for (2). + z3::expr orderExpression = !variableInformation.statePairVariables.at(statePairIndexPair.second) || variableInformation.stateOrderVariables.at(variableInformation.relevantStatesToOrderVariableIndexMap.at(sourceState)) < variableInformation.stateOrderVariables.at(variableInformation.relevantStatesToOrderVariableIndexMap.at(targetState)); + solver.add(orderExpression); } } @@ -1312,6 +1375,8 @@ namespace storm { static void analyzeZeroProbabilitySolution(z3::context& context, z3::solver& solver, storm::models::Mdp const& subMdp, storm::models::Mdp const& originalMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, storm::storage::VectorSet const& commandSet, VariableInformation& variableInformation, RelevancyInformation const& relevancyInformation) { storm::storage::BitVector reachableStates(subMdp.getNumberOfStates()); + LOG4CPLUS_DEBUG(logger, "Analyzing solution with zero probability."); + // Initialize the stack for the DFS. bool targetStateIsReachable = false; std::vector stack; @@ -1385,7 +1450,7 @@ namespace storm { if (isBorderChoice) { storm::storage::VectorSet currentLabelSet; - for (auto label : choiceLabeling[currentChoice]) { + for (auto label : choiceLabeling.at(currentChoice)) { if (!commandSet.contains(label)) { currentLabelSet.insert(label); } @@ -1398,7 +1463,7 @@ namespace storm { } } - // Given the results of the previous analysis, we construct the implications + // Given the results of the previous analysis, we construct the implications. std::vector formulae; storm::storage::VectorSet unknownReachableLabels; std::set_difference(reachableLabels.begin(), reachableLabels.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(unknownReachableLabels, unknownReachableLabels.end())); @@ -1433,8 +1498,9 @@ namespace storm { * @param variableInformation A structure with information about the variables of the solver. */ static void analyzeInsufficientProbabilitySolution(z3::context& context, z3::solver& solver, storm::models::Mdp const& subMdp, storm::models::Mdp const& originalMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, storm::storage::VectorSet const& commandSet, VariableInformation& variableInformation, RelevancyInformation const& relevancyInformation) { - // ruleOutSolution(context, solver, commandSet, variableInformation); - + + LOG4CPLUS_DEBUG(logger, "Analyzing solution with insufficient probability."); + storm::storage::BitVector reachableStates(subMdp.getNumberOfStates()); // Initialize the stack for the DFS. @@ -1741,7 +1807,7 @@ namespace storm { // Delegate the actual computation work to the function of equal name. auto startTime = std::chrono::high_resolution_clock::now(); - auto labelSet = getMinimalCommandSet(program, constantDefinitionString, labeledMdp, phiStates, psiStates, bound, strictBound, true); + auto labelSet = getMinimalCommandSet(program, constantDefinitionString, labeledMdp, phiStates, psiStates, bound, strictBound, true, storm::settings::Settings::getInstance()->isSet("encreach")); auto endTime = std::chrono::high_resolution_clock::now(); std::cout << std::endl << "Computed minimal label set of size " << labelSet.size() << " in " << std::chrono::duration_cast(endTime - startTime).count() << "ms." << std::endl; diff --git a/src/storage/VectorSet.cpp b/src/storage/VectorSet.cpp index f150c346e..363723053 100644 --- a/src/storage/VectorSet.cpp +++ b/src/storage/VectorSet.cpp @@ -204,6 +204,7 @@ namespace storm { template void VectorSet::insert(VectorSet const& other) { data.insert(data.end(), other.data.begin(), other.data.end()); + dirty = true; } template