From a3e92d2f72952e759effe1cf765f7d7cad87a2a9 Mon Sep 17 00:00:00 2001
From: Tim Quatmann <tim.quatmann@cs.rwth-aachen.de>
Date: Mon, 30 Mar 2020 12:18:06 +0200
Subject: [PATCH] Use the new reward functionality of BeliefGrid. This also
 fixes an incorrect reward assignment (previously, the same reward was
 assigned to all states with the same observation).
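
A minimal sketch (not the actual BeliefGrid implementation) of what the new
reward functionality is assumed to compute: the reward of taking an action in
a belief is the expectation of the underlying POMDP state-action rewards,
weighted by the belief probabilities, instead of the reward of a single
representative state per observation. The helper name and the reward-model
accessor used below are illustrative assumptions:

    // Hypothetical helper: expected reward of 'action' in 'belief',
    // where 'belief' maps POMDP states to their probabilities.
    template<typename ValueType, typename BeliefType>
    ValueType beliefActionReward(storm::models::sparse::Pomdp<ValueType> const& pomdp,
                                 BeliefType const& belief, uint64_t action) {
        auto const& rewardModel = pomdp.getUniqueRewardModel();
        ValueType result = storm::utility::zero<ValueType>();
        for (auto const& entry : belief) { // entry: (POMDP state, probability)
            uint64_t choice = pomdp.getChoiceIndex(storm::storage::StateActionPair(entry.first, action));
            // Weight the state-action reward of this choice by the belief probability.
            result += entry.second * rewardModel.getStateActionReward(choice);
        }
        return result;
    }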

---
 .../ApproximatePOMDPModelchecker.cpp          | 71 ++++++++++---------
 1 file changed, 39 insertions(+), 32 deletions(-)

diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp
index fb5264ca8..3aeb81d18 100644
--- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp
+++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp
@@ -137,7 +137,28 @@ namespace storm {
 
                 stream << "##########################################" << std::endl;
             }
-
+            
+            std::shared_ptr<storm::logic::Formula const> createStandardProperty(bool min, bool computeRewards) {
+                std::string propertyString = computeRewards ? "R" : "P";
+                propertyString += min ? "min" : "max";
+                propertyString += "=? [F \"target\"]";
+                std::vector<storm::jani::Property> propertyVector = storm::api::parseProperties(propertyString);
+                return storm::api::extractFormulasFromProperties(propertyVector).front();
+            }
+            
+            template<typename ValueType>
+            storm::modelchecker::CheckTask<storm::logic::Formula, ValueType> createStandardCheckTask(std::shared_ptr<storm::logic::Formula const>& property, std::vector<ValueType>&& hintVector) {
+                // Note: The property must not go out of scope after calling this method, because the returned task only stores a reference to it.
+                // This is also why the property is taken by (non-const) reference here.
+                auto task = storm::api::createTask<ValueType>(property, false);
+                if (!hintVector.empty()) {
+                    auto hint = storm::modelchecker::ExplicitModelCheckerHint<ValueType>();
+                    hint.setResultHint(std::move(hintVector));
+                    auto hintPtr = std::make_shared<storm::modelchecker::ExplicitModelCheckerHint<ValueType>>(hint);
+                    task.setHint(hintPtr);
+                }
+                return task;
+            }
             
             template<typename ValueType, typename RewardModelType>
             std::unique_ptr<POMDPCheckResult<ValueType>>
@@ -360,6 +381,10 @@ namespace storm {
                 }
 
                 storm::storage::BeliefGrid<storm::models::sparse::Pomdp<ValueType>> beliefGrid(pomdp, options.numericPrecision);
+                if (computeRewards) {
+                    beliefGrid.setRewardModel();
+                }
+                
                 bsmap_type beliefStateMap;
 
                 std::deque<uint64_t> beliefsToBeExpanded;
@@ -520,37 +545,27 @@ namespace storm {
                 storm::storage::sparse::ModelComponents<ValueType, RewardModelType> modelComponents(mdpTransitionsBuilder.build(mdpMatrixRow, nextMdpStateId, nextMdpStateId), std::move(mdpLabeling));
                 auto overApproxMdp = std::make_shared<storm::models::sparse::Mdp<ValueType, RewardModelType>>(std::move(modelComponents));
                 if (computeRewards) {
-                    storm::models::sparse::StandardRewardModel<ValueType> mdpRewardModel(boost::none, std::vector<ValueType>(mdpMatrixRow));
+                    storm::models::sparse::StandardRewardModel<ValueType> mdpRewardModel(boost::none, std::vector<ValueType>(mdpMatrixRow, storm::utility::zero<ValueType>()));
                     for (auto const &iter : beliefStateMap.left) {
                         if (fullyExpandedStates.get(iter.second)) {
-                            auto currentBelief = beliefGrid.getGridPoint(iter.first);
+                            auto const& currentBelief = beliefGrid.getGridPoint(iter.first);
                             auto representativeState = currentBelief.begin()->first;
                             for (uint64_t action = 0; action < pomdp.getNumberOfChoices(representativeState); ++action) {
-                                // Add the reward
                                 uint64_t mdpChoice = overApproxMdp->getChoiceIndex(storm::storage::StateActionPair(iter.second, action));
-                                uint64_t pomdpChoice = pomdp.getChoiceIndex(storm::storage::StateActionPair(representativeState, action));
-                                mdpRewardModel.setStateActionReward(mdpChoice, getRewardAfterAction(pomdpChoice, currentBelief));
+                                mdpRewardModel.setStateActionReward(mdpChoice, beliefGrid.getBeliefActionReward(currentBelief, action));
                             }
                         }
                     }
                     overApproxMdp->addRewardModel("default", mdpRewardModel);
-                    overApproxMdp->restrictRewardModels(std::set<std::string>({"default"}));
                 }
                 statistics.overApproximationBuildTime.stop();
                 STORM_PRINT("Over Approximation MDP build took " << statistics.overApproximationBuildTime << " seconds." << std::endl);
                 overApproxMdp->printModelInformationToStream(std::cout);
 
                 auto modelPtr = std::static_pointer_cast<storm::models::sparse::Model<ValueType, RewardModelType>>(overApproxMdp);
-                std::string propertyString = computeRewards ? "R" : "P";
-                propertyString += min ? "min" : "max";
-                propertyString += "=? [F \"target\"]";
-                std::vector<storm::jani::Property> propertyVector = storm::api::parseProperties(propertyString);
-                std::shared_ptr<storm::logic::Formula const> property = storm::api::extractFormulasFromProperties(propertyVector).front();
-                auto task = storm::api::createTask<ValueType>(property, false);
-                auto hint = storm::modelchecker::ExplicitModelCheckerHint<ValueType>();
-                hint.setResultHint(hintVector);
-                auto hintPtr = std::make_shared<storm::modelchecker::ExplicitModelCheckerHint<ValueType>>(hint);
-                task.setHint(hintPtr);
+                auto property = createStandardProperty(min, computeRewards);
+                auto task = createStandardCheckTask(property, std::move(hintVector));
+
                 statistics.overApproximationCheckTime.start();
                 std::unique_ptr<storm::modelchecker::CheckResult> res(storm::api::verifyWithSparseEngine<ValueType>(overApproxMdp, task));
                 statistics.overApproximationCheckTime.stop();
@@ -1172,16 +1187,14 @@ namespace storm {
                 storm::storage::sparse::ModelComponents<ValueType, RewardModelType> modelComponents(mdpTransitionsBuilder.build(mdpMatrixRow, nextMdpStateId, nextMdpStateId), std::move(mdpLabeling));
                 auto model = std::make_shared<storm::models::sparse::Mdp<ValueType, RewardModelType>>(std::move(modelComponents));
                 if (computeRewards) {
-                    storm::models::sparse::StandardRewardModel<ValueType> mdpRewardModel(boost::none, std::vector<ValueType>(mdpMatrixRow));
+                    storm::models::sparse::StandardRewardModel<ValueType> mdpRewardModel(boost::none, std::vector<ValueType>(mdpMatrixRow, storm::utility::zero<ValueType>()));
                     for (auto const &iter : beliefStateMap.left) {
                         if (fullyExpandedStates.get(iter.second)) {
-                            auto currentBelief = beliefGrid.getGridPoint(iter.first);
+                            auto const& currentBelief = beliefGrid.getGridPoint(iter.first);
                             auto representativeState = currentBelief.begin()->first;
                             for (uint64_t action = 0; action < pomdp.getNumberOfChoices(representativeState); ++action) {
-                                // Add the reward
                                 uint64_t mdpChoice = model->getChoiceIndex(storm::storage::StateActionPair(iter.second, action));
-                                uint64_t pomdpChoice = pomdp.getChoiceIndex(storm::storage::StateActionPair(representativeState, action));
-                                mdpRewardModel.setStateActionReward(mdpChoice, getRewardAfterAction(pomdpChoice, currentBelief));
+                                mdpRewardModel.setStateActionReward(mdpChoice, beliefGrid.getBeliefActionReward(currentBelief, action));
                             }
                         }
                     }
@@ -1192,17 +1205,11 @@ namespace storm {
                 model->printModelInformationToStream(std::cout);
                 statistics.underApproximationBuildTime.stop();
 
-                std::string propertyString;
-                if (computeRewards) {
-                    propertyString = min ? "Rmin=? [F \"target\"]" : "Rmax=? [F \"target\"]";
-                } else {
-                    propertyString = min ? "Pmin=? [F \"target\"]" : "Pmax=? [F \"target\"]";
-                }
-                std::vector<storm::jani::Property> propertyVector = storm::api::parseProperties(propertyString);
-                std::shared_ptr<storm::logic::Formula const> property = storm::api::extractFormulasFromProperties(propertyVector).front();
-
+                auto property = createStandardProperty(min, computeRewards);
+                auto task = createStandardCheckTask(property, std::vector<ValueType>());
+                
                 statistics.underApproximationCheckTime.start();
-                std::unique_ptr<storm::modelchecker::CheckResult> res(storm::api::verifyWithSparseEngine<ValueType>(model, storm::api::createTask<ValueType>(property, false)));
+                std::unique_ptr<storm::modelchecker::CheckResult> res(storm::api::verifyWithSparseEngine<ValueType>(model, task));
                 statistics.underApproximationCheckTime.stop();
                 if (storm::utility::resources::isTerminate() && !res) {
                     return nullptr;
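
A hedged usage note on the new helpers introduced above (createStandardProperty
and createStandardCheckTask): since the check task only stores a reference to
the formula, the shared_ptr returned by createStandardProperty has to stay
alive until verification has finished. A minimal sketch of the intended call
pattern, using the names from this patch:

    // The shared_ptr 'property' owns the formula; 'task' only references it.
    auto property = createStandardProperty(min, computeRewards);
    auto task = createStandardCheckTask(property, std::move(hintVector));
    // 'property' must still be in scope while the model is checked.
    std::unique_ptr<storm::modelchecker::CheckResult> res(
            storm::api::verifyWithSparseEngine<ValueType>(model, task));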