From 0d60b468e132621c20114b6bf156683afa8cf299 Mon Sep 17 00:00:00 2001 From: sp Date: Tue, 5 Nov 2024 11:24:30 +0100 Subject: [PATCH] refactored shield export --- src/storm-cli-utilities/model-handling.h | 20 ++++++- src/storm/api/verification.h | 2 +- .../rpatl/SparseSmgRpatlModelChecker.cpp | 17 +++--- src/storm/shields/AbstractShield.cpp | 56 ++++++++++++++++++- src/storm/shields/AbstractShield.h | 43 +++++++++++--- src/storm/shields/OptimalShield.cpp | 27 +++++++-- src/storm/shields/OptimalShield.h | 7 ++- src/storm/shields/PostShield.cpp | 26 +++++++-- src/storm/shields/PostShield.h | 6 +- src/storm/shields/PreShield.cpp | 29 +++++++--- src/storm/shields/PreShield.h | 9 ++- src/storm/storage/PostScheduler.cpp | 9 +-- src/storm/storage/PreScheduler.cpp | 7 ++- src/storm/storage/PreScheduler.h | 1 + 14 files changed, 213 insertions(+), 46 deletions(-) diff --git a/src/storm-cli-utilities/model-handling.h b/src/storm-cli-utilities/model-handling.h index c90852df5..cb915a005 100644 --- a/src/storm-cli-utilities/model-handling.h +++ b/src/storm-cli-utilities/model-handling.h @@ -49,6 +49,11 @@ #include "storm/settings/modules/HintSettings.h" #include "storm/storage/Qvbs.h" +#include "storm/shields/AbstractShield.h" +#include "storm/shields/PreShield.h" +#include "storm/shields/PostShield.h" +#include "storm/shields/OptimalShield.h" + #include "storm/utility/Stopwatch.h" namespace storm { @@ -1047,13 +1052,22 @@ namespace storm { if (result->isExplicitQuantitativeCheckResult()) { if (result-> template asExplicitQuantitativeCheckResult().hasShield()) { auto shield = result->template asExplicitQuantitativeCheckResult().getShield(); + if(shield->isPreShield()) { + shield->asPreShield().construct(); + } else if(shield->isPostShield()) { + shield->asPostShield().construct(); + } else if(shield->isOptimalShield()) { + shield->asOptimalShield().construct(); + } + STORM_PRINT_AND_LOG("Exporting shield ... "); - - storm::api::exportShield(sparseModel, shield, ioSettings.getExportShieldFilename()); + + STORM_LOG_WARN_COND(exportCount == 0, "Prepending " << exportCount << " to file name for this property because there are multiple properties."); + storm::api::exportShield(sparseModel, shield, (exportCount == 0 ? std::string("") : std::to_string(exportCount)) + ioSettings.getExportShieldFilename()); } } } - + if (ioSettings.isExportCheckResultSet()) { STORM_LOG_WARN_COND(sparseModel->hasStateValuations(), "No information of state valuations available. The result output will use internal state ids. You might be interested in building the model with state valuations using --buildstateval."); diff --git a/src/storm/api/verification.h b/src/storm/api/verification.h index 868c00be3..3abb50803 100644 --- a/src/storm/api/verification.h +++ b/src/storm/api/verification.h @@ -458,7 +458,7 @@ namespace storm { template typename std::enable_if::value, std::unique_ptr>::type verifyWithDdEngine(storm::Environment const&, std::shared_ptr> const&, storm::modelchecker::CheckTask const&) { - STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Dd engine cannot verify MDPs with this data type."); + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, ""); } template diff --git a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp index 53562ef04..79cc6fc87 100644 --- a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp +++ b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp @@ -141,13 +141,14 @@ namespace storm { ExplicitQualitativeCheckResult const& leftResult = leftResultPointer->asExplicitQualitativeCheckResult(); ExplicitQualitativeCheckResult const& rightResult = rightResultPointer->asExplicitQualitativeCheckResult(); + auto ret = storm::modelchecker::helper::SparseSmgRpatlHelper::computeUntilProbabilities(env, storm::solver::SolveGoal(this->getModel(), checkTask), this->getModel().getTransitionMatrix(), this->getModel().getBackwardTransitions(), leftResult.getTruthValuesVector(), rightResult.getTruthValuesVector(), checkTask.isQualitativeSet(), statesOfCoalition, checkTask.isProduceSchedulersSet(), checkTask.getHint()); std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); if(checkTask.isShieldingTask()) { storm::storage::BitVector allStatesBv = storm::storage::BitVector(this->getModel().getTransitionMatrix().getRowGroupCount(), true); auto shield = tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), allStatesBv, ~statesOfCoalition); - result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); - } + result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); + } if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } @@ -165,8 +166,8 @@ namespace storm { if(checkTask.isShieldingTask()) { storm::storage::BitVector allStatesBv = storm::storage::BitVector(this->getModel().getTransitionMatrix().getRowGroupCount(), true); auto shield = tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), allStatesBv, ~statesOfCoalition); - result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); - } + result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); + } if (checkTask.isProduceSchedulersSet() && ret.scheduler) { result->asExplicitQuantitativeCheckResult().setScheduler(std::move(ret.scheduler)); } @@ -202,7 +203,7 @@ namespace storm { if(checkTask.isShieldingTask()) { storm::storage::BitVector allStatesBv = storm::storage::BitVector(this->getModel().getTransitionMatrix().getRowGroupCount(), true); auto shield = tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), allStatesBv, ~statesOfCoalition); - result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); + result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); } return result; } @@ -225,7 +226,7 @@ namespace storm { std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(ret.values))); if(checkTask.isShieldingTask()) { auto shield = tempest::shields::createShield(std::make_shared>(this->getModel()), std::move(ret.choiceValues), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), std::move(ret.relevantStates), ~statesOfCoalition); - result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); + result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); } return result; } @@ -246,8 +247,8 @@ namespace storm { if(checkTask.isShieldingTask()) { storm::storage::BitVector allStatesBv = storm::storage::BitVector(this->getModel().getTransitionMatrix().getRowGroupCount(), true); auto shield = tempest::shields::createQuantitativeShield(std::make_shared>(this->getModel()), helper.getChoiceValues(), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), allStatesBv, statesOfCoalition); - result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); - } + result->asExplicitQuantitativeCheckResult().setShield(std::move(shield)); + } if (checkTask.isProduceSchedulersSet()) { result->asExplicitQuantitativeCheckResult().setScheduler(std::make_unique>(helper.extractScheduler())); } diff --git a/src/storm/shields/AbstractShield.cpp b/src/storm/shields/AbstractShield.cpp index 82f27bdc9..ca446726c 100644 --- a/src/storm/shields/AbstractShield.cpp +++ b/src/storm/shields/AbstractShield.cpp @@ -1,4 +1,7 @@ #include "storm/shields/AbstractShield.h" +#include "storm/shields/PreShield.h" +#include "storm/shields/PostShield.h" +#include "storm/shields/OptimalShield.h" #include @@ -29,11 +32,62 @@ namespace tempest { return optimizationDirection; } + template + void AbstractShield::setShieldingExpression(std::shared_ptr const& shieldingExpression) { + this->shieldingExpression = shieldingExpression; + } + + template + bool AbstractShield::isPreShield() const { + return false; + } + + template + bool AbstractShield::isPostShield() const { + return false; + } + + template + bool AbstractShield::isOptimalShield() const { + return false; + } + + template + PreShield& AbstractShield::asPreShield() { + return dynamic_cast&>(*this); + } + + template + PreShield const& AbstractShield::asPreShield() const { + return dynamic_cast const&>(*this); + } + + template + PostShield& AbstractShield::asPostShield() { + return dynamic_cast&>(*this); + } + + template + PostShield const& AbstractShield::asPostShield() const { + return dynamic_cast const&>(*this); + } + + template + OptimalShield& AbstractShield::asOptimalShield() { + return dynamic_cast&>(*this); + } + + template + OptimalShield const& AbstractShield::asOptimalShield() const { + return dynamic_cast const&>(*this); + } + + template std::string AbstractShield::getClassName() const { return std::string(boost::core::demangled_name(BOOST_CORE_TYPEID(*this))); } - + // Explicitly instantiate appropriate template class AbstractShield::index_type>; #ifdef STORM_HAVE_CARL diff --git a/src/storm/shields/AbstractShield.h b/src/storm/shields/AbstractShield.h index 9ef54775e..04401cd6b 100644 --- a/src/storm/shields/AbstractShield.h +++ b/src/storm/shields/AbstractShield.h @@ -16,19 +16,33 @@ #include "storm/logic/ShieldExpression.h" +#include "storm/exceptions/NotSupportedException.h" + + namespace tempest { namespace shields { + template + class PreShield; + template + class PostShield; + template + class OptimalShield; + namespace utility { template struct ChoiceFilter { bool operator()(ValueType v, ValueType opt, double shieldValue) { - Compare compare; - if(relative && std::is_same>::value) { - return compare(v, opt + opt * shieldValue); - } else if(relative && std::is_same>::value) { - return compare(v, opt * shieldValue); + if constexpr (std::is_same_v || std::is_same_v) { + Compare compare; + if(relative && std::is_same>::value) { + return compare(v, opt + opt * shieldValue); + } else if(relative && std::is_same>::value) { + return compare(v, opt * shieldValue); + } + else return compare(v, shieldValue); + } else { + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot create shields for parametric models"); } - else return compare(v, shieldValue); } }; } @@ -47,9 +61,24 @@ namespace tempest { std::vector computeRowGroupSizes(); storm::OptimizationDirection getOptimizationDirection(); + void setShieldingExpression(std::shared_ptr const& shieldingExpression); std::string getClassName() const; - + + virtual bool isPreShield() const; + virtual bool isPostShield() const; + virtual bool isOptimalShield() const; + + PreShield& asPreShield(); + PreShield const& asPreShield() const; + + PostShield& asPostShield(); + PostShield const& asPostShield() const; + + OptimalShield& asOptimalShield(); + OptimalShield const& asOptimalShield() const; + + virtual void printToStream(std::ostream& out, std::shared_ptr> const& model) = 0; virtual void printJsonToStream(std::ostream& out, std::shared_ptr> const& model) = 0; diff --git a/src/storm/shields/OptimalShield.cpp b/src/storm/shields/OptimalShield.cpp index 6d03e77da..0d8fc0c63 100644 --- a/src/storm/shields/OptimalShield.cpp +++ b/src/storm/shields/OptimalShield.cpp @@ -64,21 +64,40 @@ namespace tempest { return shield; } + template - void OptimalShield::printToStream(std::ostream& out, std::shared_ptr> const& model) { - this->construct().printToStream(out, this->shieldingExpression, model); + std::shared_ptr> OptimalShield::getScheduler() const { + return optimalScheduler; + } + + template + bool OptimalShield::isOptimalShield() const { + return true; } template void OptimalShield::printJsonToStream(std::ostream& out, std::shared_ptr> const& model) { - this->construct().printJsonToStream(out, model); + optimalScheduler->printJsonToStream(out, model); + } + + template + void OptimalShield::printToStream(std::ostream& out, std::shared_ptr> const& model) { + optimalScheduler->printToStream(out, this->shieldingExpression, model); } + //template + //template + //std::enable_if_t::value, storm::storage::PostScheduler> OptimalShield::construct() { + // STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "todo"); + //} + + + // Explicitly instantiate appropriate classes template class OptimalShield::index_type>; #ifdef STORM_HAVE_CARL template class OptimalShield::index_type>; - // template class OptimalShield::index_type>; + template class OptimalShield::index_type>; #endif } diff --git a/src/storm/shields/OptimalShield.h b/src/storm/shields/OptimalShield.h index 0dac50587..3c2d9a2ff 100644 --- a/src/storm/shields/OptimalShield.h +++ b/src/storm/shields/OptimalShield.h @@ -14,11 +14,16 @@ namespace tempest { storm::storage::PostScheduler construct(); template storm::storage::PostScheduler constructWithCompareType(); + std::shared_ptr> getScheduler() const; + + virtual bool isOptimalShield() const override; + virtual void printToStream(std::ostream& out, std::shared_ptr> const& model) override; virtual void printJsonToStream(std::ostream& out, std::shared_ptr> const& model) override; - private: std::vector choiceValues; + + std::shared_ptr> optimalScheduler; }; } } diff --git a/src/storm/shields/PostShield.cpp b/src/storm/shields/PostShield.cpp index 6dffd5344..a3e86b7f2 100644 --- a/src/storm/shields/PostShield.cpp +++ b/src/storm/shields/PostShield.cpp @@ -45,7 +45,7 @@ namespace tempest { } ValueType optProbability = *(choice_it + optProbabilityIndex); if(!relative && !choiceFilter(optProbability, optProbability, this->shieldingExpression->getValue())) { - STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); + //STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(storm::storage::PostSchedulerChoice(), state, 0); choice_it += rowGroupSize; continue; @@ -69,21 +69,37 @@ namespace tempest { template - void PostShield::printToStream(std::ostream& out, std::shared_ptr> const& model) { - this->construct().printToStream(out, this->shieldingExpression, model); + std::shared_ptr> PostShield::getScheduler() const { + return postScheduler; + } + + template + bool PostShield::isPostShield() const { + return true; } template void PostShield::printJsonToStream(std::ostream& out, std::shared_ptr> const& model) { - this->construct().printJsonToStream(out, model); + postScheduler->printJsonToStream(out, model); } + template + void PostShield::printToStream(std::ostream& out, std::shared_ptr> const& model) { + postScheduler->printToStream(out, this->shieldingExpression, model); + } + + //template + //template + //std::enable_if_t::value, storm::storage::PostScheduler> PostShield::construct() { + // STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "todo"); + //} + // Explicitly instantiate appropriate classes template class PostShield::index_type>; #ifdef STORM_HAVE_CARL template class PostShield::index_type>; - // template class PostShield::index_type>; + template class PostShield::index_type>; #endif } diff --git a/src/storm/shields/PostShield.h b/src/storm/shields/PostShield.h index b857a9216..2116cf7bb 100644 --- a/src/storm/shields/PostShield.h +++ b/src/storm/shields/PostShield.h @@ -14,12 +14,16 @@ namespace tempest { storm::storage::PostScheduler construct(); template storm::storage::PostScheduler constructWithCompareType(); + std::shared_ptr> getScheduler() const; + + virtual bool isPostShield() const override; virtual void printToStream(std::ostream& out, std::shared_ptr> const& model) override; virtual void printJsonToStream(std::ostream& out, std::shared_ptr> const& model) override; - private: std::vector choiceValues; + + std::shared_ptr> postScheduler; }; } } diff --git a/src/storm/shields/PreShield.cpp b/src/storm/shields/PreShield.cpp index 716d4f8d1..7c62d5788 100644 --- a/src/storm/shields/PreShield.cpp +++ b/src/storm/shields/PreShield.cpp @@ -10,8 +10,12 @@ namespace tempest { // Intentionally left empty. } + template storm::storage::PreScheduler PreShield::construct() { + if constexpr (std::is_same_v) { + STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "todo"); + } if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { if(this->shieldingExpression->isRelative()) { return constructWithCompareType, true>(); @@ -38,7 +42,7 @@ namespace tempest { } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; - if(this->relevantStates.get(state)) { + if(true){ //if(this->relevantStates.get(state)) { storm::storage::PreSchedulerChoice enabledChoices; ValueType optProbability; if(std::is_same>::value) { @@ -47,7 +51,7 @@ namespace tempest { optProbability = *std::min_element(choice_it, choice_it + rowGroupSize); } if(!relative && !choiceFilter(optProbability, optProbability, this->shieldingExpression->getValue())) { - STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); + //STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(storm::storage::PreSchedulerChoice(), state, 0); choice_it += rowGroupSize; continue; @@ -65,27 +69,38 @@ namespace tempest { } } + preScheduler = std::make_shared>(shield); return shield; } template - void PreShield::printToStream(std::ostream& out, std::shared_ptr> const& model) { - this->construct().printToStream(out, this->shieldingExpression, model); + std::shared_ptr> PreShield::getScheduler() const { + return preScheduler; + } + + template + bool PreShield::isPreShield() const { + return true; } template void PreShield::printJsonToStream(std::ostream& out, std::shared_ptr> const& model) { - this->construct().printJsonToStream(out, model); + preScheduler->printJsonToStream(out, model); } + template + void PreShield::printToStream(std::ostream& out, std::shared_ptr> const& model) { + preScheduler->printToStream(out, this->shieldingExpression, model); + } + + // Explicitly instantiate appropriate classes template class PreShield::index_type>; #ifdef STORM_HAVE_CARL template class PreShield::index_type>; - //template class PreShield::index_type>; - + template class PreShield::index_type>; #endif } } diff --git a/src/storm/shields/PreShield.h b/src/storm/shields/PreShield.h index 1b5f099b8..3bf80017c 100644 --- a/src/storm/shields/PreShield.h +++ b/src/storm/shields/PreShield.h @@ -11,15 +11,22 @@ namespace tempest { public: PreShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + storm::storage::PreScheduler construct(); + template storm::storage::PreScheduler constructWithCompareType(); + void setShieldingExpression(std::shared_ptr const& shieldingExpression); + std::shared_ptr> getScheduler() const; + + virtual bool isPreShield() const override; virtual void printToStream(std::ostream& out, std::shared_ptr> const& model) override; virtual void printJsonToStream(std::ostream& out, std::shared_ptr> const& model) override; - private: std::vector choiceValues; + + std::shared_ptr> preScheduler; }; } } diff --git a/src/storm/storage/PostScheduler.cpp b/src/storm/storage/PostScheduler.cpp index dd8fb3d1f..d8f5a7993 100644 --- a/src/storm/storage/PostScheduler.cpp +++ b/src/storm/storage/PostScheduler.cpp @@ -161,11 +161,11 @@ namespace storm { std::string choiceActionLabel = choiceOriginJson["action-label"]; std::string choiceCorrectionActionLabel = choiceOriginCorrectionJson["action-label"]; choiceOriginJson["action-label"] = choiceActionLabel.append(": ").append(choiceCorrectionActionLabel).append("\n"); - choiceJson["origin"] = choiceOriginJson; + choiceJson["origin"] = choiceOriginJson; } if (model && model->hasChoiceLabeling()) { auto choiceLabels = model->getChoiceLabeling().getLabelsOfChoice(globalChoiceIndex); - + choiceJson["labels"] = std::vector(choiceLabels.begin(), choiceLabels.end()); } @@ -179,10 +179,10 @@ namespace storm { } else { choicesJson = "undefined"; } - + stateChoicesJson["c"] = std::move(choicesJson); output.push_back(std::move(stateChoicesJson)); - + } out << output.dump(4); @@ -192,6 +192,7 @@ namespace storm { template class PostScheduler; #ifdef STORM_HAVE_CARL template class PostScheduler; + template class PostScheduler; #endif } } diff --git a/src/storm/storage/PreScheduler.cpp b/src/storm/storage/PreScheduler.cpp index 41508064b..f6f78c41d 100644 --- a/src/storm/storage/PreScheduler.cpp +++ b/src/storm/storage/PreScheduler.cpp @@ -138,7 +138,7 @@ namespace storm { } - + template void PreScheduler::printJsonToStream(std::ostream& out, std::shared_ptr> model, bool skipUniqueChoices) const { @@ -194,7 +194,7 @@ namespace storm { storm::json updateJson; // next model state if (model && model->hasStateValuations()) { - updateJson["s'"] = model->getStateValuations().template toJson(entryIt->getColumn()); + updateJson["s'"] = model->getStateValuations().template toJson(entryIt->getColumn()); } else { updateJson["s'"] = entryIt->getColumn(); } @@ -209,7 +209,7 @@ namespace storm { } else { choicesJson = "undefined"; } - + stateChoicesJson["c"] = std::move(choicesJson); output.push_back(std::move(stateChoicesJson)); } @@ -223,6 +223,7 @@ namespace storm { template class PreScheduler; #ifdef STORM_HAVE_CARL template class PreScheduler; + template class PreScheduler; #endif } } diff --git a/src/storm/storage/PreScheduler.h b/src/storm/storage/PreScheduler.h index e695e0732..294e7ecc1 100644 --- a/src/storm/storage/PreScheduler.h +++ b/src/storm/storage/PreScheduler.h @@ -13,6 +13,7 @@ namespace storm { /* * TODO needs obvious changes in all comment blocks */ + // TODO inherit from Scheduler? template class PreScheduler { public: