diff --git a/src/storm/storage/expressions/BaseExpression.cpp b/src/storm/storage/expressions/BaseExpression.cpp index b6f652e3f..500dcdbc2 100644 --- a/src/storm/storage/expressions/BaseExpression.cpp +++ b/src/storm/storage/expressions/BaseExpression.cpp @@ -191,6 +191,14 @@ namespace storm { VariableExpression const& BaseExpression::asVariableExpression() const { return static_cast(*this); } + + bool BaseExpression::isPredicateExpression() const { + return false; + } + + PredicateExpression const& BaseExpression::asPredicateExpression() const { + return static_cast(*this); + } std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) { expression.printToStream(stream); diff --git a/src/storm/storage/expressions/BaseExpression.h b/src/storm/storage/expressions/BaseExpression.h index 666c4d40b..cd91a6c19 100644 --- a/src/storm/storage/expressions/BaseExpression.h +++ b/src/storm/storage/expressions/BaseExpression.h @@ -32,6 +32,7 @@ namespace storm { class UnaryBooleanFunctionExpression; class UnaryNumericalFunctionExpression; class VariableExpression; + class PredicateExpression; /*! * The base class of all expression classes. @@ -286,6 +287,9 @@ namespace storm { virtual bool isVariableExpression() const; VariableExpression const& asVariableExpression() const; + + virtual bool isPredicateExpression() const; + PredicateExpression const& asPredicateExpression() const; protected: /*! diff --git a/src/storm/storage/expressions/ExpressionVisitor.cpp b/src/storm/storage/expressions/ExpressionVisitor.cpp new file mode 100644 index 000000000..3e95f2c75 --- /dev/null +++ b/src/storm/storage/expressions/ExpressionVisitor.cpp @@ -0,0 +1,11 @@ +#include "storm/storage/expressions/ExpressionVisitor.h" +#include "storm/utility/macros.h" +#include "storm/exceptions/NotImplementedException.h" + +namespace storm { + namespace expressions { + boost::any ExpressionVisitor::visit(PredicateExpression const&, boost::any const&) { + STORM_LOG_THROW(false,storm::exceptions::NotImplementedException, "Predicate Expressions are not supported by this visitor"); + } + } +} \ No newline at end of file diff --git a/src/storm/storage/expressions/ExpressionVisitor.h b/src/storm/storage/expressions/ExpressionVisitor.h index 8ef1ea90d..367178426 100644 --- a/src/storm/storage/expressions/ExpressionVisitor.h +++ b/src/storm/storage/expressions/ExpressionVisitor.h @@ -17,6 +17,7 @@ namespace storm { class BooleanLiteralExpression; class IntegerLiteralExpression; class RationalLiteralExpression; + class PredicateExpression; class ExpressionVisitor { public: @@ -32,6 +33,7 @@ namespace storm { virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) = 0; + virtual boost::any visit(PredicateExpression const& expression, boost::any const& data); }; } } diff --git a/src/storm/storage/expressions/Expressions.h b/src/storm/storage/expressions/Expressions.h index 272810509..bcbf2b76a 100644 --- a/src/storm/storage/expressions/Expressions.h +++ b/src/storm/storage/expressions/Expressions.h @@ -8,4 +8,5 @@ #include "storm/storage/expressions/UnaryBooleanFunctionExpression.h" #include "storm/storage/expressions/UnaryNumericalFunctionExpression.h" #include "storm/storage/expressions/VariableExpression.h" +#include "storm/storage/expressions/PredicateExpression.h" #include "storm/storage/expressions/Expression.h" diff --git a/src/storm/storage/expressions/OperatorType.cpp b/src/storm/storage/expressions/OperatorType.cpp index bba61a653..a36982465 100644 --- a/src/storm/storage/expressions/OperatorType.cpp +++ b/src/storm/storage/expressions/OperatorType.cpp @@ -27,6 +27,9 @@ namespace storm { case OperatorType::Floor: stream << "floor"; break; case OperatorType::Ceil: stream << "ceil"; break; case OperatorType::Ite: stream << "ite"; break; + case OperatorType::AtMostOneOf: stream << "atMostOneOf"; break; + case OperatorType::AtLeastOneOf: stream << "atLeastOneOf"; break; + case OperatorType::ExactlyOneOf: stream << "exactlyOneOf"; break; } return stream; } diff --git a/src/storm/storage/expressions/OperatorType.h b/src/storm/storage/expressions/OperatorType.h index e334b8104..196c900c7 100644 --- a/src/storm/storage/expressions/OperatorType.h +++ b/src/storm/storage/expressions/OperatorType.h @@ -29,7 +29,10 @@ namespace storm { Not, Floor, Ceil, - Ite + Ite, + AtLeastOneOf, + AtMostOneOf, + ExactlyOneOf }; std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType); diff --git a/src/storm/storage/expressions/PredicateExpression.cpp b/src/storm/storage/expressions/PredicateExpression.cpp new file mode 100644 index 000000000..357a0cb17 --- /dev/null +++ b/src/storm/storage/expressions/PredicateExpression.cpp @@ -0,0 +1,100 @@ + +#include "storm/storage/expressions/PredicateExpression.h" + +#include "storm/storage/expressions/ExpressionVisitor.h" +#include "storm/utility/macros.h" +#include "storm/storage/BitVector.h" +#include "storm/exceptions/InvalidTypeException.h" + +namespace storm { + namespace expressions { + OperatorType toOperatorType(PredicateExpression::PredicateType tp) { + switch (tp) { + case PredicateExpression::PredicateType::AtMostOneOf: return OperatorType::AtMostOneOf; + case PredicateExpression::PredicateType::AtLeastOneOf: return OperatorType::AtLeastOneOf; + case PredicateExpression::PredicateType::ExactlyOneOf: return OperatorType::ExactlyOneOf; + } + STORM_LOG_ASSERT(false, "Predicate type not supported"); + } + + PredicateExpression::PredicateExpression(ExpressionManager const &manager, Type const& type, std::vector > const &operands, PredicateType predicateType) : BaseExpression(manager, type), predicate(predicateType), operands(operands) {} + + // Override base class methods. + storm::expressions::OperatorType PredicateExpression::getOperator() const { + return toOperatorType(predicate); + } + + bool PredicateExpression::evaluateAsBool(Valuation const *valuation) const { + STORM_LOG_THROW(this->hasBooleanType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean."); + storm::storage::BitVector results(operands.size()); + uint64_t i = 0; + for(auto const& operand : operands) { + results.set(i, operand->evaluateAsBool(valuation)); + ++i; + } + switch(predicate) { + case PredicateType::ExactlyOneOf: return results.getNumberOfSetBits() == 1; + case PredicateType::AtMostOneOf: return results.getNumberOfSetBits() <= 1; + case PredicateType::AtLeastOneOf: return results.getNumberOfSetBits() >= 1; + } + STORM_LOG_ASSERT(false, "Unknown predicate type"); + } + + std::shared_ptr PredicateExpression::simplify() const { + std::vector> simplifiedOperands; + for (auto const& operand : operands) { + simplifiedOperands.push_back(operand->simplify()); + } + return std::shared_ptr(new PredicateExpression(this->getManager(), this->getType(), simplifiedOperands, predicate)); + } + + boost::any PredicateExpression::accept(ExpressionVisitor &visitor, boost::any const &data) const { + return visitor.visit(*this, data); + } + + bool PredicateExpression::isPredicateExpression() const { + return true; + } + + bool PredicateExpression::isFunctionApplication() const { + return true; + } + + bool PredicateExpression::containsVariables() const { + for(auto const& operand : operands) { + if(operand->containsVariables()) { + return true; + } + } + return false; + } + + uint_fast64_t PredicateExpression::getArity() const { + return operands.size(); + } + + std::shared_ptr PredicateExpression::getOperand(uint_fast64_t operandIndex) const { + STORM_LOG_ASSERT(operandIndex < this->getArity(), "Invalid operand access"); + return operands[operandIndex]; + } + + void PredicateExpression::gatherVariables(std::set& variables) const { + for(auto const& operand : operands) { + operand->gatherVariables(variables); + } + } + + /*! + * Retrieves the relation associated with the expression. + * + * @return The relation associated with the expression. + */ + PredicateExpression::PredicateType PredicateExpression::getPredicateType() const { + return predicate; + } + + void PredicateExpression::printToStream(std::ostream& stream) const { + + } + } +} \ No newline at end of file diff --git a/src/storm/storage/expressions/PredicateExpression.h b/src/storm/storage/expressions/PredicateExpression.h new file mode 100644 index 000000000..cfa8a5b6e --- /dev/null +++ b/src/storm/storage/expressions/PredicateExpression.h @@ -0,0 +1,66 @@ +#pragma once + +#include "storm/storage/expressions/BaseExpression.h" + +namespace storm { + namespace expressions { + /*! + * The base class of all binary expressions. + */ + class PredicateExpression : public BaseExpression { + public: + enum class PredicateType { AtLeastOneOf, AtMostOneOf, ExactlyOneOf }; + + PredicateExpression(ExpressionManager const &manager,Type const& type, + std::vector > const &operands, + PredicateType predicateType); + + // Instantiate constructors and assignments with their default implementations. + PredicateExpression(PredicateExpression const &other) = default; + + PredicateExpression &operator=(PredicateExpression const &other) = delete; + + PredicateExpression(PredicateExpression &&) = default; + + PredicateExpression &operator=(PredicateExpression &&) = delete; + + virtual ~PredicateExpression() = default; + + // Override base class methods. + virtual storm::expressions::OperatorType getOperator() const override; + + virtual bool evaluateAsBool(Valuation const *valuation = nullptr) const override; + + virtual std::shared_ptr simplify() const override; + + virtual boost::any accept(ExpressionVisitor &visitor, boost::any const &data) const override; + + virtual bool isPredicateExpression() const override; + + virtual bool isFunctionApplication() const override; + + virtual bool containsVariables() const override; + + virtual uint_fast64_t getArity() const override; + + virtual std::shared_ptr getOperand(uint_fast64_t operandIndex) const override; + + virtual void gatherVariables(std::set& variables) const override; + + /*! + * Retrieves the relation associated with the expression. + * + * @return The relation associated with the expression. + */ + PredicateType getPredicateType() const; + + protected: + // Override base class method. + virtual void printToStream(std::ostream& stream) const override; + + private: + PredicateType predicate; + std::vector> operands; + }; + } +}