#include "storm/simulator/DiscreteTimeSparseModelSimulator.h"
#include "storm/models/sparse/Model.h"

namespace storm {
namespace simulator {

namespace {
/*!
 * Adds the state reward of the given state to the running reward vector.
 *
 * The reward models are visited in the iteration order of model.getRewardModels(); the k-th visited
 * reward model contributes to rewards[k]. Reward models without state rewards leave their slot untouched.
 *
 * @param model The model whose reward models are consulted.
 * @param state The state whose state rewards are added.
 * @param rewards The vector that is updated in place. It needs one slot per reward model.
 */
template<typename ModelType, typename ValueType>
void addStateRewards(ModelType const& model, uint64_t state, std::vector<ValueType>& rewards) {
    uint64_t index = 0;
    for (auto const& rewModPair : model.getRewardModels()) {
        if (rewModPair.second.hasStateRewards()) {
            rewards[index] += rewModPair.second.getStateReward(state);
        }
        ++index;
    }
}
}  // namespace

/*!
 * Creates a simulator that is positioned on the first initial state of the given model.
 *
 * A warning is logged unless the model has exactly one initial state. The last rewards are initialized
 * with the state rewards of the state the simulator starts in.
 *
 * @param model The model to simulate.
 */
template<typename ValueType, typename RewardModelType>
DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::DiscreteTimeSparseModelSimulator(
    storm::models::sparse::Model<ValueType, RewardModelType> const& model)
    : model(model),
      currentState(*model.getInitialStates().begin()),
      zeroRewards(model.getNumberOfRewardModels(), storm::utility::zero<ValueType>()) {
    STORM_LOG_WARN_COND(model.getInitialStates().getNumberOfSetBits() == 1,
                        "The model has multiple initial states. This simulator assumes it starts from the initial state with the lowest index.");
    lastRewards = zeroRewards;
    addStateRewards(model, currentState, lastRewards);
}

/*!
 * Replaces the random generator by a fresh one that is constructed from the given seed.
 *
 * @param seed The seed handed to the new generator.
 */
template<typename ValueType, typename RewardModelType>
void DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::setSeed(uint64_t seed) {
    generator = storm::utility::RandomProbabilityGenerator<ValueType>(seed);
}

/*!
 * Picks one of the actions of the current state via the generator's random_uint and performs it.
 *
 * @return false if the current state has no actions (its row group is empty); otherwise the result of step().
 */
template<typename ValueType, typename RewardModelType>
bool DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::randomStep() {
    // TODO random_uint is slow
    auto const numberOfActions = model.getTransitionMatrix().getRowGroupSize(currentState);
    if (numberOfActions == 0) {
        return false;
    }
    return step(generator.random_uint(0, numberOfActions - 1));
}

/*!
 * Performs the given action in the current state and samples a successor state.
 *
 * The last rewards are reset and then hold the state-action rewards of the chosen row plus the state
 * rewards of the sampled successor.
 *
 * @param action The offset of the action within the row group of the current state.
 * @return true if a successor was selected; false if the chosen row has no non-zero entry to move to.
 */
template<typename ValueType, typename RewardModelType>
bool DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::step(uint64_t action) {
    // TODO lots of optimization potential.
    //  E.g., do not sample random numbers if there is only a single transition
    lastRewards = zeroRewards;
    // The random number is drawn before anything else so that the sequence of draws per step stays unchanged.
    ValueType probability = generator.random();
    STORM_LOG_ASSERT(action < model.getTransitionMatrix().getRowGroupSize(currentState), "Action index higher than number of actions");
    uint64_t row = model.getTransitionMatrix().getRowGroupIndices()[currentState] + action;

    uint64_t i = 0;
    for (auto const& rewModPair : model.getRewardModels()) {
        if (rewModPair.second.hasStateActionRewards()) {
            lastRewards[i] += rewModPair.second.getStateActionReward(row);
        }
        ++i;
    }

    // Walk through the row and select the first entry at which the accumulated value reaches the drawn number.
    ValueType sum = storm::utility::zero<ValueType>();
    bool hasFallback = false;
    uint64_t fallbackState = currentState;
    for (auto const& entry : model.getTransitionMatrix().getRow(row)) {
        sum += entry.getValue();
        if (sum >= probability) {
            currentState = entry.getColumn();
            addStateRewards(model, currentState, lastRewards);
            return true;
        }
        if (entry.getValue() != storm::utility::zero<ValueType>()) {
            hasFallback = true;
            fallbackState = entry.getColumn();
        }
    }

    // The accumulated row values stayed below the drawn number. With inexact arithmetic this can happen
    // through round-off alone, so move to the last successor with a non-zero entry instead of reporting a
    // failed step with half-updated rewards.
    if (hasFallback) {
        currentState = fallbackState;
        addStateRewards(model, currentState, lastRewards);
        return true;
    }

    // Only reached if the row has no entry with a non-zero value.
    return false;
}

/*!
 * @return The state the simulator is currently in.
 */
template<typename ValueType, typename RewardModelType>
uint64_t DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::getCurrentState() const {
    return currentState;
}

/*!
 * Moves the simulator back to the first initial state of the model and sets the last rewards to the
 * state rewards of that state.
 *
 * @return Always true.
 */
template<typename ValueType, typename RewardModelType>
bool DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::resetToInitial() {
    currentState = *model.getInitialStates().begin();
    lastRewards = zeroRewards;
    addStateRewards(model, currentState, lastRewards);
    return true;
}

/*!
 * @return The rewards collected by the most recent constructor call, step or reset, one value per reward model.
 */
template<typename ValueType, typename RewardModelType>
std::vector<ValueType> const& DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::getLastRewards() const {
    return lastRewards;
}

template class DiscreteTimeSparseModelSimulator<double>;
template class DiscreteTimeSparseModelSimulator<storm::RationalNumber>;

}  // namespace simulator
}  // namespace storm