Browse Source

BeliefMdpExplorer: Various bugfixes for exploration restarts. Unexplored (= unreachable) states are now dropped before building the MDP since we do not get a valid MDP otherwise.

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
961baa4386
  1. 184
      src/storm-pomdp/builder/BeliefMdpExplorer.h

184
src/storm-pomdp/builder/BeliefMdpExplorer.h

@ -62,32 +62,35 @@ namespace storm {
exploredChoiceIndices.clear(); exploredChoiceIndices.clear();
mdpActionRewards.clear(); mdpActionRewards.clear();
exploredMdp = nullptr; exploredMdp = nullptr;
currentMdpState = noState();
internalAddRowGroupIndex(); // Mark the start of the first row group
// Add some states with special treatment (if requested) // Add some states with special treatment (if requested)
if (extraBottomStateValue) { if (extraBottomStateValue) {
extraBottomState = getCurrentNumberOfMdpStates();
currentMdpState = getCurrentNumberOfMdpStates();
extraBottomState = currentMdpState;
mdpStateToBeliefIdMap.push_back(beliefManager->noId()); mdpStateToBeliefIdMap.push_back(beliefManager->noId());
insertValueHints(extraBottomStateValue.get(), extraBottomStateValue.get()); insertValueHints(extraBottomStateValue.get(), extraBottomStateValue.get());
internalAddRowGroupIndex();
internalAddTransition(getStartOfCurrentRowGroup(), extraBottomState.get(), storm::utility::one<ValueType>()); internalAddTransition(getStartOfCurrentRowGroup(), extraBottomState.get(), storm::utility::one<ValueType>());
internalAddRowGroupIndex();
} else { } else {
extraBottomState = boost::none; extraBottomState = boost::none;
} }
if (extraTargetStateValue) { if (extraTargetStateValue) {
extraTargetState = getCurrentNumberOfMdpStates();
currentMdpState = getCurrentNumberOfMdpStates();
extraTargetState = currentMdpState;
mdpStateToBeliefIdMap.push_back(beliefManager->noId()); mdpStateToBeliefIdMap.push_back(beliefManager->noId());
insertValueHints(extraTargetStateValue.get(), extraTargetStateValue.get()); insertValueHints(extraTargetStateValue.get(), extraTargetStateValue.get());
internalAddRowGroupIndex();
internalAddTransition(getStartOfCurrentRowGroup(), extraTargetState.get(), storm::utility::one<ValueType>()); internalAddTransition(getStartOfCurrentRowGroup(), extraTargetState.get(), storm::utility::one<ValueType>());
internalAddRowGroupIndex();
targetStates.grow(getCurrentNumberOfMdpStates(), false); targetStates.grow(getCurrentNumberOfMdpStates(), false);
targetStates.set(extraTargetState.get(), true); targetStates.set(extraTargetState.get(), true);
} else { } else {
extraTargetState = boost::none; extraTargetState = boost::none;
} }
currentMdpState = noState();
// Set up the initial state. // Set up the initial state.
initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief()); initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief());
@ -101,6 +104,7 @@ namespace storm {
*/ */
void restartExploration() { void restartExploration() {
STORM_LOG_ASSERT(status == Status::ModelChecked || status == Status::ModelFinished, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::ModelChecked || status == Status::ModelFinished, "Method call is invalid in current status.");
status = Status::Exploring;
// We will not erase old states during the exploration phase, so most state-based data (like mappings between MDP and Belief states) remain valid. // We will not erase old states during the exploration phase, so most state-based data (like mappings between MDP and Belief states) remain valid.
exploredBeliefIds.clear(); exploredBeliefIds.clear();
exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false); exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false);
@ -124,6 +128,7 @@ namespace storm {
if (extraTargetState) { if (extraTargetState) {
currentMdpState = extraTargetState.get(); currentMdpState = extraTargetState.get();
restoreOldBehaviorAtCurrentState(0); restoreOldBehaviorAtCurrentState(0);
targetStates.set(extraTargetState.get(), true);
} }
currentMdpState = noState(); currentMdpState = noState();
@ -138,23 +143,22 @@ namespace storm {
BeliefId exploreNextState() { BeliefId exploreNextState() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// Mark the end of the previously explored row group.
if (currentMdpState != noState() && !currentStateHasOldBehavior()) {
internalAddRowGroupIndex();
}
// Pop from the queue. // Pop from the queue.
currentMdpState = mdpStatesToExplore.front(); currentMdpState = mdpStatesToExplore.front();
mdpStatesToExplore.pop_front(); mdpStatesToExplore.pop_front();
if (!currentStateHasOldBehavior()) {
internalAddRowGroupIndex();
}
return mdpStateToBeliefIdMap[currentMdpState]; return mdpStateToBeliefIdMap[currentMdpState];
} }
void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) { void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
STORM_LOG_ASSERT(!currentStateHasOldBehavior() || localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP.");
uint64_t row = getStartOfCurrentRowGroup() + localActionIndex; uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
if (!storm::utility::isZero(bottomStateValue)) { if (!storm::utility::isZero(bottomStateValue)) {
STORM_LOG_ASSERT(extraBottomState.is_initialized(), "Requested a transition to the extra bottom state but there is none."); STORM_LOG_ASSERT(extraBottomState.is_initialized(), "Requested a transition to the extra bottom state but there is none.");
@ -168,6 +172,7 @@ namespace storm {
void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) { void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
STORM_LOG_ASSERT(!currentStateHasOldBehavior() || localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP.");
uint64_t row = getStartOfCurrentRowGroup() + localActionIndex; uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
internalAddTransition(row, getCurrentMdpState(), value); internalAddTransition(row, getCurrentMdpState(), value);
} }
@ -182,8 +187,8 @@ namespace storm {
*/ */
bool addTransitionToBelief(uint64_t const& localActionIndex, BeliefId const& transitionTarget, ValueType const& value, bool ignoreNewBeliefs) { bool addTransitionToBelief(uint64_t const& localActionIndex, BeliefId const& transitionTarget, ValueType const& value, bool ignoreNewBeliefs) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
STORM_LOG_ASSERT(!currentStateHasOldBehavior() || localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP.");
MdpStateType column; MdpStateType column;
if (ignoreNewBeliefs) { if (ignoreNewBeliefs) {
column = getExploredMdpState(transitionTarget); column = getExploredMdpState(transitionTarget);
@ -221,6 +226,7 @@ namespace storm {
bool currentStateHasOldBehavior() { bool currentStateHasOldBehavior() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
STORM_LOG_ASSERT(getCurrentMdpState() != noState(), "Method 'currentStateHasOldBehavior' called but there is no current state.");
return exploredMdp && getCurrentMdpState() < exploredMdp->getNumberOfStates(); return exploredMdp && getCurrentMdpState() < exploredMdp->getNumberOfStates();
} }
@ -232,6 +238,8 @@ namespace storm {
*/ */
void restoreOldBehaviorAtCurrentState(uint64_t const& localActionIndex) { void restoreOldBehaviorAtCurrentState(uint64_t const& localActionIndex) {
STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Cannot restore old behavior as the current state does not have any."); STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Cannot restore old behavior as the current state does not have any.");
STORM_LOG_ASSERT(localActionIndex < exploredChoiceIndices[currentMdpState + 1] - exploredChoiceIndices[currentMdpState], "Action index " << localActionIndex << " was not valid at state " << currentMdpState << " of the previously explored MDP.");
uint64_t choiceIndex = exploredChoiceIndices[getCurrentMdpState()] + localActionIndex; uint64_t choiceIndex = exploredChoiceIndices[getCurrentMdpState()] + localActionIndex;
STORM_LOG_ASSERT(choiceIndex < exploredChoiceIndices[getCurrentMdpState() + 1], "Invalid local action index."); STORM_LOG_ASSERT(choiceIndex < exploredChoiceIndices[getCurrentMdpState() + 1], "Invalid local action index.");
@ -255,10 +263,27 @@ namespace storm {
void finishExploration() { void finishExploration() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states."); STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states.");
// Complete the exploration
// Finish the last row grouping in case the last explored state was new // Finish the last row grouping in case the last explored state was new
if (!currentStateHasOldBehavior()) { if (!currentStateHasOldBehavior()) {
internalAddRowGroupIndex(); internalAddRowGroupIndex();
} }
// Resize state- and choice based vectors to the correct size
targetStates.resize(getCurrentNumberOfMdpStates(), false);
truncatedStates.resize(getCurrentNumberOfMdpStates(), false);
if (!mdpActionRewards.empty()) {
mdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero<ValueType>());
}
// We are not exploring anymore
currentMdpState = noState();
// If this was a restarted exploration, we might still have unexplored states (which were only reachable and explored in a previous build).
// We get rid of these before rebuilding the model
if (exploredMdp) {
dropUnexploredStates();
}
// Create the tranistion matrix // Create the tranistion matrix
uint64_t entryCount = 0; uint64_t entryCount = 0;
@ -300,50 +325,101 @@ namespace storm {
status = Status::ModelFinished; status = Status::ModelFinished;
} }
void dropUnreachableStates() {
STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status.");
auto reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(),
storm::storage::BitVector(getCurrentNumberOfMdpStates(), std::vector<uint64_t>{initialMdpState}),
storm::storage::BitVector(getCurrentNumberOfMdpStates(), true),
getExploredMdp()->getStateLabeling().getStates("target"));
auto reachableTransitionMatrix = getExploredMdp()->getTransitionMatrix().getSubmatrix(true, reachableStates, reachableStates);
auto reachableStateLabeling = getExploredMdp()->getStateLabeling().getSubLabeling(reachableStates);
std::vector<BeliefId> reachableMdpStateToBeliefIdMap(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableLowerValueBounds(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableUpperValueBounds(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableValues(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableMdpActionRewards;
for (uint64_t state = 0; state < reachableStates.size(); ++state) {
if (reachableStates[state]) {
reachableMdpStateToBeliefIdMap.push_back(mdpStateToBeliefIdMap[state]);
reachableLowerValueBounds.push_back(lowerValueBounds[state]);
reachableUpperValueBounds.push_back(upperValueBounds[state]);
reachableValues.push_back(values[state]);
if (getExploredMdp()->hasRewardModel()) {
//TODO FIXME is there some mismatch with the indices here?
for (uint64_t i = 0; i < getExploredMdp()->getTransitionMatrix().getRowGroupSize(state); ++i) {
reachableMdpActionRewards.push_back(getExploredMdp()->getUniqueRewardModel().getStateActionRewardVector()[state + i]);
}
}
}
//TODO drop BeliefIds from exploredBeliefIDs?
void dropUnexploredStates() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states.");
STORM_LOG_ASSERT(exploredMdp, "Method called although no 'old' MDP is available.");
// Find the states (and corresponding choices) that were not explored.
// These correspond to "empty" MDP transitions
storm::storage::BitVector relevantMdpStates(getCurrentNumberOfMdpStates(), true), relevantMdpChoices(getCurrentNumberOfMdpChoices(), true);
std::vector<MdpStateType> toRelevantStateIndexMap(getCurrentNumberOfMdpStates(), noState());
MdpStateType nextRelevantIndex = 0;
for (uint64_t groupIndex = 0; groupIndex < exploredChoiceIndices.size() - 1; ++groupIndex) {
uint64_t rowIndex = exploredChoiceIndices[groupIndex];
// Check first row in group
if (exploredMdpTransitions[rowIndex].empty()) {
relevantMdpChoices.set(rowIndex, false);
relevantMdpStates.set(groupIndex, false);
} else {
toRelevantStateIndexMap[groupIndex] = nextRelevantIndex;
++nextRelevantIndex;
} }
std::unordered_map<std::string, storm::models::sparse::StandardRewardModel<ValueType>> mdpRewardModels;
if (!reachableMdpActionRewards.empty()) {
//reachableMdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero<ValueType>());
mdpRewardModels.emplace("default",
storm::models::sparse::StandardRewardModel<ValueType>(boost::optional<std::vector<ValueType>>(), std::move(reachableMdpActionRewards)));
uint64_t groupEnd = exploredChoiceIndices[groupIndex + 1];
// process remaining rows in group
for (++rowIndex; rowIndex < groupEnd; ++rowIndex) {
// Assert that all actions at the current state were consistently explored or unexplored.
STORM_LOG_ASSERT(exploredMdpTransitions[rowIndex].empty() != relevantMdpStates.get(groupIndex), "Actions at 'old' MDP state " << groupIndex << " were only partly explored.");
if (exploredMdpTransitions[rowIndex].empty()) {
relevantMdpChoices.set(rowIndex, false);
}
}
}
if (relevantMdpStates.full()) {
// All states are relevant so nothing to do
return;
} }
storm::storage::sparse::ModelComponents<ValueType> modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling),
std::move(mdpRewardModels));
exploredMdp = std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(modelComponents));
std::map<BeliefId, MdpStateType> reachableBeliefIdToMdpStateMap;
for (MdpStateType state = 0; state < reachableMdpStateToBeliefIdMap.size(); ++state) {
reachableBeliefIdToMdpStateMap[reachableMdpStateToBeliefIdMap[state]] = state;
// Translate various components to the "new" MDP state set
storm::utility::vector::filterVectorInPlace(mdpStateToBeliefIdMap, relevantMdpStates);
{ // beliefIdToMdpStateMap
for (auto belIdToMdpStateIt = beliefIdToMdpStateMap.begin(); belIdToMdpStateIt != beliefIdToMdpStateMap.end();) {
if (relevantMdpStates.get(belIdToMdpStateIt->second)) {
// Keep current entry and move on to the next one.
++belIdToMdpStateIt;
} else {
STORM_LOG_ASSERT(!exploredBeliefIds.get(belIdToMdpStateIt->first), "Inconsistent exploration information: Unexplored MDPState corresponds to explored beliefId");
// Delete current entry and move on to the next one.
// This works because std::map::erase does not invalidate other iterators within the map!
beliefIdToMdpStateMap.erase(belIdToMdpStateIt++);
}
}
} }
mdpStateToBeliefIdMap = reachableMdpStateToBeliefIdMap;
beliefIdToMdpStateMap = reachableBeliefIdToMdpStateMap;
{ // exploredMdpTransitions
storm::utility::vector::filterVectorInPlace(exploredMdpTransitions, relevantMdpChoices);
// Adjust column indices. Unfortunately, the fastest way seems to be to "rebuild" the map
// It might payoff to do this when building the matrix.
for (auto& transitions : exploredMdpTransitions) {
std::map<MdpStateType, ValueType> newTransitions;
for (auto const& entry : transitions) {
STORM_LOG_ASSERT(relevantMdpStates.get(entry.first), "Relevant state has transition to irrelevant state.");
newTransitions.emplace_hint(newTransitions.end(), toRelevantStateIndexMap[entry.first], entry.second);
}
transitions = std::move(newTransitions);
}
}
{ // exploredChoiceIndices
MdpStateType newState = 0;
assert(exploredChoiceIndices[0] == 0u);
// Loop invariant: all indices up to exploredChoiceIndices[newState] consider the new row indices and all other entries are not touched.
for (auto const& oldState : relevantMdpStates) {
if (oldState != newState) {
assert(oldState > newState);
uint64_t groupSize = exploredChoiceIndices[oldState + 1] - exploredChoiceIndices[oldState];
exploredChoiceIndices[newState + 1] = exploredChoiceIndices[newState] + groupSize;
}
++newState;
}
exploredChoiceIndices.resize(newState + 1);
}
if (!mdpActionRewards.empty()) {
storm::utility::vector::filterVectorInPlace(mdpActionRewards, relevantMdpChoices);
}
if (extraBottomState) {
extraBottomState = toRelevantStateIndexMap[extraBottomState.get()];
}
if (extraTargetState) {
extraTargetState = toRelevantStateIndexMap[extraTargetState.get()];
}
targetStates = targetStates % relevantMdpStates;
truncatedStates = truncatedStates % relevantMdpStates;
initialMdpState = toRelevantStateIndexMap[initialMdpState];
storm::utility::vector::filterVectorInPlace(lowerValueBounds, relevantMdpStates);
storm::utility::vector::filterVectorInPlace(upperValueBounds, relevantMdpStates);
storm::utility::vector::filterVectorInPlace(values, relevantMdpStates);
} }
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> getExploredMdp() const { std::shared_ptr<storm::models::sparse::Mdp<ValueType>> getExploredMdp() const {
@ -364,7 +440,7 @@ namespace storm {
MdpStateType getStartOfCurrentRowGroup() const { MdpStateType getStartOfCurrentRowGroup() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return exploredChoiceIndices.back();
return exploredChoiceIndices[getCurrentMdpState()];
} }
ValueType getLowerValueBoundAtCurrentState() const { ValueType getLowerValueBoundAtCurrentState() const {

Loading…
Cancel
Save