Sebastian Junges
5 years ago
4 changed files with 235 additions and 0 deletions
-
3src/storm-pomdp-cli/storm-pomdp.cpp
-
196src/storm-pomdp/transformer/MakePOMDPCanonic.cpp
-
24src/storm-pomdp/transformer/MakePOMDPCanonic.h
-
12src/storm/exceptions/AmbiguousModelException.h
@ -0,0 +1,196 @@ |
|||||
|
#include "storm-pomdp/transformer/MakePOMDPCanonic.h"
|
||||
|
#include "storm/storage/sparse/ModelComponents.h"
|
||||
|
#include "storm/exceptions/AmbiguousModelException.h"
|
||||
|
|
||||
|
#include "storm/exceptions/InvalidArgumentException.h"
|
||||
|
|
||||
|
|
||||
|
namespace storm { |
||||
|
namespace transformer { |
||||
|
|
||||
|
namespace detail { |
||||
|
struct ActionIdentifier { |
||||
|
uint64_t choiceLabelId; |
||||
|
uint64_t choiceOriginId; |
||||
|
|
||||
|
bool compatibleWith(ActionIdentifier const& other) const { |
||||
|
if (choiceLabelId != other.choiceLabelId) { |
||||
|
return false; // different labels.
|
||||
|
} |
||||
|
if (choiceLabelId > 0) { |
||||
|
// Notice that we assume that we already have ensured that names
|
||||
|
// are not used more than once.
|
||||
|
return true; // actions have a name, name coincides.
|
||||
|
} else { |
||||
|
// action is unnamed.
|
||||
|
// We only call this method (at least we only should call this method)
|
||||
|
// if there are multiple actions. Then two tau actions are only compatible
|
||||
|
// if they are described by the same choice origin.
|
||||
|
return choiceOriginId == other.choiceOriginId; |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
friend bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs); |
||||
|
}; |
||||
|
|
||||
|
template<typename iterator1, typename iterator2> |
||||
|
bool compatibleWith(iterator1 start1, iterator1 end1, iterator2 start2, iterator2 end2) { |
||||
|
iterator1 it1 = start1; |
||||
|
iterator2 it2 = start2; |
||||
|
while (it1 != end1 && it2 != end2) { |
||||
|
if (!it1->compatibleWith(it2->first)) { |
||||
|
return false; |
||||
|
} |
||||
|
++it1; |
||||
|
++it2; |
||||
|
} |
||||
|
return it1 == end1 && it2 == end2; |
||||
|
}; |
||||
|
|
||||
|
|
||||
|
bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs) { |
||||
|
if (lhs.choiceLabelId == rhs.choiceLabelId) { |
||||
|
return lhs.choiceOriginId < rhs.choiceOriginId; |
||||
|
} |
||||
|
return lhs.choiceLabelId < rhs.choiceLabelId; |
||||
|
} |
||||
|
|
||||
|
class ChoiceLabelIdStorage { |
||||
|
|
||||
|
public: |
||||
|
uint64_t registerLabel(std::string const& label) { |
||||
|
auto it = std::find(storage.begin(), storage.end(), label); |
||||
|
if (it == storage.end()) { |
||||
|
storage.push_back(label); |
||||
|
return storage.size() - 1; |
||||
|
} else { |
||||
|
return it - storage.begin(); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
|
||||
|
private: |
||||
|
std::vector<std::string> storage = {""}; |
||||
|
|
||||
|
}; |
||||
|
} |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> MakePOMDPCanonic<ValueType>::transform() const { |
||||
|
STORM_LOG_THROW(pomdp.hasChoiceOrigins(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice origins"); |
||||
|
STORM_LOG_THROW(pomdp.hasChoiceLabeling(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice labels"); |
||||
|
std::vector<uint64_t> permutation = computeCanonicalPermutation(); |
||||
|
return applyPermutationOnPomdp(permutation); |
||||
|
} |
||||
|
|
||||
|
|
||||
|
template<typename ValueType> |
||||
|
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> MakePOMDPCanonic<ValueType>::applyPermutationOnPomdp(std::vector<uint64_t> permutation) const { |
||||
|
|
||||
|
auto rewardModels = pomdp.getRewardModels(); |
||||
|
std::unordered_map<std::string, storm::models::sparse::StandardRewardModel<ValueType>> newRewardModels; |
||||
|
for (auto& rewardModelNameAndModel : rewardModels) { |
||||
|
newRewardModels.emplace(rewardModelNameAndModel.first, rewardModelNameAndModel.second.permuteActions(permutation)); |
||||
|
} |
||||
|
storm::storage::sparse::ModelComponents<ValueType> modelcomponents(pomdp.getTransitionMatrix().permuteRows(permutation), |
||||
|
pomdp.getStateLabeling(), |
||||
|
newRewardModels, |
||||
|
false, boost::none); |
||||
|
modelcomponents.observabilityClasses = pomdp.getObservations(); |
||||
|
modelcomponents.choiceLabeling = pomdp.getChoiceLabeling(); |
||||
|
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(modelcomponents); |
||||
|
} |
||||
|
|
||||
|
|
||||
|
template<typename ValueType> |
||||
|
std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation() const { |
||||
|
std::map<uint32_t, std::vector<detail::ActionIdentifier>> observationActionIdentifiers; |
||||
|
std::map<uint32_t, uint64_t> actionIdentifierDefinition; |
||||
|
auto const& choiceLabeling = pomdp.getChoiceLabeling(); |
||||
|
detail::ChoiceLabelIdStorage labelStorage; |
||||
|
|
||||
|
std::vector<uint64_t> permutation; |
||||
|
uint64_t nrObservations = pomdp.getNrObservations(); |
||||
|
storm::storage::BitVector oneActionObservations(nrObservations); |
||||
|
storm::storage::BitVector moreActionObservations(nrObservations); |
||||
|
|
||||
|
|
||||
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
||||
|
uint64_t rowIndexFrom = pomdp.getTransitionMatrix().getRowGroupIndices()[state]; |
||||
|
uint64_t rowIndexTo = pomdp.getTransitionMatrix().getRowGroupIndices()[state+1]; |
||||
|
|
||||
|
uint64_t observation = pomdp.getObservation(state); |
||||
|
if (rowIndexFrom + 1 == rowIndexTo) { |
||||
|
permutation.push_back(rowIndexFrom); |
||||
|
if (moreActionObservations.get(observation)) { |
||||
|
// We have seen this observation previously with multiple actions. Error!
|
||||
|
// TODO provide more diagnostic information
|
||||
|
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides multiple action, but in state " << state << " provides one action."); |
||||
|
} |
||||
|
oneActionObservations.set(observation); |
||||
|
|
||||
|
// One action is ALWAYS fine.
|
||||
|
continue; |
||||
|
} else { |
||||
|
if (oneActionObservations.get(observation)) { |
||||
|
// We have seen this observation previously with one action. Error!
|
||||
|
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides one action, but in state " << state << " provides multiple actions."); |
||||
|
} |
||||
|
moreActionObservations.set(observation); |
||||
|
} |
||||
|
|
||||
|
std::map<detail::ActionIdentifier, uint64_t> actionIdentifiers; |
||||
|
std::set<uint64_t> actionLabels; |
||||
|
for (uint64_t actionIndex = rowIndexFrom; actionIndex < rowIndexTo; ++actionIndex) { |
||||
|
// While this is in full generality a set of labels,
|
||||
|
// for models modelled with prism, we actually know that these are singleton sets.
|
||||
|
std::set<std::string> labels = choiceLabeling.getLabelsOfChoice(actionIndex); |
||||
|
STORM_LOG_ASSERT(labels.size() <= 1, "We expect choice labels to be single sets"); |
||||
|
// Generate action identifier
|
||||
|
uint64_t labelId = -1; |
||||
|
if (labels.size() == 1) { |
||||
|
labelId = labelStorage.registerLabel(*labels.begin()); |
||||
|
STORM_LOG_THROW(actionLabels.count(labelId) == 0, storm::exceptions::AmbiguousModelException, "Multiple actions with label '" << *labels.begin() << "' exist in state id " << state << "."); |
||||
|
actionLabels.emplace(labelId); |
||||
|
} else { |
||||
|
labelId = labelStorage.registerLabel(""); |
||||
|
} |
||||
|
|
||||
|
detail::ActionIdentifier ai; |
||||
|
ai.choiceLabelId = labelId; |
||||
|
ai.choiceOriginId = pomdp.getChoiceOrigins()->getIdentifier(actionIndex); |
||||
|
actionIdentifiers.emplace(ai,actionIndex); |
||||
|
} |
||||
|
|
||||
|
if (observationActionIdentifiers.count(observation) == 0) { |
||||
|
// First state with this observation
|
||||
|
// store the corresponding vector.
|
||||
|
std::vector<detail::ActionIdentifier> ais; |
||||
|
for (auto const& als : actionIdentifiers) { |
||||
|
ais.push_back(als.first); |
||||
|
} |
||||
|
observationActionIdentifiers.emplace(observation, ais); |
||||
|
actionIdentifierDefinition.emplace(observation, state); |
||||
|
} else { |
||||
|
auto referenceStart = observationActionIdentifiers[observation].begin(); |
||||
|
auto referenceEnd = observationActionIdentifiers[observation].end(); |
||||
|
if (!detail::compatibleWith(referenceStart, referenceEnd, actionIdentifiers.begin(), actionIdentifiers.end())) { |
||||
|
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Actions identifiers do not align between states '" << state << "' and '" << actionIdentifierDefinition[observation] << "', both having observation " << observation << "."); |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
for (auto const& al : actionIdentifiers) { |
||||
|
permutation.push_back(al.second); |
||||
|
} |
||||
|
} |
||||
|
return permutation; |
||||
|
|
||||
|
|
||||
|
|
||||
|
} |
||||
|
|
||||
|
|
||||
|
template class MakePOMDPCanonic<double>; |
||||
|
template class MakePOMDPCanonic<storm::RationalNumber>; |
||||
|
} |
||||
|
} |
@ -0,0 +1,24 @@ |
|||||
|
#pragma once |
||||
|
|
||||
|
#include "storm/models/sparse/Pomdp.h" |
||||
|
|
||||
|
namespace storm { |
||||
|
namespace transformer { |
||||
|
|
||||
|
template<typename ValueType> |
||||
|
class MakePOMDPCanonic { |
||||
|
|
||||
|
public: |
||||
|
MakePOMDPCanonic(storm::models::sparse::Pomdp<ValueType> const& pomdp) : pomdp(pomdp) { |
||||
|
|
||||
|
} |
||||
|
|
||||
|
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> transform() const; |
||||
|
protected: |
||||
|
std::vector<uint64_t> computeCanonicalPermutation() const; |
||||
|
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> applyPermutationOnPomdp(std::vector<uint64_t> permutation) const; |
||||
|
|
||||
|
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
||||
|
}; |
||||
|
} |
||||
|
} |
@ -0,0 +1,12 @@ |
|||||
|
#pragma once |
||||
|
|
||||
|
#include "storm/exceptions/BaseException.h" |
||||
|
#include "storm/exceptions/ExceptionMacros.h" |
||||
|
|
||||
|
namespace storm { |
||||
|
namespace exceptions { |
||||
|
|
||||
|
STORM_NEW_EXCEPTION(AmbiguousModelException) |
||||
|
|
||||
|
} // namespace exceptions |
||||
|
} // namespace storm |
Write
Preview
Loading…
Cancel
Save
Reference in new issue