From 25074b50a9094e374d71dc941976fb4ee516ae29 Mon Sep 17 00:00:00 2001
From: TimQu <tim.quatmann@cs.rwth-aachen.de>
Date: Thu, 18 May 2017 10:54:38 +0200
Subject: [PATCH] Added function to get the next unset bit in a bitvector

---
 src/storm/storage/BitVector.cpp    | 90 +++++++++++++++++++++---------
 src/storm/storage/BitVector.h      | 16 +++++-
 src/storm/utility/vector.h         | 23 ++++----
 src/test/storage/BitVectorTest.cpp | 15 +++++
 4 files changed, 104 insertions(+), 40 deletions(-)

diff --git a/src/storm/storage/BitVector.cpp b/src/storm/storage/BitVector.cpp
index a0004263a..532e923e3 100644
--- a/src/storm/storage/BitVector.cpp
+++ b/src/storm/storage/BitVector.cpp
@@ -22,7 +22,7 @@ namespace storm {
         BitVector::const_iterator::const_iterator(uint64_t const* dataPtr, uint_fast64_t startIndex, uint_fast64_t endIndex, bool setOnFirstBit) : dataPtr(dataPtr), endIndex(endIndex) {
             if (setOnFirstBit) {
                 // Set the index of the first set bit in the vector.
-                currentIndex = getNextSetIndex(dataPtr, startIndex, endIndex);
+                currentIndex = getNextIndexWithValue(true, dataPtr, startIndex, endIndex);
             } else {
                 currentIndex = startIndex;
             }
@@ -43,13 +43,13 @@ namespace storm {
         }
 
         BitVector::const_iterator& BitVector::const_iterator::operator++() {
-            currentIndex = getNextSetIndex(dataPtr, ++currentIndex, endIndex);
+            currentIndex = getNextIndexWithValue(true, dataPtr, ++currentIndex, endIndex);
             return *this;
         }
 
         BitVector::const_iterator& BitVector::const_iterator::operator+=(size_t n) {
             for(size_t i = 0; i < n; ++i) {
-                currentIndex = getNextSetIndex(dataPtr, ++currentIndex, endIndex);
+                currentIndex = getNextIndexWithValue(true, dataPtr, ++currentIndex, endIndex);
             }
             return *this;
         }
@@ -628,10 +628,17 @@ namespace storm {
         }
 
         uint_fast64_t BitVector::getNextSetIndex(uint_fast64_t startingIndex) const {
-            return getNextSetIndex(buckets, startingIndex, bitCount);
+            return getNextIndexWithValue(true, buckets, startingIndex, bitCount);
         }
 
-        uint_fast64_t BitVector::getNextSetIndex(uint64_t const* dataPtr, uint_fast64_t startingIndex, uint_fast64_t endIndex) {
+       uint_fast64_t BitVector::getNextUnsetIndex(uint_fast64_t startingIndex) const {
+#ifdef ASSERT_BITVECTOR
+           STORM_LOG_ASSERT(getNextIndexWithValue(false, buckets, startingIndex, bitCount) == (~(*this)).getNextSetIndex(startingIndex), "The result is inconsistent with the next set index of the complement of this bitvector");
+#endif
+            return getNextIndexWithValue(false, buckets, startingIndex, bitCount);
+        }
+
+        uint_fast64_t BitVector::getNextIndexWithValue(bool value, uint64_t const* dataPtr, uint_fast64_t startingIndex, uint_fast64_t endIndex) {
             uint_fast8_t currentBitInByte = startingIndex & mod64mask;
             uint64_t const* bucketIt = dataPtr + (startingIndex >> 6);
             startingIndex = (startingIndex >> 6 << 6);
@@ -642,31 +649,62 @@ namespace storm {
             } else {
                 mask = (1ull << (64 - currentBitInByte)) - 1ull;
             }
-            while (startingIndex < endIndex) {
-                // Compute the remaining bucket content.
-                uint64_t remainingInBucket = *bucketIt & mask;
-
-                // Check if there is at least one bit in the remainder of the bucket that is set to true.
-                if (remainingInBucket != 0) {
-                    // As long as the current bit is not set, move the current bit.
-                    while ((remainingInBucket & (1ull << (63 - currentBitInByte))) == 0) {
-                        ++currentBitInByte;
+            
+            // For efficiency reasons, we branch on the desired value at this point
+            if (value) {
+                while (startingIndex < endIndex) {
+                    // Compute the remaining bucket content.
+                    uint64_t remainingInBucket = *bucketIt & mask;
+    
+                    // Check if there is at least one bit in the remainder of the bucket that is set to true.
+                    if (remainingInBucket != 0) {
+                        // As long as the current bit is not set, move the current bit.
+                        while ((remainingInBucket & (1ull << (63 - currentBitInByte))) == 0) {
+                            ++currentBitInByte;
+                        }
+    
+                        // Only return the index of the set bit if we are still in the valid range.
+                        if (startingIndex + currentBitInByte < endIndex) {
+                            return startingIndex + currentBitInByte;
+                        } else {
+                            return endIndex;
+                        }
                     }
-
-                    // Only return the index of the set bit if we are still in the valid range.
-                    if (startingIndex + currentBitInByte < endIndex) {
-                        return startingIndex + currentBitInByte;
-                    } else {
-                        return endIndex;
+    
+                    // Advance to the next bucket.
+                    startingIndex += 64;
+                    ++bucketIt;
+                    mask = -1ull;
+                    currentBitInByte = 0;
+                }
+            } else {
+                while (startingIndex < endIndex) {
+                    // Compute the remaining bucket content.
+                    uint64_t remainingInBucket = *bucketIt & mask;
+    
+                    // Check if there is at least one bit in the remainder of the bucket that is set to false.
+                    if (remainingInBucket != (-1ull & mask)) {
+                        // As long as the current bit is not false, move the current bit.
+                        while ((remainingInBucket & (1ull << (63 - currentBitInByte))) != 0) {
+                            ++currentBitInByte;
+                        }
+    
+                        // Only return the index of the set bit if we are still in the valid range.
+                        if (startingIndex + currentBitInByte < endIndex) {
+                            return startingIndex + currentBitInByte;
+                        } else {
+                            return endIndex;
+                        }
                     }
+    
+                    // Advance to the next bucket.
+                    startingIndex += 64;
+                    ++bucketIt;
+                    mask = -1ull;
+                    currentBitInByte = 0;
                 }
-
-                // Advance to the next bucket.
-                startingIndex += 64;
-                ++bucketIt;
-                mask = -1ull;
-                currentBitInByte = 0;
             }
+            
             return endIndex;
         }
         
diff --git a/src/storm/storage/BitVector.h b/src/storm/storage/BitVector.h
index ecdd03536..dccfa0830 100644
--- a/src/storm/storage/BitVector.h
+++ b/src/storm/storage/BitVector.h
@@ -477,6 +477,17 @@ namespace storm {
              */
             uint_fast64_t getNextSetIndex(uint_fast64_t startingIndex) const;
             
+            /*
+             * Retrieves the index of the bit that is the next bit set to false in the bit vector. If there is none,
+             * this function returns the number of bits this vector holds in total. Put differently, if the return
+             * value is equal to a call to size(), then there is no unset bit after the specified position.
+             *
+             * @param startingIndex The index at which to start the search for the next bit that is not set. The
+             * bit at this index itself is included in the search range.
+             * @return The index of the next bit that is set after the given index.
+             */
+            uint_fast64_t getNextUnsetIndex(uint_fast64_t startingIndex) const;
+            
             /*
              * Compare two intervals [start1, start1+length] and [start2, start2+length] and swap them if the second
              * one is larger than the first one. After the method the intervals are sorted in decreasing order.
@@ -502,15 +513,16 @@ namespace storm {
             BitVector(uint_fast64_t bucketCount, uint_fast64_t bitCount);
             
             /*!
-             * Retrieves the index of the next bit that is set to true after (and including) the given starting index.
+             * Retrieves the index of the next bit that is set to the given value after (and including) the given starting index.
              *
+             * @param value the value of the bit whose index is to be found.
              * @param dataPtr A pointer to the first bucket of the data storage.
              * @param startingIndex The index where to start the search.
              * @param endIndex The index at which to stop the search.
              * @return The index of the bit that is set after the given starting index, but before the given end index
              * in the given bit vector or endIndex in case the end index was reached.
              */
-            static uint_fast64_t getNextSetIndex(uint64_t const* dataPtr, uint_fast64_t startingIndex, uint_fast64_t endIndex);
+            static uint_fast64_t getNextIndexWithValue(bool value, uint64_t const* dataPtr, uint_fast64_t startingIndex, uint_fast64_t endIndex);
             
             /*!
              * Truncate the last bucket so that no bits are set starting from bitCount.
diff --git a/src/storm/utility/vector.h b/src/storm/utility/vector.h
index 6bf06a4f9..76fc58a5c 100644
--- a/src/storm/utility/vector.h
+++ b/src/storm/utility/vector.h
@@ -862,19 +862,18 @@ namespace storm {
             template<typename Type>
             void filterVectorInPlace(std::vector<Type>& v, storm::storage::BitVector const& filter) {
                 STORM_LOG_ASSERT(v.size() == filter.size(), "The filter size does not match the size of the input vector");
-                auto vIt = v.begin();
-                auto filterIt = filter.begin();
-                // get the first position where the filter has a 0.
-                // Note that we do not have to modify the entries of v on all positions before
-                for (uint_fast64_t i = 0; i == *filterIt && filterIt != filter.end(); ++i) {
-                    ++filterIt;
-                    ++vIt;
-                }
-                for (; filterIt != filter.end(); ++filterIt) {
-                    *vIt = std::move(v[*filterIt]);
-                    ++vIt;
+                uint_fast64_t size = v.size();
+                // we can start our work at the first index where the filter has value zero
+                uint_fast64_t firstUnsetIndex = filter.getNextUnsetIndex(0);
+                if (firstUnsetIndex < size) {
+                    auto vIt = v.begin() + firstUnsetIndex;
+                    for (uint_fast64_t index = filter.getNextSetIndex(firstUnsetIndex + 1); index != size; index = filter.getNextSetIndex(index + 1)) {
+                        *vIt = std::move(v[index]);
+                        ++vIt;
+                    }
+                    v.resize(vIt - v.begin());
+                    v.shrink_to_fit();
                 }
-                v.resize(vIt - v.begin());
                 STORM_LOG_ASSERT(v.size() == filter.getNumberOfSetBits(), "Result does not match.");
             }
             
diff --git a/src/test/storage/BitVectorTest.cpp b/src/test/storage/BitVectorTest.cpp
index 9d2fd60c5..0bb9847f6 100644
--- a/src/test/storage/BitVectorTest.cpp
+++ b/src/test/storage/BitVectorTest.cpp
@@ -482,6 +482,21 @@ TEST(BitVectorTest, NextSetIndex) {
     ASSERT_EQ(vector.size(), vector.getNextSetIndex(18));
 }
 
+TEST(BitVectorTest, NextUnsetIndex) {
+    storm::storage::BitVector vector(32);
+    
+    vector.set(14);
+    vector.set(17);
+    
+    vector.complement();
+    
+    ASSERT_EQ(14ul, vector.getNextUnsetIndex(14));
+    ASSERT_EQ(17ul, vector.getNextUnsetIndex(15));
+    ASSERT_EQ(17ul, vector.getNextUnsetIndex(16));
+    ASSERT_EQ(17ul, vector.getNextUnsetIndex(17));
+    ASSERT_EQ(vector.size(), vector.getNextUnsetIndex(18));
+}
+
 TEST(BitVectorTest, Iterator) {
     storm::storage::BitVector vector(32);