diff --git a/src/storm-pomdp/generator/BeliefSupportTracker.cpp b/src/storm-pomdp/generator/BeliefSupportTracker.cpp index 3ee6195ef..714a85446 100644 --- a/src/storm-pomdp/generator/BeliefSupportTracker.cpp +++ b/src/storm-pomdp/generator/BeliefSupportTracker.cpp @@ -29,6 +29,12 @@ namespace storm { currentBeliefSupport = newBeliefSupport; } + template + void BeliefSupportTracker::reset() { + currentBeliefSupport = pomdp.getInitialStates(); + } + + template class BeliefSupportTracker; template class BeliefSupportTracker; diff --git a/src/storm-pomdp/generator/BeliefSupportTracker.h b/src/storm-pomdp/generator/BeliefSupportTracker.h index dcd133043..91cfed593 100644 --- a/src/storm-pomdp/generator/BeliefSupportTracker.h +++ b/src/storm-pomdp/generator/BeliefSupportTracker.h @@ -12,10 +12,21 @@ namespace storm { */ public: BeliefSupportTracker(storm::models::sparse::Pomdp const& pomdp); - + /** + * The current belief support according to the tracker + * @return + */ storm::storage::BitVector const& getCurrentBeliefSupport() const; - + /*! + * Update current belief support state + * @param action The action that was taken + * @param observation The new (state) observation + */ void track(uint64_t action, uint64_t observation); + /*! + * Reset to initial state + */ + void reset(); private: storm::models::sparse::Pomdp const& pomdp; diff --git a/src/test/storm-pomdp/CMakeLists.txt b/src/test/storm-pomdp/CMakeLists.txt index 70ceeb3dd..d4e2ed819 100644 --- a/src/test/storm-pomdp/CMakeLists.txt +++ b/src/test/storm-pomdp/CMakeLists.txt @@ -9,7 +9,7 @@ register_source_groups_from_filestructure("${ALL_FILES}" test) # Note that the tests also need the source files, except for the main file include_directories(${GTEST_INCLUDE_DIR}) -foreach (testsuite analysis transformation modelchecker) +foreach (testsuite analysis transformation modelchecker tracking) file(GLOB_RECURSE TEST_${testsuite}_FILES ${STORM_TESTS_BASE_PATH}/${testsuite}/*.h ${STORM_TESTS_BASE_PATH}/${testsuite}/*.cpp) add_executable (test-pomdp-${testsuite} ${TEST_${testsuite}_FILES} ${STORM_TESTS_BASE_PATH}/storm-test.cpp) diff --git a/src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp b/src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp new file mode 100644 index 000000000..ec28bdbf0 --- /dev/null +++ b/src/test/storm-pomdp/tracking/BeliefSupportTrackingTest.cpp @@ -0,0 +1,47 @@ +#include "test/storm_gtest.h" +#include "storm-config.h" +#include "storm/models/sparse/StandardRewardModel.h" +#include "storm-parsers/parser/PrismParser.h" +#include "storm/builder/ExplicitModelBuilder.h" +#include "storm/api/storm.h" +#include "storm-parsers/api/storm-parsers.h" +#include "storm-pomdp/analysis/FormulaInformation.h" +#include "storm-pomdp/transformer/MakePOMDPCanonic.h" +#include "storm-pomdp/generator/BeliefSupportTracker.h" + +// TODO +// These tests depend on the interpretation of action and observation numbers and those may change. +// A more robust test would take the high-level actions and observations and track on those. + +TEST(BeliefSupportTracking, Maze) { + storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/pomdp/maze2.prism"); + program = storm::utility::prism::preprocess(program, "sl=0.4"); + std::shared_ptr formula = storm::api::parsePropertiesForPrismProgram("Pmax=? [F \"goal\" ]", program).front().getRawFormula(); + std::shared_ptr> pomdp = storm::api::buildSparseModel(program, {formula})->as>(); + storm::transformer::MakePOMDPCanonic makeCanonic(*pomdp); + pomdp = makeCanonic.transform(); + + storm::generator::BeliefSupportTracker tracker(*pomdp); + EXPECT_EQ(pomdp->getInitialStates(), tracker.getCurrentBeliefSupport()); + tracker.track(0,0); + auto beliefsup = tracker.getCurrentBeliefSupport(); + EXPECT_EQ(6ul, beliefsup.getNumberOfSetBits()); + tracker.track(0,0); + EXPECT_EQ(beliefsup, tracker.getCurrentBeliefSupport()); + tracker.track(0,1); + EXPECT_TRUE(tracker.getCurrentBeliefSupport().empty()); + tracker.reset(); + EXPECT_EQ(pomdp->getInitialStates(), tracker.getCurrentBeliefSupport()); + tracker.track(0, 0); + EXPECT_EQ(beliefsup, tracker.getCurrentBeliefSupport()); + tracker.track(1,0); + EXPECT_EQ(beliefsup, tracker.getCurrentBeliefSupport()); + tracker.track(2,1); + EXPECT_EQ(1ul, tracker.getCurrentBeliefSupport().getNumberOfSetBits()); + tracker.track(3,0); + EXPECT_EQ(1ul, tracker.getCurrentBeliefSupport().getNumberOfSetBits()); + tracker.track(3,0); + EXPECT_EQ(2ul, tracker.getCurrentBeliefSupport().getNumberOfSetBits()); + + +}