From 1f281ff45a747a07dbedaaf44b822b4b60b1de9d Mon Sep 17 00:00:00 2001
From: Sebastian Junges <sebastian.junges@gmail.com>
Date: Fri, 11 Sep 2020 18:42:38 -0700
Subject: [PATCH] add predicate expressions for n-ary predicates

---
 .../storage/expressions/BaseExpression.cpp    |   8 ++
 .../storage/expressions/BaseExpression.h      |   4 +
 .../storage/expressions/ExpressionVisitor.cpp |  11 ++
 .../storage/expressions/ExpressionVisitor.h   |   2 +
 src/storm/storage/expressions/Expressions.h   |   1 +
 .../storage/expressions/OperatorType.cpp      |   3 +
 src/storm/storage/expressions/OperatorType.h  |   5 +-
 .../expressions/PredicateExpression.cpp       | 100 ++++++++++++++++++
 .../storage/expressions/PredicateExpression.h |  66 ++++++++++++
 9 files changed, 199 insertions(+), 1 deletion(-)
 create mode 100644 src/storm/storage/expressions/ExpressionVisitor.cpp
 create mode 100644 src/storm/storage/expressions/PredicateExpression.cpp
 create mode 100644 src/storm/storage/expressions/PredicateExpression.h

diff --git a/src/storm/storage/expressions/BaseExpression.cpp b/src/storm/storage/expressions/BaseExpression.cpp
index b6f652e3f..500dcdbc2 100644
--- a/src/storm/storage/expressions/BaseExpression.cpp
+++ b/src/storm/storage/expressions/BaseExpression.cpp
@@ -191,6 +191,14 @@ namespace storm {
         VariableExpression const& BaseExpression::asVariableExpression() const {
             return static_cast<VariableExpression const&>(*this);
         }
+
+        bool BaseExpression::isPredicateExpression() const {
+            return false;
+        }
+
+        PredicateExpression const& BaseExpression::asPredicateExpression() const {
+            return static_cast<PredicateExpression const&>(*this);
+        }
         
         std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) {
             expression.printToStream(stream);
diff --git a/src/storm/storage/expressions/BaseExpression.h b/src/storm/storage/expressions/BaseExpression.h
index 666c4d40b..cd91a6c19 100644
--- a/src/storm/storage/expressions/BaseExpression.h
+++ b/src/storm/storage/expressions/BaseExpression.h
@@ -32,6 +32,7 @@ namespace storm {
         class UnaryBooleanFunctionExpression;
         class UnaryNumericalFunctionExpression;
         class VariableExpression;
+        class PredicateExpression;
         
         /*!
          * The base class of all expression classes.
@@ -286,6 +287,9 @@ namespace storm {
             
             virtual bool isVariableExpression() const;
             VariableExpression const& asVariableExpression() const;
+
+            virtual bool isPredicateExpression() const;
+            PredicateExpression const& asPredicateExpression() const;
             
         protected:
             /*!
diff --git a/src/storm/storage/expressions/ExpressionVisitor.cpp b/src/storm/storage/expressions/ExpressionVisitor.cpp
new file mode 100644
index 000000000..3e95f2c75
--- /dev/null
+++ b/src/storm/storage/expressions/ExpressionVisitor.cpp
@@ -0,0 +1,11 @@
+#include "storm/storage/expressions/ExpressionVisitor.h"
+#include "storm/utility/macros.h"
+#include "storm/exceptions/NotImplementedException.h"
+
+namespace storm {
+    namespace expressions {
+        boost::any ExpressionVisitor::visit(PredicateExpression const&, boost::any const&) {
+            STORM_LOG_THROW(false,storm::exceptions::NotImplementedException, "Predicate Expressions are not supported by this visitor");
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/storm/storage/expressions/ExpressionVisitor.h b/src/storm/storage/expressions/ExpressionVisitor.h
index 8ef1ea90d..367178426 100644
--- a/src/storm/storage/expressions/ExpressionVisitor.h
+++ b/src/storm/storage/expressions/ExpressionVisitor.h
@@ -17,6 +17,7 @@ namespace storm {
         class BooleanLiteralExpression;
         class IntegerLiteralExpression;
         class RationalLiteralExpression;
+        class PredicateExpression;
         
         class ExpressionVisitor {
         public:
@@ -32,6 +33,7 @@ namespace storm {
             virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) = 0;
             virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) = 0;
             virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) = 0;
+            virtual boost::any visit(PredicateExpression const& expression, boost::any const& data);
         };
     }
 }
diff --git a/src/storm/storage/expressions/Expressions.h b/src/storm/storage/expressions/Expressions.h
index 272810509..bcbf2b76a 100644
--- a/src/storm/storage/expressions/Expressions.h
+++ b/src/storm/storage/expressions/Expressions.h
@@ -8,4 +8,5 @@
 #include "storm/storage/expressions/UnaryBooleanFunctionExpression.h"
 #include "storm/storage/expressions/UnaryNumericalFunctionExpression.h"
 #include "storm/storage/expressions/VariableExpression.h"
+#include "storm/storage/expressions/PredicateExpression.h"
 #include "storm/storage/expressions/Expression.h"
diff --git a/src/storm/storage/expressions/OperatorType.cpp b/src/storm/storage/expressions/OperatorType.cpp
index bba61a653..a36982465 100644
--- a/src/storm/storage/expressions/OperatorType.cpp
+++ b/src/storm/storage/expressions/OperatorType.cpp
@@ -27,6 +27,9 @@ namespace storm {
                 case OperatorType::Floor: stream << "floor"; break;
                 case OperatorType::Ceil: stream << "ceil"; break;
                 case OperatorType::Ite: stream << "ite"; break;
+                case OperatorType::AtMostOneOf: stream << "atMostOneOf"; break;
+                case OperatorType::AtLeastOneOf: stream << "atLeastOneOf"; break;
+                case OperatorType::ExactlyOneOf: stream << "exactlyOneOf"; break;
             }
             return stream;
         }
diff --git a/src/storm/storage/expressions/OperatorType.h b/src/storm/storage/expressions/OperatorType.h
index e334b8104..196c900c7 100644
--- a/src/storm/storage/expressions/OperatorType.h
+++ b/src/storm/storage/expressions/OperatorType.h
@@ -29,7 +29,10 @@ namespace storm {
             Not,
             Floor,
             Ceil,
-            Ite
+            Ite,
+            AtLeastOneOf,
+            AtMostOneOf,
+            ExactlyOneOf
         };
         
         std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType);
diff --git a/src/storm/storage/expressions/PredicateExpression.cpp b/src/storm/storage/expressions/PredicateExpression.cpp
new file mode 100644
index 000000000..357a0cb17
--- /dev/null
+++ b/src/storm/storage/expressions/PredicateExpression.cpp
@@ -0,0 +1,100 @@
+
+#include "storm/storage/expressions/PredicateExpression.h"
+
+#include "storm/storage/expressions/ExpressionVisitor.h"
+#include "storm/utility/macros.h"
+#include "storm/storage/BitVector.h"
+#include "storm/exceptions/InvalidTypeException.h"
+
+namespace storm {
+    namespace expressions {
+        OperatorType toOperatorType(PredicateExpression::PredicateType tp) {
+            switch (tp) {
+                case PredicateExpression::PredicateType::AtMostOneOf: return OperatorType::AtMostOneOf;
+                case PredicateExpression::PredicateType::AtLeastOneOf: return OperatorType::AtLeastOneOf;
+                case PredicateExpression::PredicateType::ExactlyOneOf: return OperatorType::ExactlyOneOf;
+            }
+            STORM_LOG_ASSERT(false, "Predicate type not supported");
+        }
+
+        PredicateExpression::PredicateExpression(ExpressionManager const &manager, Type const& type,  std::vector <std::shared_ptr<BaseExpression const>> const &operands, PredicateType predicateType) : BaseExpression(manager, type), predicate(predicateType), operands(operands) {}
+
+        // Override base class methods.
+        storm::expressions::OperatorType PredicateExpression::getOperator() const {
+            return toOperatorType(predicate);
+        }
+
+        bool PredicateExpression::evaluateAsBool(Valuation const *valuation) const {
+            STORM_LOG_THROW(this->hasBooleanType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean.");
+            storm::storage::BitVector results(operands.size());
+            uint64_t i = 0;
+            for(auto const& operand : operands) {
+                results.set(i, operand->evaluateAsBool(valuation));
+                ++i;
+            }
+            switch(predicate) {
+                case PredicateType::ExactlyOneOf: return results.getNumberOfSetBits() == 1;
+                case PredicateType::AtMostOneOf: return results.getNumberOfSetBits() <= 1;
+                case PredicateType::AtLeastOneOf: return results.getNumberOfSetBits() >= 1;
+            }
+            STORM_LOG_ASSERT(false, "Unknown predicate type");
+        }
+
+        std::shared_ptr<BaseExpression const> PredicateExpression::simplify() const {
+            std::vector<std::shared_ptr<BaseExpression const>> simplifiedOperands;
+            for (auto const& operand : operands) {
+                simplifiedOperands.push_back(operand->simplify());
+            }
+            return std::shared_ptr<BaseExpression>(new PredicateExpression(this->getManager(), this->getType(), simplifiedOperands, predicate));
+        }
+
+        boost::any PredicateExpression::accept(ExpressionVisitor &visitor, boost::any const &data) const {
+            return visitor.visit(*this, data);
+        }
+
+        bool PredicateExpression::isPredicateExpression() const {
+            return true;
+        }
+
+        bool PredicateExpression::isFunctionApplication() const {
+            return true;
+        }
+
+        bool PredicateExpression::containsVariables() const {
+            for(auto const& operand : operands) {
+                if(operand->containsVariables()) {
+                    return true;
+                }
+            }
+            return false;
+        }
+
+        uint_fast64_t PredicateExpression::getArity() const {
+            return operands.size();
+        }
+
+        std::shared_ptr<BaseExpression const> PredicateExpression::getOperand(uint_fast64_t operandIndex) const {
+            STORM_LOG_ASSERT(operandIndex < this->getArity(), "Invalid operand access");
+            return operands[operandIndex];
+        }
+
+        void PredicateExpression::gatherVariables(std::set<storm::expressions::Variable>& variables) const {
+            for(auto const& operand : operands) {
+                operand->gatherVariables(variables);
+            }
+        }
+
+        /*!
+         * Retrieves the relation associated with the expression.
+         *
+         * @return The relation associated with the expression.
+         */
+        PredicateExpression::PredicateType PredicateExpression::getPredicateType() const {
+            return predicate;
+        }
+
+        void PredicateExpression::printToStream(std::ostream& stream) const {
+
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/storm/storage/expressions/PredicateExpression.h b/src/storm/storage/expressions/PredicateExpression.h
new file mode 100644
index 000000000..cfa8a5b6e
--- /dev/null
+++ b/src/storm/storage/expressions/PredicateExpression.h
@@ -0,0 +1,66 @@
+#pragma once
+
+#include "storm/storage/expressions/BaseExpression.h"
+
+namespace storm {
+    namespace expressions {
+        /*!
+         * The base class of all binary expressions.
+         */
+        class PredicateExpression : public BaseExpression {
+        public:
+            enum class PredicateType { AtLeastOneOf, AtMostOneOf, ExactlyOneOf };
+
+            PredicateExpression(ExpressionManager const &manager,Type const& type,
+                                std::vector <std::shared_ptr<BaseExpression const>> const &operands,
+                                PredicateType predicateType);
+
+            // Instantiate constructors and assignments with their default implementations.
+            PredicateExpression(PredicateExpression const &other) = default;
+
+            PredicateExpression &operator=(PredicateExpression const &other) = delete;
+
+            PredicateExpression(PredicateExpression &&) = default;
+
+            PredicateExpression &operator=(PredicateExpression &&) = delete;
+
+            virtual ~PredicateExpression() = default;
+
+            // Override base class methods.
+            virtual storm::expressions::OperatorType getOperator() const override;
+
+            virtual bool evaluateAsBool(Valuation const *valuation = nullptr) const override;
+
+            virtual std::shared_ptr<BaseExpression const> simplify() const override;
+
+            virtual boost::any accept(ExpressionVisitor &visitor, boost::any const &data) const override;
+
+            virtual bool isPredicateExpression() const override;
+
+            virtual bool isFunctionApplication() const override;
+
+            virtual bool containsVariables() const override;
+
+            virtual uint_fast64_t getArity() const override;
+
+            virtual std::shared_ptr<BaseExpression const> getOperand(uint_fast64_t operandIndex) const override;
+
+            virtual void gatherVariables(std::set<storm::expressions::Variable>& variables) const override;
+
+            /*!
+             * Retrieves the relation associated with the expression.
+             *
+             * @return The relation associated with the expression.
+             */
+            PredicateType getPredicateType() const;
+
+        protected:
+            // Override base class method.
+            virtual void printToStream(std::ostream& stream) const override;
+
+        private:
+            PredicateType predicate;
+            std::vector<std::shared_ptr<BaseExpression const>> operands;
+        };
+    }
+}