4 changed files with 261 additions and 1 deletions
			
			
		- 
					104src/storm/simulator/PrismProgramSimulator.cpp
- 
					100src/storm/simulator/PrismProgramSimulator.h
- 
					2src/test/storm/CMakeLists.txt
- 
					56src/test/storm/simulator/PrismProgramSimulator.cpp
| @ -0,0 +1,104 @@ | |||||
|  | #include "storm/simulator/PrismProgramSimulator.h"
 | ||||
|  | #include "storm/exceptions/NotSupportedException.h"
 | ||||
|  | 
 | ||||
|  | using namespace storm::generator; | ||||
|  | 
 | ||||
|  | namespace storm { | ||||
|  |     namespace simulator { | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         DiscreteTimePrismProgramSimulator<ValueType>::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, storm::generator::NextStateGeneratorOptions const& options) | ||||
|  |         : program(program), currentState(), stateGenerator(std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(program, options)), zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero<ValueType>()), lastActionRewards(zeroRewards) | ||||
|  |         { | ||||
|  |             // Current state needs to be overwritten to actual initial state.
 | ||||
|  |             // But first, let us create a state generator.
 | ||||
|  | 
 | ||||
|  |             clearStateCaches(); | ||||
|  |             resetToInitial(); | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         void DiscreteTimePrismProgramSimulator<ValueType>::setSeed(uint64_t newSeed) { | ||||
|  |             generator = storm::utility::RandomProbabilityGenerator<ValueType>(newSeed); | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         bool DiscreteTimePrismProgramSimulator<ValueType>::step(uint64_t actionNumber) { | ||||
|  |             uint32_t nextState = behavior.getChoices()[actionNumber].sampleFromDistribution(generator.random()); | ||||
|  |             lastActionRewards = behavior.getChoices()[actionNumber].getRewards(); | ||||
|  |             STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model."); | ||||
|  |             currentState = idToState[nextState]; | ||||
|  |             // TODO we do not need to do this in every step!
 | ||||
|  |             clearStateCaches(); | ||||
|  |             explore(); | ||||
|  |             return true; | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         bool DiscreteTimePrismProgramSimulator<ValueType>::explore() { | ||||
|  |             // Load the current state into the next state generator.
 | ||||
|  |             stateGenerator->load(currentState); | ||||
|  |             // TODO: This low-level code currently expands all actions, while this may not be necessary.
 | ||||
|  |             behavior = stateGenerator->expand(stateToIdCallback); | ||||
|  |             if (behavior.getStateRewards().size() > 0) { | ||||
|  |                 STORM_LOG_ASSERT(behavior.getStateRewards().size() == lastActionRewards.size(), "Reward vectors should have same length."); | ||||
|  |             } | ||||
|  |             return true; | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         std::vector<generator::Choice<ValueType, uint32_t>> const& DiscreteTimePrismProgramSimulator<ValueType>::getChoices() { | ||||
|  |             return behavior.getChoices(); | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         std::vector<ValueType> const& DiscreteTimePrismProgramSimulator<ValueType>::getLastRewards() const { | ||||
|  |             return lastActionRewards; | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         CompressedState const& DiscreteTimePrismProgramSimulator<ValueType>::getCurrentState() const { | ||||
|  |             return currentState; | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         expressions::SimpleValuation DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateAsValuation() const { | ||||
|  |             return unpackStateIntoValuation(currentState, stateGenerator->getVariableInformation(), program.getManager()); | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         std::string DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateString() const { | ||||
|  |             return stateGenerator->stateToString(currentState); | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         bool DiscreteTimePrismProgramSimulator<ValueType>::resetToInitial() { | ||||
|  |             auto indices = stateGenerator->getInitialStates(stateToIdCallback); | ||||
|  |             STORM_LOG_THROW(indices.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state"); | ||||
|  |             currentState = idToState[indices[0]]; | ||||
|  |             return explore(); | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) { | ||||
|  |             uint32_t newIndex = static_cast<uint32_t>(stateToId.size()); | ||||
|  | 
 | ||||
|  |             // Check, if the state was already registered.
 | ||||
|  |             std::pair<uint32_t, std::size_t> actualIndexBucketPair = stateToId.findOrAddAndGetBucket(state, newIndex); | ||||
|  | 
 | ||||
|  |             uint32_t actualIndex = actualIndexBucketPair.first; | ||||
|  |             if (actualIndex == newIndex) { | ||||
|  |                 idToState[actualIndex] = state; | ||||
|  |             } | ||||
|  |             return actualIndex; | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         template<typename ValueType> | ||||
|  |         void DiscreteTimePrismProgramSimulator<ValueType>::clearStateCaches() { | ||||
|  |             idToState.clear(); | ||||
|  |             stateToId = storm::storage::BitVectorHashMap<uint32_t>(stateGenerator->getStateSize()); | ||||
|  |         } | ||||
|  | 
 | ||||
// Explicit instantiation of the simulator for ValueType = double (the only one in this file).
|  |         template class DiscreteTimePrismProgramSimulator<double>; | ||||
|  |     } | ||||
|  | } | ||||
| @ -0,0 +1,100 @@ | |||||
|  | #pragma once | ||||
|  | 
 | ||||
|  | #include "storm/storage/prism/Program.h" | ||||
|  | #include "storm/storage/expressions/SimpleValuation.h" | ||||
|  | #include "storm/generator/PrismNextStateGenerator.h" | ||||
|  | #include "storm/utility/random.h" | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | namespace storm { | ||||
|  |     namespace simulator { | ||||
|  | 
 | ||||
|  |         /** | ||||
|  |          * This class provides a simulator interface on the prism program, | ||||
|  |          * and uses the next state generator. While the next state generator has been tuned, | ||||
|  |          * it is not targeted for simulation purposes. In particular, we (as of now) | ||||
|  |          * always extend all actions as soon as we arrive in a state. | ||||
|  |          * This may cause significant overhead, especially with a larger branching factor. | ||||
|  |          * | ||||
|  |          * On the other hand, this simulator is convenient for stepping through the model | ||||
|  |          * as it potentially allows considering the next states. | ||||
|  |          * Thus, while a performant alternative would be great, this simulator has its own merits. | ||||
|  |          * | ||||
|  |          * @tparam ValueType | ||||
|  |          */ | ||||
|  |         template<typename ValueType> | ||||
|  |         class DiscreteTimePrismProgramSimulator { | ||||
|  |         public: | ||||
|  |             /** | ||||
|  |              * Initialize the simulator for a given prism program. | ||||
|  |              * | ||||
|  |              * @param program The prism program. Should have a unique initial state. | ||||
|  |              * @param options The generator options that are used to generate successor states. | ||||
|  |              */ | ||||
|  |             DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, | ||||
|  |                                               storm::generator::NextStateGeneratorOptions const& options); | ||||
|  |             /** | ||||
|  |              * Set the simulation seed. | ||||
|  |              */ | ||||
|  |             void setSeed(uint64_t); | ||||
|  |             /** | ||||
|  |              * | ||||
|  |              * @return A list of choices that encode the possibilities in the current state. | ||||
|  |              */ | ||||
|  |             std::vector<generator::Choice<ValueType, uint32_t>> const& getChoices(); | ||||
|  |             /** | ||||
|  |              * Make a step and randomly select the successor. The action is given as an argument, the index reflects the index of the getChoices vector that can be accessed. | ||||
|  |              * | ||||
|  |              * @param actionNumber The action to select. | ||||
|  |              * @return true, if this action can be taken. | ||||
|  |              */ | ||||
|  |             bool step(uint64_t actionNumber); | ||||
|  |             /** | ||||
|  |              * Accessor for the last state action reward and the current state reward, added together. | ||||
|  |              * @return A vector with te number of rewards. | ||||
|  |              */ | ||||
|  |             std::vector<ValueType> const& getLastRewards() const; | ||||
|  |             generator::CompressedState const& getCurrentState() const; | ||||
|  |             expressions::SimpleValuation getCurrentStateAsValuation() const; | ||||
|  | 
 | ||||
|  |             std::string getCurrentStateString() const; | ||||
|  |             /** | ||||
|  |              * Reset to the (unique) initial state. | ||||
|  |              * | ||||
|  |              * @return | ||||
|  |              */ | ||||
|  |             bool resetToInitial(); | ||||
|  |         protected: | ||||
|  |             bool explore(); | ||||
|  |             void clearStateCaches(); | ||||
|  |             /** | ||||
|  |              * Helper function for (temp) storing states. | ||||
|  |              */ | ||||
|  |             uint32_t getOrAddStateIndex(generator::CompressedState const&); | ||||
|  | 
 | ||||
|  |             /// The program that we are simulating. | ||||
|  |             storm::prism::Program const& program; | ||||
|  |             /// The current state in the program, in its compressed form. | ||||
|  |             generator::CompressedState currentState; | ||||
|  |             /// Generator for the next states | ||||
|  |             std::shared_ptr<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>> stateGenerator; | ||||
|  |             bool explored = false; | ||||
|  |             generator::StateBehavior<ValueType> behavior; | ||||
|  |             /// Helper for last action reward construction | ||||
|  |             std::vector<ValueType> zeroRewards; | ||||
|  |             /// Stores the action rewards from the last action. | ||||
|  |             std::vector<ValueType> lastActionRewards; | ||||
|  |             /// Random number generator | ||||
|  |             storm::utility::RandomProbabilityGenerator<ValueType> generator; | ||||
|  |             /// Data structure to temp store states. | ||||
|  |             storm::storage::BitVectorHashMap<uint32_t> stateToId; | ||||
|  | 
 | ||||
|  |             std::unordered_map<uint32_t, generator::CompressedState> idToState; | ||||
|  | 
 | ||||
|  |         private: | ||||
|  |             // Create a callback for the next-state generator to enable it to request the index of states. | ||||
|  |             std::function<uint32_t (generator::CompressedState const&)> stateToIdCallback = std::bind(&DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex, this, std::placeholders::_1); | ||||
|  | 
 | ||||
|  |         }; | ||||
|  |     } | ||||
|  | } | ||||
| @ -0,0 +1,56 @@ | |||||
|  | #include "test/storm_gtest.h"
 | ||||
|  | #include "storm/simulator/PrismProgramSimulator.h"
 | ||||
|  | #include "storm-parsers/parser/PrismParser.h"
 | ||||
|  | #include "storm/environment/Environment.h"
 | ||||
|  | 
 | ||||
|  | TEST(PrismProgramSimulatorTest, KnuthYaoDieTest) { | ||||
|  |     storm::Environment env; | ||||
|  |     storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/die_c1.nm"); | ||||
|  |     storm::builder::BuilderOptions options; | ||||
|  |     options.setBuildAllRewardModels(); | ||||
|  | 
 | ||||
|  |     storm::simulator::DiscreteTimePrismProgramSimulator<double> sim(program, options); | ||||
|  |     auto rew = sim.getLastRewards(); | ||||
|  |     rew = sim.getLastRewards(); | ||||
|  |     EXPECT_EQ(1ul, rew.size()); | ||||
|  |     EXPECT_EQ(0.0, rew[0]); | ||||
|  | #if 0
 | ||||
|  |     std::cout << "reward: "; | ||||
|  |     for (auto const& r : rew) { | ||||
|  |         std::cout << r << " "; | ||||
|  |     } | ||||
|  |     std::cout << std::endl; | ||||
|  |     std::cout << sim.getCurrentStateAsValuation() << std::endl; | ||||
|  |     for (auto const& c : sim.getChoices()) { | ||||
|  |         std::cout << "Choice "; | ||||
|  |         std::cout << "action index: " << program.getActionName(c.getActionIndex()) << std::endl; | ||||
|  |     } | ||||
|  | #endif
 | ||||
|  |     EXPECT_EQ(2ul, sim.getChoices().size()); | ||||
|  |     sim.step(0); | ||||
|  |     rew = sim.getLastRewards(); | ||||
|  |     EXPECT_EQ(1ul, rew.size()); | ||||
|  |     EXPECT_EQ(0.0, rew[0]); | ||||
|  | #if 0
 | ||||
|  |     std::cout << "reward: "; | ||||
|  |     for (auto const& r : rew) { | ||||
|  |         std::cout << r << " "; | ||||
|  |     } | ||||
|  |     std::cout << std::endl; | ||||
|  |     std::cout << sim.getCurrentStateString() << std::endl; | ||||
|  | #endif
 | ||||
|  |     EXPECT_EQ(1ul, sim.getChoices().size()); | ||||
|  |     sim.step(0); | ||||
|  |     rew = sim.getLastRewards(); | ||||
|  |     EXPECT_EQ(1ul, rew.size()); | ||||
|  |     EXPECT_EQ(1.0, rew[0]); | ||||
|  | #if 0
 | ||||
|  |     std::cout << "reward: "; | ||||
|  |     for (auto const& r : rew) { | ||||
|  |         std::cout << r << " "; | ||||
|  |     } | ||||
|  |     std::cout << std::endl; | ||||
|  |     std::cout << sim.getCurrentStateString() << std::endl; | ||||
|  | #endif
 | ||||
|  | 
 | ||||
|  | } | ||||
						Write
						Preview
					
					
					Loading…
					
					Cancel
						Save
					
		Reference in new issue