diff --git a/src/storm-pomdp/builder/BeliefMdpExplorer.h b/src/storm-pomdp/builder/BeliefMdpExplorer.h index 2a97c5e05..d59b770f0 100644 --- a/src/storm-pomdp/builder/BeliefMdpExplorer.h +++ b/src/storm-pomdp/builder/BeliefMdpExplorer.h @@ -62,33 +62,36 @@ namespace storm { exploredChoiceIndices.clear(); mdpActionRewards.clear(); exploredMdp = nullptr; - currentMdpState = noState(); - + internalAddRowGroupIndex(); // Mark the start of the first row group + // Add some states with special treatment (if requested) if (extraBottomStateValue) { - extraBottomState = getCurrentNumberOfMdpStates(); + currentMdpState = getCurrentNumberOfMdpStates(); + extraBottomState = currentMdpState; mdpStateToBeliefIdMap.push_back(beliefManager->noId()); insertValueHints(extraBottomStateValue.get(), extraBottomStateValue.get()); - internalAddRowGroupIndex(); internalAddTransition(getStartOfCurrentRowGroup(), extraBottomState.get(), storm::utility::one()); + internalAddRowGroupIndex(); } else { extraBottomState = boost::none; } if (extraTargetStateValue) { - extraTargetState = getCurrentNumberOfMdpStates(); + currentMdpState = getCurrentNumberOfMdpStates(); + extraTargetState = currentMdpState; mdpStateToBeliefIdMap.push_back(beliefManager->noId()); insertValueHints(extraTargetStateValue.get(), extraTargetStateValue.get()); - internalAddRowGroupIndex(); internalAddTransition(getStartOfCurrentRowGroup(), extraTargetState.get(), storm::utility::one()); + internalAddRowGroupIndex(); targetStates.grow(getCurrentNumberOfMdpStates(), false); targetStates.set(extraTargetState.get(), true); } else { extraTargetState = boost::none; } - + currentMdpState = noState(); + // Set up the initial state. initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief()); } @@ -101,6 +104,7 @@ namespace storm { */ void restartExploration() { STORM_LOG_ASSERT(status == Status::ModelChecked || status == Status::ModelFinished, "Method call is invalid in current status."); + status = Status::Exploring; // We will not erase old states during the exploration phase, so most state-based data (like mappings between MDP and Belief states) remain valid. exploredBeliefIds.clear(); exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false); @@ -124,6 +128,7 @@ namespace storm { if (extraTargetState) { currentMdpState = extraTargetState.get(); restoreOldBehaviorAtCurrentState(0); + targetStates.set(extraTargetState.get(), true); } currentMdpState = noState(); @@ -138,23 +143,22 @@ namespace storm { BeliefId exploreNextState() { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + // Mark the end of the previously explored row group. + if (currentMdpState != noState() && !currentStateHasOldBehavior()) { + internalAddRowGroupIndex(); + } // Pop from the queue. currentMdpState = mdpStatesToExplore.front(); mdpStatesToExplore.pop_front(); - if (!currentStateHasOldBehavior()) { - internalAddRowGroupIndex(); - } return mdpStateToBeliefIdMap[currentMdpState]; } void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero(), ValueType const& bottomStateValue = storm::utility::zero()) { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); - // We first insert the entries of the current row in a separate map. - // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder) - + STORM_LOG_ASSERT(!currentStateHasOldBehavior() || localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP."); uint64_t row = getStartOfCurrentRowGroup() + localActionIndex; if (!storm::utility::isZero(bottomStateValue)) { STORM_LOG_ASSERT(extraBottomState.is_initialized(), "Requested a transition to the extra bottom state but there is none."); @@ -168,6 +172,7 @@ namespace storm { void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one()) { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + STORM_LOG_ASSERT(!currentStateHasOldBehavior() || localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP."); uint64_t row = getStartOfCurrentRowGroup() + localActionIndex; internalAddTransition(row, getCurrentMdpState(), value); } @@ -182,8 +187,8 @@ namespace storm { */ bool addTransitionToBelief(uint64_t const& localActionIndex, BeliefId const& transitionTarget, ValueType const& value, bool ignoreNewBeliefs) { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); - // We first insert the entries of the current row in a separate map. - // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder) + STORM_LOG_ASSERT(!currentStateHasOldBehavior() || localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP."); + MdpStateType column; if (ignoreNewBeliefs) { column = getExploredMdpState(transitionTarget); @@ -221,6 +226,7 @@ namespace storm { bool currentStateHasOldBehavior() { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'currentStateHasOldBehavior' called but there is no current state."); return exploredMdp && getCurrentMdpState() < exploredMdp->getNumberOfStates(); } @@ -232,6 +238,8 @@ namespace storm { */ void restoreOldBehaviorAtCurrentState(uint64_t const& localActionIndex) { STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Cannot restore old behavior as the current state does not have any."); + STORM_LOG_ASSERT(localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP."); + uint64_t choiceIndex = exploredChoiceIndices[getCurrentMdpState()] + localActionIndex; STORM_LOG_ASSERT(choiceIndex < exploredChoiceIndices[getCurrentMdpState() + 1], "Invalid local action index."); @@ -255,10 +263,27 @@ namespace storm { void finishExploration() { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states."); + + // Complete the exploration // Finish the last row grouping in case the last explored state was new if (!currentStateHasOldBehavior()) { internalAddRowGroupIndex(); } + // Resize state- and choice based vectors to the correct size + targetStates.resize(getCurrentNumberOfMdpStates(), false); + truncatedStates.resize(getCurrentNumberOfMdpStates(), false); + if (!mdpActionRewards.empty()) { + mdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero()); + } + + // We are not exploring anymore + currentMdpState = noState(); + + // If this was a restarted exploration, we might still have unexplored states (which were only reachable and explored in a previous build). + // We get rid of these before rebuilding the model + if (exploredMdp) { + dropUnexploredStates(); + } // Create the tranistion matrix uint64_t entryCount = 0; @@ -300,50 +325,101 @@ namespace storm { status = Status::ModelFinished; } - void dropUnreachableStates() { - STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status."); - auto reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(), - storm::storage::BitVector(getCurrentNumberOfMdpStates(), std::vector{initialMdpState}), - storm::storage::BitVector(getCurrentNumberOfMdpStates(), true), - getExploredMdp()->getStateLabeling().getStates("target")); - auto reachableTransitionMatrix = getExploredMdp()->getTransitionMatrix().getSubmatrix(true, reachableStates, reachableStates); - auto reachableStateLabeling = getExploredMdp()->getStateLabeling().getSubLabeling(reachableStates); - std::vector reachableMdpStateToBeliefIdMap(reachableStates.getNumberOfSetBits()); - std::vector reachableLowerValueBounds(reachableStates.getNumberOfSetBits()); - std::vector reachableUpperValueBounds(reachableStates.getNumberOfSetBits()); - std::vector reachableValues(reachableStates.getNumberOfSetBits()); - std::vector reachableMdpActionRewards; - for (uint64_t state = 0; state < reachableStates.size(); ++state) { - if (reachableStates[state]) { - reachableMdpStateToBeliefIdMap.push_back(mdpStateToBeliefIdMap[state]); - reachableLowerValueBounds.push_back(lowerValueBounds[state]); - reachableUpperValueBounds.push_back(upperValueBounds[state]); - reachableValues.push_back(values[state]); - if (getExploredMdp()->hasRewardModel()) { - //TODO FIXME is there some mismatch with the indices here? - for (uint64_t i = 0; i < getExploredMdp()->getTransitionMatrix().getRowGroupSize(state); ++i) { - reachableMdpActionRewards.push_back(getExploredMdp()->getUniqueRewardModel().getStateActionRewardVector()[state + i]); - } + void dropUnexploredStates() { + STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); + STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states."); + + STORM_LOG_ASSERT(exploredMdp, "Method called although no 'old' MDP is available."); + // Find the states (and corresponding choices) that were not explored. + // These correspond to "empty" MDP transitions + storm::storage::BitVector relevantMdpStates(getCurrentNumberOfMdpStates(), true), relevantMdpChoices(getCurrentNumberOfMdpChoices(), true); + std::vector toRelevantStateIndexMap(getCurrentNumberOfMdpStates(), noState()); + MdpStateType nextRelevantIndex = 0; + for (uint64_t groupIndex = 0; groupIndex < exploredChoiceIndices.size() - 1; ++groupIndex) { + uint64_t rowIndex = exploredChoiceIndices[groupIndex]; + // Check first row in group + if (exploredMdpTransitions[rowIndex].empty()) { + relevantMdpChoices.set(rowIndex, false); + relevantMdpStates.set(groupIndex, false); + } else { + toRelevantStateIndexMap[groupIndex] = nextRelevantIndex; + ++nextRelevantIndex; + } + uint64_t groupEnd = exploredChoiceIndices[groupIndex + 1]; + // process remaining rows in group + for (++rowIndex; rowIndex < groupEnd; ++rowIndex) { + // Assert that all actions at the current state were consistently explored or unexplored. + STORM_LOG_ASSERT(exploredMdpTransitions[rowIndex].empty() != relevantMdpStates.get(groupIndex), "Actions at 'old' MDP state " << groupIndex << " were only partly explored."); + if (exploredMdpTransitions[rowIndex].empty()) { + relevantMdpChoices.set(rowIndex, false); } } - //TODO drop BeliefIds from exploredBeliefIDs? } - std::unordered_map> mdpRewardModels; - if (!reachableMdpActionRewards.empty()) { - //reachableMdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero()); - mdpRewardModels.emplace("default", - storm::models::sparse::StandardRewardModel(boost::optional>(), std::move(reachableMdpActionRewards))); + + if (relevantMdpStates.full()) { + // All states are relevant so nothing to do + return; } - storm::storage::sparse::ModelComponents modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling), - std::move(mdpRewardModels)); - exploredMdp = std::make_shared>(std::move(modelComponents)); - - std::map reachableBeliefIdToMdpStateMap; - for (MdpStateType state = 0; state < reachableMdpStateToBeliefIdMap.size(); ++state) { - reachableBeliefIdToMdpStateMap[reachableMdpStateToBeliefIdMap[state]] = state; + + // Translate various components to the "new" MDP state set + storm::utility::vector::filterVectorInPlace(mdpStateToBeliefIdMap, relevantMdpStates); + { // beliefIdToMdpStateMap + for (auto belIdToMdpStateIt = beliefIdToMdpStateMap.begin(); belIdToMdpStateIt != beliefIdToMdpStateMap.end();) { + if (relevantMdpStates.get(belIdToMdpStateIt->second)) { + // Keep current entry and move on to the next one. + ++belIdToMdpStateIt; + } else { + STORM_LOG_ASSERT(!exploredBeliefIds.get(belIdToMdpStateIt->first), "Inconsistent exploration information: Unexplored MDPState corresponds to explored beliefId"); + // Delete current entry and move on to the next one. + // This works because std::map::erase does not invalidate other iterators within the map! + beliefIdToMdpStateMap.erase(belIdToMdpStateIt++); + } + } + } + { // exploredMdpTransitions + storm::utility::vector::filterVectorInPlace(exploredMdpTransitions, relevantMdpChoices); + // Adjust column indices. Unfortunately, the fastest way seems to be to "rebuild" the map + // It might payoff to do this when building the matrix. + for (auto& transitions : exploredMdpTransitions) { + std::map newTransitions; + for (auto const& entry : transitions) { + STORM_LOG_ASSERT(relevantMdpStates.get(entry.first), "Relevant state has transition to irrelevant state."); + newTransitions.emplace_hint(newTransitions.end(), toRelevantStateIndexMap[entry.first], entry.second); + } + transitions = std::move(newTransitions); + } } - mdpStateToBeliefIdMap = reachableMdpStateToBeliefIdMap; - beliefIdToMdpStateMap = reachableBeliefIdToMdpStateMap; + { // exploredChoiceIndices + MdpStateType newState = 0; + assert(exploredChoiceIndices[0] == 0u); + // Loop invariant: all indices up to exploredChoiceIndices[newState] consider the new row indices and all other entries are not touched. + for (auto const& oldState : relevantMdpStates) { + if (oldState != newState) { + assert(oldState > newState); + uint64_t groupSize = exploredChoiceIndices[oldState + 1] - exploredChoiceIndices[oldState]; + exploredChoiceIndices[newState + 1] = exploredChoiceIndices[newState] + groupSize; + } + ++newState; + } + exploredChoiceIndices.resize(newState + 1); + } + if (!mdpActionRewards.empty()) { + storm::utility::vector::filterVectorInPlace(mdpActionRewards, relevantMdpChoices); + } + if (extraBottomState) { + extraBottomState = toRelevantStateIndexMap[extraBottomState.get()]; + } + if (extraTargetState) { + extraTargetState = toRelevantStateIndexMap[extraTargetState.get()]; + } + targetStates = targetStates % relevantMdpStates; + truncatedStates = truncatedStates % relevantMdpStates; + initialMdpState = toRelevantStateIndexMap[initialMdpState]; + + storm::utility::vector::filterVectorInPlace(lowerValueBounds, relevantMdpStates); + storm::utility::vector::filterVectorInPlace(upperValueBounds, relevantMdpStates); + storm::utility::vector::filterVectorInPlace(values, relevantMdpStates); + } std::shared_ptr> getExploredMdp() const { @@ -364,7 +440,7 @@ namespace storm { MdpStateType getStartOfCurrentRowGroup() const { STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); - return exploredChoiceIndices.back(); + return exploredChoiceIndices[getCurrentMdpState()]; } ValueType getLowerValueBoundAtCurrentState() const {