Browse Source

debugging array elimination

tempestpy_adaptions
TimQu 6 years ago
parent
commit
ed45fa80e6
  1. 85
      src/storm/storage/jani/ArrayEliminator.cpp
  2. 4
      src/storm/storage/jani/TemplateEdge.cpp
  3. 1
      src/storm/storage/jani/TemplateEdge.h
  4. 4
      src/storm/storage/jani/expressions/JaniExpressionVisitor.h
  5. 102
      src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp
  6. 18
      src/storm/storage/jani/traverser/ArrayExpressionFinder.h

85
src/storm/storage/jani/ArrayEliminator.cpp

@ -5,6 +5,8 @@
#include "storm/storage/expressions/ExpressionVisitor.h"
#include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
#include "storm/storage/jani/Variable.h"
#include "storm/storage/jani/traverser/ArrayExpressionFinder.h"
#include "storm/storage/expressions/Expressions.h"
#include "storm/storage/jani/expressions/JaniExpressions.h"
#include "storm/storage/expressions/ExpressionManager.h"
@ -118,20 +120,10 @@ namespace storm {
ArrayExpressionEliminationVisitor(std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements, std::unordered_map<storm::expressions::Variable, std::size_t> const& sizes) : replacements(replacements), arraySizes(sizes) {}
virtual ~ArrayExpressionEliminationVisitor() = default;
storm::expressions::Expression eliminate(storm::expressions::Expression const& expression, storm::expressions::Expression const& outOfBoundsExpression) {
auto res = eliminate(expression, false);
if (outOfBoundsError) {
return outOfBoundsExpression;
} else {
return res;
}
}
storm::expressions::Expression eliminate(storm::expressions::Expression const& expression, bool failIfOutOfBounds = true) {
outOfBoundsError = false;
storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) {
// here, data is the accessed index of the most recent array access expression. Initially, there is none.
auto res = storm::expressions::Expression(boost::any_cast<BaseExprPtr>(expression.accept(*this, boost::any())));
STORM_LOG_THROW(!failIfOutOfBounds || !outOfBoundsError, storm::exceptions::UnexpectedException, "Out of bounds array access occured while eliminating expression " << expression);
STORM_LOG_ASSERT(!containsArrayExpression(res), "Expression still contains array expressions. Before: " << std::endl << expression << std::endl << "After:" << std::endl << res);
return res.simplify();
}
@ -242,12 +234,8 @@ namespace storm {
STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression.");
uint64_t index = boost::any_cast<uint64_t>(data);
STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
if (index < static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
return expression.at(index);
} else {
outOfBoundsError = true;
return expression.at(0);
}
STORM_LOG_THROW(index < static_cast<uint64_t>(expression.size()->evaluateAsInt()), storm::exceptions::UnexpectedException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression);
return expression.at(index);
}
virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
@ -255,8 +243,8 @@ namespace storm {
uint64_t index = boost::any_cast<uint64_t>(data);
if (expression.size()->containsVariables()) {
STORM_LOG_WARN("Ignoring length of constructorArrayExpression " << expression << " as it still contains variables.");
} else if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
outOfBoundsError = true;
} else {
STORM_LOG_THROW(index < static_cast<uint64_t>(expression.size()->evaluateAsInt()), storm::exceptions::UnexpectedException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression);
}
return expression.at(index);
}
@ -270,22 +258,22 @@ namespace storm {
storm::expressions::Expression result = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index))->toExpression();
while (index > 0) {
--index;
result = storm::expressions::ite(
expression.getSecondOperand()->toExpression() == expression.getManager().integer(index),
storm::expressions::Expression isCurrentIndex = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, boost::any()))->toExpression() == expression.getManager().integer(index);
result = storm::expressions::ite(isCurrentIndex,
boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index))->toExpression(),
result);
}
return result.getBaseExpressionPointer();
} else {
uint64_t index = expression.getSecondOperand()->evaluateAsInt();
return boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index));
auto result = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index));
return result;
}
}
private:
std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements;
std::unordered_map<storm::expressions::Variable, std::size_t> const& arraySizes;
bool outOfBoundsError;
};
class MaxArraySizeDeterminer : public ConstJaniTraverser {
@ -404,6 +392,28 @@ namespace storm {
automaton.setInitialStatesRestriction(arrayExprEliminator->eliminate(automaton.getInitialStatesRestriction()));
}
}
void traverse(TemplateEdge& templateEdge, boost::any const& data) override {
templateEdge.setGuard(arrayExprEliminator->eliminate(templateEdge.getGuard()));
for (auto& dest : templateEdge.getDestinations()) {
JaniTraverser::traverse(dest, data);
}
traverse(templateEdge.getAssignments(), data);
}
void traverse(Edge& edge, boost::any const& data) override {
if (edge.hasRate()) {
edge.setRate(arrayExprEliminator->eliminate(edge.getRate()));
}
for (auto& dest : edge.getDestinations()) {
JaniTraverser::traverse(dest, data);
}
}
void traverse(EdgeDestination& edgeDestination, boost::any const& data) override {
edgeDestination.setProbability(arrayExprEliminator->eliminate(edgeDestination.getProbability()));
}
virtual void traverse(OrderedAssignments& orderedAssignments, boost::any const& data) override {
auto const& replacements = boost::any_cast<ResultType*>(data)->replacements;
@ -428,19 +438,31 @@ namespace storm {
insertionRes.first->second.push_back(&assignment);
}
continue;
} else {
// Keeping array access LValue
LValue newLValue(LValue(assignment.getLValue().getArray()), arrayExprEliminator->eliminate(assignment.getLValue().getArrayIndex()));
newAssignments.emplace_back(newLValue, arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
}
} else if (assignment.getLValue().isVariable() && assignment.getVariable().isArrayVariable()) {
STORM_LOG_ASSERT(assignment.getAssignedExpression().getType().isArrayType(), "Assigning a non-array expression to an array variable...");
std::vector<storm::jani::Variable const*> const& arrayVariableReplacements = replacements.at(assignment.getExpressionVariable());
// Get the maximum size of the array expression on the rhs
uint64_t rhsSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), arraySizes);
STORM_LOG_ASSERT(arrayVariableReplacements.size() >= rhsSize, "Array size too small.");
for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) {
auto const& replacement = *arrayVariableReplacements[index];
auto arrayAccessExpression = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
arrayAccessExpression = arrayExprEliminator->eliminate(arrayAccessExpression, getOutOfBoundsValue(replacement));
newAssignments.emplace_back(LValue(replacement), arrayAccessExpression, level);
storm::expressions::Expression newRhs;
if (index < rhsSize) {
newRhs = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
} else {
newRhs = getOutOfBoundsValue(replacement);
}
newRhs = arrayExprEliminator->eliminate(newRhs);
newAssignments.emplace_back(LValue(replacement), newRhs, level);
}
continue;
} else {
newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
}
newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
}
for (auto const& arrayAssignments : collectedArrayAccessAssignments) {
insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments);
@ -513,9 +535,9 @@ namespace storm {
storm::expressions::Expression assignedExpression = arrayVariableReplacements[index]->getExpressionVariable().getExpression();
auto indexExpression = expressionManager.integer(index);
for (auto const& aa : arrayAccesses) {
assignedExpression = storm::expressions::ite(aa->getLValue().getArrayIndex() == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression);
newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level);
assignedExpression = storm::expressions::ite(arrayExprEliminator->eliminate(aa->getLValue().getArrayIndex()) == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression);
}
newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level);
}
}
}
@ -552,6 +574,7 @@ namespace storm {
ArrayEliminatorData result = detail::ArrayVariableReplacer(model.getExpressionManager(), keepNonTrivialArrayAccess, sizes).replace(model);
model.finalize();
STORM_LOG_ASSERT(!containsArrayExpression(model), "the model still contains array expressions.");
return result;
}
}

4
src/storm/storage/jani/TemplateEdge.cpp

@ -56,6 +56,10 @@ namespace storm {
return guard;
}
void TemplateEdge::setGuard(storm::expressions::Expression const& newGuard) {
guard = newGuard;
}
std::size_t TemplateEdge::getNumberOfDestinations() const {
return destinations.size();
}

1
src/storm/storage/jani/TemplateEdge.h

@ -21,6 +21,7 @@ namespace storm {
TemplateEdge(storm::expressions::Expression const& guard, OrderedAssignments const& assignments, std::vector<TemplateEdgeDestination> const& destinations);
storm::expressions::Expression const& getGuard() const;
void setGuard(storm::expressions::Expression const& newGuard);
void addDestination(TemplateEdgeDestination const& destination);

4
src/storm/storage/jani/expressions/JaniExpressionVisitor.h

@ -1,12 +1,12 @@
#pragma once
#include "storm/storage/expressions/SubstitutionVisitor.h"
#include "storm/storage/expressions/ExpressionVisitor.h"
#include "storm/storage/jani/expressions/JaniExpressions.h"
namespace storm {
namespace expressions {
class JaniExpressionVisitor{
class JaniExpressionVisitor {
public:
virtual boost::any visit(ValueArrayExpression const& expression, boost::any const& data) = 0;
virtual boost::any visit(ConstructorArrayExpression const& expression, boost::any const& data) = 0;

102
src/storm/storage/jani/traverser/ArrayExpressionFinder.cpp

@ -0,0 +1,102 @@
#include "storm/storage/jani/traverser/ArrayExpressionFinder.h"
#include "storm/storage/jani/traverser/JaniTraverser.h"
#include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
#include "storm/storage/jani/expressions/JaniExpressions.h"
#include "storm/storage/jani/Model.h"
namespace storm {
namespace jani {
namespace detail {
class ArrayExpressionFinderExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
public:
virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
return
boost::any_cast<bool>(expression.getCondition()->accept(*this, data)) ||
boost::any_cast<bool>(expression.getThenExpression()->accept(*this, data)) ||
boost::any_cast<bool>(expression.getElseExpression()->accept(*this, data));
}
virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
return
boost::any_cast<bool>(expression.getFirstOperand()->accept(*this, data)) ||
boost::any_cast<bool>(expression.getSecondOperand()->accept(*this, data));
}
virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
return
boost::any_cast<bool>(expression.getFirstOperand()->accept(*this, data)) ||
boost::any_cast<bool>(expression.getSecondOperand()->accept(*this, data));
}
virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
return
boost::any_cast<bool>(expression.getFirstOperand()->accept(*this, data)) ||
boost::any_cast<bool>(expression.getSecondOperand()->accept(*this, data));
}
virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const&) override {
return false;
}
virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
return expression.getOperand()->accept(*this, data);
}
virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
return expression.getOperand()->accept(*this, data);
}
virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override {
return false;
}
virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override {
return false;
}
virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override {
return false;
}
virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override {
return true;
}
virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
return true;
}
virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override {
return true;
}
};
class ArrayExpressionFinderTraverser : public ConstJaniTraverser {
public:
virtual void traverse(Model const& model, boost::any const& data) override {
ConstJaniTraverser::traverse(model, data);
}
virtual void traverse(storm::expressions::Expression const& expression, boost::any const& data) override {
auto& res = *boost::any_cast<bool*>(data);
res = res || containsArrayExpression(expression);
}
};
}
bool containsArrayExpression(Model const& model) {
bool result = false;
detail::ArrayExpressionFinderTraverser().traverse(model, &result);
return result;
}
bool containsArrayExpression(storm::expressions::Expression const& expression) {
detail::ArrayExpressionFinderExpressionVisitor v;
return boost::any_cast<bool>(expression.accept(v, boost::none));
}
}
}

18
src/storm/storage/jani/traverser/ArrayExpressionFinder.h

@ -0,0 +1,18 @@
#pragma once
namespace storm {
namespace expressions {
class Expression;
}
namespace jani {
class Model;
bool containsArrayExpression(Model const& model);
bool containsArrayExpression(storm::expressions::Expression const& expr);
}
}
Loading…
Cancel
Save