diff --git a/src/storm/models/ModelType.cpp b/src/storm/models/ModelType.cpp index 956c9baed..d556448bb 100644 --- a/src/storm/models/ModelType.cpp +++ b/src/storm/models/ModelType.cpp @@ -17,7 +17,7 @@ namespace storm { return ModelType::MarkovAutomaton; } else if (type == "S2PG") { return ModelType::S2pg; - } else if (type == "Pomdp") { + } else if (type == "POMDP") { return ModelType::Pomdp; } else { STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Type " << type << "not known."); diff --git a/src/storm/models/sparse/Pomdp.cpp b/src/storm/models/sparse/Pomdp.cpp index b49e72cb4..4baa24751 100644 --- a/src/storm/models/sparse/Pomdp.cpp +++ b/src/storm/models/sparse/Pomdp.cpp @@ -44,6 +44,16 @@ namespace storm { // In debug mode, ensure that every observability is used. } + template + uint32_t Pomdp::getObservation(uint64_t state) const { + return observations.at(state); + } + + template + uint64_t Pomdp::getNrObservations() const { + return nrObservations; + } + template class Pomdp; template class Pomdp; diff --git a/src/storm/models/sparse/Pomdp.h b/src/storm/models/sparse/Pomdp.h index 7be220743..4d4d54747 100644 --- a/src/storm/models/sparse/Pomdp.h +++ b/src/storm/models/sparse/Pomdp.h @@ -55,6 +55,11 @@ namespace storm { virtual void printModelInformationToStream(std::ostream& out) const override; + uint32_t getObservation(uint64_t state) const; + + uint64_t getNrObservations() const; + + protected: // TODO: consider a bitvector based presentation (depending on our needs). diff --git a/src/storm/parser/DirectEncodingParser.cpp b/src/storm/parser/DirectEncodingParser.cpp index c18d1965c..e6743e941 100644 --- a/src/storm/parser/DirectEncodingParser.cpp +++ b/src/storm/parser/DirectEncodingParser.cpp @@ -109,10 +109,11 @@ namespace storm { std::shared_ptr> DirectEncodingParser::parseStates(std::istream& file, storm::models::ModelType type, size_t stateSize, ValueParser const& valueParser) { // Initialize auto modelComponents = std::make_shared>(); - bool nonDeterministic = (type == storm::models::ModelType::Mdp || type == storm::models::ModelType::MarkovAutomaton); + bool nonDeterministic = (type == storm::models::ModelType::Mdp || type == storm::models::ModelType::MarkovAutomaton || type == storm::models::ModelType::Pomdp); storm::storage::SparseMatrixBuilder builder = storm::storage::SparseMatrixBuilder(0, 0, 0, false, nonDeterministic, 0); modelComponents->stateLabeling = storm::models::sparse::StateLabeling(stateSize); - + modelComponents->observabilityClasses = std::vector(); + modelComponents->observabilityClasses->resize(stateSize); // We parse rates for continuous time models. if (type == storm::models::ModelType::Ctmc) { modelComponents->rateTransitions = true; @@ -152,6 +153,19 @@ namespace storm { STORM_LOG_WARN("Rewards were not imported"); line = line.substr(posEndReward+1); } + + if (type == storm::models::ModelType::Pomdp) { + if (boost::starts_with(line, "{")) { + size_t posEndObservation = line.find("}"); + std::string observation = line.substr(1, posEndObservation-1); + STORM_LOG_TRACE("State observation " << observations); + modelComponents->observabilityClasses.get()[state] = std::stoi(observation); + line = line.substr(posEndObservation+1); + } else { + STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Expected an observation for state " << state << "."); + } + } + // Check for labels std::vector labels; boost::split(labels, line, boost::is_any_of(" ")); diff --git a/src/storm/storage/prism/Program.cpp b/src/storm/storage/prism/Program.cpp index fce0c6682..d8dece894 100644 --- a/src/storm/storage/prism/Program.cpp +++ b/src/storm/storage/prism/Program.cpp @@ -1684,6 +1684,7 @@ namespace storm { case Program::ModelType::MDP: out << "mdp"; break; case Program::ModelType::CTMDP: out << "ctmdp"; break; case Program::ModelType::MA: out << "ma"; break; + case Program::ModelType::POMDP: out << "pomdp"; break; } return out; } diff --git a/src/storm/utility/DirectEncodingExporter.cpp b/src/storm/utility/DirectEncodingExporter.cpp index 8c93b3b7f..ff5301045 100644 --- a/src/storm/utility/DirectEncodingExporter.cpp +++ b/src/storm/utility/DirectEncodingExporter.cpp @@ -8,6 +8,7 @@ #include "storm/models/sparse/Mdp.h" #include "storm/models/sparse/Ctmc.h" #include "storm/models/sparse/MarkovAutomaton.h" +#include "storm/models/sparse/Pomdp.h" #include "storm/models/sparse/StandardRewardModel.h" @@ -75,6 +76,10 @@ namespace storm { os << "]"; } + if (sparseModel->getType() == storm::models::ModelType::Pomdp) { + os << " {" << sparseModel->template as>()->getObservation(group) << "}"; + } + // Write labels for(auto const& label : sparseModel->getStateLabeling().getLabelsOfState(group)) { os << " " << label;