Sebastian Junges
5 years ago
4 changed files with 235 additions and 0 deletions
-
3src/storm-pomdp-cli/storm-pomdp.cpp
-
196src/storm-pomdp/transformer/MakePOMDPCanonic.cpp
-
24src/storm-pomdp/transformer/MakePOMDPCanonic.h
-
12src/storm/exceptions/AmbiguousModelException.h
@ -0,0 +1,196 @@ |
|||
#include "storm-pomdp/transformer/MakePOMDPCanonic.h"
|
|||
#include "storm/storage/sparse/ModelComponents.h"
|
|||
#include "storm/exceptions/AmbiguousModelException.h"
|
|||
|
|||
#include "storm/exceptions/InvalidArgumentException.h"
|
|||
|
|||
|
|||
namespace storm { |
|||
namespace transformer { |
|||
|
|||
namespace detail { |
|||
struct ActionIdentifier { |
|||
uint64_t choiceLabelId; |
|||
uint64_t choiceOriginId; |
|||
|
|||
bool compatibleWith(ActionIdentifier const& other) const { |
|||
if (choiceLabelId != other.choiceLabelId) { |
|||
return false; // different labels.
|
|||
} |
|||
if (choiceLabelId > 0) { |
|||
// Notice that we assume that we already have ensured that names
|
|||
// are not used more than once.
|
|||
return true; // actions have a name, name coincides.
|
|||
} else { |
|||
// action is unnamed.
|
|||
// We only call this method (at least we only should call this method)
|
|||
// if there are multiple actions. Then two tau actions are only compatible
|
|||
// if they are described by the same choice origin.
|
|||
return choiceOriginId == other.choiceOriginId; |
|||
} |
|||
} |
|||
|
|||
friend bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs); |
|||
}; |
|||
|
|||
template<typename iterator1, typename iterator2> |
|||
bool compatibleWith(iterator1 start1, iterator1 end1, iterator2 start2, iterator2 end2) { |
|||
iterator1 it1 = start1; |
|||
iterator2 it2 = start2; |
|||
while (it1 != end1 && it2 != end2) { |
|||
if (!it1->compatibleWith(it2->first)) { |
|||
return false; |
|||
} |
|||
++it1; |
|||
++it2; |
|||
} |
|||
return it1 == end1 && it2 == end2; |
|||
}; |
|||
|
|||
|
|||
bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs) { |
|||
if (lhs.choiceLabelId == rhs.choiceLabelId) { |
|||
return lhs.choiceOriginId < rhs.choiceOriginId; |
|||
} |
|||
return lhs.choiceLabelId < rhs.choiceLabelId; |
|||
} |
|||
|
|||
class ChoiceLabelIdStorage { |
|||
|
|||
public: |
|||
uint64_t registerLabel(std::string const& label) { |
|||
auto it = std::find(storage.begin(), storage.end(), label); |
|||
if (it == storage.end()) { |
|||
storage.push_back(label); |
|||
return storage.size() - 1; |
|||
} else { |
|||
return it - storage.begin(); |
|||
} |
|||
} |
|||
|
|||
|
|||
private: |
|||
std::vector<std::string> storage = {""}; |
|||
|
|||
}; |
|||
} |
|||
|
|||
template<typename ValueType> |
|||
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> MakePOMDPCanonic<ValueType>::transform() const { |
|||
STORM_LOG_THROW(pomdp.hasChoiceOrigins(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice origins"); |
|||
STORM_LOG_THROW(pomdp.hasChoiceLabeling(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice labels"); |
|||
std::vector<uint64_t> permutation = computeCanonicalPermutation(); |
|||
return applyPermutationOnPomdp(permutation); |
|||
} |
|||
|
|||
|
|||
template<typename ValueType> |
|||
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> MakePOMDPCanonic<ValueType>::applyPermutationOnPomdp(std::vector<uint64_t> permutation) const { |
|||
|
|||
auto rewardModels = pomdp.getRewardModels(); |
|||
std::unordered_map<std::string, storm::models::sparse::StandardRewardModel<ValueType>> newRewardModels; |
|||
for (auto& rewardModelNameAndModel : rewardModels) { |
|||
newRewardModels.emplace(rewardModelNameAndModel.first, rewardModelNameAndModel.second.permuteActions(permutation)); |
|||
} |
|||
storm::storage::sparse::ModelComponents<ValueType> modelcomponents(pomdp.getTransitionMatrix().permuteRows(permutation), |
|||
pomdp.getStateLabeling(), |
|||
newRewardModels, |
|||
false, boost::none); |
|||
modelcomponents.observabilityClasses = pomdp.getObservations(); |
|||
modelcomponents.choiceLabeling = pomdp.getChoiceLabeling(); |
|||
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(modelcomponents); |
|||
} |
|||
|
|||
|
|||
template<typename ValueType> |
|||
std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation() const { |
|||
std::map<uint32_t, std::vector<detail::ActionIdentifier>> observationActionIdentifiers; |
|||
std::map<uint32_t, uint64_t> actionIdentifierDefinition; |
|||
auto const& choiceLabeling = pomdp.getChoiceLabeling(); |
|||
detail::ChoiceLabelIdStorage labelStorage; |
|||
|
|||
std::vector<uint64_t> permutation; |
|||
uint64_t nrObservations = pomdp.getNrObservations(); |
|||
storm::storage::BitVector oneActionObservations(nrObservations); |
|||
storm::storage::BitVector moreActionObservations(nrObservations); |
|||
|
|||
|
|||
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|||
uint64_t rowIndexFrom = pomdp.getTransitionMatrix().getRowGroupIndices()[state]; |
|||
uint64_t rowIndexTo = pomdp.getTransitionMatrix().getRowGroupIndices()[state+1]; |
|||
|
|||
uint64_t observation = pomdp.getObservation(state); |
|||
if (rowIndexFrom + 1 == rowIndexTo) { |
|||
permutation.push_back(rowIndexFrom); |
|||
if (moreActionObservations.get(observation)) { |
|||
// We have seen this observation previously with multiple actions. Error!
|
|||
// TODO provide more diagnostic information
|
|||
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides multiple action, but in state " << state << " provides one action."); |
|||
} |
|||
oneActionObservations.set(observation); |
|||
|
|||
// One action is ALWAYS fine.
|
|||
continue; |
|||
} else { |
|||
if (oneActionObservations.get(observation)) { |
|||
// We have seen this observation previously with one action. Error!
|
|||
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides one action, but in state " << state << " provides multiple actions."); |
|||
} |
|||
moreActionObservations.set(observation); |
|||
} |
|||
|
|||
std::map<detail::ActionIdentifier, uint64_t> actionIdentifiers; |
|||
std::set<uint64_t> actionLabels; |
|||
for (uint64_t actionIndex = rowIndexFrom; actionIndex < rowIndexTo; ++actionIndex) { |
|||
// While this is in full generality a set of labels,
|
|||
// for models modelled with prism, we actually know that these are singleton sets.
|
|||
std::set<std::string> labels = choiceLabeling.getLabelsOfChoice(actionIndex); |
|||
STORM_LOG_ASSERT(labels.size() <= 1, "We expect choice labels to be single sets"); |
|||
// Generate action identifier
|
|||
uint64_t labelId = -1; |
|||
if (labels.size() == 1) { |
|||
labelId = labelStorage.registerLabel(*labels.begin()); |
|||
STORM_LOG_THROW(actionLabels.count(labelId) == 0, storm::exceptions::AmbiguousModelException, "Multiple actions with label '" << *labels.begin() << "' exist in state id " << state << "."); |
|||
actionLabels.emplace(labelId); |
|||
} else { |
|||
labelId = labelStorage.registerLabel(""); |
|||
} |
|||
|
|||
detail::ActionIdentifier ai; |
|||
ai.choiceLabelId = labelId; |
|||
ai.choiceOriginId = pomdp.getChoiceOrigins()->getIdentifier(actionIndex); |
|||
actionIdentifiers.emplace(ai,actionIndex); |
|||
} |
|||
|
|||
if (observationActionIdentifiers.count(observation) == 0) { |
|||
// First state with this observation
|
|||
// store the corresponding vector.
|
|||
std::vector<detail::ActionIdentifier> ais; |
|||
for (auto const& als : actionIdentifiers) { |
|||
ais.push_back(als.first); |
|||
} |
|||
observationActionIdentifiers.emplace(observation, ais); |
|||
actionIdentifierDefinition.emplace(observation, state); |
|||
} else { |
|||
auto referenceStart = observationActionIdentifiers[observation].begin(); |
|||
auto referenceEnd = observationActionIdentifiers[observation].end(); |
|||
if (!detail::compatibleWith(referenceStart, referenceEnd, actionIdentifiers.begin(), actionIdentifiers.end())) { |
|||
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Actions identifiers do not align between states '" << state << "' and '" << actionIdentifierDefinition[observation] << "', both having observation " << observation << "."); |
|||
} |
|||
} |
|||
|
|||
for (auto const& al : actionIdentifiers) { |
|||
permutation.push_back(al.second); |
|||
} |
|||
} |
|||
return permutation; |
|||
|
|||
|
|||
|
|||
} |
|||
|
|||
|
|||
template class MakePOMDPCanonic<double>; |
|||
template class MakePOMDPCanonic<storm::RationalNumber>; |
|||
} |
|||
} |
@ -0,0 +1,24 @@ |
|||
#pragma once |
|||
|
|||
#include "storm/models/sparse/Pomdp.h" |
|||
|
|||
namespace storm { |
|||
namespace transformer { |
|||
|
|||
template<typename ValueType> |
|||
class MakePOMDPCanonic { |
|||
|
|||
public: |
|||
MakePOMDPCanonic(storm::models::sparse::Pomdp<ValueType> const& pomdp) : pomdp(pomdp) { |
|||
|
|||
} |
|||
|
|||
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> transform() const; |
|||
protected: |
|||
std::vector<uint64_t> computeCanonicalPermutation() const; |
|||
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> applyPermutationOnPomdp(std::vector<uint64_t> permutation) const; |
|||
|
|||
storm::models::sparse::Pomdp<ValueType> const& pomdp; |
|||
}; |
|||
} |
|||
} |
@ -0,0 +1,12 @@ |
|||
#pragma once |
|||
|
|||
#include "storm/exceptions/BaseException.h" |
|||
#include "storm/exceptions/ExceptionMacros.h" |
|||
|
|||
namespace storm { |
|||
namespace exceptions { |
|||
|
|||
STORM_NEW_EXCEPTION(AmbiguousModelException) |
|||
|
|||
} // namespace exceptions |
|||
} // namespace storm |
Write
Preview
Loading…
Cancel
Save
Reference in new issue