diff --git a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp index 68d7631d0..a37d11782 100644 --- a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp +++ b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp @@ -25,7 +25,7 @@ namespace storm { } template - std::vector SparseNondeterministicStepBoundedHorizonHelper::compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint) + std::vector SparseNondeterministicStepBoundedHorizonHelper::compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint, storm::storage::BitVector& resultMaybeStates, std::vector& choiceValues) { std::vector result(transitionMatrix.getRowGroupCount(), storm::utility::zero()); storm::storage::BitVector makeZeroColumns; @@ -61,13 +61,27 @@ namespace storm { auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); if (lowerBound == 0) { - multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound); + if(goal.isShieldingTask()) + { + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound, nullptr, choiceValues, transitionMatrix.getRowGroupIndices()); + } else { + multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound); + } } else { - multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1); - storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); - auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); - b = std::vector(b.size(), storm::utility::zero()); - multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, lowerBound - 1); + if(goal.isShieldingTask()) + { + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1, nullptr, choiceValues, transitionMatrix.getRowGroupIndices()); + storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); + auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + b = std::vector(b.size(), storm::utility::zero()); + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, lowerBound - 1, nullptr, choiceValues, transitionMatrix.getRowGroupIndices()); + } else { + multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1); + storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); + auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + b = std::vector(b.size(), storm::utility::zero()); + multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, lowerBound - 1); + } } // Set the values of the resulting vector accordingly. storm::utility::vector::setVectorValues(result, maybeStates, subresult); @@ -75,10 +89,15 @@ namespace storm { if (lowerBound == 0) { storm::utility::vector::setVectorValues(result, psiStates, storm::utility::one()); } + + //TODO: check if this works with nullptr as default for resultMaybeStates + if(!resultMaybeStates.empty()) + { + resultMaybeStates = maybeStates; + } return result; } - template class SparseNondeterministicStepBoundedHorizonHelper; template class SparseNondeterministicStepBoundedHorizonHelper; } diff --git a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h index 98fc49c2c..0a0a0fc51 100644 --- a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h +++ b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h @@ -4,6 +4,7 @@ #include "storm/modelchecker/hints/ModelCheckerHint.h" #include "storm/modelchecker/prctl/helper/SolutionType.h" #include "storm/storage/SparseMatrix.h" +#include "storm/storage/BitVector.h" #include "storm/utility/solver.h" #include "storm/solver/SolveGoal.h" @@ -15,7 +16,7 @@ namespace storm { class SparseNondeterministicStepBoundedHorizonHelper { public: SparseNondeterministicStepBoundedHorizonHelper(); - std::vector compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint = ModelCheckerHint()); + std::vector compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint = ModelCheckerHint(), storm::storage::BitVector& resultMaybeStates = nullptr, std::vector& choiceValues = nullptr); private: };