From 45a7db822283b4b2a57280c39618aa1150f5ff22 Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Wed, 10 Mar 2021 06:31:18 +0100 Subject: [PATCH] LpMinMaxLinearEquationSolver: Fixed an issue when using glpk occurring when the lower bound of a variable matches the upper bound. Also revamped retrieving of lower/upper bounds. --- .../solver/LpMinMaxLinearEquationSolver.cpp | 61 ++++++++++++------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/src/storm/solver/LpMinMaxLinearEquationSolver.cpp b/src/storm/solver/LpMinMaxLinearEquationSolver.cpp index 2b371aae3..cbbba9922 100644 --- a/src/storm/solver/LpMinMaxLinearEquationSolver.cpp +++ b/src/storm/solver/LpMinMaxLinearEquationSolver.cpp @@ -3,6 +3,8 @@ #include "storm/environment/solver/MinMaxSolverEnvironment.h" #include "storm/utility/vector.h" #include "storm/utility/macros.h" +#include "storm/storage/expressions/VariableExpression.h" +#include "storm/storage/expressions/RationalLiteralExpression.h" #include "storm/exceptions/InvalidEnvironmentException.h" #include "storm/exceptions/UnexpectedException.h" @@ -26,28 +28,35 @@ namespace storm { template bool LpMinMaxLinearEquationSolver::internalSolveEquations(Environment const& env, OptimizationDirection dir, std::vector& x, std::vector const& b) const { - + STORM_LOG_THROW(env.solver().minMax().getMethod() == MinMaxMethod::LinearProgramming, storm::exceptions::InvalidEnvironmentException, "This min max solver does not support the selected technique."); - + // Set up the LP solver std::unique_ptr> solver = lpSolverFactory->create(""); - solver->setOptimizationDirection(invert(dir)); - + // Create a variable for each row group - std::vector variables; - variables.reserve(this->A->getRowGroupCount()); + std::vector variableExpressions; + variableExpressions.reserve(this->A->getRowGroupCount()); for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) { - if (this->lowerBound) { - if (this->upperBound) { - variables.push_back(solver->addBoundedContinuousVariable("x" + std::to_string(rowGroup), this->lowerBound.get(), this->upperBound.get(), storm::utility::one())); + if (this->hasLowerBound()) { + ValueType lowerBound = this->getLowerBound(rowGroup); + if (this->hasUpperBound()) { + ValueType upperBound = this->getUpperBound(rowGroup); + if (lowerBound == upperBound) { + // Some solvers (like glpk) don't support variables with bounds [x,x]. We therefore just use a constant instead. This should be more efficient anyways. + variableExpressions.push_back(solver->getConstant(lowerBound)); + } else { + STORM_LOG_ASSERT(lowerBound <= upperBound, "Lower Bound at row group " << rowGroup << " is " << lowerBound << " which exceeds the upper bound " << upperBound << "."); + variableExpressions.emplace_back(solver->addBoundedContinuousVariable("x" + std::to_string(rowGroup), lowerBound, upperBound, storm::utility::one())); + } } else { - variables.push_back(solver->addLowerBoundedContinuousVariable("x" + std::to_string(rowGroup), this->lowerBound.get(), storm::utility::one())); + variableExpressions.emplace_back(solver->addLowerBoundedContinuousVariable("x" + std::to_string(rowGroup), lowerBound, storm::utility::one())); } } else { if (this->upperBound) { - variables.push_back(solver->addUpperBoundedContinuousVariable("x" + std::to_string(rowGroup), this->upperBound.get(), storm::utility::one())); + variableExpressions.emplace_back(solver->addUpperBoundedContinuousVariable("x" + std::to_string(rowGroup), this->getUpperBound(rowGroup), storm::utility::one())); } else { - variables.push_back(solver->addUnboundedContinuousVariable("x" + std::to_string(rowGroup), storm::utility::one())); + variableExpressions.emplace_back(solver->addUnboundedContinuousVariable("x" + std::to_string(rowGroup), storm::utility::one())); } } } @@ -55,15 +64,19 @@ namespace storm { // Add a constraint for each row for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) { - for (uint64_t row = this->A->getRowGroupIndices()[rowGroup]; row < this->A->getRowGroupIndices()[rowGroup + 1]; ++row) { - storm::expressions::Expression rowConstraint = solver->getConstant(b[row]); - for (auto const& entry : this->A->getRow(row)) { - rowConstraint = rowConstraint + (solver->getConstant(entry.getValue()) * variables[entry.getColumn()].getExpression()); + for (uint64_t rowIndex = this->A->getRowGroupIndices()[rowGroup]; rowIndex < this->A->getRowGroupIndices()[rowGroup + 1]; ++rowIndex) { + auto row = this->A->getRow(rowIndex); + std::vector summands; + summands.reserve(1+row.getNumberOfEntries()); + summands.push_back(solver->getConstant(b[rowIndex])); + for (auto const& entry : row) { + summands.push_back(solver->getConstant(entry.getValue()) * variableExpressions[entry.getColumn()]); } + storm::expressions::Expression rowConstraint = storm::expressions::sum(summands); if (minimize(dir)) { - rowConstraint = variables[rowGroup].getExpression() <= rowConstraint; + rowConstraint = variableExpressions[rowGroup] <= rowConstraint; } else { - rowConstraint = variables[rowGroup].getExpression() >= rowConstraint; + rowConstraint = variableExpressions[rowGroup] >= rowConstraint; } solver->addConstraint("", rowConstraint); } @@ -76,11 +89,17 @@ namespace storm { STORM_LOG_THROW(solver->isOptimal(), storm::exceptions::UnexpectedException, "Unable to find optimal solution for MinMax equation system."); // write the solution into the solution vector - STORM_LOG_ASSERT(x.size() == variables.size(), "Dimension of x-vector does not match number of varibales."); + STORM_LOG_ASSERT(x.size() == variableExpressions.size(), "Dimension of x-vector does not match number of varibales."); auto xIt = x.begin(); - auto vIt = variables.begin(); + auto vIt = variableExpressions.begin(); for (; xIt != x.end(); ++xIt, ++vIt) { - *xIt = solver->getContinuousValue(*vIt); + auto const& vBaseExpr = vIt->getBaseExpression(); + if (vBaseExpr.isVariable()) { + *xIt = solver->getContinuousValue(vBaseExpr.asVariableExpression().getVariable()); + } else { + STORM_LOG_ASSERT(vBaseExpr.isRationalLiteralExpression(), "Variable expression has unexpected type."); + *xIt = storm::utility::convertNumber(vBaseExpr.asRationalLiteralExpression().getValue()); + } } // If requested, we store the scheduler for retrieval.