diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 59a4279df..bbd4e623b 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -8,6 +8,8 @@ #ifndef STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_ #define STORM_COUNTEREXAMPLES_SMTMINIMALCOMMANDSETGENERATOR_MDP_H_ +#include + // To detect whether the usage of Z3 is possible, this include is neccessary. #include "storm-config.h" @@ -153,110 +155,231 @@ namespace storm { for (uint_fast64_t index = 0; index < variableInformation.labelVariables.size(); ++index) { solver.add(!variableInformation.labelVariables[index] || variableInformation.auxiliaryVariables[index]); } - - std::vector> const& choiceLabeling = labeledMdp.getChoiceLabeling(); + } + + /*! + * Asserts cuts that rule out a lot of suboptimal solutions. + * + * @param labeledMdp The labeled MDP for which to compute the cuts. + * @param context The Z3 context in which to build the expressions. + * @param solver The solver to use for the satisfiability evaluation. + */ + static void assertCuts(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, VariableInformation const& variableInformation, RelevancyInformation const& relevancyInformation, z3::context& context, z3::solver& solver) { + // Walk through the MDP and: + // identify labels enabled in initial states + // identify labels that can directly precede a given action + // identify labels that directly reach a target state + // identify labels that can directly follow a given action + // TODO: identify which labels need to synchronize + + std::set initialLabels; + std::map> precedingLabels; + std::set targetLabels; + std::map> followingLabels; + std::map> synchronizingLabels; + + // Get some data from the MDP for convenient access. storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); - - // Assert that at least one of the labels of one of the relevant initial states is taken. - std::vector expressionVector; - bool firstAssignment = true; - for (auto state : labeledMdp.getInitialStates()) { - if (relevancyInformation.relevantStates.get(state)) { - for (auto const& choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) { - for (auto const& label : choiceLabeling[choice]) { - z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); - if (firstAssignment) { - expressionVector.push_back(labelExpression); - firstAssignment = false; - } else { - expressionVector.back() = expressionVector.back() && labelExpression; + std::vector const& nondeterministicChoiceIndices = labeledMdp.getNondeterministicChoiceIndices(); + storm::storage::BitVector const& initialStates = labeledMdp.getInitialStates(); + std::vector> const& choiceLabeling = labeledMdp.getChoiceLabeling(); + storm::storage::SparseMatrix backwardTransitions = labeledMdp.getBackwardTransitions(); + + for (auto currentState : relevancyInformation.relevantStates) { + for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) { + + // If the state is initial, we need to add all the choice labels to the initial label set. + if (initialStates.get(currentState)) { + for (auto label : choiceLabeling[currentChoice]) { + initialLabels.insert(label); + } + } + + // Iterate over successors and add relevant choices of relevant successors to the following label set. + bool canReachTargetState = false; + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(currentChoice), successorIte = transitionMatrix.constColumnIteratorEnd(currentChoice); successorIt != successorIte; ++successorIt) { + if (relevancyInformation.relevantStates.get(*successorIt)) { + for (auto relevantChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*successorIt)) { + for (auto labelToAdd : choiceLabeling[relevantChoice]) { + for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) { + followingLabels[labelForWhichToAdd].insert(labelToAdd); + } + } + } + } else if (psiStates.get(*successorIt)) { + canReachTargetState = true; + } + } + + // If the choice can reach a target state directly, we add all the labels to the target label set. + if (canReachTargetState) { + for (auto label : choiceLabeling[currentChoice]) { + targetLabels.insert(label); + } + } + + // Iterate over predecessors and add all choices that target the current state to the preceding + // label set of all labels of all relevant choices of the current state. + for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) { + for (auto predecessorChoice : relevancyInformation.relevantChoicesForRelevantStates.at(*predecessorIt)) { + bool choiceTargetsCurrentState = false; + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(predecessorChoice), successorIte = transitionMatrix.constColumnIteratorEnd(predecessorChoice); successorIt != successorIte; ++successorIt) { + if (*successorIt == currentState) { + choiceTargetsCurrentState = true; + } + } + + if (choiceTargetsCurrentState) { + for (auto labelToAdd : choiceLabeling[predecessorChoice]) { + for (auto labelForWhichToAdd : choiceLabeling[currentChoice]) { + precedingLabels[labelForWhichToAdd].insert(labelToAdd); + } + } } } } } } - assertDisjunction(context, solver, expressionVector); - // Assert that at least one of the labels that are selected can reach a target state in one step. - storm::storage::SparseMatrix backwardTransitions = labeledMdp.getBackwardTransitions(); - - // Compute the set of predecessors of target states. - std::unordered_set predecessors; + std::vector formulae; + + // Start by asserting that we take at least one initial label. + for (auto label : initialLabels) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label))); + } + assertDisjunction(context, solver, formulae); + + formulae.clear(); + + // Also assert that we take at least one target label. + for (auto label : targetLabels) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label))); + } + assertDisjunction(context, solver, formulae); + + // Now assert that for each non-target label, we take a following label. + for (auto const& labelSetPair : followingLabels) { + formulae.clear(); + if (targetLabels.find(labelSetPair.first) == targetLabels.end()) { + formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSetPair.first))); + for (auto followingLabel : labelSetPair.second) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(followingLabel))); + } + } + if (formulae.size() > 0) { + assertDisjunction(context, solver, formulae); + } + } + + // Consequently, assert that for each non-initial label, we take preceding command. + for (auto const& labelSetPair : precedingLabels) { + formulae.clear(); + if (initialLabels.find(labelSetPair.first) == initialLabels.end()) { + formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(labelSetPair.first))); + for (auto followingLabel : labelSetPair.second) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(followingLabel))); + } + } + if (formulae.size() > 0) { + assertDisjunction(context, solver, formulae); + } + } + + // Now we compute the set of labels that is present on all paths from the initial to the target states. + std::vector> analysisInformation(labeledMdp.getNumberOfStates(), relevancyInformation.relevantLabels); + std::queue> worklist; + + // Initially, put all predecessors of target states in the worklist and empty the analysis information + // them. for (auto state : psiStates) { - for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state); predecessorIt != backwardTransitions.constColumnIteratorEnd(state); ++predecessorIt) { - if (state != *predecessorIt) { - predecessors.insert(*predecessorIt); + analysisInformation[state] = std::set(); + for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(state), predecessorIte = backwardTransitions.constColumnIteratorEnd(state); predecessorIt != predecessorIte; ++predecessorIt) { + if (relevancyInformation.relevantStates.get(*predecessorIt)) { + worklist.push(std::make_pair(*predecessorIt, state)); } } } - - expressionVector.clear(); - firstAssignment = true; - for (auto state : predecessors) { - for (auto choice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) { - for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(choice); successorIt != transitionMatrix.constColumnIteratorEnd(choice); ++successorIt) { - if (psiStates.get(*successorIt)) { - for (auto const& label : choiceLabeling[choice]) { - z3::expr labelExpression = variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)); - if (firstAssignment) { - expressionVector.push_back(labelExpression); - firstAssignment = false; - } else { - expressionVector.back() = expressionVector.back() && labelExpression; - } - } + + // Iterate as long as the worklist is non-empty. + while (!worklist.empty()) { + std::pair const& currentStateTargetStatePair = worklist.front(); + uint_fast64_t currentState = currentStateTargetStatePair.first; + uint_fast64_t targetState = currentStateTargetStatePair.second; + + // Iterate over the successor states for all choices and compute new analysis information. + std::set intersection; + for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(currentState)) { + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(currentChoice), successorIte = transitionMatrix.constColumnIteratorEnd(currentChoice); successorIt != successorIte; ++successorIt) { + // If we can reach the target state with this choice, we need to intersect the current + // analysis information with the union of the new analysis information of the target state + // and the choice labels. + if (*successorIt == targetState) { + std::set_intersection(analysisInformation[currentState].begin(), analysisInformation[currentState].end(), analysisInformation[targetState].begin(), analysisInformation[targetState].end(), std::inserter(intersection, intersection.begin())); + + std::set choiceLabelIntersection; + std::set_intersection(analysisInformation[currentState].begin(), analysisInformation[currentState].end(), choiceLabeling[currentChoice].begin(), choiceLabeling[currentChoice].end(), std::inserter(intersection, intersection.begin())); } } } + + // If the analysis information changed, we need to update it and put all the predecessors of this + // state in the worklist. + if (analysisInformation[currentState] != intersection) { + analysisInformation[currentState] = std::move(intersection); + + for (typename storm::storage::SparseMatrix::ConstIndexIterator predecessorIt = backwardTransitions.constColumnIteratorBegin(currentState), predecessorIte = backwardTransitions.constColumnIteratorEnd(currentState); predecessorIt != predecessorIte; ++predecessorIt) { + worklist.push(std::make_pair(*predecessorIt, currentState)); + } + } + + + worklist.pop(); } - assertDisjunction(context, solver, expressionVector); - } - - /*! - * Asserts cuts that rule out a lot of suboptimal solutions. - * - * @param program The program for which to derive the cuts. - * @param context The Z3 context in which to build the expressions. - * @param solver The solver to use for the satisfiability evaluation. - */ - static void assertCuts(storm::ir::Program const& program, z3::context& context, z3::solver& solver) { - // TODO. + + // Now build the intersection over the analysis information of all initial states. + std::set knownLabels(relevancyInformation.relevantLabels); + std::set tempIntersection; + for (auto initialState : labeledMdp.getInitialStates()) { + std::set_intersection(knownLabels.begin(), knownLabels.end(), analysisInformation[initialState].begin(), analysisInformation[initialState].end(), std::inserter(tempIntersection, tempIntersection.begin())); + std::swap(knownLabels, tempIntersection); + tempIntersection.clear(); + } + + formulae.clear(); + for (auto label : knownLabels) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label))); + } + assertConjunction(context, solver, formulae); } /*! - * Asserts that the disjunction of the given formulae holds. + * Asserts that the disjunction of the given formulae holds. If the content of the disjunction is empty, + * this corresponds to asserting false. * * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. * @param formulaVector A vector of expressions that shall form the disjunction. */ static void assertDisjunction(z3::context& context, z3::solver& solver, std::vector const& formulaVector) { - z3::expr disjunction(context); - for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) { - if (i == 0) { - disjunction = formulaVector[i]; - } else { - disjunction = disjunction || formulaVector[i]; - } + z3::expr disjunction = context.bool_val(false); + for (auto expr : formulaVector) { + disjunction = disjunction || expr; } solver.add(disjunction); } /*! - * Asserts that the conjunction of the given formulae holds. + * Asserts that the conjunction of the given formulae holds. If the content of the conjunction is empty, + * this corresponds to asserting true. * * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. * @param formulaVector A vector of expressions that shall form the conjunction. */ static void assertConjunction(z3::context& context, z3::solver& solver, std::vector const& formulaVector) { - z3::expr conjunction(context); - for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) { - if (i == 0) { - conjunction = formulaVector[i]; - } else { - conjunction = conjunction && formulaVector[i]; - } + z3::expr conjunction = context.bool_val(true); + for (auto expr : formulaVector) { + conjunction = conjunction && expr; } solver.add(conjunction); } @@ -383,13 +506,6 @@ namespace storm { static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector const& literals) { std::vector counter = createCounterCircuit(context, literals); assertLessOrEqualOne(context, solver, counter); - - -// for (uint_fast64_t i = 0; i < blockingVariables.size(); ++i) { -// for (uint_fast64_t j = i + 1; j < blockingVariables.size(); ++j) { -// solver.add(!blockingVariables[i] || !blockingVariables[j]); -// } -// } } /*! @@ -545,7 +661,7 @@ namespace storm { assertInitialConstraints(program, labeledMdp, psiStates, context, solver, variableInformation, relevancyInformation); // (6) Add constraints that cut off a lot of suboptimal solutions. - assertCuts(program, context, solver); + assertCuts(labeledMdp, psiStates, variableInformation, relevancyInformation, context, solver); // (7) Find the smallest set of commands that satisfies all constraints. If the probability of // satisfying phi until psi exceeds the given threshold, the set of labels is minimal and can be returned. diff --git a/src/storm.cpp b/src/storm.cpp index 448169d94..a1ba8ec34 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -338,19 +338,19 @@ int main(const int argc, const char* argv[]) { model->printModelInformationToStream(std::cout); // Enable the following lines to test the MinimalLabelSetGenerator. - if (model->getType() == storm::models::MDP) { - std::shared_ptr> labeledMdp = model->as>(); - storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); - storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); - storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; - std::unordered_set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); - - std::cout << "Found solution with " << labels.size() << " commands." << std::endl; - for (uint_fast64_t label : labels) { - std::cout << label << ", "; - } - std::cout << std::endl; - } +// if (model->getType() == storm::models::MDP) { +// std::shared_ptr> labeledMdp = model->as>(); +// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); +// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); +// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; +// std::unordered_set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); +// +// std::cout << "Found solution with " << labels.size() << " commands." << std::endl; +// for (uint_fast64_t label : labels) { +// std::cout << label << ", "; +// } +// std::cout << std::endl; +// } // Enable the following lines to test the SMTMinimalCommandSetGenerator. if (model->getType() == storm::models::MDP) {