diff --git a/src/storm/simulator/PrismProgramSimulator.cpp b/src/storm/simulator/PrismProgramSimulator.cpp index 879848d1e..40064c295 100644 --- a/src/storm/simulator/PrismProgramSimulator.cpp +++ b/src/storm/simulator/PrismProgramSimulator.cpp @@ -125,6 +125,15 @@ namespace storm { return explore(); } + template + std::vector DiscreteTimePrismProgramSimulator::getRewardNames() const { + std::vector names; + for (uint64_t i = 0; i < stateGenerator->getNumberOfRewardModels(); ++i) { + names.push_back(stateGenerator->getRewardModelInformation(i).getName()); + } + return names; + } + template uint32_t DiscreteTimePrismProgramSimulator::getOrAddStateIndex(generator::CompressedState const& state) { uint32_t newIndex = static_cast(stateToId.size()); diff --git a/src/storm/simulator/PrismProgramSimulator.h b/src/storm/simulator/PrismProgramSimulator.h index daec0a81e..8b2207a1c 100644 --- a/src/storm/simulator/PrismProgramSimulator.h +++ b/src/storm/simulator/PrismProgramSimulator.h @@ -75,6 +75,11 @@ namespace storm { bool resetToState(generator::CompressedState const& compressedState); bool resetToState(expressions::SimpleValuation const& valuationState); + + /** + * The names of the rewards that are returned. + */ + std::vector getRewardNames() const; protected: bool explore(); void clearStateCaches(); diff --git a/src/test/storm/simulator/PrismProgramSimulator.cpp b/src/test/storm/simulator/PrismProgramSimulator.cpp index 21a8439c6..56ce2d25f 100644 --- a/src/test/storm/simulator/PrismProgramSimulator.cpp +++ b/src/test/storm/simulator/PrismProgramSimulator.cpp @@ -10,6 +10,7 @@ TEST(PrismProgramSimulatorTest, KnuthYaoDieTest) { options.setBuildAllRewardModels(); storm::simulator::DiscreteTimePrismProgramSimulator sim(program, options); + EXPECT_EQ("coin_flips", sim.getRewardNames()[0]); auto rew = sim.getLastRewards(); rew = sim.getLastRewards(); EXPECT_EQ(1ul, rew.size());