1//===- llvm/ADT/SparseBitVector.h - Efficient Sparse BitVector --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SparseBitVector class. See the doxygen comment for
10// SparseBitVector for more details on the algorithm used.
11//
12//===----------------------------------------------------------------------===//
13
14#pragma once
15#include <c10/macros/Macros.h>
16#include <c10/util/llvmMathExtras.h>
17#include <cassert>
18#include <climits>
19#include <cstring>
20#include <iterator>
21#include <list>
22
23C10_CLANG_DIAGNOSTIC_PUSH()
24#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
25C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
26#endif
27
28namespace c10 {
29
/// SparseBitVector is an implementation of a bitvector that is sparse by only
/// storing the elements that have non-zero bits set. In order to make this
/// fast for the most common cases, SparseBitVector is implemented as a linked
/// list of SparseBitVectorElements. We maintain a pointer to the last
/// SparseBitVectorElement accessed (in the form of a list iterator), in order
/// to make multiple in-order test/set operations constant time after the first
/// one is executed. Note that using vectors to store SparseBitVectorElements
/// does not work out very well because it causes insertion in the middle to
/// take enormous amounts of time with a large number of bits. Other structures
/// that have better worst cases for insertion in the middle (various balanced
/// trees, etc.) do not perform as well in practice as a linked list with this
/// iterator kept up to date. They are also significantly more memory intensive.
42
43template <unsigned ElementSize = 128>
44struct SparseBitVectorElement {
45 public:
46 using BitWord = unsigned long;
47 using size_type = unsigned;
48 enum {
49 BITWORD_SIZE = sizeof(BitWord) * CHAR_BIT,
50 BITWORDS_PER_ELEMENT = (ElementSize + BITWORD_SIZE - 1) / BITWORD_SIZE,
51 BITS_PER_ELEMENT = ElementSize
52 };
53
54 private:
55 // Index of Element in terms of where first bit starts.
56 unsigned ElementIndex;
57 BitWord Bits[BITWORDS_PER_ELEMENT];
58
59 SparseBitVectorElement() {
60 ElementIndex = ~0U;
61 memset(&Bits[0], 0, sizeof(BitWord) * BITWORDS_PER_ELEMENT);
62 }
63
64 public:
65 explicit SparseBitVectorElement(unsigned Idx) {
66 ElementIndex = Idx;
67 memset(&Bits[0], 0, sizeof(BitWord) * BITWORDS_PER_ELEMENT);
68 }
69
70 // Comparison.
71 bool operator==(const SparseBitVectorElement& RHS) const {
72 if (ElementIndex != RHS.ElementIndex)
73 return false;
74 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
75 if (Bits[i] != RHS.Bits[i])
76 return false;
77 return true;
78 }
79
80 bool operator!=(const SparseBitVectorElement& RHS) const {
81 return !(*this == RHS);
82 }
83
84 // Return the bits that make up word Idx in our element.
85 BitWord word(unsigned Idx) const {
86 assert(Idx < BITWORDS_PER_ELEMENT);
87 return Bits[Idx];
88 }
89
90 unsigned index() const {
91 return ElementIndex;
92 }
93
94 bool empty() const {
95 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
96 if (Bits[i])
97 return false;
98 return true;
99 }
100
101 void set(unsigned Idx) {
102 Bits[Idx / BITWORD_SIZE] |= 1L << (Idx % BITWORD_SIZE);
103 }
104
105 bool test_and_set(unsigned Idx) {
106 bool old = test(Idx);
107 if (!old) {
108 set(Idx);
109 return true;
110 }
111 return false;
112 }
113
114 void reset(unsigned Idx) {
115 Bits[Idx / BITWORD_SIZE] &= ~(1L << (Idx % BITWORD_SIZE));
116 }
117
118 bool test(unsigned Idx) const {
119 return Bits[Idx / BITWORD_SIZE] & (1L << (Idx % BITWORD_SIZE));
120 }
121
122 size_type count() const {
123 unsigned NumBits = 0;
124 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
125 NumBits += llvm::countPopulation(Bits[i]);
126 return NumBits;
127 }
128
129 /// find_first - Returns the index of the first set bit.
130 int find_first() const {
131 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i)
132 if (Bits[i] != 0)
133 return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
134 throw std::runtime_error("Illegal empty element");
135 }
136
137 /// find_last - Returns the index of the last set bit.
138 int find_last() const {
139 for (unsigned I = 0; I < BITWORDS_PER_ELEMENT; ++I) {
140 unsigned Idx = BITWORDS_PER_ELEMENT - I - 1;
141 if (Bits[Idx] != 0)
142 return Idx * BITWORD_SIZE + BITWORD_SIZE -
143 llvm::countLeadingZeros(Bits[Idx]);
144 }
145 throw std::runtime_error("Illegal empty element");
146 }
147
148 /// find_next - Returns the index of the next set bit starting from the
149 /// "Curr" bit. Returns -1 if the next set bit is not found.
150 int find_next(unsigned Curr) const {
151 if (Curr >= BITS_PER_ELEMENT)
152 return -1;
153
154 unsigned WordPos = Curr / BITWORD_SIZE;
155 unsigned BitPos = Curr % BITWORD_SIZE;
156 BitWord Copy = Bits[WordPos];
157 assert(
158 WordPos <= BITWORDS_PER_ELEMENT && "Word Position outside of element");
159
160 // Mask off previous bits.
161 Copy &= ~0UL << BitPos;
162
163 if (Copy != 0)
164 return WordPos * BITWORD_SIZE + llvm::countTrailingZeros(Copy);
165
166 // Check subsequent words.
167 for (unsigned i = WordPos + 1; i < BITWORDS_PER_ELEMENT; ++i)
168 if (Bits[i] != 0)
169 return i * BITWORD_SIZE + llvm::countTrailingZeros(Bits[i]);
170 return -1;
171 }
172
173 // Union this element with RHS and return true if this one changed.
174 bool unionWith(const SparseBitVectorElement& RHS) {
175 bool changed = false;
176 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
177 BitWord old = changed ? 0 : Bits[i];
178
179 Bits[i] |= RHS.Bits[i];
180 if (!changed && old != Bits[i])
181 changed = true;
182 }
183 return changed;
184 }
185
186 // Return true if we have any bits in common with RHS
187 bool intersects(const SparseBitVectorElement& RHS) const {
188 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
189 if (RHS.Bits[i] & Bits[i])
190 return true;
191 }
192 return false;
193 }
194
195 // Intersect this Element with RHS and return true if this one changed.
196 // BecameZero is set to true if this element became all-zero bits.
197 bool intersectWith(const SparseBitVectorElement& RHS, bool& BecameZero) {
198 bool changed = false;
199 bool allzero = true;
200
201 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
202 BitWord old = changed ? 0 : Bits[i];
203
204 Bits[i] &= RHS.Bits[i];
205 if (Bits[i] != 0)
206 allzero = false;
207
208 if (!changed && old != Bits[i])
209 changed = true;
210 }
211 BecameZero = allzero;
212 return changed;
213 }
214
215 // Intersect this Element with the complement of RHS and return true if this
216 // one changed. BecameZero is set to true if this element became all-zero
217 // bits.
218 bool intersectWithComplement(
219 const SparseBitVectorElement& RHS,
220 bool& BecameZero) {
221 bool changed = false;
222 bool allzero = true;
223
224 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
225 BitWord old = changed ? 0 : Bits[i];
226
227 Bits[i] &= ~RHS.Bits[i];
228 if (Bits[i] != 0)
229 allzero = false;
230
231 if (!changed && old != Bits[i])
232 changed = true;
233 }
234 BecameZero = allzero;
235 return changed;
236 }
237
238 // Three argument version of intersectWithComplement that intersects
239 // RHS1 & ~RHS2 into this element
240 void intersectWithComplement(
241 const SparseBitVectorElement& RHS1,
242 const SparseBitVectorElement& RHS2,
243 bool& BecameZero) {
244 bool allzero = true;
245
246 for (unsigned i = 0; i < BITWORDS_PER_ELEMENT; ++i) {
247 Bits[i] = RHS1.Bits[i] & ~RHS2.Bits[i];
248 if (Bits[i] != 0)
249 allzero = false;
250 }
251 BecameZero = allzero;
252 }
253};
254
255template <unsigned ElementSize = 128>
256class SparseBitVector {
257 using ElementList = std::list<SparseBitVectorElement<ElementSize>>;
258 using ElementListIter = typename ElementList::iterator;
259 using ElementListConstIter = typename ElementList::const_iterator;
260 enum { BITWORD_SIZE = SparseBitVectorElement<ElementSize>::BITWORD_SIZE };
261
262 ElementList Elements;
263 // Pointer to our current Element. This has no visible effect on the external
264 // state of a SparseBitVector, it's just used to improve performance in the
265 // common case of testing/modifying bits with similar indices.
266 mutable ElementListIter CurrElementIter;
267
268 // This is like std::lower_bound, except we do linear searching from the
269 // current position.
270 ElementListIter FindLowerBoundImpl(unsigned ElementIndex) const {
271 // We cache a non-const iterator so we're forced to resort to const_cast to
272 // get the begin/end in the case where 'this' is const. To avoid duplication
273 // of code with the only difference being whether the const cast is present
274 // 'this' is always const in this particular function and we sort out the
275 // difference in FindLowerBound and FindLowerBoundConst.
276 ElementListIter Begin =
277 const_cast<SparseBitVector<ElementSize>*>(this)->Elements.begin();
278 ElementListIter End =
279 const_cast<SparseBitVector<ElementSize>*>(this)->Elements.end();
280
281 if (Elements.empty()) {
282 CurrElementIter = Begin;
283 return CurrElementIter;
284 }
285
286 // Make sure our current iterator is valid.
287 if (CurrElementIter == End)
288 --CurrElementIter;
289
290 // Search from our current iterator, either backwards or forwards,
291 // depending on what element we are looking for.
292 ElementListIter ElementIter = CurrElementIter;
293 if (CurrElementIter->index() == ElementIndex) {
294 return ElementIter;
295 } else if (CurrElementIter->index() > ElementIndex) {
296 while (ElementIter != Begin && ElementIter->index() > ElementIndex)
297 --ElementIter;
298 } else {
299 while (ElementIter != End && ElementIter->index() < ElementIndex)
300 ++ElementIter;
301 }
302 CurrElementIter = ElementIter;
303 return ElementIter;
304 }
305 ElementListConstIter FindLowerBoundConst(unsigned ElementIndex) const {
306 return FindLowerBoundImpl(ElementIndex);
307 }
308 ElementListIter FindLowerBound(unsigned ElementIndex) {
309 return FindLowerBoundImpl(ElementIndex);
310 }
311
312 // Iterator to walk set bits in the bitmap. This iterator is a lot uglier
313 // than it would be, in order to be efficient.
314 class SparseBitVectorIterator {
315 private:
316 bool AtEnd;
317
318 const SparseBitVector<ElementSize>* BitVector = nullptr;
319
320 // Current element inside of bitmap.
321 ElementListConstIter Iter;
322
323 // Current bit number inside of our bitmap.
324 unsigned BitNumber;
325
326 // Current word number inside of our element.
327 unsigned WordNumber;
328
329 // Current bits from the element.
330 typename SparseBitVectorElement<ElementSize>::BitWord Bits;
331
332 // Move our iterator to the first non-zero bit in the bitmap.
333 void AdvanceToFirstNonZero() {
334 if (AtEnd)
335 return;
336 if (BitVector->Elements.empty()) {
337 AtEnd = true;
338 return;
339 }
340 Iter = BitVector->Elements.begin();
341 BitNumber = Iter->index() * ElementSize;
342 unsigned BitPos = Iter->find_first();
343 BitNumber += BitPos;
344 WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
345 Bits = Iter->word(WordNumber);
346 Bits >>= BitPos % BITWORD_SIZE;
347 }
348
349 // Move our iterator to the next non-zero bit.
350 void AdvanceToNextNonZero() {
351 if (AtEnd)
352 return;
353
354 while (Bits && !(Bits & 1)) {
355 Bits >>= 1;
356 BitNumber += 1;
357 }
358
359 // See if we ran out of Bits in this word.
360 if (!Bits) {
361 int NextSetBitNumber = Iter->find_next(BitNumber % ElementSize);
362 // If we ran out of set bits in this element, move to next element.
363 if (NextSetBitNumber == -1 || (BitNumber % ElementSize == 0)) {
364 ++Iter;
365 WordNumber = 0;
366
367 // We may run out of elements in the bitmap.
368 if (Iter == BitVector->Elements.end()) {
369 AtEnd = true;
370 return;
371 }
372 // Set up for next non-zero word in bitmap.
373 BitNumber = Iter->index() * ElementSize;
374 NextSetBitNumber = Iter->find_first();
375 BitNumber += NextSetBitNumber;
376 WordNumber = (BitNumber % ElementSize) / BITWORD_SIZE;
377 Bits = Iter->word(WordNumber);
378 Bits >>= NextSetBitNumber % BITWORD_SIZE;
379 } else {
380 WordNumber = (NextSetBitNumber % ElementSize) / BITWORD_SIZE;
381 Bits = Iter->word(WordNumber);
382 Bits >>= NextSetBitNumber % BITWORD_SIZE;
383 BitNumber = Iter->index() * ElementSize;
384 BitNumber += NextSetBitNumber;
385 }
386 }
387 }
388
389 public:
390 SparseBitVectorIterator() = default;
391
392 SparseBitVectorIterator(
393 const SparseBitVector<ElementSize>* RHS,
394 bool end = false)
395 : BitVector(RHS) {
396 Iter = BitVector->Elements.begin();
397 BitNumber = 0;
398 Bits = 0;
399 WordNumber = ~0;
400 AtEnd = end;
401 AdvanceToFirstNonZero();
402 }
403
404 // Preincrement.
405 inline SparseBitVectorIterator& operator++() {
406 ++BitNumber;
407 Bits >>= 1;
408 AdvanceToNextNonZero();
409 return *this;
410 }
411
412 // Postincrement.
413 inline SparseBitVectorIterator operator++(int) {
414 SparseBitVectorIterator tmp = *this;
415 ++*this;
416 return tmp;
417 }
418
419 // Return the current set bit number.
420 unsigned operator*() const {
421 return BitNumber;
422 }
423
424 bool operator==(const SparseBitVectorIterator& RHS) const {
425 // If they are both at the end, ignore the rest of the fields.
426 if (AtEnd && RHS.AtEnd)
427 return true;
428 // Otherwise they are the same if they have the same bit number and
429 // bitmap.
430 return AtEnd == RHS.AtEnd && RHS.BitNumber == BitNumber;
431 }
432
433 bool operator!=(const SparseBitVectorIterator& RHS) const {
434 return !(*this == RHS);
435 }
436 };
437
438 public:
439 using iterator = SparseBitVectorIterator;
440
441 SparseBitVector() : Elements(), CurrElementIter(Elements.begin()) {}
442
443 SparseBitVector(const SparseBitVector& RHS)
444 : Elements(RHS.Elements), CurrElementIter(Elements.begin()) {}
445 SparseBitVector(SparseBitVector&& RHS)
446 : Elements(std::move(RHS.Elements)), CurrElementIter(Elements.begin()) {}
447
448 // Clear.
449 void clear() {
450 Elements.clear();
451 }
452
453 // Assignment
454 SparseBitVector& operator=(const SparseBitVector& RHS) {
455 if (this == &RHS)
456 return *this;
457
458 Elements = RHS.Elements;
459 CurrElementIter = Elements.begin();
460 return *this;
461 }
462 SparseBitVector& operator=(SparseBitVector&& RHS) {
463 Elements = std::move(RHS.Elements);
464 CurrElementIter = Elements.begin();
465 return *this;
466 }
467
468 // Test, Reset, and Set a bit in the bitmap.
469 bool test(unsigned Idx) const {
470 if (Elements.empty())
471 return false;
472
473 unsigned ElementIndex = Idx / ElementSize;
474 ElementListConstIter ElementIter = FindLowerBoundConst(ElementIndex);
475
476 // If we can't find an element that is supposed to contain this bit, there
477 // is nothing more to do.
478 if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
479 return false;
480 return ElementIter->test(Idx % ElementSize);
481 }
482
483 void reset(unsigned Idx) {
484 if (Elements.empty())
485 return;
486
487 unsigned ElementIndex = Idx / ElementSize;
488 ElementListIter ElementIter = FindLowerBound(ElementIndex);
489
490 // If we can't find an element that is supposed to contain this bit, there
491 // is nothing more to do.
492 if (ElementIter == Elements.end() || ElementIter->index() != ElementIndex)
493 return;
494 ElementIter->reset(Idx % ElementSize);
495
496 // When the element is zeroed out, delete it.
497 if (ElementIter->empty()) {
498 ++CurrElementIter;
499 Elements.erase(ElementIter);
500 }
501 }
502
503 void set(unsigned Idx) {
504 unsigned ElementIndex = Idx / ElementSize;
505 ElementListIter ElementIter;
506 if (Elements.empty()) {
507 ElementIter = Elements.emplace(Elements.end(), ElementIndex);
508 } else {
509 ElementIter = FindLowerBound(ElementIndex);
510
511 if (ElementIter == Elements.end() ||
512 ElementIter->index() != ElementIndex) {
513 // We may have hit the beginning of our SparseBitVector, in which case,
514 // we may need to insert right after this element, which requires moving
515 // the current iterator forward one, because insert does insert before.
516 if (ElementIter != Elements.end() &&
517 ElementIter->index() < ElementIndex)
518 ++ElementIter;
519 ElementIter = Elements.emplace(ElementIter, ElementIndex);
520 }
521 }
522 CurrElementIter = ElementIter;
523
524 ElementIter->set(Idx % ElementSize);
525 }
526
527 bool test_and_set(unsigned Idx) {
528 bool old = test(Idx);
529 if (!old) {
530 set(Idx);
531 return true;
532 }
533 return false;
534 }
535
536 bool operator!=(const SparseBitVector& RHS) const {
537 return !(*this == RHS);
538 }
539
540 bool operator==(const SparseBitVector& RHS) const {
541 ElementListConstIter Iter1 = Elements.begin();
542 ElementListConstIter Iter2 = RHS.Elements.begin();
543
544 for (; Iter1 != Elements.end() && Iter2 != RHS.Elements.end();
545 ++Iter1, ++Iter2) {
546 if (*Iter1 != *Iter2)
547 return false;
548 }
549 return Iter1 == Elements.end() && Iter2 == RHS.Elements.end();
550 }
551
552 // Union our bitmap with the RHS and return true if we changed.
553 bool operator|=(const SparseBitVector& RHS) {
554 if (this == &RHS)
555 return false;
556
557 if (empty()) {
558 *this = RHS;
559 return true;
560 }
561
562 bool changed = false;
563 ElementListIter Iter1 = Elements.begin();
564 ElementListConstIter Iter2 = RHS.Elements.begin();
565
566 // If RHS is empty, we are done
567 if (RHS.Elements.empty())
568 return false;
569
570 while (Iter2 != RHS.Elements.end()) {
571 if (Iter1 == Elements.end() || Iter1->index() > Iter2->index()) {
572 Elements.insert(Iter1, *Iter2);
573 ++Iter2;
574 changed = true;
575 } else if (Iter1->index() == Iter2->index()) {
576 changed |= Iter1->unionWith(*Iter2);
577 ++Iter1;
578 ++Iter2;
579 } else {
580 ++Iter1;
581 }
582 }
583 CurrElementIter = Elements.begin();
584 return changed;
585 }
586
587 // Intersect our bitmap with the RHS and return true if ours changed.
588 bool operator-=(const SparseBitVector& RHS) {
589 return intersectWithComplement(RHS);
590 }
591
592 // Intersect our bitmap with the RHS and return true if ours changed.
593 bool operator&=(const SparseBitVector& RHS) {
594 if (this == &RHS)
595 return false;
596
597 bool changed = false;
598 ElementListIter Iter1 = Elements.begin();
599 ElementListConstIter Iter2 = RHS.Elements.begin();
600
601 // Check if both bitmaps are empty.
602 if (Elements.empty() && RHS.Elements.empty())
603 return false;
604
605 // Loop through, intersecting as we go, erasing elements when necessary.
606 while (Iter2 != RHS.Elements.end()) {
607 if (Iter1 == Elements.end()) {
608 CurrElementIter = Elements.begin();
609 return changed;
610 }
611
612 if (Iter1->index() > Iter2->index()) {
613 ++Iter2;
614 } else if (Iter1->index() == Iter2->index()) {
615 bool BecameZero;
616 changed |= Iter1->intersectWith(*Iter2, BecameZero);
617 if (BecameZero) {
618 ElementListIter IterTmp = Iter1;
619 ++Iter1;
620 Elements.erase(IterTmp);
621 } else {
622 ++Iter1;
623 }
624 ++Iter2;
625 } else {
626 ElementListIter IterTmp = Iter1;
627 ++Iter1;
628 Elements.erase(IterTmp);
629 changed = true;
630 }
631 }
632 if (Iter1 != Elements.end()) {
633 Elements.erase(Iter1, Elements.end());
634 changed = true;
635 }
636 CurrElementIter = Elements.begin();
637 return changed;
638 }
639
640 // Intersect our bitmap with the complement of the RHS and return true
641 // if ours changed.
642 bool intersectWithComplement(const SparseBitVector& RHS) {
643 if (this == &RHS) {
644 if (!empty()) {
645 clear();
646 return true;
647 }
648 return false;
649 }
650
651 bool changed = false;
652 ElementListIter Iter1 = Elements.begin();
653 ElementListConstIter Iter2 = RHS.Elements.begin();
654
655 // If either our bitmap or RHS is empty, we are done
656 if (Elements.empty() || RHS.Elements.empty())
657 return false;
658
659 // Loop through, intersecting as we go, erasing elements when necessary.
660 while (Iter2 != RHS.Elements.end()) {
661 if (Iter1 == Elements.end()) {
662 CurrElementIter = Elements.begin();
663 return changed;
664 }
665
666 if (Iter1->index() > Iter2->index()) {
667 ++Iter2;
668 } else if (Iter1->index() == Iter2->index()) {
669 bool BecameZero;
670 changed |= Iter1->intersectWithComplement(*Iter2, BecameZero);
671 if (BecameZero) {
672 ElementListIter IterTmp = Iter1;
673 ++Iter1;
674 Elements.erase(IterTmp);
675 } else {
676 ++Iter1;
677 }
678 ++Iter2;
679 } else {
680 ++Iter1;
681 }
682 }
683 CurrElementIter = Elements.begin();
684 return changed;
685 }
686
687 bool intersectWithComplement(const SparseBitVector<ElementSize>* RHS) const {
688 return intersectWithComplement(*RHS);
689 }
690
691 // Three argument version of intersectWithComplement.
692 // Result of RHS1 & ~RHS2 is stored into this bitmap.
693 void intersectWithComplement(
694 const SparseBitVector<ElementSize>& RHS1,
695 const SparseBitVector<ElementSize>& RHS2) {
696 if (this == &RHS1) {
697 intersectWithComplement(RHS2);
698 return;
699 } else if (this == &RHS2) {
700 SparseBitVector RHS2Copy(RHS2);
701 intersectWithComplement(RHS1, RHS2Copy);
702 return;
703 }
704
705 Elements.clear();
706 CurrElementIter = Elements.begin();
707 ElementListConstIter Iter1 = RHS1.Elements.begin();
708 ElementListConstIter Iter2 = RHS2.Elements.begin();
709
710 // If RHS1 is empty, we are done
711 // If RHS2 is empty, we still have to copy RHS1
712 if (RHS1.Elements.empty())
713 return;
714
715 // Loop through, intersecting as we go, erasing elements when necessary.
716 while (Iter2 != RHS2.Elements.end()) {
717 if (Iter1 == RHS1.Elements.end())
718 return;
719
720 if (Iter1->index() > Iter2->index()) {
721 ++Iter2;
722 } else if (Iter1->index() == Iter2->index()) {
723 bool BecameZero = false;
724 Elements.emplace_back(Iter1->index());
725 Elements.back().intersectWithComplement(*Iter1, *Iter2, BecameZero);
726 if (BecameZero)
727 Elements.pop_back();
728 ++Iter1;
729 ++Iter2;
730 } else {
731 Elements.push_back(*Iter1++);
732 }
733 }
734
735 // copy the remaining elements
736 std::copy(Iter1, RHS1.Elements.end(), std::back_inserter(Elements));
737 }
738
739 void intersectWithComplement(
740 const SparseBitVector<ElementSize>* RHS1,
741 const SparseBitVector<ElementSize>* RHS2) {
742 intersectWithComplement(*RHS1, *RHS2);
743 }
744
745 bool intersects(const SparseBitVector<ElementSize>* RHS) const {
746 return intersects(*RHS);
747 }
748
749 // Return true if we share any bits in common with RHS
750 bool intersects(const SparseBitVector<ElementSize>& RHS) const {
751 ElementListConstIter Iter1 = Elements.begin();
752 ElementListConstIter Iter2 = RHS.Elements.begin();
753
754 // Check if both bitmaps are empty.
755 if (Elements.empty() && RHS.Elements.empty())
756 return false;
757
758 // Loop through, intersecting stopping when we hit bits in common.
759 while (Iter2 != RHS.Elements.end()) {
760 if (Iter1 == Elements.end())
761 return false;
762
763 if (Iter1->index() > Iter2->index()) {
764 ++Iter2;
765 } else if (Iter1->index() == Iter2->index()) {
766 if (Iter1->intersects(*Iter2))
767 return true;
768 ++Iter1;
769 ++Iter2;
770 } else {
771 ++Iter1;
772 }
773 }
774 return false;
775 }
776
777 // Return true iff all bits set in this SparseBitVector are
778 // also set in RHS.
779 bool contains(const SparseBitVector<ElementSize>& RHS) const {
780 SparseBitVector<ElementSize> Result(*this);
781 Result &= RHS;
782 return (Result == RHS);
783 }
784
785 // Return the first set bit in the bitmap. Return -1 if no bits are set.
786 int find_first() const {
787 if (Elements.empty())
788 return -1;
789 const SparseBitVectorElement<ElementSize>& First = *(Elements.begin());
790 return (First.index() * ElementSize) + First.find_first();
791 }
792
793 // Return the last set bit in the bitmap. Return -1 if no bits are set.
794 int find_last() const {
795 if (Elements.empty())
796 return -1;
797 const SparseBitVectorElement<ElementSize>& Last = *(Elements.rbegin());
798 return (Last.index() * ElementSize) + Last.find_last();
799 }
800
801 // Return true if the SparseBitVector is empty
802 bool empty() const {
803 return Elements.empty();
804 }
805
806 unsigned count() const {
807 unsigned BitCount = 0;
808 for (ElementListConstIter Iter = Elements.begin(); Iter != Elements.end();
809 ++Iter)
810 BitCount += Iter->count();
811
812 return BitCount;
813 }
814
815 iterator begin() const {
816 return iterator(this);
817 }
818
819 iterator end() const {
820 return iterator(this, true);
821 }
822};
823
824// Convenience functions to allow Or and And without dereferencing in the user
825// code.
826
827template <unsigned ElementSize>
828inline bool operator|=(
829 SparseBitVector<ElementSize>& LHS,
830 const SparseBitVector<ElementSize>* RHS) {
831 return LHS |= *RHS;
832}
833
834template <unsigned ElementSize>
835inline bool operator|=(
836 SparseBitVector<ElementSize>* LHS,
837 const SparseBitVector<ElementSize>& RHS) {
838 return LHS->operator|=(RHS);
839}
840
841template <unsigned ElementSize>
842inline bool operator&=(
843 SparseBitVector<ElementSize>* LHS,
844 const SparseBitVector<ElementSize>& RHS) {
845 return LHS->operator&=(RHS);
846}
847
848template <unsigned ElementSize>
849inline bool operator&=(
850 SparseBitVector<ElementSize>& LHS,
851 const SparseBitVector<ElementSize>* RHS) {
852 return LHS &= *RHS;
853}
854
855// Convenience functions for infix union, intersection, difference operators.
856
857template <unsigned ElementSize>
858inline SparseBitVector<ElementSize> operator|(
859 const SparseBitVector<ElementSize>& LHS,
860 const SparseBitVector<ElementSize>& RHS) {
861 SparseBitVector<ElementSize> Result(LHS);
862 Result |= RHS;
863 return Result;
864}
865
866template <unsigned ElementSize>
867inline SparseBitVector<ElementSize> operator&(
868 const SparseBitVector<ElementSize>& LHS,
869 const SparseBitVector<ElementSize>& RHS) {
870 SparseBitVector<ElementSize> Result(LHS);
871 Result &= RHS;
872 return Result;
873}
874
875template <unsigned ElementSize>
876inline SparseBitVector<ElementSize> operator-(
877 const SparseBitVector<ElementSize>& LHS,
878 const SparseBitVector<ElementSize>& RHS) {
879 SparseBitVector<ElementSize> Result;
880 Result.intersectWithComplement(LHS, RHS);
881 return Result;
882}
883
884template <unsigned ElementSize>
885std::ostream& operator<<(
886 std::ostream& stream,
887 const SparseBitVector<ElementSize>& vec) {
888 bool first = true;
889 stream << "{";
890 for (auto el : vec) {
891 if (first) {
892 first = false;
893 } else {
894 stream << ", ";
895 }
896 stream << el;
897 }
898 stream << "}";
899 return stream;
900}
901
902} // end namespace c10
903
904C10_CLANG_DIAGNOSTIC_POP()
905