From fa4a1aa68f24517cfd38d06daec991ebcb2a6903 Mon Sep 17 00:00:00 2001
From: Mavo <matthias.volk@rwth-aachen.de>
Date: Mon, 22 Feb 2016 15:44:37 +0100
Subject: [PATCH] Fixed bug with filtering reward vector

Former-commit-id: ad709ad0dd3da3c062480f1cd23e910e46064dab
---
 .../SparseDtmcEliminationModelChecker.cpp     | 23 +++++++++++++++----
 .../SparseDtmcEliminationModelChecker.h       |  2 ++
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp
index 2ea8fa459..5425044b2 100644
--- a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp
+++ b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.cpp
@@ -584,11 +584,13 @@ namespace storm {
             
             // Do some sanity checks to establish some required properties.
             RewardModelType const& rewardModel = this->getModel().getRewardModel(checkTask.isRewardModelSet() ? checkTask.getRewardModel() : "");
-            // TODO use maybeStates instead of all states
-            std::vector<ValueType> stateRewardValues = rewardModel.getTotalRewardVector(trueStates.getNumberOfSetBits(), this->getModel().getTransitionMatrix(), trueStates);
 
             STORM_LOG_THROW(!rewardModel.empty(), storm::exceptions::IllegalArgumentException, "Input model does not have a reward model.");
-            std::vector<ValueType> result = computeReachabilityRewards(this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), this->getModel().getInitialStates(), targetStates, stateRewardValues, checkTask.isOnlyInitialStatesRelevantSet());
+            std::vector<ValueType> result = computeReachabilityRewards(this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), this->getModel().getInitialStates(), targetStates,
+                                                                       [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& maybeStates) {
+                                                                           return rewardModel.getTotalRewardVector(numberOfRows, transitionMatrix, maybeStates);
+                                                                       },
+                                                                       checkTask.isOnlyInitialStatesRelevantSet());
 
             // Construct check result.
             std::unique_ptr<CheckResult> checkResult(new ExplicitQuantitativeCheckResult<ValueType>(result));
@@ -597,6 +599,17 @@ namespace storm {
 
         template<typename SparseDtmcModelType>
         std::vector<typename SparseDtmcModelType::ValueType> SparseDtmcEliminationModelChecker<SparseDtmcModelType>::computeReachabilityRewards(storm::storage::SparseMatrix<ValueType> const& probabilityMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& targetStates, std::vector<ValueType>& stateRewardValues, bool computeForInitialStatesOnly) {
+            return computeReachabilityRewards(probabilityMatrix, backwardTransitions, initialStates, targetStates,
+                                              [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::BitVector const& maybeStates) {
+                                                  std::vector<ValueType> result(numberOfRows);
+                                                  storm::utility::vector::selectVectorValues(result, maybeStates, stateRewardValues);
+                                                  return result;
+                                              },
+                                              computeForInitialStatesOnly);
+        }
+
+        template<typename SparseDtmcModelType>
+        std::vector<typename SparseDtmcModelType::ValueType> SparseDtmcEliminationModelChecker<SparseDtmcModelType>::computeReachabilityRewards(storm::storage::SparseMatrix<ValueType> const& probabilityMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& targetStates, std::function<std::vector<ValueType>(uint_fast64_t, storm::storage::SparseMatrix<ValueType> const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, bool computeForInitialStatesOnly) {
 
             uint_fast64_t numberOfStates = probabilityMatrix.getRowCount();
             
@@ -639,9 +652,9 @@ namespace storm {
                 storm::storage::SparseMatrix<ValueType> submatrixTransposed = submatrix.transpose();
                 
                 // Project the state reward vector to all maybe-states.
-                storm::storage::BitVector phiStates(numberOfStates, true);
+                std::vector<ValueType> stateRewardValues = totalStateRewardVectorGetter(submatrix.getRowCount(), probabilityMatrix, maybeStates);
 
-                std::vector<ValueType> subresult = computeReachabilityValues(submatrix, stateRewardValues, submatrixTransposed, newInitialStates, computeForInitialStatesOnly, phiStates, targetStates, probabilityMatrix.getConstrainedRowSumVector(maybeStates, targetStates));
+                std::vector<ValueType> subresult = computeReachabilityValues(submatrix, stateRewardValues, submatrixTransposed, newInitialStates, computeForInitialStatesOnly, trueStates, targetStates, probabilityMatrix.getConstrainedRowSumVector(maybeStates, targetStates));
                 storm::utility::vector::setVectorValues<ValueType>(result, maybeStates, subresult);
             }
             
diff --git a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h
index 76e939d66..33b61a2c4 100644
--- a/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h
+++ b/src/modelchecker/reachability/SparseDtmcEliminationModelChecker.h
@@ -91,6 +91,8 @@ namespace storm {
             
             static std::vector<ValueType> computeLongRunValues(storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& maybeStates, bool computeResultsForInitialStatesOnly, std::vector<ValueType>& stateValues);
             
+            static std::vector<ValueType> computeReachabilityRewards(storm::storage::SparseMatrix<ValueType> const& probabilityMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& initialStates, storm::storage::BitVector const& targetStates, std::function<std::vector<ValueType>(uint_fast64_t, storm::storage::SparseMatrix<ValueType> const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, bool computeForInitialStatesOnly);
+
             static std::vector<ValueType> computeReachabilityValues(storm::storage::SparseMatrix<ValueType> const& transitionMatrix, std::vector<ValueType>& values, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& initialStates, bool computeResultsForInitialStatesOnly, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, std::vector<ValueType> const& oneStepProbabilitiesToTarget);
             
             static std::shared_ptr<StatePriorityQueue<ValueType>> createStatePriorityQueue(boost::optional<std::vector<uint_fast64_t>> const& stateDistances, storm::storage::FlexibleSparseMatrix<ValueType> const& transitionMatrix, storm::storage::FlexibleSparseMatrix<ValueType> const& backwardTransitions, std::vector<ValueType>& oneStepProbabilities, storm::storage::BitVector const& states);