Browse Source

allow inference from actions, but only with a set option

tempestpy_adaptions
sjunges 8 years ago
parent
commit
9314d99354
  1. 6
      src/storm/builder/BuilderOptions.cpp
  2. 11
      src/storm/builder/BuilderOptions.h
  3. 42
      src/storm/builder/ExplicitModelBuilder.cpp
  4. 2
      src/storm/generator/NextStateGenerator.cpp

6
src/storm/builder/BuilderOptions.cpp

@ -35,7 +35,7 @@ namespace storm {
return boost::get<storm::expressions::Expression>(labelOrExpression); return boost::get<storm::expressions::Expression>(labelOrExpression);
} }
BuilderOptions::BuilderOptions(bool buildAllRewardModels, bool buildAllLabels) : buildAllRewardModels(buildAllRewardModels), buildAllLabels(buildAllLabels), buildChoiceLabels(false), buildStateValuations(false), buildChoiceOrigins(false), explorationChecks(false), explorationShowProgress(false), explorationShowProgressDelay(0) {
BuilderOptions::BuilderOptions(bool buildAllRewardModels, bool buildAllLabels) : buildAllRewardModels(buildAllRewardModels), buildAllLabels(buildAllLabels), buildChoiceLabels(false), buildStateValuations(false), buildChoiceOrigins(false), explorationChecks(false), explorationShowProgress(false), inferObservationsFromActions(false), explorationShowProgressDelay(0) {
// Intentionally left empty. // Intentionally left empty.
} }
@ -156,6 +156,10 @@ namespace storm {
bool BuilderOptions::isBuildAllLabelsSet() const { bool BuilderOptions::isBuildAllLabelsSet() const {
return buildAllLabels; return buildAllLabels;
} }
bool BuilderOptions::isInferObservationsFromActionsSet() const {
return inferObservationsFromActions;
}
BuilderOptions& BuilderOptions::setBuildAllRewardModels(bool newValue) { BuilderOptions& BuilderOptions::setBuildAllRewardModels(bool newValue) {
buildAllRewardModels = newValue; buildAllRewardModels = newValue;

11
src/storm/builder/BuilderOptions.h

@ -82,6 +82,7 @@ namespace storm {
*/ */
void setTerminalStatesFromFormula(storm::logic::Formula const& formula); void setTerminalStatesFromFormula(storm::logic::Formula const& formula);
/*! /*!
* Which reward models are built * Which reward models are built
* @return * @return
@ -107,6 +108,7 @@ namespace storm {
bool isBuildAllLabelsSet() const; bool isBuildAllLabelsSet() const;
bool isExplorationChecksSet() const; bool isExplorationChecksSet() const;
bool isExplorationShowProgressSet() const; bool isExplorationShowProgressSet() const;
bool isInferObservationsFromActionsSet() const;
uint64_t getExplorationShowProgressDelay() const; uint64_t getExplorationShowProgressDelay() const;
/** /**
@ -155,7 +157,11 @@ namespace storm {
* @return this * @return this
*/ */
BuilderOptions& setExplorationChecks(bool newValue = true); BuilderOptions& setExplorationChecks(bool newValue = true);
BuilderOptions& setInferObservationsFromActions(bool newValue = true);
private: private:
/// A flag that indicates whether all reward models are to be built. In this case, the reward model names are /// A flag that indicates whether all reward models are to be built. In this case, the reward model names are
/// to be ignored. /// to be ignored.
@ -191,6 +197,9 @@ namespace storm {
/// A flag that stores whether the progress of exploration is to be printed. /// A flag that stores whether the progress of exploration is to be printed.
bool explorationShowProgress; bool explorationShowProgress;
/// For POMDPs, should we allow inference of observation classes from different enabled actions.
bool inferObservationsFromActions;
/// The delay for printing progress information. /// The delay for printing progress information.
uint64_t explorationShowProgressDelay; uint64_t explorationShowProgressDelay;

42
src/storm/builder/ExplicitModelBuilder.cpp

@ -330,9 +330,49 @@ namespace storm {
} }
if (generator->isPartiallyObservable()) { if (generator->isPartiallyObservable()) {
std::vector<uint32_t> classes; std::vector<uint32_t> classes;
uint32_t newObservation = 0;
classes.resize(stateStorage.getNumberOfStates()); classes.resize(stateStorage.getNumberOfStates());
std::unordered_map<uint32_t, std::vector<std::pair<std::vector<std::string>, uint32_t>>> observationActions;
for (auto const& bitVectorIndexPair : stateStorage.stateToId) { for (auto const& bitVectorIndexPair : stateStorage.stateToId) {
classes[bitVectorIndexPair.second] = generator->observabilityClass(bitVectorIndexPair.first);
uint32_t varObservation = generator->observabilityClass(bitVectorIndexPair.first);
uint32_t observation = -1; // Is replaced later on.
bool foundActionSet = false;
std::vector<std::string> actionNames;
bool addedAnonymousAction = false;
for (uint64 choice = modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second]; choice < modelComponents.transitionMatrix.getRowGroupIndices()[bitVectorIndexPair.second+1]; ++choice) {
if (modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).empty()) {
STORM_LOG_THROW(!addedAnonymousAction, storm::exceptions::WrongFormatException, "Cannot have multiple anonymous actions, as these cannot be mapped correctly.");
actionNames.push_back("");
addedAnonymousAction = true;
} else {
STORM_LOG_ASSERT(modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size() == 1, "Expect choice labelling to contain exactly one label at this point, but found " << modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).size());
actionNames.push_back(*modelComponents.choiceLabeling.get().getLabelsOfChoice(choice).begin());
}
}
STORM_LOG_TRACE("VarObservation: " << varObservation << " Action Names: " << storm::utility::vector::toString(actionNames));
auto it = observationActions.find(varObservation);
if (it == observationActions.end()) {
observationActions.emplace(varObservation, std::vector<std::pair<std::vector<std::string>, uint32_t>>());
} else {
for(auto const& entries : it->second) {
STORM_LOG_TRACE(storm::utility::vector::toString(entries.first));
if (entries.first == actionNames) {
observation = entries.second;
foundActionSet = true;
break;
}
}
STORM_LOG_THROW(generator->getOptions().isInferObservationsFromActionsSet() || foundActionSet, storm::exceptions::WrongFormatException, "Two states with the same observation have a different set of enabled actions, this is only allowed with a special option.");
}
if (!foundActionSet) {
observation = newObservation;
observationActions.find(varObservation)->second.emplace_back(actionNames, newObservation);
++newObservation;
}
classes[bitVectorIndexPair.second] = observation;
} }
modelComponents.observabilityClasses = classes; modelComponents.observabilityClasses = classes;
} }

2
src/storm/generator/NextStateGenerator.cpp

@ -158,9 +158,7 @@ namespace storm {
uint32_t NextStateGenerator<ValueType, StateType>::observabilityClass(CompressedState const &state) const { uint32_t NextStateGenerator<ValueType, StateType>::observabilityClass(CompressedState const &state) const {
if (this->mask.size() == 0) { if (this->mask.size() == 0) {
this->mask = computeObservabilityMask(variableInformation); this->mask = computeObservabilityMask(variableInformation);
std::cout << mask.size() << std::endl;
} }
std::cout << state.size() << std::endl;
return unpackStateToObservabilityClass(state, observabilityMap, mask); return unpackStateToObservabilityClass(state, observabilityMap, mask);
} }

Loading…
Cancel
Save