70 lines
3.7 KiB
70 lines
3.7 KiB
#include "storm-pomdp/transformer/MakeStateSetObservationClosed.h"

#include <utility>
|
|
|
|
|
|
namespace storm {
|
|
namespace transformer {
|
|
|
|
template<typename ValueType>
|
|
MakeStateSetObservationClosed<ValueType>::MakeStateSetObservationClosed(std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> pomdp) : pomdp(pomdp) { }
|
|
|
|
template<typename ValueType>
|
|
std::pair<std::shared_ptr<storm::models::sparse::Pomdp<ValueType>>, std::set<uint32_t>> MakeStateSetObservationClosed<ValueType>::transform(storm::storage::BitVector const& stateSet) const {
|
|
// Collect observations of target states
|
|
std::set<uint32_t> oldObservations;
|
|
for (auto const& state : stateSet) {
|
|
oldObservations.insert(pomdp->getObservation(state));
|
|
}
|
|
|
|
// Collect observations that belong to both, target states and non-target states.
|
|
// Add a fresh observation for each of them
|
|
std::map<uint32_t, uint32_t> oldToNewObservationMap;
|
|
uint32_t freshObs = pomdp->getNrObservations();
|
|
for (uint64_t state = stateSet.getNextUnsetIndex(0); state < stateSet.size(); state = stateSet.getNextUnsetIndex(state + 1)) {
|
|
uint32_t obs = pomdp->getObservation(state);
|
|
if (oldObservations.count(obs) > 0) {
|
|
// this observation belongs to both, target and non-target states.
|
|
if (oldToNewObservationMap.emplace(obs, freshObs).second) {
|
|
// We actually inserted something, i.e., we have not seen this observation before.
|
|
// For the next observation (that is different from obs) we want to insert another fresh observation index
|
|
// This is to preserve the assumption that states with the same observation have the same enabled actions.
|
|
++freshObs;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check whether the state set already is observation closed.
|
|
if (oldToNewObservationMap.empty()) {
|
|
return {pomdp, std::move(oldObservations)};
|
|
} else {
|
|
// Create new observations
|
|
auto newObservationVector = pomdp->getObservations();
|
|
for (auto const& state : stateSet) {
|
|
auto findRes = oldToNewObservationMap.find(pomdp->getObservation(state));
|
|
if (findRes != oldToNewObservationMap.end()) {
|
|
newObservationVector[state] = findRes->second;
|
|
}
|
|
}
|
|
// Create a copy of the pomdp and change observations accordingly.
|
|
// This transformation preserves canonicity.
|
|
auto transformed = std::make_shared<storm::models::sparse::Pomdp<ValueType>>(*pomdp);
|
|
transformed->updateObservations(std::move(newObservationVector), true);
|
|
|
|
// Finally, get the new set of target observations
|
|
std::set<uint32_t> newObservations;
|
|
for (auto const& obs : oldObservations) {
|
|
auto findRes = oldToNewObservationMap.find(obs);
|
|
if (findRes == oldToNewObservationMap.end()) {
|
|
newObservations.insert(obs);
|
|
} else {
|
|
newObservations.insert(findRes->second);
|
|
}
|
|
}
|
|
|
|
return {transformed, std::move(newObservations)};
|
|
}
|
|
}
|
|
|
|
// Explicit instantiation for ValueType = double, so that the member definitions in this
// translation unit are compiled and available to the linker for this value type.
template class MakeStateSetObservationClosed<double>;
|
|
// Explicit instantiation for ValueType = storm::RationalNumber, so that the member definitions
// in this translation unit are compiled and available to the linker for this value type.
template class MakeStateSetObservationClosed<storm::RationalNumber>;
|
|
}
|
|
}
|