diff --git a/src/storm-pomdp/builder/BeliefMdpExplorer.h b/src/storm-pomdp/builder/BeliefMdpExplorer.h index e13e20cf3..bb53c61c6 100644 --- a/src/storm-pomdp/builder/BeliefMdpExplorer.h +++ b/src/storm-pomdp/builder/BeliefMdpExplorer.h @@ -287,20 +287,53 @@ namespace storm { std::unordered_map> mdpRewardModels; if (!mdpActionRewards.empty()) { mdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero()); - mdpRewardModels.emplace("default", storm::models::sparse::StandardRewardModel(boost::optional>(), std::move(mdpActionRewards))); + mdpRewardModels.emplace("default", + storm::models::sparse::StandardRewardModel(boost::optional>(), std::move(mdpActionRewards))); } - + storm::storage::sparse::ModelComponents modelComponents(std::move(mdpTransitionMatrix), std::move(mdpLabeling), std::move(mdpRewardModels)); exploredMdp = std::make_shared>(std::move(modelComponents)); status = Status::ModelFinished; } - + + void dropUnreachableStates() { + STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status."); + storm::storage::BitVector reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(), + storm::storage::BitVector(getCurrentNumberOfMdpStates(), {initialMdpState}), + storm::storage::BitVector(getCurrentNumberOfMdpStates(), true), targetStates); + auto reachableTransitionMatrix = getExploredMdp()->getTransitionMatrix().getSubmatrix(true, reachableStates, reachableStates); + auto reachableStateLabeling = getExploredMdp()->getStateLabeling().getSubLabeling(reachableStates); + // TODO reward model + storm::storage::sparse::ModelComponents modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling)); + exploredMdp = std::make_shared>(std::move(modelComponents)); + + std::vector reachableMdpStateToBeliefIdMap(reachableStates.getNumberOfSetBits()); + std::vector reachableLowerValueBounds(reachableStates.getNumberOfSetBits()); + std::vector reachableUpperValueBounds(reachableStates.getNumberOfSetBits()); + std::vector reachableValues(reachableStates.getNumberOfSetBits()); + for (uint64_t state = 0; state < reachableStates.size(); ++state) { + if (reachableStates[state]) { + reachableMdpStateToBeliefIdMap.push_back(mdpStateToBeliefIdMap[state]); + reachableLowerValueBounds.push_back(lowerValueBounds[state]); + reachableUpperValueBounds.push_back(upperValueBounds[state]); + reachableValues.push_back(values[state]); + } + //TODO drop BeliefIds from exploredBeliefIDs? + } + std::map reachableBeliefIdToMdpStateMap; + for (MdpStateType state = 0; state < reachableMdpStateToBeliefIdMap.size(); ++state) { + reachableBeliefIdToMdpStateMap[reachableMdpStateToBeliefIdMap[state]] = state; + } + mdpStateToBeliefIdMap = reachableMdpStateToBeliefIdMap; + beliefIdToMdpStateMap = reachableBeliefIdToMdpStateMap; + } + std::shared_ptr> getExploredMdp() const { STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status."); STORM_LOG_ASSERT(exploredMdp, "Tried to get the explored MDP but exploration was not finished yet."); return exploredMdp; } - + MdpStateType getCurrentNumberOfMdpStates() const { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); return mdpStateToBeliefIdMap.size();