You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

179 lines
12 KiB

#include <algorithm>
#include <cmath>
#include "storm/adapters/RationalNumberAdapter.h"
#include "storm/storage/expressions/BinaryNumericalFunctionExpression.h"
#include "storm/storage/expressions/IntegerLiteralExpression.h"
#include "storm/storage/expressions/RationalLiteralExpression.h"
#include "storm/storage/expressions/ExpressionVisitor.h"
#include "storm/utility/macros.h"
#include "storm/utility/constants.h"
#include "storm/utility/NumberTraits.h"
#include "storm/exceptions/InvalidTypeException.h"
#include "storm/exceptions/InvalidStateException.h"
namespace storm {
namespace expressions {
BinaryNumericalFunctionExpression::BinaryNumericalFunctionExpression(ExpressionManager const& manager, Type const& type, std::shared_ptr<BaseExpression const> const& firstOperand, std::shared_ptr<BaseExpression const> const& secondOperand, OperatorType operatorType) : BinaryExpression(manager, type, firstOperand, secondOperand), operatorType(operatorType) {
// Intentionally left empty.
}
BinaryNumericalFunctionExpression::OperatorType BinaryNumericalFunctionExpression::getOperatorType() const {
return this->operatorType;
}
storm::expressions::OperatorType BinaryNumericalFunctionExpression::getOperator() const {
storm::expressions::OperatorType result = storm::expressions::OperatorType::Plus;
switch (this->getOperatorType()) {
case OperatorType::Plus: result = storm::expressions::OperatorType::Plus; break;
case OperatorType::Minus: result = storm::expressions::OperatorType::Minus; break;
case OperatorType::Times: result = storm::expressions::OperatorType::Times; break;
case OperatorType::Divide: result = storm::expressions::OperatorType::Divide; break;
case OperatorType::Min: result = storm::expressions::OperatorType::Min; break;
case OperatorType::Max: result = storm::expressions::OperatorType::Max; break;
case OperatorType::Power: result = storm::expressions::OperatorType::Power; break;
case OperatorType::Modulo: result = storm::expressions::OperatorType::Modulo; break;
}
return result;
}
int_fast64_t BinaryNumericalFunctionExpression::evaluateAsInt(Valuation const* valuation) const {
STORM_LOG_THROW(this->hasIntegerType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as integer.");
int_fast64_t firstOperandEvaluation = this->getFirstOperand()->evaluateAsInt(valuation);
int_fast64_t secondOperandEvaluation = this->getSecondOperand()->evaluateAsInt(valuation);
int_fast64_t result = 0;
switch (this->getOperatorType()) {
case OperatorType::Plus: result = firstOperandEvaluation + secondOperandEvaluation; break;
case OperatorType::Minus: result = firstOperandEvaluation - secondOperandEvaluation; break;
case OperatorType::Times: result = firstOperandEvaluation * secondOperandEvaluation; break;
case OperatorType::Divide: result = firstOperandEvaluation / secondOperandEvaluation; break;
case OperatorType::Min: result = std::min(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Max: result = std::max(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Power: result = static_cast<int_fast64_t>(std::pow(firstOperandEvaluation, secondOperandEvaluation)); break;
case OperatorType::Modulo: result = firstOperandEvaluation % secondOperandEvaluation; break;
}
return result;
}
double BinaryNumericalFunctionExpression::evaluateAsDouble(Valuation const* valuation) const {
STORM_LOG_THROW(this->hasNumericalType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as double.");
double firstOperandEvaluation = this->getFirstOperand()->evaluateAsDouble(valuation);
double secondOperandEvaluation = this->getSecondOperand()->evaluateAsDouble(valuation);
double result = 0;
switch (this->getOperatorType()) {
case OperatorType::Plus: result = firstOperandEvaluation + secondOperandEvaluation; break;
case OperatorType::Minus: result = firstOperandEvaluation - secondOperandEvaluation; break;
case OperatorType::Times: result = firstOperandEvaluation * secondOperandEvaluation; break;
case OperatorType::Divide: result = firstOperandEvaluation / secondOperandEvaluation; break;
case OperatorType::Min: result = std::min(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Max: result = std::max(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Power: result = std::pow(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Modulo: result = std::fmod(firstOperandEvaluation, secondOperandEvaluation); break;
}
return result;
}
std::shared_ptr<BaseExpression const> BinaryNumericalFunctionExpression::simplify() const {
std::shared_ptr<BaseExpression const> firstOperandSimplified = this->getFirstOperand()->simplify();
std::shared_ptr<BaseExpression const> secondOperandSimplified = this->getSecondOperand()->simplify();
if (firstOperandSimplified->isLiteral() && secondOperandSimplified->isLiteral()) {
if (this->hasIntegerType()) {
int_fast64_t firstOperandEvaluation = firstOperandSimplified->evaluateAsInt();
int_fast64_t secondOperandEvaluation = secondOperandSimplified->evaluateAsInt();
boost::optional<int_fast64_t> newValue;
switch (this->getOperatorType()) {
case OperatorType::Plus: newValue = firstOperandEvaluation + secondOperandEvaluation; break;
case OperatorType::Minus: newValue = firstOperandEvaluation - secondOperandEvaluation; break;
case OperatorType::Times: newValue = firstOperandEvaluation * secondOperandEvaluation; break;
case OperatorType::Min: newValue = std::min(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Max: newValue = std::max(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Power:
if (secondOperandEvaluation >= 0) {
// Only simplify if this evaluates to an integer.
// Otherwise, we note that the type of this expression could change due to simplifications (which might or might not be expected)
newValue = static_cast<int_fast64_t>(std::pow(firstOperandEvaluation, secondOperandEvaluation));
}
break;
case OperatorType::Modulo: newValue = firstOperandEvaluation % secondOperandEvaluation; break;
case OperatorType::Divide:
if (firstOperandEvaluation % secondOperandEvaluation == 0) {
// Only simplify if there is no remainder, because otherwise it is not clear whether we want integer division or floating point division.
// Otherwise, we note that the type of this expression could change due to simplifications (which might or might not be expected)
newValue = firstOperandEvaluation / secondOperandEvaluation;
}
break;
}
if (newValue) {
return std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(this->getManager(), newValue.get()));
}
} else if (this->hasRationalType()) {
storm::RationalNumber firstOperandEvaluation = firstOperandSimplified->evaluateAsRational();
storm::RationalNumber secondOperandEvaluation = secondOperandSimplified->evaluateAsRational();
boost::optional<storm::RationalNumber> newValue;
switch (this->getOperatorType()) {
case OperatorType::Plus: newValue = firstOperandEvaluation + secondOperandEvaluation; break;
case OperatorType::Minus: newValue = firstOperandEvaluation - secondOperandEvaluation; break;
case OperatorType::Times: newValue = firstOperandEvaluation * secondOperandEvaluation; break;
case OperatorType::Min: newValue = std::min(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Max: newValue = std::max(firstOperandEvaluation, secondOperandEvaluation); break;
case OperatorType::Divide: newValue = firstOperandEvaluation / secondOperandEvaluation; break;
case OperatorType::Power: {
if (carl::isInteger(secondOperandEvaluation)) {
auto exponent = carl::toInt<carl::sint>(secondOperandEvaluation);
if (exponent >= 0) {
newValue = carl::pow(firstOperandEvaluation, exponent);
} else {
storm::RationalNumber power = carl::pow(firstOperandEvaluation, -exponent);
newValue = storm::utility::one<storm::RationalNumber>() / power;
}
}
break;
}
case OperatorType::Modulo: {
if (carl::isInteger(firstOperandEvaluation) && carl::isInteger(secondOperandEvaluation)) {
newValue = storm::utility::mod(storm::utility::numerator(firstOperandEvaluation), storm::utility::numerator(secondOperandEvaluation));
}
break;
}
}
if (newValue) {
return std::shared_ptr<BaseExpression>(new RationalLiteralExpression(this->getManager(), newValue.get()));
}
}
}
if (firstOperandSimplified.get() == this->getFirstOperand().get() && secondOperandSimplified.get() == this->getSecondOperand().get()) {
return this->shared_from_this();
} else {
return std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(this->getManager(), this->getType(), firstOperandSimplified, secondOperandSimplified, this->getOperatorType()));
}
}
boost::any BinaryNumericalFunctionExpression::accept(ExpressionVisitor& visitor, boost::any const& data) const {
return visitor.visit(*this, data);
}
bool BinaryNumericalFunctionExpression::isBinaryNumericalFunctionExpression() const {
return true;
}
void BinaryNumericalFunctionExpression::printToStream(std::ostream& stream) const {
stream << "(";
switch (this->getOperatorType()) {
case OperatorType::Plus: stream << *this->getFirstOperand() << " + " << *this->getSecondOperand(); break;
case OperatorType::Minus: stream << *this->getFirstOperand() << " - " << *this->getSecondOperand(); break;
case OperatorType::Times: stream << *this->getFirstOperand() << " * " << *this->getSecondOperand(); break;
case OperatorType::Divide: stream << *this->getFirstOperand() << " / " << *this->getSecondOperand(); break;
case OperatorType::Min: stream << "min(" << *this->getFirstOperand() << ", " << *this->getSecondOperand() << ")"; break;
case OperatorType::Max: stream << "max(" << *this->getFirstOperand() << ", " << *this->getSecondOperand() << ")"; break;
case OperatorType::Power: stream << *this->getFirstOperand() << " ^ " << *this->getSecondOperand(); break;
case OperatorType::Modulo: stream << *this->getFirstOperand() << " % " << *this->getSecondOperand(); break;
}
stream << ")";
}
}
}