#include <cmath> #include <boost/variant.hpp> #include "storm/adapters/RationalNumberAdapter.h" #include "storm/storage/expressions/UnaryNumericalFunctionExpression.h" #include "storm/storage/expressions/IntegerLiteralExpression.h" #include "storm/storage/expressions/RationalLiteralExpression.h" #include "ExpressionVisitor.h" #include "storm/utility/macros.h" #include "storm/utility/constants.h" #include "storm/exceptions/InvalidTypeException.h" #include "storm/exceptions/InvalidOperationException.h" namespace storm { namespace expressions { UnaryNumericalFunctionExpression::UnaryNumericalFunctionExpression(ExpressionManager const& manager, Type const& type, std::shared_ptr<BaseExpression const> const& operand, OperatorType operatorType) : UnaryExpression(manager, type, operand), operatorType(operatorType) { // Intentionally left empty. } UnaryNumericalFunctionExpression::OperatorType UnaryNumericalFunctionExpression::getOperatorType() const { return this->operatorType; } storm::expressions::OperatorType UnaryNumericalFunctionExpression::getOperator() const { storm::expressions::OperatorType result = storm::expressions::OperatorType::Minus; switch (this->getOperatorType()) { case OperatorType::Minus: result = storm::expressions::OperatorType::Minus; break; case OperatorType::Floor: result = storm::expressions::OperatorType::Floor; break; case OperatorType::Ceil: result = storm::expressions::OperatorType::Ceil; break; } return result; } int_fast64_t UnaryNumericalFunctionExpression::evaluateAsInt(Valuation const* valuation) const { STORM_LOG_THROW(this->hasIntegerType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as integer."); if (this->getOperatorType() == OperatorType::Minus) { STORM_LOG_THROW(this->getOperand()->hasIntegerType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as integer."); int_fast64_t result = this->getOperand()->evaluateAsInt(valuation); return -result; } else { // TODO: this should evaluate the operand as a rational. double result = this->getOperand()->evaluateAsDouble(valuation); switch (this->getOperatorType()) { case OperatorType::Floor: return static_cast<int_fast64_t>(std::floor(result)); break; case OperatorType::Ceil: return static_cast<int_fast64_t>(std::ceil(result)); break; default: STORM_LOG_ASSERT(false, "All other operator types should have been handled before."); return 0;// Warning suppression. } } } double UnaryNumericalFunctionExpression::evaluateAsDouble(Valuation const* valuation) const { STORM_LOG_THROW(this->hasNumericalType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as double."); double result = this->getOperand()->evaluateAsDouble(valuation); switch (this->getOperatorType()) { case OperatorType::Minus: result = -result; break; case OperatorType::Floor: result = std::floor(result); break; case OperatorType::Ceil: result = std::ceil(result); break; } return result; } std::shared_ptr<BaseExpression const> UnaryNumericalFunctionExpression::simplify() const { std::shared_ptr<BaseExpression const> operandSimplified = this->getOperand()->simplify(); if (operandSimplified->isLiteral()) { if (operandSimplified->hasIntegerType()) { int_fast64_t value = operandSimplified->evaluateAsInt(); switch (this->getOperatorType()) { case OperatorType::Minus: value = -value; break; // Nothing to be done for the other cases: case OperatorType::Floor: case OperatorType::Ceil: break; } return std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(this->getManager(), value)); } else { storm::RationalNumber value = operandSimplified->evaluateAsRational(); bool convertToInteger = false; switch (this->getOperatorType()) { case OperatorType::Minus: value = -value; break; case OperatorType::Floor: value = storm::utility::floor(value); convertToInteger = true; break; case OperatorType::Ceil: value = storm::utility::ceil(value); convertToInteger = true; break; } if (convertToInteger) { return std::shared_ptr<BaseExpression>(new IntegerLiteralExpression(this->getManager(), storm::utility::convertNumber<int64_t>(value))); } else { return std::shared_ptr<BaseExpression>(new RationalLiteralExpression(this->getManager(), value)); } } } if (operandSimplified.get() == this->getOperand().get()) { return this->shared_from_this(); } else { return std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(this->getManager(), this->getType(), operandSimplified, this->getOperatorType())); } } boost::any UnaryNumericalFunctionExpression::accept(ExpressionVisitor& visitor, boost::any const& data) const { return visitor.visit(*this, data); } bool UnaryNumericalFunctionExpression::isUnaryNumericalFunctionExpression() const { return true; } void UnaryNumericalFunctionExpression::printToStream(std::ostream& stream) const { switch (this->getOperatorType()) { case OperatorType::Minus: stream << "-("; break; case OperatorType::Floor: stream << "floor("; break; case OperatorType::Ceil: stream << "ceil("; break; } stream << *this->getOperand() << ")"; } } }