Browse Source

added function that reduces the nesting of expressions (e.g. when considering a big sum with many summands. This fixes stack overflows when translating expressions

main
TimQu 7 years ago
parent
commit
ed2de09ce3
  1. 8
      src/storm/storage/expressions/BaseExpression.cpp
  2. 7
      src/storm/storage/expressions/BaseExpression.h
  3. 4
      src/storm/storage/expressions/Expression.cpp
  4. 7
      src/storm/storage/expressions/Expression.h
  5. 180
      src/storm/storage/expressions/ReduceNestingVisitor.cpp
  6. 36
      src/storm/storage/expressions/ReduceNestingVisitor.h
  7. 7
      src/storm/storage/jani/JSONExporter.cpp
  8. 17
      src/storm/utility/ExpressionHelper.cpp

8
src/storm/storage/expressions/BaseExpression.cpp

@ -6,6 +6,7 @@
#include "storm/storage/expressions/Expressions.h"
#include "storm/storage/expressions/ToRationalNumberVisitor.h"
#include "storm/storage/expressions/ReduceNestingVisitor.h"
namespace storm {
namespace expressions {
@ -63,7 +64,7 @@ namespace storm {
}
std::shared_ptr<BaseExpression const> BaseExpression::getOperand(uint_fast64_t operandIndex) const {
STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, "Unable to access operand " << operandIndex << " in expression of arity 0.");
STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, "Unable to access operand " << operandIndex << " in expression '" << *this << "' of arity 0.");
}
std::string const& BaseExpression::getIdentifier() const {
@ -74,6 +75,11 @@ namespace storm {
STORM_LOG_THROW(false, storm::exceptions::InvalidAccessException, "Unable to access operator of non-function application expression.");
}
std::shared_ptr<BaseExpression const> BaseExpression::reduceNesting() const {
ReduceNestingVisitor v;
return v.reduceNesting(this->toExpression()).getBaseExpressionPointer();
}
bool BaseExpression::containsVariables() const {
return false;
}

7
src/storm/storage/expressions/BaseExpression.h

@ -185,6 +185,13 @@ namespace storm {
*/
virtual std::shared_ptr<BaseExpression const> simplify() const = 0;
/*!
* Tries to flatten the syntax tree of the expression, e.g., 1 + (2 + (3 + 4)) becomes (1 + 2) + (3 + 4)
*
* @return A semantically equivalent expression with reduced nesting
*/
std::shared_ptr<BaseExpression const> reduceNesting() const;
/*!
* Accepts the given visitor by calling its visit method.
*

4
src/storm/storage/expressions/Expression.cpp

@ -73,6 +73,10 @@ namespace storm {
return Expression(this->getBaseExpression().simplify());
}
Expression Expression::reduceNesting() const {
return Expression(this->getBaseExpression().reduceNesting());
}
OperatorType Expression::getOperator() const {
return this->getBaseExpression().getOperator();
}

7
src/storm/storage/expressions/Expression.h

@ -152,6 +152,13 @@ namespace storm {
*/
Expression simplify() const;
/*!
* Tries to flatten the syntax tree of the expression, e.g., 1 + (2 + (3 + 4)) becomes (1 + 2) + (3 + 4)
*
* @return A semantically equivalent expression with reduced nesting
*/
Expression reduceNesting() const;
/*!
* Retrieves the operator of a function application. This is only legal to call if the expression is
* function application.

180
src/storm/storage/expressions/ReduceNestingVisitor.cpp

@ -0,0 +1,180 @@
#include <string>
#include "storm/storage/expressions/ReduceNestingVisitor.h"
#include "storm/storage/expressions/Expressions.h"
namespace storm {
namespace expressions {
ReduceNestingVisitor::ReduceNestingVisitor() {
// Intentionally left empty.
}
Expression ReduceNestingVisitor::reduceNesting(Expression const& expression) {
return Expression(boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getBaseExpression().accept(*this, boost::none)));
}
boost::any ReduceNestingVisitor::visit(IfThenElseExpression const& expression, boost::any const& data) {
std::shared_ptr<BaseExpression const> conditionExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getCondition()->accept(*this, data));
std::shared_ptr<BaseExpression const> thenExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getThenExpression()->accept(*this, data));
std::shared_ptr<BaseExpression const> elseExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getElseExpression()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new IfThenElseExpression(expression.getManager(), expression.getType(), conditionExpression, thenExpression, elseExpression)));
}
}
template <typename BinaryFunc>
std::vector<std::shared_ptr<BaseExpression const>> getAllOperands(BinaryFunc const& binaryExpression) {
auto opType = binaryExpression.getOperatorType();
std::vector<std::shared_ptr<BaseExpression const>> stack = {binaryExpression.getSharedPointer()};
std::vector<std::shared_ptr<BaseExpression const>> res;
while (!stack.empty()) {
auto f = std::move(stack.back());
stack.pop_back();
for (uint64_t opIndex = 0; opIndex < 2; ++opIndex) {
BinaryFunc const* subexp = dynamic_cast<BinaryFunc const*>(f->getOperand(opIndex).get());
if (subexp != nullptr && subexp->getOperatorType() == opType) {
stack.push_back(f->getOperand(opIndex));
} else {
res.push_back(f->getOperand(opIndex));
}
}
}
return res;
}
boost::any ReduceNestingVisitor::visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) {
// Check if the operator is commutative and associative
if (expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::Or || expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::And || expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::Iff) {
std::vector<std::shared_ptr<BaseExpression const>> operands = getAllOperands<BinaryBooleanFunctionExpression>(expression);
// Balance the syntax tree if there are enough operands
if (operands.size() >= 4) {
for (auto& operand : operands) {
operand = boost::any_cast<std::shared_ptr<BaseExpression const>>(operand->accept(*this, data));
}
auto opIt = operands.begin();
while (operands.size() > 1) {
if (opIt == operands.end() || opIt == operands.end() - 1) {
opIt = operands.begin();
}
*opIt = std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), *opIt, operands.back(), expression.getOperatorType())));
operands.pop_back();
++opIt;
}
return operands.front();
}
}
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this, data));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType())));
}
}
boost::any ReduceNestingVisitor::visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) {
// Check if the operator is commutative and associative
if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Max || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Min) {
std::vector<std::shared_ptr<BaseExpression const>> operands = getAllOperands<BinaryNumericalFunctionExpression>(expression);
// Balance the syntax tree if there are enough operands
if (operands.size() >= 4) {
for (auto& operand : operands) {
operand = boost::any_cast<std::shared_ptr<BaseExpression const>>(operand->accept(*this, data));
}
auto opIt = operands.begin();
while (operands.size() > 1) {
if (opIt == operands.end() || opIt == operands.end() - 1) {
opIt = operands.begin();
}
*opIt = std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), *opIt, operands.back(), expression.getOperatorType())));
operands.pop_back();
++opIt;
}
return operands.front();
}
}
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this, data));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType())));
}
}
boost::any ReduceNestingVisitor::visit(BinaryRelationExpression const& expression, boost::any const& data) {
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this, data));
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this, data));
// If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getRelationType())));
}
}
boost::any ReduceNestingVisitor::visit(VariableExpression const& expression, boost::any const&) {
return expression.getSharedPointer();
}
boost::any ReduceNestingVisitor::visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) {
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getOperand()->accept(*this, data));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType())));
}
}
boost::any ReduceNestingVisitor::visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) {
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getOperand()->accept(*this, data));
// If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression.getOperand().get()) {
return expression.getSharedPointer();
} else {
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType())));
}
}
boost::any ReduceNestingVisitor::visit(BooleanLiteralExpression const& expression, boost::any const&) {
return expression.getSharedPointer();
}
boost::any ReduceNestingVisitor::visit(IntegerLiteralExpression const& expression, boost::any const&) {
return expression.getSharedPointer();
}
boost::any ReduceNestingVisitor::visit(RationalLiteralExpression const& expression, boost::any const&) {
return expression.getSharedPointer();
}
}
}

36
src/storm/storage/expressions/ReduceNestingVisitor.h

@ -0,0 +1,36 @@
#pragma once
#include "storm/storage/expressions/Expression.h"
#include "storm/storage/expressions/ExpressionVisitor.h"
namespace storm {
namespace expressions {
class ReduceNestingVisitor : public ExpressionVisitor {
public:
/*!
* Creates a new reduce nesting visitor.
*/
ReduceNestingVisitor();
/*!
* Reduces the nesting in the given expression
*
* @return A semantically equivalent expression with reduced nesting
*/
Expression reduceNesting(Expression const& expression);
virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override;
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override;
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override;
virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override;
virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override;
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override;
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override;
virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override;
virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override;
virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override;
private:
};
}
}

7
src/storm/storage/jani/JSONExporter.cpp

@ -548,8 +548,13 @@ namespace storm {
}
modernjson::json ExpressionToJson::translate(storm::expressions::Expression const& expr, std::vector<storm::jani::Constant> const& constants, VariableSet const& globalVariables, VariableSet const& localVariables) {
// Simplify the expression first and reduce the nesting
auto simplifiedExpr = expr.simplify().reduceNesting();
ExpressionToJson visitor(constants, globalVariables, localVariables);
return boost::any_cast<modernjson::json>(expr.accept(visitor, boost::none));
return boost::any_cast<modernjson::json>(simplifiedExpr.accept(visitor, boost::none));
}

17
src/storm/utility/ExpressionHelper.cpp

@ -12,17 +12,16 @@ namespace storm {
if (summands.empty()) {
return manager->rational(storm::utility::zero<storm::RationalNumber>());
}
// As the sum can potentially have many summands, we want to make sure that the formula tree is (roughly balanced)
auto it = summands.begin();
while (summands.size() > 1) {
if (it == summands.end() || it == summands.end() - 1) {
it = summands.begin();
storm::expressions::Expression res = summands.front();
bool first = true;
for (auto& s : summands) {
if (first) {
first = false;
} else {
res = res + s;
}
*it = *it + summands.back();
summands.pop_back();
++it;
}
return summands.front();
return res.simplify().reduceNesting();
}
}
Loading…
Cancel
Save