1 | /******************************************************************************* |
2 | Copyright (c) The Taichi Authors (2016- ). All Rights Reserved. |
3 | The use of this software is governed by the LICENSE file. |
4 | *******************************************************************************/ |
5 | #pragma once |
6 | |
#include <cstring>

#include "taichi/common/core.h"
8 | |
9 | namespace taichi { |
10 | namespace bit { |
11 | |
12 | TI_FORCE_INLINE constexpr bool is_power_of_two(int32 x) { |
13 | return x != 0 && (x & (x - 1)) == 0; |
14 | } |
15 | |
16 | TI_FORCE_INLINE constexpr bool is_power_of_two(uint32 x) { |
17 | return x != 0 && (x & (x - 1)) == 0; |
18 | } |
19 | |
20 | TI_FORCE_INLINE constexpr bool is_power_of_two(int64 x) { |
21 | return x != 0 && (x & (x - 1)) == 0; |
22 | } |
23 | |
24 | TI_FORCE_INLINE constexpr bool is_power_of_two(uint64 x) { |
25 | return x != 0 && (x & (x - 1)) == 0; |
26 | } |
27 | |
28 | TI_FORCE_INLINE uint32 as_uint(const float32 x) { |
29 | return *(uint32 *)&x; |
30 | } |
31 | |
32 | TI_FORCE_INLINE float32 as_float(const uint32 x) { |
33 | return *(float32 *)&x; |
34 | } |
35 | |
36 | TI_FORCE_INLINE float32 half_to_float(const uint16 x) { |
37 | // Reference: https://stackoverflow.com/a/60047308 |
38 | const uint32 e = (x & 0x7C00) >> 10; // exponent |
39 | const uint32 m = (x & 0x03FF) << 13; // mantissa |
40 | const uint32 v = |
41 | as_uint((float32)m) >> |
42 | 23; // evil log2 bit hack to count leading zeros in denormalized format |
43 | return as_float( |
44 | (x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) | |
45 | ((e == 0) & (m != 0)) * |
46 | ((v - 37) << 23 | ((m << (150 - v)) & |
47 | 0x007FE000))); // sign : normalized : denormalized |
48 | } |
49 | |
50 | template <int length> |
51 | struct Bits { |
52 | static_assert(is_power_of_two(length), "length must be a power of two" ); |
53 | static_assert(length == 32 || length == 64, "length must be 32/64" ); |
54 | |
55 | using T = std::conditional_t<length == 32, uint32, uint64>; |
56 | |
57 | T data; |
58 | |
59 | Bits() : data(0) { |
60 | } |
61 | |
62 | // Uninitialized |
63 | explicit Bits(void *) { |
64 | } |
65 | |
66 | template <int start, int bits = 1> |
67 | static constexpr T mask() { |
68 | return (((T)1 << bits) - 1) << start; |
69 | } |
70 | |
71 | template <int start, int bits = 1> |
72 | TI_FORCE_INLINE T get() const { |
73 | return (data >> start) & (((T)1 << bits) - 1); |
74 | } |
75 | |
76 | template <int start, int bits = 1> |
77 | TI_FORCE_INLINE void set(T val) { |
78 | data = |
79 | (data & ~mask<start, bits>()) | ((val << start) & mask<start, bits>()); |
80 | } |
81 | |
82 | TI_FORCE_INLINE T operator()(T) const { |
83 | return data; |
84 | } |
85 | |
86 | TI_FORCE_INLINE T get() const { |
87 | return data; |
88 | } |
89 | |
90 | TI_FORCE_INLINE void set(const T &data) { |
91 | this->data = data; |
92 | } |
93 | }; |
94 | |
95 | template <int length> |
96 | using BitFlags = Bits<length>; |
97 | |
// Returns the number of bits used to store a value of type T in a bit field:
// 1 for bool, otherwise sizeof(T) * 8.
template <typename T>
constexpr int bit_length() {
  if (std::is_same<T, bool>::value) {
    // A bool only needs a single bit, regardless of sizeof(bool).
    return 1;
  }
  return static_cast<int>(sizeof(T) * 8);
}
102 | |
// Declares a getter/setter pair for a bit field of type `T` named `name`
// whose lowest bit is at position `start`.
//
// The field width is bit::bit_length<T>(). The macro must be expanded inside
// a class in which `Base` names a type providing `get<start, bits>()` and
// `set<start, bits>(val)`; it generates `T get_<name>() const` and
// `void set_<name>(const T &val)`.
#define TI_BIT_FIELD(T, name, start)                    \
  T get_##name() const {                                \
    return (T)Base::get<start, bit::bit_length<T>()>(); \
  }                                                     \
  void set_##name(const T &val) {                       \
    Base::set<start, bit::bit_length<T>()>(val);        \
  }
106 | |
107 | template <typename T, int N> |
108 | TI_FORCE_INLINE constexpr T product(const std::array<T, N> arr) { |
109 | T ret(1); |
110 | for (int i = 0; i < N; i++) { |
111 | ret *= arr[i]; |
112 | } |
113 | return ret; |
114 | } |
115 | |
116 | constexpr std::size_t least_pot_bound(std::size_t v) { |
117 | if (v > std::numeric_limits<std::size_t>::max() / 2 + 1) { |
118 | TI_ERROR("v({}) too large" , v) |
119 | } |
120 | std::size_t ret = 1; |
121 | while (ret < v) { |
122 | ret *= 2; |
123 | } |
124 | return ret; |
125 | } |
126 | |
127 | TI_FORCE_INLINE constexpr uint32 pot_mask(int x) { |
128 | return (1u << x) - 1; |
129 | } |
130 | |
131 | TI_FORCE_INLINE constexpr uint32 log2int(uint64 value) { |
132 | int ret = 0; |
133 | value >>= 1; |
134 | while (value) { |
135 | value >>= 1; |
136 | ret += 1; |
137 | } |
138 | return ret; |
139 | } |
140 | |
141 | TI_FORCE_INLINE constexpr uint32 ceil_log2int(uint64 value) { |
142 | // Returns ceil(log2(value)). When value == 0, it returns 0. |
143 | return log2int(value) + ((value & (value - 1)) != 0); |
144 | } |
145 | |
146 | TI_FORCE_INLINE constexpr uint64 lowbit(uint64 x) { |
147 | return x & (-x); |
148 | } |
149 | |
150 | template <typename G, typename T> |
151 | constexpr TI_FORCE_INLINE copy_refcv_t<T, G> &&reinterpret_bits(T &&t) { |
152 | TI_STATIC_ASSERT(sizeof(G) == sizeof(T)); |
153 | return std::forward<copy_refcv_t<T, G>>(*reinterpret_cast<G *>(&t)); |
154 | }; |
155 | |
156 | TI_FORCE_INLINE constexpr float64 compress(float32 h, float32 l) { |
157 | uint64 data = |
158 | ((uint64)reinterpret_bits<uint32>(h) << 32) + reinterpret_bits<uint32>(l); |
159 | return reinterpret_bits<float64>(data); |
160 | } |
161 | |
162 | TI_FORCE_INLINE constexpr std::tuple<float32, float32> (float64 x) { |
163 | auto data = reinterpret_bits<uint64>(x); |
164 | return std::make_tuple(reinterpret_bits<float32>((uint32)(data >> 32)), |
165 | reinterpret_bits<float32>((uint32)(data & (-1)))); |
166 | } |
167 | |
168 | class Bitset { |
169 | public: |
170 | using value_t = uint64; |
171 | static constexpr std::size_t kBits = sizeof(value_t) * 8; |
172 | // kBits should be a power of two. However, the function is_power_of_two is |
173 | // ambiguous and can't be called here. |
174 | static_assert((kBits & (kBits - 1)) == 0); |
175 | static constexpr std::size_t kLogBits = log2int(kBits); |
176 | static constexpr value_t kMask = ((value_t)-1); |
177 | class reference { |
178 | public: |
179 | reference(std::vector<value_t> &vec, int x); |
180 | explicit operator bool() const; |
181 | bool operator~() const; |
182 | reference &operator=(bool x); |
183 | reference &operator=(const reference &other); |
184 | reference &flip(); |
185 | |
186 | private: |
187 | value_t *pos_; |
188 | value_t digit_; |
189 | }; |
190 | |
191 | Bitset(); |
192 | explicit Bitset(int n); |
193 | std::size_t size() const; |
194 | void reset(); |
195 | void flip(int x); |
196 | bool any() const; |
197 | bool none() const; |
198 | reference operator[](int x); |
199 | Bitset &operator&=(const Bitset &other); |
200 | Bitset operator&(const Bitset &other) const; |
201 | Bitset &operator|=(const Bitset &other); |
202 | Bitset operator|(const Bitset &other) const; |
203 | Bitset &operator^=(const Bitset &other); |
204 | Bitset operator~() const; |
205 | |
206 | // Find the place of the first "1", or return -1 if it doesn't exist. |
207 | int find_first_one() const; |
208 | // Find the place of the first "1" which is not before x, or return -1 if |
209 | // it doesn't exist. |
210 | int lower_bound(int x) const; |
211 | |
212 | std::vector<int> or_eq_get_update_list(const Bitset &other); |
213 | |
214 | // output from the lowest bit to the highest bit |
215 | friend std::ostream &operator<<(std::ostream &os, const Bitset &b); |
216 | |
217 | private: |
218 | std::vector<value_t> vec_; |
219 | }; |
220 | |
221 | } // namespace bit |
222 | } // namespace taichi |
223 | |