Browse Source

extensions to pomdp stuff

tempestpy_adaptions
sjunges 7 years ago
parent
commit
3ac42caf7c
  1. 6
      src/storm-pomdp-cli/storm-pomdp.cpp
  2. 23
      src/storm-pomdp/analysis/UniqueObservationStates.cpp
  3. 5
      src/storm-pomdp/analysis/UniqueObservationStates.h
  4. 34
      src/storm/builder/ExplicitModelBuilder.cpp
  5. 22
      src/storm/parser/PrismParser.cpp

6
src/storm-pomdp-cli/storm-pomdp.cpp

@ -1,5 +1,6 @@
#include <storm-pomdp/analysis/UniqueObservationStates.h>
#include "storm/utility/initialize.h" #include "storm/utility/initialize.h"
#include "storm/settings/modules/GeneralSettings.h" #include "storm/settings/modules/GeneralSettings.h"
@ -97,17 +98,20 @@ int main(const int argc, const char** argv) {
storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine(); storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine();
storm::cli::SymbolicInput symbolicInput = storm::cli::parseAndPreprocessSymbolicInput(); storm::cli::SymbolicInput symbolicInput = storm::cli::parseAndPreprocessSymbolicInput();
// We should not export here if we are going to do some processing first.
auto model = storm::cli::buildPreprocessExportModelWithValueTypeAndDdlib<storm::dd::DdType::Sylvan, storm::RationalNumber>(symbolicInput, engine); auto model = storm::cli::buildPreprocessExportModelWithValueTypeAndDdlib<storm::dd::DdType::Sylvan, storm::RationalNumber>(symbolicInput, engine);
STORM_LOG_THROW(model && model->getType() == storm::models::ModelType::Pomdp, storm::exceptions::WrongFormatException, "Expected a POMDP."); STORM_LOG_THROW(model && model->getType() == storm::models::ModelType::Pomdp, storm::exceptions::WrongFormatException, "Expected a POMDP.");
// CHECK if prop maximizes, only apply in those situations // CHECK if prop maximizes, only apply in those situations
std::shared_ptr<storm::models::sparse::Pomdp<storm::RationalNumber>> pomdp = model->template as<storm::models::sparse::Pomdp<storm::RationalNumber>>(); std::shared_ptr<storm::models::sparse::Pomdp<storm::RationalNumber>> pomdp = model->template as<storm::models::sparse::Pomdp<storm::RationalNumber>>();
storm::transformer::GlobalPOMDPSelfLoopEliminator<storm::RationalNumber> selfLoopEliminator(*pomdp); storm::transformer::GlobalPOMDPSelfLoopEliminator<storm::RationalNumber> selfLoopEliminator(*pomdp);
pomdp = selfLoopEliminator.transform(); pomdp = selfLoopEliminator.transform();
storm::analysis::UniqueObservationStates<storm::RationalNumber> uniqueAnalysis(*pomdp);
std::cout << uniqueAnalysis.analyse() << std::endl;
storm::transformer::ApplyFiniteSchedulerToPomdp<storm::RationalNumber> toPMCTransformer(*pomdp);
if (pomdpSettings.isExportToParametricSet()) { if (pomdpSettings.isExportToParametricSet()) {
storm::transformer::ApplyFiniteSchedulerToPomdp<storm::RationalNumber> toPMCTransformer(*pomdp);
auto const &pmc = toPMCTransformer.transform(); auto const &pmc = toPMCTransformer.transform();
storm::analysis::ConstraintCollector<storm::RationalFunction> constraints(*pmc); storm::analysis::ConstraintCollector<storm::RationalFunction> constraints(*pmc);
pmc->printModelInformationToStream(std::cout); pmc->printModelInformationToStream(std::cout);

23
src/storm-pomdp/analysis/UniqueObservationStates.cpp

@ -3,12 +3,33 @@
namespace storm { namespace storm {
namespace analysis { namespace analysis {
typename <ValueType>
template <typename ValueType>
UniqueObservationStates<ValueType>::UniqueObservationStates(storm::models::sparse::Pomdp<ValueType> const &pomdp) : pomdp(pomdp) {
}
template <typename ValueType>
storm::storage::BitVector UniqueObservationStates<ValueType>::analyse() const { storm::storage::BitVector UniqueObservationStates<ValueType>::analyse() const {
storm::storage::BitVector seenOnce(pomdp.getNrObservations(), false);
storm::storage::BitVector seenMoreThanOnce(pomdp.getNrObservations(), false);
for (auto const& observation : pomdp.getObservations()) {
if (seenOnce.get(observation)) {
seenMoreThanOnce.set(observation);
}
seenOnce.set(observation);
} }
storm::storage::BitVector uniqueObservation(pomdp.getNumberOfStates(), false);
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) {
if (!seenMoreThanOnce.get(pomdp.getObservation(state))) {
uniqueObservation.set(state);
}
}
return uniqueObservation;
}
template class UniqueObservationStates<storm::RationalNumber>;
} }
} }

5
src/storm-pomdp/analysis/UniqueObservationStates.h

@ -1,12 +1,15 @@
#include "storm/models/sparse/Pomdp.h" #include "storm/models/sparse/Pomdp.h"
#include "storm/storage/BitVector.h"
namespace storm { namespace storm {
namespace analysis { namespace analysis {
template<typename ValueType> template<typename ValueType>
class UniqueObservationStates { class UniqueObservationStates {
public:
UniqueObservationStates(storm::models::sparse::Pomdp<ValueType> const& pomdp); UniqueObservationStates(storm::models::sparse::Pomdp<ValueType> const& pomdp);
storm::storage::BitVector analyse() const; storm::storage::BitVector analyse() const;
private:
storm::models::sparse::Pomdp<ValueType> const& pomdp; storm::models::sparse::Pomdp<ValueType> const& pomdp;
}; };
} }

34
src/storm/builder/ExplicitModelBuilder.cpp

@ -336,25 +336,37 @@ namespace storm {
for (auto const& bitVectorIndexPair : stateStorage.stateToId) { for (auto const& bitVectorIndexPair : stateStorage.stateToId) {
uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first); uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first);
uint32_t observation = -1; // Is replaced later on. uint32_t observation = -1; // Is replaced later on.
bool checkActionNames = false;
if (checkActionNames) {
bool foundActionSet = false; bool foundActionSet = false;
std::vector<std::string> actionNames; std::vector<std::string> actionNames;
bool addedAnonymousAction = false; bool addedAnonymousAction = false;
for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second+1]; ++choice) {
for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second];
choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second +
1]; ++choice) {
if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) { if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) {
STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, "Cannot have multiple anonymous actions, as these cannot be mapped correctly.");
STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException,
"Cannot have multiple anonymous actions, as these cannot be mapped correctly.");
actionNames.push_back(""); actionNames.push_back("");
addedAnonymousAction = true; addedAnonymousAction = true;
} else { } else {
STORM_LOG_ASSERT(modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, "Expect choice labelling to contain exactly one label at this point, but found " << modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size());
actionNames.push_back(*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin());
STORM_LOG_ASSERT(
modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1,
"Expect choice labelling to contain exactly one label at this point, but found "
<< modelComponents.choiceLabeling.get().getLabelsOfChoice(
choice).size());
actionNames.push_back(
*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin());
} }
} }
STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " << storm::utility::vector::toString(actionNames));
STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: "
<< storm::utility::vector::toString(actionNames));
auto it = observationActions.find(varObservation); auto it = observationActions.find(varObservation);
if (it == observationActions.end()) { if (it == observationActions.end()) {
observationActions.emplace(varObservation, std::vector<std::pair<std::vector<std::string>, uint32_t>>());
observationActions.emplace(varObservation,
std::vector<std::pair<std::vector<std::string>, uint32_t>>());
} else { } else {
for(auto const& entries : it->second) {
for (auto const &entries : it->second) {
STORM_LOG_TRACE(storm::utility::vector::toString(entries.first)); STORM_LOG_TRACE(storm::utility::vector::toString(entries.first));
if (entries.first == actionNames) { if (entries.first == actionNames) {
observation = entries.second; observation = entries.second;
@ -363,7 +375,10 @@ namespace storm {
} }
} }
STORM_LOG_THROW(generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet, storm::exceptions::WrongFormatException, "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option.");
STORM_LOG_THROW(
generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet,
storm::exceptions::WrongFormatException,
"Two states with the same observation have a different set of enabled actions, this is only allowed with a special option.");
} }
if (!foundActionSet) { if (!foundActionSet) {
@ -373,6 +388,9 @@ namespace storm {
} }
classes[bitVectorIndexPair.second] = observation; classes[bitVectorIndexPair.second] = observation;
} else {
classes[bitVectorIndexPair.second] = varObservation;
}
} }
modelComponents.observabilityClasses = classes; modelComponents.observabilityClasses = classes;
} }

22
src/storm/parser/PrismParser.cpp

@ -263,6 +263,8 @@ namespace storm {
void PrismParser::moveToSecondRun() { void PrismParser::moveToSecondRun() {
// In the second run, we actually need to parse the commands instead of just skipping them, // In the second run, we actually need to parse the commands instead of just skipping them,
// so we adapt the rule for parsing commands. // so we adapt the rule for parsing commands.
STORM_LOG_THROW(observables.empty(), storm::exceptions::WrongFormatException, "Some variables marked as observable, but never declared");
commandDefinition = (((qi::lit("[") > -identifier > qi::lit("]")) commandDefinition = (((qi::lit("[") > -identifier > qi::lit("]"))
| |
(qi::lit("<") > -identifier > qi::lit(">")[qi::_a = true])) (qi::lit("<") > -identifier > qi::lit(">")[qi::_a = true]))
@ -587,12 +589,20 @@ namespace storm {
STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Boolean variable '" << variable.getName() << " was not renamed."); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Boolean variable '" << variable.getName() << " was not renamed.");
storm::expressions::Variable renamedVariable = manager->declareBooleanVariable(renamingPair->second); storm::expressions::Variable renamedVariable = manager->declareBooleanVariable(renamingPair->second);
this->identifiers_.add(renamingPair->second, renamedVariable.getExpression()); this->identifiers_.add(renamingPair->second, renamedVariable.getExpression());
if(this->observables.count(renamingPair->second) > 0) {
this->observables.erase(renamingPair->second);
std::cout << renamingPair->second << " is observable." << std::endl;
}
} }
for (auto const& variable : moduleToRename.getIntegerVariables()) { for (auto const& variable : moduleToRename.getIntegerVariables()) {
auto const& renamingPair = renaming.find(variable.getName()); auto const& renamingPair = renaming.find(variable.getName());
STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Integer variable '" << variable.getName() << " was not renamed."); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Integer variable '" << variable.getName() << " was not renamed.");
storm::expressions::Variable renamedVariable = manager->declareIntegerVariable(renamingPair->second); storm::expressions::Variable renamedVariable = manager->declareIntegerVariable(renamingPair->second);
this->identifiers_.add(renamingPair->second, renamedVariable.getExpression()); this->identifiers_.add(renamingPair->second, renamedVariable.getExpression());
if(this->observables.count(renamingPair->second) > 0) {
this->observables.erase(renamingPair->second);
std::cout << renamingPair->second << " is observable." << std::endl;
}
} }
for (auto const& command : moduleToRename.getCommands()) { for (auto const& command : moduleToRename.getCommands()) {
@ -631,7 +641,11 @@ namespace storm {
for (auto const& variable : moduleToRename.getBooleanVariables()) { for (auto const& variable : moduleToRename.getBooleanVariables()) {
auto const& renamingPair = renaming.find(variable.getName()); auto const& renamingPair = renaming.find(variable.getName());
STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Boolean variable '" << variable.getName() << " was not renamed."); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Boolean variable '" << variable.getName() << " was not renamed.");
bool observable = variable.isObservable();
bool observable = this->observables.count(renamingPair->second) > 0;
if(observable) {
this->observables.erase(renamingPair->second);
std::cout << renamingPair->second << " is observable." << std::endl;
}
booleanVariables.push_back(storm::prism::BooleanVariable(manager->getVariable(renamingPair->second), variable.hasInitialValue() ? variable.getInitialValueExpression().substitute(expressionRenaming) : variable.getInitialValueExpression(), observable, this->getFilename(), get_line(qi::_1))); booleanVariables.push_back(storm::prism::BooleanVariable(manager->getVariable(renamingPair->second), variable.hasInitialValue() ? variable.getInitialValueExpression().substitute(expressionRenaming) : variable.getInitialValueExpression(), observable, this->getFilename(), get_line(qi::_1)));
} }
@ -640,7 +654,11 @@ namespace storm {
for (auto const& variable : moduleToRename.getIntegerVariables()) { for (auto const& variable : moduleToRename.getIntegerVariables()) {
auto const& renamingPair = renaming.find(variable.getName()); auto const& renamingPair = renaming.find(variable.getName());
STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Integer variable '" << variable.getName() << " was not renamed."); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Integer variable '" << variable.getName() << " was not renamed.");
bool observable = variable.isObservable();
bool observable = this->observables.count(renamingPair->second) > 0;
if(observable) {
this->observables.erase(renamingPair->second);
std::cout << renamingPair->second << " is observable." << std::endl;
}
integerVariables.push_back(storm::prism::IntegerVariable(manager->getVariable(renamingPair->second), variable.getLowerBoundExpression().substitute(expressionRenaming), variable.getUpperBoundExpression().substitute(expressionRenaming), variable.hasInitialValue() ? variable.getInitialValueExpression().substitute(expressionRenaming) : variable.getInitialValueExpression(), observable, this->getFilename(), get_line(qi::_1))); integerVariables.push_back(storm::prism::IntegerVariable(manager->getVariable(renamingPair->second), variable.getLowerBoundExpression().substitute(expressionRenaming), variable.getUpperBoundExpression().substitute(expressionRenaming), variable.hasInitialValue() ? variable.getInitialValueExpression().substitute(expressionRenaming) : variable.getInitialValueExpression(), observable, this->getFilename(), get_line(qi::_1)));
} }

Loading…
Cancel
Save