Browse Source

Minimal label set generator now works for coin example, yay

Former-commit-id: 9ab8552d82
main
dehnert 12 years ago
parent
commit
5ff550194c
  1. 113
      src/counterexamples/MinimalLabelSetGenerator.h
  2. 6
      src/storm.cpp

113
src/counterexamples/MinimalLabelSetGenerator.h

@ -62,9 +62,14 @@ namespace storm {
storm::storage::BitVector problematicStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates); storm::storage::BitVector problematicStates = storm::utility::graph::performProbGreater0E(labeledMdp, backwardTransitions, phiStates, psiStates);
problematicStates.complement(); problematicStates.complement();
problematicStates &= relevantStates; problematicStates &= relevantStates;
LOG4CPLUS_INFO(logger, "Found " << phiStates.getNumberOfSetBits() << " filter states (" << phiStates.toString() << ").");
LOG4CPLUS_INFO(logger, "Found " << psiStates.getNumberOfSetBits() << " target states (" << psiStates.toString() << ").");
LOG4CPLUS_INFO(logger, "Found " << relevantStates.getNumberOfSetBits() << " relevant states (" << relevantStates.toString() << ").");
LOG4CPLUS_INFO(logger, "Found " << problematicStates.getNumberOfSetBits() << " problematic states (" << problematicStates.toString() << ").");
// (3) Determine set of relevant labels.
// (3) Determine sets of relevant labels and problematic choices.
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates; std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> problematicChoicesForProblematicStates;
std::unordered_set<uint_fast64_t> relevantLabels; std::unordered_set<uint_fast64_t> relevantLabels;
storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix(); storm::storage::SparseMatrix<T> const& transitionMatrix = labeledMdp.getTransitionMatrix();
std::vector<uint_fast64_t> const& nondeterministicChoiceIndices = labeledMdp.getNondeterministicChoiceIndices(); std::vector<uint_fast64_t> const& nondeterministicChoiceIndices = labeledMdp.getNondeterministicChoiceIndices();
@ -73,8 +78,12 @@ namespace storm {
// If so, the associated labels become relevant. // If so, the associated labels become relevant.
for (auto state : relevantStates) { for (auto state : relevantStates) {
relevantChoicesForRelevantStates.emplace(state, std::list<uint_fast64_t>()); relevantChoicesForRelevantStates.emplace(state, std::list<uint_fast64_t>());
if (problematicStates.get(state)) {
problematicChoicesForProblematicStates.emplace(state, std::list<uint_fast64_t>());
}
for (uint_fast64_t row = nondeterministicChoiceIndices[state]; row < nondeterministicChoiceIndices[state + 1]; ++row) { for (uint_fast64_t row = nondeterministicChoiceIndices[state]; row < nondeterministicChoiceIndices[state + 1]; ++row) {
bool currentChoiceRelevant = false; bool currentChoiceRelevant = false;
bool allSuccessorsProblematic = true;
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) { for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) {
// If there is a relevant successor, we need to add the labels of the current choice. // If there is a relevant successor, we need to add the labels of the current choice.
if (relevantStates.get(*successorIt) || psiStates.get(*successorIt)) { if (relevantStates.get(*successorIt) || psiStates.get(*successorIt)) {
@ -86,16 +95,18 @@ namespace storm {
relevantChoicesForRelevantStates[state].emplace_back(row); relevantChoicesForRelevantStates[state].emplace_back(row);
} }
} }
if (!problematicStates.get(*successorIt)) {
allSuccessorsProblematic = false;
}
}
if (problematicStates.get(state) && allSuccessorsProblematic) {
problematicChoicesForProblematicStates[state].emplace_back(row);
} }
} }
} }
LOG4CPLUS_INFO(logger, "Found " << relevantLabels.size() << " relevant labels."); LOG4CPLUS_INFO(logger, "Found " << relevantLabels.size() << " relevant labels.");
// Determine set of problematic transitions.
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> problematicChoicesForProblematicStates;
for (auto state : problematicStates) {
std::unordered_map<uint_fast64_t, std::list<uint_fast64_t>> relevantChoicesForRelevantStates;
for (auto label : relevantLabels) {
LOG4CPLUS_INFO(logger, "Relevant label " << label << ".");
} }
// (3) Encode resulting system as MILP problem. // (3) Encode resulting system as MILP problem.
@ -171,11 +182,11 @@ namespace storm {
} }
// Create variables for problematic states, successors of problematic states and transitions of problematic states. // Create variables for problematic states, successors of problematic states and transitions of problematic states.
std::unordered_map<uint_fast64_t, uint_fast64_t> problematicStateVariables;
std::unordered_map<uint_fast64_t, uint_fast64_t> problematicStateVariablesToIndexMap;
std::unordered_map<std::pair<uint_fast64_t, uint_fast64_t>, uint_fast64_t, PairHash> problematicTransitionVariables; std::unordered_map<std::pair<uint_fast64_t, uint_fast64_t>, uint_fast64_t, PairHash> problematicTransitionVariables;
for (auto state : problematicStates) { for (auto state : problematicStates) {
// First check whether there is not already a variable for this state and proceed with next state. // First check whether there is not already a variable for this state and proceed with next state.
if (problematicStateVariables.find(state) == problematicStateVariables.end()) {
if (problematicStateVariablesToIndexMap.find(state) == problematicStateVariablesToIndexMap.end()) {
// Reset stringstream properly to construct new variable name. // Reset stringstream properly to construct new variable name.
variableNameBuffer.str(""); variableNameBuffer.str("");
variableNameBuffer.clear(); variableNameBuffer.clear();
@ -186,7 +197,7 @@ namespace storm {
LOG4CPLUS_ERROR(logger, "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ")."); LOG4CPLUS_ERROR(logger, "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ")."; throw storm::exceptions::InvalidStateException() << "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ").";
} }
problematicStateVariables[state] = nextLabelIndex;
problematicStateVariablesToIndexMap[state] = nextLabelIndex;
++nextLabelIndex; ++nextLabelIndex;
} }
@ -194,7 +205,7 @@ namespace storm {
for (uint_fast64_t row : relevantChoicesForState) { for (uint_fast64_t row : relevantChoicesForState) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) { for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(row); successorIt != transitionMatrix.constColumnIteratorEnd(row); ++successorIt) {
if (relevantStates.get(*successorIt)) { if (relevantStates.get(*successorIt)) {
if (problematicStateVariables.find(*successorIt) == problematicStateVariables.end()) {
if (problematicStateVariablesToIndexMap.find(*successorIt) == problematicStateVariablesToIndexMap.end()) {
// Reset stringstream properly to construct new variable name. // Reset stringstream properly to construct new variable name.
variableNameBuffer.str(""); variableNameBuffer.str("");
variableNameBuffer.clear(); variableNameBuffer.clear();
@ -204,7 +215,7 @@ namespace storm {
LOG4CPLUS_ERROR(logger, "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ")."); LOG4CPLUS_ERROR(logger, "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ")."; throw storm::exceptions::InvalidStateException() << "Could not create Gurobi variable (" << GRBgeterrormsg(env) << ").";
} }
problematicStateVariables[state] = nextLabelIndex;
problematicStateVariablesToIndexMap[state] = nextLabelIndex;
++nextLabelIndex; ++nextLabelIndex;
} }
variableNameBuffer.str(""); variableNameBuffer.str("");
@ -222,7 +233,7 @@ namespace storm {
} }
} }
LOG4CPLUS_ERROR(logger, "Successfully created " << nextLabelIndex << " Gurobi variables.");
LOG4CPLUS_INFO(logger, "Successfully created " << nextLabelIndex << " Gurobi variables.");
// Update model to incorporate prior changes. // Update model to incorporate prior changes.
error = GRBupdatemodel(model); error = GRBupdatemodel(model);
@ -240,7 +251,7 @@ namespace storm {
for (auto initialState : initialStates) { for (auto initialState : initialStates) {
int variableIndex = static_cast<int>(stateToProbabilityVariableIndex[initialState]); int variableIndex = static_cast<int>(stateToProbabilityVariableIndex[initialState]);
double coefficient = 1.0; double coefficient = 1.0;
error = GRBaddconstr(model, 1, &variableIndex, &coefficient, GRB_GREATER_EQUAL, lowerProbabilityBound + 10e-6, nullptr);
error = GRBaddconstr(model, 1, &variableIndex, &coefficient, GRB_GREATER_EQUAL, lowerProbabilityBound + 1e-6, nullptr);
if (error) { if (error) {
LOG4CPLUS_ERROR(logger, "Unable to assert constraint (" << GRBgeterrormsg(env) << ")."); LOG4CPLUS_ERROR(logger, "Unable to assert constraint (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Unable to assert constraint (" << GRBgeterrormsg(env) << ")."; throw storm::exceptions::InvalidStateException() << "Unable to assert constraint (" << GRBgeterrormsg(env) << ").";
@ -321,13 +332,13 @@ namespace storm {
for (typename storm::storage::SparseMatrix<T>::ConstIterator successorIt = rows.begin(), successorIte = rows.end(); successorIt != successorIte; ++successorIt) { for (typename storm::storage::SparseMatrix<T>::ConstIterator successorIt = rows.begin(), successorIte = rows.end(); successorIt != successorIte; ++successorIt) {
if (relevantStates.get(successorIt.column())) { if (relevantStates.get(successorIt.column())) {
variables.push_back(stateToProbabilityVariableIndex[successorIt.column()]); variables.push_back(stateToProbabilityVariableIndex[successorIt.column()]);
coefficients.push_back(-1);
coefficients.push_back(-successorIt.value());
} else if (psiStates.get(successorIt.column())) { } else if (psiStates.get(successorIt.column())) {
rightHandSide += successorIt.value(); rightHandSide += successorIt.value();
} }
} }
coefficients.push_back(-1);
coefficients.push_back(1);
variables.push_back(currentChoiceVariableIndex); variables.push_back(currentChoiceVariableIndex);
error = GRBaddconstr(model, variables.size(), &variables[0], &coefficients[0], GRB_LESS_EQUAL, rightHandSide, nullptr); error = GRBaddconstr(model, variables.size(), &variables[0], &coefficients[0], GRB_LESS_EQUAL, rightHandSide, nullptr);
@ -341,8 +352,54 @@ namespace storm {
} }
// Add constraints that ensure reachability of at least one unproblematic state. // Add constraints that ensure reachability of at least one unproblematic state.
for (auto stateListPair : problematicChoicesForProblematicStates) {
for (auto problematicChoice : stateListPair.second) {
uint_fast64_t currentChoiceVariableIndex = stateToStartingIndexMap[stateListPair.first];
for (auto relevantChoice : relevantChoicesForRelevantStates[stateListPair.first]) {
if (relevantChoice == problematicChoice) {
break;
}
++currentChoiceVariableIndex;
}
std::vector<int> variables;
std::vector<double> coefficients;
variables.push_back(currentChoiceVariableIndex);
coefficients.push_back(1);
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(problematicChoice); successorIt != transitionMatrix.constColumnIteratorEnd(problematicChoice); ++successorIt) {
variables.push_back(problematicTransitionVariables[std::make_pair(stateListPair.first, *successorIt)]);
coefficients.push_back(-1);
}
error = GRBaddconstr(model, variables.size(), &variables[0], &coefficients[0], GRB_LESS_EQUAL, 0, nullptr);
if (error) {
LOG4CPLUS_ERROR(logger, "Unable to assert constraint (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Unable to assert constraint (" << GRBgeterrormsg(env) << ").";
}
}
}
for (auto state : problematicStates) { for (auto state : problematicStates) {
for (auto problematicChoice : problematicChoicesForProblematicStates[state]) {
for (typename storm::storage::SparseMatrix<T>::ConstIndexIterator successorIt = transitionMatrix.constColumnIteratorBegin(problematicChoice); successorIt != transitionMatrix.constColumnIteratorEnd(problematicChoice); ++successorIt) {
std::vector<int> variables;
std::vector<double> coefficients;
variables.push_back(problematicStateVariablesToIndexMap[state]);
coefficients.push_back(1);
variables.push_back(problematicStateVariablesToIndexMap[*successorIt]);
coefficients.push_back(-1);
variables.push_back(problematicTransitionVariables[std::make_pair(state, *successorIt)]);
coefficients.push_back(1);
error = GRBaddconstr(model, variables.size(), &variables[0], &coefficients[0], GRB_LESS_EQUAL, 1 - 1e-6, nullptr);
if (error) {
LOG4CPLUS_ERROR(logger, "Unable to assert constraint (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Unable to assert constraint (" << GRBgeterrormsg(env) << ").";
}
}
}
} }
// Update model to incorporate prior changes. // Update model to incorporate prior changes.
@ -358,7 +415,27 @@ namespace storm {
throw storm::exceptions::InvalidStateException() << "Unable to write Gurobi model (" << GRBgeterrormsg(env) << ")."; throw storm::exceptions::InvalidStateException() << "Unable to write Gurobi model (" << GRBgeterrormsg(env) << ").";
} }
// (3.3) Construct objective function.
error = GRBoptimize(model);
if (error) {
LOG4CPLUS_ERROR(logger, "Unable to optimize Gurobi model (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Unable to optimize Gurobi model (" << GRBgeterrormsg(env) << ").";
}
std::vector<double> solution(labelToIndexMap.size());
error = GRBgetdblattrarray(model, GRB_DBL_ATTR_X, 0, labelToIndexMap.size(), &solution[0]);
if (error) {
LOG4CPLUS_ERROR(logger, "Unable to get Gurobi solution (" << GRBgeterrormsg(env) << ").");
throw storm::exceptions::InvalidStateException() << "Unable to get Gurobi solution (" << GRBgeterrormsg(env) << ").";
}
for (auto labelIndexPair : labelToIndexMap) {
std::cout << "label: " << labelIndexPair.first << " with value " << solution[labelIndexPair.second] << std::endl;
}
double reachabilityProbability = 0;
error = GRBgetdblattrarray(model, GRB_DBL_ATTR_X, stateToProbabilityVariableIndex[0], 1, &reachabilityProbability);
std::cout << "prob: " << reachabilityProbability << std::endl;
// (3.4) Construct constraint system. // (3.4) Construct constraint system.
// (4) Read off result from MILP variables. // (4) Read off result from MILP variables.

6
src/storm.cpp

@ -338,9 +338,9 @@ int main(const int argc, const char* argv[]) {
if (model->getType() == storm::models::MDP) { if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>(); std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
storm::storage::BitVector const& allCoinsEqualStates = labeledMdp->getLabeledStates("agree");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqualStates;
storm::counterexamples::MinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.4, true);
storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
storm::counterexamples::MinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.2, true);
} }
} }

Loading…
Cancel
Save