diff --git a/resources/examples/testfiles/mdp/die_selection.nm b/resources/examples/testfiles/mdp/die_selection.nm index 819472268..17c211622 100644 --- a/resources/examples/testfiles/mdp/die_selection.nm +++ b/resources/examples/testfiles/mdp/die_selection.nm @@ -9,23 +9,26 @@ module die d : [0..6] init 0; [fair] s=0 -> 0.5 : (s'=1) + 0.5 : (s'=2); - [ufair1] s=0 -> 0.6 : (s'=1) + 0.4 : (s'=2); - [ufair2] s=0 -> 0.7 : (s'=1) + 0.3 : (s'=2); + [ufair1] s=0 -> 0.6 : (s'=1) + 0.4 : (s'=2); + [ufair2] s=0 -> 0.7 : (s'=1) + 0.3 : (s'=2); [fair] s=1 -> 0.5 : (s'=3) + 0.5 : (s'=4); [ufair1] s=1 -> 0.6 : (s'=3) + 0.4 : (s'=4); - [ufair2] s=1 -> 0.7 : (s'=3) + 0.3 : (s'=4); - [fair] s=2 -> 0.5 : (s'=5) + 0.5 : (s'=6); + [ufair2] s=1 -> 0.7 : (s'=3) + 0.3 : (s'=4); + [fair] s=2 -> 0.5 : (s'=5) + 0.5 : (s'=6); [ufair1] s=2 -> 0.6 : (s'=5) + 0.4 : (s'=6); [ufair2] s=2 -> 0.7 : (s'=5) + 0.3 : (s'=6); - [fair] s=3 -> 0.5 : (s'=1) + 0.5 : (s'=7) & (d'=1); + [fair] s=3 -> 0.5 : (s'=1) + 0.5 : (s'=7) & (d'=1); [ufair1] s=3 -> 0.6 : (s'=1) + 0.4 : (s'=7) & (d'=1); [ufair2] s=3 -> 0.7 : (s'=1) + 0.3 : (s'=7) & (d'=1); [fair] s=4 -> 0.5 : (s'=7) & (d'=2) + 0.5 : (s'=7) & (d'=3); [ufair1] s=4 -> 0.6 : (s'=7) & (d'=2) + 0.4 : (s'=7) & (d'=3); [ufair2] s=4 -> 0.7 : (s'=7) & (d'=2) + 0.3 : (s'=7) & (d'=3); [fair] s=5 -> 0.5 : (s'=7) & (d'=4) + 0.5 : (s'=7) & (d'=5); - [ufair1] s=5 -> 0.6 : (s'=2) + 0.4 : (s'=7) & (d'=6); - [ufair2] s=5 -> 0.7 : (s'=2) + 0.3 : (s'=7) & (d'=6); + [ufair1] s=5 -> 0.6 : (s'=7) & (d'=4) + 0.4 : (s'=7) & (d'=5); + [ufair2] s=5 -> 0.7 : (s'=7) & (d'=4) + 0.3 : (s'=7) & (d'=5); + [fair] s=6 -> 0.5 : (s'=2) + 0.5 : (s'=7) & (d'=6); + [ufair1] s=6 -> 0.6 : (s'=2) + 0.4 : (s'=7) & (d'=6); + [ufair2] s=6 -> 0.7 : (s'=2) + 0.3 : (s'=7) & (d'=6); [] s=7 -> 1: (s'=7); endmodule diff --git a/src/test/storm/builder/ExplicitPrismModelBuilderTest.cpp b/src/test/storm/builder/ExplicitPrismModelBuilderTest.cpp index d0f4201e1..7af0028e7 100644 --- a/src/test/storm/builder/ExplicitPrismModelBuilderTest.cpp +++ b/src/test/storm/builder/ExplicitPrismModelBuilderTest.cpp @@ -1,3 +1,4 @@ +#include #include "test/storm_gtest.h" #include "storm-config.h" #include "storm/models/sparse/StandardRewardModel.h" @@ -66,7 +67,6 @@ TEST(ExplicitPrismModelBuilderTest, Ctmc) { TEST(ExplicitPrismModelBuilderTest, Mdp) { storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/two_dice.nm"); - std::shared_ptr> model = storm::builder::ExplicitModelBuilder(program).build(); EXPECT_EQ(169ul, model->getNumberOfStates()); EXPECT_EQ(436ul, model->getNumberOfTransitions()); @@ -169,4 +169,47 @@ TEST(ExplicitPrismModelBuilderTest, ExportExplicitLookup) { EXPECT_EQ(model->getNumberOfStates(), lookup.lookup({{svar, manager.integer(1)}, {dvar, manager.integer(2)}})); EXPECT_TRUE(model->getNumberOfStates() > lookup.lookup({{svar, manager.integer(7)}, {dvar, manager.integer(2)}})); EXPECT_EQ(1ul, model->getLabelsOfState(lookup.lookup({{svar, manager.integer(7)}, {dvar, manager.integer(2)}})).count("two")); -} \ No newline at end of file +} + + +bool trivial_true_mask(storm::expressions::SimpleValuation const&, uint64_t) { + return true; +} + +bool trivial_false_mask(storm::expressions::SimpleValuation const&, uint64_t) { + return false; +} + +bool only_first_action_mask(storm::expressions::SimpleValuation const&, uint64_t actionIndex) { + return actionIndex <= 1; +} + +TEST(ExplicitPrismModelBuilderTest, CallbackActionMask) { + storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/die_selection.nm"); + storm::generator::NextStateGeneratorOptions generatorOptions; + generatorOptions.setBuildAllLabels(); + generatorOptions.setBuildChoiceLabels(); + std::shared_ptr> mask_object = std::make_shared>(trivial_true_mask); + std::shared_ptr> generator = std::make_shared>(program, generatorOptions, mask_object); + auto builder = storm::builder::ExplicitModelBuilder(generator); + + std::shared_ptr> model = builder.build(); + EXPECT_EQ(13ul, model->getNumberOfStates()); + EXPECT_EQ(48ul, model->getNumberOfTransitions()); + + mask_object = std::make_shared>(trivial_false_mask); + generator = std::make_shared>(program, generatorOptions, mask_object); + builder = storm::builder::ExplicitModelBuilder(generator); + + model = builder.build(); + EXPECT_EQ(1ul, model->getNumberOfStates()); + EXPECT_EQ(1ul, model->getNumberOfTransitions()); + + mask_object = std::make_shared>(only_first_action_mask); + generator = std::make_shared>(program, generatorOptions, mask_object); + builder = storm::builder::ExplicitModelBuilder(generator); + + model = builder.build(); + EXPECT_EQ(13ul, model->getNumberOfStates()); + EXPECT_EQ(20ul, model->getNumberOfTransitions()); +}