From 9314d99354a1e9ff445e5eb59345774ab6043bab Mon Sep 17 00:00:00 2001
From: sjunges <sebastian.junges@gmail.com>
Date: Thu, 24 Aug 2017 17:31:05 +0200
Subject: [PATCH] allow inference from actions, but only with a set option

---
 src/storm/builder/BuilderOptions.cpp       |  6 +++-
 src/storm/builder/BuilderOptions.h         | 11 +++++-
 src/storm/builder/ExplicitModelBuilder.cpp | 42 +++++++++++++++++++++-
 src/storm/generator/NextStateGenerator.cpp |  2 --
 4 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/src/storm/builder/BuilderOptions.cpp b/src/storm/builder/BuilderOptions.cpp
index b0d2641bd..4774227fa 100644
--- a/src/storm/builder/BuilderOptions.cpp
+++ b/src/storm/builder/BuilderOptions.cpp
@@ -35,7 +35,7 @@ namespace storm {
             return boost::get<storm::expressions::Expression>(labelOrExpression);
         }
         
-        BuilderOptions::BuilderOptions(bool buildAllRewardModels, bool buildAllLabels) : buildAllRewardModels(buildAllRewardModels), buildAllLabels(buildAllLabels), buildChoiceLabels(false), buildStateValuations(false), buildChoiceOrigins(false), explorationChecks(false), explorationShowProgress(false), explorationShowProgressDelay(0) {
+        BuilderOptions::BuilderOptions(bool buildAllRewardModels, bool buildAllLabels) : buildAllRewardModels(buildAllRewardModels), buildAllLabels(buildAllLabels), buildChoiceLabels(false), buildStateValuations(false), buildChoiceOrigins(false), explorationChecks(false), explorationShowProgress(false), inferObservationsFromActions(false), explorationShowProgressDelay(0) {
             // Intentionally left empty.
         }
         
@@ -156,6 +156,10 @@ namespace storm {
         bool BuilderOptions::isBuildAllLabelsSet() const {
             return buildAllLabels;
         }
+
+        bool BuilderOptions::isInferObservationsFromActionsSet() const {
+            return inferObservationsFromActions;
+        }
         
         BuilderOptions& BuilderOptions::setBuildAllRewardModels(bool newValue) {
             buildAllRewardModels = newValue;
diff --git a/src/storm/builder/BuilderOptions.h b/src/storm/builder/BuilderOptions.h
index f3a6edfbf..0aeb0f674 100644
--- a/src/storm/builder/BuilderOptions.h
+++ b/src/storm/builder/BuilderOptions.h
@@ -82,6 +82,7 @@ namespace storm {
              */
             void setTerminalStatesFromFormula(storm::logic::Formula const& formula);
 
+
             /*!
              * Which reward models are built
              * @return
@@ -107,6 +108,7 @@ namespace storm {
             bool isBuildAllLabelsSet() const;
             bool isExplorationChecksSet() const;
             bool isExplorationShowProgressSet() const;
+            bool isInferObservationsFromActionsSet() const;
             uint64_t getExplorationShowProgressDelay() const;
 
             /**
@@ -155,7 +157,11 @@ namespace storm {
              * @return this
              */
             BuilderOptions& setExplorationChecks(bool newValue = true);
-            
+
+
+            BuilderOptions& setInferObservationsFromActions(bool newValue = true);
+
+
         private:
             /// A flag that indicates whether all reward models are to be built. In this case, the reward model names are
             /// to be ignored.
@@ -191,6 +197,9 @@ namespace storm {
             
             /// A flag that stores whether the progress of exploration is to be printed.
             bool explorationShowProgress;
+
+            /// For POMDPs, should we allow inference of observation classes from different enabled actions.
+            bool inferObservationsFromActions;
             
             /// The delay for printing progress information.
             uint64_t explorationShowProgressDelay;
diff --git a/src/storm/builder/ExplicitModelBuilder.cpp b/src/storm/builder/ExplicitModelBuilder.cpp
index 79d6fc15d..1f2a53d54 100644
--- a/src/storm/builder/ExplicitModelBuilder.cpp
+++ b/src/storm/builder/ExplicitModelBuilder.cpp
@@ -330,9 +330,49 @@ namespace storm {
             }
             if (generator->isPartiallyObservable()) {
                 std::vector<uint32_t> classes;
+                uint32_t newObservation = 0;
                 classes.resize(stateStorage.getNumberOfStates());
+                std::unordered_map<uint32_t, std::vector<std::pair<std::vector<std::string>, uint32_t>>> observationActions;
                 for (auto const& bitVectorIndexPair : stateStorage.stateToId) {
-                    classes[bitVectorIndexPair.second] = generator->observabilityClass(bitVectorIndexPair.first);
+                    uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first);
+                    uint32_t observation = -1; // Is replaced later on.
+                    bool foundActionSet = false;
+                    std::vector<std::string> actionNames;
+                    bool addedAnonymousAction = false;
+                    for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second+1]; ++choice) {
+                        if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) {
+                            STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, "Cannot have multiple anonymous actions, as these cannot be mapped correctly.");
+                            actionNames.push_back("");
+                            addedAnonymousAction = true;
+                        } else {
+                            STORM_LOG_ASSERT(modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, "Expect choice labelling to contain exactly one label at this point, but found " << modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size());
+                            actionNames.push_back(*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin());
+                        }
+                    }
+                    STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " << storm::utility::vector::toString(actionNames));
+                    auto it = observationActions.find(varObservation);
+                    if (it == observationActions.end()) {
+                        observationActions.emplace(varObservation, std::vector<std::pair<std::vector<std::string>, uint32_t>>());
+                    } else {
+                        for(auto const& entries : it->second) {
+                            STORM_LOG_TRACE(storm::utility::vector::toString(entries.first));
+                            if (entries.first == actionNames) {
+                                observation = entries.second;
+                                foundActionSet = true;
+                                break;
+                            }
+                        }
+
+                        STORM_LOG_THROW(generator->getOptions().isInferObservationsFromActionsSet() ||  foundActionSet, storm::exceptions::WrongFormatException, "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option.");
+
+                    }
+                    if (!foundActionSet) {
+                        observation = newObservation;
+                        observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation);
+                        ++newObservation;
+                    }
+
+                    classes[bitVectorIndexPair.second] = observation;
                 }
                 modelComponents.observabilityClasses = classes;
             }
diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp
index faaadf1ed..8d35f9cf2 100644
--- a/src/storm/generator/NextStateGenerator.cpp
+++ b/src/storm/generator/NextStateGenerator.cpp
@@ -158,9 +158,7 @@ namespace storm {
         uint32_t NextStateGenerator<ValueType, StateType>::observabilityClass(CompressedState const &state) const {
             if (this->mask.size() == 0) {
                 this->mask = computeObservabilityMask(variableInformation);
-                std::cout << mask.size() << std::endl;
             }
-            std::cout << state.size() << std::endl;
 
             return unpackStateToObservabilityClass(state, observabilityMap, mask);
         }