#include "storm/simulator/DiscreteTimeSparseModelSimulator.h"

#include <cstdint>
#include <vector>

#include "storm/models/sparse/Model.h"

namespace storm {
    namespace simulator {

        namespace {
            /*!
             * Adds the state reward of `state` to `rewards` for every reward model of `model` that has state rewards.
             *
             * Entry i of `rewards` belongs to the i-th reward model in the iteration order of `model.getRewardModels()`.
             * Reward models without state rewards leave their entry untouched.
             */
            template<typename ValueType, typename RewardModelType>
            void accumulateStateRewards(storm::models::sparse::Model<ValueType, RewardModelType> const& model, uint64_t state, std::vector<ValueType>& rewards) {
                uint64_t i = 0;
                for (auto const& rewModPair : model.getRewardModels()) {
                    if (rewModPair.second.hasStateRewards()) {
                        rewards[i] += rewModPair.second.getStateReward(state);
                    }
                    ++i;
                }
            }

            /*!
             * Adds the state-action reward of transition-matrix row `row` to `rewards` for every reward model of `model`
             * that has state-action rewards. Entries are matched to reward models exactly as in accumulateStateRewards.
             */
            template<typename ValueType, typename RewardModelType>
            void accumulateStateActionRewards(storm::models::sparse::Model<ValueType, RewardModelType> const& model, uint64_t row, std::vector<ValueType>& rewards) {
                uint64_t i = 0;
                for (auto const& rewModPair : model.getRewardModels()) {
                    if (rewModPair.second.hasStateActionRewards()) {
                        rewards[i] += rewModPair.second.getStateActionReward(row);
                    }
                    ++i;
                }
            }
        }

        /*!
         * Creates a simulator positioned in the initial state with the lowest index.
         * The model is expected to have at least one initial state (the first one is dereferenced unconditionally).
         * The rewards of the "last step" are initialised with the state rewards of that initial state.
         */
        template<typename ValueType, typename RewardModelType>
        DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::DiscreteTimeSparseModelSimulator(storm::models::sparse::Model<ValueType, RewardModelType> const& model) : model(model), currentState(*model.getInitialStates().begin()), zeroRewards(model.getNumberOfRewardModels(), storm::utility::zero<ValueType>()) {
            STORM_LOG_WARN_COND(model.getInitialStates().getNumberOfSetBits()==1, "The model has multiple initial states. This simulator assumes it starts from the initial state with the lowest index.");
            lastRewards = zeroRewards;
            accumulateStateRewards(model, currentState, lastRewards);
        }

        /*!
         * Re-seeds the internal random generator by replacing it with a freshly seeded one.
         */
        template<typename ValueType, typename RewardModelType>
        void DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::setSeed(uint64_t seed) {
            generator = storm::utility::RandomProbabilityGenerator<ValueType>(seed);
        }

        /*!
         * Picks one of the actions of the current state uniformly at random and performs a step with it.
         * @return false if the current state has no actions, otherwise the result of step().
         */
        template<typename ValueType, typename RewardModelType>
        bool DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::randomStep() {
            uint64_t const numberOfActions = model.getTransitionMatrix().getRowGroupSize(currentState);
            if (numberOfActions == 0) {
                return false;
            }
            // TODO random_uint is slow
            return step(generator.random_uint(0, numberOfActions - 1));
        }

        /*!
         * Performs one step from the current state using the given local action index.
         *
         * The successor is sampled by walking the chosen row and accumulating transition values until the
         * running sum reaches a freshly drawn random value. lastRewards is reset and then receives the
         * state-action rewards of the chosen row plus the state rewards of the successor.
         *
         * @param action Index of the action within the current state's row group.
         * @return true if a successor was chosen; false if the chosen row has no positive entry.
         */
        template<typename ValueType, typename RewardModelType>
        bool DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::step(uint64_t action) {
            // TODO lots of optimization potential.
            // E.g., do not sample random numbers if there is only a single transition
            lastRewards = zeroRewards;
            ValueType probability = generator.random();
            STORM_LOG_ASSERT(action < model.getTransitionMatrix().getRowGroupSize(currentState), "Action index higher than number of actions");
            uint64_t row = model.getTransitionMatrix().getRowGroupIndices()[currentState] + action;
            accumulateStateActionRewards(model, row, lastRewards);

            ValueType const zeroValue = storm::utility::zero<ValueType>();
            ValueType sum = zeroValue;
            bool hasSuccessor = false;
            uint64_t successor = 0;
            for (auto const& entry : model.getTransitionMatrix().getRow(row)) {
                if (entry.getValue() <= zeroValue) {
                    // Never move along a transition that cannot be taken (e.g. an explicitly stored zero entry);
                    // with `>=` such an entry would otherwise be picked when the draw is exactly zero.
                    continue;
                }
                successor = entry.getColumn();
                hasSuccessor = true;
                sum += entry.getValue();
                if (sum >= probability) {
                    break;
                }
            }
            if (!hasSuccessor) {
                // The row has no positive entry, so there is no state to move to.
                return false;
            }
            // If the loop ran to completion without reaching `probability`, the row's entries summed to slightly
            // less than the drawn value due to rounding; that leftover mass is attributed to the last successor.
            currentState = successor;
            accumulateStateRewards(model, currentState, lastRewards);
            return true;
        }

        /*!
         * @return The index of the state the simulator is currently in.
         */
        template<typename ValueType, typename RewardModelType>
        uint64_t DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::getCurrentState() const {
            return currentState;
        }

        /*!
         * Moves back to the initial state with the lowest index and sets lastRewards to that state's state rewards.
         * @return Always true.
         */
        template<typename ValueType, typename RewardModelType>
        bool DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::resetToInitial() {
            currentState = *model.getInitialStates().begin();
            lastRewards = zeroRewards;
            accumulateStateRewards(model, currentState, lastRewards);
            return true;
        }

        /*!
         * @return The rewards collected in the most recent step (or reset), one entry per reward model.
         */
        template<typename ValueType, typename RewardModelType>
        std::vector<ValueType> const& DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::getLastRewards() const {
            return lastRewards;
        }

        template class DiscreteTimeSparseModelSimulator<double>;
        template class DiscreteTimeSparseModelSimulator<storm::RationalNumber>;
    }
}