From ed45fa80e648b58988f78f23865cdf5e881a3721 Mon Sep 17 00:00:00 2001
From: TimQu <tim.quatmann@cs.rwth-aachen.de>
Date: Tue, 11 Sep 2018 18:00:51 +0200
Subject: [PATCH] debugging array elimination

---
 src/storm/storage/jani/ArrayEliminator.cpp    |  85 +++++++++------
 src/storm/storage/jani/TemplateEdge.cpp       |   4 +
 src/storm/storage/jani/TemplateEdge.h         |   1 +
 .../jani/expressions/JaniExpressionVisitor.h  |   4 +-
 .../jani/traverser/ArrayExpressionFinder.cpp  | 102 ++++++++++++++++++
 .../jani/traverser/ArrayExpressionFinder.h    |  18 ++++
 6 files changed, 181 insertions(+), 33 deletions(-)
 create mode 100644 src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp
 create mode 100644 src/storm/storage/jani/traverser/ArrayExpressionFinder.h

diff --git a/src/storm/storage/jani/ArrayEliminator.cpp b/src/storm/storage/jani/ArrayEliminator.cpp
index fb50de59f..114e8ef3a 100644
--- a/src/storm/storage/jani/ArrayEliminator.cpp
+++ b/src/storm/storage/jani/ArrayEliminator.cpp
@@ -5,6 +5,8 @@
 #include "storm/storage/expressions/ExpressionVisitor.h"
 #include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
 #include "storm/storage/jani/Variable.h"
+#include "storm/storage/jani/traverser/ArrayExpressionFinder.h"
+
 #include "storm/storage/expressions/Expressions.h"
 #include "storm/storage/jani/expressions/JaniExpressions.h"
 #include "storm/storage/expressions/ExpressionManager.h"
@@ -118,20 +120,10 @@ namespace storm {
                 ArrayExpressionEliminationVisitor(std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements, std::unordered_map<storm::expressions::Variable, std::size_t> const& sizes) : replacements(replacements), arraySizes(sizes) {}
                 virtual ~ArrayExpressionEliminationVisitor() = default;
     
-                storm::expressions::Expression eliminate(storm::expressions::Expression const& expression, storm::expressions::Expression const& outOfBoundsExpression)  {
-                    auto res = eliminate(expression, false);
-                    if (outOfBoundsError) {
-                        return outOfBoundsExpression;
-                    } else {
-                        return res;
-                    }
-                }
-                
-                storm::expressions::Expression eliminate(storm::expressions::Expression const& expression, bool failIfOutOfBounds = true) {
-                    outOfBoundsError = false;
+                storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) {
                     // here, data is the accessed index of the most recent array access expression. Initially, there is none.
                     auto res = storm::expressions::Expression(boost::any_cast<BaseExprPtr>(expression.accept(*this, boost::any())));
-                    STORM_LOG_THROW(!failIfOutOfBounds || !outOfBoundsError, storm::exceptions::UnexpectedException, "Out of bounds array access occured while eliminating expression " << expression);
+                    STORM_LOG_ASSERT(!containsArrayExpression(res), "Expression still contains array expressions. Before: " << std::endl << expression << std::endl << "After:" << std::endl << res);
                     return res.simplify();
                 }
                 
@@ -242,12 +234,8 @@ namespace storm {
                     STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression.");
                     uint64_t index = boost::any_cast<uint64_t>(data);
                     STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
-                    if (index < static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
-                        return expression.at(index);
-                    } else {
-                        outOfBoundsError = true;
-                        return expression.at(0);
-                    }
+                    STORM_LOG_THROW(index < static_cast<uint64_t>(expression.size()->evaluateAsInt()), storm::exceptions::UnexpectedException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression);
+                    return expression.at(index);
                 }
                 
                 virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
@@ -255,8 +243,8 @@ namespace storm {
                     uint64_t index = boost::any_cast<uint64_t>(data);
                     if (expression.size()->containsVariables()) {
                         STORM_LOG_WARN("Ignoring length of constructorArrayExpression " << expression << " as it still contains variables.");
-                    } else if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
-                        outOfBoundsError = true;
+                    } else {
+                        STORM_LOG_THROW(index < static_cast<uint64_t>(expression.size()->evaluateAsInt()), storm::exceptions::UnexpectedException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression);
                     }
                     return expression.at(index);
                 }
@@ -270,22 +258,22 @@ namespace storm {
                         storm::expressions::Expression result = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index))->toExpression();
                         while (index > 0) {
                             --index;
-                            result = storm::expressions::ite(
-                                    expression.getSecondOperand()->toExpression() == expression.getManager().integer(index),
+                            storm::expressions::Expression isCurrentIndex = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, boost::any()))->toExpression() == expression.getManager().integer(index);
+                            result = storm::expressions::ite(isCurrentIndex,
                                     boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index))->toExpression(),
                                     result);
                         }
                         return result.getBaseExpressionPointer();
                     } else {
                         uint64_t index = expression.getSecondOperand()->evaluateAsInt();
-                        return boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index));
+                        auto result = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index));
+                        return result;
                     }
                 }
                 
             private:
                 std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements;
                 std::unordered_map<storm::expressions::Variable, std::size_t> const& arraySizes;
-                bool outOfBoundsError;
             };
             
             class MaxArraySizeDeterminer : public ConstJaniTraverser {
@@ -404,6 +392,28 @@ namespace storm {
                         automaton.setInitialStatesRestriction(arrayExprEliminator->eliminate(automaton.getInitialStatesRestriction()));
                     }
                 }
+         
+                void traverse(TemplateEdge& templateEdge, boost::any const& data) override {
+                    templateEdge.setGuard(arrayExprEliminator->eliminate(templateEdge.getGuard()));
+                    for (auto& dest : templateEdge.getDestinations()) {
+                        JaniTraverser::traverse(dest, data);
+                    }
+                    traverse(templateEdge.getAssignments(), data);
+                }
+
+                
+                void traverse(Edge& edge, boost::any const& data) override {
+                    if (edge.hasRate()) {
+                        edge.setRate(arrayExprEliminator->eliminate(edge.getRate()));
+                    }
+                    for (auto& dest : edge.getDestinations()) {
+                        JaniTraverser::traverse(dest, data);
+                    }
+                }
+                
+                void traverse(EdgeDestination& edgeDestination, boost::any const& data) override {
+                    edgeDestination.setProbability(arrayExprEliminator->eliminate(edgeDestination.getProbability()));
+                }
                 
                 virtual void traverse(OrderedAssignments& orderedAssignments, boost::any const& data) override {
                     auto const& replacements = boost::any_cast<ResultType*>(data)->replacements;
@@ -428,19 +438,31 @@ namespace storm {
                                     insertionRes.first->second.push_back(&assignment);
                                 }
                                 continue;
+                            } else {
+                                // Keeping array access LValue
+                                LValue newLValue(LValue(assignment.getLValue().getArray()), arrayExprEliminator->eliminate(assignment.getLValue().getArrayIndex()));
+                                newAssignments.emplace_back(newLValue, arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
                             }
                         } else if (assignment.getLValue().isVariable() && assignment.getVariable().isArrayVariable()) {
                             STORM_LOG_ASSERT(assignment.getAssignedExpression().getType().isArrayType(), "Assigning a non-array expression to an array variable...");
                             std::vector<storm::jani::Variable const*> const& arrayVariableReplacements = replacements.at(assignment.getExpressionVariable());
+                            // Get the maximum size of the array expression on the rhs
+                            uint64_t rhsSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), arraySizes);
+                            STORM_LOG_ASSERT(arrayVariableReplacements.size() >= rhsSize, "Array size too small.");
                             for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) {
                                 auto const& replacement = *arrayVariableReplacements[index];
-                                auto arrayAccessExpression = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
-                                arrayAccessExpression = arrayExprEliminator->eliminate(arrayAccessExpression, getOutOfBoundsValue(replacement));
-                                newAssignments.emplace_back(LValue(replacement), arrayAccessExpression, level);
+                                storm::expressions::Expression newRhs;
+                                if (index < rhsSize) {
+                                    newRhs = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
+                                } else {
+                                    newRhs = getOutOfBoundsValue(replacement);
+                                }
+                                newRhs = arrayExprEliminator->eliminate(newRhs);
+                                newAssignments.emplace_back(LValue(replacement), newRhs, level);
                             }
-                            continue;
+                        } else {
+                            newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
                         }
-                        newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
                     }
                     for (auto const& arrayAssignments : collectedArrayAccessAssignments) {
                         insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments);
@@ -513,9 +535,9 @@ namespace storm {
                             storm::expressions::Expression assignedExpression = arrayVariableReplacements[index]->getExpressionVariable().getExpression();
                             auto indexExpression = expressionManager.integer(index);
                             for (auto const& aa : arrayAccesses) {
-                                assignedExpression = storm::expressions::ite(aa->getLValue().getArrayIndex() == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression);
-                                newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level);
+                                assignedExpression = storm::expressions::ite(arrayExprEliminator->eliminate(aa->getLValue().getArrayIndex()) == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression);
                             }
+                            newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level);
                         }
                     }
                 }
@@ -552,6 +574,7 @@ namespace storm {
             ArrayEliminatorData result = detail::ArrayVariableReplacer(model.getExpressionManager(), keepNonTrivialArrayAccess, sizes).replace(model);
             
             model.finalize();
+            STORM_LOG_ASSERT(!containsArrayExpression(model), "the model still contains array expressions.");
             return result;
         }
     }
diff --git a/src/storm/storage/jani/TemplateEdge.cpp b/src/storm/storage/jani/TemplateEdge.cpp
index 79644f8a6..a1ea3f7b5 100644
--- a/src/storm/storage/jani/TemplateEdge.cpp
+++ b/src/storm/storage/jani/TemplateEdge.cpp
@@ -56,6 +56,10 @@ namespace storm {
             return guard;
         }
         
+        void TemplateEdge::setGuard(storm::expressions::Expression const& newGuard) {
+            guard = newGuard;
+        }
+        
         std::size_t TemplateEdge::getNumberOfDestinations() const {
             return destinations.size();
         }
diff --git a/src/storm/storage/jani/TemplateEdge.h b/src/storm/storage/jani/TemplateEdge.h
index 4e0d2a027..0928c50e0 100644
--- a/src/storm/storage/jani/TemplateEdge.h
+++ b/src/storm/storage/jani/TemplateEdge.h
@@ -21,6 +21,7 @@ namespace storm {
             TemplateEdge(storm::expressions::Expression const& guard, OrderedAssignments const& assignments, std::vector<TemplateEdgeDestination> const& destinations);
 
             storm::expressions::Expression const& getGuard() const;
+            void setGuard(storm::expressions::Expression const& newGuard);
 
             void addDestination(TemplateEdgeDestination const& destination);
             
diff --git a/src/storm/storage/jani/expressions/JaniExpressionVisitor.h b/src/storm/storage/jani/expressions/JaniExpressionVisitor.h
index e0e76e95e..61808bb0d 100644
--- a/src/storm/storage/jani/expressions/JaniExpressionVisitor.h
+++ b/src/storm/storage/jani/expressions/JaniExpressionVisitor.h
@@ -1,12 +1,12 @@
 #pragma once
 
-#include "storm/storage/expressions/SubstitutionVisitor.h"
+#include "storm/storage/expressions/ExpressionVisitor.h"
 #include "storm/storage/jani/expressions/JaniExpressions.h"
 
 
 namespace storm {
     namespace expressions {
-        class JaniExpressionVisitor{
+        class JaniExpressionVisitor {
         public:
             virtual boost::any visit(ValueArrayExpression const& expression, boost::any const& data) = 0;
             virtual boost::any visit(ConstructorArrayExpression const& expression, boost::any const& data) = 0;
diff --git a/src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp b/src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp
new file mode 100644
index 000000000..f9c80a36a
--- /dev/null
+++ b/src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp
@@ -0,0 +1,102 @@
+#include "storm/storage/jani/traverser/ArrayExpressionFinder.h"
+
+#include "storm/storage/jani/traverser/JaniTraverser.h"
+#include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
+#include "storm/storage/jani/expressions/JaniExpressions.h"
+#include "storm/storage/jani/Model.h"
+
+namespace storm {
+    namespace jani {
+        
+        namespace detail {
+            class ArrayExpressionFinderExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
+            public:
+                virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
+                    return
+                        boost::any_cast<bool>(expression.getCondition()->accept(*this, data)) ||
+                        boost::any_cast<bool>(expression.getThenExpression()->accept(*this, data)) ||
+                        boost::any_cast<bool>(expression.getElseExpression()->accept(*this, data));
+                }
+                
+                virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
+                    return
+                        boost::any_cast<bool>(expression.getFirstOperand()->accept(*this, data)) ||
+                        boost::any_cast<bool>(expression.getSecondOperand()->accept(*this, data));
+                }
+                
+                virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
+                    return
+                        boost::any_cast<bool>(expression.getFirstOperand()->accept(*this, data)) ||
+                        boost::any_cast<bool>(expression.getSecondOperand()->accept(*this, data));
+                }
+                
+                virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
+                    return
+                        boost::any_cast<bool>(expression.getFirstOperand()->accept(*this, data)) ||
+                        boost::any_cast<bool>(expression.getSecondOperand()->accept(*this, data));
+                }
+                
+                virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const&) override {
+                    return false;
+                }
+                
+                virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
+                    return expression.getOperand()->accept(*this, data);
+                }
+                
+                virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
+                    return expression.getOperand()->accept(*this, data);
+                }
+                
+                virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override {
+                    return false;
+                }
+                
+                virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override {
+                    return false;
+                }
+                
+                virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override {
+                    return false;
+                }
+                
+                virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override {
+                    return true;
+                }
+                
+                virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
+                    return true;
+                }
+                
+                virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override {
+                    return true;
+                }
+            };
+            
+            class ArrayExpressionFinderTraverser : public ConstJaniTraverser {
+            public:
+                virtual void traverse(Model const& model, boost::any const& data) override {
+                    ConstJaniTraverser::traverse(model, data);
+                }
+                
+                virtual void traverse(storm::expressions::Expression const& expression, boost::any const& data) override {
+                    auto& res = *boost::any_cast<bool*>(data);
+                    res = res || containsArrayExpression(expression);
+                }
+            };
+        }
+        
+        
+        bool containsArrayExpression(Model const& model) {
+            bool result = false;
+            detail::ArrayExpressionFinderTraverser().traverse(model, &result);
+            return result;
+        }
+        
+        bool containsArrayExpression(storm::expressions::Expression const& expression) {
+            detail::ArrayExpressionFinderExpressionVisitor v;
+            return boost::any_cast<bool>(expression.accept(v, boost::none));
+        }
+    }
+}
+
diff --git a/src/storm/storage/jani/traverser/ArrayExpressionFinder.h b/src/storm/storage/jani/traverser/ArrayExpressionFinder.h
new file mode 100644
index 000000000..c924472cd
--- /dev/null
+++ b/src/storm/storage/jani/traverser/ArrayExpressionFinder.h
@@ -0,0 +1,18 @@
+#pragma once
+
+
+namespace storm {
+    
+    namespace expressions {
+        class Expression;
+    }
+    
+    namespace jani {
+        
+        class Model;
+        
+        bool containsArrayExpression(Model const& model);
+        bool containsArrayExpression(storm::expressions::Expression const& expr);
+    }
+}
+