|
@ -299,7 +299,6 @@ namespace storm { |
|
|
STORM_PRINT("Exploration threshold: " << explorationThreshold << std::endl) |
|
|
STORM_PRINT("Exploration threshold: " << explorationThreshold << std::endl) |
|
|
} |
|
|
} |
|
|
while (!beliefsToBeExpanded.empty()) { |
|
|
while (!beliefsToBeExpanded.empty()) { |
|
|
// TODO direct generation of transition matrix?
|
|
|
|
|
|
uint64_t currId = beliefsToBeExpanded.front(); |
|
|
uint64_t currId = beliefsToBeExpanded.front(); |
|
|
beliefsToBeExpanded.pop_front(); |
|
|
beliefsToBeExpanded.pop_front(); |
|
|
bool isTarget = beliefIsTarget[currId]; |
|
|
bool isTarget = beliefIsTarget[currId]; |
|
@ -637,7 +636,7 @@ namespace storm { |
|
|
// Belief ID -> ActionIndex
|
|
|
// Belief ID -> ActionIndex
|
|
|
std::map<uint64_t, std::vector<uint64_t>> chosenActions; |
|
|
std::map<uint64_t, std::vector<uint64_t>> chosenActions; |
|
|
|
|
|
|
|
|
// Belief ID -> Observation -> Probability
|
|
|
|
|
|
|
|
|
// Belief ID -> action -> Observation -> Probability
|
|
|
std::map<uint64_t, std::vector<std::map<uint32_t, ValueType>>> observationProbabilities; |
|
|
std::map<uint64_t, std::vector<std::map<uint32_t, ValueType>>> observationProbabilities; |
|
|
// current ID -> action -> next ID
|
|
|
// current ID -> action -> next ID
|
|
|
std::map<uint64_t, std::vector<std::map<uint32_t, uint64_t>>> nextBelieves; |
|
|
std::map<uint64_t, std::vector<std::map<uint32_t, uint64_t>>> nextBelieves; |
|
@ -1025,7 +1024,7 @@ namespace storm { |
|
|
storm::models::sparse::Pomdp<ValueType, RewardModelType> const &pomdp, |
|
|
storm::models::sparse::Pomdp<ValueType, RewardModelType> const &pomdp, |
|
|
std::set<uint32_t> const &target_observations, uint64_t gridResolution, |
|
|
std::set<uint32_t> const &target_observations, uint64_t gridResolution, |
|
|
std::vector<storm::pomdp::Belief<ValueType>> &beliefList, |
|
|
std::vector<storm::pomdp::Belief<ValueType>> &beliefList, |
|
|
std::vector<storm::pomdp::Belief<ValueType>> &grid, std::vector<bool> &beliefIsKnown, |
|
|
|
|
|
|
|
|
std::vector<storm::pomdp::Belief<ValueType>> &grid, std::vector<bool> &beliefIsTarget, |
|
|
uint64_t nextId) { |
|
|
uint64_t nextId) { |
|
|
bool isTarget; |
|
|
bool isTarget; |
|
|
uint64_t newId = nextId; |
|
|
uint64_t newId = nextId; |
|
@ -1045,7 +1044,7 @@ namespace storm { |
|
|
<< distribution << "]"); |
|
|
<< distribution << "]"); |
|
|
beliefList.push_back(belief); |
|
|
beliefList.push_back(belief); |
|
|
grid.push_back(belief); |
|
|
grid.push_back(belief); |
|
|
beliefIsKnown.push_back(isTarget); |
|
|
|
|
|
|
|
|
beliefIsTarget.push_back(isTarget); |
|
|
++newId; |
|
|
++newId; |
|
|
} else { |
|
|
} else { |
|
|
// Otherwise we have to enumerate all possible distributions with regards to the grid
|
|
|
// Otherwise we have to enumerate all possible distributions with regards to the grid
|
|
@ -1073,7 +1072,7 @@ namespace storm { |
|
|
STORM_LOG_TRACE("Add Belief " << std::to_string(newId) << " [(" << std::to_string(observation) << ")," << distribution << "]"); |
|
|
STORM_LOG_TRACE("Add Belief " << std::to_string(newId) << " [(" << std::to_string(observation) << ")," << distribution << "]"); |
|
|
beliefList.push_back(belief); |
|
|
beliefList.push_back(belief); |
|
|
grid.push_back(belief); |
|
|
grid.push_back(belief); |
|
|
beliefIsKnown.push_back(isTarget); |
|
|
|
|
|
|
|
|
beliefIsTarget.push_back(isTarget); |
|
|
if (helper[statesWithObservation.size() - 1] == |
|
|
if (helper[statesWithObservation.size() - 1] == |
|
|
storm::utility::convertNumber<ValueType>(gridResolution)) { |
|
|
storm::utility::convertNumber<ValueType>(gridResolution)) { |
|
|
// If the last entry of helper is the gridResolution, we have enumerated all necessary distributions
|
|
|
// If the last entry of helper is the gridResolution, we have enumerated all necessary distributions
|
|
@ -1175,7 +1174,16 @@ namespace storm { |
|
|
uint64_t actionIndex) { |
|
|
uint64_t actionIndex) { |
|
|
std::map<uint32_t, ValueType> res; |
|
|
std::map<uint32_t, ValueType> res; |
|
|
// the id is not important here as we immediately discard the belief (very hacky, I don't like it either)
|
|
|
// the id is not important here as we immediately discard the belief (very hacky, I don't like it either)
|
|
|
std::map<uint64_t, ValueType> postProbabilities = getBeliefAfterAction(pomdp, belief, actionIndex, 0).probabilities; |
|
|
|
|
|
|
|
|
std::map<uint64_t, ValueType> postProbabilities; |
|
|
|
|
|
for (auto const &probEntry : belief.probabilities) { |
|
|
|
|
|
uint64_t state = probEntry.first; |
|
|
|
|
|
auto row = pomdp.getTransitionMatrix().getRow(pomdp.getChoiceIndex(storm::storage::StateActionPair(state, actionIndex))); |
|
|
|
|
|
for (auto const &entry : row) { |
|
|
|
|
|
if (entry.getValue() > 0) { |
|
|
|
|
|
postProbabilities[entry.getColumn()] += belief.probabilities[state] * entry.getValue(); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
for (auto const &probEntry : postProbabilities) { |
|
|
for (auto const &probEntry : postProbabilities) { |
|
|
uint32_t observation = pomdp.getObservation(probEntry.first); |
|
|
uint32_t observation = pomdp.getObservation(probEntry.first); |
|
|
if (res.count(observation) == 0) { |
|
|
if (res.count(observation) == 0) { |
|
@ -1212,7 +1220,6 @@ namespace storm { |
|
|
storm::models::sparse::Pomdp<ValueType, RewardModelType> const &pomdp, std::vector<storm::pomdp::Belief<ValueType>> &beliefList, |
|
|
storm::models::sparse::Pomdp<ValueType, RewardModelType> const &pomdp, std::vector<storm::pomdp::Belief<ValueType>> &beliefList, |
|
|
std::vector<bool> &beliefIsTarget, std::set<uint32_t> const &targetObservations, storm::pomdp::Belief<ValueType> &belief, uint64_t actionIndex, |
|
|
std::vector<bool> &beliefIsTarget, std::set<uint32_t> const &targetObservations, storm::pomdp::Belief<ValueType> &belief, uint64_t actionIndex, |
|
|
uint32_t observation, uint64_t id) { |
|
|
uint32_t observation, uint64_t id) { |
|
|
storm::utility::Stopwatch distrWatch(true); |
|
|
|
|
|
std::map<uint64_t, ValueType> distributionAfter; |
|
|
std::map<uint64_t, ValueType> distributionAfter; |
|
|
for (auto const &probEntry : belief.probabilities) { |
|
|
for (auto const &probEntry : belief.probabilities) { |
|
|
uint64_t state = probEntry.first; |
|
|
uint64_t state = probEntry.first; |
|
@ -1223,9 +1230,7 @@ namespace storm { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
distrWatch.stop(); |
|
|
|
|
|
// We have to normalize the distribution
|
|
|
// We have to normalize the distribution
|
|
|
storm::utility::Stopwatch normalizationWatch(true); |
|
|
|
|
|
auto sum = storm::utility::zero<ValueType>(); |
|
|
auto sum = storm::utility::zero<ValueType>(); |
|
|
for (auto const &entry : distributionAfter) { |
|
|
for (auto const &entry : distributionAfter) { |
|
|
sum += entry.second; |
|
|
sum += entry.second; |
|
@ -1234,19 +1239,12 @@ namespace storm { |
|
|
for (auto const &entry : distributionAfter) { |
|
|
for (auto const &entry : distributionAfter) { |
|
|
distributionAfter[entry.first] /= sum; |
|
|
distributionAfter[entry.first] /= sum; |
|
|
} |
|
|
} |
|
|
normalizationWatch.stop(); |
|
|
|
|
|
if (getBeliefIdInVector(beliefList, observation, distributionAfter) != uint64_t(-1)) { |
|
|
if (getBeliefIdInVector(beliefList, observation, distributionAfter) != uint64_t(-1)) { |
|
|
storm::utility::Stopwatch getWatch(true); |
|
|
|
|
|
auto res = getBeliefIdInVector(beliefList, observation, distributionAfter); |
|
|
auto res = getBeliefIdInVector(beliefList, observation, distributionAfter); |
|
|
getWatch.stop(); |
|
|
|
|
|
//STORM_PRINT("Distribution: "<< distrWatch.getTimeInNanoseconds() << " / Normalization: " << normalizationWatch.getTimeInNanoseconds() << " / getId: " << getWatch.getTimeInNanoseconds() << std::endl)
|
|
|
|
|
|
return res; |
|
|
return res; |
|
|
} else { |
|
|
} else { |
|
|
storm::utility::Stopwatch pushWatch(true); |
|
|
|
|
|
beliefList.push_back(storm::pomdp::Belief<ValueType>{id, observation, distributionAfter}); |
|
|
beliefList.push_back(storm::pomdp::Belief<ValueType>{id, observation, distributionAfter}); |
|
|
beliefIsTarget.push_back(targetObservations.find(observation) != targetObservations.end()); |
|
|
beliefIsTarget.push_back(targetObservations.find(observation) != targetObservations.end()); |
|
|
pushWatch.stop(); |
|
|
|
|
|
//STORM_PRINT("Distribution: "<< distrWatch.getTimeInNanoseconds() << " / Normalization: " << normalizationWatch.getTimeInNanoseconds() << " / generateBelief: " << pushWatch.getTimeInNanoseconds() << std::endl)
|
|
|
|
|
|
return id; |
|
|
return id; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|