You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

95 lines
4.6 KiB

#include "storm/simulator/DiscreteTimeSparseModelSimulator.h"
#include "storm/models/sparse/Model.h"
namespace storm {
namespace simulator {
/// Creates a simulator for the given model and places it in an initial state.
///
/// The current state is the first element of the model's initial-state set. A warning is logged
/// whenever the number of initial states is not exactly one. The last-rewards vector gets one entry
/// per reward model, holding the state reward of the starting state for every reward model that
/// has state rewards and zero otherwise.
///
/// @param model The sparse model to simulate.
template<typename ValueType, typename RewardModelType>
DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::DiscreteTimeSparseModelSimulator(storm::models::sparse::Model<ValueType, RewardModelType> const& model) : model(model), currentState(*model.getInitialStates().begin()), zeroRewards(model.getNumberOfRewardModels(), storm::utility::zero<ValueType>()) {
    STORM_LOG_WARN_COND(model.getInitialStates().getNumberOfSetBits() == 1, "The model has multiple initial states. This simulator assumes it starts from the initial state with the lowest index.");

    // Start from all-zero rewards and add the state rewards of the starting state.
    lastRewards = zeroRewards;
    auto const& rewardModels = model.getRewardModels();
    uint64_t rewardModelIndex = 0;
    for (auto const& namedRewardModel : rewardModels) {
        auto const& rewardModel = namedRewardModel.second;
        if (rewardModel.hasStateRewards()) {
            lastRewards[rewardModelIndex] += rewardModel.getStateReward(currentState);
        }
        ++rewardModelIndex;
    }
}
/// Replaces the random number generator of this simulator by a freshly constructed one that is
/// initialized with the given seed.
///
/// @param seed The seed passed to the new generator.
template<typename ValueType, typename RewardModelType>
void DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::setSeed(uint64_t seed) {
    using Generator = storm::utility::RandomProbabilityGenerator<ValueType>;
    generator = Generator(seed);
}
/// Performs a step with an action index drawn from the generator among the choices of the
/// current state.
///
/// @return false if the current state has no choices (empty row group); otherwise the result of
///         step() for the drawn action.
template<typename ValueType, typename RewardModelType>
bool DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::randomStep() {
    // TODO random_uint is slow
    auto const numberOfChoices = model.getTransitionMatrix().getRowGroupSize(currentState);
    if (numberOfChoices == 0) {
        // Nothing can be executed from this state.
        return false;
    }
    // Draw an action index from the inclusive range [0, numberOfChoices - 1].
    return step(generator.random_uint(0, numberOfChoices - 1));
}
/// Executes the given action in the current state and samples the successor state.
///
/// The last-rewards vector is reset to zero and then receives, per reward model, the state-action
/// reward of the selected row (if the reward model has state-action rewards) plus the state reward
/// of the state that is entered (if the reward model has state rewards).
///
/// The successor is the first entry of the row at which the accumulated probability reaches the
/// sampled random number. If accumulated rounding errors keep the row sum below the sampled number,
/// the last entry with nonzero probability is taken instead of reporting a failure.
///
/// @param action Offset of the choice within the row group of the current state. It must be smaller
///               than the size of that row group; this is only checked via STORM_LOG_ASSERT.
/// @return true if a successor state was entered; false if no successor could be selected (the row
///         has no entry that can be chosen), in which case the current state is unchanged.
template<typename ValueType, typename RewardModelType>
bool DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::step(uint64_t action) {
    // TODO lots of optimization potential.
    // E.g., do not sample random numbers if there is only a single transition
    lastRewards = zeroRewards;
    ValueType probability = generator.random();
    STORM_LOG_ASSERT(action < model.getTransitionMatrix().getRowGroupSize(currentState), "Action index higher than number of actions");
    uint64_t row = model.getTransitionMatrix().getRowGroupIndices()[currentState] + action;

    // Collect the rewards for taking the selected choice.
    uint64_t i = 0;
    for (auto const& rewModPair : model.getRewardModels()) {
        if (rewModPair.second.hasStateActionRewards()) {
            lastRewards[i] += rewModPair.second.getStateActionReward(row);
        }
        ++i;
    }

    // Select the successor by walking over the cumulative distribution of the row.
    ValueType const zero = storm::utility::zero<ValueType>();
    ValueType sum = zero;
    bool successorFound = false;
    uint64_t successor = currentState;
    for (auto const& entry : model.getTransitionMatrix().getRow(row)) {
        sum += entry.getValue();
        if (sum >= probability) {
            successor = entry.getColumn();
            successorFound = true;
            break;
        }
        if (entry.getValue() != zero) {
            // Remember the last successor with nonzero probability. It is used when inexact
            // arithmetic leaves the total row sum marginally below the sampled number.
            successor = entry.getColumn();
            successorFound = true;
        }
    }
    if (!successorFound) {
        // The row offers no successor that could be entered.
        return false;
    }

    // Enter the successor and collect its state rewards.
    currentState = successor;
    i = 0;
    for (auto const& rewModPair : model.getRewardModels()) {
        if (rewModPair.second.hasStateRewards()) {
            lastRewards[i] += rewModPair.second.getStateReward(currentState);
        }
        ++i;
    }
    return true;
}
/// Retrieves the state the simulator is currently in.
///
/// @return The index of the current state.
template<typename ValueType, typename RewardModelType>
uint64_t DiscreteTimeSparseModelSimulator<ValueType, RewardModelType>::getCurrentState() const {
    return this->currentState;
}
/// Moves the simulator back to the first element of the model's initial-state set and recomputes
/// the last-rewards vector for that state: zero for every reward model, plus the state reward of
/// the initial state where the reward model has state rewards.
///
/// @return Always true.
template<typename ValueType, typename RewardModelType>
bool DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::resetToInitial() {
    currentState = *model.getInitialStates().begin();

    // Start from all-zero rewards and add the state rewards of the initial state.
    lastRewards = zeroRewards;
    auto const& rewardModels = model.getRewardModels();
    uint64_t rewardModelIndex = 0;
    for (auto const& namedRewardModel : rewardModels) {
        auto const& rewardModel = namedRewardModel.second;
        if (rewardModel.hasStateRewards()) {
            lastRewards[rewardModelIndex] += rewardModel.getStateReward(currentState);
        }
        ++rewardModelIndex;
    }
    return true;
}
/// Retrieves the rewards recorded by the most recent construction, reset, or step.
///
/// @return A reference to the vector with one entry per reward model.
template<typename ValueType, typename RewardModelType>
std::vector<ValueType> const& DiscreteTimeSparseModelSimulator<ValueType,RewardModelType>::getLastRewards() const {
    return this->lastRewards;
}
// Explicit instantiations of the simulator for the value types double and storm::RationalNumber,
// so that the member definitions in this file are compiled for both.
template class DiscreteTimeSparseModelSimulator<double>;
template class DiscreteTimeSparseModelSimulator<storm::RationalNumber>;
}
}