From 6ce1f96efcdafad207c93a168f04083e1733d4d1 Mon Sep 17 00:00:00 2001
From: Sebastian Junges <sebastian.junges@gmail.com>
Date: Sun, 23 May 2021 21:40:24 -0700
Subject: [PATCH] belief support tracking test and cleaning

---
 .../generator/BeliefSupportTracker.cpp        |  6 +++
 .../generator/BeliefSupportTracker.h          | 15 +++++-
 src/test/storm-pomdp/CMakeLists.txt           |  2 +-
 .../tracking/BeliefSupportTrackingTest.cpp    | 47 +++++++++++++++++++
 4 files changed, 67 insertions(+), 3 deletions(-)
 create mode 100644 src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp

diff --git a/src/storm-pomdp/generator/BeliefSupportTracker.cpp b/src/storm-pomdp/generator/BeliefSupportTracker.cpp
index 3ee6195ef..714a85446 100644
--- a/src/storm-pomdp/generator/BeliefSupportTracker.cpp
+++ b/src/storm-pomdp/generator/BeliefSupportTracker.cpp
@@ -29,6 +29,12 @@ namespace storm {
             currentBeliefSupport = newBeliefSupport;
         }
 
+        template<typename ValueType>
+        void BeliefSupportTracker<ValueType>::reset() {
+            currentBeliefSupport = pomdp.getInitialStates();
+        }
+
+
         template class BeliefSupportTracker<double>;
         template class BeliefSupportTracker<storm::RationalNumber>;
 
diff --git a/src/storm-pomdp/generator/BeliefSupportTracker.h b/src/storm-pomdp/generator/BeliefSupportTracker.h
index dcd133043..91cfed593 100644
--- a/src/storm-pomdp/generator/BeliefSupportTracker.h
+++ b/src/storm-pomdp/generator/BeliefSupportTracker.h
@@ -12,10 +12,21 @@ namespace storm {
              */
         public:
             BeliefSupportTracker(storm::models::sparse::Pomdp<ValueType> const& pomdp);
-
+            /**
+             * The current belief support according to the tracker
+             * @return
+             */
             storm::storage::BitVector const& getCurrentBeliefSupport() const;
-
+            /*!
+             * Update current belief support state
+             * @param action The action that was taken
+             * @param observation The new (state) observation
+             */
             void track(uint64_t action, uint64_t observation);
+            /*!
+             * Reset to initial state
+             */
+            void reset();
 
         private:
             storm::models::sparse::Pomdp<ValueType> const& pomdp;
diff --git a/src/test/storm-pomdp/CMakeLists.txt b/src/test/storm-pomdp/CMakeLists.txt
index 70ceeb3dd..d4e2ed819 100644
--- a/src/test/storm-pomdp/CMakeLists.txt
+++ b/src/test/storm-pomdp/CMakeLists.txt
@@ -9,7 +9,7 @@ register_source_groups_from_filestructure("${ALL_FILES}" test)
 # Note that the tests also need the source files, except for the main file
 include_directories(${GTEST_INCLUDE_DIR})
 
-foreach (testsuite analysis transformation modelchecker)
+foreach (testsuite analysis transformation modelchecker tracking)
 
 	  file(GLOB_RECURSE TEST_${testsuite}_FILES ${STORM_TESTS_BASE_PATH}/${testsuite}/*.h ${STORM_TESTS_BASE_PATH}/${testsuite}/*.cpp)
       add_executable (test-pomdp-${testsuite} ${TEST_${testsuite}_FILES} ${STORM_TESTS_BASE_PATH}/storm-test.cpp)
diff --git a/src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp b/src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp
new file mode 100644
index 000000000..ec28bdbf0
--- /dev/null
+++ b/src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp
@@ -0,0 +1,47 @@
+#include "test/storm_gtest.h"
+#include "storm-config.h"
+#include "storm/models/sparse/StandardRewardModel.h"
+#include "storm-parsers/parser/PrismParser.h"
+#include "storm/builder/ExplicitModelBuilder.h"
+#include "storm/api/storm.h"
+#include "storm-parsers/api/storm-parsers.h"
+#include "storm-pomdp/analysis/FormulaInformation.h"
+#include "storm-pomdp/transformer/MakePOMDPCanonic.h"
+#include "storm-pomdp/generator/BeliefSupportTracker.h"
+
+// TODO
+// These tests depend on the interpretation of action and observation numbers and those may change.
+// A more robust test would take the high-level actions and observations and track on those.
+
+TEST(BeliefSupportTracking, Maze) {
+    storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/pomdp/maze2.prism");
+    program = storm::utility::prism::preprocess(program, "sl=0.4");
+    std::shared_ptr<storm::logic::Formula const> formula = storm::api::parsePropertiesForPrismProgram("Pmax=? [F \"goal\" ]", program).front().getRawFormula();
+    std::shared_ptr<storm::models::sparse::Pomdp<double>> pomdp = storm::api::buildSparseModel<double>(program, {formula})->as<storm::models::sparse::Pomdp<double>>();
+    storm::transformer::MakePOMDPCanonic<double> makeCanonic(*pomdp);
+    pomdp = makeCanonic.transform();
+
+    storm::generator::BeliefSupportTracker<double> tracker(*pomdp);
+    EXPECT_EQ(pomdp->getInitialStates(), tracker.getCurrentBeliefSupport());
+    tracker.track(0,0);
+    auto beliefsup = tracker.getCurrentBeliefSupport();
+    EXPECT_EQ(6ul, beliefsup.getNumberOfSetBits());
+    tracker.track(0,0);
+    EXPECT_EQ(beliefsup, tracker.getCurrentBeliefSupport());
+    tracker.track(0,1);
+    EXPECT_TRUE(tracker.getCurrentBeliefSupport().empty());
+    tracker.reset();
+    EXPECT_EQ(pomdp->getInitialStates(), tracker.getCurrentBeliefSupport());
+    tracker.track(0, 0);
+    EXPECT_EQ(beliefsup, tracker.getCurrentBeliefSupport());
+    tracker.track(1,0);
+    EXPECT_EQ(beliefsup, tracker.getCurrentBeliefSupport());
+    tracker.track(2,1);
+    EXPECT_EQ(1ul, tracker.getCurrentBeliefSupport().getNumberOfSetBits());
+    tracker.track(3,0);
+    EXPECT_EQ(1ul, tracker.getCurrentBeliefSupport().getNumberOfSetBits());
+    tracker.track(3,0);
+    EXPECT_EQ(2ul, tracker.getCurrentBeliefSupport().getNumberOfSetBits());
+
+
+}