diff --git a/src/storage/prism.cpp b/src/storage/prism.cpp index 0439b81..e2eb9af 100644 --- a/src/storage/prism.cpp +++ b/src/storage/prism.cpp @@ -1,12 +1,14 @@ #include "prism.h" #include #include +#include #include "src/helpers.h" #include #include #include #include #include +#include #include #include "storm/exceptions/NotSupportedException.h" #include @@ -262,22 +264,98 @@ class StateGenerator { return generator.satisfies(expression); } - choice_list_type expand() { + storm::generator::StateBehavior expandBehavior() { if (!hasComputedInitialStates) { STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "Initial state not initialized"); } + return generator.expand(stateToIdCallback); + } + + choice_list_type expand() { + auto behavior = expandBehavior(); choice_list_type choices_result; - auto behavior = generator.expand(stateToIdCallback); - for (auto& choice : behavior.getChoices()) { + for (auto choice : behavior.getChoices()) { choices_result.push_back(GeneratorChoice(choice)); } return choices_result; } + bool isTerminal() { + if (!hasComputedInitialStates) { + STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, + "Initial state not initialized"); + } + choice_list_type choices_result; + auto behavior = generator.expand(stateToIdCallback); + return behavior.getChoices().empty(); + } + }; +std::map> simulate(storm::prism::Program const& program, uint64_t totalSamples, uint64_t maxSteps) { + using StateType = uint32_t; + using ValueType = double; + + StateGenerator generator(program); + std::map> result; + std::set visited; + auto goal = program.getLabelExpression("goal"); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + + const auto addSample = [&result](StateType state, bool isGoal) { + auto hitsVisits = result.count(state) != 0 ? result[state] : std::make_pair(0, 0); + if (isGoal) hitsVisits.first++; + hitsVisits.second++; + result[state] = hitsVisits; + }; + + const auto sampleBehavior = [&gen, &dis](std::vector> const& choices) -> StateType { + ValueType rnd = dis(gen); + STORM_LOG_THROW(choices.size() == 1, storm::exceptions::InvalidStateException, "nondeterminism"); + auto choice = choices[0]; + for (auto entry : choice) { + if (rnd <= entry.second) { + return entry.first; + } + rnd -= entry.second; + } + STORM_LOG_THROW(false, storm::exceptions::InvalidStateException, "unreachable"); + }; + + for (unsigned int i = 0; i <= totalSamples; i++) { + StateType state = generator.loadInitialState(); + uint64_t steps = 0; + bool hitGoal = false; + visited.clear(); + + while (steps <= maxSteps) { + steps++; + visited.insert(state); + + if (generator.satisfies(goal)) { + addSample(state, true); + break; + } + + auto behavior = generator.expandBehavior(); + auto choices = behavior.getChoices(); + if (choices.empty()) { + addSample(state, false); + break; + } + + state = sampleBehavior(choices); + } + } + + return result; +} + template void define_stateGeneration(py::module& m) { py::class_> valuation_mapping(m, "ValuationMapping", "A valuation mapping for a state consists of a mapping from variable to value for each of the three types."); @@ -333,4 +411,6 @@ void define_stateGeneration(py::module& m) { :rtype: [GeneratorChoice] )doc"); + + m.def("simulate", &simulate); }