diff --git a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp index c18c59961..fb5264ca8 100644 --- a/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp +++ b/src/storm-pomdp/modelchecker/ApproximatePOMDPModelchecker.cpp @@ -387,6 +387,7 @@ namespace storm { hintVector[extraTargetState] = storm::utility::one(); } std::vector targetStates = {extraTargetState}; + storm::storage::BitVector fullyExpandedStates; // Map to save the weighted values resulting from the preprocessing for the beliefs / indices in beliefSpace std::map weightedSumOverMap; @@ -441,9 +442,8 @@ namespace storm { beliefsToBeExpanded.pop_front(); uint64_t currMdpState = beliefStateMap.left.at(currId); - auto const& currBelief = beliefGrid.getGridPoint(currId); - uint32_t currObservation = beliefGrid.getBeliefObservation(currBelief); - + uint32_t currObservation = beliefGrid.getBeliefObservation(currId); + mdpTransitionsBuilder.newRowGroup(mdpMatrixRow); if (targetObservations.count(currObservation) != 0) { @@ -457,8 +457,9 @@ namespace storm { mdpTransitionsBuilder.addNextValue(mdpMatrixRow, extraBottomState, storm::utility::one() - weightedSumOverMap[currId]); ++mdpMatrixRow; } else { - auto const& currBelief = beliefGrid.getGridPoint(currId); - uint64_t someState = currBelief.begin()->first; + fullyExpandedStates.grow(nextMdpStateId, false); + fullyExpandedStates.set(currMdpState, true); + uint64_t someState = beliefGrid.getGridPoint(currId).begin()->first; uint64_t numChoices = pomdp.getNumberOfChoices(someState); for (uint64_t action = 0; action < numChoices; ++action) { @@ -507,6 +508,7 @@ namespace storm { statistics.overApproximationBuildTime.stop(); return nullptr; } + fullyExpandedStates.resize(nextMdpStateId, false); storm::models::sparse::StateLabeling mdpLabeling(nextMdpStateId); mdpLabeling.addLabel("init"); @@ -520,13 +522,15 @@ namespace storm { if (computeRewards) { storm::models::sparse::StandardRewardModel mdpRewardModel(boost::none, std::vector(mdpMatrixRow)); for (auto const &iter : beliefStateMap.left) { - auto currentBelief = beliefGrid.getGridPoint(iter.first); - auto representativeState = currentBelief.begin()->first; - for (uint64_t action = 0; action < overApproxMdp->getNumberOfChoices(representativeState); ++action) { - // Add the reward - uint64_t mdpChoice = overApproxMdp->getChoiceIndex(storm::storage::StateActionPair(iter.second, action)); - uint64_t pomdpChoice = pomdp.getChoiceIndex(storm::storage::StateActionPair(representativeState, action)); - mdpRewardModel.setStateActionReward(mdpChoice, getRewardAfterAction(pomdpChoice, currentBelief)); + if (fullyExpandedStates.get(iter.second)) { + auto currentBelief = beliefGrid.getGridPoint(iter.first); + auto representativeState = currentBelief.begin()->first; + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(representativeState); ++action) { + // Add the reward + uint64_t mdpChoice = overApproxMdp->getChoiceIndex(storm::storage::StateActionPair(iter.second, action)); + uint64_t pomdpChoice = pomdp.getChoiceIndex(storm::storage::StateActionPair(representativeState, action)); + mdpRewardModel.setStateActionReward(mdpChoice, getRewardAfterAction(pomdpChoice, currentBelief)); + } } } overApproxMdp->addRewardModel("default", mdpRewardModel); @@ -1076,7 +1080,8 @@ namespace storm { ++mdpMatrixRow; } std::vector targetStates = {extraTargetState}; - + storm::storage::BitVector fullyExpandedStates; + bsmap_type beliefStateMap; std::deque beliefsToBeExpanded; @@ -1106,11 +1111,11 @@ namespace storm { mdpTransitionsBuilder.addNextValue(mdpMatrixRow, currMdpState, storm::utility::one()); ++mdpMatrixRow; } else if (currMdpState > maxModelSize) { - // In other cases, this could be helpflull as well. if (min) { // Get an upper bound here if (computeRewards) { // TODO: With minimizing rewards we need an upper bound! + // In other cases, this could be helpflull as well. // For now, add a selfloop to "generate" infinite reward mdpTransitionsBuilder.addNextValue(mdpMatrixRow, currMdpState, storm::utility::one()); } else { @@ -1121,6 +1126,8 @@ namespace storm { } ++mdpMatrixRow; } else { + fullyExpandedStates.grow(nextMdpStateId, false); + fullyExpandedStates.set(currMdpState, true); // Iterate over all actions and add the corresponding transitions uint64_t someState = currBelief.begin()->first; uint64_t numChoices = pomdp.getNumberOfChoices(someState); @@ -1153,7 +1160,7 @@ namespace storm { statistics.underApproximationBuildTime.stop(); return nullptr; } - + fullyExpandedStates.resize(nextMdpStateId, false); storm::models::sparse::StateLabeling mdpLabeling(nextMdpStateId); mdpLabeling.addLabel("init"); mdpLabeling.addLabel("target"); @@ -1167,13 +1174,15 @@ namespace storm { if (computeRewards) { storm::models::sparse::StandardRewardModel mdpRewardModel(boost::none, std::vector(mdpMatrixRow)); for (auto const &iter : beliefStateMap.left) { - auto currentBelief = beliefGrid.getGridPoint(iter.first); - auto representativeState = currentBelief.begin()->first; - for (uint64_t action = 0; action < model->getNumberOfChoices(representativeState); ++action) { - // Add the reward - uint64_t mdpChoice = model->getChoiceIndex(storm::storage::StateActionPair(iter.second, action)); - uint64_t pomdpChoice = pomdp.getChoiceIndex(storm::storage::StateActionPair(representativeState, action)); - mdpRewardModel.setStateActionReward(mdpChoice, getRewardAfterAction(pomdpChoice, currentBelief)); + if (fullyExpandedStates.get(iter.second)) { + auto currentBelief = beliefGrid.getGridPoint(iter.first); + auto representativeState = currentBelief.begin()->first; + for (uint64_t action = 0; action < pomdp.getNumberOfChoices(representativeState); ++action) { + // Add the reward + uint64_t mdpChoice = model->getChoiceIndex(storm::storage::StateActionPair(iter.second, action)); + uint64_t pomdpChoice = pomdp.getChoiceIndex(storm::storage::StateActionPair(representativeState, action)); + mdpRewardModel.setStateActionReward(mdpChoice, getRewardAfterAction(pomdpChoice, currentBelief)); + } } } model->addRewardModel("default", mdpRewardModel);