Browse Source

various diagnostic informations to explain why we reject a POMDP

tempestpy_adaptions
Sebastian Junges 5 years ago
parent
commit
1fe1bd4dea
  1. 70
      src/storm-pomdp/transformer/MakePOMDPCanonic.cpp
  2. 1
      src/storm-pomdp/transformer/MakePOMDPCanonic.h

70
src/storm-pomdp/transformer/MakePOMDPCanonic.cpp

@ -31,6 +31,8 @@ namespace storm {
}
friend bool operator<(ActionIdentifier const& lhs, ActionIdentifier const& rhs);
};
template<typename iterator1, typename iterator2>
@ -68,11 +70,49 @@ namespace storm {
}
}
std::string const& getLabel(uint64_t id) const {
STORM_LOG_ASSERT(id < storage.size(), "Id must be in storage");
return storage[id];
}
friend std::ostream& operator<<(std::ostream& os, ChoiceLabelIdStorage const& labelStorage);
private:
std::vector<std::string> storage = {""};
};
std::ostream& operator<<(std::ostream& os, ChoiceLabelIdStorage const& labelStorage) {
os << "LabelStorage: {";
uint64_t i = 0;
for (auto const& entry : labelStorage.storage) {
os << i << " -> " << entry << " ;";
++i;
}
os << "}";
return os;
}
void actionIdentifiersToStream(std::ostream& stream, std::vector<ActionIdentifier> const& actionIdentifiers, ChoiceLabelIdStorage const& labelStorage) {
stream << "actions: {";
for (auto ai : actionIdentifiers) {
stream << "[" << ai.choiceLabelId << " (" << labelStorage.getLabel(ai.choiceLabelId) << ")";
stream << ", " << ai.choiceOriginId << "]";
}
stream << "}";
}
template <typename IrrelevantType>
void actionIdentifiersToStream(std::ostream& stream, std::map<ActionIdentifier, IrrelevantType> const& actionIdentifiers, ChoiceLabelIdStorage const& labelStorage) {
stream << "actions: {";
for (auto ai : actionIdentifiers) {
stream << "[" << ai.first.choiceLabelId << "('" << labelStorage.getLabel(ai.first.choiceLabelId) << "')";
stream << ", " << ai.first.choiceOriginId << "]";
}
stream << "}";
}
}
template<typename ValueType>
@ -97,11 +137,20 @@ namespace storm {
newRewardModels,
false, boost::none);
modelcomponents.observabilityClasses = pomdp.getObservations();
modelcomponents.choiceLabeling = pomdp.getChoiceLabeling();
//modelcomponents.choiceLabeling = pomdp.getChoiceLabeling();
return std::make_shared<storm::models::sparse::Pomdp<ValueType>>(modelcomponents);
}
template<typename ValueType>
std::string MakePOMDPCanonic<ValueType>::getStateInformation(uint64_t state) const {
if(pomdp.hasStateValuations()) {
return std::to_string(state) + "[" + pomdp.getStateValuations().getStateInfo(state) + "]";
} else {
return std::to_string(state);
}
}
template<typename ValueType>
std::vector<uint64_t> MakePOMDPCanonic<ValueType>::computeCanonicalPermutation() const {
std::map<uint32_t, std::vector<detail::ActionIdentifier>> observationActionIdentifiers;
@ -125,7 +174,7 @@ namespace storm {
if (moreActionObservations.get(observation)) {
// We have seen this observation previously with multiple actions. Error!
// TODO provide more diagnostic information
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides multiple action, but in state " << state << " provides one action.");
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Observation " << observation << " sometimes provides multiple actions, but in state " << state << " provides one action.");
}
oneActionObservations.set(observation);
@ -159,8 +208,10 @@ namespace storm {
detail::ActionIdentifier ai;
ai.choiceLabelId = labelId;
ai.choiceOriginId = pomdp.getChoiceOrigins()->getIdentifier(actionIndex);
STORM_LOG_ASSERT(actionIdentifiers.count(ai) == 0, "Action with this identifier already exists for this state");
actionIdentifiers.emplace(ai,actionIndex);
}
STORM_LOG_ASSERT(actionIdentifiers.size() == rowIndexTo - rowIndexFrom, "Number of action identifiers should match number of actions");
if (observationActionIdentifiers.count(observation) == 0) {
// First state with this observation
@ -169,13 +220,26 @@ namespace storm {
for (auto const& als : actionIdentifiers) {
ais.push_back(als.first);
}
observationActionIdentifiers.emplace(observation, ais);
actionIdentifierDefinition.emplace(observation, state);
} else {
auto referenceStart = observationActionIdentifiers[observation].begin();
auto referenceEnd = observationActionIdentifiers[observation].end();
STORM_LOG_ASSERT(observationActionIdentifiers[observation].size() == pomdp.getNumberOfChoices(actionIdentifierDefinition[observation]), "Number of actions recorded for state does not coinide with number of actions.");
if (observationActionIdentifiers[observation].size() != actionIdentifiers.size()) {
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Number of actions in state '" << getStateInformation(state) << "' (nr actions:" << actionIdentifiers.size() << ") and state '" << getStateInformation(actionIdentifierDefinition[observation]) << "' (actions: "<< observationActionIdentifiers[observation].size() << " ), both having observation " << observation << " do not match." );
}
if (!detail::compatibleWith(referenceStart, referenceEnd, actionIdentifiers.begin(), actionIdentifiers.end())) {
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Actions identifiers do not align between states '" << state << "' and '" << actionIdentifierDefinition[observation] << "', both having observation " << observation << ".");
std::cout << "Observation " << observation << ": ";
detail::actionIdentifiersToStream(std::cout, observationActionIdentifiers[observation], labelStorage);
std::cout << " according to state " << actionIdentifierDefinition[observation] << "." << std::endl;
std::cout << "Observation " << observation << ": ";
detail::actionIdentifiersToStream(std::cout, actionIdentifiers, labelStorage);
std::cout << " according to state " << state << "." << std::endl;
STORM_LOG_THROW(false, storm::exceptions::AmbiguousModelException, "Actions identifiers do not align between states '" << getStateInformation(state) << "' and '" << getStateInformation(actionIdentifierDefinition[observation]) << "', both having observation " << observation << ". See output above for more information.");
}
}

1
src/storm-pomdp/transformer/MakePOMDPCanonic.h

@ -17,6 +17,7 @@ namespace storm {
protected:
std::vector<uint64_t> computeCanonicalPermutation() const;
std::shared_ptr<storm::models::sparse::Pomdp<ValueType>> applyPermutationOnPomdp(std::vector<uint64_t> permutation) const;
std::string getStateInformation(uint64_t state) const;
storm::models::sparse::Pomdp<ValueType> const& pomdp;
};

Loading…
Cancel
Save