Browse Source

major changes to shield handling

- Introduced OptimalPre and OptimalPost shields
 - Renamed *Safety to PreShield and PostShield
 - Introduced min case for shields
 - fixed coalition states in shield handling
tempestpy_adaptions
Stefan Pranger 3 years ago
parent
commit
454bffe03f
  1. 20
      src/storm-parsers/parser/FormulaParserGrammar.cpp
  2. 2
      src/storm-parsers/parser/FormulaParserGrammar.h
  3. 26
      src/storm/logic/ShieldExpression.cpp
  4. 10
      src/storm/logic/ShieldExpression.h
  5. 8
      src/storm/shields/AbstractShield.h
  6. 54
      src/storm/shields/OptimalShield.cpp
  7. 10
      src/storm/shields/OptimalShield.h
  8. 27
      src/storm/shields/PostShield.cpp
  9. 4
      src/storm/shields/PostShield.h
  10. 25
      src/storm/shields/PreShield.cpp
  11. 4
      src/storm/shields/PreShield.h
  12. 19
      src/storm/shields/ShieldHandling.cpp
  13. 6
      src/storm/shields/ShieldHandling.h
  14. 12
      src/test/storm/parser/GameShieldingParserTest.cpp
  15. 12
      src/test/storm/parser/MdpShieldingParserTest.cpp

20
src/storm-parsers/parser/FormulaParserGrammar.cpp

@ -184,16 +184,21 @@ namespace storm {
shieldExpression.name("shield expression");
shieldingType = (qi::lit("PreSafety")[qi::_val = storm::logic::ShieldingType::PreSafety] |
qi::lit("PostSafety")[qi::_val = storm::logic::ShieldingType::PostSafety] |
qi::lit("Optimal")[qi::_val = storm::logic::ShieldingType::Optimal]) > -qi::lit("Shield");
shieldingType = (qi::lit("PreSafety")[qi::_val = storm::logic::ShieldingType::PreSafety] |
qi::lit("PostSafety")[qi::_val = storm::logic::ShieldingType::PostSafety] |
qi::lit("OptimalPre")[qi::_val = storm::logic::ShieldingType::OptimalPre] |
qi::lit("OptimalPost")[qi::_val = storm::logic::ShieldingType::OptimalPost] |
qi::lit("Optimal")[qi::_val = storm::logic::ShieldingType::OptimalPost]) // backwards compatability, will be disabled in the future
> -qi::lit("Shield");
shieldingType.name("shielding type");
probability = qi::double_[qi::_pass = (qi::_1 >= 0) & (qi::_1 <= 1.0), qi::_val = qi::_1 ];
probability.name("double between 0 and 1");
//probability = qi::double_[qi::_pass = (qi::_1 >= 0) & (qi::_1 <= 1.0), qi::_val = qi::_1 ];
//probability.name("double between 0 and 1");
comparisonValue = qi::double_[qi::_val = qi::_1 ];
comparisonValue.name("double comparison value");
shieldComparison = ((qi::lit("lambda")[qi::_a = storm::logic::ShieldComparison::Relative] |
qi::lit("gamma")[qi::_a = storm::logic::ShieldComparison::Absolute]) > qi::lit("=") > probability)[qi::_val = phoenix::bind(&FormulaParserGrammar::createShieldComparisonStruct, phoenix::ref(*this), qi::_a, qi::_1)];
qi::lit("gamma")[qi::_a = storm::logic::ShieldComparison::Absolute]) > qi::lit("=") > comparisonValue)[qi::_val = phoenix::bind(&FormulaParserGrammar::createShieldComparisonStruct, phoenix::ref(*this), qi::_a, qi::_1)];
shieldComparison.name("shield comparison type");
#pragma clang diagnostic push
@ -649,10 +654,9 @@ namespace storm {
std::shared_ptr<storm::logic::ShieldExpression const> FormulaParserGrammar::createShieldExpression(storm::logic::ShieldingType type, std::string name, boost::optional<std::pair<storm::logic::ShieldComparison, double>> comparisonStruct) {
if(comparisonStruct.is_initialized()) {
STORM_LOG_WARN_COND(type != storm::logic::ShieldingType::Optimal , "Comparison for optimal shield will be ignored.");
return std::shared_ptr<storm::logic::ShieldExpression>(new storm::logic::ShieldExpression(type, name, comparisonStruct.get().first, comparisonStruct.get().second));
} else {
STORM_LOG_THROW(type == storm::logic::ShieldingType::Optimal , storm::exceptions::WrongFormatException, "Construction of safety shield needs a comparison parameter (lambda or gamma)");
STORM_LOG_INFO("Construction of shield without a comparison parameter (lambda or gamma) will default to 'lambda=0'");
return std::shared_ptr<storm::logic::ShieldExpression>(new storm::logic::ShieldExpression(type, name));
}
}

2
src/storm-parsers/parser/FormulaParserGrammar.h

@ -237,7 +237,7 @@ namespace storm {
// Shielding properties
qi::rule<Iterator, std::shared_ptr<storm::logic::ShieldExpression const>(), Skipper> shieldExpression;
qi::rule<Iterator, storm::logic::ShieldingType, Skipper> shieldingType;
qi::rule<Iterator, double, Skipper> probability;
qi::rule<Iterator, double, Skipper> comparisonValue;
qi::rule<Iterator, std::pair<storm::logic::ShieldComparison, double>, qi::locals<storm::logic::ShieldComparison>, Skipper> shieldComparison;
// Start symbol

26
src/storm/logic/ShieldExpression.cpp

@ -26,8 +26,12 @@ namespace storm {
return type == storm::logic::ShieldingType::PostSafety;
}
bool ShieldExpression::isOptimalShield() const {
return type == storm::logic::ShieldingType::Optimal;
bool ShieldExpression::isOptimalPreShield() const {
return type == storm::logic::ShieldingType::OptimalPre;
}
bool ShieldExpression::isOptimalPostShield() const {
return type == storm::logic::ShieldingType::OptimalPost;
}
double ShieldExpression::getValue() const {
@ -36,9 +40,10 @@ namespace storm {
std::string ShieldExpression::typeToString() const {
switch(type) {
case storm::logic::ShieldingType::PostSafety: return "PostSafety";
case storm::logic::ShieldingType::PreSafety: return "PreSafety";
case storm::logic::ShieldingType::Optimal: return "Optimal";
case storm::logic::ShieldingType::PostSafety: return "Post";
case storm::logic::ShieldingType::PreSafety: return "Pre";
case storm::logic::ShieldingType::OptimalPre: return "OptimalPre";
case storm::logic::ShieldingType::OptimalPost: return "OptimalPost";
}
}
@ -57,14 +62,13 @@ namespace storm {
std::string prettyString = "";
std::string comparisonType = isRelative() ? "relative" : "absolute";
switch(type) {
case storm::logic::ShieldingType::PostSafety: prettyString += "Post-Safety"; break;
case storm::logic::ShieldingType::PreSafety: prettyString += "Pre-Safety"; break;
case storm::logic::ShieldingType::Optimal: prettyString += "Optimal"; break;
case storm::logic::ShieldingType::PostSafety: prettyString += "Post-Safety"; break;
case storm::logic::ShieldingType::PreSafety: prettyString += "Pre-Safety"; break;
case storm::logic::ShieldingType::OptimalPre: prettyString += "Optimal-Pre"; break;
case storm::logic::ShieldingType::OptimalPost: prettyString += "Optimal-Post"; break;
}
prettyString += "-Shield ";
if(!(type == storm::logic::ShieldingType::Optimal)) {
prettyString += "with " + comparisonType + " comparison (" + comparisonToString() + " = " + std::to_string(value) + "):";
}
prettyString += "with " + comparisonType + " comparison (" + comparisonToString() + " = " + std::to_string(value) + "):";
return prettyString;
}

10
src/storm/logic/ShieldExpression.h

@ -9,7 +9,8 @@ namespace storm {
enum class ShieldingType {
PostSafety,
PreSafety,
Optimal
OptimalPre,
OptimalPost
};
enum class ShieldComparison { Absolute, Relative };
@ -23,7 +24,8 @@ namespace storm {
bool isRelative() const;
bool isPreSafetyShield() const;
bool isPostSafetyShield() const;
bool isOptimalShield() const;
bool isOptimalPreShield() const;
bool isOptimalPostShield() const;
double getValue() const;
@ -36,8 +38,8 @@ namespace storm {
private:
ShieldingType type;
ShieldComparison comparison;
double value;
ShieldComparison comparison = ShieldComparison::Relative;
double value = 0;
std::string filename;
};

8
src/storm/shields/AbstractShield.h

@ -21,9 +21,13 @@ namespace tempest {
namespace utility {
template<typename ValueType, typename Compare, bool relative>
struct ChoiceFilter {
bool operator()(ValueType v, ValueType max, double shieldValue) {
bool operator()(ValueType v, ValueType opt, double shieldValue) {
Compare compare;
if(relative) return compare(v, max * shieldValue);
if(relative && std::is_same<Compare, storm::utility::ElementLessEqual<ValueType>>::value) {
return compare(v, opt + opt * shieldValue);
} else if(relative && std::is_same<Compare, storm::utility::ElementGreaterEqual<ValueType>>::value) {
return compare(v, opt * shieldValue);
}
else return compare(v, shieldValue);
}
};

54
src/storm/shields/OptimalShield.cpp

@ -6,27 +6,65 @@ namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
OptimalShield<ValueType, IndexType>::OptimalShield(std::vector<IndexType> const& rowGroupIndices, std::vector<uint64_t> const& precomputedChoices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), precomputedChoices(precomputedChoices) {
OptimalShield<ValueType, IndexType>::OptimalShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) {
// Intentionally left empty.
}
template<typename ValueType, typename IndexType>
storm::storage::OptimalScheduler<ValueType> OptimalShield<ValueType, IndexType>::construct() {
storm::storage::OptimalScheduler<ValueType> shield(this->rowGroupIndices.size() - 1);
// TODO Needs fixing as soon as we support MDPs
storm::storage::PostScheduler<ValueType> OptimalShield<ValueType, IndexType>::construct() {
if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, true>();
} else {
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, false>();
}
} else {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, true>();
} else {
return constructWithCompareType<storm::utility::ElementGreaterEqual<ValueType>, false>();
}
}
}
template<typename ValueType, typename IndexType>
template<typename Compare, bool relative>
storm::storage::PostScheduler<ValueType> OptimalShield<ValueType, IndexType>::constructWithCompareType() {
tempest::shields::utility::ChoiceFilter<ValueType, Compare, relative> choiceFilter;
storm::storage::PostScheduler<ValueType> shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes());
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates = ~this->relevantStates;
this->relevantStates &= this->coalitionStates.get();
}
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
if(this->relevantStates.get(state)) {
shield.setChoice(precomputedChoices[state], state);
auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it;
ValueType maxProbability = *(choice_it + maxProbabilityIndex);
if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state);
shield.setChoice(storm::storage::PostSchedulerChoice<ValueType>(), state, 0);
choice_it += rowGroupSize;
continue;
}
storm::storage::PostSchedulerChoice<ValueType> choiceMapping;
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) {
choiceMapping.addChoice(choice, choice);
} else {
choiceMapping.addChoice(choice, maxProbabilityIndex);
}
}
shield.setChoice(choiceMapping, state, 0);
} else {
shield.setChoice(storm::storage::Distribution<ValueType, IndexType>(), state);
shield.setChoice(storm::storage::PostSchedulerChoice<ValueType>(), state, 0);
choice_it += rowGroupSize;
}
}
return shield;
}
// Explicitly instantiate appropriate
// Explicitly instantiate appropriate classes
template class OptimalShield<double, typename storm::storage::SparseMatrix<double>::index_type>;
#ifdef STORM_HAVE_CARL
template class OptimalShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>;

10
src/storm/shields/OptimalShield.h

@ -1,7 +1,7 @@
#pragma once
#include "storm/shields/AbstractShield.h"
#include "storm/storage/OptimalScheduler.h"
#include "storm/storage/PostScheduler.h"
namespace tempest {
namespace shields {
@ -9,11 +9,13 @@ namespace tempest {
template<typename ValueType, typename IndexType>
class OptimalShield : public AbstractShield<ValueType, IndexType> {
public:
OptimalShield(std::vector<IndexType> const& rowGroupIndices, std::vector<uint64_t> const& precomputedChoices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
OptimalShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::OptimalScheduler<ValueType> construct();
storm::storage::PostScheduler<ValueType> construct();
template<typename Compare, bool relative>
storm::storage::PostScheduler<ValueType> constructWithCompareType();
private:
std::vector<uint64_t> precomputedChoices;
std::vector<ValueType> choiceValues;
};
}
}

27
src/storm/shields/PostSafetyShield.cpp → src/storm/shields/PostShield.cpp

@ -1,4 +1,4 @@
#include "storm/shields/PostSafetyShield.h"
#include "storm/shields/PostShield.h"
#include <algorithm>
@ -6,12 +6,12 @@ namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
PostSafetyShield<ValueType, IndexType>::PostSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) {
PostShield<ValueType, IndexType>::PostShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) {
// Intentionally left empty.
}
template<typename ValueType, typename IndexType>
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::construct() {
storm::storage::PostScheduler<ValueType> PostShield<ValueType, IndexType>::construct() {
if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, true>();
@ -29,19 +29,22 @@ namespace tempest {
template<typename ValueType, typename IndexType>
template<typename Compare, bool relative>
storm::storage::PostScheduler<ValueType> PostSafetyShield<ValueType, IndexType>::constructWithCompareType() {
storm::storage::PostScheduler<ValueType> PostShield<ValueType, IndexType>::constructWithCompareType() {
tempest::shields::utility::ChoiceFilter<ValueType, Compare, relative> choiceFilter;
storm::storage::PostScheduler<ValueType> shield(this->rowGroupIndices.size() - 1, this->computeRowGroupSizes());
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates &= this->coalitionStates.get();
this->relevantStates &= ~this->coalitionStates.get();
}
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
if(this->relevantStates.get(state)) {
auto maxProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it;
ValueType maxProbability = *(choice_it + maxProbabilityIndex);
if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
auto optProbabilityIndex = std::min_element(choice_it, choice_it + rowGroupSize) - choice_it;
if(std::is_same<Compare, storm::utility::ElementGreaterEqual<ValueType>>::value) {
optProbabilityIndex = std::max_element(choice_it, choice_it + rowGroupSize) - choice_it;
}
ValueType optProbability = *(choice_it + optProbabilityIndex);
if(!relative && !choiceFilter(optProbability, optProbability, this->shieldingExpression->getValue())) {
STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state);
shield.setChoice(storm::storage::PostSchedulerChoice<ValueType>(), state, 0);
choice_it += rowGroupSize;
@ -49,10 +52,10 @@ namespace tempest {
}
storm::storage::PostSchedulerChoice<ValueType> choiceMapping;
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) {
if(choiceFilter(*choice_it, optProbability, this->shieldingExpression->getValue())) {
choiceMapping.addChoice(choice, choice);
} else {
choiceMapping.addChoice(choice, maxProbabilityIndex);
choiceMapping.addChoice(choice, optProbabilityIndex);
}
}
shield.setChoice(choiceMapping, state, 0);
@ -65,9 +68,9 @@ namespace tempest {
}
// Explicitly instantiate appropriate classes
template class PostSafetyShield<double, typename storm::storage::SparseMatrix<double>::index_type>;
template class PostShield<double, typename storm::storage::SparseMatrix<double>::index_type>;
#ifdef STORM_HAVE_CARL
template class PostSafetyShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>;
template class PostShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>;
#endif
}
}

4
src/storm/shields/PostSafetyShield.h → src/storm/shields/PostShield.h

@ -7,9 +7,9 @@ namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
class PostSafetyShield : public AbstractShield<ValueType, IndexType> {
class PostShield : public AbstractShield<ValueType, IndexType> {
public:
PostSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
PostShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::PostScheduler<ValueType> construct();
template<typename Compare, bool relative>

25
src/storm/shields/PreSafetyShield.cpp → src/storm/shields/PreShield.cpp

@ -1,4 +1,4 @@
#include "storm/shields/PreSafetyShield.h"
#include "storm/shields/PreShield.h"
#include <algorithm>
@ -6,12 +6,12 @@ namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
PreSafetyShield<ValueType, IndexType>::PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) {
PreShield<ValueType, IndexType>::PreShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) : AbstractShield<ValueType, IndexType>(rowGroupIndices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates), choiceValues(choiceValues) {
// Intentionally left empty.
}
template<typename ValueType, typename IndexType>
storm::storage::PreScheduler<ValueType> PreSafetyShield<ValueType, IndexType>::construct() {
storm::storage::PreScheduler<ValueType> PreShield<ValueType, IndexType>::construct() {
if (this->getOptimizationDirection() == storm::OptimizationDirection::Minimize) {
if(this->shieldingExpression->isRelative()) {
return constructWithCompareType<storm::utility::ElementLessEqual<ValueType>, true>();
@ -29,26 +29,31 @@ namespace tempest {
template<typename ValueType, typename IndexType>
template<typename Compare, bool relative>
storm::storage::PreScheduler<ValueType> PreSafetyShield<ValueType, IndexType>::constructWithCompareType() {
storm::storage::PreScheduler<ValueType> PreShield<ValueType, IndexType>::constructWithCompareType() {
tempest::shields::utility::ChoiceFilter<ValueType, Compare, relative> choiceFilter;
storm::storage::PreScheduler<ValueType> shield(this->rowGroupIndices.size() - 1);
auto choice_it = this->choiceValues.begin();
if(this->coalitionStates.is_initialized()) {
this->relevantStates &= this->coalitionStates.get();
this->relevantStates &= ~this->coalitionStates.get();
}
for(uint state = 0; state < this->rowGroupIndices.size() - 1; state++) {
uint rowGroupSize = this->rowGroupIndices[state + 1] - this->rowGroupIndices[state];
if(this->relevantStates.get(state)) {
storm::storage::PreSchedulerChoice<ValueType> enabledChoices;
ValueType maxProbability = *std::max_element(choice_it, choice_it + rowGroupSize);
if(!relative && !choiceFilter(maxProbability, maxProbability, this->shieldingExpression->getValue())) {
ValueType optProbability;
if(std::is_same<Compare, storm::utility::ElementGreaterEqual<ValueType>>::value) {
optProbability = *std::max_element(choice_it, choice_it + rowGroupSize);
} else {
optProbability = *std::min_element(choice_it, choice_it + rowGroupSize);
}
if(!relative && !choiceFilter(optProbability, optProbability, this->shieldingExpression->getValue())) {
STORM_LOG_WARN("No shielding action possible with absolute comparison for state with index " << state);
shield.setChoice(storm::storage::PreSchedulerChoice<ValueType>(), state, 0);
choice_it += rowGroupSize;
continue;
}
for(uint choice = 0; choice < rowGroupSize; choice++, choice_it++) {
if(choiceFilter(*choice_it, maxProbability, this->shieldingExpression->getValue())) {
if(choiceFilter(*choice_it, optProbability, this->shieldingExpression->getValue())) {
enabledChoices.addChoice(choice, *choice_it);
}
}
@ -63,9 +68,9 @@ namespace tempest {
return shield;
}
// Explicitly instantiate appropriate classes
template class PreSafetyShield<double, typename storm::storage::SparseMatrix<double>::index_type>;
template class PreShield<double, typename storm::storage::SparseMatrix<double>::index_type>;
#ifdef STORM_HAVE_CARL
template class PreSafetyShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>;
template class PreShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>;
#endif
}
}

4
src/storm/shields/PreSafetyShield.h → src/storm/shields/PreShield.h

@ -7,9 +7,9 @@ namespace tempest {
namespace shields {
template<typename ValueType, typename IndexType>
class PreSafetyShield : public AbstractShield<ValueType, IndexType> {
class PreShield : public AbstractShield<ValueType, IndexType> {
public:
PreSafetyShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
PreShield(std::vector<IndexType> const& rowGroupIndices, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
storm::storage::PreScheduler<ValueType> construct();
template<typename Compare, bool relative>

19
src/storm/shields/ShieldHandling.cpp

@ -10,11 +10,12 @@ namespace tempest {
void createShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> model, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) {
std::ofstream stream;
storm::utility::openFile(shieldFilename(shieldingExpression), stream);
if(coalitionStates.is_initialized()) coalitionStates.get().complement();
if(shieldingExpression->isPreSafetyShield()) {
PreSafetyShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
PreShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
shield.construct().printToStream(stream, shieldingExpression, model);
} else if(shieldingExpression->isPostSafetyShield()) {
PostSafetyShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
PostShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
shield.construct().printToStream(stream, shieldingExpression, model);
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString());
@ -24,11 +25,15 @@ namespace tempest {
}
template<typename ValueType, typename IndexType>
void createQuantitativeShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> model, std::vector<uint64_t> const& precomputedChoices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) {
void createQuantitativeShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> model, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates) {
std::ofstream stream;
storm::utility::openFile(shieldFilename(shieldingExpression), stream);
if(shieldingExpression->isOptimalShield()) {
OptimalShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), precomputedChoices, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
if(coalitionStates.is_initialized()) coalitionStates.get().complement(); // TODO CHECK THIS!!!
if(shieldingExpression->isOptimalPreShield()) {
PreShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
shield.construct().printToStream(stream, shieldingExpression, model);
} else if(shieldingExpression->isOptimalPostShield()) {
PostShield<ValueType, IndexType> shield(model->getTransitionMatrix().getRowGroupIndices(), choiceValues, shieldingExpression, optimizationDirection, relevantStates, coalitionStates);
shield.construct().printToStream(stream, shieldingExpression, model);
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidArgumentException, "Unknown Shielding Type: " + shieldingExpression->typeToString());
@ -38,10 +43,10 @@ namespace tempest {
}
// Explicitly instantiate appropriate
template void createShield<double, typename storm::storage::SparseMatrix<double>::index_type>(std::shared_ptr<storm::models::sparse::Model<double>> model, std::vector<double> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
template void createQuantitativeShield<double, typename storm::storage::SparseMatrix<double>::index_type>(std::shared_ptr<storm::models::sparse::Model<double>> model, std::vector<uint64_t> const& precomputedChoices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
template void createQuantitativeShield<double, typename storm::storage::SparseMatrix<double>::index_type>(std::shared_ptr<storm::models::sparse::Model<double>> model, std::vector<double> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
#ifdef STORM_HAVE_CARL
template void createShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>(std::shared_ptr<storm::models::sparse::Model<storm::RationalNumber>> model, std::vector<storm::RationalNumber> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
template void createQuantitativeShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>(std::shared_ptr<storm::models::sparse::Model<storm::RationalNumber>> model, std::vector<uint64_t> const& precomputedChoices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
template void createQuantitativeShield<storm::RationalNumber, typename storm::storage::SparseMatrix<storm::RationalNumber>::index_type>(std::shared_ptr<storm::models::sparse::Model<storm::RationalNumber>> model, std::vector<storm::RationalNumber> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
#endif
}
}

6
src/storm/shields/ShieldHandling.h

@ -10,8 +10,8 @@
#include "storm/logic/ShieldExpression.h"
#include "storm/shields/AbstractShield.h"
#include "storm/shields/PreSafetyShield.h"
#include "storm/shields/PostSafetyShield.h"
#include "storm/shields/PreShield.h"
#include "storm/shields/PostShield.h"
#include "storm/shields/OptimalShield.h"
#include "storm/io/file.h"
@ -27,6 +27,6 @@ namespace tempest {
void createShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> model, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
template<typename ValueType, typename IndexType = storm::storage::sparse::state_type>
void createQuantitativeShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> model, std::vector<uint64_t> const& precomputedChoices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
void createQuantitativeShield(std::shared_ptr<storm::models::sparse::Model<ValueType>> model, std::vector<ValueType> const& choiceValues, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);
}
}

12
src/test/storm/parser/GameShieldingParserTest.cpp

@ -20,15 +20,15 @@ TEST(GameShieldingParserTest, PreSafetyShieldTest) {
std::shared_ptr<storm::logic::ShieldExpression const> shieldExpression(nullptr);
ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression());
EXPECT_TRUE(shieldExpression->isPreSafetyShield());
EXPECT_FALSE(shieldExpression->isPostSafetyShield());
EXPECT_TRUE(shieldExpression->isPreShield());
EXPECT_FALSE(shieldExpression->isPostShield());
EXPECT_FALSE(shieldExpression->isOptimalShield());
EXPECT_TRUE(shieldExpression->isRelative());
EXPECT_EQ(std::stod(value), shieldExpression->getValue());
EXPECT_EQ(filename, shieldExpression->getFilename());
}
TEST(GameShieldingParserTest, PostSafetyShieldTest) {
TEST(GameShieldingParserTest, PostShieldTest) {
storm::parser::FormulaParser formulaParser;
std::string filename = "postSafetyShieldFileName";
@ -46,7 +46,7 @@ TEST(GameShieldingParserTest, PostSafetyShieldTest) {
std::shared_ptr<storm::logic::ShieldExpression const> shieldExpression(nullptr);
ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression());
EXPECT_FALSE(shieldExpression->isPreSafetyShield());
EXPECT_TRUE(shieldExpression->isPostSafetyShield());
EXPECT_TRUE(shieldExpression->isPostShield());
EXPECT_FALSE(shieldExpression->isOptimalShield());
EXPECT_FALSE(shieldExpression->isRelative());
EXPECT_EQ(std::stod(value), shieldExpression->getValue());
@ -74,8 +74,8 @@ TEST(GameShieldingParserTest, OptimalShieldTest) {
std::shared_ptr<storm::logic::ShieldExpression const> shieldExpression(nullptr);
ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression());
EXPECT_FALSE(shieldExpression->isPreSafetyShield());
EXPECT_FALSE(shieldExpression->isPostSafetyShield());
EXPECT_FALSE(shieldExpression->isPostShield());
EXPECT_TRUE(shieldExpression->isOptimalShield());
EXPECT_FALSE(shieldExpression->isRelative());
EXPECT_EQ(filename, shieldExpression->getFilename());
}
}

12
src/test/storm/parser/MdpShieldingParserTest.cpp

@ -18,19 +18,19 @@ TEST(MdpShieldingParserTest, PreSafetyShieldTest) {
std::shared_ptr<storm::logic::ShieldExpression const> shieldExpression(nullptr);
ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression());
EXPECT_TRUE(shieldExpression->isPreSafetyShield());
EXPECT_FALSE(shieldExpression->isPostSafetyShield());
EXPECT_FALSE(shieldExpression->isPostShield());
EXPECT_FALSE(shieldExpression->isOptimalShield());
EXPECT_FALSE(shieldExpression->isRelative());
EXPECT_EQ(std::stod(value), shieldExpression->getValue());
EXPECT_EQ(filename, shieldExpression->getFilename());
}
TEST(MdpShieldingParserTest, PostSafetyShieldTest) {
TEST(MdpShieldingParserTest, PostShieldTest) {
storm::parser::FormulaParser formulaParser;
std::string filename = "postSafetyShieldFileName";
std::string value = "0.95";
std::string input = "<" + filename + ", PostSafety, lambda=" + value + "> Pmin=? [X !\"label\"]";
std::string input = "<" + filename + ", Post, lambda=" + value + "> Pmin=? [X !\"label\"]";
std::shared_ptr<storm::logic::Formula const> formula(nullptr);
std::vector<storm::jani::Property> property;
@ -40,7 +40,7 @@ TEST(MdpShieldingParserTest, PostSafetyShieldTest) {
std::shared_ptr<storm::logic::ShieldExpression const> shieldExpression(nullptr);
ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression());
EXPECT_FALSE(shieldExpression->isPreSafetyShield());
EXPECT_TRUE(shieldExpression->isPostSafetyShield());
EXPECT_TRUE(shieldExpression->isPostShield());
EXPECT_FALSE(shieldExpression->isOptimalShield());
EXPECT_TRUE(shieldExpression->isRelative());
EXPECT_EQ(std::stod(value), shieldExpression->getValue());
@ -65,8 +65,8 @@ TEST(MdpShieldingParserTest, OptimalShieldTest) {
std::shared_ptr<storm::logic::ShieldExpression const> shieldExpression(nullptr);
ASSERT_NO_THROW(shieldExpression = property.at(0).getShieldingExpression());
EXPECT_FALSE(shieldExpression->isPreSafetyShield());
EXPECT_FALSE(shieldExpression->isPostSafetyShield());
EXPECT_FALSE(shieldExpression->isPostShield());
EXPECT_TRUE(shieldExpression->isOptimalShield());
EXPECT_FALSE(shieldExpression->isRelative());
EXPECT_EQ(filename, shieldExpression->getFilename());
}
}
Loading…
Cancel
Save