#include "storm/simulator/PrismProgramSimulator.h" #include "storm/exceptions/NotSupportedException.h" using namespace storm::generator; namespace storm { namespace simulator { template DiscreteTimePrismProgramSimulator::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, storm::generator::NextStateGeneratorOptions const& options) : program(program), currentState(), stateGenerator(std::make_shared>(program, options)), zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero()), lastActionRewards(zeroRewards) { // Current state needs to be overwritten to actual initial state. // But first, let us create a state generator. clearStateCaches(); resetToInitial(); } template void DiscreteTimePrismProgramSimulator::setSeed(uint64_t newSeed) { generator = storm::utility::RandomProbabilityGenerator(newSeed); } template bool DiscreteTimePrismProgramSimulator::step(uint64_t actionNumber) { uint32_t nextState = behavior.getChoices()[actionNumber].sampleFromDistribution(generator.random()); lastActionRewards = behavior.getChoices()[actionNumber].getRewards(); STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model."); currentState = idToState[nextState]; // TODO we do not need to do this in every step! clearStateCaches(); explore(); return true; } template bool DiscreteTimePrismProgramSimulator::explore() { // Load the current state into the next state generator. stateGenerator->load(currentState); // TODO: This low-level code currently expands all actions, while this is not necessary. // However, using the next state generator ensures compatibliity with the model generator. behavior = stateGenerator->expand(stateToIdCallback); if (behavior.getStateRewards().size() > 0) { STORM_LOG_ASSERT(behavior.getStateRewards().size() == lastActionRewards.size(), "Reward vectors should have same length."); } for(uint64_t i = 0; i < behavior.getStateRewards().size(); i++) { lastActionRewards[i] += behavior.getStateRewards()[i]; } return true; } template bool DiscreteTimePrismProgramSimulator::isSinkState() const { if(behavior.empty()) { return true; } std::set successorIds; for (Choice const& choice : behavior.getChoices()) { for (auto it = choice.begin(); it != choice.end(); ++it) { successorIds.insert(it->first); if (successorIds.size() > 1) { return false; } } } if (idToState.at(*(successorIds.begin())) == currentState) { return true; } return false; } template std::vector> const& DiscreteTimePrismProgramSimulator::getChoices() const { return behavior.getChoices(); } template std::vector const& DiscreteTimePrismProgramSimulator::getLastRewards() const { return lastActionRewards; } template CompressedState const& DiscreteTimePrismProgramSimulator::getCurrentState() const { return currentState; } template expressions::SimpleValuation DiscreteTimePrismProgramSimulator::getCurrentStateAsValuation() const { return unpackStateIntoValuation(currentState, stateGenerator->getVariableInformation(), program.getManager()); } template std::string DiscreteTimePrismProgramSimulator::getCurrentStateString() const { return stateGenerator->stateToString(currentState); } template storm::json DiscreteTimePrismProgramSimulator::getStateAsJson() const { return stateGenerator->currentStateToJson(false); } template storm::json DiscreteTimePrismProgramSimulator::getObservationAsJson() const { return stateGenerator->currentStateToJson(true); } template bool DiscreteTimePrismProgramSimulator::resetToInitial() { auto 
indices = stateGenerator->getInitialStates(stateToIdCallback); STORM_LOG_THROW(indices.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state"); currentState = idToState[indices[0]]; return explore(); } template bool DiscreteTimePrismProgramSimulator::resetToState(generator::CompressedState const& newState) { currentState = newState; return explore(); } template bool DiscreteTimePrismProgramSimulator::resetToState(expressions::SimpleValuation const& valuation) { currentState = generator::packStateFromValuation(valuation, stateGenerator->getVariableInformation(), true); return explore(); } template uint32_t DiscreteTimePrismProgramSimulator::getOrAddStateIndex(generator::CompressedState const& state) { uint32_t newIndex = static_cast(stateToId.size()); // Check, if the state was already registered. std::pair actualIndexBucketPair = stateToId.findOrAddAndGetBucket(state, newIndex); uint32_t actualIndex = actualIndexBucketPair.first; if (actualIndex == newIndex) { idToState[actualIndex] = state; } return actualIndex; } template void DiscreteTimePrismProgramSimulator::clearStateCaches() { idToState.clear(); stateToId = storm::storage::BitVectorHashMap(stateGenerator->getStateSize()); } template class DiscreteTimePrismProgramSimulator; } }
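
// Illustrative usage sketch (not part of the Storm sources): drives the simulator under a
// trivial policy that always takes the first enabled choice, stopping at a sink state or
// after a fixed step bound. It assumes a PRISM program has already been parsed (e.g. via
// storm::api::parseProgram); the function name `simulateFirstChoicePolicy`, the seed, and
// the step bound are made up for illustration.
//
//     void simulateFirstChoicePolicy(storm::prism::Program const& program) {
//         storm::generator::NextStateGeneratorOptions options;
//         options.setBuildAllRewardModels();
//         storm::simulator::DiscreteTimePrismProgramSimulator<double> simulator(program, options);
//         simulator.setSeed(42);
//         for (uint64_t i = 0; i < 100 && !simulator.isSinkState(); ++i) {
//             // The argument indexes into getChoices(); 0 picks the first enabled action.
//             simulator.step(0);
//         }
//     }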