Browse Source

DirectEncodingParser: Added support for Action-based rewards.

tempestpy_adaptions
Tim Quatmann 6 years ago
parent
commit
42b7865e7e
  1. 62
      src/storm-parsers/parser/DirectEncodingParser.cpp
  2. 14
      src/test/storm/parser/DirectEncodingParserTest.cpp

62
src/storm-parsers/parser/DirectEncodingParser.cpp

@ -114,6 +114,7 @@ namespace storm {
modelComponents->observabilityClasses = std::vector<uint32_t>(); modelComponents->observabilityClasses = std::vector<uint32_t>();
modelComponents->observabilityClasses->resize(stateSize); modelComponents->observabilityClasses->resize(stateSize);
std::vector<std::vector<ValueType>> stateRewards; std::vector<std::vector<ValueType>> stateRewards;
std::vector<std::vector<ValueType>> actionRewards;
if (continuousTime) { if (continuousTime) {
modelComponents->exitRates = std::vector<ValueType>(stateSize); modelComponents->exitRates = std::vector<ValueType>(stateSize);
if (type == storm::models::ModelType::MarkovAutomaton) { if (type == storm::models::ModelType::MarkovAutomaton) {
@ -128,6 +129,7 @@ namespace storm {
// Iterate over all lines // Iterate over all lines
std::string line; std::string line;
size_t row = 0; size_t row = 0;
size_t firstRowOfState = 0;
size_t state = 0; size_t state = 0;
bool firstState = true; bool firstState = true;
bool firstActionForState = true; bool firstActionForState = true;
@ -160,6 +162,7 @@ namespace storm {
STORM_LOG_TRACE("new Row Group starts at " << row << "."); STORM_LOG_TRACE("new Row Group starts at " << row << ".");
builder.newRowGroup(row); builder.newRowGroup(row);
} }
firstRowOfState = row;
if (type == storm::models::ModelType::Ctmc || type == storm::models::ModelType::MarkovAutomaton) { if (type == storm::models::ModelType::Ctmc || type == storm::models::ModelType::MarkovAutomaton) {
// Parse exit rate for CTMC or MA // Parse exit rate for CTMC or MA
@ -190,14 +193,19 @@ namespace storm {
std::vector<std::string> rewards; std::vector<std::string> rewards;
boost::split(rewards, rewardsStr, boost::is_any_of(",")); boost::split(rewards, rewardsStr, boost::is_any_of(","));
if (stateRewards.size() < rewards.size()) { if (stateRewards.size() < rewards.size()) {
stateRewards.resize(rewards.size(), std::vector<ValueType>(stateSize, storm::utility::zero<ValueType>()));
stateRewards.resize(rewards.size());
} }
auto stateRewardsIt = stateRewards.begin(); auto stateRewardsIt = stateRewards.begin();
for (auto const& rew : rewards) { for (auto const& rew : rewards) {
(*stateRewardsIt)[state] = valueParser.parseValue(rew);
auto rewardValue = valueParser.parseValue(rew);
if (!storm::utility::isZero(rewardValue)) {
if (stateRewardsIt->empty()) {
stateRewardsIt->resize(stateSize, storm::utility::zero<ValueType>());
}
(*stateRewardsIt)[row] = std::move(rewardValue);
}
++stateRewardsIt; ++stateRewardsIt;
} }
line = line.substr(posEndReward+1); line = line.substr(posEndReward+1);
} }
@ -258,15 +266,39 @@ namespace storm {
} }
STORM_LOG_TRACE("New action: " << row); STORM_LOG_TRACE("New action: " << row);
line = line.substr(8); //Remove "\taction " line = line.substr(8); //Remove "\taction "
std::string curString = line;
size_t posEnd = line.find(" ");
if (posEnd != std::string::npos) {
curString = line.substr(0, posEnd);
line = line.substr(posEnd+1);
} else {
line = "";
}
size_t parsedId = parseNumber<size_t>(curString);
STORM_LOG_ASSERT(row == firstRowOfState + parsedId, "Action ids do not correspond.");
// Check for rewards // Check for rewards
if (boost::starts_with(line, "[")) { if (boost::starts_with(line, "[")) {
// Rewards found // Rewards found
size_t posEndReward = line.find(']'); size_t posEndReward = line.find(']');
STORM_LOG_THROW(posEndReward != std::string::npos, storm::exceptions::WrongFormatException, "] missing."); STORM_LOG_THROW(posEndReward != std::string::npos, storm::exceptions::WrongFormatException, "] missing.");
std::string rewards = line.substr(1, posEndReward-1);
STORM_LOG_TRACE("Transition rewards: " << rewards);
STORM_LOG_WARN("Transition rewards [" << rewards << "] not parsed.");
// TODO save rewards
std::string rewardsStr = line.substr(1, posEndReward-1);
STORM_LOG_TRACE("Action rewards: " << rewardsStr);
std::vector<std::string> rewards;
boost::split(rewards, rewardsStr, boost::is_any_of(","));
if (actionRewards.size() < rewards.size()) {
actionRewards.resize(rewards.size());
}
auto actionRewardsIt = actionRewards.begin();
for (auto const& rew : rewards) {
auto rewardValue = valueParser.parseValue(rew);
if (!storm::utility::isZero(rewardValue)) {
if (actionRewardsIt->size() <= row) {
actionRewardsIt->resize(row + 1, storm::utility::zero<ValueType>());
}
(*actionRewardsIt)[row] = std::move(rewardValue);
}
++actionRewardsIt;
}
line = line.substr(posEndReward+1); line = line.substr(posEndReward+1);
} }
// TODO import choice labeling when the export works // TODO import choice labeling when the export works
@ -286,17 +318,27 @@ namespace storm {
STORM_LOG_TRACE("Finished parsing"); STORM_LOG_TRACE("Finished parsing");
modelComponents->transitionMatrix = builder.build(row + 1, stateSize, nonDeterministic ? stateSize : 0); modelComponents->transitionMatrix = builder.build(row + 1, stateSize, nonDeterministic ? stateSize : 0);
STORM_LOG_TRACE("Built matrix");
for (uint64_t i = 0; i < stateRewards.size(); ++i) {
uint64_t numRewardModels = std::max(stateRewards.size(), actionRewards.size());
for (uint64_t i = 0; i < numRewardModels; ++i) {
std::string rewardModelName; std::string rewardModelName;
if (rewardModelNames.size() <= i) { if (rewardModelNames.size() <= i) {
rewardModelName = "rew" + std::to_string(i); rewardModelName = "rew" + std::to_string(i);
} else { } else {
rewardModelName = rewardModelNames[i]; rewardModelName = rewardModelNames[i];
} }
modelComponents->rewardModels.emplace(rewardModelName, storm::models::sparse::StandardRewardModel<ValueType>(std::move(stateRewards[i])));
boost::optional<std::vector<ValueType>> stateRewardVector, actionRewardVector;
if (!stateRewards[i].empty()) {
stateRewardVector = std::move(stateRewards[i]);
}
if (!actionRewards[i].empty()) {
actionRewards[i].resize(row + 1, storm::utility::zero<ValueType>());
actionRewardVector = std::move(actionRewards[i]);
}
modelComponents->rewardModels.emplace(rewardModelName, storm::models::sparse::StandardRewardModel<ValueType>(std::move(stateRewardVector), std::move(actionRewardVector)));
} }
STORM_LOG_TRACE("Built matrix");
STORM_LOG_TRACE("Built reward models");
return modelComponents; return modelComponents;
} }

14
src/test/storm/parser/DirectEncodingParserTest.cpp

@ -35,6 +35,11 @@ TEST(DirectEncodingParserTest, MdpParsing) {
ASSERT_EQ(5ul, modelPtr->getStates("six").getNumberOfSetBits()); ASSERT_EQ(5ul, modelPtr->getStates("six").getNumberOfSetBits());
ASSERT_TRUE(modelPtr->hasLabel("eleven")); ASSERT_TRUE(modelPtr->hasLabel("eleven"));
ASSERT_EQ(2ul, modelPtr->getStates("eleven").getNumberOfSetBits()); ASSERT_EQ(2ul, modelPtr->getStates("eleven").getNumberOfSetBits());
ASSERT_EQ(1ul, modelPtr->getNumberOfRewardModels());
ASSERT_TRUE(modelPtr->hasRewardModel("coinflips"));
ASSERT_TRUE(!modelPtr->getRewardModel("coinflips").hasStateRewards());
ASSERT_TRUE(modelPtr->getRewardModel("coinflips").hasStateActionRewards());
ASSERT_TRUE(!modelPtr->getRewardModel("coinflips").isAllZero());
} }
TEST(DirectEncodingParserTest, CtmcParsing) { TEST(DirectEncodingParserTest, CtmcParsing) {
@ -50,6 +55,11 @@ TEST(DirectEncodingParserTest, CtmcParsing) {
ASSERT_EQ(64ul, modelPtr->getStates("premium").getNumberOfSetBits()); ASSERT_EQ(64ul, modelPtr->getStates("premium").getNumberOfSetBits());
ASSERT_TRUE(modelPtr->hasLabel("minimum")); ASSERT_TRUE(modelPtr->hasLabel("minimum"));
ASSERT_EQ(132ul, modelPtr->getStates("minimum").getNumberOfSetBits()); ASSERT_EQ(132ul, modelPtr->getStates("minimum").getNumberOfSetBits());
ASSERT_EQ(1ul, modelPtr->getNumberOfRewardModels());
ASSERT_TRUE(modelPtr->hasRewardModel("num_repairs"));
ASSERT_TRUE(!modelPtr->getRewardModel("num_repairs").hasStateRewards());
ASSERT_TRUE(modelPtr->getRewardModel("num_repairs").hasStateActionRewards());
ASSERT_TRUE(!modelPtr->getRewardModel("num_repairs").isAllZero());
} }
TEST(DirectEncodingParserTest, MarkovAutomatonParsing) { TEST(DirectEncodingParserTest, MarkovAutomatonParsing) {
@ -63,7 +73,11 @@ TEST(DirectEncodingParserTest, MarkovAutomatonParsing) {
ASSERT_EQ(19ul, ma->getNumberOfChoices()); ASSERT_EQ(19ul, ma->getNumberOfChoices());
ASSERT_EQ(10ul, ma->getMarkovianStates().getNumberOfSetBits()); ASSERT_EQ(10ul, ma->getMarkovianStates().getNumberOfSetBits());
ASSERT_EQ(5, ma->getMaximalExitRate()); ASSERT_EQ(5, ma->getMaximalExitRate());
ASSERT_EQ(1ul, ma->getNumberOfRewardModels());
ASSERT_TRUE(ma->hasRewardModel("avg_waiting_time")); ASSERT_TRUE(ma->hasRewardModel("avg_waiting_time"));
ASSERT_TRUE(ma->getRewardModel("avg_waiting_time").hasStateRewards());
ASSERT_TRUE(!ma->getRewardModel("avg_waiting_time").hasStateActionRewards());
ASSERT_TRUE(!ma->getRewardModel("avg_waiting_time").isAllZero());
ASSERT_TRUE(modelPtr->hasLabel("init")); ASSERT_TRUE(modelPtr->hasLabel("init"));
ASSERT_EQ(1ul, modelPtr->getInitialStates().getNumberOfSetBits()); ASSERT_EQ(1ul, modelPtr->getInitialStates().getNumberOfSetBits());
ASSERT_TRUE(modelPtr->hasLabel("one_job_finished")); ASSERT_TRUE(modelPtr->hasLabel("one_job_finished"));

Loading…
Cancel
Save