Browse Source

Further work on minimal label set generators.

Former-commit-id: 84e86f5842
tempestpy_adaptions
dehnert 11 years ago
parent
commit
b18199d3ec
  1. 31
      src/counterexamples/MILPMinimalLabelSetGenerator.h
  2. 147
      src/counterexamples/SMTMinimalCommandSetGenerator.h
  3. 5
      src/models/Mdp.h
  4. 29
      src/storm.cpp
  5. 21
      src/utility/counterexamples.h

31
src/counterexamples/MILPMinimalLabelSetGenerator.h

@ -66,6 +66,7 @@ namespace storm {
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> problematicChoicesForProblematicStates;
std::set<uint_fast64_t> allRelevantLabels;
std::set<uint_fast64_t> knownLabels;
};
/*!
@ -868,8 +869,8 @@ namespace storm {
uint_fast64_t numberOfConstraintsCreated = 0;
int error = 0;
std::set<uint_fast64_t> knownLabels = storm::utility::counterexamples::getGuaranteedLabelSet(labeledMdp, psiStates, choiceInformation.allRelevantLabels);
for (auto label : knownLabels) {
choiceInformation.knownLabels = storm::utility::counterexamples::getGuaranteedLabelSet(labeledMdp, psiStates, choiceInformation.allRelevantLabels);
for (auto label : choiceInformation.knownLabels) {
double coefficient = 1;
int variableIndex = variableInformation.labelToVariableIndexMap.at(label);
@ -963,6 +964,10 @@ namespace storm {
}
for (auto predecessor : predecessors) {
if (!stateInformation.relevantStates.get(predecessor)) {
continue;
}
std::list<uint_fast64_t>::const_iterator choiceVariableIndicesIterator = variableInformation.stateToChoiceVariablesIndexMap.at(predecessor).begin();
for (auto relevantChoice : choiceInformation.relevantChoicesForRelevantStates.at(predecessor)) {
bool choiceTargetsCurrentState = false;
@ -998,7 +1003,7 @@ namespace storm {
}
++numberOfConstraintsCreated;
}
// Assert that at least one initial state selects at least one action.
variables.clear();
coefficients.clear();
@ -1241,13 +1246,24 @@ namespace storm {
static std::set<uint_fast64_t> getMinimalLabelSet(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, double probabilityThreshold, bool checkThresholdFeasible = false, bool includeSchedulerCuts = false) {
#ifdef STORM_HAVE_GUROBI
auto startTime = std::chrono::high_resolution_clock::now();
// (0) Check whether the MDP is indeed labeled.
if (!labeledMdp.hasChoiceLabels()) {
throw storm::exceptions::InvalidArgumentException() << "Minimal label set generation is impossible for unlabeled model.";
}
// (1) FIXME: check whether its possible to exceed the threshold if checkThresholdFeasible is set.
// (1) Check whether its possible to exceed the threshold if checkThresholdFeasible is set.
double maximalReachabilityProbability = 0;
storm::modelchecker::prctl::SparseMdpPrctlModelChecker<T> modelchecker(labeledMdp, new storm::solver::GmmxxNondeterministicLinearEquationSolver<T>());
std::vector<T> result = modelchecker.checkUntil(false, phiStates, psiStates, false, nullptr);
for (auto state : labeledMdp.getInitialStates()) {
maximalReachabilityProbability = std::max(maximalReachabilityProbability, result[state]);
}
if (maximalReachabilityProbability <= probabilityThreshold) {
throw storm::exceptions::InvalidArgumentException() << "Given probability threshold " << probabilityThreshold << " can not be achieved in model with maximal reachability probability of " << maximalReachabilityProbability << ".";
}
// (2) Identify relevant and problematic states.
StateInformation stateInformation = determineRelevantAndProblematicStates(labeledMdp, phiStates, psiStates);
@ -1277,10 +1293,13 @@ namespace storm {
// (4.5) Read off result from variables.
std::set<uint_fast64_t> usedLabelSet = getUsedLabelsInSolution(environmentModelPair.first, environmentModelPair.second, variableInformation);
usedLabelSet.insert(choiceInformation.knownLabels.begin(), choiceInformation.knownLabels.end());
// Display achieved probability.
std::pair<uint_fast64_t, double> initialStateProbabilityPair = getReachabilityProbability(environmentModelPair.first, environmentModelPair.second, labeledMdp, variableInformation);
LOG4CPLUS_DEBUG(logger, "Achieved probability " << initialStateProbabilityPair.second << " in initial state " << initialStateProbabilityPair.first << ".");
auto endTime = std::chrono::high_resolution_clock::now();
std::cout << "Computed minimal label set of size " << usedLabelSet.size() << " in " << std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count() << "ms." << std::endl;
// (4.6) Shutdown Gurobi.
destroyGurobiModelAndEnvironment(environmentModelPair.first, environmentModelPair.second);

147
src/counterexamples/SMTMinimalCommandSetGenerator.h

@ -89,6 +89,7 @@ namespace storm {
relevancyInformation.relevantStates &= ~psiStates;
LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantStates.getNumberOfSetBits() << " relevant states.");
LOG4CPLUS_DEBUG(logger, relevancyInformation.relevantStates.toString());
// Retrieve some references for convenient access.
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
@ -121,13 +122,21 @@ namespace storm {
// Compute the set of labels that are known to be taken in any case.
relevancyInformation.knownLabels = storm::utility::counterexamples::getGuaranteedLabelSet(labeledMdp, psiStates, relevancyInformation.relevantLabels);
if (!relevancyInformation.knownLabels.empty()) {
std::set<uint_fast64_t> remainingLabels;
std::set_difference(relevancyInformation.relevantLabels.begin(), relevancyInformation.relevantLabels.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(remainingLabels, remainingLabels.begin()));
relevancyInformation.relevantLabels = remainingLabels;
}
// std::vector<std::set<uint_fast64_t>> guaranteedLabels = storm::utility::counterexamples::getGuaranteedLabelSets(labeledMdp, psiStates, relevancyInformation.relevantLabels);
// for (auto state : relevancyInformation.relevantStates) {
// std::cout << "state " << state << " ##########################################################" << std::endl;
// for (auto label : guaranteedLabels[state]) {
// std::cout << label << ", ";
// }
// std::cout << std::endl;
// }
LOG4CPLUS_DEBUG(logger, "Found " << relevancyInformation.relevantLabels.size() << " relevant and " << relevancyInformation.knownLabels.size() << " known labels.");
return relevancyInformation;
}
@ -852,20 +861,10 @@ namespace storm {
* @param variableInformation A structure with information about the variables for the labels.
*/
static void ruleOutSolution(z3::context& context, z3::solver& solver, std::set<uint_fast64_t> const& commandSet, VariableInformation const& variableInformation) {
std::map<uint_fast64_t, uint_fast64_t>::const_iterator labelIndexIterator = variableInformation.labelToIndexMap.begin();
z3::expr blockSolutionExpression(context);
if (commandSet.find(labelIndexIterator->first) != commandSet.end()) {
blockSolutionExpression = !variableInformation.labelVariables[labelIndexIterator->second];
} else {
blockSolutionExpression = variableInformation.labelVariables[labelIndexIterator->second];
}
++labelIndexIterator;
for (; labelIndexIterator != variableInformation.labelToIndexMap.end(); ++labelIndexIterator) {
if (commandSet.find(labelIndexIterator->first) != commandSet.end()) {
blockSolutionExpression = blockSolutionExpression || !variableInformation.labelVariables[labelIndexIterator->second];
} else {
blockSolutionExpression = blockSolutionExpression || variableInformation.labelVariables[labelIndexIterator->second];
z3::expr blockSolutionExpression = context.bool_val(false);
for (auto labelIndexPair : variableInformation.labelToIndexMap) {
if (commandSet.find(labelIndexPair.first) == commandSet.end()) {
blockSolutionExpression = blockSolutionExpression || variableInformation.labelVariables[labelIndexPair.second];
}
}
@ -949,6 +948,113 @@ namespace storm {
// set and return it.
return getUsedLabelSet(context, solver.get_model(), variableInformation);
}
static void analyzeBadSolution(z3::context& context, z3::solver& solver, storm::models::Mdp<T> const& subMdp, storm::models::Mdp<T> const& originalMdp, storm::storage::BitVector const& psiStates, std::set<uint_fast64_t> const& commandSet, VariableInformation& variableInformation, RelevancyInformation const& relevancyInformation) {
storm::storage::BitVector reachableStates(subMdp.getNumberOfStates());
// Initialize the stack for the DFS.
bool targetStateIsReachable = false;
std::vector<uint_fast64_t> stack;
stack.reserve(subMdp.getNumberOfStates());
for (auto initialState : subMdp.getInitialStates()) {
stack.push_back(initialState);
reachableStates.set(initialState, true);
}
storm::storage::SparseMatrix<T> const& transitionMatrix = subMdp.getTransitionMatrix();
std::vector<uint_fast64_t> const& nondeterministicChoiceIndices = subMdp.getNondeterministicChoiceIndices();
std::vector<std::set<uint_fast64_t>> const& subChoiceLabeling = subMdp.getChoiceLabeling();
std::set<uint_fast64_t> reachableLabels;
while (!stack.empty()) {
uint_fast64_t currentState = stack.back();
stack.pop_back();
for (uint_fast64_t currentChoice = nondeterministicChoiceIndices[currentState]; currentChoice < nondeterministicChoiceIndices[currentState + 1]; ++currentChoice) {
bool choiceTargetsRelevantState = false;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(currentChoice), successorIte = transitionMatrix.constColumnIteratorEnd(currentChoice); successorIt != successorIte; ++successorIt) {
if (relevancyInformation.relevantStates.get(*successorIt) && currentState != *successorIt) {
choiceTargetsRelevantState = true;
if (!reachableStates.get(*successorIt)) {
reachableStates.set(*successorIt, true);
stack.push_back(*successorIt);
}
} else if (psiStates.get(*successorIt)) {
targetStateIsReachable = true;
}
}
if (choiceTargetsRelevantState) {
for (auto label : subChoiceLabeling[currentChoice]) {
reachableLabels.insert(label);
}
}
}
}
LOG4CPLUS_DEBUG(logger, "Successfully performed reachability analysis.");
if (targetStateIsReachable) {
LOG4CPLUS_ERROR(logger, "Target must be unreachable for this analysis.");
throw storm::exceptions::InvalidStateException() << "Target must be unreachable for this analysis.";
}
std::vector<std::set<uint_fast64_t>> const& choiceLabeling = originalMdp.getChoiceLabeling();
std::set<uint_fast64_t> cutLabels;
for (auto state : reachableStates) {
for (auto currentChoice : relevancyInformation.relevantChoicesForRelevantStates.at(state)) {
if (!storm::utility::set::isSubsetOf(choiceLabeling[currentChoice], commandSet)) {
for (auto label : choiceLabeling[currentChoice]) {
if (commandSet.find(label) == commandSet.end()) {
cutLabels.insert(label);
}
}
}
}
}
std::vector<z3::expr> formulae;
std::set<uint_fast64_t> unknownReachableLabels;
std::set_difference(reachableLabels.begin(), reachableLabels.end(), relevancyInformation.knownLabels.begin(), relevancyInformation.knownLabels.end(), std::inserter(unknownReachableLabels, unknownReachableLabels.begin()));
for (auto label : unknownReachableLabels) {
formulae.push_back(!variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(label)));
}
for (auto cutLabel : cutLabels) {
formulae.push_back(variableInformation.labelVariables.at(variableInformation.labelToIndexMap.at(cutLabel)));
}
LOG4CPLUS_DEBUG(logger, "Asserting reachability implications.");
// for (auto e : formulae) {
// std::cout << e << ", ";
// }
// std::cout << std::endl;
assertDisjunction(context, solver, formulae);
//
// std::cout << "formulae: " << std::endl;
// for (auto e : formulae) {
// std::cout << e << ", ";
// }
// std::cout << std::endl;
//
// storm::storage::BitVector unreachableRelevantStates = ~reachableStates & relevancyInformation.relevantStates;
// std::cout << unreachableRelevantStates.toString() << std::endl;
// std::cout << reachableStates.toString() << std::endl;
// std::cout << "reachable commands" << std::endl;
// for (auto label : reachableLabels) {
// std::cout << label << ", ";
// }
// std::cout << std::endl;
// std::cout << "cut commands" << std::endl;
// for (auto label : cutLabels) {
// std::cout << label << ", ";
// }
// std::cout << std::endl;
}
#endif
public:
@ -1013,6 +1119,7 @@ namespace storm {
uint_fast64_t currentBound = 0;
maximalReachabilityProbability = 0;
auto iterationTimer = std::chrono::high_resolution_clock::now();
uint_fast64_t zeroProbabilityCount = 0;
do {
LOG4CPLUS_DEBUG(logger, "Computing minimal command set.");
commandSet = findSmallestCommandSet(context, solver, variableInformation, currentBound);
@ -1027,11 +1134,19 @@ namespace storm {
LOG4CPLUS_DEBUG(logger, "Computed model checking results.");
// Now determine the maximal reachability probability by checking all initial states.
maximalReachabilityProbability = 0;
for (auto state : labeledMdp.getInitialStates()) {
maximalReachabilityProbability = std::max(maximalReachabilityProbability, result[state]);
}
if (maximalReachabilityProbability <= probabilityThreshold) {
if (maximalReachabilityProbability == 0) {
++zeroProbabilityCount;
// If there was no target state reachable, analyze the solution and guide the solver into the
// right direction.
analyzeBadSolution(context, solver, subMdp, labeledMdp, psiStates, commandSet, variableInformation, relevancyInformation);
}
// In case we have not yet exceeded the given threshold, we have to rule out the current solution.
ruleOutSolution(context, solver, commandSet, variableInformation);
} else {
@ -1041,7 +1156,7 @@ namespace storm {
endTime = std::chrono::high_resolution_clock::now();
if (std::chrono::duration_cast<std::chrono::seconds>(endTime - iterationTimer).count() > 5) {
std::cout << "Performed " << iterations << " iterations in " << std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count() << "s. Current command set size is " << commandSet.size() << "." << std::endl;
std::cout << "Performed " << iterations << " iterations in " << std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count() << "s. Current command set size is " << commandSet.size() << ". Encountered maximal probability of zero " << zeroProbabilityCount << " times." << std::endl;
iterationTimer = std::chrono::high_resolution_clock::now();
}
} while (!done);

5
src/models/Mdp.h

@ -145,6 +145,7 @@ public:
storm::storage::SparseMatrix<T> transitionMatrix;
transitionMatrix.initialize();
std::vector<uint_fast64_t> nondeterministicChoiceIndices;
std::vector<std::set<uint_fast64_t>> newChoiceLabeling;
// Check for each choice of each state, whether the choice labels are fully contained in the given label set.
uint_fast64_t currentRow = 0;
@ -163,6 +164,7 @@ public:
for (typename storm::storage::SparseMatrix<T>::ConstIterator rowIt = row.begin(), rowIte = row.end(); rowIt != rowIte; ++rowIt) {
transitionMatrix.insertNextValue(currentRow, rowIt.column(), rowIt.value(), true);
}
newChoiceLabeling.emplace_back(choiceLabeling[choice]);
++currentRow;
}
}
@ -171,13 +173,14 @@ public:
if (!stateHasValidChoice) {
nondeterministicChoiceIndices.push_back(currentRow);
transitionMatrix.insertNextValue(currentRow, state, storm::utility::constGetOne<T>(), true);
newChoiceLabeling.emplace_back();
++currentRow;
}
}
transitionMatrix.finalize(true);
nondeterministicChoiceIndices.push_back(currentRow);
Mdp<T> restrictedMdp(std::move(transitionMatrix), storm::models::AtomicPropositionsLabeling(this->getStateLabeling()), std::move(nondeterministicChoiceIndices), this->hasStateRewards() ? boost::optional<std::vector<T>>(this->getStateRewardVector()) : boost::optional<std::vector<T>>(), this->hasTransitionRewards() ? boost::optional<storm::storage::SparseMatrix<T>>(this->getTransitionRewardMatrix()) : boost::optional<storm::storage::SparseMatrix<T>>(), boost::optional<std::vector<std::set<uint_fast64_t>>>(this->getChoiceLabeling()));
Mdp<T> restrictedMdp(std::move(transitionMatrix), storm::models::AtomicPropositionsLabeling(this->getStateLabeling()), std::move(nondeterministicChoiceIndices), this->hasStateRewards() ? boost::optional<std::vector<T>>(this->getStateRewardVector()) : boost::optional<std::vector<T>>(), this->hasTransitionRewards() ? boost::optional<storm::storage::SparseMatrix<T>>(this->getTransitionRewardMatrix()) : boost::optional<storm::storage::SparseMatrix<T>>(), boost::optional<std::vector<std::set<uint_fast64_t>>>(newChoiceLabeling));
return restrictedMdp;
}

29
src/storm.cpp

@ -338,19 +338,22 @@ int main(const int argc, const char* argv[]) {
model->printModelInformationToStream(std::cout);
// Enable the following lines to test the MinimalLabelSetGenerator.
// if (model->getType() == storm::models::MDP) {
// std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
// Build stuff for coin example.
// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
// std::set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true, true);
//
// std::cout << "Found solution with " << labels.size() << " commands." << std::endl;
// for (uint_fast64_t label : labels) {
// std::cout << label << ", ";
// }
// std::cout << std::endl;
// }
// storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff");
// storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered");
// std::set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, ~collisionStates, deliveredStates, 0.5, true, false);
// storm::storage::BitVector const& electedStates = labeledMdp->getLabeledStates("elected");
// std::set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.5, true, true);
}
// Enable the following lines to test the SMTMinimalCommandSetGenerator.
if (model->getType() == storm::models::MDP) {
@ -363,13 +366,13 @@ int main(const int argc, const char* argv[]) {
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true);
// Build stuff for csma example.
storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff");
storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered");
std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, ~collisionStates, deliveredStates, 0.5, true);
// storm::storage::BitVector const& collisionStates = labeledMdp->getLabeledStates("collision_max_backoff");
// storm::storage::BitVector const& deliveredStates = labeledMdp->getLabeledStates("all_delivered");
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, ~collisionStates, deliveredStates, 0.5, true);
// Build stuff for firewire example.
// storm::storage::BitVector const& electedStates = labeledMdp->getLabeledStates("elected");
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.01, true);
// std::set<uint_fast64_t> labels = storm::counterexamples::SMTMinimalCommandSetGenerator<double>::getMinimalCommandSet(program, constants, *labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), electedStates, 0.5, true);
// Build stuff for wlan example.
// storm::storage::BitVector const& oneCollisionStates = labeledMdp->getLabeledStates("oneCollision");

21
src/utility/counterexamples.h

@ -15,12 +15,12 @@ namespace storm {
namespace counterexamples {
/*!
* Computes a set of action labels that is visited along all paths from an initial to a target state.
* Computes a set of action labels that is visited along all paths from any state to a target state.
*
* @return The set of action labels that is visited on all paths from an initial to a target state.
* @return The set of action labels that is visited on all paths from any state to a target state.
*/
template <typename T>
std::set<uint_fast64_t> getGuaranteedLabelSet(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& psiStates, std::set<uint_fast64_t> const& relevantLabels) {
std::vector<std::set<uint_fast64_t>> getGuaranteedLabelSets(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& psiStates, std::set<uint_fast64_t> const& relevantLabels) {
// Get some data from the MDP for convenient access.
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
std::vector<uint_fast64_t> const& nondeterministicChoiceIndices = labeledMdp.getNondeterministicChoiceIndices();
@ -78,11 +78,22 @@ namespace storm {
worklist.pop();
}
// Now build the intersection over the analysis information of all initial states.
return analysisInformation;
}
/*!
* Computes a set of action labels that is visited along all paths from an initial state to a target state.
*
* @return The set of action labels that is visited on all paths from an initial state to a target state.
*/
template <typename T>
std::set<uint_fast64_t> getGuaranteedLabelSet(storm::models::Mdp<T> const& labeledMdp, storm::storage::BitVector const& psiStates, std::set<uint_fast64_t> const& relevantLabels) {
std::vector<std::set<uint_fast64_t>> guaranteedLabels = getGuaranteedLabelSets(labeledMdp, psiStates, relevantLabels);
std::set<uint_fast64_t> knownLabels(relevantLabels);
std::set<uint_fast64_t> tempIntersection;
for (auto initialState : labeledMdp.getInitialStates()) {
std::set_intersection(knownLabels.begin(), knownLabels.end(), analysisInformation[initialState].begin(), analysisInformation[initialState].end(), std::inserter(tempIntersection, tempIntersection.begin()));
std::set_intersection(knownLabels.begin(), knownLabels.end(), guaranteedLabels[initialState].begin(), guaranteedLabels[initialState].end(), std::inserter(tempIntersection, tempIntersection.begin()));
std::swap(knownLabels, tempIntersection);
tempIntersection.clear();
}
Loading…
Cancel
Save