From ed45fa80e648b58988f78f23865cdf5e881a3721 Mon Sep 17 00:00:00 2001 From: TimQu Date: Tue, 11 Sep 2018 18:00:51 +0200 Subject: [PATCH] debugging array elimination --- src/storm/storage/jani/ArrayEliminator.cpp | 85 +++++++++------ src/storm/storage/jani/TemplateEdge.cpp | 4 + src/storm/storage/jani/TemplateEdge.h | 1 + .../jani/expressions/JaniExpressionVisitor.h | 4 +- .../jani/traverser/ArrayExpressionFinder.cpp | 102 ++++++++++++++++++ .../jani/traverser/ArrayExpressionFinder.h | 18 ++++ 6 files changed, 181 insertions(+), 33 deletions(-) create mode 100644 src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp create mode 100644 src/storm/storage/jani/traverser/ArrayExpressionFinder.h diff --git a/src/storm/storage/jani/ArrayEliminator.cpp b/src/storm/storage/jani/ArrayEliminator.cpp index fb50de59f..114e8ef3a 100644 --- a/src/storm/storage/jani/ArrayEliminator.cpp +++ b/src/storm/storage/jani/ArrayEliminator.cpp @@ -5,6 +5,8 @@ #include "storm/storage/expressions/ExpressionVisitor.h" #include "storm/storage/jani/expressions/JaniExpressionVisitor.h" #include "storm/storage/jani/Variable.h" +#include "storm/storage/jani/traverser/ArrayExpressionFinder.h" + #include "storm/storage/expressions/Expressions.h" #include "storm/storage/jani/expressions/JaniExpressions.h" #include "storm/storage/expressions/ExpressionManager.h" @@ -118,20 +120,10 @@ namespace storm { ArrayExpressionEliminationVisitor(std::unordered_map> const& replacements, std::unordered_map const& sizes) : replacements(replacements), arraySizes(sizes) {} virtual ~ArrayExpressionEliminationVisitor() = default; - storm::expressions::Expression eliminate(storm::expressions::Expression const& expression, storm::expressions::Expression const& outOfBoundsExpression) { - auto res = eliminate(expression, false); - if (outOfBoundsError) { - return outOfBoundsExpression; - } else { - return res; - } - } - - storm::expressions::Expression eliminate(storm::expressions::Expression const& expression, bool failIfOutOfBounds = true) { - outOfBoundsError = false; + storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) { // here, data is the accessed index of the most recent array access expression. Initially, there is none. auto res = storm::expressions::Expression(boost::any_cast(expression.accept(*this, boost::any()))); - STORM_LOG_THROW(!failIfOutOfBounds || !outOfBoundsError, storm::exceptions::UnexpectedException, "Out of bounds array access occured while eliminating expression " << expression); + STORM_LOG_ASSERT(!containsArrayExpression(res), "Expression still contains array expressions. Before: " << std::endl << expression << std::endl << "After:" << std::endl << res); return res.simplify(); } @@ -242,12 +234,8 @@ namespace storm { STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression."); uint64_t index = boost::any_cast(data); STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ")."); - if (index < static_cast(expression.size()->evaluateAsInt())) { - return expression.at(index); - } else { - outOfBoundsError = true; - return expression.at(0); - } + STORM_LOG_THROW(index < static_cast(expression.size()->evaluateAsInt()), storm::exceptions::UnexpectedException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression); + return expression.at(index); } virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override { @@ -255,8 +243,8 @@ namespace storm { uint64_t index = boost::any_cast(data); if (expression.size()->containsVariables()) { STORM_LOG_WARN("Ignoring length of constructorArrayExpression " << expression << " as it still contains variables."); - } else if (index >= static_cast(expression.size()->evaluateAsInt())) { - outOfBoundsError = true; + } else { + STORM_LOG_THROW(index < static_cast(expression.size()->evaluateAsInt()), storm::exceptions::UnexpectedException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression); } return expression.at(index); } @@ -270,22 +258,22 @@ namespace storm { storm::expressions::Expression result = boost::any_cast(expression.getFirstOperand()->accept(*this, index))->toExpression(); while (index > 0) { --index; - result = storm::expressions::ite( - expression.getSecondOperand()->toExpression() == expression.getManager().integer(index), + storm::expressions::Expression isCurrentIndex = boost::any_cast(expression.getSecondOperand()->accept(*this, boost::any()))->toExpression() == expression.getManager().integer(index); + result = storm::expressions::ite(isCurrentIndex, boost::any_cast(expression.getFirstOperand()->accept(*this, index))->toExpression(), result); } return result.getBaseExpressionPointer(); } else { uint64_t index = expression.getSecondOperand()->evaluateAsInt(); - return boost::any_cast(expression.getFirstOperand()->accept(*this, index)); + auto result = boost::any_cast(expression.getFirstOperand()->accept(*this, index)); + return result; } } private: std::unordered_map> const& replacements; std::unordered_map const& arraySizes; - bool outOfBoundsError; }; class MaxArraySizeDeterminer : public ConstJaniTraverser { @@ -404,6 +392,28 @@ namespace storm { automaton.setInitialStatesRestriction(arrayExprEliminator->eliminate(automaton.getInitialStatesRestriction())); } } + + void traverse(TemplateEdge& templateEdge, boost::any const& data) override { + templateEdge.setGuard(arrayExprEliminator->eliminate(templateEdge.getGuard())); + for (auto& dest : templateEdge.getDestinations()) { + JaniTraverser::traverse(dest, data); + } + traverse(templateEdge.getAssignments(), data); + } + + + void traverse(Edge& edge, boost::any const& data) override { + if (edge.hasRate()) { + edge.setRate(arrayExprEliminator->eliminate(edge.getRate())); + } + for (auto& dest : edge.getDestinations()) { + JaniTraverser::traverse(dest, data); + } + } + + void traverse(EdgeDestination& edgeDestination, boost::any const& data) override { + edgeDestination.setProbability(arrayExprEliminator->eliminate(edgeDestination.getProbability())); + } virtual void traverse(OrderedAssignments& orderedAssignments, boost::any const& data) override { auto const& replacements = boost::any_cast(data)->replacements; @@ -428,19 +438,31 @@ namespace storm { insertionRes.first->second.push_back(&assignment); } continue; + } else { + // Keeping array access LValue + LValue newLValue(LValue(assignment.getLValue().getArray()), arrayExprEliminator->eliminate(assignment.getLValue().getArrayIndex())); + newAssignments.emplace_back(newLValue, arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel()); } } else if (assignment.getLValue().isVariable() && assignment.getVariable().isArrayVariable()) { STORM_LOG_ASSERT(assignment.getAssignedExpression().getType().isArrayType(), "Assigning a non-array expression to an array variable..."); std::vector const& arrayVariableReplacements = replacements.at(assignment.getExpressionVariable()); + // Get the maximum size of the array expression on the rhs + uint64_t rhsSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), arraySizes); + STORM_LOG_ASSERT(arrayVariableReplacements.size() >= rhsSize, "Array size too small."); for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) { auto const& replacement = *arrayVariableReplacements[index]; - auto arrayAccessExpression = std::make_shared(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression(); - arrayAccessExpression = arrayExprEliminator->eliminate(arrayAccessExpression, getOutOfBoundsValue(replacement)); - newAssignments.emplace_back(LValue(replacement), arrayAccessExpression, level); + storm::expressions::Expression newRhs; + if (index < rhsSize) { + newRhs = std::make_shared(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression(); + } else { + newRhs = getOutOfBoundsValue(replacement); + } + newRhs = arrayExprEliminator->eliminate(newRhs); + newAssignments.emplace_back(LValue(replacement), newRhs, level); } - continue; + } else { + newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel()); } - newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel()); } for (auto const& arrayAssignments : collectedArrayAccessAssignments) { insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments); @@ -513,9 +535,9 @@ namespace storm { storm::expressions::Expression assignedExpression = arrayVariableReplacements[index]->getExpressionVariable().getExpression(); auto indexExpression = expressionManager.integer(index); for (auto const& aa : arrayAccesses) { - assignedExpression = storm::expressions::ite(aa->getLValue().getArrayIndex() == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression); - newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level); + assignedExpression = storm::expressions::ite(arrayExprEliminator->eliminate(aa->getLValue().getArrayIndex()) == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression); } + newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level); } } } @@ -552,6 +574,7 @@ namespace storm { ArrayEliminatorData result = detail::ArrayVariableReplacer(model.getExpressionManager(), keepNonTrivialArrayAccess, sizes).replace(model); model.finalize(); + STORM_LOG_ASSERT(!containsArrayExpression(model), "the model still contains array expressions."); return result; } } diff --git a/src/storm/storage/jani/TemplateEdge.cpp b/src/storm/storage/jani/TemplateEdge.cpp index 79644f8a6..a1ea3f7b5 100644 --- a/src/storm/storage/jani/TemplateEdge.cpp +++ b/src/storm/storage/jani/TemplateEdge.cpp @@ -56,6 +56,10 @@ namespace storm { return guard; } + void TemplateEdge::setGuard(storm::expressions::Expression const& newGuard) { + guard = newGuard; + } + std::size_t TemplateEdge::getNumberOfDestinations() const { return destinations.size(); } diff --git a/src/storm/storage/jani/TemplateEdge.h b/src/storm/storage/jani/TemplateEdge.h index 4e0d2a027..0928c50e0 100644 --- a/src/storm/storage/jani/TemplateEdge.h +++ b/src/storm/storage/jani/TemplateEdge.h @@ -21,6 +21,7 @@ namespace storm { TemplateEdge(storm::expressions::Expression const& guard, OrderedAssignments const& assignments, std::vector const& destinations); storm::expressions::Expression const& getGuard() const; + void setGuard(storm::expressions::Expression const& newGuard); void addDestination(TemplateEdgeDestination const& destination); diff --git a/src/storm/storage/jani/expressions/JaniExpressionVisitor.h b/src/storm/storage/jani/expressions/JaniExpressionVisitor.h index e0e76e95e..61808bb0d 100644 --- a/src/storm/storage/jani/expressions/JaniExpressionVisitor.h +++ b/src/storm/storage/jani/expressions/JaniExpressionVisitor.h @@ -1,12 +1,12 @@ #pragma once -#include "storm/storage/expressions/SubstitutionVisitor.h" +#include "storm/storage/expressions/ExpressionVisitor.h" #include "storm/storage/jani/expressions/JaniExpressions.h" namespace storm { namespace expressions { - class JaniExpressionVisitor{ + class JaniExpressionVisitor { public: virtual boost::any visit(ValueArrayExpression const& expression, boost::any const& data) = 0; virtual boost::any visit(ConstructorArrayExpression const& expression, boost::any const& data) = 0; diff --git a/src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp b/src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp new file mode 100644 index 000000000..f9c80a36a --- /dev/null +++ b/src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp @@ -0,0 +1,102 @@ +#include "storm/storage/jani/traverser/ArrayExpressionFinder.h" + +#include "storm/storage/jani/traverser/JaniTraverser.h" +#include "storm/storage/jani/expressions/JaniExpressionVisitor.h" +#include "storm/storage/jani/expressions/JaniExpressions.h" +#include "storm/storage/jani/Model.h" + +namespace storm { + namespace jani { + + namespace detail { + class ArrayExpressionFinderExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor { + public: + virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override { + return + boost::any_cast(expression.getCondition()->accept(*this, data)) || + boost::any_cast(expression.getThenExpression()->accept(*this, data)) || + boost::any_cast(expression.getElseExpression()->accept(*this, data)); + } + + virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override { + return + boost::any_cast(expression.getFirstOperand()->accept(*this, data)) || + boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + } + + virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override { + return + boost::any_cast(expression.getFirstOperand()->accept(*this, data)) || + boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + } + + virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override { + return + boost::any_cast(expression.getFirstOperand()->accept(*this, data)) || + boost::any_cast(expression.getSecondOperand()->accept(*this, data)); + } + + virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const&) override { + return false; + } + + virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override { + return expression.getOperand()->accept(*this, data); + } + + virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override { + return expression.getOperand()->accept(*this, data); + } + + virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override { + return false; + } + + virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override { + return false; + } + + virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override { + return false; + } + + virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override { + return true; + } + + virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override { + return true; + } + + virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override { + return true; + } + }; + + class ArrayExpressionFinderTraverser : public ConstJaniTraverser { + public: + virtual void traverse(Model const& model, boost::any const& data) override { + ConstJaniTraverser::traverse(model, data); + } + + virtual void traverse(storm::expressions::Expression const& expression, boost::any const& data) override { + auto& res = *boost::any_cast(data); + res = res || containsArrayExpression(expression); + } + }; + } + + + bool containsArrayExpression(Model const& model) { + bool result = false; + detail::ArrayExpressionFinderTraverser().traverse(model, &result); + return result; + } + + bool containsArrayExpression(storm::expressions::Expression const& expression) { + detail::ArrayExpressionFinderExpressionVisitor v; + return boost::any_cast(expression.accept(v, boost::none)); + } + } +} + diff --git a/src/storm/storage/jani/traverser/ArrayExpressionFinder.h b/src/storm/storage/jani/traverser/ArrayExpressionFinder.h new file mode 100644 index 000000000..c924472cd --- /dev/null +++ b/src/storm/storage/jani/traverser/ArrayExpressionFinder.h @@ -0,0 +1,18 @@ +#pragma once + + +namespace storm { + + namespace expressions { + class Expression; + } + + namespace jani { + + class Model; + + bool containsArrayExpression(Model const& model); + bool containsArrayExpression(storm::expressions::Expression const& expr); + } +} +