|
|
@ -14,36 +14,37 @@ namespace tempest { |
|
|
|
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::construct() { |
|
|
|
if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) { |
|
|
|
if(this->shieldingExpression->isRelative()) { |
|
|
|
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementLessEqual<ValueType>, true>>(); |
|
|
|
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, true>(); |
|
|
|
} else { |
|
|
|
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementLessEqual<ValueType>, false>>(); |
|
|
|
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, false>(); |
|
|
|
} |
|
|
|
} else { |
|
|
|
if(this->shieldingExpression->isRelative()) { |
|
|
|
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementGreaterEqual<ValueType>, true>>(); |
|
|
|
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, true>(); |
|
|
|
} else { |
|
|
|
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementGreaterEqual<ValueType>, false>>(); |
|
|
|
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, false>(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template<typename ValueType, typename IndexType> |
|
|
|
template<typename ChoiceFilter> |
|
|
|
template<typename Compare, bool relative> |
|
|
|
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::constructWithCompareType() { |
|
|
|
ChoiceFilter choiceFilter; |
|
|
|
tempest::shields::utility::ChoiceFilter<ValueType, Compare, relative> choiceFilter; |
|
|
|
storm::storage::PostScheduler<ValueType> shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes()); |
|
|
|
auto choice_it = this->choiceValues.begin(); |
|
|
|
if(this->coalitionStates.is_initialized()) { |
|
|
|
this->relevantStates &= this->coalitionStates.get(); |
|
|
|
} |
|
|
|
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { |
|
|
|
if(this->relevantStates.get(state)) { |
|
|
|
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; |
|
|
|
if(this->relevantStates.get(state)) { |
|
|
|
auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; |
|
|
|
ValueType maxProbability = *(choice_it + maxProbabilityIndex); |
|
|
|
if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { |
|
|
|
if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) { |
|
|
|
STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state); |
|
|
|
shield.setChoice(0, storm::storage::Distribution<ValueType, IndexType>(), state); |
|
|
|
choice_it += rowGroupSize; |
|
|
|
continue; |
|
|
|
} |
|
|
|
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { |
|
|
@ -53,11 +54,11 @@ namespace tempest { |
|
|
|
} else { |
|
|
|
actionDistribution.addProbability(maxProbabilityIndex, 1); |
|
|
|
} |
|
|
|
actionDistribution.normalize(); |
|
|
|
shield.setChoice(choice, storm::storage::SchedulerChoice<ValueType>(actionDistribution), state); |
|
|
|
} |
|
|
|
} else { |
|
|
|
shield.setChoice(0, storm::storage::Distribution<ValueType, IndexType>(), state); |
|
|
|
choice_it += rowGroupSize; |
|
|
|
} |
|
|
|
} |
|
|
|
return shield; |
|
|
|