Browse Source

test for masking during building

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
76bf1049ee
  1. 17
      resources/examples/testfiles/mdp/die_selection.nm
  2. 45
      src/test/storm/builder/ExplicitPrismModelBuilderTest.cpp

17
resources/examples/testfiles/mdp/die_selection.nm

@ -9,23 +9,26 @@ module die
d : [0..6] init 0; d : [0..6] init 0;
[fair] s=0 -> 0.5 : (s'=1) + 0.5 : (s'=2); [fair] s=0 -> 0.5 : (s'=1) + 0.5 : (s'=2);
[ufair1] s=0 -> 0.6 : (s'=1) + 0.4 : (s'=2);
[ufair2] s=0 -> 0.7 : (s'=1) + 0.3 : (s'=2);
[ufair1] s=0 -> 0.6 : (s'=1) + 0.4 : (s'=2);
[ufair2] s=0 -> 0.7 : (s'=1) + 0.3 : (s'=2);
[fair] s=1 -> 0.5 : (s'=3) + 0.5 : (s'=4); [fair] s=1 -> 0.5 : (s'=3) + 0.5 : (s'=4);
[ufair1] s=1 -> 0.6 : (s'=3) + 0.4 : (s'=4); [ufair1] s=1 -> 0.6 : (s'=3) + 0.4 : (s'=4);
[ufair2] s=1 -> 0.7 : (s'=3) + 0.3 : (s'=4);
[fair] s=2 -> 0.5 : (s'=5) + 0.5 : (s'=6);
[ufair2] s=1 -> 0.7 : (s'=3) + 0.3 : (s'=4);
[fair] s=2 -> 0.5 : (s'=5) + 0.5 : (s'=6);
[ufair1] s=2 -> 0.6 : (s'=5) + 0.4 : (s'=6); [ufair1] s=2 -> 0.6 : (s'=5) + 0.4 : (s'=6);
[ufair2] s=2 -> 0.7 : (s'=5) + 0.3 : (s'=6); [ufair2] s=2 -> 0.7 : (s'=5) + 0.3 : (s'=6);
[fair] s=3 -> 0.5 : (s'=1) + 0.5 : (s'=7) & (d'=1);
[fair] s=3 -> 0.5 : (s'=1) + 0.5 : (s'=7) & (d'=1);
[ufair1] s=3 -> 0.6 : (s'=1) + 0.4 : (s'=7) & (d'=1); [ufair1] s=3 -> 0.6 : (s'=1) + 0.4 : (s'=7) & (d'=1);
[ufair2] s=3 -> 0.7 : (s'=1) + 0.3 : (s'=7) & (d'=1); [ufair2] s=3 -> 0.7 : (s'=1) + 0.3 : (s'=7) & (d'=1);
[fair] s=4 -> 0.5 : (s'=7) & (d'=2) + 0.5 : (s'=7) & (d'=3); [fair] s=4 -> 0.5 : (s'=7) & (d'=2) + 0.5 : (s'=7) & (d'=3);
[ufair1] s=4 -> 0.6 : (s'=7) & (d'=2) + 0.4 : (s'=7) & (d'=3); [ufair1] s=4 -> 0.6 : (s'=7) & (d'=2) + 0.4 : (s'=7) & (d'=3);
[ufair2] s=4 -> 0.7 : (s'=7) & (d'=2) + 0.3 : (s'=7) & (d'=3); [ufair2] s=4 -> 0.7 : (s'=7) & (d'=2) + 0.3 : (s'=7) & (d'=3);
[fair] s=5 -> 0.5 : (s'=7) & (d'=4) + 0.5 : (s'=7) & (d'=5); [fair] s=5 -> 0.5 : (s'=7) & (d'=4) + 0.5 : (s'=7) & (d'=5);
[ufair1] s=5 -> 0.6 : (s'=2) + 0.4 : (s'=7) & (d'=6);
[ufair2] s=5 -> 0.7 : (s'=2) + 0.3 : (s'=7) & (d'=6);
[ufair1] s=5 -> 0.6 : (s'=7) & (d'=4) + 0.4 : (s'=7) & (d'=5);
[ufair2] s=5 -> 0.7 : (s'=7) & (d'=4) + 0.3 : (s'=7) & (d'=5);
[fair] s=6 -> 0.5 : (s'=2) + 0.5 : (s'=7) & (d'=6);
[ufair1] s=6 -> 0.6 : (s'=2) + 0.4 : (s'=7) & (d'=6);
[ufair2] s=6 -> 0.7 : (s'=2) + 0.3 : (s'=7) & (d'=6);
[] s=7 -> 1: (s'=7); [] s=7 -> 1: (s'=7);
endmodule endmodule

45
src/test/storm/builder/ExplicitPrismModelBuilderTest.cpp

@ -1,3 +1,4 @@
#include <storm/generator/PrismNextStateGenerator.h>
#include "test/storm_gtest.h" #include "test/storm_gtest.h"
#include "storm-config.h" #include "storm-config.h"
#include "storm/models/sparse/StandardRewardModel.h" #include "storm/models/sparse/StandardRewardModel.h"
@ -66,7 +67,6 @@ TEST(ExplicitPrismModelBuilderTest, Ctmc) {
TEST(ExplicitPrismModelBuilderTest, Mdp) { TEST(ExplicitPrismModelBuilderTest, Mdp) {
storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/two_dice.nm"); storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/two_dice.nm");
std::shared_ptr<storm::models::sparse::Model<double>> model = storm::builder::ExplicitModelBuilder<double>(program).build(); std::shared_ptr<storm::models::sparse::Model<double>> model = storm::builder::ExplicitModelBuilder<double>(program).build();
EXPECT_EQ(169ul, model->getNumberOfStates()); EXPECT_EQ(169ul, model->getNumberOfStates());
EXPECT_EQ(436ul, model->getNumberOfTransitions()); EXPECT_EQ(436ul, model->getNumberOfTransitions());
@ -170,3 +170,46 @@ TEST(ExplicitPrismModelBuilderTest, ExportExplicitLookup) {
EXPECT_TRUE(model->getNumberOfStates() > lookup.lookup({{svar, manager.integer(7)}, {dvar, manager.integer(2)}})); EXPECT_TRUE(model->getNumberOfStates() > lookup.lookup({{svar, manager.integer(7)}, {dvar, manager.integer(2)}}));
EXPECT_EQ(1ul, model->getLabelsOfState(lookup.lookup({{svar, manager.integer(7)}, {dvar, manager.integer(2)}})).count("two")); EXPECT_EQ(1ul, model->getLabelsOfState(lookup.lookup({{svar, manager.integer(7)}, {dvar, manager.integer(2)}})).count("two"));
} }
bool trivial_true_mask(storm::expressions::SimpleValuation const&, uint64_t) {
return true;
}
bool trivial_false_mask(storm::expressions::SimpleValuation const&, uint64_t) {
return false;
}
bool only_first_action_mask(storm::expressions::SimpleValuation const&, uint64_t actionIndex) {
return actionIndex <= 1;
}
TEST(ExplicitPrismModelBuilderTest, CallbackActionMask) {
storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/die_selection.nm");
storm::generator::NextStateGeneratorOptions generatorOptions;
generatorOptions.setBuildAllLabels();
generatorOptions.setBuildChoiceLabels();
std::shared_ptr<storm::generator::StateValuationFunctionMask<double>> mask_object = std::make_shared<storm::generator::StateValuationFunctionMask<double>>(trivial_true_mask);
std::shared_ptr<storm::generator::PrismNextStateGenerator<double>> generator = std::make_shared<storm::generator::PrismNextStateGenerator<double>>(program, generatorOptions, mask_object);
auto builder = storm::builder::ExplicitModelBuilder<double>(generator);
std::shared_ptr<storm::models::sparse::Model<double>> model = builder.build();
EXPECT_EQ(13ul, model->getNumberOfStates());
EXPECT_EQ(48ul, model->getNumberOfTransitions());
mask_object = std::make_shared<storm::generator::StateValuationFunctionMask<double>>(trivial_false_mask);
generator = std::make_shared<storm::generator::PrismNextStateGenerator<double>>(program, generatorOptions, mask_object);
builder = storm::builder::ExplicitModelBuilder<double>(generator);
model = builder.build();
EXPECT_EQ(1ul, model->getNumberOfStates());
EXPECT_EQ(1ul, model->getNumberOfTransitions());
mask_object = std::make_shared<storm::generator::StateValuationFunctionMask<double>>(only_first_action_mask);
generator = std::make_shared<storm::generator::PrismNextStateGenerator<double>>(program, generatorOptions, mask_object);
builder = storm::builder::ExplicitModelBuilder<double>(generator);
model = builder.build();
EXPECT_EQ(13ul, model->getNumberOfStates());
EXPECT_EQ(20ul, model->getNumberOfTransitions());
}
Loading…
Cancel
Save