From c3847d05afd0bce9aae4d80047a1a1543846f263 Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Tue, 7 Apr 2020 06:37:01 +0200 Subject: [PATCH] Scaling the rating of an observation with the current resolution. --- .../ApproximatePOMDPModelchecker.cpp | 18 +++++++++++------- .../ApproximatePOMDPModelchecker.h | 4 ++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp index 4526607d4..2936d9b40 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp @@ -264,21 +264,23 @@ namespace storm { * Here, 0 means a bad approximation and 1 means a good approximation. */ template - typename ApproximatePOMDPModelchecker::ValueType ApproximatePOMDPModelchecker::rateObservation(typename ExplorerType::SuccessorObservationInformation const& info) { + typename ApproximatePOMDPModelchecker::ValueType ApproximatePOMDPModelchecker::rateObservation(typename ExplorerType::SuccessorObservationInformation const& info, uint64_t const& observationResolution, uint64_t const& maxResolution) { auto n = storm::utility::convertNumber(info.successorWithObsCount); auto one = storm::utility::one(); - // Create the actual rating for this observation at this choice from the given info + // Create the rating for this observation at this choice from the given info ValueType obsChoiceRating = info.maxProbabilityToSuccessorWithObs / info.observationProbability; // At this point, obsRating is the largest triangulation weight (which ranges from 1/n to 1 // Normalize the rating so that it ranges from 0 to 1, where // 0 means that the actual belief lies in the middle of the triangulating simplex (i.e. a "bad" approximation) and 1 means that the belief is precisely approximated. obsChoiceRating = (obsChoiceRating * n - one) / (n - one); + // Scale the ratings with the resolutions, so that low resolutions get a lower rating (and are thus more likely to be refined) + obsChoiceRating *= storm::utility::convertNumber(observationResolution) / storm::utility::convertNumber(maxResolution); return obsChoiceRating; } template - std::vector::ValueType> ApproximatePOMDPModelchecker::getObservationRatings(std::shared_ptr const& overApproximation) { + std::vector::ValueType> ApproximatePOMDPModelchecker::getObservationRatings(std::shared_ptr const& overApproximation, std::vector const& observationResolutionVector, uint64_t const& maxResolution) { uint64_t numMdpChoices = overApproximation->getExploredMdp()->getNumberOfChoices(); std::vector resultingRatings(pomdp.getNrObservations(), storm::utility::one()); @@ -289,7 +291,7 @@ namespace storm { overApproximation->gatherSuccessorObservationInformationAtMdpChoice(mdpChoice, gatheredSuccessorObservations); for (auto const& obsInfo : gatheredSuccessorObservations) { auto const& obs = obsInfo.first; - ValueType obsChoiceRating = rateObservation(obsInfo.second); + ValueType obsChoiceRating = rateObservation(obsInfo.second, observationResolutionVector[obs], maxResolution); // The rating of the observation will be the minimum over all choice-based observation ratings resultingRatings[obs] = std::min(resultingRatings[obs], obsChoiceRating); @@ -303,7 +305,9 @@ namespace storm { STORM_LOG_ASSERT(!refine || refinementAggressiveness != nullptr, "Refinement enabled but no aggressiveness given"); STORM_LOG_ASSERT(!refine || *refinementAggressiveness >= storm::utility::zero(), "Can not refine with negative aggressiveness."); STORM_LOG_ASSERT(!refine || *refinementAggressiveness <= storm::utility::one(), "Refinement with aggressiveness > 1 is invalid."); - + uint64_t maxResolution = *std::max_element(observationResolutionVector.begin(), observationResolutionVector.end()); + STORM_LOG_INFO("Refining with maximal resolution " << maxResolution << "."); + statistics.overApproximationBuildTime.start(); storm::storage::BitVector refinedObservations; if (!refine) { @@ -315,7 +319,7 @@ namespace storm { } } else { // If we refine the existing overApproximation, we need to find out which observation resolutions need refinement. - auto obsRatings = getObservationRatings(overApproximation); + auto obsRatings = getObservationRatings(overApproximation, observationResolutionVector, maxResolution); ValueType minRating = *std::min_element(obsRatings.begin(), obsRatings.end()); // Potentially increase the aggressiveness so that at least one observation actually gets refinement. *refinementAggressiveness = std::max(minRating, *refinementAggressiveness); @@ -356,7 +360,7 @@ namespace storm { overApproximation->gatherSuccessorObservationInformationAtCurrentState(action, gatheredSuccessorObservations); for (auto const& obsInfo : gatheredSuccessorObservations) { if (refinedObservations.get(obsInfo.first)) { - ValueType obsRating = rateObservation(obsInfo.second); + ValueType obsRating = rateObservation(obsInfo.second, observationResolutionVector[obsInfo.first], maxResolution); stateActionRating = std::min(stateActionRating, obsRating); } } diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h index 7fbd2ab5e..f895a3138 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h @@ -84,9 +84,9 @@ namespace storm { */ void buildUnderApproximation(std::set const &targetObservations, bool min, bool computeRewards, uint64_t maxStateCount, std::shared_ptr& beliefManager, std::shared_ptr& underApproximation); - ValueType rateObservation(typename ExplorerType::SuccessorObservationInformation const& info); + ValueType rateObservation(typename ExplorerType::SuccessorObservationInformation const& info, uint64_t const& observationResolution, uint64_t const& maxResolution); - std::vector getObservationRatings(std::shared_ptr const& overApproximation); + std::vector getObservationRatings(std::shared_ptr const& overApproximation, std::vector const& observationResolutionVector, uint64_t const& maxResolution); struct Statistics { Statistics();