You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

123 lines
5.5 KiB

#include "gtest/gtest.h"
#include "storm-config.h"
#include "test/storm_gtest.h"
#include "storm/storage/SparseMatrix.h"
#include "storm/solver/Multiplier.h"
#include "storm/environment/solver/MultiplierEnvironment.h"
#include "storm/utility/vector.h"
namespace {
/*
 * Test configuration: double-precision (inexact) values, multiplied by storm's
 * native multiplier implementation.
 */
class NativeEnvironment {
   public:
    using ValueType = double;
    static constexpr bool isExact = false;

    // Builds a fresh environment whose multiplier type is MultiplierType::Native.
    static storm::Environment createEnvironment() {
        storm::Environment result;
        result.solver().multiplier().setType(storm::solver::MultiplierType::Native);
        return result;
    }
};
/*
 * Test configuration: double-precision (inexact) values, multiplied by the
 * gmm++-backed multiplier implementation.
 */
class GmmxxEnvironment {
   public:
    using ValueType = double;
    static constexpr bool isExact = false;

    // Builds a fresh environment whose multiplier type is MultiplierType::Gmmxx.
    static storm::Environment createEnvironment() {
        storm::Environment result;
        result.solver().multiplier().setType(storm::solver::MultiplierType::Gmmxx);
        return result;
    }
};
/*
 * Typed test fixture. TestType supplies the value type (TestType::ValueType),
 * whether that type is exact (TestType::isExact), and a factory for the
 * storm::Environment under test (TestType::createEnvironment()).
 */
template<typename TestType>
class MultiplierTest : public ::testing::Test {
   public:
    using ValueType = typename TestType::ValueType;

    MultiplierTest() : environment(TestType::createEnvironment()) {}

    // Environment created once per test instance by TestType::createEnvironment().
    storm::Environment const& env() const {
        return environment;
    }

    // Comparison tolerance: zero for exact value types, 1e-15 otherwise.
    ValueType precision() const {
        if (TestType::isExact) {
            return parseNumber("0");
        }
        return parseNumber("1e-15");
    }

    // Converts a decimal string into ValueType via storm::utility::convertNumber.
    ValueType parseNumber(std::string const& input) const {
        return storm::utility::convertNumber<ValueType>(input);
    }

   private:
    storm::Environment environment;
};
// Each TYPED_TEST of MultiplierTest is instantiated once for every configuration listed here.
using TestingTypes = ::testing::Types<NativeEnvironment, GmmxxEnvironment>;
TYPED_TEST_SUITE(MultiplierTest, TestingTypes);
/*
 * Exercises Multiplier::repeatedMultiply on a 5x5 matrix:
 *   rows 0..2: 0.5 to column i+1 and 0.5 to column 4,
 *   row 3:     1 to column 4,
 *   row 4:     1 to column 4 (self-loop).
 * Starting from x = (0, 0, 0, 0, 1), the expected values assume each step computes
 * x <- A * x (b == nullptr). The successive vectors are then
 *   (0.5, 0.5, 0.5, 1, 1), (0.75, 0.75, 1, 1, 1), (0.875, 1, 1, 1, 1), (1, 1, 1, 1, 1).
 * Checking both the 3-step and the 4-step result pins the number of performed steps.
 */
TYPED_TEST(MultiplierTest, repeatedMultiplyTest) {
    typedef typename TestFixture::ValueType ValueType;
    ASSERT_NO_THROW(storm::storage::SparseMatrixBuilder<ValueType> builder);
    storm::storage::SparseMatrixBuilder<ValueType> builder;
    ASSERT_NO_THROW(builder.addNextValue(0, 1, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(0, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(1, 2, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(1, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(2, 3, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(2, 4, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(3, 4, this->parseNumber("1")));
    ASSERT_NO_THROW(builder.addNextValue(4, 4, this->parseNumber("1")));
    storm::storage::SparseMatrix<ValueType> A;
    ASSERT_NO_THROW(A = builder.build());

    // Only the last entry starts at 1; the others are value-initialized (zero).
    std::vector<ValueType> initialX(5);
    initialX[4] = this->parseNumber("1");

    auto factory = storm::solver::MultiplierFactory<ValueType>();
    auto multiplier = factory.create(this->env(), A);
    ASSERT_TRUE(multiplier != nullptr);

    ValueType const one = this->parseNumber("1");

    // Three steps: x[0] has not reached 1 yet, all other entries have.
    std::vector<ValueType> x = initialX;
    ASSERT_NO_THROW(multiplier->repeatedMultiply(this->env(), x, nullptr, 3));
    ASSERT_EQ(x.size(), initialX.size());
    EXPECT_NEAR(x[0], this->parseNumber("0.875"), this->precision());
    for (std::size_t i = 1; i < x.size(); ++i) {
        EXPECT_NEAR(x[i], one, this->precision()) << "after 3 steps, at index " << i;
    }

    // Four steps: every entry equals 1.
    x = initialX;
    ASSERT_NO_THROW(multiplier->repeatedMultiply(this->env(), x, nullptr, 4));
    ASSERT_EQ(x.size(), initialX.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        EXPECT_NEAR(x[i], one, this->precision()) << "after 4 steps, at index " << i;
    }
}
/*
 * Exercises Multiplier::repeatedMultiplyAndReduce on a matrix with custom row grouping:
 *   group 0 = rows {0, 1}, group 1 = row 2, group 2 = row 3.
 * Starting each run from x = (0, 1, 0), the value of group 0 after the given number of
 * steps is compared against the expected value for both minimization and maximization.
 */
TYPED_TEST(MultiplierTest, repeatedMultiplyAndReduceTest) {
    using ValueType = typename TestFixture::ValueType;

    // The trailing 'true' enables custom row grouping, so newRowGroup() marks group starts.
    storm::storage::SparseMatrixBuilder<ValueType> builder(0, 0, 0, false, true);
    ASSERT_NO_THROW(builder.newRowGroup(0));
    ASSERT_NO_THROW(builder.addNextValue(0, 0, this->parseNumber("0.9")));
    ASSERT_NO_THROW(builder.addNextValue(0, 1, this->parseNumber("0.099")));
    ASSERT_NO_THROW(builder.addNextValue(0, 2, this->parseNumber("0.001")));
    ASSERT_NO_THROW(builder.addNextValue(1, 1, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.addNextValue(1, 2, this->parseNumber("0.5")));
    ASSERT_NO_THROW(builder.newRowGroup(2));
    ASSERT_NO_THROW(builder.addNextValue(2, 1, this->parseNumber("1")));
    ASSERT_NO_THROW(builder.newRowGroup(3));
    ASSERT_NO_THROW(builder.addNextValue(3, 2, this->parseNumber("1")));
    storm::storage::SparseMatrix<ValueType> A;
    ASSERT_NO_THROW(A = builder.build());

    std::vector<ValueType> const initialX = {this->parseNumber("0"), this->parseNumber("1"), this->parseNumber("0")};

    auto factory = storm::solver::MultiplierFactory<ValueType>();
    auto multiplier = factory.create(this->env(), A);

    // One row per run: optimization direction, number of steps, expected x[0].
    struct ReduceCase {
        storm::OptimizationDirection direction;
        uint64_t steps;
        char const* expected;
    };
    ReduceCase const cases[] = {
        {storm::OptimizationDirection::Minimize, 1, "0.099"},
        {storm::OptimizationDirection::Minimize, 2, "0.1881"},
        {storm::OptimizationDirection::Minimize, 20, "0.5"},
        {storm::OptimizationDirection::Maximize, 1, "0.5"},
        {storm::OptimizationDirection::Maximize, 20, "0.923808265834023387639"},
    };

    std::vector<ValueType> x;
    for (auto const& reduceCase : cases) {
        // Every run restarts from the same initial vector.
        x = initialX;
        ASSERT_NO_THROW(multiplier->repeatedMultiplyAndReduce(this->env(), reduceCase.direction, x, nullptr, reduceCase.steps));
        EXPECT_NEAR(x[0], this->parseNumber(reduceCase.expected), this->precision());
    }
}
}