Browse Source

BeliefMdpExplorer: Added a few asserts so that methods can only be called in the corresponding exploration phase

tempestpy_adaptions
Tim Quatmann 5 years ago
parent
commit
5388ed98e3
  1. 38
      src/storm-pomdp/builder/BeliefMdpExplorer.h

38
src/storm-pomdp/builder/BeliefMdpExplorer.h

@ -29,12 +29,20 @@ namespace storm {
typedef typename BeliefManagerType::BeliefId BeliefId;
typedef uint64_t MdpStateType;
BeliefMdpExplorer(std::shared_ptr<BeliefManagerType> beliefManager, std::vector<ValueType> const& pomdpLowerValueBounds, std::vector<ValueType> const& pomdpUpperValueBounds) : beliefManager(beliefManager), pomdpLowerValueBounds(pomdpLowerValueBounds), pomdpUpperValueBounds(pomdpUpperValueBounds) {
enum class Status {
Uninitialized,
Exploring,
ModelFinished,
ModelChecked
};
BeliefMdpExplorer(std::shared_ptr<BeliefManagerType> beliefManager, std::vector<ValueType> const& pomdpLowerValueBounds, std::vector<ValueType> const& pomdpUpperValueBounds) : beliefManager(beliefManager), pomdpLowerValueBounds(pomdpLowerValueBounds), pomdpUpperValueBounds(pomdpUpperValueBounds), status(Status::Uninitialized) {
// Intentionally left empty
}
BeliefMdpExplorer(BeliefMdpExplorer&& other) = default;
void startNewExploration(boost::optional<ValueType> extraTargetStateValue = boost::none, boost::optional<ValueType> extraBottomStateValue = boost::none) {
status = Status::Exploring;
// Reset data from potential previous explorations
mdpStateToBeliefIdMap.clear();
beliefIdToMdpStateMap.clear();
@ -83,10 +91,12 @@ namespace storm {
}
bool hasUnexploredState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return !beliefIdsToExplore.empty();
}
BeliefId exploreNextState() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// Set up the matrix builder
finishCurrentRow();
startOfCurrentRowGroup = currentRowCount;
@ -100,6 +110,7 @@ namespace storm {
}
void addTransitionsToExtraStates(uint64_t const& localActionIndex, ValueType const& targetStateValue = storm::utility::zero<ValueType>(), ValueType const& bottomStateValue = storm::utility::zero<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
@ -115,6 +126,7 @@ namespace storm {
}
void addSelfloopTransition(uint64_t const& localActionIndex = 0, ValueType const& value = storm::utility::one<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
uint64_t row = startOfCurrentRowGroup + localActionIndex;
internalAddTransition(row, getCurrentMdpState(), value);
}
@ -128,6 +140,7 @@ namespace storm {
* @return true iff a transition was actually inserted. False can only happen if ignoreNewBeliefs is true.
*/
bool addTransitionToBelief(uint64_t const& localActionIndex, BeliefId const& transitionTarget, ValueType const& value, bool ignoreNewBeliefs) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// We first insert the entries of the current row in a separate map.
// This is to ensure that entries are sorted in the right way (as required for the transition matrix builder)
MdpStateType column;
@ -145,6 +158,7 @@ namespace storm {
}
void computeRewardAtCurrentState(uint64 const& localActionIndex, ValueType extraReward = storm::utility::zero<ValueType>()) {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
if (currentRowCount >= mdpActionRewards.size()) {
mdpActionRewards.resize(currentRowCount, storm::utility::zero<ValueType>());
}
@ -153,16 +167,19 @@ namespace storm {
}
void setCurrentStateIsTarget() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
targetStates.grow(getCurrentNumberOfMdpStates(), false);
targetStates.set(getCurrentMdpState(), true);
}
void setCurrentStateIsTruncated() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
truncatedStates.grow(getCurrentNumberOfMdpStates(), false);
truncatedStates.set(getCurrentMdpState(), true);
}
void finishExploration() {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
// Create the tranistion matrix
finishCurrentRow();
auto mdpTransitionMatrix = mdpTransitionsBuilder.build(getCurrentNumberOfMdpChoices(), getCurrentNumberOfMdpStates(), getCurrentNumberOfMdpStates());
@ -185,26 +202,32 @@ namespace storm {
storm::storage::sparse::ModelComponents<ValueType> modelComponents(std::move(mdpTransitionMatrix), std::move(mdpLabeling), std::move(mdpRewardModels));
exploredMdp = std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(modelComponents));
status = Status::ModelFinished;
}
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> getExploredMdp() const {
STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status.");
STORM_LOG_ASSERT(exploredMdp, "Tried to get the explored MDP but exploration was not finished yet.");
return exploredMdp;
}
MdpStateType getCurrentNumberOfMdpStates() const {
STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
return mdpStateToBeliefIdMap.size();
}
MdpStateType getCurrentNumberOfMdpChoices() const {
STORM_LOG_ASSERT(status != Status::Uninitialized, "Method call is invalid in current status.");
return currentRowCount;
}
ValueType getLowerValueBoundAtCurrentState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return lowerValueBounds[getCurrentMdpState()];
}
ValueType getUpperValueBoundAtCurrentState() const {
STORM_LOG_ASSERT(status == Status::Exploring, "Method call is invalid in current status.");
return upperValueBounds[getCurrentMdpState()];
}
@ -216,7 +239,8 @@ namespace storm {
return beliefManager->getWeightedSum(beliefId, pomdpUpperValueBounds);
}
std::vector<ValueType> const& computeValuesOfExploredMdp(storm::solver::OptimizationDirection const& dir) {
void computeValuesOfExploredMdp(storm::solver::OptimizationDirection const& dir) {
STORM_LOG_ASSERT(status == Status::ModelFinished, "Method call is invalid in current status.");
STORM_LOG_ASSERT(exploredMdp, "Tried to compute values but the MDP is not explored");
auto property = createStandardProperty(dir, exploredMdp->hasRewardModel());
auto task = createStandardCheckTask(property);
@ -228,12 +252,18 @@ namespace storm {
STORM_LOG_ASSERT(storm::utility::resources::isTerminate(), "Empty check result!");
STORM_LOG_ERROR("No result obtained while checking.");
}
status = Status::ModelChecked;
}
std::vector<ValueType> const& getValuesOfExploredMdp() const {
STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status.");
return values;
}
ValueType const& getComputedValueAtInitialState() const {
STORM_LOG_ASSERT(status == Status::ModelChecked, "Method call is invalid in current status.");
STORM_LOG_ASSERT(exploredMdp, "Tried to get a value but no MDP was explored.");
return values[exploredMdp->getInitialStates().getNextSetIndex(0)];
return getValuesOfExploredMdp()[exploredMdp->getInitialStates().getNextSetIndex(0)];
}
private:
@ -355,6 +385,8 @@ namespace storm {
std::vector<ValueType> upperValueBounds;
std::vector<ValueType> values; // Contains an estimate during building and the actual result after a check has performed
// The current status of this explorer
Status status;
};
}
}
Loading…
Cancel
Save