From e001d8dbcb281609c66f4760b32cf83da1d0b89d Mon Sep 17 00:00:00 2001
From: Thomas Knoll <thomas.knolł@student.tugraz.at>
Date: Thu, 10 Aug 2023 12:51:02 +0200
Subject: [PATCH] changed json export in pre scheduler

---
 src/storm/api/export.h              | 11 +++--
 src/storm/shields/OptimalShield.cpp |  2 +-
 src/storm/shields/PostShield.cpp    |  2 +-
 src/storm/storage/PostScheduler.cpp | 73 +++++++++++++++++++++++++++++
 src/storm/storage/PostScheduler.h   |  8 ++++
 src/storm/storage/PreScheduler.cpp  |  8 ++--
 6 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/src/storm/api/export.h b/src/storm/api/export.h
index 078c2fc46..452e957d2 100644
--- a/src/storm/api/export.h
+++ b/src/storm/api/export.h
@@ -65,10 +65,15 @@ namespace storm {
         }
 
         template <typename ValueType, typename IndexType>
-        void exportShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model, std::shared_ptr<tempest::shields::AbstractShield<ValueType, IndexType>> const& shield) {
+        void exportShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model, std::shared_ptr<tempest::shields::AbstractShield<ValueType, IndexType>> const& shield, std::string const& filename) {
             std::ofstream stream;
-            storm::utility::openFile(shield->getShieldFileName(), stream);
-            shield->printToStream(stream, model);
+            storm::utility::openFile(filename, stream);
+            std::string jsonFileExtension = ".json";
+            if (filename.size() > 4 && std::equal(jsonFileExtension.rbegin(), jsonFileExtension.rend(), filename.rbegin())) {
+                shield->printJsonToStream(stream, model);
+            } else {
+                shield->printToStream(stream, model);
+            }
             storm::utility::closeFile(stream);
         }
         
diff --git a/src/storm/shields/OptimalShield.cpp b/src/storm/shields/OptimalShield.cpp
index ff12e1c12..6d03e77da 100644
--- a/src/storm/shields/OptimalShield.cpp
+++ b/src/storm/shields/OptimalShield.cpp
@@ -71,7 +71,7 @@ namespace tempest {
 
         template<typename ValueType, typename IndexType>
         void OptimalShield<ValueType, IndexType>::printJsonToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model) {
-            STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "not supported yet");
+            this->construct().printJsonToStream(out, model);
         }
 
         // Explicitly instantiate appropriate classes
diff --git a/src/storm/shields/PostShield.cpp b/src/storm/shields/PostShield.cpp
index b9165b45e..6dffd5344 100644
--- a/src/storm/shields/PostShield.cpp
+++ b/src/storm/shields/PostShield.cpp
@@ -75,7 +75,7 @@ namespace tempest {
 
         template<typename ValueType, typename IndexType>
         void PostShield<ValueType, IndexType>::printJsonToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model) {
-            STORM_LOG_THROW(false, storm::exceptions::InvalidOperationException, "not supported yet");
+            this->construct().printJsonToStream(out, model);
         }
 
 
diff --git a/src/storm/storage/PostScheduler.cpp b/src/storm/storage/PostScheduler.cpp
index 3131e21c9..35873ea31 100644
--- a/src/storm/storage/PostScheduler.cpp
+++ b/src/storm/storage/PostScheduler.cpp
@@ -128,6 +128,79 @@ namespace storm {
             out << "___________________________________________________________________" << std::endl;
         }
 
+        template <typename ValueType>
+        void PostScheduler<ValueType>::printJsonToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> model, bool skipUniqueChoices) const {
+            // STORM_LOG_THROW(model == nullptr || model->getNumberOfStates() == schedulerChoices.front().size(), storm::exceptions::InvalidOperationException, "The given model is not compatible with this scheduler.");
+            // STORM_LOG_WARN_COND(!(skipUniqueChoices && model == nullptr), "Can not skip unique choices if the model is not given.");
+            // storm::json<storm::RationalNumber> output;
+            // for (uint64_t state = 0; state < schedulerChoices.front().size(); ++state) {
+            //     // Check whether the state is skipped
+            //     if (skipUniqueChoices && model != nullptr && model->getTransitionMatrix().getRowGroupSize(state) == 1) {
+            //         continue;
+            //     }
+
+            //     for (uint_fast64_t memoryState = 0; memoryState < getNumberOfMemoryStates(); ++memoryState) {
+            //         storm::json<storm::RationalNumber> stateChoicesJson;
+            //         if (model && model->hasStateValuations()) {
+            //             stateChoicesJson["s"] = model->getStateValuations().template toJson<storm::RationalNumber>(state);
+            //         } else {
+            //             stateChoicesJson["s"] = state;
+            //         }
+
+            //         if (!isMemorylessScheduler()) {
+            //             stateChoicesJson["m"] = memoryState;
+            //         }
+
+            //         auto const &choice = schedulerChoices[memoryState][state];
+            //         storm::json<storm::RationalNumber> choicesJson;
+
+            //         for (auto const &choiceProbPair : choice.getChoiceMap()) {
+            //             uint64_t globalChoiceIndex = model->getTransitionMatrix().getRowGroupIndices()[state] ;//+ choiceProbPair.first;
+            //             storm::json<storm::RationalNumber> choiceJson;
+            //             if (model && model->hasChoiceOrigins() &&
+            //                 model->getChoiceOrigins()->getIdentifier(globalChoiceIndex) !=
+            //                 model->getChoiceOrigins()->getIdentifierForChoicesWithNoOrigin()) {
+            //                 choiceJson["origin"] = model->getChoiceOrigins()->getChoiceAsJson(globalChoiceIndex);
+            //             }
+            //             if (model && model->hasChoiceLabeling()) {
+            //                 auto choiceLabels = model->getChoiceLabeling().getLabelsOfChoice(globalChoiceIndex);
+            //                 choiceJson["labels"] = std::vector<std::string>(choiceLabels.begin(),
+            //                                                                 choiceLabels.end());
+            //             }
+            //             choiceJson["index"] = globalChoiceIndex;
+            //             choiceJson["prob"] = storm::utility::convertNumber<storm::RationalNumber>(
+            //                     std::get<1>(choiceProbPair));
+
+            //             // Memory updates
+            //             if(!isMemorylessScheduler()) {
+            //                 choiceJson["memory-updates"] = std::vector<storm::json<storm::RationalNumber>>();
+            //                 uint64_t row = model->getTransitionMatrix().getRowGroupIndices()[state]; //+ std::get<0>(choiceProbPair);
+            //                 for (auto entryIt = model->getTransitionMatrix().getRow(row).begin(); entryIt < model->getTransitionMatrix().getRow(row).end(); ++entryIt) {
+            //                     storm::json<storm::RationalNumber> updateJson;
+            //                     // next model state
+            //                     if (model && model->hasStateValuations()) {
+            //                         updateJson["s'"] = model->getStateValuations().template toJson<storm::RationalNumber>(entryIt->getColumn());
+            //                     } else {
+            //                         updateJson["s'"] = entryIt->getColumn();
+            //                     }
+            //                     // next memory state
+            //                     updateJson["m'"] = this->memoryStructure->getSuccessorMemoryState(memoryState, entryIt - model->getTransitionMatrix().begin());
+            //                     choiceJson["memory-updates"].push_back(std::move(updateJson));
+            //                 }
+            //             }
+
+            //             choicesJson.push_back(std::move(choiceJson));
+            //         }
+            //         if (!choicesJson.is_null()) {
+            //             stateChoicesJson["c"] = std::move(choicesJson);
+            //             output.push_back(std::move(stateChoicesJson));
+            //         }
+            //     }
+            // }
+            // out << output.dump(4);
+        }
+
+
         template class PostScheduler<double>;
 #ifdef STORM_HAVE_CARL
         template class PostScheduler<storm::RationalNumber>;
diff --git a/src/storm/storage/PostScheduler.h b/src/storm/storage/PostScheduler.h
index 44fcb6abd..c2eaebde2 100644
--- a/src/storm/storage/PostScheduler.h
+++ b/src/storm/storage/PostScheduler.h
@@ -92,6 +92,14 @@ namespace storm {
              *                          Requires a model to be given.
              */
             void printToStream(std::ostream& out, std::shared_ptr<storm::logic::ShieldExpression const> shieldingExpression, std::shared_ptr<storm::models::sparse::Model<ValueType>> model = nullptr, bool skipUniqueChoices = false) const;
+
+             /*!
+             * Prints the pre scheduler in json format to the given output stream.
+             */
+            void printJsonToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> model = nullptr, bool skipUniqueChoices = false) const;
+
+
+
         private:
             boost::optional<storm::storage::MemoryStructure> memoryStructure;
             std::vector<std::vector<PostSchedulerChoice<ValueType>>> schedulerChoiceMapping;
diff --git a/src/storm/storage/PreScheduler.cpp b/src/storm/storage/PreScheduler.cpp
index 92069070b..373b8986e 100644
--- a/src/storm/storage/PreScheduler.cpp
+++ b/src/storm/storage/PreScheduler.cpp
@@ -167,7 +167,7 @@ namespace storm {
                     storm::json<storm::RationalNumber> choicesJson;
 
                     for (auto const &choiceProbPair : choice.getChoiceMap()) {
-                        uint64_t globalChoiceIndex = model->getTransitionMatrix().getRowGroupIndices()[state] ;//+ choiceProbPair.first;
+                        uint64_t globalChoiceIndex = model->getTransitionMatrix().getRowGroupIndices()[state] + std::get<uint_fast64_t>(choiceProbPair);
                         storm::json<storm::RationalNumber> choiceJson;
                         if (model && model->hasChoiceOrigins() &&
                             model->getChoiceOrigins()->getIdentifier(globalChoiceIndex) !=
@@ -181,17 +181,17 @@ namespace storm {
                         }
                         choiceJson["index"] = globalChoiceIndex;
                         choiceJson["prob"] = storm::utility::convertNumber<storm::RationalNumber>(
-                                std::get<1>(choiceProbPair));
+                                std::get<ValueType>(choiceProbPair));
 
                         // Memory updates
                         if(!isMemorylessScheduler()) {
                             choiceJson["memory-updates"] = std::vector<storm::json<storm::RationalNumber>>();
-                            uint64_t row = model->getTransitionMatrix().getRowGroupIndices()[state]; //+ std::get<0>(choiceProbPair);
+                            uint64_t row = model->getTransitionMatrix().getRowGroupIndices()[state] + std::get<uint_fast64_t>(choiceProbPair);
                             for (auto entryIt = model->getTransitionMatrix().getRow(row).begin(); entryIt < model->getTransitionMatrix().getRow(row).end(); ++entryIt) {
                                 storm::json<storm::RationalNumber> updateJson;
                                 // next model state
                                 if (model && model->hasStateValuations()) {
-                                    updateJson["s'"] = model->getStateValuations().template toJson<storm::RationalNumber>(entryIt->getColumn());
+                                    updateJson["s'"] = model->getStateValuations().template toJson<storm::RationalNumber>(entryIt->getColumn());    
                                 } else {
                                     updateJson["s'"] = entryIt->getColumn();
                                 }