151 lines
9.1 KiB

#include "src/storage/expressions/LinearCoefficientVisitor.h"
#include "src/storage/expressions/Expressions.h"
#include "src/exceptions/ExceptionMacros.h"
#include "src/exceptions/InvalidArgumentException.h"
namespace storm {
namespace expressions {
std::pair<SimpleValuation, double> LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
expression.getBaseExpression().accept(this);
return resultStack.top();
}
void LinearCoefficientVisitor::visit(IfThenElseExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) {
if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// Now add the left result to the right result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) + rightResult.first.getDoubleValue(identifier));
} else {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier));
}
}
rightResult.second += leftResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// Now subtract the right result from the left result.
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
if (rightResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) - rightResult.first.getDoubleValue(identifier));
} else {
rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier));
}
}
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
if (!leftResult.first.containsDoubleIdentifier(identifier)) {
rightResult.first.setDoubleValue(identifier, -rightResult.first.getDoubleValue(identifier));
}
}
rightResult.second = leftResult.second - rightResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// If the expression is linear, either the left or the right side must not contain variables.
LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
rightResult.first.setDoubleValue(identifier, leftResult.second * rightResult.first.getDoubleValue(identifier));
}
} else {
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
rightResult.first.addDoubleIdentifier(identifier, rightResult.second * leftResult.first.getDoubleValue(identifier));
}
}
rightResult.second *= leftResult.second;
return;
} else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) {
expression->getFirstOperand()->accept(this);
std::pair<SimpleValuation, double> leftResult = resultStack.top();
resultStack.pop();
expression->getSecondOperand()->accept(this);
std::pair<SimpleValuation, double>& rightResult = resultStack.top();
// If the expression is linear, either the left or the right side must not contain variables.
LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
if (leftResult.first.getNumberOfIdentifiers() == 0) {
for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) {
rightResult.first.setDoubleValue(identifier, leftResult.second / rightResult.first.getDoubleValue(identifier));
}
} else {
for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) {
rightResult.first.addDoubleIdentifier(identifier, leftResult.first.getDoubleValue(identifier) / rightResult.second);
}
}
rightResult.second = leftResult.second / leftResult.second;
return;
} else {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
}
void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(VariableExpression const* expression) {
SimpleValuation valuation;
switch (expression->getReturnType()) {
case ExpressionReturnType::Bool: LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break;
case ExpressionReturnType::Int:
case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression->getVariableName(), 1); break;
case ExpressionReturnType::Undefined: LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break;
}
resultStack.push(std::make_pair(valuation, 0));
}
void LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const* expression) {
if (expression->getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) {
// Here, we need to negate all double identifiers.
std::pair<SimpleValuation, double>& valuationConstantPair = resultStack.top();
for (auto const& identifier : valuationConstantPair.first.getDoubleIdentifiers()) {
valuationConstantPair.first.setDoubleValue(identifier, -valuationConstantPair.first.getDoubleValue(identifier));
}
} else {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
}
void LinearCoefficientVisitor::visit(BooleanLiteralExpression const* expression) {
LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
void LinearCoefficientVisitor::visit(IntegerLiteralExpression const* expression) {
resultStack.push(std::make_pair(SimpleValuation(), static_cast<double>(expression->getValue())));
}
void LinearCoefficientVisitor::visit(DoubleLiteralExpression const* expression) {
resultStack.push(std::make_pair(SimpleValuation(), expression->getValue()));
}
}
}