@ -1,7 +1,12 @@
# include "storm/modelchecker/prctl/helper/rewardbounded/EpochModel.h"
# include "storm/modelchecker/prctl/helper/rewardbounded/MultiDimensionalRewardUnfolding.h"
# include "storm/utility/graph.h"
# include "storm/environment/solver/MinMaxSolverEnvironment.h"
# include "storm/environment/solver/SolverEnvironment.h"
# include "storm/exceptions/UncheckedRequirementException.h"
# include "storm/exceptions/UnexpectedException.h"
namespace storm {
namespace modelchecker {
@ -46,9 +51,20 @@ namespace storm {
if ( epochModel . epochMatrixChanged ) {
x . assign ( epochModel . epochMatrix . getRowGroupCount ( ) , storm : : utility : : zero < ValueType > ( ) ) ;
storm : : solver : : GeneralLinearEquationSolverFactory < ValueType > linearEquationSolverFactory ;
linEqSolver = linearEquationSolverFactory . create ( env , epochModel . epochMatrix ) ;
// We only check for acyclic models if the equation problem has the fixedPointSystem format.
// We could also do this for other formats, however, this requires either matrix conversions or a different 'hasCycle' implementation.
// Also, we would have to match the equationProblemFormat of the acyclic solver.
bool epochMatrixAcyclic = epochModel . equationSolverProblemFormat . get ( ) = = storm : : solver : : LinearEquationSolverProblemFormat : : FixedPointSystem & & ! storm : : utility : : graph : : hasCycle ( epochModel . epochMatrix ) ;
Environment acyclicEnv ;
if ( epochMatrixAcyclic ) {
acyclicEnv = env ;
acyclicEnv . solver ( ) . setLinearEquationSolverType ( storm : : solver : : EquationSolverType : : Acyclic ) ;
linEqSolver = linearEquationSolverFactory . create ( acyclicEnv , epochModel . epochMatrix ) ;
} else {
linEqSolver = linearEquationSolverFactory . create ( env , epochModel . epochMatrix ) ;
}
linEqSolver - > setCachingEnabled ( true ) ;
auto req = linEqSolver - > getRequirements ( env ) ;
auto req = linEqSolver - > getRequirements ( epochMatrixAcyclic ? acyclicEnv : e nv ) ;
if ( lowerBound ) {
linEqSolver - > setLowerBound ( lowerBound . get ( ) ) ;
req . clearLowerBounds ( ) ;
@ -57,7 +73,11 @@ namespace storm {
linEqSolver - > setUpperBound ( upperBound . get ( ) ) ;
req . clearUpperBounds ( ) ;
}
if ( epochMatrixAcyclic ) {
req . clearAcyclic ( ) ;
}
STORM_LOG_THROW ( ! req . hasEnabledCriticalRequirement ( ) , storm : : exceptions : : UncheckedRequirementException , " Solver requirements " + req . getEnabledRequirementsAsString ( ) + " not checked. " ) ;
STORM_LOG_THROW ( linEqSolver - > getEquationProblemFormat ( epochMatrixAcyclic ? acyclicEnv : env ) = = epochModel . equationSolverProblemFormat . get ( ) , storm : : exceptions : : UnexpectedException , " The constructed solver uses a different equation problem format then the one that has been specified initially. " ) ;
}
// Prepare the right hand side of the equation system
@ -79,8 +99,6 @@ namespace storm {
return storm : : utility : : vector : : filterVector ( x , epochModel . epochInStates ) ;
}
template < typename ValueType >
std : : vector < ValueType > analyzeTrivialMdpEpochModel ( OptimizationDirection dir , EpochModel < ValueType , true > & epochModel ) {
// Assert that the epoch model is indeed trivial
@ -138,13 +156,21 @@ namespace storm {
if ( epochModel . epochMatrixChanged ) {
x . assign ( epochModel . epochMatrix . getRowGroupCount ( ) , storm : : utility : : zero < ValueType > ( ) ) ;
storm : : solver : : GeneralMinMaxLinearEquationSolverFactory < ValueType > minMaxLinearEquationSolverFactory ;
minMaxSolver = minMaxLinearEquationSolverFactory . create ( env , epochModel . epochMatrix ) ;
bool epochMatrixAcyclic = ! storm : : utility : : graph : : hasCycle ( epochModel . epochMatrix ) ;
Environment acyclicEnv ;
if ( epochMatrixAcyclic ) {
acyclicEnv = env ;
acyclicEnv . solver ( ) . minMax ( ) . setMethod ( storm : : solver : : MinMaxMethod : : Acyclic ) ;
minMaxSolver = minMaxLinearEquationSolverFactory . create ( acyclicEnv , epochModel . epochMatrix ) ;
} else {
minMaxSolver = minMaxLinearEquationSolverFactory . create ( env , epochModel . epochMatrix ) ;
}
minMaxSolver - > setHasUniqueSolution ( ) ;
minMaxSolver - > setHasNoEndComponents ( ) ;
minMaxSolver - > setOptimizationDirection ( dir ) ;
minMaxSolver - > setCachingEnabled ( true ) ;
minMaxSolver - > setTrackScheduler ( true ) ;
auto req = minMaxSolver - > getRequirements ( env , dir , false ) ;
minMaxSolver - > setTrackScheduler ( ! epochMatrixAcyclic ) ; // only track the scheduler if there are cycles
auto req = minMaxSolver - > getRequirements ( epochMatrixAcyclic ? acyclicEnv : e nv , dir , false ) ;
if ( lowerBound ) {
minMaxSolver - > setLowerBound ( lowerBound . get ( ) ) ;
req . clearLowerBounds ( ) ;
@ -153,11 +179,16 @@ namespace storm {
minMaxSolver - > setUpperBound ( upperBound . get ( ) ) ;
req . clearUpperBounds ( ) ;
}
if ( epochMatrixAcyclic ) {
req . clearAcyclic ( ) ;
}
STORM_LOG_THROW ( ! req . hasEnabledCriticalRequirement ( ) , storm : : exceptions : : UncheckedRequirementException , " Solver requirements " + req . getEnabledRequirementsAsString ( ) + " not checked. " ) ;
minMaxSolver - > setRequirementsChecked ( ) ;
} else {
auto choicesTmp = minMaxSolver - > getSchedulerChoices ( ) ;
minMaxSolver - > setInitialScheduler ( std : : move ( choicesTmp ) ) ;
if ( minMaxSolver & & minMaxSolver - > isTrackSchedulerSet ( ) ) {
auto choicesTmp = minMaxSolver - > getSchedulerChoices ( ) ;
minMaxSolver - > setInitialScheduler ( std : : move ( choicesTmp ) ) ;
}
}
// Prepare the right hand side of the equation system