From c83721066cd8553d61d4fd3060f27d16884d537d Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Mon, 9 Mar 2020 14:00:09 +0100 Subject: [PATCH] Use acyclic solver in reward bounded properties. --- .../prctl/helper/rewardbounded/EpochModel.cpp | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp b/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp index f48253fca..63c1e2c76 100644 --- a/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp +++ b/src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp @@ -1,7 +1,12 @@ #include "storm/modelchecker/prctl/helper/rewardbounded/EpochModel.h" #include "storm/modelchecker/prctl/helper/rewardbounded/MultiDimensionalRewardUnfolding.h" +#include "storm/utility/graph.h" +#include "storm/environment/solver/MinMaxSolverEnvironment.h" +#include "storm/environment/solver/SolverEnvironment.h" + #include "storm/exceptions/UncheckedRequirementException.h" +#include "storm/exceptions/UnexpectedException.h" namespace storm { namespace modelchecker { @@ -46,9 +51,20 @@ namespace storm { if (epochModel.epochMatrixChanged) { x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero()); storm::solver::GeneralLinearEquationSolverFactory linearEquationSolverFactory; - linEqSolver = linearEquationSolverFactory.create(env, epochModel.epochMatrix); + // We only check for acyclic models if the equation problem has the fixedPointSystem format. + // We could also do this for other formats, however, this requires either matrix conversions or a different 'hasCycle' implementation. + // Also, we would have to match the equationProblemFormat of the acyclic solver. + bool epochMatrixAcyclic = epochModel.equationSolverProblemFormat.get() == storm::solver::LinearEquationSolverProblemFormat::FixedPointSystem && !storm::utility::graph::hasCycle(epochModel.epochMatrix); + Environment acyclicEnv; + if (epochMatrixAcyclic) { + acyclicEnv = env; + acyclicEnv.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Acyclic); + linEqSolver = linearEquationSolverFactory.create(acyclicEnv, epochModel.epochMatrix); + } else { + linEqSolver = linearEquationSolverFactory.create(env, epochModel.epochMatrix); + } linEqSolver->setCachingEnabled(true); - auto req = linEqSolver->getRequirements(env); + auto req = linEqSolver->getRequirements(epochMatrixAcyclic ? acyclicEnv : env); if (lowerBound) { linEqSolver->setLowerBound(lowerBound.get()); req.clearLowerBounds(); @@ -57,7 +73,11 @@ namespace storm { linEqSolver->setUpperBound(upperBound.get()); req.clearUpperBounds(); } + if (epochMatrixAcyclic) { + req.clearAcyclic(); + } STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked."); + STORM_LOG_THROW(linEqSolver->getEquationProblemFormat(epochMatrixAcyclic ? acyclicEnv : env) == epochModel.equationSolverProblemFormat.get(), storm::exceptions::UnexpectedException, "The constructed solver uses a different equation problem format then the one that has been specified initially."); } // Prepare the right hand side of the equation system @@ -79,8 +99,6 @@ namespace storm { return storm::utility::vector::filterVector(x, epochModel.epochInStates); } - - template std::vector analyzeTrivialMdpEpochModel(OptimizationDirection dir, EpochModel& epochModel) { // Assert that the epoch model is indeed trivial @@ -138,13 +156,21 @@ namespace storm { if (epochModel.epochMatrixChanged) { x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero()); storm::solver::GeneralMinMaxLinearEquationSolverFactory minMaxLinearEquationSolverFactory; - minMaxSolver = minMaxLinearEquationSolverFactory.create(env, epochModel.epochMatrix); + bool epochMatrixAcyclic = !storm::utility::graph::hasCycle(epochModel.epochMatrix); + Environment acyclicEnv; + if (epochMatrixAcyclic) { + acyclicEnv = env; + acyclicEnv.solver().minMax().setMethod(storm::solver::MinMaxMethod::Acyclic); + minMaxSolver = minMaxLinearEquationSolverFactory.create(acyclicEnv, epochModel.epochMatrix); + } else { + minMaxSolver = minMaxLinearEquationSolverFactory.create(env, epochModel.epochMatrix); + } minMaxSolver->setHasUniqueSolution(); minMaxSolver->setHasNoEndComponents(); minMaxSolver->setOptimizationDirection(dir); minMaxSolver->setCachingEnabled(true); - minMaxSolver->setTrackScheduler(true); - auto req = minMaxSolver->getRequirements(env, dir, false); + minMaxSolver->setTrackScheduler(!epochMatrixAcyclic); // only track the scheduler if there are cycles + auto req = minMaxSolver->getRequirements(epochMatrixAcyclic ? acyclicEnv : env, dir, false); if (lowerBound) { minMaxSolver->setLowerBound(lowerBound.get()); req.clearLowerBounds(); @@ -153,11 +179,16 @@ namespace storm { minMaxSolver->setUpperBound(upperBound.get()); req.clearUpperBounds(); } + if (epochMatrixAcyclic) { + req.clearAcyclic(); + } STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked."); minMaxSolver->setRequirementsChecked(); } else { - auto choicesTmp = minMaxSolver->getSchedulerChoices(); - minMaxSolver->setInitialScheduler(std::move(choicesTmp)); + if (minMaxSolver && minMaxSolver->isTrackSchedulerSet()) { + auto choicesTmp = minMaxSolver->getSchedulerChoices(); + minMaxSolver->setInitialScheduler(std::move(choicesTmp)); + } } // Prepare the right hand side of the equation system