From 06e3d4a331a04c20c1625483345b9b2a769aaea1 Mon Sep 17 00:00:00 2001
From: dehnert <dehnert@cs.rwth-aachen.de>
Date: Thu, 5 Apr 2018 22:56:36 +0200
Subject: [PATCH] fixing issues in gmmxx TBB multiply-and-reduce

---
 src/storm/solver/GmmxxMultiplier.cpp | 44 +++++++++++++++++++---------
 1 file changed, 30 insertions(+), 14 deletions(-)

diff --git a/src/storm/solver/GmmxxMultiplier.cpp b/src/storm/solver/GmmxxMultiplier.cpp
index 9618e2495..a11675b75 100644
--- a/src/storm/solver/GmmxxMultiplier.cpp
+++ b/src/storm/solver/GmmxxMultiplier.cpp
@@ -210,9 +210,13 @@ namespace storm {
             }
             
             void operator()(tbb::blocked_range<unsigned long> const& range) const {
+                typedef std::vector<ValueType> VectorType;
+                typedef gmm::csr_matrix<ValueType> MatrixType;
+
+                bool min = dir == OptimizationDirection::Minimize;
                 auto groupIt = rowGroupIndices.begin() + range.begin();
                 auto groupIte = rowGroupIndices.begin() + range.end();
-
+                
                 auto itr = mat_row_const_begin(matrix) + *groupIt;
                 typename std::vector<ValueType>::const_iterator bIt;
                 if (b) {
@@ -224,12 +228,13 @@ namespace storm {
                 }
                 
                 auto resultIt = result.begin() + range.begin();
-
+                
+                // Variables for correctly tracking choices (only update if new choice is strictly better).
+                ValueType oldSelectedChoiceValue;
+                uint64_t selectedChoice;
+                
+                uint64_t currentRow = *groupIt;
                 for (; groupIt != groupIte; ++groupIt, ++resultIt, ++choiceIt) {
-                    if (choices) {
-                        *choiceIt = 0;
-                    }
-                    
                     ValueType currentValue = storm::utility::zero<ValueType>();
                     
                     // Only multiply and reduce if the row group is not empty.
@@ -239,25 +244,36 @@ namespace storm {
                             ++bIt;
                         }
                         
-                        ++itr;
+                        currentValue += vect_sp(gmm::linalg_traits<MatrixType>::row(itr), x);
                         
-                        for (auto itre = mat_row_const_begin(matrix) + *(groupIt + 1); itr != itre; ++itr) {
-                            ValueType newValue = vect_sp(gmm::linalg_traits<gmm::csr_matrix<ValueType>>::row(itr), x, typename gmm::linalg_traits<gmm::csr_matrix<ValueType>>::storage_type(), typename gmm::linalg_traits<std::vector<ValueType>>::storage_type());
-                            if (b) {
-                                newValue += *bIt;
-                                ++bIt;
+                        if (choices) {
+                            selectedChoice = currentRow - *groupIt;
+                            if (*choiceIt == selectedChoice) {
+                                oldSelectedChoiceValue = currentValue;
                             }
+                        }
+                        
+                        ++itr;
+                        ++currentRow;
+                        
+                        for (auto itre = mat_row_const_begin(matrix) + *(groupIt + 1); itr != itre; ++itr, ++bIt, ++currentRow) {
+                            ValueType newValue = b ? *bIt : storm::utility::zero<ValueType>();
+                            newValue += vect_sp(gmm::linalg_traits<MatrixType>::row(itr), x);
                             
-                            if ((dir == OptimizationDirection::Minimize && newValue < currentValue) || (dir == OptimizationDirection::Maximize && newValue > currentValue)) {
+                            if (min ? newValue < currentValue : newValue > currentValue) {
                                 currentValue = newValue;
                                 if (choices) {
-                                    *choiceIt = std::distance(mat_row_const_begin(matrix), itr) - *groupIt;
+                                    selectedChoice = currentRow - *groupIt;
                                 }
                             }
                         }
                     }
                     
+                    // Finally write value to target vector.
                     *resultIt = currentValue;
+                    if (choices && (min ? currentValue < oldSelectedChoiceValue : currentValue > oldSelectedChoiceValue)) {
+                        *choiceIt = selectedChoice;
+                    }
                 }
             }