From 621e9f679405ddc16b1d7f769bc47e0ac8a7949e Mon Sep 17 00:00:00 2001
From: Fabian Russold <fabian.russold@student.tugraz.at>
Date: Tue, 1 Oct 2024 12:09:17 +0200
Subject: [PATCH] optimization: WIP relevant states for game VI

---
 .../rpatl/helper/SparseSmgRpatlHelper.cpp     | 56 +++++++++++++------
 1 file changed, 38 insertions(+), 18 deletions(-)

diff --git a/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp b/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp
index e3dc2d8ee..88f2bd1a2 100644
--- a/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp
+++ b/src/storm/modelchecker/rpatl/helper/SparseSmgRpatlHelper.cpp
@@ -22,8 +22,7 @@ namespace storm {
 
                 // Relevant states are those states which are phiStates and not PsiStates.
                 storm::storage::BitVector relevantStates = phiStates & ~psiStates;
-
-                // Initialize the x vector and solution vector result.
+                    // Initialize the x vector and solution vector result.
                 std::vector<ValueType> x = std::vector<ValueType>(relevantStates.getNumberOfSetBits(), storm::utility::zero<ValueType>());
                 std::vector<ValueType> result = std::vector<ValueType>(transitionMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
                 std::vector<ValueType> b = transitionMatrix.getConstrainedRowGroupSumVector(relevantStates, psiStates);
@@ -62,44 +61,65 @@ namespace storm {
 
             template<typename ValueType>
             SMGSparseModelCheckingHelperReturnType<ValueType> SparseSmgRpatlHelper<ValueType>::computeUntilProbabilitiesSound(Environment const& env, storm::solver::SolveGoal<ValueType>&& goal, storm::storage::SparseMatrix<ValueType> const& transitionMatrix, storm::storage::SparseMatrix<ValueType> const& backwardTransitions, storm::storage::BitVector const& phiStates, storm::storage::BitVector const& psiStates, bool qualitative, storm::storage::BitVector statesOfCoalition, bool produceScheduler, ModelCheckerHint const& hint) {
+                STORM_LOG_DEBUG("statesOfCoalition: " << statesOfCoalition << std::endl);
+
+                storm::storage::BitVector prob1 = storm::utility::graph::performProb1(backwardTransitions, phiStates, psiStates);
+                storm::storage::BitVector probGreater0 = storm::utility::graph::performProbGreater0(backwardTransitions, phiStates, psiStates);
+                STORM_LOG_DEBUG("probGreater0: " << probGreater0 << std::endl);
+
+
 
-                storm::modelchecker::helper::internal::SoundGameViHelper<ValueType> viHelper(transitionMatrix, backwardTransitions, statesOfCoalition, psiStates, goal.direction());
                 std::unique_ptr<storm::storage::Scheduler<ValueType>> scheduler;
-                storm::storage::BitVector relevantStates(psiStates.size(), true);
+                storm::storage::BitVector relevantStates = storm::storage::BitVector(transitionMatrix.getRowGroupCount(), true); // TODO Fabian
+
+                storm::storage::SparseMatrix<ValueType> submatrix = transitionMatrix.getSubmatrix(true, relevantStates, relevantStates, false);
 
                 // Initialize the x vector and solution vector result.
-                // TODO Fabian: maybe relevant states (later)
-                std::vector<ValueType> xL = std::vector<ValueType>(transitionMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
+                std::vector<ValueType> xL = std::vector<ValueType>(relevantStates.getNumberOfSetBits(), storm::utility::zero<ValueType>());
+                auto xL_begin = xL.begin();
+                std::for_each(xL.begin(), xL.end(), [&prob1, &xL_begin](ValueType &it)
+                              {
+                                  if (prob1[&it - &(*xL_begin)])
+                                      it = 1;
+                              });
                 // std::transform(xL.begin(), xL.end(), psiStates.begin(), xL, [](double& a) { a *= 3; }) // TODO Fabian
                 // assigning 1s to the xL vector for all Goal states
-                assert(xL.size() == psiStates.size());
-                for (size_t i = 0; i < xL.size(); i++)
-                {
-                    if (psiStates[i])
-                        xL[i] = 1;
-                }
-                STORM_LOG_DEBUG("xL " << xL);
                 std::vector<ValueType> xU = std::vector<ValueType>(transitionMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
-                storm::storage::BitVector probGreater0 = storm::utility::graph::performProbGreater0(backwardTransitions, phiStates, psiStates);
                 // assigning 1s to the xU vector for all states except the states s where Prob(sEf) = 0 for all goal states f
-                assert(xU.size() == probGreater0.size());
                 auto xU_begin = xU.begin();
                 std::for_each(xU.begin(), xU.end(), [&probGreater0, &xU_begin](ValueType &it)
                               {
                                   if (probGreater0[&it - &(*xU_begin)])
                                       it = 1;
                               });
-
-                STORM_LOG_DEBUG("xU " << xU);
+                /*size_t i = 0;
+                auto new_end = std::remove_if(xU.begin(), xU.end(), [&relevantStates, &i](const auto& item) {
+                    bool ret = !(relevantStates[i]);
+                    i++;
+                    return ret;
+                });
+                xU.erase(new_end, xU.end());
+                xU.resize(relevantStates.getNumberOfSetBits()); */
 
                 std::vector<ValueType> result = std::vector<ValueType>(transitionMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
+                std::vector<ValueType> b = transitionMatrix.getConstrainedRowGroupSumVector(relevantStates, psiStates);
+
+                // STORM_LOG_DEBUG(transitionMatrix);
+                STORM_LOG_DEBUG("b: " << b);
+                storm::storage::BitVector clippedStatesOfCoalition(relevantStates.getNumberOfSetBits());
+                clippedStatesOfCoalition.setClippedStatesOfCoalition(relevantStates, statesOfCoalition);
                 // std::vector<ValueType> constrainedChoiceValues = std::vector<ValueType>(b.size(), storm::utility::zero<ValueType>()); // TODO Fabian: do I need this?
                 std::vector<ValueType> constrainedChoiceValues;
 
+                storm::modelchecker::helper::internal::SoundGameViHelper<ValueType> viHelper(transitionMatrix, backwardTransitions, b, statesOfCoalition, psiStates, goal.direction());
+
                 viHelper.performValueIteration(env, xL, xU, goal.direction(), constrainedChoiceValues);
 
+                storm::utility::vector::setVectorValues(result, relevantStates, xU);
+                storm::utility::vector::setVectorValues(result, psiStates, storm::utility::one<ValueType>());
+
                 STORM_LOG_DEBUG(xU);
-                return SMGSparseModelCheckingHelperReturnType<ValueType>(std::move(xU), std::move(relevantStates), std::move(scheduler), std::move(constrainedChoiceValues));
+                return SMGSparseModelCheckingHelperReturnType<ValueType>(std::move(result), std::move(relevantStates), std::move(scheduler), std::move(constrainedChoiceValues));
             }
 
             template<typename ValueType>