Browse Source

Further work on MaxSAT-based minimal command counterexamples.

Former-commit-id: 4991bdcb3d
tempestpy_adaptions
dehnert 11 years ago
parent
commit
b860f16ada
  1. 149
      src/counterexamples/SMTMinimalCommandSetGenerator.h
  2. 2
      src/models/Mdp.h
  3. 20
      src/storm.cpp

149
src/counterexamples/SMTMinimalCommandSetGenerator.h

@ -243,20 +243,154 @@ namespace storm {
}
/*!
* Asserts that at most one of the blocking variables may be true at any time.
* Asserts that the conjunction of the given formulae holds.
*
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
* @param blockingVariables A vector of variables out of which only one may be true.
* @param formulaVector A vector of expressions that shall form the conjunction.
*/
static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& blockingVariables) {
for (uint_fast64_t i = 0; i < blockingVariables.size(); ++i) {
for (uint_fast64_t j = i + 1; j < blockingVariables.size(); ++j) {
solver.add(!blockingVariables[i] || !blockingVariables[j]);
static void assertConjunction(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& formulaVector) {
z3::expr conjunction(context);
for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) {
if (i == 0) {
conjunction = formulaVector[i];
} else {
conjunction = conjunction && formulaVector[i];
}
}
solver.add(conjunction);
}
/*!
* Creates a full-adder for the two inputs and returns the resulting bit as well as the carry bit.
*
* @param in1 The first input to the adder.
* @param in2 The second input to the adder.
* @param carryIn The carry bit input to the adder.
* @return A pair whose first component represents the carry bit and whose second component represents the
* result bit.
*/
static std::pair<z3::expr, z3::expr> createFullAdder(z3::expr in1, z3::expr in2, z3::expr carryIn) {
z3::expr resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn);
z3::expr carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn;
return std::make_pair(carryBit, resultBit);
}
/*!
* Creates an adder for the two inputs of equal size. The resulting vector represents the different bits of
* the sum (and is thus one bit longer than the two inputs).
*
* @param context The Z3 context in which to build the expressions.
* @param in1 The first input to the adder.
* @param in2 The second input to the adder.
* @return A vector representing the bits of the sum of the two inputs.
*/
static std::vector<z3::expr> createAdder(z3::context& context, std::vector<z3::expr> const& in1, std::vector<z3::expr> const& in2) {
// Sanity check for sizes of input.
if (in1.size() != in2.size() || in1.size() == 0) {
LOG4CPLUS_ERROR(logger, "Illegal input to adder (" << in1.size() << ", " << in2.size() << ").");
throw storm::exceptions::InvalidArgumentException() << "Illegal input to adder.";
}
// Prepare result.
std::vector<z3::expr> result;
result.reserve(in1.size() + 1);
// Add all bits individually and pass on carry bit appropriately.
z3::expr carryBit = context.bool_val(false);
for (uint_fast64_t currentBit = 0; currentBit < in1.size(); ++currentBit) {
std::pair<z3::expr, z3::expr> localResult = createFullAdder(in1[currentBit], in2[currentBit], carryBit);
result.push_back(localResult.second);
carryBit = localResult.first;
}
result.push_back(carryBit);
return result;
}
/*!
* Given a number of input numbers, creates a number of output numbers that corresponds to the sum of two
* consecutive numbers of the input. If the number if input numbers is odd, the last number is simply added
* to the output.
*
* @param context The Z3 context in which to build the expressions.
* @param in A vector or binary encoded numbers.
* @return A vector of numbers that each correspond to the sum of two consecutive elements of the input.
*/
static std::vector<std::vector<z3::expr>> createAdderPairs(z3::context& context, std::vector<std::vector<z3::expr>> const& in) {
std::vector<std::vector<z3::expr>> result;
result.reserve(in.size() / 2 + in.size() % 2);
for (uint_fast64_t index = 0; index < in.size() / 2; ++index) {
result.push_back(createAdder(context, in[2 * index], in[2 * index + 1]));
}
if (in.size() % 2 != 0) {
result.push_back(in[in.size() - 1]);
result.back().push_back(context.bool_val(false));
}
return result;
}
/*!
* Creates a counter circuit that returns the number of literals out of the given vector that are set to true.
*
* @param context The Z3 context in which to build the expressions.
* @param literals The literals for which to create the adder circuit.
* @return A bit vector representing the number of literals that are set to true.
*/
static std::vector<z3::expr> createCounterCircuit(z3::context& context, std::vector<z3::expr> const& literals) {
// Create the auxiliary vector.
std::vector<std::vector<z3::expr>> aux;
for (uint_fast64_t index = 0; index < literals.size(); ++index) {
aux.emplace_back();
aux.back().push_back(literals[index]);
}
while (aux.size() > 1) {
aux = createAdderPairs(context, aux);
}
return aux[0];
}
/*!
* Asserts that the input vector encodes a decimal smaller or equal to one.
*
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
* @param input The binary encoded input number.
*/
static void assertLessOrEqualOne(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& input) {
std::vector<z3::expr> tmp;
tmp.reserve(input.size() - 1);
for (uint_fast64_t index = 1; index < input.size(); ++index) {
tmp.push_back(!input[index]);
}
assertConjunction(context, solver, tmp);
}
/*!
* Asserts that at most one of the blocking variables may be true at any time.
*
* @param context The Z3 context in which to build the expressions.
* @param solver The solver to use for the satisfiability evaluation.
* @param blockingVariables A vector of variables out of which only one may be true.
*/
static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector<z3::expr> const& literals) {
std::vector<z3::expr> counter = createCounterCircuit(context, literals);
assertLessOrEqualOne(context, solver, counter);
// for (uint_fast64_t i = 0; i < blockingVariables.size(); ++i) {
// for (uint_fast64_t j = i + 1; j < blockingVariables.size(); ++j) {
// solver.add(!blockingVariables[i] || !blockingVariables[j]);
// }
// }
}
/*!
* Performs one Fu-Malik-Maxsat step.
@ -366,6 +500,7 @@ namespace storm {
// Now we are ready to construct the label set from the model of the solver.
std::set<uint_fast64_t> result;
z3::model model = solver.get_model();
for (auto const& labelIndexPair : variableInformation.labelToIndexMap) {
z3::expr value = model.eval(variableInformation.labelVariables[labelIndexPair.second]);

2
src/models/Mdp.h

@ -151,7 +151,7 @@ public:
for(uint_fast64_t state = 0; state < this->getNumberOfStates(); ++state) {
bool stateHasValidChoice = false;
for (uint_fast64_t choice = this->getNondeterministicChoiceIndices()[state]; choice < this->getNondeterministicChoiceIndices()[state + 1]; ++choice) {
bool choiceValid = storm::utility::set::isSubsetOf(choiceLabeling[state], enabledChoiceLabels);
bool choiceValid = storm::utility::set::isSubsetOf(choiceLabeling[choice], enabledChoiceLabels);
// If the choice is valid, copy over all its elements.
if (choiceValid) {

20
src/storm.cpp

@ -338,13 +338,19 @@ int main(const int argc, const char* argv[]) {
model->printModelInformationToStream(std::cout);
// Enable the following lines to test the MinimalLabelSetGenerator.
// if (model->getType() == storm::models::MDP) {
// std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
// storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
// }
if (model->getType() == storm::models::MDP) {
std::shared_ptr<storm::models::Mdp<double>> labeledMdp = model->as<storm::models::Mdp<double>>();
storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished");
storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1");
storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States;
std::unordered_set<uint_fast64_t> labels = storm::counterexamples::MILPMinimalLabelSetGenerator<double>::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true);
std::cout << "Found solution with " << labels.size() << " commands." << std::endl;
for (uint_fast64_t label : labels) {
std::cout << label << ", ";
}
std::cout << std::endl;
}
// Enable the following lines to test the SMTMinimalCommandSetGenerator.
if (model->getType() == storm::models::MDP) {

Loading…
Cancel
Save