From a9f72198a075459db7dd5d95567be3ec314a223b Mon Sep 17 00:00:00 2001 From: TimQu Date: Thu, 18 Jan 2018 20:10:27 +0100 Subject: [PATCH] made filtering states with reward zero a setting --- .../prctl/helper/SparseDtmcPrctlHelper.cpp | 49 ++++++++++++------- .../prctl/helper/SparseDtmcPrctlHelper.h | 2 +- .../settings/modules/ModelCheckerSettings.cpp | 4 +- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp index fd4ca4ff1..f10e1155e 100644 --- a/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp @@ -382,29 +382,32 @@ namespace storm { template std::vector SparseDtmcPrctlHelper::computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, RewardModelType const& rewardModel, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, ModelCheckerHint const& hint) { - // Extend the set of target states such that states for which target is reached without collecting any reward are included - // TODO - storm::storage::BitVector extendedTargetStates = storm::utility::graph::performProb1(backwardTransitions, rewardModel.getStatesWithZeroReward(transitionMatrix), targetStates); - STORM_LOG_INFO("Extended the set of target states from " << targetStates.getNumberOfSetBits() << " states to " << extendedTargetStates.getNumberOfSetBits() << " states."); - std::cout << "TODO: make target state extension a setting." << std::endl; - return computeReachabilityRewards(env, std::move(goal), transitionMatrix, backwardTransitions, [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates) { return rewardModel.getTotalRewardVector(numberOfRows, transitionMatrix, maybeStates); }, extendedTargetStates, qualitative, linearEquationSolverFactory, hint); + + return computeReachabilityRewards(env, std::move(goal), transitionMatrix, backwardTransitions, + [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates) { + return rewardModel.getTotalRewardVector(numberOfRows, transitionMatrix, maybeStates); + }, + targetStates, qualitative, linearEquationSolverFactory, + [&] () { + return rewardModel.getStatesWithZeroReward(transitionMatrix); + }, + hint); } template std::vector SparseDtmcPrctlHelper::computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, std::vector const& totalStateRewardVector, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, ModelCheckerHint const& hint) { - // TODO - storm::storage::BitVector extendedTargetStates = storm::utility::graph::performProb1(backwardTransitions, storm::utility::vector::filterZero(totalStateRewardVector), targetStates); - STORM_LOG_INFO("Extended the set of target states from " << targetStates.getNumberOfSetBits() << " states to " << extendedTargetStates.getNumberOfSetBits() << " states."); - std::cout << "TODO: make target state extension a setting" << std::endl; - return computeReachabilityRewards(env, std::move(goal), transitionMatrix, backwardTransitions, [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const&, storm::storage::BitVector const& maybeStates) { std::vector result(numberOfRows); storm::utility::vector::selectVectorValues(result, maybeStates, totalStateRewardVector); return result; }, - targetStates, qualitative, linearEquationSolverFactory, hint); + targetStates, qualitative, linearEquationSolverFactory, + [&] () { + return storm::utility::vector::filterZero(totalStateRewardVector); + }, + hint); } // This function computes an upper bound on the reachability rewards (see Baier et al, CAV'17). @@ -421,24 +424,32 @@ namespace storm { } template - std::vector SparseDtmcPrctlHelper::computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, std::function(uint_fast64_t, storm::storage::SparseMatrix const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, ModelCheckerHint const& hint) { + std::vector SparseDtmcPrctlHelper::computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, std::function(uint_fast64_t, storm::storage::SparseMatrix const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, std::function const& zeroRewardStatesGetter, ModelCheckerHint const& hint) { std::vector result(transitionMatrix.getRowCount(), storm::utility::zero()); + // Determine which states have reward zero + storm::storage::BitVector rew0States; + if (storm::settings::getModule().isFilterRewZeroSet()) { + rew0States = storm::utility::graph::performProb1(backwardTransitions, zeroRewardStatesGetter(), targetStates); + } else { + rew0States = targetStates; + } + // Determine which states have a reward that is less than infinity. storm::storage::BitVector maybeStates; if (hint.isExplicitModelCheckerHint() && hint.template asExplicitModelCheckerHint().getComputeOnlyMaybeStates()) { maybeStates = hint.template asExplicitModelCheckerHint().getMaybeStates(); - storm::utility::vector::setVectorValues(result, ~(maybeStates | targetStates), storm::utility::infinity()); + storm::utility::vector::setVectorValues(result, ~(maybeStates | rew0States), storm::utility::infinity()); - STORM_LOG_INFO("Preprocessing: " << targetStates.getNumberOfSetBits() << " target states (" << maybeStates.getNumberOfSetBits() << " states remaining)."); + STORM_LOG_INFO("Preprocessing: " << rew0States.getNumberOfSetBits() << " States with reward zero (" << maybeStates.getNumberOfSetBits() << " states remaining)."); } else { storm::storage::BitVector trueStates(transitionMatrix.getRowCount(), true); - storm::storage::BitVector infinityStates = storm::utility::graph::performProb1(backwardTransitions, trueStates, targetStates); + storm::storage::BitVector infinityStates = storm::utility::graph::performProb1(backwardTransitions, trueStates, rew0States); infinityStates.complement(); - maybeStates = ~(targetStates | infinityStates); + maybeStates = ~(rew0States | infinityStates); - STORM_LOG_INFO("Preprocessing: " << infinityStates.getNumberOfSetBits() << " states with reward infinity, " << targetStates.getNumberOfSetBits() << " target states (" << maybeStates.getNumberOfSetBits() << " states remaining)."); + STORM_LOG_INFO("Preprocessing: " << infinityStates.getNumberOfSetBits() << " states with reward infinity, " << rew0States.getNumberOfSetBits() << " states with reward zero (" << maybeStates.getNumberOfSetBits() << " states remaining)."); storm::utility::vector::setVectorValues(result, infinityStates, storm::utility::infinity()); } @@ -473,7 +484,7 @@ namespace storm { boost::optional> upperRewardBounds; requirements.clearLowerBounds(); if (requirements.requiresUpperBounds()) { - upperRewardBounds = computeUpperRewardBounds(submatrix, b, transitionMatrix.getConstrainedRowSumVector(maybeStates, targetStates)); + upperRewardBounds = computeUpperRewardBounds(submatrix, b, transitionMatrix.getConstrainedRowSumVector(maybeStates, rew0States)); requirements.clearUpperBounds(); } STORM_LOG_THROW(requirements.empty(), storm::exceptions::UncheckedRequirementException, "There are unchecked requirements of the solver."); diff --git a/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.h b/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.h index ead15de74..e9b4f939f 100644 --- a/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.h +++ b/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.h @@ -57,7 +57,7 @@ namespace storm { static std::vector computeConditionalRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, RewardModelType const& rewardModel, storm::storage::BitVector const& targetStates, storm::storage::BitVector const& conditionStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory); private: - static std::vector computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, std::function(uint_fast64_t, storm::storage::SparseMatrix const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, ModelCheckerHint const& hint = ModelCheckerHint()); + static std::vector computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, std::function(uint_fast64_t, storm::storage::SparseMatrix const&, storm::storage::BitVector const&)> const& totalStateRewardVectorGetter, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, std::function const& zeroRewardStatesGetter, ModelCheckerHint const& hint = ModelCheckerHint()); struct BaierTransformedModel { BaierTransformedModel() : noTargetStates(false) { diff --git a/src/storm/settings/modules/ModelCheckerSettings.cpp b/src/storm/settings/modules/ModelCheckerSettings.cpp index ec5ec7921..9d7646327 100644 --- a/src/storm/settings/modules/ModelCheckerSettings.cpp +++ b/src/storm/settings/modules/ModelCheckerSettings.cpp @@ -12,8 +12,8 @@ namespace storm { namespace settings { namespace modules { - const std::string GeneralSettings::moduleName = "modelchecker"; - const std::string GeneralSettings::filterRewZeroOptionName = "filterrewzero"; + const std::string ModelCheckerSettings::moduleName = "modelchecker"; + const std::string ModelCheckerSettings::filterRewZeroOptionName = "filterrewzero"; ModelCheckerSettings::ModelCheckerSettings() : ModuleSettings(moduleName) { this->addOption(storm::settings::OptionBuilder(moduleName, filterRewZeroOptionName, false, "If set, states with reward zero are filtered out, potentially reducing the size of the equation system").build());