From 79641ef1310d99a187fba3083fba9c3c6dddfb26 Mon Sep 17 00:00:00 2001
From: Tim Quatmann <tim.quatmann@cs.rwth-aachen.de>
Date: Wed, 1 Apr 2020 15:59:31 +0200
Subject: [PATCH] Started to make the BeliefMdpExplorer more flexible, allowing
 to restart the exploration

---
 src/storm-pomdp/builder/BeliefMdpExplorer.h | 224 ++++++++++++++------
 src/storm-pomdp/storage/BeliefManager.h     |   4 +-
 2 files changed, 162 insertions(+), 66 deletions(-)

diff --git a/src/storm-pomdp/builder/BeliefMdpExplorer.h b/src/storm-pomdp/builder/BeliefMdpExplorer.h
index 86f49fe02..e13e20cf3 100644
--- a/src/storm-pomdp/builder/BeliefMdpExplorer.h
+++ b/src/storm-pomdp/builder/BeliefMdpExplorer.h
@@ -11,6 +11,7 @@
 #include "storm/api/verification.h"
 
 #include "storm/storage/BitVector.h"
+#include "storm/storage/SparseMatrix.h"
 #include "storm/utility/macros.h"
 #include "storm-pomdp/storage/BeliefManager.h"
 #include "storm/utility/SignalHandler.h"
@@ -19,6 +20,7 @@
 #include "storm/modelchecker/results/ExplicitQuantitativeCheckResult.h"
 #include "storm/modelchecker/hints/ExplicitModelCheckerHint.cpp"
 
+
 namespace storm {
     namespace builder {
         template<typename PomdpType, typename BeliefValueType = typename PomdpType::ValueType>
@@ -46,16 +48,17 @@ namespace storm {
                 // Reset data from potential previous explorations
                 mdpStateToBeliefIdMap.clear();
                 beliefIdToMdpStateMap.clear();
-                beliefIdsWithMdpState.clear();
-                beliefIdsWithMdpState.grow(beliefManager->getNumberOfBeliefIds(), false);
+                exploredBeliefIds.clear();
+                exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false);
+                mdpStatesToExplore.clear();
                 lowerValueBounds.clear();
                 upperValueBounds.clear();
                 values.clear();
-                mdpTransitionsBuilder = storm::storage::SparseMatrixBuilder<ValueType>(0, 0, 0, true, true);
-                currentRowCount = 0;
-                startOfCurrentRowGroup = 0;
+                exploredMdpTransitions.clear();
+                exploredChoiceIndices.clear();
                 mdpActionRewards.clear();
                 exploredMdp = nullptr;
+                currentMdpState = noState();
                 
                 // Add some states with special treatment (if requested)
                 if (extraBottomStateValue) {
@@ -63,10 +66,8 @@ namespace storm {
                     mdpStateToBeliefIdMap.push_back(beliefManager->noId());
                     insertValueHints(extraBottomStateValue.get(), extraBottomStateValue.get());
 
-                    startOfCurrentRowGroup = currentRowCount;
-                    mdpTransitionsBuilder.newRowGroup(startOfCurrentRowGroup);
-                    mdpTransitionsBuilder.addNextValue(currentRowCount, extraBottomState.get(), storm::utility::one<ValueType>());
-                    ++currentRowCount;
+                    internalAddRowGroupIndex();
+                    internalAddTransition(getStartOfCurrentRowGroup(), extraBottomState.get(), storm::utility::one<ValueType>());
                 } else {
                     extraBottomState = boost::none;
                 }
@@ -75,10 +76,8 @@ namespace storm {
                     mdpStateToBeliefIdMap.push_back(beliefManager->noId());
                     insertValueHints(extraTargetStateValue.get(), extraTargetStateValue.get());
                     
-                    startOfCurrentRowGroup = currentRowCount;
-                    mdpTransitionsBuilder.newRowGroup(startOfCurrentRowGroup);
-                    mdpTransitionsBuilder.addNextValue(currentRowCount, extraTargetState.get(), storm::utility::one<ValueType>());
-                    ++currentRowCount;
+                    internalAddRowGroupIndex();
+                    internalAddTransition(getStartOfCurrentRowGroup(), extraTargetState.get(), storm::utility::one<ValueType>());
                     
                     targetStates.grow(getCurrentNumberOfMdpStates(), false);
                     targetStates.set(extraTargetState.get(), true);
@@ -89,24 +88,62 @@ namespace storm {
                 // Set up the initial state.
                 initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief());
             }
+            
+            /*!
+             * Restarts the exploration to allow re-exploring each state.
+             * After calling this, the "currently explored" MDP has the same number of states and choices as the "old" one, but the choices are still empty.
+             * This method inserts the initial state of the MDP in the exploration queue.
+             * While re-exploring, the reference to the old MDP remains valid.
+             */
+            void restartExploration() {
+                STORM_LOG_ASSERT(status == Status::ModelChecked || status == Status::ModelFinished, "Method call is invalid in current status.");
+                status = Status::Exploring; // Must happen first: the helpers used below assert that we are exploring.
+                // We will not erase old states during the exploration phase, so most state-based data (like mappings between MDP and Belief states) remain valid.
+                exploredBeliefIds.clear();
+                exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false);
+                exploredMdpTransitions.clear();
+                exploredMdpTransitions.resize(exploredMdp->getNumberOfChoices()); // One (still empty) row per choice of the old MDP
+                exploredChoiceIndices = exploredMdp->getNondeterministicChoiceIndices();
+                mdpActionRewards.clear();
+                if (exploredMdp->hasRewardModel()) {
+                    // Can be overwritten during exploration
+                    mdpActionRewards = exploredMdp->getUniqueRewardModel().getStateActionRewardVector();
+                }
+                targetStates = storm::storage::BitVector(getCurrentNumberOfMdpStates(), false);
+                truncatedStates = storm::storage::BitVector(getCurrentNumberOfMdpStates(), false);
+                mdpStatesToExplore.clear();
+                // The extra states are not changed
+                if (extraBottomState) {
+                    currentMdpState = extraBottomState.get();
+                    restoreOldBehaviorAtCurrentState(0);
+                }
+                if (extraTargetState) {
+                    currentMdpState = extraTargetState.get();
+                    restoreOldBehaviorAtCurrentState(0);
+                    targetStates.set(extraTargetState.get(), true); // targetStates was reset above and restoring does not set it
+                }
+                currentMdpState = noState();
+                // Set up the initial state.
+                initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief());
+            }
     
             bool hasUnexploredState() const {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
-                return !beliefIdsToExplore.empty();
+                return !mdpStatesToExplore.empty();
             }
     
             BeliefId exploreNextState() {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
-                // Set up the matrix builder
-                finishCurrentRow();
-                startOfCurrentRowGroup = currentRowCount;
-                mdpTransitionsBuilder.newRowGroup(startOfCurrentRowGroup);
-                ++currentRowCount;
                 
                 // Pop from the queue.
-                auto result = beliefIdsToExplore.front();
-                beliefIdsToExplore.pop_front();
-                return result;
+                currentMdpState = mdpStatesToExplore.front(); // The queue holds MDP states (not belief ids)
+                mdpStatesToExplore.pop_front();
+                
+                if (!currentStateHasOldBehavior()) { // Only states unknown to the previous MDP open a new row group here
+                    internalAddRowGroupIndex();
+                }
+                
+                return mdpStateToBeliefIdMap[currentMdpState]; // Callers receive the belief id associated with the popped MDP state
             }
             
             void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) {
@@ -114,7 +151,7 @@ namespace storm {
                 // We first insert the entries of the current row in a separate map.
                 // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
                 
-                uint64_t row = startOfCurrentRowGroup + localActionIndex;
+                uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
                 if (!storm::utility::isZero(bottomStateValue)) {
                     STORM_LOG_ASSERT(extraBottomState.is_initialized(), "Requested a transition to the extra bottom state but there is none.");
                     internalAddTransition(row, extraBottomState.get(), bottomStateValue);
@@ -127,7 +164,7 @@ namespace storm {
             
             void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
-                uint64_t row = startOfCurrentRowGroup + localActionIndex;
+                uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
                 internalAddTransition(row, getCurrentMdpState(), value);
             }
             
@@ -145,24 +182,24 @@ namespace storm {
                 // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
                 MdpStateType column;
                 if (ignoreNewBeliefs) {
-                    column = getMdpState(transitionTarget);
+                    column = getExploredMdpState(transitionTarget);
                     if (column == noState()) {
                         return false;
                     }
                 } else {
                     column = getOrAddMdpState(transitionTarget);
                 }
-                uint64_t row = startOfCurrentRowGroup + localActionIndex;
+                uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
                 internalAddTransition(row, column, value);
                 return true;
             }
             
             void computeRewardAtCurrentState(uint64 const& localActionIndex, ValueType extraReward = storm::utility::zero<ValueType>()) {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
-                if (currentRowCount >= mdpActionRewards.size()) {
-                    mdpActionRewards.resize(currentRowCount, storm::utility::zero<ValueType>());
+                if (getCurrentNumberOfMdpChoices() > mdpActionRewards.size()) {
+                    mdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero<ValueType>());
                 }
-                uint64_t row = startOfCurrentRowGroup + localActionIndex;
+                uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
                 mdpActionRewards[row] = beliefManager->getBeliefActionReward(getCurrentBeliefId(), localActionIndex) + extraReward;
             }
             
@@ -178,11 +215,64 @@ namespace storm {
                 truncatedStates.set(getCurrentMdpState(), true);
             }
             
+            bool currentStateHasOldBehavior() const { // True iff a previous MDP exists and already contains the current state
+                STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
+                return exploredMdp && getCurrentMdpState() < exploredMdp->getNumberOfStates();
+            }
+            
+            /*!
+             * Inserts transitions and rewards at the given action as in the MDP of the previous exploration.
+             * Does NOT set whether the state is truncated and/or target.
+             * Will add "old" states that have not been considered before into the exploration queue
+             * @param localActionIndex
+             */
+            void restoreOldBehaviorAtCurrentState(uint64_t const& localActionIndex) {
+                STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Cannot restore old behavior as the current state does not have any.");
+                MdpStateType const state = getCurrentMdpState();
+                uint64_t const oldRow = exploredChoiceIndices[state] + localActionIndex;
+                STORM_LOG_ASSERT(oldRow < exploredChoiceIndices[state + 1], "Invalid local action index.");
+                
+                // Copy every entry of the old row; successors not yet marked in this run are enqueued on the way.
+                for (auto const& oldEntry : exploredMdp->getTransitionMatrix().getRow(oldRow)) {
+                    auto const successor = oldEntry.getColumn();
+                    internalAddTransition(oldRow, successor, oldEntry.getValue());
+                    auto const successorBelief = mdpStateToBeliefIdMap[successor];
+                    // The extra target/bottom states are mapped to noId() and are never enqueued.
+                    bool const isBeliefState = successorBelief != beliefManager->noId();
+                    if (isBeliefState && !exploredBeliefIds.get(successorBelief)) {
+                        exploredBeliefIds.set(successorBelief, true);
+                        mdpStatesToExplore.push_back(successor);
+                    }
+                }
+                
+                // Rewards need no treatment here: the reward vector is already initialized with the "old" values.
+            }
+            
             void finishExploration() {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
+                STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states.");
+                // Finish the last row grouping in case the last explored state was new
+                if (!currentStateHasOldBehavior()) {
+                    internalAddRowGroupIndex();
+                }
+                
                 // Create the tranistion matrix
-                finishCurrentRow();
-                auto mdpTransitionMatrix = mdpTransitionsBuilder.build(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), getCurrentNumberOfMdpStates());
+                uint64_t entryCount = 0;
+                for (auto const& row : exploredMdpTransitions) {
+                    entryCount += row.size();
+                }
+                storm::storage::SparseMatrixBuilder<ValueType> builder(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), entryCount, true, true, getCurrentNumberOfMdpStates());
+                for (uint64_t groupIndex = 0; groupIndex < exploredChoiceIndices.size() - 1; ++groupIndex) {
+                    uint64_t rowIndex = exploredChoiceIndices[groupIndex];
+                    uint64_t groupEnd = exploredChoiceIndices[groupIndex + 1];
+                    builder.newRowGroup(rowIndex);
+                    for (; rowIndex < groupEnd; ++rowIndex) {
+                        for (auto const& entry : exploredMdpTransitions[rowIndex]) {
+                            builder.addNextValue(rowIndex, entry.first, entry.second);
+                        }
+                    }
+                }
+                auto mdpTransitionMatrix = builder.build();
                 
                 // Create a standard labeling
                 storm::models::sparse::StateLabeling mdpLabeling(getCurrentNumberOfMdpStates());
@@ -212,13 +302,18 @@ namespace storm {
             }
             
             MdpStateType getCurrentNumberOfMdpStates() const {
-                STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
+                STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
                 return mdpStateToBeliefIdMap.size();
             }
     
             MdpStateType getCurrentNumberOfMdpChoices() const {
-                STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
-                return currentRowCount;
+                STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
+                return exploredMdpTransitions.size();
+            }
+    
+            MdpStateType getStartOfCurrentRowGroup() const { // First row of the current state: its old row group when re-exploring, otherwise the most recently opened group
+                STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
+                return (exploredMdp && currentMdpState < exploredMdp->getNumberOfStates()) ? exploredChoiceIndices[currentMdpState] : exploredChoiceIndices.back();
             }
 
             ValueType getLowerValueBoundAtCurrentState() const {
@@ -291,7 +386,8 @@ namespace storm {
             }
             
             MdpStateType getCurrentMdpState() const {
-                return mdpTransitionsBuilder.getCurrentRowGroupCount() - 1;
+                STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
+                return currentMdpState;
             }
             
             MdpStateType getCurrentBeliefId() const {
@@ -299,27 +395,20 @@ namespace storm {
             }
             
             void internalAddTransition(uint64_t const& row, MdpStateType const& column, ValueType const& value) {
-                // We first insert the entries of the current row in a separate map.
-                // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
-                STORM_LOG_ASSERT(row >= currentRowCount - 1, "Trying to insert in an already completed row.");
-                if (row >= currentRowCount) {
-                    // We are going to start a new row, so insert the entries of the old one
-                    finishCurrentRow();
-                    currentRowCount = row + 1;
+                STORM_LOG_ASSERT(row <= exploredMdpTransitions.size(), "Skipped at least one row.");
+                if (row == exploredMdpTransitions.size()) { // Writing one past the last row appends a fresh, empty row
+                    exploredMdpTransitions.emplace_back();
                 }
-                STORM_LOG_ASSERT(mdpTransitionsBuilderCurrentRowEntries.count(column) == 0, "Trying to insert multiple transitions to the same state.");
-                mdpTransitionsBuilderCurrentRowEntries[column] = value;
+                STORM_LOG_ASSERT(exploredMdpTransitions[row].count(column) == 0, "Trying to insert multiple transitions to the same state.");
+                exploredMdpTransitions[row][column] = value; // Each row is a std::map keyed by column, so entries stay ordered by column
             }
             
-            void finishCurrentRow() {
-                for (auto const& entry : mdpTransitionsBuilderCurrentRowEntries) {
-                    mdpTransitionsBuilder.addNextValue(currentRowCount - 1, entry.first, entry.second);
-                }
-                mdpTransitionsBuilderCurrentRowEntries.clear();
+            void internalAddRowGroupIndex() {
+                exploredChoiceIndices.push_back(getCurrentNumberOfMdpChoices());
             }
             
-            MdpStateType getMdpState(BeliefId const& beliefId) const {
-                if (beliefId < beliefIdsWithMdpState.size() && beliefIdsWithMdpState.get(beliefId)) {
+            MdpStateType getExploredMdpState(BeliefId const& beliefId) const {
+                if (beliefId < exploredBeliefIds.size() && exploredBeliefIds.get(beliefId)) {
                     return beliefIdToMdpStateMap.at(beliefId);
                 } else {
                     return noState();
@@ -336,20 +425,28 @@ namespace storm {
             }
             
             MdpStateType getOrAddMdpState(BeliefId const& beliefId) {
-                beliefIdsWithMdpState.grow(beliefId + 1, false);
-                if (beliefIdsWithMdpState.get(beliefId)) {
+                exploredBeliefIds.grow(beliefId + 1, false);
+                if (exploredBeliefIds.get(beliefId)) {
                     return beliefIdToMdpStateMap[beliefId];
                 } else {
-                    // Add a new MDP state
-                    beliefIdsWithMdpState.set(beliefId, true);
+                    // This state needs exploration (the belief was not yet marked in the current run)
+                    exploredBeliefIds.set(beliefId, true);
+                    
+                    // If this is a restart of the exploration, we still might have an MDP state for the belief
+                    if (exploredMdp) {
+                        auto findRes = beliefIdToMdpStateMap.find(beliefId);
+                        if (findRes != beliefIdToMdpStateMap.end()) {
+                            mdpStatesToExplore.push_back(findRes->second); // Reuse the existing MDP state and enqueue it for exploration
+                            return findRes->second;
+                        }
+                    }
+                    // At this point we need to add a new MDP state
                     MdpStateType result = getCurrentNumberOfMdpStates();
                     assert(getCurrentNumberOfMdpStates() == mdpStateToBeliefIdMap.size());
                     mdpStateToBeliefIdMap.push_back(beliefId);
                     beliefIdToMdpStateMap[beliefId] = result;
-                    // This new belief needs exploration
-                    beliefIdsToExplore.push_back(beliefId);
-                    
                     insertValueHints(computeLowerValueBoundAtBelief(beliefId), computeUpperValueBoundAtBelief(beliefId));
+                    mdpStatesToExplore.push_back(result); // Enqueue the freshly created MDP state for exploration
                     return result;
                 }
             }
@@ -358,15 +455,14 @@ namespace storm {
             std::shared_ptr<BeliefManagerType> beliefManager;
             std::vector<BeliefId> mdpStateToBeliefIdMap;
             std::map<BeliefId, MdpStateType> beliefIdToMdpStateMap;
-            storm::storage::BitVector beliefIdsWithMdpState;
+            storm::storage::BitVector exploredBeliefIds;
             
             // Exploration information
-            std::deque<uint64_t> beliefIdsToExplore;
-            storm::storage::SparseMatrixBuilder<ValueType> mdpTransitionsBuilder;
-            std::map<MdpStateType, ValueType> mdpTransitionsBuilderCurrentRowEntries;
+            std::deque<uint64_t> mdpStatesToExplore;
+            std::vector<std::map<MdpStateType, ValueType>> exploredMdpTransitions;
+            std::vector<MdpStateType> exploredChoiceIndices;
             std::vector<ValueType> mdpActionRewards;
-            uint64_t startOfCurrentRowGroup;
-            uint64_t currentRowCount;
+            uint64_t currentMdpState;
             
             // Special states during exploration
             boost::optional<MdpStateType> extraTargetState;
diff --git a/src/storm-pomdp/storage/BeliefManager.h b/src/storm-pomdp/storage/BeliefManager.h
index 9cb7c039c..8f0dcd225 100644
--- a/src/storm-pomdp/storage/BeliefManager.h
+++ b/src/storm-pomdp/storage/BeliefManager.h
@@ -324,8 +324,8 @@ namespace storm {
             }
             
             std::map<BeliefId, ValueType> expandInternal(BeliefId const& beliefId, uint64_t actionIndex, boost::optional<std::vector<uint64_t>> const& observationTriangulationResolutions = boost::none) {
-                std::map<BeliefId, ValueType> destinations; // The belief ids should be ordered
-                // TODO: Does this make sense? It could be better to order them afterwards because now we rely on the fact that MDP states have the same order than their associated BeliefIds
+                std::map<BeliefId, ValueType> destinations;
+                // TODO: Output as vector?
                 
                 BeliefType belief = getBelief(beliefId);