1/*******************************************************************************
2 Copyright (c) The Taichi Authors (2016- ). All Rights Reserved.
3 The use of this software is governed by the LICENSE file.
4*******************************************************************************/
5#pragma once
6
#include <cstring>

#include "taichi/common/core.h"
8
9namespace taichi {
10namespace bit {
11
12TI_FORCE_INLINE constexpr bool is_power_of_two(int32 x) {
13 return x != 0 && (x & (x - 1)) == 0;
14}
15
16TI_FORCE_INLINE constexpr bool is_power_of_two(uint32 x) {
17 return x != 0 && (x & (x - 1)) == 0;
18}
19
20TI_FORCE_INLINE constexpr bool is_power_of_two(int64 x) {
21 return x != 0 && (x & (x - 1)) == 0;
22}
23
24TI_FORCE_INLINE constexpr bool is_power_of_two(uint64 x) {
25 return x != 0 && (x & (x - 1)) == 0;
26}
27
28TI_FORCE_INLINE uint32 as_uint(const float32 x) {
29 return *(uint32 *)&x;
30}
31
32TI_FORCE_INLINE float32 as_float(const uint32 x) {
33 return *(float32 *)&x;
34}
35
36TI_FORCE_INLINE float32 half_to_float(const uint16 x) {
37 // Reference: https://stackoverflow.com/a/60047308
38 const uint32 e = (x & 0x7C00) >> 10; // exponent
39 const uint32 m = (x & 0x03FF) << 13; // mantissa
40 const uint32 v =
41 as_uint((float32)m) >>
42 23; // evil log2 bit hack to count leading zeros in denormalized format
43 return as_float(
44 (x & 0x8000) << 16 | (e != 0) * ((e + 112) << 23 | m) |
45 ((e == 0) & (m != 0)) *
46 ((v - 37) << 23 | ((m << (150 - v)) &
47 0x007FE000))); // sign : normalized : denormalized
48}
49
50template <int length>
51struct Bits {
52 static_assert(is_power_of_two(length), "length must be a power of two");
53 static_assert(length == 32 || length == 64, "length must be 32/64");
54
55 using T = std::conditional_t<length == 32, uint32, uint64>;
56
57 T data;
58
59 Bits() : data(0) {
60 }
61
62 // Uninitialized
63 explicit Bits(void *) {
64 }
65
66 template <int start, int bits = 1>
67 static constexpr T mask() {
68 return (((T)1 << bits) - 1) << start;
69 }
70
71 template <int start, int bits = 1>
72 TI_FORCE_INLINE T get() const {
73 return (data >> start) & (((T)1 << bits) - 1);
74 }
75
76 template <int start, int bits = 1>
77 TI_FORCE_INLINE void set(T val) {
78 data =
79 (data & ~mask<start, bits>()) | ((val << start) & mask<start, bits>());
80 }
81
82 TI_FORCE_INLINE T operator()(T) const {
83 return data;
84 }
85
86 TI_FORCE_INLINE T get() const {
87 return data;
88 }
89
90 TI_FORCE_INLINE void set(const T &data) {
91 this->data = data;
92 }
93};
94
95template <int length>
96using BitFlags = Bits<length>;
97
// Number of bits a value of type T occupies in a bit field: a bool takes a
// single bit, every other type takes its full object size (8 bits per byte).
template <typename T>
constexpr int bit_length() {
  if constexpr (std::is_same_v<T, bool>) {
    return 1;
  } else {
    return static_cast<int>(sizeof(T) * 8);
  }
}
102
// Declares accessors get_<name>() / set_<name>() for a field of type T stored
// at bit offset `start` of the enclosing class's `Base` (a bit::Bits<...>
// specialization). The field width is bit::bit_length<T>().
// `Base::template` makes the macro usable inside class templates whose Base
// depends on a template parameter; the disambiguator is also legal (and
// harmless) in non-template classes. The fully qualified ::taichi::bit name
// lets the macro be expanded outside namespace taichi.
#define TI_BIT_FIELD(T, name, start)                                        \
  T get_##name() const {                                                    \
    return static_cast<T>(                                                  \
        Base::template get<start, ::taichi::bit::bit_length<T>()>());       \
  }                                                                         \
  void set_##name(const T &val) {                                           \
    Base::template set<start, ::taichi::bit::bit_length<T>()>(val);         \
  }
106
107template <typename T, int N>
108TI_FORCE_INLINE constexpr T product(const std::array<T, N> arr) {
109 T ret(1);
110 for (int i = 0; i < N; i++) {
111 ret *= arr[i];
112 }
113 return ret;
114}
115
116constexpr std::size_t least_pot_bound(std::size_t v) {
117 if (v > std::numeric_limits<std::size_t>::max() / 2 + 1) {
118 TI_ERROR("v({}) too large", v)
119 }
120 std::size_t ret = 1;
121 while (ret < v) {
122 ret *= 2;
123 }
124 return ret;
125}
126
127TI_FORCE_INLINE constexpr uint32 pot_mask(int x) {
128 return (1u << x) - 1;
129}
130
131TI_FORCE_INLINE constexpr uint32 log2int(uint64 value) {
132 int ret = 0;
133 value >>= 1;
134 while (value) {
135 value >>= 1;
136 ret += 1;
137 }
138 return ret;
139}
140
141TI_FORCE_INLINE constexpr uint32 ceil_log2int(uint64 value) {
142 // Returns ceil(log2(value)). When value == 0, it returns 0.
143 return log2int(value) + ((value & (value - 1)) != 0);
144}
145
146TI_FORCE_INLINE constexpr uint64 lowbit(uint64 x) {
147 return x & (-x);
148}
149
150template <typename G, typename T>
151constexpr TI_FORCE_INLINE copy_refcv_t<T, G> &&reinterpret_bits(T &&t) {
152 TI_STATIC_ASSERT(sizeof(G) == sizeof(T));
153 return std::forward<copy_refcv_t<T, G>>(*reinterpret_cast<G *>(&t));
154};
155
156TI_FORCE_INLINE constexpr float64 compress(float32 h, float32 l) {
157 uint64 data =
158 ((uint64)reinterpret_bits<uint32>(h) << 32) + reinterpret_bits<uint32>(l);
159 return reinterpret_bits<float64>(data);
160}
161
162TI_FORCE_INLINE constexpr std::tuple<float32, float32> extract(float64 x) {
163 auto data = reinterpret_bits<uint64>(x);
164 return std::make_tuple(reinterpret_bits<float32>((uint32)(data >> 32)),
165 reinterpret_bits<float32>((uint32)(data & (-1))));
166}
167
// Dynamically sized bitset whose storage is a std::vector of value_t words.
// Only declarations are visible in this header; the member functions are
// defined out of line.
class Bitset {
 public:
  // Storage word type.
  using value_t = uint64;
  // Number of bits in one storage word.
  static constexpr std::size_t kBits = sizeof(value_t) * 8;
  // kBits should be a power of two. However, the function is_power_of_two is
  // ambiguous and can't be called here.
  static_assert((kBits & (kBits - 1)) == 0);
  // log2(kBits). Presumably used to map a bit index to its word index
  // (index >> kLogBits); TODO confirm in the out-of-line definitions.
  static constexpr std::size_t kLogBits = log2int(kBits);
  // A storage word with every bit set.
  static constexpr value_t kMask = ((value_t)-1);
  // Proxy for a single bit, returned by operator[] (similar in spirit to
  // std::bitset::reference).
  // NOTE(review): it holds a raw `value_t *`, presumably pointing into the
  // Bitset's storage vector. If so, it must not be used after that vector
  // reallocates or the Bitset is destroyed; confirm against the definitions.
  class reference {
   public:
    reference(std::vector<value_t> &vec, int x);
    explicit operator bool() const;
    bool operator~() const;
    reference &operator=(bool x);
    reference &operator=(const reference &other);
    reference &flip();

   private:
    value_t *pos_;
    value_t digit_;
  };

  Bitset();
  // Presumably creates a bitset of n bits; verify against the definition.
  explicit Bitset(int n);
  std::size_t size() const;
  void reset();
  void flip(int x);
  bool any() const;
  bool none() const;
  reference operator[](int x);
  Bitset &operator&=(const Bitset &other);
  Bitset operator&(const Bitset &other) const;
  Bitset &operator|=(const Bitset &other);
  Bitset operator|(const Bitset &other) const;
  Bitset &operator^=(const Bitset &other);
  Bitset operator~() const;

  // Find the place of the first "1", or return -1 if it doesn't exist.
  int find_first_one() const;
  // Find the place of the first "1" which is not before x, or return -1 if
  // it doesn't exist.
  int lower_bound(int x) const;

  // Presumably performs *this |= other and returns the positions of the bits
  // that changed as a result; TODO confirm against the definition.
  std::vector<int> or_eq_get_update_list(const Bitset &other);

  // output from the lowest bit to the highest bit
  friend std::ostream &operator<<(std::ostream &os, const Bitset &b);

 private:
  // Backing storage, kBits bits per word.
  std::vector<value_t> vec_;
};
220
221} // namespace bit
222} // namespace taichi
223