diff --git a/src/storm/storage/expressions/BaseExpression.cpp b/src/storm/storage/expressions/BaseExpression.cpp index 4e07fd089..b6f652e3f 100644 --- a/src/storm/storage/expressions/BaseExpression.cpp +++ b/src/storm/storage/expressions/BaseExpression.cpp @@ -6,6 +6,7 @@ #include "storm/storage/expressions/Expressions.h" #include "storm/storage/expressions/ToRationalNumberVisitor.h" +#include "storm/storage/expressions/ReduceNestingVisitor.h" namespace storm { namespace expressions { @@ -63,7 +64,7 @@ namespace storm { } std::shared_ptr BaseExpression::getOperand(uint_fast64_t operandIndex) const { - STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, "Unable to access operand " << operandIndex << " in expression of arity 0."); + STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, "Unable to access operand " << operandIndex << " in expression '" << *this << "' of arity 0."); } std::string const& BaseExpression::getIdentifier() const { @@ -74,6 +75,11 @@ namespace storm { STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, "Unable to access operator of non-function application expression."); } + std::shared_ptr BaseExpression::reduceNesting() const { + ReduceNestingVisitor v; + return v.reduceNesting(this->toExpression()).getBaseExpressionPointer(); + } + bool BaseExpression::containsVariables() const { return false; } diff --git a/src/storm/storage/expressions/BaseExpression.h b/src/storm/storage/expressions/BaseExpression.h index 93091e536..666c4d40b 100644 --- a/src/storm/storage/expressions/BaseExpression.h +++ b/src/storm/storage/expressions/BaseExpression.h @@ -185,6 +185,13 @@ namespace storm { */ virtual std::shared_ptr simplify() const = 0; + /*! + * Tries to flatten the syntax tree of the expression, e.g., 1 + (2 + (3 + 4)) becomes (1 + 2) + (3 + 4) + * + * @return A semantically equivalent expression with reduced nesting + */ + std::shared_ptr reduceNesting() const; + /*! * Accepts the given visitor by calling its visit method. * diff --git a/src/storm/storage/expressions/Expression.cpp b/src/storm/storage/expressions/Expression.cpp index 2ec1604a7..1d31f4c6f 100644 --- a/src/storm/storage/expressions/Expression.cpp +++ b/src/storm/storage/expressions/Expression.cpp @@ -73,6 +73,10 @@ namespace storm { return Expression(this->getBaseExpression().simplify()); } + Expression Expression::reduceNesting() const { + return Expression(this->getBaseExpression().reduceNesting()); + } + OperatorType Expression::getOperator() const { return this->getBaseExpression().getOperator(); } diff --git a/src/storm/storage/expressions/Expression.h b/src/storm/storage/expressions/Expression.h index 99d409e2e..5778f7214 100644 --- a/src/storm/storage/expressions/Expression.h +++ b/src/storm/storage/expressions/Expression.h @@ -152,6 +152,13 @@ namespace storm { */ Expression simplify() const; + /*! + * Tries to flatten the syntax tree of the expression, e.g., 1 + (2 + (3 + 4)) becomes (1 + 2) + (3 + 4) + * + * @return A semantically equivalent expression with reduced nesting + */ + Expression reduceNesting() const; + /*! * Retrieves the operator of a function application. This is only legal to call if the expression is * function application. diff --git a/src/storm/storage/expressions/ReduceNestingVisitor.cpp b/src/storm/storage/expressions/ReduceNestingVisitor.cpp new file mode 100644 index 000000000..5f77ca10c --- /dev/null +++ b/src/storm/storage/expressions/ReduceNestingVisitor.cpp @@ -0,0 +1,180 @@ +#include + +#include "storm/storage/expressions/ReduceNestingVisitor.h" +#include "storm/storage/expressions/Expressions.h" + +namespace storm { + namespace expressions { + + ReduceNestingVisitor::ReduceNestingVisitor() { + // Intentionally left empty. + } + + Expression ReduceNestingVisitor::reduceNesting(Expression const& expression) { + return Expression(boost::any_cast>(expression.getBaseExpression().accept(*this, boost::none))); + } + + boost::any ReduceNestingVisitor::visit(IfThenElseExpression const& expression, boost::any const& data) { + std::shared_ptr conditionExpression = boost::any_cast>(expression.getCondition()->accept(*this, data)); + std::shared_ptr thenExpression = boost::any_cast>(expression.getThenExpression()->accept(*this, data)); + std::shared_ptr elseExpression = boost::any_cast>(expression.getElseExpression()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new IfThenElseExpression(expression.getManager(), expression.getType(), conditionExpression, thenExpression, elseExpression))); + } + } + + template + std::vector> getAllOperands(BinaryFunc const& binaryExpression) { + auto opType = binaryExpression.getOperatorType(); + std::vector> stack = {binaryExpression.getSharedPointer()}; + std::vector> res; + while (!stack.empty()) { + auto f = std::move(stack.back()); + stack.pop_back(); + + for (uint64_t opIndex = 0; opIndex < 2; ++opIndex) { + BinaryFunc const* subexp = dynamic_cast(f->getOperand(opIndex).get()); + if (subexp != nullptr && subexp->getOperatorType() == opType) { + stack.push_back(f->getOperand(opIndex)); + } else { + res.push_back(f->getOperand(opIndex)); + } + } + } + return res; + } + + boost::any ReduceNestingVisitor::visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) { + + // Check if the operator is commutative and associative + if (expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::Or || expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::And || expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::Iff) { + + std::vector> operands = getAllOperands(expression); + + // Balance the syntax tree if there are enough operands + if (operands.size() >= 4) { + + for (auto& operand : operands) { + operand = boost::any_cast>(operand->accept(*this, data)); + } + + auto opIt = operands.begin(); + while (operands.size() > 1) { + if (opIt == operands.end() || opIt == operands.end() - 1) { + opIt = operands.begin(); + } + *opIt = std::const_pointer_cast(std::shared_ptr(new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), *opIt, operands.back(), expression.getOperatorType()))); + operands.pop_back(); + ++opIt; + } + return operands.front(); + } + } + + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); + } + } + + boost::any ReduceNestingVisitor::visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) { + // Check if the operator is commutative and associative + if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Max || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Min) { + + std::vector> operands = getAllOperands(expression); + + // Balance the syntax tree if there are enough operands + if (operands.size() >= 4) { + + for (auto& operand : operands) { + operand = boost::any_cast>(operand->accept(*this, data)); + } + + auto opIt = operands.begin(); + while (operands.size() > 1) { + if (opIt == operands.end() || opIt == operands.end() - 1) { + opIt = operands.begin(); + } + *opIt = std::const_pointer_cast(std::shared_ptr(new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), *opIt, operands.back(), expression.getOperatorType()))); + operands.pop_back(); + ++opIt; + } + return operands.front(); + } + } + + + + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); + } + } + + boost::any ReduceNestingVisitor::visit(BinaryRelationExpression const& expression, boost::any const& data) { + + std::shared_ptr firstExpression = boost::any_cast>(expression.getFirstOperand()->accept(*this, data)); + std::shared_ptr secondExpression = boost::any_cast>(expression.getSecondOperand()->accept(*this, data)); + + // If the arguments did not change, we simply push the expression itself. + if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getRelationType()))); + } + } + + boost::any ReduceNestingVisitor::visit(VariableExpression const& expression, boost::any const&) { + return expression.getSharedPointer(); + } + + boost::any ReduceNestingVisitor::visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) { + std::shared_ptr operandExpression = boost::any_cast>(expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); + } + } + + boost::any ReduceNestingVisitor::visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) { + std::shared_ptr operandExpression = boost::any_cast>(expression.getOperand()->accept(*this, data)); + + // If the argument did not change, we simply push the expression itself. + if (operandExpression.get() == expression.getOperand().get()) { + return expression.getSharedPointer(); + } else { + return std::const_pointer_cast(std::shared_ptr(new UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); + } + } + + boost::any ReduceNestingVisitor::visit(BooleanLiteralExpression const& expression, boost::any const&) { + return expression.getSharedPointer(); + } + + boost::any ReduceNestingVisitor::visit(IntegerLiteralExpression const& expression, boost::any const&) { + return expression.getSharedPointer(); + } + + boost::any ReduceNestingVisitor::visit(RationalLiteralExpression const& expression, boost::any const&) { + return expression.getSharedPointer(); + } + + } +} diff --git a/src/storm/storage/expressions/ReduceNestingVisitor.h b/src/storm/storage/expressions/ReduceNestingVisitor.h new file mode 100644 index 000000000..f7c9c1566 --- /dev/null +++ b/src/storm/storage/expressions/ReduceNestingVisitor.h @@ -0,0 +1,36 @@ +#pragma once + +#include "storm/storage/expressions/Expression.h" +#include "storm/storage/expressions/ExpressionVisitor.h" + +namespace storm { + namespace expressions { + class ReduceNestingVisitor : public ExpressionVisitor { + public: + /*! + * Creates a new reduce nesting visitor. + */ + ReduceNestingVisitor(); + + /*! + * Reduces the nesting in the given expression + * + * @return A semantically equivalent expression with reduced nesting + */ + Expression reduceNesting(Expression const& expression); + + virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; + virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; + virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; + virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; + + private: + }; + } +} diff --git a/src/storm/storage/jani/JSONExporter.cpp b/src/storm/storage/jani/JSONExporter.cpp index befb793c5..c48927a78 100644 --- a/src/storm/storage/jani/JSONExporter.cpp +++ b/src/storm/storage/jani/JSONExporter.cpp @@ -548,8 +548,13 @@ namespace storm { } modernjson::json ExpressionToJson::translate(storm::expressions::Expression const& expr, std::vector const& constants, VariableSet const& globalVariables, VariableSet const& localVariables) { + + // Simplify the expression first and reduce the nesting + auto simplifiedExpr = expr.simplify().reduceNesting(); + + ExpressionToJson visitor(constants, globalVariables, localVariables); - return boost::any_cast(expr.accept(visitor, boost::none)); + return boost::any_cast(simplifiedExpr.accept(visitor, boost::none)); } diff --git a/src/storm/utility/ExpressionHelper.cpp b/src/storm/utility/ExpressionHelper.cpp index 219d6e39c..fd99855f8 100644 --- a/src/storm/utility/ExpressionHelper.cpp +++ b/src/storm/utility/ExpressionHelper.cpp @@ -12,17 +12,16 @@ namespace storm { if (summands.empty()) { return manager->rational(storm::utility::zero()); } - // As the sum can potentially have many summands, we want to make sure that the formula tree is (roughly balanced) - auto it = summands.begin(); - while (summands.size() > 1) { - if (it == summands.end() || it == summands.end() - 1) { - it = summands.begin(); + storm::expressions::Expression res = summands.front(); + bool first = true; + for (auto& s : summands) { + if (first) { + first = false; + } else { + res = res + s; } - *it = *it + summands.back(); - summands.pop_back(); - ++it; } - return summands.front(); + return res.simplify().reduceNesting(); } }