Browse Source

Added basis for rewards in dropUnreachableStates()

tempestpy_adaptions
Alexander Bork 5 years ago
parent
commit
62c905fc58
  1. 28
      src/storm-pomdp/builder/BeliefMdpExplorer.h

28
src/storm-pomdp/builder/BeliefMdpExplorer.h

@ -298,28 +298,42 @@ namespace storm {
void dropUnreachableStates() { void dropUnreachableStates() {
STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status."); STORM_LOG_ASSERT(status == Status::ModelFinished || status == Status::ModelChecked, "Method call is invalid in current status.");
storm::storage::BitVector reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(),
storm::storage::BitVector(getCurrentNumberOfMdpStates(), {initialMdpState}),
storm::storage::BitVector(getCurrentNumberOfMdpStates(), true), targetStates);
auto reachableStates = storm::utility::graph::getReachableStates(getExploredMdp()->getTransitionMatrix(),
storm::storage::BitVector(getCurrentNumberOfMdpStates(), std::vector<uint64_t>{initialMdpState}),
storm::storage::BitVector(getCurrentNumberOfMdpStates(), true),
getExploredMdp()->getStateLabeling().getStates("target"));
auto reachableTransitionMatrix = getExploredMdp()->getTransitionMatrix().getSubmatrix(true, reachableStates, reachableStates); auto reachableTransitionMatrix = getExploredMdp()->getTransitionMatrix().getSubmatrix(true, reachableStates, reachableStates);
auto reachableStateLabeling = getExploredMdp()->getStateLabeling().getSubLabeling(reachableStates); auto reachableStateLabeling = getExploredMdp()->getStateLabeling().getSubLabeling(reachableStates);
// TODO reward model
storm::storage::sparse::ModelComponents<ValueType> modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling));
exploredMdp = std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(modelComponents));
std::vector<BeliefId> reachableMdpStateToBeliefIdMap(reachableStates.getNumberOfSetBits()); std::vector<BeliefId> reachableMdpStateToBeliefIdMap(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableLowerValueBounds(reachableStates.getNumberOfSetBits()); std::vector<ValueType> reachableLowerValueBounds(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableUpperValueBounds(reachableStates.getNumberOfSetBits()); std::vector<ValueType> reachableUpperValueBounds(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableValues(reachableStates.getNumberOfSetBits()); std::vector<ValueType> reachableValues(reachableStates.getNumberOfSetBits());
std::vector<ValueType> reachableMdpActionRewards;
for (uint64_t state = 0; state < reachableStates.size(); ++state) { for (uint64_t state = 0; state < reachableStates.size(); ++state) {
if (reachableStates[state]) { if (reachableStates[state]) {
reachableMdpStateToBeliefIdMap.push_back(mdpStateToBeliefIdMap[state]); reachableMdpStateToBeliefIdMap.push_back(mdpStateToBeliefIdMap[state]);
reachableLowerValueBounds.push_back(lowerValueBounds[state]); reachableLowerValueBounds.push_back(lowerValueBounds[state]);
reachableUpperValueBounds.push_back(upperValueBounds[state]); reachableUpperValueBounds.push_back(upperValueBounds[state]);
reachableValues.push_back(values[state]); reachableValues.push_back(values[state]);
if (getExploredMdp()->hasRewardModel()) {
//TODO FIXME is there some mismatch with the indices here?
for (uint64_t i = 0; i < getExploredMdp()->getTransitionMatrix().getRowGroupSize(state); ++i) {
reachableMdpActionRewards.push_back(getExploredMdp()->getUniqueRewardModel().getStateActionRewardVector()[state + i]);
}
}
} }
//TODO drop BeliefIds from exploredBeliefIDs? //TODO drop BeliefIds from exploredBeliefIDs?
} }
std::unordered_map<std::string, storm::models::sparse::StandardRewardModel<ValueType>> mdpRewardModels;
if (!reachableMdpActionRewards.empty()) {
//reachableMdpActionRewards.resize(getCurrentNumberOfMdpChoices(), storm::utility::zero<ValueType>());
mdpRewardModels.emplace("default",
storm::models::sparse::StandardRewardModel<ValueType>(boost::optional<std::vector<ValueType>>(), std::move(reachableMdpActionRewards)));
}
storm::storage::sparse::ModelComponents<ValueType> modelComponents(std::move(reachableTransitionMatrix), std::move(reachableStateLabeling),
std::move(mdpRewardModels));
exploredMdp = std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(modelComponents));
std::map<BeliefId, MdpStateType> reachableBeliefIdToMdpStateMap; std::map<BeliefId, MdpStateType> reachableBeliefIdToMdpStateMap;
for (MdpStateType state = 0; state < reachableMdpStateToBeliefIdMap.size(); ++state) { for (MdpStateType state = 0; state < reachableMdpStateToBeliefIdMap.size(); ++state) {
reachableBeliefIdToMdpStateMap[reachableMdpStateToBeliefIdMap[state]] = state; reachableBeliefIdToMdpStateMap[reachableMdpStateToBeliefIdMap[state]] = state;

Loading…
Cancel
Save