You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
5.5 KiB
123 lines
5.5 KiB
#include "gtest/gtest.h"
|
|
#include "storm-config.h"
|
|
#include "test/storm_gtest.h"
|
|
|
|
#include "storm/storage/SparseMatrix.h"
|
|
#include "storm/solver/Multiplier.h"
|
|
#include "storm/environment/solver/MultiplierEnvironment.h"
|
|
|
|
#include "storm/utility/vector.h"
|
|
namespace {
|
|
|
|
class NativeEnvironment {
|
|
public:
|
|
typedef double ValueType;
|
|
static const bool isExact = false;
|
|
static storm::Environment createEnvironment() {
|
|
storm::Environment env;
|
|
env.solver().multiplier().setType(storm::solver::MultiplierType::Native);
|
|
return env;
|
|
}
|
|
};
|
|
|
|
class GmmxxEnvironment {
|
|
public:
|
|
typedef double ValueType;
|
|
static const bool isExact = false;
|
|
static storm::Environment createEnvironment() {
|
|
storm::Environment env;
|
|
env.solver().multiplier().setType(storm::solver::MultiplierType::Gmmxx);
|
|
return env;
|
|
}
|
|
};
|
|
|
|
template<typename TestType>
|
|
class MultiplierTest : public ::testing::Test {
|
|
public:
|
|
typedef typename TestType::ValueType ValueType;
|
|
MultiplierTest() : _environment(TestType::createEnvironment()) {}
|
|
storm::Environment const& env() const { return _environment; }
|
|
ValueType precision() const { return TestType::isExact ? parseNumber("0") : parseNumber("1e-15");}
|
|
ValueType parseNumber(std::string const& input) const { return storm::utility::convertNumber<ValueType>(input);}
|
|
private:
|
|
storm::Environment _environment;
|
|
};
|
|
|
|
typedef ::testing::Types<
|
|
NativeEnvironment,
|
|
GmmxxEnvironment
|
|
> TestingTypes;
|
|
|
|
TYPED_TEST_SUITE(MultiplierTest, TestingTypes);
|
|
|
|
/*!
 * Checks repeatedMultiply on a five-state chain.
 *
 * States 0, 1 and 2 move to their successor or directly to state 4 with probability 0.5 each,
 * state 3 moves to state 4 and state 4 loops on itself. Starting from x = (0, 0, 0, 0, 1), every
 * state is at most four steps away from state 4, so after four multiplications every entry of
 * x must be 1 (all intermediate values are dyadic and thus exact in floating point).
 */
TYPED_TEST(MultiplierTest, repeatedMultiplyTest) {
    typedef typename TestFixture::ValueType ValueType;
    // Constructing a builder on its own must not throw.
    ASSERT_NO_THROW(storm::storage::SparseMatrixBuilder<ValueType> builder);
    storm::storage::SparseMatrixBuilder<ValueType> builder;
    ASSERT_NO_THROW(builder.addNextValue(0, 1, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(0, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(1, 2, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(1, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(2, 3, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(2, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(3, 4, this->parseNumber("1")));
    ASSERT_NO_THROW(builder.addNextValue(4, 4, this->parseNumber("1")));

    storm::storage::SparseMatrix<ValueType> A;
    ASSERT_NO_THROW(A = builder.build());

    std::vector<ValueType> x(5);
    x[4] = this->parseNumber("1");

    auto factory = storm::solver::MultiplierFactory<ValueType>();
    auto multiplier = factory.create(this->env(), A);
    ASSERT_NO_THROW(multiplier->repeatedMultiply(this->env(), x, nullptr, 4));

    // Check the whole result vector, not just state 0: every state has reached state 4.
    ValueType const one = this->parseNumber("1");
    for (auto state = 0u; state < x.size(); ++state) {
        EXPECT_NEAR(x[state], one, this->precision()) << "state " << state;
    }
}
|
|
|
|
/*!
 * Checks repeatedMultiplyAndReduce on a row-grouped matrix with three row groups (states).
 *
 * State 0 has two choices:
 *   - row 0 stays in state 0 with 0.9, moves to state 1 with 0.099 and to state 2 with 0.001;
 *   - row 1 moves to states 1 and 2 with 0.5 each.
 * States 1 and 2 only loop on themselves. Starting from x = (0, 1, 0), x[0] after n steps is the
 * minimal/maximal probability of reaching state 1 within n steps, while x[1] and x[2] must keep
 * their initial values.
 */
TYPED_TEST(MultiplierTest, repeatedMultiplyAndReduceTest) {
    typedef typename TestFixture::ValueType ValueType;

    storm::storage::SparseMatrixBuilder<ValueType> builder(0, 0, 0, false, true);
    ASSERT_NO_THROW(builder.newRowGroup(0));
    ASSERT_NO_THROW(builder.addNextValue(0, 0, this->parseNumber("0.9")));
    ASSERT_NO_THROW(builder.addNextValue(0, 1, this->parseNumber("0.099")));
    ASSERT_NO_THROW(builder.addNextValue(0, 2, this->parseNumber("0.001")));
    ASSERT_NO_THROW(builder.addNextValue(1, 1, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(1, 2, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.newRowGroup(2));
    ASSERT_NO_THROW(builder.addNextValue(2, 1, this->parseNumber("1")));
    ASSERT_NO_THROW(builder.newRowGroup(3));
    ASSERT_NO_THROW(builder.addNextValue(3, 2, this->parseNumber("1")));

    storm::storage::SparseMatrix<ValueType> A;
    ASSERT_NO_THROW(A = builder.build());

    std::vector<ValueType> const initialX = {this->parseNumber("0"), this->parseNumber("1"), this->parseNumber("0")};

    auto factory = storm::solver::MultiplierFactory<ValueType>();
    auto multiplier = factory.create(this->env(), A);

    // One scenario per (direction, number of steps); every run restarts from initialX.
    struct ReduceScenario {
        storm::OptimizationDirection direction;
        unsigned steps;
        char const* expectedX0;
    };
    ReduceScenario const scenarios[] = {
        // min(0.099, 0.5)
        {storm::OptimizationDirection::Minimize, 1, "0.099"},
        // min(0.9 * 0.099 + 0.099, 0.5)
        {storm::OptimizationDirection::Minimize, 2, "0.1881"},
        // The value of the first choice exceeds 0.5 after seven steps, so the minimum saturates at 0.5.
        {storm::OptimizationDirection::Minimize, 20, "0.5"},
        // max(0.099, 0.5)
        {storm::OptimizationDirection::Maximize, 1, "0.5"},
        // First iteration takes the 0.5 choice, the remaining 19 the first choice: 0.99 - 0.49 * 0.9^19.
        {storm::OptimizationDirection::Maximize, 20, "0.923808265834023387639"},
    };

    for (ReduceScenario const& scenario : scenarios) {
        char const* directionName = scenario.direction == storm::OptimizationDirection::Minimize ? "Minimize" : "Maximize";
        std::vector<ValueType> x = initialX;
        // ASSERT stays in the test body so a throwing multiplier still aborts the whole test.
        ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), scenario.direction, x, nullptr, scenario.steps))
            << directionName << ", " << scenario.steps << " step(s)";
        EXPECT_NEAR(x[0], this->parseNumber(scenario.expectedX0), this->precision())
            << directionName << ", " << scenario.steps << " step(s)";
        // States 1 and 2 only loop on themselves, so their values must be preserved.
        EXPECT_NEAR(x[1], this->parseNumber("1"), this->precision()) << directionName << ", " << scenario.steps << " step(s)";
        EXPECT_NEAR(x[2], this->parseNumber("0"), this->precision()) << directionName << ", " << scenario.steps << " step(s)";
    }
}
|
|
|
|
}
|