diff --git a/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp b/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp index a57a0ccef..78d554346 100644 --- a/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp +++ b/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp @@ -41,8 +41,7 @@ namespace storm { } viHelper.performValueIteration(env, x, b, goal.direction()); - //if(goal.isShieldingTask()) { - if(true) { + if(goal.isShieldingTask()) { viHelper.getChoiceValues(env, x, constrainedChoiceValues); } viHelper.fillResultVector(x, relevantStates, psiStates); @@ -107,13 +106,12 @@ namespace storm { // create multiplier and execute the calculation for 1 step auto multiplier = storm::solver::MultiplierFactory().create(env, transitionMatrix); std::vector choiceValues = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); - - //if(goal.isShieldingTask()) { - if (true) { + if (goal.isShieldingTask()) { multiplier->multiply(env, x, &b, choiceValues); + multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), x, &statesOfCoalition); + } else { + multiplier->multiplyAndReduce(env, goal.direction(), x, &b, x, nullptr, &statesOfCoalition); } - multiplier->multiplyAndReduce(env, goal.direction(), x, &b, x, nullptr, &statesOfCoalition); - return SMGSparseModelCheckingHelperReturnType(std::move(x), std::move(allStates), nullptr, std::move(choiceValues)); } diff --git a/src/storm/solver/Multiplier.cpp b/src/storm/solver/Multiplier.cpp index 97eb22bb8..f36a6c87c 100644 --- a/src/storm/solver/Multiplier.cpp +++ b/src/storm/solver/Multiplier.cpp @@ -68,12 +68,46 @@ namespace storm { } } + template + void Multiplier::repeatedMultiplyAndReduceWithChoices(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, uint64_t n, storm::storage::BitVector const* dirOverride, std::vector& choiceValues, std::vector::index_type> rowGroupIndices) const { + storm::utility::ProgressMeasurement progress("multiplications"); + progress.setMaxCount(n); + progress.startNewMeasurement(0); + for (uint64_t i = 0; i < n; ++i) { + + multiply(env, x, b, choiceValues); + reduce(env, dir, choiceValues, rowGroupIndices, x); + + multiplyAndReduce(env, dir, x, b, x); + if (storm::utility::resources::isTerminate()) { + STORM_LOG_WARN("Aborting after " << i << " of " << n << " multiplications"); + break; + } + } + } + template void Multiplier::multiplyRow2(uint64_t const& rowIndex, std::vector const& x1, ValueType& val1, std::vector const& x2, ValueType& val2) const { multiplyRow(rowIndex, x1, val1); multiplyRow(rowIndex, x2, val2); } + template + void Multiplier::reduce(Environment const& env, OptimizationDirection const& dir, std::vector const& choiceValues, std::vector::index_type> rowGroupIndices, std::vector& result, storm::storage::BitVector const* dirOverride) const { + auto choice_it = choiceValues.begin(); + for(uint state = 0; state < rowGroupIndices.size() - 1; state++) { + uint rowGroupSize = rowGroupIndices[state + 1] - rowGroupIndices[state]; + if((dir == storm::OptimizationDirection::Minimize && !dirOverride->get(state)) || (dir == storm::OptimizationDirection::Maximize && dirOverride->get(state))) { + result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize); + choice_it += rowGroupSize; + } + else { + result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize); + choice_it += rowGroupSize; + } + } + } + template std::unique_ptr> MultiplierFactory::create(Environment const& env, storm::storage::SparseMatrix const& matrix) { auto type = env.solver().multiplier().getType(); diff --git a/src/storm/solver/Multiplier.h b/src/storm/solver/Multiplier.h index 4127a6625..5e132a628 100644 --- a/src/storm/solver/Multiplier.h +++ b/src/storm/solver/Multiplier.h @@ -8,6 +8,8 @@ #include "storm/solver/OptimizationDirection.h" #include "storm/solver/MultiplicationStyle.h" +#include "storm/storage/SparseMatrix.h" + namespace storm { @@ -119,6 +121,8 @@ namespace storm { */ void repeatedMultiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector& x, std::vector const* b, uint64_t n, storm::storage::BitVector const* dirOverride = nullptr) const; + void repeatedMultiplyAndReduceWithChoices(const Environment &env, const OptimizationDirection &dir, std::vector &x, const std::vector *b, uint64_t n, const storage::BitVector *dirOverride, std::vector &choiceValues, std::vector rowGroupIndices) const; + /*! * Multiplies the row with the given index with x and adds the result to the provided value * @param rowIndex The index of the considered row @@ -137,9 +141,12 @@ namespace storm { */ virtual void multiplyRow2(uint64_t const& rowIndex, std::vector const& x1, ValueType& val1, std::vector const& x2, ValueType& val2) const; + void reduce(Environment const& env, OptimizationDirection const& dir, std::vector const& choiceValues, std::vector::index_type> rowGroupIndices, std::vector& result, storm::storage::BitVector const* dirOverride = nullptr) const; + protected: mutable std::unique_ptr> cachedVector; storm::storage::SparseMatrix const& matrix; + }; template