From 3ac42caf7cd2aa8b939af54ab384473603ebaf59 Mon Sep 17 00:00:00 2001 From: sjunges Date: Sat, 26 Aug 2017 14:29:15 +0200 Subject: [PATCH] extensions to pomdp stuff --- src/storm-pomdp-cli/storm-pomdp.cpp | 6 +- .../analysis/UniqueObservationStates.cpp | 25 +++++- .../analysis/UniqueObservationStates.h | 5 +- src/storm/builder/ExplicitModelBuilder.cpp | 82 +++++++++++-------- src/storm/parser/PrismParser.cpp | 22 ++++- 5 files changed, 102 insertions(+), 38 deletions(-) diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index 6328ae90e..c63f989a4 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -1,5 +1,6 @@ +#include #include "storm/utility/initialize.h" #include "storm/settings/modules/GeneralSettings.h" @@ -97,17 +98,20 @@ int main(const int argc, const char** argv) { storm::settings::modules::CoreSettings::Engine engine = coreSettings.getEngine(); storm::cli::SymbolicInput symbolicInput = storm::cli::parseAndPreprocessSymbolicInput(); + // We should not export here if we are going to do some processing first. auto model = storm::cli::buildPreprocessExportModelWithValueTypeAndDdlib(symbolicInput, engine); STORM_LOG_THROW(model && model->getType() == storm::models::ModelType::Pomdp, storm::exceptions::WrongFormatException, "Expected a POMDP."); // CHECK if prop maximizes, only apply in those situations std::shared_ptr> pomdp = model->template as>(); storm::transformer::GlobalPOMDPSelfLoopEliminator selfLoopEliminator(*pomdp); pomdp = selfLoopEliminator.transform(); + storm::analysis::UniqueObservationStates uniqueAnalysis(*pomdp); + std::cout << uniqueAnalysis.analyse() << std::endl; - storm::transformer::ApplyFiniteSchedulerToPomdp toPMCTransformer(*pomdp); if (pomdpSettings.isExportToParametricSet()) { + storm::transformer::ApplyFiniteSchedulerToPomdp toPMCTransformer(*pomdp); auto const &pmc = toPMCTransformer.transform(); storm::analysis::ConstraintCollector constraints(*pmc); pmc->printModelInformationToStream(std::cout); diff --git a/src/storm-pomdp/analysis/UniqueObservationStates.cpp b/src/storm-pomdp/analysis/UniqueObservationStates.cpp index 7c0ff28af..5fc8e1426 100644 --- a/src/storm-pomdp/analysis/UniqueObservationStates.cpp +++ b/src/storm-pomdp/analysis/UniqueObservationStates.cpp @@ -3,12 +3,33 @@ namespace storm { namespace analysis { - typename - storm::storage::BitVector UniqueObservationStates::analyse() const { + template + UniqueObservationStates::UniqueObservationStates(storm::models::sparse::Pomdp const &pomdp) : pomdp(pomdp) { } + template + storm::storage::BitVector UniqueObservationStates::analyse() const { + storm::storage::BitVector seenOnce(pomdp.getNrObservations(), false); + storm::storage::BitVector seenMoreThanOnce(pomdp.getNrObservations(), false); + + for (auto const& observation : pomdp.getObservations()) { + if (seenOnce.get(observation)) { + seenMoreThanOnce.set(observation); + } + seenOnce.set(observation); + } + + storm::storage::BitVector uniqueObservation(pomdp.getNumberOfStates(), false); + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + if (!seenMoreThanOnce.get(pomdp.getObservation(state))) { + uniqueObservation.set(state); + } + } + return uniqueObservation; + } + template class UniqueObservationStates; } } \ No newline at end of file diff --git a/src/storm-pomdp/analysis/UniqueObservationStates.h b/src/storm-pomdp/analysis/UniqueObservationStates.h index b16190fb4..a0c13c406 100644 --- a/src/storm-pomdp/analysis/UniqueObservationStates.h +++ b/src/storm-pomdp/analysis/UniqueObservationStates.h @@ -1,12 +1,15 @@ #include "storm/models/sparse/Pomdp.h" +#include "storm/storage/BitVector.h" namespace storm { namespace analysis { template class UniqueObservationStates { + public: + UniqueObservationStates(storm::models::sparse::Pomdp const& pomdp); storm::storage::BitVector analyse() const; - + private: storm::models::sparse::Pomdp const& pomdp; }; } diff --git a/src/storm/builder/ExplicitModelBuilder.cpp b/src/storm/builder/ExplicitModelBuilder.cpp index 1f2a53d54..c4c5b6440 100644 --- a/src/storm/builder/ExplicitModelBuilder.cpp +++ b/src/storm/builder/ExplicitModelBuilder.cpp @@ -336,43 +336,61 @@ namespace storm { for (auto const& bitVectorIndexPair : stateStorage.stateToId) { uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first); uint32_t observation = -1; // Is replaced later on. - bool foundActionSet = false; - std::vector actionNames; - bool addedAnonymousAction = false; - for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second+1]; ++choice) { - if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) { - STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, "Cannot have multiple anonymous actions, as these cannot be mapped correctly."); - actionNames.push_back(""); - addedAnonymousAction = true; - } else { - STORM_LOG_ASSERT(modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, "Expect choice labelling to contain exactly one label at this point, but found " << modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size()); - actionNames.push_back(*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin()); - } - } - STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " << storm::utility::vector::toString(actionNames)); - auto it = observationActions.find(varObservation); - if (it == observationActions.end()) { - observationActions.emplace(varObservation, std::vector, uint32_t>>()); - } else { - for(auto const& entries : it->second) { - STORM_LOG_TRACE(storm::utility::vector::toString(entries.first)); - if (entries.first == actionNames) { - observation = entries.second; - foundActionSet = true; - break; + bool checkActionNames = false; + if (checkActionNames) { + bool foundActionSet = false; + std::vector actionNames; + bool addedAnonymousAction = false; + for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; + choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second + + 1]; ++choice) { + if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) { + STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, + "Cannot have multiple anonymous actions, as these cannot be mapped correctly."); + actionNames.push_back(""); + addedAnonymousAction = true; + } else { + STORM_LOG_ASSERT( + modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, + "Expect choice labelling to contain exactly one label at this point, but found " + << modelComponents.choiceLabeling.get().getLabelsOfChoice( + choice).size()); + actionNames.push_back( + *modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin()); } } + STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " + << storm::utility::vector::toString(actionNames)); + auto it = observationActions.find(varObservation); + if (it == observationActions.end()) { + observationActions.emplace(varObservation, + std::vector, uint32_t>>()); + } else { + for (auto const &entries : it->second) { + STORM_LOG_TRACE(storm::utility::vector::toString(entries.first)); + if (entries.first == actionNames) { + observation = entries.second; + foundActionSet = true; + break; + } + } - STORM_LOG_THROW(generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet, storm::exceptions::WrongFormatException, "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option."); + STORM_LOG_THROW( + generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet, + storm::exceptions::WrongFormatException, + "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option."); - } - if (!foundActionSet) { - observation = newObservation; - observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation); - ++newObservation; - } + } + if (!foundActionSet) { + observation = newObservation; + observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation); + ++newObservation; + } - classes[bitVectorIndexPair.second] = observation; + classes[bitVectorIndexPair.second] = observation; + } else { + classes[bitVectorIndexPair.second] = varObservation; + } } modelComponents.observabilityClasses = classes; } diff --git a/src/storm/parser/PrismParser.cpp b/src/storm/parser/PrismParser.cpp index 759b3b140..5310fc38e 100644 --- a/src/storm/parser/PrismParser.cpp +++ b/src/storm/parser/PrismParser.cpp @@ -263,6 +263,8 @@ namespace storm { void PrismParser::moveToSecondRun() { // In the second run, we actually need to parse the commands instead of just skipping them, // so we adapt the rule for parsing commands. + STORM_LOG_THROW(observables.empty(), storm::exceptions::WrongFormatException, "Some variables marked as observable, but never declared"); + commandDefinition = (((qi::lit("[") > -identifier > qi::lit("]")) | (qi::lit("<") > -identifier > qi::lit(">")[qi::_a = true])) @@ -587,12 +589,20 @@ namespace storm { STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Boolean variable '" << variable.getName() << " was not renamed."); storm::expressions::Variable renamedVariable = manager->declareBooleanVariable(renamingPair->second); this->identifiers_.add(renamingPair->second, renamedVariable.getExpression()); + if(this->observables.count(renamingPair->second) > 0) { + this->observables.erase(renamingPair->second); + std::cout << renamingPair->second << " is observable." << std::endl; + } } for (auto const& variable : moduleToRename.getIntegerVariables()) { auto const& renamingPair = renaming.find(variable.getName()); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Integer variable '" << variable.getName() << " was not renamed."); storm::expressions::Variable renamedVariable = manager->declareIntegerVariable(renamingPair->second); this->identifiers_.add(renamingPair->second, renamedVariable.getExpression()); + if(this->observables.count(renamingPair->second) > 0) { + this->observables.erase(renamingPair->second); + std::cout << renamingPair->second << " is observable." << std::endl; + } } for (auto const& command : moduleToRename.getCommands()) { @@ -631,7 +641,11 @@ namespace storm { for (auto const& variable : moduleToRename.getBooleanVariables()) { auto const& renamingPair = renaming.find(variable.getName()); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Boolean variable '" << variable.getName() << " was not renamed."); - bool observable = variable.isObservable(); + bool observable = this->observables.count(renamingPair->second) > 0; + if(observable) { + this->observables.erase(renamingPair->second); + std::cout << renamingPair->second << " is observable." << std::endl; + } booleanVariables.push_back(storm::prism::BooleanVariable(manager->getVariable(renamingPair->second), variable.hasInitialValue() ? variable.getInitialValueExpression().substitute(expressionRenaming) : variable.getInitialValueExpression(), observable, this->getFilename(), get_line(qi::_1))); } @@ -640,7 +654,11 @@ namespace storm { for (auto const& variable : moduleToRename.getIntegerVariables()) { auto const& renamingPair = renaming.find(variable.getName()); STORM_LOG_THROW(renamingPair != renaming.end(), storm::exceptions::WrongFormatException, "Parsing error in " << this->getFilename() << ", line " << get_line(qi::_3) << ": Integer variable '" << variable.getName() << " was not renamed."); - bool observable = variable.isObservable(); + bool observable = this->observables.count(renamingPair->second) > 0; + if(observable) { + this->observables.erase(renamingPair->second); + std::cout << renamingPair->second << " is observable." << std::endl; + } integerVariables.push_back(storm::prism::IntegerVariable(manager->getVariable(renamingPair->second), variable.getLowerBoundExpression().substitute(expressionRenaming), variable.getUpperBoundExpression().substitute(expressionRenaming), variable.hasInitialValue() ? variable.getInitialValueExpression().substitute(expressionRenaming) : variable.getInitialValueExpression(), observable, this->getFilename(), get_line(qi::_1))); }