diff --git a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp index b193b9f09..1668bc440 100644 --- a/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp +++ b/src/storm/modelchecker/rpatl/SparseSmgRpatlModelChecker.cpp @@ -24,7 +24,7 @@ #include "storm/models/sparse/StandardRewardModel.h" -#include "storm/shields/shield-handling.h" +#include "storm/shields/ShieldHandling.h" #include "storm/settings/modules/GeneralSettings.h" @@ -212,7 +212,7 @@ namespace storm { std::unique_ptr result(new ExplicitQuantitativeCheckResult(std::move(values))); if(checkTask.isShieldingTask()) { - tempest::shields::createOptimalShield(std::make_shared>(this->getModel()), helper.getProducedOptimalChoices(), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), statesOfCoalition, statesOfCoalition); + tempest::shields::createQuantitativeShield(std::make_shared>(this->getModel()), helper.getProducedOptimalChoices(), checkTask.getShieldingExpression(), checkTask.getOptimizationDirection(), statesOfCoalition, statesOfCoalition); } else if (checkTask.isProduceSchedulersSet()) { result->asExplicitQuantitativeCheckResult().setScheduler(std::make_unique>(helper.extractScheduler())); } diff --git a/src/storm/shields/ShieldHandling.cpp b/src/storm/shields/ShieldHandling.cpp new file mode 100644 index 000000000..69240f6f1 --- /dev/null +++ b/src/storm/shields/ShieldHandling.cpp @@ -0,0 +1,47 @@ +#include "ShieldHandling.h" + +namespace tempest { + namespace shields { + std::string shieldFilename(std::shared_ptr const& shieldingExpression) { + return shieldingExpression->getFilename() + ".shield"; + } + + template + void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { + std::ofstream stream; + storm::utility::openFile(shieldFilename(shieldingExpression), stream); + if(shieldingExpression->isPreSafetyShield()) { + PreSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else if(shieldingExpression->isPostSafetyShield()) { + PostSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); + storm::utility::closeFile(stream); + } + storm::utility::closeFile(stream); + } + + template + void createQuantitativeShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { + std::ofstream stream; + storm::utility::openFile(shieldFilename(shieldingExpression), stream); + if(shieldingExpression->isOptimalShield()) { + OptimalShield shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); + storm::utility::closeFile(stream); + } + storm::utility::closeFile(stream); + } + // Explicitly instantiate appropriate + template void createShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); +#ifdef STORM_HAVE_CARL + template void createShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); +#endif + } +} diff --git a/src/storm/shields/ShieldHandling.h b/src/storm/shields/ShieldHandling.h new file mode 100644 index 000000000..2b21a8522 --- /dev/null +++ b/src/storm/shields/ShieldHandling.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + +#include "storm/storage/Scheduler.h" +#include "storm/storage/BitVector.h" + +#include "storm/logic/ShieldExpression.h" + +#include "storm/shields/AbstractShield.h" +#include "storm/shields/PreSafetyShield.h" +#include "storm/shields/PostSafetyShield.h" +#include "storm/shields/OptimalShield.h" + +#include "storm/io/file.h" +#include "storm/utility/macros.h" + +#include "storm/exceptions/InvalidArgumentException.h" + +namespace tempest { + namespace shields { + std::string shieldFilename(std::shared_ptr const& shieldingExpression); + + template + void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + + template + void createQuantitativeShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + } +} diff --git a/src/storm/shields/shield-handling.h b/src/storm/shields/shield-handling.h deleted file mode 100644 index ef4686334..000000000 --- a/src/storm/shields/shield-handling.h +++ /dev/null @@ -1,60 +0,0 @@ -#pragma once - -#include -#include -#include - -#include "storm/storage/Scheduler.h" -#include "storm/storage/BitVector.h" - -#include "storm/logic/ShieldExpression.h" - -#include "storm/shields/AbstractShield.h" -#include "storm/shields/PreSafetyShield.h" -#include "storm/shields/PostSafetyShield.h" -#include "storm/shields/OptimalShield.h" - -#include "storm/io/file.h" -#include "storm/utility/macros.h" - -#include "storm/exceptions/InvalidArgumentException.h" - - -namespace tempest { - namespace shields { - std::string shieldFilename(std::shared_ptr const& shieldingExpression) { - return shieldingExpression->getFilename() + ".shield"; - } - - template - void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { - std::ofstream stream; - storm::utility::openFile(shieldFilename(shieldingExpression), stream); - if(shieldingExpression->isPreSafetyShield()) { - PreSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, shieldingExpression, model); - } else if(shieldingExpression->isPostSafetyShield()) { - PostSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, shieldingExpression, model); - } else { - STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); - storm::utility::closeFile(stream); - } - storm::utility::closeFile(stream); - } - - template - void createOptimalShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { - std::ofstream stream; - storm::utility::openFile(shieldFilename(shieldingExpression), stream); - if(shieldingExpression->isOptimalShield()) { - OptimalShield shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, shieldingExpression, model); - } else { - STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); - storm::utility::closeFile(stream); - } - storm::utility::closeFile(stream); - } - } -}