diff --git a/src/counterexamples/SMTMinimalCommandSetGenerator.h b/src/counterexamples/SMTMinimalCommandSetGenerator.h index 5f0e6ae85..59a4279df 100644 --- a/src/counterexamples/SMTMinimalCommandSetGenerator.h +++ b/src/counterexamples/SMTMinimalCommandSetGenerator.h @@ -243,20 +243,154 @@ namespace storm { } /*! - * Asserts that at most one of the blocking variables may be true at any time. + * Asserts that the conjunction of the given formulae holds. * * @param context The Z3 context in which to build the expressions. * @param solver The solver to use for the satisfiability evaluation. - * @param blockingVariables A vector of variables out of which only one may be true. + * @param formulaVector A vector of expressions that shall form the conjunction. */ - static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector const& blockingVariables) { - for (uint_fast64_t i = 0; i < blockingVariables.size(); ++i) { - for (uint_fast64_t j = i + 1; j < blockingVariables.size(); ++j) { - solver.add(!blockingVariables[i] || !blockingVariables[j]); + static void assertConjunction(z3::context& context, z3::solver& solver, std::vector const& formulaVector) { + z3::expr conjunction(context); + for (uint_fast64_t i = 0; i < formulaVector.size(); ++i) { + if (i == 0) { + conjunction = formulaVector[i]; + } else { + conjunction = conjunction && formulaVector[i]; } } + solver.add(conjunction); + } + + /*! + * Creates a full-adder for the two inputs and returns the resulting bit as well as the carry bit. + * + * @param in1 The first input to the adder. + * @param in2 The second input to the adder. + * @param carryIn The carry bit input to the adder. + * @return A pair whose first component represents the carry bit and whose second component represents the + * result bit. + */ + static std::pair createFullAdder(z3::expr in1, z3::expr in2, z3::expr carryIn) { + z3::expr resultBit = (in1 && !in2 && !carryIn) || (!in1 && in2 && !carryIn) || (!in1 && !in2 && carryIn); + z3::expr carryBit = in1 && in2 || in1 && carryIn || in2 && carryIn; + + return std::make_pair(carryBit, resultBit); + } + + /*! + * Creates an adder for the two inputs of equal size. The resulting vector represents the different bits of + * the sum (and is thus one bit longer than the two inputs). + * + * @param context The Z3 context in which to build the expressions. + * @param in1 The first input to the adder. + * @param in2 The second input to the adder. + * @return A vector representing the bits of the sum of the two inputs. + */ + static std::vector createAdder(z3::context& context, std::vector const& in1, std::vector const& in2) { + // Sanity check for sizes of input. + if (in1.size() != in2.size() || in1.size() == 0) { + LOG4CPLUS_ERROR(logger, "Illegal input to adder (" << in1.size() << ", " << in2.size() << ")."); + throw storm::exceptions::InvalidArgumentException() << "Illegal input to adder."; + } + + // Prepare result. + std::vector result; + result.reserve(in1.size() + 1); + + // Add all bits individually and pass on carry bit appropriately. + z3::expr carryBit = context.bool_val(false); + for (uint_fast64_t currentBit = 0; currentBit < in1.size(); ++currentBit) { + std::pair localResult = createFullAdder(in1[currentBit], in2[currentBit], carryBit); + + result.push_back(localResult.second); + carryBit = localResult.first; + } + result.push_back(carryBit); + + return result; + } + + /*! + * Given a number of input numbers, creates a number of output numbers that corresponds to the sum of two + * consecutive numbers of the input. If the number if input numbers is odd, the last number is simply added + * to the output. + * + * @param context The Z3 context in which to build the expressions. + * @param in A vector or binary encoded numbers. + * @return A vector of numbers that each correspond to the sum of two consecutive elements of the input. + */ + static std::vector> createAdderPairs(z3::context& context, std::vector> const& in) { + std::vector> result; + result.reserve(in.size() / 2 + in.size() % 2); + + for (uint_fast64_t index = 0; index < in.size() / 2; ++index) { + result.push_back(createAdder(context, in[2 * index], in[2 * index + 1])); + } + + if (in.size() % 2 != 0) { + result.push_back(in[in.size() - 1]); + result.back().push_back(context.bool_val(false)); + } + + return result; + } + + /*! + * Creates a counter circuit that returns the number of literals out of the given vector that are set to true. + * + * @param context The Z3 context in which to build the expressions. + * @param literals The literals for which to create the adder circuit. + * @return A bit vector representing the number of literals that are set to true. + */ + static std::vector createCounterCircuit(z3::context& context, std::vector const& literals) { + // Create the auxiliary vector. + std::vector> aux; + for (uint_fast64_t index = 0; index < literals.size(); ++index) { + aux.emplace_back(); + aux.back().push_back(literals[index]); + } + + while (aux.size() > 1) { + aux = createAdderPairs(context, aux); + } + + return aux[0]; + } + + /*! + * Asserts that the input vector encodes a decimal smaller or equal to one. + * + * @param context The Z3 context in which to build the expressions. + * @param solver The solver to use for the satisfiability evaluation. + * @param input The binary encoded input number. + */ + static void assertLessOrEqualOne(z3::context& context, z3::solver& solver, std::vector const& input) { + std::vector tmp; + tmp.reserve(input.size() - 1); + for (uint_fast64_t index = 1; index < input.size(); ++index) { + tmp.push_back(!input[index]); + } + assertConjunction(context, solver, tmp); + } + + /*! + * Asserts that at most one of the blocking variables may be true at any time. + * + * @param context The Z3 context in which to build the expressions. + * @param solver The solver to use for the satisfiability evaluation. + * @param blockingVariables A vector of variables out of which only one may be true. + */ + static void assertAtMostOne(z3::context& context, z3::solver& solver, std::vector const& literals) { + std::vector counter = createCounterCircuit(context, literals); + assertLessOrEqualOne(context, solver, counter); + + +// for (uint_fast64_t i = 0; i < blockingVariables.size(); ++i) { +// for (uint_fast64_t j = i + 1; j < blockingVariables.size(); ++j) { +// solver.add(!blockingVariables[i] || !blockingVariables[j]); +// } +// } } - /*! * Performs one Fu-Malik-Maxsat step. @@ -366,6 +500,7 @@ namespace storm { // Now we are ready to construct the label set from the model of the solver. std::set result; z3::model model = solver.get_model(); + for (auto const& labelIndexPair : variableInformation.labelToIndexMap) { z3::expr value = model.eval(variableInformation.labelVariables[labelIndexPair.second]); diff --git a/src/models/Mdp.h b/src/models/Mdp.h index 48a2584d5..4bdbbf85c 100644 --- a/src/models/Mdp.h +++ b/src/models/Mdp.h @@ -151,7 +151,7 @@ public: for(uint_fast64_t state = 0; state < this->getNumberOfStates(); ++state) { bool stateHasValidChoice = false; for (uint_fast64_t choice = this->getNondeterministicChoiceIndices()[state]; choice < this->getNondeterministicChoiceIndices()[state + 1]; ++choice) { - bool choiceValid = storm::utility::set::isSubsetOf(choiceLabeling[state], enabledChoiceLabels); + bool choiceValid = storm::utility::set::isSubsetOf(choiceLabeling[choice], enabledChoiceLabels); // If the choice is valid, copy over all its elements. if (choiceValid) { diff --git a/src/storm.cpp b/src/storm.cpp index ad9985d60..448169d94 100644 --- a/src/storm.cpp +++ b/src/storm.cpp @@ -338,13 +338,19 @@ int main(const int argc, const char* argv[]) { model->printModelInformationToStream(std::cout); // Enable the following lines to test the MinimalLabelSetGenerator. -// if (model->getType() == storm::models::MDP) { -// std::shared_ptr> labeledMdp = model->as>(); -// storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); -// storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); -// storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; -// storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); -// } + if (model->getType() == storm::models::MDP) { + std::shared_ptr> labeledMdp = model->as>(); + storm::storage::BitVector const& finishedStates = labeledMdp->getLabeledStates("finished"); + storm::storage::BitVector const& allCoinsEqual1States = labeledMdp->getLabeledStates("all_coins_equal_1"); + storm::storage::BitVector targetStates = finishedStates & allCoinsEqual1States; + std::unordered_set labels = storm::counterexamples::MILPMinimalLabelSetGenerator::getMinimalLabelSet(*labeledMdp, storm::storage::BitVector(labeledMdp->getNumberOfStates(), true), targetStates, 0.3, true, true); + + std::cout << "Found solution with " << labels.size() << " commands." << std::endl; + for (uint_fast64_t label : labels) { + std::cout << label << ", "; + } + std::cout << std::endl; + } // Enable the following lines to test the SMTMinimalCommandSetGenerator. if (model->getType() == storm::models::MDP) {