TimQu
7 years ago
4 changed files with 314 additions and 58 deletions
-
75src/storm/solver/Multiplier.cpp
-
126src/storm/solver/Multiplier.h
-
143src/storm/solver/NativeMultiplier.cpp
-
28src/storm/solver/NativeMultiplier.h
@ -0,0 +1,75 @@ |
|||
#include "storm/solver/Multiplier.h"
|
|||
|
|||
#include "storm-config.h"
|
|||
|
|||
#include "storm/storage/SparseMatrix.h"
|
|||
|
|||
#include "storm/adapters/RationalNumberAdapter.h"
|
|||
#include "storm/adapters/RationalFunctionAdapter.h"
|
|||
|
|||
#include "storm/utility/macros.h"
|
|||
#include "storm/solver/SolverSelectionOptions.h"
|
|||
#include "storm/solver/NativeMultiplier.h"
|
|||
#include "storm/environment/solver/MultiplierEnvironment.h"
|
|||
|
|||
namespace storm { |
|||
namespace solver { |
|||
|
|||
template<typename ValueType> |
|||
Multiplier<ValueType>::Multiplier(storm::storage::SparseMatrix<ValueType> const& matrix) : matrix(matrix), allowGaussSeidelMultiplications(false) { |
|||
// Intentionally left empty.
|
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool Multiplier<ValueType>::getAllowGaussSeidelMultiplications() const { |
|||
return allowGaussSeidelMultiplications; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void Multiplier<ValueType>::setAllowGaussSeidelMultiplications(bool value) { |
|||
allowGaussSeidelMultiplications = value; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void Multiplier<ValueType>::clearCache() const { |
|||
cachedVector.reset(); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void Multiplier<ValueType>::repeatedMultiply(Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const* b, uint64_t n) const { |
|||
for (uint64_t i = 0; i < n; ++i) { |
|||
multiply(env, x, b, x); |
|||
} |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void Multiplier<ValueType>::repeatedMultiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, uint64_t n) const { |
|||
for (uint64_t i = 0; i < n; ++i) { |
|||
multiplyAndReduce(env, dir, rowGroupIndices, x, b, x); |
|||
} |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::unique_ptr<Multiplier<ValueType>> MultiplierFactory<ValueType>::create(Environment const& env, storm::storage::SparseMatrix<ValueType> const& matrix) { |
|||
switch (env.solver().multiplier().getType()) { |
|||
case MultiplierType::Gmmxx: |
|||
//return std::make_unique<GmmxxMultiplier<ValueType>>(matrix);
|
|||
STORM_PRINT_AND_LOG("gmm mult not yet supported"); |
|||
case MultiplierType::Native: |
|||
return std::make_unique<NativeMultiplier<ValueType>>(matrix); |
|||
} |
|||
} |
|||
|
|||
|
|||
template class Multiplier<double>; |
|||
template class MultiplierFactory<double>; |
|||
|
|||
#ifdef STORM_HAVE_CARL
|
|||
template class Multiplier<storm::RationalNumber>; |
|||
template class MultiplierFactory<storm::RationalNumber>; |
|||
template class Multiplier<storm::RationalFunction>; |
|||
template class MultiplierFactory<storm::RationalFunction>; |
|||
#endif
|
|||
|
|||
} |
|||
} |
@ -0,0 +1,126 @@ |
|||
#pragma once |
|||
|
|||
#include "storm/solver/OptimizationDirection.h" |
|||
#include "storm/solver/MultiplicationStyle.h" |
|||
|
|||
namespace storm { |
|||
|
|||
class Environment; |
|||
|
|||
namespace storage { |
|||
template<typename ValueType> |
|||
class SparseMatrix; |
|||
} |
|||
|
|||
namespace solver { |
|||
|
|||
template<typename ValueType> |
|||
class Multiplier { |
|||
public: |
|||
|
|||
Multiplier(storm::storage::SparseMatrix<ValueType> const& matrix); |
|||
|
|||
/*! |
|||
* Retrieves whether Gauss Seidel style multiplications are allowed. |
|||
*/ |
|||
bool getAllowGaussSeidelMultiplications() const; |
|||
|
|||
/*! |
|||
* Sets whether Gauss Seidel style multiplications are allowed. |
|||
*/ |
|||
void setAllowGaussSeidelMultiplications(bool value); |
|||
|
|||
/*! |
|||
* Returns the multiplication style performed by this multiplier |
|||
*/ |
|||
virtual MultiplicationStyle getMultiplicationStyle() const = 0; |
|||
|
|||
/* |
|||
* Clears the currently cached data of this multiplier in order to free some memory. |
|||
*/ |
|||
virtual void clearCache() const; |
|||
|
|||
/*! |
|||
* Performs a matrix-vector multiplication x' = A*x + b. |
|||
* |
|||
* @param x The input vector with which to multiply the matrix. Its length must be equal |
|||
* to the number of columns of A. |
|||
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal |
|||
* to the number of rows of A. |
|||
* @param result The target vector into which to write the multiplication result. Its length must be equal |
|||
* to the number of rows of A. Can be the same as the x vector. |
|||
*/ |
|||
virtual void multiply(Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result) const = 0; |
|||
|
|||
/*! |
|||
* Performs a matrix-vector multiplication x' = A*x + b and then minimizes/maximizes over the row groups |
|||
* so that the resulting vector has the size of number of row groups of A. |
|||
* |
|||
* @param dir The direction for the reduction step. |
|||
* @param rowGroupIndices A vector storing the row groups over which to reduce. |
|||
* @param x The input vector with which to multiply the matrix. Its length must be equal |
|||
* to the number of columns of A. |
|||
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal |
|||
* to the number of rows of A. |
|||
* @param result The target vector into which to write the multiplication result. Its length must be equal |
|||
* to the number of rows of A. Can be the same as the x vector. |
|||
* @param choices If given, the choices made in the reduction process are written to this vector. |
|||
*/ |
|||
virtual void multiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, std::vector<ValueType>& result, std::vector<uint_fast64_t>* choices = nullptr) const = 0; |
|||
|
|||
/*! |
|||
* Performs repeated matrix-vector multiplication, using x[0] = x and x[i + 1] = A*x[i] + b. After |
|||
* performing the necessary multiplications, the result is written to the input vector x. Note that the |
|||
* matrix A has to be given upon construction time of the solver object. |
|||
* |
|||
* @param x The initial vector with which to perform matrix-vector multiplication. Its length must be equal |
|||
* to the number of columns of A. |
|||
* @param b If non-null, this vector is added after each multiplication. If given, its length must be equal |
|||
* to the number of rows of A. |
|||
* @param n The number of times to perform the multiplication. |
|||
*/ |
|||
void repeatedMultiply(Environment const& env, std::vector<ValueType>& x, std::vector<ValueType> const* b, uint64_t n) const; |
|||
|
|||
/*! |
|||
* Performs repeated matrix-vector multiplication x' = A*x + b and then minimizes/maximizes over the row groups |
|||
* so that the resulting vector has the size of number of row groups of A. |
|||
* |
|||
* @param dir The direction for the reduction step. |
|||
* @param rowGroupIndices A vector storing the row groups over which to reduce. |
|||
* @param x The input vector with which to multiply the matrix. Its length must be equal |
|||
* to the number of columns of A. |
|||
* @param b If non-null, this vector is added after the multiplication. If given, its length must be equal |
|||
* to the number of rows of A. |
|||
* @param result The target vector into which to write the multiplication result. Its length must be equal |
|||
* to the number of rows of A. |
|||
* @param n The number of times to perform the multiplication. |
|||
*/ |
|||
void repeatedMultiplyAndReduce(Environment const& env, OptimizationDirection const& dir, std::vector<uint64_t> const& rowGroupIndices, std::vector<ValueType>& x, std::vector<ValueType> const* b, uint64_t n) const; |
|||
|
|||
/*! |
|||
* Multiplies the row with the given index with x and adds the given offset |
|||
* @param rowIndex The index of the considered row |
|||
* @param x The input vector with which the row is multiplied |
|||
*/ |
|||
virtual ValueType multiplyRow(Environment const& env, uint64_t const& rowIndex, std::vector<ValueType> const& x, ValueType const& offset) const = 0; |
|||
|
|||
protected: |
|||
mutable std::unique_ptr<std::vector<ValueType>> cachedVector; |
|||
storm::storage::SparseMatrix<ValueType> const& matrix; |
|||
private: |
|||
bool allowGaussSeidelMultiplications; |
|||
}; |
|||
|
|||
template<typename ValueType> |
|||
class MultiplierFactory { |
|||
public: |
|||
MultiplierFactory() = default; |
|||
~MultiplierFactory() = default; |
|||
|
|||
std::unique_ptr<Multiplier<ValueType>> create(Environment const& env, storm::storage::SparseMatrix<ValueType> const& matrix); |
|||
|
|||
|
|||
}; |
|||
|
|||
} |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue