diff --git a/src/storage/expressions/BaseExpression.cpp b/src/storage/expressions/BaseExpression.cpp index c2f4a5ac3..38a0604e2 100644 --- a/src/storage/expressions/BaseExpression.cpp +++ b/src/storage/expressions/BaseExpression.cpp @@ -81,16 +81,6 @@ namespace storm { return this->shared_from_this(); } - std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue) { - switch (enumValue) { - case ExpressionReturnType::Undefined: stream << "undefined"; break; - case ExpressionReturnType::Bool: stream << "bool"; break; - case ExpressionReturnType::Int: stream << "int"; break; - case ExpressionReturnType::Double: stream << "double"; break; - } - return stream; - } - std::ostream& operator<<(std::ostream& stream, BaseExpression const& expression) { expression.printToStream(stream); return stream; diff --git a/src/storage/expressions/BaseExpression.h b/src/storage/expressions/BaseExpression.h index 6a02b35ab..9524fcdbb 100644 --- a/src/storage/expressions/BaseExpression.h +++ b/src/storage/expressions/BaseExpression.h @@ -6,6 +6,7 @@ #include #include +#include "src/storage/expressions/ExpressionReturnType.h" #include "src/storage/expressions/Valuation.h" #include "src/storage/expressions/ExpressionVisitor.h" #include "src/storage/expressions/OperatorType.h" @@ -13,14 +14,7 @@ #include "src/utility/OsDetection.h" namespace storm { - namespace expressions { - /*! - * Each node in an expression tree has a uniquely defined type from this enum. - */ - enum class ExpressionReturnType {Undefined, Bool, Int, Double}; - - std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue); - + namespace expressions { /*! * The base class of all expression classes. */ diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index 169578c34..762ed07bf 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -5,6 +5,7 @@ #include "src/storage/expressions/SubstitutionVisitor.h" #include "src/storage/expressions/IdentifierSubstitutionVisitor.h" #include "src/storage/expressions/TypeCheckVisitor.h" +#include "src/storage/expressions/LinearityCheckVisitor.h" #include "src/storage/expressions/Expressions.h" #include "src/exceptions/InvalidTypeException.h" #include "src/exceptions/ExceptionMacros.h" @@ -16,27 +17,27 @@ namespace storm { } Expression Expression::substitute(std::map const& identifierToExpressionMap) const { - return SubstitutionVisitor>(identifierToExpressionMap).substitute(this->getBaseExpressionPointer().get()); + return SubstitutionVisitor>(identifierToExpressionMap).substitute(this); } Expression Expression::substitute(std::unordered_map const& identifierToExpressionMap) const { - return SubstitutionVisitor>(identifierToExpressionMap).substitute(this->getBaseExpressionPointer().get()); + return SubstitutionVisitor>(identifierToExpressionMap).substitute(this); } Expression Expression::substitute(std::map const& identifierToIdentifierMap) const { - return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(this->getBaseExpressionPointer().get()); + return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(this); } Expression Expression::substitute(std::unordered_map const& identifierToIdentifierMap) const { - return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(this->getBaseExpressionPointer().get()); + return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(this); } void Expression::check(std::map const& identifierToTypeMap) const { - return TypeCheckVisitor>(identifierToTypeMap).check(this->getBaseExpressionPointer().get()); + return TypeCheckVisitor>(identifierToTypeMap).check(this); } void Expression::check(std::unordered_map const& identifierToTypeMap) const { - return TypeCheckVisitor>(identifierToTypeMap).check(this->getBaseExpressionPointer().get()); + return TypeCheckVisitor>(identifierToTypeMap).check(this); } bool Expression::evaluateAsBool(Valuation const* valuation) const { @@ -105,6 +106,10 @@ namespace storm { || this->getOperator() == OperatorType::Greater || this->getOperator() == OperatorType::GreaterOrEqual; } + bool Expression::isLinear() const { + return LinearityCheckVisitor().check(*this); + } + std::set Expression::getVariables() const { return this->getBaseExpression().getVariables(); } diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index 81ae5cf54..08c758f27 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -6,6 +6,7 @@ #include #include "src/storage/expressions/BaseExpression.h" +#include "src/storage/expressions/ExpressionVisitor.h" #include "src/utility/OsDetection.h" namespace storm { @@ -238,6 +239,13 @@ namespace storm { */ bool isRelationalExpression() const; + /*! + * Retrieves whether this expression is a linear expression. + * + * @return True iff the expression is linear. + */ + bool isLinear() const; + /*! * Retrieves the set of all variables that appear in the expression. * @@ -281,6 +289,13 @@ namespace storm { */ bool hasBooleanReturnType() const; + /*! + * Accepts the given visitor. + * + * @param visitor The visitor to accept. + */ + void accept(ExpressionVisitor* visitor) const; + friend std::ostream& operator<<(std::ostream& stream, Expression const& expression); private: diff --git a/src/storage/expressions/ExpressionReturnType.cpp b/src/storage/expressions/ExpressionReturnType.cpp new file mode 100644 index 000000000..c10810f6c --- /dev/null +++ b/src/storage/expressions/ExpressionReturnType.cpp @@ -0,0 +1,15 @@ +#include "src/storage/expressions/ExpressionReturnType.h" + +namespace storm { + namespace expressions { + std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue) { + switch (enumValue) { + case ExpressionReturnType::Undefined: stream << "undefined"; break; + case ExpressionReturnType::Bool: stream << "bool"; break; + case ExpressionReturnType::Int: stream << "int"; break; + case ExpressionReturnType::Double: stream << "double"; break; + } + return stream; + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/ExpressionReturnType.h b/src/storage/expressions/ExpressionReturnType.h new file mode 100644 index 000000000..0cf928ed8 --- /dev/null +++ b/src/storage/expressions/ExpressionReturnType.h @@ -0,0 +1,17 @@ +#ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_ +#define STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_ + +#include + +namespace storm { + namespace expressions { + /*! + * Each node in an expression tree has a uniquely defined type from this enum. + */ + enum class ExpressionReturnType {Undefined, Bool, Int, Double}; + + std::ostream& operator<<(std::ostream& stream, ExpressionReturnType const& enumValue); + } +} + +#endif /* STORM_STORAGE_EXPRESSIONS_EXPRESSIONRETURNTYPE_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/IdentifierSubstitutionVisitor.cpp b/src/storage/expressions/IdentifierSubstitutionVisitor.cpp index a2153af6d..19724dea1 100644 --- a/src/storage/expressions/IdentifierSubstitutionVisitor.cpp +++ b/src/storage/expressions/IdentifierSubstitutionVisitor.cpp @@ -13,8 +13,8 @@ namespace storm { } template - Expression IdentifierSubstitutionVisitor::substitute(BaseExpression const* expression) { - expression->accept(this); + Expression IdentifierSubstitutionVisitor::substitute(Expression const& expression) { + expression.getBaseExpression().accept(this); return Expression(this->expressionStack.top()); } diff --git a/src/storage/expressions/IdentifierSubstitutionVisitor.h b/src/storage/expressions/IdentifierSubstitutionVisitor.h index 23dae79de..8e8723a70 100644 --- a/src/storage/expressions/IdentifierSubstitutionVisitor.h +++ b/src/storage/expressions/IdentifierSubstitutionVisitor.h @@ -26,7 +26,7 @@ namespace storm { * @return The expression in which all identifiers in the key set of the previously given mapping are * substituted with the mapped-to expressions. */ - Expression substitute(BaseExpression const* expression); + Expression substitute(Expression const& expression); virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override; diff --git a/src/storage/expressions/LinearCoefficientVisitor.cpp b/src/storage/expressions/LinearCoefficientVisitor.cpp new file mode 100644 index 000000000..58fd427d5 --- /dev/null +++ b/src/storage/expressions/LinearCoefficientVisitor.cpp @@ -0,0 +1,70 @@ +#include "src/storage/expressions/LinearCoefficientVisitor.h" + +#include "src/storage/expressions/Expressions.h" +#include "src/exceptions/ExceptionMacros.h" +#include "src/exceptions/InvalidArgumentException.h" + +namespace storm { + namespace expressions { + std::pair LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) { + expression.getBaseExpression().accept(this); + return resultStack.top(); + } + + void LinearCoefficientVisitor::visit(IfThenElseExpression const* expression) { + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + } + + void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) { + + } + + void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) { + + } + + void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) { + + } + + void LinearCoefficientVisitor::visit(VariableExpression const* expression) { + SimpleValuation valuation; + switch (expression->getReturnType()) { + case ExpressionReturnType::Bool: LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break; + case ExpressionReturnType::Int: + case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression->getVariableName(), 1); break; + case ExpressionReturnType::Undefined: LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break; + } + + resultStack.push(std::make_pair(valuation, 0)); + } + + void LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const* expression) { + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + } + + void LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const* expression) { + if (expression->getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) { + // Here, we need to negate all double identifiers. + std::pair& valuationConstantPair = resultStack.top(); + for (auto const& identifier : valuationConstantPair.first.getDoubleIdentifiers()) { + valuationConstantPair.first.setDoubleValue(identifier, -valuationConstantPair.first.getDoubleValue(identifier)); + } + } else { + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + } + } + + void LinearCoefficientVisitor::visit(BooleanLiteralExpression const* expression) { + LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + } + + void LinearCoefficientVisitor::visit(IntegerLiteralExpression const* expression) { + resultStack.push(std::make_pair(SimpleValuation(), static_cast(expression->getValue()))); + } + + void LinearCoefficientVisitor::visit(DoubleLiteralExpression const* expression) { + resultStack.push(std::make_pair(SimpleValuation(), expression->getValue())); + } + } +} \ No newline at end of file diff --git a/src/storage/expressions/LinearCoefficientVisitor.h b/src/storage/expressions/LinearCoefficientVisitor.h new file mode 100644 index 000000000..263e752c8 --- /dev/null +++ b/src/storage/expressions/LinearCoefficientVisitor.h @@ -0,0 +1,46 @@ +#ifndef STORM_STORAGE_EXPRESSIONS_LINEARCOEFFICIENTVISITOR_H_ +#define STORM_STORAGE_EXPRESSIONS_LINEARCOEFFICIENTVISITOR_H_ + +#include + +#include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/ExpressionVisitor.h" +#include "src/storage/expressions/SimpleValuation.h" + +namespace storm { + namespace expressions { + class LinearCoefficientVisitor : public ExpressionVisitor { + public: + /*! + * Creates a linear coefficient visitor. + */ + LinearCoefficientVisitor() = default; + + /*! + * Computes the (double) coefficients of all identifiers appearing in the expression if the expression + * was rewritten as a sum of atoms.. If the expression is not linear, an exception is thrown. + * + * @param expression The expression for which to compute the coefficients. + * @return A pair consisting of a mapping from identifiers to their coefficients and the coefficient of + * the constant atom. + */ + std::pair getLinearCoefficients(Expression const& expression); + + virtual void visit(IfThenElseExpression const* expression) override; + virtual void visit(BinaryBooleanFunctionExpression const* expression) override; + virtual void visit(BinaryNumericalFunctionExpression const* expression) override; + virtual void visit(BinaryRelationExpression const* expression) override; + virtual void visit(VariableExpression const* expression) override; + virtual void visit(UnaryBooleanFunctionExpression const* expression) override; + virtual void visit(UnaryNumericalFunctionExpression const* expression) override; + virtual void visit(BooleanLiteralExpression const* expression) override; + virtual void visit(IntegerLiteralExpression const* expression) override; + virtual void visit(DoubleLiteralExpression const* expression) override; + + private: + std::stack> resultStack; + }; + } +} + +#endif /* STORM_STORAGE_EXPRESSIONS_LINEARCOEFFICIENTVISITOR_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/LinearityCheckVisitor.cpp b/src/storage/expressions/LinearityCheckVisitor.cpp index d4086abf1..5dd2f38ac 100644 --- a/src/storage/expressions/LinearityCheckVisitor.cpp +++ b/src/storage/expressions/LinearityCheckVisitor.cpp @@ -6,73 +6,124 @@ namespace storm { namespace expressions { - bool LinearityCheckVisitor::check(BaseExpression const* expression) { - expression->accept(this); - return resultStack.top(); + LinearityCheckVisitor::LinearityCheckVisitor() : resultStack() { + // Intentionally left empty. + } + + bool LinearityCheckVisitor::check(Expression const& expression) { + expression.getBaseExpression().accept(this); + return resultStack.top() == LinearityStatus::LinearWithoutVariables || resultStack.top() == LinearityStatus::LinearContainsVariables; } void LinearityCheckVisitor::visit(IfThenElseExpression const* expression) { // An if-then-else expression is never linear. - resultStack.push(false); + resultStack.push(LinearityStatus::NonLinear); } void LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const* expression) { // Boolean function applications are not allowed in linear expressions. - resultStack.push(false); + resultStack.push(LinearityStatus::NonLinear); } void LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - bool leftResult = true; - bool rightResult = true; + LinearityStatus leftResult; + LinearityStatus rightResult; switch (expression->getOperatorType()) { case BinaryNumericalFunctionExpression::OperatorType::Plus: case BinaryNumericalFunctionExpression::OperatorType::Minus: expression->getFirstOperand()->accept(this); leftResult = resultStack.top(); - if (!leftResult) { - + if (leftResult == LinearityStatus::NonLinear) { + return; } else { resultStack.pop(); - expression->getSecondOperand()->accept(this); + rightResult = resultStack.top(); + if (rightResult == LinearityStatus::NonLinear) { + return; + } + resultStack.pop(); } - + + resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); + break; case BinaryNumericalFunctionExpression::OperatorType::Times: case BinaryNumericalFunctionExpression::OperatorType::Divide: - case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(false); break; - case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(false); break; + expression->getFirstOperand()->accept(this); + leftResult = resultStack.top(); + + if (leftResult == LinearityStatus::NonLinear) { + return; + } else { + resultStack.pop(); + expression->getSecondOperand()->accept(this); + rightResult = resultStack.top(); + if (rightResult == LinearityStatus::NonLinear) { + return; + } + resultStack.pop(); + } + + if (leftResult == LinearityStatus::LinearContainsVariables && rightResult == LinearityStatus::LinearContainsVariables) { + resultStack.push(LinearityStatus::NonLinear); + } + + resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); + break; + case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(LinearityStatus::NonLinear); break; + case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(LinearityStatus::NonLinear); break; } } void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) { - resultStack.push(false); + LinearityStatus leftResult; + LinearityStatus rightResult; + expression->getFirstOperand()->accept(this); + leftResult = resultStack.top(); + + if (leftResult == LinearityStatus::NonLinear) { + return; + } else { + resultStack.pop(); + expression->getSecondOperand()->accept(this); + rightResult = resultStack.top(); + if (rightResult == LinearityStatus::NonLinear) { + return; + } + resultStack.pop(); + } + + resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); } void LinearityCheckVisitor::visit(VariableExpression const* expression) { - resultStack.push(true); + resultStack.push(LinearityStatus::LinearContainsVariables); } void LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const* expression) { // Boolean function applications are not allowed in linear expressions. - resultStack.push(false); + resultStack.push(LinearityStatus::NonLinear); } void LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const* expression) { - // Intentionally left empty (just pass subresult one level further). + switch (expression->getOperatorType()) { + case UnaryNumericalFunctionExpression::OperatorType::Minus: break; + case UnaryNumericalFunctionExpression::OperatorType::Floor: + case UnaryNumericalFunctionExpression::OperatorType::Ceil: resultStack.pop(); resultStack.push(LinearityStatus::NonLinear); break; + } } void LinearityCheckVisitor::visit(BooleanLiteralExpression const* expression) { - // Boolean function applications are not allowed in linear expressions. - resultStack.push(false); + resultStack.push(LinearityStatus::NonLinear); } void LinearityCheckVisitor::visit(IntegerLiteralExpression const* expression) { - resultStack.push(true); + resultStack.push(LinearityStatus::LinearWithoutVariables); } void LinearityCheckVisitor::visit(DoubleLiteralExpression const* expression) { - resultStack.push(true); + resultStack.push(LinearityStatus::LinearWithoutVariables); } } } \ No newline at end of file diff --git a/src/storage/expressions/LinearityCheckVisitor.h b/src/storage/expressions/LinearityCheckVisitor.h index 2b1c3937c..d76b658c8 100644 --- a/src/storage/expressions/LinearityCheckVisitor.h +++ b/src/storage/expressions/LinearityCheckVisitor.h @@ -13,12 +13,14 @@ namespace storm { /*! * Creates a linearity check visitor. */ - LinearityCheckVisitor() = default; + LinearityCheckVisitor(); /*! * Checks that the given expression is linear. + * + * @param expression The expression to check for linearity. */ - bool check(BaseExpression const* expression); + bool check(Expression const& expression); virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override; @@ -32,8 +34,10 @@ namespace storm { virtual void visit(DoubleLiteralExpression const* expression) override; private: + enum class LinearityStatus { NonLinear, LinearContainsVariables, LinearWithoutVariables }; + // A stack for communicating the results of the subexpressions. - std::stack resultStack; + std::stack resultStack; }; } } diff --git a/src/storage/expressions/SimpleValuation.cpp b/src/storage/expressions/SimpleValuation.cpp index 4a5441e99..923bbe84b 100644 --- a/src/storage/expressions/SimpleValuation.cpp +++ b/src/storage/expressions/SimpleValuation.cpp @@ -1,5 +1,8 @@ -#include #include "src/storage/expressions/SimpleValuation.h" + +#include + +#include #include "src/exceptions/ExceptionMacros.h" #include "src/exceptions/InvalidArgumentException.h" #include "src/exceptions/InvalidAccessException.h" @@ -43,6 +46,18 @@ namespace storm { this->identifierToValueMap.erase(nameValuePair); } + ExpressionReturnType SimpleValuation::getIdentifierType(std::string const& name) const { + auto nameValuePair = this->identifierToValueMap.find(name); + LOG_THROW(nameValuePair != this->identifierToValueMap.end(), storm::exceptions::InvalidAccessException, "Access to unkown identifier '" << name << "'."); + if (nameValuePair->second.type() == typeid(bool)) { + return ExpressionReturnType::Bool; + } else if (nameValuePair->second.type() == typeid(int_fast64_t)) { + return ExpressionReturnType::Int; + } else { + return ExpressionReturnType::Double; + } + } + bool SimpleValuation::containsBooleanIdentifier(std::string const& name) const { auto nameValuePair = this->identifierToValueMap.find(name); if (nameValuePair == this->identifierToValueMap.end()) { @@ -85,6 +100,48 @@ namespace storm { return boost::get(nameValuePair->second); } + std::size_t SimpleValuation::getNumberOfIdentifiers() const { + return this->identifierToValueMap.size(); + } + + std::set SimpleValuation::getIdentifiers() const { + std::set result; + for (auto const& nameValuePair : this->identifierToValueMap) { + result.insert(nameValuePair.first); + } + return result; + } + + std::set SimpleValuation::getBooleanIdentifiers() const { + std::set result; + for (auto const& nameValuePair : this->identifierToValueMap) { + if (nameValuePair.second.type() == typeid(bool)) { + result.insert(nameValuePair.first); + } + } + return result; + } + + std::set SimpleValuation::getIntegerIdentifiers() const { + std::set result; + for (auto const& nameValuePair : this->identifierToValueMap) { + if (nameValuePair.second.type() == typeid(int_fast64_t)) { + result.insert(nameValuePair.first); + } + } + return result; + } + + std::set SimpleValuation::getDoubleIdentifiers() const { + std::set result; + for (auto const& nameValuePair : this->identifierToValueMap) { + if (nameValuePair.second.type() == typeid(double)) { + result.insert(nameValuePair.first); + } + } + return result; + } + std::ostream& operator<<(std::ostream& stream, SimpleValuation const& valuation) { stream << "{ "; uint_fast64_t elementIndex = 0; diff --git a/src/storage/expressions/SimpleValuation.h b/src/storage/expressions/SimpleValuation.h index 0ea05a511..fccfe2faa 100644 --- a/src/storage/expressions/SimpleValuation.h +++ b/src/storage/expressions/SimpleValuation.h @@ -6,6 +6,7 @@ #include #include "src/storage/expressions/Valuation.h" +#include "src/storage/expressions/ExpressionReturnType.h" #include "src/utility/OsDetection.h" namespace storm { @@ -84,11 +85,24 @@ namespace storm { * @param name The name of the identifier that is to be removed. */ void removeIdentifier(std::string const& name); + + /*! + * Retrieves the type of the identifier with the given name. + * + * @param name The name of the identifier whose type to retrieve. + * @return The type of the identifier with the given name. + */ + ExpressionReturnType getIdentifierType(std::string const& name) const; // Override base class methods. virtual bool containsBooleanIdentifier(std::string const& name) const override; virtual bool containsIntegerIdentifier(std::string const& name) const override; virtual bool containsDoubleIdentifier(std::string const& name) const override; + virtual std::size_t getNumberOfIdentifiers() const override; + virtual std::set getIdentifiers() const override; + virtual std::set getBooleanIdentifiers() const override; + virtual std::set getIntegerIdentifiers() const override; + virtual std::set getDoubleIdentifiers() const override; virtual bool getBooleanValue(std::string const& name) const override; virtual int_fast64_t getIntegerValue(std::string const& name) const override; virtual double getDoubleValue(std::string const& name) const override; diff --git a/src/storage/expressions/SubstitutionVisitor.cpp b/src/storage/expressions/SubstitutionVisitor.cpp index 2559b474b..43aa01ab3 100644 --- a/src/storage/expressions/SubstitutionVisitor.cpp +++ b/src/storage/expressions/SubstitutionVisitor.cpp @@ -13,8 +13,8 @@ namespace storm { } template - Expression SubstitutionVisitor::substitute(BaseExpression const* expression) { - expression->accept(this); + Expression SubstitutionVisitor::substitute(Expression const& expression) { + expression.getBaseExpression().accept(this); return Expression(this->expressionStack.top()); } diff --git a/src/storage/expressions/SubstitutionVisitor.h b/src/storage/expressions/SubstitutionVisitor.h index bc58148e3..0ebc0941e 100644 --- a/src/storage/expressions/SubstitutionVisitor.h +++ b/src/storage/expressions/SubstitutionVisitor.h @@ -26,7 +26,7 @@ namespace storm { * @return The expression in which all identifiers in the key set of the previously given mapping are * substituted with the mapped-to expressions. */ - Expression substitute(BaseExpression const* expression); + Expression substitute(Expression const& expression); virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override; diff --git a/src/storage/expressions/TypeCheckVisitor.cpp b/src/storage/expressions/TypeCheckVisitor.cpp index 044b85237..5ab80e141 100644 --- a/src/storage/expressions/TypeCheckVisitor.cpp +++ b/src/storage/expressions/TypeCheckVisitor.cpp @@ -12,8 +12,8 @@ namespace storm { } template - void TypeCheckVisitor::check(BaseExpression const* expression) { - expression->accept(this); + void TypeCheckVisitor::check(Expression const& expression) { + expression.getBaseExpression().accept(this); } template diff --git a/src/storage/expressions/TypeCheckVisitor.h b/src/storage/expressions/TypeCheckVisitor.h index 7772e0e7e..0cbf40f92 100644 --- a/src/storage/expressions/TypeCheckVisitor.h +++ b/src/storage/expressions/TypeCheckVisitor.h @@ -24,7 +24,7 @@ namespace storm { * * @param expression The expression in which to check the types. */ - void check(BaseExpression const* expression); + void check(Expression const& expression); virtual void visit(IfThenElseExpression const* expression) override; virtual void visit(BinaryBooleanFunctionExpression const* expression) override; diff --git a/src/storage/expressions/Valuation.h b/src/storage/expressions/Valuation.h index 19792b720..baf39462f 100644 --- a/src/storage/expressions/Valuation.h +++ b/src/storage/expressions/Valuation.h @@ -58,6 +58,42 @@ namespace storm { * @return True iff the identifier exists and is of boolean type. */ virtual bool containsDoubleIdentifier(std::string const& name) const = 0; + + /*! + * Retrieves the number of identifiers in this valuation. + * + * @return The number of identifiers in this valuation. + */ + virtual std::size_t getNumberOfIdentifiers() const = 0; + + /*! + * Retrieves the set of all identifiers contained in this valuation. + * + * @return The set of all identifiers contained in this valuation. + */ + virtual std::set getIdentifiers() const = 0; + + /*! + * Retrieves the set of boolean identifiers contained in this valuation. + * + * @return The set of boolean identifiers contained in this valuation. + */ + virtual std::set getBooleanIdentifiers() const = 0; + + /*! + * Retrieves the set of integer identifiers contained in this valuation. + * + * @return The set of integer identifiers contained in this valuation. + */ + virtual std::set getIntegerIdentifiers() const = 0; + + /*! + * Retrieves the set of double identifiers contained in this valuation. + * + * @return The set of double identifiers contained in this valuation. + */ + virtual std::set getDoubleIdentifiers() const = 0; + }; } diff --git a/test/functional/storage/ExpressionTest.cpp b/test/functional/storage/ExpressionTest.cpp index 26822190d..d3a920d1d 100644 --- a/test/functional/storage/ExpressionTest.cpp +++ b/test/functional/storage/ExpressionTest.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/LinearityCheckVisitor.h" #include "src/storage/expressions/SimpleValuation.h" #include "src/exceptions/InvalidTypeException.h" @@ -332,4 +333,20 @@ TEST(Expression, SimpleEvaluationTest) { ASSERT_THROW(tempExpression.evaluateAsDouble(&valuation), storm::exceptions::InvalidTypeException); ASSERT_THROW(tempExpression.evaluateAsInt(&valuation), storm::exceptions::InvalidTypeException); EXPECT_FALSE(tempExpression.evaluateAsBool(&valuation)); +} + +TEST(Expression, VisitorTest) { + storm::expressions::Expression threeExpression; + storm::expressions::Expression piExpression; + storm::expressions::Expression intVarExpression; + storm::expressions::Expression doubleVarExpression; + + ASSERT_NO_THROW(threeExpression = storm::expressions::Expression::createIntegerLiteral(3)); + ASSERT_NO_THROW(piExpression = storm::expressions::Expression::createDoubleLiteral(3.14)); + ASSERT_NO_THROW(intVarExpression = storm::expressions::Expression::createIntegerVariable("y")); + ASSERT_NO_THROW(doubleVarExpression = storm::expressions::Expression::createDoubleVariable("z")); + + storm::expressions::Expression tempExpression = intVarExpression + doubleVarExpression * threeExpression; + storm::expressions::LinearityCheckVisitor visitor; + EXPECT_TRUE(visitor.check(tempExpression)); } \ No newline at end of file