diff --git a/src/storm/solver/StandardGameSolver.cpp b/src/storm/solver/StandardGameSolver.cpp index bd1024a86..6140a5c15 100644 --- a/src/storm/solver/StandardGameSolver.cpp +++ b/src/storm/solver/StandardGameSolver.cpp @@ -159,9 +159,8 @@ namespace storm { template bool StandardGameSolver::solveGameValueIteration(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const& b) const { - if(!linEqSolverPlayer2Matrix) { - linEqSolverPlayer2Matrix = linearEquationSolverFactory->create(env, player2Matrix, storm::solver::LinearEquationSolverTask::Multiply); - linEqSolverPlayer2Matrix->setCachingEnabled(true); + if (!multiplierPlayer2Matrix) { + multiplierPlayer2Matrix = storm::solver::MultiplierFactory().create(env, player2Matrix); } if (!auxiliaryP2RowVector) { @@ -204,7 +203,7 @@ namespace storm { Status status = Status::InProgress; while (status == Status::InProgress) { - multiplyAndReduce(player1Dir, player2Dir, *currentX, &b, *linEqSolverPlayer2Matrix, multiplyResult, reducedMultiplyResult, *newX); + multiplyAndReduce(env, player1Dir, player2Dir, *currentX, &b, *multiplierPlayer2Matrix, multiplyResult, reducedMultiplyResult, *newX); // Determine whether the method converged. if (storm::utility::vector::equalModuloPrecision(*currentX, *newX, precision, relative)) { @@ -242,9 +241,8 @@ namespace storm { template void StandardGameSolver::repeatedMultiply(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const* b, uint_fast64_t n) const { - if(!linEqSolverPlayer2Matrix) { - linEqSolverPlayer2Matrix = linearEquationSolverFactory->create(env, player2Matrix, storm::solver::LinearEquationSolverTask::Multiply); - linEqSolverPlayer2Matrix->setCachingEnabled(true); + if (!multiplierPlayer2Matrix) { + multiplierPlayer2Matrix = storm::solver::MultiplierFactory().create(env, player2Matrix); } if (!auxiliaryP2RowVector) { @@ -258,7 +256,7 @@ namespace storm { std::vector& reducedMultiplyResult = *auxiliaryP2RowGroupVector; for (uint_fast64_t iteration = 0; iteration < n; ++iteration) { - multiplyAndReduce(player1Dir, player2Dir, x, b, *linEqSolverPlayer2Matrix, multiplyResult, reducedMultiplyResult, x); + multiplyAndReduce(env, player1Dir, player2Dir, x, b, *multiplierPlayer2Matrix, multiplyResult, reducedMultiplyResult, x); } if(!this->isCachingEnabled()) { @@ -267,9 +265,9 @@ namespace storm { } template - void StandardGameSolver::multiplyAndReduce(OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const* b, storm::solver::LinearEquationSolver const& linEqSolver, std::vector& multiplyResult, std::vector& p2ReducedMultiplyResult, std::vector& p1ReducedMultiplyResult) const { + void StandardGameSolver::multiplyAndReduce(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const* b, storm::solver::Multiplier const& multiplier, std::vector& multiplyResult, std::vector& p2ReducedMultiplyResult, std::vector& p1ReducedMultiplyResult) const { - linEqSolver.multiply(x, b, multiplyResult); + multiplier.multiply(env, x, b, multiplyResult); storm::utility::vector::reduceVectorMinOrMax(player2Dir, multiplyResult, p2ReducedMultiplyResult, player2Matrix.getRowGroupIndices()); @@ -404,7 +402,7 @@ namespace storm { template void StandardGameSolver::clearCache() const { - linEqSolverPlayer2Matrix.reset(); + multiplierPlayer2Matrix.reset(); auxiliaryP2RowVector.reset(); auxiliaryP2RowGroupVector.reset(); auxiliaryP1RowGroupVector.reset(); diff --git a/src/storm/solver/StandardGameSolver.h b/src/storm/solver/StandardGameSolver.h index 336be3911..e9e8060a7 100644 --- a/src/storm/solver/StandardGameSolver.h +++ b/src/storm/solver/StandardGameSolver.h @@ -1,6 +1,7 @@ #pragma once #include "storm/solver/LinearEquationSolver.h" +#include "storm/solver/Multiplier.h" #include "storm/solver/GameSolver.h" #include "SolverSelectionOptions.h" @@ -26,8 +27,7 @@ namespace storm { bool solveGameValueIteration(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const& b) const; // Computes p2Matrix * x + b, reduces the result w.r.t. player 2 choices, and then reduces the result w.r.t. player 1 choices. - void multiplyAndReduce(OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const* b, - storm::solver::LinearEquationSolver const& linEqSolver, std::vector& multiplyResult, std::vector& p2ReducedMultiplyResult, std::vector& p1ReducedMultiplyResult) const; + void multiplyAndReduce(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector& x, std::vector const* b, storm::solver::Multiplier const& multiplier, std::vector& multiplyResult, std::vector& p2ReducedMultiplyResult, std::vector& p1ReducedMultiplyResult) const; // Solves the equation system given by the two choice selections void getInducedMatrixVector(std::vector& x, std::vector const& b, std::vector const& player1Choices, std::vector const& player2Choices, storm::storage::SparseMatrix& inducedMatrix, std::vector& inducedVector) const; @@ -43,7 +43,7 @@ namespace storm { }; // possibly cached data - mutable std::unique_ptr> linEqSolverPlayer2Matrix; + mutable std::unique_ptr> multiplierPlayer2Matrix; mutable std::unique_ptr> auxiliaryP2RowVector; // player2Matrix.rowCount() entries mutable std::unique_ptr> auxiliaryP2RowGroupVector; // player2Matrix.rowGroupCount() entries mutable std::unique_ptr> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries