diff --git a/src/storm/shields/AbstractShield.cpp b/src/storm/shields/AbstractShield.cpp index 62944b7d9..309666aa9 100644 --- a/src/storm/shields/AbstractShield.cpp +++ b/src/storm/shields/AbstractShield.cpp @@ -5,7 +5,7 @@ namespace tempest { namespace shields { template - AbstractShield::AbstractShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, boost::optional coalitionStates) : rowGroupIndices(rowGroupIndices), choiceValues(choiceValues), shieldingExpression(shieldingExpression), coalitionStates(coalitionStates) { + AbstractShield::AbstractShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : rowGroupIndices(rowGroupIndices), choiceValues(choiceValues), shieldingExpression(shieldingExpression), relevantStates(relevantStates), coalitionStates(coalitionStates) { // Intentionally left empty. } diff --git a/src/storm/shields/AbstractShield.h b/src/storm/shields/AbstractShield.h index cc7fe9ca6..0ba34b56a 100644 --- a/src/storm/shields/AbstractShield.h +++ b/src/storm/shields/AbstractShield.h @@ -3,9 +3,12 @@ #include #include #include +#include #include "storm/storage/Scheduler.h" +#include "storm/storage/SchedulerChoice.h" #include "storm/storage/BitVector.h" +#include "storm/storage/Distribution.h" #include "storm/logic/ShieldExpression.h" @@ -31,14 +34,21 @@ namespace tempest { std::string getClassName() const; protected: - AbstractShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, boost::optional coalitionStates); + AbstractShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates); std::vector rowGroupIndices; std::vector choiceValues; std::shared_ptr shieldingExpression; + storm::storage::BitVector relevantStates; + boost::optional coalitionStates; }; + + template + bool allowedValue(ValueType const& max, ValueType const& v, std::shared_ptr const shieldExpression) { + return shieldExpression->isRelative() ? v >= shieldExpression->getValue() * max : v >= shieldExpression->getValue(); + } } } diff --git a/src/storm/shields/PreSafetyShield.cpp b/src/storm/shields/PreSafetyShield.cpp index 7cb944299..a87aa940b 100644 --- a/src/storm/shields/PreSafetyShield.cpp +++ b/src/storm/shields/PreSafetyShield.cpp @@ -1,22 +1,45 @@ #include "storm/shields/PreSafetyShield.h" +#include + namespace tempest { namespace shields { template - PreSafetyShield::PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, choiceValues, shieldingExpression, coalitionStates) { + PreSafetyShield::PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates) : AbstractShield(rowGroupIndices, choiceValues, shieldingExpression, relevantStates, coalitionStates) { // Intentionally left empty. } template storm::storage::Scheduler PreSafetyShield::construct() { - for(auto const& x: this->rowGroupIndices) { - STORM_LOG_DEBUG(x << ", "); + storm::storage::Scheduler shield(this->rowGroupIndices.size()); + STORM_LOG_DEBUG(this->rowGroupIndices.size()); + STORM_LOG_DEBUG(this->relevantStates); + for(auto const& x : this->choiceValues) { + STORM_LOG_DEBUG(x << ","); } - for(auto const& x: this->choiceValues) { - STORM_LOG_DEBUG(x << ", "); + auto choice_it = this->choiceValues.begin(); + for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) { + if(this->relevantStates.get(state)) { + uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state]; + STORM_LOG_DEBUG("rowGroupSize: " << rowGroupSize); + storm::storage::Distribution actionDistribution; + ValueType maxProbability = *std::max_element(this->choiceValues.begin(), this->choiceValues.begin() + rowGroupSize); + for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) { + if(allowedValue(maxProbability, *choice_it, this->shieldingExpression)) { + actionDistribution.addProbability(choice, *choice_it); + STORM_LOG_DEBUG("Adding " << *choice_it << " to dist"); + } + } + actionDistribution.normalize(); + shield.setChoice(storm::storage::SchedulerChoice(actionDistribution), state); + STORM_LOG_DEBUG("SchedulerChoice: " << storm::storage::SchedulerChoice(actionDistribution)); + + } else { + shield.setChoice(storm::storage::Distribution(), state); + } } - STORM_LOG_ASSERT(false, "construct NYI"); + return shield; } // Explicitly instantiate appropriate template class PreSafetyShield::index_type>; diff --git a/src/storm/shields/PreSafetyShield.h b/src/storm/shields/PreSafetyShield.h index 0dbcb6906..c962bb0d1 100644 --- a/src/storm/shields/PreSafetyShield.h +++ b/src/storm/shields/PreSafetyShield.h @@ -8,7 +8,7 @@ namespace tempest { template class PreSafetyShield : public AbstractShield { public: - PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, boost::optional coalitionStates); + PreSafetyShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates); storm::storage::Scheduler construct(); }; } diff --git a/src/storm/shields/shield-handling.h b/src/storm/shields/shield-handling.h index a53c7e8f6..be3b1ce88 100644 --- a/src/storm/shields/shield-handling.h +++ b/src/storm/shields/shield-handling.h @@ -14,12 +14,13 @@ #include "storm/exceptions/InvalidArgumentException.h" + namespace tempest { namespace shields { template - storm::storage::Scheduler createShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, boost::optional coalitionStates) { + storm::storage::Scheduler createShield(std::vector const& rowGroupIndices, std::vector const& choiceValues, std::shared_ptr const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional coalitionStates) { if(shieldingExpression->isPreSafetyShield()) { - PreSafetyShield shield(rowGroupIndices, choiceValues, shieldingExpression, coalitionStates); + PreSafetyShield shield(rowGroupIndices, choiceValues, shieldingExpression, relevantStates, coalitionStates); return shield.construct(); } else if(shieldingExpression->isPostSafetyShield()) { STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot create post safety shields yet");