diff --git a/src/exceptions/ExceptionMacros.h b/src/exceptions/ExceptionMacros.h index dabfe25cb..09e7058ec 100644 --- a/src/exceptions/ExceptionMacros.h +++ b/src/exceptions/ExceptionMacros.h @@ -16,16 +16,16 @@ extern log4cplus::Logger logger; assert(cond); \ } \ } while (false) -#define LOG_THROW(cond, exception, message) \ -{ \ - if (!(cond)) { \ - LOG4CPLUS_ERROR(logger, message); \ - throw exception() << message; \ - } \ -} while (false) #else #define LOG_ASSERT(cond, message) /* empty */ -#define LOG_THROW(cond, exception, message) /* empty */ #endif +#define LOG_THROW(cond, exception, message) \ +{ \ +if (!(cond)) { \ +LOG4CPLUS_ERROR(logger, message); \ +throw exception() << message; \ +} \ +} while (false) + #endif /* STORM_EXCEPTIONS_EXCEPTIONMACROS_H_ */ \ No newline at end of file diff --git a/src/exceptions/InvalidOperationException.h b/src/exceptions/InvalidOperationException.h new file mode 100644 index 000000000..728f7ff1c --- /dev/null +++ b/src/exceptions/InvalidOperationException.h @@ -0,0 +1,17 @@ +#ifndef STORM_EXCEPTIONS_INVALIDOPERATIONEXCEPTION_H_ +#define STORM_EXCEPTIONS_INVALIDOPERATIONEXCEPTION_H_ + +#include "src/exceptions/BaseException.h" + +namespace storm { + + namespace exceptions { + + /*! + * @brief This exception is thrown when an operation is invalid in this context + */ + STORM_EXCEPTION_DEFINE_NEW(InvalidOperationException) + + } // namespace exceptions +} // namespace storm +#endif // STORM_EXCEPTIONS_INVALIDOPERATIONEXCEPTION_H_ diff --git a/src/exceptions/InvalidTypeException.h b/src/exceptions/InvalidTypeException.h new file mode 100644 index 000000000..de2874218 --- /dev/null +++ b/src/exceptions/InvalidTypeException.h @@ -0,0 +1,18 @@ +#ifndef STORM_EXCEPTIONS_INVALIDTYPEEXCEPTION_H_ +#define STORM_EXCEPTIONS_INVALIDTYPEEXCEPTION_H_ + +#include "src/exceptions/BaseException.h" + +namespace storm { + + namespace exceptions { + + /*! + * @brief This exception is thrown when a type is invalid in this context + */ + STORM_EXCEPTION_DEFINE_NEW(InvalidTypeException) + + } // namespace exceptions + +} // namespace storm +#endif // STORM_EXCEPTIONS_INVALIDTYPEEXCEPTION_H_ diff --git a/src/storage/expressions/BaseExpression.cpp b/src/storage/expressions/BaseExpression.cpp index ca2db09f1..200b8ac0f 100644 --- a/src/storage/expressions/BaseExpression.cpp +++ b/src/storage/expressions/BaseExpression.cpp @@ -1,5 +1,6 @@ #include "src/storage/expressions/BaseExpression.h" #include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" namespace storm { namespace expressions { @@ -11,6 +12,10 @@ namespace storm { return this->returnType; } + bool BaseExpression::hasIntegralReturnType() const { + return this->getReturnType() == ExpressionReturnType::Int; + } + bool BaseExpression::hasNumericalReturnType() const { return this->getReturnType() == ExpressionReturnType::Double || this->getReturnType() == ExpressionReturnType::Int; } @@ -20,15 +25,15 @@ namespace storm { } int_fast64_t BaseExpression::evaluateAsInt(Valuation const& evaluation) const { - LOG_ASSERT(false, "Unable to evaluate expression as integer."); + LOG_THROW(false, storm::exceptions::InvalidTypeException, "Unable to evaluate expression as integer."); } bool BaseExpression::evaluateAsBool(Valuation const& evaluation) const { - LOG_ASSERT(false, "Unable to evaluate expression as boolean."); + LOG_THROW(false, storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean."); } double BaseExpression::evaluateAsDouble(Valuation const& evaluation) const { - LOG_ASSERT(false, "Unable to evaluate expression as double."); + LOG_THROW(false, storm::exceptions::InvalidTypeException, "Unable to evaluate expression as double."); } bool BaseExpression::isConstant() const { @@ -42,5 +47,14 @@ namespace storm { bool BaseExpression::isFalse() const { return false; } + + std::shared_ptr BaseExpression::getSharedPointer() const { + return this->shared_from_this(); + } + + std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) { + expression.printToStream(stream); + return stream; + } } } \ No newline at end of file diff --git a/src/storage/expressions/BaseExpression.h b/src/storage/expressions/BaseExpression.h index 0698278da..7ca9e63ad 100644 --- a/src/storage/expressions/BaseExpression.h +++ b/src/storage/expressions/BaseExpression.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "src/storage/expressions/Valuation.h" #include "src/storage/expressions/ExpressionVisitor.h" @@ -19,7 +20,7 @@ namespace storm { /*! * The base class of all expression classes. */ - class BaseExpression { + class BaseExpression : public std::enable_shared_from_this { public: /*! * Constructs a base expression with the given return type. @@ -107,7 +108,7 @@ namespace storm { * * @return A pointer to the simplified expression. */ - virtual std::unique_ptr simplify() const = 0; + virtual std::shared_ptr simplify() const = 0; /*! * Accepts the given visitor by calling its visit method. @@ -117,18 +118,18 @@ namespace storm { virtual void accept(ExpressionVisitor* visitor) const = 0; /*! - * Performs a deep-copy of the expression. + * Retrieves whether the expression has a numerical return type, i.e., integer or double. * - * @return A pointer to a deep-copy of the expression. + * @return True iff the expression has a numerical return type. */ - virtual std::unique_ptr clone() const = 0; + bool hasNumericalReturnType() const; /*! - * Retrieves whether the expression has a numerical return type, i.e., integer or double. + * Retrieves whether the expression has an integral return type, i.e., integer. * - * @return True iff the expression has a numerical return type. + * @return True iff the expression has an integral return type. */ - bool hasNumericalReturnType() const; + bool hasIntegralReturnType() const; /*! * Retrieves whether the expression has a boolean return type. @@ -137,6 +138,13 @@ namespace storm { */ bool hasBooleanReturnType() const; + /*! + * Retrieves a shared pointer to this expression. + * + * @return A shared pointer to this expression. + */ + std::shared_ptr getSharedPointer() const; + /*! * Retrieves the return type of the expression. * @@ -144,6 +152,15 @@ namespace storm { */ ExpressionReturnType getReturnType() const; + friend std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression); + protected: + /*! + * Prints the expression to the given stream. + * + * @param stream The stream to which to write the expression. + */ + virtual void printToStream(std::ostream& stream) const = 0; + private: // The return type of this expression. ExpressionReturnType returnType; diff --git a/src/storage/expressions/BinaryBooleanFunctionExpression.cpp b/src/storage/expressions/BinaryBooleanFunctionExpression.cpp index e4034794e..8f5abf57c 100644 --- a/src/storage/expressions/BinaryBooleanFunctionExpression.cpp +++ b/src/storage/expressions/BinaryBooleanFunctionExpression.cpp @@ -1,8 +1,11 @@ #include "src/storage/expressions/BinaryBooleanFunctionExpression.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" + namespace storm { namespace expressions { - BinaryBooleanFunctionExpression::BinaryBooleanFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand, OperatorType operatorType) : BinaryExpression(returnType, std::move(firstOperand), std::move(secondOperand)), operatorType(operatorType) { + BinaryBooleanFunctionExpression::BinaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType) : BinaryExpression(returnType, firstOperand, secondOperand), operatorType(operatorType) { // Intentionally left empty. } @@ -11,6 +14,8 @@ namespace storm { } bool BinaryBooleanFunctionExpression::evaluateAsBool(Valuation const& valuation) const { + LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean."); + bool firstOperandEvaluation = this->getFirstOperand()->evaluateAsBool(valuation); bool secondOperandEvaluation = this->getSecondOperand()->evaluateAsBool(valuation); @@ -23,9 +28,9 @@ namespace storm { return result; } - std::unique_ptr BinaryBooleanFunctionExpression::simplify() const { - std::unique_ptr firstOperandSimplified = this->getFirstOperand()->simplify(); - std::unique_ptr secondOperandSimplified = this->getSecondOperand()->simplify(); + std::shared_ptr BinaryBooleanFunctionExpression::simplify() const { + std::shared_ptr firstOperandSimplified = this->getFirstOperand()->simplify(); + std::shared_ptr secondOperandSimplified = this->getSecondOperand()->simplify(); switch (this->getOperatorType()) { case OperatorType::And: if (firstOperandSimplified->isTrue()) { @@ -49,15 +54,25 @@ namespace storm { } } - return std::unique_ptr(new BinaryBooleanFunctionExpression(this->getReturnType(), std::move(firstOperandSimplified), std::move(secondOperandSimplified), this->getOperatorType())); + // If the two successors remain unchanged, we can return a shared_ptr to this very object. + if (firstOperandSimplified.get() == this->getFirstOperand().get() && secondOperandSimplified.get() == this->getSecondOperand().get()) { + return this->shared_from_this(); + } else { + return std::shared_ptr(new BinaryBooleanFunctionExpression(this->getReturnType(), firstOperandSimplified, secondOperandSimplified, this->getOperatorType())); + } } void BinaryBooleanFunctionExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr BinaryBooleanFunctionExpression::clone() const { - return std::unique_ptr(new BinaryBooleanFunctionExpression(*this)); + void BinaryBooleanFunctionExpression::printToStream(std::ostream& stream) const { + stream << "(" << *this->getFirstOperand(); + switch (this->getOperatorType()) { + case OperatorType::And: stream << " && "; break; + case OperatorType::Or: stream << " || "; break; + } + stream << *this->getSecondOperand() << ")"; } } } \ No newline at end of file diff --git a/src/storage/expressions/BinaryBooleanFunctionExpression.h b/src/storage/expressions/BinaryBooleanFunctionExpression.h index eaabd1d56..41ae92b7b 100644 --- a/src/storage/expressions/BinaryBooleanFunctionExpression.h +++ b/src/storage/expressions/BinaryBooleanFunctionExpression.h @@ -20,7 +20,7 @@ namespace storm { * @param secondOperand The second operand of the expression. * @param functionType The operator of the expression. */ - BinaryBooleanFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& fistOperand, std::unique_ptr&& secondOperand, OperatorType operatorType); + BinaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. BinaryBooleanFunctionExpression(BinaryBooleanFunctionExpression const& other) = default; @@ -31,10 +31,9 @@ namespace storm { // Override base class methods. virtual bool evaluateAsBool(Valuation const& valuation) const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; - + /*! * Retrieves the operator associated with the expression. * @@ -42,6 +41,10 @@ namespace storm { */ OperatorType getOperatorType() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The operator of the expression. OperatorType operatorType; diff --git a/src/storage/expressions/BinaryExpression.cpp b/src/storage/expressions/BinaryExpression.cpp index 9038aca41..cfd18447c 100644 --- a/src/storage/expressions/BinaryExpression.cpp +++ b/src/storage/expressions/BinaryExpression.cpp @@ -2,22 +2,10 @@ namespace storm { namespace expressions { - BinaryExpression::BinaryExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand) : BaseExpression(returnType), firstOperand(std::move(firstOperand)), secondOperand(std::move(secondOperand)) { + BinaryExpression::BinaryExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand) : BaseExpression(returnType), firstOperand(firstOperand), secondOperand(secondOperand) { // Intentionally left empty. } - - BinaryExpression::BinaryExpression(BinaryExpression const& other) : BaseExpression(other.getReturnType()), firstOperand(other.getFirstOperand()->clone()), secondOperand(other.getSecondOperand()->clone()) { - // Intentionally left empty. - } - - BinaryExpression& BinaryExpression::operator=(BinaryExpression const& other) { - if (this != &other) { - this->firstOperand = other.getFirstOperand()->clone(); - this->secondOperand = other.getSecondOperand()->clone(); - } - return *this; - } - + bool BinaryExpression::isConstant() const { return this->getFirstOperand()->isConstant() && this->getSecondOperand()->isConstant(); } @@ -36,11 +24,11 @@ namespace storm { return firstConstantSet; } - std::unique_ptr const& BinaryExpression::getFirstOperand() const { + std::shared_ptr const& BinaryExpression::getFirstOperand() const { return this->firstOperand; } - std::unique_ptr const& BinaryExpression::getSecondOperand() const { + std::shared_ptr const& BinaryExpression::getSecondOperand() const { return this->secondOperand; } } diff --git a/src/storage/expressions/BinaryExpression.h b/src/storage/expressions/BinaryExpression.h index a94513cb5..fd1b1903a 100644 --- a/src/storage/expressions/BinaryExpression.h +++ b/src/storage/expressions/BinaryExpression.h @@ -17,13 +17,11 @@ namespace storm { * @param firstOperand The first operand of the expression. * @param secondOperand The second operand of the expression. */ - BinaryExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand); + BinaryExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand); - // Provide custom versions of copy construction and assignment. - BinaryExpression(BinaryExpression const& other); - BinaryExpression& operator=(BinaryExpression const& other); - - // Create default variants of move construction/assignment and virtual destructor. + // Instantiate constructors and assignments with their default implementations. + BinaryExpression(BinaryExpression const& other) = default; + BinaryExpression& operator=(BinaryExpression const& other) = default; BinaryExpression(BinaryExpression&&) = default; BinaryExpression& operator=(BinaryExpression&&) = default; virtual ~BinaryExpression() = default; @@ -38,21 +36,21 @@ namespace storm { * * @return The first operand of the expression. */ - std::unique_ptr const& getFirstOperand() const; + std::shared_ptr const& getFirstOperand() const; /*! * Retrieves the second operand of the expression. * * @return The second operand of the expression. */ - std::unique_ptr const& getSecondOperand() const; + std::shared_ptr const& getSecondOperand() const; private: // The first operand of the expression. - std::unique_ptr firstOperand; + std::shared_ptr firstOperand; // The second operand of the expression. - std::unique_ptr secondOperand; + std::shared_ptr secondOperand; }; } } diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp index 7139db695..521e554bb 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp @@ -2,15 +2,21 @@ #include "src/storage/expressions/BinaryNumericalFunctionExpression.h" #include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" namespace storm { namespace expressions { - BinaryNumericalFunctionExpression::BinaryNumericalFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand, OperatorType operatorType) : BinaryExpression(returnType, std::move(firstOperand), std::move(secondOperand)), operatorType(operatorType) { + BinaryNumericalFunctionExpression::BinaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType) : BinaryExpression(returnType, firstOperand, secondOperand), operatorType(operatorType) { // Intentionally left empty. } + BinaryNumericalFunctionExpression::OperatorType BinaryNumericalFunctionExpression::getOperatorType() const { + return this->operatorType; + } + int_fast64_t BinaryNumericalFunctionExpression::evaluateAsInt(Valuation const& valuation) const { - LOG_ASSERT(this->getReturnType() == ExpressionReturnType::Int, "Unable to evaluate expression as integer."); + LOG_THROW(this->hasIntegralReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as integer."); + int_fast64_t firstOperandEvaluation = this->getFirstOperand()->evaluateAsInt(valuation); int_fast64_t secondOperandEvaluation = this->getSecondOperand()->evaluateAsInt(valuation); switch (this->getOperatorType()) { @@ -24,7 +30,8 @@ namespace storm { } double BinaryNumericalFunctionExpression::evaluateAsDouble(Valuation const& valuation) const { - LOG_ASSERT(this->getReturnType() == ExpressionReturnType::Double, "Unable to evaluate expression as double."); + LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as double."); + double firstOperandEvaluation = this->getFirstOperand()->evaluateAsDouble(valuation); double secondOperandEvaluation = this->getSecondOperand()->evaluateAsDouble(valuation); switch (this->getOperatorType()) { @@ -37,16 +44,32 @@ namespace storm { } } - std::unique_ptr BinaryNumericalFunctionExpression::simplify() const { - return std::unique_ptr(new BinaryNumericalFunctionExpression(this->getReturnType(), this->getFirstOperand()->simplify(), this->getSecondOperand()->simplify(), this->getOperatorType())); + std::shared_ptr BinaryNumericalFunctionExpression::simplify() const { + std::shared_ptr firstOperandSimplified = this->getFirstOperand()->simplify(); + std::shared_ptr secondOperandSimplified = this->getSecondOperand()->simplify(); + + if (firstOperandSimplified.get() == this->getFirstOperand().get() && secondOperandSimplified.get() == this->getSecondOperand().get()) { + return this->shared_from_this(); + } else { + return std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType(), firstOperandSimplified, secondOperandSimplified, this->getOperatorType())); + } } void BinaryNumericalFunctionExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr BinaryNumericalFunctionExpression::clone() const { - return std::unique_ptr(new BinaryNumericalFunctionExpression(*this)); + void BinaryNumericalFunctionExpression::printToStream(std::ostream& stream) const { + stream << "("; + switch (this->getOperatorType()) { + case OperatorType::Plus: stream << *this->getFirstOperand() << " + " << *this->getSecondOperand(); break; + case OperatorType::Minus: stream << *this->getFirstOperand() << " - " << *this->getSecondOperand(); break; + case OperatorType::Times: stream << *this->getFirstOperand() << " * " << *this->getSecondOperand(); break; + case OperatorType::Divide: stream << *this->getFirstOperand() << " / " << *this->getSecondOperand(); break; + case OperatorType::Min: stream << "min(" << *this->getFirstOperand() << ", " << *this->getSecondOperand() << ")"; break; + case OperatorType::Max: stream << "max(" << *this->getFirstOperand() << ", " << *this->getSecondOperand() << ")"; break; + } + stream << ")"; } } } \ No newline at end of file diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.h b/src/storage/expressions/BinaryNumericalFunctionExpression.h index 92d4bc0d2..8d49e541a 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.h +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.h @@ -20,7 +20,7 @@ namespace storm { * @param secondOperand The second operand of the expression. * @param functionType The operator of the expression. */ - BinaryNumericalFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand, OperatorType operatorType); + BinaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. BinaryNumericalFunctionExpression(BinaryNumericalFunctionExpression const& other) = default; @@ -32,10 +32,9 @@ namespace storm { // Override base class methods. virtual int_fast64_t evaluateAsInt(Valuation const& valuation) const override; virtual double evaluateAsDouble(Valuation const& valuation) const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; - + /*! * Retrieves the operator associated with the expression. * @@ -43,6 +42,10 @@ namespace storm { */ OperatorType getOperatorType() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The operator of the expression. OperatorType operatorType; diff --git a/src/storage/expressions/BinaryRelationExpression.cpp b/src/storage/expressions/BinaryRelationExpression.cpp index fb4123fca..e23df77c4 100644 --- a/src/storage/expressions/BinaryRelationExpression.cpp +++ b/src/storage/expressions/BinaryRelationExpression.cpp @@ -1,12 +1,17 @@ #include "src/storage/expressions/BinaryRelationExpression.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" + namespace storm { namespace expressions { - BinaryRelationExpression::BinaryRelationExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand, RelationType relationType) : BinaryExpression(returnType, std::move(firstOperand), std::move(secondOperand)), relationType(relationType) { + BinaryRelationExpression::BinaryRelationExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, RelationType relationType) : BinaryExpression(returnType, firstOperand, secondOperand), relationType(relationType) { // Intentionally left empty. } bool BinaryRelationExpression::evaluateAsBool(Valuation const& valuation) const { + LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean."); + double firstOperandEvaluated = this->getFirstOperand()->evaluateAsDouble(valuation); double secondOperandEvaluated = this->getSecondOperand()->evaluateAsDouble(valuation); switch (this->getRelationType()) { @@ -19,20 +24,36 @@ namespace storm { } } - std::unique_ptr BinaryRelationExpression::simplify() const { - return std::unique_ptr(new BinaryRelationExpression(this->getReturnType(), this->getFirstOperand()->simplify(), this->getSecondOperand()->simplify(), this->getRelationType())); + std::shared_ptr BinaryRelationExpression::simplify() const { + std::shared_ptr firstOperandSimplified = this->getFirstOperand()->simplify(); + std::shared_ptr secondOperandSimplified = this->getSecondOperand()->simplify(); + + if (firstOperandSimplified.get() == this->getFirstOperand().get() && secondOperandSimplified.get() == this->getSecondOperand().get()) { + return this->shared_from_this(); + } else { + return std::shared_ptr(new BinaryRelationExpression(this->getReturnType(), firstOperandSimplified, secondOperandSimplified, this->getRelationType())); + } } void BinaryRelationExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr BinaryRelationExpression::clone() const { - return std::unique_ptr(new BinaryRelationExpression(*this)); - } - BinaryRelationExpression::RelationType BinaryRelationExpression::getRelationType() const { return this->relationType; } + + void BinaryRelationExpression::printToStream(std::ostream& stream) const { + stream << "(" << *this->getFirstOperand(); + switch (this->getRelationType()) { + case RelationType::Equal: stream << " == "; break; + case RelationType::NotEqual: stream << " != "; break; + case RelationType::Greater: stream << " > "; break; + case RelationType::GreaterOrEqual: stream << " >= "; break; + case RelationType::Less: stream << " < "; break; + case RelationType::LessOrEqual: stream << " <= "; break; + } + stream << *this->getSecondOperand() << ")"; + } } } \ No newline at end of file diff --git a/src/storage/expressions/BinaryRelationExpression.h b/src/storage/expressions/BinaryRelationExpression.h index 31145ec30..70201313c 100644 --- a/src/storage/expressions/BinaryRelationExpression.h +++ b/src/storage/expressions/BinaryRelationExpression.h @@ -20,7 +20,7 @@ namespace storm { * @param secondOperand The second operand of the expression. * @param relationType The operator of the expression. */ - BinaryRelationExpression(ExpressionReturnType returnType, std::unique_ptr&& firstOperand, std::unique_ptr&& secondOperand, RelationType relationType); + BinaryRelationExpression(ExpressionReturnType returnType, std::shared_ptr const& firstOperand, std::shared_ptr const& secondOperand, RelationType relationType); // Instantiate constructors and assignments with their default implementations. BinaryRelationExpression(BinaryRelationExpression const& other) = default; @@ -31,10 +31,9 @@ namespace storm { // Override base class methods. virtual bool evaluateAsBool(Valuation const& valuation) const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; - + /*! * Retrieves the relation associated with the expression. * @@ -42,6 +41,10 @@ namespace storm { */ RelationType getRelationType() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The relation type of the expression. RelationType relationType; diff --git a/src/storage/expressions/BooleanConstantExpression.cpp b/src/storage/expressions/BooleanConstantExpression.cpp index c20d8a8b7..b3b453097 100644 --- a/src/storage/expressions/BooleanConstantExpression.cpp +++ b/src/storage/expressions/BooleanConstantExpression.cpp @@ -10,12 +10,12 @@ namespace storm { return valuation.getBooleanValue(this->getConstantName()); } - std::unique_ptr BooleanConstantExpression::clone() const { - return std::unique_ptr(new BooleanConstantExpression(*this)); - } - void BooleanConstantExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } + + std::shared_ptr BooleanConstantExpression::simplify() const { + return this->shared_from_this(); + } } } \ No newline at end of file diff --git a/src/storage/expressions/BooleanConstantExpression.h b/src/storage/expressions/BooleanConstantExpression.h index 7b72fdc4e..d239f146d 100644 --- a/src/storage/expressions/BooleanConstantExpression.h +++ b/src/storage/expressions/BooleanConstantExpression.h @@ -22,9 +22,9 @@ namespace storm { virtual ~BooleanConstantExpression() = default; // Override base class methods. - virtual bool evaluateAsBool(Valuation const& valuation) const; - virtual std::unique_ptr clone() const; - virtual void accept(ExpressionVisitor* visitor) const; + virtual bool evaluateAsBool(Valuation const& valuation) const override; + virtual void accept(ExpressionVisitor* visitor) const override; + virtual std::shared_ptr simplify() const override; }; } } diff --git a/src/storage/expressions/BooleanLiteralExpression.cpp b/src/storage/expressions/BooleanLiteralExpression.cpp index 371e9f7e9..7fb01cec4 100644 --- a/src/storage/expressions/BooleanLiteralExpression.cpp +++ b/src/storage/expressions/BooleanLiteralExpression.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - BooleanLiteralExpression::BooleanLiteralExpression(bool value) : value(value) { + BooleanLiteralExpression::BooleanLiteralExpression(bool value) : BaseExpression(ExpressionReturnType::Bool), value(value) { // Intentionally left empty. } @@ -23,27 +23,27 @@ namespace storm { } std::set BooleanLiteralExpression::getVariables() const { - return {}; + return std::set(); } std::set BooleanLiteralExpression::getConstants() const { - return {}; + return std::set(); } - std::unique_ptr BooleanLiteralExpression::simplify() const { - return this->clone(); + std::shared_ptr BooleanLiteralExpression::simplify() const { + return this->shared_from_this(); } void BooleanLiteralExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr BooleanLiteralExpression::clone() const { - return std::unique_ptr(new BooleanLiteralExpression(*this)); - } - bool BooleanLiteralExpression::getValue() const { return this->value; } + + void BooleanLiteralExpression::printToStream(std::ostream& stream) const { + stream << (this->getValue() ? "true" : "false"); + } } } \ No newline at end of file diff --git a/src/storage/expressions/BooleanLiteralExpression.h b/src/storage/expressions/BooleanLiteralExpression.h index 837190c95..1c7baed09 100644 --- a/src/storage/expressions/BooleanLiteralExpression.h +++ b/src/storage/expressions/BooleanLiteralExpression.h @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - class BooleanLiteralExpression : BaseExpression { + class BooleanLiteralExpression : public BaseExpression { public: /*! * Creates a boolean literal expression with the given value. @@ -28,9 +28,8 @@ namespace storm { virtual bool isFalse() const override; virtual std::set getVariables() const override; virtual std::set getConstants() const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; /*! * Retrieves the value of the boolean literal. @@ -39,6 +38,10 @@ namespace storm { */ bool getValue() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The value of the boolean literal. bool value; diff --git a/src/storage/expressions/ConstantExpression.cpp b/src/storage/expressions/ConstantExpression.cpp index 08bd26a95..d09a75573 100644 --- a/src/storage/expressions/ConstantExpression.cpp +++ b/src/storage/expressions/ConstantExpression.cpp @@ -14,12 +14,12 @@ namespace storm { return {this->getConstantName()}; } - std::unique_ptr ConstantExpression::simplify() const { - return this->clone(); - } - std::string const& ConstantExpression::getConstantName() const { return this->constantName; } + + void ConstantExpression::printToStream(std::ostream& stream) const { + stream << this->getConstantName(); + } } } \ No newline at end of file diff --git a/src/storage/expressions/ConstantExpression.h b/src/storage/expressions/ConstantExpression.h index b8b985ac2..70ccf6b35 100644 --- a/src/storage/expressions/ConstantExpression.h +++ b/src/storage/expressions/ConstantExpression.h @@ -25,7 +25,6 @@ namespace storm { // Override base class methods. virtual std::set getVariables() const override; virtual std::set getConstants() const override; - virtual std::unique_ptr simplify() const override; /*! * Retrieves the name of the constant. @@ -33,6 +32,10 @@ namespace storm { * @return The name of the constant. */ std::string const& getConstantName() const; + + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; private: // The name of the constant. diff --git a/src/storage/expressions/DoubleConstantExpression.cpp b/src/storage/expressions/DoubleConstantExpression.cpp index c0c7925dc..d0f5a2799 100644 --- a/src/storage/expressions/DoubleConstantExpression.cpp +++ b/src/storage/expressions/DoubleConstantExpression.cpp @@ -10,8 +10,8 @@ namespace storm { return valuation.getDoubleValue(this->getConstantName()); } - std::unique_ptr DoubleConstantExpression::clone() const { - return std::unique_ptr(new DoubleConstantExpression(*this)); + std::shared_ptr DoubleConstantExpression::simplify() const { + return this->shared_from_this(); } void DoubleConstantExpression::accept(ExpressionVisitor* visitor) const { diff --git a/src/storage/expressions/DoubleConstantExpression.h b/src/storage/expressions/DoubleConstantExpression.h index 28455190f..349de5216 100644 --- a/src/storage/expressions/DoubleConstantExpression.h +++ b/src/storage/expressions/DoubleConstantExpression.h @@ -22,9 +22,9 @@ namespace storm { virtual ~DoubleConstantExpression() = default; // Override base class methods. - virtual double evaluateAsDouble(Valuation const& valuation) const; - virtual std::unique_ptr clone() const; - virtual void accept(ExpressionVisitor* visitor) const; + virtual double evaluateAsDouble(Valuation const& valuation) const override; + virtual void accept(ExpressionVisitor* visitor) const override; + virtual std::shared_ptr simplify() const override; }; } } diff --git a/src/storage/expressions/DoubleLiteralExpression.cpp b/src/storage/expressions/DoubleLiteralExpression.cpp index a96bdcdd9..176823803 100644 --- a/src/storage/expressions/DoubleLiteralExpression.cpp +++ b/src/storage/expressions/DoubleLiteralExpression.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - DoubleLiteralExpression::DoubleLiteralExpression(double value) : value(value) { + DoubleLiteralExpression::DoubleLiteralExpression(double value) : BaseExpression(ExpressionReturnType::Double), value(value) { // Intentionally left empty. } @@ -15,27 +15,27 @@ namespace storm { } std::set DoubleLiteralExpression::getVariables() const { - return {}; + return std::set(); } std::set DoubleLiteralExpression::getConstants() const { - return {}; + return std::set(); } - std::unique_ptr DoubleLiteralExpression::simplify() const { - return this->clone(); + std::shared_ptr DoubleLiteralExpression::simplify() const { + return this->shared_from_this(); } void DoubleLiteralExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr DoubleLiteralExpression::clone() const { - return std::unique_ptr(new DoubleLiteralExpression(*this)); - } - double DoubleLiteralExpression::getValue() const { return this->value; } + + void DoubleLiteralExpression::printToStream(std::ostream& stream) const { + stream << this->getValue(); + } } } \ No newline at end of file diff --git a/src/storage/expressions/DoubleLiteralExpression.h b/src/storage/expressions/DoubleLiteralExpression.h index b15e5e013..5d97e6b59 100644 --- a/src/storage/expressions/DoubleLiteralExpression.h +++ b/src/storage/expressions/DoubleLiteralExpression.h @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - class DoubleLiteralExpression : BaseExpression { + class DoubleLiteralExpression : public BaseExpression { public: /*! * Creates an double literal expression with the given value. @@ -26,10 +26,9 @@ namespace storm { virtual bool isConstant() const override; virtual std::set getVariables() const override; virtual std::set getConstants() const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; - + /*! * Retrieves the value of the double literal. * @@ -37,6 +36,10 @@ namespace storm { */ double getValue() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The value of the double literal. double value; diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index c73108336..bd2e3cbc9 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -3,32 +3,229 @@ #include "src/storage/expressions/Expression.h" #include "src/storage/expressions/SubstitutionVisitor.h" +#include "src/exceptions/InvalidTypeException.h" +#include "src/exceptions/ExceptionMacros.h" + +#include "src/storage/expressions/BinaryBooleanFunctionExpression.h" +#include "src/storage/expressions/BinaryNumericalFunctionExpression.h" +#include "src/storage/expressions/BinaryRelationExpression.h" +#include "src/storage/expressions/BooleanConstantExpression.h" +#include "src/storage/expressions/IntegerConstantExpression.h" +#include "src/storage/expressions/DoubleConstantExpression.h" +#include "src/storage/expressions/BooleanLiteralExpression.h" +#include "src/storage/expressions/IntegerLiteralExpression.h" +#include "src/storage/expressions/DoubleLiteralExpression.h" +#include "src/storage/expressions/VariableExpression.h" +#include "src/storage/expressions/UnaryBooleanFunctionExpression.h" +#include "src/storage/expressions/UnaryNumericalFunctionExpression.h" namespace storm { namespace expressions { - Expression::Expression(std::unique_ptr&& expressionPtr) : expressionPtr(std::move(expressionPtr)) { + Expression::Expression(std::shared_ptr const& expressionPtr) : expressionPtr(expressionPtr) { // Intentionally left empty. } template class MapType> Expression Expression::substitute(MapType const& identifierToExpressionMap) const { - SubstitutionVisitor visitor; - return visitor.substitute(this->getBaseExpressionPointer(), identifierToExpressionMap); + return SubstitutionVisitor(identifierToExpressionMap).substitute(this->getBaseExpressionPointer().get()); } - - Expression Expression::operator+(Expression const& other) { - return Expression(this->getBaseExpression() + other.getBaseExpression()); + + bool Expression::evaluateAsBool(Valuation const& valuation) const { + return this->getBaseExpression().evaluateAsBool(valuation); + } + + int_fast64_t Expression::evaluateAsInt(Valuation const& valuation) const { + return this->getBaseExpression().evaluateAsInt(valuation); + } + + double Expression::evaluateAsDouble(Valuation const& valuation) const { + return this->getBaseExpression().evaluateAsDouble(valuation); + } + + Expression Expression::simplify() { + return Expression(this->getBaseExpression().simplify()); + } + + bool Expression::isConstant() const { + return this->getBaseExpression().isConstant(); + } + + bool Expression::isTrue() const { + return this->getBaseExpression().isTrue(); + } + + bool Expression::isFalse() const { + return this->getBaseExpression().isFalse(); + } + + std::set Expression::getVariables() const { + return this->getBaseExpression().getVariables(); + } + + std::set Expression::getConstants() const { + return this->getBaseExpression().getConstants(); } BaseExpression const& Expression::getBaseExpression() const { return *this->expressionPtr; } - BaseExpression const* Expression::getBaseExpressionPointer() const { - return this->expressionPtr.get(); + std::shared_ptr const& Expression::getBaseExpressionPointer() const { + return this->expressionPtr; + } + + ExpressionReturnType Expression::getReturnType() const { + return this->getBaseExpression().getReturnType(); + } + + bool Expression::hasNumericalReturnType() const { + return this->getReturnType() == ExpressionReturnType::Int || this->getReturnType() == ExpressionReturnType::Double; + } + + bool Expression::hasBooleanReturnType() const { + return this->getReturnType() == ExpressionReturnType::Bool; + } + + Expression Expression::createBooleanLiteral(bool value) { + return Expression(std::shared_ptr(new BooleanLiteralExpression(value))); + } + + Expression Expression::createTrue() { + return createBooleanLiteral(true); + } + + Expression Expression::createFalse() { + return createBooleanLiteral(false); + } + + Expression Expression::createIntegerLiteral(int_fast64_t value) { + return Expression(std::shared_ptr(new IntegerLiteralExpression(value))); + } + + Expression Expression::createDoubleLiteral(double value) { + return Expression(std::shared_ptr(new DoubleLiteralExpression(value))); + } + + Expression Expression::createBooleanVariable(std::string const& variableName) { + return Expression(std::shared_ptr(new VariableExpression(ExpressionReturnType::Bool, variableName))); + } + + Expression Expression::createIntegerVariable(std::string const& variableName) { + return Expression(std::shared_ptr(new VariableExpression(ExpressionReturnType::Int, variableName))); + } + + Expression Expression::createDoubleVariable(std::string const& variableName) { + return Expression(std::shared_ptr(new VariableExpression(ExpressionReturnType::Double, variableName))); + } + + Expression Expression::createBooleanConstant(std::string const& constantName) { + return Expression(std::shared_ptr(new BooleanConstantExpression(constantName))); + } + + Expression Expression::createIntegerConstant(std::string const& constantName) { + return Expression(std::shared_ptr(new IntegerConstantExpression(constantName))); + } + + Expression Expression::createDoubleConstant(std::string const& constantName) { + return Expression(std::shared_ptr(new DoubleConstantExpression(constantName))); + } + + Expression Expression::operator+(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '+' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Plus))); + } + + Expression Expression::operator-(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus))); + } + + Expression Expression::operator-() const { + LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operand."); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getReturnType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus))); + } + + Expression Expression::operator*(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '*' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times))); + } + + Expression Expression::operator/(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '/' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide))); + } + + Expression Expression::operator&&(Expression const& other) const { + LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And))); + } + + Expression Expression::operator||(Expression const& other) const { + LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '||' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or))); + } + + Expression Expression::operator!() const { + LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '!' requires boolean operand."); + return Expression(std::shared_ptr(new UnaryBooleanFunctionExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not))); + } + + Expression Expression::operator==(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '==' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal))); + } + + Expression Expression::operator!=(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '!=' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); + } + + Expression Expression::operator>(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater))); + } + + Expression Expression::operator>=(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>=' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual))); + } + + Expression Expression::operator<(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less))); + } + + Expression Expression::operator<=(Expression const& other) const { + LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<=' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryRelationExpression(ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual))); + } + + Expression Expression::minimum(Expression const& lhs, Expression const& rhs) { + LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'min' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min))); + } + + Expression Expression::maximum(Expression const& lhs, Expression const& rhs) { + LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'max' requires numerical operands."); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max))); + } + + Expression Expression::floor() const { + LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'floor' requires numerical operand."); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor))); + } + + Expression Expression::ceil() const { + LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'ceil' requires numerical operand."); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil))); } template Expression Expression::substitute(std::map const&) const; template Expression Expression::substitute(std::unordered_map const&) const; + + std::ostream& operator<<(std::ostream& stream, Expression const& expression) { + stream << expression.getBaseExpression(); + return stream; + } } } diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index c866d64cb..52e656a15 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -11,29 +11,53 @@ namespace storm { public: Expression() = default; + /*! + * Creates an expression with the given underlying base expression. + * + * @param expressionPtr A pointer to the underlying base expression. + */ + Expression(std::shared_ptr const& expressionPtr); + + // Instantiate constructors and assignments with their default implementations. + Expression(Expression const& other) = default; + Expression& operator=(Expression const& other) = default; + Expression(Expression&&) = default; + Expression& operator=(Expression&&) = default; + // Static factory methods to create atomic expression parts. + static Expression createBooleanLiteral(bool value); + static Expression createTrue(); + static Expression createFalse(); + static Expression createIntegerLiteral(int_fast64_t value); + static Expression createDoubleLiteral(double value); + static Expression createBooleanVariable(std::string const& variableName); + static Expression createIntegerVariable(std::string const& variableName); + static Expression createDoubleVariable(std::string const& variableName); + static Expression createBooleanConstant(std::string const& constantName); + static Expression createIntegerConstant(std::string const& constantName); + static Expression createDoubleConstant(std::string const& constantName); - // Virtual operator overloading. + // Provide operator overloads to conveniently construct new expressions from other expressions. Expression operator+(Expression const& other) const; Expression operator-(Expression const& other) const; Expression operator-() const; Expression operator*(Expression const& other) const; Expression operator/(Expression const& other) const; - Expression operator&(Expression const& other) const; - Expression operator|(Expression const& other) const; - Expression operator~() const; - - Expression equals(Expression const& other) const; - Expression notEquals(Expression const& other) const; - Expression greater(Expression const& other) const; - Expression greaterOrEqual(Expression const& other) const; - Expression less(Expression const& other) const; - Expression lessOrEqual(Expression const& other) const; - Expression minimum(Expression const& other) const; - Expression maximum(Expression const& other) const; - Expression mod(Expression const& other) const; + Expression operator&&(Expression const& other) const; + Expression operator||(Expression const& other) const; + Expression operator!() const; + Expression operator==(Expression const& other) const; + Expression operator!=(Expression const& other) const; + Expression operator>(Expression const& other) const; + Expression operator>=(Expression const& other) const; + Expression operator<(Expression const& other) const; + Expression operator<=(Expression const& other) const; + Expression floor() const; Expression ceil() const; + + static Expression minimum(Expression const& lhs, Expression const& rhs); + static Expression maximum(Expression const& lhs, Expression const& rhs); /*! * Substitutes all occurrences of identifiers according to the given map. Note that this substitution is @@ -48,19 +72,76 @@ namespace storm { Expression substitute(MapType const& identifierToExpressionMap) const; /*! - * Retrieves the return type of the expression. + * Evaluates the expression under the valuation of unknowns (variables and constants) given by the + * valuation and returns the resulting boolean value. If the return type of the expression is not a boolean + * an exception is thrown. * - * @return The return type of the expression. + * @param valuation The valuation of unknowns under which to evaluate the expression. + * @return The boolean value of the expression under the given valuation. */ - ExpressionReturnType getReturnType() const; + bool evaluateAsBool(Valuation const& valuation) const; - private: /*! - * Creates an expression with the given underlying base expression. + * Evaluates the expression under the valuation of unknowns (variables and constants) given by the + * valuation and returns the resulting integer value. If the return type of the expression is not an integer + * an exception is thrown. * - * @param expressionPtr A pointer to the underlying base expression. + * @param valuation The valuation of unknowns under which to evaluate the expression. + * @return The integer value of the expression under the given valuation. + */ + int_fast64_t evaluateAsInt(Valuation const& valuation) const; + + /*! + * Evaluates the expression under the valuation of unknowns (variables and constants) given by the + * valuation and returns the resulting double value. If the return type of the expression is not a double + * an exception is thrown. + * + * @param valuation The valuation of unknowns under which to evaluate the expression. + * @return The double value of the expression under the given valuation. + */ + double evaluateAsDouble(Valuation const& valuation) const; + + /*! + * Simplifies the expression according to some basic rules. + * + * @return The simplified expression. + */ + Expression simplify(); + + /*! + * Retrieves whether the expression is constant, i.e., contains no variables or undefined constants. + * + * @return True iff the expression is constant. + */ + bool isConstant() const; + + /*! + * Checks if the expression is equal to the boolean literal true. + * + * @return True iff the expression is equal to the boolean literal true. + */ + bool isTrue() const; + + /*! + * Checks if the expression is equal to the boolean literal false. + * + * @return True iff the expression is equal to the boolean literal false. */ - Expression(std::unique_ptr&& expressionPtr); + bool isFalse() const; + + /*! + * Retrieves the set of all variables that appear in the expression. + * + * @return The set of all variables that appear in the expression. + */ + std::set getVariables() const; + + /*! + * Retrieves the set of all constants that appear in the expression. + * + * @return The set of all constants that appear in the expression. + */ + std::set getConstants() const; /*! * Retrieves the base expression underlying this expression object. Note that prior to calling this, the @@ -75,10 +156,34 @@ namespace storm { * * @return A pointer to the underlying base expression. */ - BaseExpression const* getBaseExpressionPointer() const; + std::shared_ptr const& getBaseExpressionPointer() const; + + /*! + * Retrieves the return type of the expression. + * + * @return The return type of the expression. + */ + ExpressionReturnType getReturnType() const; + /*! + * Retrieves whether the expression has a numerical return type, i.e., integer or double. + * + * @return True iff the expression has a numerical return type. + */ + bool hasNumericalReturnType() const; + + /*! + * Retrieves whether the expression has a boolean return type. + * + * @return True iff the expression has a boolean return type. + */ + bool hasBooleanReturnType() const; + + friend std::ostream& operator<<(std::ostream& stream, Expression const& expression); + + private: // A pointer to the underlying base expression. - std::unique_ptr expressionPtr; + std::shared_ptr expressionPtr; }; } } diff --git a/src/storage/expressions/ExpressionVisitor.h b/src/storage/expressions/ExpressionVisitor.h index 4bd2e8d29..8e6dca24e 100644 --- a/src/storage/expressions/ExpressionVisitor.h +++ b/src/storage/expressions/ExpressionVisitor.h @@ -19,6 +19,7 @@ namespace storm { class DoubleLiteralExpression; class ExpressionVisitor { + public: virtual void visit(BinaryBooleanFunctionExpression const* expression) = 0; virtual void visit(BinaryNumericalFunctionExpression const* expression) = 0; virtual void visit(BinaryRelationExpression const* expression) = 0; diff --git a/src/storage/expressions/IntegerConstantExpression.cpp b/src/storage/expressions/IntegerConstantExpression.cpp index 87e109423..92bedebc9 100644 --- a/src/storage/expressions/IntegerConstantExpression.cpp +++ b/src/storage/expressions/IntegerConstantExpression.cpp @@ -6,7 +6,7 @@ namespace storm { // Intentionally left empty. } - int_fast64_t IntegerConstantExpression::evaluateAsInteger(Valuation const& valuation) const { + int_fast64_t IntegerConstantExpression::evaluateAsInt(Valuation const& valuation) const { return valuation.getIntegerValue(this->getConstantName()); } @@ -14,8 +14,8 @@ namespace storm { return static_cast(valuation.getIntegerValue(this->getConstantName())); } - std::unique_ptr IntegerConstantExpression::clone() const { - return std::unique_ptr(new IntegerConstantExpression(*this)); + std::shared_ptr IntegerConstantExpression::simplify() const { + return this->shared_from_this(); } void IntegerConstantExpression::accept(ExpressionVisitor* visitor) const { diff --git a/src/storage/expressions/IntegerConstantExpression.h b/src/storage/expressions/IntegerConstantExpression.h index 402679163..05deaf49c 100644 --- a/src/storage/expressions/IntegerConstantExpression.h +++ b/src/storage/expressions/IntegerConstantExpression.h @@ -22,9 +22,9 @@ namespace storm { virtual ~IntegerConstantExpression() = default; // Override base class methods. - virtual int_fast64_t evaluateAsInteger(Valuation const& valuation) const; - virtual double evaluateAsDouble(Valuation const& valuation) const; - virtual std::unique_ptr clone() const; + virtual int_fast64_t evaluateAsInt(Valuation const& valuation) const override; + virtual double evaluateAsDouble(Valuation const& valuation) const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const; }; } diff --git a/src/storage/expressions/IntegerLiteralExpression.cpp b/src/storage/expressions/IntegerLiteralExpression.cpp index ad74e58bd..16adeadc6 100644 --- a/src/storage/expressions/IntegerLiteralExpression.cpp +++ b/src/storage/expressions/IntegerLiteralExpression.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - IntegerLiteralExpression::IntegerLiteralExpression(int_fast64_t value) : value(value) { + IntegerLiteralExpression::IntegerLiteralExpression(int_fast64_t value) : BaseExpression(ExpressionReturnType::Int), value(value) { // Intentionally left empty. } @@ -19,27 +19,27 @@ namespace storm { } std::set IntegerLiteralExpression::getVariables() const { - return {}; + return std::set(); } std::set IntegerLiteralExpression::getConstants() const { - return {}; + return std::set(); } - std::unique_ptr IntegerLiteralExpression::simplify() const { - return this->clone(); + std::shared_ptr IntegerLiteralExpression::simplify() const { + return this->shared_from_this(); } void IntegerLiteralExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr IntegerLiteralExpression::clone() const { - return std::unique_ptr(new IntegerLiteralExpression(*this)); - } - int_fast64_t IntegerLiteralExpression::getValue() const { return this->value; } + + void IntegerLiteralExpression::printToStream(std::ostream& stream) const { + stream << this->getValue(); + } } } \ No newline at end of file diff --git a/src/storage/expressions/IntegerLiteralExpression.h b/src/storage/expressions/IntegerLiteralExpression.h index 1fc1b03f7..1b71ea306 100644 --- a/src/storage/expressions/IntegerLiteralExpression.h +++ b/src/storage/expressions/IntegerLiteralExpression.h @@ -5,7 +5,7 @@ namespace storm { namespace expressions { - class IntegerLiteralExpression : BaseExpression { + class IntegerLiteralExpression : public BaseExpression { public: /*! * Creates an integer literal expression with the given value. @@ -27,9 +27,8 @@ namespace storm { virtual bool isConstant() const override; virtual std::set getVariables() const override; virtual std::set getConstants() const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; /*! * Retrieves the value of the integer literal. @@ -38,6 +37,10 @@ namespace storm { */ int_fast64_t getValue() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The value of the integer literal. int_fast64_t value; diff --git a/src/storage/expressions/SimpleValuation.cpp b/src/storage/expressions/SimpleValuation.cpp index ea829557e..e2e230b64 100644 --- a/src/storage/expressions/SimpleValuation.cpp +++ b/src/storage/expressions/SimpleValuation.cpp @@ -2,7 +2,7 @@ namespace storm { namespace expressions { - SimpleValuation::SimpleValuation(std::size_t booleanVariableCount, std::size_t integerVariableCount, std::size_t doubleVariableCount) : identifierToIndexMap(), booleanValues(booleanVariableCount), integerValues(integerVariableCount), doubleValues(doubleVariableCount) { + SimpleValuation::SimpleValuation(std::size_t booleanVariableCount, std::size_t integerVariableCount, std::size_t doubleVariableCount) : identifierToIndexMap(new std::unordered_map), booleanValues(booleanVariableCount), integerValues(integerVariableCount), doubleValues(doubleVariableCount) { // Intentionally left empty. } @@ -15,15 +15,15 @@ namespace storm { } void SimpleValuation::setBooleanValue(std::string const& name, bool value) { - this->booleanValues[(*this->identifierToIndexMap)[name]] = value; + this->booleanValues[this->identifierToIndexMap->at(name)] = value; } void SimpleValuation::setIntegerValue(std::string const& name, int_fast64_t value) { - this->integerValues[(*this->identifierToIndexMap)[name]] = value; + this->integerValues[this->identifierToIndexMap->at(name)] = value; } void SimpleValuation::setDoubleValue(std::string const& name, double value) { - this->doubleValues[(*this->identifierToIndexMap)[name]] = value; + this->doubleValues[this->identifierToIndexMap->at(name)] = value; } bool SimpleValuation::getBooleanValue(std::string const& name) const { @@ -40,5 +40,23 @@ namespace storm { auto const& nameIndexPair = this->identifierToIndexMap->find(name); return this->doubleValues[nameIndexPair->second]; } + + std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation) { + stream << "valuation { bool["; + for (uint_fast64_t i = 0; i < valuation.booleanValues.size() - 1; ++i) { + stream << valuation.booleanValues[i] << ", "; + } + stream << valuation.booleanValues.back() << "] ints["; + for (uint_fast64_t i = 0; i < valuation.integerValues.size() - 1; ++i) { + stream << valuation.integerValues[i] << ", "; + } + stream << valuation.integerValues.back() << "] double["; + for (uint_fast64_t i = 0; i < valuation.doubleValues.size() - 1; ++i) { + stream << valuation.doubleValues[i] << ", "; + } + stream << valuation.doubleValues.back() << "] }"; + + return stream; + } } } \ No newline at end of file diff --git a/src/storage/expressions/SimpleValuation.h b/src/storage/expressions/SimpleValuation.h index f4808face..8f0e9d161 100644 --- a/src/storage/expressions/SimpleValuation.h +++ b/src/storage/expressions/SimpleValuation.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "src/storage/expressions/Valuation.h" @@ -76,6 +77,8 @@ namespace storm { virtual int_fast64_t getIntegerValue(std::string const& name) const override; virtual double getDoubleValue(std::string const& name) const override; + friend std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation); + private: // A mapping of identifiers to their local indices in the value containers. std::shared_ptr> identifierToIndexMap; diff --git a/src/storage/expressions/SubstitutionVisitor.cpp b/src/storage/expressions/SubstitutionVisitor.cpp index 8c478d682..878d06036 100644 --- a/src/storage/expressions/SubstitutionVisitor.cpp +++ b/src/storage/expressions/SubstitutionVisitor.cpp @@ -3,15 +3,175 @@ #include "src/storage/expressions/SubstitutionVisitor.h" +#include "src/storage/expressions/BinaryBooleanFunctionExpression.h" +#include "src/storage/expressions/BinaryNumericalFunctionExpression.h" +#include "src/storage/expressions/BinaryRelationExpression.h" +#include "src/storage/expressions/BooleanConstantExpression.h" +#include "src/storage/expressions/IntegerConstantExpression.h" +#include "src/storage/expressions/DoubleConstantExpression.h" +#include "src/storage/expressions/BooleanLiteralExpression.h" +#include "src/storage/expressions/IntegerLiteralExpression.h" +#include "src/storage/expressions/DoubleLiteralExpression.h" +#include "src/storage/expressions/VariableExpression.h" +#include "src/storage/expressions/UnaryBooleanFunctionExpression.h" +#include "src/storage/expressions/UnaryNumericalFunctionExpression.h" + namespace storm { namespace expressions { template class MapType> - Expression SubstitutionVisitor::substitute(BaseExpression const* expression, MapType const& identifierToExpressionMap) { - return Expression(); + SubstitutionVisitor::SubstitutionVisitor(MapType const& identifierToExpressionMap) : identifierToExpressionMap(identifierToExpressionMap) { + // Intentionally left empty. + } + + template class MapType> + Expression SubstitutionVisitor::substitute(BaseExpression const* expression) { + expression->accept(this); + return Expression(this->expressionStack.top()); + } + + template class MapType> + void SubstitutionVisitor::visit(BinaryBooleanFunctionExpression const* expression) { + expression->getFirstOperand()->accept(this); + std::shared_ptr firstExpression = expressionStack.top(); + expressionStack.pop(); + + expression->getSecondOperand()->accept(this); + std::shared_ptr secondExpression = expressionStack.top(); + expressionStack.pop(); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { + this->expressionStack.push(expression->getSharedPointer()); + } else { + this->expressionStack.push(std::shared_ptr(new BinaryBooleanFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType()))); + } + } + + template class MapType> + void SubstitutionVisitor::visit(BinaryNumericalFunctionExpression const* expression) { + expression->getFirstOperand()->accept(this); + std::shared_ptr firstExpression = expressionStack.top(); + expressionStack.pop(); + + expression->getSecondOperand()->accept(this); + std::shared_ptr secondExpression = expressionStack.top(); + expressionStack.pop(); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { + this->expressionStack.push(expression->getSharedPointer()); + } else { + this->expressionStack.push(std::shared_ptr(new BinaryNumericalFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType()))); + } + } + + template class MapType> + void SubstitutionVisitor::visit(BinaryRelationExpression const* expression) { + expression->getFirstOperand()->accept(this); + std::shared_ptr firstExpression = expressionStack.top(); + expressionStack.pop(); + + expression->getSecondOperand()->accept(this); + std::shared_ptr secondExpression = expressionStack.top(); + expressionStack.pop(); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { + this->expressionStack.push(expression->getSharedPointer()); + } else { + this->expressionStack.push(std::shared_ptr(new BinaryRelationExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getRelationType()))); + } + } + + template class MapType> + void SubstitutionVisitor::visit(BooleanConstantExpression const* expression) { + // If the boolean constant is in the key set of the substitution, we need to replace it. + auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression->getConstantName()); + if (nameExpressionPair != this->identifierToExpressionMap.end()) { + this->expressionStack.push(nameExpressionPair->second.getBaseExpressionPointer()); + } else { + this->expressionStack.push(expression->getSharedPointer()); + } + } + + template class MapType> + void SubstitutionVisitor::visit(DoubleConstantExpression const* expression) { + // If the double constant is in the key set of the substitution, we need to replace it. + auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression->getConstantName()); + if (nameExpressionPair != this->identifierToExpressionMap.end()) { + this->expressionStack.push(nameExpressionPair->second.getBaseExpressionPointer()); + } else { + this->expressionStack.push(expression->getSharedPointer()); + } + } + + template class MapType> + void SubstitutionVisitor::visit(IntegerConstantExpression const* expression) { + // If the integer constant is in the key set of the substitution, we need to replace it. + auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression->getConstantName()); + if (nameExpressionPair != this->identifierToExpressionMap.end()) { + this->expressionStack.push(nameExpressionPair->second.getBaseExpressionPointer()); + } else { + this->expressionStack.push(expression->getSharedPointer()); + } + } + + template class MapType> + void SubstitutionVisitor::visit(VariableExpression const* expression) { + // If the variable is in the key set of the substitution, we need to replace it. + auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression->getVariableName()); + if (nameExpressionPair != this->identifierToExpressionMap.end()) { + this->expressionStack.push(nameExpressionPair->second.getBaseExpressionPointer()); + } else { + this->expressionStack.push(expression->getSharedPointer()); + } + } + + template class MapType> + void SubstitutionVisitor::visit(UnaryBooleanFunctionExpression const* expression) { + expression->getOperand()->accept(this); + std::shared_ptr operandExpression = expressionStack.top(); + expressionStack.pop(); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression->getOperand().get()) { + expressionStack.push(expression->getSharedPointer()); + } else { + expressionStack.push(std::shared_ptr(new UnaryBooleanFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType()))); + } + } + + template class MapType> + void SubstitutionVisitor::visit(UnaryNumericalFunctionExpression const* expression) { + expression->getOperand()->accept(this); + std::shared_ptr operandExpression = expressionStack.top(); + expressionStack.pop(); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression->getOperand().get()) { + expressionStack.push(expression->getSharedPointer()); + } else { + expressionStack.push(std::shared_ptr(new UnaryNumericalFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType()))); + } + } + + template class MapType> + void SubstitutionVisitor::visit(BooleanLiteralExpression const* expression) { + this->expressionStack.push(expression->getSharedPointer()); + } + + template class MapType> + void SubstitutionVisitor::visit(IntegerLiteralExpression const* expression) { + this->expressionStack.push(expression->getSharedPointer()); + } + + template class MapType> + void SubstitutionVisitor::visit(DoubleLiteralExpression const* expression) { + this->expressionStack.push(expression->getSharedPointer()); } - // Explicitly instantiate substitute with map and unordered_map. - template Expression SubstitutionVisitor::substitute(BaseExpression const* expression, std::map const& identifierToExpressionMap); - template Expression SubstitutionVisitor::substitute(BaseExpression const* expression, std::unordered_map const& identifierToExpressionMap); + // Explicitly instantiate the class with map and unordered_map. + template class SubstitutionVisitor; + template class SubstitutionVisitor; } } diff --git a/src/storage/expressions/SubstitutionVisitor.h b/src/storage/expressions/SubstitutionVisitor.h index db2947e2e..4c3abc1e7 100644 --- a/src/storage/expressions/SubstitutionVisitor.h +++ b/src/storage/expressions/SubstitutionVisitor.h @@ -1,15 +1,52 @@ #ifndef STORM_STORAGE_EXPRESSIONS_SUBSTITUTIONVISITOR_H_ #define STORM_STORAGE_EXPRESSIONS_SUBSTITUTIONVISITOR_H_ +#include + #include "src/storage/expressions/Expression.h" #include "src/storage/expressions/ExpressionVisitor.h" namespace storm { namespace expressions { + template class MapType> class SubstitutionVisitor : public ExpressionVisitor { public: - template class MapType> - Expression substitute(BaseExpression const* expression, MapType const& identifierToExpressionMap); + /*! + * Creates a new substitution visitor that uses the given map to replace identifiers. + * + * @param identifierToExpressionMap A mapping from identifiers to expressions. + */ + SubstitutionVisitor(MapType const& identifierToExpressionMap); + + /*! + * Substitutes the identifiers in the given expression according to the previously given map and returns the + * resulting expression. + * + * @param expression The expression in which to substitute the identifiers. + * @return The expression in which all identifiers in the key set of the previously given mapping are + * substituted with the mapped-to expressions. + */ + Expression substitute(BaseExpression const* expression); + + virtual void visit(BinaryBooleanFunctionExpression const* expression) override; + virtual void visit(BinaryNumericalFunctionExpression const* expression) override; + virtual void visit(BinaryRelationExpression const* expression) override; + virtual void visit(BooleanConstantExpression const* expression) override; + virtual void visit(DoubleConstantExpression const* expression) override; + virtual void visit(IntegerConstantExpression const* expression) override; + virtual void visit(VariableExpression const* expression) override; + virtual void visit(UnaryBooleanFunctionExpression const* expression) override; + virtual void visit(UnaryNumericalFunctionExpression const* expression) override; + virtual void visit(BooleanLiteralExpression const* expression) override; + virtual void visit(IntegerLiteralExpression const* expression) override; + virtual void visit(DoubleLiteralExpression const* expression) override; + + private: + // A stack of expression used to pass the results to the higher levels. + std::stack> expressionStack; + + // A mapping of identifier names to expressions with which they shall be replaced. + MapType const& identifierToExpressionMap; }; } } diff --git a/src/storage/expressions/UnaryBooleanFunctionExpression.cpp b/src/storage/expressions/UnaryBooleanFunctionExpression.cpp index d25460c3c..0b12b5bf6 100644 --- a/src/storage/expressions/UnaryBooleanFunctionExpression.cpp +++ b/src/storage/expressions/UnaryBooleanFunctionExpression.cpp @@ -1,42 +1,50 @@ #include "src/storage/expressions/UnaryBooleanFunctionExpression.h" #include "src/storage/expressions/BooleanLiteralExpression.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" namespace storm { namespace expressions { - UnaryBooleanFunctionExpression::UnaryBooleanFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& operand, OperatorType operatorType) : UnaryExpression(returnType, std::move(operand)), operatorType(operatorType) { + UnaryBooleanFunctionExpression::UnaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType) : UnaryExpression(returnType, operand), operatorType(operatorType) { // Intentionally left empty. } - OperatorType UnaryBooleanFunctionExpression::getOperatorType() const { + UnaryBooleanFunctionExpression::OperatorType UnaryBooleanFunctionExpression::getOperatorType() const { return this->operatorType; } bool UnaryBooleanFunctionExpression::evaluateAsBool(Valuation const& valuation) const { + LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean."); + bool operandEvaluated = this->getOperand()->evaluateAsBool(valuation); switch (this->getOperatorType()) { case OperatorType::Not: return !operandEvaluated; break; } } - std::unique_ptr UnaryBooleanFunctionExpression::simplify() const { - std::unique_ptr operandSimplified = this->getOperand()->simplify(); + std::shared_ptr UnaryBooleanFunctionExpression::simplify() const { + std::shared_ptr operandSimplified = this->getOperand()->simplify(); switch (this->getOperatorType()) { case OperatorType::Not: if (operandSimplified->isTrue()) { - return std::unique_ptr(new BooleanLiteralExpression(false)); + return std::shared_ptr(new BooleanLiteralExpression(false)); } else { - return std::unique_ptr(new BooleanLiteralExpression(true)); + return std::shared_ptr(new BooleanLiteralExpression(true)); } } - return UnaryBooleanFunctionExpression(this->getReturnType(), std::move(operandSimplified), this->getOperatorType()); + if (operandSimplified.get() == this->getOperand().get()) { + return this->shared_from_this(); + } else { + return std::shared_ptr(new UnaryBooleanFunctionExpression(this->getReturnType(), operandSimplified, this->getOperatorType())); + } } void UnaryBooleanFunctionExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr UnaryBooleanFunctionExpression::clone() const { - return std::unique_ptr(new UnaryBooleanFunctionExpression(*this)); + void UnaryBooleanFunctionExpression::printToStream(std::ostream& stream) const { + stream << "!(" << *this->getOperand() << ")"; } } } \ No newline at end of file diff --git a/src/storage/expressions/UnaryBooleanFunctionExpression.h b/src/storage/expressions/UnaryBooleanFunctionExpression.h index 8453f99e0..fefbded86 100644 --- a/src/storage/expressions/UnaryBooleanFunctionExpression.h +++ b/src/storage/expressions/UnaryBooleanFunctionExpression.h @@ -6,10 +6,11 @@ namespace storm { namespace expressions { class UnaryBooleanFunctionExpression : public UnaryExpression { + public: /*! * An enum type specifying the different functions applicable. */ - enum class OperatorType {Not}; + enum class OperatorType { Not }; /*! * Creates a unary boolean function expression with the given return type, operand and operator. @@ -18,7 +19,7 @@ namespace storm { * @param operand The operand of the expression. * @param operatorType The operator of the expression. */ - UnaryBooleanFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& operand, OperatorType operatorType); + UnaryBooleanFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. UnaryBooleanFunctionExpression(UnaryBooleanFunctionExpression const& other) = default; @@ -29,9 +30,8 @@ namespace storm { // Override base class methods. virtual bool evaluateAsBool(Valuation const& valuation) const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; /*! * Retrieves the operator associated with this expression. @@ -40,6 +40,10 @@ namespace storm { */ OperatorType getOperatorType() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The operator of this expression. OperatorType operatorType; diff --git a/src/storage/expressions/UnaryExpression.cpp b/src/storage/expressions/UnaryExpression.cpp index e0573d813..1e2d9135e 100644 --- a/src/storage/expressions/UnaryExpression.cpp +++ b/src/storage/expressions/UnaryExpression.cpp @@ -2,21 +2,9 @@ namespace storm { namespace expressions { - UnaryExpression::UnaryExpression(ExpressionReturnType returnType, std::unique_ptr&& operand) : BaseExpression(returnType), operand(std::move(operand)) { + UnaryExpression::UnaryExpression(ExpressionReturnType returnType, std::shared_ptr const& operand) : BaseExpression(returnType), operand(operand) { // Intentionally left empty. } - - UnaryExpression::UnaryExpression(UnaryExpression const& other) : BaseExpression(other), operand(other.getOperand()->clone()) { - // Intentionally left empty. - } - - UnaryExpression& UnaryExpression::operator=(UnaryExpression const& other) { - if (this != &other) { - BaseExpression::operator=(other); - this->operand = other.getOperand()->clone(); - } - return *this; - } bool UnaryExpression::isConstant() const { return this->getOperand()->isConstant(); @@ -29,5 +17,9 @@ namespace storm { std::set UnaryExpression::getConstants() const { return this->getOperand()->getVariables(); } + + std::shared_ptr const& UnaryExpression::getOperand() const { + return this->operand; + } } } \ No newline at end of file diff --git a/src/storage/expressions/UnaryExpression.h b/src/storage/expressions/UnaryExpression.h index c0d4acf33..686e97c4e 100644 --- a/src/storage/expressions/UnaryExpression.h +++ b/src/storage/expressions/UnaryExpression.h @@ -13,13 +13,11 @@ namespace storm { * @param returnType The return type of the expression. * @param operand The operand of the unary expression. */ - UnaryExpression(ExpressionReturnType returnType, std::unique_ptr&& operand); + UnaryExpression(ExpressionReturnType returnType, std::shared_ptr const& operand); - // Provide custom versions of copy construction and assignment. + // Instantiate constructors and assignments with their default implementations. UnaryExpression(UnaryExpression const& other); UnaryExpression& operator=(UnaryExpression const& other); - - // Create default variants of move construction/assignment and virtual destructor. UnaryExpression(UnaryExpression&&) = default; UnaryExpression& operator=(UnaryExpression&&) = default; virtual ~UnaryExpression() = default; @@ -34,11 +32,11 @@ namespace storm { * * @return The operand of the unary expression. */ - std::unique_ptr const& getOperand() const; + std::shared_ptr const& getOperand() const; private: // The operand of the unary expression. - std::unique_ptr operand; + std::shared_ptr operand; }; } } diff --git a/src/storage/expressions/UnaryNumericalFunctionExpression.cpp b/src/storage/expressions/UnaryNumericalFunctionExpression.cpp index fc8ef625b..31f9aeb1e 100644 --- a/src/storage/expressions/UnaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/UnaryNumericalFunctionExpression.cpp @@ -1,14 +1,22 @@ #include #include "src/storage/expressions/UnaryNumericalFunctionExpression.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" namespace storm { namespace expressions { - UnaryNumericalFunctionExpression::UnaryNumericalFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& operand, OperatorType operatorType) : UnaryExpression(returnType, std::move(operand)), operatorType(operatorType) { + UnaryNumericalFunctionExpression::UnaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType) : UnaryExpression(returnType, operand), operatorType(operatorType) { // Intentionally left empty. } + UnaryNumericalFunctionExpression::OperatorType UnaryNumericalFunctionExpression::getOperatorType() const { + return this->operatorType; + } + int_fast64_t UnaryNumericalFunctionExpression::evaluateAsInt(Valuation const& valuation) const { + LOG_THROW(this->hasIntegralReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as integer."); + int_fast64_t operandEvaluated = this->getOperand()->evaluateAsInt(valuation); switch (this->getOperatorType()) { case OperatorType::Minus: return -operandEvaluated; break; @@ -18,6 +26,8 @@ namespace storm { } double UnaryNumericalFunctionExpression::evaluateAsDouble(Valuation const& valuation) const { + LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as double."); + double operandEvaluated = this->getOperand()->evaluateAsDouble(valuation); switch (this->getOperatorType()) { case OperatorType::Minus: return -operandEvaluated; break; @@ -26,16 +36,27 @@ namespace storm { } } - std::unique_ptr UnaryNumericalFunctionExpression::simplify() const { - return std::unique_ptr(new UnaryNumericalFunctionExpression(this->getReturnType(), this->getOperand()->simplify(), this->getOperatorType())); + std::shared_ptr UnaryNumericalFunctionExpression::simplify() const { + std::shared_ptr operandSimplified = this->getOperand()->simplify(); + + if (operandSimplified.get() == this->getOperand().get()) { + return this->shared_from_this(); + } else { + return std::shared_ptr(new UnaryNumericalFunctionExpression(this->getReturnType(), operandSimplified, this->getOperatorType())); + } } void UnaryNumericalFunctionExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr UnaryNumericalFunctionExpression::clone() const { - return std::unique_ptr(new UnaryNumericalFunctionExpression(*this)); + void UnaryNumericalFunctionExpression::printToStream(std::ostream& stream) const { + switch (this->getOperatorType()) { + case OperatorType::Minus: stream << "-("; break; + case OperatorType::Floor: stream << "floor("; break; + case OperatorType::Ceil: stream << "ceil("; break; + } + stream << *this->getOperand() << ")"; } } } \ No newline at end of file diff --git a/src/storage/expressions/UnaryNumericalFunctionExpression.h b/src/storage/expressions/UnaryNumericalFunctionExpression.h index 7bbba69ca..2f50549ef 100644 --- a/src/storage/expressions/UnaryNumericalFunctionExpression.h +++ b/src/storage/expressions/UnaryNumericalFunctionExpression.h @@ -6,6 +6,7 @@ namespace storm { namespace expressions { class UnaryNumericalFunctionExpression : public UnaryExpression { + public: /*! * An enum type specifying the different functions applicable. */ @@ -18,7 +19,7 @@ namespace storm { * @param operand The operand of the expression. * @param operatorType The operator of the expression. */ - UnaryNumericalFunctionExpression(ExpressionReturnType returnType, std::unique_ptr&& operand, OperatorType operatorType); + UnaryNumericalFunctionExpression(ExpressionReturnType returnType, std::shared_ptr const& operand, OperatorType operatorType); // Instantiate constructors and assignments with their default implementations. UnaryNumericalFunctionExpression(UnaryNumericalFunctionExpression const& other) = default; @@ -30,10 +31,9 @@ namespace storm { // Override base class methods. virtual int_fast64_t evaluateAsInt(Valuation const& valuation) const override; virtual double evaluateAsDouble(Valuation const& valuation) const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; - + /*! * Retrieves the operator associated with this expression. * @@ -41,6 +41,10 @@ namespace storm { */ OperatorType getOperatorType() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The operator of this expression. OperatorType operatorType; diff --git a/src/storage/expressions/VariableExpression.cpp b/src/storage/expressions/VariableExpression.cpp index 6cc96f799..b7a6ce6d2 100644 --- a/src/storage/expressions/VariableExpression.cpp +++ b/src/storage/expressions/VariableExpression.cpp @@ -1,5 +1,6 @@ #include "src/storage/expressions/VariableExpression.h" #include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidTypeException.h" namespace storm { namespace expressions { @@ -11,19 +12,27 @@ namespace storm { return this->variableName; } - int_fast64_t VariableExpression::evaluateAsInt(Valuation const& evaluation) const { - LOG_ASSERT((this->getReturnType() == ExpressionReturnType::Int), "Cannot evaluate expression as integer: return type is not an integer."); - return evaluation.getIntegerValue(this->getVariableName()); - } - bool VariableExpression::evaluateAsBool(Valuation const& evaluation) const { - LOG_ASSERT((this->getReturnType() == ExpressionReturnType::Bool), "Cannot evaluate expression as integer: return type is not a boolean."); + LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as boolean: return type is not a boolean."); + return evaluation.getBooleanValue(this->getVariableName()); } + + int_fast64_t VariableExpression::evaluateAsInt(Valuation const& evaluation) const { + LOG_THROW(this->hasIntegralReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as integer: return type is not an integer."); + + return evaluation.getIntegerValue(this->getVariableName()); + } double VariableExpression::evaluateAsDouble(Valuation const& evaluation) const { - LOG_ASSERT((this->getReturnType() == ExpressionReturnType::Double), "Cannot evaluate expression as integer: return type is not a double."); - return evaluation.getDoubleValue(this->getVariableName()); + LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as double: return type is not a double."); + + switch (this->getReturnType()) { + case ExpressionReturnType::Int: return static_cast(evaluation.getIntegerValue(this->getVariableName())); break; + case ExpressionReturnType::Double: evaluation.getDoubleValue(this->getVariableName()); break; + default: break; + } + LOG_THROW(false, storm::exceptions::InvalidTypeException, "Type of variable is required to be numeric."); } std::set VariableExpression::getVariables() const { @@ -34,16 +43,16 @@ namespace storm { return std::set(); } - std::unique_ptr VariableExpression::simplify() const { - return this->clone(); + std::shared_ptr VariableExpression::simplify() const { + return this->shared_from_this(); } void VariableExpression::accept(ExpressionVisitor* visitor) const { visitor->visit(this); } - std::unique_ptr VariableExpression::clone() const { - return std::unique_ptr(new VariableExpression(*this)); + void VariableExpression::printToStream(std::ostream& stream) const { + stream << this->getVariableName(); } } } \ No newline at end of file diff --git a/src/storage/expressions/VariableExpression.h b/src/storage/expressions/VariableExpression.h index 5aa89e4ce..7e4732eda 100644 --- a/src/storage/expressions/VariableExpression.h +++ b/src/storage/expressions/VariableExpression.h @@ -6,6 +6,13 @@ namespace storm { namespace expressions { class VariableExpression : public BaseExpression { + public: + /*! + * Creates a variable expression with the given return type and variable name. + * + * @param returnType The return type of the variable expression. + * @param variableName The name of the variable associated with this expression. + */ VariableExpression(ExpressionReturnType returnType, std::string const& variableName); // Instantiate constructors and assignments with their default implementations. @@ -21,9 +28,8 @@ namespace storm { virtual double evaluateAsDouble(Valuation const& valuation) const override; virtual std::set getVariables() const override; virtual std::set getConstants() const override; - virtual std::unique_ptr simplify() const override; + virtual std::shared_ptr simplify() const override; virtual void accept(ExpressionVisitor* visitor) const override; - virtual std::unique_ptr clone() const override; /*! * Retrieves the name of the variable associated with this expression. @@ -32,6 +38,10 @@ namespace storm { */ std::string const& getVariableName() const; + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + private: // The variable name associated with this expression. std::string variableName; diff --git a/test/functional/storage/ExpressionTest.cpp b/test/functional/storage/ExpressionTest.cpp index 9d9307bea..a7f789769 100644 --- a/test/functional/storage/ExpressionTest.cpp +++ b/test/functional/storage/ExpressionTest.cpp @@ -1,9 +1,358 @@ -#include "gtest/gtest.h" - #include +#include "gtest/gtest.h" +#include "src/storage/expressions/Expression.h" #include "src/storage/expressions/SimpleValuation.h" +#include "src/exceptions/InvalidTypeException.h" + +TEST(Expression, FactoryMethodTest) { + EXPECT_NO_THROW(storm::expressions::Expression::createBooleanLiteral(true)); + EXPECT_NO_THROW(storm::expressions::Expression::createTrue()); + EXPECT_NO_THROW(storm::expressions::Expression::createFalse()); + EXPECT_NO_THROW(storm::expressions::Expression::createIntegerLiteral(3)); + EXPECT_NO_THROW(storm::expressions::Expression::createDoubleLiteral(3.14)); + EXPECT_NO_THROW(storm::expressions::Expression::createBooleanVariable("x")); + EXPECT_NO_THROW(storm::expressions::Expression::createIntegerVariable("y")); + EXPECT_NO_THROW(storm::expressions::Expression::createDoubleVariable("z")); + EXPECT_NO_THROW(storm::expressions::Expression::createBooleanConstant("a")); + EXPECT_NO_THROW(storm::expressions::Expression::createIntegerConstant("b")); + EXPECT_NO_THROW(storm::expressions::Expression::createDoubleConstant("c")); +} + +TEST(Expression, AccessorTest) { + storm::expressions::Expression trueExpression; + storm::expressions::Expression falseExpression; + storm::expressions::Expression threeExpression; + storm::expressions::Expression piExpression; + storm::expressions::Expression boolVarExpression; + storm::expressions::Expression intVarExpression; + storm::expressions::Expression doubleVarExpression; + storm::expressions::Expression boolConstExpression; + storm::expressions::Expression intConstExpression; + storm::expressions::Expression doubleConstExpression; + + ASSERT_NO_THROW(trueExpression = storm::expressions::Expression::createTrue()); + ASSERT_NO_THROW(falseExpression = storm::expressions::Expression::createFalse()); + ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3)); + ASSERT_NO_THROW(piExpression = storm::expressions::Expression::createDoubleLiteral(3.14)); + ASSERT_NO_THROW(boolVarExpression = storm::expressions::Expression::createBooleanVariable("x")); + ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y")); + ASSERT_NO_THROW(doubleVarExpression = storm::expressions::Expression::createDoubleVariable("z")); + ASSERT_NO_THROW(boolConstExpression = storm::expressions::Expression::createBooleanConstant("a")); + ASSERT_NO_THROW(intConstExpression = storm::expressions::Expression::createIntegerConstant("b")); + ASSERT_NO_THROW(doubleConstExpression = storm::expressions::Expression::createDoubleConstant("c")); + + EXPECT_TRUE(trueExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + EXPECT_TRUE(trueExpression.isConstant()); + EXPECT_TRUE(trueExpression.isTrue()); + EXPECT_FALSE(trueExpression.isFalse()); + EXPECT_TRUE(trueExpression.getVariables() == std::set()); + EXPECT_TRUE(trueExpression.getConstants() == std::set()); + + EXPECT_TRUE(falseExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + EXPECT_TRUE(falseExpression.isConstant()); + EXPECT_FALSE(falseExpression.isTrue()); + EXPECT_TRUE(falseExpression.isFalse()); + EXPECT_TRUE(falseExpression.getVariables() == std::set()); + EXPECT_TRUE(falseExpression.getConstants() == std::set()); + + EXPECT_TRUE(threeExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + EXPECT_TRUE(threeExpression.isConstant()); + EXPECT_FALSE(threeExpression.isTrue()); + EXPECT_FALSE(threeExpression.isFalse()); + EXPECT_TRUE(threeExpression.getVariables() == std::set()); + EXPECT_TRUE(threeExpression.getConstants() == std::set()); + + EXPECT_TRUE(piExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + EXPECT_TRUE(piExpression.isConstant()); + EXPECT_FALSE(piExpression.isTrue()); + EXPECT_FALSE(piExpression.isFalse()); + EXPECT_TRUE(piExpression.getVariables() == std::set()); + EXPECT_TRUE(piExpression.getConstants() == std::set()); + + EXPECT_TRUE(boolVarExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + EXPECT_FALSE(boolVarExpression.isConstant()); + EXPECT_FALSE(boolVarExpression.isTrue()); + EXPECT_FALSE(boolVarExpression.isFalse()); + EXPECT_TRUE(boolVarExpression.getVariables() == std::set({"x"})); + EXPECT_TRUE(boolVarExpression.getConstants() == std::set()); + + EXPECT_TRUE(intVarExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + EXPECT_FALSE(intVarExpression.isConstant()); + EXPECT_FALSE(intVarExpression.isTrue()); + EXPECT_FALSE(intVarExpression.isFalse()); + EXPECT_TRUE(intVarExpression.getVariables() == std::set({"y"})); + EXPECT_TRUE(intVarExpression.getConstants() == std::set()); + + EXPECT_TRUE(doubleVarExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + EXPECT_FALSE(doubleVarExpression.isConstant()); + EXPECT_FALSE(doubleVarExpression.isTrue()); + EXPECT_FALSE(doubleVarExpression.isFalse()); + EXPECT_TRUE(doubleVarExpression.getVariables() == std::set({"z"})); + EXPECT_TRUE(doubleVarExpression.getConstants() == std::set()); + + EXPECT_TRUE(boolConstExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + EXPECT_FALSE(boolConstExpression.isConstant()); + EXPECT_FALSE(boolConstExpression.isTrue()); + EXPECT_FALSE(boolConstExpression.isFalse()); + EXPECT_TRUE(boolConstExpression.getVariables() == std::set()); + EXPECT_TRUE(boolConstExpression.getConstants() == std::set({"a"})); + + EXPECT_TRUE(intConstExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + EXPECT_FALSE(intConstExpression.isConstant()); + EXPECT_FALSE(intConstExpression.isTrue()); + EXPECT_FALSE(intConstExpression.isFalse()); + EXPECT_TRUE(intConstExpression.getVariables() == std::set()); + EXPECT_TRUE(intConstExpression.getConstants() == std::set({"b"})); + + EXPECT_TRUE(doubleConstExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + EXPECT_FALSE(doubleConstExpression.isConstant()); + EXPECT_FALSE(doubleConstExpression.isTrue()); + EXPECT_FALSE(doubleConstExpression.isFalse()); + EXPECT_TRUE(doubleConstExpression.getVariables() == std::set()); + EXPECT_TRUE(doubleConstExpression.getConstants() == std::set({"c"})); +} + +TEST(Expression, OperatorTest) { + storm::expressions::Expression trueExpression; + storm::expressions::Expression falseExpression; + storm::expressions::Expression threeExpression; + storm::expressions::Expression piExpression; + storm::expressions::Expression boolVarExpression; + storm::expressions::Expression intVarExpression; + storm::expressions::Expression doubleVarExpression; + storm::expressions::Expression boolConstExpression; + storm::expressions::Expression intConstExpression; + storm::expressions::Expression doubleConstExpression; + + ASSERT_NO_THROW(trueExpression = storm::expressions::Expression::createTrue()); + ASSERT_NO_THROW(falseExpression = storm::expressions::Expression::createFalse()); + ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3)); + ASSERT_NO_THROW(piExpression = storm::expressions::Expression::createDoubleLiteral(3.14)); + ASSERT_NO_THROW(boolVarExpression = storm::expressions::Expression::createBooleanVariable("x")); + ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y")); + ASSERT_NO_THROW(doubleVarExpression = storm::expressions::Expression::createDoubleVariable("z")); + ASSERT_NO_THROW(boolConstExpression = storm::expressions::Expression::createBooleanConstant("a")); + ASSERT_NO_THROW(intConstExpression = storm::expressions::Expression::createIntegerConstant("b")); + ASSERT_NO_THROW(doubleConstExpression = storm::expressions::Expression::createDoubleConstant("c")); + + storm::expressions::Expression tempExpression; + + ASSERT_THROW(tempExpression = trueExpression + piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression + threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = threeExpression + piExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + ASSERT_NO_THROW(tempExpression = doubleVarExpression + doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + + ASSERT_THROW(tempExpression = trueExpression - piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression - threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = threeExpression - piExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + ASSERT_NO_THROW(tempExpression = doubleVarExpression - doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + + ASSERT_THROW(tempExpression = -trueExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = -threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = -piExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + ASSERT_NO_THROW(tempExpression = -doubleVarExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + + ASSERT_THROW(tempExpression = trueExpression * piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression * threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = threeExpression * piExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + ASSERT_NO_THROW(tempExpression = intVarExpression * intConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + + ASSERT_THROW(tempExpression = trueExpression / piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression / threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = threeExpression / piExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + ASSERT_NO_THROW(tempExpression = doubleVarExpression / intConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + + ASSERT_THROW(tempExpression = trueExpression && piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = trueExpression && falseExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = boolVarExpression && boolConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression || piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = trueExpression || falseExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = boolVarExpression || boolConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = !threeExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = !trueExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = !boolVarExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression == piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression == threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = intVarExpression == doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression != piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression != threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = intVarExpression != doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression > piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression > threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = intVarExpression > doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression >= piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression >= threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = intVarExpression >= doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression < piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression < threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = intVarExpression < doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = trueExpression <= piExpression, storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression <= threeExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + ASSERT_NO_THROW(tempExpression = intVarExpression <= doubleConstExpression); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Bool); + + ASSERT_THROW(tempExpression = storm::expressions::Expression::minimum(trueExpression, piExpression), storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = storm::expressions::Expression::minimum(threeExpression, threeExpression)); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = storm::expressions::Expression::minimum(intVarExpression, doubleConstExpression)); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + + ASSERT_THROW(tempExpression = storm::expressions::Expression::maximum(trueExpression, piExpression), storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = storm::expressions::Expression::maximum(threeExpression, threeExpression)); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = storm::expressions::Expression::maximum(intVarExpression, doubleConstExpression)); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Double); + + ASSERT_THROW(tempExpression = trueExpression.floor(), storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression.floor()); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = doubleConstExpression.floor()); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + + ASSERT_THROW(tempExpression = trueExpression.ceil(), storm::exceptions::InvalidTypeException); + ASSERT_NO_THROW(tempExpression = threeExpression.ceil()); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); + ASSERT_NO_THROW(tempExpression = doubleConstExpression.ceil()); + EXPECT_TRUE(tempExpression.getReturnType() == storm::expressions::ExpressionReturnType::Int); +} + +TEST(Expression, SubstitutionTest) { + storm::expressions::Expression trueExpression; + storm::expressions::Expression falseExpression; + storm::expressions::Expression threeExpression; + storm::expressions::Expression piExpression; + storm::expressions::Expression boolVarExpression; + storm::expressions::Expression intVarExpression; + storm::expressions::Expression doubleVarExpression; + storm::expressions::Expression boolConstExpression; + storm::expressions::Expression intConstExpression; + storm::expressions::Expression doubleConstExpression; + + ASSERT_NO_THROW(trueExpression = storm::expressions::Expression::createTrue()); + ASSERT_NO_THROW(falseExpression = storm::expressions::Expression::createFalse()); + ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3)); + ASSERT_NO_THROW(piExpression = storm::expressions::Expression::createDoubleLiteral(3.14)); + ASSERT_NO_THROW(boolVarExpression = storm::expressions::Expression::createBooleanVariable("x")); + ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y")); + ASSERT_NO_THROW(doubleVarExpression = storm::expressions::Expression::createDoubleVariable("z")); + ASSERT_NO_THROW(boolConstExpression = storm::expressions::Expression::createBooleanConstant("a")); + ASSERT_NO_THROW(intConstExpression = storm::expressions::Expression::createIntegerConstant("b")); + ASSERT_NO_THROW(doubleConstExpression = storm::expressions::Expression::createDoubleConstant("c")); + + storm::expressions::Expression tempExpression; + ASSERT_NO_THROW(tempExpression = (intVarExpression < threeExpression || boolVarExpression) && boolConstExpression); + + std::map substution = { std::make_pair("y", doubleConstExpression), std::make_pair("x", storm::expressions::Expression::createTrue()), std::make_pair("a", storm::expressions::Expression::createTrue()) }; + storm::expressions::Expression substitutedExpression; + ASSERT_NO_THROW(substitutedExpression = tempExpression.substitute(substution)); + EXPECT_TRUE(substitutedExpression.simplify().isTrue()); +} + +TEST(Expression, SimplificationTest) { + storm::expressions::Expression trueExpression; + storm::expressions::Expression falseExpression; + storm::expressions::Expression threeExpression; + storm::expressions::Expression intVarExpression; + + ASSERT_NO_THROW(trueExpression = storm::expressions::Expression::createTrue()); + ASSERT_NO_THROW(falseExpression = storm::expressions::Expression::createFalse()); + ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3)); + ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y")); + + storm::expressions::Expression tempExpression; + storm::expressions::Expression simplifiedExpression; + + ASSERT_NO_THROW(tempExpression = trueExpression || intVarExpression > threeExpression); + ASSERT_NO_THROW(simplifiedExpression = tempExpression.simplify()); + EXPECT_TRUE(simplifiedExpression.isTrue()); + + ASSERT_NO_THROW(tempExpression = falseExpression && intVarExpression > threeExpression); + ASSERT_NO_THROW(simplifiedExpression = tempExpression.simplify()); + EXPECT_TRUE(simplifiedExpression.isFalse()); +} -TEST(Expression, SimpleValuationTest) { - ASSERT_NO_THROW(storm::expressions::SimpleValuation evaluation(1, 1, 1)); +TEST(Expression, SimpleEvaluationTest) { + storm::expressions::Expression trueExpression; + storm::expressions::Expression falseExpression; + storm::expressions::Expression threeExpression; + storm::expressions::Expression piExpression; + storm::expressions::Expression boolVarExpression; + storm::expressions::Expression intVarExpression; + storm::expressions::Expression doubleVarExpression; + storm::expressions::Expression boolConstExpression; + storm::expressions::Expression intConstExpression; + storm::expressions::Expression doubleConstExpression; + + ASSERT_NO_THROW(trueExpression = storm::expressions::Expression::createTrue()); + ASSERT_NO_THROW(falseExpression = storm::expressions::Expression::createFalse()); + ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3)); + ASSERT_NO_THROW(piExpression = storm::expressions::Expression::createDoubleLiteral(3.14)); + ASSERT_NO_THROW(boolVarExpression = storm::expressions::Expression::createBooleanVariable("x")); + ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y")); + ASSERT_NO_THROW(doubleVarExpression = storm::expressions::Expression::createDoubleVariable("z")); + ASSERT_NO_THROW(boolConstExpression = storm::expressions::Expression::createBooleanConstant("a")); + ASSERT_NO_THROW(intConstExpression = storm::expressions::Expression::createIntegerConstant("b")); + ASSERT_NO_THROW(doubleConstExpression = storm::expressions::Expression::createDoubleConstant("c")); + + storm::expressions::Expression tempExpression; + + ASSERT_NO_THROW(tempExpression = (intVarExpression < threeExpression || boolVarExpression) && boolConstExpression); + + ASSERT_NO_THROW(storm::expressions::SimpleValuation valuation(2, 2, 2)); + storm::expressions::SimpleValuation valuation(2, 2, 2); + ASSERT_NO_THROW(valuation.setIdentifierIndex("x", 0)); + ASSERT_NO_THROW(valuation.setIdentifierIndex("a", 1)); + ASSERT_NO_THROW(valuation.setIdentifierIndex("y", 0)); + ASSERT_NO_THROW(valuation.setIdentifierIndex("b", 1)); + ASSERT_NO_THROW(valuation.setIdentifierIndex("z", 0)); + ASSERT_NO_THROW(valuation.setIdentifierIndex("c", 1)); + + ASSERT_THROW(tempExpression.evaluateAsDouble(valuation), storm::exceptions::InvalidTypeException); + ASSERT_THROW(tempExpression.evaluateAsInt(valuation), storm::exceptions::InvalidTypeException); + EXPECT_FALSE(tempExpression.evaluateAsBool(valuation)); + ASSERT_NO_THROW(valuation.setBooleanValue("a", true)); + EXPECT_TRUE(tempExpression.evaluateAsBool(valuation)); + ASSERT_NO_THROW(valuation.setIntegerValue("y", 3)); + EXPECT_FALSE(tempExpression.evaluateAsBool(valuation)); } \ No newline at end of file