diff --git a/src/adapters/MathsatExpressionAdapter.h b/src/adapters/MathsatExpressionAdapter.h index 1aee8f53f..f15276b80 100644 --- a/src/adapters/MathsatExpressionAdapter.h +++ b/src/adapters/MathsatExpressionAdapter.h @@ -36,219 +36,171 @@ namespace storm { * expressions and are not yet known to the adapter. * @param variableToDeclarationMap A mapping from variable names to their corresponding MathSAT declarations (if already existing). */ - MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map const& variableToDeclarationMap = std::map()) : env(env), stack(), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) { + MathsatExpressionAdapter(msat_env& env, bool createVariables = true, std::map const& variableToDeclarationMap = std::map()) : env(env), variableToDeclarationMap(variableToDeclarationMap), createVariables(createVariables) { // Intentionally left empty. } /*! - * Translates the given expression to an equivalent term for MathSAT. - * - * @param expression The expression to be translated. - * @return An equivalent term for MathSAT. - */ + * Translates the given expression to an equivalent term for MathSAT. + * + * @param expression The expression to be translated. + * @return An equivalent term for MathSAT. + */ msat_term translateExpression(storm::expressions::Expression const& expression) { - expression.getBaseExpression().accept(this); - msat_term result = stack.top(); - stack.pop(); + msat_term result = boost::any_cast(expression.getBaseExpression().accept(*this)); STORM_LOG_THROW(!MSAT_ERROR_TERM(result), storm::exceptions::ExpressionEvaluationException, "Could not translate expression to MathSAT's format."); return result; } - virtual void visit(expressions::BinaryBooleanFunctionExpression const* expression) override { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); + virtual boost::any visit(expressions::BinaryBooleanFunctionExpression const& expression) override { + msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + msat_term rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); - msat_term rightResult = stack.top(); - stack.pop(); - msat_term leftResult = stack.top(); - stack.pop(); - - switch (expression->getOperatorType()) { + switch (expression.getOperatorType()) { case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: - stack.push(msat_make_and(env, leftResult, rightResult)); - break; + return msat_make_and(env, leftResult, rightResult); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: - stack.push(msat_make_or(env, leftResult, rightResult)); - break; + return msat_make_or(env, leftResult, rightResult); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: - stack.push(msat_make_iff(env, leftResult, rightResult)); - break; + return msat_make_iff(env, leftResult, rightResult); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: - stack.push(msat_make_or(env, msat_make_not(env, leftResult), rightResult)); - break; + return msat_make_or(env, msat_make_not(env, leftResult), rightResult); default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } - } - virtual void visit(expressions::BinaryNumericalFunctionExpression const* expression) override { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - - msat_term rightResult = stack.top(); - stack.pop(); - msat_term leftResult = stack.top(); - stack.pop(); + virtual boost::any visit(expressions::BinaryNumericalFunctionExpression const& expression) override { + msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + msat_term rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); - switch (expression->getOperatorType()) { + switch (expression.getOperatorType()) { case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus: - stack.push(msat_make_plus(env, leftResult, rightResult)); - break; + return msat_make_plus(env, leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus: - stack.push(msat_make_plus(env, leftResult, msat_make_times(env, msat_make_number(env, "-1"), rightResult))); - break; + return msat_make_plus(env, leftResult, msat_make_times(env, msat_make_number(env, "-1"), rightResult)); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times: - stack.push(msat_make_times(env, leftResult, rightResult)); - break; + return msat_make_times(env, leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide: - throw storm::exceptions::ExpressionEvaluationException() << "Cannot evaluate expression: " - << "Unsupported numerical binary operator: '/' (division) in expression " << expression << "."; - break; + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unsupported numerical binary operator: '/' (division) in expression."); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min: - stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), leftResult, rightResult)); - break; + return msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: - stack.push(msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult)); - break; + return msat_make_term_ite(env, msat_make_leq(env, leftResult, rightResult), rightResult, leftResult); default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } - virtual void visit(expressions::BinaryRelationExpression const* expression) override { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - - msat_term rightResult = stack.top(); - stack.pop(); - msat_term leftResult = stack.top(); - stack.pop(); + virtual boost::any visit(expressions::BinaryRelationExpression const& expression) override { + msat_term leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + msat_term rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); - switch (expression->getRelationType()) { + switch (expression.getRelationType()) { case storm::expressions::BinaryRelationExpression::RelationType::Equal: - if (expression->getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression->getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) { - stack.push(msat_make_iff(env, leftResult, rightResult)); + if (expression.getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression.getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) { + return msat_make_iff(env, leftResult, rightResult); } else { - stack.push(msat_make_equal(env, leftResult, rightResult)); + return msat_make_equal(env, leftResult, rightResult); } - break; case storm::expressions::BinaryRelationExpression::RelationType::NotEqual: - if (expression->getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression->getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) { - stack.push(msat_make_not(env, msat_make_iff(env, leftResult, rightResult))); + if (expression.getFirstOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool && expression.getSecondOperand()->getReturnType() == storm::expressions::ExpressionReturnType::Bool) { + return msat_make_not(env, msat_make_iff(env, leftResult, rightResult)); } else { - stack.push(msat_make_not(env, msat_make_equal(env, leftResult, rightResult))); + return msat_make_not(env, msat_make_equal(env, leftResult, rightResult)); } - break; case storm::expressions::BinaryRelationExpression::RelationType::Less: - stack.push(msat_make_and(env, msat_make_not(env, msat_make_equal(env, leftResult, rightResult)), msat_make_leq(env, leftResult, rightResult))); - break; + return msat_make_and(env, msat_make_not(env, msat_make_equal(env, leftResult, rightResult)), msat_make_leq(env, leftResult, rightResult)); case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual: - stack.push(msat_make_leq(env, leftResult, rightResult)); - break; + return msat_make_leq(env, leftResult, rightResult); case storm::expressions::BinaryRelationExpression::RelationType::Greater: - stack.push(msat_make_not(env, msat_make_leq(env, leftResult, rightResult))); - break; + return msat_make_not(env, msat_make_leq(env, leftResult, rightResult)); case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: - stack.push(msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult)))); - break; + return msat_make_or(env, msat_make_equal(env, leftResult, rightResult), msat_make_not(env, msat_make_leq(env, leftResult, rightResult))); default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getRelationType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getRelationType()) << "' in expression " << expression << "."); } } - virtual void visit(storm::expressions::IfThenElseExpression const* expression) override { - expression->getCondition()->accept(this); - expression->getThenExpression()->accept(this); - expression->getElseExpression()->accept(this); - - msat_term conditionResult = stack.top(); - stack.pop(); - msat_term thenResult = stack.top(); - stack.pop(); - msat_term elseResult = stack.top(); - stack.pop(); - - stack.push(msat_make_term_ite(env, conditionResult, thenResult, elseResult)); + virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression) override { + msat_term conditionResult = boost::any_cast(expression.getCondition()->accept(*this)); + msat_term thenResult = boost::any_cast(expression.getThenExpression()->accept(*this)); + msat_term elseResult = boost::any_cast(expression.getElseExpression()->accept(*this)); + return msat_make_term_ite(env, conditionResult, thenResult, elseResult); } - virtual void visit(expressions::BooleanLiteralExpression const* expression) override { - stack.push(expression->evaluateAsBool(nullptr) ? msat_make_true(env) : msat_make_false(env)); + virtual boost::any visit(expressions::BooleanLiteralExpression const& expression) override { + return expression.getValue() ? msat_make_true(env) : msat_make_false(env); } - virtual void visit(expressions::DoubleLiteralExpression const* expression) override { - stack.push(msat_make_number(env, std::to_string(expression->evaluateAsDouble(nullptr)).c_str())); + virtual boost::any visit(expressions::DoubleLiteralExpression const& expression) override { + return msat_make_number(env, std::to_string(expression.getValue()).c_str()); } - virtual void visit(expressions::IntegerLiteralExpression const* expression) override { - stack.push(msat_make_number(env, std::to_string(static_cast(expression->evaluateAsInt(nullptr))).c_str())); + virtual boost::any visit(expressions::IntegerLiteralExpression const& expression) override { + return msat_make_number(env, std::to_string(static_cast(expression.getValue())).c_str()); } - virtual void visit(expressions::UnaryBooleanFunctionExpression const* expression) override { - expression->getOperand()->accept(this); - - msat_term childResult = stack.top(); - stack.pop(); + virtual boost::any visit(expressions::UnaryBooleanFunctionExpression const& expression) override { + msat_term childResult = boost::any_cast(expression.getOperand()->accept(*this)); - switch (expression->getOperatorType()) { + switch (expression.getOperatorType()) { case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: - stack.push(msat_make_not(env, childResult)); + return msat_make_not(env, childResult); break; default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean unary operator: '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } - virtual void visit(expressions::UnaryNumericalFunctionExpression const* expression) override { - expression->getOperand()->accept(this); + virtual boost::any visit(expressions::UnaryNumericalFunctionExpression const& expression) override { + msat_term childResult = boost::any_cast(expression.getOperand()->accept(*this)); - msat_term childResult = stack.top(); - stack.pop(); - switch (expression->getOperatorType()) { + switch (expression.getOperatorType()) { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: - stack.push(msat_make_times(env, msat_make_number(env, "-1"), childResult)); + return msat_make_times(env, msat_make_number(env, "-1"), childResult); break; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: - stack.push(msat_make_floor(env, childResult)); + return msat_make_floor(env, childResult); break; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil: - stack.push(msat_make_plus(env, msat_make_floor(env, childResult), msat_make_number(env, "1"))); + return msat_make_plus(env, msat_make_floor(env, childResult), msat_make_number(env, "1")); break; default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator: '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } - virtual void visit(expressions::VariableExpression const* expression) override { - std::map::iterator stringVariablePair = variableToDeclarationMap.find(expression->getVariableName()); + virtual boost::any visit(expressions::VariableExpression const& expression) override { + std::map::iterator stringVariablePair = variableToDeclarationMap.find(expression.getVariableName()); msat_decl result; if (stringVariablePair == variableToDeclarationMap.end() && createVariables) { std::pair::iterator, bool> iteratorAndFlag; - switch (expression->getReturnType()) { + switch (expression.getReturnType()) { case storm::expressions::ExpressionReturnType::Bool: - iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_bool_type(env)))); + iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_bool_type(env)))); result = iteratorAndFlag.first->second; break; case storm::expressions::ExpressionReturnType::Int: - iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_integer_type(env)))); + iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_integer_type(env)))); result = iteratorAndFlag.first->second; break; case storm::expressions::ExpressionReturnType::Double: - iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression->getVariableName(), msat_declare_function(env, expression->getVariableName().c_str(), msat_get_rational_type(env)))); + iteratorAndFlag = this->variableToDeclarationMap.insert(std::make_pair(expression.getVariableName(), msat_declare_function(env, expression.getVariableName().c_str(), msat_get_rational_type(env)))); result = iteratorAndFlag.first->second; break; default: - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression->getVariableName() << "' with unknown type while trying to create solver variables."); + STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression.getVariableName() << "' with unknown type while trying to create solver variables."); } } else { - STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression->getVariableName() << "'."); + STORM_LOG_THROW(stringVariablePair != variableToDeclarationMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression.getVariableName() << "'."); result = stringVariablePair->second; } STORM_LOG_THROW(!MSAT_ERROR_DECL(result), storm::exceptions::ExpressionEvaluationException, "Unable to translate expression to MathSAT format, because a variable could not be translated."); - stack.push(msat_make_constant(env, result)); + return msat_make_constant(env, result); } storm::expressions::Expression translateExpression(msat_term const& term) { @@ -309,9 +261,6 @@ namespace storm { // The MathSAT environment used. msat_env& env; - // A stack used for communicating results between different functions. - std::stack stack; - // A mapping of variable names to their declaration in the MathSAT environment. std::map variableToDeclarationMap; diff --git a/src/adapters/Z3ExpressionAdapter.h b/src/adapters/Z3ExpressionAdapter.h index 09c7d56f3..b28744ae7 100644 --- a/src/adapters/Z3ExpressionAdapter.h +++ b/src/adapters/Z3ExpressionAdapter.h @@ -35,7 +35,7 @@ namespace storm { * expressions and are not yet known to the adapter. * @param variableToExpressionMap A mapping from variable names to their corresponding Z3 expressions (if already existing). */ - Z3ExpressionAdapter(z3::context& context, bool createVariables = true, std::map const& variableToExpressionMap = std::map()) : context(context) , stack() , additionalAssertions() , additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) { + Z3ExpressionAdapter(z3::context& context, bool createVariables = true, std::map const& variableToExpressionMap = std::map()) : context(context), additionalAssertions(), additionalVariableCounter(0), variableToExpressionMap(variableToExpressionMap), createVariables(createVariables) { // Intentionally left empty. } @@ -44,20 +44,18 @@ namespace storm { * * @warning The adapter internally creates helper variables prefixed with `__z3adapter_`. As a consequence, * having variables with this prefix in the variableToExpressionMap might lead to unexpected results and is - * strictly to be avoided. + * strictly to be aboost::anyed. * * @param expression The expression to translate. * @return An equivalent expression for Z3. */ z3::expr translateExpression(storm::expressions::Expression const& expression) { - expression.getBaseExpression().accept(this); - z3::expr result = stack.top(); - stack.pop(); + z3::expr result = boost::any_cast(expression.getBaseExpression().accept(*this)); - while (!additionalAssertions.empty()) { - result = result && additionalAssertions.top(); - additionalAssertions.pop(); - } + for (z3::expr const& assertion : additionalAssertions) { + result = result && assertion; + } + additionalAssertions.clear(); return result; } @@ -167,211 +165,159 @@ namespace storm { } } - virtual void visit(storm::expressions::BinaryBooleanFunctionExpression const* expression) override { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - - const z3::expr rightResult = stack.top(); - stack.pop(); - const z3::expr leftResult = stack.top(); - stack.pop(); + virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression) override { + z3::expr leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + z3::expr rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); - switch(expression->getOperatorType()) { + switch(expression.getOperatorType()) { case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::And: - stack.push(leftResult && rightResult); - break; + return leftResult && rightResult; case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Or: - stack.push(leftResult || rightResult); - break; + return leftResult || rightResult; case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Xor: - stack.push(z3::expr(context, Z3_mk_xor(context, leftResult, rightResult))); - break; + return z3::expr(context, Z3_mk_xor(context, leftResult, rightResult)); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Implies: - stack.push(z3::expr(context, Z3_mk_implies(context, leftResult, rightResult))); - break; + return z3::expr(context, Z3_mk_implies(context, leftResult, rightResult)); case storm::expressions::BinaryBooleanFunctionExpression::OperatorType::Iff: - stack.push(z3::expr(context, Z3_mk_iff(context, leftResult, rightResult))); - break; + return z3::expr(context, Z3_mk_iff(context, leftResult, rightResult)); default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } - virtual void visit(storm::expressions::BinaryNumericalFunctionExpression const* expression) override { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - - z3::expr rightResult = stack.top(); - stack.pop(); - z3::expr leftResult = stack.top(); - stack.pop(); - - switch(expression->getOperatorType()) { + virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression) override { + z3::expr leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + z3::expr rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); + + switch(expression.getOperatorType()) { case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Plus: - stack.push(leftResult + rightResult); - break; + return leftResult + rightResult; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Minus: - stack.push(leftResult - rightResult); - break; + return leftResult - rightResult; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Times: - stack.push(leftResult * rightResult); - break; + return leftResult * rightResult; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Divide: - stack.push(leftResult / rightResult); - break; + return leftResult / rightResult; case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Min: - stack.push(ite(leftResult <= rightResult, leftResult, rightResult)); - break; + return ite(leftResult <= rightResult, leftResult, rightResult); case storm::expressions::BinaryNumericalFunctionExpression::OperatorType::Max: - stack.push(ite(leftResult >= rightResult, leftResult, rightResult)); - break; - default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + return ite(leftResult >= rightResult, leftResult, rightResult); + default: + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } - virtual void visit(storm::expressions::BinaryRelationExpression const* expression) override { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - - z3::expr rightResult = stack.top(); - stack.pop(); - z3::expr leftResult = stack.top(); - stack.pop(); - - switch(expression->getRelationType()) { + virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression) override { + z3::expr leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + z3::expr rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); + + switch(expression.getRelationType()) { case storm::expressions::BinaryRelationExpression::RelationType::Equal: - stack.push(leftResult == rightResult); - break; + return leftResult == rightResult; case storm::expressions::BinaryRelationExpression::RelationType::NotEqual: - stack.push(leftResult != rightResult); - break; + return leftResult != rightResult; case storm::expressions::BinaryRelationExpression::RelationType::Less: - stack.push(leftResult < rightResult); - break; + return leftResult < rightResult; case storm::expressions::BinaryRelationExpression::RelationType::LessOrEqual: - stack.push(leftResult <= rightResult); - break; + return leftResult <= rightResult; case storm::expressions::BinaryRelationExpression::RelationType::Greater: - stack.push(leftResult > rightResult); - break; + return leftResult > rightResult; case storm::expressions::BinaryRelationExpression::RelationType::GreaterOrEqual: - stack.push(leftResult >= rightResult); - break; + return leftResult >= rightResult; default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getRelationType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getRelationType()) << "' in expression " << expression << "."); } } - virtual void visit(storm::expressions::BooleanLiteralExpression const* expression) override { - stack.push(context.bool_val(expression->evaluateAsBool())); + virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression) override { + return context.bool_val(expression.getValue()); } - virtual void visit(storm::expressions::DoubleLiteralExpression const* expression) override { + virtual boost::any visit(storm::expressions::DoubleLiteralExpression const& expression) override { std::stringstream fractionStream; - fractionStream << expression->evaluateAsDouble(); - stack.push(context.real_val(fractionStream.str().c_str())); + fractionStream << expression.getValue(); + return context.real_val(fractionStream.str().c_str()); } - virtual void visit(storm::expressions::IntegerLiteralExpression const* expression) override { - stack.push(context.int_val(static_cast(expression->evaluateAsInt()))); + virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression) override { + return context.int_val(static_cast(expression.getValue())); } - virtual void visit(storm::expressions::UnaryBooleanFunctionExpression const* expression) override { - expression->getOperand()->accept(this); - - z3::expr childResult = stack.top(); - stack.pop(); + virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression) override { + z3::expr childResult = boost::any_cast(expression.getOperand()->accept(*this)); - switch (expression->getOperatorType()) { + switch (expression.getOperatorType()) { case storm::expressions::UnaryBooleanFunctionExpression::OperatorType::Not: - stack.push(!childResult); - break; + return !childResult; default: - STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression->getOperatorType()) << "' in expression " << expression << "."); + STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown boolean binary operator '" << static_cast(expression.getOperatorType()) << "' in expression " << expression << "."); } } - virtual void visit(storm::expressions::UnaryNumericalFunctionExpression const* expression) override { - expression->getOperand()->accept(this); + virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression) override { + z3::expr childResult = boost::any_cast(expression.getOperand()->accept(*this)); - z3::expr childResult = stack.top(); - stack.pop(); - - switch(expression->getOperatorType()) { + switch(expression.getOperatorType()) { case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Minus: - stack.push(0 - childResult); - break; + return 0 - childResult; case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Floor: { z3::expr floorVariable = context.int_const(("__z3adapter_floor_" + std::to_string(additionalVariableCounter++)).c_str()); - additionalAssertions.push(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); - stack.push(floorVariable); - break; + additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, floorVariable)) <= childResult && childResult < (z3::expr(context, Z3_mk_int2real(context, floorVariable)) + 1)); + return floorVariable; } case storm::expressions::UnaryNumericalFunctionExpression::OperatorType::Ceil:{ z3::expr ceilVariable = context.int_const(("__z3adapter_ceil_" + std::to_string(additionalVariableCounter++)).c_str()); - additionalAssertions.push(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); - stack.push(ceilVariable); - break; + additionalAssertions.push_back(z3::expr(context, Z3_mk_int2real(context, ceilVariable)) - 1 <= childResult && childResult < z3::expr(context, Z3_mk_int2real(context, ceilVariable))); + return ceilVariable; } - default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast(expression->getOperatorType()) << "'."); + default: STORM_LOG_THROW(false, storm::exceptions::ExpressionEvaluationException, "Cannot evaluate expression: unknown numerical unary operator '" << static_cast(expression.getOperatorType()) << "'."); } } - virtual void visit(storm::expressions::IfThenElseExpression const* expression) override { - expression->getCondition()->accept(this); - expression->getThenExpression()->accept(this); - expression->getElseExpression()->accept(this); - - z3::expr conditionResult = stack.top(); - stack.pop(); - z3::expr thenResult = stack.top(); - stack.pop(); - z3::expr elseResult = stack.top(); - stack.pop(); - - stack.push(z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult))); + virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression) override { + z3::expr conditionResult = boost::any_cast(expression.getCondition()->accept(*this)); + z3::expr thenResult = boost::any_cast(expression.getThenExpression()->accept(*this)); + z3::expr elseResult = boost::any_cast(expression.getElseExpression()->accept(*this)); + return z3::expr(context, Z3_mk_ite(context, conditionResult, thenResult, elseResult)); } - virtual void visit(storm::expressions::VariableExpression const* expression) override { - std::map::iterator stringVariablePair = variableToExpressionMap.find(expression->getVariableName()); + virtual boost::any visit(storm::expressions::VariableExpression const& expression) override { + std::map::iterator stringVariablePair = variableToExpressionMap.find(expression.getVariableName()); z3::expr result(context); if (stringVariablePair == variableToExpressionMap.end() && createVariables) { std::pair::iterator, bool> iteratorAndFlag; - switch (expression->getReturnType()) { + switch (expression.getReturnType()) { case storm::expressions::ExpressionReturnType::Bool: - iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression->getVariableName(), context.bool_const(expression->getVariableName().c_str()))); + iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.bool_const(expression.getVariableName().c_str()))); result = iteratorAndFlag.first->second; break; case storm::expressions::ExpressionReturnType::Int: - iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression->getVariableName(), context.int_const(expression->getVariableName().c_str()))); + iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.int_const(expression.getVariableName().c_str()))); result = iteratorAndFlag.first->second; break; case storm::expressions::ExpressionReturnType::Double: - iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression->getVariableName(), context.real_const(expression->getVariableName().c_str()))); + iteratorAndFlag = this->variableToExpressionMap.insert(std::make_pair(expression.getVariableName(), context.real_const(expression.getVariableName().c_str()))); result = iteratorAndFlag.first->second; break; default: - STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression->getVariableName() << "' with unknown type while trying to create solver variables."); + STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Encountered variable '" << expression.getVariableName() << "' with unknown type while trying to create solver variables."); } } else { - STORM_LOG_THROW(stringVariablePair != variableToExpressionMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression->getVariableName() << "'."); + STORM_LOG_THROW(stringVariablePair != variableToExpressionMap.end(), storm::exceptions::InvalidArgumentException, "Expression refers to unknown variable '" << expression.getVariableName() << "'."); result = stringVariablePair->second; } - stack.push(result); + return result; } private: // The context that is used to translate the expressions. z3::context& context; - // A stack that is used to communicate the translation results between method calls. - std::stack stack; - // A stack of assertions that need to be kept separate, because they were only impliclty part of an assertion that was added. - std::stack additionalAssertions; + std::vector additionalAssertions; // A counter for the variables that were created to identify the additional assertions. uint_fast64_t additionalVariableCounter; diff --git a/src/storage/expressions/BaseExpression.h b/src/storage/expressions/BaseExpression.h index 8afe11b91..86f6f28a4 100644 --- a/src/storage/expressions/BaseExpression.h +++ b/src/storage/expressions/BaseExpression.h @@ -168,7 +168,7 @@ namespace storm { * * @param visitor The visitor that is to be accepted. */ - virtual void accept(ExpressionVisitor* visitor) const = 0; + virtual boost::any accept(ExpressionVisitor& visitor) const = 0; /*! * Retrieves whether the expression has a numerical return type, i.e., integer or double. diff --git a/src/storage/expressions/BinaryBooleanFunctionExpression.cpp b/src/storage/expressions/BinaryBooleanFunctionExpression.cpp index 9a9d81fe5..953ff9f92 100644 --- a/src/storage/expressions/BinaryBooleanFunctionExpression.cpp +++ b/src/storage/expressions/BinaryBooleanFunctionExpression.cpp @@ -91,8 +91,8 @@ namespace storm { } } - void BinaryBooleanFunctionExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any BinaryBooleanFunctionExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } void BinaryBooleanFunctionExpression::printToStream(std::ostream& stream) const { diff --git a/src/storage/expressions/BinaryBooleanFunctionExpression.h b/src/storage/expressions/BinaryBooleanFunctionExpression.h index 2a1d2417d..eb32b7914 100644 --- a/src/storage/expressions/BinaryBooleanFunctionExpression.h +++ b/src/storage/expressions/BinaryBooleanFunctionExpression.h @@ -36,7 +36,7 @@ namespace storm { virtual storm::expressions::OperatorType getOperator() const override; virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the operator associated with the expression. diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp index 18b1d3a71..b6c85ebcb 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.cpp @@ -70,8 +70,8 @@ namespace storm { } } - void BinaryNumericalFunctionExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any BinaryNumericalFunctionExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } void BinaryNumericalFunctionExpression::printToStream(std::ostream& stream) const { diff --git a/src/storage/expressions/BinaryNumericalFunctionExpression.h b/src/storage/expressions/BinaryNumericalFunctionExpression.h index 77b8021a4..8e129a21d 100644 --- a/src/storage/expressions/BinaryNumericalFunctionExpression.h +++ b/src/storage/expressions/BinaryNumericalFunctionExpression.h @@ -37,7 +37,7 @@ namespace storm { virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override; virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the operator associated with the expression. diff --git a/src/storage/expressions/BinaryRelationExpression.cpp b/src/storage/expressions/BinaryRelationExpression.cpp index 4c0a2e1ae..bb0545e8e 100644 --- a/src/storage/expressions/BinaryRelationExpression.cpp +++ b/src/storage/expressions/BinaryRelationExpression.cpp @@ -46,8 +46,8 @@ namespace storm { } } - void BinaryRelationExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any BinaryRelationExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } BinaryRelationExpression::RelationType BinaryRelationExpression::getRelationType() const { diff --git a/src/storage/expressions/BinaryRelationExpression.h b/src/storage/expressions/BinaryRelationExpression.h index 4e03a598e..898cd650f 100644 --- a/src/storage/expressions/BinaryRelationExpression.h +++ b/src/storage/expressions/BinaryRelationExpression.h @@ -36,7 +36,7 @@ namespace storm { virtual storm::expressions::OperatorType getOperator() const override; virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the relation associated with the expression. diff --git a/src/storage/expressions/BooleanLiteralExpression.cpp b/src/storage/expressions/BooleanLiteralExpression.cpp index d510c6f45..b1970b87c 100644 --- a/src/storage/expressions/BooleanLiteralExpression.cpp +++ b/src/storage/expressions/BooleanLiteralExpression.cpp @@ -34,8 +34,8 @@ namespace storm { return this->shared_from_this(); } - void BooleanLiteralExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any BooleanLiteralExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } bool BooleanLiteralExpression::getValue() const { diff --git a/src/storage/expressions/BooleanLiteralExpression.h b/src/storage/expressions/BooleanLiteralExpression.h index 57c19677c..61a5056c7 100644 --- a/src/storage/expressions/BooleanLiteralExpression.h +++ b/src/storage/expressions/BooleanLiteralExpression.h @@ -32,7 +32,7 @@ namespace storm { virtual std::set getVariables() const override; virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the value of the boolean literal. diff --git a/src/storage/expressions/DoubleLiteralExpression.cpp b/src/storage/expressions/DoubleLiteralExpression.cpp index 00471b533..772053b67 100644 --- a/src/storage/expressions/DoubleLiteralExpression.cpp +++ b/src/storage/expressions/DoubleLiteralExpression.cpp @@ -26,8 +26,8 @@ namespace storm { return this->shared_from_this(); } - void DoubleLiteralExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any DoubleLiteralExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } double DoubleLiteralExpression::getValue() const { diff --git a/src/storage/expressions/DoubleLiteralExpression.h b/src/storage/expressions/DoubleLiteralExpression.h index 515326b8b..531662b7a 100644 --- a/src/storage/expressions/DoubleLiteralExpression.h +++ b/src/storage/expressions/DoubleLiteralExpression.h @@ -30,7 +30,7 @@ namespace storm { virtual std::set getVariables() const override; virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the value of the double literal. diff --git a/src/storage/expressions/Expression.cpp b/src/storage/expressions/Expression.cpp index c0785bade..8e7d41902 100644 --- a/src/storage/expressions/Expression.cpp +++ b/src/storage/expressions/Expression.cpp @@ -4,7 +4,6 @@ #include "src/storage/expressions/Expression.h" #include "src/storage/expressions/SubstitutionVisitor.h" #include "src/storage/expressions/IdentifierSubstitutionVisitor.h" -#include "src/storage/expressions/TypeCheckVisitor.h" #include "src/storage/expressions/LinearityCheckVisitor.h" #include "src/storage/expressions/Expressions.h" #include "src/exceptions/InvalidTypeException.h" @@ -31,14 +30,6 @@ namespace storm { Expression Expression::substitute(std::unordered_map const& identifierToIdentifierMap) const { return IdentifierSubstitutionVisitor>(identifierToIdentifierMap).substitute(*this); } - - void Expression::check(std::map const& identifierToTypeMap) const { - return TypeCheckVisitor>(identifierToTypeMap).check(*this); - } - - void Expression::check(std::unordered_map const& identifierToTypeMap) const { - return TypeCheckVisitor>(identifierToTypeMap).check(*this); - } bool Expression::evaluateAsBool(Valuation const* valuation) const { return this->getBaseExpression().evaluateAsBool(valuation); @@ -99,17 +90,6 @@ namespace storm { std::set Expression::getVariables() const { return this->getBaseExpression().getVariables(); } - - std::map Expression::getVariablesAndTypes(bool validate) const { - if (validate) { - std::map result = this->getBaseExpression().getVariablesAndTypes(); - this->check(result); - return result; - } - else { - return this->getBaseExpression().getVariablesAndTypes(); - } - } bool Expression::isRelationalExpression() const { if (!this->isFunctionApplication()) { @@ -300,6 +280,10 @@ namespace storm { return Expression(std::shared_ptr(new UnaryNumericalFunctionExpression(ExpressionReturnType::Int, this->getBaseExpressionPointer(), UnaryNumericalFunctionExpression::OperatorType::Ceil))); } + boost::any Expression::accept(ExpressionVisitor& visitor) const { + return this->getBaseExpression().accept(visitor); + } + std::ostream& operator<<(std::ostream& stream, Expression const& expression) { stream << expression.getBaseExpression(); return stream; diff --git a/src/storage/expressions/Expression.h b/src/storage/expressions/Expression.h index 750641ddf..e456f295b 100644 --- a/src/storage/expressions/Expression.h +++ b/src/storage/expressions/Expression.h @@ -108,22 +108,6 @@ namespace storm { */ Expression substitute(std::unordered_map const& identifierToIdentifierMap) const; - /*! - * Checks that all identifiers appearing in the expression have the types given by the map. An exception - * is thrown in case a violation is found. - * - * @param identifierToTypeMap A mapping from identifiers to the types that are supposed to have. - */ - void check(std::map const& identifierToTypeMap) const; - - /*! - * Checks that all identifiers appearing in the expression have the types given by the map. An exception - * is thrown in case a violation is found. - * - * @param identifierToTypeMap A mapping from identifiers to the types that are supposed to have. - */ - void check(std::unordered_map const& identifierToTypeMap) const; - /*! * Evaluates the expression under the valuation of variables given by the valuation and returns the * resulting boolean value. If the return type of the expression is not a boolean an exception is thrown. @@ -314,7 +298,7 @@ namespace storm { * * @param visitor The visitor to accept. */ - void accept(ExpressionVisitor* visitor) const; + boost::any accept(ExpressionVisitor& visitor) const; friend std::ostream& operator<<(std::ostream& stream, Expression const& expression); diff --git a/src/storage/expressions/ExpressionVisitor.h b/src/storage/expressions/ExpressionVisitor.h index 1b417f6ed..5fdb486c2 100644 --- a/src/storage/expressions/ExpressionVisitor.h +++ b/src/storage/expressions/ExpressionVisitor.h @@ -1,6 +1,8 @@ #ifndef STORM_STORAGE_EXPRESSIONS_EXPRESSIONVISITOR_H_ #define STORM_STORAGE_EXPRESSIONS_EXPRESSIONVISITOR_H_ +#include + namespace storm { namespace expressions { // Forward-declare all expression classes. @@ -17,16 +19,16 @@ namespace storm { class ExpressionVisitor { public: - virtual void visit(IfThenElseExpression const* expression) = 0; - virtual void visit(BinaryBooleanFunctionExpression const* expression) = 0; - virtual void visit(BinaryNumericalFunctionExpression const* expression) = 0; - virtual void visit(BinaryRelationExpression const* expression) = 0; - virtual void visit(VariableExpression const* expression) = 0; - virtual void visit(UnaryBooleanFunctionExpression const* expression) = 0; - virtual void visit(UnaryNumericalFunctionExpression const* expression) = 0; - virtual void visit(BooleanLiteralExpression const* expression) = 0; - virtual void visit(IntegerLiteralExpression const* expression) = 0; - virtual void visit(DoubleLiteralExpression const* expression) = 0; + virtual boost::any visit(IfThenElseExpression const& expression) = 0; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) = 0; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) = 0; + virtual boost::any visit(BinaryRelationExpression const& expression) = 0; + virtual boost::any visit(VariableExpression const& expression) = 0; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) = 0; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) = 0; + virtual boost::any visit(BooleanLiteralExpression const& expression) = 0; + virtual boost::any visit(IntegerLiteralExpression const& expression) = 0; + virtual boost::any visit(DoubleLiteralExpression const& expression) = 0; }; } } diff --git a/src/storage/expressions/IdentifierSubstitutionVisitor.cpp b/src/storage/expressions/IdentifierSubstitutionVisitor.cpp index 19724dea1..30b83b9c7 100644 --- a/src/storage/expressions/IdentifierSubstitutionVisitor.cpp +++ b/src/storage/expressions/IdentifierSubstitutionVisitor.cpp @@ -14,138 +14,110 @@ namespace storm { template Expression IdentifierSubstitutionVisitor::substitute(Expression const& expression) { - expression.getBaseExpression().accept(this); - return Expression(this->expressionStack.top()); + return Expression(boost::any_cast>(expression.getBaseExpression().accept(*this))); } template - void IdentifierSubstitutionVisitor::visit(IfThenElseExpression const* expression) { - expression->getCondition()->accept(this); - std::shared_ptr conditionExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getThenExpression()->accept(this); - std::shared_ptr thenExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getElseExpression()->accept(this); - std::shared_ptr elseExpression = expressionStack.top(); - expressionStack.pop(); + boost::any IdentifierSubstitutionVisitor::visit(IfThenElseExpression const& expression) { + std::shared_ptr conditionExpression = boost::any_cast>(expression.getCondition()->accept(*this)); + std::shared_ptr thenExpression = boost::any_cast>(expression.getThenExpression()->accept(*this)); + std::shared_ptr elseExpression = boost::any_cast>(expression.getElseExpression()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (conditionExpression.get() == expression->getCondition().get() && thenExpression.get() == expression->getThenExpression().get() && elseExpression.get() == expression->getElseExpression().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new IfThenElseExpression(expression->getReturnType(), conditionExpression, thenExpression, elseExpression))); + return std::shared_ptr(new IfThenElseExpression(expression.getReturnType(), conditionExpression, thenExpression, elseExpression)); } } template - void IdentifierSubstitutionVisitor::visit(BinaryBooleanFunctionExpression const* expression) { - expression->getFirstOperand()->accept(this); - std::shared_ptr firstExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getSecondOperand()->accept(this); - std::shared_ptr secondExpression = expressionStack.top(); - expressionStack.pop(); + boost::any IdentifierSubstitutionVisitor::visit(BinaryBooleanFunctionExpression const& expression) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new BinaryBooleanFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType()))); + return std::shared_ptr(new BinaryBooleanFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType())); } } template - void IdentifierSubstitutionVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - expression->getFirstOperand()->accept(this); - std::shared_ptr firstExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getSecondOperand()->accept(this); - std::shared_ptr secondExpression = expressionStack.top(); - expressionStack.pop(); + boost::any IdentifierSubstitutionVisitor::visit(BinaryNumericalFunctionExpression const& expression) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new BinaryNumericalFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType()))); + return std::shared_ptr(new BinaryNumericalFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType())); } } template - void IdentifierSubstitutionVisitor::visit(BinaryRelationExpression const* expression) { - expression->getFirstOperand()->accept(this); - std::shared_ptr firstExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getSecondOperand()->accept(this); - std::shared_ptr secondExpression = expressionStack.top(); - expressionStack.pop(); + boost::any IdentifierSubstitutionVisitor::visit(BinaryRelationExpression const& expression) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new BinaryRelationExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getRelationType()))); + return std::shared_ptr(new BinaryRelationExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getRelationType())); } } template - void IdentifierSubstitutionVisitor::visit(VariableExpression const* expression) { + boost::any IdentifierSubstitutionVisitor::visit(VariableExpression const& expression) { // If the variable is in the key set of the substitution, we need to replace it. - auto const& namePair = this->identifierToIdentifierMap.find(expression->getVariableName()); + auto const& namePair = this->identifierToIdentifierMap.find(expression.getVariableName()); if (namePair != this->identifierToIdentifierMap.end()) { - this->expressionStack.push(std::shared_ptr(new VariableExpression(expression->getReturnType(), namePair->second))); + return std::shared_ptr(new VariableExpression(expression.getReturnType(), namePair->second)); } else { - this->expressionStack.push(expression->getSharedPointer()); + return expression.getSharedPointer(); } } template - void IdentifierSubstitutionVisitor::visit(UnaryBooleanFunctionExpression const* expression) { - expression->getOperand()->accept(this); - std::shared_ptr operandExpression = expressionStack.top(); - expressionStack.pop(); + boost::any IdentifierSubstitutionVisitor::visit(UnaryBooleanFunctionExpression const& expression) { + std::shared_ptr operandExpression = boost::any_cast>(expression.getOperand()->accept(*this)); // If the argument did not change, we simply push the expression itself. - if (operandExpression.get() == expression->getOperand().get()) { - expressionStack.push(expression->getSharedPointer()); + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); } else { - expressionStack.push(std::shared_ptr(new UnaryBooleanFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType()))); + return std::shared_ptr(new UnaryBooleanFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType())); } } template - void IdentifierSubstitutionVisitor::visit(UnaryNumericalFunctionExpression const* expression) { - expression->getOperand()->accept(this); - std::shared_ptr operandExpression = expressionStack.top(); - expressionStack.pop(); + boost::any IdentifierSubstitutionVisitor::visit(UnaryNumericalFunctionExpression const& expression) { + std::shared_ptr operandExpression = boost::any_cast>(expression.getOperand()->accept(*this)); // If the argument did not change, we simply push the expression itself. - if (operandExpression.get() == expression->getOperand().get()) { - expressionStack.push(expression->getSharedPointer()); + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); } else { - expressionStack.push(std::shared_ptr(new UnaryNumericalFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType()))); + return std::shared_ptr(new UnaryNumericalFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType())); } } template - void IdentifierSubstitutionVisitor::visit(BooleanLiteralExpression const* expression) { - this->expressionStack.push(expression->getSharedPointer()); + boost::any IdentifierSubstitutionVisitor::visit(BooleanLiteralExpression const& expression) { + return expression.getSharedPointer(); } template - void IdentifierSubstitutionVisitor::visit(IntegerLiteralExpression const* expression) { - this->expressionStack.push(expression->getSharedPointer()); + boost::any IdentifierSubstitutionVisitor::visit(IntegerLiteralExpression const& expression) { + return expression.getSharedPointer(); } template - void IdentifierSubstitutionVisitor::visit(DoubleLiteralExpression const* expression) { - this->expressionStack.push(expression->getSharedPointer()); + boost::any IdentifierSubstitutionVisitor::visit(DoubleLiteralExpression const& expression) { + return expression.getSharedPointer(); } // Explicitly instantiate the class with map and unordered_map. diff --git a/src/storage/expressions/IdentifierSubstitutionVisitor.h b/src/storage/expressions/IdentifierSubstitutionVisitor.h index 8e8723a70..a2044c9db 100644 --- a/src/storage/expressions/IdentifierSubstitutionVisitor.h +++ b/src/storage/expressions/IdentifierSubstitutionVisitor.h @@ -28,21 +28,18 @@ namespace storm { */ Expression substitute(Expression const& expression); - virtual void visit(IfThenElseExpression const* expression) override; - virtual void visit(BinaryBooleanFunctionExpression const* expression) override; - virtual void visit(BinaryNumericalFunctionExpression const* expression) override; - virtual void visit(BinaryRelationExpression const* expression) override; - virtual void visit(VariableExpression const* expression) override; - virtual void visit(UnaryBooleanFunctionExpression const* expression) override; - virtual void visit(UnaryNumericalFunctionExpression const* expression) override; - virtual void visit(BooleanLiteralExpression const* expression) override; - virtual void visit(IntegerLiteralExpression const* expression) override; - virtual void visit(DoubleLiteralExpression const* expression) override; + virtual boost::any visit(IfThenElseExpression const& expression) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BinaryRelationExpression const& expression) override; + virtual boost::any visit(VariableExpression const& expression) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BooleanLiteralExpression const& expression) override; + virtual boost::any visit(IntegerLiteralExpression const& expression) override; + virtual boost::any visit(DoubleLiteralExpression const& expression) override; private: - // A stack of expression used to pass the results to the higher levels. - std::stack> expressionStack; - // A mapping of identifier names to expressions with which they shall be replaced. MapType const& identifierToIdentifierMap; }; diff --git a/src/storage/expressions/IfThenElseExpression.cpp b/src/storage/expressions/IfThenElseExpression.cpp index fb17fe142..229b63b31 100644 --- a/src/storage/expressions/IfThenElseExpression.cpp +++ b/src/storage/expressions/IfThenElseExpression.cpp @@ -99,8 +99,8 @@ namespace storm { } } - void IfThenElseExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any IfThenElseExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } std::shared_ptr IfThenElseExpression::getCondition() const { diff --git a/src/storage/expressions/IfThenElseExpression.h b/src/storage/expressions/IfThenElseExpression.h index 63619d852..47d34e3f3 100644 --- a/src/storage/expressions/IfThenElseExpression.h +++ b/src/storage/expressions/IfThenElseExpression.h @@ -38,7 +38,7 @@ namespace storm { virtual std::set getVariables() const override; virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the condition expression of the if-then-else expression. diff --git a/src/storage/expressions/IntegerLiteralExpression.cpp b/src/storage/expressions/IntegerLiteralExpression.cpp index b05ab5a67..9f4a22d6c 100644 --- a/src/storage/expressions/IntegerLiteralExpression.cpp +++ b/src/storage/expressions/IntegerLiteralExpression.cpp @@ -30,8 +30,8 @@ namespace storm { return this->shared_from_this(); } - void IntegerLiteralExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any IntegerLiteralExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } int_fast64_t IntegerLiteralExpression::getValue() const { diff --git a/src/storage/expressions/IntegerLiteralExpression.h b/src/storage/expressions/IntegerLiteralExpression.h index 5d6c731a5..e764f5bbc 100644 --- a/src/storage/expressions/IntegerLiteralExpression.h +++ b/src/storage/expressions/IntegerLiteralExpression.h @@ -31,7 +31,7 @@ namespace storm { virtual std::set getVariables() const override; virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the value of the integer literal. diff --git a/src/storage/expressions/LinearCoefficientVisitor.cpp b/src/storage/expressions/LinearCoefficientVisitor.cpp index dce0b4ece..29adef811 100644 --- a/src/storage/expressions/LinearCoefficientVisitor.cpp +++ b/src/storage/expressions/LinearCoefficientVisitor.cpp @@ -7,26 +7,22 @@ namespace storm { namespace expressions { std::pair LinearCoefficientVisitor::getLinearCoefficients(Expression const& expression) { - expression.getBaseExpression().accept(this); - return resultStack.top(); + return boost::any_cast>(expression.getBaseExpression().accept(*this)); } - void LinearCoefficientVisitor::visit(IfThenElseExpression const* expression) { + boost::any LinearCoefficientVisitor::visit(IfThenElseExpression const& expression) { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } - void LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const* expression) { + boost::any LinearCoefficientVisitor::visit(BinaryBooleanFunctionExpression const& expression) { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } - void LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) { - expression->getFirstOperand()->accept(this); - std::pair leftResult = resultStack.top(); - resultStack.pop(); - expression->getSecondOperand()->accept(this); - std::pair& rightResult = resultStack.top(); - + boost::any LinearCoefficientVisitor::visit(BinaryNumericalFunctionExpression const& expression) { + std::pair leftResult = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::pair rightResult = boost::any_cast>(expression.getSecondOperand()->accept(*this)); + + if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus) { // Now add the left result to the right result. for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { if (rightResult.first.containsDoubleIdentifier(identifier)) { @@ -36,14 +32,7 @@ namespace storm { } } rightResult.second += leftResult.second; - return; - } else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) { - expression->getFirstOperand()->accept(this); - std::pair leftResult = resultStack.top(); - resultStack.pop(); - expression->getSecondOperand()->accept(this); - std::pair& rightResult = resultStack.top(); - + } else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Minus) { // Now subtract the right result from the left result. for (auto const& identifier : leftResult.first.getDoubleIdentifiers()) { if (rightResult.first.containsDoubleIdentifier(identifier)) { @@ -58,14 +47,7 @@ namespace storm { } } rightResult.second = leftResult.second - rightResult.second; - return; - } else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) { - expression->getFirstOperand()->accept(this); - std::pair leftResult = resultStack.top(); - resultStack.pop(); - expression->getSecondOperand()->accept(this); - std::pair& rightResult = resultStack.top(); - + } else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times) { // If the expression is linear, either the left or the right side must not contain variables. STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); if (leftResult.first.getNumberOfIdentifiers() == 0) { @@ -78,14 +60,7 @@ namespace storm { } } rightResult.second *= leftResult.second; - return; - } else if (expression->getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) { - expression->getFirstOperand()->accept(this); - std::pair leftResult = resultStack.top(); - resultStack.pop(); - expression->getSecondOperand()->accept(this); - std::pair& rightResult = resultStack.top(); - + } else if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Divide) { // If the expression is linear, either the left or the right side must not contain variables. STORM_LOG_THROW(leftResult.first.getNumberOfIdentifiers() == 0 || rightResult.first.getNumberOfIdentifiers() == 0, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); if (leftResult.first.getNumberOfIdentifiers() == 0) { @@ -98,54 +73,56 @@ namespace storm { } } rightResult.second = leftResult.second / leftResult.second; - return; } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } + return rightResult; } - void LinearCoefficientVisitor::visit(BinaryRelationExpression const* expression) { + boost::any LinearCoefficientVisitor::visit(BinaryRelationExpression const& expression) { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } - void LinearCoefficientVisitor::visit(VariableExpression const* expression) { + boost::any LinearCoefficientVisitor::visit(VariableExpression const& expression) { SimpleValuation valuation; - switch (expression->getReturnType()) { + switch (expression.getReturnType()) { case ExpressionReturnType::Bool: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); break; case ExpressionReturnType::Int: - case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression->getVariableName(), 1); break; + case ExpressionReturnType::Double: valuation.addDoubleIdentifier(expression.getVariableName(), 1); break; case ExpressionReturnType::Undefined: STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Illegal expression return type."); break; } - resultStack.push(std::make_pair(valuation, 0)); + return std::make_pair(valuation, static_cast(0)); } - void LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const* expression) { + boost::any LinearCoefficientVisitor::visit(UnaryBooleanFunctionExpression const& expression) { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } - void LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const* expression) { - if (expression->getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) { + boost::any LinearCoefficientVisitor::visit(UnaryNumericalFunctionExpression const& expression) { + std::pair childResult = boost::any_cast>(expression.getOperand()->accept(*this)); + + if (expression.getOperatorType() == UnaryNumericalFunctionExpression::OperatorType::Minus) { // Here, we need to negate all double identifiers. - std::pair& valuationConstantPair = resultStack.top(); - for (auto const& identifier : valuationConstantPair.first.getDoubleIdentifiers()) { - valuationConstantPair.first.setDoubleValue(identifier, -valuationConstantPair.first.getDoubleValue(identifier)); + for (auto const& identifier : childResult.first.getDoubleIdentifiers()) { + childResult.first.setDoubleValue(identifier, -childResult.first.getDoubleValue(identifier)); } + return childResult; } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } } - void LinearCoefficientVisitor::visit(BooleanLiteralExpression const* expression) { + boost::any LinearCoefficientVisitor::visit(BooleanLiteralExpression const& expression) { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Expression is non-linear."); } - void LinearCoefficientVisitor::visit(IntegerLiteralExpression const* expression) { - resultStack.push(std::make_pair(SimpleValuation(), static_cast(expression->getValue()))); + boost::any LinearCoefficientVisitor::visit(IntegerLiteralExpression const& expression) { + return std::make_pair(SimpleValuation(), static_cast(expression.getValue())); } - void LinearCoefficientVisitor::visit(DoubleLiteralExpression const* expression) { - resultStack.push(std::make_pair(SimpleValuation(), expression->getValue())); + boost::any LinearCoefficientVisitor::visit(DoubleLiteralExpression const& expression) { + return std::make_pair(SimpleValuation(), expression.getValue()); } } } \ No newline at end of file diff --git a/src/storage/expressions/LinearCoefficientVisitor.h b/src/storage/expressions/LinearCoefficientVisitor.h index 263e752c8..faf02414b 100644 --- a/src/storage/expressions/LinearCoefficientVisitor.h +++ b/src/storage/expressions/LinearCoefficientVisitor.h @@ -26,19 +26,16 @@ namespace storm { */ std::pair getLinearCoefficients(Expression const& expression); - virtual void visit(IfThenElseExpression const* expression) override; - virtual void visit(BinaryBooleanFunctionExpression const* expression) override; - virtual void visit(BinaryNumericalFunctionExpression const* expression) override; - virtual void visit(BinaryRelationExpression const* expression) override; - virtual void visit(VariableExpression const* expression) override; - virtual void visit(UnaryBooleanFunctionExpression const* expression) override; - virtual void visit(UnaryNumericalFunctionExpression const* expression) override; - virtual void visit(BooleanLiteralExpression const* expression) override; - virtual void visit(IntegerLiteralExpression const* expression) override; - virtual void visit(DoubleLiteralExpression const* expression) override; - - private: - std::stack> resultStack; + virtual boost::any visit(IfThenElseExpression const& expression) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BinaryRelationExpression const& expression) override; + virtual boost::any visit(VariableExpression const& expression) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BooleanLiteralExpression const& expression) override; + virtual boost::any visit(IntegerLiteralExpression const& expression) override; + virtual boost::any visit(DoubleLiteralExpression const& expression) override; }; } } diff --git a/src/storage/expressions/LinearityCheckVisitor.cpp b/src/storage/expressions/LinearityCheckVisitor.cpp index 78b6cf1f4..5749d3a35 100644 --- a/src/storage/expressions/LinearityCheckVisitor.cpp +++ b/src/storage/expressions/LinearityCheckVisitor.cpp @@ -6,108 +6,84 @@ namespace storm { namespace expressions { - LinearityCheckVisitor::LinearityCheckVisitor() : resultStack() { + LinearityCheckVisitor::LinearityCheckVisitor() { // Intentionally left empty. } bool LinearityCheckVisitor::check(Expression const& expression) { - expression.getBaseExpression().accept(this); - return resultStack.top() == LinearityStatus::LinearWithoutVariables || resultStack.top() == LinearityStatus::LinearContainsVariables; + LinearityStatus result = boost::any_cast(expression.getBaseExpression().accept(*this)); + return result == LinearityStatus::LinearWithoutVariables || result == LinearityStatus::LinearContainsVariables; } - void LinearityCheckVisitor::visit(IfThenElseExpression const* expression) { + boost::any LinearityCheckVisitor::visit(IfThenElseExpression const& expression) { // An if-then-else expression is never linear. - resultStack.push(LinearityStatus::NonLinear); + return LinearityStatus::NonLinear; } - void LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const* expression) { + boost::any LinearityCheckVisitor::visit(BinaryBooleanFunctionExpression const& expression) { // Boolean function applications are not allowed in linear expressions. - resultStack.push(LinearityStatus::NonLinear); + return LinearityStatus::NonLinear; } - void LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - LinearityStatus leftResult; - LinearityStatus rightResult; - switch (expression->getOperatorType()) { + boost::any LinearityCheckVisitor::visit(BinaryNumericalFunctionExpression const& expression) { + LinearityStatus leftResult = boost::any_cast(expression.getFirstOperand()->accept(*this)); + if (leftResult == LinearityStatus::NonLinear) { + return LinearityStatus::NonLinear; + } + + LinearityStatus rightResult = boost::any_cast(expression.getSecondOperand()->accept(*this)); + if (rightResult == LinearityStatus::NonLinear) { + return LinearityStatus::NonLinear; + } + + switch (expression.getOperatorType()) { case BinaryNumericalFunctionExpression::OperatorType::Plus: case BinaryNumericalFunctionExpression::OperatorType::Minus: - expression->getFirstOperand()->accept(this); - leftResult = resultStack.top(); - - if (leftResult == LinearityStatus::NonLinear) { - return; - } else { - resultStack.pop(); - expression->getSecondOperand()->accept(this); - rightResult = resultStack.top(); - if (rightResult == LinearityStatus::NonLinear) { - return; - } - resultStack.pop(); - } - - resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); - break; + return (leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); case BinaryNumericalFunctionExpression::OperatorType::Times: case BinaryNumericalFunctionExpression::OperatorType::Divide: - expression->getFirstOperand()->accept(this); - leftResult = resultStack.top(); - - if (leftResult == LinearityStatus::NonLinear) { - return; - } else { - resultStack.pop(); - expression->getSecondOperand()->accept(this); - rightResult = resultStack.top(); - if (rightResult == LinearityStatus::NonLinear) { - return; - } - resultStack.pop(); - } - if (leftResult == LinearityStatus::LinearContainsVariables && rightResult == LinearityStatus::LinearContainsVariables) { - resultStack.push(LinearityStatus::NonLinear); + return LinearityStatus::NonLinear; } - resultStack.push(leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); - break; - case BinaryNumericalFunctionExpression::OperatorType::Min: resultStack.push(LinearityStatus::NonLinear); break; - case BinaryNumericalFunctionExpression::OperatorType::Max: resultStack.push(LinearityStatus::NonLinear); break; - case BinaryNumericalFunctionExpression::OperatorType::Power: resultStack.push(LinearityStatus::NonLinear); break; + return (leftResult == LinearityStatus::LinearContainsVariables || rightResult == LinearityStatus::LinearContainsVariables ? LinearityStatus::LinearContainsVariables : LinearityStatus::LinearWithoutVariables); + case BinaryNumericalFunctionExpression::OperatorType::Min: return LinearityStatus::NonLinear; break; + case BinaryNumericalFunctionExpression::OperatorType::Max: return LinearityStatus::NonLinear; break; + case BinaryNumericalFunctionExpression::OperatorType::Power: return LinearityStatus::NonLinear; break; } } - void LinearityCheckVisitor::visit(BinaryRelationExpression const* expression) { - resultStack.push(LinearityStatus::NonLinear); + boost::any LinearityCheckVisitor::visit(BinaryRelationExpression const& expression) { + return LinearityStatus::NonLinear; } - void LinearityCheckVisitor::visit(VariableExpression const* expression) { - resultStack.push(LinearityStatus::LinearContainsVariables); + boost::any LinearityCheckVisitor::visit(VariableExpression const& expression) { + return LinearityStatus::LinearContainsVariables; } - void LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const* expression) { + boost::any LinearityCheckVisitor::visit(UnaryBooleanFunctionExpression const& expression) { // Boolean function applications are not allowed in linear expressions. - resultStack.push(LinearityStatus::NonLinear); + return LinearityStatus::NonLinear; } - void LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const* expression) { - switch (expression->getOperatorType()) { - case UnaryNumericalFunctionExpression::OperatorType::Minus: break; + boost::any LinearityCheckVisitor::visit(UnaryNumericalFunctionExpression const& expression) { + switch (expression.getOperatorType()) { + case UnaryNumericalFunctionExpression::OperatorType::Minus: return expression.getOperand()->accept(*this); case UnaryNumericalFunctionExpression::OperatorType::Floor: - case UnaryNumericalFunctionExpression::OperatorType::Ceil: resultStack.pop(); resultStack.push(LinearityStatus::NonLinear); break; + case UnaryNumericalFunctionExpression::OperatorType::Ceil: return LinearityStatus::NonLinear; } } - void LinearityCheckVisitor::visit(BooleanLiteralExpression const* expression) { - resultStack.push(LinearityStatus::NonLinear); + boost::any LinearityCheckVisitor::visit(BooleanLiteralExpression const& expression) { + return LinearityStatus::NonLinear; } - void LinearityCheckVisitor::visit(IntegerLiteralExpression const* expression) { - resultStack.push(LinearityStatus::LinearWithoutVariables); + boost::any LinearityCheckVisitor::visit(IntegerLiteralExpression const& expression) { + return LinearityStatus::LinearWithoutVariables; } - void LinearityCheckVisitor::visit(DoubleLiteralExpression const* expression) { - resultStack.push(LinearityStatus::LinearWithoutVariables); + boost::any LinearityCheckVisitor::visit(DoubleLiteralExpression const& expression) { + return LinearityStatus::LinearWithoutVariables; } } } \ No newline at end of file diff --git a/src/storage/expressions/LinearityCheckVisitor.h b/src/storage/expressions/LinearityCheckVisitor.h index d76b658c8..b9e1bf15c 100644 --- a/src/storage/expressions/LinearityCheckVisitor.h +++ b/src/storage/expressions/LinearityCheckVisitor.h @@ -22,22 +22,19 @@ namespace storm { */ bool check(Expression const& expression); - virtual void visit(IfThenElseExpression const* expression) override; - virtual void visit(BinaryBooleanFunctionExpression const* expression) override; - virtual void visit(BinaryNumericalFunctionExpression const* expression) override; - virtual void visit(BinaryRelationExpression const* expression) override; - virtual void visit(VariableExpression const* expression) override; - virtual void visit(UnaryBooleanFunctionExpression const* expression) override; - virtual void visit(UnaryNumericalFunctionExpression const* expression) override; - virtual void visit(BooleanLiteralExpression const* expression) override; - virtual void visit(IntegerLiteralExpression const* expression) override; - virtual void visit(DoubleLiteralExpression const* expression) override; + virtual boost::any visit(IfThenElseExpression const& expression) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BinaryRelationExpression const& expression) override; + virtual boost::any visit(VariableExpression const& expression) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BooleanLiteralExpression const& expression) override; + virtual boost::any visit(IntegerLiteralExpression const& expression) override; + virtual boost::any visit(DoubleLiteralExpression const& expression) override; private: enum class LinearityStatus { NonLinear, LinearContainsVariables, LinearWithoutVariables }; - - // A stack for communicating the results of the subexpressions. - std::stack resultStack; }; } } diff --git a/src/storage/expressions/SubstitutionVisitor.cpp b/src/storage/expressions/SubstitutionVisitor.cpp index 43aa01ab3..0924d9eb8 100644 --- a/src/storage/expressions/SubstitutionVisitor.cpp +++ b/src/storage/expressions/SubstitutionVisitor.cpp @@ -14,138 +14,110 @@ namespace storm { template Expression SubstitutionVisitor::substitute(Expression const& expression) { - expression.getBaseExpression().accept(this); - return Expression(this->expressionStack.top()); + return Expression(boost::any_cast>(expression.getBaseExpression().accept(*this))); } template - void SubstitutionVisitor::visit(IfThenElseExpression const* expression) { - expression->getCondition()->accept(this); - std::shared_ptr conditionExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getThenExpression()->accept(this); - std::shared_ptr thenExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getElseExpression()->accept(this); - std::shared_ptr elseExpression = expressionStack.top(); - expressionStack.pop(); + boost::any SubstitutionVisitor::visit(IfThenElseExpression const& expression) { + std::shared_ptr conditionExpression = boost::any_cast>(expression.getCondition()->accept(*this)); + std::shared_ptr thenExpression = boost::any_cast>(expression.getThenExpression()->accept(*this)); + std::shared_ptr elseExpression = boost::any_cast>(expression.getElseExpression()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (conditionExpression.get() == expression->getCondition().get() && thenExpression.get() == expression->getThenExpression().get() && elseExpression.get() == expression->getElseExpression().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new IfThenElseExpression(expression->getReturnType(), conditionExpression, thenExpression, elseExpression))); + return static_cast>(std::shared_ptr(new IfThenElseExpression(expression.getReturnType(), conditionExpression, thenExpression, elseExpression))); } } template - void SubstitutionVisitor::visit(BinaryBooleanFunctionExpression const* expression) { - expression->getFirstOperand()->accept(this); - std::shared_ptr firstExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getSecondOperand()->accept(this); - std::shared_ptr secondExpression = expressionStack.top(); - expressionStack.pop(); + boost::any SubstitutionVisitor::visit(BinaryBooleanFunctionExpression const& expression) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new BinaryBooleanFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType()))); + return static_cast>(std::shared_ptr(new BinaryBooleanFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType()))); } } template - void SubstitutionVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - expression->getFirstOperand()->accept(this); - std::shared_ptr firstExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getSecondOperand()->accept(this); - std::shared_ptr secondExpression = expressionStack.top(); - expressionStack.pop(); + boost::any SubstitutionVisitor::visit(BinaryNumericalFunctionExpression const& expression) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new BinaryNumericalFunctionExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getOperatorType()))); + return static_cast>(std::shared_ptr(new BinaryNumericalFunctionExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getOperatorType()))); } } template - void SubstitutionVisitor::visit(BinaryRelationExpression const* expression) { - expression->getFirstOperand()->accept(this); - std::shared_ptr firstExpression = expressionStack.top(); - expressionStack.pop(); - - expression->getSecondOperand()->accept(this); - std::shared_ptr secondExpression = expressionStack.top(); - expressionStack.pop(); + boost::any SubstitutionVisitor::visit(BinaryRelationExpression const& expression) { + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this)); // If the arguments did not change, we simply push the expression itself. - if (firstExpression.get() == expression->getFirstOperand().get() && secondExpression.get() == expression->getSecondOperand().get()) { - this->expressionStack.push(expression->getSharedPointer()); + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); } else { - this->expressionStack.push(std::shared_ptr(new BinaryRelationExpression(expression->getReturnType(), firstExpression, secondExpression, expression->getRelationType()))); + return static_cast>(std::shared_ptr(new BinaryRelationExpression(expression.getReturnType(), firstExpression, secondExpression, expression.getRelationType()))); } } template - void SubstitutionVisitor::visit(VariableExpression const* expression) { + boost::any SubstitutionVisitor::visit(VariableExpression const& expression) { // If the variable is in the key set of the substitution, we need to replace it. - auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression->getVariableName()); + auto const& nameExpressionPair = this->identifierToExpressionMap.find(expression.getVariableName()); if (nameExpressionPair != this->identifierToExpressionMap.end()) { - this->expressionStack.push(nameExpressionPair->second.getBaseExpressionPointer()); + return nameExpressionPair->second.getBaseExpressionPointer(); } else { - this->expressionStack.push(expression->getSharedPointer()); + return expression.getSharedPointer(); } } template - void SubstitutionVisitor::visit(UnaryBooleanFunctionExpression const* expression) { - expression->getOperand()->accept(this); - std::shared_ptr operandExpression = expressionStack.top(); - expressionStack.pop(); + boost::any SubstitutionVisitor::visit(UnaryBooleanFunctionExpression const& expression) { + std::shared_ptr operandExpression = boost::any_cast>(expression.getOperand()->accept(*this)); // If the argument did not change, we simply push the expression itself. - if (operandExpression.get() == expression->getOperand().get()) { - expressionStack.push(expression->getSharedPointer()); + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); } else { - expressionStack.push(std::shared_ptr(new UnaryBooleanFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType()))); + return static_cast>(std::shared_ptr(new UnaryBooleanFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType()))); } } template - void SubstitutionVisitor::visit(UnaryNumericalFunctionExpression const* expression) { - expression->getOperand()->accept(this); - std::shared_ptr operandExpression = expressionStack.top(); - expressionStack.pop(); + boost::any SubstitutionVisitor::visit(UnaryNumericalFunctionExpression const& expression) { + std::shared_ptr operandExpression = boost::any_cast>(expression.getOperand()->accept(*this)); // If the argument did not change, we simply push the expression itself. - if (operandExpression.get() == expression->getOperand().get()) { - expressionStack.push(expression->getSharedPointer()); + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); } else { - expressionStack.push(std::shared_ptr(new UnaryNumericalFunctionExpression(expression->getReturnType(), operandExpression, expression->getOperatorType()))); + return static_cast>(std::shared_ptr(new UnaryNumericalFunctionExpression(expression.getReturnType(), operandExpression, expression.getOperatorType()))); } } template - void SubstitutionVisitor::visit(BooleanLiteralExpression const* expression) { - this->expressionStack.push(expression->getSharedPointer()); + boost::any SubstitutionVisitor::visit(BooleanLiteralExpression const& expression) { + return expression.getSharedPointer(); } template - void SubstitutionVisitor::visit(IntegerLiteralExpression const* expression) { - this->expressionStack.push(expression->getSharedPointer()); + boost::any SubstitutionVisitor::visit(IntegerLiteralExpression const& expression) { + return expression.getSharedPointer(); } template - void SubstitutionVisitor::visit(DoubleLiteralExpression const* expression) { - this->expressionStack.push(expression->getSharedPointer()); + boost::any SubstitutionVisitor::visit(DoubleLiteralExpression const& expression) { + return expression.getSharedPointer(); } // Explicitly instantiate the class with map and unordered_map. diff --git a/src/storage/expressions/SubstitutionVisitor.h b/src/storage/expressions/SubstitutionVisitor.h index 0ebc0941e..e2dec9e5c 100644 --- a/src/storage/expressions/SubstitutionVisitor.h +++ b/src/storage/expressions/SubstitutionVisitor.h @@ -28,21 +28,18 @@ namespace storm { */ Expression substitute(Expression const& expression); - virtual void visit(IfThenElseExpression const* expression) override; - virtual void visit(BinaryBooleanFunctionExpression const* expression) override; - virtual void visit(BinaryNumericalFunctionExpression const* expression) override; - virtual void visit(BinaryRelationExpression const* expression) override; - virtual void visit(VariableExpression const* expression) override; - virtual void visit(UnaryBooleanFunctionExpression const* expression) override; - virtual void visit(UnaryNumericalFunctionExpression const* expression) override; - virtual void visit(BooleanLiteralExpression const* expression) override; - virtual void visit(IntegerLiteralExpression const* expression) override; - virtual void visit(DoubleLiteralExpression const* expression) override; + virtual boost::any visit(IfThenElseExpression const& expression) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BinaryRelationExpression const& expression) override; + virtual boost::any visit(VariableExpression const& expression) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression) override; + virtual boost::any visit(BooleanLiteralExpression const& expression) override; + virtual boost::any visit(IntegerLiteralExpression const& expression) override; + virtual boost::any visit(DoubleLiteralExpression const& expression) override; private: - // A stack of expression used to pass the results to the higher levels. - std::stack> expressionStack; - // A mapping of identifier names to expressions with which they shall be replaced. MapType const& identifierToExpressionMap; }; diff --git a/src/storage/expressions/TypeCheckVisitor.cpp b/src/storage/expressions/TypeCheckVisitor.cpp deleted file mode 100644 index 11643abce..000000000 --- a/src/storage/expressions/TypeCheckVisitor.cpp +++ /dev/null @@ -1,80 +0,0 @@ -#include "src/storage/expressions/TypeCheckVisitor.h" -#include "src/storage/expressions/Expressions.h" - -#include "src/utility/macros.h" -#include "src/exceptions/InvalidTypeException.h" - -namespace storm { - namespace expressions { - template - TypeCheckVisitor::TypeCheckVisitor(MapType const& identifierToTypeMap) : identifierToTypeMap(identifierToTypeMap) { - // Intentionally left empty. - } - - template - void TypeCheckVisitor::check(Expression const& expression) { - expression.getBaseExpression().accept(this); - } - - template - void TypeCheckVisitor::visit(IfThenElseExpression const* expression) { - expression->getCondition()->accept(this); - expression->getThenExpression()->accept(this); - expression->getElseExpression()->accept(this); - } - - template - void TypeCheckVisitor::visit(BinaryBooleanFunctionExpression const* expression) { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - } - - template - void TypeCheckVisitor::visit(BinaryNumericalFunctionExpression const* expression) { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - } - - template - void TypeCheckVisitor::visit(BinaryRelationExpression const* expression) { - expression->getFirstOperand()->accept(this); - expression->getSecondOperand()->accept(this); - } - - template - void TypeCheckVisitor::visit(VariableExpression const* expression) { - auto identifierTypePair = this->identifierToTypeMap.find(expression->getVariableName()); - STORM_LOG_THROW(identifierTypePair != this->identifierToTypeMap.end(), storm::exceptions::InvalidArgumentException, "No type available for identifier '" << expression->getVariableName() << "'."); - STORM_LOG_THROW(identifierTypePair->second == expression->getReturnType(), storm::exceptions::InvalidTypeException, "Type mismatch for variable '" << expression->getVariableName() << "': expected '" << identifierTypePair->first << "', but found '" << expression->getReturnType() << "'."); - } - - template - void TypeCheckVisitor::visit(UnaryBooleanFunctionExpression const* expression) { - expression->getOperand()->accept(this); - } - - template - void TypeCheckVisitor::visit(UnaryNumericalFunctionExpression const* expression) { - expression->getOperand()->accept(this); - } - - template - void TypeCheckVisitor::visit(BooleanLiteralExpression const* expression) { - // Intentionally left empty. - } - - template - void TypeCheckVisitor::visit(IntegerLiteralExpression const* expression) { - // Intentionally left empty. - } - - template - void TypeCheckVisitor::visit(DoubleLiteralExpression const* expression) { - // Intentionally left empty. - } - - // Explicitly instantiate the class with map and unordered_map. - template class TypeCheckVisitor>; - template class TypeCheckVisitor>; - } -} \ No newline at end of file diff --git a/src/storage/expressions/TypeCheckVisitor.h b/src/storage/expressions/TypeCheckVisitor.h deleted file mode 100644 index 0cbf40f92..000000000 --- a/src/storage/expressions/TypeCheckVisitor.h +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef STORM_STORAGE_EXPRESSIONS_TYPECHECKVISITOR_H_ -#define STORM_STORAGE_EXPRESSIONS_TYPECHECKVISITOR_H_ - -#include - -#include "src/storage/expressions/Expression.h" -#include "src/storage/expressions/ExpressionVisitor.h" - -namespace storm { - namespace expressions { - template - class TypeCheckVisitor : public ExpressionVisitor { - public: - /*! - * Creates a new type check visitor that uses the given map to check the types of variables and constants. - * - * @param identifierToTypeMap A mapping from identifiers to expressions. - */ - TypeCheckVisitor(MapType const& identifierToTypeMap); - - /*! - * Checks that the types of the identifiers in the given expression match the ones in the previously given - * map. - * - * @param expression The expression in which to check the types. - */ - void check(Expression const& expression); - - virtual void visit(IfThenElseExpression const* expression) override; - virtual void visit(BinaryBooleanFunctionExpression const* expression) override; - virtual void visit(BinaryNumericalFunctionExpression const* expression) override; - virtual void visit(BinaryRelationExpression const* expression) override; - virtual void visit(VariableExpression const* expression) override; - virtual void visit(UnaryBooleanFunctionExpression const* expression) override; - virtual void visit(UnaryNumericalFunctionExpression const* expression) override; - virtual void visit(BooleanLiteralExpression const* expression) override; - virtual void visit(IntegerLiteralExpression const* expression) override; - virtual void visit(DoubleLiteralExpression const* expression) override; - - private: - // A mapping of identifier names to expressions with which they shall be replaced. - MapType const& identifierToTypeMap; - }; - } -} - -#endif /* STORM_STORAGE_EXPRESSIONS_TYPECHECKVISITOR_H_ */ \ No newline at end of file diff --git a/src/storage/expressions/UnaryBooleanFunctionExpression.cpp b/src/storage/expressions/UnaryBooleanFunctionExpression.cpp index d696ef22d..d9f7839ea 100644 --- a/src/storage/expressions/UnaryBooleanFunctionExpression.cpp +++ b/src/storage/expressions/UnaryBooleanFunctionExpression.cpp @@ -45,8 +45,8 @@ namespace storm { } } - void UnaryBooleanFunctionExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any UnaryBooleanFunctionExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } void UnaryBooleanFunctionExpression::printToStream(std::ostream& stream) const { diff --git a/src/storage/expressions/UnaryBooleanFunctionExpression.h b/src/storage/expressions/UnaryBooleanFunctionExpression.h index c69c6f886..02b2fab2f 100644 --- a/src/storage/expressions/UnaryBooleanFunctionExpression.h +++ b/src/storage/expressions/UnaryBooleanFunctionExpression.h @@ -35,7 +35,7 @@ namespace storm { virtual storm::expressions::OperatorType getOperator() const override; virtual bool evaluateAsBool(Valuation const* valuation = nullptr) const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the operator associated with this expression. diff --git a/src/storage/expressions/UnaryNumericalFunctionExpression.cpp b/src/storage/expressions/UnaryNumericalFunctionExpression.cpp index d90e460a1..75a2c0186 100644 --- a/src/storage/expressions/UnaryNumericalFunctionExpression.cpp +++ b/src/storage/expressions/UnaryNumericalFunctionExpression.cpp @@ -54,8 +54,8 @@ namespace storm { } } - void UnaryNumericalFunctionExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any UnaryNumericalFunctionExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } void UnaryNumericalFunctionExpression::printToStream(std::ostream& stream) const { diff --git a/src/storage/expressions/UnaryNumericalFunctionExpression.h b/src/storage/expressions/UnaryNumericalFunctionExpression.h index 10b85e63e..7852c73c7 100644 --- a/src/storage/expressions/UnaryNumericalFunctionExpression.h +++ b/src/storage/expressions/UnaryNumericalFunctionExpression.h @@ -36,7 +36,7 @@ namespace storm { virtual int_fast64_t evaluateAsInt(Valuation const* valuation = nullptr) const override; virtual double evaluateAsDouble(Valuation const* valuation = nullptr) const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the operator associated with this expression. diff --git a/src/storage/expressions/VariableExpression.cpp b/src/storage/expressions/VariableExpression.cpp index c65e9a47a..685ce1d4d 100644 --- a/src/storage/expressions/VariableExpression.cpp +++ b/src/storage/expressions/VariableExpression.cpp @@ -65,8 +65,8 @@ namespace storm { return this->shared_from_this(); } - void VariableExpression::accept(ExpressionVisitor* visitor) const { - visitor->visit(this); + boost::any VariableExpression::accept(ExpressionVisitor& visitor) const { + return visitor.visit(*this); } void VariableExpression::printToStream(std::ostream& stream) const { diff --git a/src/storage/expressions/VariableExpression.h b/src/storage/expressions/VariableExpression.h index dda150495..d84a12c55 100644 --- a/src/storage/expressions/VariableExpression.h +++ b/src/storage/expressions/VariableExpression.h @@ -35,7 +35,7 @@ namespace storm { virtual std::set getVariables() const override; virtual std::map getVariablesAndTypes() const override; virtual std::shared_ptr simplify() const override; - virtual void accept(ExpressionVisitor* visitor) const override; + virtual boost::any accept(ExpressionVisitor& visitor) const override; /*! * Retrieves the name of the variable associated with this expression. diff --git a/src/storage/prism/Program.cpp b/src/storage/prism/Program.cpp index 00a2fca09..bd7a11137 100644 --- a/src/storage/prism/Program.cpp +++ b/src/storage/prism/Program.cpp @@ -360,13 +360,6 @@ namespace storm { std::set containedIdentifiers = constant.getExpression().getVariables(); bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << constant.getFilename() << ", line " << constant.getLineNumber() << ": defining expression refers to unknown identifiers."); - - // Now check that the constants appear with the right types. - try { - constant.getExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << constant.getFilename() << ", line " << constant.getLineNumber() << ": " << e.what()); - } } // Finally, register the type of the constant for later type checks. @@ -388,11 +381,6 @@ namespace storm { std::set containedIdentifiers = variable.getInitialValueExpression().getVariables(); bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants."); - try { - variable.getInitialValueExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } // Register the type of the constant for later type checks. identifierToTypeMap.emplace(variable.getName(), storm::expressions::ExpressionReturnType::Bool); @@ -410,30 +398,15 @@ namespace storm { std::set containedIdentifiers = variable.getLowerBoundExpression().getVariables(); bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": lower bound expression refers to unknown constants."); - try { - variable.getLowerBoundExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } containedIdentifiers = variable.getLowerBoundExpression().getVariables(); isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": upper bound expression refers to unknown constants."); - try { - variable.getUpperBoundExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } // Check the initial value of the variable. containedIdentifiers = variable.getInitialValueExpression().getVariables(); isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants."); - try { - variable.getInitialValueExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } // Register the type of the constant for later type checks. identifierToTypeMap.emplace(variable.getName(), storm::expressions::ExpressionReturnType::Int); @@ -454,11 +427,6 @@ namespace storm { std::set containedIdentifiers = variable.getInitialValueExpression().getVariables(); bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants."); - try { - variable.getInitialValueExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } // Register the type of the constant for later type checks. identifierToTypeMap.emplace(variable.getName(), storm::expressions::ExpressionReturnType::Bool); @@ -478,30 +446,15 @@ namespace storm { std::set containedIdentifiers = variable.getLowerBoundExpression().getVariables(); bool isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": lower bound expression refers to unknown constants."); - try { - variable.getLowerBoundExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } containedIdentifiers = variable.getLowerBoundExpression().getVariables(); isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": upper bound expression refers to unknown constants."); - try { - variable.getUpperBoundExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } // Check the initial value of the variable. containedIdentifiers = variable.getInitialValueExpression().getVariables(); isValid = std::includes(constantNames.begin(), constantNames.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": initial value expression refers to unknown constants."); - try { - variable.getInitialValueExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << variable.getFilename() << ", line " << variable.getLineNumber() << ": " << e.what()); - } // Record the new identifier for future checks. variableNames.insert(variable.getName()); @@ -528,11 +481,6 @@ namespace storm { std::set containedIdentifiers = command.getGuardExpression().getVariables(); bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": guard refers to unknown identifiers."); - try { - command.getGuardExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": " << e.what()); - } STORM_LOG_THROW(command.getGuardExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": expression for guard must evaluate to type 'bool'."); // Check all updates. @@ -540,11 +488,6 @@ namespace storm { containedIdentifiers = update.getLikelihoodExpression().getVariables(); isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": likelihood expression refers to unknown identifiers."); - try { - update.getLikelihoodExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": " << e.what()); - } // Check all assignments. std::set alreadyAssignedIdentifiers; @@ -563,11 +506,6 @@ namespace storm { containedIdentifiers = assignment.getExpression().getVariables(); isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": likelihood expression refers to unknown identifiers."); - try { - assignment.getExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << command.getFilename() << ", line " << command.getLineNumber() << ": " << e.what()); - } // Add the current variable to the set of assigned variables (of this update). alreadyAssignedIdentifiers.insert(assignment.getVariableName()); @@ -582,21 +520,11 @@ namespace storm { std::set containedIdentifiers = stateReward.getStatePredicateExpression().getVariables(); bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": state reward expression refers to unknown identifiers."); - try { - stateReward.getStatePredicateExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": " << e.what()); - } STORM_LOG_THROW(stateReward.getStatePredicateExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": state predicate must evaluate to type 'bool'."); containedIdentifiers = stateReward.getRewardValueExpression().getVariables(); isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": state reward value expression refers to unknown identifiers."); - try { - stateReward.getRewardValueExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": " << e.what()); - } STORM_LOG_THROW(stateReward.getRewardValueExpression().hasNumericalReturnType(), storm::exceptions::WrongFormatException, "Error in " << stateReward.getFilename() << ", line " << stateReward.getLineNumber() << ": reward value expression must evaluate to numerical type."); } @@ -604,21 +532,11 @@ namespace storm { std::set containedIdentifiers = transitionReward.getStatePredicateExpression().getVariables(); bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": state reward expression refers to unknown identifiers."); - try { - transitionReward.getStatePredicateExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": " << e.what()); - } STORM_LOG_THROW(transitionReward.getStatePredicateExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": state predicate must evaluate to type 'bool'."); containedIdentifiers = transitionReward.getRewardValueExpression().getVariables(); isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": state reward value expression refers to unknown identifiers."); - try { - transitionReward.getRewardValueExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": " << e.what()); - } STORM_LOG_THROW(transitionReward.getRewardValueExpression().hasNumericalReturnType(), storm::exceptions::WrongFormatException, "Error in " << transitionReward.getFilename() << ", line " << transitionReward.getLineNumber() << ": reward value expression must evaluate to numerical type."); } } @@ -627,11 +545,6 @@ namespace storm { std::set containedIdentifiers = this->getInitialConstruct().getInitialStatesExpression().getVariables(); bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << this->getInitialConstruct().getFilename() << ", line " << this->getInitialConstruct().getLineNumber() << ": initial expression refers to unknown identifiers."); - try { - this->getInitialConstruct().getInitialStatesExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << this->getInitialConstruct().getFilename() << ", line " << this->getInitialConstruct().getLineNumber() << ": " << e.what()); - } // Check the labels. for (auto const& label : this->getLabels()) { @@ -641,12 +554,6 @@ namespace storm { std::set containedIdentifiers = label.getStatePredicateExpression().getVariables(); bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << label.getFilename() << ", line " << label.getLineNumber() << ": label expression refers to unknown identifiers."); - try { - label.getStatePredicateExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << label.getFilename() << ", line " << label.getLineNumber() << ": " << e.what()); - } - STORM_LOG_THROW(label.getStatePredicateExpression().hasBooleanReturnType(), storm::exceptions::WrongFormatException, "Error in " << label.getFilename() << ", line " << label.getLineNumber() << ": label predicate must evaluate to type 'bool'."); } @@ -658,11 +565,6 @@ namespace storm { std::set containedIdentifiers = formula.getExpression().getVariables(); bool isValid = std::includes(variablesAndConstants.begin(), variablesAndConstants.end(), containedIdentifiers.begin(), containedIdentifiers.end()); STORM_LOG_THROW(isValid, storm::exceptions::WrongFormatException, "Error in " << formula.getFilename() << ", line " << formula.getLineNumber() << ": formula expression refers to unknown identifiers."); - try { - formula.getExpression().check(identifierToTypeMap); - } catch (storm::exceptions::InvalidTypeException const& e) { - STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Error in " << formula.getFilename() << ", line " << formula.getLineNumber() << ": " << e.what()); - } // Record the new identifier for future checks. allIdentifiers.insert(formula.getName()); diff --git a/test/functional/solver/GlpkLpSolverTest.cpp b/test/functional/solver/GlpkLpSolverTest.cpp index e72aaff67..31c134172 100644 --- a/test/functional/solver/GlpkLpSolverTest.cpp +++ b/test/functional/solver/GlpkLpSolverTest.cpp @@ -14,6 +14,7 @@ TEST(GlpkLpSolver, LPOptimizeMax) { ASSERT_NO_THROW(solver.addLowerBoundedContinuousVariable("z", 0, 1)); ASSERT_NO_THROW(solver.update()); + solver.addConstraint("", storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createDoubleVariable("y") + storm::expressions::Expression::createDoubleVariable("z") <= storm::expressions::Expression::createDoubleLiteral(12)); ASSERT_NO_THROW(solver.addConstraint("", storm::expressions::Expression::createDoubleVariable("x") + storm::expressions::Expression::createDoubleVariable("y") + storm::expressions::Expression::createDoubleVariable("z") <= storm::expressions::Expression::createDoubleLiteral(12))); ASSERT_NO_THROW(solver.addConstraint("", storm::expressions::Expression::createDoubleLiteral(0.5) * storm::expressions::Expression::createDoubleVariable("y") + storm::expressions::Expression::createDoubleVariable("z") - storm::expressions::Expression::createDoubleVariable("x") == storm::expressions::Expression::createDoubleLiteral(5))); ASSERT_NO_THROW(solver.addConstraint("", storm::expressions::Expression::createDoubleVariable("y") - storm::expressions::Expression::createDoubleVariable("x") <= storm::expressions::Expression::createDoubleLiteral(5.5)));