diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp index 9a88e5d8e..91971def3 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp @@ -552,10 +552,21 @@ namespace storm { boost::optional> overApproximationMap, boost::optional> underApproximationMap, uint64_t maxUaModelSize) { + bool initialBoundMapsSet = overApproximationMap && underApproximationMap; + std::map initialOverMap; + std::map initialUnderMap; + if (initialBoundMapsSet) { + initialOverMap = overApproximationMap.value(); + initialUnderMap = underApproximationMap.value(); + } // Note that a persistent cache is not support by the current data structure. The resolution for the given belief also has to be stored somewhere to cache effectively std::map>> subSimplexCache; std::map> lambdaCache; + // Map to save the weighted values resulting from the initial preprocessing for newly added beliefs / indices in beliefSpace + std::map weightedSumOverMap; + std::map weightedSumUnderMap; + uint64_t nextBeliefId = refinementComponents->beliefList.size(); uint64_t nextStateId = refinementComponents->overApproxModelPtr->getNumberOfStates(); std::set relevantStates; @@ -622,18 +633,16 @@ namespace storm { refinementComponents->beliefGrid.push_back(gridBelief); refinementComponents->beliefIsTarget.push_back(targetObservations.find(observation) != targetObservations.end()); // compute overapproximate value using MDP result map - //TODO do this - /* - if (boundMapsSet) { + if (initialBoundMapsSet) { auto tempWeightedSumOver = storm::utility::zero(); auto tempWeightedSumUnder = storm::utility::zero(); for (uint64_t i = 0; i < subSimplex[j].size(); ++i) { - tempWeightedSumOver += subSimplex[j][i] * storm::utility::convertNumber(overMap[i]); - tempWeightedSumUnder += subSimplex[j][i] * storm::utility::convertNumber(underMap[i]); + tempWeightedSumOver += subSimplex[j][i] * storm::utility::convertNumber(initialOverMap[i]); + tempWeightedSumUnder += subSimplex[j][i] * storm::utility::convertNumber(initialUnderMap[i]); } - weightedSumOverMap[nextId] = tempWeightedSumOver; - weightedSumUnderMap[nextId] = tempWeightedSumUnder; - } */ + weightedSumOverMap[nextBeliefId] = tempWeightedSumOver; + weightedSumUnderMap[nextBeliefId] = tempWeightedSumUnder; + } beliefsToBeExpanded.push_back(nextBeliefId); refinementComponents->overApproxBeliefStateMap.insert(bsmap_type::value_type(nextBeliefId, nextStateId)); transitionInActionBelief[nextStateId] = iter->second * lambdas[j]; @@ -654,17 +663,25 @@ namespace storm { transitionsStateActionPair[stateActionPair] = transitionInActionBelief; } } + + std::set stoppedExplorationStateSet; + // Expand newly added beliefs while (!beliefsToBeExpanded.empty()) { uint64_t currId = beliefsToBeExpanded.front(); beliefsToBeExpanded.pop_front(); bool isTarget = refinementComponents->beliefIsTarget[currId]; - /* TODO - if (boundMapsSet && cc.isLess(weightedSumOverMap[currId] - weightedSumUnderMap[currId], storm::utility::convertNumber(options.explorationThreshold))) { - mdpTransitions.push_back({{{1, weightedSumOverMap[currId]}, {0, storm::utility::one() - weightedSumOverMap[currId]}}}); + if (initialBoundMapsSet && + cc.isLess(weightedSumOverMap[currId] - weightedSumUnderMap[currId], storm::utility::convertNumber(options.explorationThreshold))) { + STORM_PRINT("Stop Exploration in State " << refinementComponents->overApproxBeliefStateMap.left.at(currId) << " with Value " << weightedSumOverMap[currId] + << std::endl) + transitionsStateActionPair[std::make_pair(refinementComponents->overApproxBeliefStateMap.left.at(currId), 0)] = {{1, weightedSumOverMap[currId]}, + {0, storm::utility::one() - + weightedSumOverMap[currId]}}; + stoppedExplorationStateSet.insert(refinementComponents->overApproxBeliefStateMap.left.at(currId)); continue; - }*/ + } if (isTarget) { // Depending on whether we compute rewards, we select the right initial result @@ -690,21 +707,21 @@ namespace storm { //Triangulate here and put the possibly resulting belief in the grid std::vector> subSimplex; std::vector lambdas; - /* TODO Caching + if (options.cacheSubsimplices && subSimplexCache.count(idNextBelief) > 0) { subSimplex = subSimplexCache[idNextBelief]; lambdas = lambdaCache[idNextBelief]; - } else { */ - auto temp = computeSubSimplexAndLambdas(refinementComponents->beliefList[idNextBelief].probabilities, - observationResolutionVector[refinementComponents->beliefList[idNextBelief].observation], - pomdp.getNumberOfStates()); - subSimplex = temp.first; - lambdas = temp.second; - /*if (options.cacheSubsimplices) { - subSimplexCache[idNextBelief] = subSimplex; - lambdaCache[idNextBelief] = lambdas; + } else { + auto temp = computeSubSimplexAndLambdas(refinementComponents->beliefList[idNextBelief].probabilities, + observationResolutionVector[refinementComponents->beliefList[idNextBelief].observation], + pomdp.getNumberOfStates()); + subSimplex = temp.first; + lambdas = temp.second; + if (options.cacheSubsimplices) { + subSimplexCache[idNextBelief] = subSimplex; + lambdaCache[idNextBelief] = lambdas; + } } - }*/ for (size_t j = 0; j < lambdas.size(); ++j) { if (!cc.isEqual(lambdas[j], storm::utility::zero())) { @@ -716,17 +733,16 @@ namespace storm { refinementComponents->beliefGrid.push_back(gridBelief); refinementComponents->beliefIsTarget.push_back(targetObservations.find(observation) != targetObservations.end()); // compute overapproximate value using MDP result map - /* - if (boundMapsSet) { + if (initialBoundMapsSet) { auto tempWeightedSumOver = storm::utility::zero(); auto tempWeightedSumUnder = storm::utility::zero(); for (uint64_t i = 0; i < subSimplex[j].size(); ++i) { - tempWeightedSumOver += subSimplex[j][i] * storm::utility::convertNumber(overMap[i]); - tempWeightedSumUnder += subSimplex[j][i] * storm::utility::convertNumber(underMap[i]); + tempWeightedSumOver += subSimplex[j][i] * storm::utility::convertNumber(initialOverMap[i]); + tempWeightedSumUnder += subSimplex[j][i] * storm::utility::convertNumber(initialUnderMap[i]); } - weightedSumOverMap[nextId] = tempWeightedSumOver; - weightedSumUnderMap[nextId] = tempWeightedSumUnder; - } */ + weightedSumOverMap[nextBeliefId] = tempWeightedSumOver; + weightedSumUnderMap[nextBeliefId] = tempWeightedSumUnder; + } beliefsToBeExpanded.push_back(nextBeliefId); refinementComponents->overApproxBeliefStateMap.insert(bsmap_type::value_type(nextBeliefId, nextStateId)); transitionInActionBelief[nextStateId] = iter->second * lambdas[j]; @@ -750,15 +766,7 @@ namespace storm { /* if (computeRewards) { beliefActionRewards.emplace(std::make_pair(currId, actionRewardsInState)); - } - - - if (transitionsInBelief.empty()) { - std::map transitionInActionBelief; - transitionInActionBelief[beliefStateMap.left.at(currId)] = storm::utility::one(); - transitionsInBelief.push_back(transitionInActionBelief); - } - mdpTransitions.push_back(transitionsInBelief);*/ + }*/ } } @@ -766,7 +774,7 @@ namespace storm { mdpLabeling.addLabel("init"); mdpLabeling.addLabel("target"); mdpLabeling.addLabelToState("init", refinementComponents->overApproxBeliefStateMap.left.at(refinementComponents->initialBeliefId)); - + mdpLabeling.addLabelToState("target", 1); uint_fast64_t currentRow = 0; uint_fast64_t currentRowGroup = 0; storm::storage::SparseMatrixBuilder smb(0, nextStateId, 0, false, true); @@ -801,6 +809,9 @@ namespace storm { mdpLabeling.addLabelToState("target", state); break; } + if (stoppedExplorationStateSet.find(state) != stoppedExplorationStateSet.end()) { + break; + } } ++currentRowGroup; }