From b3e77730a9575f60039846da07b7889289c0b014 Mon Sep 17 00:00:00 2001
From: dehnert <dehnert@cs.rwth-aachen.de>
Date: Wed, 10 Aug 2016 12:55:52 +0200
Subject: [PATCH] added uniqueness mechanism in flattenModules to compensate
 for missing uniqueness in allsat of solvers

Former-commit-id: b4ebd17f68acbe6503f6d708c314a69292eb748f
---
 src/storage/prism/Program.cpp                |  20 +++-
 src/utility/vector.h                         |  14 ++-
 test/functional/storage/PrismProgramTest.cpp | 101 ++++++++++++++++++-
 3 files changed, 126 insertions(+), 9 deletions(-)

diff --git a/src/storage/prism/Program.cpp b/src/storage/prism/Program.cpp
index 21f999a40..516983fa6 100644
--- a/src/storage/prism/Program.cpp
+++ b/src/storage/prism/Program.cpp
@@ -7,8 +7,6 @@
 #include "src/storage/expressions/ExpressionManager.h"
 #include "src/settings/SettingsManager.h"
 #include "src/settings/modules/IOSettings.h"
-#include "src/utility/macros.h"
-#include "src/utility/solver.h"
 #include "src/exceptions/InvalidArgumentException.h"
 #include "src/exceptions/OutOfRangeException.h"
 #include "src/exceptions/WrongFormatException.h"
@@ -18,6 +16,10 @@
 
 #include "src/storage/jani/Model.h"
 
+#include "src/utility/macros.h"
+#include "src/utility/solver.h"
+#include "src/utility/vector.h"
+
 #include "src/storage/prism/CompositionVisitor.h"
 #include "src/storage/prism/Compositions.h"
 #include "src/storage/prism/CompositionToJaniVisitor.h"
@@ -1360,7 +1362,10 @@ namespace storm {
                         solver->add(atLeastOneCommandFromModule);
                     }
                     
-                    // Now we are in a position to start the enumeration over all command variables.
+                    // Now we are in a position to start the enumeration over all command variables. While doing so, we
+                    // keep track of previously seen command combinations, because the AllSat procedures are not
+                    // always guaranteed to only provide distinct models.
+                    std::unordered_set<std::vector<uint_fast64_t>, storm::utility::vector::VectorHash<uint_fast64_t>> seenCommandCombinations;
                     solver->allSat(allCommandVariables, [&] (storm::solver::SmtSolver::ModelReference& modelReference) -> bool {
                         // Now we need to reconstruct the chosen commands from the valuation of the command variables.
                         std::vector<std::vector<std::reference_wrapper<Command const>>> chosenCommands(possibleCommands.size());
@@ -1382,12 +1387,19 @@ namespace storm {
                         
                         bool movedAtLeastOneIterator = false;
                         std::vector<std::reference_wrapper<Command const>> commandCombination(chosenCommands.size(), chosenCommands.front().front());
+                        std::vector<uint_fast64_t> commandCombinationIndices(iterators.size());
                         do {
                             for (uint_fast64_t index = 0; index < iterators.size(); ++index) {
                                 commandCombination[index] = *iterators[index];
+                                commandCombinationIndices[index] = commandCombination[index].get().getGlobalIndex();
                             }
                             
-                            newCommands.push_back(synchronizeCommands(nextCommandIndex, actionIndex, nextUpdateIndex, indexToActionMap.find(actionIndex)->second, commandCombination));
+                            // Only add the command combination if it was not previously seen.
+                            auto seenIt = seenCommandCombinations.find(commandCombinationIndices);
+                            if (seenIt == seenCommandCombinations.end()) {
+                                newCommands.push_back(synchronizeCommands(nextCommandIndex, actionIndex, nextUpdateIndex, indexToActionMap.find(actionIndex)->second, commandCombination));
+                                seenCommandCombinations.insert(commandCombinationIndices);
+                            }
                             
                             // Move the counters appropriately.
                             ++nextCommandIndex;
diff --git a/src/utility/vector.h b/src/utility/vector.h
index 6faf88c7c..6069e78cb 100644
--- a/src/utility/vector.h
+++ b/src/utility/vector.h
@@ -25,6 +25,18 @@ namespace storm {
     namespace utility {
         namespace vector {
 
+            template<typename ValueType>
+            struct VectorHash {
+                size_t operator()(std::vector<ValueType> const& vec) const {
+                    std::hash<ValueType> hasher;
+                    std::size_t seed = 0;
+                    for (ValueType const& element : vec) {
+                        seed ^= hasher(element) + 0x9e3779b9 + (seed<<6) + (seed>>2);
+                    }
+                    return seed;
+                }
+            };
+            
             /*!
              * Sets the provided values at the provided positions in the given vector.
              *
@@ -712,7 +724,7 @@ namespace storm {
              * @return String containing the representation of the vector.
              */
             template<typename ValueType>
-            std::string toString(std::vector<ValueType> vector) {
+            std::string toString(std::vector<ValueType> const& vector) {
                 std::stringstream stream;
                 stream << "vector (" << vector.size() << ") [ ";
                 if (!vector.empty()) {
diff --git a/test/functional/storage/PrismProgramTest.cpp b/test/functional/storage/PrismProgramTest.cpp
index e8ea63c98..3cffe21a1 100644
--- a/test/functional/storage/PrismProgramTest.cpp
+++ b/test/functional/storage/PrismProgramTest.cpp
@@ -7,7 +7,7 @@
 #include "src/storage/jani/Model.h"
 
 #ifdef STORM_HAVE_MSAT
-TEST(PrismProgramTest, FlattenModules) {
+TEST(PrismProgramTest, FlattenModules_Leader_Mathsat) {
     storm::prism::Program program;
     ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/leader3.nm"));
 
@@ -16,33 +16,126 @@ TEST(PrismProgramTest, FlattenModules) {
     ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
     EXPECT_EQ(1, program.getNumberOfModules());
     EXPECT_EQ(74, program.getModule(0).getNumberOfCommands());
+}
 
+TEST(PrismProgramTest, FlattenModules_Wlan_Mathsat) {
+    storm::prism::Program program;
     ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/wlan0_collide.nm"));
     
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
+
     ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
     EXPECT_EQ(1, program.getNumberOfModules());
-    EXPECT_EQ(180, program.getModule(0).getNumberOfCommands());
+    EXPECT_EQ(179, program.getModule(0).getNumberOfCommands());
+}
 
+TEST(PrismProgramTest, FlattenModules_Csma_Mathsat) {
+    storm::prism::Program program;
     ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/csma2_2.nm"));
     
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
+
     ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
     EXPECT_EQ(1, program.getNumberOfModules());
-    EXPECT_EQ(71, program.getModule(0).getNumberOfCommands());
+    EXPECT_EQ(70, program.getModule(0).getNumberOfCommands());
+}
 
+TEST(PrismProgramTest, FlattenModules_Firewire_Mathsat) {
+    storm::prism::Program program;
     ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/firewire.nm"));
     
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
+
     ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
     EXPECT_EQ(1, program.getNumberOfModules());
-    EXPECT_EQ(5026, program.getModule(0).getNumberOfCommands());
+    EXPECT_EQ(5024, program.getModule(0).getNumberOfCommands());
+}
 
+TEST(PrismProgramTest, FlattenModules_Coin_Mathsat) {
+    storm::prism::Program program;
     ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/coin2.nm"));
     
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
+
     ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
     EXPECT_EQ(1, program.getNumberOfModules());
     EXPECT_EQ(13, program.getModule(0).getNumberOfCommands());
+}
+
+TEST(PrismProgramTest, FlattenModules_Dice_Mathsat) {
+    storm::prism::Program program;
+    ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/two_dice.nm"));
+
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::MathsatSmtSolverFactory>();
+
+    ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
+    EXPECT_EQ(1, program.getNumberOfModules());
+    EXPECT_EQ(16, program.getModule(0).getNumberOfCommands());
+}
+#endif
+
+#ifdef STORM_HAVE_Z3
+TEST(PrismProgramTest, FlattenModules_Leader_Z3) {
+    storm::prism::Program program;
+    ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/leader3.nm"));
+    
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
+    
+    ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
+    EXPECT_EQ(1, program.getNumberOfModules());
+    EXPECT_EQ(74, program.getModule(0).getNumberOfCommands());
+}
+
+TEST(PrismProgramTest, FlattenModules_Wlan_Z3) {
+    storm::prism::Program program;
+    ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/wlan0_collide.nm"));
+    
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
+    
+    ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
+    EXPECT_EQ(1, program.getNumberOfModules());
+    EXPECT_EQ(179, program.getModule(0).getNumberOfCommands());
+}
+
+TEST(PrismProgramTest, FlattenModules_Csma_Z3) {
+    storm::prism::Program program;
+    ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/csma2_2.nm"));
+    
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
+    
+    ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
+    EXPECT_EQ(1, program.getNumberOfModules());
+    EXPECT_EQ(70, program.getModule(0).getNumberOfCommands());
+}
+
+TEST(PrismProgramTest, FlattenModules_Firewire_Z3) {
+    storm::prism::Program program;
+    ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/firewire.nm"));
+    
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
+    
+    ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
+    EXPECT_EQ(1, program.getNumberOfModules());
+    EXPECT_EQ(5024, program.getModule(0).getNumberOfCommands());
+}
+
+TEST(PrismProgramTest, FlattenModules_Coin_Z3) {
+    storm::prism::Program program;
+    ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/coin2.nm"));
+    
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
     
+    ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
+    EXPECT_EQ(1, program.getNumberOfModules());
+    EXPECT_EQ(13, program.getModule(0).getNumberOfCommands());
+}
+
+TEST(PrismProgramTest, FlattenModules_Dice_Z3) {
+    storm::prism::Program program;
     ASSERT_NO_THROW(program = storm::parser::PrismParser::parse(STORM_CPP_TESTS_BASE_PATH "/functional/parser/prism/two_dice.nm"));
     
+    std::shared_ptr<storm::utility::solver::SmtSolverFactory> smtSolverFactory = std::make_shared<storm::utility::solver::Z3SmtSolverFactory>();
+    
     ASSERT_NO_THROW(program = program.flattenModules(smtSolverFactory));
     EXPECT_EQ(1, program.getNumberOfModules());
     EXPECT_EQ(16, program.getModule(0).getNumberOfCommands());