Browse Source

major changes in safety shield creation

- changed templating
 - buxfixes with choice value to state mapping
tempestpy_adaptions
Stefan Pranger 4 years ago
parent
commit
5421b98aa5
  1. 19
      src/storm/shields/PostSafetyShield.cpp
  2. 2
      src/storm/shields/PostSafetyShield.h
  3. 19
      src/storm/shields/PreSafetyShield.cpp
  4. 2
      src/storm/shields/PreSafetyShield.h

19
src/storm/shields/PostSafetyShield.cpp

@ -14,36 +14,37 @@ namespace tempest {
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::construct() {
if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementLessEqual<ValueType>, true>>();
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, true>();
} else {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementLessEqual<ValueType>, false>>();
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, false>();
}
} else {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementGreaterEqual<ValueType>, true>>();
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, true>();
} else {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementGreaterEqual<ValueType>, false>>();
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, false>();
}
}
}
template<typename ValueType, typename IndexType>
template<typename ChoiceFilter>
template<typename Compare, bool relative>
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::constructWithCompareType() {
ChoiceFilter choiceFilter;
tempest::shields::utility::ChoiceFilter<ValueType, Compare, relative> choiceFilter;
storm::storage::PostScheduler<ValueType> shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes());
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates &= this->coalitionStates.get();
}
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
if(this->relevantStates.get(state)) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it;
ValueType maxProbability = *(choice_it + maxProbabilityIndex);
if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state);
shield.setChoice(0, storm::storage::Distribution<ValueType, IndexType>(), state);
choice_it += rowGroupSize;
continue;
}
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
@ -53,11 +54,11 @@ namespace tempest {
} else {
actionDistribution.addProbability(maxProbabilityIndex, 1);
}
actionDistribution.normalize();
shield.setChoice(choice, storm::storage::SchedulerChoice<ValueType>(actionDistribution), state);
}
} else {
shield.setChoice(0, storm::storage::Distribution<ValueType, IndexType>(), state);
choice_it += rowGroupSize;
}
}
return shield;

2
src/storm/shields/PostSafetyShield.h

@ -12,7 +12,7 @@ namespace tempest {
PostSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::PostScheduler<ValueType> construct();
template<typename ChoiceFilter>
template<typename Compare, bool relative>
storm::storage::PostScheduler<ValueType> constructWithCompareType();
};
}

19
src/storm/shields/PreSafetyShield.cpp

@ -14,36 +14,37 @@ namespace tempest {
storm::storage::Scheduler<ValueType> PreSafetyShield<ValueType, IndexType>::construct() {
if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementLessEqual<ValueType>, true>>();
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, true>();
} else {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementLessEqual<ValueType>, false>>();
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, false>();
}
} else {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementGreaterEqual<ValueType>, true>>();
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, true>();
} else {
return constructWithCompareType<tempest::shields::utility::ChoiceFilter<ValueType, storm::utility::ElementGreaterEqual<ValueType>, false>>();
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, false>();
}
}
}
template<typename ValueType, typename IndexType>
template<typename ChoiceFilter>
template<typename Compare, bool relative>
storm::storage::Scheduler<ValueType> PreSafetyShield<ValueType, IndexType>::constructWithCompareType() {
ChoiceFilter choiceFilter;
tempest::shields::utility::ChoiceFilter<ValueType, Compare, relative> choiceFilter;
storm::storage::Scheduler<ValueType> shield(this->rowGroupIndices.size() - 1);
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates &= this->coalitionStates.get();
}
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
if(this->relevantStates.get(state)) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
storm::storage::Distribution<ValueType, IndexType> actionDistribution;
ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize);
if(!choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state);
shield.setChoice(storm::storage::Distribution<ValueType, IndexType>(), state);
choice_it += rowGroupSize;
continue;
}
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
@ -51,11 +52,11 @@ namespace tempest {
actionDistribution.addProbability(choice, *choice_it);
}
}
actionDistribution.normalize();
shield.setChoice(storm::storage::SchedulerChoice<ValueType>(actionDistribution), state);
} else {
shield.setChoice(storm::storage::Distribution<ValueType, IndexType>(), state);
choice_it += rowGroupSize;
}
}
return shield;

2
src/storm/shields/PreSafetyShield.h

@ -11,7 +11,7 @@ namespace tempest {
PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::Scheduler<ValueType> construct();
template<typename ChoiceFilter>
template<typename Compare, bool relative>
storm::storage::Scheduler<ValueType> constructWithCompareType();
};
}

Loading…
Cancel
Save