diff --git a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp index 2340d9e76..0d88c456e 100644 --- a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp +++ b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.cpp @@ -17,6 +17,30 @@ namespace storm { namespace modelchecker { namespace helper { + template + void SparseNondeterministicStepBoundedHorizonHelper::getMaybeStatesRowGroupSizes(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector& maybeStatesRowGroupSizes, uint& choiceValuesCounter) { + std::vector rowGroupIndices = transitionMatrix.getRowGroupIndices(); + for(uint counter = 0; counter < maybeStates.size(); counter++) { + if(maybeStates.get(counter)) { + maybeStatesRowGroupSizes.push_back(rowGroupIndices.at(counter)); + choiceValuesCounter += transitionMatrix.getRowGroupSize(counter); + } + } + } + + template + void SparseNondeterministicStepBoundedHorizonHelper::getChoiceValues(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector const& choiceValues, std::vector& allChoices) { + auto choice_it = choiceValues.begin(); + for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) { + uint rowGroupSize = transitionMatrix.getRowGroupSize(state); + if (maybeStates.get(state)) { + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it; + } + } + } + } + template SparseNondeterministicStepBoundedHorizonHelper::SparseNondeterministicStepBoundedHorizonHelper(/*storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions*/) //transitionMatrix(transitionMatrix), backwardTransitions(backwardTransitions) @@ -60,34 +84,21 @@ namespace storm { std::vector subresult(maybeStates.getNumberOfSetBits()); auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); + + std::vector rowGroupIndices = transitionMatrix.getRowGroupIndices(); + std::vector maybeStatesRowGroupSizes; + uint choiceValuesCounter; + getMaybeStatesRowGroupSizes(transitionMatrix, maybeStates, maybeStatesRowGroupSizes, choiceValuesCounter); + choiceValues = std::vector(choiceValuesCounter, storm::utility::zero()); + if (lowerBound == 0) { if(goal.isShieldingTask()) { - - std::vector::index_type> rowGroupIndices = transitionMatrix.getRowGroupIndices(); - std::vector::index_type> reducedRowGroupIndices; - uint sizeChoiceValues = 0; - for(uint counter = 0; counter < maybeStates.size(); counter++) { - if(maybeStates.get(counter)) { - sizeChoiceValues += transitionMatrix.getRowGroupSize(counter); - reducedRowGroupIndices.push_back(rowGroupIndices.at(counter)); - } - } - choiceValues = std::vector(sizeChoiceValues, storm::utility::zero()); - - multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound, nullptr, choiceValues, reducedRowGroupIndices); + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound, nullptr, choiceValues, maybeStatesRowGroupSizes); // fill up choicesValues for shields - std::vector allChoices = std::vector(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero()); - auto choice_it = choiceValues.begin(); - for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) { - uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state); - if (maybeStates.get(state)) { - for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { - allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it; - } - } - } + std::vector allChoices = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); + getChoiceValues(transitionMatrix, maybeStates, choiceValues, allChoices); choiceValues = allChoices; } else { multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound); @@ -95,34 +106,15 @@ namespace storm { } else { if(goal.isShieldingTask()) { - std::vector::index_type> rowGroupIndices = transitionMatrix.getRowGroupIndices(); - std::vector::index_type> reducedRowGroupIndices; - uint sizeChoiceValues = 0; - for(uint counter = 0; counter < maybeStates.size(); counter++) { - if(maybeStates.get(counter)) { - sizeChoiceValues += transitionMatrix.getRowGroupSize(counter); - reducedRowGroupIndices.push_back(rowGroupIndices.at(counter)); - } - } - choiceValues = std::vector(sizeChoiceValues, storm::utility::zero()); - - multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1, nullptr, choiceValues, reducedRowGroupIndices); + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1, nullptr, choiceValues, maybeStatesRowGroupSizes); storm::storage::SparseMatrix submatrix = transitionMatrix.getSubmatrix(true, maybeStates, maybeStates, false); auto multiplier = storm::solver::MultiplierFactory().create(env, submatrix); b = std::vector(b.size(), storm::utility::zero()); - multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, lowerBound - 1, nullptr, choiceValues, reducedRowGroupIndices); + multiplier->repeatedMultiplyAndReduceWithChoices(env, goal.direction(), subresult, &b, lowerBound - 1, nullptr, choiceValues, maybeStatesRowGroupSizes); // fill up choicesValues for shields - std::vector allChoices = std::vector(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero()); - auto choice_it = choiceValues.begin(); - for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) { - uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state); - if (maybeStates.get(state)) { - for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { - allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it; - } - } - } + std::vector allChoices = std::vector(transitionMatrix.getRowCount(), storm::utility::zero()); + getChoiceValues(transitionMatrix, maybeStates, choiceValues, allChoices); choiceValues = allChoices; } else { multiplier->repeatedMultiplyAndReduce(env, goal.direction(), subresult, &b, upperBound - lowerBound + 1); @@ -147,4 +139,4 @@ namespace storm { template class SparseNondeterministicStepBoundedHorizonHelper; } } -} \ No newline at end of file +} diff --git a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h index 0a0a0fc51..0ef64fe2c 100644 --- a/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h +++ b/src/storm/modelchecker/helper/finitehorizon/SparseNondeterministicStepBoundedHorizonHelper.h @@ -16,6 +16,9 @@ namespace storm { class SparseNondeterministicStepBoundedHorizonHelper { public: SparseNondeterministicStepBoundedHorizonHelper(); + + void getMaybeStatesRowGroupSizes(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector& maybeStatesRowGroupSizes, uint& choiceValuesCounter); + void getChoiceValues(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates, std::vector const& choiceValues, std::vector& allChoices); std::vector compute(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, uint64_t lowerBound, uint64_t upperBound, ModelCheckerHint const& hint = ModelCheckerHint(), storm::storage::BitVector& resultMaybeStates = nullptr, std::vector& choiceValues = nullptr); private: };