Browse Source

restructured shield creation

construct is no longer a pure virtual function because of different
return types for different shields. Ctor is still protected so no issues
here.

Added computeRowGroupSizes method
tempestpy_adaptions
Stefan Pranger 4 years ago
parent
commit
4b2e7e020f
  1. 10
      src/storm/shields/AbstractShield.cpp
  2. 3
      src/storm/shields/AbstractShield.h
  3. 57
      src/storm/shields/PostSafetyShield.cpp
  4. 16
      src/storm/shields/PostSafetyShield.h
  5. 10
      src/storm/shields/PreSafetyShield.cpp

10
src/storm/shields/AbstractShield.cpp

@ -1,6 +1,7 @@
#include "storm/shields/AbstractShield.h"
#include <boost/core/typeinfo.hpp>
namespace tempest {
namespace shields {
@ -14,6 +15,15 @@ namespace tempest {
// Intentionally left empty.
}
template<typename ValueType, typename IndexType>
std::vector<IndexType> AbstractShield<ValueType, IndexType>::computeRowGroupSizes() {
std::vector<IndexType> rowGroupSizes(this->rowGroupIndices.size() - 1);
for(uint rowGroupStartIndex = 0; rowGroupStartIndex < rowGroupSizes.size(); rowGroupStartIndex++) {
rowGroupSizes.at(rowGroupStartIndex) = this->rowGroupIndices[rowGroupStartIndex + 1] - this->rowGroupIndices[rowGroupStartIndex];
}
return rowGroupSizes;
}
template<typename ValueType, typename IndexType>
std::string AbstractShield<ValueType, IndexType>::getClassName() const {
return std::string(boost::core::demangled_name(BOOST_CORE_TYPEID(*this)));

3
src/storm/shields/AbstractShield.h

@ -26,7 +26,8 @@ namespace tempest {
/*!
* TODO
*/
virtual storm::storage::Scheduler<ValueType> construct() = 0;
//virtual storm::storage::Scheduler<ValueType>* construct() = 0;
std::vector<IndexType> computeRowGroupSizes();
/*!
* TODO

57
src/storm/shields/PostSafetyShield.cpp

@ -0,0 +1,57 @@
#include "storm/shields/PostSafetyShield.h"
#include <algorithm>
namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
PostSafetyShield<ValueType, IndexType>::PostSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, choiceValues, shieldingExpression, relevantStates, coalitionStates) {
// Intentionally left empty.
}
template<typename ValueType, typename IndexType>
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::construct() {
storm::storage::PostScheduler<ValueType> shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes());
STORM_LOG_DEBUG(this->rowGroupIndices.size());
STORM_LOG_DEBUG(this->relevantStates);
STORM_LOG_DEBUG(this->coalitionStates.get());
for(auto const& x : this->choiceValues) {
STORM_LOG_DEBUG(x << ",");
}
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates &= this->coalitionStates.get();
}
STORM_LOG_DEBUG(this->relevantStates);
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
if(this->relevantStates.get(state)) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it;
ValueType maxProbability = *(choice_it + maxProbabilityIndex);
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
STORM_LOG_DEBUG("processing " << state << " with rowGroupSize of " << rowGroupSize << " choice index " << choice << " prob: " << *choice_it << " > " << maxProbability);
storm::storage::Distribution<ValueType, IndexType> actionDistribution;
if(allowedValue<ValueType, IndexType>(maxProbability, *choice_it, this->shieldingExpression)) {
actionDistribution.addProbability(choice, 1);
} else {
actionDistribution.addProbability(maxProbabilityIndex, 1);
}
actionDistribution.normalize();
STORM_LOG_DEBUG(" dist: " << actionDistribution);
shield.setChoice(choice, storm::storage::SchedulerChoice<ValueType>(actionDistribution), state);
}
} else {
shield.setChoice(0, storm::storage::Distribution<ValueType, IndexType>(), state);
}
}
return shield;
}
// Explicitly instantiate appropriate
template class PostSafetyShield<double, typename storm::storage::SparseMatrix<double>::index_type>;
#ifdef STORM_HAVE_CARL
template class PostSafetyShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>;
#endif
}
}

16
src/storm/shields/PostSafetyShield.h

@ -0,0 +1,16 @@
#pragma once
#include "storm/shields/AbstractShield.h"
#include "storm/storage/PostScheduler.h"
namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
class PostSafetyShield : public AbstractShield<ValueType, IndexType> {
public:
PostSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::PostScheduler<ValueType> construct();
};
}
}

10
src/storm/shields/PreSafetyShield.cpp

@ -13,17 +13,15 @@ namespace tempest {
template<typename ValueType, typename IndexType>
storm::storage::Scheduler<ValueType> PreSafetyShield<ValueType, IndexType>::construct() {
storm::storage::Scheduler<ValueType> shield(this->rowGroupIndices.size() - 1);
STORM_LOG_DEBUG(this->rowGroupIndices.size());
STORM_LOG_DEBUG(this->relevantStates);
for(auto const& x : this->choiceValues) {
STORM_LOG_DEBUG(x << ",");
}
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates &= this->coalitionStates.get();
}
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
if(this->relevantStates.get(state)) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
storm::storage::Distribution<ValueType, IndexType> actionDistribution;
ValueType maxProbability = *std::max_element(this->choiceValues.begin(), this->choiceValues.begin() + rowGroupSize);
ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize);
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
if(allowedValue<ValueType, IndexType>(maxProbability, *choice_it, this->shieldingExpression)) {
actionDistribution.addProbability(choice, *choice_it);

Loading…
Cancel
Save