You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
7.2 KiB
159 lines
7.2 KiB
#include "storm/simulator/PrismProgramSimulator.h"
|
|
#include "storm/exceptions/NotSupportedException.h"
|
|
|
|
using namespace storm::generator;
|
|
|
|
namespace storm {
|
|
namespace simulator {
|
|
|
|
template<typename ValueType>
|
|
DiscreteTimePrismProgramSimulator<ValueType>::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, storm::generator::NextStateGeneratorOptions const& options)
|
|
: program(program), currentState(), stateGenerator(std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(program, options)), zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero<ValueType>()), lastActionRewards(zeroRewards)
|
|
{
|
|
// Current state needs to be overwritten to actual initial state.
|
|
// But first, let us create a state generator.
|
|
|
|
clearStateCaches();
|
|
resetToInitial();
|
|
}
|
|
|
|
template<typename ValueType>
|
|
void DiscreteTimePrismProgramSimulator<ValueType>::setSeed(uint64_t newSeed) {
|
|
generator = storm::utility::RandomProbabilityGenerator<ValueType>(newSeed);
|
|
}
|
|
|
|
template<typename ValueType>
|
|
bool DiscreteTimePrismProgramSimulator<ValueType>::step(uint64_t actionNumber) {
|
|
uint32_t nextState = behavior.getChoices()[actionNumber].sampleFromDistribution(generator.random());
|
|
lastActionRewards = behavior.getChoices()[actionNumber].getRewards();
|
|
STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model.");
|
|
currentState = idToState[nextState];
|
|
// TODO we do not need to do this in every step!
|
|
clearStateCaches();
|
|
explore();
|
|
return true;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
bool DiscreteTimePrismProgramSimulator<ValueType>::explore() {
|
|
// Load the current state into the next state generator.
|
|
stateGenerator->load(currentState);
|
|
// TODO: This low-level code currently expands all actions, while this is not necessary.
|
|
// However, using the next state generator ensures compatibliity with the model generator.
|
|
behavior = stateGenerator->expand(stateToIdCallback);
|
|
STORM_LOG_ASSERT(behavior.getStateRewards().size() == lastActionRewards.size(), "Reward vectors should have same length.");
|
|
for(uint64_t i = 0; i < behavior.getStateRewards().size(); i++) {
|
|
lastActionRewards[i] += behavior.getStateRewards()[i];
|
|
}
|
|
return true;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
bool DiscreteTimePrismProgramSimulator<ValueType>::isSinkState() const {
|
|
if(behavior.empty()) {
|
|
return true;
|
|
}
|
|
std::set<uint32_t> successorIds;
|
|
for (Choice<ValueType,uint32_t> const& choice : behavior.getChoices()) {
|
|
for (auto it = choice.begin(); it != choice.end(); ++it) {
|
|
successorIds.insert(it->first);
|
|
if (successorIds.size() > 1) {
|
|
return false;
|
|
}
|
|
}
|
|
}
|
|
if (idToState.at(*(successorIds.begin())) == currentState) {
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
std::vector<generator::Choice<ValueType, uint32_t>> const& DiscreteTimePrismProgramSimulator<ValueType>::getChoices() const {
|
|
return behavior.getChoices();
|
|
}
|
|
|
|
template<typename ValueType>
|
|
std::vector<ValueType> const& DiscreteTimePrismProgramSimulator<ValueType>::getLastRewards() const {
|
|
return lastActionRewards;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
CompressedState const& DiscreteTimePrismProgramSimulator<ValueType>::getCurrentState() const {
|
|
return currentState;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
expressions::SimpleValuation DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateAsValuation() const {
|
|
return unpackStateIntoValuation(currentState, stateGenerator->getVariableInformation(), program.getManager());
|
|
}
|
|
|
|
template<typename ValueType>
|
|
std::string DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateString() const {
|
|
return stateGenerator->stateToString(currentState);
|
|
}
|
|
|
|
template<typename ValueType>
|
|
storm::json<ValueType> DiscreteTimePrismProgramSimulator<ValueType>::getStateAsJson() const {
|
|
return stateGenerator->currentStateToJson(false);
|
|
}
|
|
|
|
template<typename ValueType>
|
|
storm::json<ValueType> DiscreteTimePrismProgramSimulator<ValueType>::getObservationAsJson() const {
|
|
return stateGenerator->currentStateToJson(true);
|
|
}
|
|
|
|
template<typename ValueType>
|
|
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToInitial() {
|
|
lastActionRewards = zeroRewards;
|
|
auto indices = stateGenerator->getInitialStates(stateToIdCallback);
|
|
STORM_LOG_THROW(indices.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state");
|
|
currentState = idToState[indices[0]];
|
|
return explore();
|
|
}
|
|
|
|
template<typename ValueType>
|
|
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(generator::CompressedState const& newState) {
|
|
lastActionRewards = zeroRewards;
|
|
currentState = newState;
|
|
return explore();
|
|
}
|
|
|
|
template<typename ValueType>
|
|
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToState(expressions::SimpleValuation const& valuation) {
|
|
currentState = generator::packStateFromValuation(valuation, stateGenerator->getVariableInformation(), true);
|
|
return explore();
|
|
}
|
|
|
|
template<typename ValueType>
|
|
std::vector<std::string> DiscreteTimePrismProgramSimulator<ValueType>::getRewardNames() const {
|
|
std::vector<std::string> names;
|
|
for (uint64_t i = 0; i < stateGenerator->getNumberOfRewardModels(); ++i) {
|
|
names.push_back(stateGenerator->getRewardModelInformation(i).getName());
|
|
}
|
|
return names;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) {
|
|
uint32_t newIndex = static_cast<uint32_t>(stateToId.size());
|
|
|
|
// Check, if the state was already registered.
|
|
std::pair<uint32_t, std::size_t> actualIndexBucketPair = stateToId.findOrAddAndGetBucket(state, newIndex);
|
|
|
|
uint32_t actualIndex = actualIndexBucketPair.first;
|
|
if (actualIndex == newIndex) {
|
|
idToState[actualIndex] = state;
|
|
}
|
|
return actualIndex;
|
|
}
|
|
|
|
template<typename ValueType>
|
|
void DiscreteTimePrismProgramSimulator<ValueType>::clearStateCaches() {
|
|
idToState.clear();
|
|
stateToId = storm::storage::BitVectorHashMap<uint32_t>(stateGenerator->getStateSize());
|
|
}
|
|
|
|
// Explicit instantiation of the simulator for double, so the member definitions in this file are
// compiled for that value type.
template class DiscreteTimePrismProgramSimulator<double>;
|
|
}
|
|
}
|