Browse Source

LpMinMaxLinearEquationSolver: Fixed an issue when using glpk occurring when the lower bound of a variable matches the upper bound. Also revamped retrieving of lower/upper bounds.

tempestpy_adaptions
Tim Quatmann 4 years ago
parent
commit
45a7db8222
No known key found for this signature in database GPG Key ID: 6EDE19592731EEC3
  1. 55
      src/storm/solver/LpMinMaxLinearEquationSolver.cpp

55
src/storm/solver/LpMinMaxLinearEquationSolver.cpp

@ -3,6 +3,8 @@
#include "storm/environment/solver/MinMaxSolverEnvironment.h"
#include "storm/utility/vector.h"
#include "storm/utility/macros.h"
#include "storm/storage/expressions/VariableExpression.h"
#include "storm/storage/expressions/RationalLiteralExpression.h"
#include "storm/exceptions/InvalidEnvironmentException.h"
#include "storm/exceptions/UnexpectedException.h"
@ -31,23 +33,30 @@ namespace storm {
// Set up the LP solver
std::unique_ptr<storm::solver::LpSolver<ValueType>> solver = lpSolverFactory->create("");
solver->setOptimizationDirection(invert(dir));
// Create a variable for each row group
std::vector<storm::expressions::Variable> variables;
variables.reserve(this->A->getRowGroupCount());
std::vector<storm::expressions::Expression> variableExpressions;
variableExpressions.reserve(this->A->getRowGroupCount());
for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) {
if (this->lowerBound) {
if (this->upperBound) {
variables.push_back(solver->addBoundedContinuousVariable("x" + std::to_string(rowGroup), this->lowerBound.get(), this->upperBound.get(), storm::utility::one<ValueType>()));
if (this->hasLowerBound()) {
ValueType lowerBound = this->getLowerBound(rowGroup);
if (this->hasUpperBound()) {
ValueType upperBound = this->getUpperBound(rowGroup);
if (lowerBound == upperBound) {
// Some solvers (like glpk) don't support variables with bounds [x,x]. We therefore just use a constant instead. This should be more efficient anyways.
variableExpressions.push_back(solver->getConstant(lowerBound));
} else {
STORM_LOG_ASSERT(lowerBound <= upperBound, "Lower Bound at row group " << rowGroup << " is " << lowerBound << " which exceeds the upper bound " << upperBound << ".");
variableExpressions.emplace_back(solver->addBoundedContinuousVariable("x" + std::to_string(rowGroup), lowerBound, upperBound, storm::utility::one<ValueType>()));
}
} else {
variables.push_back(solver->addLowerBoundedContinuousVariable("x" + std::to_string(rowGroup), this->lowerBound.get(), storm::utility::one<ValueType>()));
variableExpressions.emplace_back(solver->addLowerBoundedContinuousVariable("x" + std::to_string(rowGroup), lowerBound, storm::utility::one<ValueType>()));
}
} else {
if (this->upperBound) {
variables.push_back(solver->addUpperBoundedContinuousVariable("x" + std::to_string(rowGroup), this->upperBound.get(), storm::utility::one<ValueType>()));
variableExpressions.emplace_back(solver->addUpperBoundedContinuousVariable("x" + std::to_string(rowGroup), this->getUpperBound(rowGroup), storm::utility::one<ValueType>()));
} else {
variables.push_back(solver->addUnboundedContinuousVariable("x" + std::to_string(rowGroup), storm::utility::one<ValueType>()));
variableExpressions.emplace_back(solver->addUnboundedContinuousVariable("x" + std::to_string(rowGroup), storm::utility::one<ValueType>()));
}
}
}
@ -55,15 +64,19 @@ namespace storm {
// Add a constraint for each row
for (uint64_t rowGroup = 0; rowGroup < this->A->getRowGroupCount(); ++rowGroup) {
for (uint64_t row = this->A->getRowGroupIndices()[rowGroup]; row < this->A->getRowGroupIndices()[rowGroup + 1]; ++row) {
storm::expressions::Expression rowConstraint = solver->getConstant(b[row]);
for (auto const& entry : this->A->getRow(row)) {
rowConstraint = rowConstraint + (solver->getConstant(entry.getValue()) * variables[entry.getColumn()].getExpression());
for (uint64_t rowIndex = this->A->getRowGroupIndices()[rowGroup]; rowIndex < this->A->getRowGroupIndices()[rowGroup + 1]; ++rowIndex) {
auto row = this->A->getRow(rowIndex);
std::vector<storm::expressions::Expression> summands;
summands.reserve(1+row.getNumberOfEntries());
summands.push_back(solver->getConstant(b[rowIndex]));
for (auto const& entry : row) {
summands.push_back(solver->getConstant(entry.getValue()) * variableExpressions[entry.getColumn()]);
}
storm::expressions::Expression rowConstraint = storm::expressions::sum(summands);
if (minimize(dir)) {
rowConstraint = variables[rowGroup].getExpression() <= rowConstraint;
rowConstraint = variableExpressions[rowGroup] <= rowConstraint;
} else {
rowConstraint = variables[rowGroup].getExpression() >= rowConstraint;
rowConstraint = variableExpressions[rowGroup] >= rowConstraint;
}
solver->addConstraint("", rowConstraint);
}
@ -76,11 +89,17 @@ namespace storm {
STORM_LOG_THROW(solver->isOptimal(), storm::exceptions::UnexpectedException, "Unable to find optimal solution for MinMax equation system.");
// write the solution into the solution vector
STORM_LOG_ASSERT(x.size() == variables.size(), "Dimension of x-vector does not match number of varibales.");
STORM_LOG_ASSERT(x.size() == variableExpressions.size(), "Dimension of x-vector does not match number of varibales.");
auto xIt = x.begin();
auto vIt = variables.begin();
auto vIt = variableExpressions.begin();
for (; xIt != x.end(); ++xIt, ++vIt) {
*xIt = solver->getContinuousValue(*vIt);
auto const& vBaseExpr = vIt->getBaseExpression();
if (vBaseExpr.isVariable()) {
*xIt = solver->getContinuousValue(vBaseExpr.asVariableExpression().getVariable());
} else {
STORM_LOG_ASSERT(vBaseExpr.isRationalLiteralExpression(), "Variable expression has unexpected type.");
*xIt = storm::utility::convertNumber<ValueType>(vBaseExpr.asRationalLiteralExpression().getValue());
}
}
// If requested, we store the scheduler for retrieval.

Loading…
Cancel
Save