You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
157 lines
8.6 KiB
157 lines
8.6 KiB
#include "src/storage/expressions/LinearCoefficientVisitor.h"
|
|
|
|
#include "src/storage/expressions/Expressions.h"
|
|
#include "src/utility/macros.h"
|
|
#include "src/exceptions/InvalidArgumentException.h"
|
|
|
|
namespace storm {
|
|
namespace expressions {
|
|
LinearCoefficientVisitor::VariableCoefficients::VariableCoefficients(double constantPart) : variableToCoefficientMapping(), constantPart(constantPart) {
|
|
// Intentionally left empty.
|
|
}
|
|
|
|
// Adds the other linear term to this one: every coefficient of `other` is
// added to the coefficient of the same variable here (a variable that is not
// yet present starts from a value-initialized entry), and the constant parts
// are summed. Returns a reference to this object.
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator+=(VariableCoefficients&& other) {
    auto const& otherTerms = other.variableToCoefficientMapping;
    for (auto termIt = otherTerms.begin(); termIt != otherTerms.end(); ++termIt) {
        variableToCoefficientMapping[termIt->first] += termIt->second;
    }
    this->constantPart += other.constantPart;
    return *this;
}
|
|
|
|
// Subtracts the other linear term from this one: every coefficient of `other`
// is subtracted from the coefficient of the same variable here (a variable
// that is not yet present starts from a value-initialized entry), and the
// other constant part is subtracted from this one. Returns a reference to
// this object.
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator-=(VariableCoefficients&& other) {
    auto const& otherTerms = other.variableToCoefficientMapping;
    for (auto termIt = otherTerms.begin(); termIt != otherTerms.end(); ++termIt) {
        variableToCoefficientMapping[termIt->first] -= termIt->second;
    }
    this->constantPart -= other.constantPart;
    return *this;
}
|
|
|
|
// Multiplies this linear term by the other one and returns a reference to
// this object.
//
// The product is only linear if at least one of the two factors has no
// variable coefficients; otherwise STORM_LOG_THROW reports the expression as
// non-linear via storm::exceptions::InvalidArgumentException.
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator*=(VariableCoefficients&& other) {
    STORM_LOG_THROW(variableToCoefficientMapping.size() == 0 || other.variableToCoefficientMapping.size() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");

    // If the variables live in the other factor, take them over and swap the
    // constant parts, so that from here on `other.constantPart` always holds
    // the scalar factor and this object holds the term that is being scaled.
    if (other.variableToCoefficientMapping.size() > 0) {
        variableToCoefficientMapping = std::move(other.variableToCoefficientMapping);
        std::swap(constantPart, other.constantPart);
    }

    // Scale the coefficients stored in *this* object. Iterating over the map
    // of `other` would do nothing: at this point it is either empty or has
    // been moved from, and the coefficients would be left unscaled.
    for (auto& variableCoefficientPair : variableToCoefficientMapping) {
        variableCoefficientPair.second *= other.constantPart;
    }
    constantPart *= other.constantPart;
    return *this;
}
|
|
|
|
// Divides this linear term by the other one and returns a reference to this
// object.
//
// The divisor must not contain any variable coefficients; otherwise
// STORM_LOG_THROW reports the expression as non-linear via
// storm::exceptions::InvalidArgumentException. No check is made for a zero
// divisor; the division is plain floating-point division.
LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator/=(VariableCoefficients&& other) {
    STORM_LOG_THROW(other.variableToCoefficientMapping.size() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");

    // Divide the coefficients stored in *this* object. The map of `other` is
    // guaranteed to be empty by the check above, so iterating over it would
    // leave all coefficients undivided.
    for (auto& variableCoefficientPair : variableToCoefficientMapping) {
        variableCoefficientPair.second /= other.constantPart;
    }
    constantPart /= other.constantPart;
    return *this;
}
|
|
|
|
void LinearCoefficientVisitor::VariableCoefficients::negate() {
|
|
for (auto& variableCoefficientPair : variableToCoefficientMapping) {
|
|
variableCoefficientPair.second = -variableCoefficientPair.second;
|
|
}
|
|
constantPart = -constantPart;
|
|
}
|
|
|
|
// Stores the given coefficient for the given variable, creating the entry if
// the variable has none yet and overwriting any previous value otherwise.
void LinearCoefficientVisitor::VariableCoefficients::setCoefficient(storm::expressions::Variable const& variable, double coefficient) {
    double& storedCoefficient = this->variableToCoefficientMapping[variable];
    storedCoefficient = coefficient;
}
|
|
|
|
// Returns the coefficient stored for the given variable, or 0 if the variable
// has no entry.
//
// The lookup deliberately uses find() instead of operator[]: operator[] would
// insert a zero entry for an absent variable, and such an entry changes the
// size of the mapping, which operator*= and operator/= use to decide whether
// a term is purely constant.
double LinearCoefficientVisitor::VariableCoefficients::getCoefficient(storm::expressions::Variable const& variable) {
    auto const coefficientIt = variableToCoefficientMapping.find(variable);
    if (coefficientIt == variableToCoefficientMapping.end()) {
        return 0.0;
    }
    return coefficientIt->second;
}
|
|
|
|
double LinearCoefficientVisitor::VariableCoefficients::getConstantPart() const {
|
|
return this->constantPart;
|
|
}
|
|
|
|
// Moves all variable terms of `rhs` over to this object and this object's
// constant over to `rhs`:
//   - every coefficient of `rhs` is subtracted from the coefficient of the
//     same variable in this object,
//   - the variable mapping of `rhs` is cleared,
//   - this object's constant part is subtracted from the constant part of
//     `rhs`.
// Note that this object's own constant part is left unchanged.
void LinearCoefficientVisitor::VariableCoefficients::separateVariablesFromConstantPart(VariableCoefficients& rhs) {
    auto& rhsTerms = rhs.variableToCoefficientMapping;
    for (auto termIt = rhsTerms.begin(); termIt != rhsTerms.end(); ++termIt) {
        variableToCoefficientMapping[termIt->first] -= termIt->second;
    }
    rhsTerms.clear();
    rhs.constantPart -= constantPart;
}
|
|
|
|
// Returns a read-only iterator to the first (variable, coefficient) entry.
std::map<storm::expressions::Variable, double>::const_iterator LinearCoefficientVisitor::VariableCoefficients::begin() const {
    return variableToCoefficientMapping.cbegin();
}
|
|
|
|
// Returns the read-only past-the-end iterator of the (variable, coefficient)
// entries.
std::map<storm::expressions::Variable, double>::const_iterator LinearCoefficientVisitor::VariableCoefficients::end() const {
    return variableToCoefficientMapping.cend();
}
|
|
|
|
// Runs this visitor over the base expression of the given expression and
// returns the resulting coefficients. The visit functions hand their result
// back wrapped in a boost::any, which is unwrapped here.
LinearCoefficientVisitor::VariableCoefficients LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) {
    boost::any const visitResult = expression.getBaseExpression().accept(*this);
    return boost::any_cast<VariableCoefficients>(visitResult);
}
|
|
|
|
// If-then-else expressions are never accepted: the visitor unconditionally
// reports them as non-linear via storm::exceptions::InvalidArgumentException.
// The expression itself is not inspected.
boost::any LinearCoefficientVisitor::visit(IfThenElseExpression const& /*expression*/) {
    STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
|
|
|
|
// Binary boolean functions are never accepted: the visitor unconditionally
// reports them as non-linear via storm::exceptions::InvalidArgumentException.
// The expression itself is not inspected.
boost::any LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const& /*expression*/) {
    STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
|
|
|
|
// Computes the linear coefficients of a binary numerical expression.
//
// Both operands are visited first (left, then right) and their coefficients
// are combined according to the operator: plus, minus, times and divide are
// supported. Any other operator is reported as non-linear via
// storm::exceptions::InvalidArgumentException; times and divide may raise the
// same exception from the corresponding compound operators.
//
// Returns the combined coefficients wrapped in a boost::any.
boost::any LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const& expression) {
    VariableCoefficients leftResult = boost::any_cast<VariableCoefficients>(expression.getFirstOperand()->accept(*this));
    VariableCoefficients rightResult = boost::any_cast<VariableCoefficients>(expression.getSecondOperand()->accept(*this));

    auto const operatorType = expression.getOperatorType();
    if (operatorType == BinaryNumericalFunctionExpression::OperatorType::Plus) {
        leftResult += std::move(rightResult);
    } else if (operatorType == BinaryNumericalFunctionExpression::OperatorType::Minus) {
        leftResult -= std::move(rightResult);
    } else if (operatorType == BinaryNumericalFunctionExpression::OperatorType::Times) {
        leftResult *= std::move(rightResult);
    } else if (operatorType == BinaryNumericalFunctionExpression::OperatorType::Divide) {
        leftResult /= std::move(rightResult);
    } else {
        STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
    }

    // The combination was accumulated into leftResult. rightResult only holds
    // the (possibly moved-from) right operand and must not be returned.
    return leftResult;
}
|
|
|
|
// Relational expressions are never accepted: the visitor unconditionally
// reports them as non-linear via storm::exceptions::InvalidArgumentException.
// The expression itself is not inspected.
boost::any LinearCoefficientVisitor::visit(BinaryRelationExpression const& /*expression*/) {
    STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
|
|
|
|
// Computes the linear coefficients of a single variable occurrence.
//
// A variable of numerical type yields a term in which that variable has
// coefficient 1. A variable of any other type is reported as non-linear via
// storm::exceptions::InvalidArgumentException.
boost::any LinearCoefficientVisitor::visit(VariableExpression const& expression) {
    // Reject non-numerical variables up front.
    if (!expression.getType().isNumericalType()) {
        STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
    }

    VariableCoefficients singleVariableTerm;
    singleVariableTerm.setCoefficient(expression.getVariable(), 1);
    return singleVariableTerm;
}
|
|
|
|
// Unary boolean functions are never accepted: the visitor unconditionally
// reports them as non-linear via storm::exceptions::InvalidArgumentException.
// The expression itself is not inspected.
boost::any LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const& /*expression*/) {
    STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
|
|
|
|
// Computes the linear coefficients of a unary numerical expression.
//
// The operand is visited first. Only unary minus is supported, in which case
// the operand's coefficients are negated and returned; every other operator
// is reported as non-linear via storm::exceptions::InvalidArgumentException.
boost::any LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const& expression) {
    VariableCoefficients operandCoefficients = boost::any_cast<VariableCoefficients>(expression.getOperand()->accept(*this));

    // Anything but unary minus cannot be expressed as a linear term here.
    if (!(expression.getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus)) {
        STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
    }

    operandCoefficients.negate();
    return operandCoefficients;
}
|
|
|
|
// Boolean literals are never accepted: the visitor unconditionally reports
// them as non-linear via storm::exceptions::InvalidArgumentException. The
// expression itself is not inspected.
boost::any LinearCoefficientVisitor::visit(BooleanLiteralExpression const& /*expression*/) {
    STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear.");
}
|
|
|
|
// An integer literal becomes a purely constant term: its value, converted to
// double, is the constant part and there are no variable coefficients.
boost::any LinearCoefficientVisitor::visit(IntegerLiteralExpression const& expression) {
    double const literalValue = static_cast<double>(expression.getValue());
    return VariableCoefficients(literalValue);
}
|
|
|
|
// A double literal becomes a purely constant term: its value is the constant
// part and there are no variable coefficients.
boost::any LinearCoefficientVisitor::visit(DoubleLiteralExpression const& expression) {
    VariableCoefficients constantTerm(expression.getValue());
    return constantTerm;
}
|
|
}
|
|
}
|