From 12b10af6728029fc5bf7fbdff0ea6fe35a767ec7 Mon Sep 17 00:00:00 2001 From: dehnert Date: Mon, 4 Sep 2017 21:32:06 +0200 Subject: [PATCH] started on hybrid MDP helper respecting solver requirements --- .../prctl/helper/HybridMdpPrctlHelper.cpp | 52 +++++++++++++++++-- .../prctl/helper/SparseMdpPrctlHelper.cpp | 8 +-- .../IterativeMinMaxLinearEquationSolver.cpp | 6 +-- src/storm/storage/SparseMatrix.h | 3 +- 4 files changed, 54 insertions(+), 15 deletions(-) diff --git a/src/storm/modelchecker/prctl/helper/HybridMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/HybridMdpPrctlHelper.cpp index 24fe1326f..02c7d783c 100644 --- a/src/storm/modelchecker/prctl/helper/HybridMdpPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/HybridMdpPrctlHelper.cpp @@ -25,6 +25,36 @@ namespace storm { namespace modelchecker { namespace helper { + template + std::vector computeValidInitialScheduler(uint64_t numberOfMaybeStates, storm::storage::SparseMatrix const& transitionMatrix, std::vector const& b) { + std::vector result(numberOfMaybeStates); + storm::storage::BitVector targetStates(numberOfMaybeStates); + + for (uint64_t state = 0; state < numberOfMaybeStates; ++state) { + // Record all states with non-zero probability of moving directly to the target states. + for (uint64_t row = transitionMatrix.getRowGroupIndices()[state]; row < transitionMatrix.getRowGroupIndices()[state + 1]; ++row) { + if (!storm::utility::isZero(b[row])) { + targetStates.set(state); + result[state] = row - transitionMatrix.getRowGroupIndices()[state]; + } + } + } + + if (!targetStates.full()) { + storm::storage::Scheduler validScheduler(numberOfMaybeStates); + storm::storage::SparseMatrix backwardTransitions = transitionMatrix.transpose(true); + storm::utility::graph::computeSchedulerProbGreater0E(transitionMatrix, backwardTransitions, storm::storage::BitVector(numberOfMaybeStates, true), targetStates, validScheduler, boost::none); + + for (uint64_t state = 0; state < numberOfMaybeStates; ++state) { + if (!targetStates.get(state)) { + result[state] = validScheduler.getChoice(state).getDeterministicChoice(); + } + } + } + + return result; + } + template std::unique_ptr HybridMdpPrctlHelper::computeUntilProbabilities(OptimizationDirection dir, storm::models::symbolic::NondeterministicModel const& model, storm::dd::Add const& transitionMatrix, storm::dd::Bdd const& phiStates, storm::dd::Bdd const& psiStates, bool qualitative, storm::solver::MinMaxLinearEquationSolverFactory const& linearEquationSolverFactory) { // We need to identify the states which have to be taken out of the matrix, i.e. all states that have @@ -75,14 +105,26 @@ namespace storm { // Translate the symbolic matrix/vector to their explicit representations and solve the equation system. std::pair, std::vector> explicitRepresentation = submatrix.toMatrixVector(subvector, std::move(rowGroupSizes), model.getNondeterminismVariables(), odd, odd); - // Check for requirements of the solver. - storm::solver::MinMaxLinearEquationSolverRequirements requirements = linearEquationSolverFactory.getRequirements(storm::solver::MinMaxLinearEquationSolverSystemType::UntilProbabilities); - STORM_LOG_THROW(requirements.empty(), storm::exceptions::UncheckedRequirementException, "Cannot establish requirements for solver."); - // Create the solution vector. std::vector x(maybeStates.getNonZeroCount(), storm::utility::zero()); + + // Check for requirements of the solver. + storm::solver::MinMaxLinearEquationSolverRequirements requirements = linearEquationSolverFactory.getRequirements(storm::solver::MinMaxLinearEquationSolverSystemType::UntilProbabilities, dir); + boost::optional> initialScheduler; + if (!requirements.empty()) { + if (requirements.requires(storm::solver::MinMaxLinearEquationSolverRequirements::Element::ValidInitialScheduler)) { + STORM_LOG_DEBUG("Computing valid scheduler hint, because the solver requires it."); + initialScheduler = computeValidInitialScheduler(x.size(), explicitRepresentation.first, explicitRepresentation.second); + + requirements.set(storm::solver::MinMaxLinearEquationSolverRequirements::Element::ValidInitialScheduler, false); + } + STORM_LOG_THROW(requirements.empty(), storm::exceptions::UncheckedRequirementException, "Cannot establish requirements for solver."); + } std::unique_ptr> solver = linearEquationSolverFactory.create(std::move(explicitRepresentation.first)); + if (initialScheduler) { + solver->setInitialScheduler(std::move(initialScheduler.get())); + } solver->setRequirementsChecked(); solver->solveEquations(dir, x, explicitRepresentation.second); @@ -247,7 +289,7 @@ namespace storm { // non-maybe states in the matrix. storm::dd::Add submatrix = transitionMatrix * maybeStatesAdd; - // Then compute the state reward vector to use in the computation. + // Then compute the reward vector to use in the computation. storm::dd::Add subvector = rewardModel.getTotalRewardVector(maybeStatesAdd, submatrix, model.getColumnVariables()); if (!rewardModel.hasStateActionRewards() && !rewardModel.hasTransitionRewards()) { // If the reward model neither has state-action nor transition rewards, we need to multiply diff --git a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp index a90c6bbfc..d842705ac 100644 --- a/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp +++ b/src/storm/modelchecker/prctl/helper/SparseMdpPrctlHelper.cpp @@ -87,12 +87,12 @@ namespace storm { template std::vector computeValidSchedulerHint(storm::solver::MinMaxLinearEquationSolverSystemType const& type, storm::storage::SparseMatrix const& transitionMatrix, storm::storage::SparseMatrix const& backwardTransitions, storm::storage::BitVector const& maybeStates, storm::storage::BitVector const& filterStates, storm::storage::BitVector const& targetStates) { - std::unique_ptr> validScheduler = std::make_unique>(maybeStates.size()); + storm::storage::Scheduler validScheduler(maybeStates.size()); if (type == storm::solver::MinMaxLinearEquationSolverSystemType::UntilProbabilities) { - storm::utility::graph::computeSchedulerProbGreater0E(transitionMatrix, backwardTransitions, filterStates, targetStates, *validScheduler, boost::none); + storm::utility::graph::computeSchedulerProbGreater0E(transitionMatrix, backwardTransitions, filterStates, targetStates, validScheduler, boost::none); } else if (type == storm::solver::MinMaxLinearEquationSolverSystemType::ReachabilityRewards) { - storm::utility::graph::computeSchedulerProb1E(maybeStates | targetStates, transitionMatrix, backwardTransitions, filterStates, targetStates, *validScheduler); + storm::utility::graph::computeSchedulerProb1E(maybeStates | targetStates, transitionMatrix, backwardTransitions, filterStates, targetStates, validScheduler); } else { STORM_LOG_ASSERT(false, "Unexpected equation system type."); } @@ -101,7 +101,7 @@ namespace storm { std::vector schedulerHint(maybeStates.getNumberOfSetBits()); auto maybeIt = maybeStates.begin(); for (auto& choice : schedulerHint) { - choice = validScheduler->getChoice(*maybeIt).getDeterministicChoice(); + choice = validScheduler.getChoice(*maybeIt).getDeterministicChoice(); ++maybeIt; } return schedulerHint; diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index 0718db5b8..314c97c2b 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -225,10 +225,8 @@ namespace storm { } } } else if (equationSystemType == MinMaxLinearEquationSolverSystemType::ReachabilityRewards) { - if (this->getSettings().getSolutionMethod() == IterativeMinMaxLinearEquationSolverSettings::SolutionMethod::PolicyIteration) { - if (!direction || direction.get() == OptimizationDirection::Minimize) { - requirements.set(MinMaxLinearEquationSolverRequirements::Element::ValidInitialScheduler); - } + if (!direction || direction.get() == OptimizationDirection::Minimize) { + requirements.set(MinMaxLinearEquationSolverRequirements::Element::ValidInitialScheduler); } } diff --git a/src/storm/storage/SparseMatrix.h b/src/storm/storage/SparseMatrix.h index bd8c1bad1..6f4dba96f 100644 --- a/src/storm/storage/SparseMatrix.h +++ b/src/storm/storage/SparseMatrix.h @@ -688,7 +688,7 @@ namespace storm { /*! * Selects exactly one row from each row group of this matrix and returns the resulting matrix. * -s * @param insertDiagonalEntries If set to true, the resulting matrix will have zero entries in column i for + * @param insertDiagonalEntries If set to true, the resulting matrix will have zero entries in column i for * each row in row group i. This can then be used for inserting other values later. * @return A submatrix of the current matrix by selecting one row out of each row group. */ @@ -715,7 +715,6 @@ s * @param insertDiagonalEntries If set to true, the resulting matri */ storm::storage::SparseMatrix transpose(bool joinGroups = false, bool keepZeros = false) const; - /*! * Transposes the matrix w.r.t. the selected rows. * This is equivalent to selectRowsFromRowGroups(rowGroupChoices, false).transpose(false, keepZeros) but avoids creating one intermediate matrix.