#pragma once #include <boost/optional.hpp> #include <iostream> #include <string> #include <memory> #include "storm/storage/Scheduler.h" #include "storm/storage/SchedulerChoice.h" #include "storm/storage/BitVector.h" #include "storm/storage/Distribution.h" #include "storm/utility/constants.h" #include "storm/solver/OptimizationDirection.h" #include "storm/logic/ShieldExpression.h" #include "storm/exceptions/NotSupportedException.h" namespace tempest { namespace shields { template<typename ValueType, typename IndexType> class PreShield; template<typename ValueType, typename IndexType> class PostShield; template<typename ValueType, typename IndexType> class OptimalShield; namespace utility { template<typename ValueType, typename Compare, bool relative> struct ChoiceFilter { bool operator()(ValueType v, ValueType opt, double shieldValue) { if constexpr (std::is_same_v<ValueType, storm::RationalNumber> || std::is_same_v<ValueType, double>) { Compare compare; if(relative && std::is_same<Compare, storm::utility::ElementLessEqual<ValueType>>::value) { return compare(v, opt + opt * shieldValue); } else if(relative && std::is_same<Compare, storm::utility::ElementGreaterEqual<ValueType>>::value) { return compare(v, opt * shieldValue); } else return compare(v, shieldValue); } else { STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot create shields for parametric models"); } } }; } template<typename ValueType, typename IndexType> class AbstractShield { public: typedef IndexType index_type; typedef ValueType value_type; virtual ~AbstractShield() = 0; /*! * Computes the sizes of the row groups based on the indices given. */ std::vector<IndexType> computeRowGroupSizes(); storm::OptimizationDirection getOptimizationDirection(); void setShieldingExpression(std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression); std::string getClassName() const; virtual bool isPreShield() const; virtual bool isPostShield() const; virtual bool isOptimalShield() const; PreShield<ValueType, IndexType>& asPreShield(); PreShield<ValueType, IndexType> const& asPreShield() const; PostShield<ValueType, IndexType>& asPostShield(); PostShield<ValueType, IndexType> const& asPostShield() const; OptimalShield<ValueType, IndexType>& asOptimalShield(); OptimalShield<ValueType, IndexType> const& asOptimalShield() const; virtual void printToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model) = 0; virtual void printJsonToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model) = 0; protected: AbstractShield(std::vector<IndexType> const& rowGroupIndices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates); std::vector<index_type> rowGroupIndices; //std::vector<value_type> choiceValues; std::shared_ptr<storm::logic::ShieldExpression const> shieldingExpression; storm::OptimizationDirection optimizationDirection; storm::storage::BitVector relevantStates; boost::optional<storm::storage::BitVector> coalitionStates; }; } }