From 36f1306b4af256ede89d2e59a749995a35752e09 Mon Sep 17 00:00:00 2001
From: dehnert <dehnert@cs.rwth-aachen.de>
Date: Wed, 19 Jun 2013 20:31:45 +0200
Subject: [PATCH] Now schedulers get computed correctly.

Former-commit-id: 3b986ffbf8917aff44be5f0669f6a7ed7db1239c
---
 .../prctl/SparseMdpPrctlModelChecker.h        | 54 +++++++++++--------
 1 file changed, 31 insertions(+), 23 deletions(-)

diff --git a/src/modelchecker/prctl/SparseMdpPrctlModelChecker.h b/src/modelchecker/prctl/SparseMdpPrctlModelChecker.h
index 739a6d408..530e9df18 100644
--- a/src/modelchecker/prctl/SparseMdpPrctlModelChecker.h
+++ b/src/modelchecker/prctl/SparseMdpPrctlModelChecker.h
@@ -336,7 +336,7 @@ namespace storm {
                     
                     // If we were required to generate a scheduler, do so now.
                     if (scheduler != nullptr) {
-                        this->computeTakenChoices(this->minimumOperatorStack.top(), *result, *scheduler, this->getModel().getNondeterministicChoiceIndices());
+                        this->computeTakenChoices(this->minimumOperatorStack.top(), false, *result, *scheduler, this->getModel().getNondeterministicChoiceIndices());
                     }
                     
                     return result;
@@ -541,11 +541,7 @@ namespace storm {
                     storm::utility::vector::setVectorValues(*result, infinityStates, storm::utility::constGetInfinity<Type>());
                     
 //                    std::vector<uint_fast64_t> myScheduler(this->getModel().getNumberOfStates());
-//                    this->computeTakenChoices(this->minimumOperatorStack.top(), *result, myScheduler, this->getModel().getNondeterministicChoiceIndices());
-//                    std::cout << "min? " << this->minimumOperatorStack.top() << " 487: " << myScheduler[487] << std::endl;
-//                    std::cout << "513: " << (*result)[513] << " 484: " << (*result)[484] << std::endl;
-//                    std::cout << "real scheduler: " << myScheduler << std::endl;
-//                    storm::storage::BitVector subsys(this->getModel().getNumberOfStates(), true);
+//                    this->computeTakenChoices(this->minimumOperatorStack.top(), true, *result, myScheduler, this->getModel().getNondeterministicChoiceIndices());
 //                    
 //                    std::vector<uint_fast64_t> stateColoring(this->getModel().getNumberOfStates());
 //                    for (auto target : *targetStates) {
@@ -556,11 +552,11 @@ namespace storm {
 //                    colors[0] = "white";
 //                    colors[1] = "blue";
 //                    
-//                    this->getModel().writeDotToStream(std::cout, true, &subsys, result, nullptr, &stateColoring, &colors, &myScheduler);
+//                    this->getModel().writeDotToStream(std::cout, true, storm::storage::BitVector(this->getModel().getNumberOfStates(), true), result, nullptr, &stateColoring, &colors, &myScheduler);
                     
                     // If we were required to generate a scheduler, do so now.
                     if (scheduler != nullptr) {
-                        this->computeTakenChoices(this->minimumOperatorStack.top(), *result, *scheduler, this->getModel().getNondeterministicChoiceIndices());
+                        this->computeTakenChoices(this->minimumOperatorStack.top(), true, *result, *scheduler, this->getModel().getNondeterministicChoiceIndices());
                     }
                     
                     // Delete temporary storages and return result.
@@ -577,19 +573,31 @@ namespace storm {
                  * @param takenChoices The output vector that is to store the taken choices.
                  * @param nondeterministicChoiceIndices The assignment of states to their nondeterministic choices in the matrix.
                  */
-                void computeTakenChoices(bool minimize, std::vector<Type> const& result, std::vector<uint_fast64_t>& takenChoices, std::vector<uint_fast64_t> const& nondeterministicChoiceIndices) const {
-//                    std::vector<Type> nondeterministicResult(this->getModel().getTransitionMatrix().getColumnCount());
-//                    std::vector<Type> temporaryResult(nondeterministicChoiceIndices.size() - 1);
-//                    if (linearEquationSolver != nullptr) {
-//                        this->linearEquationSolver->performMatrixVectorMultiplication(this->getModel().getTransitionMatrix(), );
-//                    } else {
-//                        throw storm::exceptions::InvalidStateException() << "No valid linear equation solver available.";
-//                    }
-//                    if (minimize) {
-//                        storm::utility::vector::reduceVectorMin(nondeterministicResult, temporaryResult, nondeterministicChoiceIndices, &takenChoices);
-//                    } else {
-//                        storm::utility::vector::reduceVectorMax(nondeterministicResult, temporaryResult, nondeterministicChoiceIndices, &takenChoices);
-//                    }
+                void computeTakenChoices(bool minimize, bool addRewards, std::vector<Type> const& result, std::vector<uint_fast64_t>& takenChoices, std::vector<uint_fast64_t> const& nondeterministicChoiceIndices) const {
+                    std::vector<Type> temporaryResult(nondeterministicChoiceIndices.size() - 1);
+                    std::vector<Type> nondeterministicResult(result);
+                    storm::solver::GmmxxLinearEquationSolver<Type> solver;
+                    solver.performMatrixVectorMultiplication(this->getModel().getTransitionMatrix(), nondeterministicResult, nullptr, 1);
+                    if (addRewards) {
+                        std::vector<Type> totalRewardVector;
+                        if (this->getModel().hasTransitionRewards()) {
+                            std::vector<Type> totalRewardVector = this->getModel().getTransitionMatrix().getPointwiseProductRowSumVector(this->getModel().getTransitionRewardMatrix());
+                            if (this->getModel().hasStateRewards()) {
+                                std::vector<Type> stateRewards(totalRewardVector.size());
+                                storm::utility::vector::selectVectorValuesRepeatedly(stateRewards, storm::storage::BitVector(this->getModel().getStateRewardVector().size(), true), this->getModel().getNondeterministicChoiceIndices(), this->getModel().getStateRewardVector());
+                                storm::utility::vector::addVectorsInPlace(totalRewardVector, stateRewards);
+                            }
+                        } else {
+                            totalRewardVector.resize(nondeterministicResult.size());
+                            storm::utility::vector::selectVectorValuesRepeatedly(totalRewardVector, storm::storage::BitVector(this->getModel().getStateRewardVector().size(), true), this->getModel().getNondeterministicChoiceIndices(), this->getModel().getStateRewardVector());
+                        }
+                        storm::utility::vector::addVectorsInPlace(nondeterministicResult, totalRewardVector);
+                    }
+                    if (minimize) {
+                        storm::utility::vector::reduceVectorMin(nondeterministicResult, temporaryResult, nondeterministicChoiceIndices, &takenChoices);
+                    } else {
+                        storm::utility::vector::reduceVectorMax(nondeterministicResult, temporaryResult, nondeterministicChoiceIndices, &takenChoices);
+                    }
                 }
                 
                 /*!
@@ -647,8 +655,8 @@ namespace storm {
                         storm::utility::vector::selectVectorValues(b, scheduler, subNondeterministicChoiceIndices, rightHandSide);
                         storm::storage::SparseMatrix<Type> A(submatrix.getSubmatrix(scheduler, subNondeterministicChoiceIndices));
                         A.convertToEquationSystem();
-                        std::unique_ptr<storm::solver::GmmxxLinearEquationSolver<Type>> solver(new storm::solver::GmmxxLinearEquationSolver<Type>());
-                        solver->solveEquationSystem(A, result, b);
+                        storm::solver::GmmxxLinearEquationSolver<Type> solver;
+                        solver.solveEquationSystem(A, result, b);
                         
                         // As there are sometimes some very small values in the vector due to numerical solving, we set
                         // them to zero, because they otherwise require a certain number of value iterations.