Sebastian Junges
4 years ago
4 changed files with 261 additions and 1 deletions
-
104src/storm/simulator/PrismProgramSimulator.cpp
-
100src/storm/simulator/PrismProgramSimulator.h
-
2src/test/storm/CMakeLists.txt
-
56src/test/storm/simulator/PrismProgramSimulator.cpp
@ -0,0 +1,104 @@ |
|||
#include "storm/simulator/PrismProgramSimulator.h"
|
|||
#include "storm/exceptions/NotSupportedException.h"
|
|||
|
|||
using namespace storm::generator; |
|||
|
|||
namespace storm { |
|||
namespace simulator { |
|||
|
|||
template<typename ValueType> |
|||
DiscreteTimePrismProgramSimulator<ValueType>::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, storm::generator::NextStateGeneratorOptions const& options) |
|||
: program(program), currentState(), stateGenerator(std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(program, options)), zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero<ValueType>()), lastActionRewards(zeroRewards) |
|||
{ |
|||
// Current state needs to be overwritten to actual initial state.
|
|||
// But first, let us create a state generator.
|
|||
|
|||
clearStateCaches(); |
|||
resetToInitial(); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void DiscreteTimePrismProgramSimulator<ValueType>::setSeed(uint64_t newSeed) { |
|||
generator = storm::utility::RandomProbabilityGenerator<ValueType>(newSeed); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool DiscreteTimePrismProgramSimulator<ValueType>::step(uint64_t actionNumber) { |
|||
uint32_t nextState = behavior.getChoices()[actionNumber].sampleFromDistribution(generator.random()); |
|||
lastActionRewards = behavior.getChoices()[actionNumber].getRewards(); |
|||
STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model."); |
|||
currentState = idToState[nextState]; |
|||
// TODO we do not need to do this in every step!
|
|||
clearStateCaches(); |
|||
explore(); |
|||
return true; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool DiscreteTimePrismProgramSimulator<ValueType>::explore() { |
|||
// Load the current state into the next state generator.
|
|||
stateGenerator->load(currentState); |
|||
// TODO: This low-level code currently expands all actions, while this may not be necessary.
|
|||
behavior = stateGenerator->expand(stateToIdCallback); |
|||
if (behavior.getStateRewards().size() > 0) { |
|||
STORM_LOG_ASSERT(behavior.getStateRewards().size() == lastActionRewards.size(), "Reward vectors should have same length."); |
|||
} |
|||
return true; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::vector<generator::Choice<ValueType, uint32_t>> const& DiscreteTimePrismProgramSimulator<ValueType>::getChoices() { |
|||
return behavior.getChoices(); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::vector<ValueType> const& DiscreteTimePrismProgramSimulator<ValueType>::getLastRewards() const { |
|||
return lastActionRewards; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
CompressedState const& DiscreteTimePrismProgramSimulator<ValueType>::getCurrentState() const { |
|||
return currentState; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
expressions::SimpleValuation DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateAsValuation() const { |
|||
return unpackStateIntoValuation(currentState, stateGenerator->getVariableInformation(), program.getManager()); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::string DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateString() const { |
|||
return stateGenerator->stateToString(currentState); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
bool DiscreteTimePrismProgramSimulator<ValueType>::resetToInitial() { |
|||
auto indices = stateGenerator->getInitialStates(stateToIdCallback); |
|||
STORM_LOG_THROW(indices.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state"); |
|||
currentState = idToState[indices[0]]; |
|||
return explore(); |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) { |
|||
uint32_t newIndex = static_cast<uint32_t>(stateToId.size()); |
|||
|
|||
// Check, if the state was already registered.
|
|||
std::pair<uint32_t, std::size_t> actualIndexBucketPair = stateToId.findOrAddAndGetBucket(state, newIndex); |
|||
|
|||
uint32_t actualIndex = actualIndexBucketPair.first; |
|||
if (actualIndex == newIndex) { |
|||
idToState[actualIndex] = state; |
|||
} |
|||
return actualIndex; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
void DiscreteTimePrismProgramSimulator<ValueType>::clearStateCaches() { |
|||
idToState.clear(); |
|||
stateToId = storm::storage::BitVectorHashMap<uint32_t>(stateGenerator->getStateSize()); |
|||
} |
|||
|
|||
template class DiscreteTimePrismProgramSimulator<double>; |
|||
} |
|||
} |
@ -0,0 +1,100 @@ |
|||
#pragma once |
|||
|
|||
#include "storm/storage/prism/Program.h" |
|||
#include "storm/storage/expressions/SimpleValuation.h" |
|||
#include "storm/generator/PrismNextStateGenerator.h" |
|||
#include "storm/utility/random.h" |
|||
|
|||
|
|||
namespace storm { |
|||
namespace simulator { |
|||
|
|||
/** |
|||
* This class provides a simulator interface on the prism program, |
|||
* and uses the next state generator. While the next state generator has been tuned, |
|||
* it is not targeted for simulation purposes. In particular, we (as of now) |
|||
* always extend all actions as soon as we arrive in a state. |
|||
* This may cause significant overhead, especially with a larger branching factor. |
|||
* |
|||
* On the other hand, this simulator is convenient for stepping through the model |
|||
* as it potentially allows considering the next states. |
|||
* Thus, while a performant alternative would be great, this simulator has its own merits. |
|||
* |
|||
* @tparam ValueType |
|||
*/ |
|||
template<typename ValueType> |
|||
class DiscreteTimePrismProgramSimulator { |
|||
public: |
|||
/** |
|||
* Initialize the simulator for a given prism program. |
|||
* |
|||
* @param program The prism program. Should have a unique initial state. |
|||
* @param options The generator options that are used to generate successor states. |
|||
*/ |
|||
DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, |
|||
storm::generator::NextStateGeneratorOptions const& options); |
|||
/** |
|||
* Set the simulation seed. |
|||
*/ |
|||
void setSeed(uint64_t); |
|||
/** |
|||
* |
|||
* @return A list of choices that encode the possibilities in the current state. |
|||
*/ |
|||
std::vector<generator::Choice<ValueType, uint32_t>> const& getChoices(); |
|||
/** |
|||
* Make a step and randomly select the successor. The action is given as an argument, the index reflects the index of the getChoices vector that can be accessed. |
|||
* |
|||
* @param actionNumber The action to select. |
|||
* @return true, if this action can be taken. |
|||
*/ |
|||
bool step(uint64_t actionNumber); |
|||
/** |
|||
* Accessor for the last state action reward and the current state reward, added together. |
|||
* @return A vector with te number of rewards. |
|||
*/ |
|||
std::vector<ValueType> const& getLastRewards() const; |
|||
generator::CompressedState const& getCurrentState() const; |
|||
expressions::SimpleValuation getCurrentStateAsValuation() const; |
|||
|
|||
std::string getCurrentStateString() const; |
|||
/** |
|||
* Reset to the (unique) initial state. |
|||
* |
|||
* @return |
|||
*/ |
|||
bool resetToInitial(); |
|||
protected: |
|||
bool explore(); |
|||
void clearStateCaches(); |
|||
/** |
|||
* Helper function for (temp) storing states. |
|||
*/ |
|||
uint32_t getOrAddStateIndex(generator::CompressedState const&); |
|||
|
|||
/// The program that we are simulating. |
|||
storm::prism::Program const& program; |
|||
/// The current state in the program, in its compressed form. |
|||
generator::CompressedState currentState; |
|||
/// Generator for the next states |
|||
std::shared_ptr<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>> stateGenerator; |
|||
bool explored = false; |
|||
generator::StateBehavior<ValueType> behavior; |
|||
/// Helper for last action reward construction |
|||
std::vector<ValueType> zeroRewards; |
|||
/// Stores the action rewards from the last action. |
|||
std::vector<ValueType> lastActionRewards; |
|||
/// Random number generator |
|||
storm::utility::RandomProbabilityGenerator<ValueType> generator; |
|||
/// Data structure to temp store states. |
|||
storm::storage::BitVectorHashMap<uint32_t> stateToId; |
|||
|
|||
std::unordered_map<uint32_t, generator::CompressedState> idToState; |
|||
|
|||
private: |
|||
// Create a callback for the next-state generator to enable it to request the index of states. |
|||
std::function<uint32_t (generator::CompressedState const&)> stateToIdCallback = std::bind(&DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex, this, std::placeholders::_1); |
|||
|
|||
}; |
|||
} |
|||
} |
@ -0,0 +1,56 @@ |
|||
#include "test/storm_gtest.h"
|
|||
#include "storm/simulator/PrismProgramSimulator.h"
|
|||
#include "storm-parsers/parser/PrismParser.h"
|
|||
#include "storm/environment/Environment.h"
|
|||
|
|||
TEST(PrismProgramSimulatorTest, KnuthYaoDieTest) { |
|||
storm::Environment env; |
|||
storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/die_c1.nm"); |
|||
storm::builder::BuilderOptions options; |
|||
options.setBuildAllRewardModels(); |
|||
|
|||
storm::simulator::DiscreteTimePrismProgramSimulator<double> sim(program, options); |
|||
auto rew = sim.getLastRewards(); |
|||
rew = sim.getLastRewards(); |
|||
EXPECT_EQ(1ul, rew.size()); |
|||
EXPECT_EQ(0.0, rew[0]); |
|||
#if 0
|
|||
std::cout << "reward: "; |
|||
for (auto const& r : rew) { |
|||
std::cout << r << " "; |
|||
} |
|||
std::cout << std::endl; |
|||
std::cout << sim.getCurrentStateAsValuation() << std::endl; |
|||
for (auto const& c : sim.getChoices()) { |
|||
std::cout << "Choice "; |
|||
std::cout << "action index: " << program.getActionName(c.getActionIndex()) << std::endl; |
|||
} |
|||
#endif
|
|||
EXPECT_EQ(2ul, sim.getChoices().size()); |
|||
sim.step(0); |
|||
rew = sim.getLastRewards(); |
|||
EXPECT_EQ(1ul, rew.size()); |
|||
EXPECT_EQ(0.0, rew[0]); |
|||
#if 0
|
|||
std::cout << "reward: "; |
|||
for (auto const& r : rew) { |
|||
std::cout << r << " "; |
|||
} |
|||
std::cout << std::endl; |
|||
std::cout << sim.getCurrentStateString() << std::endl; |
|||
#endif
|
|||
EXPECT_EQ(1ul, sim.getChoices().size()); |
|||
sim.step(0); |
|||
rew = sim.getLastRewards(); |
|||
EXPECT_EQ(1ul, rew.size()); |
|||
EXPECT_EQ(1.0, rew[0]); |
|||
#if 0
|
|||
std::cout << "reward: "; |
|||
for (auto const& r : rew) { |
|||
std::cout << r << " "; |
|||
} |
|||
std::cout << std::endl; |
|||
std::cout << sim.getCurrentStateString() << std::endl; |
|||
#endif
|
|||
|
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue