Browse Source

first version for action mask callbacks in explicit generator

tempestpy_adaptions
Sebastian Junges 4 years ago
parent
commit
4a7ea35959
  1. 5
      src/storm/api/builder.h
  2. 34
      src/storm/generator/NextStateGenerator.cpp
  3. 32
      src/storm/generator/NextStateGenerator.h
  4. 14
      src/storm/generator/PrismNextStateGenerator.cpp
  5. 4
      src/storm/generator/PrismNextStateGenerator.h

5
src/storm/api/builder.h

@ -85,11 +85,12 @@ namespace storm {
* @return A builder
*/
template<typename ValueType>
storm::builder::ExplicitModelBuilder<ValueType> makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options) {
storm::builder::ExplicitModelBuilder<ValueType> makeExplicitModelBuilder(storm::storage::SymbolicModelDescription const& model, storm::builder::BuilderOptions const& options, std::shared_ptr<storm::generator::ActionMask<ValueType>> actionMask = nullptr) {
std::shared_ptr<storm::generator::NextStateGenerator<ValueType, uint32_t>> generator;
if (model.isPrismProgram()) {
generator = std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(model.asPrismProgram(), options);
generator = std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(model.asPrismProgram(), options, actionMask);
} else if (model.isJaniModel()) {
STORM_LOG_THROW(actionMask == nullptr, storm::exceptions::NotSupportedException, "Action masks for JANI are not yet supported");
generator = std::make_shared<storm::generator::JaniNextStateGenerator<ValueType, uint32_t>>(model.asJaniModel(), options);
} else {
STORM_LOG_THROW(false, storm::exceptions::NotSupportedException, "Cannot build sparse model from this symbolic model description.");

34
src/storm/generator/NextStateGenerator.cpp

@ -16,9 +16,23 @@
namespace storm {
namespace generator {
template<typename ValueType, typename StateType>
StateValuationFunctionMask<ValueType,StateType>::StateValuationFunctionMask(std::function<bool(storm::expressions::SimpleValuation const&, uint64_t)> const& f)
: func(f)
{
// Intentionally left empty
}
template<typename ValueType, typename StateType>
bool StateValuationFunctionMask<ValueType,StateType>::query(storm::generator::NextStateGenerator<ValueType,StateType> const& generator, uint64_t actionIndex) {
auto val = generator.currentStateToSimpleValuation();
bool res = func(val, actionIndex);
return res;
}
template<typename ValueType, typename StateType>
NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr) {
NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(variableInformation), evaluator(nullptr), state(nullptr), actionMask(mask) {
if(variableInformation.hasOutOfBoundsBit()) {
outOfBoundsState = createOutOfBoundsState(variableInformation);
}
@ -28,7 +42,7 @@ namespace storm {
}
template<typename ValueType, typename StateType>
NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr) {
NextStateGenerator<ValueType, StateType>::NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : options(options), expressionManager(expressionManager.getSharedPointer()), variableInformation(), evaluator(nullptr), state(nullptr), actionMask(mask) {
if(variableInformation.hasOutOfBoundsBit()) {
outOfBoundsState = createOutOfBoundsState(variableInformation);
}
@ -270,6 +284,12 @@ namespace storm {
return result;
}
template<typename ValueType, typename StateType>
storm::expressions::SimpleValuation NextStateGenerator<ValueType, StateType>::currentStateToSimpleValuation() const {
return unpackStateIntoValuation(*state, variableInformation, *expressionManager);
}
template<typename ValueType, typename StateType>
void NextStateGenerator<ValueType, StateType>::extendStateInformation(storm::json<ValueType>&, bool) const {
// Intentionally left empty.
@ -301,7 +321,15 @@ namespace storm {
template class NextStateGenerator<double>;
template class ActionMask<double>;
template class StateValuationFunctionMask<double>;
#ifdef STORM_HAVE_CARL
template class ActionMask<storm::RationalNumber>;
template class StateValuationFunctionMask<storm::RationalNumber>;
template class ActionMask<storm::RationalFunction>;
template class StateValuationFunctionMask<storm::RationalFunction>;
template class NextStateGenerator<storm::RationalNumber>;
template class NextStateGenerator<storm::RationalFunction>;
#endif

32
src/storm/generator/NextStateGenerator.h

@ -11,6 +11,7 @@
#include "storm/storage/expressions/ExpressionEvaluator.h"
#include "storm/storage/sparse/ChoiceOrigins.h"
#include "storm/storage/sparse/StateValuations.h"
#include "storm/storage/expressions/SimpleValuation.h"
#include "storm/builder/BuilderOptions.h"
#include "storm/builder/RewardModelInformation.h"
@ -32,19 +33,41 @@ namespace storm {
MA,
POMDP
};
template<typename ValueType, typename StateType = uint32_t>
class NextStateGenerator;
template<typename ValueType, typename StateType = uint32_t>
class ActionMask {
public:
virtual ~ActionMask() = default;
virtual bool query(storm::generator::NextStateGenerator<ValueType, StateType> const &generator, uint64_t actionIndex) = 0;
};
template<typename ValueType, typename StateType = uint32_t>
class StateValuationFunctionMask : public ActionMask<ValueType,StateType> {
public:
StateValuationFunctionMask(std::function<bool (storm::expressions::SimpleValuation const&, uint64_t)> const& f);
virtual ~StateValuationFunctionMask() = default;
bool query(storm::generator::NextStateGenerator<ValueType,StateType> const& generator, uint64_t actionIndex) override;
private:
std::function<bool(storm::expressions::SimpleValuation, uint64_t)> func;
};
template<typename ValueType, typename StateType>
class NextStateGenerator {
public:
typedef std::function<StateType (CompressedState const&)> StateToIdCallback;
NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options);
NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, VariableInformation const& variableInformation, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr);
/*!
* Creates a new next state generator. This version of the constructor default-constructs the variable information.
* Hence, the subclass is responsible for suitably initializing it in its constructor.
*/
NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options);
NextStateGenerator(storm::expressions::ExpressionManager const& expressionManager, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr);
virtual ~NextStateGenerator() = default;
@ -73,6 +96,7 @@ namespace storm {
std::string stateToString(CompressedState const& state) const;
storm::json<ValueType> currentStateToJson(bool onlyObservable = false) const;
storm::expressions::SimpleValuation currentStateToSimpleValuation() const;
uint32_t observabilityClass(CompressedState const& state) const;
@ -147,6 +171,8 @@ namespace storm {
/// A map that stores the indices of states with overlapping guards.
boost::optional<std::vector<uint64_t>> overlappingGuardStates;
std::shared_ptr<ActionMask<ValueType,StateType>> actionMask;
};
}
}

14
src/storm/generator/PrismNextStateGenerator.cpp

@ -22,12 +22,12 @@ namespace storm {
namespace generator {
template<typename ValueType, typename StateType>
PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options) : PrismNextStateGenerator<ValueType, StateType>(program.substituteConstantsFormulas(), options, false) {
PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask) : PrismNextStateGenerator<ValueType, StateType>(program.substituteConstantsFormulas(), options, mask, false) {
// Intentionally left empty.
}
template<typename ValueType, typename StateType>
PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool) : NextStateGenerator<ValueType, StateType>(program.getManager(), options), program(program), rewardModels(), hasStateActionRewards(false) {
PrismNextStateGenerator<ValueType, StateType>::PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const& mask, bool) : NextStateGenerator<ValueType, StateType>(program.getManager(), options, mask), program(program), rewardModels(), hasStateActionRewards(false) {
STORM_LOG_TRACE("Creating next-state generator for PRISM program: " << program);
STORM_LOG_THROW(!this->program.specifiesSystemComposition(), storm::exceptions::WrongFormatException, "The explicit next-state generator currently does not support custom system compositions.");
@ -536,6 +536,11 @@ namespace storm {
continue;
}
}
if (this->actionMask != nullptr) {
if (!this->actionMask->query(*this, command.getActionIndex())) {
continue;
}
}
// Skip the command, if it is not enabled.
if (!this->evaluator->asBool(command.getGuardExpression())) {
@ -616,6 +621,11 @@ namespace storm {
void PrismNextStateGenerator<ValueType, StateType>::addLabeledChoices(std::vector<Choice<ValueType>>& choices, CompressedState const& state, StateToIdCallback stateToIdCallback, CommandFilter const& commandFilter) {
for (uint_fast64_t actionIndex : program.getSynchronizingActionIndices()) {
if (this->actionMask != nullptr) {
if (!this->actionMask->query(*this, actionIndex)) {
continue;
}
}
boost::optional<std::vector<std::vector<std::reference_wrapper<storm::prism::Command const>>>> optionalActiveCommandLists = getActiveCommandsByActionIndex(actionIndex, commandFilter);
// Only process this action label, if there is at least one feasible solution.

4
src/storm/generator/PrismNextStateGenerator.h

@ -23,7 +23,7 @@ namespace storm {
typedef storm::storage::FlatSet<uint_fast64_t> CommandSet;
enum class CommandFilter {All, Markovian, Probabilistic};
PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions());
PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options = NextStateGeneratorOptions(), std::shared_ptr<ActionMask<ValueType,StateType>> const& = nullptr);
/*!
* A quick check to detect whether the given model is not supported.
@ -55,7 +55,7 @@ namespace storm {
* being called. The last argument is only present to distinguish the signature of this constructor from the
* public one.
*/
PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, bool flag);
PrismNextStateGenerator(storm::prism::Program const& program, NextStateGeneratorOptions const& options, std::shared_ptr<ActionMask<ValueType,StateType>> const&, bool flag);
/*!
* Applies an update to the state currently loaded into the evaluator and applies the resulting values to

Loading…
Cancel
Save