From b18199d3ec4d925eaa2e2b9dba512505832f3b62 Mon Sep 17 00:00:00 2001 From: dehnert Date: Wed, 16 Oct 2013 21:04:05 +0200 Subject: [PATCH] Further work on minimal label set generators. Former-commit-id: 84e86f5842d821fcf09fbb8cec7bb11228a1887b --- .../MILPMinimalLabelSetGenerator.h | 31 +++- .../SMTMinimalCommandSetGenerator.h | 147 ++++++++++++++++-- src/models/Mdp.h | 5 +- src/storm.cpp | 29 ++-- src/utility/counterexamples.h | 21 ++- 5 files changed, 192 insertions(+), 41 deletions(-) diff --git a/src/counterexamples/MILPMinimalLabelSetGenerator.h b/src/counterexamples/MILPMinimalLabelSetGenerator.h index f82694461..7f15b3975 100644 --- a/src/counterexamples/MILPMinimalLabelSetGenerator.h +++ b/src/counterexamples/MILPMinimalLabelSetGenerator.h @@ -66,6 +66,7 @@ namespace storm { std::unordered_map> relevantChoicesForRelevantStates; std::unordered_map> problematicChoicesForProblematicStates; std::set allRelevantLabels; + std::set knownLabels; }; /*! @@ -868,8 +869,8 @@ namespace storm { uint_fast64_t numberOfConstraintsCreated = 0; int error = 0; - std::set knownLabels = storm::utility::counterexamples::getGuaranteedLabelSet(labeledMdp, psiStates, choiceInformation.allRelevantLabels); - for (auto label : knownLabels) { + choiceInformation.knownLabels = storm::utility::counterexamples::getGuaranteedLabelSet(labeledMdp, psiStates, choiceInformation.allRelevantLabels); + for (auto label : choiceInformation.knownLabels) { double coefficient = 1; int variableIndex = variableInformation.labelToVariableIndexMap.at(label); @@ -963,6 +964,10 @@ namespace storm { } for (auto predecessor : predecessors) { + if (!stateInformation.relevantStates.get(predecessor)) { + continue; + } + std::list::const_iterator choiceVariableIndicesIterator = variableInformation.stateToChoiceVariablesIndexMap.at(predecessor).begin(); for (auto relevantChoice : choiceInformation.relevantChoicesForRelevantStates.at(predecessor)) { bool choiceTargetsCurrentState = false; @@ -998,7 +1003,7 @@ namespace storm { } ++numberOfConstraintsCreated; } - + // Assert that at least one initial state selects at least one action. variables.clear(); coefficients.clear(); @@ -1241,13 +1246,24 @@ namespace storm { static std::set getMinimalLabelSet(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool checkThresholdFeasible = false, bool includeSchedulerCuts = false) { #ifdef STORM_HAVE_GUROBI + auto startTime = std::chrono::high_resolution_clock::now(); + // (0) Check whether the MDP is indeed labeled. if (!labeledMdp.hasChoiceLabels()) { throw storm::exceptions::InvalidArgumentException() << "Minimal label set generation is impossible for unlabeled model."; } - // (1) FIXME: check whether its possible to exceed the threshold if checkThresholdFeasible is set. - + // (1) Check whether its possible to exceed the threshold if checkThresholdFeasible is set. + double maximalReachabilityProbability = 0; + storm::modelchecker::prctl::SparseMdpPrctlModelChecker modelchecker(labeledMdp, new storm::solver::GmmxxNondeterministicLinearEquationSolver()); + std::vector result = modelchecker.checkUntil(false, phiStates, psiStates, false, nullptr); + for (auto state : labeledMdp.getInitialStates()) { + maximalReachabilityProbability = std::max(maximalReachabilityProbability, result[state]); + } + if (maximalReachabilityProbability <= probabilityThreshold) { + throw storm::exceptions::InvalidArgumentException() << "Given probability threshold " << probabilityThreshold << " can not be achieved in model with maximal reachability probability of " << maximalReachabilityProbability << "."; + } + // (2) Identify relevant and problematic states. StateInformation stateInformation = determineRelevantAndProblematicStates(labeledMdp, phiStates, psiStates); @@ -1277,10 +1293,13 @@ namespace storm { // (4.5) Read off result from variables. std::set usedLabelSet = getUsedLabelsInSolution(environmentModelPair.first, environmentModelPair.second, variableInformation); + usedLabelSet.insert(choiceInformation.knownLabels.begin(), choiceInformation.knownLabels.end()); // Display achieved probability. std::pair initialStateProbabilityPair = getReachabilityProbability(environmentModelPair.first, environmentModelPair.second, labeledMdp, variableInformation); - LOG4CPLUS_DEBUG(logger, "Achieved probability " << initialStateProbabilityPair.second << " in initial state " << initialStateProbabilityPair.first << "."); + + auto endTime = std::chrono::high_resolution_clock::now(); + std::cout << "Computed minimal label set of size " << usedLabelSet.size() << " in " << std::chrono::duration_cast(endTime - startTime).count() << "ms." << std::endl; // (4.6) Shutdown Gurobi. destroyGurobiModelAndEnvironment(environmentModelPair.first, environmentModelPair.second); diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 7b83b53a4..39081fca2 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -89,6 +89,7 @@ namespace storm { relevancyInformation.relevantStates &= ~psiStates; LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantStates.getNumberOfSetBits() << " relevant states."); + LOG4CPLUS_DEBUG(logger, relevancyInformation.relevantStates.toString()); // Retrieve some references for convenient access. storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); @@ -121,13 +122,21 @@ namespace storm { // Compute the set of labels that are known to be taken in any case. relevancyInformation.knownLabels = storm::utility::counterexamples::getGuaranteedLabelSet(labeledMdp, psiStates, relevancyInformation.relevantLabels); - if (!relevancyInformation.knownLabels.empty()) { std::set remainingLabels; std::set_difference(relevancyInformation.relevantLabels.begin(), relevancyInformation.relevantLabels.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(remainingLabels, remainingLabels.begin())); relevancyInformation.relevantLabels = remainingLabels; } +// std::vector> guaranteedLabels = storm::utility::counterexamples::getGuaranteedLabelSets(labeledMdp, psiStates, relevancyInformation.relevantLabels); +// for (auto state : relevancyInformation.relevantStates) { +// std::cout << "state " << state << " ##########################################################" << std::endl; +// for (auto label : guaranteedLabels[state]) { +// std::cout << label << ", "; +// } +// std::cout << std::endl; +// } + LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantLabels.size() << " relevant and " << relevancyInformation.knownLabels.size() << " known labels."); return relevancyInformation; } @@ -852,20 +861,10 @@ namespace storm { * @param variableInformation A structure with information about the variables for the labels. */ static void ruleOutSolution(z3::context& context, z3::solver& solver, std::set const& commandSet, VariableInformation const& variableInformation) { - std::map::const_iterator labelIndexIterator = variableInformation.labelToIndexMap.begin(); - z3::expr blockSolutionExpression(context); - if (commandSet.find(labelIndexIterator->first) != commandSet.end()) { - blockSolutionExpression = !variableInformation.labelVariables[labelIndexIterator->second]; - } else { - blockSolutionExpression = variableInformation.labelVariables[labelIndexIterator->second]; - } - ++labelIndexIterator; - - for (; labelIndexIterator != variableInformation.labelToIndexMap.end(); ++labelIndexIterator) { - if (commandSet.find(labelIndexIterator->first) != commandSet.end()) { - blockSolutionExpression = blockSolutionExpression || !variableInformation.labelVariables[labelIndexIterator->second]; - } else { - blockSolutionExpression = blockSolutionExpression || variableInformation.labelVariables[labelIndexIterator->second]; + z3::expr blockSolutionExpression = context.bool_val(false); + for (auto labelIndexPair : variableInformation.labelToIndexMap) { + if (commandSet.find(labelIndexPair.first) == commandSet.end()) { + blockSolutionExpression = blockSolutionExpression || variableInformation.labelVariables[labelIndexPair.second]; } } @@ -949,6 +948,113 @@ namespace storm { // set and return it. return getUsedLabelSet(context, solver.get_model(), variableInformation); } + + static void analyzeBadSolution(z3::context& context, z3::solver& solver, storm::models::Mdp const& subMdp, storm::models::Mdp const& originalMdp, storm::storage::BitVector const& psiStates, std::set const& commandSet, VariableInformation& variableInformation, RelevancyInformation const& relevancyInformation) { + storm::storage::BitVector reachableStates(subMdp.getNumberOfStates()); + + // Initialize the stack for the DFS. + bool targetStateIsReachable = false; + std::vector stack; + stack.reserve(subMdp.getNumberOfStates()); + for (auto initialState : subMdp.getInitialStates()) { + stack.push_back(initialState); + reachableStates.set(initialState, true); + } + + storm::storage::SparseMatrix const& transitionMatrix = subMdp.getTransitionMatrix(); + std::vector const& nondeterministicChoiceIndices = subMdp.getNondeterministicChoiceIndices(); + std::vector> const& subChoiceLabeling = subMdp.getChoiceLabeling(); + + std::set reachableLabels; + + while (!stack.empty()) { + uint_fast64_t currentState = stack.back(); + stack.pop_back(); + + for (uint_fast64_t currentChoice = nondeterministicChoiceIndices[currentState]; currentChoice < nondeterministicChoiceIndices[currentState + 1]; ++currentChoice) { + bool choiceTargetsRelevantState = false; + + for (typename storm::storage::SparseMatrix::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(currentChoice), successorIte = transitionMatrix.constColumnIteratorEnd(currentChoice); successorIt != successorIte; ++successorIt) { + if (relevancyInformation.relevantStates.get(*successorIt) && currentState != *successorIt) { + choiceTargetsRelevantState = true; + if (!reachableStates.get(*successorIt)) { + reachableStates.set(*successorIt, true); + stack.push_back(*successorIt); + } + } else if (psiStates.get(*successorIt)) { + targetStateIsReachable = true; + } + } + + if (choiceTargetsRelevantState) { + for (auto label : subChoiceLabeling[currentChoice]) { + reachableLabels.insert(label); + } + } + } + } + + LOG4CPLUS_DEBUG(logger, "Successfully performed reachability analysis."); + + if (targetStateIsReachable) { + LOG4CPLUS_ERROR(logger, "Target must be unreachable for this analysis."); + throw storm::exceptions::InvalidStateException() << "Target must be unreachable for this analysis."; + } + + std::vector> const& choiceLabeling = originalMdp.getChoiceLabeling(); + std::set cutLabels; + for (auto state : reachableStates) { + for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) { + if (!storm::utility::set::isSubsetOf(choiceLabeling[currentChoice], commandSet)) { + for (auto label : choiceLabeling[currentChoice]) { + if (commandSet.find(label) == commandSet.end()) { + cutLabels.insert(label); + } + } + } + } + } + + std::vector formulae; + std::set unknownReachableLabels; + std::set_difference(reachableLabels.begin(), reachableLabels.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(unknownReachableLabels, unknownReachableLabels.begin())); + for (auto label : unknownReachableLabels) { + formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label))); + } + for (auto cutLabel : cutLabels) { + formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(cutLabel))); + } + + LOG4CPLUS_DEBUG(logger, "Asserting reachability implications."); + +// for (auto e : formulae) { +// std::cout << e << ", "; +// } +// std::cout << std::endl; + + assertDisjunction(context, solver, formulae); +// +// std::cout << "formulae: " << std::endl; +// for (auto e : formulae) { +// std::cout << e << ", "; +// } +// std::cout << std::endl; +// +// storm::storage::BitVector unreachableRelevantStates = ~reachableStates & relevancyInformation.relevantStates; +// std::cout << unreachableRelevantStates.toString() << std::endl; +// std::cout << reachableStates.toString() << std::endl; +// std::cout << "reachable commands" << std::endl; +// for (auto label : reachableLabels) { +// std::cout << label << ", "; +// } +// std::cout << std::endl; +// std::cout << "cut commands" << std::endl; +// for (auto label : cutLabels) { +// std::cout << label << ", "; +// } +// std::cout << std::endl; + + } #endif public: @@ -1013,6 +1119,7 @@ namespace storm { uint_fast64_t currentBound = 0; maximalReachabilityProbability = 0; auto iterationTimer = std::chrono::high_resolution_clock::now(); + uint_fast64_t zeroProbabilityCount = 0; do { LOG4CPLUS_DEBUG(logger, "Computing minimal command set."); commandSet = findSmallestCommandSet(context, solver, variableInformation, currentBound); @@ -1027,11 +1134,19 @@ namespace storm { LOG4CPLUS_DEBUG(logger, "Computed model checking results."); // Now determine the maximal reachability probability by checking all initial states. + maximalReachabilityProbability = 0; for (auto state : labeledMdp.getInitialStates()) { maximalReachabilityProbability = std::max(maximalReachabilityProbability, result[state]); } if (maximalReachabilityProbability <= probabilityThreshold) { + if (maximalReachabilityProbability == 0) { + ++zeroProbabilityCount; + + // If there was no target state reachable, analyze the solution and guide the solver into the + // right direction. + analyzeBadSolution(context, solver, subMdp, labeledMdp, psiStates, commandSet, variableInformation, relevancyInformation); + } // In case we have not yet exceeded the given threshold, we have to rule out the current solution. ruleOutSolution(context, solver, commandSet, variableInformation); } else { @@ -1041,7 +1156,7 @@ namespace storm { endTime = std::chrono::high_resolution_clock::now(); if (std::chrono::duration_cast(endTime - iterationTimer).count() > 5) { - std::cout << "Performed " << iterations << " iterations in " << std::chrono::duration_cast(endTime - startTime).count() << "s. Current command set size is " << commandSet.size() << "." << std::endl; + std::cout << "Performed " << iterations << " iterations in " << std::chrono::duration_cast(endTime - startTime).count() << "s. Current command set size is " << commandSet.size() << ". Encountered maximal probability of zero " << zeroProbabilityCount << " times." << std::endl; iterationTimer = std::chrono::high_resolution_clock::now(); } } while (!done); diff --git a/src/models/Mdp.h b/src/models/Mdp.h index 4bdbbf85c..17620e194 100644 --- a/src/models/Mdp.h +++ b/src/models/Mdp.h @@ -145,6 +145,7 @@ public: storm::storage::SparseMatrix transitionMatrix; transitionMatrix.initialize(); std::vector nondeterministicChoiceIndices; + std::vector> newChoiceLabeling; // Check for each choice of each state, whether the choice labels are fully contained in the given label set. uint_fast64_t currentRow = 0; @@ -163,6 +164,7 @@ public: for (typename storm::storage::SparseMatrix::ConstIterator rowIt = row.begin(), rowIte = row.end(); rowIt != rowIte; ++rowIt) { transitionMatrix.insertNextValue(currentRow, rowIt.column(), rowIt.value(), true); } + newChoiceLabeling.emplace_back(choiceLabeling[choice]); ++currentRow; } } @@ -171,13 +173,14 @@ public: if (!stateHasValidChoice) { nondeterministicChoiceIndices.push_back(currentRow); transitionMatrix.insertNextValue(currentRow, state, storm::utility::constGetOne(), true); + newChoiceLabeling.emplace_back(); ++currentRow; } } transitionMatrix.finalize(true); nondeterministicChoiceIndices.push_back(currentRow); - Mdp restrictedMdp(std::move(transitionMatrix), storm::models::AtomicPropositionsLabeling(this->getStateLabeling()), std::move(nondeterministicChoiceIndices), this->hasStateRewards() ? boost::optional>(this->getStateRewardVector()) : boost::optional>(), this->hasTransitionRewards() ? boost::optional>(this->getTransitionRewardMatrix()) : boost::optional>(), boost::optional>>(this->getChoiceLabeling())); + Mdp restrictedMdp(std::move(transitionMatrix), storm::models::AtomicPropositionsLabeling(this->getStateLabeling()), std::move(nondeterministicChoiceIndices), this->hasStateRewards() ? boost::optional>(this->getStateRewardVector()) : boost::optional>(), this->hasTransitionRewards() ? boost::optional>(this->getTransitionRewardMatrix()) : boost::optional>(), boost::optional>>(newChoiceLabeling)); return restrictedMdp; } diff --git a/src/storm.cpp b/src/storm.cpp index 79e0d4d11..5fef49964 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -338,19 +338,22 @@ int main(const int argc, const char* argv[]) { model->printModelInformationToStream(std::cout); // Enable the following lines to test the MinimalLabelSetGenerator. -// if (model->getType() == storm::models::MDP) { -// std::shared_ptr> labeledMdp = model->as>(); + if (model->getType() == storm::models::MDP) { + std::shared_ptr> labeledMdp = model->as>(); + + // Build stuff for coin example. // storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); // storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); // storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; // std::set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true, true); -// -// std::cout << "Found solution with " << labels.size() << " commands." << std::endl; -// for (uint_fast64_t label : labels) { -// std::cout << label << ", "; -// } -// std::cout << std::endl; -// } + +// storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff"); +// storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered"); +// std::set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, ~collisionStates, deliveredStates, 0.5, true, false); + +// storm::storage::BitVector const& electedStates = labeledMdp->getLabeledStates("elected"); +// std::set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.5, true, true); + } // Enable the following lines to test the SMTMinimalCommandSetGenerator. if (model->getType() == storm::models::MDP) { @@ -363,13 +366,13 @@ int main(const int argc, const char* argv[]) { // std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true); // Build stuff for csma example. - storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff"); - storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered"); - std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, ~collisionStates, deliveredStates, 0.5, true); +// storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff"); +// storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered"); +// std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, ~collisionStates, deliveredStates, 0.5, true); // Build stuff for firewire example. // storm::storage::BitVector const& electedStates = labeledMdp->getLabeledStates("elected"); -// std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.01, true); +// std::set labels = storm::counterexamples::SMTMinimalCommandSetGenerator::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.5, true); // Build stuff for wlan example. // storm::storage::BitVector const& oneCollisionStates = labeledMdp->getLabeledStates("oneCollision"); diff --git a/src/utility/counterexamples.h b/src/utility/counterexamples.h index 1b5005745..b0aa9682f 100644 --- a/src/utility/counterexamples.h +++ b/src/utility/counterexamples.h @@ -15,12 +15,12 @@ namespace storm { namespace counterexamples { /*! - * Computes a set of action labels that is visited along all paths from an initial to a target state. + * Computes a set of action labels that is visited along all paths from any state to a target state. * - * @return The set of action labels that is visited on all paths from an initial to a target state. + * @return The set of action labels that is visited on all paths from any state to a target state. */ template - std::set getGuaranteedLabelSet(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, std::set const& relevantLabels) { + std::vector> getGuaranteedLabelSets(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, std::set const& relevantLabels) { // Get some data from the MDP for convenient access. storm::storage::SparseMatrix const& transitionMatrix = labeledMdp.getTransitionMatrix(); std::vector const& nondeterministicChoiceIndices = labeledMdp.getNondeterministicChoiceIndices(); @@ -78,11 +78,22 @@ namespace storm { worklist.pop(); } - // Now build the intersection over the analysis information of all initial states. + return analysisInformation; + } + + /*! + * Computes a set of action labels that is visited along all paths from an initial state to a target state. + * + * @return The set of action labels that is visited on all paths from an initial state to a target state. + */ + template + std::set getGuaranteedLabelSet(storm::models::Mdp const& labeledMdp, storm::storage::BitVector const& psiStates, std::set const& relevantLabels) { + std::vector> guaranteedLabels = getGuaranteedLabelSets(labeledMdp, psiStates, relevantLabels); + std::set knownLabels(relevantLabels); std::set tempIntersection; for (auto initialState : labeledMdp.getInitialStates()) { - std::set_intersection(knownLabels.begin(), knownLabels.end(), analysisInformation[initialState].begin(), analysisInformation[initialState].end(), std::inserter(tempIntersection, tempIntersection.begin())); + std::set_intersection(knownLabels.begin(), knownLabels.end(), guaranteedLabels[initialState].begin(), guaranteedLabels[initialState].end(), std::inserter(tempIntersection, tempIntersection.begin())); std::swap(knownLabels, tempIntersection); tempIntersection.clear(); }