diff --git a/src/storm/simulator/PrismProgramSimulator.cpp b/src/storm/simulator/PrismProgramSimulator.cpp
new file mode 100644
index 000000000..6d9428b6a
--- /dev/null
+++ b/src/storm/simulator/PrismProgramSimulator.cpp
@@ -0,0 +1,114 @@
+#include "storm/simulator/PrismProgramSimulator.h"
+#include "storm/exceptions/NotSupportedException.h"
+
+using namespace storm::generator;
+
+namespace storm {
+    namespace simulator {
+
+        template<typename ValueType>
+        DiscreteTimePrismProgramSimulator<ValueType>::DiscreteTimePrismProgramSimulator(storm::prism::Program const& program, storm::generator::NextStateGeneratorOptions const& options)
+        : program(program), currentState(), stateGenerator(std::make_shared<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>>(program, options)), zeroRewards(stateGenerator->getNumberOfRewardModels(), storm::utility::zero<ValueType>()), lastActionRewards(zeroRewards)
+        {
+            // The default-constructed current state is only a placeholder: start from
+            // empty state caches and then move to the actual initial state.
+            clearStateCaches();
+            resetToInitial();
+        }
+
+        template<typename ValueType>
+        void DiscreteTimePrismProgramSimulator<ValueType>::setSeed(uint64_t newSeed) {
+            // Replace the random generator that step() samples from by one using the new seed.
+            generator = storm::utility::RandomProbabilityGenerator<ValueType>(newSeed);
+        }
+
+        template<typename ValueType>
+        bool DiscreteTimePrismProgramSimulator<ValueType>::step(uint64_t actionNumber) {
+            auto& choices = behavior.getChoices();
+            if (actionNumber >= choices.size()) {
+                // There is no such choice in the current state: report failure and leave the simulator untouched.
+                return false;
+            }
+            auto& choice = choices[actionNumber];
+            uint32_t nextState = choice.sampleFromDistribution(generator.random());
+            lastActionRewards = choice.getRewards();
+            STORM_LOG_ASSERT(lastActionRewards.size() == stateGenerator->getNumberOfRewardModels(), "Reward vector should have as many rewards as model.");
+            // Copy the successor out of the cache before the cache is cleared below.
+            currentState = idToState[nextState];
+            // TODO we do not need to do this in every step!
+            clearStateCaches();
+            explore();
+            return true;
+        }
+
+        template<typename ValueType>
+        bool DiscreteTimePrismProgramSimulator<ValueType>::explore() {
+            // Load the current state into the next state generator.
+            stateGenerator->load(currentState);
+            // TODO: This low-level code currently expands all actions, while this may not be necessary.
+            behavior = stateGenerator->expand(stateToIdCallback);
+            if (behavior.getStateRewards().size() > 0) {
+                STORM_LOG_ASSERT(behavior.getStateRewards().size() == lastActionRewards.size(), "Reward vectors should have same length.");
+            }
+            return true;
+        }
+
+        template<typename ValueType>
+        std::vector<generator::Choice<ValueType, uint32_t>> const& DiscreteTimePrismProgramSimulator<ValueType>::getChoices() {
+            return behavior.getChoices();
+        }
+
+        template<typename ValueType>
+        std::vector<ValueType> const& DiscreteTimePrismProgramSimulator<ValueType>::getLastRewards() const {
+            return lastActionRewards;
+        }
+
+        template<typename ValueType>
+        CompressedState const& DiscreteTimePrismProgramSimulator<ValueType>::getCurrentState() const {
+            return currentState;
+        }
+
+        template<typename ValueType>
+        expressions::SimpleValuation DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateAsValuation() const {
+            return unpackStateIntoValuation(currentState, stateGenerator->getVariableInformation(), program.getManager());
+        }
+
+        template<typename ValueType>
+        std::string DiscreteTimePrismProgramSimulator<ValueType>::getCurrentStateString() const {
+            return stateGenerator->stateToString(currentState);
+        }
+
+        template<typename ValueType>
+        bool DiscreteTimePrismProgramSimulator<ValueType>::resetToInitial() {
+            auto indices = stateGenerator->getInitialStates(stateToIdCallback);
+            STORM_LOG_THROW(indices.size() == 1, storm::exceptions::NotSupportedException, "Program must have a unique initial state");
+            // A fresh run starts without collected rewards; do not leak those of the previous run.
+            lastActionRewards = zeroRewards;
+            currentState = idToState[indices[0]];
+            return explore();
+        }
+
+        template<typename ValueType>
+        uint32_t DiscreteTimePrismProgramSimulator<ValueType>::getOrAddStateIndex(generator::CompressedState const& state) {
+            uint32_t newIndex = static_cast<uint32_t>(stateToId.size());
+
+            // Check, if the state was already registered.
+            auto actualIndexBucketPair = stateToId.findOrAddAndGetBucket(state, newIndex);
+
+            uint32_t actualIndex = actualIndexBucketPair.first;
+            if (actualIndex == newIndex) {
+                // The state is new: remember it so that it can be looked up by its index.
+                idToState[actualIndex] = state;
+            }
+            return actualIndex;
+        }
+
+        template<typename ValueType>
+        void DiscreteTimePrismProgramSimulator<ValueType>::clearStateCaches() {
+            idToState.clear();
+            stateToId = storm::storage::BitVectorHashMap<uint32_t>(stateGenerator->getStateSize());
+        }
+
+        template class DiscreteTimePrismProgramSimulator<double>;
+    }
+}
diff --git a/src/storm/simulator/PrismProgramSimulator.h b/src/storm/simulator/PrismProgramSimulator.h
new file mode 100644
index 000000000..116388b05
--- /dev/null
+++ b/src/storm/simulator/PrismProgramSimulator.h
@@ -0,0 +1,129 @@
+#pragma once
+
+#include <cstdint>
+#include <functional>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "storm/storage/prism/Program.h"
+#include "storm/storage/expressions/SimpleValuation.h"
+#include "storm/generator/PrismNextStateGenerator.h"
+#include "storm/utility/random.h"
+
+namespace storm {
+    namespace simulator {
+
+        /**
+         * This class provides a simulator interface on the prism program,
+         * and uses the next state generator. While the next state generator has been tuned,
+         * it is not targeted for simulation purposes. In particular, we (as of now)
+         * always extend all actions as soon as we arrive in a state.
+         * This may cause significant overhead, especially with a larger branching factor.
+         *
+         * On the other hand, this simulator is convenient for stepping through the model
+         * as it potentially allows considering the next states.
+         * Thus, while a performant alternative would be great, this simulator has its own merits.
+         *
+         * NOTE(review): the state-index callback stored below is bound to this object, so a copy
+         * of a simulator would still call back into the simulator it was copied from.
+         *
+         * @tparam ValueType The number type used for the rewards.
+         */
+        template<typename ValueType>
+        class DiscreteTimePrismProgramSimulator {
+        public:
+            /**
+             * Initialize the simulator for a given prism program.
+             *
+             * @param program The prism program. Should have a unique initial state. It is stored by
+             *                reference and therefore has to outlive the simulator.
+             * @param options The generator options that are used to generate successor states.
+             */
+            DiscreteTimePrismProgramSimulator(storm::prism::Program const& program,
+                                              storm::generator::NextStateGeneratorOptions const& options);
+            /**
+             * Set the simulation seed.
+             */
+            void setSeed(uint64_t);
+            /**
+             * @return A list of choices that encode the possibilities in the current state.
+             */
+            std::vector<generator::Choice<ValueType, uint32_t>> const& getChoices();
+            /**
+             * Make a step and randomly select the successor. The action is given as an argument,
+             * the index reflects the index of the getChoices vector that can be accessed.
+             *
+             * @param actionNumber The action to select.
+             * @return true, if this action can be taken. If the index does not refer to a choice of
+             *         the current state, false is returned and the simulator is left unchanged.
+             */
+            bool step(uint64_t actionNumber);
+            /**
+             * Accessor for the reward vector of the choice taken in the last successful step.
+             * Before the first step and after resetToInitial(), all entries are zero.
+             * NOTE(review): the simulator does not add the state rewards of the current state
+             * itself; whether the choice rewards already include them is up to the generator.
+             *
+             * @return A vector with one entry per reward model.
+             */
+            std::vector<ValueType> const& getLastRewards() const;
+            /**
+             * @return The current state in its compressed form.
+             */
+            generator::CompressedState const& getCurrentState() const;
+            /**
+             * @return The current state, unpacked into a valuation.
+             */
+            expressions::SimpleValuation getCurrentStateAsValuation() const;
+            /**
+             * @return A textual representation of the current state, as produced by the state generator.
+             */
+            std::string getCurrentStateString() const;
+            /**
+             * Reset to the (unique) initial state and discard the rewards of the last step.
+             *
+             * @return The result of exploring the initial state.
+             */
+            bool resetToInitial();
+        protected:
+            /// Expands the current state, i.e. (re)computes the choices that are available in it.
+            bool explore();
+            /// Drops all states that were registered via getOrAddStateIndex.
+            void clearStateCaches();
+            /**
+             * Helper function for (temp) storing states.
+             *
+             * @return The index of the given state; a fresh index is assigned to unseen states.
+             */
+            uint32_t getOrAddStateIndex(generator::CompressedState const&);
+
+            /// The program that we are simulating.
+            storm::prism::Program const& program;
+            /// The current state in the program, in its compressed form.
+            generator::CompressedState currentState;
+            /// Generator for the next states
+            std::shared_ptr<storm::generator::PrismNextStateGenerator<ValueType, uint32_t>> stateGenerator;
+            /// NOTE(review): not read or written by any code of this class that is visible here.
+            bool explored = false;
+            /// The choices (and state rewards) of the current state, as computed by explore().
+            generator::StateBehavior<ValueType, uint32_t> behavior;
+            /// Helper for last action reward construction
+            std::vector<ValueType> zeroRewards;
+            /// Stores the action rewards from the last action.
+            std::vector<ValueType> lastActionRewards;
+            /// Random number generator
+            storm::utility::RandomProbabilityGenerator<ValueType> generator;
+            /// Data structure to temp store states.
+            storm::storage::BitVectorHashMap<uint32_t> stateToId;
+
+            std::unordered_map<uint32_t, generator::CompressedState> idToState;
+
+        private:
+            // Create a callback for the next-state generator to enable it to request the index of states.
+            std::function<uint32_t (generator::CompressedState const&)> stateToIdCallback = std::bind(&DiscreteTimePrismProgramSimulator::getOrAddStateIndex, this, std::placeholders::_1);
+
+        };
+    }
+}
diff --git a/src/test/storm/CMakeLists.txt b/src/test/storm/CMakeLists.txt
index 20e27e95b..d99806231 100755
--- a/src/test/storm/CMakeLists.txt
+++ b/src/test/storm/CMakeLists.txt
@@ -10,7 +10,7 @@ register_source_groups_from_filestructure("${ALL_FILES}" test)
 include_directories(${GTEST_INCLUDE_DIR})
 
 # Set split and non-split test directories
-set(NON_SPLIT_TESTS abstraction adapter builder logic model parser permissiveschedulers solver storage transformer utility)
+set(NON_SPLIT_TESTS abstraction adapter builder logic model parser permissiveschedulers simulator solver storage transformer utility)
 set(MODELCHECKER_TEST_SPLITS abstraction csl exploration multiobjective reachability)
 set(MODELCHECKER_PRCTL_TEST_SPLITS dtmc mdp)
 
diff --git a/src/test/storm/simulator/PrismProgramSimulator.cpp b/src/test/storm/simulator/PrismProgramSimulator.cpp
new file mode 100644
index 000000000..90b3873eb
--- /dev/null
+++ b/src/test/storm/simulator/PrismProgramSimulator.cpp
@@ -0,0 +1,64 @@
+#include "test/storm_gtest.h"
+#include "storm/simulator/PrismProgramSimulator.h"
+#include "storm-parsers/parser/PrismParser.h"
+#include "storm/environment/Environment.h"
+
+TEST(PrismProgramSimulatorTest, KnuthYaoDieTest) {
+    storm::prism::Program program = storm::parser::PrismParser::parse(STORM_TEST_RESOURCES_DIR "/mdp/die_c1.nm");
+    storm::builder::BuilderOptions options;
+    options.setBuildAllRewardModels();
+
+    storm::simulator::DiscreteTimePrismProgramSimulator<double> sim(program, options);
+    // Before the first step, no rewards have been collected.
+    auto rew = sim.getLastRewards();
+    EXPECT_EQ(1ul, rew.size());
+    EXPECT_EQ(0.0, rew[0]);
+#if 0
+    std::cout << "reward: ";
+    for (auto const& r : rew) {
+        std::cout << r << " ";
+    }
+    std::cout << std::endl;
+    std::cout << sim.getCurrentStateAsValuation() << std::endl;
+    for (auto const& c : sim.getChoices()) {
+        std::cout << "Choice ";
+        std::cout << "action index: " << program.getActionName(c.getActionIndex()) << std::endl;
+    }
+#endif
+    EXPECT_EQ(2ul, sim.getChoices().size());
+    // An action index beyond the available choices is rejected without changing the state.
+    EXPECT_FALSE(sim.step(sim.getChoices().size()));
+    EXPECT_EQ(2ul, sim.getChoices().size());
+    EXPECT_TRUE(sim.step(0));
+    rew = sim.getLastRewards();
+    EXPECT_EQ(1ul, rew.size());
+    EXPECT_EQ(0.0, rew[0]);
+#if 0
+    std::cout << "reward: ";
+    for (auto const& r : rew) {
+        std::cout << r << " ";
+    }
+    std::cout << std::endl;
+    std::cout << sim.getCurrentStateString() << std::endl;
+#endif
+    EXPECT_EQ(1ul, sim.getChoices().size());
+    EXPECT_TRUE(sim.step(0));
+    rew = sim.getLastRewards();
+    EXPECT_EQ(1ul, rew.size());
+    EXPECT_EQ(1.0, rew[0]);
+#if 0
+    std::cout << "reward: ";
+    for (auto const& r : rew) {
+        std::cout << r << " ";
+    }
+    std::cout << std::endl;
+    std::cout << sim.getCurrentStateString() << std::endl;
+#endif
+
+    // Resetting returns to the initial state and discards the rewards of the last step.
+    EXPECT_TRUE(sim.resetToInitial());
+    rew = sim.getLastRewards();
+    EXPECT_EQ(1ul, rew.size());
+    EXPECT_EQ(0.0, rew[0]);
+    EXPECT_EQ(2ul, sim.getChoices().size());
+}