From df8b893417ed5f548014dc264510f4e9f847f30e Mon Sep 17 00:00:00 2001 From: Stefan Pranger Date: Tue, 22 Dec 2020 17:38:56 +0100 Subject: [PATCH] change optimization direction if overridden --- src/storm/solver/GmmxxMultiplier.cpp | 30 ++++++++++++++++++---------- src/storm/solver/Multiplier.cpp | 6 ++++++ src/storm/solver/Multiplier.h | 5 +++++ 3 files changed, 31 insertions(+), 10 deletions(-) diff --git a/src/storm/solver/GmmxxMultiplier.cpp b/src/storm/solver/GmmxxMultiplier.cpp index 6e0f727b0..15b1aefbd 100644 --- a/src/storm/solver/GmmxxMultiplier.cpp +++ b/src/storm/solver/GmmxxMultiplier.cpp @@ -1,5 +1,7 @@ #include "storm/solver/GmmxxMultiplier.h" +#include + #include "storm/adapters/RationalNumberAdapter.h" #include "storm/adapters/RationalFunctionAdapter.h" #include "storm/adapters/IntelTbbAdapter.h" @@ -166,23 +168,29 @@ namespace storm { choice_it = backwards ? choices->end() - 1 : choices->begin(); } + boost::optional optimizationDirectionOverride; + if(this->getOptimizationDirectionOverride().is_initialized()) { + optimizationDirectionOverride = this->getOptimizationDirectionOverride(); + } + // Variables for correctly tracking choices (only update if new choice is strictly better). ValueType oldSelectedChoiceValue; uint64_t selectedChoice; uint64_t currentRow = backwards ? gmmMatrix.nrows() - 1 : 0; + uint64_t currentRowGroup = backwards ? rowGroupIndices.size() - 1 : 0; auto row_group_it = backwards ? rowGroupIndices.end() - 2 : rowGroupIndices.begin(); auto row_group_ite = backwards ? rowGroupIndices.begin() - 1 : rowGroupIndices.end() - 1; while (row_group_it != row_group_ite) { ValueType currentValue = storm::utility::zero(); - + // Only multiply and reduce if the row group is not empty. if (*row_group_it != *(row_group_it + 1)) { // Process the (backwards ? last : first) row of the current row group if (b) { currentValue = *add_it; } - + currentValue += vect_sp(gmm::linalg_traits::row(itr), x); if (choices) { @@ -202,18 +210,18 @@ namespace storm { ++currentRow; ++add_it; } - + // Process the (rowGroupSize-1) remaining rows within the current row Group uint64_t rowGroupSize = *(row_group_it + 1) - *row_group_it; for (uint64_t i = 1; i < rowGroupSize; ++i) { ValueType newValue = b ? *add_it : storm::utility::zero(); newValue += vect_sp(gmm::linalg_traits::row(itr), x); - + if (choices && currentRow == *choice_it + *row_group_it) { oldSelectedChoiceValue = newValue; } - if (compare(newValue, currentValue)) { + if(isOverridden(currentRowGroup) ? !compare(newValue, currentValue) : compare(newValue, currentValue)) { currentValue = newValue; if (choices) { selectedChoice = currentRow - *row_group_it; @@ -230,33 +238,35 @@ namespace storm { ++add_it; } } - + // Finally write value to target vector. *target_it = currentValue; - if (choices && compare(currentValue, oldSelectedChoiceValue)) { + if(choices && isOverridden(currentRowGroup) ? !compare(currentValue, oldSelectedChoiceValue) : compare(currentValue, oldSelectedChoiceValue) ) { *choice_it = selectedChoice; } } - + // move rowGroup-based iterators to the next row group if (backwards) { --row_group_it; --choice_it; --target_it; + --currentRowGroup; } else { ++row_group_it; ++choice_it; ++target_it; + ++currentRowGroup; } } } - + template<> template void GmmxxMultiplier::multAddReduceHelper(std::vector const& rowGroupIndices, std::vector const& x, std::vector const* b, std::vector& result, std::vector* choices) const { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Operation not supported for this data type."); } - + template void GmmxxMultiplier::multAddParallel(std::vector const& x, std::vector const* b, std::vector& result) const { #ifdef STORM_HAVE_INTELTBB diff --git a/src/storm/solver/Multiplier.cpp b/src/storm/solver/Multiplier.cpp index 87e74b16a..556a4dfa9 100644 --- a/src/storm/solver/Multiplier.cpp +++ b/src/storm/solver/Multiplier.cpp @@ -84,6 +84,12 @@ namespace storm { return optimizationDirectionOverride; } + template + bool Multiplier::isOverridden(uint_fast64_t const index) const { + if(!optimizationDirectionOverride.is_initialized()) return false; + return optimizationDirectionOverride.get().get(index); + } + template std::unique_ptr> MultiplierFactory::create(Environment const& env, storm::storage::SparseMatrix const& matrix) { auto type = env.solver().multiplier().getType(); diff --git a/src/storm/solver/Multiplier.h b/src/storm/solver/Multiplier.h index 552a427dd..acf754e94 100644 --- a/src/storm/solver/Multiplier.h +++ b/src/storm/solver/Multiplier.h @@ -147,6 +147,11 @@ namespace storm { */ boost::optional getOptimizationDirectionOverride() const; + /* + * TODO + */ + bool isOverridden(uint_fast64_t const index) const; + protected: mutable std::unique_ptr> cachedVector; storm::storage::SparseMatrix const& matrix;