diff --git a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp index 2ea8fa459..5425044b2 100644 --- a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp +++ b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp @@ -584,11 +584,13 @@ namespace storm { // Do some sanity checks to establish some required properties. RewardModelType const& rewardModel = this->getModel().getRewardModel(checkTask.isRewardModelSet() ? checkTask.getRewardModel() : ""); - // TODO use maybeStates instead of all states - std::vector stateRewardValues = rewardModel.getTotalRewardVector(trueStates.getNumberOfSetBits(), this->getModel().getTransitionMatrix(), trueStates); STORM_LOG_THROW(!rewardModel.empty(), storm::exceptions::IllegalArgumentException, "Input model does not have a reward model."); - std::vector result = computeReachabilityRewards(this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), this->getModel().getInitialStates(), targetStates, stateRewardValues, checkTask.isOnlyInitialStatesRelevantSet()); + std::vector result = computeReachabilityRewards(this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), this->getModel().getInitialStates(), targetStates, + [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates) { + return rewardModel.getTotalRewardVector(numberOfRows, transitionMatrix, maybeStates); + }, + checkTask.isOnlyInitialStatesRelevantSet()); // Construct check result. std::unique_ptr checkResult(new ExplicitQuantitativeCheckResult(result)); @@ -597,6 +599,17 @@ namespace storm { template std::vector SparseDtmcEliminationModelChecker::computeReachabilityRewards(storm::storage::SparseMatrix const& probabilityMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& targetStates, std::vector& stateRewardValues, bool computeForInitialStatesOnly) { + return computeReachabilityRewards(probabilityMatrix, backwardTransitions, initialStates, targetStates, + [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates) { + std::vector result(numberOfRows); + storm::utility::vector::selectVectorValues(result, maybeStates, stateRewardValues); + return result; + }, + computeForInitialStatesOnly); + } + + template + std::vector SparseDtmcEliminationModelChecker::computeReachabilityRewards(storm::storage::SparseMatrix const& probabilityMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& targetStates, std::function(uint_fast64_t, storm::storage::SparseMatrix const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, bool computeForInitialStatesOnly) { uint_fast64_t numberOfStates = probabilityMatrix.getRowCount(); @@ -639,9 +652,9 @@ namespace storm { storm::storage::SparseMatrix submatrixTransposed = submatrix.transpose(); // Project the state reward vector to all maybe-states. - storm::storage::BitVector phiStates(numberOfStates, true); + std::vector stateRewardValues = totalStateRewardVectorGetter(submatrix.getRowCount(), probabilityMatrix, maybeStates); - std::vector subresult = computeReachabilityValues(submatrix, stateRewardValues, submatrixTransposed, newInitialStates, computeForInitialStatesOnly, phiStates, targetStates, probabilityMatrix.getConstrainedRowSumVector(maybeStates, targetStates)); + std::vector subresult = computeReachabilityValues(submatrix, stateRewardValues, submatrixTransposed, newInitialStates, computeForInitialStatesOnly, trueStates, targetStates, probabilityMatrix.getConstrainedRowSumVector(maybeStates, targetStates)); storm::utility::vector::setVectorValues(result, maybeStates, subresult); } diff --git a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h index 76e939d66..33b61a2c4 100644 --- a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h +++ b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h @@ -91,6 +91,8 @@ namespace storm { static std::vector computeLongRunValues(storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& maybeStates, bool computeResultsForInitialStatesOnly, std::vector& stateValues); + static std::vector computeReachabilityRewards(storm::storage::SparseMatrix const& probabilityMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& targetStates, std::function(uint_fast64_t, storm::storage::SparseMatrix const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, bool computeForInitialStatesOnly); + static std::vector computeReachabilityValues(storm::storage::SparseMatrix const& transitionMatrix, std::vector& values, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& initialStates, bool computeResultsForInitialStatesOnly, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, std::vector const& oneStepProbabilitiesToTarget); static std::shared_ptr> createStatePriorityQueue(boost::optional> const& stateDistances, storm::storage::FlexibleSparseMatrix const& transitionMatrix, storm::storage::FlexibleSparseMatrix const& backwardTransitions, std::vector& oneStepProbabilities, storm::storage::BitVector const& states);