diff --git a/src/modelchecker/region/ApproximationModel.cpp b/src/modelchecker/region/ApproximationModel.cpp index 479b9d7f9..4acd5377d 100644 --- a/src/modelchecker/region/ApproximationModel.cpp +++ b/src/modelchecker/region/ApproximationModel.cpp @@ -70,6 +70,9 @@ namespace storm { this->solverData.result = std::vector(maybeStates.getNumberOfSetBits(), this->computeRewards ? storm::utility::one() : ConstantType(0.5)); this->solverData.initialStateIndex = newIndices[initialState]; + this->solverData.lastMinimizingPolicy = Policy(this->matrixData.matrix.getRowGroupCount(), 0); + this->solverData.lastMaximizingPolicy = Policy(this->matrixData.matrix.getRowGroupCount(), 0); + this->solverData.lastPlayer1Policy = Policy(this->matrixData.matrix.getRowGroupCount(), 0); } template @@ -272,10 +275,6 @@ namespace storm { instantiate(region, computeLowerBounds); Policy& policy = computeLowerBounds ? this->solverData.lastMinimizingPolicy : this->solverData.lastMaximizingPolicy; //TODO: at this point, set policy to the one stored in the region. - if(policy.empty()){ - //No guess available (yet) - policy = Policy(this->matrixData.matrix.getRowGroupCount(), 0); - } invokeSolver(computeLowerBounds, policy); //TODO: policy for games. diff --git a/src/modelchecker/region/SamplingModel.cpp b/src/modelchecker/region/SamplingModel.cpp index 2df9c63fd..f1fa8ef52 100644 --- a/src/modelchecker/region/SamplingModel.cpp +++ b/src/modelchecker/region/SamplingModel.cpp @@ -26,7 +26,7 @@ namespace storm { namespace region { template - SamplingModel::SamplingModel(ParametricSparseModelType const& parametricModel, std::shared_ptr formula) : solveGoal(storm::logic::isLowerBound(formula->getComparisonType())){ + SamplingModel::SamplingModel(ParametricSparseModelType const& parametricModel, std::shared_ptr formula){ //First some simple checks and initializations.. if(formula->isProbabilityOperatorFormula()){ this->computeRewards=false; @@ -38,6 +38,7 @@ namespace storm { } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Invalid formula: " << formula << ". Sampling model only supports eventually or reachability reward formulae."); } + this->solverData.solveGoal = storm::solver::SolveGoal(storm::logic::isLowerBound(formula->getComparisonType())); STORM_LOG_THROW(parametricModel.hasLabel("target"), storm::exceptions::InvalidArgumentException, "The given Model has no \"target\"-statelabel."); this->targetStates = parametricModel.getStateLabeling().getStates("target"); STORM_LOG_THROW(parametricModel.hasLabel("sink"), storm::exceptions::InvalidArgumentException, "The given Model has no \"sink\"-statelabel."); @@ -63,8 +64,9 @@ namespace storm { this->matrixData.assignment.shrink_to_fit(); this->vectorData.assignment.shrink_to_fit(); - this->eqSysResult = std::vector(maybeStates.getNumberOfSetBits(), this->computeRewards ? storm::utility::one() : ConstantType(0.5)); - this->eqSysInitIndex = newIndices[initialState]; + this->solverData.result = std::vector(maybeStates.getNumberOfSetBits(), this->computeRewards ? storm::utility::one() : ConstantType(0.5)); + this->solverData.initialStateIndex = newIndices[initialState]; + this->solverData.lastPolicy = Policy(this->matrixData.matrix.getRowGroupCount(), 0); } template @@ -209,7 +211,7 @@ namespace storm { instantiate(point); invokeSolver(); std::vector result(this->maybeStates.size()); - storm::utility::vector::setVectorValues(result, this->maybeStates, this->eqSysResult); + storm::utility::vector::setVectorValues(result, this->maybeStates, this->solverData.result); storm::utility::vector::setVectorValues(result, this->targetStates, this->computeRewards ? storm::utility::zero() : storm::utility::one()); storm::utility::vector::setVectorValues(result, ~(this->maybeStates | this->targetStates), this->computeRewards ? storm::utility::infinity() : storm::utility::zero()); @@ -220,7 +222,7 @@ namespace storm { ConstantType SamplingModel::computeInitialStateValue(std::mapconst& point) { instantiate(point); invokeSolver(); - return this->eqSysResult[this->eqSysInitIndex]; + return this->solverData.result[this->solverData.initialStateIndex]; } template @@ -247,16 +249,14 @@ namespace storm { template<> void SamplingModel, double>::invokeSolver(){ std::unique_ptr> solver = storm::utility::solver::LinearEquationSolverFactory().create(this->matrixData.matrix); - solver->solveEquationSystem(this->eqSysResult, this->vectorData.vector); + solver->solveEquationSystem(this->solverData.result, this->vectorData.vector); } template<> void SamplingModel, double>::invokeSolver(){ - std::unique_ptr> solver = storm::solver::configureMinMaxLinearEquationSolver(this->solveGoal, storm::utility::solver::MinMaxLinearEquationSolverFactory(), this->matrixData.matrix); - if(!this->solveGoal.minimize()){ - //The value iteration method is not correct if the value is maximized and the initial x-vector is not <= the actual probability/reward. - this->eqSysResult.assign(this->eqSysResult.size(), storm::utility::zero()); - } - solver->solveEquationSystem(this->eqSysResult, this->vectorData.vector); + std::unique_ptr> solver = storm::solver::configureMinMaxLinearEquationSolver(this->solverData.solveGoal, storm::utility::solver::MinMaxLinearEquationSolverFactory(), this->matrixData.matrix); + solver->setPolicyTracking(); + solver->solveEquationSystem(this->solverData.solveGoal.direction(), this->solverData.result, this->vectorData.vector, nullptr, nullptr, &this->solverData.lastPolicy); + this->solverData.lastPolicy = solver->getPolicy(); } diff --git a/src/modelchecker/region/SamplingModel.h b/src/modelchecker/region/SamplingModel.h index ecf2db01c..e13757ce7 100644 --- a/src/modelchecker/region/SamplingModel.h +++ b/src/modelchecker/region/SamplingModel.h @@ -51,24 +51,28 @@ namespace storm { ConstantType computeInitialStateValue(std::mapconst& point); private: - typedef typename std::unordered_map::value_type FunctionEntry; + typedef std::vector Policy; + void initializeProbabilities(ParametricSparseModelType const& parametricModel, std::vector const& newIndices); void initializeRewards(ParametricSparseModelType const& parametricModel, std::vector const& newIndices); void instantiate(std::mapconst& point); void invokeSolver(); - //Some designated states in the original model - storm::storage::BitVector targetStates, maybeStates; - //The last result of the solving the equation system. Also serves as first guess for the next call. - //Note: eqSysResult.size==maybeStates.numberOfSetBits - std::vector eqSysResult; - //The index which represents the result for the initial state in the eqSysResult vector - std::size_t eqSysInitIndex; //A flag that denotes whether we compute probabilities or rewards bool computeRewards; - //The goal we want to accomplish when solving the eq sys. - storm::solver::SolveGoal solveGoal; + + //Some designated states in the original model + storm::storage::BitVector targetStates, maybeStates; + + struct SolverData{ + //The result from the previous instantiation. Serve as first guess for the next call. + std::vector result; //Note: result.size==maybeStates.numberOfSetBits + std::size_t initialStateIndex; //The index which represents the result for the initial state in the result vector + //The following is only relevant if we consider mdps: + storm::solver::SolveGoal solveGoal = storm::solver::SolveGoal(true); //No default cunstructor for solve goal... + Policy lastPolicy; //best policy from the previous instantiation. Serves as first guess for the next call. + } solverData; /* The data required for the equation system, i.e., a matrix and a vector. diff --git a/test/functional/modelchecker/SparseDtmcRegionModelCheckerTest.cpp b/test/functional/modelchecker/SparseDtmcRegionModelCheckerTest.cpp index d07d5ab00..bbf247c6e 100644 --- a/test/functional/modelchecker/SparseDtmcRegionModelCheckerTest.cpp +++ b/test/functional/modelchecker/SparseDtmcRegionModelCheckerTest.cpp @@ -87,10 +87,10 @@ TEST(SparseDtmcRegionModelCheckerTest, Brp_Prob) { //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLVIOLATED), allVioRegionSmt.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseDtmcRegionModelCheckerTest, Brp_Rew) { - std::string const& programFile = STORM_CPP_BASE_PATH "/examples/pdtmc/brp_rewards/brp_16_2.pm"; std::string const& formulaAsString = "R>2.5 [F \"target\" ]"; std::string const& constantsAsString = "pL=0.9,TOAck=0.5"; @@ -178,8 +178,8 @@ TEST(SparseDtmcRegionModelCheckerTest, Brp_Rew) { modelchecker.checkRegion(exBothHardRegionSmt); //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::EXISTSBOTH), exBothHardRegionSmtApp.getCheckResult()); - storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseDtmcRegionModelCheckerTest, Brp_Rew_Infty) { @@ -226,6 +226,7 @@ TEST(SparseDtmcRegionModelCheckerTest, Brp_Rew_Infty) { //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLSAT), allSatRegionSmt.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseDtmcRegionModelCheckerTest, Brp_Rew_4Par) { @@ -289,6 +290,7 @@ TEST(SparseDtmcRegionModelCheckerTest, Brp_Rew_4Par) { //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLVIOLATED), allVioRegionSmt.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseDtmcRegionModelCheckerTest, Crowds_Prob) { @@ -377,6 +379,7 @@ TEST(SparseDtmcRegionModelCheckerTest, Crowds_Prob) { //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLVIOLATED), allVioHardRegionSmtApp.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseDtmcRegionModelCheckerTest, Crowds_Prob_1Par) { @@ -442,6 +445,7 @@ TEST(SparseDtmcRegionModelCheckerTest, Crowds_Prob_1Par) { //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLVIOLATED), allVioRegionSmt.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseDtmcRegionModelCheckerTest, Crowds_Prob_Const) { @@ -490,6 +494,7 @@ TEST(SparseDtmcRegionModelCheckerTest, Crowds_Prob_Const) { //smt EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLSAT), allSatRegionSmt.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } #endif \ No newline at end of file diff --git a/test/functional/modelchecker/SparseMdpRegionModelCheckerTest.cpp b/test/functional/modelchecker/SparseMdpRegionModelCheckerTest.cpp index ab0a4e03d..d4314d4de 100644 --- a/test/functional/modelchecker/SparseMdpRegionModelCheckerTest.cpp +++ b/test/functional/modelchecker/SparseMdpRegionModelCheckerTest.cpp @@ -66,6 +66,7 @@ TEST(SparseMdpRegionModelCheckerTest, two_dice_Prob) { EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLVIOLATED), allVioRegion.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } TEST(SparseMdpRegionModelCheckerTest, coin_Prob) { @@ -94,11 +95,11 @@ TEST(SparseMdpRegionModelCheckerTest, coin_Prob) { auto exBothRegion=storm::modelchecker::region::ParameterRegion::parseRegion("0.4<=p<=0.65,0.5<=q<=0.7"); auto allVioRegion=storm::modelchecker::region::ParameterRegion::parseRegion("0.4<=p<=0.7,0.55<=q<=0.6"); - EXPECT_NEAR(0.9512773402, modelchecker.getReachabilityValue(allSatRegion.getLowerBounds()), storm::settings::generalSettings().getPrecision()); + EXPECT_NEAR(0.95127874851, modelchecker.getReachabilityValue(allSatRegion.getLowerBounds()), storm::settings::generalSettings().getPrecision()); EXPECT_NEAR(0.26787251126, modelchecker.getReachabilityValue(allSatRegion.getUpperBounds()), storm::settings::generalSettings().getPrecision()); - EXPECT_NEAR(0.41879628383, modelchecker.getReachabilityValue(exBothRegion.getLowerBounds()), storm::settings::generalSettings().getPrecision()); + EXPECT_NEAR(0.41880006098, modelchecker.getReachabilityValue(exBothRegion.getLowerBounds()), storm::settings::generalSettings().getPrecision()); EXPECT_NEAR(0.01535089684, modelchecker.getReachabilityValue(exBothRegion.getUpperBounds()), storm::settings::generalSettings().getPrecision()); - EXPECT_NEAR(0.24952471590, modelchecker.getReachabilityValue(allVioRegion.getLowerBounds()), storm::settings::generalSettings().getPrecision()); + EXPECT_NEAR(0.24952791523, modelchecker.getReachabilityValue(allVioRegion.getLowerBounds()), storm::settings::generalSettings().getPrecision()); EXPECT_NEAR(0.01711494956, modelchecker.getReachabilityValue(allVioRegion.getUpperBounds()), storm::settings::generalSettings().getPrecision()); //test approximative method @@ -114,6 +115,7 @@ TEST(SparseMdpRegionModelCheckerTest, coin_Prob) { EXPECT_EQ((storm::modelchecker::region::RegionCheckResult::ALLVIOLATED), allVioRegion.getCheckResult()); storm::settings::mutableRegionSettings().resetModes(); + carl::VariablePool::getInstance().clear(); } #endif \ No newline at end of file