Browse Source

BeliefMdpExplorer: Added a few asserts so that methods can only be called in the corresponding exploration phase

main
Tim Quatmann 5 years ago
parent
commit
5388ed98e3
  1. 38
      src/storm-pomdp/builder/BeliefMdpExplorer.h

38
src/storm-pomdp/builder/BeliefMdpExplorer.h

@ -29,12 +29,20 @@ namespace storm {
typedef typename BeliefManagerType::BeliefId BeliefId; typedef typename BeliefManagerType::BeliefId BeliefId;
typedef uint64_t MdpStateType; typedef uint64_t MdpStateType;
BeliefMdpExplorer(std::shared_ptr<BeliefManagerType> beliefManager, std::vector<ValueType> const& pomdpLowerValueBounds, std::vector<ValueType> const& pomdpUpperValueBounds) : beliefManager(beliefManager), pomdpLowerValueBounds(pomdpLowerValueBounds), pomdpUpperValueBounds(pomdpUpperValueBounds) {
enum class Status {
Uninitialized,
Exploring,
ModelFinished,
ModelChecked
};
BeliefMdpExplorer(std::shared_ptr<BeliefManagerType> beliefManager, std::vector<ValueType> const& pomdpLowerValueBounds, std::vector<ValueType> const& pomdpUpperValueBounds) : beliefManager(beliefManager), pomdpLowerValueBounds(pomdpLowerValueBounds), pomdpUpperValueBounds(pomdpUpperValueBounds), status(Status::Uninitialized) {
// Intentionally left empty // Intentionally left empty
} }
BeliefMdpExplorer(BeliefMdpExplorer&& other) = default; BeliefMdpExplorer(BeliefMdpExplorer&& other) = default;
void startNewExploration(boost::optional<ValueType> extraTargetStateValue = boost::none, boost::optional<ValueType> extraBottomStateValue = boost::none) { void startNewExploration(boost::optional<ValueType> extraTargetStateValue = boost::none, boost::optional<ValueType> extraBottomStateValue = boost::none) {
status = Status::Exploring;
// Reset data from potential previous explorations // Reset data from potential previous explorations
mdpStateToBeliefIdMap.clear(); mdpStateToBeliefIdMap.clear();
beliefIdToMdpStateMap.clear(); beliefIdToMdpStateMap.clear();
@ -83,10 +91,12 @@ namespace storm {
} }
bool hasUnexploredState() const { bool hasUnexploredState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return !beliefIdsToExplore.empty(); return !beliefIdsToExplore.empty();
} }
BeliefId exploreNextState() { BeliefId exploreNextState() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// Set up the matrix builder // Set up the matrix builder
finishCurrentRow(); finishCurrentRow();
startOfCurrentRowGroup = currentRowCount; startOfCurrentRowGroup = currentRowCount;
@ -100,6 +110,7 @@ namespace storm {
} }
void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) { void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// We first insert the entries of the current row in a separate map. // We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder) // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
@ -115,6 +126,7 @@ namespace storm {
} }
void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) { void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
uint64_t row = startOfCurrentRowGroup + localActionIndex; uint64_t row = startOfCurrentRowGroup + localActionIndex;
internalAddTransition(row, getCurrentMdpState(), value); internalAddTransition(row, getCurrentMdpState(), value);
} }
@ -128,6 +140,7 @@ namespace storm {
* @return true iff a transition was actually inserted. False can only happen if ignoreNewBeliefs is true. * @return true iff a transition was actually inserted. False can only happen if ignoreNewBeliefs is true.
*/ */
bool addTransitionToBelief(uint64_t const& localActionIndex, BeliefId const& transitionTarget, ValueType const& value, bool ignoreNewBeliefs) { bool addTransitionToBelief(uint64_t const& localActionIndex, BeliefId const& transitionTarget, ValueType const& value, bool ignoreNewBeliefs) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// We first insert the entries of the current row in a separate map. // We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder) // This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
MdpStateType column; MdpStateType column;
@ -145,6 +158,7 @@ namespace storm {
} }
void computeRewardAtCurrentState(uint64 const& localActionIndex, ValueType extraReward = storm::utility::zero<ValueType>()) { void computeRewardAtCurrentState(uint64 const& localActionIndex, ValueType extraReward = storm::utility::zero<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
if (currentRowCount >= mdpActionRewards.size()) { if (currentRowCount >= mdpActionRewards.size()) {
mdpActionRewards.resize(currentRowCount, storm::utility::zero<ValueType>()); mdpActionRewards.resize(currentRowCount, storm::utility::zero<ValueType>());
} }
@ -153,16 +167,19 @@ namespace storm {
} }
void setCurrentStateIsTarget() { void setCurrentStateIsTarget() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
targetStates.grow(getCurrentNumberOfMdpStates(), false); targetStates.grow(getCurrentNumberOfMdpStates(), false);
targetStates.set(getCurrentMdpState(), true); targetStates.set(getCurrentMdpState(), true);
} }
void setCurrentStateIsTruncated() { void setCurrentStateIsTruncated() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
truncatedStates.grow(getCurrentNumberOfMdpStates(), false); truncatedStates.grow(getCurrentNumberOfMdpStates(), false);
truncatedStates.set(getCurrentMdpState(), true); truncatedStates.set(getCurrentMdpState(), true);
} }
void finishExploration() { void finishExploration() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// Create the tranistion matrix // Create the tranistion matrix
finishCurrentRow(); finishCurrentRow();
auto mdpTransitionMatrix = mdpTransitionsBuilder.build(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), getCurrentNumberOfMdpStates()); auto mdpTransitionMatrix = mdpTransitionsBuilder.build(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), getCurrentNumberOfMdpStates());
@ -185,26 +202,32 @@ namespace storm {
storm::storage::sparse::ModelComponents<ValueType> modelComponents(std::move(mdpTransitionMatrix), std::move(mdpLabeling), std::move(mdpRewardModels)); storm::storage::sparse::ModelComponents<ValueType> modelComponents(std::move(mdpTransitionMatrix), std::move(mdpLabeling), std::move(mdpRewardModels));
exploredMdp = std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(modelComponents)); exploredMdp = std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(modelComponents));
status = Status::ModelFinished;
} }
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> getExploredMdp() const { std::shared_ptr<storm::models::sparse::Mdp<ValueType>> getExploredMdp() const {
STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status.");
STORM_LOG_ASSERT(exploredMdp, "Tried to get the explored MDP but exploration was not finished yet."); STORM_LOG_ASSERT(exploredMdp, "Tried to get the explored MDP but exploration was not finished yet.");
return exploredMdp; return exploredMdp;
} }
MdpStateType getCurrentNumberOfMdpStates() const { MdpStateType getCurrentNumberOfMdpStates() const {
STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
return mdpStateToBeliefIdMap.size(); return mdpStateToBeliefIdMap.size();
} }
MdpStateType getCurrentNumberOfMdpChoices() const { MdpStateType getCurrentNumberOfMdpChoices() const {
STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
return currentRowCount; return currentRowCount;
} }
ValueType getLowerValueBoundAtCurrentState() const { ValueType getLowerValueBoundAtCurrentState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return lowerValueBounds[getCurrentMdpState()]; return lowerValueBounds[getCurrentMdpState()];
} }
ValueType getUpperValueBoundAtCurrentState() const { ValueType getUpperValueBoundAtCurrentState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return upperValueBounds[getCurrentMdpState()]; return upperValueBounds[getCurrentMdpState()];
} }
@ -216,7 +239,8 @@ namespace storm {
return beliefManager->getWeightedSum(beliefId, pomdpUpperValueBounds); return beliefManager->getWeightedSum(beliefId, pomdpUpperValueBounds);
} }
std::vector<ValueType> const& computeValuesOfExploredMdp(storm::solver::OptimizationDirection const& dir) {
void computeValuesOfExploredMdp(storm::solver::OptimizationDirection const& dir) {
STORM_LOG_ASSERT(status == Status::ModelFinished, "Method call is invalid in current status.");
STORM_LOG_ASSERT(exploredMdp, "Tried to compute values but the MDP is not explored"); STORM_LOG_ASSERT(exploredMdp, "Tried to compute values but the MDP is not explored");
auto property = createStandardProperty(dir, exploredMdp->hasRewardModel()); auto property = createStandardProperty(dir, exploredMdp->hasRewardModel());
auto task = createStandardCheckTask(property); auto task = createStandardCheckTask(property);
@ -228,12 +252,18 @@ namespace storm {
STORM_LOG_ASSERT(storm::utility::resources::isTerminate(), "Empty check result!"); STORM_LOG_ASSERT(storm::utility::resources::isTerminate(), "Empty check result!");
STORM_LOG_ERROR("No result obtained while checking."); STORM_LOG_ERROR("No result obtained while checking.");
} }
status = Status::ModelChecked;
}
std::vector<ValueType> const& getValuesOfExploredMdp() const {
STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status.");
return values; return values;
} }
ValueType const& getComputedValueAtInitialState() const { ValueType const& getComputedValueAtInitialState() const {
STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status.");
STORM_LOG_ASSERT(exploredMdp, "Tried to get a value but no MDP was explored."); STORM_LOG_ASSERT(exploredMdp, "Tried to get a value but no MDP was explored.");
return values[exploredMdp->getInitialStates().getNextSetIndex(0)];
return getValuesOfExploredMdp()[exploredMdp->getInitialStates().getNextSetIndex(0)];
} }
private: private:
@ -355,6 +385,8 @@ namespace storm {
std::vector<ValueType> upperValueBounds; std::vector<ValueType> upperValueBounds;
std::vector<ValueType> values; // Contains an estimate during building and the actual result after a check has performed std::vector<ValueType> values; // Contains an estimate during building and the actual result after a check has performed
// The current status of this explorer
Status status;
}; };
} }
} }
Loading…
Cancel
Save