diff --git a/src/storm/solver/GmmxxMultiplier.cpp b/src/storm/solver/GmmxxMultiplier.cpp index 939a4b61e..e992e4457 100644 --- a/src/storm/solver/GmmxxMultiplier.cpp +++ b/src/storm/solver/GmmxxMultiplier.cpp @@ -162,7 +162,7 @@ namespace storm { --itr; --currentRow; - for (uint64_t row = *row_group_it + 1, rowEnd = *(row_group_it + 1); row < rowEnd; ++row, ++currentRow, --itr, --add_it) { + for (uint64_t row = *row_group_it + 1, rowEnd = *(row_group_it + 1); row < rowEnd; ++row, --currentRow, --itr, --add_it) { ValueType newValue = b ? *add_it : storm::utility::zero(); newValue += vect_sp(gmm::linalg_traits::row(itr), x); @@ -170,7 +170,7 @@ namespace storm { oldSelectedChoiceValue = newValue; } - if (min ? currentValue < oldSelectedChoiceValue : currentValue > oldSelectedChoiceValue) { + if (min ? newValue < currentValue : newValue > currentValue) { currentValue = newValue; if (choices) { selectedChoice = currentRow - *row_group_it; diff --git a/src/storm/utility/vector.h b/src/storm/utility/vector.h index c283d7495..56494dc03 100644 --- a/src/storm/utility/vector.h +++ b/src/storm/utility/vector.h @@ -615,31 +615,46 @@ namespace storm { typename std::vector::const_iterator sourceIt = source.begin() + *rowGroupingIt; typename std::vector::const_iterator sourceIte; typename std::vector::iterator choiceIt; - uint_fast64_t localChoice; - if (choices != nullptr) { + if (choices) { choiceIt = choices->begin() + startRow; } - + + // Variables for correctly tracking choices (only update if new choice is strictly better). + T oldSelectedChoiceValue; + uint64_t selectedChoice; + + uint64_t currentRow = 0; for (; targetIt != targetIte; ++targetIt, ++rowGroupingIt, ++choiceIt) { // Only traverse elements if the row group is non-empty. if (*rowGroupingIt != *(rowGroupingIt + 1)) { *targetIt = *sourceIt; - ++sourceIt; - localChoice = 1; - if (choices != nullptr) { - *choiceIt = 0; + + if (choices) { + selectedChoice = 0; + if (*choiceIt == 0) { + oldSelectedChoiceValue = *targetIt; + } } - for (sourceIte = source.begin() + *(rowGroupingIt + 1); sourceIt != sourceIte; ++sourceIt, ++localChoice) { + ++sourceIt; + ++currentRow; + + for (sourceIte = source.begin() + *(rowGroupingIt + 1); sourceIt != sourceIte; ++sourceIt, ++currentRow) { + if (choices && *choiceIt + *rowGroupingIt == currentRow) { + oldSelectedChoiceValue = *sourceIt; + } + if (f(*sourceIt, *targetIt)) { *targetIt = *sourceIt; - if (choices != nullptr) { - *choiceIt = localChoice; + if (choices) { + selectedChoice = std::distance(source.begin(), sourceIt) - *rowGroupingIt; } } } - } else { - *targetIt = storm::utility::zero(); + + if (choices && f(*targetIt, oldSelectedChoiceValue)) { + *choiceIt = selectedChoice; + } } } } @@ -680,6 +695,7 @@ namespace storm { T oldSelectedChoiceValue; uint64_t selectedChoice; + uint64_t currentRow = 0; for (; targetIt != targetIte; ++targetIt, ++rowGroupingIt, ++choiceIt) { // Only traverse elements if the row group is non-empty. if (*rowGroupingIt != *(rowGroupingIt + 1)) { @@ -693,8 +709,10 @@ namespace storm { } ++sourceIt; - for (sourceIte = source.begin() + *(rowGroupingIt + 1); sourceIt != sourceIte; ++sourceIt) { - if (choices && selectedChoice == std::distance(source.begin(), sourceIt) - *rowGroupingIt) { + ++currentRow; + + for (sourceIte = source.begin() + *(rowGroupingIt + 1); sourceIt != sourceIte; ++sourceIt, ++currentRow) { + if (choices && *rowGroupingIt + *choiceIt == currentRow) { oldSelectedChoiceValue = *sourceIt; } @@ -710,6 +728,7 @@ namespace storm { *choiceIt = selectedChoice; } } else { + *choiceIt = 0; *targetIt = storm::utility::zero(); } }