You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
340 lines
18 KiB
340 lines
18 KiB
#ifndef STORM_MODELCHECKER_REACHABILITY_SPARSEMDPLEARNINGMODELCHECKER_H_
|
|
#define STORM_MODELCHECKER_REACHABILITY_SPARSEMDPLEARNINGMODELCHECKER_H_
|
|
|
|
#include <random>
|
|
|
|
#include "src/modelchecker/AbstractModelChecker.h"
|
|
|
|
#include "src/storage/prism/Program.h"
|
|
#include "src/storage/sparse/StateStorage.h"
|
|
|
|
#include "src/generator/PrismNextStateGenerator.h"
|
|
#include "src/generator/CompressedState.h"
|
|
#include "src/generator/VariableInformation.h"
|
|
|
|
#include "src/utility/ConstantsComparator.h"
|
|
#include "src/utility/constants.h"
|
|
|
|
namespace storm {
|
|
namespace storage {
|
|
namespace sparse {
|
|
template<typename StateType>
|
|
class StateStorage;
|
|
}
|
|
}
|
|
|
|
namespace generator {
|
|
template<typename ValueType, typename StateType>
|
|
class PrismNextStateGenerator;
|
|
}
|
|
|
|
namespace modelchecker {
|
|
|
|
template<typename ValueType>
|
|
class SparseMdpLearningModelChecker : public AbstractModelChecker {
|
|
public:
|
|
typedef uint32_t StateType;
|
|
typedef uint32_t ActionType;
|
|
typedef boost::container::flat_set<StateType> StateSet;
|
|
typedef boost::container::flat_set<ActionType> ActionSet;
|
|
typedef std::shared_ptr<ActionSet> ActionSetPointer;
|
|
typedef std::vector<std::pair<StateType, ActionType>> StateActionStack;
|
|
|
|
SparseMdpLearningModelChecker(storm::prism::Program const& program, boost::optional<std::map<storm::expressions::Variable, storm::expressions::Expression>> const& constantDefinitions);
|
|
|
|
virtual bool canHandle(CheckTask<storm::logic::Formula> const& checkTask) const override;
|
|
|
|
virtual std::unique_ptr<CheckResult> computeReachabilityProbabilities(CheckTask<storm::logic::EventuallyFormula> const& checkTask) override;
|
|
|
|
private:
|
|
// A struct that keeps track of certain statistics during the computation.
|
|
struct Statistics {
|
|
Statistics() : iterations(0), maxPathLength(0), numberOfTargetStates(0), numberOfExploredStates(0), pathLengthUntilEndComponentDetection(27) {
|
|
// Intentionally left empty.
|
|
}
|
|
|
|
std::size_t iterations;
|
|
std::size_t maxPathLength;
|
|
std::size_t numberOfTargetStates;
|
|
std::size_t numberOfExploredStates;
|
|
std::size_t pathLengthUntilEndComponentDetection;
|
|
};
|
|
|
|
// A struct containing the data required for state exploration.
|
|
struct StateGeneration {
|
|
StateGeneration(storm::prism::Program const& program, storm::generator::VariableInformation const& variableInformation, storm::expressions::Expression const& targetStateExpression) : generator(program, variableInformation, false), targetStateExpression(targetStateExpression) {
|
|
// Intentionally left empty.
|
|
}
|
|
|
|
std::vector<StateType> getInitialStates() {
|
|
return generator.getInitialStates(stateToIdCallback);
|
|
}
|
|
|
|
storm::generator::StateBehavior<ValueType, StateType> expand() {
|
|
return generator.expand(stateToIdCallback);
|
|
}
|
|
|
|
bool isTargetState() const {
|
|
return generator.satisfies(targetStateExpression);
|
|
}
|
|
|
|
storm::generator::PrismNextStateGenerator<ValueType, StateType> generator;
|
|
std::function<StateType (storm::generator::CompressedState const&)> stateToIdCallback;
|
|
storm::expressions::Expression targetStateExpression;
|
|
};
|
|
|
|
// A structure containing the data assembled during exploration.
|
|
struct ExplorationInformation {
|
|
ExplorationInformation(uint_fast64_t bitsPerBucket, ActionType const& unexploredMarker = std::numeric_limits<ActionType>::max()) : stateStorage(bitsPerBucket), unexploredMarker(unexploredMarker) {
|
|
// Intentionally left empty.
|
|
}
|
|
|
|
storm::storage::sparse::StateStorage<StateType> stateStorage;
|
|
|
|
std::vector<std::vector<storm::storage::MatrixEntry<StateType, ValueType>>> matrix;
|
|
std::vector<StateType> rowGroupIndices;
|
|
|
|
std::vector<StateType> stateToRowGroupMapping;
|
|
StateType unexploredMarker;
|
|
std::unordered_map<StateType, storm::generator::CompressedState> unexploredStates;
|
|
|
|
storm::OptimizationDirection optimizationDirection;
|
|
StateSet terminalStates;
|
|
|
|
std::unordered_map<StateType, ActionSetPointer> stateToLeavingChoicesOfEndComponent;
|
|
|
|
void setInitialStates(std::vector<StateType> const& initialStates) {
|
|
stateStorage.initialStateIndices = initialStates;
|
|
}
|
|
|
|
StateType getFirstInitialState() const {
|
|
return stateStorage.initialStateIndices.front();
|
|
}
|
|
|
|
std::size_t getNumberOfInitialStates() const {
|
|
return stateStorage.initialStateIndices.size();
|
|
}
|
|
|
|
void addUnexploredState(storm::generator::CompressedState const& compressedState) {
|
|
stateToRowGroupMapping.push_back(unexploredMarker);
|
|
unexploredStates[stateStorage.numberOfStates] = compressedState;
|
|
++stateStorage.numberOfStates;
|
|
}
|
|
|
|
void assignStateToRowGroup(StateType const& state, ActionType const& rowGroup) {
|
|
stateToRowGroupMapping[state] = rowGroup;
|
|
}
|
|
|
|
StateType assignStateToNextRowGroup(StateType const& state) {
|
|
stateToRowGroupMapping[state] = rowGroupIndices.size() - 1;
|
|
return stateToRowGroupMapping[state];
|
|
}
|
|
|
|
void newRowGroup(ActionType const& action) {
|
|
rowGroupIndices.push_back(action);
|
|
}
|
|
|
|
void newRowGroup() {
|
|
newRowGroup(matrix.size());
|
|
}
|
|
|
|
std::size_t getNumberOfUnexploredStates() const {
|
|
return unexploredStates.size();
|
|
}
|
|
|
|
std::size_t getNumberOfDiscoveredStates() const {
|
|
return stateStorage.numberOfStates;
|
|
}
|
|
|
|
StateType const& getRowGroup(StateType const& state) const {
|
|
return stateToRowGroupMapping[state];
|
|
}
|
|
|
|
StateType const& getUnexploredMarker() const {
|
|
return unexploredMarker;
|
|
}
|
|
|
|
bool isUnexplored(StateType const& state) const {
|
|
return stateToRowGroupMapping[state] == unexploredMarker;
|
|
}
|
|
|
|
bool isTerminal(StateType const& state) const {
|
|
return terminalStates.find(state) != terminalStates.end();
|
|
}
|
|
|
|
ActionType const& getStartRowOfGroup(StateType const& group) const {
|
|
return rowGroupIndices[group];
|
|
}
|
|
|
|
std::size_t getRowGroupSize(StateType const& group) const {
|
|
return rowGroupIndices[group + 1] - rowGroupIndices[group];
|
|
}
|
|
|
|
void addTerminalState(StateType const& state) {
|
|
terminalStates.insert(state);
|
|
}
|
|
|
|
std::vector<storm::storage::MatrixEntry<StateType, ValueType>>& getRowOfMatrix(ActionType const& row) {
|
|
return matrix[row];
|
|
}
|
|
|
|
std::vector<storm::storage::MatrixEntry<StateType, ValueType>> const& getRowOfMatrix(ActionType const& row) const {
|
|
return matrix[row];
|
|
}
|
|
|
|
void addRowsToMatrix(std::size_t const& count) {
|
|
matrix.resize(matrix.size() + count);
|
|
}
|
|
};
|
|
|
|
// A struct containg the lower and upper bounds per state and action.
|
|
struct BoundValues {
|
|
std::vector<ValueType> lowerBoundsPerState;
|
|
std::vector<ValueType> upperBoundsPerState;
|
|
std::vector<ValueType> lowerBoundsPerAction;
|
|
std::vector<ValueType> upperBoundsPerAction;
|
|
|
|
std::pair<ValueType, ValueType> getBoundsForState(StateType const& state, ExplorationInformation const& explorationInformation) const {
|
|
ActionType index = explorationInformation.getRowGroup(state);
|
|
if (index == explorationInformation.getUnexploredMarker()) {
|
|
return std::make_pair(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>());
|
|
} else {
|
|
return std::make_pair(lowerBoundsPerState[index], upperBoundsPerState[index]);
|
|
}
|
|
}
|
|
|
|
ValueType getLowerBoundForState(StateType const& state, ExplorationInformation const& explorationInformation) const {
|
|
ActionType index = explorationInformation.getRowGroup(state);
|
|
if (index == explorationInformation.getUnexploredMarker()) {
|
|
return storm::utility::zero<ValueType>();
|
|
} else {
|
|
return getLowerBoundForRowGroup(index, explorationInformation);
|
|
}
|
|
}
|
|
|
|
ValueType getLowerBoundForRowGroup(StateType const& rowGroup, ExplorationInformation const& explorationInformation) const {
|
|
return lowerBoundsPerState[rowGroup];
|
|
}
|
|
|
|
ValueType getUpperBoundForState(StateType const& state, ExplorationInformation const& explorationInformation) const {
|
|
ActionType index = explorationInformation.getRowGroup(state);
|
|
if (index == explorationInformation.getUnexploredMarker()) {
|
|
return storm::utility::one<ValueType>();
|
|
} else {
|
|
return getUpperBoundForRowGroup(index);
|
|
}
|
|
}
|
|
|
|
ValueType const& getUpperBoundForRowGroup(StateType const& rowGroup) const {
|
|
return upperBoundsPerState[rowGroup];
|
|
}
|
|
|
|
std::pair<ValueType, ValueType> getBoundsForAction(ActionType const& action) const {
|
|
return std::make_pair(lowerBoundsPerAction[action], upperBoundsPerAction[action]);
|
|
}
|
|
|
|
ValueType const& getLowerBoundForAction(ActionType const& action) const {
|
|
return lowerBoundsPerAction[action];
|
|
}
|
|
|
|
ValueType const& getUpperBoundForAction(ActionType const& action) const {
|
|
return upperBoundsPerAction[action];
|
|
}
|
|
|
|
ValueType getDifferenceOfStateBounds(StateType const& state, ExplorationInformation const& explorationInformation) {
|
|
std::pair<ValueType, ValueType> bounds = getBoundsForState(state, explorationInformation);
|
|
return bounds.second - bounds.first;
|
|
}
|
|
|
|
void initializeBoundsForNextState(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) {
|
|
lowerBoundsPerState.push_back(vals.first);
|
|
upperBoundsPerState.push_back(vals.second);
|
|
}
|
|
|
|
void initializeBoundsForNextAction(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) {
|
|
lowerBoundsPerAction.push_back(vals.first);
|
|
upperBoundsPerAction.push_back(vals.second);
|
|
}
|
|
|
|
void setLowerBoundForState(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& value) {
|
|
lowerBoundsPerState[explorationInformation.getRowGroup(state)] = value;
|
|
}
|
|
|
|
void setUpperBoundForState(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& value) {
|
|
upperBoundsPerState[explorationInformation.getRowGroup(state)] = value;
|
|
}
|
|
|
|
void setBoundsForAction(ActionType const& action, std::pair<ValueType, ValueType> const& values) {
|
|
lowerBoundsPerAction[action] = values.first;
|
|
upperBoundsPerAction[action] = values.second;
|
|
}
|
|
|
|
void setBoundsForState(StateType const& state, ExplorationInformation const& explorationInformation, std::pair<ValueType, ValueType> const& values) {
|
|
StateType const& rowGroup = explorationInformation.getRowGroup(state);
|
|
lowerBoundsPerState[rowGroup] = values.first;
|
|
upperBoundsPerState[rowGroup] = values.second;
|
|
}
|
|
|
|
bool setNewLowerBoundOfStateIfGreaterThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newLowerValue) {
|
|
StateType const& rowGroup = explorationInformation.getRowGroup(state);
|
|
if (lowerBoundsPerState[rowGroup] < newLowerValue) {
|
|
lowerBoundsPerState[rowGroup] = newLowerValue;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool setNewUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) {
|
|
StateType const& rowGroup = explorationInformation.getRowGroup(state);
|
|
if (newUpperValue < upperBoundsPerState[rowGroup]) {
|
|
upperBoundsPerState[rowGroup] = newUpperValue;
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
};
|
|
|
|
storm::expressions::Expression getTargetStateExpression(storm::logic::Formula const& subformula) const;
|
|
|
|
std::function<StateType (storm::generator::CompressedState const&)> createStateToIdCallback(ExplorationInformation& explorationInformation) const;
|
|
|
|
std::tuple<StateType, ValueType, ValueType> performLearningProcedure(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation) const;
|
|
|
|
bool samplePathFromState(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation, StateActionStack& stack, BoundValues& bounds, Statistics& stats) const;
|
|
|
|
bool exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const;
|
|
|
|
uint32_t sampleMaxAction(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const;
|
|
|
|
StateType sampleSuccessorFromAction(ActionType const& chosenAction, ExplorationInformation const& explorationInformation) const;
|
|
|
|
void detectEndComponents(StateActionStack const& stack, ExplorationInformation& explorationInformation, BoundValues& bounds) const;
|
|
|
|
void updateProbabilityBoundsAlongSampledPath(StateActionStack& stack, ExplorationInformation const& explorationInformation, BoundValues& bounds) const;
|
|
|
|
void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const;
|
|
|
|
std::pair<ValueType, ValueType> computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
ValueType computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
std::pair<ValueType, ValueType> computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
ValueType computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
ValueType computeUpperBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
ValueType computeLowerBoundOfState(StateType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
ValueType computeUpperBoundOfState(StateType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const;
|
|
|
|
// The program that defines the model to check.
|
|
storm::prism::Program program;
|
|
|
|
// The variable information.
|
|
storm::generator::VariableInformation variableInformation;
|
|
|
|
// The random number generator.
|
|
mutable std::default_random_engine randomGenerator;
|
|
|
|
// A comparator used to determine whether values are equal.
|
|
storm::utility::ConstantsComparator<ValueType> comparator;
|
|
};
|
|
}
|
|
}
|
|
|
|
#endif /* STORM_MODELCHECKER_REACHABILITY_SPARSEMDPLEARNINGMODELCHECKER_H_ */
|