From 71f60e812ca3b3c53d43deb109ff5faee93eb76e Mon Sep 17 00:00:00 2001
From: Sebastian Junges <sebastian.junges@gmail.com>
Date: Sun, 7 Mar 2021 23:02:57 -0800
Subject: [PATCH] more precise analysis of whether commands will synchronize

---
 .../generator/PrismNextStateGenerator.cpp     | 16 +++++++++++--
 src/storm/generator/PrismNextStateGenerator.h |  2 ++
 src/storm/storage/prism/Program.cpp           | 24 ++++++++++++++++++-
 src/storm/storage/prism/Program.h             |  5 ++++
 4 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/src/storm/generator/PrismNextStateGenerator.cpp b/src/storm/generator/PrismNextStateGenerator.cpp
index 6678a87b4..eeb7db84f 100644
--- a/src/storm/generator/PrismNextStateGenerator.cpp
+++ b/src/storm/generator/PrismNextStateGenerator.cpp
@@ -482,6 +482,9 @@ namespace storm {
                 bool hasOneEnabledCommand = false;
                 for (auto commandIndexIt = commandIndices.begin(), commandIndexIte = commandIndices.end(); commandIndexIt != commandIndexIte; ++commandIndexIt) {
                     storm::prism::Command const& command = module.getCommand(*commandIndexIt);
+                    if (!isCommandPotentiallySynchronizing(command)) {
+                        continue;
+                    }
                     if (commandFilter != CommandFilter::All) {
                         STORM_LOG_ASSERT(commandFilter == CommandFilter::Markovian || commandFilter == CommandFilter::Probabilistic, "Unexpected command filter.");
                         if ((commandFilter == CommandFilter::Markovian) != command.isMarkovian()) {
@@ -546,8 +549,8 @@ namespace storm {
                 for (uint_fast64_t j = 0; j < module.getNumberOfCommands(); ++j) {
                     storm::prism::Command const& command = module.getCommand(j);
 
-                    // Only consider unlabeled commands.
-                    if (command.isLabeled()) continue;
+                    // Only consider commands that are not possibly synchronizing.
+                    if (isCommandPotentiallySynchronizing(command)) continue;
 
                     if (commandFilter != CommandFilter::All) {
                         STORM_LOG_ASSERT(commandFilter == CommandFilter::Markovian || commandFilter == CommandFilter::Probabilistic, "Unexpected command filter.");
@@ -607,6 +610,10 @@ namespace storm {
                         choice.addReward(stateActionRewardValue);
                     }
 
+                    if (this->options.isBuildChoiceLabelsSet() && command.isLabeled()) {
+                        choice.addLabel(program.getActionName(command.getActionIndex()));
+                    }
+
                     if (program.getModelType() == storm::prism::Program::ModelType::SMG) {
                         storm::storage::PlayerIndex const& playerOfModule = moduleIndexToPlayerIndexMap.at(i);
                         STORM_LOG_THROW(playerOfModule != storm::storage::INVALID_PLAYER_INDEX, storm::exceptions::WrongFormatException, "Module " << module.getName() << " is not owned by any player but has at least one enabled, unlabeled command.");
@@ -838,6 +845,11 @@ namespace storm {
             return std::make_shared<storm::storage::sparse::PrismChoiceOrigins>(std::make_shared<storm::prism::Program>(program), std::move(identifiers), std::move(identifierToCommandSetMapping));
         }
 
+        template<typename ValueType, typename StateType>
+        bool PrismNextStateGenerator<ValueType, StateType>::isCommandPotentiallySynchronizing(const prism::Command &command) const {
+            return program.getPossiblySynchronizingCommands().get(command.getGlobalIndex());
+        }
+
 
         template class PrismNextStateGenerator<double>;
 
diff --git a/src/storm/generator/PrismNextStateGenerator.h b/src/storm/generator/PrismNextStateGenerator.h
index 2aa1317bd..7e63da41a 100644
--- a/src/storm/generator/PrismNextStateGenerator.h
+++ b/src/storm/generator/PrismNextStateGenerator.h
@@ -117,6 +117,8 @@ namespace storm {
              */
             void generateSynchronizedDistribution(storm::storage::BitVector const& state, ValueType const& probability, uint64_t position, std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>::const_iterator> const& iteratorList, storm::builder::jit::Distribution<StateType, ValueType>& distribution, StateToIdCallback stateToIdCallback);
 
+            bool isCommandPotentiallySynchronizing(prism::Command const& command) const;
+
             // The program used for the generation of next states.
             storm::prism::Program program;
 
diff --git a/src/storm/storage/prism/Program.cpp b/src/storm/storage/prism/Program.cpp
index d7bb65148..ca17bf607 100644
--- a/src/storm/storage/prism/Program.cpp
+++ b/src/storm/storage/prism/Program.cpp
@@ -146,7 +146,7 @@ namespace storm {
         formulas(formulas), formulaToIndexMap(), players(players), modules(modules), moduleToIndexMap(),
         rewardModels(rewardModels), rewardModelToIndexMap(), systemCompositionConstruct(compositionConstruct),
         labels(labels), labelToIndexMap(), observationLabels(observationLabels), actionToIndexMap(actionToIndexMap), indexToActionMap(), actions(),
-        synchronizingActionIndices(), actionIndicesToModuleIndexMap(), variableToModuleIndexMap(), prismCompatibility(prismCompatibility)
+        synchronizingActionIndices(), actionIndicesToModuleIndexMap(), variableToModuleIndexMap(), possiblySynchronizingCommands(), prismCompatibility(prismCompatibility)
         {
 
             // Start by creating the necessary mappings from the given ones.
@@ -163,6 +163,24 @@ namespace storm {
                 }
             }
 
+            possiblySynchronizingCommands = storage::BitVector(this->getNumberOfCommands());
+            std::set<uint64_t> possiblySynchronizingActionIndices;
+            for(uint64_t syncAction : synchronizingActionIndices) {
+                if (getModuleIndicesByActionIndex(syncAction).size() > 1) {
+                    std::cout << "syncAction " << syncAction << std::endl;
+                    possiblySynchronizingActionIndices.insert(syncAction);
+                }
+            }
+            for (auto const& module : getModules()) {
+                for (auto const& command : module.getCommands()) {
+                    if (command.isLabeled()) {
+                        if (possiblySynchronizingActionIndices.count(command.getActionIndex())) {
+                            possiblySynchronizingCommands.set(command.getGlobalIndex());
+                        }
+                    }
+                }
+            }
+
             if (finalModel) {
                 // If the model is supposed to be a CTMC, but contains probabilistic commands, we transform them to Markovian
                 // commands and issue a warning.
@@ -816,6 +834,10 @@ namespace storm {
             return this->observationLabels.size();
         }
 
+        storm::storage::BitVector const& Program::getPossiblySynchronizingCommands() const {
+            return possiblySynchronizingCommands;
+        }
+
         Program Program::restrictCommands(storm::storage::FlatSet<uint_fast64_t> const& indexSet) const {
             std::vector<storm::prism::Module> newModules;
             newModules.reserve(this->getNumberOfModules());
diff --git a/src/storm/storage/prism/Program.h b/src/storm/storage/prism/Program.h
index 3092af7f8..697eb9e0f 100644
--- a/src/storm/storage/prism/Program.h
+++ b/src/storm/storage/prism/Program.h
@@ -6,6 +6,7 @@
 #include <vector>
 #include <set>
 #include <boost/optional.hpp>
+#include <storm/storage/BitVector.h>
 
 #include "storm/storage/prism/Constant.h"
 #include "storm/storage/prism/Formula.h"
@@ -729,6 +730,8 @@ namespace storm {
              */
             std::pair<storm::jani::Model, std::vector<storm::jani::Property>> toJani(std::vector<storm::jani::Property> const& properties, bool allVariablesGlobal = true, std::string suffix = "") const;
 
+            storm::storage::BitVector const& getPossiblySynchronizingCommands() const;
+
         private:
             /*!
              * This function builds a command that corresponds to the synchronization of the given list of commands.
@@ -831,6 +834,8 @@ namespace storm {
             // A mapping from variable names to the modules in which they were declared.
             std::map<std::string, uint_fast64_t> variableToModuleIndexMap;
 
+            storage::BitVector possiblySynchronizingCommands;
+
             bool prismCompatibility;
         };