Browse Source

compute size of winning region

main
Sebastian Junges 5 years ago
parent
commit
34fce002cb
  1. 141
      src/storm-pomdp/analysis/WinningRegion.cpp
  2. 5
      src/storm-pomdp/analysis/WinningRegion.h

141
src/storm-pomdp/analysis/WinningRegion.cpp

@ -1,6 +1,7 @@
#include <iostream>
#include <boost/algorithm/string.hpp>
#include "storm/utility/file.h"
#include "storm/utility/constants.h"
#include "storm/storage/expressions/Expression.h"
#include "storm/storage/expressions/ExpressionManager.h"
#include "storm-pomdp/analysis/WinningRegion.h"
@ -164,6 +165,146 @@ namespace pomdp {
return winningRegion[observation];
}
storm::RationalNumber WinningRegion::beliefSupportStates() const {
storm::RationalNumber total = 0;
storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2);
for (auto const& size : observationSizes) {
total += carl::pow(two,size) - 1;
}
return total;
}
std::pair<storm::RationalNumber, storm::RationalNumber> count(std::vector<storm::storage::BitVector> const& origSets, std::vector<storm::storage::BitVector> const& intersects,
std::vector<storm::storage::BitVector> const& intersectsInfo,
storm::RationalNumber val,
bool plus, uint64_t remdepth) {
assert(intersects.size() == intersectsInfo.size());
storm::RationalNumber newVal = val;
storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2);
for (uint64_t i = 0; i < intersects.size(); ++i) {
if(plus) {
newVal += carl::pow(two, intersects[i].getNumberOfSetBits());
} else {
newVal -= carl::pow(two, intersects[i].getNumberOfSetBits());
}
}
storm::RationalNumber diff = val-newVal;
storm::RationalNumber max = storm::utility::max(val,newVal);
diff = storm::utility::abs(diff);
if (remdepth == 0 || 20 * diff < max) {
if (plus) {
return std::make_pair(val, newVal);
} else {
return std::make_pair(newVal, val);
}
} else {
storm::RationalNumber skipped = 0;
uint64_t upperBoundElements = origSets.size() * intersects.size();
STORM_LOG_DEBUG("Upper bound on number of elements to be considered " << upperBoundElements);
STORM_LOG_DEBUG("Value " << val << " newVal " << newVal);
uint64_t oom = 0;
uint64_t critoom = 0;
storm::RationalNumber n = 1;
while (n < max) {
oom += 1;
n *= 2;
}
STORM_LOG_DEBUG("Order of magnitude = " << oom);
critoom = oom - floor(log2(upperBoundElements));
STORM_LOG_DEBUG("Crit Order of magnitude = " << critoom);
uint64_t intersectSetSkip = critoom - floor(log2(origSets.size()));
std::vector<storm::storage::BitVector> useIntersects;
std::vector<storm::storage::BitVector> useInfo;
for(uint64_t i = 0; i < intersects.size(); ++i) {
if (upperBoundElements > 1000 && intersects[i].getNumberOfSetBits() < intersectSetSkip - 2) {
skipped += (carl::pow(two, intersects[i].getNumberOfSetBits()) * origSets.size());
STORM_LOG_DEBUG("Skipped " << skipped);
} else {
useIntersects.push_back(intersects[i]);
useInfo.push_back(intersectsInfo[i]);
}
}
uint64_t origSetSkip = critoom - floor(log2(useIntersects.size()));
STORM_LOG_DEBUG("OrigSkip= " << origSetSkip);
std::vector<storm::storage::BitVector> newIntersects;
std::vector<storm::storage::BitVector> newInfo;
for (uint64_t i = 0; i < origSets.size(); ++i) {
if (upperBoundElements > 1000 && origSets[i].getNumberOfSetBits() < origSetSkip - 2) {
skipped += (carl::pow(two, origSets[i].getNumberOfSetBits()) * useIntersects.size());
STORM_LOG_DEBUG("Skipped " << skipped);
continue;
}
for (uint64_t j = 0; j < useIntersects.size(); ++j ) {
if (useInfo[j].get(i)) {
continue;
}
storm::storage::BitVector newinf = useInfo[j];
newinf.set(i);
if (newinf == useInfo[j]) {
continue;
}
bool exists = false;
for( auto const& entry : newInfo) {
if (entry == newinf) {
exists = true;
break;
}
}
if(!exists) {
newInfo.push_back(newinf);
newIntersects.push_back(origSets[i] & useIntersects[j]);
}
}
}
auto res = count(origSets, newIntersects, newInfo, newVal, !plus, remdepth - 1);
if (plus) {
return std::make_pair(res.first - skipped, res.second);
} else {
return std::make_pair(res.first, res.second + skipped);
}
}
}
std::pair<storm::RationalNumber,storm::RationalNumber> WinningRegion::computeNrWinningBeliefs() const {
storm::RationalNumber upper = 0;
storm::RationalNumber lower = 0;
storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2);
for (auto const& winningSets : winningRegion) {
storm::RationalNumber totalForObs = 0;
storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2);
std::vector<storm::storage::BitVector> info; // which intersections are part of this
for (uint64_t i = 0; i < winningSets.size(); ++i) {
storm::storage::BitVector entry(winningSets.size());
entry.set(i);
info.push_back(entry);
}
auto res = count(winningSets, winningSets, info, totalForObs, true, 6);
lower += res.first;
upper += res.second;
}
if (lower > 0) {
lower -= 1;
upper -= 1;
}
return std::make_pair(lower,upper);
}
uint64_t WinningRegion::getStorageSize() const {
uint64_t result = 0;
for (uint64_t i = 0; i < getNumberOfObservations(); ++i) {

5
src/storm-pomdp/analysis/WinningRegion.h

@ -12,6 +12,8 @@ namespace storm {
bool update(uint64_t observation, storm::storage::BitVector const& winning);
bool query(uint64_t observation, storm::storage::BitVector const& currently) const;
bool isWinning(uint64_t observation, uint64_t offset) const {
assert(observation < observationSizes.size());
assert(offset < observationSizes[observation]);
storm::storage::BitVector currently(observationSizes[observation]);
currently.set(offset);
return query(observation,currently);
@ -27,6 +29,9 @@ namespace storm {
uint64_t getStorageSize() const;
storm::RationalNumber beliefSupportStates() const;
std::pair<storm::RationalNumber,storm::RationalNumber> computeNrWinningBeliefs() const;
uint64_t getNumberOfObservations() const;
bool empty() const;
void print() const;

|||||||
100:0
Loading…
Cancel
Save