Browse Source

fixes in parsing, support for POMDPs in DRN

tempestpy_adaptions
Sebastian Junges 7 years ago
parent
commit
fb8dd88314
  1. 2
      src/storm/models/ModelType.cpp
  2. 10
      src/storm/models/sparse/Pomdp.cpp
  3. 5
      src/storm/models/sparse/Pomdp.h
  4. 18
      src/storm/parser/DirectEncodingParser.cpp
  5. 1
      src/storm/storage/prism/Program.cpp
  6. 5
      src/storm/utility/DirectEncodingExporter.cpp

2
src/storm/models/ModelType.cpp

@ -17,7 +17,7 @@ namespace storm {
return ModelType::MarkovAutomaton;
} else if (type == "S2PG") {
return ModelType::S2pg;
} else if (type == "Pomdp") {
} else if (type == "POMDP") {
return ModelType::Pomdp;
} else {
STORM_LOG_THROW(false, storm::exceptions::InvalidTypeException, "Type " << type << "not known.");

10
src/storm/models/sparse/Pomdp.cpp

@ -44,6 +44,16 @@ namespace storm {
// In debug mode, ensure that every observability is used.
}
template<typename ValueType, typename RewardModelType>
uint32_t Pomdp<ValueType, RewardModelType>::getObservation(uint64_t state) const {
return observations.at(state);
}
template<typename ValueType, typename RewardModelType>
uint64_t Pomdp<ValueType, RewardModelType>::getNrObservations() const {
return nrObservations;
}
template class Pomdp<double>;
template class Pomdp<storm::RationalNumber>;

5
src/storm/models/sparse/Pomdp.h

@ -55,6 +55,11 @@ namespace storm {
virtual void printModelInformationToStream(std::ostream& out) const override;
uint32_t getObservation(uint64_t state) const;
uint64_t getNrObservations() const;
protected:
// TODO: consider a bitvector based presentation (depending on our needs).

18
src/storm/parser/DirectEncodingParser.cpp

@ -109,10 +109,11 @@ namespace storm {
std::shared_ptr<storm::storage::sparse::ModelComponents<ValueType, RewardModelType>> DirectEncodingParser<ValueType, RewardModelType>::parseStates(std::istream& file, storm::models::ModelType type, size_t stateSize, ValueParser<ValueType> const& valueParser) {
// Initialize
auto modelComponents = std::make_shared<storm::storage::sparse::ModelComponents<ValueType, RewardModelType>>();
bool nonDeterministic = (type == storm::models::ModelType::Mdp || type == storm::models::ModelType::MarkovAutomaton);
bool nonDeterministic = (type == storm::models::ModelType::Mdp || type == storm::models::ModelType::MarkovAutomaton || type == storm::models::ModelType::Pomdp);
storm::storage::SparseMatrixBuilder<ValueType> builder = storm::storage::SparseMatrixBuilder<ValueType>(0, 0, 0, false, nonDeterministic, 0);
modelComponents->stateLabeling = storm::models::sparse::StateLabeling(stateSize);
modelComponents->observabilityClasses = std::vector<uint32_t>();
modelComponents->observabilityClasses->resize(stateSize);
// We parse rates for continuous time models.
if (type == storm::models::ModelType::Ctmc) {
modelComponents->rateTransitions = true;
@ -152,6 +153,19 @@ namespace storm {
STORM_LOG_WARN("Rewards were not imported");
line = line.substr(posEndReward+1);
}
if (type == storm::models::ModelType::Pomdp) {
if (boost::starts_with(line, "{")) {
size_t posEndObservation = line.find("}");
std::string observation = line.substr(1, posEndObservation-1);
STORM_LOG_TRACE("State observation " << observations);
modelComponents->observabilityClasses.get()[state] = std::stoi(observation);
line = line.substr(posEndObservation+1);
} else {
STORM_LOG_THROW(false, storm::exceptions::WrongFormatException, "Expected an observation for state " << state << ".");
}
}
// Check for labels
std::vector<std::string> labels;
boost::split(labels, line, boost::is_any_of(" "));

1
src/storm/storage/prism/Program.cpp

@ -1684,6 +1684,7 @@ namespace storm {
case Program::ModelType::MDP: out << "mdp"; break;
case Program::ModelType::CTMDP: out << "ctmdp"; break;
case Program::ModelType::MA: out << "ma"; break;
case Program::ModelType::POMDP: out << "pomdp"; break;
}
return out;
}

5
src/storm/utility/DirectEncodingExporter.cpp

@ -8,6 +8,7 @@
#include "storm/models/sparse/Mdp.h"
#include "storm/models/sparse/Ctmc.h"
#include "storm/models/sparse/MarkovAutomaton.h"
#include "storm/models/sparse/Pomdp.h"
#include "storm/models/sparse/StandardRewardModel.h"
@ -75,6 +76,10 @@ namespace storm {
os << "]";
}
if (sparseModel->getType() == storm::models::ModelType::Pomdp) {
os << " {" << sparseModel->template as<storm::models::sparse::Pomdp<ValueType>>()->getObservation(group) << "}";
}
// Write labels
for(auto const& label : sparseModel->getStateLabeling().getLabelsOfState(group)) {
os << " " << label;

Loading…
Cancel
Save