diff --git a/src/storage/BitVector.cpp b/src/storage/BitVector.cpp index dff4b2f5f..02e40aa59 100644 --- a/src/storage/BitVector.cpp +++ b/src/storage/BitVector.cpp @@ -10,6 +10,12 @@ #include "src/utility/Hash.h" #include "src/utility/macros.h" +#include + +#ifndef NDEBUG +//#define ASSERT_BITVECTOR +#endif + namespace storm { namespace storage { @@ -413,6 +419,7 @@ namespace storm { } uint_fast64_t BitVector::getAsInt(uint_fast64_t bitIndex, uint_fast64_t numberOfBits) const { + assert(numberOfBits <= 64); uint64_t bucket = bitIndex >> 6; uint64_t bitIndexInBucket = bitIndex & mod64mask; @@ -662,6 +669,103 @@ namespace storm { } return endIndex; } + + BitVector BitVector::getAsBitVector(uint_fast64_t start, uint_fast64_t length) { + BitVector result(length); + + uint_fast64_t index = 0; + for ( ; index + 63 <= length; index += 63) { + result.setFromInt(index, 63, getAsInt(start + index, 63)); + } + // Insert remaining bits + result.setFromInt(index, length-index, getAsInt(start+index, length-index)); + +#ifdef ASSERT_BITVECTOR + // Check correctness of getter + for (uint_fast64_t i = 0; i < length; ++i) { + STORM_LOG_ASSERT(result.get(i) == get(start + i), "Getting of bits not correct for index " << i); + } +#endif + return result; + } + + void BitVector::setFromBitVector(uint_fast64_t start, BitVector const& other) { + uint_fast64_t index = 0; + for ( ; index + 63 <= other.bitCount; index += 63) { + setFromInt(start+index, 63, other.getAsInt(index, 63)); + } + // Insert remaining bits + setFromInt(start+index, other.bitCount-index, other.getAsInt(index, other.bitCount-index)); + +#ifdef ASSERT_BITVECTOR + // Check correctness of setter + for (uint_fast64_t i = 0; i < other.bitCount; ++i) { + STORM_LOG_ASSERT(other.get(i) == get(start + i), "Setting of bits not correct for index " << i); + } +#endif + } + + bool BitVector::compareAndSwap(uint_fast64_t start1, uint_fast64_t start2, uint_fast64_t length) { + if (length < 64) { + uint_fast64_t elem1 = getAsInt(start1, length); + uint_fast64_t elem2 = getAsInt(start2, length); + if (elem1 < elem2) { + // Swap elements + setFromInt(start1, length, elem2); + setFromInt(start2, length, elem1); + return true; + } + return false; + } else { + //TODO improve performance + BitVector elem1 = getAsBitVector(start1, length); + BitVector elem2 = getAsBitVector(start2, length); + + if (!(elem1 < elem2)) { + // Elements already sorted +#ifdef ASSERT_BITVECTOR + // Check that sorted + for (uint_fast64_t i = 0; i < length; ++i) { + if (get(start1 + i) > get(start2 + i)) { + break; + } + STORM_LOG_ASSERT(get(start1 + i) >= get(start2 + i), "Bit vector not sorted for indices " << start1+i << " and " << start2+i); + } +#endif + return false; + } + +#ifdef ASSERT_BITVECTOR + BitVector check(*this); +#endif + + // Swap elements + setFromBitVector(start1, elem2); + setFromBitVector(start2, elem1); + +#ifdef ASSERT_BITVECTOR + // Check correctness of swapping + bool tmp; + for (uint_fast64_t i = 0; i < length; ++i) { + tmp = check.get(i + start1); + check.set(i + start1, check.get(i + start2)); + check.set(i + start2, tmp); + } + STORM_LOG_ASSERT(*this == check, "Swapping not correct"); + + // Check that sorted + for (uint_fast64_t i = 0; i < length; ++i) { + if (get(start1 + i) > get(start2 + i)) { + break; + } + STORM_LOG_ASSERT(get(start1 + i) >= get(start2 + i), "Bit vector not sorted for indices " << start1+i << " and " << start2+i); + } +#endif + + return true; + } + } + void BitVector::truncateLastBucket() { if ((bitCount & mod64mask) != 0) { @@ -682,6 +786,26 @@ namespace storm { return out; } + + void BitVector::printBits(std::ostream& out) { + out << "bit vector(" << getNumberOfSetBits() << "/" << bitCount << ") "; + uint_fast64_t index = 0; + for ( ; index * 64 <= bitCount; ++index) { + std::bitset<64> tmp(bucketVector[index]); + out << tmp << "|"; + } + + --index; + // Print last bits + if (index * 64 < bitCount) { + assert(index == bucketVector.size() - 1); + std::bitset<64> tmp(bucketVector[index]); + for (size_t i = 0; i + index * 64 < bitCount; ++i) { + out << tmp[i]; + } + } + out << std::endl; + } std::size_t NonZeroBitVectorHash::operator()(storm::storage::BitVector const& bv) const { STORM_LOG_ASSERT(bv.size() > 0, "Cannot hash bit vector of zero size."); diff --git a/src/storage/BitVector.h b/src/storage/BitVector.h index 695bdcb82..e3f81ebc9 100644 --- a/src/storage/BitVector.h +++ b/src/storage/BitVector.h @@ -461,6 +461,17 @@ namespace storm { */ uint_fast64_t getNextSetIndex(uint_fast64_t startingIndex) const; + /* + * Compare two intervals [start1, start1+length] and [start2, start2+length] and swap them if the second + * one is larger than the first one. After the method the intervals are sorted in decreasing order. + * + * @param start1 Starting index of first interval. + * @param start2 Starting index of second interval. + * @param length Length of both intervals. + * @return True, if the intervals were swapped, false if nothing changed. + */ + bool compareAndSwap(uint_fast64_t start1, uint_fast64_t start2, uint_fast64_t length); + friend std::ostream& operator<<(std::ostream& out, BitVector const& bitVector); friend struct std::hash; friend struct NonZeroBitVectorHash; @@ -490,6 +501,31 @@ namespace storm { */ void truncateLastBucket(); + /*! Retrieves the content of the current bit vector at the given index for the given number of bits as a new + * bit vector. + * + * @param start The index of the first bit to get. + * @param length The number of bits to get. + * @return A new bit vector holding the selected bits. + */ + BitVector getAsBitVector(uint_fast64_t start, uint_fast64_t length); + + /*! + * Sets the exact bit pattern of the given bit vector starting at the given bit index. Note: the given bit + * vector must be shorter than the current one minus the given index. + * + * @param start The index of the first bit that is supposed to be set. + * @param other The bit vector whose pattern to set. + */ + void setFromBitVector(uint_fast64_t start, BitVector const& other); + + /*! + * Print bit vector and display all bits. + * + * @param out Stream to print to. + */ + void printBits(std::ostream& out); + /*! * Retrieves the number of buckets of the underlying storage. * diff --git a/src/storage/dft/DFTState.cpp b/src/storage/dft/DFTState.cpp index 7d3b25c8a..547fdd169 100644 --- a/src/storage/dft/DFTState.cpp +++ b/src/storage/dft/DFTState.cpp @@ -265,17 +265,12 @@ namespace storm { std::vector symmetryIndices = mStateGenerationInfo.getSymmetryIndices(pos); // Sort symmetry group in decreasing order by bubble sort // TODO use better algorithm? - size_t tmp, elem1, elem2; + size_t tmp; size_t n = symmetryIndices.size(); do { tmp = 0; for (size_t i = 1; i < n; ++i) { - elem1 = mStatus.getAsInt(symmetryIndices[i-1], length); - elem2 = mStatus.getAsInt(symmetryIndices[i], length); - if (elem1 < elem2) { - // Swap elements - mStatus.setFromInt(symmetryIndices[i-1], length, elem2); - mStatus.setFromInt(symmetryIndices[i], length, elem1); + if (mStatus.compareAndSwap(symmetryIndices[i-1], symmetryIndices[i], length)) { tmp = i; changed = true; }