#include "gtest/gtest.h"
#include "storm-config.h"
#include "test/storm_gtest.h"

#include "storm/storage/SparseMatrix.h"

#include "storm/solver/Multiplier.h"
#include "storm/environment/solver/MultiplierEnvironment.h"

#include "storm/utility/vector.h"

namespace {

/// Test configuration: double-valued, inexact, multiplier type set to Native.
class NativeEnvironment {
   public:
    using ValueType = double;
    static const bool isExact = false;

    /// Builds an environment whose multiplier type is MultiplierType::Native.
    static storm::Environment createEnvironment() {
        storm::Environment nativeEnv;
        nativeEnv.solver().multiplier().setType(storm::solver::MultiplierType::Native);
        return nativeEnv;
    }
};

/// Test configuration: double-valued, inexact, multiplier type set to Gmmxx.
class GmmxxEnvironment {
   public:
    using ValueType = double;
    static const bool isExact = false;

    /// Builds an environment whose multiplier type is MultiplierType::Gmmxx.
    static storm::Environment createEnvironment() {
        storm::Environment gmmxxEnv;
        gmmxxEnv.solver().multiplier().setType(storm::solver::MultiplierType::Gmmxx);
        return gmmxxEnv;
    }
};

/// Typed fixture shared by all multiplier tests.
///
/// TestType must provide a `ValueType`, a boolean `isExact` and a static
/// `createEnvironment()`; the environment is created once per fixture instance.
template<typename TestType>
class MultiplierTest : public ::testing::Test {
   public:
    using ValueType = typename TestType::ValueType;

    MultiplierTest() : environment_(TestType::createEnvironment()) {}

    /// The environment produced by TestType::createEnvironment().
    storm::Environment const& env() const {
        return environment_;
    }

    /// Tolerance used by the EXPECT_NEAR checks: "0" for exact configurations, "1e-15" otherwise.
    ValueType precision() const {
        if (TestType::isExact) {
            return parseNumber("0");
        }
        return parseNumber("1e-15");
    }

    /// Converts a textual number into ValueType via storm::utility::convertNumber.
    ValueType parseNumber(std::string const& input) const {
        return storm::utility::convertNumber<ValueType>(input);
    }

   private:
    storm::Environment environment_;
};

using TestingTypes = ::testing::Types<NativeEnvironment, GmmxxEnvironment>;

TYPED_TEST_SUITE(MultiplierTest, TestingTypes);

// Multiplies a vector four times with a fixed matrix and checks the first entry of the result.
TYPED_TEST(MultiplierTest, repeatedMultiplyTest) {
    using ValueType = typename TestFixture::ValueType;

    // Default construction of the builder must not throw; the object declared here is
    // local to the assertion, so the builder actually used is declared right after.
    ASSERT_NO_THROW(storm::storage::SparseMatrixBuilder<ValueType> builder);
    storm::storage::SparseMatrixBuilder<ValueType> matrixBuilder;

    // Values are added in the same order as before: (0,1), (0,4), (1,2), (1,4), (2,3), (2,4), (3,4), (4,4).
    ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 2, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(2, 3, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(2, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(3, 4, this->parseNumber("1")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(4, 4, this->parseNumber("1")));

    storm::storage::SparseMatrix<ValueType> matrix;
    ASSERT_NO_THROW(matrix = matrixBuilder.build());

    // Five value-initialised entries, with the last one set to one.
    std::vector<ValueType> values(5);
    values[4] = this->parseNumber("1");

    auto multiplierFactory = storm::solver::MultiplierFactory<ValueType>();
    auto multiplier = multiplierFactory.create(this->env(), matrix);

    ASSERT_NO_THROW(multiplier->repeatedMultiply(this->env(), values, nullptr, 4));
    EXPECT_NEAR(values[0], this->parseNumber("1"), this->precision());
}

// Runs repeatedMultiplyAndReduce on a matrix with explicit row groups, for both
// optimization directions and several step counts, always restarting from the same vector.
TYPED_TEST(MultiplierTest, repeatedMultiplyAndReduceTest) {
    using ValueType = typename TestFixture::ValueType;

    storm::storage::SparseMatrixBuilder<ValueType> matrixBuilder(0, 0, 0, false, true);

    // Row group starting at row 0 (rows 0 and 1).
    ASSERT_NO_THROW(matrixBuilder.newRowGroup(0));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 0, this->parseNumber("0.9")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 1, this->parseNumber("0.099")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(0, 2, this->parseNumber("0.001")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 1, this->parseNumber("0.5")));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(1, 2, this->parseNumber("0.5")));

    // Row group starting at row 2.
    ASSERT_NO_THROW(matrixBuilder.newRowGroup(2));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(2, 1, this->parseNumber("1")));

    // Row group starting at row 3.
    ASSERT_NO_THROW(matrixBuilder.newRowGroup(3));
    ASSERT_NO_THROW(matrixBuilder.addNextValue(3, 2, this->parseNumber("1")));

    storm::storage::SparseMatrix<ValueType> matrix;
    ASSERT_NO_THROW(matrix = matrixBuilder.build());

    // Every scenario below starts from this vector again.
    std::vector<ValueType> initialValues = {this->parseNumber("0"), this->parseNumber("1"), this->parseNumber("0")};
    std::vector<ValueType> values;

    auto multiplierFactory = storm::solver::MultiplierFactory<ValueType>();
    auto multiplier = multiplierFactory.create(this->env(), matrix);

    // Minimize, 1 step.
    values = initialValues;
    ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), storm::OptimizationDirection::Minimize, values, nullptr, 1));
    EXPECT_NEAR(values[0], this->parseNumber("0.099"), this->precision());

    // Minimize, 2 steps.
    values = initialValues;
    ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), storm::OptimizationDirection::Minimize, values, nullptr, 2));
    EXPECT_NEAR(values[0], this->parseNumber("0.1881"), this->precision());

    // Minimize, 20 steps.
    values = initialValues;
    ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), storm::OptimizationDirection::Minimize, values, nullptr, 20));
    EXPECT_NEAR(values[0], this->parseNumber("0.5"), this->precision());

    // Maximize, 1 step.
    values = initialValues;
    ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), storm::OptimizationDirection::Maximize, values, nullptr, 1));
    EXPECT_NEAR(values[0], this->parseNumber("0.5"), this->precision());

    // Maximize, 20 steps.
    values = initialValues;
    ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), storm::OptimizationDirection::Maximize, values, nullptr, 20));
    EXPECT_NEAR(values[0], this->parseNumber("0.923808265834023387639"), this->precision());
}

}  // namespace