You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 

159 lines
7.2 KiB

#include "storm/simulator/PrismProgramSimulator.h"
#include "storm/exceptions/NotSupportedException.h"
using namespace storm::generator;
namespace storm {
namespace simulator {
/*!
 * Creates a simulator for the given PRISM program and positions it at the initial state.
 *
 * A PrismNextStateGenerator is built from the program and the options. Both reward vectors are
 * initialised with one zero entry per reward model reported by that generator.
 *
 * NOTE(review): zeroRewards is initialised from stateGenerator, which presumably relies on
 * stateGenerator being declared before zeroRewards in the class definition - confirm in the header.
 *
 * @param program The PRISM program to simulate.
 * @param options The options handed to the next-state generator.
 */
template<typename ValueType>
DiscreteTimePrismProgramSimulator<ValueType>::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program,
                                                                                storm::generator::NextStateGeneratorOptions const& options)
    : program(program),
      currentState(),
      stateGenerator(std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(program, options)),
      zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero<ValueType>()),
      lastActionRewards(zeroRewards) {
    // Set up empty state <-> id maps (sized via the generator's state size) before any state is registered.
    clearStateCaches();
    // The default-constructed current state is only a placeholder: move to the real initial state and explore it.
    resetToInitial();
}
/*!
 * Replaces the random generator used for sampling successors by a new one constructed from the given seed.
 *
 * @param newSeed The seed passed to the new RandomProbabilityGenerator.
 */
template<typename ValueType>
void DiscreteTimePrismProgramSimulator<ValueType>::setSeed(uint64_t newSeed) {
    this->generator = storm::utility::RandomProbabilityGenerator<ValueType>(newSeed);
}
/*!
 * Performs one simulation step by executing the choice with the given index in the current state.
 *
 * A successor is sampled from the distribution of the selected choice using the simulator's random
 * generator. The rewards of that choice become the last rewards, the sampled successor becomes the
 * current state, and the new state is explored (which adds its state rewards to the last rewards).
 *
 * @param actionNumber Index of the choice to execute; must be smaller than getChoices().size().
 * @return Always true.
 */
template<typename ValueType>
bool DiscreteTimePrismProgramSimulator<ValueType>::step(uint64_t actionNumber) {
    // Indexing the choice vector out of range is undefined behavior, so reject invalid indices up front.
    STORM_LOG_ASSERT(actionNumber < behavior.getChoices().size(), "Action index is out of range.");
    // Bind the selected choice once; behavior is not modified until explore() below.
    auto const& selectedChoice = behavior.getChoices()[actionNumber];
    uint32_t nextState = selectedChoice.sampleFromDistribution(generator.random());
    lastActionRewards = selectedChoice.getRewards();
    STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model.");
    // Copy the successor out of the id map before the caches are cleared.
    currentState = idToState[nextState];
    // TODO we do not need to do this in every step!
    clearStateCaches();
    explore();
    return true;
}
/*!
 * Expands the current state and stores the resulting behavior.
 *
 * The current state is loaded into the next-state generator and expanded. Successor states are
 * registered through stateToIdCallback. The state rewards of the expanded state are then added
 * entry-wise onto the last rewards.
 *
 * @return Always true.
 */
template<typename ValueType>
bool DiscreteTimePrismProgramSimulator<ValueType>::explore() {
    // Make the generator operate on the state the simulator is currently in.
    stateGenerator->load(currentState);

    // TODO: This low-level code currently expands all actions, while this is not necessary.
    // However, using the next state generator ensures compatibility with the model generator.
    behavior = stateGenerator->expand(stateToIdCallback);

    auto const& stateRewards = behavior.getStateRewards();
    STORM_LOG_ASSERT(stateRewards.size() == lastActionRewards.size(), "Reward vectors should have same length.");

    // Accumulate the state rewards on top of whatever rewards are currently stored.
    uint64_t rewardIndex = 0;
    for (auto const& stateReward : stateRewards) {
        lastActionRewards[rewardIndex] += stateReward;
        ++rewardIndex;
    }
    return true;
}
/*!
 * Checks whether the current state is a sink.
 *
 * The state counts as a sink if its behavior is empty, or if all successors over all of its choices
 * are one and the same state and that state equals the current state.
 *
 * @return True iff the current state is a sink in the sense above.
 */
template<typename ValueType>
bool DiscreteTimePrismProgramSimulator<ValueType>::isSinkState() const {
    if (behavior.empty()) {
        return true;
    }

    // Gather the distinct successor ids; a second distinct id means the state can be left.
    std::set<uint32_t> distinctSuccessors;
    for (auto const& choice : behavior.getChoices()) {
        for (auto const& successorEntry : choice) {
            distinctSuccessors.insert(successorEntry.first);
            if (distinctSuccessors.size() > 1) {
                return false;
            }
        }
    }

    // NOTE(review): this assumes a non-empty behavior yields at least one successor; verify, as
    // dereferencing begin() of an empty set would be invalid.
    // Exactly one successor remains: the state is a sink iff that successor is the state itself.
    return idToState.at(*distinctSuccessors.begin()) == currentState;
}
/*!
 * Gives access to the choices of the behavior stored by the most recent exploration.
 *
 * @return A reference to the choices held by the stored behavior.
 */
template<typename ValueType>
std::vector<generator::Choice<ValueType, uint32_t>> const& DiscreteTimePrismProgramSimulator<ValueType>::getChoices() const {
    return this->behavior.getChoices();
}
/*!
 * Gives access to the stored reward vector, i.e. the rewards of the last executed choice (or zeros
 * after a reset) plus the state rewards added by the exploration that followed.
 *
 * @return A reference to the stored reward vector.
 */
template<typename ValueType>
std::vector<ValueType> const& DiscreteTimePrismProgramSimulator<ValueType>::getLastRewards() const {
    return this->lastActionRewards;
}
/*!
 * Gives access to the state the simulator is currently in, in its compressed representation.
 *
 * @return A reference to the current compressed state.
 */
template<typename ValueType>
CompressedState const& DiscreteTimePrismProgramSimulator<ValueType>::getCurrentState() const {
    return this->currentState;
}
/*!
 * Converts the current compressed state into a valuation, using the generator's variable
 * information and the expression manager of the program.
 *
 * @return The valuation obtained by unpacking the current state.
 */
template<typename ValueType>
expressions::SimpleValuation DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateAsValuation() const {
    auto const& variableInformation = stateGenerator->getVariableInformation();
    return unpackStateIntoValuation(currentState, variableInformation, program.getManager());
}
/*!
 * Renders the current state as a string by delegating to the state generator.
 *
 * @return The generator's string representation of the current state.
 */
template<typename ValueType>
std::string DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateString() const {
    return this->stateGenerator->stateToString(this->currentState);
}
/*!
 * Returns the JSON produced by the state generator for the state it currently has loaded.
 * explore() loads the simulator's current state into the generator.
 *
 * NOTE(review): the `false` flag presumably requests the full state rather than only its
 * observable part - confirm against currentStateToJson.
 *
 * @return The JSON object produced by the generator.
 */
template<typename ValueType>
storm::json<ValueType> DiscreteTimePrismProgramSimulator<ValueType>::getStateAsJson() const {
    return this->stateGenerator->currentStateToJson(false);
}
/*!
 * Returns the JSON produced by the state generator for the state it currently has loaded, with the
 * generator's flag set to `true`. explore() loads the simulator's current state into the generator.
 *
 * NOTE(review): the `true` flag presumably restricts the output to the observable part of the
 * state - confirm against currentStateToJson.
 *
 * @return The JSON object produced by the generator.
 */
template<typename ValueType>
storm::json<ValueType> DiscreteTimePrismProgramSimulator<ValueType>::getObservationAsJson() const {
    return this->stateGenerator->currentStateToJson(true);
}
/*!
 * Moves the simulator back to the initial state of the program.
 *
 * The stored rewards are reset to zero, the initial states are obtained from the generator
 * (registering them through stateToIdCallback), and the unique initial state is explored.
 *
 * @return The result of explore().
 * @throws storm::exceptions::NotSupportedException if the number of initial states is not exactly one.
 */
template<typename ValueType>
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToInitial() {
    // Start from a clean reward vector; explore() adds the state rewards of the initial state.
    lastActionRewards = zeroRewards;

    auto const initialStateIds = stateGenerator->getInitialStates(stateToIdCallback);
    STORM_LOG_THROW(initialStateIds.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state");

    // Translate the id of the single initial state back into the state itself.
    currentState = idToState[initialStateIds[0]];
    return explore();
}
/*!
 * Moves the simulator to the given compressed state.
 *
 * The stored rewards are reset to zero before the new state is explored, so afterwards they only
 * contain the state rewards of the new state.
 *
 * @param newState The state to continue the simulation from.
 * @return The result of explore().
 */
template<typename ValueType>
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(generator::CompressedState const& newState) {
    this->lastActionRewards = this->zeroRewards;
    this->currentState = newState;
    return this->explore();
}
/*!
 * Moves the simulator to the state described by the given valuation.
 *
 * The valuation is packed into a compressed state using the generator's variable information. The
 * stored rewards are reset to zero before the new state is explored, so afterwards they only
 * contain the state rewards of the new state (consistent with the other reset operations).
 *
 * @param valuation The valuation describing the state to continue the simulation from.
 * @return The result of explore().
 */
template<typename ValueType>
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(expressions::SimpleValuation const& valuation) {
    // Without this reset, explore() would add the new state's rewards on top of stale rewards
    // left over from the previous step.
    lastActionRewards = zeroRewards;
    currentState = generator::packStateFromValuation(valuation, stateGenerator->getVariableInformation(), true);
    return explore();
}
/*!
 * Collects the names of all reward models known to the state generator, in index order.
 *
 * @return One name per reward model, where entry i is the name of reward model i.
 */
template<typename ValueType>
std::vector<std::string> DiscreteTimePrismProgramSimulator<ValueType>::getRewardNames() const {
    uint64_t const rewardModelCount = stateGenerator->getNumberOfRewardModels();

    std::vector<std::string> rewardNames;
    for (uint64_t modelIndex = 0; modelIndex < rewardModelCount; ++modelIndex) {
        auto const& modelInformation = stateGenerator->getRewardModelInformation(modelIndex);
        rewardNames.push_back(modelInformation.getName());
    }
    return rewardNames;
}
/*!
 * Returns the id of the given state, registering the state under a fresh id if needed.
 *
 * The candidate id for a new state is the current number of entries in the state-to-id map. If the
 * lookup reports exactly that id, the state is also recorded in the id-to-state map.
 *
 * @param state The state whose id is requested.
 * @return The id reported by the state-to-id map for the state.
 */
template<typename ValueType>
uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) {
    // Id the state would receive if it has not been seen yet.
    uint32_t const candidateId = static_cast<uint32_t>(stateToId.size());

    // Look the state up, inserting it with the candidate id if absent; the bucket is not needed here.
    uint32_t const registeredId = stateToId.findOrAddAndGetBucket(state, candidateId).first;

    // Getting the candidate id back is taken to mean the state is new, so remember the reverse mapping.
    if (registeredId == candidateId) {
        idToState[registeredId] = state;
    }
    return registeredId;
}
/*!
 * Discards all registered states.
 *
 * The state-to-id map is replaced by a fresh BitVectorHashMap constructed with the generator's
 * state size, and the id-to-state map is emptied.
 */
template<typename ValueType>
void DiscreteTimePrismProgramSimulator<ValueType>::clearStateCaches() {
    this->stateToId = storm::storage::BitVectorHashMap<uint32_t>(this->stateGenerator->getStateSize());
    this->idToState.clear();
}
// Explicit instantiation for double: the member templates above are defined in this source file,
// so this line makes their definitions available for ValueType = double.
template class DiscreteTimePrismProgramSimulator<double>;
}
}