
simulator presents rewards for efficiency
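This commit extends DiscreteTimeSparseModelSimulator so that it tracks the rewards collected along the simulated path: the constructor and resetToInitial() record the state rewards of the initial state, step() accumulates the state-action reward of the chosen action plus the state reward of the successor state, and the new getLastRewards() accessor returns one value per reward model. A usage sketch follows the diff below.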

Branch: tempestpy_adaptions
Author: Sebastian Junges, 5 years ago
Commit: cf3bfc3d2d

2 changed files:
1. src/storm/simulator/DiscreteTimeSparseModelSimulator.cpp (40 changes)
2. src/storm/simulator/DiscreteTimeSparseModelSimulator.h (5 changes)

src/storm/simulator/DiscreteTimeSparseModelSimulator.cpp (40 changes)

@@ -4,9 +4,16 @@
 namespace storm {
     namespace simulator {
         template<typename ValueType, typename RewardModelType>
-        DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::DiscreteTimeSparseModelSimulator(storm::models::sparse::Model<ValueType, RewardModelType> const& model) : currentState(*model.getInitialStates().begin()), model(model) {
+        DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::DiscreteTimeSparseModelSimulator(storm::models::sparse::Model<ValueType, RewardModelType> const& model) : model(model), currentState(*model.getInitialStates().begin()), zeroRewards(model.getNumberOfRewardModels(), storm::utility::zero<ValueType>()) {
             STORM_LOG_WARN_COND(model.getInitialStates().getNumberOfSetBits()==1, "The model has multiple initial states. This simulator assumes it starts from the initial state with the lowest index.");
+            lastRewards = zeroRewards;
+            uint64_t i = 0;
+            for (auto const& rewModPair : model.getRewardModels()) {
+                if (rewModPair.second.hasStateRewards()) {
+                    lastRewards[i] += rewModPair.second.getStateReward(currentState);
+                }
+                ++i;
+            }
         }
         template<typename ValueType, typename RewardModelType>
@@ -18,19 +25,34 @@ namespace storm {
         bool DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::step(uint64_t action) {
             // TODO lots of optimization potential.
             // E.g., do not sample random numbers if there is only a single transition
+            lastRewards = zeroRewards;
             ValueType probability = generator.random();
             STORM_LOG_ASSERT(action < model.getTransitionMatrix().getRowGroupSize(currentState), "Action index higher than number of actions");
             uint64_t row = model.getTransitionMatrix().getRowGroupIndices()[currentState] + action;
+            uint64_t i = 0;
+            for (auto const& rewModPair : model.getRewardModels()) {
+                if (rewModPair.second.hasStateActionRewards()) {
+                    lastRewards[i] += rewModPair.second.getStateActionReward(row);
+                }
+                ++i;
+            }
             ValueType sum = storm::utility::zero<ValueType>();
             for (auto const& entry : model.getTransitionMatrix().getRow(row)) {
                 sum += entry.getValue();
                 if (sum >= probability) {
                     currentState = entry.getColumn();
+                    i = 0;
+                    for (auto const& rewModPair : model.getRewardModels()) {
+                        if (rewModPair.second.hasStateRewards()) {
+                            lastRewards[i] += rewModPair.second.getStateReward(currentState);
+                        }
+                        ++i;
+                    }
                     return true;
                 }
             }
-            // This position should never be reached
             return false;
+            STORM_LOG_ASSERT(false, "This position should never be reached");
         }
         template<typename ValueType, typename RewardModelType>
@@ -41,9 +63,21 @@ namespace storm {
         template<typename ValueType, typename RewardModelType>
         bool DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::resetToInitial() {
             currentState = *model.getInitialStates().begin();
+            lastRewards = zeroRewards;
+            uint64_t i = 0;
+            for (auto const& rewModPair : model.getRewardModels()) {
+                if (rewModPair.second.hasStateRewards()) {
+                    lastRewards[i] += rewModPair.second.getStateReward(currentState);
+                }
+                ++i;
+            }
             return true;
         }
+        template<typename ValueType, typename RewardModelType>
+        std::vector<ValueType> const& DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::getLastRewards() const {
+            return lastRewards;
+        }
         template class DiscreteTimeSparseModelSimulator<double>;
     }

src/storm/simulator/DiscreteTimeSparseModelSimulator.h (5 changes)

@@ -19,11 +19,14 @@ namespace storm {
             DiscreteTimeSparseModelSimulator(storm::models::sparse::Model<ValueType, RewardModelType> const& model);
             void setSeed(uint64_t);
             bool step(uint64_t action);
+            std::vector<ValueType> const& getLastRewards() const;
             uint64_t getCurrentState() const;
             bool resetToInitial();
         protected:
-            uint64_t currentState;
             storm::models::sparse::Model<ValueType, RewardModelType> const& model;
+            uint64_t currentState;
+            std::vector<ValueType> lastRewards;
+            std::vector<ValueType> zeroRewards;
             storm::utility::RandomProbabilityGenerator<ValueType> generator;
         };
     }
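For context, here is a minimal usage sketch of the new accessor. It is illustrative only and not part of this commit: it assumes a sparse model built elsewhere (e.g. from a PRISM program), always picks action 0, and uses a caller-provided run length.

#include <cstdint>
#include <vector>

#include "storm/simulator/DiscreteTimeSparseModelSimulator.h"

// Accumulate the total reward of one random run of bounded length,
// one accumulator per reward model of the underlying model.
std::vector<double> simulateRun(storm::models::sparse::Model<double> const& model, uint64_t numberOfSteps) {
    storm::simulator::DiscreteTimeSparseModelSimulator<double> simulator(model);
    simulator.setSeed(42);
    // After construction, getLastRewards() already holds the state rewards
    // of the initial state (one entry per reward model).
    std::vector<double> totalRewards = simulator.getLastRewards();
    for (uint64_t step = 0; step < numberOfSteps; ++step) {
        simulator.step(0);  // always take action 0 for simplicity
        // After step(), getLastRewards() holds the state-action reward of
        // the chosen action plus the state reward of the successor state.
        auto const& last = simulator.getLastRewards();
        for (uint64_t i = 0; i < last.size(); ++i) {
            totalRewards[i] += last[i];
        }
    }
    return totalRewards;
}

Storing the preallocated zeroRewards vector lets step() and resetToInitial() clear lastRewards with a single vector assignment instead of reconstructing it on every call.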