Browse Source

Using multiplier in game solver

tempestpy_adaptions
TimQu 7 years ago
parent
commit
66c5255d8c
  1. 20
      src/storm/solver/StandardGameSolver.cpp
  2. 6
      src/storm/solver/StandardGameSolver.h

20
src/storm/solver/StandardGameSolver.cpp

@ -159,9 +159,8 @@ namespace storm {
template<typename ValueType>
bool StandardGameSolver<ValueType>::solveGameValueIteration(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const {
if(!linEqSolverPlayer2Matrix) {
linEqSolverPlayer2Matrix = linearEquationSolverFactory->create(env, player2Matrix, storm::solver::LinearEquationSolverTask::Multiply);
linEqSolverPlayer2Matrix->setCachingEnabled(true);
if (!multiplierPlayer2Matrix) {
multiplierPlayer2Matrix = storm::solver::MultiplierFactory<ValueType>().create(env, player2Matrix);
}
if (!auxiliaryP2RowVector) {
@ -204,7 +203,7 @@ namespace storm {
Status status = Status::InProgress;
while (status == Status::InProgress) {
multiplyAndReduce(player1Dir, player2Dir, *currentX, &b, *linEqSolverPlayer2Matrix, multiplyResult, reducedMultiplyResult, *newX);
multiplyAndReduce(env, player1Dir, player2Dir, *currentX, &b, *multiplierPlayer2Matrix, multiplyResult, reducedMultiplyResult, *newX);
// Determine whether the method converged.
if (storm::utility::vector::equalModuloPrecision<ValueType>(*currentX, *newX, precision, relative)) {
@ -242,9 +241,8 @@ namespace storm {
template<typename ValueType>
void StandardGameSolver<ValueType>::repeatedMultiply(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const* b, uint_fast64_t n) const {
if(!linEqSolverPlayer2Matrix) {
linEqSolverPlayer2Matrix = linearEquationSolverFactory->create(env, player2Matrix, storm::solver::LinearEquationSolverTask::Multiply);
linEqSolverPlayer2Matrix->setCachingEnabled(true);
if (!multiplierPlayer2Matrix) {
multiplierPlayer2Matrix = storm::solver::MultiplierFactory<ValueType>().create(env, player2Matrix);
}
if (!auxiliaryP2RowVector) {
@ -258,7 +256,7 @@ namespace storm {
std::vector<ValueType>& reducedMultiplyResult = *auxiliaryP2RowGroupVector;
for (uint_fast64_t iteration = 0; iteration < n; ++iteration) {
multiplyAndReduce(player1Dir, player2Dir, x, b, *linEqSolverPlayer2Matrix, multiplyResult, reducedMultiplyResult, x);
multiplyAndReduce(env, player1Dir, player2Dir, x, b, *multiplierPlayer2Matrix, multiplyResult, reducedMultiplyResult, x);
}
if(!this->isCachingEnabled()) {
@ -267,9 +265,9 @@ namespace storm {
}
template<typename ValueType>
void StandardGameSolver<ValueType>::multiplyAndReduce(OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const* b, storm::solver::LinearEquationSolver<ValueType> const& linEqSolver, std::vector<ValueType>& multiplyResult, std::vector<ValueType>& p2ReducedMultiplyResult, std::vector<ValueType>& p1ReducedMultiplyResult) const {
void StandardGameSolver<ValueType>::multiplyAndReduce(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const* b, storm::solver::Multiplier<ValueType> const& multiplier, std::vector<ValueType>& multiplyResult, std::vector<ValueType>& p2ReducedMultiplyResult, std::vector<ValueType>& p1ReducedMultiplyResult) const {
linEqSolver.multiply(x, b, multiplyResult);
multiplier.multiply(env, x, b, multiplyResult);
storm::utility::vector::reduceVectorMinOrMax(player2Dir, multiplyResult, p2ReducedMultiplyResult, player2Matrix.getRowGroupIndices());
@ -404,7 +402,7 @@ namespace storm {
template<typename ValueType>
void StandardGameSolver<ValueType>::clearCache() const {
linEqSolverPlayer2Matrix.reset();
multiplierPlayer2Matrix.reset();
auxiliaryP2RowVector.reset();
auxiliaryP2RowGroupVector.reset();
auxiliaryP1RowGroupVector.reset();

6
src/storm/solver/StandardGameSolver.h

@ -1,6 +1,7 @@
#pragma once
#include "storm/solver/LinearEquationSolver.h"
#include "storm/solver/Multiplier.h"
#include "storm/solver/GameSolver.h"
#include "SolverSelectionOptions.h"
@ -26,8 +27,7 @@ namespace storm {
bool solveGameValueIteration(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const& b) const;
// Computes p2Matrix * x + b, reduces the result w.r.t. player 2 choices, and then reduces the result w.r.t. player 1 choices.
void multiplyAndReduce(OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const* b,
storm::solver::LinearEquationSolver<ValueType> const& linEqSolver, std::vector<ValueType>& multiplyResult, std::vector<ValueType>& p2ReducedMultiplyResult, std::vector<ValueType>& p1ReducedMultiplyResult) const;
void multiplyAndReduce(Environment const& env, OptimizationDirection player1Dir, OptimizationDirection player2Dir, std::vector<ValueType>& x, std::vector<ValueType> const* b, storm::solver::Multiplier<ValueType> const& multiplier, std::vector<ValueType>& multiplyResult, std::vector<ValueType>& p2ReducedMultiplyResult, std::vector<ValueType>& p1ReducedMultiplyResult) const;
// Solves the equation system given by the two choice selections
void getInducedMatrixVector(std::vector<ValueType>& x, std::vector<ValueType> const& b, std::vector<uint_fast64_t> const& player1Choices, std::vector<uint_fast64_t> const& player2Choices, storm::storage::SparseMatrix<ValueType>& inducedMatrix, std::vector<ValueType>& inducedVector) const;
@ -43,7 +43,7 @@ namespace storm {
};
// possibly cached data
mutable std::unique_ptr<storm::solver::LinearEquationSolver<ValueType>> linEqSolverPlayer2Matrix;
mutable std::unique_ptr<storm::solver::Multiplier<ValueType>> multiplierPlayer2Matrix;
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP2RowVector; // player2Matrix.rowCount() entries
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP2RowGroupVector; // player2Matrix.rowGroupCount() entries
mutable std::unique_ptr<std::vector<ValueType>> auxiliaryP1RowGroupVector; // player1Matrix.rowGroupCount() entries

Loading…
Cancel
Save