Browse Source

multiplier reduce also returns choices

tempestpy_adaptions
Stefan Pranger 3 years ago
parent
commit
5bf552a43a
  1. 27
      src/storm/solver/Multiplier.cpp
  2. 2
      src/storm/solver/Multiplier.h

27
src/storm/solver/Multiplier.cpp

@ -75,7 +75,7 @@ namespace storm {
progress.startNewMeasurement(0);
for (uint64_t i = 0; i < n; ++i) {
multiply(env, x, b, choiceValues);
reduce(env, dir, choiceValues, rowGroupIndices, x);
reduce(env, dir, rowGroupIndices, choiceValues, x);
if (storm::utility::resources::isTerminate()) {
STORM_LOG_WARN("Aborting after " << i << " of " << n << " multiplications");
break;
@ -90,8 +90,9 @@ namespace storm {
}
template<typename ValueType>
void Multiplier<ValueType>::reduce(Environment const& env, OptimizationDirection const& dir, std::vector<ValueType> const& choiceValues, std::vector<storm::storage::SparseMatrix<double>::index_type> rowGroupIndices, std::vector<ValueType>& result, storm::storage::BitVector const* dirOverride) const {
auto choice_it = choiceValues.begin();
void Multiplier<ValueType>::reduce(Environment const& env, OptimizationDirection const& dir, std::vector<storm::storage::SparseMatrix<double>::index_type> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices, storm::storage::BitVector const* dirOverride) const {
auto choiceValue_it = choiceValues.begin();
auto optChoice_it = choiceValues.begin();
for(uint state = 0; state < rowGroupIndices.size(); state++) {
uint rowGroupSize;
if(state == 0) {
@ -101,22 +102,22 @@ namespace storm {
}
if(dirOverride != nullptr) {
if((dir == storm::OptimizationDirection::Minimize && !dirOverride->get(state)) || (dir == storm::OptimizationDirection::Maximize && dirOverride->get(state))) {
result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize);
choice_it += rowGroupSize;
}
else {
result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize);
choice_it += rowGroupSize;
optChoice_it = std::min_element(choiceValue_it, choiceValue_it + rowGroupSize);
} else {
optChoice_it = std::max_element(choiceValue_it, choiceValue_it + rowGroupSize);
}
} else {
if(dir == storm::OptimizationDirection::Minimize) {
result.at(state) = *std::min_element(choice_it, choice_it + rowGroupSize);
choice_it += rowGroupSize;
optChoice_it = std::min_element(choiceValue_it, choiceValue_it + rowGroupSize);
} else {
result.at(state) = *std::max_element(choice_it, choice_it + rowGroupSize);
choice_it += rowGroupSize;
optChoice_it = std::max_element(choiceValue_it, choiceValue_it + rowGroupSize);
}
}
result.at(state) = *optChoice_it;
if(choices) {
choices->at(state) = std::distance(choiceValue_it, optChoice_it);
}
choiceValue_it += rowGroupSize;
}
}

2
src/storm/solver/Multiplier.h

@ -141,7 +141,7 @@ namespace storm {
*/
virtual void multiplyRow2(uint64_t const& rowIndex, std::vector<ValueType> const& x1, ValueType& val1, std::vector<ValueType> const& x2, ValueType& val2) const;
void reduce(Environment const& env, OptimizationDirection const& dir, std::vector<ValueType> const& choiceValues, std::vector<storm::storage::SparseMatrix<double>::index_type> rowGroupIndices, std::vector<ValueType>& result, storm::storage::BitVector const* dirOverride = nullptr) const;
void reduce(Environment const& env, OptimizationDirection const& dir, std::vector<storm::storage::SparseMatrix<double>::index_type> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices = nullptr, storm::storage::BitVector const* dirOverride = nullptr) const;
protected:
mutable std::unique_ptr<std::vector<ValueType>> cachedVector;

Loading…
Cancel
Save