From 09cf1902e0e236b11cb949273a065799b052bc85 Mon Sep 17 00:00:00 2001 From: Sebastian Junges Date: Sat, 25 Jan 2020 13:43:14 +0100 Subject: [PATCH] added transformer to make pomdp canonic --- src/storm-pomdp-cli/storm-pomdp.cpp | 3 + .../transformer/MakePOMDPCanonic.cpp | 196 ++++++++++++++++++ .../transformer/MakePOMDPCanonic.h | 24 +++ .../exceptions/AmbiguousModelException.h | 12 ++ 4 files changed, 235 insertions(+) create mode 100644 src/storm-pomdp/transformer/MakePOMDPCanonic.cpp create mode 100644 src/storm-pomdp/transformer/MakePOMDPCanonic.h create mode 100644 src/storm/exceptions/AmbiguousModelException.h diff --git a/src/storm-pomdp-cli/storm-pomdp.cpp b/src/storm-pomdp-cli/storm-pomdp.cpp index 56de8b404..6d6725435 100644 --- a/src/storm-pomdp-cli/storm-pomdp.cpp +++ b/src/storm-pomdp-cli/storm-pomdp.cpp @@ -36,6 +36,7 @@ #include "storm-pomdp/transformer/GlobalPomdpMecChoiceEliminator.h" #include "storm-pomdp/transformer/PomdpMemoryUnfolder.h" #include "storm-pomdp/transformer/BinaryPomdpTransformer.h" +#include "storm-pomdp/transformer/MakePOMDPCanonic.h" #include "storm-pomdp/analysis/UniqueObservationStates.h" #include "storm-pomdp/analysis/QualitativeAnalysis.h" #include "storm/api/storm.h" @@ -100,6 +101,8 @@ int main(const int argc, const char** argv) { auto model = storm::cli::buildPreprocessExportModelWithValueTypeAndDdlib(symbolicInput, engine); STORM_LOG_THROW(model && model->getType() == storm::models::ModelType::Pomdp, storm::exceptions::WrongFormatException, "Expected a POMDP."); std::shared_ptr> pomdp = model->template as>(); + storm::transformer::MakePOMDPCanonic makeCanonic(*pomdp); + pomdp = makeCanonic.transform(); std::shared_ptr formula; if (!symbolicInput.properties.empty()) { diff --git a/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp b/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp new file mode 100644 index 000000000..2477b122c --- /dev/null +++ b/src/storm-pomdp/transformer/MakePOMDPCanonic.cpp @@ -0,0 +1,196 @@ +#include "storm-pomdp/transformer/MakePOMDPCanonic.h" +#include "storm/storage/sparse/ModelComponents.h" +#include "storm/exceptions/AmbiguousModelException.h" + +#include "storm/exceptions/InvalidArgumentException.h" + + +namespace storm { + namespace transformer { + + namespace detail { + struct ActionIdentifier { + uint64_t choiceLabelId; + uint64_t choiceOriginId; + + bool compatibleWith(ActionIdentifier const& other) const { + if (choiceLabelId != other.choiceLabelId) { + return false; // different labels. + } + if (choiceLabelId > 0) { + // Notice that we assume that we already have ensured that names + // are not used more than once. + return true; // actions have a name, name coincides. + } else { + // action is unnamed. + // We only call this method (at least we only should call this method) + // if there are multiple actions. Then two tau actions are only compatible + // if they are described by the same choice origin. + return choiceOriginId == other.choiceOriginId; + } + } + + friend bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs); + }; + + template + bool compatibleWith(iterator1 start1, iterator1 end1, iterator2 start2, iterator2 end2) { + iterator1 it1 = start1; + iterator2 it2 = start2; + while (it1 != end1 && it2 != end2) { + if (!it1->compatibleWith(it2->first)) { + return false; + } + ++it1; + ++it2; + } + return it1 == end1 && it2 == end2; + }; + + + bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs) { + if (lhs.choiceLabelId == rhs.choiceLabelId) { + return lhs.choiceOriginId < rhs.choiceOriginId; + } + return lhs.choiceLabelId < rhs.choiceLabelId; + } + + class ChoiceLabelIdStorage { + + public: + uint64_t registerLabel(std::string const& label) { + auto it = std::find(storage.begin(), storage.end(), label); + if (it == storage.end()) { + storage.push_back(label); + return storage.size() - 1; + } else { + return it - storage.begin(); + } + } + + + private: + std::vector storage = {""}; + + }; + } + + template + std::shared_ptr> MakePOMDPCanonic::transform() const { + STORM_LOG_THROW(pomdp.hasChoiceOrigins(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice origins"); + STORM_LOG_THROW(pomdp.hasChoiceLabeling(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice labels"); + std::vector permutation = computeCanonicalPermutation(); + return applyPermutationOnPomdp(permutation); + } + + + template + std::shared_ptr> MakePOMDPCanonic::applyPermutationOnPomdp(std::vector permutation) const { + + auto rewardModels = pomdp.getRewardModels(); + std::unordered_map> newRewardModels; + for (auto& rewardModelNameAndModel : rewardModels) { + newRewardModels.emplace(rewardModelNameAndModel.first, rewardModelNameAndModel.second.permuteActions(permutation)); + } + storm::storage::sparse::ModelComponents modelcomponents(pomdp.getTransitionMatrix().permuteRows(permutation), + pomdp.getStateLabeling(), + newRewardModels, + false, boost::none); + modelcomponents.observabilityClasses = pomdp.getObservations(); + modelcomponents.choiceLabeling = pomdp.getChoiceLabeling(); + return std::make_shared>(modelcomponents); + } + + + template + std::vector MakePOMDPCanonic::computeCanonicalPermutation() const { + std::map> observationActionIdentifiers; + std::map actionIdentifierDefinition; + auto const& choiceLabeling = pomdp.getChoiceLabeling(); + detail::ChoiceLabelIdStorage labelStorage; + + std::vector permutation; + uint64_t nrObservations = pomdp.getNrObservations(); + storm::storage::BitVector oneActionObservations(nrObservations); + storm::storage::BitVector moreActionObservations(nrObservations); + + + for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { + uint64_t rowIndexFrom = pomdp.getTransitionMatrix().getRowGroupIndices()[state]; + uint64_t rowIndexTo = pomdp.getTransitionMatrix().getRowGroupIndices()[state+1]; + + uint64_t observation = pomdp.getObservation(state); + if (rowIndexFrom + 1 == rowIndexTo) { + permutation.push_back(rowIndexFrom); + if (moreActionObservations.get(observation)) { + // We have seen this observation previously with multiple actions. Error! + // TODO provide more diagnostic information + STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides multiple action, but in state " << state << " provides one action."); + } + oneActionObservations.set(observation); + + // One action is ALWAYS fine. + continue; + } else { + if (oneActionObservations.get(observation)) { + // We have seen this observation previously with one action. Error! + STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides one action, but in state " << state << " provides multiple actions."); + } + moreActionObservations.set(observation); + } + + std::map actionIdentifiers; + std::set actionLabels; + for (uint64_t actionIndex = rowIndexFrom; actionIndex < rowIndexTo; ++actionIndex) { + // While this is in full generality a set of labels, + // for models modelled with prism, we actually know that these are singleton sets. + std::set labels = choiceLabeling.getLabelsOfChoice(actionIndex); + STORM_LOG_ASSERT(labels.size() <= 1, "We expect choice labels to be single sets"); + // Generate action identifier + uint64_t labelId = -1; + if (labels.size() == 1) { + labelId = labelStorage.registerLabel(*labels.begin()); + STORM_LOG_THROW(actionLabels.count(labelId) == 0, storm::exceptions::AmbiguousModelException, "Multiple actions with label '" << *labels.begin() << "' exist in state id " << state << "."); + actionLabels.emplace(labelId); + } else { + labelId = labelStorage.registerLabel(""); + } + + detail::ActionIdentifier ai; + ai.choiceLabelId = labelId; + ai.choiceOriginId = pomdp.getChoiceOrigins()->getIdentifier(actionIndex); + actionIdentifiers.emplace(ai,actionIndex); + } + + if (observationActionIdentifiers.count(observation) == 0) { + // First state with this observation + // store the corresponding vector. + std::vector ais; + for (auto const& als : actionIdentifiers) { + ais.push_back(als.first); + } + observationActionIdentifiers.emplace(observation, ais); + actionIdentifierDefinition.emplace(observation, state); + } else { + auto referenceStart = observationActionIdentifiers[observation].begin(); + auto referenceEnd = observationActionIdentifiers[observation].end(); + if (!detail::compatibleWith(referenceStart, referenceEnd, actionIdentifiers.begin(), actionIdentifiers.end())) { + STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Actions identifiers do not align between states '" << state << "' and '" << actionIdentifierDefinition[observation] << "', both having observation " << observation << "."); + } + } + + for (auto const& al : actionIdentifiers) { + permutation.push_back(al.second); + } + } + return permutation; + + + + } + + + template class MakePOMDPCanonic; + template class MakePOMDPCanonic; + } +} \ No newline at end of file diff --git a/src/storm-pomdp/transformer/MakePOMDPCanonic.h b/src/storm-pomdp/transformer/MakePOMDPCanonic.h new file mode 100644 index 000000000..689e89023 --- /dev/null +++ b/src/storm-pomdp/transformer/MakePOMDPCanonic.h @@ -0,0 +1,24 @@ +#pragma once + +#include "storm/models/sparse/Pomdp.h" + +namespace storm { + namespace transformer { + + template + class MakePOMDPCanonic { + + public: + MakePOMDPCanonic(storm::models::sparse::Pomdp const& pomdp) : pomdp(pomdp) { + + } + + std::shared_ptr> transform() const; + protected: + std::vector computeCanonicalPermutation() const; + std::shared_ptr> applyPermutationOnPomdp(std::vector permutation) const; + + storm::models::sparse::Pomdp const& pomdp; + }; + } +} \ No newline at end of file diff --git a/src/storm/exceptions/AmbiguousModelException.h b/src/storm/exceptions/AmbiguousModelException.h new file mode 100644 index 000000000..98d8fcc20 --- /dev/null +++ b/src/storm/exceptions/AmbiguousModelException.h @@ -0,0 +1,12 @@ +#pragma once + +#include "storm/exceptions/BaseException.h" +#include "storm/exceptions/ExceptionMacros.h" + +namespace storm { + namespace exceptions { + + STORM_NEW_EXCEPTION(AmbiguousModelException) + + } // namespace exceptions +} // namespace storm