From cf3bfc3d2ddc90c80735f030be69d65e0d496887 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Thu, 18 Jun 2020 12:26:51 -0700 Subject: [PATCH] simulator presents rewards for efficiency --- .../DiscreteTimeSparseModelSimulator.cpp | 40 +++++++++++++++++-- .../DiscreteTimeSparseModelSimulator.h | 5 ++- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/storm/simulator/DiscreteTimeSparseModelSimulator.cpp b/src/storm/simulator/DiscreteTimeSparseModelSimulator.cpp index 2e9161b8a..d6669516d 100644 --- a/src/storm/simulator/DiscreteTimeSparseModelSimulator.cpp +++ b/src/storm/simulator/DiscreteTimeSparseModelSimulator.cpp @@ -4,9 +4,16 @@ namespace storm { namespace simulator { template - DiscreteTimeSparseModelSimulator::DiscreteTimeSparseModelSimulator(storm::models::sparse::Model const& model) : currentState(*model.getInitialStates().begin()), model(model) { + DiscreteTimeSparseModelSimulator::DiscreteTimeSparseModelSimulator(storm::models::sparse::Model const& model) : model(model), currentState(*model.getInitialStates().begin()), zeroRewards(model.getNumberOfRewardModels(), storm::utility::zero()) { STORM_LOG_WARN_COND(model.getInitialStates().getNumberOfSetBits()==1, "The model has multiple initial states. This simulator assumes it starts from the initial state with the lowest index."); - + lastRewards = zeroRewards; + uint64_t i = 0; + for (auto const& rewModPair : model.getRewardModels()) { + if (rewModPair.second.hasStateRewards()) { + lastRewards[i] += rewModPair.second.getStateReward(currentState); + } + ++i; + } } template @@ -18,19 +25,34 @@ namespace storm { bool DiscreteTimeSparseModelSimulator::step(uint64_t action) { // TODO lots of optimization potential. // E.g., do not sample random numbers if there is only a single transition + lastRewards = zeroRewards; ValueType probability = generator.random(); STORM_LOG_ASSERT(action < model.getTransitionMatrix().getRowGroupSize(currentState), "Action index higher than number of actions"); uint64_t row = model.getTransitionMatrix().getRowGroupIndices()[currentState] + action; + uint64_t i = 0; + for (auto const& rewModPair : model.getRewardModels()) { + if (rewModPair.second.hasStateActionRewards()) { + lastRewards[i] += rewModPair.second.getStateActionReward(row); + } + ++i; + } ValueType sum = storm::utility::zero(); for (auto const& entry : model.getTransitionMatrix().getRow(row)) { sum += entry.getValue(); if (sum >= probability) { currentState = entry.getColumn(); + i = 0; + for (auto const& rewModPair : model.getRewardModels()) { + if (rewModPair.second.hasStateRewards()) { + lastRewards[i] += rewModPair.second.getStateReward(currentState); + } + ++i; + } return true; } } + // This position should never be reached return false; - STORM_LOG_ASSERT(false, "This position should never be reached"); } template @@ -41,9 +63,21 @@ namespace storm { template bool DiscreteTimeSparseModelSimulator::resetToInitial() { currentState = *model.getInitialStates().begin(); + lastRewards = zeroRewards; + uint64_t i = 0; + for (auto const& rewModPair : model.getRewardModels()) { + if (rewModPair.second.hasStateRewards()) { + lastRewards[i] += rewModPair.second.getStateReward(currentState); + } + ++i; + } return true; } + template + std::vector const& DiscreteTimeSparseModelSimulator::getLastRewards() const { + return lastRewards; + } template class DiscreteTimeSparseModelSimulator; } diff --git a/src/storm/simulator/DiscreteTimeSparseModelSimulator.h b/src/storm/simulator/DiscreteTimeSparseModelSimulator.h index 1637bffec..a0dcf4e3a 100644 --- a/src/storm/simulator/DiscreteTimeSparseModelSimulator.h +++ b/src/storm/simulator/DiscreteTimeSparseModelSimulator.h @@ -19,11 +19,14 @@ namespace storm { DiscreteTimeSparseModelSimulator(storm::models::sparse::Model const& model); void setSeed(uint64_t); bool step(uint64_t action); + std::vector const& getLastRewards() const; uint64_t getCurrentState() const; bool resetToInitial(); protected: - uint64_t currentState; storm::models::sparse::Model const& model; + uint64_t currentState; + std::vector lastRewards; + std::vector zeroRewards; storm::utility::RandomProbabilityGenerator generator; }; }