Browse Source

add predicate expressions for n-ary predicates

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
1f281ff45a
  1. 8
      src/storm/storage/expressions/BaseExpression.cpp
  2. 4
      src/storm/storage/expressions/BaseExpression.h
  3. 11
      src/storm/storage/expressions/ExpressionVisitor.cpp
  4. 2
      src/storm/storage/expressions/ExpressionVisitor.h
  5. 1
      src/storm/storage/expressions/Expressions.h
  6. 3
      src/storm/storage/expressions/OperatorType.cpp
  7. 5
      src/storm/storage/expressions/OperatorType.h
  8. 100
      src/storm/storage/expressions/PredicateExpression.cpp
  9. 66
      src/storm/storage/expressions/PredicateExpression.h

8
src/storm/storage/expressions/BaseExpression.cpp

@ -192,6 +192,14 @@ namespace storm {
return static_cast<VariableExpression const&>(*this); return static_cast<VariableExpression const&>(*this);
} }
bool BaseExpression::isPredicateExpression() const {
return false;
}
PredicateExpression const& BaseExpression::asPredicateExpression() const {
return static_cast<PredicateExpression const&>(*this);
}
std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) { std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) {
expression.printToStream(stream); expression.printToStream(stream);
return stream; return stream;

4
src/storm/storage/expressions/BaseExpression.h

@ -32,6 +32,7 @@ namespace storm {
class UnaryBooleanFunctionExpression; class UnaryBooleanFunctionExpression;
class UnaryNumericalFunctionExpression; class UnaryNumericalFunctionExpression;
class VariableExpression; class VariableExpression;
class PredicateExpression;
/*! /*!
* The base class of all expression classes. * The base class of all expression classes.
@ -287,6 +288,9 @@ namespace storm {
virtual bool isVariableExpression() const; virtual bool isVariableExpression() const;
VariableExpression const& asVariableExpression() const; VariableExpression const& asVariableExpression() const;
virtual bool isPredicateExpression() const;
PredicateExpression const& asPredicateExpression() const;
protected: protected:
/*! /*!
* Prints the expression to the given stream. * Prints the expression to the given stream.

11
src/storm/storage/expressions/ExpressionVisitor.cpp

@ -0,0 +1,11 @@
#include "storm/storage/expressions/ExpressionVisitor.h"
#include "storm/utility/macros.h"
#include "storm/exceptions/NotImplementedException.h"
namespace storm {
namespace expressions {
boost::any ExpressionVisitor::visit(PredicateExpression const&, boost::any const&) {
STORM_LOG_THROW(false,storm::exceptions::NotImplementedException, "Predicate Expressions are not supported by this visitor");
}
}
}

2
src/storm/storage/expressions/ExpressionVisitor.h

@ -17,6 +17,7 @@ namespace storm {
class BooleanLiteralExpression; class BooleanLiteralExpression;
class IntegerLiteralExpression; class IntegerLiteralExpression;
class RationalLiteralExpression; class RationalLiteralExpression;
class PredicateExpression;
class ExpressionVisitor { class ExpressionVisitor {
public: public:
@ -32,6 +33,7 @@ namespace storm {
virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) = 0;
virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) = 0;
virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) = 0;
virtual boost::any visit(PredicateExpression const& expression, boost::any const& data);
}; };
} }
} }

1
src/storm/storage/expressions/Expressions.h

@ -8,4 +8,5 @@
#include "storm/storage/expressions/UnaryBooleanFunctionExpression.h" #include "storm/storage/expressions/UnaryBooleanFunctionExpression.h"
#include "storm/storage/expressions/UnaryNumericalFunctionExpression.h" #include "storm/storage/expressions/UnaryNumericalFunctionExpression.h"
#include "storm/storage/expressions/VariableExpression.h" #include "storm/storage/expressions/VariableExpression.h"
#include "storm/storage/expressions/PredicateExpression.h"
#include "storm/storage/expressions/Expression.h" #include "storm/storage/expressions/Expression.h"

3
src/storm/storage/expressions/OperatorType.cpp

@ -27,6 +27,9 @@ namespace storm {
case OperatorType::Floor: stream << "floor"; break; case OperatorType::Floor: stream << "floor"; break;
case OperatorType::Ceil: stream << "ceil"; break; case OperatorType::Ceil: stream << "ceil"; break;
case OperatorType::Ite: stream << "ite"; break; case OperatorType::Ite: stream << "ite"; break;
case OperatorType::AtMostOneOf: stream << "atMostOneOf"; break;
case OperatorType::AtLeastOneOf: stream << "atLeastOneOf"; break;
case OperatorType::ExactlyOneOf: stream << "exactlyOneOf"; break;
} }
return stream; return stream;
} }

5
src/storm/storage/expressions/OperatorType.h

@ -29,7 +29,10 @@ namespace storm {
Not, Not,
Floor, Floor,
Ceil, Ceil,
Ite
Ite,
AtLeastOneOf,
AtMostOneOf,
ExactlyOneOf
}; };
std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType); std::ostream& operator<<(std::ostream& stream, OperatorType const& operatorType);

100
src/storm/storage/expressions/PredicateExpression.cpp

@ -0,0 +1,100 @@
#include "storm/storage/expressions/PredicateExpression.h"
#include "storm/storage/expressions/ExpressionVisitor.h"
#include "storm/utility/macros.h"
#include "storm/storage/BitVector.h"
#include "storm/exceptions/InvalidTypeException.h"
namespace storm {
namespace expressions {
OperatorType toOperatorType(PredicateExpression::PredicateType tp) {
switch (tp) {
case PredicateExpression::PredicateType::AtMostOneOf: return OperatorType::AtMostOneOf;
case PredicateExpression::PredicateType::AtLeastOneOf: return OperatorType::AtLeastOneOf;
case PredicateExpression::PredicateType::ExactlyOneOf: return OperatorType::ExactlyOneOf;
}
STORM_LOG_ASSERT(false, "Predicate type not supported");
}
PredicateExpression::PredicateExpression(ExpressionManager const &manager, Type const& type, std::vector <std::shared_ptr<BaseExpression const>> const &operands, PredicateType predicateType) : BaseExpression(manager, type), predicate(predicateType), operands(operands) {}
// Override base class methods.
storm::expressions::OperatorType PredicateExpression::getOperator() const {
return toOperatorType(predicate);
}
bool PredicateExpression::evaluateAsBool(Valuation const *valuation) const {
STORM_LOG_THROW(this->hasBooleanType(), storm::exceptions::InvalidTypeException, "Unable to evaluate expression as boolean.");
storm::storage::BitVector results(operands.size());
uint64_t i = 0;
for(auto const& operand : operands) {
results.set(i, operand->evaluateAsBool(valuation));
++i;
}
switch(predicate) {
case PredicateType::ExactlyOneOf: return results.getNumberOfSetBits() == 1;
case PredicateType::AtMostOneOf: return results.getNumberOfSetBits() <= 1;
case PredicateType::AtLeastOneOf: return results.getNumberOfSetBits() >= 1;
}
STORM_LOG_ASSERT(false, "Unknown predicate type");
}
std::shared_ptr<BaseExpression const> PredicateExpression::simplify() const {
std::vector<std::shared_ptr<BaseExpression const>> simplifiedOperands;
for (auto const& operand : operands) {
simplifiedOperands.push_back(operand->simplify());
}
return std::shared_ptr<BaseExpression>(new PredicateExpression(this->getManager(), this->getType(), simplifiedOperands, predicate));
}
boost::any PredicateExpression::accept(ExpressionVisitor &visitor, boost::any const &data) const {
return visitor.visit(*this, data);
}
bool PredicateExpression::isPredicateExpression() const {
return true;
}
bool PredicateExpression::isFunctionApplication() const {
return true;
}
bool PredicateExpression::containsVariables() const {
for(auto const& operand : operands) {
if(operand->containsVariables()) {
return true;
}
}
return false;
}
uint_fast64_t PredicateExpression::getArity() const {
return operands.size();
}
std::shared_ptr<BaseExpression const> PredicateExpression::getOperand(uint_fast64_t operandIndex) const {
STORM_LOG_ASSERT(operandIndex < this->getArity(), "Invalid operand access");
return operands[operandIndex];
}
void PredicateExpression::gatherVariables(std::set<storm::expressions::Variable>& variables) const {
for(auto const& operand : operands) {
operand->gatherVariables(variables);
}
}
/*!
* Retrieves the relation associated with the expression.
*
* @return The relation associated with the expression.
*/
PredicateExpression::PredicateType PredicateExpression::getPredicateType() const {
return predicate;
}
void PredicateExpression::printToStream(std::ostream& stream) const {
}
}
}

66
src/storm/storage/expressions/PredicateExpression.h

@ -0,0 +1,66 @@
#pragma once
#include "storm/storage/expressions/BaseExpression.h"
namespace storm {
namespace expressions {
/*!
* The base class of all binary expressions.
*/
class PredicateExpression : public BaseExpression {
public:
enum class PredicateType { AtLeastOneOf, AtMostOneOf, ExactlyOneOf };
PredicateExpression(ExpressionManager const &manager,Type const& type,
std::vector <std::shared_ptr<BaseExpression const>> const &operands,
PredicateType predicateType);
// Instantiate constructors and assignments with their default implementations.
PredicateExpression(PredicateExpression const &other) = default;
PredicateExpression &operator=(PredicateExpression const &other) = delete;
PredicateExpression(PredicateExpression &&) = default;
PredicateExpression &operator=(PredicateExpression &&) = delete;
virtual ~PredicateExpression() = default;
// Override base class methods.
virtual storm::expressions::OperatorType getOperator() const override;
virtual bool evaluateAsBool(Valuation const *valuation = nullptr) const override;
virtual std::shared_ptr<BaseExpression const> simplify() const override;
virtual boost::any accept(ExpressionVisitor &visitor, boost::any const &data) const override;
virtual bool isPredicateExpression() const override;
virtual bool isFunctionApplication() const override;
virtual bool containsVariables() const override;
virtual uint_fast64_t getArity() const override;
virtual std::shared_ptr<BaseExpression const> getOperand(uint_fast64_t operandIndex) const override;
virtual void gatherVariables(std::set<storm::expressions::Variable>& variables) const override;
/*!
* Retrieves the relation associated with the expression.
*
* @return The relation associated with the expression.
*/
PredicateType getPredicateType() const;
protected:
// Override base class method.
virtual void printToStream(std::ostream& stream) const override;
private:
PredicateType predicate;
std::vector<std::shared_ptr<BaseExpression const>> operands;
};
}
}
Loading…
Cancel
Save