Browse Source

change optimization direction if overridden

tempestpy_adaptions
Stefan Pranger 4 years ago
parent
commit
df8b893417
  1. 30
      src/storm/solver/GmmxxMultiplier.cpp
  2. 6
      src/storm/solver/Multiplier.cpp
  3. 5
      src/storm/solver/Multiplier.h

30
src/storm/solver/GmmxxMultiplier.cpp

@ -1,5 +1,7 @@
#include "storm/solver/GmmxxMultiplier.h"
#include <boost/optional.hpp>
#include "storm/adapters/RationalNumberAdapter.h"
#include "storm/adapters/RationalFunctionAdapter.h"
#include "storm/adapters/IntelTbbAdapter.h"
@ -166,23 +168,29 @@ namespace storm {
choice_it = backwards ? choices->end() - 1 : choices->begin();
}
boost::optional<storm::storage::BitVector> optimizationDirectionOverride;
if(this->getOptimizationDirectionOverride().is_initialized()) {
optimizationDirectionOverride = this->getOptimizationDirectionOverride();
}
// Variables for correctly tracking choices (only update if new choice is strictly better).
ValueType oldSelectedChoiceValue;
uint64_t selectedChoice;
uint64_t currentRow = backwards ? gmmMatrix.nrows() - 1 : 0;
uint64_t currentRowGroup = backwards ? rowGroupIndices.size() - 1 : 0;
auto row_group_it = backwards ? rowGroupIndices.end() - 2 : rowGroupIndices.begin();
auto row_group_ite = backwards ? rowGroupIndices.begin() - 1 : rowGroupIndices.end() - 1;
while (row_group_it != row_group_ite) {
ValueType currentValue = storm::utility::zero<ValueType>();
// Only multiply and reduce if the row group is not empty.
if (*row_group_it != *(row_group_it + 1)) {
// Process the (backwards ? last : first) row of the current row group
if (b) {
currentValue = *add_it;
}
currentValue += vect_sp(gmm::linalg_traits<MatrixType>::row(itr), x);
if (choices) {
@ -202,18 +210,18 @@ namespace storm {
++currentRow;
++add_it;
}
// Process the (rowGroupSize-1) remaining rows within the current row Group
uint64_t rowGroupSize = *(row_group_it + 1) - *row_group_it;
for (uint64_t i = 1; i < rowGroupSize; ++i) {
ValueType newValue = b ? *add_it : storm::utility::zero<ValueType>();
newValue += vect_sp(gmm::linalg_traits<MatrixType>::row(itr), x);
if (choices && currentRow == *choice_it + *row_group_it) {
oldSelectedChoiceValue = newValue;
}
if (compare(newValue, currentValue)) {
if(isOverridden(currentRowGroup) ? !compare(newValue, currentValue) : compare(newValue, currentValue)) {
currentValue = newValue;
if (choices) {
selectedChoice = currentRow - *row_group_it;
@ -230,33 +238,35 @@ namespace storm {
++add_it;
}
}
// Finally write value to target vector.
*target_it = currentValue;
if (choices && compare(currentValue, oldSelectedChoiceValue)) {
if(choices && isOverridden(currentRowGroup) ? !compare(currentValue, oldSelectedChoiceValue) : compare(currentValue, oldSelectedChoiceValue) ) {
*choice_it = selectedChoice;
}
}
// move rowGroup-based iterators to the next row group
if (backwards) {
--row_group_it;
--choice_it;
--target_it;
--currentRowGroup;
} else {
++row_group_it;
++choice_it;
++target_it;
++currentRowGroup;
}
}
}
template<>
template<typename Compare, bool backwards>
void GmmxxMultiplier<storm::RationalFunction>::multAddReduceHelper(std::vector<uint64_t> const& rowGroupIndices, std::vector<storm::RationalFunction> const& x, std::vector<storm::RationalFunction> const* b, std::vector<storm::RationalFunction>& result, std::vector<uint64_t>* choices) const {
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Operation not supported for this data type.");
}
template<typename ValueType>
void GmmxxMultiplier<ValueType>::multAddParallel(std::vector<ValueType> const& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const {
#ifdef STORM_HAVE_INTELTBB

6
src/storm/solver/Multiplier.cpp

@ -84,6 +84,12 @@ namespace storm {
return optimizationDirectionOverride;
}
template<typename ValueType>
bool Multiplier<ValueType>::isOverridden(uint_fast64_t const index) const {
if(!optimizationDirectionOverride.is_initialized()) return false;
return optimizationDirectionOverride.get().get(index);
}
template<typename ValueType>
std::unique_ptr<Multiplier<ValueType>> MultiplierFactory<ValueType>::create(Environment const& env, storm::storage::SparseMatrix<ValueType> const& matrix) {
auto type = env.solver().multiplier().getType();

5
src/storm/solver/Multiplier.h

@ -147,6 +147,11 @@ namespace storm {
*/
boost::optional<storm::storage::BitVector> getOptimizationDirectionOverride() const;
/*
* TODO
*/
bool isOverridden(uint_fast64_t const index) const;
protected:
mutable std::unique_ptr<std::vector<ValueType>> cachedVector;
storm::storage::SparseMatrix<ValueType> const& matrix;

Loading…
Cancel
Save