From 62c905fc583ef215d3dcd885181f0d99f816a5bc Mon Sep 17 00:00:00 2001 From: Alexander Bork Date: Thu, 2 Apr 2020 20:05:00 +0200 Subject: [PATCH] Added basis for rewards in dropUnreachableStates() --- src/storm-pomdp/builder/BeliefMdpExplorer.h | 28 +++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/storm-pomdp/builder/BeliefMdpExplorer.h b/src/storm-pomdp/builder/BeliefMdpExplorer.h index bb53c61c6..426eff188 100644 --- a/src/storm-pomdp/builder/BeliefMdpExplorer.h +++ b/src/storm-pomdp/builder/BeliefMdpExplorer.h @@ -298,28 +298,42 @@ namespace storm { void dropUnreachableStates() { STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status."); - storm::storage::BitVector reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(), - storm::storage::BitVector(getCurrentNumberOfMdpStates(), {initialMdpState}), - storm::storage::BitVector(getCurrentNumberOfMdpStates(), true), targetStates); + auto reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(), + storm::storage::BitVector(getCurrentNumberOfMdpStates(), std::vector{initialMdpState}), + storm::storage::BitVector(getCurrentNumberOfMdpStates(), true), + getExploredMdp()->getStateLabeling().getStates("target")); auto reachableTransitionMatrix = getExploredMdp()->getTransitionMatrix().getSubmatrix(true, reachableStates, reachableStates); auto reachableStateLabeling = getExploredMdp()->getStateLabeling().getSubLabeling(reachableStates); - // TODO reward model - storm::storage::sparse::ModelComponents modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling)); - exploredMdp = std::make_shared>(std::move(modelComponents)); - std::vector reachableMdpStateToBeliefIdMap(reachableStates.getNumberOfSetBits()); std::vector reachableLowerValueBounds(reachableStates.getNumberOfSetBits()); std::vector reachableUpperValueBounds(reachableStates.getNumberOfSetBits()); std::vector reachableValues(reachableStates.getNumberOfSetBits()); + std::vector reachableMdpActionRewards; for (uint64_t state = 0; state < reachableStates.size(); ++state) { if (reachableStates[state]) { reachableMdpStateToBeliefIdMap.push_back(mdpStateToBeliefIdMap[state]); reachableLowerValueBounds.push_back(lowerValueBounds[state]); reachableUpperValueBounds.push_back(upperValueBounds[state]); reachableValues.push_back(values[state]); + if (getExploredMdp()->hasRewardModel()) { + //TODO FIXME is there some mismatch with the indices here? + for (uint64_t i = 0; i < getExploredMdp()->getTransitionMatrix().getRowGroupSize(state); ++i) { + reachableMdpActionRewards.push_back(getExploredMdp()->getUniqueRewardModel().getStateActionRewardVector()[state + i]); + } + } } //TODO drop BeliefIds from exploredBeliefIDs? } + std::unordered_map> mdpRewardModels; + if (!reachableMdpActionRewards.empty()) { + //reachableMdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero()); + mdpRewardModels.emplace("default", + storm::models::sparse::StandardRewardModel(boost::optional>(), std::move(reachableMdpActionRewards))); + } + storm::storage::sparse::ModelComponents modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling), + std::move(mdpRewardModels)); + exploredMdp = std::make_shared>(std::move(modelComponents)); + std::map reachableBeliefIdToMdpStateMap; for (MdpStateType state = 0; state < reachableMdpStateToBeliefIdMap.size(); ++state) { reachableBeliefIdToMdpStateMap[reachableMdpStateToBeliefIdMap[state]] = state;