You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

100 lines
4.4 KiB

  1. #include "gtest/gtest.h"
  2. #include "src/transformer/StateDuplicator.h"
  3. TEST(StateDuplicator, SimpleModelTest) {
  4. storm::storage::SparseMatrix<double> matrix;
  5. storm::storage::SparseMatrixBuilder<double> builder(6, 4, 7, true, true, 4);
  6. ASSERT_NO_THROW(builder.newRowGroup(0));
  7. ASSERT_NO_THROW(builder.addNextValue(0, 0, 0.3));
  8. ASSERT_NO_THROW(builder.addNextValue(0, 1, 0.7));
  9. ASSERT_NO_THROW(builder.addNextValue(1, 3, 1.0));
  10. ASSERT_NO_THROW(builder.newRowGroup(2));
  11. ASSERT_NO_THROW(builder.addNextValue(2, 1, 1.0));
  12. ASSERT_NO_THROW(builder.newRowGroup(3));
  13. ASSERT_NO_THROW(builder.addNextValue(3, 0, 1.0));
  14. ASSERT_NO_THROW(builder.newRowGroup(4));
  15. ASSERT_NO_THROW(builder.addNextValue(4, 0, 1.0));
  16. ASSERT_NO_THROW(builder.addNextValue(5, 3, 1.0));
  17. ASSERT_NO_THROW(matrix = builder.build());
  18. storm::models::sparse::StateLabeling labeling(4);
  19. storm::storage::BitVector initStates(4);
  20. initStates.set(0);
  21. labeling.addLabel("init", initStates);
  22. storm::storage::BitVector gateStates(4);
  23. gateStates.set(3);
  24. labeling.addLabel("gate", gateStates);
  25. storm::storage::BitVector aStates(4);
  26. aStates.set(0);
  27. aStates.set(2);
  28. labeling.addLabel("a", aStates);
  29. storm::storage::BitVector bStates(4);
  30. bStates.set(1);
  31. bStates.set(3);
  32. labeling.addLabel("b", bStates);
  33. std::unordered_map<std::string, storm::models::sparse::StandardRewardModel<double>> rewardModels;
  34. std::vector<double> stateReward = {1.0, 2.0, 3.0, 4.0};
  35. std::vector<double> stateActionReward = {1.1, 1.2, 2.1, 3.1, 4.1, 4.2};
  36. rewardModels.insert(std::make_pair("rewards", storm::models::sparse::StandardRewardModel<double>(stateReward, stateActionReward)));
  37. storm::models::sparse::Mdp<double> model(matrix, labeling, rewardModels);
  38. auto res = storm::transformer::StateDuplicator<storm::models::sparse::Mdp<double>>::transform(model, gateStates);
  39. storm::storage::SparseMatrixBuilder<double> expectedBuilder(8, 5, 10, true, true, 5);
  40. ASSERT_NO_THROW(expectedBuilder.newRowGroup(0));
  41. ASSERT_NO_THROW(expectedBuilder.addNextValue(0, 0, 0.3));
  42. ASSERT_NO_THROW(expectedBuilder.addNextValue(0, 1, 0.7));
  43. ASSERT_NO_THROW(expectedBuilder.addNextValue(1, 2, 1.0));
  44. ASSERT_NO_THROW(expectedBuilder.newRowGroup(2));
  45. ASSERT_NO_THROW(expectedBuilder.addNextValue(2, 1, 1.0));
  46. ASSERT_NO_THROW(expectedBuilder.newRowGroup(3));
  47. ASSERT_NO_THROW(expectedBuilder.addNextValue(3, 3, 1.0));
  48. ASSERT_NO_THROW(expectedBuilder.addNextValue(4, 2, 1.0));
  49. ASSERT_NO_THROW(expectedBuilder.newRowGroup(5));
  50. ASSERT_NO_THROW(expectedBuilder.addNextValue(5, 3, 0.3));
  51. ASSERT_NO_THROW(expectedBuilder.addNextValue(5, 4, 0.7));
  52. ASSERT_NO_THROW(expectedBuilder.addNextValue(6, 2, 1.0));
  53. ASSERT_NO_THROW(expectedBuilder.newRowGroup(7));
  54. ASSERT_NO_THROW(expectedBuilder.addNextValue(7, 4, 1.0));
  55. ASSERT_NO_THROW(matrix = expectedBuilder.build());
  56. EXPECT_EQ(matrix, res.model->getTransitionMatrix());
  57. initStates.resize(5);
  58. EXPECT_EQ(initStates, res.model->getInitialStates());
  59. gateStates=storm::storage::BitVector(5);
  60. gateStates.set(2);
  61. EXPECT_EQ(gateStates, res.model->getStates("gate"));
  62. aStates = initStates;
  63. aStates.set(3);
  64. EXPECT_EQ(aStates, res.model->getStates("a"));
  65. bStates = ~aStates;
  66. EXPECT_EQ(bStates, res.model->getStates("b"));
  67. EXPECT_TRUE(res.model->hasRewardModel("rewards"));
  68. EXPECT_TRUE(res.model->getRewardModel("rewards").hasStateRewards());
  69. stateReward = {1.0, 2.0, 4.0, 1.0, 2.0};
  70. EXPECT_EQ(stateReward, res.model->getRewardModel("rewards").getStateRewardVector());
  71. EXPECT_TRUE(res.model->getRewardModel("rewards").hasStateActionRewards());
  72. stateActionReward = {1.1, 1.2, 2.1, 4.1, 4.2, 1.1, 1.2, 2.1};
  73. EXPECT_EQ(stateActionReward, res.model->getRewardModel("rewards").getStateActionRewardVector());
  74. storm::storage::BitVector firstCopy(5);
  75. firstCopy.set(0);
  76. firstCopy.set(1);
  77. EXPECT_EQ(firstCopy, res.firstCopy);
  78. EXPECT_EQ(~firstCopy, res.secondCopy);
  79. std::vector<uint_fast64_t> mapping = {0,1,3,0,1};
  80. EXPECT_EQ(mapping, res.newToOldStateIndexMapping);
  81. uint_fast64_t max = std::numeric_limits<uint_fast64_t>::max();
  82. mapping = {0, 1, max, max};
  83. EXPECT_EQ(mapping, res.firstCopyOldToNewStateIndexMapping);
  84. mapping = {3, 4, max, 2};
  85. EXPECT_EQ(mapping, res.secondCopyOldToNewStateIndexMapping);
  86. }