You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

700 lines
49 KiB

  1. #include "storm/storage/jani/ArrayEliminator.h"
  2. #include <unordered_map>
  3. #include "storm/storage/expressions/ExpressionVisitor.h"
  4. #include "storm/storage/jani/expressions/JaniExpressionVisitor.h"
  5. #include "storm/storage/jani/Variable.h"
  6. #include "storm/storage/jani/Model.h"
  7. #include "storm/storage/jani/Property.h"
  8. #include "storm/storage/jani/traverser/JaniTraverser.h"
  9. #include "storm/storage/jani/traverser/ArrayExpressionFinder.h"
  10. #include "storm/storage/expressions/Expressions.h"
  11. #include "storm/storage/jani/expressions/JaniExpressions.h"
  12. #include "storm/storage/expressions/ExpressionManager.h"
  13. #include "storm/exceptions/NotSupportedException.h"
  14. #include "storm/exceptions/UnexpectedException.h"
  15. #include "storm/exceptions/OutOfRangeException.h"
  16. namespace storm {
  17. namespace jani {
  18. namespace detail {
  19. class MaxArraySizeExpressionVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
  20. public:
  21. using storm::expressions::ExpressionVisitor::visit;
  22. MaxArraySizeExpressionVisitor() = default;
  23. virtual ~MaxArraySizeExpressionVisitor() = default;
  24. std::size_t getMaxSize(storm::expressions::Expression const& expression, std::unordered_map<storm::expressions::Variable, std::size_t> const& arrayVariableSizeMap) {
  25. return boost::any_cast<std::size_t>(expression.accept(*this, &arrayVariableSizeMap));
  26. }
  27. virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
  28. if (expression.getCondition()->containsVariables()) {
  29. return std::max<std::size_t>(boost::any_cast<std::size_t>(expression.getThenExpression()->accept(*this, data)), boost::any_cast<std::size_t>(expression.getElseExpression()->accept(*this, data)));
  30. } else {
  31. if (expression.getCondition()->evaluateAsBool()) {
  32. return boost::any_cast<std::size_t>(expression.getThenExpression()->accept(*this, data));
  33. }
  34. return boost::any_cast<std::size_t>(expression.getElseExpression()->accept(*this, data));
  35. }
  36. }
  37. virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
  38. return std::max<std::size_t>(boost::any_cast<std::size_t>(expression.getFirstOperand()->accept(*this, data)), boost::any_cast<std::size_t>(expression.getSecondOperand()->accept(*this, data)));
  39. }
  40. virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
  41. return std::max<std::size_t>(boost::any_cast<std::size_t>(expression.getFirstOperand()->accept(*this, data)), boost::any_cast<std::size_t>(expression.getSecondOperand()->accept(*this, data)));
  42. }
  43. virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
  44. return std::max<std::size_t>(boost::any_cast<std::size_t>(expression.getFirstOperand()->accept(*this, data)), boost::any_cast<std::size_t>(expression.getSecondOperand()->accept(*this, data)));
  45. }
  46. virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const& data) override {
  47. std::unordered_map<storm::expressions::Variable, std::size_t> const* arrayVariableSizeMap = boost::any_cast<std::unordered_map<storm::expressions::Variable, std::size_t> const*>(data);
  48. if (expression.getType().isArrayType()) {
  49. auto varIt = arrayVariableSizeMap->find(expression.getVariable());
  50. if (varIt != arrayVariableSizeMap->end()) {
  51. return varIt->second;
  52. }
  53. }
  54. return static_cast<std::size_t>(0);
  55. }
  56. virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
  57. return boost::any_cast<std::size_t>(expression.getOperand()->accept(*this, data));
  58. }
  59. virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
  60. return boost::any_cast<std::size_t>(expression.getOperand()->accept(*this, data));
  61. }
  62. virtual boost::any visit(storm::expressions::BooleanLiteralExpression const&, boost::any const&) override {
  63. return 0;
  64. }
  65. virtual boost::any visit(storm::expressions::IntegerLiteralExpression const&, boost::any const&) override {
  66. return 0;
  67. }
  68. virtual boost::any visit(storm::expressions::RationalLiteralExpression const&, boost::any const&) override {
  69. return 0;
  70. }
  71. virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const&) override {
  72. STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
  73. return static_cast<std::size_t>(expression.size()->evaluateAsInt());
  74. }
  75. virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const&) override {
  76. if (!expression.size()->containsVariables()) {
  77. return static_cast<std::size_t>(expression.size()->evaluateAsInt());
  78. } else {
  79. auto vars = expression.size()->toExpression().getVariables();
  80. std::string variables = "";
  81. for (auto const& v : vars) {
  82. if (variables != "") {
  83. variables += ", ";
  84. }
  85. variables += v.getName();
  86. }
  87. if (vars.size() == 1) {
  88. variables = "variable " + variables;
  89. } else {
  90. variables = "variables " + variables;
  91. }
  92. STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Unable to determine array size: Size of ConstructorArrayExpression '" << expression << "' still contains the " << variables << ".");
  93. }
  94. }
  95. virtual boost::any visit(storm::expressions::ArrayAccessExpression const&, boost::any const&) override {
  96. STORM_LOG_WARN("Found Array access expression within an array expression. This is not expected since nested arrays are currently not supported.");
  97. return 0;
  98. }
  99. virtual boost::any visit(storm::expressions::FunctionCallExpression const&, boost::any const&) override {
  100. STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "Found Function call expression within an array expression. This is not expected since functions are expected to be eliminated at this point.");
  101. return 0;
  102. }
  103. };
  104. class ArrayExpressionEliminationVisitor : public storm::expressions::ExpressionVisitor, public storm::expressions::JaniExpressionVisitor {
  105. public:
  106. using storm::expressions::ExpressionVisitor::visit;
  107. typedef std::shared_ptr<storm::expressions::BaseExpression const> BaseExprPtr;
  108. class ResultType {
  109. public:
  110. ResultType(ResultType const& other) = default;
  111. ResultType(BaseExprPtr expression) : expression(expression), arrayOutOfBoundsMessage("") {}
  112. ResultType(std::string arrayOutOfBoundsMessage) : expression(nullptr), arrayOutOfBoundsMessage(arrayOutOfBoundsMessage) {}
  113. BaseExprPtr& expr() {
  114. STORM_LOG_ASSERT(!isArrayOutOfBounds(), "Tried to get the result expression, but " << arrayOutOfBoundsMessage);
  115. return expression;
  116. };
  117. bool isArrayOutOfBounds() { return arrayOutOfBoundsMessage != ""; };
  118. std::string const& outOfBoundsMessage() const { return arrayOutOfBoundsMessage; }
  119. private:
  120. BaseExprPtr expression;
  121. std::string arrayOutOfBoundsMessage;
  122. };
  123. ArrayExpressionEliminationVisitor(std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements, std::unordered_map<storm::expressions::Variable, std::size_t> const& sizes) : replacements(replacements), arraySizes(sizes) {}
  124. virtual ~ArrayExpressionEliminationVisitor() = default;
  125. storm::expressions::Expression eliminate(storm::expressions::Expression const& expression) {
  126. // here, data is the accessed index of the most recent array access expression. Initially, there is none.
  127. auto res = boost::any_cast<ResultType>(expression.accept(*this, boost::any()));
  128. STORM_LOG_THROW(!res.isArrayOutOfBounds(), storm::exceptions::OutOfRangeException, res.outOfBoundsMessage());
  129. STORM_LOG_ASSERT(!containsArrayExpression(res.expr()->toExpression()), "Expression still contains array expressions. Before: " << std::endl << expression << std::endl << "After:" << std::endl << res.expr()->toExpression());
  130. return res.expr()->simplify();
  131. }
  132. virtual boost::any visit(storm::expressions::IfThenElseExpression const& expression, boost::any const& data) override {
  133. // for the condition expression, outer array accesses should not matter.
  134. ResultType conditionResult = boost::any_cast<ResultType>(expression.getCondition()->accept(*this, boost::any()));
  135. if (conditionResult.isArrayOutOfBounds()) {
  136. return conditionResult;
  137. }
  138. // We need to handle expressions of the kind '42<size : A[42] : 0', where size is a variable and A[42] might be out of bounds.
  139. ResultType thenResult = boost::any_cast<ResultType>(expression.getThenExpression()->accept(*this, data));
  140. ResultType elseResult = boost::any_cast<ResultType>(expression.getElseExpression()->accept(*this, data));
  141. if (thenResult.isArrayOutOfBounds()) {
  142. if (elseResult.isArrayOutOfBounds()) {
  143. return ResultType(thenResult.outOfBoundsMessage() + " and " + elseResult.outOfBoundsMessage());
  144. } else {
  145. // Assume the else expression
  146. return elseResult;
  147. }
  148. } else if (elseResult.isArrayOutOfBounds()) {
  149. // Assume the then expression
  150. return thenResult;
  151. } else {
  152. // If the arguments did not change, we simply push the expression itself.
  153. if (conditionResult.expr().get() == expression.getCondition().get() && thenResult.expr().get() == expression.getThenExpression().get() && elseResult.expr().get() == expression.getElseExpression().get()) {
  154. return ResultType(expression.getSharedPointer());
  155. } else {
  156. return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::IfThenElseExpression(expression.getManager(), thenResult.expr()->getType(), conditionResult.expr(), thenResult.expr(), elseResult.expr()))));
  157. }
  158. }
  159. }
  160. virtual boost::any visit(storm::expressions::BinaryBooleanFunctionExpression const& expression, boost::any const& data) override {
  161. STORM_LOG_ASSERT(data.empty(), "BinaryBooleanFunctionExpressions should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
  162. ResultType firstResult = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, data));
  163. ResultType secondResult = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, data));
  164. if (firstResult.isArrayOutOfBounds()) {
  165. return firstResult;
  166. } else if (secondResult.isArrayOutOfBounds()) {
  167. return secondResult;
  168. }
  169. // If the arguments did not change, we simply push the expression itself.
  170. if (firstResult.expr().get() == expression.getFirstOperand().get() && secondResult.expr().get() == expression.getSecondOperand().get()) {
  171. return ResultType(expression.getSharedPointer());
  172. } else {
  173. return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryBooleanFunctionExpression(expression.getManager(), expression.getType(), firstResult.expr(), secondResult.expr(), expression.getOperatorType()))));
  174. }
  175. }
  176. virtual boost::any visit(storm::expressions::BinaryNumericalFunctionExpression const& expression, boost::any const& data) override {
  177. STORM_LOG_ASSERT(data.empty(), "BinaryNumericalFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
  178. ResultType firstResult = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, data));
  179. ResultType secondResult = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, data));
  180. if (firstResult.isArrayOutOfBounds()) {
  181. return firstResult;
  182. } else if (secondResult.isArrayOutOfBounds()) {
  183. return secondResult;
  184. }
  185. // If the arguments did not change, we simply push the expression itself.
  186. if (firstResult.expr().get() == expression.getFirstOperand().get() && secondResult.expr().get() == expression.getSecondOperand().get()) {
  187. return ResultType(expression.getSharedPointer());
  188. } else {
  189. return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryNumericalFunctionExpression(expression.getManager(), expression.getType(), firstResult.expr(), secondResult.expr(), expression.getOperatorType()))));
  190. }
  191. }
  192. virtual boost::any visit(storm::expressions::BinaryRelationExpression const& expression, boost::any const& data) override {
  193. STORM_LOG_ASSERT(data.empty(), "BinaryRelationExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
  194. ResultType firstResult = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, data));
  195. ResultType secondResult = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, data));
  196. if (firstResult.isArrayOutOfBounds()) {
  197. return firstResult;
  198. } else if (secondResult.isArrayOutOfBounds()) {
  199. return secondResult;
  200. }
  201. // If the arguments did not change, we simply push the expression itself.
  202. if (firstResult.expr().get() == expression.getFirstOperand().get() && secondResult.expr().get() == expression.getSecondOperand().get()) {
  203. return ResultType(expression.getSharedPointer());
  204. } else {
  205. return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::BinaryRelationExpression(expression.getManager(), expression.getType(), firstResult.expr(), secondResult.expr(), expression.getRelationType()))));
  206. }
  207. }
  208. virtual boost::any visit(storm::expressions::VariableExpression const& expression, boost::any const& data) override {
  209. if (expression.getType().isArrayType()) {
  210. STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate array variable to basic variable, since it does not seem to be within an array access expression.");
  211. uint64_t index = boost::any_cast<uint64_t>(data);
  212. STORM_LOG_ASSERT(replacements.find(expression.getVariable()) != replacements.end(), "Unable to find array variable " << expression << " in array replacements.");
  213. auto const& arrayVarReplacements = replacements.at(expression.getVariable());
  214. if (index >= arrayVarReplacements.size()) {
  215. return ResultType("Array index " + std::to_string(index) + " for variable " + expression.getVariableName() + " is out of bounds.");
  216. }
  217. return ResultType(arrayVarReplacements[index]->getExpressionVariable().getExpression().getBaseExpressionPointer());
  218. } else {
  219. STORM_LOG_ASSERT(data.empty(), "VariableExpression of non-array variable should not be a subexpressions of array access expressions. However, the expression " << expression << " is.");
  220. return ResultType(expression.getSharedPointer());
  221. }
  222. }
  223. virtual boost::any visit(storm::expressions::UnaryBooleanFunctionExpression const& expression, boost::any const& data) override {
  224. STORM_LOG_ASSERT(data.empty(), "UnaryBooleanFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
  225. ResultType operandResult = boost::any_cast<ResultType>(expression.getOperand()->accept(*this, data));
  226. if (operandResult.isArrayOutOfBounds()) {
  227. return operandResult;
  228. }
  229. // If the argument did not change, we simply push the expression itself.
  230. if (operandResult.expr().get() == expression.getOperand().get()) {
  231. return ResultType(expression.getSharedPointer());
  232. } else {
  233. return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryBooleanFunctionExpression(expression.getManager(), expression.getType(), operandResult.expr(), expression.getOperatorType()))));
  234. }
  235. }
  236. virtual boost::any visit(storm::expressions::UnaryNumericalFunctionExpression const& expression, boost::any const& data) override {
  237. STORM_LOG_ASSERT(data.empty(), "UnaryBooleanFunctionExpression should not be direct subexpressions of array access expressions. However, the expression " << expression << " is.");
  238. ResultType operandResult = boost::any_cast<ResultType>(expression.getOperand()->accept(*this, data));
  239. if (operandResult.isArrayOutOfBounds()) {
  240. return operandResult;
  241. }
  242. // If the argument did not change, we simply push the expression itself.
  243. if (operandResult.expr().get() == expression.getOperand().get()) {
  244. return ResultType(expression.getSharedPointer());
  245. } else {
  246. return ResultType(std::const_pointer_cast<storm::expressions::BaseExpression const>(std::shared_ptr<storm::expressions::BaseExpression>(new storm::expressions::UnaryNumericalFunctionExpression(expression.getManager(), expression.getType(), operandResult.expr(), expression.getOperatorType()))));
  247. }
  248. }
  249. virtual boost::any visit(storm::expressions::BooleanLiteralExpression const& expression, boost::any const&) override {
  250. return ResultType(expression.getSharedPointer());
  251. }
  252. virtual boost::any visit(storm::expressions::IntegerLiteralExpression const& expression, boost::any const&) override {
  253. return ResultType(expression.getSharedPointer());
  254. }
  255. virtual boost::any visit(storm::expressions::RationalLiteralExpression const& expression, boost::any const&) override {
  256. return ResultType(expression.getSharedPointer());
  257. }
  258. virtual boost::any visit(storm::expressions::ValueArrayExpression const& expression, boost::any const& data) override {
  259. STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression.");
  260. uint64_t index = boost::any_cast<uint64_t>(data);
  261. STORM_LOG_ASSERT(expression.size()->isIntegerLiteralExpression(), "unexpected kind of size expression of ValueArrayExpression (" << expression.size()->toExpression() << ").");
  262. if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
  263. return ResultType("Array index " + std::to_string(index) + " for ValueArrayExpression " + expression.toExpression().toString() + " is out of bounds.");
  264. }
  265. return ResultType(boost::any_cast<ResultType>(expression.at(index)->accept(*this, boost::any())));
  266. }
  267. virtual boost::any visit(storm::expressions::ConstructorArrayExpression const& expression, boost::any const& data) override {
  268. STORM_LOG_THROW(!data.empty(), storm::exceptions::NotSupportedException, "Unable to translate ValueArrayExpression to element expression since it does not seem to be within an array access expression.");
  269. uint64_t index = boost::any_cast<uint64_t>(data);
  270. if (expression.size()->containsVariables()) {
  271. STORM_LOG_WARN("Ignoring length of constructorArrayExpression " << expression << " as it still contains variables.");
  272. } else {
  273. if (index >= static_cast<uint64_t>(expression.size()->evaluateAsInt())) {
  274. return ResultType("Array index " + std::to_string(index) + " for ConstructorArrayExpression " + expression.toExpression().toString() + " is out of bounds.");
  275. }
  276. }
  277. return ResultType(boost::any_cast<ResultType>(expression.at(index)->accept(*this, boost::any())));
  278. }
  279. virtual boost::any visit(storm::expressions::ArrayAccessExpression const& expression, boost::any const&) override {
  280. if (expression.getSecondOperand()->containsVariables()) {
  281. //get the size of the array expression
  282. uint64_t size = MaxArraySizeExpressionVisitor().getMaxSize(expression.getFirstOperand()->toExpression(), arraySizes);
  283. STORM_LOG_THROW(size > 0, storm::exceptions::NotSupportedException, "Unable to get size of array expression for array access " << expression << ".");
  284. uint64_t index = size - 1;
  285. storm::expressions::Expression result = boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, index)).expr()->toExpression();
  286. while (index > 0) {
  287. --index;
  288. storm::expressions::Expression isCurrentIndex = boost::any_cast<ResultType>(expression.getSecondOperand()->accept(*this, boost::any())).expr()->toExpression() == expression.getManager().integer(index);
  289. result = storm::expressions::ite(isCurrentIndex,
  290. boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, index)).expr()->toExpression(),
  291. result);
  292. }
  293. return ResultType(result.getBaseExpressionPointer());
  294. } else {
  295. uint64_t index = expression.getSecondOperand()->evaluateAsInt();
  296. return boost::any_cast<ResultType>(expression.getFirstOperand()->accept(*this, index));
  297. }
  298. }
  299. virtual boost::any visit(storm::expressions::FunctionCallExpression const&, boost::any const&) override {
  300. STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "Found Function call expression while eliminating array expressions. This is not expected since functions are expected to be eliminated at this point.");
  301. return false;
  302. }
  303. private:
  304. std::unordered_map<storm::expressions::Variable, std::vector<storm::jani::Variable const*>> const& replacements;
  305. std::unordered_map<storm::expressions::Variable, std::size_t> const& arraySizes;
  306. };
  307. class MaxArraySizeDeterminer : public ConstJaniTraverser {
  308. public:
  309. typedef std::unordered_map<storm::expressions::Variable, std::size_t>* MapPtr;
  310. MaxArraySizeDeterminer() = default;
  311. virtual ~MaxArraySizeDeterminer() = default;
  312. std::unordered_map<storm::expressions::Variable, std::size_t> getMaxSizes(Model const& model) {
  313. // We repeatedly determine the max array sizes until convergence. This is to cover assignments of one array variable to another (A := B)
  314. std::unordered_map<storm::expressions::Variable, std::size_t> result, previousResult;
  315. do {
  316. previousResult = result;
  317. ConstJaniTraverser::traverse(model, &result);
  318. } while (previousResult != result);
  319. return result;
  320. }
  321. virtual void traverse(Assignment const& assignment, boost::any const& data) override {
  322. if (assignment.lValueIsVariable() && assignment.getExpressionVariable().getType().isArrayType()) {
  323. auto& map = *boost::any_cast<MapPtr>(data);
  324. std::size_t newSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), map);
  325. auto insertionRes = map.emplace(assignment.getExpressionVariable(), newSize);
  326. if (!insertionRes.second) {
  327. insertionRes.first->second = std::max(newSize, insertionRes.first->second);
  328. }
  329. }
  330. }
  331. virtual void traverse(ArrayVariable const& variable, boost::any const& data) override {
  332. if (variable.hasInitExpression()) {
  333. auto& map = *boost::any_cast<MapPtr>(data);
  334. std::size_t newSize = MaxArraySizeExpressionVisitor().getMaxSize(variable.getInitExpression(), map);
  335. auto insertionRes = map.emplace(variable.getExpressionVariable(), newSize);
  336. if (!insertionRes.second) {
  337. insertionRes.first->second = std::max(newSize, insertionRes.first->second);
  338. }
  339. }
  340. }
  341. };
  342. class ArrayVariableReplacer : public JaniTraverser {
  343. public:
  344. typedef ArrayEliminatorData ResultType;
  345. using JaniTraverser::traverse;
  346. ArrayVariableReplacer(storm::expressions::ExpressionManager& expressionManager, bool keepNonTrivialArrayAccess, std::unordered_map<storm::expressions::Variable, std::size_t> const& arrayVarToSizesMap) : expressionManager(expressionManager) , keepNonTrivialArrayAccess(keepNonTrivialArrayAccess), arraySizes(arrayVarToSizesMap) {}
  347. virtual ~ArrayVariableReplacer() = default;
  348. ResultType replace(Model& model) {
  349. ResultType result;
  350. arrayExprEliminator = std::make_unique<ArrayExpressionEliminationVisitor>(result.replacements, arraySizes);
  351. for (auto const& arraySize : arraySizes) {
  352. result.replacements.emplace(arraySize.first, std::vector<storm::jani::Variable const*>(arraySize.second, nullptr));
  353. }
  354. traverse(model, &result);
  355. return result;
  356. }
  357. virtual void traverse(Model& model, boost::any const& data) override {
  358. // Insert fresh basic variables for global array variables
  359. auto& replacements = boost::any_cast<ResultType*>(data)->replacements;
  360. for (storm::jani::ArrayVariable const& arrayVariable : model.getGlobalVariables().getArrayVariables()) {
  361. std::vector<storm::jani::Variable const*>& basicVars = replacements.at(arrayVariable.getExpressionVariable());
  362. for (uint64_t index = 0; index < basicVars.size(); ++index) {
  363. basicVars[index] = &model.addVariable(*getBasicVariable(arrayVariable, index));
  364. }
  365. }
  366. // drop all occuring array variables
  367. auto elVars = model.getGlobalVariables().dropAllArrayVariables();
  368. auto& eliminatedArrayVariables = boost::any_cast<ResultType*>(data)->eliminatedArrayVariables;
  369. eliminatedArrayVariables.insert(eliminatedArrayVariables.end(), elVars.begin(), elVars.end());
  370. // Make new variable replacements known to the expression eliminator
  371. arrayExprEliminator = std::make_unique<ArrayExpressionEliminationVisitor>(replacements, arraySizes);
  372. for (auto& aut : model.getAutomata()) {
  373. traverse(aut, data);
  374. }
  375. // traversal of remaining components
  376. if (model.hasInitialStatesRestriction()) {
  377. model.setInitialStatesRestriction(arrayExprEliminator->eliminate(model.getInitialStatesRestriction()));
  378. }
  379. for (auto& nonTrivRew : model.getNonTrivialRewardExpressions()) {
  380. nonTrivRew.second = arrayExprEliminator->eliminate(nonTrivRew.second);
  381. }
  382. }
  383. virtual void traverse(Automaton& automaton, boost::any const& data) override {
  384. // No need to traverse the init restriction.
  385. // Insert fresh basic variables for local array variables
  386. auto& replacements = boost::any_cast<ResultType*>(data)->replacements;
  387. for (storm::jani::ArrayVariable const& arrayVariable : automaton.getVariables().getArrayVariables()) {
  388. std::vector<storm::jani::Variable const*>& basicVars = replacements.at(arrayVariable.getExpressionVariable());
  389. for (uint64_t index = 0; index < basicVars.size(); ++index) {
  390. basicVars[index] = &automaton.addVariable(*getBasicVariable(arrayVariable, index));
  391. }
  392. }
  393. // drop all occuring array variables
  394. auto elVars = automaton.getVariables().dropAllArrayVariables();
  395. auto& eliminatedArrayVariables = boost::any_cast<ResultType*>(data)->eliminatedArrayVariables;
  396. eliminatedArrayVariables.insert(eliminatedArrayVariables.end(), elVars.begin(), elVars.end());
  397. // Make new variable replacements known to the expression eliminator
  398. arrayExprEliminator = std::make_unique<ArrayExpressionEliminationVisitor>(replacements, arraySizes);
  399. for (auto& loc : automaton.getLocations()) {
  400. traverse(loc, data);
  401. }
  402. traverse(automaton.getEdgeContainer(), data);
  403. if (automaton.hasInitialStatesRestriction()) {
  404. automaton.setInitialStatesRestriction(arrayExprEliminator->eliminate(automaton.getInitialStatesRestriction()));
  405. }
  406. }
  407. virtual void traverse(Location& location, boost::any const& data) override {
  408. traverse(location.getAssignments(), data);
  409. if (location.hasTimeProgressInvariant()) {
  410. location.setTimeProgressInvariant(arrayExprEliminator->eliminate(location.getTimeProgressInvariant()));
  411. traverse(location.getTimeProgressInvariant(), data);
  412. }
  413. }
  414. void traverse(TemplateEdge& templateEdge, boost::any const& data) override {
  415. templateEdge.setGuard(arrayExprEliminator->eliminate(templateEdge.getGuard()));
  416. for (auto& dest : templateEdge.getDestinations()) {
  417. traverse(dest, data);
  418. }
  419. traverse(templateEdge.getAssignments(), data);
  420. }
  421. void traverse(Edge& edge, boost::any const& data) override {
  422. if (edge.hasRate()) {
  423. edge.setRate(arrayExprEliminator->eliminate(edge.getRate()));
  424. }
  425. for (auto& dest : edge.getDestinations()) {
  426. traverse(dest, data);
  427. }
  428. }
  429. void traverse(EdgeDestination& edgeDestination, boost::any const&) override {
  430. edgeDestination.setProbability(arrayExprEliminator->eliminate(edgeDestination.getProbability()));
  431. }
  432. virtual void traverse(OrderedAssignments& orderedAssignments, boost::any const& data) override {
  433. auto const& replacements = boost::any_cast<ResultType*>(data)->replacements;
  434. // Replace array occurrences in LValues and assigned expressions.
  435. std::vector<Assignment> newAssignments;
  436. if (!orderedAssignments.empty()) {
  437. int64_t level = orderedAssignments.getLowestLevel();
  438. std::unordered_map<storm::expressions::Variable, std::vector<Assignment const*>> collectedArrayAccessAssignments;
  439. for (Assignment const& assignment : orderedAssignments) {
  440. if (assignment.getLevel() != level) {
  441. STORM_LOG_ASSERT(assignment.getLevel() > level, "Ordered Assignment does not have the expected order.");
  442. for (auto const& arrayAssignments : collectedArrayAccessAssignments) {
  443. insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments);
  444. }
  445. collectedArrayAccessAssignments.clear();
  446. level = assignment.getLevel();
  447. }
  448. if (assignment.getLValue().isArrayAccess()) {
  449. if (!keepNonTrivialArrayAccess || !assignment.getLValue().getArrayIndex().containsVariables()) {
  450. auto insertionRes = collectedArrayAccessAssignments.emplace(assignment.getLValue().getArray().getExpressionVariable(), std::vector<Assignment const*>({&assignment}));
  451. if (!insertionRes.second) {
  452. insertionRes.first->second.push_back(&assignment);
  453. }
  454. } else {
  455. // Keeping array access LValue
  456. LValue newLValue(LValue(assignment.getLValue().getArray()), arrayExprEliminator->eliminate(assignment.getLValue().getArrayIndex()));
  457. newAssignments.emplace_back(newLValue, arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
  458. }
  459. } else if (assignment.getLValue().isVariable() && assignment.getVariable().isArrayVariable()) {
  460. STORM_LOG_ASSERT(assignment.getAssignedExpression().getType().isArrayType(), "Assigning a non-array expression to an array variable...");
  461. std::vector<storm::jani::Variable const*> const& arrayVariableReplacements = replacements.at(assignment.getExpressionVariable());
  462. // Get the maximum size of the array expression on the rhs
  463. uint64_t rhsSize = MaxArraySizeExpressionVisitor().getMaxSize(assignment.getAssignedExpression(), arraySizes);
  464. STORM_LOG_ASSERT(arrayVariableReplacements.size() >= rhsSize, "Array size too small.");
  465. for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) {
  466. auto const& replacement = *arrayVariableReplacements[index];
  467. storm::expressions::Expression newRhs;
  468. if (index < rhsSize) {
  469. newRhs = std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, assignment.getAssignedExpression().getType().getElementType(), assignment.getAssignedExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression();
  470. } else {
  471. newRhs = getOutOfBoundsValue(replacement);
  472. }
  473. newRhs = arrayExprEliminator->eliminate(newRhs);
  474. newAssignments.emplace_back(LValue(replacement), newRhs, level);
  475. }
  476. } else {
  477. newAssignments.emplace_back(assignment.getLValue(), arrayExprEliminator->eliminate(assignment.getAssignedExpression()), assignment.getLevel());
  478. }
  479. }
  480. for (auto const& arrayAssignments : collectedArrayAccessAssignments) {
  481. insertLValueArrayAccessReplacements(arrayAssignments.second, replacements.at(arrayAssignments.first), level, newAssignments);
  482. }
  483. collectedArrayAccessAssignments.clear();
  484. orderedAssignments.clear();
  485. for (auto const& assignment : newAssignments) {
  486. orderedAssignments.add(assignment);
  487. }
  488. }
  489. }
  490. private:
  491. std::shared_ptr<Variable> getBasicVariable(ArrayVariable const& arrayVariable, uint64_t index) const {
  492. std::string name = arrayVariable.getExpressionVariable().getName() + "_at_" + std::to_string(index);
  493. storm::expressions::Expression initValue;
  494. if (arrayVariable.hasInitExpression()) {
  495. initValue = arrayExprEliminator->eliminate(std::make_shared<storm::expressions::ArrayAccessExpression>(expressionManager, arrayVariable.getExpressionVariable().getType().getElementType(), arrayVariable.getInitExpression().getBaseExpressionPointer(), expressionManager.integer(index).getBaseExpressionPointer())->toExpression());
  496. }
  497. if (arrayVariable.getElementType() == ArrayVariable::ElementType::Int) {
  498. storm::expressions::Variable exprVariable = expressionManager.declareIntegerVariable(name);
  499. if (arrayVariable.hasElementTypeBound()) {
  500. if (initValue.isInitialized()) {
  501. return std::make_shared<BoundedIntegerVariable>(name, exprVariable, initValue, arrayVariable.isTransient(), arrayVariable.getLowerElementTypeBound(), arrayVariable.getUpperElementTypeBound());
  502. } else {
  503. return std::make_shared<BoundedIntegerVariable>(name, exprVariable, arrayVariable.getLowerElementTypeBound(), arrayVariable.getUpperElementTypeBound());
  504. }
  505. } else {
  506. if (initValue.isInitialized()) {
  507. return std::make_shared<UnboundedIntegerVariable>(name, exprVariable, initValue, arrayVariable.isTransient());
  508. } else {
  509. return std::make_shared<UnboundedIntegerVariable>(name, exprVariable);
  510. }
  511. }
  512. } else if (arrayVariable.getElementType() == ArrayVariable::ElementType::Real) {
  513. storm::expressions::Variable exprVariable = expressionManager.declareRationalVariable(name);
  514. if (initValue.isInitialized()) {
  515. return std::make_shared<RealVariable>(name, exprVariable, initValue, arrayVariable.isTransient());
  516. } else {
  517. return std::make_shared<RealVariable>(name, exprVariable);
  518. }
  519. } else if (arrayVariable.getElementType() == ArrayVariable::ElementType::Bool) {
  520. storm::expressions::Variable exprVariable = expressionManager.declareBooleanVariable(name);
  521. if (initValue.isInitialized()) {
  522. return std::make_shared<BooleanVariable>(name, exprVariable, initValue, arrayVariable.isTransient());
  523. } else {
  524. return std::make_shared<BooleanVariable>(name, exprVariable);
  525. }
  526. }
  527. STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Unhandled array base type.");
  528. return nullptr;
  529. }
  530. void insertLValueArrayAccessReplacements(std::vector<Assignment const*> const& arrayAccesses, std::vector<storm::jani::Variable const*> const& arrayVariableReplacements, int64_t level, std::vector<Assignment>& newAssignments) const {
  531. bool onlyConstantIndices = true;
  532. for (auto const& aa : arrayAccesses) {
  533. if (aa->getLValue().getArrayIndex().containsVariables()) {
  534. onlyConstantIndices = false;
  535. break;
  536. }
  537. }
  538. if (onlyConstantIndices) {
  539. for (auto const& aa : arrayAccesses) {
  540. LValue lvalue(*arrayVariableReplacements.at(aa->getLValue().getArrayIndex().evaluateAsInt()));
  541. newAssignments.emplace_back(lvalue, arrayExprEliminator->eliminate(aa->getAssignedExpression()), level);
  542. }
  543. } else {
  544. for (uint64_t index = 0; index < arrayVariableReplacements.size(); ++index) {
  545. storm::expressions::Expression assignedExpression = arrayVariableReplacements[index]->getExpressionVariable().getExpression();
  546. auto indexExpression = expressionManager.integer(index);
  547. for (auto const& aa : arrayAccesses) {
  548. assignedExpression = storm::expressions::ite(arrayExprEliminator->eliminate(aa->getLValue().getArrayIndex()) == indexExpression, arrayExprEliminator->eliminate(aa->getAssignedExpression()), assignedExpression);
  549. }
  550. newAssignments.emplace_back(LValue(*arrayVariableReplacements[index]), assignedExpression, level);
  551. }
  552. }
  553. }
  554. storm::expressions::Expression getOutOfBoundsValue(Variable const& var) const {
  555. if (var.hasInitExpression()) {
  556. return var.getInitExpression();
  557. }
  558. if (var.isBooleanVariable()) {
  559. return expressionManager.boolean(false);
  560. }
  561. if (var.isBoundedIntegerVariable()) {
  562. return var.asBoundedIntegerVariable().getLowerBound();
  563. }
  564. if (var.isUnboundedIntegerVariable()) {
  565. return expressionManager.integer(0);
  566. }
  567. if (var.isRealVariable()) {
  568. return expressionManager.rational(0.0);
  569. }
  570. STORM_LOG_THROW(false, storm::exceptions::UnexpectedException, "unhandled variabe type");
  571. return storm::expressions::Expression();
  572. }
  573. std::unique_ptr<ArrayExpressionEliminationVisitor> arrayExprEliminator; // Removes array expressions from indices, assigned values and initial values.
  574. storm::expressions::ExpressionManager& expressionManager; // Declares the replacing basic variables and builds constant expressions; held by reference, so it must outlive this object.
  575. bool const keepNonTrivialArrayAccess; // If true, writes to arrays at indices containing variables are kept instead of being replaced.
  576. std::unordered_map<storm::expressions::Variable, std::size_t> const& arraySizes; // Array sizes passed to MaxArraySizeExpressionVisitor when splitting whole-array assignments; held by reference, so the map must outlive this object.
  577. };
  578. } // namespace detail
  579. storm::expressions::Expression ArrayEliminatorData::transformExpression(storm::expressions::Expression const& arrayExpression) const {
  580. std::unordered_map<storm::expressions::Variable, std::size_t> arraySizes;
  581. for (auto const& r : replacements) {
  582. arraySizes.emplace(r.first, r.second.size());
  583. }
  584. detail::ArrayExpressionEliminationVisitor eliminator(replacements, arraySizes);
  585. return eliminator.eliminate(arrayExpression);
  586. }
  587. void ArrayEliminatorData::transformProperty(storm::jani::Property& property) const {
  588. property = property.substitute([this](storm::expressions::Expression const& exp) {return transformExpression(exp);});
  589. }
  590. ArrayEliminatorData ArrayEliminator::eliminate(Model& model, bool keepNonTrivialArrayAccess) {
  591. ArrayEliminatorData result;
  592. // Only perform actions if there actually are arrays.
  593. if (model.getModelFeatures().hasArrays()) {
  594. auto sizes = detail::MaxArraySizeDeterminer().getMaxSizes(model);
  595. result = detail::ArrayVariableReplacer(model.getExpressionManager(), keepNonTrivialArrayAccess, sizes).replace(model);
  596. if (!keepNonTrivialArrayAccess) {
  597. model.getModelFeatures().remove(ModelFeature::Arrays);
  598. }
  599. model.finalize();
  600. }
  601. STORM_LOG_ASSERT(!containsArrayExpression(model), "the model still contains array expressions.");
  602. return result;
  603. }
  604. }
  605. }