From 591e63e11eed3edae0522c7049f35ce4cff5b0c0 Mon Sep 17 00:00:00 2001
From: lukpo <lukas.posch@student.tugraz.at>
Date: Mon, 19 Jul 2021 16:57:18 +0200
Subject: [PATCH] fixed choiceValues for MDP shields - Until and Globally

---
 .../prctl/SparseMdpPrctlModelChecker.cpp      |  5 +-
 .../prctl/helper/SparseMdpPrctlHelper.cpp     | 61 +++++++++++++++++--
 2 files changed, 61 insertions(+), 5 deletions(-)

diff --git a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp
index 30cfcf19f..39b2ea6e9 100644
--- a/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp
+++ b/src/storm/modelchecker/prctl/SparseMdpPrctlModelChecker.cpp
@@ -102,6 +102,8 @@ namespace storm {
                 storm::storage::BitVector resultMaybeStates;
                 std::vector<ValueType> choiceValues;
 
+                //TODO: check the result for shields
+
                 if(checkTask.isShieldingTask()) {
                     numericResult = helper.compute(env, storm::solver::SolveGoal<ValueType>(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), pathFormula.getNonStrictLowerBound<uint64_t>(), pathFormula.getNonStrictUpperBound<uint64_t>(), checkTask.getHint(), resultMaybeStates, choiceValues);
                     tempest::shields::createShield<ValueType>(std::make_shared<storm::models::sparse::Mdp<ValueType>>(this->getModel()), std::move(choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(resultMaybeStates), storm::storage::BitVector(resultMaybeStates.size(), false));
@@ -121,7 +123,8 @@ namespace storm {
             auto ret = storm::modelchecker::helper::SparseMdpPrctlHelper<ValueType>::computeNextProbabilities(env, storm::solver::SolveGoal<ValueType>(this->getModel(), checkTask), checkTask.getOptimizationDirection(), this->getModel().getTransitionMatrix(), subResult.getTruthValuesVector());
             std::unique_ptr<CheckResult> result(new ExplicitQuantitativeCheckResult<ValueType>(std::move(ret.values)));
             if(checkTask.isShieldingTask()) {
-                tempest::shields::createShield<ValueType>(std::make_shared<storm::models::sparse::Mdp<ValueType>>(this->getModel()), std::move(ret.values), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false));
+                //TODO: creating a shield for NEXT does not work
+                tempest::shields::createShield<ValueType>(std::make_shared<storm::models::sparse::Mdp<ValueType>>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.maybeStates), storm::storage::BitVector(ret.maybeStates.size(), false));
             } else if (checkTask.isProduceSchedulersSet() && ret.scheduler) {
                 result->asExplicitQuantitativeCheckResult<ValueType>().setScheduler(std::move(ret.scheduler));
             }
diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp
index 04eab8cfe..9ad693722 100644
--- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp
+++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp
@@ -620,10 +620,18 @@ namespace storm {
 
                 // create multiplier and execute the calculation for 1 additional step
                 auto multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, transitionMatrix);
-                std::vector<ValueType> choiceValues = std::vector<ValueType>(transitionMatrix.getRowCount(), storm::utility::zero<ValueType>());
+
+                uint sizeChoiceValues = 0;
+                for(uint counter = 0; counter < qualitativeStateSets.maybeStates.size(); counter++) {
+                    if(qualitativeStateSets.maybeStates.get(counter)) {
+                        sizeChoiceValues += transitionMatrix.getRowGroupSize(counter);
+                    }
+                }
+
+                std::vector<ValueType> choiceValues = std::vector<ValueType>(sizeChoiceValues, storm::utility::zero<ValueType>());
                 
                 // Check whether we need to compute exact probabilities for some states.
-                if (qualitative || maybeStatesNotRelevant) {
+                if (qualitative || maybeStatesNotRelevant || !goal.isShieldingTask()) {
                     // Set the values for all maybe-states to 0.5 to indicate that their probability values are neither 0 nor 1.
                     storm::utility::vector::setVectorValues<ValueType>(result, qualitativeStateSets.maybeStates, storm::utility::convertNumber<ValueType>(0.5));
                 } else {
@@ -663,8 +671,50 @@ namespace storm {
                             }
                         }
                         if (goal.isShieldingTask()) {
-                            multiplier->multiply(env, result, &b, choiceValues);
-                            multiplier->reduce(env, goal.direction(), choiceValues, transitionMatrix.getRowGroupIndices(), result, nullptr);
+                            STORM_LOG_DEBUG("SparseMdpPrctlHelper<ValueType>::computeUntilProbabilities: Before multiply()");
+                            STORM_LOG_DEBUG(result);
+
+                            std::vector<ValueType> subResult;
+                            uint sizeChoiceValues = 0;
+                            for(uint counter = 0; counter < qualitativeStateSets.maybeStates.size(); counter++) {
+                                if(qualitativeStateSets.maybeStates.get(counter)) {
+                                    subResult.push_back(result.at(counter));
+                                }
+                            }
+
+                            STORM_LOG_DEBUG(subResult);
+
+
+                            STORM_LOG_DEBUG(choiceValues);
+                            //std::vector<ValueType> b_complete = transitionMatrix.getConstrainedRowGroupSumVector(storm::storage::BitVector(qualitativeStateSets.maybeStates.size(), true), qualitativeStateSets.statesWithProbability1);
+
+                            submatrix = transitionMatrix.getSubmatrix(true, qualitativeStateSets.maybeStates, qualitativeStateSets.maybeStates, false);
+
+                            //STORM_LOG_DEBUG(b_complete);
+
+                            auto sub_multiplier = storm::solver::MultiplierFactory<ValueType>().create(env, submatrix);
+
+
+                            sub_multiplier->multiply(env, subResult, &b, choiceValues);
+                            STORM_LOG_DEBUG(choiceValues);
+
+                            std::vector<ValueType> allChoices = std::vector<ValueType>(transitionMatrix.getRowGroupIndices().at(transitionMatrix.getRowGroupIndices().size() - 1), storm::utility::zero<ValueType>());
+                            auto choice_it = choiceValues.begin();
+                            for(uint state = 0; state < transitionMatrix.getRowGroupIndices().size() - 1; state++) {
+                                uint rowGroupSize = transitionMatrix.getRowGroupIndices().at(state + 1) - transitionMatrix.getRowGroupIndices().at(state);
+                                if (qualitativeStateSets.maybeStates.get(state)) {
+                                    for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
+                                        allChoices.at(transitionMatrix.getRowGroupIndices().at(state) + choice) = *choice_it;
+                                    }
+                                }
+                            }
+                            choiceValues = allChoices;
+
+                            STORM_LOG_DEBUG(choiceValues);
+
+
+                            //TODO: fill up choiceValues with Zeros (dimension!)
+
                         }
                     }
                 }
@@ -703,6 +753,9 @@ namespace storm {
                     for (auto& element : result.values) {
                         element = storm::utility::one<ValueType>() - element;
                     }
+                    for (auto& choice : result.choiceValues) {
+                        choice = storm::utility::one<ValueType>() - choice;
+                    }
                     return result;
                 }
             }