From 92d550be1259263de8fca2ab730851926287feaf Mon Sep 17 00:00:00 2001 From: dehnert Date: Sun, 4 Jan 2015 19:18:00 +0100 Subject: [PATCH] More and more refactoring. Former-commit-id: b2f5b25c9259e80d2e1fd48e937f34ea29f17cf6 --- src/storage/expressions/Expression.cpp | 86 ++++------- src/storage/expressions/Expression.h | 36 ++--- src/storage/expressions/ExpressionManager.cpp | 34 +++-- src/storage/expressions/ExpressionManager.h | 16 +- .../expressions/LinearCoefficientVisitor.cpp | 141 ++++++++++-------- .../expressions/LinearCoefficientVisitor.h | 29 +++- src/storage/expressions/Type.cpp | 31 ++-- src/storage/expressions/Type.h | 20 ++- src/storage/expressions/Variable.cpp | 24 ++- src/storage/expressions/Variable.h | 46 +++--- .../expressions/VariableExpression.cpp | 18 +-- src/storage/prism/Assignment.cpp | 2 +- src/storage/prism/Assignment.h | 5 +- src/storage/prism/BooleanVariable.cpp | 10 +- src/storage/prism/BooleanVariable.h | 17 +-- src/storage/prism/Command.cpp | 2 +- src/storage/prism/Command.h | 2 +- src/storage/prism/Constant.cpp | 26 ++-- src/storage/prism/Constant.h | 33 ++-- src/storage/prism/Formula.cpp | 6 +- src/storage/prism/Formula.h | 7 +- src/storage/prism/InitialConstruct.cpp | 2 +- src/storage/prism/InitialConstruct.h | 3 +- src/storage/prism/IntegerVariable.cpp | 10 +- src/storage/prism/IntegerVariable.h | 19 +-- src/storage/prism/Label.cpp | 2 +- src/storage/prism/Label.h | 3 +- src/storage/prism/Program.cpp | 50 ++++--- src/storage/prism/Program.h | 24 ++- src/storage/prism/StateReward.cpp | 2 +- src/storage/prism/StateReward.h | 3 +- src/storage/prism/TransitionReward.cpp | 2 +- src/storage/prism/TransitionReward.h | 3 +- src/storage/prism/Update.cpp | 2 +- src/storage/prism/Update.h | 2 +- src/storage/prism/Variable.cpp | 16 +- src/storage/prism/Variable.h | 30 +++- 37 files changed, 422 insertions(+), 342 deletions(-) diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index 2adbb4235..db11f28ac 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -2,8 +2,8 @@ #include #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/ExpressionManager.h" #include "src/storage/expressions/SubstitutionVisitor.h" -#include "src/storage/expressions/IdentifierSubstitutionVisitor.h" #include "src/storage/expressions/LinearityCheckVisitor.h" #include "src/storage/expressions/Expressions.h" #include "src/exceptions/InvalidTypeException.h" @@ -15,24 +15,16 @@ namespace storm { // Intentionally left empty. } - Expression::Expression(Variable const& variable) : expressionPtr(new VariableExpression(variable)) { + Expression::Expression(Variable const& variable) : expressionPtr(std::shared_ptr(new VariableExpression(variable))) { // Intentionally left empty. } - Expression Expression::substitute(std::map const& identifierToExpressionMap) const { - return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); + Expression Expression::substitute(std::map const& identifierToExpressionMap) const { + return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); } - Expression Expression::substitute(std::unordered_map const& identifierToExpressionMap) const { - return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); - } - - Expression Expression::substitute(std::map const& identifierToIdentifierMap) const { - return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(*this); - } - - Expression Expression::substitute(std::unordered_map const& identifierToIdentifierMap) const { - return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(*this); + Expression Expression::substitute(std::unordered_map const& identifierToExpressionMap) const { + return SubstitutionVisitor>(identifierToExpressionMap).substitute(*this); } bool Expression::evaluateAsBool(Valuation const* valuation) const { @@ -117,7 +109,7 @@ namespace storm { return this->expressionPtr; } - Type Expression::getType() const { + Type const& Expression::getType() const { return this->getBaseExpression().getType(); } @@ -140,130 +132,110 @@ namespace storm { Expression Expression::operator-(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().plusMinusTimes(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Minus))); } Expression Expression::operator-() const { - STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '-' requires numerical operand."); - return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus))); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Minus))); } Expression Expression::operator*(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '*' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().plusMinusTimes(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Times))); } Expression Expression::operator/(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '/' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().divide(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Divide))); } Expression Expression::operator^(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '^' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getReturnType() == ExpressionReturnType::Int && other.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().power(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Power))); } Expression Expression::operator&&(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::And))); } Expression Expression::operator||(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '||' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Or))); } Expression Expression::operator!() const { - STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '!' requires boolean operand."); - return Expression(std::shared_ptr(new UnaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not))); + return Expression(std::shared_ptr(new UnaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(), this->getBaseExpressionPointer(), UnaryBooleanFunctionExpression::OperatorType::Not))); } Expression Expression::operator==(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '==' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Equal))); } Expression Expression::operator!=(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW((this->hasNumericalReturnType() && other.hasNumericalReturnType()) || (this->hasBooleanReturnType() && other.hasBooleanReturnType()), storm::exceptions::InvalidTypeException, "Operator '!=' requires operands of equal type."); if (this->hasNumericalReturnType() && other.hasNumericalReturnType()) { - return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::NotEqual))); } else { - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Xor))); } } Expression Expression::operator>(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Greater))); } Expression Expression::operator>=(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '>=' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::GreaterOrEqual))); } Expression Expression::operator<(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::Less))); } Expression Expression::operator<=(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasNumericalReturnType() && other.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator '<=' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual))); + return Expression(std::shared_ptr(new BinaryRelationExpression(this->getBaseExpression().getManager(), this->getType().numericalComparison(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryRelationExpression::RelationType::LessOrEqual))); } Expression Expression::minimum(Expression const& lhs, Expression const& rhs) { assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression()); - STORM_LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'min' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getType().minimumMaximum(rhs.getType()), lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Min))); } Expression Expression::maximum(Expression const& lhs, Expression const& rhs) { assertSameManager(lhs.getBaseExpression(), rhs.getBaseExpression()); - STORM_LOG_THROW(lhs.hasNumericalReturnType() && rhs.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'max' requires numerical operands."); - return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getReturnType() == ExpressionReturnType::Int && rhs.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double, lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max))); + return Expression(std::shared_ptr(new BinaryNumericalFunctionExpression(lhs.getBaseExpression().getManager(), lhs.getType().minimumMaximum(rhs.getType()), lhs.getBaseExpressionPointer(), rhs.getBaseExpressionPointer(), BinaryNumericalFunctionExpression::OperatorType::Max))); } Expression Expression::ite(Expression const& thenExpression, Expression const& elseExpression) { assertSameManager(this->getBaseExpression(), thenExpression.getBaseExpression()); assertSameManager(thenExpression.getBaseExpression(), elseExpression.getBaseExpression()); - STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Condition of if-then-else operator must be of boolean type."); - STORM_LOG_THROW(thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() || thenExpression.hasNumericalReturnType() && elseExpression.hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "'then' and 'else' expression of if-then-else operator must have equal return type."); - return Expression(std::shared_ptr(new IfThenElseExpression(this->getBaseExpression().getManager(), thenExpression.hasBooleanReturnType() && elseExpression.hasBooleanReturnType() ? ExpressionReturnType::Bool : (thenExpression.getReturnType() == ExpressionReturnType::Int && elseExpression.getReturnType() == ExpressionReturnType::Int ? ExpressionReturnType::Int : ExpressionReturnType::Double), this->getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer()))); + return Expression(std::shared_ptr(new IfThenElseExpression(this->getBaseExpression().getManager(), this->getType().ite(thenExpression.getType(), elseExpression.getType()), this->getBaseExpressionPointer(), thenExpression.getBaseExpressionPointer(), elseExpression.getBaseExpressionPointer()))); } Expression Expression::implies(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Implies))); } Expression Expression::iff(Expression const& other) const { assertSameManager(this->getBaseExpression(), other.getBaseExpression()); - STORM_LOG_THROW(this->hasBooleanReturnType() && other.hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Operator '&&' requires boolean operands."); - return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Bool, this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff))); + return Expression(std::shared_ptr(new BinaryBooleanFunctionExpression(this->getBaseExpression().getManager(), this->getType().logicalConnective(other.getType()), this->getBaseExpressionPointer(), other.getBaseExpressionPointer(), BinaryBooleanFunctionExpression::OperatorType::Iff))); } Expression Expression::floor() const { STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'floor' requires numerical operand."); - return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor))); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().floorCeil(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Floor))); } Expression Expression::ceil() const { STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Operator 'ceil' requires numerical operand."); - return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil))); + return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(this->getBaseExpression().getManager(), this->getType().floorCeil(), this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil))); } boost::any Expression::accept(ExpressionVisitor& visitor) const { diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index 8334f415f..4afb90e2d 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -59,44 +59,26 @@ namespace storm { static Expression maximum(Expression const& lhs, Expression const& rhs); /*! - * Substitutes all occurrences of identifiers according to the given map. Note that this substitution is - * done simultaneously, i.e., identifiers appearing in the expressions that were "plugged in" are not + * Substitutes all occurrences of the variables according to the given map. Note that this substitution is + * done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not * substituted. * - * @param identifierToExpressionMap A mapping from identifiers to the expression they are substituted with. + * @param variableToExpressionMap A mapping from variables to the expression they are substituted with. * @return An expression in which all identifiers in the key set of the mapping are replaced by the * expression they are mapped to. */ - Expression substitute(std::map const& identifierToExpressionMap) const; + Expression substitute(std::map const& variableToExpressionMap) const; /*! - * Substitutes all occurrences of identifiers according to the given map. Note that this substitution is - * done simultaneously, i.e., identifiers appearing in the expressions that were "plugged in" are not + * Substitutes all occurrences of the variables according to the given map. Note that this substitution is + * done simultaneously, i.e., variables appearing in the expressions that were "plugged in" are not * substituted. * - * @param identifierToExpressionMap A mapping from identifiers to the expression they are substituted with. + * @param variableToExpressionMap A mapping from variables to the expression they are substituted with. * @return An expression in which all identifiers in the key set of the mapping are replaced by the * expression they are mapped to. */ - Expression substitute(std::unordered_map const& identifierToExpressionMap) const; - - /*! - * Substitutes all occurrences of identifiers with different names given by a mapping. - * - * @param identifierToIdentifierMap A mapping from identifiers to identifiers they are substituted with. - * @return An expression in which all identifiers in the key set of the mapping are replaced by the - * identifiers they are mapped to. - */ - Expression substitute(std::map const& identifierToIdentifierMap) const; - - /*! - * Substitutes all occurrences of identifiers with different names given by a mapping. - * - * @param identifierToIdentifierMap A mapping from identifiers to identifiers they are substituted with. - * @return An expression in which all identifiers in the key set of the mapping are replaced by the - * identifiers they are mapped to. - */ - Expression substitute(std::unordered_map const& identifierToIdentifierMap) const; + Expression substitute(std::unordered_map const& variableToExpressionMap) const; /*! * Evaluates the expression under the valuation of variables given by the valuation and returns the @@ -247,7 +229,7 @@ namespace storm { * * @return The type of the expression. */ - Type getType() const; + Type const& getType() const; /*! * Retrieves whether the expression has a numerical return type, i.e., integer or double. diff --git a/src/storage/expressions/ExpressionManager.cpp b/src/storage/expressions/ExpressionManager.cpp index 42e555431..3e07bcfae 100644 --- a/src/storage/expressions/ExpressionManager.cpp +++ b/src/storage/expressions/ExpressionManager.cpp @@ -47,7 +47,7 @@ namespace storm { } if (nameIndexIterator != nameIndexIteratorEnd) { - currentElement = std::make_pair(Variable(manager, nameIndexIterator->second), manager.getVariableType(nameIndexIterator->second)); + currentElement = std::make_pair(Variable(manager.getSharedPointer(), nameIndexIterator->second), manager.getVariableType(nameIndexIterator->second)); } } @@ -56,15 +56,15 @@ namespace storm { } Expression ExpressionManager::boolean(bool value) const { - return Expression(std::shared_ptr(new BooleanLiteralExpression(*this, value))); + return Expression(std::shared_ptr(new BooleanLiteralExpression(*this, value))); } Expression ExpressionManager::integer(int_fast64_t value) const { - return Expression(std::shared_ptr(new IntegerLiteralExpression(*this, value))); + return Expression(std::shared_ptr(new IntegerLiteralExpression(*this, value))); } Expression ExpressionManager::rational(double value) const { - return Expression(std::shared_ptr(new DoubleLiteralExpression(*this, value))); + return Expression(std::shared_ptr(new DoubleLiteralExpression(*this, value))); } bool ExpressionManager::operator==(ExpressionManager const& other) const { @@ -72,19 +72,19 @@ namespace storm { } Type ExpressionManager::getBooleanType() const { - return Type(std::shared_ptr(new BooleanType())); + return Type(this->getSharedPointer(), std::shared_ptr(new BooleanType())); } Type ExpressionManager::getIntegerType() const { - return Type(std::shared_ptr(new IntegerType())); + return Type(this->getSharedPointer(), std::shared_ptr(new IntegerType())); } Type ExpressionManager::getBoundedIntegerType(std::size_t width) const { - return Type(std::shared_ptr(new BoundedIntegerType(width))); + return Type(this->getSharedPointer(), std::shared_ptr(new BoundedIntegerType(width))); } Type ExpressionManager::getRationalType() const { - return Type(std::shared_ptr(new RationalType())); + return Type(this->getSharedPointer(), std::shared_ptr(new RationalType())); } bool ExpressionManager::isValidVariableName(std::string const& name) { @@ -110,7 +110,7 @@ namespace storm { STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'."); auto nameIndexPair = nameToIndexMapping.find(name); if (nameIndexPair != nameToIndexMapping.end()) { - return Variable(*this, nameIndexPair->second); + return Variable(this->getSharedPointer(), nameIndexPair->second); } else { std::unordered_map::iterator typeCountPair = variableTypeToCountMapping.find(variableType); uint_fast64_t& oldCount = variableTypeToCountMapping[variableType]; @@ -122,7 +122,7 @@ namespace storm { nameToIndexMapping[name] = newIndex; indexToNameMapping[newIndex] = name; indexToTypeMapping[newIndex] = variableType; - return Variable(*this, newIndex); + return Variable(this->getSharedPointer(), newIndex); } } @@ -130,7 +130,7 @@ namespace storm { STORM_LOG_THROW(isValidVariableName(name), storm::exceptions::InvalidArgumentException, "Invalid variable name '" << name << "'."); auto nameIndexPair = nameToIndexMapping.find(name); if (nameIndexPair != nameToIndexMapping.end()) { - return Variable(*this, nameIndexPair->second); + return Variable(this->getSharedPointer(), nameIndexPair->second); } else { std::unordered_map::iterator typeCountPair = auxiliaryVariableTypeToCountMapping.find(variableType); uint_fast64_t& oldCount = auxiliaryVariableTypeToCountMapping[variableType]; @@ -142,14 +142,14 @@ namespace storm { nameToIndexMapping[name] = newIndex; indexToNameMapping[newIndex] = name; indexToTypeMapping[newIndex] = variableType; - return Variable(*this, newIndex); + return Variable(this->getSharedPointer(), newIndex); } } Variable ExpressionManager::getVariable(std::string const& name) const { auto nameIndexPair = nameToIndexMapping.find(name); STORM_LOG_THROW(nameIndexPair != nameToIndexMapping.end(), storm::exceptions::InvalidArgumentException, "Unknown variable '" << name << "'."); - return Variable(*this, nameIndexPair->second); + return Variable(this->getSharedPointer(), nameIndexPair->second); } Expression ExpressionManager::getVariableExpression(std::string const& name) const { @@ -240,5 +240,13 @@ namespace storm { return ExpressionManager::const_iterator(*this, this->nameToIndexMapping.end(), this->nameToIndexMapping.end(), const_iterator::VariableSelection::OnlyRegularVariables); } + std::shared_ptr ExpressionManager::getSharedPointer() { + return this->shared_from_this(); + } + + std::shared_ptr ExpressionManager::getSharedPointer() const { + return this->shared_from_this(); + } + } // namespace expressions } // namespace storm \ No newline at end of file diff --git a/src/storage/expressions/ExpressionManager.h b/src/storage/expressions/ExpressionManager.h index 3cc84b5bc..9b4acfc8c 100644 --- a/src/storage/expressions/ExpressionManager.h +++ b/src/storage/expressions/ExpressionManager.h @@ -58,7 +58,7 @@ namespace storm { /*! * This class is responsible for managing a set of typed variables and all expressions using these variables. */ - class ExpressionManager { + class ExpressionManager : public std::enable_shared_from_this { public: friend class VariableIterator; @@ -315,6 +315,20 @@ namespace storm { const_iterator end() const; private: + /*! + * Retrieves a shared pointer to the expression manager. + * + * @return A shared pointer to the expression manager. + */ + std::shared_ptr getSharedPointer(); + + /*! + * Retrieves a shared pointer to the expression manager. + * + * @return A shared pointer to the expression manager. + */ + std::shared_ptr getSharedPointer() const; + /*! * Checks whether the given variable name is valid. * diff --git a/src/storage/expressions/LinearCoefficientVisitor.cpp b/src/storage/expressions/LinearCoefficientVisitor.cpp index 29adef811..aa644f588 100644 --- a/src/storage/expressions/LinearCoefficientVisitor.cpp +++ b/src/storage/expressions/LinearCoefficientVisitor.cpp @@ -6,8 +6,65 @@ namespace storm { namespace expressions { - std::pair LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) { - return boost::any_cast>(expression.getBaseExpression().accept(*this)); + LinearCoefficientVisitor::VariableCoefficients::VariableCoefficients(double constantPart) : variableToCoefficientMapping(), constantPart(constantPart) { + // Intentionally left empty. + } + + LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator+=(VariableCoefficients&& other) { + for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) { + this->variableToCoefficientMapping[otherVariableCoefficientPair.first] += otherVariableCoefficientPair.second; + } + constantPart += other.constantPart; + return *this; + } + + LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator-=(VariableCoefficients&& other) { + for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) { + this->variableToCoefficientMapping[otherVariableCoefficientPair.first] -= otherVariableCoefficientPair.second; + } + constantPart -= other.constantPart; + return *this; + } + + LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator*=(VariableCoefficients&& other) { + STORM_LOG_THROW(variableToCoefficientMapping.size() == 0 || other.variableToCoefficientMapping.size() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + if (other.variableToCoefficientMapping.size() > 0) { + variableToCoefficientMapping = std::move(other.variableToCoefficientMapping); + std::swap(constantPart, other.constantPart); + } + for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) { + this->variableToCoefficientMapping[otherVariableCoefficientPair.first] *= other.constantPart; + } + constantPart *= other.constantPart; + return *this; + } + + LinearCoefficientVisitor::VariableCoefficients& LinearCoefficientVisitor::VariableCoefficients::operator/=(VariableCoefficients&& other) { + STORM_LOG_THROW(other.variableToCoefficientMapping.size() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); + for (auto const& otherVariableCoefficientPair : other.variableToCoefficientMapping) { + this->variableToCoefficientMapping[otherVariableCoefficientPair.first] /= other.constantPart; + } + constantPart /= other.constantPart; + return *this; + } + + void LinearCoefficientVisitor::VariableCoefficients::negate() { + for (auto& variableCoefficientPair : variableToCoefficientMapping) { + variableCoefficientPair.second = -variableCoefficientPair.second; + } + constantPart = -constantPart; + } + + void LinearCoefficientVisitor::VariableCoefficients::setCoefficient(storm::expressions::Variable const& variable, double coefficient) { + variableToCoefficientMapping[variable] = coefficient; + } + + double LinearCoefficientVisitor::VariableCoefficients::getCoefficient(storm::expressions::Variable const& variable) { + return variableToCoefficientMapping[variable]; + } + + LinearCoefficientVisitor::VariableCoefficients LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) { + return boost::any_cast(expression.getBaseExpression().accept(*this)); } boost::any LinearCoefficientVisitor::visit(IfThenElseExpression const& expression) { @@ -19,60 +76,17 @@ namespace storm { } boost::any LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const& expression) { - std::pair leftResult = boost::any_cast>(expression.getFirstOperand()->accept(*this)); - std::pair rightResult = boost::any_cast>(expression.getSecondOperand()->accept(*this)); + VariableCoefficients leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + VariableCoefficients rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) { - // Now add the left result to the right result. - for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { - if (rightResult.first.containsDoubleIdentifier(identifier)) { - rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) + rightResult.first.getDoubleValue(identifier)); - } else { - rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier)); - } - } - rightResult.second += leftResult.second; + leftResult += std::move(rightResult); } else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) { - // Now subtract the right result from the left result. - for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { - if (rightResult.first.containsDoubleIdentifier(identifier)) { - rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier) - rightResult.first.getDoubleValue(identifier)); - } else { - rightResult.first.setDoubleValue(identifier, leftResult.first.getDoubleValue(identifier)); - } - } - for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) { - if (!leftResult.first.containsDoubleIdentifier(identifier)) { - rightResult.first.setDoubleValue(identifier, -rightResult.first.getDoubleValue(identifier)); - } - } - rightResult.second = leftResult.second - rightResult.second; + leftResult -= std::move(rightResult); } else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) { - // If the expression is linear, either the left or the right side must not contain variables. - STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); - if (leftResult.first.getNumberOfIdentifiers() == 0) { - for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) { - rightResult.first.setDoubleValue(identifier, leftResult.second * rightResult.first.getDoubleValue(identifier)); - } - } else { - for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { - rightResult.first.addDoubleIdentifier(identifier, rightResult.second * leftResult.first.getDoubleValue(identifier)); - } - } - rightResult.second *= leftResult.second; + leftResult *= std::move(rightResult); } else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) { - // If the expression is linear, either the left or the right side must not contain variables. - STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); - if (leftResult.first.getNumberOfIdentifiers() == 0) { - for (auto const& identifier : rightResult.first.getDoubleIdentifiers()) { - rightResult.first.setDoubleValue(identifier, leftResult.second / rightResult.first.getDoubleValue(identifier)); - } - } else { - for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { - rightResult.first.addDoubleIdentifier(identifier, leftResult.first.getDoubleValue(identifier) / rightResult.second); - } - } - rightResult.second = leftResult.second / leftResult.second; + leftResult /= std::move(rightResult); } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } @@ -84,15 +98,13 @@ namespace storm { } boost::any LinearCoefficientVisitor::visit(VariableExpression const& expression) { - SimpleValuation valuation; - switch (expression.getReturnType()) { - case ExpressionReturnType::Bool: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break; - case ExpressionReturnType::Int: - case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression.getVariableName(), 1); break; - case ExpressionReturnType::Undefined: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break; + VariableCoefficients coefficients; + if (expression.getType().isNumericalType()) { + coefficients.setCoefficient(expression.getVariable(), 1); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } - - return std::make_pair(valuation, static_cast(0)); + return coefficients; } boost::any LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const& expression) { @@ -100,13 +112,10 @@ namespace storm { } boost::any LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const& expression) { - std::pair childResult = boost::any_cast>(expression.getOperand()->accept(*this)); + VariableCoefficients childResult = boost::any_cast(expression.getOperand()->accept(*this)); if (expression.getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) { - // Here, we need to negate all double identifiers. - for (auto const& identifier : childResult.first.getDoubleIdentifiers()) { - childResult.first.setDoubleValue(identifier, -childResult.first.getDoubleValue(identifier)); - } + childResult.negate(); return childResult; } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); @@ -118,11 +127,11 @@ namespace storm { } boost::any LinearCoefficientVisitor::visit(IntegerLiteralExpression const& expression) { - return std::make_pair(SimpleValuation(), static_cast(expression.getValue())); + return VariableCoefficients(static_cast(expression.getValue())); } boost::any LinearCoefficientVisitor::visit(DoubleLiteralExpression const& expression) { - return std::make_pair(SimpleValuation(), expression.getValue()); + return VariableCoefficients(expression.getValue()); } } } \ No newline at end of file diff --git a/src/storage/expressions/LinearCoefficientVisitor.h b/src/storage/expressions/LinearCoefficientVisitor.h index faf02414b..38bcd88c7 100644 --- a/src/storage/expressions/LinearCoefficientVisitor.h +++ b/src/storage/expressions/LinearCoefficientVisitor.h @@ -4,6 +4,7 @@ #include #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Variable.h" #include "src/storage/expressions/ExpressionVisitor.h" #include "src/storage/expressions/SimpleValuation.h" @@ -11,6 +12,29 @@ namespace storm { namespace expressions { class LinearCoefficientVisitor : public ExpressionVisitor { public: + struct VariableCoefficients { + public: + VariableCoefficients(double constantPart = 0); + + VariableCoefficients(VariableCoefficients const& other) = default; + VariableCoefficients& operator=(VariableCoefficients const& other) = default; + VariableCoefficients(VariableCoefficients&& other) = default; + VariableCoefficients& operator=(VariableCoefficients&& other) = default; + + VariableCoefficients& operator+=(VariableCoefficients&& other); + VariableCoefficients& operator-=(VariableCoefficients&& other); + VariableCoefficients& operator*=(VariableCoefficients&& other); + VariableCoefficients& operator/=(VariableCoefficients&& other); + + void negate(); + void setCoefficient(storm::expressions::Variable const& variable, double coefficient); + double getCoefficient(storm::expressions::Variable const& variable); + + private: + std::map variableToCoefficientMapping; + double constantPart; + }; + /*! * Creates a linear coefficient visitor. */ @@ -21,10 +45,9 @@ namespace storm { * was rewritten as a sum of atoms.. If the expression is not linear, an exception is thrown. * * @param expression The expression for which to compute the coefficients. - * @return A pair consisting of a mapping from identifiers to their coefficients and the coefficient of - * the constant atom. + * @return A structure representing the coefficients of the variables and the constant part. */ - std::pair getLinearCoefficients(Expression const& expression); + VariableCoefficients getLinearCoefficients(Expression const& expression); virtual boost::any visit(IfThenElseExpression const& expression) override; virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; diff --git a/src/storage/expressions/Type.cpp b/src/storage/expressions/Type.cpp index bfc2adaff..f9d0e4713 100644 --- a/src/storage/expressions/Type.cpp +++ b/src/storage/expressions/Type.cpp @@ -50,7 +50,7 @@ namespace storm { return "rational"; } - Type::Type(ExpressionManager const& manager, std::shared_ptr innerType) : manager(manager), innerType(innerType) { + Type::Type(std::shared_ptr const& manager, std::shared_ptr innerType) : manager(manager), innerType(innerType) { // Intentionally left empty. } @@ -94,12 +94,16 @@ namespace storm { return typeid(*this->innerType) == typeid(RationalType); } + storm::expressions::ExpressionManager const& Type::getManager() const { + return *manager; + } + Type Type::plusMinusTimes(Type const& other) const { STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); if (this->isRationalType() || other.isRationalType()) { - return manager.getRationalType(); + return this->getManager().getRationalType(); } - return manager.getIntegerType(); + return getManager().getIntegerType(); } Type Type::minus() const { @@ -109,7 +113,15 @@ namespace storm { Type Type::divide(Type const& other) const { STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); - return manager.getRationalType(); + return this->getManager().getRationalType(); + } + + Type Type::power(Type const& other) const { + STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); + if (this->isRationalType() || other.isRationalType()) { + return getManager().getRationalType(); + } + return this->getManager().getIntegerType(); } Type Type::logicalConnective(Type const& other) const { @@ -125,27 +137,28 @@ namespace storm { Type Type::numericalComparison(Type const& other) const { STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); if (this->isRationalType() || other.isRationalType()) { - return manager.getRationalType(); + return this->getManager().getRationalType(); } - return manager.getIntegerType(); + return this->getManager().getIntegerType(); } Type Type::ite(Type const& thenType, Type const& elseType) const { + STORM_LOG_ASSERT(this->isBooleanType(), "Operator requires boolean condition."); STORM_LOG_ASSERT(thenType == elseType, "Operator requires equal types."); return thenType; } Type Type::floorCeil() const { STORM_LOG_ASSERT(this->isRationalType(), "Operator requires rational operand."); - return manager.getIntegerType(); + return this->getManager().getIntegerType(); } Type Type::minimumMaximum(Type const& other) const { STORM_LOG_ASSERT(this->isNumericalType() && other.isNumericalType(), "Operator requires numerical operands."); if (this->isRationalType() || other.isRationalType()) { - return manager.getRationalType(); + return this->getManager().getRationalType(); } - return manager.getIntegerType(); + return this->getManager().getIntegerType(); } std::ostream& operator<<(std::ostream& stream, Type const& type) { diff --git a/src/storage/expressions/Type.h b/src/storage/expressions/Type.h index b8a3f51de..a4a8c0307 100644 --- a/src/storage/expressions/Type.h +++ b/src/storage/expressions/Type.h @@ -106,7 +106,15 @@ namespace storm { class Type { public: - Type(ExpressionManager const& manager, std::shared_ptr innerType); + Type() = default; + + /*! + * Constructs a new type of the given manager with the given encapsulated type. + * + * @param manager The manager responsible for this type. + * @param innerType The encapsulated type. + */ + Type(std::shared_ptr const& manager, std::shared_ptr innerType); /*! * Checks whether two types are the same. @@ -179,10 +187,18 @@ namespace storm { */ bool isRationalType() const; + /*! + * Retrieves the manager of the type. + * + * @return The manager of the type. + */ + storm::expressions::ExpressionManager const& getManager() const; + // Functions that, given the input types, produce the output type of the corresponding function application. Type plusMinusTimes(Type const& other) const; Type minus() const; Type divide(Type const& other) const; + Type power(Type const& other) const; Type logicalConnective(Type const& other) const; Type logicalConnective() const; Type numericalComparison(Type const& other) const; @@ -192,7 +208,7 @@ namespace storm { private: // The manager responsible for the type. - ExpressionManager const& manager; + std::shared_ptr manager; // The encapsulated type. std::shared_ptr innerType; diff --git a/src/storage/expressions/Variable.cpp b/src/storage/expressions/Variable.cpp index 2c6b5780e..b0acaba22 100644 --- a/src/storage/expressions/Variable.cpp +++ b/src/storage/expressions/Variable.cpp @@ -3,7 +3,7 @@ namespace storm { namespace expressions { - Variable::Variable(ExpressionManager const& manager, uint_fast64_t index) : manager(manager), index(index) { + Variable::Variable(std::shared_ptr const& manager, uint_fast64_t index) : manager(manager), index(index) { // Intentionally left empty. } @@ -20,21 +20,21 @@ namespace storm { } uint_fast64_t Variable::getOffset() const { - return manager.getOffset(index); + return this->getManager().getOffset(index); } std::string const& Variable::getName() const { - return manager.getVariableName(index); + return this->getManager().getVariableName(index); } Type const& Variable::getType() const { - return manager.getVariableType(index); + return this->getManager().getVariableType(index); } ExpressionManager const& Variable::getManager() const { - return manager; + return *manager; } - + bool Variable::hasBooleanType() const { return this->getType().isBooleanType(); } @@ -51,4 +51,14 @@ namespace storm { return this->getType().isNumericalType(); } } -} \ No newline at end of file +} + +namespace std { + std::size_t hash::operator()(storm::expressions::Variable const& variable) const { + return std::hash()(variable.getIndex()); + } + + std::size_t less::operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const { + return variable1.getIndex() < variable2.getIndex(); + } +} diff --git a/src/storage/expressions/Variable.h b/src/storage/expressions/Variable.h index 51e3c13b0..e99fd70c4 100644 --- a/src/storage/expressions/Variable.h +++ b/src/storage/expressions/Variable.h @@ -8,6 +8,26 @@ #include "src/storage/expressions/Type.h" #include "src/storage/expressions/Expression.h" +namespace storm { + namespace expressions { + class Variable; + } +} + +namespace std { + // Provide a hashing operator, so we can put variables in unordered collections. + template <> + struct hash { + std::size_t operator()(storm::expressions::Variable const& variable) const; + }; + + // Provide a less operator, so we can put variables in ordered collections. + template <> + struct less { + std::size_t operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const; + }; +} + namespace storm { namespace expressions { class ExpressionManager; @@ -15,13 +35,15 @@ namespace storm { // This class captures a simple variable. class Variable { public: + Variable() = default; + /*! * Constructs a variable with the given index and type. * * @param manager The manager that is responsible for this variable. * @param index The (unique) index of the variable. */ - Variable(ExpressionManager const& manager, uint_fast64_t index); + Variable(std::shared_ptr const& manager, uint_fast64_t index); // Default-instantiate some copy/move construction/assignment. Variable(Variable const& other) = default; @@ -62,6 +84,8 @@ namespace storm { /*! * Retrieves the manager responsible for this variable. + * + * @return The manager responsible for this variable. */ ExpressionManager const& getManager() const; @@ -109,7 +133,7 @@ namespace storm { private: // The manager that is responsible for this variable. - ExpressionManager const& manager; + std::shared_ptr manager; // The index of the variable. uint_fast64_t index; @@ -117,22 +141,4 @@ namespace storm { } } -namespace std { - // Provide a hashing operator, so we can put variables in unordered collections. - template <> - struct hash { - std::size_t operator()(storm::expressions::Variable const& variable) const { - return std::hash()(variable.getIndex()); - } - }; - - // Provide a less operator, so we can put variables in ordered collections. - template <> - struct less { - std::size_t operator()(storm::expressions::Variable const& variable1, storm::expressions::Variable const& variable2) const { - return variable1.getIndex() < variable2.getIndex(); - } - }; -} - #endif /* STORM_STORAGE_EXPRESSIONS_VARIABLE_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/VariableExpression.cpp b/src/storage/expressions/VariableExpression.cpp index 0737b94d9..19cdb8cd3 100644 --- a/src/storage/expressions/VariableExpression.cpp +++ b/src/storage/expressions/VariableExpression.cpp @@ -18,31 +18,27 @@ namespace storm { bool VariableExpression::evaluateAsBool(Valuation const* valuation) const { STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); - STORM_LOG_THROW(this->hasBooleanReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as boolean: return type is not a boolean."); + STORM_LOG_THROW(this->hasBooleanType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as boolean: return type is not a boolean."); return valuation->getBooleanValue(this->getVariable()); } int_fast64_t VariableExpression::evaluateAsInt(Valuation const* valuation) const { STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); - STORM_LOG_THROW(this->hasIntegralReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as integer: return type is not an integer."); + STORM_LOG_THROW(this->hasIntegralType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as integer: return type is not an integer."); return valuation->getIntegerValue(this->getVariable()); } double VariableExpression::evaluateAsDouble(Valuation const* valuation) const { STORM_LOG_ASSERT(valuation != nullptr, "Evaluating expressions with unknowns without valuation."); - STORM_LOG_THROW(this->hasNumericalReturnType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as double: return type is not a double."); + STORM_LOG_THROW(this->hasNumericalType(), storm::exceptions::InvalidTypeException, "Cannot evaluate expression as double: return type is not a double."); - switch (this->getReturnType()) { - case ExpressionReturnType::Int: return static_cast(valuation->getIntegerValue(this->getVariable())); break; - case ExpressionReturnType::Double: valuation->getRationalValue(this->getVariable()); break; - default: break; + if (this->getType().isIntegralType()) { + return static_cast(valuation->getIntegerValue(this->getVariable())); + } else { + return valuation->getRationalValue(this->getVariable()); } - STORM_LOG_ASSERT(false, "Type of variable is required to be numeric."); - - // Silence warning. This point can never be reached. - return 0; } std::string const& VariableExpression::getIdentifier() const { diff --git a/src/storage/prism/Assignment.cpp b/src/storage/prism/Assignment.cpp index 3b63703ed..4d1211959 100644 --- a/src/storage/prism/Assignment.cpp +++ b/src/storage/prism/Assignment.cpp @@ -14,7 +14,7 @@ namespace storm { return this->expression; } - Assignment Assignment::substitute(std::map const& substitution) const { + Assignment Assignment::substitute(std::map const& substitution) const { return Assignment(this->getVariableName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } diff --git a/src/storage/prism/Assignment.h b/src/storage/prism/Assignment.h index c6604a03e..a22154650 100644 --- a/src/storage/prism/Assignment.h +++ b/src/storage/prism/Assignment.h @@ -5,6 +5,7 @@ #include "src/storage/prism/LocatedInformation.h" #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Variable.h" #include "src/utility/OsDetection.h" namespace storm { @@ -45,12 +46,12 @@ namespace storm { storm::expressions::Expression const& getExpression() const; /*! - * Substitutes all identifiers in the assignment according to the given map. + * Substitutes all variables in the assignment according to the given map. * * @param substitution The substitution to perform. * @return The resulting assignment. */ - Assignment substitute(std::map const& substitution) const; + Assignment substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, Assignment const& assignment); diff --git a/src/storage/prism/BooleanVariable.cpp b/src/storage/prism/BooleanVariable.cpp index ef88cefc4..7f32fd7ea 100644 --- a/src/storage/prism/BooleanVariable.cpp +++ b/src/storage/prism/BooleanVariable.cpp @@ -2,16 +2,12 @@ namespace storm { namespace prism { - BooleanVariable::BooleanVariable(std::string const& variableName, std::string const& filename, uint_fast64_t lineNumber) : Variable(variableName, storm::expressions::Expression::createFalse(), true, filename, lineNumber) { - // Nothing to do here. - } - - BooleanVariable::BooleanVariable(std::string const& variableName, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(variableName, initialValueExpression, false, filename, lineNumber) { + BooleanVariable::BooleanVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(variable, initialValueExpression, false, filename, lineNumber) { // Nothing to do here. } - BooleanVariable BooleanVariable::substitute(std::map const& substitution) const { - return BooleanVariable(this->getName(), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); + BooleanVariable BooleanVariable::substitute(std::map const& substitution) const { + return BooleanVariable(this->getExpressionVariable(), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable) { diff --git a/src/storage/prism/BooleanVariable.h b/src/storage/prism/BooleanVariable.h index f32035010..5b43e191c 100644 --- a/src/storage/prism/BooleanVariable.h +++ b/src/storage/prism/BooleanVariable.h @@ -20,23 +20,14 @@ namespace storm { #endif /*! - * Creates a boolean variable with the given name and the default initial value expression. + * Creates a boolean variable with the given constant initial value expression. * - * @param variableName The name of the variable. - * @param filename The filename in which the variable is defined. - * @param lineNumber The line number in which the variable is defined. - */ - BooleanVariable(std::string const& variableName, std::string const& filename = "", uint_fast64_t lineNumber = 0); - - /*! - * Creates a boolean variable with the given name and the given constant initial value expression. - * - * @param variableName The name of the variable. + * @param variable The expression variable associated with this variable. * @param initialValueExpression The constant expression that defines the initial value of the variable. * @param filename The filename in which the variable is defined. * @param lineNumber The line number in which the variable is defined. */ - BooleanVariable(std::string const& variableName, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0); + BooleanVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0); /*! * Substitutes all identifiers in the boolean variable according to the given map. @@ -44,7 +35,7 @@ namespace storm { * @param substitution The substitution to perform. * @return The resulting boolean variable. */ - BooleanVariable substitute(std::map const& substitution) const; + BooleanVariable substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, BooleanVariable const& variable); }; diff --git a/src/storage/prism/Command.cpp b/src/storage/prism/Command.cpp index 4c1619732..aa13da0e8 100644 --- a/src/storage/prism/Command.cpp +++ b/src/storage/prism/Command.cpp @@ -30,7 +30,7 @@ namespace storm { return this->globalIndex; } - Command Command::substitute(std::map const& substitution) const { + Command Command::substitute(std::map const& substitution) const { std::vector newUpdates; newUpdates.reserve(this->getNumberOfUpdates()); for (auto const& update : this->getUpdates()) { diff --git a/src/storage/prism/Command.h b/src/storage/prism/Command.h index 6e32eba5e..60095d13c 100644 --- a/src/storage/prism/Command.h +++ b/src/storage/prism/Command.h @@ -82,7 +82,7 @@ namespace storm { * @param substitution The substitution to perform. * @return The resulting command. */ - Command substitute(std::map const& substitution) const; + Command substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, Command const& command); diff --git a/src/storage/prism/Constant.cpp b/src/storage/prism/Constant.cpp index d327461a7..1e3d8d1d8 100644 --- a/src/storage/prism/Constant.cpp +++ b/src/storage/prism/Constant.cpp @@ -4,20 +4,24 @@ namespace storm { namespace prism { - Constant::Constant(storm::expressions::ExpressionReturnType type, std::string const& name, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), type(type), name(name), defined(true), expression(expression) { + Constant::Constant(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), defined(true), expression(expression) { // Intentionally left empty. } - Constant::Constant(storm::expressions::ExpressionReturnType type, std::string const& name, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), type(type), name(name), defined(false), expression() { + Constant::Constant(storm::expressions::Variable const& variable, std::string const& filename, uint_fast64_t lineNumber) : LocatedInformation(filename, lineNumber), variable(variable), defined(false), expression() { // Intentionally left empty. } std::string const& Constant::getName() const { - return this->name; + return this->variable.getName(); } - storm::expressions::ExpressionReturnType Constant::getType() const { - return this->type; + storm::expressions::Type const& Constant::getType() const { + return this->getExpressionVariable().getType(); + } + + storm::expressions::Variable const& Constant::getExpressionVariable() const { + return this->variable; } bool Constant::isDefined() const { @@ -29,18 +33,12 @@ namespace storm { return this->expression; } - Constant Constant::substitute(std::map const& substitution) const { - return Constant(this->getType(), this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); + Constant Constant::substitute(std::map const& substitution) const { + return Constant(variable, this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } std::ostream& operator<<(std::ostream& stream, Constant const& constant) { - stream << "const "; - switch (constant.getType()) { - case storm::expressions::ExpressionReturnType::Undefined: stream << "undefined "; break; - case storm::expressions::ExpressionReturnType::Bool: stream << "bool "; break; - case storm::expressions::ExpressionReturnType::Int: stream << "int "; break; - case storm::expressions::ExpressionReturnType::Double: stream << "double "; break; - } + stream << "const " << constant.getExpressionVariable().getType(); stream << constant.getName(); if (constant.isDefined()) { stream << " = " << constant.getExpression(); diff --git a/src/storage/prism/Constant.h b/src/storage/prism/Constant.h index 2e1c27590..1138395a7 100644 --- a/src/storage/prism/Constant.h +++ b/src/storage/prism/Constant.h @@ -5,6 +5,7 @@ #include "src/storage/prism/LocatedInformation.h" #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Variable.h" #include "src/utility/OsDetection.h" namespace storm { @@ -12,25 +13,23 @@ namespace storm { class Constant : public LocatedInformation { public: /*! - * Creates a constant with the given type, name and defining expression. + * Creates a defined constant. * - * @param type The type of the constant. - * @param name The name of the constant. + * @param variable The expression variable associated with the constant. * @param expression The expression that defines the constant. * @param filename The filename in which the transition reward is defined. * @param lineNumber The line number in which the transition reward is defined. */ - Constant(storm::expressions::ExpressionReturnType type, std::string const& name, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0); + Constant(storm::expressions::Variable const& variable, storm::expressions::Expression const& expression, std::string const& filename = "", uint_fast64_t lineNumber = 0); /*! - * Creates an undefined constant with the given type and name. + * Creates an undefined constant. * - * @param constantType The type of the constant. - * @param constantName The name of the constant. + * @param variable The expression variable associated with the constant. * @param filename The filename in which the transition reward is defined. * @param lineNumber The line number in which the transition reward is defined. */ - Constant(storm::expressions::ExpressionReturnType constantType, std::string const& constantName, std::string const& filename = "", uint_fast64_t lineNumber = 0); + Constant(storm::expressions::Variable const& variable, std::string const& filename = "", uint_fast64_t lineNumber = 0); // Create default implementations of constructors/assignment. Constant() = default; @@ -53,7 +52,14 @@ namespace storm { * * @return The type of the constant; */ - storm::expressions::ExpressionReturnType getType() const; + storm::expressions::Type const& getType() const; + + /*! + * Retrieves the expression variable associated with this constant. + * + * @return The expression variable associated with this constant. + */ + storm::expressions::Variable const& getExpressionVariable() const; /*! * Retrieves whether the constant is defined, i.e., whether there is an expression defining its value. @@ -76,16 +82,13 @@ namespace storm { * @param substitution The substitution to perform. * @return The resulting constant. */ - Constant substitute(std::map const& substitution) const; + Constant substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, Constant const& constant); private: - // The type of the constant. - storm::expressions::ExpressionReturnType type; - - // The name of the constant. - std::string name; + // The expression variable associated with the constant. + storm::expressions::Variable variable; // A flag that stores whether or not the constant is defined. bool defined; diff --git a/src/storage/prism/Formula.cpp b/src/storage/prism/Formula.cpp index a120078b8..78e1bfaf1 100644 --- a/src/storage/prism/Formula.cpp +++ b/src/storage/prism/Formula.cpp @@ -14,11 +14,11 @@ namespace storm { return this->expression; } - storm::expressions::ExpressionReturnType Formula::getType() const { - return this->getExpression().getReturnType(); + storm::expressions::Type const& Formula::getType() const { + return this->getExpression().getType(); } - Formula Formula::substitute(std::map const& substitution) const { + Formula Formula::substitute(std::map const& substitution) const { return Formula(this->getName(), this->getExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } diff --git a/src/storage/prism/Formula.h b/src/storage/prism/Formula.h index 2df8e6d32..0e9ad7bd3 100644 --- a/src/storage/prism/Formula.h +++ b/src/storage/prism/Formula.h @@ -5,6 +5,7 @@ #include "src/storage/prism/LocatedInformation.h" #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Variable.h" #include "src/utility/OsDetection.h" namespace storm { @@ -49,15 +50,15 @@ namespace storm { * * @return The return type of the formula. */ - storm::expressions::ExpressionReturnType getType() const; + storm::expressions::Type const& getType() const; /*! - * Substitutes all identifiers in the expression of the formula according to the given map. + * Substitutes all variables in the expression of the formula according to the given map. * * @param substitution The substitution to perform. * @return The resulting formula. */ - Formula substitute(std::map const& substitution) const; + Formula substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, Formula const& formula); diff --git a/src/storage/prism/InitialConstruct.cpp b/src/storage/prism/InitialConstruct.cpp index 61c662843..9fff50a97 100644 --- a/src/storage/prism/InitialConstruct.cpp +++ b/src/storage/prism/InitialConstruct.cpp @@ -10,7 +10,7 @@ namespace storm { return this->initialStatesExpression; } - InitialConstruct InitialConstruct::substitute(std::map const& substitution) const { + InitialConstruct InitialConstruct::substitute(std::map const& substitution) const { return InitialConstruct(this->getInitialStatesExpression().substitute(substitution)); } diff --git a/src/storage/prism/InitialConstruct.h b/src/storage/prism/InitialConstruct.h index 6f7f6adb9..1bfd54d5f 100644 --- a/src/storage/prism/InitialConstruct.h +++ b/src/storage/prism/InitialConstruct.h @@ -5,6 +5,7 @@ #include "src/storage/prism/LocatedInformation.h" #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Variable.h" #include "src/utility/OsDetection.h" namespace storm { @@ -42,7 +43,7 @@ namespace storm { * @param substitution The substitution to perform. * @return The resulting initial construct. */ - InitialConstruct substitute(std::map const& substitution) const; + InitialConstruct substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, InitialConstruct const& initialConstruct); diff --git a/src/storage/prism/IntegerVariable.cpp b/src/storage/prism/IntegerVariable.cpp index 1968a07ca..8ab2f9678 100644 --- a/src/storage/prism/IntegerVariable.cpp +++ b/src/storage/prism/IntegerVariable.cpp @@ -2,11 +2,7 @@ namespace storm { namespace prism { - IntegerVariable::IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(name, lowerBoundExpression, true, filename, lineNumber), lowerBoundExpression(lowerBoundExpression), upperBoundExpression(upperBoundExpression) { - // Intentionally left empty. - } - - IntegerVariable::IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(name, initialValueExpression, false, filename, lineNumber), lowerBoundExpression(lowerBoundExpression), upperBoundExpression(upperBoundExpression) { + IntegerVariable::IntegerVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename, uint_fast64_t lineNumber) : Variable(variable, initialValueExpression, false, filename, lineNumber), lowerBoundExpression(lowerBoundExpression), upperBoundExpression(upperBoundExpression) { // Intentionally left empty. } @@ -18,8 +14,8 @@ namespace storm { return this->upperBoundExpression; } - IntegerVariable IntegerVariable::substitute(std::map const& substitution) const { - return IntegerVariable(this->getName(), this->getLowerBoundExpression().substitute(substitution), this->getUpperBoundExpression().substitute(substitution), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); + IntegerVariable IntegerVariable::substitute(std::map const& substitution) const { + return IntegerVariable(this->getExpressionVariable(), this->getLowerBoundExpression().substitute(substitution), this->getUpperBoundExpression().substitute(substitution), this->getInitialValueExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable) { diff --git a/src/storage/prism/IntegerVariable.h b/src/storage/prism/IntegerVariable.h index a7c069c53..024e1d214 100644 --- a/src/storage/prism/IntegerVariable.h +++ b/src/storage/prism/IntegerVariable.h @@ -20,27 +20,16 @@ namespace storm { #endif /*! - * Creates an integer variable with the given name and a default initial value. + * Creates an integer variable with the given initial value expression. * - * @param name The name of the variable. - * @param lowerBoundExpression A constant expression defining the lower bound of the domain of the variable. - * @param upperBoundExpression A constant expression defining the upper bound of the domain of the variable. - * @param filename The filename in which the variable is defined. - * @param lineNumber The line number in which the variable is defined. - */ - IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0); - - /*! - * Creates an integer variable with the given name and the given initial value expression. - * - * @param name The name of the variable. + * @param variable The expression variable associated with this variable. * @param lowerBoundExpression A constant expression defining the lower bound of the domain of the variable. * @param upperBoundExpression A constant expression defining the upper bound of the domain of the variable. * @param initialValueExpression A constant expression that defines the initial value of the variable. * @param filename The filename in which the variable is defined. * @param lineNumber The line number in which the variable is defined. */ - IntegerVariable(std::string const& name, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0); + IntegerVariable(storm::expressions::Variable const& variable, storm::expressions::Expression const& lowerBoundExpression, storm::expressions::Expression const& upperBoundExpression, storm::expressions::Expression const& initialValueExpression, std::string const& filename = "", uint_fast64_t lineNumber = 0); /*! * Retrieves an expression defining the lower bound for this integer variable. @@ -62,7 +51,7 @@ namespace storm { * @param substitution The substitution to perform. * @return The resulting boolean variable. */ - IntegerVariable substitute(std::map const& substitution) const; + IntegerVariable substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, IntegerVariable const& variable); diff --git a/src/storage/prism/Label.cpp b/src/storage/prism/Label.cpp index d11091cf0..dbc9d7f6f 100644 --- a/src/storage/prism/Label.cpp +++ b/src/storage/prism/Label.cpp @@ -14,7 +14,7 @@ namespace storm { return this->statePredicateExpression; } - Label Label::substitute(std::map const& substitution) const { + Label Label::substitute(std::map const& substitution) const { return Label(this->getName(), this->getStatePredicateExpression().substitute(substitution), this->getFilename(), this->getLineNumber()); } diff --git a/src/storage/prism/Label.h b/src/storage/prism/Label.h index 96934adf2..c0202cc04 100644 --- a/src/storage/prism/Label.h +++ b/src/storage/prism/Label.h @@ -5,6 +5,7 @@ #include "src/storage/prism/LocatedInformation.h" #include "src/storage/expressions/Expression.h" +#include "src/storage/expressions/Variable.h" #include "src/utility/OsDetection.h" namespace storm { @@ -51,7 +52,7 @@ namespace storm { * @param substitution The substitution to perform. * @return The resulting label. */ - Label substitute(std::map const& substitution) const; + Label substitute(std::map const& substitution) const; friend std::ostream& operator<<(std::ostream& stream, Label const& label); diff --git a/src/storage/prism/Program.cpp b/src/storage/prism/Program.cpp index bd7a11137..203e25e75 100644 --- a/src/storage/prism/Program.cpp +++ b/src/storage/prism/Program.cpp @@ -2,34 +2,35 @@ #include +#include "src/storage/expressions/ExpressionManager.h" #include "src/utility/macros.h" -#include "exceptions/InvalidArgumentException.h" +#include "src/exceptions/InvalidArgumentException.h" #include "src/exceptions/OutOfRangeException.h" #include "src/exceptions/WrongFormatException.h" #include "src/exceptions/InvalidTypeException.h" namespace storm { namespace prism { - Program::Program(ModelType modelType, std::vector const& constants, std::vector const& globalBooleanVariables, std::vector const& globalIntegerVariables, std::vector const& formulas, std::vector const& modules, std::vector const& rewardModels, bool fixInitialConstruct, storm::prism::InitialConstruct const& initialConstruct, std::vector