From 454bffe03f76ba2bc5d054f0e589b5747ca7c988 Mon Sep 17 00:00:00 2001 From: Stefan Pranger Date: Wed, 3 Nov 2021 10:39:52 +0100 Subject: [PATCH] major changes to shield handling - Introduced OptimalPre and OptimalPost shields - Renamed *Safety to PreShield and PostShield - Introduced min case for shields - fixed coalition states in shield handling --- .../parser/FormulaParserGrammar.cpp | 20 ++++--- .../parser/FormulaParserGrammar.h | 2 +- src/storm/logic/ShieldExpression.cpp | 26 +++++---- src/storm/logic/ShieldExpression.h | 10 ++-- src/storm/shields/AbstractShield.h | 8 ++- src/storm/shields/OptimalShield.cpp | 54 ++++++++++++++++--- src/storm/shields/OptimalShield.h | 10 ++-- .../{PostSafetyShield.cpp => PostShield.cpp} | 27 +++++----- .../{PostSafetyShield.h => PostShield.h} | 4 +- .../{PreSafetyShield.cpp => PreShield.cpp} | 25 +++++---- .../{PreSafetyShield.h => PreShield.h} | 4 +- src/storm/shields/ShieldHandling.cpp | 19 ++++--- src/storm/shields/ShieldHandling.h | 6 +-- .../storm/parser/GameShieldingParserTest.cpp | 12 ++--- .../storm/parser/MdpShieldingParserTest.cpp | 12 ++--- 15 files changed, 153 insertions(+), 86 deletions(-) rename src/storm/shields/{PostSafetyShield.cpp => PostShield.cpp} (62%) rename src/storm/shields/{PostSafetyShield.h => PostShield.h} (53%) rename src/storm/shields/{PreSafetyShield.cpp => PreShield.cpp} (62%) rename src/storm/shields/{PreSafetyShield.h => PreShield.h} (53%) diff --git a/src/storm-parsers/parser/FormulaParserGrammar.cpp b/src/storm-parsers/parser/FormulaParserGrammar.cpp index ea96f31a2..2aae130bd 100644 --- a/src/storm-parsers/parser/FormulaParserGrammar.cpp +++ b/src/storm-parsers/parser/FormulaParserGrammar.cpp @@ -184,16 +184,21 @@ namespace storm { shieldExpression.name("shield expression"); - shieldingType = (qi::lit("PreSafety")[qi::_val = storm::logic::ShieldingType::PreSafety] | - qi::lit("PostSafety")[qi::_val = storm::logic::ShieldingType::PostSafety] | - qi::lit("Optimal")[qi::_val = storm::logic::ShieldingType::Optimal]) > -qi::lit("Shield"); + shieldingType = (qi::lit("PreSafety")[qi::_val = storm::logic::ShieldingType::PreSafety] | + qi::lit("PostSafety")[qi::_val = storm::logic::ShieldingType::PostSafety] | + qi::lit("OptimalPre")[qi::_val = storm::logic::ShieldingType::OptimalPre] | + qi::lit("OptimalPost")[qi::_val = storm::logic::ShieldingType::OptimalPost] | + qi::lit("Optimal")[qi::_val = storm::logic::ShieldingType::OptimalPost]) // backwards compatability, will be disabled in the future + > -qi::lit("Shield"); shieldingType.name("shielding type"); - probability = qi::double_[qi::_pass = (qi::_1 >= 0) & (qi::_1 <= 1.0), qi::_val = qi::_1 ]; - probability.name("double between 0 and 1"); + //probability = qi::double_[qi::_pass = (qi::_1 >= 0) & (qi::_1 <= 1.0), qi::_val = qi::_1 ]; + //probability.name("double between 0 and 1"); + comparisonValue = qi::double_[qi::_val = qi::_1 ]; + comparisonValue.name("double comparison value"); shieldComparison = ((qi::lit("lambda")[qi::_a = storm::logic::ShieldComparison::Relative] | - qi::lit("gamma")[qi::_a = storm::logic::ShieldComparison::Absolute]) > qi::lit("=") > probability)[qi::_val = phoenix::bind(&FormulaParserGrammar::createShieldComparisonStruct, phoenix::ref(*this), qi::_a, qi::_1)]; + qi::lit("gamma")[qi::_a = storm::logic::ShieldComparison::Absolute]) > qi::lit("=") > comparisonValue)[qi::_val = phoenix::bind(&FormulaParserGrammar::createShieldComparisonStruct, phoenix::ref(*this), qi::_a, qi::_1)]; shieldComparison.name("shield comparison type"); #pragma clang diagnostic push @@ -649,10 +654,9 @@ namespace storm { std::shared_ptr FormulaParserGrammar::createShieldExpression(storm::logic::ShieldingType type, std::string name, boost::optional> comparisonStruct) { if(comparisonStruct.is_initialized()) { - STORM_LOG_WARN_COND(type != storm::logic::ShieldingType::Optimal , "Comparison for optimal shield will be ignored."); return std::shared_ptr(new storm::logic::ShieldExpression(type, name, comparisonStruct.get().first, comparisonStruct.get().second)); } else { - STORM_LOG_THROW(type == storm::logic::ShieldingType::Optimal , storm::exceptions::WrongFormatException, "Construction of safety shield needs a comparison parameter (lambda or gamma)"); + STORM_LOG_INFO("Construction of shield without a comparison parameter (lambda or gamma) will default to 'lambda=0'"); return std::shared_ptr(new storm::logic::ShieldExpression(type, name)); } } diff --git a/src/storm-parsers/parser/FormulaParserGrammar.h b/src/storm-parsers/parser/FormulaParserGrammar.h index 9f0953577..972f19028 100644 --- a/src/storm-parsers/parser/FormulaParserGrammar.h +++ b/src/storm-parsers/parser/FormulaParserGrammar.h @@ -237,7 +237,7 @@ namespace storm { // Shielding properties qi::rule(), Skipper> shieldExpression; qi::rule shieldingType; - qi::rule probability; + qi::rule comparisonValue; qi::rule, qi::locals, Skipper> shieldComparison; // Start symbol diff --git a/src/storm/logic/ShieldExpression.cpp b/src/storm/logic/ShieldExpression.cpp index a9cbeb088..74a86e79f 100644 --- a/src/storm/logic/ShieldExpression.cpp +++ b/src/storm/logic/ShieldExpression.cpp @@ -26,8 +26,12 @@ namespace storm { return type == storm::logic::ShieldingType::PostSafety; } - bool ShieldExpression::isOptimalShield() const { - return type == storm::logic::ShieldingType::Optimal; + bool ShieldExpression::isOptimalPreShield() const { + return type == storm::logic::ShieldingType::OptimalPre; + } + + bool ShieldExpression::isOptimalPostShield() const { + return type == storm::logic::ShieldingType::OptimalPost; } double ShieldExpression::getValue() const { @@ -36,9 +40,10 @@ namespace storm { std::string ShieldExpression::typeToString() const { switch(type) { - case storm::logic::ShieldingType::PostSafety: return "PostSafety"; - case storm::logic::ShieldingType::PreSafety: return "PreSafety"; - case storm::logic::ShieldingType::Optimal: return "Optimal"; + case storm::logic::ShieldingType::PostSafety: return "Post"; + case storm::logic::ShieldingType::PreSafety: return "Pre"; + case storm::logic::ShieldingType::OptimalPre: return "OptimalPre"; + case storm::logic::ShieldingType::OptimalPost: return "OptimalPost"; } } @@ -57,14 +62,13 @@ namespace storm { std::string prettyString = ""; std::string comparisonType = isRelative() ? "relative" : "absolute"; switch(type) { - case storm::logic::ShieldingType::PostSafety: prettyString += "Post-Safety"; break; - case storm::logic::ShieldingType::PreSafety: prettyString += "Pre-Safety"; break; - case storm::logic::ShieldingType::Optimal: prettyString += "Optimal"; break; + case storm::logic::ShieldingType::PostSafety: prettyString += "Post-Safety"; break; + case storm::logic::ShieldingType::PreSafety: prettyString += "Pre-Safety"; break; + case storm::logic::ShieldingType::OptimalPre: prettyString += "Optimal-Pre"; break; + case storm::logic::ShieldingType::OptimalPost: prettyString += "Optimal-Post"; break; } prettyString += "-Shield "; - if(!(type == storm::logic::ShieldingType::Optimal)) { - prettyString += "with " + comparisonType + " comparison (" + comparisonToString() + " = " + std::to_string(value) + "):"; - } + prettyString += "with " + comparisonType + " comparison (" + comparisonToString() + " = " + std::to_string(value) + "):"; return prettyString; } diff --git a/src/storm/logic/ShieldExpression.h b/src/storm/logic/ShieldExpression.h index b214f3baa..f13575591 100644 --- a/src/storm/logic/ShieldExpression.h +++ b/src/storm/logic/ShieldExpression.h @@ -9,7 +9,8 @@ namespace storm { enum class ShieldingType { PostSafety, PreSafety, - Optimal + OptimalPre, + OptimalPost }; enum class ShieldComparison { Absolute, Relative }; @@ -23,7 +24,8 @@ namespace storm { bool isRelative() const; bool isPreSafetyShield() const; bool isPostSafetyShield() const; - bool isOptimalShield() const; + bool isOptimalPreShield() const; + bool isOptimalPostShield() const; double getValue() const; @@ -36,8 +38,8 @@ namespace storm { private: ShieldingType type; - ShieldComparison comparison; - double value; + ShieldComparison comparison = ShieldComparison::Relative; + double value = 0; std::string filename; }; diff --git a/src/storm/shields/AbstractShield.h b/src/storm/shields/AbstractShield.h index d93da4407..b479efe71 100644 --- a/src/storm/shields/AbstractShield.h +++ b/src/storm/shields/AbstractShield.h @@ -21,9 +21,13 @@ namespace tempest { namespace utility { template struct ChoiceFilter { - bool operator()(ValueType v, ValueType max, double shieldValue) { + bool operator()(ValueType v, ValueType opt, double shieldValue) { Compare compare; - if(relative) return compare(v, max * shieldValue); + if(relative && std::is_same>::value) { + return compare(v, opt + opt * shieldValue); + } else if(relative && std::is_same>::value) { + return compare(v, opt * shieldValue); + } else return compare(v, shieldValue); } }; diff --git a/src/storm/shields/OptimalShield.cpp b/src/storm/shields/OptimalShield.cpp index 0bd26908d..8be4b1fef 100644 --- a/src/storm/shields/OptimalShield.cpp +++ b/src/storm/shields/OptimalShield.cpp @@ -6,27 +6,65 @@ namespace tempest { namespace shields { template - OptimalShield::OptimalShield(std::vector const& rowGroupIndices, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), precomputedChoices(precomputedChoices) { + OptimalShield::OptimalShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) { // Intentionally left empty. } template - storm::storage::OptimalScheduler OptimalShield::construct() { - storm::storage::OptimalScheduler shield(this->rowGroupIndices.size() - 1); - // TODO Needs fixing as soon as we support MDPs + storm::storage::PostScheduler OptimalShield::construct() { + if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { + if(this->shieldingExpression->isRelative()) { + return constructWithCompareType, true>(); + } else { + return constructWithCompareType, false>(); + } + } else { + if(this->shieldingExpression->isRelative()) { + return constructWithCompareType, true>(); + } else { + return constructWithCompareType, false>(); + } + } + } + + template + template + storm::storage::PostScheduler OptimalShield::constructWithCompareType() { + tempest::shields::utility::ChoiceFilter choiceFilter; + storm::storage::PostScheduler shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes()); + auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { - this->relevantStates = ~this->relevantStates; + this->relevantStates &= this->coalitionStates.get(); } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { + uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; if(this->relevantStates.get(state)) { - shield.setChoice(precomputedChoices[state], state); + auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; + ValueType maxProbability = *(choice_it + maxProbabilityIndex); + if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { + STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); + shield.setChoice(storm::storage::PostSchedulerChoice(), state, 0); + choice_it += rowGroupSize; + continue; + } + storm::storage::PostSchedulerChoice choiceMapping; + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) { + choiceMapping.addChoice(choice, choice); + } else { + choiceMapping.addChoice(choice, maxProbabilityIndex); + } + } + shield.setChoice(choiceMapping, state, 0); } else { - shield.setChoice(storm::storage::Distribution(), state); + shield.setChoice(storm::storage::PostSchedulerChoice(), state, 0); + choice_it += rowGroupSize; } } return shield; } - // Explicitly instantiate appropriate + + // Explicitly instantiate appropriate classes template class OptimalShield::index_type>; #ifdef STORM_HAVE_CARL template class OptimalShield::index_type>; diff --git a/src/storm/shields/OptimalShield.h b/src/storm/shields/OptimalShield.h index 03bff3542..b7c55712e 100644 --- a/src/storm/shields/OptimalShield.h +++ b/src/storm/shields/OptimalShield.h @@ -1,7 +1,7 @@ #pragma once #include "storm/shields/AbstractShield.h" -#include "storm/storage/OptimalScheduler.h" +#include "storm/storage/PostScheduler.h" namespace tempest { namespace shields { @@ -9,11 +9,13 @@ namespace tempest { template class OptimalShield : public AbstractShield { public: - OptimalShield(std::vector const& rowGroupIndices, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + OptimalShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); - storm::storage::OptimalScheduler construct(); + storm::storage::PostScheduler construct(); + template + storm::storage::PostScheduler constructWithCompareType(); private: - std::vector precomputedChoices; + std::vector choiceValues; }; } } diff --git a/src/storm/shields/PostSafetyShield.cpp b/src/storm/shields/PostShield.cpp similarity index 62% rename from src/storm/shields/PostSafetyShield.cpp rename to src/storm/shields/PostShield.cpp index 311da2832..5ef487ad0 100644 --- a/src/storm/shields/PostSafetyShield.cpp +++ b/src/storm/shields/PostShield.cpp @@ -1,4 +1,4 @@ -#include "storm/shields/PostSafetyShield.h" +#include "storm/shields/PostShield.h" #include @@ -6,12 +6,12 @@ namespace tempest { namespace shields { template - PostSafetyShield::PostSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) { + PostShield::PostShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) { // Intentionally left empty. } template - storm::storage::PostScheduler PostSafetyShield::construct() { + storm::storage::PostScheduler PostShield::construct() { if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { if(this->shieldingExpression->isRelative()) { return constructWithCompareType, true>(); @@ -29,19 +29,22 @@ namespace tempest { template template - storm::storage::PostScheduler PostSafetyShield::constructWithCompareType() { + storm::storage::PostScheduler PostShield::constructWithCompareType() { tempest::shields::utility::ChoiceFilter choiceFilter; storm::storage::PostScheduler shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes()); auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { - this->relevantStates &= this->coalitionStates.get(); + this->relevantStates &= ~this->coalitionStates.get(); } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; if(this->relevantStates.get(state)) { - auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; - ValueType maxProbability = *(choice_it + maxProbabilityIndex); - if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { + auto optProbabilityIndex = std::min_element(choice_it, choice_it + rowGroupSize) - choice_it; + if(std::is_same>::value) { + optProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; + } + ValueType optProbability = *(choice_it + optProbabilityIndex); + if(!relative && !choiceFilter(optProbability, optProbability, this->shieldingExpression->getValue())) { STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(storm::storage::PostSchedulerChoice(), state, 0); choice_it += rowGroupSize; @@ -49,10 +52,10 @@ namespace tempest { } storm::storage::PostSchedulerChoice choiceMapping; for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { - if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) { + if(choiceFilter(*choice_it, optProbability, this->shieldingExpression->getValue())) { choiceMapping.addChoice(choice, choice); } else { - choiceMapping.addChoice(choice, maxProbabilityIndex); + choiceMapping.addChoice(choice, optProbabilityIndex); } } shield.setChoice(choiceMapping, state, 0); @@ -65,9 +68,9 @@ namespace tempest { } // Explicitly instantiate appropriate classes - template class PostSafetyShield::index_type>; + template class PostShield::index_type>; #ifdef STORM_HAVE_CARL - template class PostSafetyShield::index_type>; + template class PostShield::index_type>; #endif } } diff --git a/src/storm/shields/PostSafetyShield.h b/src/storm/shields/PostShield.h similarity index 53% rename from src/storm/shields/PostSafetyShield.h rename to src/storm/shields/PostShield.h index 6f04e53c3..f2a43905f 100644 --- a/src/storm/shields/PostSafetyShield.h +++ b/src/storm/shields/PostShield.h @@ -7,9 +7,9 @@ namespace tempest { namespace shields { template - class PostSafetyShield : public AbstractShield { + class PostShield : public AbstractShield { public: - PostSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + PostShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); storm::storage::PostScheduler construct(); template diff --git a/src/storm/shields/PreSafetyShield.cpp b/src/storm/shields/PreShield.cpp similarity index 62% rename from src/storm/shields/PreSafetyShield.cpp rename to src/storm/shields/PreShield.cpp index f4374b929..a16b777c7 100644 --- a/src/storm/shields/PreSafetyShield.cpp +++ b/src/storm/shields/PreShield.cpp @@ -1,4 +1,4 @@ -#include "storm/shields/PreSafetyShield.h" +#include "storm/shields/PreShield.h" #include @@ -6,12 +6,12 @@ namespace tempest { namespace shields { template - PreSafetyShield::PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) { + PreShield::PreShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) { // Intentionally left empty. } template - storm::storage::PreScheduler PreSafetyShield::construct() { + storm::storage::PreScheduler PreShield::construct() { if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { if(this->shieldingExpression->isRelative()) { return constructWithCompareType, true>(); @@ -29,26 +29,31 @@ namespace tempest { template template - storm::storage::PreScheduler PreSafetyShield::constructWithCompareType() { + storm::storage::PreScheduler PreShield::constructWithCompareType() { tempest::shields::utility::ChoiceFilter choiceFilter; storm::storage::PreScheduler shield(this->rowGroupIndices.size() - 1); auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { - this->relevantStates &= this->coalitionStates.get(); + this->relevantStates &= ~this->coalitionStates.get(); } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; if(this->relevantStates.get(state)) { storm::storage::PreSchedulerChoice enabledChoices; - ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize); - if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { + ValueType optProbability; + if(std::is_same>::value) { + optProbability = *std::max_element(choice_it, choice_it + rowGroupSize); + } else { + optProbability = *std::min_element(choice_it, choice_it + rowGroupSize); + } + if(!relative && !choiceFilter(optProbability, optProbability, this->shieldingExpression->getValue())) { STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(storm::storage::PreSchedulerChoice(), state, 0); choice_it += rowGroupSize; continue; } for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { - if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) { + if(choiceFilter(*choice_it, optProbability, this->shieldingExpression->getValue())) { enabledChoices.addChoice(choice, *choice_it); } } @@ -63,9 +68,9 @@ namespace tempest { return shield; } // Explicitly instantiate appropriate classes - template class PreSafetyShield::index_type>; + template class PreShield::index_type>; #ifdef STORM_HAVE_CARL - template class PreSafetyShield::index_type>; + template class PreShield::index_type>; #endif } } diff --git a/src/storm/shields/PreSafetyShield.h b/src/storm/shields/PreShield.h similarity index 53% rename from src/storm/shields/PreSafetyShield.h rename to src/storm/shields/PreShield.h index 8a4667bf2..6e98dd7e8 100644 --- a/src/storm/shields/PreSafetyShield.h +++ b/src/storm/shields/PreShield.h @@ -7,9 +7,9 @@ namespace tempest { namespace shields { template - class PreSafetyShield : public AbstractShield { + class PreShield : public AbstractShield { public: - PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + PreShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); storm::storage::PreScheduler construct(); template diff --git a/src/storm/shields/ShieldHandling.cpp b/src/storm/shields/ShieldHandling.cpp index 69240f6f1..007959d5d 100644 --- a/src/storm/shields/ShieldHandling.cpp +++ b/src/storm/shields/ShieldHandling.cpp @@ -10,11 +10,12 @@ namespace tempest { void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { std::ofstream stream; storm::utility::openFile(shieldFilename(shieldingExpression), stream); + if(coalitionStates.is_initialized()) coalitionStates.get().complement(); if(shieldingExpression->isPreSafetyShield()) { - PreSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + PreShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); shield.construct().printToStream(stream, shieldingExpression, model); } else if(shieldingExpression->isPostSafetyShield()) { - PostSafetyShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + PostShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); shield.construct().printToStream(stream, shieldingExpression, model); } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); @@ -24,11 +25,15 @@ namespace tempest { } template - void createQuantitativeShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { + void createQuantitativeShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { std::ofstream stream; storm::utility::openFile(shieldFilename(shieldingExpression), stream); - if(shieldingExpression->isOptimalShield()) { - OptimalShield shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + if(coalitionStates.is_initialized()) coalitionStates.get().complement(); // TODO CHECK THIS!!! + if(shieldingExpression->isOptimalPreShield()) { + PreShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); + shield.construct().printToStream(stream, shieldingExpression, model); + } else if(shieldingExpression->isOptimalPostShield()) { + PostShield shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates); shield.construct().printToStream(stream, shieldingExpression, model); } else { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString()); @@ -38,10 +43,10 @@ namespace tempest { } // Explicitly instantiate appropriate template void createShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); - template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); #ifdef STORM_HAVE_CARL template void createShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); - template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + template void createQuantitativeShield::index_type>(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); #endif } } diff --git a/src/storm/shields/ShieldHandling.h b/src/storm/shields/ShieldHandling.h index 2b21a8522..768888a7d 100644 --- a/src/storm/shields/ShieldHandling.h +++ b/src/storm/shields/ShieldHandling.h @@ -10,8 +10,8 @@ #include "storm/logic/ShieldExpression.h" #include "storm/shields/AbstractShield.h" -#include "storm/shields/PreSafetyShield.h" -#include "storm/shields/PostSafetyShield.h" +#include "storm/shields/PreShield.h" +#include "storm/shields/PostShield.h" #include "storm/shields/OptimalShield.h" #include "storm/io/file.h" @@ -27,6 +27,6 @@ namespace tempest { void createShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); template - void createQuantitativeShield(std::shared_ptr> model, std::vector const& precomputedChoices, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + void createQuantitativeShield(std::shared_ptr> model, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); } } diff --git a/src/test/storm/parser/GameShieldingParserTest.cpp b/src/test/storm/parser/GameShieldingParserTest.cpp index 53ae456a7..6279efc42 100644 --- a/src/test/storm/parser/GameShieldingParserTest.cpp +++ b/src/test/storm/parser/GameShieldingParserTest.cpp @@ -20,15 +20,15 @@ TEST(GameShieldingParserTest, PreSafetyShieldTest) { std::shared_ptr shieldExpression(nullptr); ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression()); - EXPECT_TRUE(shieldExpression->isPreSafetyShield()); - EXPECT_FALSE(shieldExpression->isPostSafetyShield()); + EXPECT_TRUE(shieldExpression->isPreShield()); + EXPECT_FALSE(shieldExpression->isPostShield()); EXPECT_FALSE(shieldExpression->isOptimalShield()); EXPECT_TRUE(shieldExpression->isRelative()); EXPECT_EQ(std::stod(value), shieldExpression->getValue()); EXPECT_EQ(filename, shieldExpression->getFilename()); } -TEST(GameShieldingParserTest, PostSafetyShieldTest) { +TEST(GameShieldingParserTest, PostShieldTest) { storm::parser::FormulaParser formulaParser; std::string filename = "postSafetyShieldFileName"; @@ -46,7 +46,7 @@ TEST(GameShieldingParserTest, PostSafetyShieldTest) { std::shared_ptr shieldExpression(nullptr); ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression()); EXPECT_FALSE(shieldExpression->isPreSafetyShield()); - EXPECT_TRUE(shieldExpression->isPostSafetyShield()); + EXPECT_TRUE(shieldExpression->isPostShield()); EXPECT_FALSE(shieldExpression->isOptimalShield()); EXPECT_FALSE(shieldExpression->isRelative()); EXPECT_EQ(std::stod(value), shieldExpression->getValue()); @@ -74,8 +74,8 @@ TEST(GameShieldingParserTest, OptimalShieldTest) { std::shared_ptr shieldExpression(nullptr); ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression()); EXPECT_FALSE(shieldExpression->isPreSafetyShield()); - EXPECT_FALSE(shieldExpression->isPostSafetyShield()); + EXPECT_FALSE(shieldExpression->isPostShield()); EXPECT_TRUE(shieldExpression->isOptimalShield()); EXPECT_FALSE(shieldExpression->isRelative()); EXPECT_EQ(filename, shieldExpression->getFilename()); -} \ No newline at end of file +} diff --git a/src/test/storm/parser/MdpShieldingParserTest.cpp b/src/test/storm/parser/MdpShieldingParserTest.cpp index 2baf68f30..85ffe9883 100644 --- a/src/test/storm/parser/MdpShieldingParserTest.cpp +++ b/src/test/storm/parser/MdpShieldingParserTest.cpp @@ -18,19 +18,19 @@ TEST(MdpShieldingParserTest, PreSafetyShieldTest) { std::shared_ptr shieldExpression(nullptr); ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression()); EXPECT_TRUE(shieldExpression->isPreSafetyShield()); - EXPECT_FALSE(shieldExpression->isPostSafetyShield()); + EXPECT_FALSE(shieldExpression->isPostShield()); EXPECT_FALSE(shieldExpression->isOptimalShield()); EXPECT_FALSE(shieldExpression->isRelative()); EXPECT_EQ(std::stod(value), shieldExpression->getValue()); EXPECT_EQ(filename, shieldExpression->getFilename()); } -TEST(MdpShieldingParserTest, PostSafetyShieldTest) { +TEST(MdpShieldingParserTest, PostShieldTest) { storm::parser::FormulaParser formulaParser; std::string filename = "postSafetyShieldFileName"; std::string value = "0.95"; - std::string input = "<" + filename + ", PostSafety, lambda=" + value + "> Pmin=? [X !\"label\"]"; + std::string input = "<" + filename + ", Post, lambda=" + value + "> Pmin=? [X !\"label\"]"; std::shared_ptr formula(nullptr); std::vector property; @@ -40,7 +40,7 @@ TEST(MdpShieldingParserTest, PostSafetyShieldTest) { std::shared_ptr shieldExpression(nullptr); ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression()); EXPECT_FALSE(shieldExpression->isPreSafetyShield()); - EXPECT_TRUE(shieldExpression->isPostSafetyShield()); + EXPECT_TRUE(shieldExpression->isPostShield()); EXPECT_FALSE(shieldExpression->isOptimalShield()); EXPECT_TRUE(shieldExpression->isRelative()); EXPECT_EQ(std::stod(value), shieldExpression->getValue()); @@ -65,8 +65,8 @@ TEST(MdpShieldingParserTest, OptimalShieldTest) { std::shared_ptr shieldExpression(nullptr); ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression()); EXPECT_FALSE(shieldExpression->isPreSafetyShield()); - EXPECT_FALSE(shieldExpression->isPostSafetyShield()); + EXPECT_FALSE(shieldExpression->isPostShield()); EXPECT_TRUE(shieldExpression->isOptimalShield()); EXPECT_FALSE(shieldExpression->isRelative()); EXPECT_EQ(filename, shieldExpression->getFilename()); -} \ No newline at end of file +}