diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp index b007e1a43..d9991f535 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.cpp @@ -144,6 +144,37 @@ namespace storm { return result; } + + template + bool IterativeMinMaxLinearEquationSolver::solveInducedEquationSystem(std::unique_ptr>& linearEquationSolver, std::vector const& scheduler, std::vector& x, std::vector& subB, std::vector const& originalB) const { + assert(subB.size() == x.size()); + + // Resolve the nondeterminism according to the given scheduler. + bool convertToEquationSystem = this->linearEquationSolverFactory->getEquationProblemFormat() == LinearEquationSolverProblemFormat::EquationSystem; + storm::storage::SparseMatrix submatrix = this->A->selectRowsFromRowGroups(scheduler, convertToEquationSystem); + if (convertToEquationSystem) { + submatrix.convertToEquationSystem(); + } + storm::utility::vector::selectVectorValues(subB, scheduler, this->A->getRowGroupIndices(), originalB); + + // Check whether the linear equation solver is already initialized + if (!linearEquationSolver) { + // Initialize the equation solver + linearEquationSolver = this->linearEquationSolverFactory->create(std::move(submatrix)); + if (this->lowerBound) { // TODO + linearEquationSolver->setLowerBound(this->lowerBound.get()); + } + if (this->upperBound) { + linearEquationSolver->setUpperBound(this->upperBound.get()); + } + linearEquationSolver->setCachingEnabled(true); + } else { + // If the equation solver is already initialized, it suffices to update the matrix + linearEquationSolver->setMatrix(std::move(submatrix)); + } + // Solve the equation system for the 'DTMC' and return true upon success + return linearEquationSolver->solveEquations(x, subB); + } template bool IterativeMinMaxLinearEquationSolver::solveEquationsPolicyIteration(OptimizationDirection dir, std::vector& x, std::vector const& b) const { @@ -156,30 +187,15 @@ namespace storm { } std::vector& subB = *auxiliaryRowGroupVector; - // Resolve the nondeterminism according to the current scheduler. - bool convertToEquationSystem = this->linearEquationSolverFactory->getEquationProblemFormat() == LinearEquationSolverProblemFormat::EquationSystem; - storm::storage::SparseMatrix submatrix = this->A->selectRowsFromRowGroups(scheduler, convertToEquationSystem); - if (convertToEquationSystem) { - submatrix.convertToEquationSystem(); - } - storm::utility::vector::selectVectorValues(subB, scheduler, this->A->getRowGroupIndices(), b); - - // Create a solver that we will use throughout the procedure. We will modify the matrix in each iteration. - auto solver = this->linearEquationSolverFactory->create(std::move(submatrix)); - if (this->lowerBound) { - solver->setLowerBound(this->lowerBound.get()); - } - if (this->upperBound) { - solver->setUpperBound(this->upperBound.get()); - } - solver->setCachingEnabled(true); + // The solver that we will use throughout the procedure. + std::unique_ptr> solver; SolverStatus status = SolverStatus::InProgress; uint64_t iterations = 0; this->startMeasureProgress(); do { // Solve the equation system for the 'DTMC'. - solver->solveEquations(x, subB); + solveInducedEquationSystem(solver, scheduler, x, subB, b); // Go through the multiplication result and see whether we can improve any of the choices. bool schedulerImproved = false; @@ -212,14 +228,6 @@ namespace storm { // If the scheduler did not improve, we are done. if (!schedulerImproved) { status = SolverStatus::Converged; - } else { - // Update the scheduler and the solver. - submatrix = this->A->selectRowsFromRowGroups(scheduler, true); - if (convertToEquationSystem) { - submatrix.convertToEquationSystem(); - } - storm::utility::vector::selectVectorValues(subB, scheduler, this->A->getRowGroupIndices(), b); - solver->setMatrix(std::move(submatrix)); } // Update environment variables. @@ -376,37 +384,25 @@ namespace storm { SolverGuarantee guarantee = SolverGuarantee::None; if (this->hasInitialScheduler()) { - // Resolve the nondeterminism according to the initial scheduler. - bool convertToEquationSystem = this->linearEquationSolverFactory->getEquationProblemFormat() == LinearEquationSolverProblemFormat::EquationSystem; - storm::storage::SparseMatrix submatrix = this->A->selectRowsFromRowGroups(this->getInitialScheduler(), convertToEquationSystem); - if (convertToEquationSystem) { - submatrix.convertToEquationSystem(); - } - storm::utility::vector::selectVectorValues(*auxiliaryRowGroupVector, this->getInitialScheduler(), this->A->getRowGroupIndices(), b); - - // Solve the resulting equation system. - auto submatrixSolver = this->linearEquationSolverFactory->create(std::move(submatrix)); - submatrixSolver->setCachingEnabled(true); - if (this->lowerBound) { - submatrixSolver->setLowerBound(this->lowerBound.get()); - } - if (this->upperBound) { - submatrixSolver->setUpperBound(this->upperBound.get()); - } - submatrixSolver->solveEquations(x, *auxiliaryRowGroupVector); - + // Solve the equation system induced by the initial scheduler. + std::unique_ptr> linEqSolver; + solveInducedEquationSystem(linEqSolver, this->getInitialScheduler(), x, *auxiliaryRowGroupVector, b); // If we were given an initial scheduler and are maximizing (minimizing), our current solution becomes // always less-or-equal (greater-or-equal) than the actual solution. - if (dir == storm::OptimizationDirection::Maximize) { + guarantee = maximize(dir) ? SolverGuarantee::LessOrEqual : SolverGuarantee::GreaterOrEqual; + } else if (!this->hasUniqueSolution()) { + if (maximize(dir)) { + this->createLowerBoundsVector(x); guarantee = SolverGuarantee::LessOrEqual; } else { + this->createUpperBoundsVector(x); guarantee = SolverGuarantee::GreaterOrEqual; } - } else if (!this->hasUniqueSolution()) { - if (dir == storm::OptimizationDirection::Maximize) { + } else if (this->hasCustomTerminationCondition()) { + if (this->getTerminationCondition().requiresGuarantee(SolverGuarantee::LessOrEqual)) { this->createLowerBoundsVector(x); guarantee = SolverGuarantee::LessOrEqual; - } else { + } else if (this->getTerminationCondition().requiresGuarantee(SolverGuarantee::GreaterOrEqual)) { this->createUpperBoundsVector(x); guarantee = SolverGuarantee::GreaterOrEqual; } diff --git a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h index 582e3ebd3..0ba6a2e67 100644 --- a/src/storm/solver/IterativeMinMaxLinearEquationSolver.h +++ b/src/storm/solver/IterativeMinMaxLinearEquationSolver.h @@ -67,12 +67,12 @@ namespace storm { virtual MinMaxLinearEquationSolverRequirements getRequirements(boost::optional const& direction = boost::none) const override; private: + bool solveInducedEquationSystem(std::unique_ptr>& linearEquationSolver, std::vector const& scheduler, std::vector& x, std::vector& subB, std::vector const& originalB) const; bool solveEquationsPolicyIteration(OptimizationDirection dir, std::vector& x, std::vector const& b) const; bool valueImproved(OptimizationDirection dir, ValueType const& value1, ValueType const& value2) const; bool solveEquationsValueIteration(OptimizationDirection dir, std::vector& x, std::vector const& b) const; bool solveEquationsSoundValueIteration(OptimizationDirection dir, std::vector& x, std::vector const& b) const; - bool solveEquationsAcyclic(OptimizationDirection dir, std::vector& x, std::vector const& b) const; bool solveEquationsRationalSearch(OptimizationDirection dir, std::vector& x, std::vector const& b) const; template diff --git a/src/storm/solver/TerminationCondition.cpp b/src/storm/solver/TerminationCondition.cpp index e074eaaa2..745cc8a51 100644 --- a/src/storm/solver/TerminationCondition.cpp +++ b/src/storm/solver/TerminationCondition.cpp @@ -13,6 +13,11 @@ namespace storm { return false; } + template + bool NoTerminationCondition::requiresGuarantee(SolverGuarantee const&) const { + return false; + } + template TerminateIfFilteredSumExceedsThreshold::TerminateIfFilteredSumExceedsThreshold(storm::storage::BitVector const& filter, ValueType const& threshold, bool strict) : threshold(threshold), filter(filter), strict(strict) { // Intentionally left empty. @@ -29,6 +34,11 @@ namespace storm { return strict ? currentThreshold > this->threshold : currentThreshold >= this->threshold; } + template + bool TerminateIfFilteredSumExceedsThreshold::requiresGuarantee(SolverGuarantee const& guarantee) const { + return guarantee == SolverGuarantee::LessOrEqual; + } + template TerminateIfFilteredExtremumExceedsThreshold::TerminateIfFilteredExtremumExceedsThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum) : TerminateIfFilteredSumExceedsThreshold(filter, threshold, strict), useMinimum(useMinimum) { // Intentionally left empty. @@ -45,6 +55,11 @@ namespace storm { return this->strict ? currentValue > this->threshold : currentValue >= this->threshold; } + template + bool TerminateIfFilteredExtremumExceedsThreshold::requiresGuarantee(SolverGuarantee const& guarantee) const { + return guarantee == SolverGuarantee::LessOrEqual; + } + template TerminateIfFilteredExtremumBelowThreshold::TerminateIfFilteredExtremumBelowThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum) : TerminateIfFilteredSumExceedsThreshold(filter, threshold, strict), useMinimum(useMinimum) { // Intentionally left empty. @@ -61,6 +76,11 @@ namespace storm { return this->strict ? currentValue < this->threshold : currentValue <= this->threshold; } + template + bool TerminateIfFilteredExtremumBelowThreshold::requiresGuarantee(SolverGuarantee const& guarantee) const { + return guarantee == SolverGuarantee::GreaterOrEqual; + } + template class TerminateIfFilteredSumExceedsThreshold; template class TerminateIfFilteredExtremumExceedsThreshold; template class TerminateIfFilteredExtremumBelowThreshold; diff --git a/src/storm/solver/TerminationCondition.h b/src/storm/solver/TerminationCondition.h index 7c364989c..86733ade9 100644 --- a/src/storm/solver/TerminationCondition.h +++ b/src/storm/solver/TerminationCondition.h @@ -14,12 +14,19 @@ namespace storm { * Retrieves whether the guarantee provided by the solver for the current result is sufficient to terminate. */ virtual bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const = 0; + + /*! + * Retrieves whether the termination criterion requires the given guarantee in order to decide termination. + * @return + */ + virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const = 0; }; template class NoTerminationCondition : public TerminationCondition { public: virtual bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; }; template @@ -27,7 +34,8 @@ namespace storm { public: TerminateIfFilteredSumExceedsThreshold(storm::storage::BitVector const& filter, ValueType const& threshold, bool strict); - bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const; + bool terminateNow(std::vector const& currentValues, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; protected: ValueType threshold; @@ -40,7 +48,8 @@ namespace storm { public: TerminateIfFilteredExtremumExceedsThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum); - bool terminateNow(std::vector const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const; + bool terminateNow(std::vector const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; protected: bool useMinimum; @@ -51,7 +60,8 @@ namespace storm { public: TerminateIfFilteredExtremumBelowThreshold(storm::storage::BitVector const& filter, bool strict, ValueType const& threshold, bool useMinimum); - bool terminateNow(std::vector const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const; + bool terminateNow(std::vector const& currentValue, SolverGuarantee const& guarantee = SolverGuarantee::None) const override; + virtual bool requiresGuarantee(SolverGuarantee const& guarantee) const override; protected: bool useMinimum;