Browse Source

Started to make the BeliefMdpExplorer more flexible, allowing to restart the exploration

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
79641ef131
  1. 224
      src/storm-pomdp/builder/BeliefMdpExplorer.h
  2. 4
      src/storm-pomdp/storage/BeliefManager.h

224
src/storm-pomdp/builder/BeliefMdpExplorer.h

@ -11,6 +11,7 @@
#include "storm/api/verification.h" #include "storm/api/verification.h"
#include "storm/storage/BitVector.h" #include "storm/storage/BitVector.h"
#include "storm/storage/SparseMatrix.h"
#include "storm/utility/macros.h" #include "storm/utility/macros.h"
#include "storm-pomdp/storage/BeliefManager.h" #include "storm-pomdp/storage/BeliefManager.h"
#include "storm/utility/SignalHandler.h" #include "storm/utility/SignalHandler.h"
@ -19,6 +20,7 @@
#include "storm/modelchecker/results/ExplicitQuantitativeCheckResult.h" #include "storm/modelchecker/results/ExplicitQuantitativeCheckResult.h"
#include "storm/modelchecker/hints/ExplicitModelCheckerHint.cpp" #include "storm/modelchecker/hints/ExplicitModelCheckerHint.cpp"
namespace storm { namespace storm {
namespace builder { namespace builder {
template<typename PomdpType, typename BeliefValueType = typename PomdpType::ValueType> template<typename PomdpType, typename BeliefValueType = typename PomdpType::ValueType>
@ -46,16 +48,17 @@ namespace storm {
// Reset data from potential previous explorations // Reset data from potential previous explorations
mdpStateToBeliefIdMap.clear(); mdpStateToBeliefIdMap.clear();
beliefIdToMdpStateMap.clear(); beliefIdToMdpStateMap.clear();
beliefIdsWithMdpState.clear();
beliefIdsWithMdpState.grow(beliefManager->getNumberOfBeliefIds(), false);
exploredBeliefIds.clear();
exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false);
mdpStatesToExplore.clear();
lowerValueBounds.clear(); lowerValueBounds.clear();
upperValueBounds.clear(); upperValueBounds.clear();
values.clear(); values.clear();
mdpTransitionsBuilder = storm::storage::SparseMatrixBuilder<ValueType>(0, 0, 0, true, true);
currentRowCount = 0;
startOfCurrentRowGroup = 0;
exploredMdpTransitions.clear();
exploredChoiceIndices.clear();
mdpActionRewards.clear(); mdpActionRewards.clear();
exploredMdp = nullptr; exploredMdp = nullptr;
currentMdpState = noState();
// Add some states with special treatment (if requested) // Add some states with special treatment (if requested)
if (extraBottomStateValue) { if (extraBottomStateValue) {
@ -63,10 +66,8 @@ namespace storm {
mdpStateToBeliefIdMap.push_back(beliefManager->noId()); mdpStateToBeliefIdMap.push_back(beliefManager->noId());
insertValueHints(extraBottomStateValue.get(), extraBottomStateValue.get()); insertValueHints(extraBottomStateValue.get(), extraBottomStateValue.get());
startOfCurrentRowGroup = currentRowCount;
mdpTransitionsBuilder.newRowGroup(startOfCurrentRowGroup);
mdpTransitionsBuilder.addNextValue(currentRowCount, extraBottomState.get(), storm::utility::one<ValueType>());
++currentRowCount;
internalAddRowGroupIndex();
internalAddTransition(getStartOfCurrentRowGroup(), extraBottomState.get(), storm::utility::one<ValueType>());
} else { } else {
extraBottomState = boost::none; extraBottomState = boost::none;
} }
@ -75,10 +76,8 @@ namespace storm {
mdpStateToBeliefIdMap.push_back(beliefManager->noId()); mdpStateToBeliefIdMap.push_back(beliefManager->noId());
insertValueHints(extraTargetStateValue.get(), extraTargetStateValue.get()); insertValueHints(extraTargetStateValue.get(), extraTargetStateValue.get());
startOfCurrentRowGroup = currentRowCount;
mdpTransitionsBuilder.newRowGroup(startOfCurrentRowGroup);
mdpTransitionsBuilder.addNextValue(currentRowCount, extraTargetState.get(), storm::utility::one<ValueType>());
++currentRowCount;
internalAddRowGroupIndex();
internalAddTransition(getStartOfCurrentRowGroup(), extraTargetState.get(), storm::utility::one<ValueType>());
targetStates.grow(getCurrentNumberOfMdpStates(), false); targetStates.grow(getCurrentNumberOfMdpStates(), false);
targetStates.set(extraTargetState.get(), true); targetStates.set(extraTargetState.get(), true);
@ -89,24 +88,62 @@ namespace storm {
// Set up the initial state. // Set up the initial state.
initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief()); initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief());
} }
/*!
* Restarts the exploration to allow re-exploring each state.
* After calling this, the "currently explored" MDP has the same number of states and choices as the "old" one, but the choices are still empty
* This method inserts the initial state of the MDP in the exploration queue.
* While re-exploring, the reference to the old MDP remains valid.
*/
void restartExploration() {
STORM_LOG_ASSERT(status == Status::ModelChecked || status == Status::ModelFinished, "Method call is invalid in current status.");
// We will not erase old states during the exploration phase, so most state-based data (like mappings between MDP and Belief states) remain valid.
exploredBeliefIds.clear();
exploredBeliefIds.grow(beliefManager->getNumberOfBeliefIds(), false);
exploredMdpTransitions.clear();
exploredMdpTransitions.resize(exploredMdp->getNumberOfChoices);
exploredChoiceIndices = exploredMdp->getNondeterministicChoiceIndices();
mdpActionRewards.clear();
if (exploredMdp->hasRewardModel()) {
// Can be overwritten during exploration
mdpActionRewards = exploredMdp->getUniqueRewardModel().getStateActionRewardVector();
}
targetStates = storm::storage::BitVector(getCurrentNumberOfMdpStates(), false);
truncatedStates = storm::storage::BitVector(getCurrentNumberOfMdpStates(), false);
mdpStatesToExplore.clear();
// The extra states are not changed
if (extraBottomState) {
currentMdpState = extraBottomState.get();
restoreOldBehaviorAtCurrentState(0);
}
if (extraTargetState) {
currentMdpState = extraTargetState.get();
restoreOldBehaviorAtCurrentState(0);
}
currentMdpState = noState();
// Set up the initial state.
initialMdpState = getOrAddMdpState(beliefManager->getInitialBelief());
}
bool hasUnexploredState() const { bool hasUnexploredState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return !beliefIdsToExplore.empty();
return !mdpStatesToExplore.empty();
} }
BeliefId exploreNextState() { BeliefId exploreNextState() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// Set up the matrix builder
finishCurrentRow();
startOfCurrentRowGroup = currentRowCount;
mdpTransitionsBuilder.newRowGroup(startOfCurrentRowGroup);
++currentRowCount;
// Pop from the queue. // Pop from the queue.
auto result = beliefIdsToExplore.front();
beliefIdsToExplore.pop_front();
return result;
currentMdpState = mdpStatesToExplore.front();
mdpStatesToExplore.pop_front();
if (!currentStateHasOldBehavior()) {
internalAddRowGroupIndex();
}
return mdpStateToBeliefIdMap[currentMdpState];
} }
void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) { void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) {
@ -114,7 +151,7 @@ namespace storm {
// We first insert the entries of the current row in a separate map. // We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder) // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
uint64_t row = startOfCurrentRowGroup + localActionIndex;
uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
if (!storm::utility::isZero(bottomStateValue)) { if (!storm::utility::isZero(bottomStateValue)) {
STORM_LOG_ASSERT(extraBottomState.is_initialized(), "Requested a transition to the extra bottom state but there is none."); STORM_LOG_ASSERT(extraBottomState.is_initialized(), "Requested a transition to the extra bottom state but there is none.");
internalAddTransition(row, extraBottomState.get(), bottomStateValue); internalAddTransition(row, extraBottomState.get(), bottomStateValue);
@ -127,7 +164,7 @@ namespace storm {
void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) { void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
uint64_t row = startOfCurrentRowGroup + localActionIndex;
uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
internalAddTransition(row, getCurrentMdpState(), value); internalAddTransition(row, getCurrentMdpState(), value);
} }
@ -145,24 +182,24 @@ namespace storm {
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder) // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
MdpStateType column; MdpStateType column;
if (ignoreNewBeliefs) { if (ignoreNewBeliefs) {
column = getMdpState(transitionTarget);
column = getExploredMdpState(transitionTarget);
if (column == noState()) { if (column == noState()) {
return false; return false;
} }
} else { } else {
column = getOrAddMdpState(transitionTarget); column = getOrAddMdpState(transitionTarget);
} }
uint64_t row = startOfCurrentRowGroup + localActionIndex;
uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
internalAddTransition(row, column, value); internalAddTransition(row, column, value);
return true; return true;
} }
void computeRewardAtCurrentState(uint64 const& localActionIndex, ValueType extraReward = storm::utility::zero<ValueType>()) { void computeRewardAtCurrentState(uint64 const& localActionIndex, ValueType extraReward = storm::utility::zero<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
if (currentRowCount >= mdpActionRewards.size()) {
mdpActionRewards.resize(currentRowCount, storm::utility::zero<ValueType>());
if (getCurrentNumberOfMdpChoices() > mdpActionRewards.size()) {
mdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero<ValueType>());
} }
uint64_t row = startOfCurrentRowGroup + localActionIndex;
uint64_t row = getStartOfCurrentRowGroup() + localActionIndex;
mdpActionRewards[row] = beliefManager->getBeliefActionReward(getCurrentBeliefId(), localActionIndex) + extraReward; mdpActionRewards[row] = beliefManager->getBeliefActionReward(getCurrentBeliefId(), localActionIndex) + extraReward;
} }
@ -178,11 +215,64 @@ namespace storm {
truncatedStates.set(getCurrentMdpState(), true); truncatedStates.set(getCurrentMdpState(), true);
} }
bool currentStateHasOldBehavior() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return exploredMdp && getCurrentMdpState() < exploredMdp->getNumberOfStates();
}
/*!
* Inserts transitions and rewards at the given action as in the MDP of the previous exploration.
* Does NOT set whether the state is truncated and/or target.
* Will add "old" states that have not been considered before into the exploration queue
* @param localActionIndex
*/
void restoreOldBehaviorAtCurrentState(uint64_t const& localActionIndex) {
STORM_LOG_ASSERT(currentStateHasOldBehavior(), "Cannot restore old behavior as the current state does not have any.");
uint64_t choiceIndex = exploredChoiceIndices[getCurrentMdpState()] + localActionIndex;
STORM_LOG_ASSERT(choiceIndex < exploredChoiceIndices[getCurrentMdpState() + 1], "Invalid local action index.");
// Insert the transitions
for (auto const& transition : exploredMdp->getTransitionMatrix().getRow(choiceIndex)) {
internalAddTransition(choiceIndex, transition.getColumn(), transition.getValue());
// Check whether exploration is needed
auto beliefId = mdpStateToBeliefIdMap[transition.getColumn()];
if (beliefId != beliefManager->noId()) { // Not the extra target or bottom state
if (!exploredBeliefIds.get(beliefId)) {
// This belief needs exploration
exploredBeliefIds.set(beliefId, true);
mdpStatesToExplore.push_back(transition.getColumn());
}
}
}
// Actually, nothing needs to be done for rewards since we already initialize the vector with the "old" values
}
void finishExploration() { void finishExploration() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
STORM_LOG_ASSERT(!hasUnexploredState(), "Finishing exploration not possible if there are still unexplored states.");
// Finish the last row grouping in case the last explored state was new
if (!currentStateHasOldBehavior()) {
internalAddRowGroupIndex();
}
// Create the tranistion matrix // Create the tranistion matrix
finishCurrentRow();
auto mdpTransitionMatrix = mdpTransitionsBuilder.build(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), getCurrentNumberOfMdpStates());
uint64_t entryCount = 0;
for (auto const& row : exploredMdpTransitions) {
entryCount += row.size();
}
storm::storage::SparseMatrixBuilder<ValueType> builder(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), entryCount, true, true, getCurrentNumberOfMdpStates());
for (uint64_t groupIndex = 0; groupIndex < exploredChoiceIndices.size() - 1; ++groupIndex) {
uint64_t rowIndex = exploredChoiceIndices[groupIndex];
uint64_t groupEnd = exploredChoiceIndices[groupIndex + 1];
builder.newRowGroup(rowIndex);
for (; rowIndex < groupEnd; ++rowIndex) {
for (auto const& entry : exploredMdpTransitions[rowIndex]) {
builder.addNextValue(rowIndex, entry.first, entry.second);
}
}
}
auto mdpTransitionMatrix = builder.build();
// Create a standard labeling // Create a standard labeling
storm::models::sparse::StateLabeling mdpLabeling(getCurrentNumberOfMdpStates()); storm::models::sparse::StateLabeling mdpLabeling(getCurrentNumberOfMdpStates());
@ -212,13 +302,18 @@ namespace storm {
} }
MdpStateType getCurrentNumberOfMdpStates() const { MdpStateType getCurrentNumberOfMdpStates() const {
STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return mdpStateToBeliefIdMap.size(); return mdpStateToBeliefIdMap.size();
} }
MdpStateType getCurrentNumberOfMdpChoices() const { MdpStateType getCurrentNumberOfMdpChoices() const {
STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
return currentRowCount;
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return exploredMdpTransitions.size();
}
MdpStateType getStartOfCurrentRowGroup() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return exploredChoiceIndices.back();
} }
ValueType getLowerValueBoundAtCurrentState() const { ValueType getLowerValueBoundAtCurrentState() const {
@ -291,7 +386,8 @@ namespace storm {
} }
MdpStateType getCurrentMdpState() const { MdpStateType getCurrentMdpState() const {
return mdpTransitionsBuilder.getCurrentRowGroupCount() - 1;
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return currentMdpState;
} }
MdpStateType getCurrentBeliefId() const { MdpStateType getCurrentBeliefId() const {
@ -299,27 +395,20 @@ namespace storm {
} }
void internalAddTransition(uint64_t const& row, MdpStateType const& column, ValueType const& value) { void internalAddTransition(uint64_t const& row, MdpStateType const& column, ValueType const& value) {
// We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
STORM_LOG_ASSERT(row >= currentRowCount - 1, "Trying to insert in an already completed row.");
if (row >= currentRowCount) {
// We are going to start a new row, so insert the entries of the old one
finishCurrentRow();
currentRowCount = row + 1;
STORM_LOG_ASSERT(row <= exploredMdpTransitions.size(), "Skipped at least one row.");
if (row == exploredMdpTransitions.size()) {
exploredMdpTransitions.emplace_back();
} }
STORM_LOG_ASSERT(mdpTransitionsBuilderCurrentRowEntries.count(column) == 0, "Trying to insert multiple transitions to the same state.");
mdpTransitionsBuilderCurrentRowEntries[column] = value;
STORM_LOG_ASSERT(exploredMdpTransitions[row].count(column) == 0, "Trying to insert multiple transitions to the same state.");
exploredMdpTransitions[row][column] = value;
} }
void finishCurrentRow() {
for (auto const& entry : mdpTransitionsBuilderCurrentRowEntries) {
mdpTransitionsBuilder.addNextValue(currentRowCount - 1, entry.first, entry.second);
}
mdpTransitionsBuilderCurrentRowEntries.clear();
void internalAddRowGroupIndex() {
exploredChoiceIndices.push_back(getCurrentNumberOfMdpChoices());
} }
MdpStateType getMdpState(BeliefId const& beliefId) const {
if (beliefId < beliefIdsWithMdpState.size() && beliefIdsWithMdpState.get(beliefId)) {
MdpStateType getExploredMdpState(BeliefId const& beliefId) const {
if (beliefId < exploredBeliefIds.size() && exploredBeliefIds.get(beliefId)) {
return beliefIdToMdpStateMap.at(beliefId); return beliefIdToMdpStateMap.at(beliefId);
} else { } else {
return noState(); return noState();
@ -336,20 +425,28 @@ namespace storm {
} }
MdpStateType getOrAddMdpState(BeliefId const& beliefId) { MdpStateType getOrAddMdpState(BeliefId const& beliefId) {
beliefIdsWithMdpState.grow(beliefId + 1, false);
if (beliefIdsWithMdpState.get(beliefId)) {
exploredBeliefIds.grow(beliefId + 1, false);
if (exploredBeliefIds.get(beliefId)) {
return beliefIdToMdpStateMap[beliefId]; return beliefIdToMdpStateMap[beliefId];
} else { } else {
// Add a new MDP state
beliefIdsWithMdpState.set(beliefId, true);
// This state needs exploration
exploredBeliefIds.set(beliefId, true);
// If this is a restart of the exploration, we still might have an MDP state for the belief
if (exploredMdp) {
auto findRes = beliefIdToMdpStateMap.find(beliefId);
if (findRes != beliefIdToMdpStateMap.end()) {
mdpStatesToExplore.push_back(findRes->second);
return findRes->second;
}
}
// At this poind we need to add a new MDP state
MdpStateType result = getCurrentNumberOfMdpStates(); MdpStateType result = getCurrentNumberOfMdpStates();
assert(getCurrentNumberOfMdpStates() == mdpStateToBeliefIdMap.size()); assert(getCurrentNumberOfMdpStates() == mdpStateToBeliefIdMap.size());
mdpStateToBeliefIdMap.push_back(beliefId); mdpStateToBeliefIdMap.push_back(beliefId);
beliefIdToMdpStateMap[beliefId] = result; beliefIdToMdpStateMap[beliefId] = result;
// This new belief needs exploration
beliefIdsToExplore.push_back(beliefId);
insertValueHints(computeLowerValueBoundAtBelief(beliefId), computeUpperValueBoundAtBelief(beliefId)); insertValueHints(computeLowerValueBoundAtBelief(beliefId), computeUpperValueBoundAtBelief(beliefId));
mdpStatesToExplore.push_back(result);
return result; return result;
} }
} }
@ -358,15 +455,14 @@ namespace storm {
std::shared_ptr<BeliefManagerType> beliefManager; std::shared_ptr<BeliefManagerType> beliefManager;
std::vector<BeliefId> mdpStateToBeliefIdMap; std::vector<BeliefId> mdpStateToBeliefIdMap;
std::map<BeliefId, MdpStateType> beliefIdToMdpStateMap; std::map<BeliefId, MdpStateType> beliefIdToMdpStateMap;
storm::storage::BitVector beliefIdsWithMdpState;
storm::storage::BitVector exploredBeliefIds;
// Exploration information // Exploration information
std::deque<uint64_t> beliefIdsToExplore;
storm::storage::SparseMatrixBuilder<ValueType> mdpTransitionsBuilder;
std::map<MdpStateType, ValueType> mdpTransitionsBuilderCurrentRowEntries;
std::deque<uint64_t> mdpStatesToExplore;
std::vector<std::map<MdpStateType, ValueType>> exploredMdpTransitions;
std::vector<MdpStateType> exploredChoiceIndices;
std::vector<ValueType> mdpActionRewards; std::vector<ValueType> mdpActionRewards;
uint64_t startOfCurrentRowGroup;
uint64_t currentRowCount;
uint64_t currentMdpState;
// Special states during exploration // Special states during exploration
boost::optional<MdpStateType> extraTargetState; boost::optional<MdpStateType> extraTargetState;

4
src/storm-pomdp/storage/BeliefManager.h

@ -324,8 +324,8 @@ namespace storm {
} }
std::map<BeliefId, ValueType> expandInternal(BeliefId const& beliefId, uint64_t actionIndex, boost::optional<std::vector<uint64_t>> const& observationTriangulationResolutions = boost::none) { std::map<BeliefId, ValueType> expandInternal(BeliefId const& beliefId, uint64_t actionIndex, boost::optional<std::vector<uint64_t>> const& observationTriangulationResolutions = boost::none) {
std::map<BeliefId, ValueType> destinations; // The belief ids should be ordered
// TODO: Does this make sense? It could be better to order them afterwards because now we rely on the fact that MDP states have the same order than their associated BeliefIds
std::map<BeliefId, ValueType> destinations;
// TODO: Output as vector?
BeliefType belief = getBelief(beliefId); BeliefType belief = getBelief(beliefId);

Loading…
Cancel
Save