diff --git a/src/storm-pomdp/transformer/MakeStateSetObservationClosed.cpp b/src/storm-pomdp/transformer/MakeStateSetObservationClosed.cpp new file mode 100644 index 000000000..ce23358f6 --- /dev/null +++ b/src/storm-pomdp/transformer/MakeStateSetObservationClosed.cpp @@ -0,0 +1,70 @@ +#include "storm-pomdp/transformer/MakeStateSetObservationClosed.h" + + +namespace storm { + namespace transformer { + + template + MakeStateSetObservationClosed::MakeStateSetObservationClosed(std::shared_ptr> pomdp) : pomdp(pomdp) { } + + template + std::pair>, std::set> MakeStateSetObservationClosed::transform(storm::storage::BitVector const& stateSet) const { + // Collect observations of target states + std::set oldObservations; + for (auto const& state : stateSet) { + oldObservations.insert(pomdp->getObservation(state)); + } + + // Collect observations that belong to both, target states and non-target states. + // Add a fresh observation for each of them + std::map oldToNewObservationMap; + uint32_t freshObs = pomdp->getNrObservations(); + for (uint64_t state = stateSet.getNextUnsetIndex(0); state < stateSet.size(); state = stateSet.getNextUnsetIndex(state + 1)) { + uint32_t obs = pomdp->getObservation(state); + if (oldObservations.count(obs) > 0) { + // this observation belongs to both, target and non-target states. + if (oldToNewObservationMap.emplace(obs, freshObs).second) { + // We actually inserted something, i.e., we have not seen this observation before. + // For the next observation (that is different from obs) we want to insert another fresh observation index + // This is to preserve the assumption that states with the same observation have the same enabled actions. + ++freshObs; + } + } + } + + // Check whether the state set already is observation closed. + if (oldToNewObservationMap.empty()) { + return {pomdp, std::move(oldObservations)}; + } else { + // Create new observations + auto newObservationVector = pomdp->getObservations(); + for (auto const& state : stateSet) { + auto findRes = oldToNewObservationMap.find(pomdp->getObservation(state)); + if (findRes != oldToNewObservationMap.end()) { + newObservationVector[state] = findRes->second; + } + } + // Create a copy of the pomdp and change observations accordingly. + // This transformation preserves canonicity. + auto transformed = std::make_shared>(*pomdp); + transformed->updateObservations(std::move(newObservationVector), true); + + // Finally, get the new set of target observations + std::set newObservations; + for (auto const& obs : oldObservations) { + auto findRes = oldToNewObservationMap.find(obs); + if (findRes == oldToNewObservationMap.end()) { + newObservations.insert(obs); + } else { + newObservations.insert(findRes->second); + } + } + + return {transformed, std::move(newObservations)}; + } + } + + template class MakeStateSetObservationClosed; + template class MakeStateSetObservationClosed; + } +} \ No newline at end of file diff --git a/src/storm-pomdp/transformer/MakeStateSetObservationClosed.h b/src/storm-pomdp/transformer/MakeStateSetObservationClosed.h new file mode 100644 index 000000000..4a7f20c92 --- /dev/null +++ b/src/storm-pomdp/transformer/MakeStateSetObservationClosed.h @@ -0,0 +1,26 @@ +#pragma once +#include + +#include "storm/models/sparse/Pomdp.h" +namespace storm { + namespace transformer { + + template + class MakeStateSetObservationClosed { + + public: + MakeStateSetObservationClosed(std::shared_ptr> pomdp); + + /*! + * Ensures that the given set of states is observation closed, potentially, adding new observation(s) + * A set of states S is observation closed, iff there is a set of observations Z such that `o(s) in Z iff s in S` + * + * @return the model where the given set of states is observation closed as well as the set of observations that uniquely describe the given state set. + * If the state set is already observation close, we return the original POMDP, i.e., the pomdp is not copied. + */ + std::pair>, std::set> transform(storm::storage::BitVector const& stateSet) const; + protected: + std::shared_ptr> pomdp; + }; + } +} \ No newline at end of file diff --git a/src/storm/models/sparse/Pomdp.cpp b/src/storm/models/sparse/Pomdp.cpp index b38b3dd2f..dc5461120 100644 --- a/src/storm/models/sparse/Pomdp.cpp +++ b/src/storm/models/sparse/Pomdp.cpp @@ -75,6 +75,12 @@ namespace storm { return observations; } + template + void Pomdp::updateObservations(std::vector&& newObservations, bool preservesCanonicity) { + observations = std::move(newObservations); + computeNrObservations(); + setIsCanonic(isCanonic() && preservesCanonicity); + } template std::string Pomdp::additionalDotStateInfo(uint64_t state) const { diff --git a/src/storm/models/sparse/Pomdp.h b/src/storm/models/sparse/Pomdp.h index 408b2dec7..cacee0953 100644 --- a/src/storm/models/sparse/Pomdp.h +++ b/src/storm/models/sparse/Pomdp.h @@ -65,6 +65,16 @@ namespace storm { uint64_t getMaxNrStatesWithSameObservation() const; std::vector const& getObservations() const; + + /*! + * Changes the observations to the given ones and updates redundant informations (like the number of observations) + * After calling this method, isCanonic() returns true iff (i) isCanonic() returned true before calling this method and (ii) preservesCanonicity was set to true. + * + * @param newObservations The new observations + * @param preservesCanonicity specifies whether the pomdp is still canonic (assuming that it was canonic before) + * + */ + void updateObservations(std::vector&& newObservations, bool preservesCanonicity); std::vector getStatesWithObservation(uint32_t observation) const;