From 0d8b5bc0c2d8fd2c9765206ef267522bb639cbfe Mon Sep 17 00:00:00 2001
From: Thomas Knoll <thomas.knolł@student.tugraz.at>
Date: Fri, 1 Sep 2023 14:38:37 +0200
Subject: [PATCH] added module config added overwrite some fixes

---
 main.cpp                     |  7 ++++-
 util/ConfigYaml.cpp          | 58 ++++++++++++++++++++++++++----------
 util/ConfigYaml.h            | 30 ++++++++++++-------
 util/Grid.cpp                | 26 +++++++++++++++-
 util/Grid.h                  |  1 +
 util/PrismModulesPrinter.cpp | 36 ++++++++++++++--------
 util/PrismModulesPrinter.h   |  4 ++-
 7 files changed, 121 insertions(+), 41 deletions(-)

diff --git a/main.cpp b/main.cpp
index 1eaa0c7..31e1b32 100644
--- a/main.cpp
+++ b/main.cpp
@@ -169,7 +169,12 @@ int main(int argc, char* argv[]) {
     if(ok) {
       Grid grid(contentCells, backgroundCells, gridOptions, stateRewards);
       //grid.printToPrism(std::cout, prism::ModelType::MDP);
-      grid.printToPrism(file, configurations ,prism::ModelType::MDP);
+      std::stringstream ss;
+      // grid.printToPrism(file, configurations ,prism::ModelType::MDP);
+      grid.printToPrism(ss, configurations ,prism::ModelType::MDP);
+      std::string str = ss.str();
+      grid.applyOverwrites(str, configurations);
+      file << str;
     }
   } catch(qi::expectation_failure<pos_iterator_t> const& e) {
     std::cout << "expected: "; print_info(e.what_);
diff --git a/util/ConfigYaml.cpp b/util/ConfigYaml.cpp
index 4e5d6ef..014099e 100644
--- a/util/ConfigYaml.cpp
+++ b/util/ConfigYaml.cpp
@@ -24,6 +24,30 @@ std::ostream& operator << (std::ostream& os, const Module& module) {
     return os;
 }
 
+std::string Label::createExpression() const {
+    if (overwrite_) {
+        return "label \"" + label_ + "\" = " + text_ + "; // Overwrite";
+    } 
+
+    return "label \"" + label_ + "\" = " + text_ + ";";
+}
+
+std::string Formula::createExpression() const {
+    if (overwrite_) {
+        return "formula " + formula_ + " = " + content_ + "; // Overwrite";
+    }
+
+    return "formula " + formula_ + " = " + content_ + ";";
+}
+
+std::string Action::createExpression() const {
+    if (overwrite_) {
+        return action_  + "\t" + guard_ + "-> " + update_ + "; // Overwrite";
+    }
+    
+    return "\t" + action_  + "\t" + guard_ + "-> " + update_ + ";";
+}
+
 YAML::Node YAML::convert<Module>::encode(const Module& rhs) {
     YAML::Node node;
     
@@ -122,28 +146,30 @@ bool YAML::convert<Formula>::decode(const YAML::Node& node, Formula& rhs) {
         std::vector<Configuration> configuration;
 
         try {
-        YAML::Node config = YAML::LoadFile(file_);  
+            YAML::Node config = YAML::LoadFile(file_);  
 
-        const std::vector<Label> labels = config["labels"].as<std::vector<Label>>();
-        const std::vector<Formula> formulas = config["formulas"].as<std::vector<Formula>>();
-        const std::vector<Module> modules = config["modules"].as<std::vector<Module>>();
+            const std::vector<Label> labels = config["labels"].as<std::vector<Label>>();
+            const std::vector<Formula> formulas = config["formulas"].as<std::vector<Formula>>();
+            const std::vector<Module> modules = config["modules"].as<std::vector<Module>>();
 
-        for (auto& label : labels) {
-            configuration.push_back({label.text_, label.label_, ConfigType::Label, label.overwrite_});
-        }
-        for (auto& formula : formulas) {
-            configuration.push_back({formula.content_, formula.formula_ , ConfigType::Formula, formula.overwrite_});
-        }
-        for (auto& module : modules) {
-            std::cout << module << std::endl;
-        }
+            for (auto& label : labels) {
+                configuration.push_back({label.createExpression(), label.label_ , ConfigType::Label, label.overwrite_});
+            }
+            for (auto& formula : formulas) {
+                configuration.push_back({formula.createExpression(), formula.formula_ ,ConfigType::Formula, formula.overwrite_});
+            }
+            for (auto& module : modules) {
+                for (auto& action : module.actions_) {
+                    configuration.push_back({action.createExpression(), action.action_, ConfigType::Module, action.overwrite_, module.module_});
+                }
+            }
 
 
         }
         catch(const std::exception& e) {
-        std::cout << "Exception '" << typeid(e).name() << "' caught:" << std::endl;
-        std::cout << "\t" << e.what() << std::endl;
-        std::cout << "while parsing configuration " << file_ << std::endl;
+            std::cout << "Exception '" << typeid(e).name() << "' caught:" << std::endl;
+            std::cout << "\t" << e.what() << std::endl;
+            std::cout << "while parsing configuration " << file_ << std::endl;
         }
 
         return configuration;
diff --git a/util/ConfigYaml.h b/util/ConfigYaml.h
index 1066ba0..4ca26f5 100644
--- a/util/ConfigYaml.h
+++ b/util/ConfigYaml.h
@@ -5,7 +5,6 @@
 
 #include "yaml-cpp/yaml.h"
 
-typedef std::string expressions;
 
 enum class ConfigType : char {
   Label = 'L',
@@ -15,20 +14,25 @@ enum class ConfigType : char {
 
 struct Configuration
 {
-  expressions expressions_;
-  std::string derivation_;
+  std::string module_ {};
+  std::string expression_{};
+  std::string identifier_{};
   ConfigType type_ {ConfigType::Label};
-  bool overwrite_;
+  bool overwrite_ {false};
 
   Configuration() = default;
-  Configuration(std::string expression, std::string derivation, ConfigType type, bool overwrite = false) : expressions_(expression), derivation_(derivation), type_(type), overwrite_(overwrite) {}
+  Configuration(std::string expression
+                , std::string identifier
+                , ConfigType type
+                , bool overwrite = false
+                , std::string module = "") : expression_(expression), identifier_(identifier), type_(type), overwrite_(overwrite), module_{module} {}
+  
   ~Configuration() = default;
   Configuration(const Configuration&) = default;
 
   friend std::ostream& operator << (std::ostream& os, const Configuration& config) {
     os << "Configuration with Type: " << static_cast<char>(config.type_) << std::endl; 
-    os << "\tExpression=" << config.expressions_ << std::endl;
-    return os << "\tDerviation=" << config.derivation_;
+    return os << "\tExpression=" << config.expression_ << std::endl;
   }
 };
 
@@ -40,7 +44,9 @@ struct Label {
   public:
   std::string text_;
   std::string label_;
-  bool overwrite_;
+  bool overwrite_{false};
+
+  std::string createExpression() const;
 
   friend std::ostream& operator <<(std::ostream &os, const Label& label);
 };
@@ -51,7 +57,9 @@ struct Formula {
   public:
   std::string formula_;
   std::string content_;
-  bool overwrite_;
+  bool overwrite_ {false};
+
+  std::string createExpression() const;
 
   friend std::ostream& operator << (std::ostream &os, const Formula& formula);
 };
@@ -61,7 +69,9 @@ struct Action {
   std::string action_;
   std::string guard_;
   std::string update_;
-  bool overwrite_;
+  bool overwrite_ {false};
+
+  std::string createExpression() const;
 
   friend std::ostream& operator << (std::ostream& os, const Action& action);
 };
diff --git a/util/Grid.cpp b/util/Grid.cpp
index b45266c..2ef4871 100644
--- a/util/Grid.cpp
+++ b/util/Grid.cpp
@@ -129,6 +129,30 @@ bool Grid::isBox(coordinates p) {
       }) != boxes.end();
 }
 
+void Grid::applyOverwrites(std::string& str, std::vector<Configuration>& configuration) {
+  for (auto& config : configuration) {
+    if (!config.overwrite_) {
+      continue;
+    }
+      std::cout << "Searching for " << config.identifier_ << std::endl;
+      size_t start_pos;
+      
+      if (config.type_ == ConfigType::Formula) {
+        start_pos = str.find("formula " + config.identifier_);
+      } else if (config.type_ == ConfigType::Label) {
+        start_pos = str.find("label " + config.identifier_);
+      } else if (config.type_ == ConfigType::Module) {
+        start_pos = str.find(config.identifier_);
+      }
+
+      size_t end_pos = str.find(';', start_pos) + 1;
+
+      std::string expression = config.expression_;
+    
+      str.replace(start_pos, end_pos - start_pos , expression);
+  }
+}
+
 void Grid::printToPrism(std::ostream& os, std::vector<Configuration>& configuration ,const prism::ModelType& modelType) {
   cells northRestriction;
   cells eastRestriction;
@@ -151,7 +175,7 @@ void Grid::printToPrism(std::ostream& os, std::vector<Configuration>& configurat
     if(isBlocked(c.getWest()))   westRestriction.push_back(c);
   }
 
-  prism::PrismModulesPrinter printer(modelType, agentNameAndPositionMap.size(), gridOptions.enforceOneWays);
+  prism::PrismModulesPrinter printer(modelType, agentNameAndPositionMap.size(), configuration, gridOptions.enforceOneWays);
   printer.printModel(os, modelType);
   if(modelType == prism::ModelType::SMG) {
     printer.printGlobalMoveVariable(os, agentNameAndPositionMap.size());
diff --git a/util/Grid.h b/util/Grid.h
index 18cfd95..a7c5667 100644
--- a/util/Grid.h
+++ b/util/Grid.h
@@ -30,6 +30,7 @@ class Grid {
     bool isKey(coordinates p);
     bool isBox(coordinates p);
     void printToPrism(std::ostream &os, std::vector<Configuration>& configuration, const prism::ModelType& modelType);
+    void applyOverwrites(std::string& str, std::vector<Configuration>& configuration);
 
     std::array<bool, 8> getWalkableDirOf8Neighborhood(cell c);
 
diff --git a/util/PrismModulesPrinter.cpp b/util/PrismModulesPrinter.cpp
index 9d50860..7e804cd 100644
--- a/util/PrismModulesPrinter.cpp
+++ b/util/PrismModulesPrinter.cpp
@@ -5,8 +5,8 @@
 
 namespace prism {
 
-  PrismModulesPrinter::PrismModulesPrinter(const ModelType &modelType, const size_t &numberOfPlayer, const bool enforceOneWays)
-    : modelType(modelType), numberOfPlayer(numberOfPlayer), enforceOneWays(enforceOneWays) {
+  PrismModulesPrinter::PrismModulesPrinter(const ModelType &modelType, const size_t &numberOfPlayer, std::vector<Configuration> config, const bool enforceOneWays)
+    : modelType(modelType), numberOfPlayer(numberOfPlayer), enforceOneWays(enforceOneWays), configuration(config) {
   }
 
   std::ostream& PrismModulesPrinter::printModel(std::ostream &os, const ModelType &modelType) {
@@ -222,18 +222,12 @@ namespace prism {
     os << "\n// Configuration\n";
     
     for (auto& configuration : configurations) {
-      if (configuration.type_ == ConfigType::Label) {
-        os << "label \"" << configuration.derivation_ << "\" = ";
-      }
-      else if (configuration.type_ == ConfigType::Formula) {
-        os << "formula " << configuration.derivation_ << " = ";
-      }
-
-      for (auto& expr : configuration.expressions_) {
-        os << expr;
+      std::cout << configuration.overwrite_ << std::endl;
+      if (configuration.overwrite_ || configuration.type_ == ConfigType::Module) {
+        continue;
       }
       
-      os << ";\n";
+      os << configuration.expression_ << std::endl;
     }
 
     return os;
@@ -362,7 +356,25 @@ namespace prism {
       printMovementActions(os, agentName, agentIndex, agentWithView, probability);
     }
     printDoneActions(os, agentName, agentIndex);
+
+    printConfiguredActions(os, agentName);
+
+    os << "\n";
+    return os;
+  }
+
+  std::ostream& PrismModulesPrinter::printConfiguredActions(std::ostream &os, const AgentName &agentName) {
+    os << "\t//Configuration \n";
+
+
+    for (auto& config : configuration) {
+      if (config.type_ == ConfigType::Module && !config.overwrite_ && agentName == config.module_) {
+        os << config.expression_ ;
+      }
+    }
+
     os << "\n";
+
     return os;
   }
 
diff --git a/util/PrismModulesPrinter.h b/util/PrismModulesPrinter.h
index 9e25d44..85483ab 100644
--- a/util/PrismModulesPrinter.h
+++ b/util/PrismModulesPrinter.h
@@ -9,7 +9,7 @@
 namespace prism {
   class PrismModulesPrinter {
     public:
-      PrismModulesPrinter(const ModelType &modelType, const size_t &numberOfPlayer, const bool enforceOneWays = false);
+      PrismModulesPrinter(const ModelType &modelType, const size_t &numberOfPlayer, std::vector<Configuration> config ,const bool enforceOneWays = false);
 
       std::ostream& printRestrictionFormula(std::ostream& os, const AgentName &agentName, const std::string &direction, const cells &cells);
       std::ostream& printIsOnSlipperyFormula(std::ostream& os, const AgentName &agentName, const std::vector<std::reference_wrapper<cells>> &slipperyCollection, const cells &slipperyNorth, const cells &slipperyEast, const cells &slipperySouth, const cells &slipperyWest);
@@ -78,6 +78,7 @@ namespace prism {
       std::ostream& printRewards(std::ostream &os, const AgentName &agentName, const std::map<coordinates, float> &stateRewards, const cells &lava, const cells &goals, const std::map<Color, cells> &backgroundTiles);
 
       std::ostream& printConfiguration(std::ostream &os, const std::vector<Configuration>& configurations);
+      std::ostream& printConfiguredActions(std::ostream &os, const AgentName &agentName);
 
       std::string moveGuard(const size_t &agentIndex);
       std::string moveUpdate(const size_t &agentIndex);
@@ -90,5 +91,6 @@ namespace prism {
       ModelType const& modelType;
       const size_t numberOfPlayer;
       bool enforceOneWays;
+      std::vector<Configuration> configuration;
   };
 }