1#pragma once
2
3#include <c10/macros/Macros.h>
4#include <c10/util/C++17.h>
5#include <c10/util/Optional.h>
6#if defined(_MSC_VER)
7#include <intrin.h>
8#endif
9
10namespace c10 {
11namespace utils {
12
13/**
 * This is a simple bitset class with 8 * sizeof(long long int) bits
 * (see NUM_BITS(); typically 64).
15 * You can set bits, unset bits, query bits by index,
16 * and query for the first set bit.
17 * Before using this class, please also take a look at std::bitset,
18 * which has more functionality and is more generic. It is probably
19 * a better fit for your use case. The sole reason for c10::utils::bitset
20 * to exist is that std::bitset misses a find_first_set() method.
21 */
struct bitset final {
 private:
#if defined(_MSC_VER)
  // MSVC's _BitScanForward64 expects int64_t
  using bitset_type = int64_t;
#else
  // POSIX ffsll expects long long int
  using bitset_type = long long int;
#endif

  // Returns a word in which only the bit at `index` is set.
  // Precondition: index < NUM_BITS(); a larger shift count is undefined
  // behavior and is not checked here.
  static constexpr bitset_type mask_of(size_t index) noexcept {
    return static_cast<bitset_type>(1) << index;
  }

 public:
  // Number of bits that can be stored in this bitset.
  static constexpr size_t NUM_BITS() {
    return 8 * sizeof(bitset_type);
  }

  // A default-constructed bitset has every bit unset.
  constexpr bitset() noexcept = default;
  constexpr bitset(const bitset&) noexcept = default;
  constexpr bitset(bitset&&) noexcept = default;
  // There is an issue in gcc 5.3.0 when defining a defaulted function as
  // constexpr, so the assignment operators are intentionally not constexpr.
  // See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=68754.
  bitset& operator=(const bitset&) noexcept = default;
  bitset& operator=(bitset&&) noexcept = default;

  // Sets the bit at the given zero-based index (index < NUM_BITS()).
  constexpr void set(size_t index) noexcept {
    bitset_ |= mask_of(index);
  }

  // Clears the bit at the given zero-based index (index < NUM_BITS()).
  constexpr void unset(size_t index) noexcept {
    bitset_ &= ~mask_of(index);
  }

  // Returns true iff the bit at the given zero-based index is set
  // (index < NUM_BITS()).
  constexpr bool get(size_t index) const noexcept {
    return 0 != (bitset_ & mask_of(index));
  }

  // Returns true iff no bit is set.
  constexpr bool is_entirely_unset() const noexcept {
    return 0 == bitset_;
  }

  // Call the given functor with the zero-based index of each bit that is set,
  // starting with the lowest set bit and moving upwards.
  template <class Func>
  void for_each_set_bit(Func&& func) const {
    // Work on a copy so that visited bits can be cleared one by one.
    bitset cur = *this;
    size_t index = cur.find_first_set();
    while (0 != index) {
      // -1 because find_first_set() is one-indexed.
      index -= 1;
      func(index);
      cur.unset(index);
      index = cur.find_first_set();
    }
  }

 private:
  // Return the index of the first (lowest) set bit. The returned index is
  // one-indexed (i.e. if the very first bit is set, this function returns
  // '1'), and a return of '0' means that there was no bit set.
  size_t find_first_set() const {
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64))
    unsigned long result;
    bool has_bits_set = (0 != _BitScanForward64(&result, bitset_));
    if (!has_bits_set) {
      return 0;
    }
    return result + 1;
#elif defined(_MSC_VER) && defined(_M_IX86)
    // No 64-bit scan is available on 32-bit x86, so scan the lower and the
    // upper 32-bit halves separately.
    unsigned long result;
    if (static_cast<uint32_t>(bitset_) != 0) {
      bool has_bits_set =
          (0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_)));
      if (!has_bits_set) {
        return 0;
      }
      return result + 1;
    } else {
      bool has_bits_set =
          (0 != _BitScanForward(&result, static_cast<uint32_t>(bitset_ >> 32)));
      if (!has_bits_set) {
        // Neither half has a bit set: report "no bit set". Returning a
        // non-zero value here would make an empty bitset look like bit 31 is
        // set and for_each_set_bit() would never terminate.
        return 0;
      }
      // +32 for the skipped lower half, +1 for one-indexing.
      return result + 33;
    }
#else
    // ffsll already returns a one-indexed position, or 0 if no bit is set.
    return __builtin_ffsll(bitset_);
#endif
  }

  // Two bitsets are equal iff exactly the same bits are set.
  friend bool operator==(bitset lhs, bitset rhs) noexcept {
    return lhs.bitset_ == rhs.bitset_;
  }

  bitset_type bitset_{0};
};
114
115inline bool operator!=(bitset lhs, bitset rhs) noexcept {
116 return !(lhs == rhs);
117}
118
119} // namespace utils
120} // namespace c10
121