From 1c3e669efab5bd0ee3d67c1ec92587b3a68769c1 Mon Sep 17 00:00:00 2001 From: Stefan Pranger Date: Tue, 16 Mar 2021 13:29:48 +0100 Subject: [PATCH] introduced choiceFilter for PostSafetyShields --- src/storm/shields/PostSafetyShield.cpp | 23 +++++++++++++++++++++-- src/storm/shields/PostSafetyShield.h | 3 +++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/storm/shields/PostSafetyShield.cpp b/src/storm/shields/PostSafetyShield.cpp index cf76550d5..c04499b28 100644 --- a/src/storm/shields/PostSafetyShield.cpp +++ b/src/storm/shields/PostSafetyShield.cpp @@ -12,6 +12,25 @@ namespace tempest { template storm::storage::PostScheduler PostSafetyShield::construct() { + if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { + if(this->shieldingExpression->isRelative()) { + return constructWithCompareType, true>>(); + } else { + return constructWithCompareType, false>>(); + } + } else { + if(this->shieldingExpression->isRelative()) { + return constructWithCompareType, true>>(); + } else { + return constructWithCompareType, false>>(); + } + } + } + + template + template + storm::storage::PostScheduler PostSafetyShield::constructWithCompareType() { + ChoiceFilter choiceFilter; storm::storage::PostScheduler shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes()); auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { @@ -22,14 +41,14 @@ namespace tempest { uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; ValueType maxProbability = *(choice_it + maxProbabilityIndex); - if(!this->allowedValue(maxProbability, maxProbability, this->shieldingExpression)) { + if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(0, storm::storage::Distribution(), state); continue; } for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { storm::storage::Distribution actionDistribution; - if(this->allowedValue(maxProbability, *choice_it, this->shieldingExpression)) { + if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) { actionDistribution.addProbability(choice, 1); } else { actionDistribution.addProbability(maxProbabilityIndex, 1); diff --git a/src/storm/shields/PostSafetyShield.h b/src/storm/shields/PostSafetyShield.h index ea6eab497..58974b9c4 100644 --- a/src/storm/shields/PostSafetyShield.h +++ b/src/storm/shields/PostSafetyShield.h @@ -10,7 +10,10 @@ namespace tempest { class PostSafetyShield : public AbstractShield { public: PostSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + storm::storage::PostScheduler construct(); + template + storm::storage::PostScheduler constructWithCompareType(); }; } }