From 48011735054ec028d4d9213057cf851cf66487c2 Mon Sep 17 00:00:00 2001 From: Stefan Pranger Date: Mon, 15 Mar 2021 23:47:54 +0100 Subject: [PATCH] refactored preSafetyShield to use ChoiceFilter --- src/storm/shields/PreSafetyShield.cpp | 26 ++++++++++++++++++++++++-- src/storm/shields/PreSafetyShield.h | 5 ++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/storm/shields/PreSafetyShield.cpp b/src/storm/shields/PreSafetyShield.cpp index 9d0e24b43..071f6c53a 100644 --- a/src/storm/shields/PreSafetyShield.cpp +++ b/src/storm/shields/PreSafetyShield.cpp @@ -12,6 +12,27 @@ namespace tempest { template storm::storage::Scheduler PreSafetyShield::construct() { + STORM_LOG_DEBUG("PreSafetyShield::construct called"); + if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { + if(this->shieldingExpression->isRelative()) { + return constructWithCompareType, true>>(); + } else { + return constructWithCompareType, false>>(); + } + } else { + if(this->shieldingExpression->isRelative()) { + return constructWithCompareType, true>>(); + } else { + return constructWithCompareType, false>>(); + } + } + } + + template + template + storm::storage::Scheduler PreSafetyShield::constructWithCompareType() { + STORM_LOG_DEBUG("PreSafetyShield::constructWithCompareType called"); + ChoiceFilter choiceFilter; storm::storage::Scheduler shield(this->rowGroupIndices.size() - 1); auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { @@ -22,13 +43,13 @@ namespace tempest { uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; storm::storage::Distribution actionDistribution; ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize); - if(!this->allowedValue(maxProbability, maxProbability, this->shieldingExpression)) { + if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(storm::storage::Distribution(), state); continue; } for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { - if(this->allowedValue(maxProbability, *choice_it, this->shieldingExpression)) { + if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) { actionDistribution.addProbability(choice, *choice_it); } } @@ -39,6 +60,7 @@ namespace tempest { shield.setChoice(storm::storage::Distribution(), state); } } + STORM_LOG_DEBUG("PreSafetyShield::constructWithCompareType done"); return shield; } // Explicitly instantiate appropriate diff --git a/src/storm/shields/PreSafetyShield.h b/src/storm/shields/PreSafetyShield.h index c962bb0d1..d7f3cecea 100644 --- a/src/storm/shields/PreSafetyShield.h +++ b/src/storm/shields/PreSafetyShield.h @@ -8,8 +8,11 @@ namespace tempest { template class PreSafetyShield : public AbstractShield { public: - PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + storm::storage::Scheduler construct(); + template + storm::storage::Scheduler constructWithCompareType(); }; } }