99 lines
3.9 KiB

#pragma once
#include <boost/optional.hpp>
#include <iostream>
#include <string>
#include <memory>
#include "storm/storage/Scheduler.h"
#include "storm/storage/SchedulerChoice.h"
#include "storm/storage/BitVector.h"
#include "storm/storage/Distribution.h"
#include "storm/utility/constants.h"
#include "storm/solver/OptimizationDirection.h"
#include "storm/logic/ShieldExpression.h"
#include "storm/exceptions/NotSupportedException.h"
namespace tempest {
namespace shields {
// Forward declarations of the concrete shield class templates. AbstractShield
// only mentions them in the return types of its as*Shield() accessors, so the
// complete definitions are not needed in this header.
template<typename ValueType, typename IndexType> class PreShield;
template<typename ValueType, typename IndexType> class PostShield;
template<typename ValueType, typename IndexType> class OptimalShield;
namespace utility {
/*!
 * Stateless predicate that compares the value of a choice against a threshold
 * derived from the shield value.
 *
 * @tparam ValueType The numeric type of the values. Only storm::RationalNumber and double are
 *                   supported; for any other type the call throws a NotSupportedException.
 * @tparam Compare   The comparison functor; it is default-constructed on each call.
 * @tparam relative  If true and Compare is ElementLessEqual or ElementGreaterEqual over ValueType,
 *                   the threshold is derived from the optimal value `opt`; otherwise the shield
 *                   value itself is used as the threshold.
 */
template<typename ValueType, typename Compare, bool relative>
struct ChoiceFilter {
    /*!
     * @param v           The value of the choice under consideration.
     * @param opt         The optimal value the choice is measured against (only read in the relative cases).
     * @param shieldValue The threshold/factor of the shield.
     * @return The result of the selected comparison:
     *         - relative, ElementLessEqual:    compare(v, opt + opt * shieldValue)
     *         - relative, ElementGreaterEqual: compare(v, opt * shieldValue)
     *         - otherwise:                     compare(v, shieldValue)
     * @throws storm::exceptions::NotSupportedException if ValueType is neither storm::RationalNumber nor double.
     */
    bool operator()(ValueType v, ValueType opt, double shieldValue) const {
        if constexpr (std::is_same_v<ValueType, storm::RationalNumber> || std::is_same_v<ValueType, double>) {
            Compare compare;
            // 'relative' and 'Compare' are compile-time constants, so the branch is selected
            // at compile time and only the chosen comparison is instantiated.
            if constexpr (relative && std::is_same_v<Compare, storm::utility::ElementLessEqual<ValueType>>) {
                return compare(v, opt + opt * shieldValue);
            } else if constexpr (relative && std::is_same_v<Compare, storm::utility::ElementGreaterEqual<ValueType>>) {
                return compare(v, opt * shieldValue);
            } else {
                return compare(v, shieldValue);
            }
        } else {
            STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot create shields for parametric models");
        }
    }
};
}
/*!
 * Abstract base class of the shield class templates (PreShield, PostShield, OptimalShield).
 *
 * It stores the row group indices, the shielding expression, the optimization direction,
 * the relevant states and, optionally, the coalition states. Printing is left to the
 * derived classes via the pure virtual print functions.
 *
 * @tparam ValueType The value type; also used as the value type of the sparse model passed to the print functions.
 * @tparam IndexType The type of the row group indices.
 */
template<typename ValueType, typename IndexType>
class AbstractShield {
   public:
    using index_type = IndexType;
    using value_type = ValueType;

    /*!
     * Pure virtual destructor; the class cannot be instantiated on its own.
     * Its definition is not part of this header.
     */
    virtual ~AbstractShield() = 0;

    /*!
     * Computes the sizes of the row groups based on the indices given.
     */
    std::vector<IndexType> computeRowGroupSizes();

    /*! Accessor for the optimization direction (presumably the stored member; definition not visible here). */
    storm::OptimizationDirection getOptimizationDirection();

    /*! Sets the shielding expression (presumably replaces the stored member; definition not visible here). */
    void setShieldingExpression(std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression);

    /*! Returns a name for the class as a string (exact format defined elsewhere). */
    std::string getClassName() const;

    // Type queries, overridable by the derived shield classes.
    virtual bool isPreShield() const;
    virtual bool isPostShield() const;
    virtual bool isOptimalShield() const;

    // Access to this object as one of the concrete shield types, in mutable and const flavours.
    // NOTE(review): the behaviour on a type mismatch is defined in the implementation, not here.
    PreShield<ValueType, IndexType>& asPreShield();
    PreShield<ValueType, IndexType> const& asPreShield() const;

    PostShield<ValueType, IndexType>& asPostShield();
    PostShield<ValueType, IndexType> const& asPostShield() const;

    OptimalShield<ValueType, IndexType>& asOptimalShield();
    OptimalShield<ValueType, IndexType> const& asOptimalShield() const;

    /*!
     * Prints the shield to the given stream. Must be implemented by derived classes.
     *
     * @param out   The stream to write to.
     * @param model The sparse model handed to the implementation.
     */
    virtual void printToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model) = 0;

    /*!
     * Prints the shield to the given stream; judging by the name, in JSON format.
     * Must be implemented by derived classes.
     *
     * @param out   The stream to write to.
     * @param model The sparse model handed to the implementation.
     */
    virtual void printJsonToStream(std::ostream& out, std::shared_ptr<storm::models::sparse::Model<ValueType>> const& model) = 0;

   protected:
    /*!
     * Protected constructor: only derived classes can construct the base.
     * The parameters correspond by name and type to the data members below.
     */
    AbstractShield(std::vector<IndexType> const& rowGroupIndices, std::shared_ptr<storm::logic::ShieldExpression const> const& shieldingExpression, storm::OptimizationDirection optimizationDirection, storm::storage::BitVector relevantStates, boost::optional<storm::storage::BitVector> coalitionStates);

    std::vector<index_type> rowGroupIndices;
    // Currently disabled member:
    //std::vector<value_type> choiceValues;

    std::shared_ptr<storm::logic::ShieldExpression const> shieldingExpression;
    storm::OptimizationDirection optimizationDirection;

    storm::storage::BitVector relevantStates;
    // Optional: may be left empty (boost::none).
    boost::optional<storm::storage::BitVector> coalitionStates;
};
}
}