Browse Source

Finished linear coefficient visitor and adapted glpk solver to new expression-based LP solver interface.

Former-commit-id: ba1d3a912f
tempestpy_adaptions
dehnert 11 years ago
parent
commit
d5c2f9248f
  1. 42
      src/solver/GlpkLpSolver.cpp
  2. 85
      src/storage/expressions/LinearCoefficientVisitor.cpp
  3. 19
      src/storage/expressions/LinearityCheckVisitor.cpp

42
src/solver/GlpkLpSolver.cpp

@ -4,6 +4,8 @@
#include <iostream> #include <iostream>
#include "src/storage/expressions/LinearCoefficientVisitor.h"
#include "src/settings/Settings.h" #include "src/settings/Settings.h"
#include "src/exceptions/ExceptionMacros.h" #include "src/exceptions/ExceptionMacros.h"
#include "src/exceptions/InvalidAccessException.h" #include "src/exceptions/InvalidAccessException.h"
@ -160,26 +162,46 @@ namespace storm {
LOG_THROW(constraint.isRelationalExpression(), storm::exceptions::InvalidArgumentException, "Illegal constraint is not a relational expression."); LOG_THROW(constraint.isRelationalExpression(), storm::exceptions::InvalidArgumentException, "Illegal constraint is not a relational expression.");
LOG_THROW(constraint.getOperator() != storm::expressions::OperatorType::NotEqual, storm::exceptions::InvalidArgumentException, "Illegal constraint uses inequality operator."); LOG_THROW(constraint.getOperator() != storm::expressions::OperatorType::NotEqual, storm::exceptions::InvalidArgumentException, "Illegal constraint uses inequality operator.");
// TODO: get variable/coefficients vector from constraint.
std::pair<storm::expressions::SimpleValuation, double> leftCoefficients = storm::expressions::LinearCoefficientVisitor().getLinearCoefficients(constraint.getOperand(0));
std::pair<storm::expressions::SimpleValuation, double> rightCoefficients = storm::expressions::LinearCoefficientVisitor().getLinearCoefficients(constraint.getOperand(1));
for (auto const& identifier : rightCoefficients.first.getDoubleIdentifiers()) {
if (leftCoefficients.first.containsDoubleIdentifier(identifier)) {
leftCoefficients.first.setDoubleValue(identifier, leftCoefficients.first.getDoubleValue(identifier) - rightCoefficients.first.getDoubleValue(identifier));
} else {
leftCoefficients.first.addDoubleIdentifier(identifier, -rightCoefficients.first.getDoubleValue(identifier));
}
}
rightCoefficients.second -= leftCoefficients.second;
// Now we need to transform the coefficients to the vector representation.
std::vector<int> variables;
std::vector<double> coefficients;
for (auto const& identifier : leftCoefficients.first.getDoubleIdentifiers()) {
auto identifierIndexPair = this->variableNameToIndexMap.find(identifier);
LOG_THROW(identifierIndexPair != this->variableNameToIndexMap.end(), storm::exceptions::InvalidArgumentException, "Constraint contains illegal identifier '" << identifier << "'.");
variables.push_back(identifierIndexPair->second);
coefficients.push_back(leftCoefficients.first.getDoubleValue(identifier));
}
// Determine the type of the constraint and add it properly. // Determine the type of the constraint and add it properly.
switch (constraint.getOperator()) { switch (constraint.getOperator()) {
case storm::expressions::OperatorType::Less: case storm::expressions::OperatorType::Less:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightHandSideValue - storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble());
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightCoefficients.second - storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble());
break; break;
case LESS_EQUAL:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightHandSideValue);
case storm::expressions::OperatorType::LessOrEqual:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightCoefficients.second);
break; break;
case GREATER:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightHandSideValue + storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble(), 0);
case storm::expressions::OperatorType::Greater:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightCoefficients.second + storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble(), 0);
break; break;
case GREATER_EQUAL:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightHandSideValue, 0);
case storm::expressions::OperatorType::GreaterOrEqual:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightCoefficients.second, 0);
break; break;
case EQUAL:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_FX, rightHandSideValue, rightHandSideValue);
case storm::expressions::OperatorType::Equal:
glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_FX, rightCoefficients.second, rightCoefficients.second);
break; break;
default:
LOG_ASSERT(false, "Illegal operator in LP solver constraint.");
} }
// Record the variables and coefficients in the coefficient matrix. // Record the variables and coefficients in the coefficient matrix.

85
src/storage/expressions/LinearCoefficientVisitor.cpp

@ -16,15 +16,96 @@ namespace storm {
} }
void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) { void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
} }
void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) { void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// Now add the left result to the right result.
for (auto const& identifier : leftResult.first.Valuation::getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) + rightResult.first.getDoubleValue(identifier));
} else {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier));
}
}
rightResult.second += leftResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// Now subtract the right result from the left result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) - rightResult.first.getDoubleValue(identifier));
} else {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier));
}
}
for (auto const& identifier : rightResult.first.Valuation::getDoubleIdentifiers()) {
if (!leftResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, -rightResult.first.getDoubleValue(identifier));
}
}
rightResult.second = leftResult.second - rightResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// If the expression is linear, either the left or the right side must not contain variables.
LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
rightResult.first.setDoubleValue(identifier, leftResult.second * rightResult.first.getDoubleValue(identifier));
}
} else {
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
rightResult.first.addDoubleIdentifier(identifier, rightResult.second * leftResult.first.getDoubleValue(identifier));
}
}
rightResult.second *= leftResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// If the expression is linear, either the left or the right side must not contain variables.
LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
rightResult.first.setDoubleValue(identifier, leftResult.second / rightResult.first.getDoubleValue(identifier));
}
} else {
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
rightResult.first.addDoubleIdentifier(identifier, leftResult.first.getDoubleValue(identifier) / rightResult.second);
}
}
rightResult.second = leftResult.second / leftResult.second;
return;
} else {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
} }
void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) { void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
} }
void LinearCoefficientVisitor::visit(VariableExpression const* expression) { void LinearCoefficientVisitor::visit(VariableExpression const* expression) {

19
src/storage/expressions/LinearityCheckVisitor.cpp

@ -77,24 +77,7 @@ namespace storm {
} }
void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) { void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) {
LinearityStatus leftResult;
LinearityStatus rightResult;
expression->getFirstOperand()->accept(this);
leftResult = resultStack.top();
if (leftResult == LinearityStatus::NonLinear) {
return;
} else {
resultStack.pop();
expression->getSecondOperand()->accept(this);
rightResult = resultStack.top();
if (rightResult == LinearityStatus::NonLinear) {
return;
}
resultStack.pop();
}
resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables);
resultStack.push(LinearityStatus::NonLinear);
} }
void LinearityCheckVisitor::visit(VariableExpression const* expression) { void LinearityCheckVisitor::visit(VariableExpression const* expression) {

Loading…
Cancel
Save