Browse Source

added first version of pre safety shield

tempestpy_adaptions
Stefan Pranger 4 years ago
parent
commit
caf855e1a4
  1. 2
      src/storm/shields/AbstractShield.cpp
  2. 12
      src/storm/shields/AbstractShield.h
  3. 35
      src/storm/shields/PreSafetyShield.cpp
  4. 2
      src/storm/shields/PreSafetyShield.h
  5. 5
      src/storm/shields/shield-handling.h

2
src/storm/shields/AbstractShield.cpp

@ -5,7 +5,7 @@ namespace tempest {
namespace shields { namespace shields {
template<typename ValueType, typename IndexType> template<typename ValueType, typename IndexType>
AbstractShield<ValueType, IndexType>::AbstractShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, boost::optional<storm::storage::BitVector> coalitionStates) : rowGroupIndices(rowGroupIndices), choiceValues(choiceValues), shieldingExpression(shieldingExpression), coalitionStates(coalitionStates) {
AbstractShield<ValueType, IndexType>::AbstractShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : rowGroupIndices(rowGroupIndices), choiceValues(choiceValues), shieldingExpression(shieldingExpression), relevantStates(relevantStates), coalitionStates(coalitionStates) {
// Intentionally left empty. // Intentionally left empty.
} }

12
src/storm/shields/AbstractShield.h

@ -3,9 +3,12 @@
#include <boost/optional.hpp> #include <boost/optional.hpp>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <memory>
#include "storm/storage/Scheduler.h" #include "storm/storage/Scheduler.h"
#include "storm/storage/SchedulerChoice.h"
#include "storm/storage/BitVector.h" #include "storm/storage/BitVector.h"
#include "storm/storage/Distribution.h"
#include "storm/logic/ShieldExpression.h" #include "storm/logic/ShieldExpression.h"
@ -31,14 +34,21 @@ namespace tempest {
std::string getClassName() const; std::string getClassName() const;
protected: protected:
AbstractShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, boost::optional<storm::storage::BitVector> coalitionStates);
AbstractShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
std::vector<index_type> rowGroupIndices; std::vector<index_type> rowGroupIndices;
std::vector<value_type> choiceValues; std::vector<value_type> choiceValues;
std::shared_ptr<storm::logic::ShieldExpression const> shieldingExpression; std::shared_ptr<storm::logic::ShieldExpression const> shieldingExpression;
storm::storage::BitVector relevantStates;
boost::optional<storm::storage::BitVector> coalitionStates; boost::optional<storm::storage::BitVector> coalitionStates;
}; };
template<typename ValueType, typename IndexType>
bool allowedValue(ValueType const& max, ValueType const& v, std::shared_ptr<storm::logic::ShieldExpression const> const shieldExpression) {
return shieldExpression->isRelative() ? v >= shieldExpression->getValue() * max : v >= shieldExpression->getValue();
}
} }
} }

35
src/storm/shields/PreSafetyShield.cpp

@ -1,22 +1,45 @@
#include "storm/shields/PreSafetyShield.h" #include "storm/shields/PreSafetyShield.h"
#include <algorithm>
namespace tempest { namespace tempest {
namespace shields { namespace shields {
template<typename ValueType, typename IndexType> template<typename ValueType, typename IndexType>
PreSafetyShield<ValueType, IndexType>::PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, choiceValues, shieldingExpression, coalitionStates) {
PreSafetyShield<ValueType, IndexType>::PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, choiceValues, shieldingExpression, relevantStates, coalitionStates) {
// Intentionally left empty. // Intentionally left empty.
} }
template<typename ValueType, typename IndexType> template<typename ValueType, typename IndexType>
storm::storage::Scheduler<ValueType> PreSafetyShield<ValueType, IndexType>::construct() { storm::storage::Scheduler<ValueType> PreSafetyShield<ValueType, IndexType>::construct() {
for(auto const& x: this->rowGroupIndices) {
STORM_LOG_DEBUG(x << ", ");
storm::storage::Scheduler<ValueType> shield(this->rowGroupIndices.size());
STORM_LOG_DEBUG(this->rowGroupIndices.size());
STORM_LOG_DEBUG(this->relevantStates);
for(auto const& x : this->choiceValues) {
STORM_LOG_DEBUG(x << ",");
} }
for(auto const& x: this->choiceValues) {
STORM_LOG_DEBUG(x << ", ");
auto choice_it = this->choiceValues.begin();
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
if(this->relevantStates.get(state)) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
STORM_LOG_DEBUG("rowGroupSize: " << rowGroupSize);
storm::storage::Distribution<ValueType, IndexType> actionDistribution;
ValueType maxProbability = *std::max_element(this->choiceValues.begin(), this->choiceValues.begin() + rowGroupSize);
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
if(allowedValue<ValueType, IndexType>(maxProbability, *choice_it, this->shieldingExpression)) {
actionDistribution.addProbability(choice, *choice_it);
STORM_LOG_DEBUG("Adding " << *choice_it << " to dist");
}
}
actionDistribution.normalize();
shield.setChoice(storm::storage::SchedulerChoice<ValueType>(actionDistribution), state);
STORM_LOG_DEBUG("SchedulerChoice: " << storm::storage::SchedulerChoice<ValueType>(actionDistribution));
} else {
shield.setChoice(storm::storage::Distribution<ValueType, IndexType>(), state);
}
} }
STORM_LOG_ASSERT(false, "construct NYI");
return shield;
} }
// Explicitly instantiate appropriate // Explicitly instantiate appropriate
template class PreSafetyShield<double, typename storm::storage::SparseMatrix<double>::index_type>; template class PreSafetyShield<double, typename storm::storage::SparseMatrix<double>::index_type>;

2
src/storm/shields/PreSafetyShield.h

@ -8,7 +8,7 @@ namespace tempest {
template<typename ValueType, typename IndexType> template<typename ValueType, typename IndexType>
class PreSafetyShield : public AbstractShield<ValueType, IndexType> { class PreSafetyShield : public AbstractShield<ValueType, IndexType> {
public: public:
PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, boost::optional<storm::storage::BitVector> coalitionStates);
PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::Scheduler<ValueType> construct(); storm::storage::Scheduler<ValueType> construct();
}; };
} }

5
src/storm/shields/shield-handling.h

@ -14,12 +14,13 @@
#include "storm/exceptions/InvalidArgumentException.h" #include "storm/exceptions/InvalidArgumentException.h"
namespace tempest { namespace tempest {
namespace shields { namespace shields {
template<typename ValueType, typename IndexType = storm::storage::sparse::state_type> template<typename ValueType, typename IndexType = storm::storage::sparse::state_type>
storm::storage::Scheduler<ValueType> createShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, boost::optional<storm::storage::BitVector> coalitionStates) {
storm::storage::Scheduler<ValueType> createShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) {
if(shieldingExpression->isPreSafetyShield()) { if(shieldingExpression->isPreSafetyShield()) {
PreSafetyShield<ValueType, IndexType> shield(rowGroupIndices, choiceValues, shieldingExpression, coalitionStates);
PreSafetyShield<ValueType, IndexType> shield(rowGroupIndices, choiceValues, shieldingExpression, relevantStates, coalitionStates);
return shield.construct(); return shield.construct();
} else if(shieldingExpression->isPostSafetyShield()) { } else if(shieldingExpression->isPostSafetyShield()) {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot create post safety shields yet"); STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Cannot create post safety shields yet");

Loading…
Cancel
Save