Browse Source

fixes in parsing, support for POMDPs in DRN

tempestpy_adaptions
Sebastian Junges 7 years ago
parent
commit
fb8dd88314
  1. 2
      src/storm/models/ModelType.cpp
  2. 10
      src/storm/models/sparse/Pomdp.cpp
  3. 5
      src/storm/models/sparse/Pomdp.h
  4. 18
      src/storm/parser/DirectEncodingParser.cpp
  5. 1
      src/storm/storage/prism/Program.cpp
  6. 5
      src/storm/utility/DirectEncodingExporter.cpp

2
src/storm/models/ModelType.cpp

@ -17,7 +17,7 @@ namespace storm {
return ModelType::MarkovAutomaton; return ModelType::MarkovAutomaton;
} else if (type == "S2PG") { } else if (type == "S2PG") {
return ModelType::S2pg; return ModelType::S2pg;
} else if (type == "Pomdp") {
} else if (type == "POMDP") {
return ModelType::Pomdp; return ModelType::Pomdp;
} else { } else {
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Type " << type << "not known."); STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Type " << type << "not known.");

10
src/storm/models/sparse/Pomdp.cpp

@ -44,6 +44,16 @@ namespace storm {
// In debug mode, ensure that every observability is used. // In debug mode, ensure that every observability is used.
} }
template<typename ValueType, typename RewardModelType>
uint32_t Pomdp<ValueType, RewardModelType>::getObservation(uint64_t state) const {
return observations.at(state);
}
template<typename ValueType, typename RewardModelType>
uint64_t Pomdp<ValueType, RewardModelType>::getNrObservations() const {
return nrObservations;
}
template class Pomdp<double>; template class Pomdp<double>;
template class Pomdp<storm::RationalNumber>; template class Pomdp<storm::RationalNumber>;

5
src/storm/models/sparse/Pomdp.h

@ -55,6 +55,11 @@ namespace storm {
virtual void printModelInformationToStream(std::ostream& out) const override; virtual void printModelInformationToStream(std::ostream& out) const override;
uint32_t getObservation(uint64_t state) const;
uint64_t getNrObservations() const;
protected: protected:
// TODO: consider a bitvector based presentation (depending on our needs). // TODO: consider a bitvector based presentation (depending on our needs).

18
src/storm/parser/DirectEncodingParser.cpp

@ -109,10 +109,11 @@ namespace storm {
std::shared_ptr<storm::storage::sparse::ModelComponents<ValueType, RewardModelType>> DirectEncodingParser<ValueType, RewardModelType>::parseStates(std::istream& file, storm::models::ModelType type, size_t stateSize, ValueParser<ValueType> const& valueParser) { std::shared_ptr<storm::storage::sparse::ModelComponents<ValueType, RewardModelType>> DirectEncodingParser<ValueType, RewardModelType>::parseStates(std::istream& file, storm::models::ModelType type, size_t stateSize, ValueParser<ValueType> const& valueParser) {
// Initialize // Initialize
auto modelComponents = std::make_shared<storm::storage::sparse::ModelComponents<ValueType, RewardModelType>>(); auto modelComponents = std::make_shared<storm::storage::sparse::ModelComponents<ValueType, RewardModelType>>();
bool nonDeterministic = (type == storm::models::ModelType::Mdp || type == storm::models::ModelType::MarkovAutomaton);
bool nonDeterministic = (type == storm::models::ModelType::Mdp || type == storm::models::ModelType::MarkovAutomaton || type == storm::models::ModelType::Pomdp);
storm::storage::SparseMatrixBuilder<ValueType> builder = storm::storage::SparseMatrixBuilder<ValueType>(0, 0, 0, false, nonDeterministic, 0); storm::storage::SparseMatrixBuilder<ValueType> builder = storm::storage::SparseMatrixBuilder<ValueType>(0, 0, 0, false, nonDeterministic, 0);
modelComponents->stateLabeling = storm::models::sparse::StateLabeling(stateSize); modelComponents->stateLabeling = storm::models::sparse::StateLabeling(stateSize);
modelComponents->observabilityClasses = std::vector<uint32_t>();
modelComponents->observabilityClasses->resize(stateSize);
// We parse rates for continuous time models. // We parse rates for continuous time models.
if (type == storm::models::ModelType::Ctmc) { if (type == storm::models::ModelType::Ctmc) {
modelComponents->rateTransitions = true; modelComponents->rateTransitions = true;
@ -152,6 +153,19 @@ namespace storm {
STORM_LOG_WARN("Rewards were not imported"); STORM_LOG_WARN("Rewards were not imported");
line = line.substr(posEndReward+1); line = line.substr(posEndReward+1);
} }
if (type == storm::models::ModelType::Pomdp) {
if (boost::starts_with(line, "{")) {
size_t posEndObservation = line.find("}");
std::string observation = line.substr(1, posEndObservation-1);
STORM_LOG_TRACE("State observation " << observations);
modelComponents->observabilityClasses.get()[state] = std::stoi(observation);
line = line.substr(posEndObservation+1);
} else {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Expected an observation for state " << state << ".");
}
}
// Check for labels // Check for labels
std::vector<std::string> labels; std::vector<std::string> labels;
boost::split(labels, line, boost::is_any_of(" ")); boost::split(labels, line, boost::is_any_of(" "));

1
src/storm/storage/prism/Program.cpp

@ -1684,6 +1684,7 @@ namespace storm {
case Program::ModelType::MDP: out << "mdp"; break; case Program::ModelType::MDP: out << "mdp"; break;
case Program::ModelType::CTMDP: out << "ctmdp"; break; case Program::ModelType::CTMDP: out << "ctmdp"; break;
case Program::ModelType::MA: out << "ma"; break; case Program::ModelType::MA: out << "ma"; break;
case Program::ModelType::POMDP: out << "pomdp"; break;
} }
return out; return out;
} }

5
src/storm/utility/DirectEncodingExporter.cpp

@ -8,6 +8,7 @@
#include "storm/models/sparse/Mdp.h" #include "storm/models/sparse/Mdp.h"
#include "storm/models/sparse/Ctmc.h" #include "storm/models/sparse/Ctmc.h"
#include "storm/models/sparse/MarkovAutomaton.h" #include "storm/models/sparse/MarkovAutomaton.h"
#include "storm/models/sparse/Pomdp.h"
#include "storm/models/sparse/StandardRewardModel.h" #include "storm/models/sparse/StandardRewardModel.h"
@ -75,6 +76,10 @@ namespace storm {
os << "]"; os << "]";
} }
if (sparseModel->getType() == storm::models::ModelType::Pomdp) {
os << " {" << sparseModel->template as<storm::models::sparse::Pomdp<ValueType>>()->getObservation(group) << "}";
}
// Write labels // Write labels
for(auto const& label : sparseModel->getStateLabeling().getLabelsOfState(group)) { for(auto const& label : sparseModel->getStateLabeling().getLabelsOfState(group)) {
os << " " << label; os << " " << label;

Loading…
Cancel
Save