|
|
@ -0,0 +1,196 @@ |
|
|
|
#include "storm-pomdp/transformer/MakePOMDPCanonic.h"
|
|
|
|
#include "storm/storage/sparse/ModelComponents.h"
|
|
|
|
#include "storm/exceptions/AmbiguousModelException.h"
|
|
|
|
|
|
|
|
#include "storm/exceptions/InvalidArgumentException.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace storm { |
|
|
|
namespace transformer { |
|
|
|
|
|
|
|
namespace detail { |
|
|
|
struct ActionIdentifier { |
|
|
|
uint64_t choiceLabelId; |
|
|
|
uint64_t choiceOriginId; |
|
|
|
|
|
|
|
bool compatibleWith(ActionIdentifier const& other) const { |
|
|
|
if (choiceLabelId != other.choiceLabelId) { |
|
|
|
return false; // different labels.
|
|
|
|
} |
|
|
|
if (choiceLabelId > 0) { |
|
|
|
// Notice that we assume that we already have ensured that names
|
|
|
|
// are not used more than once.
|
|
|
|
return true; // actions have a name, name coincides.
|
|
|
|
} else { |
|
|
|
// action is unnamed.
|
|
|
|
// We only call this method (at least we only should call this method)
|
|
|
|
// if there are multiple actions. Then two tau actions are only compatible
|
|
|
|
// if they are described by the same choice origin.
|
|
|
|
return choiceOriginId == other.choiceOriginId; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
friend bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs); |
|
|
|
}; |
|
|
|
|
|
|
|
template<typename iterator1, typename iterator2> |
|
|
|
bool compatibleWith(iterator1 start1, iterator1 end1, iterator2 start2, iterator2 end2) { |
|
|
|
iterator1 it1 = start1; |
|
|
|
iterator2 it2 = start2; |
|
|
|
while (it1 != end1 && it2 != end2) { |
|
|
|
if (!it1->compatibleWith(it2->first)) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
++it1; |
|
|
|
++it2; |
|
|
|
} |
|
|
|
return it1 == end1 && it2 == end2; |
|
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs) { |
|
|
|
if (lhs.choiceLabelId == rhs.choiceLabelId) { |
|
|
|
return lhs.choiceOriginId < rhs.choiceOriginId; |
|
|
|
} |
|
|
|
return lhs.choiceLabelId < rhs.choiceLabelId; |
|
|
|
} |
|
|
|
|
|
|
|
class ChoiceLabelIdStorage { |
|
|
|
|
|
|
|
public: |
|
|
|
uint64_t registerLabel(std::string const& label) { |
|
|
|
auto it = std::find(storage.begin(), storage.end(), label); |
|
|
|
if (it == storage.end()) { |
|
|
|
storage.push_back(label); |
|
|
|
return storage.size() - 1; |
|
|
|
} else { |
|
|
|
return it - storage.begin(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
private: |
|
|
|
std::vector<std::string> storage = {""}; |
|
|
|
|
|
|
|
}; |
|
|
|
} |
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> MakePOMDPCanonic<ValueType>::transform() const { |
|
|
|
STORM_LOG_THROW(pomdp.hasChoiceOrigins(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice origins"); |
|
|
|
STORM_LOG_THROW(pomdp.hasChoiceLabeling(), storm::exceptions::InvalidArgumentException, "Model must have been built with choice labels"); |
|
|
|
std::vector<uint64_t> permutation = computeCanonicalPermutation(); |
|
|
|
return applyPermutationOnPomdp(permutation); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> MakePOMDPCanonic<ValueType>::applyPermutationOnPomdp(std::vector<uint64_t> permutation) const { |
|
|
|
|
|
|
|
auto rewardModels = pomdp.getRewardModels(); |
|
|
|
std::unordered_map<std::string, storm::models::sparse::StandardRewardModel<ValueType>> newRewardModels; |
|
|
|
for (auto& rewardModelNameAndModel : rewardModels) { |
|
|
|
newRewardModels.emplace(rewardModelNameAndModel.first, rewardModelNameAndModel.second.permuteActions(permutation)); |
|
|
|
} |
|
|
|
storm::storage::sparse::ModelComponents<ValueType> modelcomponents(pomdp.getTransitionMatrix().permuteRows(permutation), |
|
|
|
pomdp.getStateLabeling(), |
|
|
|
newRewardModels, |
|
|
|
false, boost::none); |
|
|
|
modelcomponents.observabilityClasses = pomdp.getObservations(); |
|
|
|
modelcomponents.choiceLabeling = pomdp.getChoiceLabeling(); |
|
|
|
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(modelcomponents); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template<typename ValueType> |
|
|
|
std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation() const { |
|
|
|
std::map<uint32_t, std::vector<detail::ActionIdentifier>> observationActionIdentifiers; |
|
|
|
std::map<uint32_t, uint64_t> actionIdentifierDefinition; |
|
|
|
auto const& choiceLabeling = pomdp.getChoiceLabeling(); |
|
|
|
detail::ChoiceLabelIdStorage labelStorage; |
|
|
|
|
|
|
|
std::vector<uint64_t> permutation; |
|
|
|
uint64_t nrObservations = pomdp.getNrObservations(); |
|
|
|
storm::storage::BitVector oneActionObservations(nrObservations); |
|
|
|
storm::storage::BitVector moreActionObservations(nrObservations); |
|
|
|
|
|
|
|
|
|
|
|
for (uint64_t state = 0; state < pomdp.getNumberOfStates(); ++state) { |
|
|
|
uint64_t rowIndexFrom = pomdp.getTransitionMatrix().getRowGroupIndices()[state]; |
|
|
|
uint64_t rowIndexTo = pomdp.getTransitionMatrix().getRowGroupIndices()[state+1]; |
|
|
|
|
|
|
|
uint64_t observation = pomdp.getObservation(state); |
|
|
|
if (rowIndexFrom + 1 == rowIndexTo) { |
|
|
|
permutation.push_back(rowIndexFrom); |
|
|
|
if (moreActionObservations.get(observation)) { |
|
|
|
// We have seen this observation previously with multiple actions. Error!
|
|
|
|
// TODO provide more diagnostic information
|
|
|
|
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides multiple action, but in state " << state << " provides one action."); |
|
|
|
} |
|
|
|
oneActionObservations.set(observation); |
|
|
|
|
|
|
|
// One action is ALWAYS fine.
|
|
|
|
continue; |
|
|
|
} else { |
|
|
|
if (oneActionObservations.get(observation)) { |
|
|
|
// We have seen this observation previously with one action. Error!
|
|
|
|
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides one action, but in state " << state << " provides multiple actions."); |
|
|
|
} |
|
|
|
moreActionObservations.set(observation); |
|
|
|
} |
|
|
|
|
|
|
|
std::map<detail::ActionIdentifier, uint64_t> actionIdentifiers; |
|
|
|
std::set<uint64_t> actionLabels; |
|
|
|
for (uint64_t actionIndex = rowIndexFrom; actionIndex < rowIndexTo; ++actionIndex) { |
|
|
|
// While this is in full generality a set of labels,
|
|
|
|
// for models modelled with prism, we actually know that these are singleton sets.
|
|
|
|
std::set<std::string> labels = choiceLabeling.getLabelsOfChoice(actionIndex); |
|
|
|
STORM_LOG_ASSERT(labels.size() <= 1, "We expect choice labels to be single sets"); |
|
|
|
// Generate action identifier
|
|
|
|
uint64_t labelId = -1; |
|
|
|
if (labels.size() == 1) { |
|
|
|
labelId = labelStorage.registerLabel(*labels.begin()); |
|
|
|
STORM_LOG_THROW(actionLabels.count(labelId) == 0, storm::exceptions::AmbiguousModelException, "Multiple actions with label '" << *labels.begin() << "' exist in state id " << state << "."); |
|
|
|
actionLabels.emplace(labelId); |
|
|
|
} else { |
|
|
|
labelId = labelStorage.registerLabel(""); |
|
|
|
} |
|
|
|
|
|
|
|
detail::ActionIdentifier ai; |
|
|
|
ai.choiceLabelId = labelId; |
|
|
|
ai.choiceOriginId = pomdp.getChoiceOrigins()->getIdentifier(actionIndex); |
|
|
|
actionIdentifiers.emplace(ai,actionIndex); |
|
|
|
} |
|
|
|
|
|
|
|
if (observationActionIdentifiers.count(observation) == 0) { |
|
|
|
// First state with this observation
|
|
|
|
// store the corresponding vector.
|
|
|
|
std::vector<detail::ActionIdentifier> ais; |
|
|
|
for (auto const& als : actionIdentifiers) { |
|
|
|
ais.push_back(als.first); |
|
|
|
} |
|
|
|
observationActionIdentifiers.emplace(observation, ais); |
|
|
|
actionIdentifierDefinition.emplace(observation, state); |
|
|
|
} else { |
|
|
|
auto referenceStart = observationActionIdentifiers[observation].begin(); |
|
|
|
auto referenceEnd = observationActionIdentifiers[observation].end(); |
|
|
|
if (!detail::compatibleWith(referenceStart, referenceEnd, actionIdentifiers.begin(), actionIdentifiers.end())) { |
|
|
|
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Actions identifiers do not align between states '" << state << "' and '" << actionIdentifierDefinition[observation] << "', both having observation " << observation << "."); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (auto const& al : actionIdentifiers) { |
|
|
|
permutation.push_back(al.second); |
|
|
|
} |
|
|
|
} |
|
|
|
return permutation; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
template class MakePOMDPCanonic<double>; |
|
|
|
template class MakePOMDPCanonic<storm::RationalNumber>; |
|
|
|
} |
|
|
|
} |
xxxxxxxxxx