Browse Source

Use acyclic solver in reward bounded properties.

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
c83721066c
  1. 49
      src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp

49
src/storm/modelchecker/prctl/helper/rewardbounded/EpochModel.cpp

@ -1,7 +1,12 @@
#include "storm/modelchecker/prctl/helper/rewardbounded/EpochModel.h" #include "storm/modelchecker/prctl/helper/rewardbounded/EpochModel.h"
#include "storm/modelchecker/prctl/helper/rewardbounded/MultiDimensionalRewardUnfolding.h" #include "storm/modelchecker/prctl/helper/rewardbounded/MultiDimensionalRewardUnfolding.h"
#include "storm/utility/graph.h"
#include "storm/environment/solver/MinMaxSolverEnvironment.h"
#include "storm/environment/solver/SolverEnvironment.h"
#include "storm/exceptions/UncheckedRequirementException.h" #include "storm/exceptions/UncheckedRequirementException.h"
#include "storm/exceptions/UnexpectedException.h"
namespace storm { namespace storm {
namespace modelchecker { namespace modelchecker {
@ -46,9 +51,20 @@ namespace storm {
if (epochModel.epochMatrixChanged) { if (epochModel.epochMatrixChanged) {
x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero<ValueType>()); x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
storm::solver::GeneralLinearEquationSolverFactory<ValueType> linearEquationSolverFactory; storm::solver::GeneralLinearEquationSolverFactory<ValueType> linearEquationSolverFactory;
linEqSolver = linearEquationSolverFactory.create(env, epochModel.epochMatrix);
// We only check for acyclic models if the equation problem has the fixedPointSystem format.
// We could also do this for other formats, however, this requires either matrix conversions or a different 'hasCycle' implementation.
// Also, we would have to match the equationProblemFormat of the acyclic solver.
bool epochMatrixAcyclic = epochModel.equationSolverProblemFormat.get() == storm::solver::LinearEquationSolverProblemFormat::FixedPointSystem && !storm::utility::graph::hasCycle(epochModel.epochMatrix);
Environment acyclicEnv;
if (epochMatrixAcyclic) {
acyclicEnv = env;
acyclicEnv.solver().setLinearEquationSolverType(storm::solver::EquationSolverType::Acyclic);
linEqSolver = linearEquationSolverFactory.create(acyclicEnv, epochModel.epochMatrix);
} else {
linEqSolver = linearEquationSolverFactory.create(env, epochModel.epochMatrix);
}
linEqSolver->setCachingEnabled(true); linEqSolver->setCachingEnabled(true);
auto req = linEqSolver->getRequirements(env);
auto req = linEqSolver->getRequirements(epochMatrixAcyclic ? acyclicEnv : env);
if (lowerBound) { if (lowerBound) {
linEqSolver->setLowerBound(lowerBound.get()); linEqSolver->setLowerBound(lowerBound.get());
req.clearLowerBounds(); req.clearLowerBounds();
@ -57,7 +73,11 @@ namespace storm {
linEqSolver->setUpperBound(upperBound.get()); linEqSolver->setUpperBound(upperBound.get());
req.clearUpperBounds(); req.clearUpperBounds();
} }
if (epochMatrixAcyclic) {
req.clearAcyclic();
}
STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked."); STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked.");
STORM_LOG_THROW(linEqSolver->getEquationProblemFormat(epochMatrixAcyclic ? acyclicEnv : env) == epochModel.equationSolverProblemFormat.get(), storm::exceptions::UnexpectedException, "The constructed solver uses a different equation problem format then the one that has been specified initially.");
} }
// Prepare the right hand side of the equation system // Prepare the right hand side of the equation system
@ -79,8 +99,6 @@ namespace storm {
return storm::utility::vector::filterVector(x, epochModel.epochInStates); return storm::utility::vector::filterVector(x, epochModel.epochInStates);
} }
template<typename ValueType> template<typename ValueType>
std::vector<ValueType> analyzeTrivialMdpEpochModel(OptimizationDirection dir, EpochModel<ValueType, true>& epochModel) { std::vector<ValueType> analyzeTrivialMdpEpochModel(OptimizationDirection dir, EpochModel<ValueType, true>& epochModel) {
// Assert that the epoch model is indeed trivial // Assert that the epoch model is indeed trivial
@ -138,13 +156,21 @@ namespace storm {
if (epochModel.epochMatrixChanged) { if (epochModel.epochMatrixChanged) {
x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero<ValueType>()); x.assign(epochModel.epochMatrix.getRowGroupCount(), storm::utility::zero<ValueType>());
storm::solver::GeneralMinMaxLinearEquationSolverFactory<ValueType> minMaxLinearEquationSolverFactory; storm::solver::GeneralMinMaxLinearEquationSolverFactory<ValueType> minMaxLinearEquationSolverFactory;
minMaxSolver = minMaxLinearEquationSolverFactory.create(env, epochModel.epochMatrix);
bool epochMatrixAcyclic = !storm::utility::graph::hasCycle(epochModel.epochMatrix);
Environment acyclicEnv;
if (epochMatrixAcyclic) {
acyclicEnv = env;
acyclicEnv.solver().minMax().setMethod(storm::solver::MinMaxMethod::Acyclic);
minMaxSolver = minMaxLinearEquationSolverFactory.create(acyclicEnv, epochModel.epochMatrix);
} else {
minMaxSolver = minMaxLinearEquationSolverFactory.create(env, epochModel.epochMatrix);
}
minMaxSolver->setHasUniqueSolution(); minMaxSolver->setHasUniqueSolution();
minMaxSolver->setHasNoEndComponents(); minMaxSolver->setHasNoEndComponents();
minMaxSolver->setOptimizationDirection(dir); minMaxSolver->setOptimizationDirection(dir);
minMaxSolver->setCachingEnabled(true); minMaxSolver->setCachingEnabled(true);
minMaxSolver->setTrackScheduler(true);
auto req = minMaxSolver->getRequirements(env, dir, false);
minMaxSolver->setTrackScheduler(!epochMatrixAcyclic); // only track the scheduler if there are cycles
auto req = minMaxSolver->getRequirements(epochMatrixAcyclic ? acyclicEnv : env, dir, false);
if (lowerBound) { if (lowerBound) {
minMaxSolver->setLowerBound(lowerBound.get()); minMaxSolver->setLowerBound(lowerBound.get());
req.clearLowerBounds(); req.clearLowerBounds();
@ -153,11 +179,16 @@ namespace storm {
minMaxSolver->setUpperBound(upperBound.get()); minMaxSolver->setUpperBound(upperBound.get());
req.clearUpperBounds(); req.clearUpperBounds();
} }
if (epochMatrixAcyclic) {
req.clearAcyclic();
}
STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked."); STORM_LOG_THROW(!req.hasEnabledCriticalRequirement(), storm::exceptions::UncheckedRequirementException, "Solver requirements " + req.getEnabledRequirementsAsString() + " not checked.");
minMaxSolver->setRequirementsChecked(); minMaxSolver->setRequirementsChecked();
} else { } else {
auto choicesTmp = minMaxSolver->getSchedulerChoices();
minMaxSolver->setInitialScheduler(std::move(choicesTmp));
if (minMaxSolver && minMaxSolver->isTrackSchedulerSet()) {
auto choicesTmp = minMaxSolver->getSchedulerChoices();
minMaxSolver->setInitialScheduler(std::move(choicesTmp));
}
} }
// Prepare the right hand side of the equation system // Prepare the right hand side of the equation system

Loading…
Cancel
Save