diff --git a/src/storage/prism/Program.cpp b/src/storage/prism/Program.cpp index 21f999a40..516983fa6 100644 --- a/src/storage/prism/Program.cpp +++ b/src/storage/prism/Program.cpp @@ -7,8 +7,6 @@ #include "src/storage/expressions/ExpressionManager.h" #include "src/settings/SettingsManager.h" #include "src/settings/modules/IOSettings.h" -#include "src/utility/macros.h" -#include "src/utility/solver.h" #include "src/exceptions/InvalidArgumentException.h" #include "src/exceptions/OutOfRangeException.h" #include "src/exceptions/WrongFormatException.h" @@ -18,6 +16,10 @@ #include "src/storage/jani/Model.h" +#include "src/utility/macros.h" +#include "src/utility/solver.h" +#include "src/utility/vector.h" + #include "src/storage/prism/CompositionVisitor.h" #include "src/storage/prism/Compositions.h" #include "src/storage/prism/CompositionToJaniVisitor.h" @@ -1360,7 +1362,10 @@ namespace storm { solver->add(atLeastOneCommandFromModule); } - // Now we are in a position to start the enumeration over all command variables. + // Now we are in a position to start the enumeration over all command variables. While doing so, we + // keep track of previously seen command combinations, because the AllSat procedures are not + // always guaranteed to only provide distinct models. + std::unordered_set, storm::utility::vector::VectorHash> seenCommandCombinations; solver->allSat(allCommandVariables, [&] (storm::solver::SmtSolver::ModelReference& modelReference) -> bool { // Now we need to reconstruct the chosen commands from the valuation of the command variables. std::vector>> chosenCommands(possibleCommands.size()); @@ -1382,12 +1387,19 @@ namespace storm { bool movedAtLeastOneIterator = false; std::vector> commandCombination(chosenCommands.size(), chosenCommands.front().front()); + std::vector commandCombinationIndices(iterators.size()); do { for (uint_fast64_t index = 0; index < iterators.size(); ++index) { commandCombination[index] = *iterators[index]; + commandCombinationIndices[index] = commandCombination[index].get().getGlobalIndex(); } - newCommands.push_back(synchronizeCommands(nextCommandIndex, actionIndex, nextUpdateIndex, indexToActionMap.find(actionIndex)->second, commandCombination)); + // Only add the command combination if it was not previously seen. + auto seenIt = seenCommandCombinations.find(commandCombinationIndices); + if (seenIt == seenCommandCombinations.end()) { + newCommands.push_back(synchronizeCommands(nextCommandIndex, actionIndex, nextUpdateIndex, indexToActionMap.find(actionIndex)->second, commandCombination)); + seenCommandCombinations.insert(commandCombinationIndices); + } // Move the counters appropriately. ++nextCommandIndex; diff --git a/src/utility/vector.h b/src/utility/vector.h index 6faf88c7c..6069e78cb 100644 --- a/src/utility/vector.h +++ b/src/utility/vector.h @@ -25,6 +25,18 @@ namespace storm { namespace utility { namespace vector { + template + struct VectorHash { + size_t operator()(std::vector const& vec) const { + std::hash hasher; + std::size_t seed = 0; + for (ValueType const& element : vec) { + seed ^= hasher(element) + 0x9e3779b9 + (seed<<6) + (seed>>2); + } + return seed; + } + }; + /*! * Sets the provided values at the provided positions in the given vector. * @@ -712,7 +724,7 @@ namespace storm { * @return String containing the representation of the vector. */ template - std::string toString(std::vector vector) { + std::string toString(std::vector const& vector) { std::stringstream stream; stream << "vector (" << vector.size() << ") [ "; if (!vector.empty()) { diff --git a/test/functional/storage/PrismProgramTest.cpp b/test/functional/storage/PrismProgramTest.cpp index e8ea63c98..3cffe21a1 100644 --- a/test/functional/storage/PrismProgramTest.cpp +++ b/test/functional/storage/PrismProgramTest.cpp @@ -7,7 +7,7 @@ #include "src/storage/jani/Model.h" #ifdef STORM_HAVE_MSAT -TEST(PrismProgramTest, FlattenModules) { +TEST(PrismProgramTest, FlattenModules_Leader_Mathsat) { storm::prism::Program program; ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/leader3.nm")); @@ -16,33 +16,126 @@ TEST(PrismProgramTest, FlattenModules) { ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); EXPECT_EQ(1, program.getNumberOfModules()); EXPECT_EQ(74, program.getModule(0).getNumberOfCommands()); +} +TEST(PrismProgramTest, FlattenModules_Wlan_Mathsat) { + storm::prism::Program program; ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/wlan0_collide.nm")); + std::shared_ptr smtSolverFactory = std::make_shared(); + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); EXPECT_EQ(1, program.getNumberOfModules()); - EXPECT_EQ(180, program.getModule(0).getNumberOfCommands()); + EXPECT_EQ(179, program.getModule(0).getNumberOfCommands()); +} +TEST(PrismProgramTest, FlattenModules_Csma_Mathsat) { + storm::prism::Program program; ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/csma2_2.nm")); + std::shared_ptr smtSolverFactory = std::make_shared(); + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); EXPECT_EQ(1, program.getNumberOfModules()); - EXPECT_EQ(71, program.getModule(0).getNumberOfCommands()); + EXPECT_EQ(70, program.getModule(0).getNumberOfCommands()); +} +TEST(PrismProgramTest, FlattenModules_Firewire_Mathsat) { + storm::prism::Program program; ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/firewire.nm")); + std::shared_ptr smtSolverFactory = std::make_shared(); + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); EXPECT_EQ(1, program.getNumberOfModules()); - EXPECT_EQ(5026, program.getModule(0).getNumberOfCommands()); + EXPECT_EQ(5024, program.getModule(0).getNumberOfCommands()); +} +TEST(PrismProgramTest, FlattenModules_Coin_Mathsat) { + storm::prism::Program program; ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/coin2.nm")); + std::shared_ptr smtSolverFactory = std::make_shared(); + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); EXPECT_EQ(1, program.getNumberOfModules()); EXPECT_EQ(13, program.getModule(0).getNumberOfCommands()); +} + +TEST(PrismProgramTest, FlattenModules_Dice_Mathsat) { + storm::prism::Program program; + ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/two_dice.nm")); + + std::shared_ptr smtSolverFactory = std::make_shared(); + + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); + EXPECT_EQ(1, program.getNumberOfModules()); + EXPECT_EQ(16, program.getModule(0).getNumberOfCommands()); +} +#endif + +#ifdef STORM_HAVE_Z3 +TEST(PrismProgramTest, FlattenModules_Leader_Z3) { + storm::prism::Program program; + ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/leader3.nm")); + + std::shared_ptr smtSolverFactory = std::make_shared(); + + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); + EXPECT_EQ(1, program.getNumberOfModules()); + EXPECT_EQ(74, program.getModule(0).getNumberOfCommands()); +} + +TEST(PrismProgramTest, FlattenModules_Wlan_Z3) { + storm::prism::Program program; + ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/wlan0_collide.nm")); + + std::shared_ptr smtSolverFactory = std::make_shared(); + + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); + EXPECT_EQ(1, program.getNumberOfModules()); + EXPECT_EQ(179, program.getModule(0).getNumberOfCommands()); +} + +TEST(PrismProgramTest, FlattenModules_Csma_Z3) { + storm::prism::Program program; + ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/csma2_2.nm")); + + std::shared_ptr smtSolverFactory = std::make_shared(); + + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); + EXPECT_EQ(1, program.getNumberOfModules()); + EXPECT_EQ(70, program.getModule(0).getNumberOfCommands()); +} + +TEST(PrismProgramTest, FlattenModules_Firewire_Z3) { + storm::prism::Program program; + ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/firewire.nm")); + + std::shared_ptr smtSolverFactory = std::make_shared(); + + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); + EXPECT_EQ(1, program.getNumberOfModules()); + EXPECT_EQ(5024, program.getModule(0).getNumberOfCommands()); +} + +TEST(PrismProgramTest, FlattenModules_Coin_Z3) { + storm::prism::Program program; + ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/coin2.nm")); + + std::shared_ptr smtSolverFactory = std::make_shared(); + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); + EXPECT_EQ(1, program.getNumberOfModules()); + EXPECT_EQ(13, program.getModule(0).getNumberOfCommands()); +} + +TEST(PrismProgramTest, FlattenModules_Dice_Z3) { + storm::prism::Program program; ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/two_dice.nm")); + std::shared_ptr smtSolverFactory = std::make_shared(); + ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory)); EXPECT_EQ(1, program.getNumberOfModules()); EXPECT_EQ(16, program.getModule(0).getNumberOfCommands());