The source code and Dockerfile for the GSW2024 AI Lab.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
This repo is archived. You can view files and clone it, but cannot push or open issues/pull-requests.

199 lines
9.3 KiB

2 months ago
  1. #include "Grid.h"
  2. #include <boost/algorithm/string/find.hpp>
  3. #include <algorithm>
  4. Grid::Grid(cells gridCells, cells background, const std::map<coordinates, float> &stateRewards, const float probIntended, const float faultyProbability)
  5. : allGridCells(gridCells), background(background), stateRewards(stateRewards), probIntended(probIntended), faultyProbability(faultyProbability)
  6. {
  7. cell max = allGridCells.at(allGridCells.size() - 1);
  8. maxBoundaries = std::make_pair(max.column - 1, max.row - 1);
  9. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(walls), [](cell c) { return c.type == Type::Wall; });
  10. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(lava), [](cell c) { return c.type == Type::Lava; });
  11. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(floor), [](cell c) { return c.type == Type::Floor; }); // TODO CHECK IF ALL AGENTS ARE ADDED TO FLOOR
  12. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperyNorth), [](cell c) { return c.type == Type::SlipperyNorth; });
  13. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperyEast), [](cell c) { return c.type == Type::SlipperyEast; });
  14. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperySouth), [](cell c) { return c.type == Type::SlipperySouth; });
  15. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperyWest), [](cell c) { return c.type == Type::SlipperyWest; });
  16. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperyNorthWest), [](cell c) { return c.type == Type::SlipperyNorthWest; });
  17. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperyNorthEast), [](cell c) { return c.type == Type::SlipperyNorthEast; });
  18. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperySouthWest), [](cell c) { return c.type == Type::SlipperySouthWest; });
  19. std::copy_if(background.begin(), background.end(), std::back_inserter(slipperySouthEast), [](cell c) { return c.type == Type::SlipperySouthEast; });
  20. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(lockedDoors), [](cell c) { return c.type == Type::LockedDoor; });
  21. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(unlockedDoors), [](cell c) { return c.type == Type::Door; });
  22. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(goals), [](cell c) { return c.type == Type::Goal; });
  23. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(keys), [](cell c) { return c.type == Type::Key; });
  24. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(boxes), [](cell c) { return c.type == Type::Box; });
  25. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(balls), [](cell c) { return c.type == Type::Ball; });
  26. std::copy_if(gridCells.begin(), gridCells.end(), std::back_inserter(adversaries), [](cell c) { return c.type == Type::Adversary; });
  27. agent = *std::find_if(gridCells.begin(), gridCells.end(), [](cell c) { return c.type == Type::Agent; });
  28. floor.push_back(agent);
  29. agentNameAndPositionMap.insert({ "Agent", agent.getCoordinates() });
  30. for(auto const& adversary : adversaries) {
  31. std::string color = adversary.getColor();
  32. color.at(0) = std::toupper(color.at(0));
  33. try {
  34. auto success = agentNameAndPositionMap.insert({ color, adversary.getCoordinates() });
  35. floor.push_back(adversary);
  36. if(!success.second) {
  37. throw std::logic_error("Agent with " + color + " already present\n");
  38. }
  39. } catch(const std::logic_error& e) {
  40. std::cerr << "Expected agents colors to be different. Agent with color : '" << color << "' already present." << std::endl;
  41. throw;
  42. }
  43. }
  44. for(auto const& key : keys) {
  45. std::string color = key.getColor();
  46. try {
  47. auto success = keyNameAndPositionMap.insert({color, key.getCoordinates() });
  48. if (!success.second) {
  49. throw std::logic_error("Multiple keys with same color not supported " + color + "\n");
  50. }
  51. } catch(const std::logic_error& e) {
  52. std::cerr << "Expected key colors to be different. Key with color : '" << color << "' already present." << std::endl;
  53. throw;
  54. }
  55. }
  56. for(auto const& color : allColors) {
  57. cells cellsOfColor;
  58. std::copy_if(background.begin(), background.end(), std::back_inserter(cellsOfColor), [&color](cell c) {
  59. return c.type == Type::Floor && c.color == color;
  60. });
  61. if(cellsOfColor.size() > 0) {
  62. backgroundTiles.emplace(color, cellsOfColor);
  63. }
  64. }
  65. if (adversaries.empty()) {
  66. modelType = prism::ModelType::MDP;
  67. } else {
  68. modelType = prism::ModelType::SMG;
  69. }
  70. }
  71. std::ostream& operator<<(std::ostream& os, const Grid& grid) {
  72. int lastRow = 1;
  73. for(auto const& cell : grid.allGridCells) {
  74. if(lastRow != cell.row)
  75. os << std::endl;
  76. os << static_cast<char>(cell.type) << static_cast<char>(cell.color);
  77. lastRow = cell.row;
  78. }
  79. return os;
  80. }
  81. cells Grid::getGridCells() {
  82. return allGridCells;
  83. }
  84. bool Grid::isBlocked(coordinates p) {
  85. return isWall(p);
  86. }
  87. bool Grid::isWall(coordinates p) {
  88. return std::find_if(walls.begin(), walls.end(),
  89. [p](cell cell) {
  90. return cell.row == p.second && cell.column == p.first;
  91. }) != walls.end();
  92. }
  93. void Grid::applyOverwrites(std::string& str, std::vector<Configuration>& configuration) {
  94. for (auto& config : configuration) {
  95. if (!config.overwrite_) {
  96. continue;
  97. }
  98. for (auto& index : config.indexes_) {
  99. size_t start_pos;
  100. std::string search;
  101. if (config.type_ == ConfigType::Formula) {
  102. search = "formula " + config.identifier_;
  103. } else if (config.type_ == ConfigType::Label) {
  104. search = "label " + config.identifier_;
  105. } else if (config.type_ == ConfigType::Module) {
  106. search = config.identifier_;
  107. } else if (config.type_ == ConfigType::UpdateOnly) {
  108. search = config.identifier_;
  109. } else if (config.type_ == ConfigType::GuardOnly) {
  110. search = config.identifier_;
  111. }
  112. else if (config.type_ == ConfigType::Constant) {
  113. search = config.identifier_;
  114. }
  115. auto iter = boost::find_nth(str, search, index);
  116. auto end_identifier = config.end_identifier_;
  117. start_pos = std::distance(str.begin(), iter.begin());
  118. size_t end_pos = str.find(end_identifier, start_pos);
  119. if (config.type_ == ConfigType::GuardOnly || config.type_ == ConfigType::Module) {
  120. start_pos += search.length();
  121. } else if (config.type_ == ConfigType::UpdateOnly) {
  122. start_pos = str.find("->", start_pos) + 2;
  123. }
  124. if (end_pos != std::string::npos && end_pos != 0) {
  125. std::string expression = config.expression_;
  126. str.replace(start_pos, end_pos - start_pos , expression);
  127. }
  128. }
  129. }
  130. }
  131. void Grid::printToPrism(std::ostream& os, std::vector<Configuration>& configuration) {
  132. cells northRestriction, eastRestriction, southRestriction, westRestriction;
  133. cells walkable = floor;
  134. walkable.insert(walkable.end(), goals.begin(), goals.end());
  135. walkable.insert(walkable.end(), boxes.begin(), boxes.end());
  136. walkable.insert(walkable.end(), lava.begin(), lava.end());
  137. walkable.insert(walkable.end(), keys.begin(), keys.end());
  138. walkable.insert(walkable.end(), balls.begin(), balls.end());
  139. for(auto const& c : walkable) {
  140. if(isWall(c.getNorth())) northRestriction.push_back(c);
  141. if(isWall(c.getEast())) eastRestriction.push_back(c);
  142. if(isWall(c.getSouth())) southRestriction.push_back(c);
  143. if(isWall(c.getWest())) westRestriction.push_back(c);
  144. }
  145. std::map<std::string, cells> wallRestrictions = {{"North", northRestriction}, {"East", eastRestriction}, {"South", southRestriction}, {"West", westRestriction}};
  146. std::map<std::string, cells> slipperyTiles = {{"North", slipperyNorth}, {"East", slipperyEast}, {"South", slipperySouth}, {"West", slipperyWest}, {"NorthWest", slipperyNorthWest}, {"NorthEast", slipperyNorthEast},{"SouthWest", slipperySouthWest},{"SouthEast", slipperySouthEast}};
  147. std::vector<AgentName> agentNames;
  148. std::transform(agentNameAndPositionMap.begin(),
  149. agentNameAndPositionMap.end(),
  150. std::back_inserter(agentNames),
  151. [](const std::map<AgentNameAndPosition::first_type,AgentNameAndPosition::second_type>::value_type &pair){return pair.first;});
  152. std::string agentName = agentNames.at(0);
  153. prism::PrismFormulaPrinter formulas(os, wallRestrictions, walls, lockedDoors, unlockedDoors, keys, slipperyTiles, lava, goals, agentNameAndPositionMap, faultyProbability > 0.0);
  154. prism::PrismModulesPrinter modules(os, modelType, maxBoundaries, lockedDoors, unlockedDoors, keys, slipperyTiles, agentNameAndPositionMap, configuration, probIntended, faultyProbability, !lava.empty(), !goals.empty());
  155. modules.printModelType(modelType);
  156. for(const auto &agentName : agentNames) {
  157. formulas.print(agentName);
  158. }
  159. if(agentNameAndPositionMap.size() > 1) formulas.printCollisionFormula(agentName);
  160. formulas.printInitStruct();
  161. modules.print();
  162. //if(!stateRewards.empty()) {
  163. // modules.printRewards(os, agentName, stateRewards, lava, goals, backgroundTiles);
  164. //}
  165. //if (!configuration.empty()) {
  166. // modules.printConfiguration(os, configuration);
  167. //}
  168. }
  169. void Grid::setModelType(prism::ModelType type)
  170. {
  171. modelType = type;
  172. }