From 6ddddd8cfa9591eb1090bf129ad278a67affba0c Mon Sep 17 00:00:00 2001 From: TimQu Date: Wed, 28 Oct 2015 19:04:50 +0100 Subject: [PATCH] Implemented policy extraction for value iteration Former-commit-id: 604b4667b80ff411bec899feacf98ec7cf3f031c --- src/modelchecker/region/ApproximationModel.cpp | 3 +++ src/solver/GmmxxMinMaxLinearEquationSolver.cpp | 5 +++++ src/solver/MinMaxLinearEquationSolver.cpp | 2 +- src/solver/MinMaxLinearEquationSolver.h | 12 ++++++++++++ src/solver/NativeMinMaxLinearEquationSolver.cpp | 5 +++++ src/utility/solver.cpp | 1 + 6 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/modelchecker/region/ApproximationModel.cpp b/src/modelchecker/region/ApproximationModel.cpp index c92b4d0c1..c93c0fc25 100644 --- a/src/modelchecker/region/ApproximationModel.cpp +++ b/src/modelchecker/region/ApproximationModel.cpp @@ -350,7 +350,10 @@ namespace storm { void ApproximationModel, double>::invokeSolver(bool computeLowerBounds){ storm::solver::SolveGoal goal(computeLowerBounds); std::unique_ptr> solver = storm::solver::configureMinMaxLinearEquationSolver(goal, storm::utility::solver::MinMaxLinearEquationSolverFactory(), this->matrixData.matrix); + solver->setPolicyTracking(); solver->solveEquationSystem(this->eqSysResult, this->vectorData.vector); + std::vector policy(solver->getPolicy()); + std::cout << "Policy: " << policy.size() << " entries. [0]=" << policy[0] << " [20]=" << policy[20] << std::endl; } template<> diff --git a/src/solver/GmmxxMinMaxLinearEquationSolver.cpp b/src/solver/GmmxxMinMaxLinearEquationSolver.cpp index 86b8dfa23..6c82ab1cf 100644 --- a/src/solver/GmmxxMinMaxLinearEquationSolver.cpp +++ b/src/solver/GmmxxMinMaxLinearEquationSolver.cpp @@ -89,6 +89,11 @@ namespace storm { if (!multiplyResultMemoryProvided) { delete multiplyResult; } + + if(this->trackPolicy){ + this->policy = this->computePolicy(x,b); + } + } else { // We will use Policy Iteration to solve the given system. // We first guess an initial choice resolution which will be refined after each iteration. diff --git a/src/solver/MinMaxLinearEquationSolver.cpp b/src/solver/MinMaxLinearEquationSolver.cpp index f0c2e4869..7df50cfff 100644 --- a/src/solver/MinMaxLinearEquationSolver.cpp +++ b/src/solver/MinMaxLinearEquationSolver.cpp @@ -24,7 +24,7 @@ namespace storm { } std::vector AbstractMinMaxLinearEquationSolver::getPolicy() const { - STORM_LOG_THROW(!useValueIteration, storm::exceptions::NotImplementedException, "Getting policies after value iteration is not yet supported!"); + assert(!policy.empty()); return policy; } } diff --git a/src/solver/MinMaxLinearEquationSolver.h b/src/solver/MinMaxLinearEquationSolver.h index 79d14907b..f30b6605b 100644 --- a/src/solver/MinMaxLinearEquationSolver.h +++ b/src/solver/MinMaxLinearEquationSolver.h @@ -8,6 +8,7 @@ #include "src/storage/sparse/StateType.h" #include "AllowEarlyTerminationCondition.h" #include "OptimizationDirection.h" +#include "src/utility/vector.h" namespace storm { namespace storage { @@ -140,6 +141,17 @@ namespace storm { protected: + + std::vector computePolicy(std::vector& x, std::vector const& b) const{ + std::vector xPrime(this->A.getRowCount()); + this->A.multiplyVectorWithMatrix(x, xPrime); + storm::utility::vector::addVectors(xPrime, b, xPrime); + std::vector policy(x.size()); + std::vector reduced(x.size()); + storm::utility::vector::reduceVectorMinOrMax(convert(this->direction), xPrime, reduced, this->A.getRowGroupIndices(), &(policy)); + return policy; + } + storm::storage::SparseMatrix const& A; std::unique_ptr> earlyTermination; diff --git a/src/solver/NativeMinMaxLinearEquationSolver.cpp b/src/solver/NativeMinMaxLinearEquationSolver.cpp index 1a0881402..4835de497 100644 --- a/src/solver/NativeMinMaxLinearEquationSolver.cpp +++ b/src/solver/NativeMinMaxLinearEquationSolver.cpp @@ -90,6 +90,11 @@ namespace storm { if (!multiplyResultMemoryProvided) { delete multiplyResult; } + + if(this->trackPolicy){ + this->policy = this->computePolicy(x,b); + } + } else { // We will use Policy Iteration to solve the given system. // We first guess an initial choice resolution which will be refined after each iteration. diff --git a/src/utility/solver.cpp b/src/utility/solver.cpp index 15d8a71ac..4df7215d4 100644 --- a/src/utility/solver.cpp +++ b/src/utility/solver.cpp @@ -123,6 +123,7 @@ namespace storm { case storm::solver::EquationSolverType::Topological: { STORM_LOG_THROW(prefTech != storm::solver::MinMaxTechniqueSelection::PolicyIteration, storm::exceptions::NotImplementedException, "Policy iteration for topological solver is not supported."); + STORM_LOG_THROW(!trackPolicy , storm::exceptions::NotImplementedException, "Policy extraction for Topological solver not supported."); p1.reset(new storm::solver::TopologicalMinMaxLinearEquationSolver(matrix)); break; }