diff --git a/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp index 0d0ec5928..eb1212419 100644 --- a/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseDtmcPrctlHelper.cpp @@ -209,12 +209,22 @@ namespace storm { template std::vector SparseDtmcPrctlHelper::computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, RewardModelType const& rewardModel, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, ModelCheckerHint const& hint) { - return computeReachabilityRewards(env, std::move(goal), transitionMatrix, backwardTransitions, [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates) { return rewardModel.getTotalRewardVector(numberOfRows, transitionMatrix, maybeStates); }, targetStates, qualitative, linearEquationSolverFactory, hint); + // Extend the set of target states such that states for which target is reached without collecting any reward are included + // TODO + storm::storage::BitVector extendedTargetStates = storm::utility::graph::performProb1(backwardTransitions, rewardModel.getStatesWithZeroReward(transitionMatrix), targetStates); + STORM_LOG_INFO("Extended the set of target states from " << targetStates.getNumberOfSetBits() << " states to " << extendedTargetStates.getNumberOfSetBits() << " states."); + std::cout << "TODO: make target state extension a setting." << std::endl; + return computeReachabilityRewards(env, std::move(goal), transitionMatrix, backwardTransitions, [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::BitVector const& maybeStates) { return rewardModel.getTotalRewardVector(numberOfRows, transitionMatrix, maybeStates); }, extendedTargetStates, qualitative, linearEquationSolverFactory, hint); } template std::vector SparseDtmcPrctlHelper::computeReachabilityRewards(Environment const& env, storm::solver::SolveGoal&& goal, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, std::vector const& totalStateRewardVector, storm::storage::BitVector const& targetStates, bool qualitative, storm::solver::LinearEquationSolverFactory const& linearEquationSolverFactory, ModelCheckerHint const& hint) { + // TODO + storm::storage::BitVector extendedTargetStates = storm::utility::graph::performProb1(backwardTransitions, storm::utility::vector::filterZero(totalStateRewardVector), targetStates); + STORM_LOG_INFO("Extended the set of target states from " << targetStates.getNumberOfSetBits() << " states to " << extendedTargetStates.getNumberOfSetBits() << " states."); + std::cout << "TODO: make target state extension a setting" << std::endl; + return computeReachabilityRewards(env, std::move(goal), transitionMatrix, backwardTransitions, [&] (uint_fast64_t numberOfRows, storm::storage::SparseMatrix const&, storm::storage::BitVector const& maybeStates) { std::vector result(numberOfRows);