diff --git a/src/storm/shields/PostSafetyShield.cpp b/src/storm/shields/PostSafetyShield.cpp index c04499b28..c1e094c34 100644 --- a/src/storm/shields/PostSafetyShield.cpp +++ b/src/storm/shields/PostSafetyShield.cpp @@ -14,36 +14,37 @@ namespace tempest { storm::storage::PostScheduler PostSafetyShield::construct() { if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { if(this->shieldingExpression->isRelative()) { - return constructWithCompareType, true>>(); + return constructWithCompareType, true>(); } else { - return constructWithCompareType, false>>(); + return constructWithCompareType, false>(); } } else { if(this->shieldingExpression->isRelative()) { - return constructWithCompareType, true>>(); + return constructWithCompareType, true>(); } else { - return constructWithCompareType, false>>(); + return constructWithCompareType, false>(); } } } template - template + template storm::storage::PostScheduler PostSafetyShield::constructWithCompareType() { - ChoiceFilter choiceFilter; + tempest::shields::utility::ChoiceFilter choiceFilter; storm::storage::PostScheduler shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes()); auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { this->relevantStates &= this->coalitionStates.get(); } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { + uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; if(this->relevantStates.get(state)) { - uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; ValueType maxProbability = *(choice_it + maxProbabilityIndex); - if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { + if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(0, storm::storage::Distribution(), state); + choice_it += rowGroupSize; continue; } for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { @@ -53,11 +54,11 @@ namespace tempest { } else { actionDistribution.addProbability(maxProbabilityIndex, 1); } - actionDistribution.normalize(); shield.setChoice(choice, storm::storage::SchedulerChoice(actionDistribution), state); } } else { shield.setChoice(0, storm::storage::Distribution(), state); + choice_it += rowGroupSize; } } return shield; diff --git a/src/storm/shields/PostSafetyShield.h b/src/storm/shields/PostSafetyShield.h index 58974b9c4..9d633274d 100644 --- a/src/storm/shields/PostSafetyShield.h +++ b/src/storm/shields/PostSafetyShield.h @@ -12,7 +12,7 @@ namespace tempest { PostSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); storm::storage::PostScheduler construct(); - template + template storm::storage::PostScheduler constructWithCompareType(); }; } diff --git a/src/storm/shields/PreSafetyShield.cpp b/src/storm/shields/PreSafetyShield.cpp index 5a7f4a9c3..c33e3e8fe 100644 --- a/src/storm/shields/PreSafetyShield.cpp +++ b/src/storm/shields/PreSafetyShield.cpp @@ -14,36 +14,37 @@ namespace tempest { storm::storage::Scheduler PreSafetyShield::construct() { if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { if(this->shieldingExpression->isRelative()) { - return constructWithCompareType, true>>(); + return constructWithCompareType, true>(); } else { - return constructWithCompareType, false>>(); + return constructWithCompareType, false>(); } } else { if(this->shieldingExpression->isRelative()) { - return constructWithCompareType, true>>(); + return constructWithCompareType, true>(); } else { - return constructWithCompareType, false>>(); + return constructWithCompareType, false>(); } } } template - template + template storm::storage::Scheduler PreSafetyShield::constructWithCompareType() { - ChoiceFilter choiceFilter; + tempest::shields::utility::ChoiceFilter choiceFilter; storm::storage::Scheduler shield(this->rowGroupIndices.size() - 1); auto choice_it = this->choiceValues.begin(); if(this->coalitionStates.is_initialized()) { this->relevantStates &= this->coalitionStates.get(); } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { + uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; if(this->relevantStates.get(state)) { - uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; storm::storage::Distribution actionDistribution; ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize); - if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { + if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); shield.setChoice(storm::storage::Distribution(), state); + choice_it += rowGroupSize; continue; } for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { @@ -51,11 +52,11 @@ namespace tempest { actionDistribution.addProbability(choice, *choice_it); } } - actionDistribution.normalize(); shield.setChoice(storm::storage::SchedulerChoice(actionDistribution), state); } else { shield.setChoice(storm::storage::Distribution(), state); + choice_it += rowGroupSize; } } return shield; diff --git a/src/storm/shields/PreSafetyShield.h b/src/storm/shields/PreSafetyShield.h index d7f3cecea..56033a98f 100644 --- a/src/storm/shields/PreSafetyShield.h +++ b/src/storm/shields/PreSafetyShield.h @@ -11,7 +11,7 @@ namespace tempest { PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional coalitionStates); storm::storage::Scheduler construct(); - template + template storm::storage::Scheduler constructWithCompareType(); }; }