diff --git a/src/storm/shields/shield-handling.h b/src/storm/shields/shield-handling.h index 7e74ebcf2..ba59e8054 100644 --- a/src/storm/shields/shield-handling.h +++ b/src/storm/shields/shield-handling.h @@ -12,6 +12,7 @@ #include "storm/shields/AbstractShield.h" #include "storm/shields/PreSafetyShield.h" #include "storm/shields/PostSafetyShield.h" +#include "storm/shields/OptimalShield.h" #include "storm/io/file.h" #include "storm/utility/macros.h" @@ -22,11 +23,7 @@ namespace tempest { namespace shields { std::string shieldFilename(std::shared_ptr const& shieldingExpression) { - std::stringstream filename; - filename << shieldingExpression->typeToString() << "_"; - filename << shieldingExpression->comparisonToString(); - filename << shieldingExpression->getValue() << ".shield"; - return filename.str(); + return shieldingExpression->getFilename() + ".shield"; } template @@ -35,13 +32,28 @@ namespace tempest { storm::utility::openFile(shieldFilename(shieldingExpression), stream); if(shieldingExpression->isPreSafetyShield()) { PreSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, model); + shield.construct().printToStream(stream, shieldingExpression, model); } else if(shieldingExpression->isPostSafetyShield()) { PostSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); - shield.construct().printToStream(stream, model); - } else if(shieldingExpression->isOptimalShield()) { + shield.construct().printToStream(stream, shieldingExpression, model); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot create " + shieldingExpression->typeToString() + " shields yet"); + storm::utility::closeFile(stream); + } + storm::utility::closeFile(stream); + } + + template + void createOptimalShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { + std::ofstream stream; + storm::utility::openFile(shieldFilename(shieldingExpression), stream); + if(shieldingExpression->isOptimalShield()) { + STORM_LOG_DEBUG("createOptimalShield"); + OptimalShield shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else { + STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot create " + shieldingExpression->typeToString() + " shields yet"); storm::utility::closeFile(stream); - STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot create optimal shields yet"); } storm::utility::closeFile(stream); }