diff --git a/src/storm/shields/AbstractShield.cpp b/src/storm/shields/AbstractShield.cpp index 309666aa9..2722b1916 100644 --- a/src/storm/shields/AbstractShield.cpp +++ b/src/storm/shields/AbstractShield.cpp @@ -1,6 +1,7 @@ #include "storm/shields/AbstractShield.h" #include + namespace tempest { namespace shields { @@ -14,6 +15,15 @@ namespace tempest { // Intentionally left empty. } + template + std::vector AbstractShield::computeRowGroupSizes() { + std::vector rowGroupSizes(this->rowGroupIndices.size() - 1); + for(uint rowGroupStartIndex = 0; rowGroupStartIndex < rowGroupSizes.size(); rowGroupStartIndex++) { + rowGroupSizes.at(rowGroupStartIndex) = this->rowGroupIndices[rowGroupStartIndex + 1] - this->rowGroupIndices[rowGroupStartIndex]; + } + return rowGroupSizes; + } + template std::string AbstractShield::getClassName() const { return std::string(boost::core::demangled_name(BOOST_CORE_TYPEID(*this))); diff --git a/src/storm/shields/AbstractShield.h b/src/storm/shields/AbstractShield.h index 0ba34b56a..3dfcc2b1b 100644 --- a/src/storm/shields/AbstractShield.h +++ b/src/storm/shields/AbstractShield.h @@ -26,7 +26,8 @@ namespace tempest { /*! * TODO */ - virtual storm::storage::Scheduler construct() = 0; + //virtual storm::storage::Scheduler* construct() = 0; + std::vector computeRowGroupSizes(); /*! * TODO diff --git a/src/storm/shields/PostSafetyShield.cpp b/src/storm/shields/PostSafetyShield.cpp new file mode 100644 index 000000000..33b86a815 --- /dev/null +++ b/src/storm/shields/PostSafetyShield.cpp @@ -0,0 +1,57 @@ +#include "storm/shields/PostSafetyShield.h" + +#include + +namespace tempest { + namespace shields { + + template + PostSafetyShield::PostSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, choiceValues, shieldingExpression, relevantStates, coalitionStates) { + // Intentionally left empty. + } + + template + storm::storage::PostScheduler PostSafetyShield::construct() { + storm::storage::PostScheduler shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes()); + STORM_LOG_DEBUG(this->rowGroupIndices.size()); + STORM_LOG_DEBUG(this->relevantStates); + STORM_LOG_DEBUG(this->coalitionStates.get()); + for(auto const& x : this->choiceValues) { + STORM_LOG_DEBUG(x << ","); + } + auto choice_it = this->choiceValues.begin(); + if(this->coalitionStates.is_initialized()) { + this->relevantStates &= this->coalitionStates.get(); + } + STORM_LOG_DEBUG(this->relevantStates); + for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { + if(this->relevantStates.get(state)) { + uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; + auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it; + ValueType maxProbability = *(choice_it + maxProbabilityIndex); + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + STORM_LOG_DEBUG("processing " << state << " with rowGroupSize of " << rowGroupSize << " choice index " << choice << " prob: " << *choice_it << " > " << maxProbability); + storm::storage::Distribution actionDistribution; + if(allowedValue(maxProbability, *choice_it, this->shieldingExpression)) { + actionDistribution.addProbability(choice, 1); + } else { + actionDistribution.addProbability(maxProbabilityIndex, 1); + } + actionDistribution.normalize(); + STORM_LOG_DEBUG(" dist: " << actionDistribution); + shield.setChoice(choice, storm::storage::SchedulerChoice(actionDistribution), state); + } + } else { + shield.setChoice(0, storm::storage::Distribution(), state); + } + } + return shield; + } + + // Explicitly instantiate appropriate + template class PostSafetyShield::index_type>; +#ifdef STORM_HAVE_CARL + template class PostSafetyShield::index_type>; +#endif + } +} diff --git a/src/storm/shields/PostSafetyShield.h b/src/storm/shields/PostSafetyShield.h new file mode 100644 index 000000000..7e077a940 --- /dev/null +++ b/src/storm/shields/PostSafetyShield.h @@ -0,0 +1,16 @@ +#pragma once + +#include "storm/shields/AbstractShield.h" +#include "storm/storage/PostScheduler.h" + +namespace tempest { + namespace shields { + + template + class PostSafetyShield : public AbstractShield { + public: + PostSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates); + storm::storage::PostScheduler construct(); + }; + } +} diff --git a/src/storm/shields/PreSafetyShield.cpp b/src/storm/shields/PreSafetyShield.cpp index 98d32ded7..d1c769b2b 100644 --- a/src/storm/shields/PreSafetyShield.cpp +++ b/src/storm/shields/PreSafetyShield.cpp @@ -13,17 +13,15 @@ namespace tempest { template storm::storage::Scheduler PreSafetyShield::construct() { storm::storage::Scheduler shield(this->rowGroupIndices.size() - 1); - STORM_LOG_DEBUG(this->rowGroupIndices.size()); - STORM_LOG_DEBUG(this->relevantStates); - for(auto const& x : this->choiceValues) { - STORM_LOG_DEBUG(x << ","); - } auto choice_it = this->choiceValues.begin(); + if(this->coalitionStates.is_initialized()) { + this->relevantStates &= this->coalitionStates.get(); + } for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { if(this->relevantStates.get(state)) { uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; storm::storage::Distribution actionDistribution; - ValueType maxProbability = *std::max_element(this->choiceValues.begin(), this->choiceValues.begin() + rowGroupSize); + ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize); for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { if(allowedValue(maxProbability, *choice_it, this->shieldingExpression)) { actionDistribution.addProbability(choice, *choice_it);