From 75d792e9876913d1b000aa9fd43647897baaf97b Mon Sep 17 00:00:00 2001 From: Tim Quatmann Date: Wed, 22 Apr 2020 13:51:18 +0200 Subject: [PATCH] Implemented refinement heuristic. --- src/storm-pomdp/builder/BeliefMdpExplorer.h | 87 ++++++- .../ApproximatePOMDPModelchecker.cpp | 212 ++++++++++++------ .../ApproximatePOMDPModelchecker.h | 11 +- 3 files changed, 234 insertions(+), 76 deletions(-) diff --git a/src/storm-pomdp/builder/BeliefMdpExplorer.h b/src/storm-pomdp/builder/BeliefMdpExplorer.h index 2e1b0f4cc..0fc6c9ba4 100644 --- a/src/storm-pomdp/builder/BeliefMdpExplorer.h +++ b/src/storm-pomdp/builder/BeliefMdpExplorer.h @@ -61,7 +61,10 @@ namespace storm { exploredMdpTransitions.clear(); exploredChoiceIndices.clear(); mdpActionRewards.clear(); - optimalMdpChoices = boost::none; + targetStates.clear(); + truncatedStates.clear(); + delayedExplorationChoices.clear(); + optimalChoices = boost::none; optimalChoicesReachableMdpStates = boost::none; exploredMdp = nullptr; internalAddRowGroupIndex(); // Mark the start of the first row group @@ -120,6 +123,7 @@ namespace storm { } targetStates = storm::storage::BitVector(getCurrentNumberOfMdpStates(), false); truncatedStates = storm::storage::BitVector(getCurrentNumberOfMdpStates(), false); + delayedExplorationChoices.clear(); mdpStatesToExplore.clear(); // The extra states are not changed @@ -226,17 +230,51 @@ namespace storm { truncatedStates.set(getCurrentMdpState(), true); } - bool currentStateHasOldBehavior() { + void setCurrentChoiceIsDelayed(uint64_t const& localActionIndex) { + STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + delayedExplorationChoices.grow(getCurrentNumberOfMdpChoices(), false); + delayedExplorationChoices.set(getStartOfCurrentRowGroup() + localActionIndex, true); + } + + bool currentStateHasOldBehavior() const { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'currentStateHasOldBehavior' called but there is no current state."); return exploredMdp && getCurrentMdpState() < exploredMdp->getNumberOfStates(); } + bool getCurrentStateWasTruncated() const { + STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'actionAtCurrentStateWasOptimal' called but there is no current state."); + STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Method 'actionAtCurrentStateWasOptimal' called but current state has no old behavior"); + STORM_LOG_ASSERT(exploredMdp, "No 'old' mdp available"); + return exploredMdp->getStateLabeling().getStateHasLabel("truncated", getCurrentMdpState()); + } + + /*! + * Retrieves whether the current state can be reached under an optimal scheduler + * This requires a previous call of computeOptimalChoicesAndReachableMdpStates. + */ + bool stateIsOptimalSchedulerReachable(MdpStateType mdpState) const { + STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status."); + STORM_LOG_ASSERT(optimalChoicesReachableMdpStates.is_initialized(), "Method 'stateIsOptimalSchedulerReachable' called but 'computeOptimalChoicesAndReachableMdpStates' was not called before."); + return optimalChoicesReachableMdpStates->get(mdpState); + } + + /*! + * Retrieves whether the given action at the current state was optimal in the most recent check. + * This requires a previous call of computeOptimalChoicesAndReachableMdpStates. 
+             */
+            bool actionIsOptimal(uint64_t const& globalActionIndex) const {
+                STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status.");
+                STORM_LOG_ASSERT(optimalChoices.is_initialized(), "Method 'actionIsOptimal' called but 'computeOptimalChoicesAndReachableMdpStates' was not called before.");
+                return optimalChoices->get(globalActionIndex);
+            }
+
             /*!
              * Retrieves whether the current state can be reached under a scheduler that was optimal in the most recent check.
              * This requires (i) a previous call of computeOptimalChoicesAndReachableMdpStates and (ii) that the current state has old behavior.
              */
-            bool currentStateIsOptimalSchedulerReachable() {
+            bool currentStateIsOptimalSchedulerReachable() const {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
                 STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'currentStateIsOptimalSchedulerReachable' called but there is no current state.");
                 STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Method 'currentStateIsOptimalSchedulerReachable' called but current state has no old behavior");
@@ -248,13 +286,22 @@
              * Retrieves whether the given action at the current state was optimal in the most recent check.
              * This requires (i) a previous call of computeOptimalChoicesAndReachableMdpStates and (ii) that the current state has old behavior.
              */
-            bool actionAtCurrentStateWasOptimal(uint64_t const& localActionIndex) {
+            bool actionAtCurrentStateWasOptimal(uint64_t const& localActionIndex) const {
                 STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
                 STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'actionAtCurrentStateWasOptimal' called but there is no current state.");
                 STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Method 'actionAtCurrentStateWasOptimal' called but current state has no old behavior");
                 STORM_LOG_ASSERT(optimalChoices.is_initialized(), "Method 'currentStateIsOptimalSchedulerReachable' called but 'computeOptimalChoicesAndReachableMdpStates' was not called before.");
-                uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
-                return optimalChoices->get(row);
+                uint64_t choice = getStartOfCurrentRowGroup() + localActionIndex;
+                return optimalChoices->get(choice);
+            }
+
+            bool getCurrentStateActionExplorationWasDelayed(uint64_t const& localActionIndex) const {
+                STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
+                STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'getCurrentStateActionExplorationWasDelayed' called but there is no current state.");
+                STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Method 'getCurrentStateActionExplorationWasDelayed' called but current state has no old behavior");
+                STORM_LOG_ASSERT(exploredMdp, "No 'old' mdp available");
+                uint64_t choice = exploredMdp->getNondeterministicChoiceIndices()[getCurrentMdpState()] + localActionIndex;
+                return exploredMdp->hasChoiceLabeling() && exploredMdp->getChoiceLabeling().getChoiceHasLabel("delayed", choice);
+            }

             /*!
@@ -351,7 +398,17 @@ namespace storm { storm::models::sparse::StandardRewardModel(boost::optional>(), std::move(mdpActionRewards))); } + // Create model components storm::storage::sparse::ModelComponents modelComponents(std::move(mdpTransitionMatrix), std::move(mdpLabeling), std::move(mdpRewardModels)); + + // Potentially create a choice labeling + if (!delayedExplorationChoices.empty()) { + modelComponents.choiceLabeling = storm::models::sparse::ChoiceLabeling(getCurrentNumberOfMdpChoices()); + delayedExplorationChoices.resize(getCurrentNumberOfMdpChoices(), false); + modelComponents.choiceLabeling->addLabel("delayed", std::move(delayedExplorationChoices)); + } + + // Create the final model. exploredMdp = std::make_shared>(std::move(modelComponents)); status = Status::ModelFinished; STORM_LOG_DEBUG("Explored Mdp with " << exploredMdp->getNumberOfStates() << " states (" << truncatedStates.getNumberOfSetBits() << " of which were flagged as truncated)."); @@ -579,6 +636,19 @@ namespace storm { } } + bool currentStateHasSuccessorObservationInObservationSet(uint64_t localActionIndex, storm::storage::BitVector const& observationSet) { + STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Method call is invalid since the current state has no old behavior"); + uint64_t mdpChoice = getStartOfCurrentRowGroup() + localActionIndex; + for (auto const& entry : exploredMdp->getTransitionMatrix().getRow(mdpChoice)) { + auto const& beliefId = getBeliefId(entry.getColumn()); + if (observationSet.get(beliefManager->getBeliefObservation(beliefId))) { + return true; + } + } + return false; + } + void takeCurrentValuesAsUpperBounds() { STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status."); upperValueBounds = values; @@ -735,12 +805,13 @@ namespace storm { std::vector mdpActionRewards; uint64_t currentMdpState; - // Special states during exploration + // Special states and choices during exploration boost::optional extraTargetState; boost::optional extraBottomState; storm::storage::BitVector targetStates; storm::storage::BitVector truncatedStates; MdpStateType initialMdpState; + storm::storage::BitVector delayedExplorationChoices; // Final Mdp std::shared_ptr> exploredMdp; @@ -751,7 +822,7 @@ namespace storm { std::vector lowerValueBounds; std::vector upperValueBounds; std::vector values; // Contains an estimate during building and the actual result after a check has performed - boost::optional optimalMdpChoices; + boost::optional optimalChoices; boost::optional optimalChoicesReachableMdpStates; // The current status of this explorer diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp index 8aa83452a..09e0f8ce8 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp @@ -175,7 +175,13 @@ namespace storm { manager->setRewardModel(rewardModelName); } auto approx = std::make_shared(manager, lowerPomdpValueBounds, upperPomdpValueBounds); - buildOverApproximation(targetObservations, min, rewardModelName.is_initialized(), false, nullptr, observationResolutionVector, manager, approx); + HeuristicParameters heuristicParameters; + heuristicParameters.gapThreshold = storm::utility::convertNumber(options.explorationThreshold); + heuristicParameters.observationThreshold = storm::utility::zero(); // Not relevant without 
refinement + heuristicParameters.sizeThreshold = std::numeric_limits::max(); + heuristicParameters.optimalChoiceValueEpsilon = storm::utility::convertNumber(1e-4); + + buildOverApproximation(targetObservations, min, rewardModelName.is_initialized(), false, heuristicParameters, observationResolutionVector, manager, approx); if (approx->hasComputedValues()) { STORM_PRINT_AND_LOG("Explored and checked Over-Approximation MDP:\n"); approx->getExploredMdp()->printModelInformationToStream(std::cout); @@ -220,7 +226,12 @@ namespace storm { // OverApproximaion auto overApproximation = std::make_shared(overApproxBeliefManager, lowerPomdpValueBounds, upperPomdpValueBounds); - buildOverApproximation(targetObservations, min, rewardModelName.is_initialized(), false, nullptr, observationResolutionVector, overApproxBeliefManager, overApproximation); + HeuristicParameters heuristicParameters; + heuristicParameters.gapThreshold = storm::utility::convertNumber(options.explorationThreshold); + heuristicParameters.observationThreshold = storm::utility::zero(); // Will be set to lowest observation score automatically + heuristicParameters.sizeThreshold = std::numeric_limits::max(); + heuristicParameters.optimalChoiceValueEpsilon = storm::utility::convertNumber(1e-4); + buildOverApproximation(targetObservations, min, rewardModelName.is_initialized(), false, heuristicParameters, observationResolutionVector, overApproxBeliefManager, overApproximation); if (!overApproximation->hasComputedValues()) { return; } @@ -245,7 +256,6 @@ namespace storm { // ValueType lastMinScore = storm::utility::infinity(); // Start refinement statistics.refinementSteps = 0; - ValueType refinementAggressiveness = storm::utility::convertNumber(0.0); while (result.diff() > options.refinementPrecision) { if (storm::utility::resources::isTerminate()) { break; @@ -254,13 +264,15 @@ namespace storm { STORM_LOG_INFO("Starting refinement step " << statistics.refinementSteps.get() << ". 
Current difference between lower and upper bound is " << result.diff() << "."); // Refine over-approximation - STORM_LOG_DEBUG("Refining over-approximation with aggressiveness " << refinementAggressiveness << "."); if (min) { overApproximation->takeCurrentValuesAsLowerBounds(); } else { overApproximation->takeCurrentValuesAsUpperBounds(); } - buildOverApproximation(targetObservations, min, rewardModelName.is_initialized(), true, &refinementAggressiveness, observationResolutionVector, overApproxBeliefManager, overApproximation); + heuristicParameters.gapThreshold /= storm::utility::convertNumber(4); + heuristicParameters.sizeThreshold = overApproximation->getExploredMdp()->getNumberOfStates() * 4; + heuristicParameters.observationThreshold += storm::utility::convertNumber(0.1) * (storm::utility::one() - heuristicParameters.observationThreshold); + buildOverApproximation(targetObservations, min, rewardModelName.is_initialized(), true, heuristicParameters, observationResolutionVector, overApproxBeliefManager, overApproximation); if (overApproximation->hasComputedValues()) { overApproxValue = overApproximation->getComputedValueAtInitialState(); } else { @@ -269,7 +281,7 @@ namespace storm { if (result.diff() > options.refinementPrecision) { // Refine under-approximation - underApproxSizeThreshold = storm::utility::convertNumber(storm::utility::convertNumber(underApproxSizeThreshold) * (storm::utility::one() + refinementAggressiveness)); + underApproxSizeThreshold *= 4; underApproxSizeThreshold = std::max(underApproxSizeThreshold, overApproximation->getExploredMdp()->getNumberOfStates()); STORM_LOG_DEBUG("Refining under-approximation with size threshold " << underApproxSizeThreshold << "."); buildUnderApproximation(targetObservations, min, rewardModelName.is_initialized(), underApproxSizeThreshold, underApproxBeliefManager, underApproximation); @@ -309,30 +321,38 @@ namespace storm { template std::vector::ValueType> ApproximatePOMDPModelchecker::getObservationRatings(std::shared_ptr const& overApproximation, std::vector const& observationResolutionVector, uint64_t const& maxResolution) { - uint64_t numMdpChoices = overApproximation->getExploredMdp()->getNumberOfChoices(); + uint64_t numMdpStates = overApproximation->getExploredMdp()->getNumberOfStates(); + auto const& choiceIndices = overApproximation->getExploredMdp()->getNondeterministicChoiceIndices(); std::vector resultingRatings(pomdp.getNrObservations(), storm::utility::one()); std::map gatheredSuccessorObservations; // Declare here to avoid reallocations - for (uint64_t mdpChoice = 0; mdpChoice < numMdpChoices; ++mdpChoice) { - gatheredSuccessorObservations.clear(); - overApproximation->gatherSuccessorObservationInformationAtMdpChoice(mdpChoice, gatheredSuccessorObservations); - for (auto const& obsInfo : gatheredSuccessorObservations) { - auto const& obs = obsInfo.first; - ValueType obsChoiceRating = rateObservation(obsInfo.second, observationResolutionVector[obs], maxResolution); - - // The rating of the observation will be the minimum over all choice-based observation ratings - resultingRatings[obs] = std::min(resultingRatings[obs], obsChoiceRating); + for (uint64_t mdpState = 0; mdpState < numMdpStates; ++mdpState) { + // Check whether this state is reached under an optimal scheduler. + // The heuristic assumes that the remaining states are not relevant for the observation score. 
+                    if (overApproximation->stateIsOptimalSchedulerReachable(mdpState)) {
+                        for (uint64_t mdpChoice = choiceIndices[mdpState]; mdpChoice < choiceIndices[mdpState + 1]; ++mdpChoice) {
+                            // Similarly, only optimal actions are relevant
+                            if (overApproximation->actionIsOptimal(mdpChoice)) {
+                                // score the observations for this choice
+                                gatheredSuccessorObservations.clear();
+                                overApproximation->gatherSuccessorObservationInformationAtMdpChoice(mdpChoice, gatheredSuccessorObservations);
+                                for (auto const& obsInfo : gatheredSuccessorObservations) {
+                                    auto const& obs = obsInfo.first;
+                                    ValueType obsChoiceRating = rateObservation(obsInfo.second, observationResolutionVector[obs], maxResolution);
+
+                                    // The rating of the observation will be the minimum over all choice-based observation ratings
+                                    resultingRatings[obs] = std::min(resultingRatings[obs], obsChoiceRating);
+                                }
+                            }
+                        }
+                    }
                 }
                 return resultingRatings;
             }

             template<typename PomdpModelType, typename BeliefValueType>
-            void ApproximatePOMDPModelchecker<PomdpModelType, BeliefValueType>::buildOverApproximation(std::set<uint32_t> const &targetObservations, bool min, bool computeRewards, bool refine, ValueType* refinementAggressiveness, std::vector<uint64_t>& observationResolutionVector, std::shared_ptr<BeliefManagerType>& beliefManager, std::shared_ptr<ExplorerType>& overApproximation) {
-                STORM_LOG_ASSERT(!refine || refinementAggressiveness != nullptr, "Refinement enabled but no aggressiveness given");
-                STORM_LOG_ASSERT(!refine || *refinementAggressiveness >= storm::utility::zero<ValueType>(), "Can not refine with negative aggressiveness.");
-                STORM_LOG_ASSERT(!refine || *refinementAggressiveness <= storm::utility::one<ValueType>(), "Refinement with aggressiveness > 1 is invalid.");
+            void ApproximatePOMDPModelchecker<PomdpModelType, BeliefValueType>::buildOverApproximation(std::set<uint32_t> const &targetObservations, bool min, bool computeRewards, bool refine, HeuristicParameters& heuristicParameters, std::vector<uint64_t>& observationResolutionVector, std::shared_ptr<BeliefManagerType>& beliefManager, std::shared_ptr<ExplorerType>& overApproximation) {
                 // current maximal resolution (needed for refinement heuristic)
                 uint64_t oldMaxResolution = *std::max_element(observationResolutionVector.begin(), observationResolutionVector.end());
@@ -347,17 +367,18 @@
                        overApproximation->startNewExploration(storm::utility::one<ValueType>(), storm::utility::zero<ValueType>());
                    }
                } else {
-                    // If we refine the existing overApproximation, we need to find out which observation resolutions need refinement.
+                    // If we refine the existing overApproximation, our heuristic also wants to know which states are reachable under an optimal policy.
+                    overApproximation->computeOptimalChoicesAndReachableMdpStates(heuristicParameters.optimalChoiceValueEpsilon, true);
+                    // We also need to find out which observation resolutions need refinement.
                    auto obsRatings = getObservationRatings(overApproximation, observationResolutionVector, oldMaxResolution);
                    ValueType minRating = *std::min_element(obsRatings.begin(), obsRatings.end());
-                    // Potentially increase the aggressiveness so that at least one observation actually gets refinement.
-                    *refinementAggressiveness = std::max(minRating, *refinementAggressiveness);
-                    refinedObservations = storm::utility::vector::filter(obsRatings, [&refinementAggressiveness](ValueType const& r) { return r <= *refinementAggressiveness;});
+                    // Potentially increase the observationThreshold so that at least one observation actually gets refinement.
+ heuristicParameters.observationThreshold = std::max(minRating, heuristicParameters.observationThreshold); + refinedObservations = storm::utility::vector::filter(obsRatings, [&heuristicParameters](ValueType const& r) { return r <= heuristicParameters.observationThreshold;}); STORM_LOG_DEBUG("Refining the resolution of " << refinedObservations.getNumberOfSetBits() << "/" << refinedObservations.size() << " observations."); for (auto const& obs : refinedObservations) { - // Heuristically increment the resolution at the refined observations (also based on the refinementAggressiveness) - ValueType incrementValue = storm::utility::one() + (*refinementAggressiveness) * storm::utility::convertNumber(observationResolutionVector[obs]); - observationResolutionVector[obs] += storm::utility::convertNumber(storm::utility::ceil(incrementValue)); + // Increment the resolution at the refined observations + observationResolutionVector[obs] *= 2; } overApproximation->restartExploration(); } @@ -365,6 +386,7 @@ namespace storm { // Start exploration std::map gatheredSuccessorObservations; // Declare here to avoid reallocations + uint64_t numRewiredOrExploredStates = 0; while (overApproximation->hasUnexploredState()) { uint64_t currId = overApproximation->exploreNextState(); @@ -373,66 +395,124 @@ namespace storm { overApproximation->setCurrentStateIsTarget(); overApproximation->addSelfloopTransition(); } else { - bool stopExploration = false; - if (storm::utility::abs(overApproximation->getUpperValueBoundAtCurrentState() - overApproximation->getLowerValueBoundAtCurrentState()) < options.explorationThreshold) { - stopExploration = true; - overApproximation->setCurrentStateIsTruncated(); + // We need to decide how to treat this state (and each individual enabled action). 
There are the following cases: + // 1 The state has no old behavior and + // 1.1 we explore all actions or + // 1.2 we truncate all actions + // 2 The state has old behavior and was truncated in the last iteration and + // 2.1 we explore all actions or + // 2.2 we truncate all actions (essentially restoring old behavior, but we do the truncation step again to benefit from updated bounds) + // 3 The state has old behavior and was not truncated in the last iteration and the current action + // 3.1 should be rewired or + // 3.2 should get the old behavior but either + // 3.2.1 none of the successor observation has been refined since the last rewiring or exploration of this action + // 3.2.2 rewiring is only delayed as it could still have an effect in a later refinement step + + // Find out in which case we are + bool exploreAllActions = false; + bool truncateAllActions = false; + bool restoreAllActions = false; + bool checkRewireForAllActions = false; + ValueType gap = storm::utility::abs(overApproximation->getUpperValueBoundAtCurrentState() - overApproximation->getLowerValueBoundAtCurrentState()); + if (!refine || !overApproximation->currentStateHasOldBehavior()) { + // Case 1 + // If we explore this state and if it has no old behavior, it is clear that an "old" optimal scheduler can be extended to a scheduler that reaches this state + if (gap > heuristicParameters.gapThreshold && numRewiredOrExploredStates < heuristicParameters.sizeThreshold) { + exploreAllActions = true; // Case 1.1 + } else { + truncateAllActions = true; // Case 1.2 + overApproximation->setCurrentStateIsTruncated(); + } + } else { + if (overApproximation->getCurrentStateWasTruncated()) { + // Case 2 + if (overApproximation->currentStateIsOptimalSchedulerReachable() && gap > heuristicParameters.gapThreshold && numRewiredOrExploredStates < heuristicParameters.sizeThreshold) { + exploreAllActions = true; // Case 2.1 + } else { + truncateAllActions = true; // Case 2.2 + overApproximation->setCurrentStateIsTruncated(); + } + } else { + // Case 3 + // The decision for rewiring also depends on the corresponding action, but we have some criteria that lead to case 3.2 (independent of the action) + if (overApproximation->currentStateIsOptimalSchedulerReachable() && gap > heuristicParameters.gapThreshold && numRewiredOrExploredStates < heuristicParameters.sizeThreshold) { + checkRewireForAllActions = true; // Case 3.1 or Case 3.2 + } else { + restoreAllActions = true; // Definitely Case 3.2 + // We still need to check for each action whether rewiring makes sense later + checkRewireForAllActions = true; + } + } } + bool expandedAtLeastOneAction = false; for (uint64 action = 0, numActions = beliefManager->getBeliefNumberOfChoices(currId); action < numActions; ++action) { - // Check whether we expand this state/action pair - // We always expand if we are not doing refinement of if the state was not available in the "old" MDP. - // Otherwise, a heuristic decides. 
- bool expandStateAction = true; - if (refine && overApproximation->currentStateHasOldBehavior()) { - // Compute a rating of the current state/action pair - ValueType stateActionRating = storm::utility::one(); - gatheredSuccessorObservations.clear(); - overApproximation->gatherSuccessorObservationInformationAtCurrentState(action, gatheredSuccessorObservations); - for (auto const& obsInfo : gatheredSuccessorObservations) { - if (refinedObservations.get(obsInfo.first)) { - ValueType obsRating = rateObservation(obsInfo.second, observationResolutionVector[obsInfo.first], oldMaxResolution); - stateActionRating = std::min(stateActionRating, obsRating); + bool expandCurrentAction = exploreAllActions || truncateAllActions; + if (checkRewireForAllActions) { + assert(refine); + // In this case, we still need to check whether this action needs to be expanded + assert(!expandCurrentAction); + // Check the action dependent conditions for rewiring + // First, check whether this action has been rewired since the last refinement of one of the successor observations (i.e. whether rewiring would actually change the successor states) + assert(overApproximation->currentStateHasOldBehavior()); + if (overApproximation->getCurrentStateActionExplorationWasDelayed(action) || overApproximation->currentStateHasSuccessorObservationInObservationSet(action, refinedObservations)) { + // Then, check whether the other criteria for rewiring are satisfied + if (!restoreAllActions && overApproximation->actionAtCurrentStateWasOptimal(action)) { + // Do the rewiring now! (Case 3.1) + expandCurrentAction = true; + } else { + // Delay the rewiring (Case 3.2.2) + overApproximation->setCurrentChoiceIsDelayed(action); } - } - // Only refine if this rating is below the doubled refinementAggressiveness - expandStateAction = stateActionRating < storm::utility::convertNumber(2.0) * (*refinementAggressiveness); + } // else { Case 3.2.1 } } - if (expandStateAction) { - ValueType truncationProbability = storm::utility::zero(); - ValueType truncationValueBound = storm::utility::zero(); - auto successorGridPoints = beliefManager->expandAndTriangulate(currId, action, observationResolutionVector); - for (auto const& successor : successorGridPoints) { - bool added = overApproximation->addTransitionToBelief(action, successor.first, successor.second, stopExploration); - if (!added) { - STORM_LOG_ASSERT(stopExploration, "Didn't add a transition although exploration shouldn't be stopped."); - // We did not explore this successor state. Get a bound on the "missing" value - truncationProbability += successor.second; - truncationValueBound += successor.second * (min ? 
overApproximation->computeLowerValueBoundAtBelief(successor.first) : overApproximation->computeUpperValueBoundAtBelief(successor.first)); + + if (expandCurrentAction) { + expandedAtLeastOneAction = true; + if (!truncateAllActions) { + // Cases 1.1, 2.1, or 3.1 + auto successorGridPoints = beliefManager->expandAndTriangulate(currId, action, observationResolutionVector); + for (auto const& successor : successorGridPoints) { + overApproximation->addTransitionToBelief(action, successor.first, successor.second, false); } - } - if (stopExploration) { if (computeRewards) { + overApproximation->computeRewardAtCurrentState(action); + } + } else { + // Cases 1.2 or 2.2 + ValueType truncationProbability = storm::utility::zero(); + ValueType truncationValueBound = storm::utility::zero(); + auto successorGridPoints = beliefManager->expandAndTriangulate(currId, action, observationResolutionVector); + for (auto const& successor : successorGridPoints) { + bool added = overApproximation->addTransitionToBelief(action, successor.first, successor.second, true); + if (!added) { + // We did not explore this successor state. Get a bound on the "missing" value + truncationProbability += successor.second; + truncationValueBound += successor.second * (min ? overApproximation->computeLowerValueBoundAtBelief(successor.first) : overApproximation->computeUpperValueBoundAtBelief(successor.first)); + } + } + if (computeRewards) { + // The truncationValueBound will be added on top of the reward introduced by the current belief state. overApproximation->addTransitionsToExtraStates(action, truncationProbability); + overApproximation->computeRewardAtCurrentState(action, truncationValueBound); } else { overApproximation->addTransitionsToExtraStates(action, truncationValueBound, truncationProbability - truncationValueBound); } } - if (computeRewards) { - // The truncationValueBound will be added on top of the reward introduced by the current belief state. - overApproximation->computeRewardAtCurrentState(action, truncationValueBound); - } } else { - // Do not refine here + // Case 3.2 overApproximation->restoreOldBehaviorAtCurrentState(action); } } + if (expandedAtLeastOneAction) { + ++numRewiredOrExploredStates; + } } + if (storm::utility::resources::isTerminate()) { statistics.overApproximationBuildAborted = true; break; } } - // TODO: Drop unreachable states (sometimes?) statistics.overApproximationStates = overApproximation->getCurrentNumberOfMdpStates(); if (storm::utility::resources::isTerminate()) { statistics.overApproximationBuildTime.stop(); diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h index 823eebf60..8b892e1f1 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.h @@ -74,11 +74,18 @@ namespace storm { * @return A struct containing the final overapproximation (overApproxValue) and underapproximation (underApproxValue) values */ void refineReachability(std::set const &targetObservations, bool min, boost::optional rewardModelName, std::vector const& lowerPomdpValueBounds, std::vector const& upperPomdpValueBounds, Result& result); - + + struct HeuristicParameters { + ValueType gapThreshold; + ValueType observationThreshold; + uint64_t sizeThreshold; + ValueType optimalChoiceValueEpsilon; + }; + /** * Builds and checks an MDP that over-approximates the POMDP behavior, i.e. 
             * provides an upper bound for maximizing and a lower bound for minimizing properties
             */
-            void buildOverApproximation(std::set<uint32_t> const &targetObservations, bool min, bool computeRewards, bool refine, ValueType* refinementAggressiveness, std::vector<uint64_t>& observationResolutionVector, std::shared_ptr<BeliefManagerType>& beliefManager, std::shared_ptr<ExplorerType>& overApproximation);
+            void buildOverApproximation(std::set<uint32_t> const &targetObservations, bool min, bool computeRewards, bool refine, HeuristicParameters& heuristicParameters, std::vector<uint64_t>& observationResolutionVector, std::shared_ptr<BeliefManagerType>& beliefManager, std::shared_ptr<ExplorerType>& overApproximation);

            /**
             * Builds and checks an MDP that under-approximates the POMDP behavior, i.e. provides a lower bound for maximizing and an upper bound for minimizing properties
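
Note (reviewer sketch, not part of the patch): the refinement loop in refineReachability tightens the new heuristic parameters between iterations. Below is a minimal, self-contained sketch of that schedule, with ValueType fixed to double and the helper name advanceHeuristicParameters chosen purely for illustration.

#include <cstdint>

// Mirrors the HeuristicParameters struct introduced by the patch (ValueType assumed to be double here).
struct HeuristicParameters {
    double gapThreshold;              // explore a state only if its upper/lower bound gap exceeds this
    double observationThreshold;      // observations rated at or below this get a refined resolution
    uint64_t sizeThreshold;           // budget of states that may be explored or rewired per iteration
    double optimalChoiceValueEpsilon; // tolerance when identifying optimal choices of the previous MDP
};

// One refinement step, following the updates performed in refineReachability:
// the gap threshold shrinks, the size budget grows with the previously explored MDP,
// and the observation threshold approaches 1 so that eventually every observation is refined.
void advanceHeuristicParameters(HeuristicParameters& p, uint64_t exploredMdpStates) {
    p.gapThreshold /= 4.0;
    p.sizeThreshold = exploredMdpStates * 4;
    p.observationThreshold += 0.1 * (1.0 - p.observationThreshold);
}

The under-approximation budget is scaled analogously in the patch (underApproxSizeThreshold *= 4, but never smaller than the number of states of the current over-approximation MDP).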