Browse Source

better treatment of cases where array out of bounds accesses occurr

main
TimQu 7 years ago
parent
commit
bdeacf8669
  1. 280
      src/storm/storage/jani/ArrayEliminator.cpp

280
src/storm/storage/jani/ArrayEliminator.cpp

@ -128,94 +128,119 @@ namespace storm {
class ArrayExpressionEliminationVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor { class ArrayExpressionEliminationVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
public: public:
typedef std::shared_ptr<storm::expressions::BaseExpression const> BaseExprPtr; typedef std::shared_ptr<storm::expressions::BaseExpression const> BaseExprPtr;
class ResultType {
public:
ResultType(ResultType const& other) = default;
ResultType(BaseExprPtr expression) : expression(expression), arrayOutOfBoundsMessage("") {}
ResultType(std::string arrayOutOfBoundsMessage) : expression(nullptr), arrayOutOfBoundsMessage(arrayOutOfBoundsMessage) {}
BaseExprPtr& expr() {
STORM_LOG_ASSERT(!isArrayOutOfBounds(), "Tried to get the result expression, but " << arrayOutOfBoundsMessage);
return expression;
};
bool isArrayOutOfBounds() { return arrayOutOfBoundsMessage != ""; };
std::string const& outOfBoundsMessage() const { return arrayOutOfBoundsMessage; }
private:
BaseExprPtr expression;
std::string arrayOutOfBoundsMessage;
};
ArrayExpressionEliminationVisitor(std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements, std::unordered_map<storm::expressions::Variable, std::size_t> const& sizes) : replacements(replacements), arraySizes(sizes) {} ArrayExpressionEliminationVisitor(std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements, std::unordered_map<storm::expressions::Variable, std::size_t> const& sizes) : replacements(replacements), arraySizes(sizes) {}
virtual ~ArrayExpressionEliminationVisitor() = default; virtual ~ArrayExpressionEliminationVisitor() = default;
storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) { storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) {
// here, data is the accessed index of the most recent array access expression. Initially, there is none. // here, data is the accessed index of the most recent array access expression. Initially, there is none.
auto res = storm::expressions::Expression(boost::any_cast<BaseExprPtr>(expression.accept(*this, boost::any()))); auto res = boost::any_cast<ResultType>(expression.accept(*this, boost::any()));
STORM_LOG_ASSERT(!containsArrayExpression(res), "Expression still contains array expressions. Before: " << std::endl << expression << std::endl << "After:" << std::endl << res); STORM_LOG_THROW(!res.isArrayOutOfBounds(), storm::exceptions::OutOfRangeException, res.outOfBoundsMessage());
return res.simplify(); STORM_LOG_ASSERT(!containsArrayExpression(res.expr()->toExpression()), "Expression still contains array expressions. Before: " << std::endl << expression << std::endl << "After:" << std::endl << res.expr()->toExpression());
return res.expr()->simplify();
} }
virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
BaseExprPtr thenExpression, elseExpression; // for the condition expression, outer array accesses should not matter.
ResultType conditionResult = boost::any_cast<ResultType>(expression.getCondition()->accept(*this, boost::any()));
if (conditionResult.isArrayOutOfBounds()) {
return conditionResult;
}
// We need to handle expressions of the kind '42<size : A[42] : 0', where size is a variable. // We need to handle expressions of the kind '42<size : A[42] : 0', where size is a variable and A[42] might be out of bounds.
// If an out of range exception occurrs in the 'then' or the 'else' branch, we assume that the this expression represents the other branch. ResultType thenResult = boost::any_cast<ResultType>(expression.getThenExpression()->accept(*this, data));
// TODO: Make this more reliable ResultType elseResult = boost::any_cast<ResultType>(expression.getElseExpression()->accept(*this, data));
bool thenOutOfRange(false), elseOutOfRange(false); if (thenResult.isArrayOutOfBounds()) {
try { if (elseResult.isArrayOutOfBounds()) {
thenExpression = boost::any_cast<BaseExprPtr>(expression.getThenExpression()->accept(*this, data)); return ResultType(thenResult.outOfBoundsMessage() + " and " + elseResult.outOfBoundsMessage());
} catch (storm::exceptions::OutOfRangeException const&) {
thenOutOfRange = true;
}
try {
elseExpression = boost::any_cast<BaseExprPtr>(expression.getElseExpression()->accept(*this, data));
} catch (storm::exceptions::OutOfRangeException const& e) {
if (thenOutOfRange) {
throw e;
} else { } else {
elseOutOfRange = true; // Assume the else expression
return elseResult;
} }
} } else if (elseResult.isArrayOutOfBounds()) {
// Assume the then expression
if (thenOutOfRange) { return thenResult;
assert(!elseOutOfRange);
return elseExpression;
} else if (elseOutOfRange) {
return thenExpression;
}
// for the condition expression, outer array accesses should not matter.
BaseExprPtr conditionExpression = boost::any_cast<BaseExprPtr>(expression.getCondition()->accept(*this, boost::any()));
// If the arguments did not change, we simply push the expression itself.
if (conditionExpression.get() == expression.getCondition().get() && thenExpression.get() == expression.getThenExpression().get() && elseExpression.get() == expression.getElseExpression().get()) {
return expression.getSharedPointer();
} else { } else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::IfThenElseExpression(expression.getManager(), thenExpression->getType(), conditionExpression, thenExpression, elseExpression))); // If the arguments did not change, we simply push the expression itself.
if (conditionResult.expr().get() == expression.getCondition().get() && thenResult.expr().get() == expression.getThenExpression().get() && elseResult.expr().get() == expression.getElseExpression().get()) {
return ResultType(expression.getSharedPointer());
} else {
return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::IfThenElseExpression(expression.getManager(), thenResult.expr()->getType(), conditionResult.expr(), thenResult.expr(), elseResult.expr()))));
}
} }
} }
virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(data.empty(), "BinaryBooleanFunctionExpressions should not be direct subexpressions of array access expressions. However, the expression " << expression << " is."); STORM_LOG_ASSERT(data.empty(), "BinaryBooleanFunctionExpressions should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data)); ResultType firstResult = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data)); ResultType secondResult = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, data));
if (firstResult.isArrayOutOfBounds()) {
return firstResult;
} else if (secondResult.isArrayOutOfBounds()) {
return secondResult;
}
// If the arguments did not change, we simply push the expression itself. // If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { if (firstResult.expr().get() == expression.getFirstOperand().get() && secondResult.expr().get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} else { } else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstResult.expr(), secondResult.expr(), expression.getOperatorType()))));
} }
} }
virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(data.empty(), "BinaryNumericalFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is."); STORM_LOG_ASSERT(data.empty(), "BinaryNumericalFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data)); ResultType firstResult = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data)); ResultType secondResult = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, data));
if (firstResult.isArrayOutOfBounds()) {
return firstResult;
} else if (secondResult.isArrayOutOfBounds()) {
return secondResult;
}
// If the arguments did not change, we simply push the expression itself. // If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { if (firstResult.expr().get() == expression.getFirstOperand().get() && secondResult.expr().get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} else { } else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getOperatorType()))); return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstResult.expr(), secondResult.expr(), expression.getOperatorType()))));
} }
} }
virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(data.empty(), "BinaryRelationExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is."); STORM_LOG_ASSERT(data.empty(), "BinaryRelationExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
BaseExprPtr firstExpression = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, data)); ResultType firstResult = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, data));
BaseExprPtr secondExpression = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, data)); ResultType secondResult = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, data));
if (firstResult.isArrayOutOfBounds()) {
return firstResult;
} else if (secondResult.isArrayOutOfBounds()) {
return secondResult;
}
// If the arguments did not change, we simply push the expression itself. // If the arguments did not change, we simply push the expression itself.
if (firstExpression.get() == expression.getFirstOperand().get() && secondExpression.get() == expression.getSecondOperand().get()) { if (firstResult.expr().get() == expression.getFirstOperand().get() && secondResult.expr().get() == expression.getSecondOperand().get()) {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} else { } else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryRelationExpression(expression.getManager(), expression.getType(), firstExpression, secondExpression, expression.getRelationType()))); return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryRelationExpression(expression.getManager(), expression.getType(), firstResult.expr(), secondResult.expr(), expression.getRelationType()))));
} }
} }
@ -225,59 +250,69 @@ namespace storm {
uint64_t index = boost::any_cast<uint64_t>(data); uint64_t index = boost::any_cast<uint64_t>(data);
STORM_LOG_ASSERT(replacements.find(expression.getVariable()) != replacements.end(), "Unable to find array variable " << expression << " in array replacements."); STORM_LOG_ASSERT(replacements.find(expression.getVariable()) != replacements.end(), "Unable to find array variable " << expression << " in array replacements.");
auto const& arrayVarReplacements = replacements.at(expression.getVariable()); auto const& arrayVarReplacements = replacements.at(expression.getVariable());
if (index >= arrayVarReplacements.size()) throw storm::exceptions::OutOfRangeException(); if (index >= arrayVarReplacements.size()) {
//STORM_LOG_THROW(index < arrayVarReplacements.size(), storm::exceptions::OutOfRangeException, "Array index " << index << " for variable " << expression << " is out of bounds."); return ResultType("Array index " + std::to_string(index) + " for variable " + expression.getVariableName() + " is out of bounds.");
return arrayVarReplacements[index]->getExpressionVariable().getExpression().getBaseExpressionPointer(); }
return ResultType(arrayVarReplacements[index]->getExpressionVariable().getExpression().getBaseExpressionPointer());
} else { } else {
STORM_LOG_ASSERT(data.empty(), "VariableExpression of non-array variable should not be a subexpressions of array access expressions. However, the expression " << expression << " is."); STORM_LOG_ASSERT(data.empty(), "VariableExpression of non-array variable should not be a subexpressions of array access expressions. However, the expression " << expression << " is.");
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} }
} }
virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(data.empty(), "UnaryBooleanFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is."); STORM_LOG_ASSERT(data.empty(), "UnaryBooleanFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
BaseExprPtr operandExpression = boost::any_cast<BaseExprPtr>(expression.getOperand()->accept(*this, data)); ResultType operandResult = boost::any_cast<ResultType>(expression.getOperand()->accept(*this, data));
if (operandResult.isArrayOutOfBounds()) {
return operandResult;
}
// If the argument did not change, we simply push the expression itself. // If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression.getOperand().get()) { if (operandResult.expr().get() == expression.getOperand().get()) {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} else { } else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandResult.expr(), expression.getOperatorType()))));
} }
} }
virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
STORM_LOG_ASSERT(data.empty(), "UnaryBooleanFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is."); STORM_LOG_ASSERT(data.empty(), "UnaryBooleanFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
BaseExprPtr operandExpression = boost::any_cast<BaseExprPtr>(expression.getOperand()->accept(*this, data)); ResultType operandResult = boost::any_cast<ResultType>(expression.getOperand()->accept(*this, data));
if (operandResult.isArrayOutOfBounds()) {
return operandResult;
}
// If the argument did not change, we simply push the expression itself. // If the argument did not change, we simply push the expression itself.
if (operandExpression.get() == expression.getOperand().get()) { if (operandResult.expr().get() == expression.getOperand().get()) {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} else { } else {
return std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandExpression, expression.getOperatorType()))); return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandResult.expr(), expression.getOperatorType()))));
} }
} }
virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override { virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} }
virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override { virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} }
virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override { virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override {
return expression.getSharedPointer(); return ResultType(expression.getSharedPointer());
} }
virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override {
STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression."); STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression.");
uint64_t index = boost::any_cast<uint64_t>(data); uint64_t index = boost::any_cast<uint64_t>(data);
STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ")."); STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) throw storm::exceptions::OutOfRangeException(); if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
// STORM_LOG_THROW(index < static_cast<uint64_t>(expression.size()->evaluateAsInt()), storm::exceptions::OutOfRangeException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression); return ResultType("Array index " + std::to_string(index) + " for ValueArrayExpression " + expression.toExpression().toString() + " is out of bounds.");
return boost::any_cast<BaseExprPtr>(expression.at(index)->accept(*this, boost::any())); }
return ResultType(boost::any_cast<ResultType>(expression.at(index)->accept(*this, boost::any())));
} }
virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
@ -286,10 +321,11 @@ namespace storm {
if (expression.size()->containsVariables()) { if (expression.size()->containsVariables()) {
STORM_LOG_WARN("Ignoring length of constructorArrayExpression " << expression << " as it still contains variables."); STORM_LOG_WARN("Ignoring length of constructorArrayExpression " << expression << " as it still contains variables.");
} else { } else {
if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) throw storm::exceptions::OutOfRangeException(); if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
// STORM_LOG_THROW(index < static_cast<uint64_t>(expression.size()->evaluateAsInt()), storm::exceptions::OutOfRangeException, "Out of bounds array access occured while accessing index " << index << " of expression " << expression); return ResultType("Array index " + std::to_string(index) + " for ConstructorArrayExpression " + expression.toExpression().toString() + " is out of bounds.");
}
} }
return boost::any_cast<BaseExprPtr>(expression.at(index)->accept(*this, boost::any())); return ResultType(boost::any_cast<ResultType>(expression.at(index)->accept(*this, boost::any())));
} }
virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override { virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const& data) override {
@ -298,19 +334,18 @@ namespace storm {
uint64_t size = MaxArraySizeExpressionVisitor().getMaxSize(expression.getFirstOperand()->toExpression(), arraySizes); uint64_t size = MaxArraySizeExpressionVisitor().getMaxSize(expression.getFirstOperand()->toExpression(), arraySizes);
STORM_LOG_THROW(size > 0, storm::exceptions::NotSupportedException, "Unable to get size of array expression for array access " << expression << "."); STORM_LOG_THROW(size > 0, storm::exceptions::NotSupportedException, "Unable to get size of array expression for array access " << expression << ".");
uint64_t index = size - 1; uint64_t index = size - 1;
storm::expressions::Expression result = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index))->toExpression(); storm::expressions::Expression result = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, index)).expr()->toExpression();
while (index > 0) { while (index > 0) {
--index; --index;
storm::expressions::Expression isCurrentIndex = boost::any_cast<BaseExprPtr>(expression.getSecondOperand()->accept(*this, boost::any()))->toExpression() == expression.getManager().integer(index); storm::expressions::Expression isCurrentIndex = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, boost::any())).expr()->toExpression() == expression.getManager().integer(index);
result = storm::expressions::ite(isCurrentIndex, result = storm::expressions::ite(isCurrentIndex,
boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index))->toExpression(), boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, index)).expr()->toExpression(),
result); result);
} }
return result.getBaseExpressionPointer(); return ResultType(result.getBaseExpressionPointer());
} else { } else {
uint64_t index = expression.getSecondOperand()->evaluateAsInt(); uint64_t index = expression.getSecondOperand()->evaluateAsInt();
auto result = boost::any_cast<BaseExprPtr>(expression.getFirstOperand()->accept(*this, index)); return boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, index));
return result;
} }
} }
@ -468,56 +503,59 @@ namespace storm {
// Replace array occurrences in LValues and assigned expressions. // Replace array occurrences in LValues and assigned expressions.
std::vector<Assignment> newAssignments; std::vector<Assignment> newAssignments;
int64_t level = orderedAssignments.getLowestLevel(); if (!orderedAssignments.empty()) {
std::unordered_map<storm::expressions::Variable, std::vector<Assignment const*>> collectedArrayAccessAssignments; int64_t level = orderedAssignments.getLowestLevel();
for (Assignment const& assignment : orderedAssignments) { std::unordered_map<storm::expressions::Variable, std::vector<Assignment const*>> collectedArrayAccessAssignments;
if (assignment.getLevel() != level) { for (Assignment const& assignment : orderedAssignments) {
STORM_LOG_ASSERT(assignment.getLevel() > level, "Ordered Assignment does not have the expected order."); if (assignment.getLevel() != level) {
for (auto const& arrayAssignments : collectedArrayAccessAssignments) { STORM_LOG_ASSERT(assignment.getLevel() > level, "Ordered Assignment does not have the expected order.");
insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments); for (auto const& arrayAssignments : collectedArrayAccessAssignments) {
} insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments);
collectedArrayAccessAssignments.clear();
level = assignment.getLevel();
}
if (assignment.getLValue().isArrayAccess()) {
if (!keepNonTrivialArrayAccess || !assignment.getLValue().getArrayIndex().containsVariables()) {
auto insertionRes = collectedArrayAccessAssignments.emplace(assignment.getLValue().getArray().getExpressionVariable(), std::vector<Assignment const*>({&assignment}));
if (!insertionRes.second) {
insertionRes.first->second.push_back(&assignment);
} }
} else { collectedArrayAccessAssignments.clear();
// Keeping array access LValue level = assignment.getLevel();
LValue newLValue(LValue(assignment.getLValue().getArray()), arrayExprEliminator->eliminate(assignment.getLValue().getArrayIndex()));
newAssignments.emplace_back(newLValue, arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
} }
} else if (assignment.getLValue().isVariable() && assignment.getVariable().isArrayVariable()) { if (assignment.getLValue().isArrayAccess()) {
STORM_LOG_ASSERT(assignment.getAssignedExpression().getType().isArrayType(), "Assigning a non-array expression to an array variable..."); if (!keepNonTrivialArrayAccess || !assignment.getLValue().getArrayIndex().containsVariables()) {
std::vector<storm::jani::Variable const*> const& arrayVariableReplacements = replacements.at(assignment.getExpressionVariable()); auto insertionRes = collectedArrayAccessAssignments.emplace(assignment.getLValue().getArray().getExpressionVariable(), std::vector<Assignment const*>({&assignment}));
// Get the maximum size of the array expression on the rhs if (!insertionRes.second) {
uint64_t rhsSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), arraySizes); insertionRes.first->second.push_back(&assignment);
STORM_LOG_ASSERT(arrayVariableReplacements.size() >= rhsSize, "Array size too small."); }
for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) {
auto const& replacement = *arrayVariableReplacements[index];
storm::expressions::Expression newRhs;
if (index < rhsSize) {
newRhs = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
} else { } else {
newRhs = getOutOfBoundsValue(replacement); // Keeping array access LValue
LValue newLValue(LValue(assignment.getLValue().getArray()), arrayExprEliminator->eliminate(assignment.getLValue().getArrayIndex()));
newAssignments.emplace_back(newLValue, arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
}
} else if (assignment.getLValue().isVariable() && assignment.getVariable().isArrayVariable()) {
STORM_LOG_ASSERT(assignment.getAssignedExpression().getType().isArrayType(), "Assigning a non-array expression to an array variable...");
std::vector<storm::jani::Variable const*> const& arrayVariableReplacements = replacements.at(assignment.getExpressionVariable());
// Get the maximum size of the array expression on the rhs
uint64_t rhsSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), arraySizes);
STORM_LOG_ASSERT(arrayVariableReplacements.size() >= rhsSize, "Array size too small.");
for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) {
auto const& replacement = *arrayVariableReplacements[index];
storm::expressions::Expression newRhs;
if (index < rhsSize) {
newRhs = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
} else {
newRhs = getOutOfBoundsValue(replacement);
}
newRhs = arrayExprEliminator->eliminate(newRhs);
newAssignments.emplace_back(LValue(replacement), newRhs, level);
} }
newRhs = arrayExprEliminator->eliminate(newRhs); } else {
newAssignments.emplace_back(LValue(replacement), newRhs, level); newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
} }
} else {
newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
} }
} for (auto const& arrayAssignments : collectedArrayAccessAssignments) {
for (auto const& arrayAssignments : collectedArrayAccessAssignments) { insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments);
insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments); }
} collectedArrayAccessAssignments.clear();
collectedArrayAccessAssignments.clear(); orderedAssignments.clear();
orderedAssignments.clear(); for (auto const& assignment : newAssignments) {
for (auto const& assignment : newAssignments) { orderedAssignments.add(assignment);
orderedAssignments.add(assignment); }
} }
} }

|||||||
100:0
Loading…
Cancel
Save