Browse Source
added function that reduces the nesting of expressions (e.g. when considering a big sum with many summands. This fixes stack overflows when translating expressions
main
added function that reduces the nesting of expressions (e.g. when considering a big sum with many summands. This fixes stack overflows when translating expressions
main
8 changed files with 255 additions and 11 deletions
-
8src/storm/storage/expressions/BaseExpression.cpp
-
7src/storm/storage/expressions/BaseExpression.h
-
4src/storm/storage/expressions/Expression.cpp
-
7src/storm/storage/expressions/Expression.h
-
180src/storm/storage/expressions/ReduceNestingVisitor.cpp
-
36src/storm/storage/expressions/ReduceNestingVisitor.h
-
7src/storm/storage/jani/JSONExporter.cpp
-
17src/storm/utility/ExpressionHelper.cpp
@ -0,0 +1,180 @@ |
|||||
|
#include <string>
|
||||
|
|
||||
|
#include "storm/storage/expressions/ReduceNestingVisitor.h"
|
||||
|
#include "storm/storage/expressions/Expressions.h"
|
||||
|
|
||||
|
namespace storm { |
||||
|
namespace expressions { |
||||
|
|
||||
|
ReduceNestingVisitor::ReduceNestingVisitor() { |
||||
|
// Intentionally left empty.
|
||||
|
} |
||||
|
|
||||
|
Expression ReduceNestingVisitor::reduceNesting(Expression const& expression) { |
||||
|
return Expression(boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getBaseExpression().accept(*this, boost::none))); |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(IfThenElseExpression const& expression, boost::any const& data) { |
||||
|
std::shared_ptr<BaseExpression const> conditionExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getCondition()->accept(*this, data)); |
||||
|
std::shared_ptr<BaseExpression const> thenExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getThenExpression()->accept(*this, data)); |
||||
|
std::shared_ptr<BaseExpression const> elseExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getElseExpression()->accept(*this, data)); |
||||
|
|
||||
|
// If the arguments did not change, we simply push the expression itself.
|
||||
|
if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} else { |
||||
|
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new IfThenElseExpression(expression.getManager(), expression.getType(), conditionExpression, thenExpression, elseExpression))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
template <typename BinaryFunc> |
||||
|
std::vector<std::shared_ptr<BaseExpression const>> getAllOperands(BinaryFunc const& binaryExpression) { |
||||
|
auto opType = binaryExpression.getOperatorType(); |
||||
|
std::vector<std::shared_ptr<BaseExpression const>> stack = {binaryExpression.getSharedPointer()}; |
||||
|
std::vector<std::shared_ptr<BaseExpression const>> res; |
||||
|
while (!stack.empty()) { |
||||
|
auto f = std::move(stack.back()); |
||||
|
stack.pop_back(); |
||||
|
|
||||
|
for (uint64_t opIndex = 0; opIndex < 2; ++opIndex) { |
||||
|
BinaryFunc const* subexp = dynamic_cast<BinaryFunc const*>(f->getOperand(opIndex).get()); |
||||
|
if (subexp != nullptr && subexp->getOperatorType() == opType) { |
||||
|
stack.push_back(f->getOperand(opIndex)); |
||||
|
} else { |
||||
|
res.push_back(f->getOperand(opIndex)); |
||||
|
} |
||||
|
} |
||||
|
} |
||||
|
return res; |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) { |
||||
|
|
||||
|
// Check if the operator is commutative and associative
|
||||
|
if (expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::Or || expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::And || expression.getOperatorType() == BinaryBooleanFunctionExpression::OperatorType::Iff) { |
||||
|
|
||||
|
std::vector<std::shared_ptr<BaseExpression const>> operands = getAllOperands<BinaryBooleanFunctionExpression>(expression); |
||||
|
|
||||
|
// Balance the syntax tree if there are enough operands
|
||||
|
if (operands.size() >= 4) { |
||||
|
|
||||
|
for (auto& operand : operands) { |
||||
|
operand = boost::any_cast<std::shared_ptr<BaseExpression const>>(operand->accept(*this, data)); |
||||
|
} |
||||
|
|
||||
|
auto opIt = operands.begin(); |
||||
|
while (operands.size() > 1) { |
||||
|
if (opIt == operands.end() || opIt == operands.end() - 1) { |
||||
|
opIt = operands.begin(); |
||||
|
} |
||||
|
*opIt = std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), *opIt, operands.back(), expression.getOperatorType()))); |
||||
|
operands.pop_back(); |
||||
|
++opIt; |
||||
|
} |
||||
|
return operands.front(); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this, data)); |
||||
|
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this, data)); |
||||
|
|
||||
|
// If the arguments did not change, we simply push the expression itself.
|
||||
|
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} else { |
||||
|
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) { |
||||
|
// Check if the operator is commutative and associative
|
||||
|
if (expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Plus || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Times || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Max || expression.getOperatorType() == BinaryNumericalFunctionExpression::OperatorType::Min) { |
||||
|
|
||||
|
std::vector<std::shared_ptr<BaseExpression const>> operands = getAllOperands<BinaryNumericalFunctionExpression>(expression); |
||||
|
|
||||
|
// Balance the syntax tree if there are enough operands
|
||||
|
if (operands.size() >= 4) { |
||||
|
|
||||
|
for (auto& operand : operands) { |
||||
|
operand = boost::any_cast<std::shared_ptr<BaseExpression const>>(operand->accept(*this, data)); |
||||
|
} |
||||
|
|
||||
|
auto opIt = operands.begin(); |
||||
|
while (operands.size() > 1) { |
||||
|
if (opIt == operands.end() || opIt == operands.end() - 1) { |
||||
|
opIt = operands.begin(); |
||||
|
} |
||||
|
*opIt = std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), *opIt, operands.back(), expression.getOperatorType()))); |
||||
|
operands.pop_back(); |
||||
|
++opIt; |
||||
|
} |
||||
|
return operands.front(); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
|
||||
|
|
||||
|
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this, data)); |
||||
|
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this, data)); |
||||
|
|
||||
|
// If the arguments did not change, we simply push the expression itself.
|
||||
|
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} else { |
||||
|
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(BinaryRelationExpression const& expression, boost::any const& data) { |
||||
|
|
||||
|
std::shared_ptr<BaseExpression const> firstExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getFirstOperand()->accept(*this, data)); |
||||
|
std::shared_ptr<BaseExpression const> secondExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getSecondOperand()->accept(*this, data)); |
||||
|
|
||||
|
// If the arguments did not change, we simply push the expression itself.
|
||||
|
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} else { |
||||
|
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getRelationType()))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(VariableExpression const& expression, boost::any const&) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) { |
||||
|
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getOperand()->accept(*this, data)); |
||||
|
|
||||
|
// If the argument did not change, we simply push the expression itself.
|
||||
|
if (operandExpression.get() == expression.getOperand().get()) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} else { |
||||
|
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) { |
||||
|
std::shared_ptr<BaseExpression const> operandExpression = boost::any_cast<std::shared_ptr<BaseExpression const>>(expression.getOperand()->accept(*this, data)); |
||||
|
|
||||
|
// If the argument did not change, we simply push the expression itself.
|
||||
|
if (operandExpression.get() == expression.getOperand().get()) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} else { |
||||
|
return std::const_pointer_cast<BaseExpression const>(std::shared_ptr<BaseExpression>(new UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(BooleanLiteralExpression const& expression, boost::any const&) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(IntegerLiteralExpression const& expression, boost::any const&) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} |
||||
|
|
||||
|
boost::any ReduceNestingVisitor::visit(RationalLiteralExpression const& expression, boost::any const&) { |
||||
|
return expression.getSharedPointer(); |
||||
|
} |
||||
|
|
||||
|
} |
||||
|
} |
@ -0,0 +1,36 @@ |
|||||
|
#pragma once |
||||
|
|
||||
|
#include "storm/storage/expressions/Expression.h" |
||||
|
#include "storm/storage/expressions/ExpressionVisitor.h" |
||||
|
|
||||
|
namespace storm { |
||||
|
namespace expressions { |
||||
|
class ReduceNestingVisitor : public ExpressionVisitor { |
||||
|
public: |
||||
|
/*! |
||||
|
* Creates a new reduce nesting visitor. |
||||
|
*/ |
||||
|
ReduceNestingVisitor(); |
||||
|
|
||||
|
/*! |
||||
|
* Reduces the nesting in the given expression |
||||
|
* |
||||
|
* @return A semantically equivalent expression with reduced nesting |
||||
|
*/ |
||||
|
Expression reduceNesting(Expression const& expression); |
||||
|
|
||||
|
virtual boost::any visit(IfThenElseExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(BinaryBooleanFunctionExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(BinaryNumericalFunctionExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(BinaryRelationExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(VariableExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(UnaryBooleanFunctionExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(UnaryNumericalFunctionExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(BooleanLiteralExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(IntegerLiteralExpression const& expression, boost::any const& data) override; |
||||
|
virtual boost::any visit(RationalLiteralExpression const& expression, boost::any const& data) override; |
||||
|
|
||||
|
private: |
||||
|
}; |
||||
|
} |
||||
|
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue