From 9314d99354a1e9ff445e5eb59345774ab6043bab Mon Sep 17 00:00:00 2001 From: sjunges <sebastian.junges@gmail.com> Date: Thu, 24 Aug 2017 17:31:05 +0200 Subject: [PATCH] allow inference from actions, but only with a set option --- src/storm/builder/BuilderOptions.cpp | 6 +++- src/storm/builder/BuilderOptions.h | 11 +++++- src/storm/builder/ExplicitModelBuilder.cpp | 42 +++++++++++++++++++++- src/storm/generator/NextStateGenerator.cpp | 2 -- 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/storm/builder/BuilderOptions.cpp b/src/storm/builder/BuilderOptions.cpp index b0d2641bd..4774227fa 100644 --- a/src/storm/builder/BuilderOptions.cpp +++ b/src/storm/builder/BuilderOptions.cpp @@ -35,7 +35,7 @@ namespace storm { return boost::get<storm::expressions::Expression>(labelOrExpression); } - BuilderOptions::BuilderOptions(bool buildAllRewardModels, bool buildAllLabels) : buildAllRewardModels(buildAllRewardModels), buildAllLabels(buildAllLabels), buildChoiceLabels(false), buildStateValuations(false), buildChoiceOrigins(false), explorationChecks(false), explorationShowProgress(false), explorationShowProgressDelay(0) { + BuilderOptions::BuilderOptions(bool buildAllRewardModels, bool buildAllLabels) : buildAllRewardModels(buildAllRewardModels), buildAllLabels(buildAllLabels), buildChoiceLabels(false), buildStateValuations(false), buildChoiceOrigins(false), explorationChecks(false), explorationShowProgress(false), inferObservationsFromActions(false), explorationShowProgressDelay(0) { // Intentionally left empty. } @@ -156,6 +156,10 @@ namespace storm { bool BuilderOptions::isBuildAllLabelsSet() const { return buildAllLabels; } + + bool BuilderOptions::isInferObservationsFromActionsSet() const { + return inferObservationsFromActions; + } BuilderOptions& BuilderOptions::setBuildAllRewardModels(bool newValue) { buildAllRewardModels = newValue; diff --git a/src/storm/builder/BuilderOptions.h b/src/storm/builder/BuilderOptions.h index f3a6edfbf..0aeb0f674 100644 --- a/src/storm/builder/BuilderOptions.h +++ b/src/storm/builder/BuilderOptions.h @@ -82,6 +82,7 @@ namespace storm { */ void setTerminalStatesFromFormula(storm::logic::Formula const& formula); + /*! * Which reward models are built * @return @@ -107,6 +108,7 @@ namespace storm { bool isBuildAllLabelsSet() const; bool isExplorationChecksSet() const; bool isExplorationShowProgressSet() const; + bool isInferObservationsFromActionsSet() const; uint64_t getExplorationShowProgressDelay() const; /** @@ -155,7 +157,11 @@ namespace storm { * @return this */ BuilderOptions& setExplorationChecks(bool newValue = true); - + + + BuilderOptions& setInferObservationsFromActions(bool newValue = true); + + private: /// A flag that indicates whether all reward models are to be built. In this case, the reward model names are /// to be ignored. @@ -191,6 +197,9 @@ namespace storm { /// A flag that stores whether the progress of exploration is to be printed. bool explorationShowProgress; + + /// For POMDPs, should we allow inference of observation classes from different enabled actions. + bool inferObservationsFromActions; /// The delay for printing progress information. uint64_t explorationShowProgressDelay; diff --git a/src/storm/builder/ExplicitModelBuilder.cpp b/src/storm/builder/ExplicitModelBuilder.cpp index 79d6fc15d..1f2a53d54 100644 --- a/src/storm/builder/ExplicitModelBuilder.cpp +++ b/src/storm/builder/ExplicitModelBuilder.cpp @@ -330,9 +330,49 @@ namespace storm { } if (generator->isPartiallyObservable()) { std::vector<uint32_t> classes; + uint32_t newObservation = 0; classes.resize(stateStorage.getNumberOfStates()); + std::unordered_map<uint32_t, std::vector<std::pair<std::vector<std::string>, uint32_t>>> observationActions; for (auto const& bitVectorIndexPair : stateStorage.stateToId) { - classes[bitVectorIndexPair.second] = generator->observabilityClass(bitVectorIndexPair.first); + uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first); + uint32_t observation = -1; // Is replaced later on. + bool foundActionSet = false; + std::vector<std::string> actionNames; + bool addedAnonymousAction = false; + for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second+1]; ++choice) { + if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) { + STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, "Cannot have multiple anonymous actions, as these cannot be mapped correctly."); + actionNames.push_back(""); + addedAnonymousAction = true; + } else { + STORM_LOG_ASSERT(modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, "Expect choice labelling to contain exactly one label at this point, but found " << modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size()); + actionNames.push_back(*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin()); + } + } + STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " << storm::utility::vector::toString(actionNames)); + auto it = observationActions.find(varObservation); + if (it == observationActions.end()) { + observationActions.emplace(varObservation, std::vector<std::pair<std::vector<std::string>, uint32_t>>()); + } else { + for(auto const& entries : it->second) { + STORM_LOG_TRACE(storm::utility::vector::toString(entries.first)); + if (entries.first == actionNames) { + observation = entries.second; + foundActionSet = true; + break; + } + } + + STORM_LOG_THROW(generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet, storm::exceptions::WrongFormatException, "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option."); + + } + if (!foundActionSet) { + observation = newObservation; + observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation); + ++newObservation; + } + + classes[bitVectorIndexPair.second] = observation; } modelComponents.observabilityClasses = classes; } diff --git a/src/storm/generator/NextStateGenerator.cpp b/src/storm/generator/NextStateGenerator.cpp index faaadf1ed..8d35f9cf2 100644 --- a/src/storm/generator/NextStateGenerator.cpp +++ b/src/storm/generator/NextStateGenerator.cpp @@ -158,9 +158,7 @@ namespace storm { uint32_t NextStateGenerator<ValueType, StateType>::observabilityClass(CompressedState const &state) const { if (this->mask.size() == 0) { this->mask = computeObservabilityMask(variableInformation); - std::cout << mask.size() << std::endl; } - std::cout << state.size() << std::endl; return unpackStateToObservabilityClass(state, observabilityMap, mask); }