diff --git a/src/storm/solver/Multiplier.cpp b/src/storm/solver/Multiplier.cpp index e5df00b05..4e4fa9abe 100644 --- a/src/storm/solver/Multiplier.cpp +++ b/src/storm/solver/Multiplier.cpp @@ -75,7 +75,7 @@ namespace storm { progress.startNewMeasurement(0); for (uint64_t i = 0; i < n; ++i) { multiply(env, x, b, choiceValues); - reduce(env, dir, choiceValues, rowGroupIndices, x); + reduce(env, dir, rowGroupIndices, choiceValues, x); if (storm::utility::resources::isTerminate()) { STORM_LOG_WARN("Aborting after " << i << " of " << n << " multiplications"); break; @@ -90,33 +90,34 @@ namespace storm { } template - void Multiplier::reduce(Environment const& env, OptimizationDirection const& dir, std::vector const& choiceValues, std::vector::index_type> rowGroupIndices, std::vector& result, storm::storage::BitVector const* dirOverride) const { - auto choice_it = choiceValues.begin(); + void Multiplier::reduce(Environment const& env, OptimizationDirection const& dir, std::vector::index_type> const& rowGroupIndices, std::vector const& choiceValues, std::vector& result, std::vector* choices, storm::storage::BitVector const* dirOverride) const { + auto choiceValue_it = choiceValues.begin(); + auto optChoice_it = choiceValues.begin(); for(uint state = 0; state < rowGroupIndices.size(); state++) { uint rowGroupSize; if(state == 0) { rowGroupSize = rowGroupIndices[state]; } else { - rowGroupSize = rowGroupIndices[state] - rowGroupIndices[state-1]; + rowGroupSize = rowGroupIndices[state] - rowGroupIndices[state - 1]; } if(dirOverride != nullptr) { if((dir == storm::OptimizationDirection::Minimize && !dirOverride->get(state)) || (dir == storm::OptimizationDirection::Maximize && dirOverride->get(state))) { - result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize); - choice_it += rowGroupSize; - } - else { - result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize); - choice_it += rowGroupSize; + optChoice_it = std::min_element(choiceValue_it, choiceValue_it + rowGroupSize); + } else { + optChoice_it = std::max_element(choiceValue_it, choiceValue_it + rowGroupSize); } } else { if(dir == storm::OptimizationDirection::Minimize) { - result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize); - choice_it += rowGroupSize; + optChoice_it = std::min_element(choiceValue_it, choiceValue_it + rowGroupSize); } else { - result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize); - choice_it += rowGroupSize; + optChoice_it = std::max_element(choiceValue_it, choiceValue_it + rowGroupSize); } } + result.at(state) = *optChoice_it; + if(choices) { + choices->at(state) = std::distance(choiceValue_it, optChoice_it); + } + choiceValue_it += rowGroupSize; } } diff --git a/src/storm/solver/Multiplier.h b/src/storm/solver/Multiplier.h index 5e132a628..fc300ff4b 100644 --- a/src/storm/solver/Multiplier.h +++ b/src/storm/solver/Multiplier.h @@ -141,7 +141,7 @@ namespace storm { */ virtual void multiplyRow2(uint64_t const& rowIndex, std::vector const& x1, ValueType& val1, std::vector const& x2, ValueType& val2) const; - void reduce(Environment const& env, OptimizationDirection const& dir, std::vector const& choiceValues, std::vector::index_type> rowGroupIndices, std::vector& result, storm::storage::BitVector const* dirOverride = nullptr) const; + void reduce(Environment const& env, OptimizationDirection const& dir, std::vector::index_type> const& rowGroupIndices, std::vector const& choiceValues, std::vector& result, std::vector* choices = nullptr, storm::storage::BitVector const* dirOverride = nullptr) const; protected: mutable std::unique_ptr> cachedVector;