From d5c2f9248fb183ee4864b8177a24082d81b3a34b Mon Sep 17 00:00:00 2001 From: dehnert Date: Sat, 10 May 2014 23:30:02 +0200 Subject: [PATCH] Finished linear coefficient visitor and adapted glpk solver to new expression-based LP solver interface. Former-commit-id: ba1d3a912f8bebe6d388f59198870d5b060391c5 --- src/solver/GlpkLpSolver.cpp | 42 ++++++--- .../expressions/LinearCoefficientVisitor.cpp | 87 ++++++++++++++++++- .../expressions/LinearityCheckVisitor.cpp | 19 +--- 3 files changed, 117 insertions(+), 31 deletions(-) diff --git a/src/solver/GlpkLpSolver.cpp b/src/solver/GlpkLpSolver.cpp index fc2c5f261..1cffa1a18 100644 --- a/src/solver/GlpkLpSolver.cpp +++ b/src/solver/GlpkLpSolver.cpp @@ -4,6 +4,8 @@ #include +#include "src/storage/expressions/LinearCoefficientVisitor.h" + #include "src/settings/Settings.h" #include "src/exceptions/ExceptionMacros.h" #include "src/exceptions/InvalidAccessException.h" @@ -160,26 +162,46 @@ namespace storm { LOG_THROW(constraint.isRelationalExpression(), storm::exceptions::InvalidArgumentException, "Illegal constraint is not a relational expression."); LOG_THROW(constraint.getOperator() != storm::expressions::OperatorType::NotEqual, storm::exceptions::InvalidArgumentException, "Illegal constraint uses inequality operator."); - // TODO: get variable/coefficients vector from constraint. + std::pair leftCoefficients = storm::expressions::LinearCoefficientVisitor().getLinearCoefficients(constraint.getOperand(0)); + std::pair rightCoefficients = storm::expressions::LinearCoefficientVisitor().getLinearCoefficients(constraint.getOperand(1)); + for (auto const& identifier : rightCoefficients.first.getDoubleIdentifiers()) { + if (leftCoefficients.first.containsDoubleIdentifier(identifier)) { + leftCoefficients.first.setDoubleValue(identifier, leftCoefficients.first.getDoubleValue(identifier) - rightCoefficients.first.getDoubleValue(identifier)); + } else { + leftCoefficients.first.addDoubleIdentifier(identifier, -rightCoefficients.first.getDoubleValue(identifier)); + } + } + rightCoefficients.second -= leftCoefficients.second; + // Now we need to transform the coefficients to the vector representation. + std::vector variables; + std::vector coefficients; + for (auto const& identifier : leftCoefficients.first.getDoubleIdentifiers()) { + auto identifierIndexPair = this->variableNameToIndexMap.find(identifier); + LOG_THROW(identifierIndexPair != this->variableNameToIndexMap.end(), storm::exceptions::InvalidArgumentException, "Constraint contains illegal identifier '" << identifier << "'."); + variables.push_back(identifierIndexPair->second); + coefficients.push_back(leftCoefficients.first.getDoubleValue(identifier)); + } // Determine the type of the constraint and add it properly. switch (constraint.getOperator()) { case storm::expressions::OperatorType::Less: - glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightHandSideValue - storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble()); + glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightCoefficients.second - storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble()); break; - case LESS_EQUAL: - glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightHandSideValue); + case storm::expressions::OperatorType::LessOrEqual: + glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_UP, 0, rightCoefficients.second); break; - case GREATER: - glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightHandSideValue + storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble(), 0); + case storm::expressions::OperatorType::Greater: + glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightCoefficients.second + storm::settings::Settings::getInstance()->getOptionByLongName("glpkinttol").getArgument(0).getValueAsDouble(), 0); break; - case GREATER_EQUAL: - glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightHandSideValue, 0); + case storm::expressions::OperatorType::GreaterOrEqual: + glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_LO, rightCoefficients.second, 0); break; - case EQUAL: - glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_FX, rightHandSideValue, rightHandSideValue); + case storm::expressions::OperatorType::Equal: + glp_set_row_bnds(this->lp, nextConstraintIndex, GLP_FX, rightCoefficients.second, rightCoefficients.second); break; + default: + LOG_ASSERT(false, "Illegal operator in LP solver constraint."); } // Record the variables and coefficients in the coefficient matrix. diff --git a/src/storage/expressions/LinearCoefficientVisitor.cpp b/src/storage/expressions/LinearCoefficientVisitor.cpp index 58fd427d5..f59804927 100644 --- a/src/storage/expressions/LinearCoefficientVisitor.cpp +++ b/src/storage/expressions/LinearCoefficientVisitor.cpp @@ -16,15 +16,96 @@ namespace storm { } void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) { - + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - + if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) { + expression->getFirstOperand()->accept(this); + std::pair leftResult = resultStack.top(); + resultStack.pop(); + expression->getSecondOperand()->accept(this); + std::pair& rightResult = resultStack.top(); + + // Now add the left result to the right result. + for (auto const& identifier : leftResult.first.Valuation::getDoubleIdentifiers()) { + if (rightResult.first.containsDoubleIdentifier(identifier)) { + rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) + rightResult.first.getDoubleValue(identifier)); + } else { + rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier)); + } + } + rightResult.second += leftResult.second; + return; + } else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) { + expression->getFirstOperand()->accept(this); + std::pair leftResult = resultStack.top(); + resultStack.pop(); + expression->getSecondOperand()->accept(this); + std::pair& rightResult = resultStack.top(); + + // Now subtract the right result from the left result. + for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { + if (rightResult.first.containsDoubleIdentifier(identifier)) { + rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) - rightResult.first.getDoubleValue(identifier)); + } else { + rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier)); + } + } + for (auto const& identifier : rightResult.first.Valuation::getDoubleIdentifiers()) { + if (!leftResult.first.containsDoubleIdentifier(identifier)) { + rightResult.first.setDoubleValue(identifier, -rightResult.first.getDoubleValue(identifier)); + } + } + rightResult.second = leftResult.second - rightResult.second; + return; + } else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) { + expression->getFirstOperand()->accept(this); + std::pair leftResult = resultStack.top(); + resultStack.pop(); + expression->getSecondOperand()->accept(this); + std::pair& rightResult = resultStack.top(); + + // If the expression is linear, either the left or the right side must not contain variables. + LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + if (leftResult.first.getNumberOfIdentifiers() == 0) { + for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) { + rightResult.first.setDoubleValue(identifier, leftResult.second * rightResult.first.getDoubleValue(identifier)); + } + } else { + for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { + rightResult.first.addDoubleIdentifier(identifier, rightResult.second * leftResult.first.getDoubleValue(identifier)); + } + } + rightResult.second *= leftResult.second; + return; + } else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) { + expression->getFirstOperand()->accept(this); + std::pair leftResult = resultStack.top(); + resultStack.pop(); + expression->getSecondOperand()->accept(this); + std::pair& rightResult = resultStack.top(); + + // If the expression is linear, either the left or the right side must not contain variables. + LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + if (leftResult.first.getNumberOfIdentifiers() == 0) { + for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) { + rightResult.first.setDoubleValue(identifier, leftResult.second / rightResult.first.getDoubleValue(identifier)); + } + } else { + for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { + rightResult.first.addDoubleIdentifier(identifier, leftResult.first.getDoubleValue(identifier) / rightResult.second); + } + } + rightResult.second = leftResult.second / leftResult.second; + return; + } else { + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + } } void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) { - + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } void LinearCoefficientVisitor::visit(VariableExpression const* expression) { diff --git a/src/storage/expressions/LinearityCheckVisitor.cpp b/src/storage/expressions/LinearityCheckVisitor.cpp index 5dd2f38ac..9b382ed22 100644 --- a/src/storage/expressions/LinearityCheckVisitor.cpp +++ b/src/storage/expressions/LinearityCheckVisitor.cpp @@ -77,24 +77,7 @@ namespace storm { } void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) { - LinearityStatus leftResult; - LinearityStatus rightResult; - expression->getFirstOperand()->accept(this); - leftResult = resultStack.top(); - - if (leftResult == LinearityStatus::NonLinear) { - return; - } else { - resultStack.pop(); - expression->getSecondOperand()->accept(this); - rightResult = resultStack.top(); - if (rightResult == LinearityStatus::NonLinear) { - return; - } - resultStack.pop(); - } - - resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); + resultStack.push(LinearityStatus::NonLinear); } void LinearityCheckVisitor::visit(VariableExpression const* expression) {