#ifndef STORM_MODELCHECKER_REACHABILITY_SPARSEMDPLEARNINGMODELCHECKER_H_ #define STORM_MODELCHECKER_REACHABILITY_SPARSEMDPLEARNINGMODELCHECKER_H_ #include <random> #include "src/modelchecker/AbstractModelChecker.h" #include "src/storage/prism/Program.h" #include "src/storage/sparse/StateStorage.h" #include "src/generator/PrismNextStateGenerator.h" #include "src/generator/CompressedState.h" #include "src/generator/VariableInformation.h" #include "src/utility/ConstantsComparator.h" #include "src/utility/constants.h" namespace storm { namespace storage { namespace sparse { template<typename StateType> class StateStorage; } } namespace generator { template<typename ValueType, typename StateType> class PrismNextStateGenerator; } namespace modelchecker { template<typename ValueType> class SparseMdpLearningModelChecker : public AbstractModelChecker { public: typedef uint32_t StateType; typedef uint32_t ActionType; typedef boost::container::flat_set<StateType> StateSet; typedef boost::container::flat_set<ActionType> ActionSet; typedef std::shared_ptr<ActionSet> ActionSetPointer; typedef std::vector<std::pair<StateType, ActionType>> StateActionStack; SparseMdpLearningModelChecker(storm::prism::Program const& program, boost::optional<std::map<storm::expressions::Variable, storm::expressions::Expression>> const& constantDefinitions); virtual bool canHandle(CheckTask<storm::logic::Formula> const& checkTask) const override; virtual std::unique_ptr<CheckResult> computeReachabilityProbabilities(CheckTask<storm::logic::EventuallyFormula> const& checkTask) override; private: // A struct that keeps track of certain statistics during the computation. struct Statistics { Statistics() : iterations(0), maxPathLength(0), numberOfTargetStates(0), numberOfExploredStates(0), pathLengthUntilEndComponentDetection(27) { // Intentionally left empty. } std::size_t iterations; std::size_t maxPathLength; std::size_t numberOfTargetStates; std::size_t numberOfExploredStates; std::size_t pathLengthUntilEndComponentDetection; }; // A struct containing the data required for state exploration. struct StateGeneration { StateGeneration(storm::prism::Program const& program, storm::generator::VariableInformation const& variableInformation, storm::expressions::Expression const& targetStateExpression) : generator(program, variableInformation, false), targetStateExpression(targetStateExpression) { // Intentionally left empty. } std::vector<StateType> getInitialStates() { return generator.getInitialStates(stateToIdCallback); } storm::generator::StateBehavior<ValueType, StateType> expand() { return generator.expand(stateToIdCallback); } bool isTargetState() const { return generator.satisfies(targetStateExpression); } storm::generator::PrismNextStateGenerator<ValueType, StateType> generator; std::function<StateType (storm::generator::CompressedState const&)> stateToIdCallback; storm::expressions::Expression targetStateExpression; }; // A structure containing the data assembled during exploration. struct ExplorationInformation { ExplorationInformation(uint_fast64_t bitsPerBucket, ActionType const& unexploredMarker = std::numeric_limits<ActionType>::max()) : stateStorage(bitsPerBucket), unexploredMarker(unexploredMarker) { // Intentionally left empty. } storm::storage::sparse::StateStorage<StateType> stateStorage; std::vector<std::vector<storm::storage::MatrixEntry<StateType, ValueType>>> matrix; std::vector<StateType> rowGroupIndices; std::vector<StateType> stateToRowGroupMapping; StateType unexploredMarker; std::unordered_map<StateType, storm::generator::CompressedState> unexploredStates; storm::OptimizationDirection optimizationDirection; StateSet terminalStates; std::unordered_map<StateType, ActionSetPointer> stateToLeavingChoicesOfEndComponent; void setInitialStates(std::vector<StateType> const& initialStates) { stateStorage.initialStateIndices = initialStates; } StateType getFirstInitialState() const { return stateStorage.initialStateIndices.front(); } std::size_t getNumberOfInitialStates() const { return stateStorage.initialStateIndices.size(); } void addUnexploredState(storm::generator::CompressedState const& compressedState) { stateToRowGroupMapping.push_back(unexploredMarker); unexploredStates[stateStorage.numberOfStates] = compressedState; ++stateStorage.numberOfStates; } void assignStateToRowGroup(StateType const& state, ActionType const& rowGroup) { stateToRowGroupMapping[state] = rowGroup; } StateType assignStateToNextRowGroup(StateType const& state) { stateToRowGroupMapping[state] = rowGroupIndices.size() - 1; return stateToRowGroupMapping[state]; } void newRowGroup(ActionType const& action) { rowGroupIndices.push_back(action); } void newRowGroup() { newRowGroup(matrix.size()); } std::size_t getNumberOfUnexploredStates() const { return unexploredStates.size(); } std::size_t getNumberOfDiscoveredStates() const { return stateStorage.numberOfStates; } StateType const& getRowGroup(StateType const& state) const { return stateToRowGroupMapping[state]; } StateType const& getUnexploredMarker() const { return unexploredMarker; } bool isUnexplored(StateType const& state) const { return stateToRowGroupMapping[state] == unexploredMarker; } bool isTerminal(StateType const& state) const { return terminalStates.find(state) != terminalStates.end(); } ActionType const& getStartRowOfGroup(StateType const& group) const { return rowGroupIndices[group]; } std::size_t getRowGroupSize(StateType const& group) const { return rowGroupIndices[group + 1] - rowGroupIndices[group]; } void addTerminalState(StateType const& state) { terminalStates.insert(state); } std::vector<storm::storage::MatrixEntry<StateType, ValueType>>& getRowOfMatrix(ActionType const& row) { return matrix[row]; } std::vector<storm::storage::MatrixEntry<StateType, ValueType>> const& getRowOfMatrix(ActionType const& row) const { return matrix[row]; } void addRowsToMatrix(std::size_t const& count) { matrix.resize(matrix.size() + count); } }; // A struct containg the lower and upper bounds per state and action. struct BoundValues { std::vector<ValueType> lowerBoundsPerState; std::vector<ValueType> upperBoundsPerState; std::vector<ValueType> lowerBoundsPerAction; std::vector<ValueType> upperBoundsPerAction; std::pair<ValueType, ValueType> getBoundsForState(StateType const& state, ExplorationInformation const& explorationInformation) const { ActionType index = explorationInformation.getRowGroup(state); if (index == explorationInformation.getUnexploredMarker()) { return std::make_pair(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>()); } else { return std::make_pair(lowerBoundsPerState[index], upperBoundsPerState[index]); } } ValueType getLowerBoundForState(StateType const& state, ExplorationInformation const& explorationInformation) const { ActionType index = explorationInformation.getRowGroup(state); if (index == explorationInformation.getUnexploredMarker()) { return storm::utility::zero<ValueType>(); } else { return getLowerBoundForRowGroup(index, explorationInformation); } } ValueType getLowerBoundForRowGroup(StateType const& rowGroup, ExplorationInformation const& explorationInformation) const { return lowerBoundsPerState[rowGroup]; } ValueType getUpperBoundForState(StateType const& state, ExplorationInformation const& explorationInformation) const { ActionType index = explorationInformation.getRowGroup(state); if (index == explorationInformation.getUnexploredMarker()) { return storm::utility::one<ValueType>(); } else { return getUpperBoundForRowGroup(index); } } ValueType const& getUpperBoundForRowGroup(StateType const& rowGroup) const { return upperBoundsPerState[rowGroup]; } std::pair<ValueType, ValueType> getBoundsForAction(ActionType const& action) const { return std::make_pair(lowerBoundsPerAction[action], upperBoundsPerAction[action]); } ValueType const& getLowerBoundForAction(ActionType const& action) const { return lowerBoundsPerAction[action]; } ValueType const& getUpperBoundForAction(ActionType const& action) const { return upperBoundsPerAction[action]; } ValueType getDifferenceOfStateBounds(StateType const& state, ExplorationInformation const& explorationInformation) { std::pair<ValueType, ValueType> bounds = getBoundsForState(state, explorationInformation); return bounds.second - bounds.first; } void initializeBoundsForNextState(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) { lowerBoundsPerState.push_back(vals.first); upperBoundsPerState.push_back(vals.second); } void initializeBoundsForNextAction(std::pair<ValueType, ValueType> const& vals = std::pair<ValueType, ValueType>(storm::utility::zero<ValueType>(), storm::utility::one<ValueType>())) { lowerBoundsPerAction.push_back(vals.first); upperBoundsPerAction.push_back(vals.second); } void setLowerBoundForState(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& value) { lowerBoundsPerState[explorationInformation.getRowGroup(state)] = value; } void setUpperBoundForState(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& value) { upperBoundsPerState[explorationInformation.getRowGroup(state)] = value; } void setBoundsForAction(ActionType const& action, std::pair<ValueType, ValueType> const& values) { lowerBoundsPerAction[action] = values.first; upperBoundsPerAction[action] = values.second; } void setBoundsForState(StateType const& state, ExplorationInformation const& explorationInformation, std::pair<ValueType, ValueType> const& values) { StateType const& rowGroup = explorationInformation.getRowGroup(state); lowerBoundsPerState[rowGroup] = values.first; upperBoundsPerState[rowGroup] = values.second; } bool setNewLowerBoundOfStateIfGreaterThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newLowerValue) { StateType const& rowGroup = explorationInformation.getRowGroup(state); if (lowerBoundsPerState[rowGroup] < newLowerValue) { lowerBoundsPerState[rowGroup] = newLowerValue; return true; } return false; } bool setNewUpperBoundOfStateIfLessThanOld(StateType const& state, ExplorationInformation const& explorationInformation, ValueType const& newUpperValue) { StateType const& rowGroup = explorationInformation.getRowGroup(state); if (newUpperValue < upperBoundsPerState[rowGroup]) { upperBoundsPerState[rowGroup] = newUpperValue; return true; } return false; } }; storm::expressions::Expression getTargetStateExpression(storm::logic::Formula const& subformula) const; std::function<StateType (storm::generator::CompressedState const&)> createStateToIdCallback(ExplorationInformation& explorationInformation) const; std::tuple<StateType, ValueType, ValueType> performLearningProcedure(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation) const; bool samplePathFromState(StateGeneration& stateGeneration, ExplorationInformation& explorationInformation, StateActionStack& stack, BoundValues& bounds, Statistics& stats) const; bool exploreState(StateGeneration& stateGeneration, StateType const& currentStateId, storm::generator::CompressedState const& currentState, ExplorationInformation& explorationInformation, BoundValues& bounds, Statistics& stats) const; uint32_t sampleMaxAction(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; StateType sampleSuccessorFromAction(ActionType const& chosenAction, ExplorationInformation const& explorationInformation) const; void detectEndComponents(StateActionStack const& stack, ExplorationInformation& explorationInformation, BoundValues& bounds) const; void updateProbabilityBoundsAlongSampledPath(StateActionStack& stack, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; void updateProbabilityOfAction(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues& bounds) const; std::pair<ValueType, ValueType> computeBoundsOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeUpperBoundOverAllOtherActions(StateType const& state, ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; std::pair<ValueType, ValueType> computeBoundsOfState(StateType const& currentStateId, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeLowerBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeUpperBoundOfAction(ActionType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeLowerBoundOfState(StateType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; ValueType computeUpperBoundOfState(StateType const& action, ExplorationInformation const& explorationInformation, BoundValues const& bounds) const; // The program that defines the model to check. storm::prism::Program program; // The variable information. storm::generator::VariableInformation variableInformation; // The random number generator. mutable std::default_random_engine randomGenerator; // A comparator used to determine whether values are equal. storm::utility::ConstantsComparator<ValueType> comparator; }; } } #endif /* STORM_MODELCHECKER_REACHABILITY_SPARSEMDPLEARNINGMODELCHECKER_H_ */