From 42b7865e7ea2be16f04693701d5cbb20d0812ed6 Mon Sep 17 00:00:00 2001
From: Tim Quatmann
Date: Thu, 1 Aug 2019 16:14:26 +0200
Subject: [PATCH] DirectEncodingParser: Added support for Action-based rewards.

---
Reviewer notes (everything between the "---" line and the first "diff --git"
is outside the commit message):
  * State rewards are stored at index `state` again. The per-model vector is
    sized to stateSize, while `row` is the choice (matrix row) index and may
    be >= stateSize, so `[row]` could write out of bounds.
  * Reward-model assembly now checks `i < stateRewards.size()` and
    `i < actionRewards.size()` before indexing, because the loop bound is the
    larger of the two sizes.
  * NOTE(review): this copy of the patch had lost its line breaks and all
    angle-bracket template arguments. The layout was rebuilt from the hunk
    headers and the arguments re-inserted (<ValueType>, <std::string>,
    <size_t>, <uint32_t>) -- confirm against the upstream commit before
    applying. The "index" blob hashes below describe the upstream post-image,
    not this corrected one.

 .../parser/DirectEncodingParser.cpp           | 62 ++++++++++++++++---
 .../storm/parser/DirectEncodingParserTest.cpp | 14 +++++
 2 files changed, 66 insertions(+), 10 deletions(-)

diff --git a/src/storm-parsers/parser/DirectEncodingParser.cpp b/src/storm-parsers/parser/DirectEncodingParser.cpp
index a34bbdb35..dea8a60fb 100644
--- a/src/storm-parsers/parser/DirectEncodingParser.cpp
+++ b/src/storm-parsers/parser/DirectEncodingParser.cpp
@@ -114,6 +114,7 @@ namespace storm {
             modelComponents->observabilityClasses = std::vector<uint32_t>();
             modelComponents->observabilityClasses->resize(stateSize);
             std::vector<std::vector<ValueType>> stateRewards;
+            std::vector<std::vector<ValueType>> actionRewards;
             if (continuousTime) {
                 modelComponents->exitRates = std::vector<ValueType>(stateSize);
                 if (type == storm::models::ModelType::MarkovAutomaton) {
@@ -128,6 +129,7 @@ namespace storm {
             // Iterate over all lines
             std::string line;
             size_t row = 0;
+            size_t firstRowOfState = 0;
             size_t state = 0;
             bool firstState = true;
             bool firstActionForState = true;
@@ -160,6 +162,7 @@ namespace storm {
                         STORM_LOG_TRACE("new Row Group starts at " << row << ".");
                         builder.newRowGroup(row);
                     }
+                    firstRowOfState = row;
 
                     if (type == storm::models::ModelType::Ctmc || type == storm::models::ModelType::MarkovAutomaton) {
                         // Parse exit rate for CTMC or MA
@@ -190,14 +193,19 @@ namespace storm {
                         std::vector<std::string> rewards;
                         boost::split(rewards, rewardsStr, boost::is_any_of(","));
                         if (stateRewards.size() < rewards.size()) {
-                            stateRewards.resize(rewards.size(), std::vector<ValueType>(stateSize, storm::utility::zero<ValueType>()));
+                            stateRewards.resize(rewards.size());
                         }
                         auto stateRewardsIt = stateRewards.begin();
                         for (auto const& rew : rewards) {
-                            (*stateRewardsIt)[state] = valueParser.parseValue(rew);
+                            auto rewardValue = valueParser.parseValue(rew);
+                            if (!storm::utility::isZero(rewardValue)) {
+                                if (stateRewardsIt->empty()) {
+                                    stateRewardsIt->resize(stateSize, storm::utility::zero<ValueType>());
+                                }
+                                (*stateRewardsIt)[state] = std::move(rewardValue);
+                            }
                             ++stateRewardsIt;
                         }
-
                         line = line.substr(posEndReward+1);
                     }
 
@@ -258,15 +266,39 @@ namespace storm {
                     }
                     STORM_LOG_TRACE("New action: " << row);
                     line = line.substr(8); //Remove "\taction "
+                    std::string curString = line;
+                    size_t posEnd = line.find(" ");
+                    if (posEnd != std::string::npos) {
+                        curString = line.substr(0, posEnd);
+                        line = line.substr(posEnd+1);
+                    } else {
+                        line = "";
+                    }
+                    size_t parsedId = parseNumber<size_t>(curString);
+                    STORM_LOG_ASSERT(row == firstRowOfState + parsedId, "Action ids do not correspond.");
                     // Check for rewards
                     if (boost::starts_with(line, "[")) {
                         // Rewards found
                         size_t posEndReward = line.find(']');
                         STORM_LOG_THROW(posEndReward != std::string::npos, storm::exceptions::WrongFormatException, "] missing.");
-                        std::string rewards = line.substr(1, posEndReward-1);
-                        STORM_LOG_TRACE("Transition rewards: " << rewards);
-                        STORM_LOG_WARN("Transition rewards [" << rewards << "] not parsed.");
-                        // TODO save rewards
+                        std::string rewardsStr = line.substr(1, posEndReward-1);
+                        STORM_LOG_TRACE("Action rewards: " << rewardsStr);
+                        std::vector<std::string> rewards;
+                        boost::split(rewards, rewardsStr, boost::is_any_of(","));
+                        if (actionRewards.size() < rewards.size()) {
+                            actionRewards.resize(rewards.size());
+                        }
+                        auto actionRewardsIt = actionRewards.begin();
+                        for (auto const& rew : rewards) {
+                            auto rewardValue = valueParser.parseValue(rew);
+                            if (!storm::utility::isZero(rewardValue)) {
+                                if (actionRewardsIt->size() <= row) {
+                                    actionRewardsIt->resize(row + 1, storm::utility::zero<ValueType>());
+                                }
+                                (*actionRewardsIt)[row] = std::move(rewardValue);
+                            }
+                            ++actionRewardsIt;
+                        }
                         line = line.substr(posEndReward+1);
                     }
                     // TODO import choice labeling when the export works
@@ -286,17 +318,27 @@ namespace storm {
 
             STORM_LOG_TRACE("Finished parsing");
             modelComponents->transitionMatrix = builder.build(row + 1, stateSize, nonDeterministic ? stateSize : 0);
+            STORM_LOG_TRACE("Built matrix");
 
-            for (uint64_t i = 0; i < stateRewards.size(); ++i) {
+            uint64_t numRewardModels = std::max(stateRewards.size(), actionRewards.size());
+            for (uint64_t i = 0; i < numRewardModels; ++i) {
                 std::string rewardModelName;
                 if (rewardModelNames.size() <= i) {
                     rewardModelName = "rew" + std::to_string(i);
                 } else {
                     rewardModelName = rewardModelNames[i];
                 }
-                modelComponents->rewardModels.emplace(rewardModelName, storm::models::sparse::StandardRewardModel<ValueType>(std::move(stateRewards[i])));
+                boost::optional<std::vector<ValueType>> stateRewardVector, actionRewardVector;
+                if (i < stateRewards.size() && !stateRewards[i].empty()) {
+                    stateRewardVector = std::move(stateRewards[i]);
+                }
+                if (i < actionRewards.size() && !actionRewards[i].empty()) {
+                    actionRewards[i].resize(row + 1, storm::utility::zero<ValueType>());
+                    actionRewardVector = std::move(actionRewards[i]);
+                }
+                modelComponents->rewardModels.emplace(rewardModelName, storm::models::sparse::StandardRewardModel<ValueType>(std::move(stateRewardVector), std::move(actionRewardVector)));
             }
-            STORM_LOG_TRACE("Built matrix");
+            STORM_LOG_TRACE("Built reward models");
             return modelComponents;
         }
 
diff --git a/src/test/storm/parser/DirectEncodingParserTest.cpp b/src/test/storm/parser/DirectEncodingParserTest.cpp
index c5716195b..235c4070c 100644
--- a/src/test/storm/parser/DirectEncodingParserTest.cpp
+++ b/src/test/storm/parser/DirectEncodingParserTest.cpp
@@ -35,6 +35,11 @@ TEST(DirectEncodingParserTest, MdpParsing) {
     ASSERT_EQ(5ul, modelPtr->getStates("six").getNumberOfSetBits());
     ASSERT_TRUE(modelPtr->hasLabel("eleven"));
     ASSERT_EQ(2ul, modelPtr->getStates("eleven").getNumberOfSetBits());
+    ASSERT_EQ(1ul, modelPtr->getNumberOfRewardModels());
+    ASSERT_TRUE(modelPtr->hasRewardModel("coinflips"));
+    ASSERT_TRUE(!modelPtr->getRewardModel("coinflips").hasStateRewards());
+    ASSERT_TRUE(modelPtr->getRewardModel("coinflips").hasStateActionRewards());
+    ASSERT_TRUE(!modelPtr->getRewardModel("coinflips").isAllZero());
 }
 
 TEST(DirectEncodingParserTest, CtmcParsing) {
@@ -50,6 +55,11 @@ TEST(DirectEncodingParserTest, CtmcParsing) {
     ASSERT_EQ(64ul, modelPtr->getStates("premium").getNumberOfSetBits());
     ASSERT_TRUE(modelPtr->hasLabel("minimum"));
     ASSERT_EQ(132ul, modelPtr->getStates("minimum").getNumberOfSetBits());
+    ASSERT_EQ(1ul, modelPtr->getNumberOfRewardModels());
+    ASSERT_TRUE(modelPtr->hasRewardModel("num_repairs"));
+    ASSERT_TRUE(!modelPtr->getRewardModel("num_repairs").hasStateRewards());
+    ASSERT_TRUE(modelPtr->getRewardModel("num_repairs").hasStateActionRewards());
+    ASSERT_TRUE(!modelPtr->getRewardModel("num_repairs").isAllZero());
 }
 
 TEST(DirectEncodingParserTest, MarkovAutomatonParsing) {
@@ -63,7 +73,11 @@ TEST(DirectEncodingParserTest, MarkovAutomatonParsing) {
     ASSERT_EQ(19ul, ma->getNumberOfChoices());
     ASSERT_EQ(10ul, ma->getMarkovianStates().getNumberOfSetBits());
     ASSERT_EQ(5, ma->getMaximalExitRate());
+    ASSERT_EQ(1ul, ma->getNumberOfRewardModels());
     ASSERT_TRUE(ma->hasRewardModel("avg_waiting_time"));
+    ASSERT_TRUE(ma->getRewardModel("avg_waiting_time").hasStateRewards());
+    ASSERT_TRUE(!ma->getRewardModel("avg_waiting_time").hasStateActionRewards());
+    ASSERT_TRUE(!ma->getRewardModel("avg_waiting_time").isAllZero());
     ASSERT_TRUE(modelPtr->hasLabel("init"));
     ASSERT_EQ(1ul, modelPtr->getInitialStates().getNumberOfSetBits());
     ASSERT_TRUE(modelPtr->hasLabel("one_job_finished"));