diff --git a/src/storm-pomdp/analysis/WinningRegion.cpp b/src/storm-pomdp/analysis/WinningRegion.cpp index c2408f20e..1b6c37f73 100644 --- a/src/storm-pomdp/analysis/WinningRegion.cpp +++ b/src/storm-pomdp/analysis/WinningRegion.cpp @@ -1,6 +1,7 @@ #include <iostream> #include <boost/algorithm/string.hpp> #include "storm/utility/file.h" +#include "storm/utility/constants.h" #include "storm/storage/expressions/Expression.h" #include "storm/storage/expressions/ExpressionManager.h" #include "storm-pomdp/analysis/WinningRegion.h" @@ -164,6 +165,146 @@ namespace pomdp { return winningRegion[observation]; } + storm::RationalNumber WinningRegion::beliefSupportStates() const { + storm::RationalNumber total = 0; + storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2); + for (auto const& size : observationSizes) { + total += carl::pow(two,size) - 1; + } + return total; + } + + std::pair<storm::RationalNumber, storm::RationalNumber> count(std::vector<storm::storage::BitVector> const& origSets, std::vector<storm::storage::BitVector> const& intersects, + std::vector<storm::storage::BitVector> const& intersectsInfo, + storm::RationalNumber val, + bool plus, uint64_t remdepth) { + assert(intersects.size() == intersectsInfo.size()); + storm::RationalNumber newVal = val; + storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2); + for (uint64_t i = 0; i < intersects.size(); ++i) { + if(plus) { + newVal += carl::pow(two, intersects[i].getNumberOfSetBits()); + } else { + newVal -= carl::pow(two, intersects[i].getNumberOfSetBits()); + } + } + + storm::RationalNumber diff = val-newVal; + storm::RationalNumber max = storm::utility::max(val,newVal); + + diff = storm::utility::abs(diff); + if (remdepth == 0 || 20 * diff < max) { + if (plus) { + return std::make_pair(val, newVal); + } else { + return std::make_pair(newVal, val); + } + } else { + storm::RationalNumber skipped = 0; + uint64_t upperBoundElements = origSets.size() * intersects.size(); + STORM_LOG_DEBUG("Upper bound on number of elements to be considered " << upperBoundElements); + STORM_LOG_DEBUG("Value " << val << " newVal " << newVal); + uint64_t oom = 0; + uint64_t critoom = 0; + storm::RationalNumber n = 1; + while (n < max) { + oom += 1; + n *= 2; + } + STORM_LOG_DEBUG("Order of magnitude = " << oom); + + critoom = oom - floor(log2(upperBoundElements)); + + + STORM_LOG_DEBUG("Crit Order of magnitude = " << critoom); + + + uint64_t intersectSetSkip = critoom - floor(log2(origSets.size())); + + + std::vector<storm::storage::BitVector> useIntersects; + std::vector<storm::storage::BitVector> useInfo; + for(uint64_t i = 0; i < intersects.size(); ++i) { + if (upperBoundElements > 1000 && intersects[i].getNumberOfSetBits() < intersectSetSkip - 2) { + skipped += (carl::pow(two, intersects[i].getNumberOfSetBits()) * origSets.size()); + STORM_LOG_DEBUG("Skipped " << skipped); + } else { + useIntersects.push_back(intersects[i]); + useInfo.push_back(intersectsInfo[i]); + } + } + + uint64_t origSetSkip = critoom - floor(log2(useIntersects.size())); + STORM_LOG_DEBUG("OrigSkip= " << origSetSkip); + + + std::vector<storm::storage::BitVector> newIntersects; + std::vector<storm::storage::BitVector> newInfo; + + for (uint64_t i = 0; i < origSets.size(); ++i) { + if (upperBoundElements > 1000 && origSets[i].getNumberOfSetBits() < origSetSkip - 2) { + skipped += (carl::pow(two, origSets[i].getNumberOfSetBits()) * useIntersects.size()); + STORM_LOG_DEBUG("Skipped " << skipped); + continue; + } + for (uint64_t j = 0; j < useIntersects.size(); ++j ) { + if (useInfo[j].get(i)) { + continue; + } + storm::storage::BitVector newinf = useInfo[j]; + newinf.set(i); + if (newinf == useInfo[j]) { + continue; + } + bool exists = false; + for( auto const& entry : newInfo) { + if (entry == newinf) { + exists = true; + break; + } + } + if(!exists) { + newInfo.push_back(newinf); + newIntersects.push_back(origSets[i] & useIntersects[j]); + } + } + } + + auto res = count(origSets, newIntersects, newInfo, newVal, !plus, remdepth - 1); + if (plus) { + return std::make_pair(res.first - skipped, res.second); + } else { + return std::make_pair(res.first, res.second + skipped); + } + } + } + + std::pair<storm::RationalNumber,storm::RationalNumber> WinningRegion::computeNrWinningBeliefs() const { + storm::RationalNumber upper = 0; + storm::RationalNumber lower = 0; + storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2); + for (auto const& winningSets : winningRegion) { + storm::RationalNumber totalForObs = 0; + storm::RationalNumber two = storm::utility::convertNumber<storm::RationalNumber>(2); + + std::vector<storm::storage::BitVector> info; // which intersections are part of this + for (uint64_t i = 0; i < winningSets.size(); ++i) { + storm::storage::BitVector entry(winningSets.size()); + entry.set(i); + info.push_back(entry); + } + auto res = count(winningSets, winningSets, info, totalForObs, true, 6); + lower += res.first; + upper += res.second; + } + if (lower > 0) { + lower -= 1; + upper -= 1; + } + return std::make_pair(lower,upper); + } + + uint64_t WinningRegion::getStorageSize() const { uint64_t result = 0; for (uint64_t i = 0; i < getNumberOfObservations(); ++i) { diff --git a/src/storm-pomdp/analysis/WinningRegion.h b/src/storm-pomdp/analysis/WinningRegion.h index bc81cf907..88ef79cf7 100644 --- a/src/storm-pomdp/analysis/WinningRegion.h +++ b/src/storm-pomdp/analysis/WinningRegion.h @@ -12,6 +12,8 @@ namespace storm { bool update(uint64_t observation, storm::storage::BitVector const& winning); bool query(uint64_t observation, storm::storage::BitVector const& currently) const; bool isWinning(uint64_t observation, uint64_t offset) const { + assert(observation < observationSizes.size()); + assert(offset < observationSizes[observation]); storm::storage::BitVector currently(observationSizes[observation]); currently.set(offset); return query(observation,currently); @@ -27,6 +29,9 @@ namespace storm { uint64_t getStorageSize() const; + storm::RationalNumber beliefSupportStates() const; + std::pair<storm::RationalNumber,storm::RationalNumber> computeNrWinningBeliefs() const; + uint64_t getNumberOfObservations() const; bool empty() const; void print() const;