1 | #pragma once |
2 | |
3 | // DO NOT DEFINE STATIC DATA IN THIS HEADER! |
4 | // See Note [Do not compile initializers with AVX] |
5 | |
#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#include <cstdint>
#include <cstring>
#include <tuple>
9 | |
10 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
11 | #include <sleef.h> |
12 | #endif |
13 | |
14 | #pragma GCC diagnostic push |
15 | #pragma GCC diagnostic ignored "-Wignored-qualifiers" |
16 | |
17 | namespace at { |
18 | namespace vec { |
19 | // See Note [CPU_CAPABILITY namespace] |
20 | inline namespace CPU_CAPABILITY { |
21 | |
22 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
23 | |
// Widen 8 bf16 values (one 128-bit lane) to 8 fp32 values. A bf16 is the
// upper 16 bits of the corresponding fp32, so zero-extend then shift left.
static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
  __m256i widened = _mm256_cvtepu16_epi32(a);
  __m256i shifted = _mm256_slli_epi32(widened, 16);
  o = _mm256_castsi256_ps(shifted);
}
27 | |
28 | static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) { |
29 | __m128i lo = _mm256_extractf128_si256(a, 0); |
30 | __m128i hi = _mm256_extractf128_si256(a, 1); |
31 | cvtbf16_fp32(lo, o1); |
32 | cvtbf16_fp32(hi, o2); |
33 | } |
// Narrow two fp32 vectors (16 floats total) into one vector of 16 bf16
// values using round-to-nearest-even on the 16 truncated mantissa bits.
// NaN inputs are canonicalized to the bf16 bit pattern 0xffff.
static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) {
  __m256i lo = _mm256_castps_si256(a);
  __m256i hi = _mm256_castps_si256(b);
  __m256i nan = _mm256_set1_epi32(0xffff);
  // _CMP_ORD_Q yields all-ones for lanes that are NOT NaN.
  __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q));
  __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q));
  __m256i ones = _mm256_set1_epi32(0x1);
  __m256i vec_bias = _mm256_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_lo = _mm256_and_si256(_mm256_srli_epi32(lo, 16), ones);
  auto t_hi = _mm256_and_si256(_mm256_srli_epi32(hi, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;  (ties round to even)
  t_lo = _mm256_add_epi32(t_lo, vec_bias);
  t_hi = _mm256_add_epi32(t_hi, vec_bias);
  // input += rounding_bias;
  t_lo = _mm256_add_epi32(t_lo, lo);
  t_hi = _mm256_add_epi32(t_hi, hi);
  // input = input >> 16;  (keep the upper half = bf16 bits)
  t_lo = _mm256_srli_epi32(t_lo, 16);
  t_hi = _mm256_srli_epi32(t_hi, 16);
  // Check NaN before converting back to bf16
  t_lo = _mm256_blendv_epi8(nan, t_lo, mask_lo);
  t_hi = _mm256_blendv_epi8(nan, t_hi, mask_hi);

  t_lo = _mm256_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-4] t_lo[0-4]
  // packus interleaves 128-bit lanes; permute the 64-bit quadrants back into
  // element order (control 0xd8 == 11 01 10 00).
  return _mm256_permute4x64_epi64(t_lo, 0xd8); // 11 01 10 00
}
61 | |
// Compress two fp32 comparison masks into one 16x16-bit bf16 mask by keeping
// the high 16 bits of every 32-bit lane, then restoring element order.
static inline __m256i merge_compare_result(const __m256& a, const __m256& b) {
  __m256i first = _mm256_srli_epi32(_mm256_castps_si256(a), 16);
  __m256i second = _mm256_srli_epi32(_mm256_castps_si256(b), 16);
  __m256i packed = _mm256_packus_epi32(first, second);
  return _mm256_permute4x64_epi64(packed, 0xd8);
}
70 | |
71 | template <> class Vectorized<BFloat16> { |
72 | private: |
73 | __m256i values; |
74 | public: |
75 | using value_type = uint16_t; |
76 | using size_type = int; |
77 | static constexpr size_type size() { |
78 | return 16; |
79 | } |
80 | Vectorized() {} |
81 | Vectorized(__m256i v) : values(v) {} |
82 | Vectorized(BFloat16 val) { |
83 | value_type uw = val.x; |
84 | values = _mm256_set1_epi16(uw); |
85 | } |
86 | Vectorized(BFloat16 val1, BFloat16 val2, BFloat16 val3, BFloat16 val4, |
87 | BFloat16 val5, BFloat16 val6, BFloat16 val7, BFloat16 val8, |
88 | BFloat16 val9, BFloat16 val10, BFloat16 val11, BFloat16 val12, |
89 | BFloat16 val13, BFloat16 val14, BFloat16 val15, BFloat16 val16) { |
90 | values = _mm256_setr_epi16( |
91 | val1.x, val2.x, val3.x, val4.x, val5.x, val6.x, val7.x, val8.x, |
92 | val9.x, val10.x, val11.x, val12.x, val13.x, val14.x, val15.x, val16.x); |
93 | } |
94 | operator __m256i() const { |
95 | return values; |
96 | } |
97 | BFloat16& operator[](int idx) = delete; |
98 | const BFloat16& operator[](int idx) const = delete; |
99 | int zero_mask() const { |
100 | // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit |
101 | __m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0)); |
102 | return _mm256_movemask_epi8(cmp); |
103 | } |
104 | static Vectorized<BFloat16> loadu(const void* ptr) { |
105 | return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr)); |
106 | } |
107 | static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) { |
108 | __at_align__ int16_t tmp_values[size()]; |
109 | std::memcpy(tmp_values, ptr, count * sizeof(int16_t)); |
110 | return loadu(tmp_values); |
111 | } |
112 | void store(void* ptr, int count = size()) const { |
113 | if (count == size()) { |
114 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values); |
115 | } else if (count > 0) { |
116 | __at_align__ int16_t tmp_values[size()]; |
117 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values); |
118 | std::memcpy(ptr, tmp_values, count * sizeof(int16_t)); |
119 | } |
120 | } |
121 | template <int64_t mask> |
122 | static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
123 | __at_align__ int16_t tmp_values[size()]; |
124 | a.store(tmp_values); |
125 | if (mask & 0x01) |
126 | tmp_values[0] = _mm256_extract_epi16(b.values, 0); |
127 | if (mask & 0x02) |
128 | tmp_values[1] = _mm256_extract_epi16(b.values, 1); |
129 | if (mask & 0x04) |
130 | tmp_values[2] = _mm256_extract_epi16(b.values, 2); |
131 | if (mask & 0x08) |
132 | tmp_values[3] = _mm256_extract_epi16(b.values, 3); |
133 | if (mask & 0x10) |
134 | tmp_values[4] = _mm256_extract_epi16(b.values, 4); |
135 | if (mask & 0x20) |
136 | tmp_values[5] = _mm256_extract_epi16(b.values, 5); |
137 | if (mask & 0x40) |
138 | tmp_values[6] = _mm256_extract_epi16(b.values, 6); |
139 | if (mask & 0x80) |
140 | tmp_values[7] = _mm256_extract_epi16(b.values, 7); |
141 | if (mask & 0x100) |
142 | tmp_values[8] = _mm256_extract_epi16(b.values, 8); |
143 | if (mask & 0x200) |
144 | tmp_values[9] = _mm256_extract_epi16(b.values, 9); |
145 | if (mask & 0x400) |
146 | tmp_values[10] = _mm256_extract_epi16(b.values, 10); |
147 | if (mask & 0x800) |
148 | tmp_values[11] = _mm256_extract_epi16(b.values, 11); |
149 | if (mask & 0x1000) |
150 | tmp_values[12] = _mm256_extract_epi16(b.values, 12); |
151 | if (mask & 0x2000) |
152 | tmp_values[13] = _mm256_extract_epi16(b.values, 13); |
153 | if (mask & 0x4000) |
154 | tmp_values[14] = _mm256_extract_epi16(b.values, 14); |
155 | if (mask & 0x8000) |
156 | tmp_values[15] = _mm256_extract_epi16(b.values, 15); |
157 | return loadu(tmp_values); |
158 | } |
159 | static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a, |
160 | const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask) { |
161 | return _mm256_blendv_epi8(a.values, b.values, mask.values); |
162 | } |
163 | template<typename step_t> |
164 | static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step = static_cast<step_t>(1)) { |
165 | return Vectorized<BFloat16>( |
166 | base, base + step, base + 2 * step, base + 3 * step, |
167 | base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step, |
168 | base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step, |
169 | base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step); |
170 | } |
171 | static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a, |
172 | const Vectorized<BFloat16>& b, int64_t count = size()) { |
173 | switch (count) { |
174 | case 0: |
175 | return a; |
176 | case 1: |
177 | return blend<1>(a, b); |
178 | case 2: |
179 | return blend<3>(a, b); |
180 | case 3: |
181 | return blend<7>(a, b); |
182 | case 4: |
183 | return blend<15>(a, b); |
184 | case 5: |
185 | return blend<31>(a, b); |
186 | case 6: |
187 | return blend<63>(a, b); |
188 | case 7: |
189 | return blend<127>(a, b); |
190 | case 8: |
191 | return blend<255>(a, b); |
192 | case 9: |
193 | return blend<511>(a, b); |
194 | case 10: |
195 | return blend<1023>(a, b); |
196 | case 11: |
197 | return blend<2047>(a, b); |
198 | case 12: |
199 | return blend<4095>(a, b); |
200 | case 13: |
201 | return blend<8191>(a, b); |
202 | case 14: |
203 | return blend<16383>(a, b); |
204 | case 15: |
205 | return blend<32767>(a, b); |
206 | } |
207 | return b; |
208 | } |
209 | Vectorized<BFloat16> map(const __m256 (*const vop)(__m256)) const { |
210 | __m256 lo, hi; |
211 | cvtbf16_fp32(values, lo, hi); |
212 | const auto o1 = vop(lo); |
213 | const auto o2 = vop(hi); |
214 | return cvtfp32_bf16(o1, o2); |
215 | } |
216 | Vectorized<BFloat16> abs() const { |
217 | __m256 lo, hi; |
218 | cvtbf16_fp32(values, lo, hi); |
219 | const auto mask = _mm256_set1_ps(-0.f); |
220 | const auto o1 = _mm256_andnot_ps(mask, lo); |
221 | const auto o2 = _mm256_andnot_ps(mask, hi); |
222 | return cvtfp32_bf16(o1, o2); |
223 | } |
224 | Vectorized<BFloat16> angle() const { |
225 | __m256 lo, hi; |
226 | cvtbf16_fp32(values, lo, hi); |
227 | auto angle_lambda = [](__m256 values) { |
228 | const auto zero_vec = _mm256_set1_ps(0.f); |
229 | const auto nan_vec = _mm256_set1_ps(NAN); |
230 | const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ); |
231 | const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ); |
232 | const auto pi = _mm256_set1_ps(c10::pi<float>); |
233 | |
234 | const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ); |
235 | auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask); |
236 | angle = _mm256_blendv_ps(angle, nan_vec, nan_mask); |
237 | return angle; |
238 | }; |
239 | auto o1 = angle_lambda(lo); |
240 | auto o2 = angle_lambda(hi); |
241 | return cvtfp32_bf16(o1, o2); |
242 | } |
243 | Vectorized<BFloat16> real() const { |
244 | return *this; |
245 | } |
246 | Vectorized<BFloat16> imag() const { |
247 | return _mm256_set1_epi16(0); |
248 | } |
249 | Vectorized<BFloat16> conj() const { |
250 | return *this; |
251 | } |
252 | Vectorized<BFloat16> acos() const { |
253 | return map(Sleef_acosf8_u10); |
254 | } |
255 | Vectorized<BFloat16> asin() const { |
256 | return map(Sleef_asinf8_u10); |
257 | } |
258 | Vectorized<BFloat16> atan() const { |
259 | return map(Sleef_atanf8_u10); |
260 | } |
261 | Vectorized<BFloat16> atan2(const Vectorized<BFloat16> &b) const { |
262 | __m256 lo, hi; |
263 | __m256 b1, b2; |
264 | cvtbf16_fp32(values, lo, hi); |
265 | cvtbf16_fp32(b.values, b1, b2); |
266 | auto o1 = Sleef_atan2f8_u10(lo, b1); |
267 | auto o2 = Sleef_atan2f8_u10(hi, b2); |
268 | return cvtfp32_bf16(o1, o2); |
269 | } |
270 | Vectorized<BFloat16> copysign(const Vectorized<BFloat16> &sign) const { |
271 | // copy sign bit (0x8000) from sign and remaining bits from values |
272 | __m256i mask_value = _mm256_set1_epi32(~0x80008000); |
273 | __m256i mask_signbit = _mm256_set1_epi32(0x80008000); |
274 | return Vectorized<BFloat16>( |
275 | _mm256_or_si256( |
276 | _mm256_and_si256(values, mask_value), |
277 | _mm256_and_si256(sign, mask_signbit))); |
278 | } |
279 | Vectorized<BFloat16> erf() const { |
280 | return map(Sleef_erff8_u10); |
281 | } |
282 | Vectorized<BFloat16> erfc() const { |
283 | return map(Sleef_erfcf8_u15); |
284 | } |
285 | Vectorized<BFloat16> erfinv() const { |
286 | __m256 lo, hi; |
287 | cvtbf16_fp32(values, lo, hi); |
288 | __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; |
289 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); |
290 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); |
291 | for (int64_t i = 0; i < size() / 2; i++) { |
292 | tmp1[i] = calc_erfinv(tmp1[i]); |
293 | tmp2[i] = calc_erfinv(tmp2[i]); |
294 | } |
295 | auto o1 = _mm256_loadu_ps(tmp1); |
296 | auto o2 = _mm256_loadu_ps(tmp2); |
297 | return cvtfp32_bf16(o1, o2); |
298 | } |
299 | Vectorized<BFloat16> exp() const { |
300 | return map(Sleef_expf8_u10); |
301 | } |
302 | Vectorized<BFloat16> exp2() const { |
303 | return map(Sleef_exp2f8_u10); |
304 | } |
305 | Vectorized<BFloat16> expm1() const { |
306 | return map(Sleef_expm1f8_u10); |
307 | } |
308 | Vectorized<BFloat16> fmod(const Vectorized<BFloat16> & q) const { |
309 | __m256 x_lo, x_hi; |
310 | cvtbf16_fp32(values, x_lo, x_hi); |
311 | __m256 q_lo, q_hi; |
312 | cvtbf16_fp32(q.values, q_lo, q_hi); |
313 | auto o1 = Sleef_fmodf8(x_lo, q_lo); |
314 | auto o2 = Sleef_fmodf8(x_hi, q_hi); |
315 | return cvtfp32_bf16(o1, o2); |
316 | } |
317 | Vectorized<BFloat16> hypot(const Vectorized<BFloat16> &b) const { |
318 | __m256 lo, hi; |
319 | __m256 b1, b2; |
320 | cvtbf16_fp32(values, lo, hi); |
321 | cvtbf16_fp32(b.values, b1, b2); |
322 | auto o1 = Sleef_hypotf8_u05(lo, b1); |
323 | auto o2 = Sleef_hypotf8_u05(hi, b2); |
324 | return cvtfp32_bf16(o1, o2); |
325 | } |
326 | Vectorized<BFloat16> i0() const { |
327 | __m256 lo, hi; |
328 | cvtbf16_fp32(values, lo, hi); |
329 | __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; |
330 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); |
331 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); |
332 | for (int64_t i = 0; i < size() / 2; i++) { |
333 | tmp1[i] = calc_i0(tmp1[i]); |
334 | tmp2[i] = calc_i0(tmp2[i]); |
335 | } |
336 | auto o1 = _mm256_loadu_ps(tmp1); |
337 | auto o2 = _mm256_loadu_ps(tmp2); |
338 | return cvtfp32_bf16(o1, o2); |
339 | } |
340 | Vectorized<BFloat16> i0e() const { |
341 | __m256 lo, hi; |
342 | cvtbf16_fp32(values, lo, hi); |
343 | constexpr auto sz = size(); |
344 | __at_align__ float tmp1[sz / 2], tmp2[sz / 2]; |
345 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); |
346 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); |
347 | |
348 | for (auto i = decltype(sz){0}; i < sz / 2; i++) { |
349 | tmp1[i] = calc_i0e(tmp1[i]); |
350 | tmp2[i] = calc_i0e(tmp2[i]); |
351 | } |
352 | const auto o1 = _mm256_loadu_ps(tmp1); |
353 | const auto o2 = _mm256_loadu_ps(tmp2); |
354 | return cvtfp32_bf16(o1, o2); |
355 | } |
356 | Vectorized<BFloat16> igamma(const Vectorized<BFloat16> &x) const { |
357 | __m256 lo, hi; |
358 | __m256 xlo, xhi; |
359 | cvtbf16_fp32(values, lo, hi); |
360 | cvtbf16_fp32(x.values, xlo, xhi); |
361 | __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; |
362 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); |
363 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); |
364 | __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; |
365 | _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo); |
366 | _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi); |
367 | for (int64_t i = 0; i < size() / 2; ++i) { |
368 | tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]); |
369 | tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]); |
370 | } |
371 | auto o1 = _mm256_loadu_ps(tmp1); |
372 | auto o2 = _mm256_loadu_ps(tmp2); |
373 | return cvtfp32_bf16(o1, o2); |
374 | } |
375 | |
376 | Vectorized<BFloat16> igammac(const Vectorized<BFloat16> &x) const { |
377 | __m256 lo, hi; |
378 | __m256 xlo, xhi; |
379 | cvtbf16_fp32(values, lo, hi); |
380 | cvtbf16_fp32(x.values, xlo, xhi); |
381 | __at_align__ float tmp1[size() / 2], tmp2[size() / 2]; |
382 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo); |
383 | _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi); |
384 | __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2]; |
385 | _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo); |
386 | _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi); |
387 | for (int64_t i = 0; i < size() / 2; ++i) { |
388 | tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]); |
389 | tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]); |
390 | } |
391 | auto o1 = _mm256_loadu_ps(tmp1); |
392 | auto o2 = _mm256_loadu_ps(tmp2); |
393 | return cvtfp32_bf16(o1, o2); |
394 | } |
395 | Vectorized<BFloat16> log() const { |
396 | return map(Sleef_logf8_u10); |
397 | } |
398 | Vectorized<BFloat16> log2() const { |
399 | return map(Sleef_log2f8_u10); |
400 | } |
401 | Vectorized<BFloat16> log10() const { |
402 | return map(Sleef_log10f8_u10); |
403 | } |
404 | Vectorized<BFloat16> log1p() const { |
405 | return map(Sleef_log1pf8_u10); |
406 | } |
407 | Vectorized<BFloat16> frac() const; |
408 | Vectorized<BFloat16> sin() const { |
409 | return map(Sleef_sinf8_u10); |
410 | } |
411 | Vectorized<BFloat16> sinh() const { |
412 | return map(Sleef_sinhf8_u10); |
413 | } |
414 | Vectorized<BFloat16> cos() const { |
415 | return map(Sleef_cosf8_u10); |
416 | } |
417 | Vectorized<BFloat16> cosh() const { |
418 | return map(Sleef_coshf8_u10); |
419 | } |
420 | Vectorized<BFloat16> ceil() const { |
421 | __m256 lo, hi; |
422 | cvtbf16_fp32(values, lo, hi); |
423 | auto o1 = _mm256_ceil_ps(lo); |
424 | auto o2 = _mm256_ceil_ps(hi); |
425 | return cvtfp32_bf16(o1, o2); |
426 | } |
427 | Vectorized<BFloat16> floor() const { |
428 | __m256 lo, hi; |
429 | cvtbf16_fp32(values, lo, hi); |
430 | auto o1 = _mm256_floor_ps(lo); |
431 | auto o2 = _mm256_floor_ps(hi); |
432 | return cvtfp32_bf16(o1, o2); |
433 | } |
434 | Vectorized<BFloat16> neg() const { |
435 | __m256 lo, hi; |
436 | cvtbf16_fp32(values, lo, hi); |
437 | auto mask = _mm256_set1_ps(-0.f); |
438 | auto o1 = _mm256_xor_ps(mask, lo); |
439 | auto o2 = _mm256_xor_ps(mask, hi); |
440 | return cvtfp32_bf16(o1, o2); |
441 | } |
442 | Vectorized<BFloat16> round() const { |
443 | __m256 lo, hi; |
444 | cvtbf16_fp32(values, lo, hi); |
445 | auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
446 | auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
447 | return cvtfp32_bf16(o1, o2); |
448 | } |
449 | Vectorized<BFloat16> tan() const { |
450 | return map(Sleef_tanf8_u10); |
451 | } |
452 | Vectorized<BFloat16> tanh() const { |
453 | return map(Sleef_tanhf8_u10); |
454 | } |
455 | Vectorized<BFloat16> trunc() const { |
456 | __m256 lo, hi; |
457 | cvtbf16_fp32(values, lo, hi); |
458 | auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); |
459 | auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); |
460 | return cvtfp32_bf16(o1, o2); |
461 | } |
462 | Vectorized<BFloat16> lgamma() const { |
463 | return map(Sleef_lgammaf8_u10); |
464 | } |
465 | Vectorized<BFloat16> sqrt() const { |
466 | __m256 lo, hi; |
467 | cvtbf16_fp32(values, lo, hi); |
468 | auto o1 = _mm256_sqrt_ps(lo); |
469 | auto o2 = _mm256_sqrt_ps(hi); |
470 | return cvtfp32_bf16(o1, o2); |
471 | } |
472 | Vectorized<BFloat16> reciprocal() const { |
473 | __m256 lo, hi; |
474 | cvtbf16_fp32(values, lo, hi); |
475 | auto ones = _mm256_set1_ps(1); |
476 | auto o1 = _mm256_div_ps(ones, lo); |
477 | auto o2 = _mm256_div_ps(ones, hi); |
478 | return cvtfp32_bf16(o1, o2); |
479 | } |
480 | Vectorized<BFloat16> rsqrt() const { |
481 | __m256 lo, hi; |
482 | cvtbf16_fp32(values, lo, hi); |
483 | auto ones = _mm256_set1_ps(1); |
484 | auto o1 = _mm256_div_ps(ones, _mm256_sqrt_ps(lo)); |
485 | auto o2 = _mm256_div_ps(ones, _mm256_sqrt_ps(hi)); |
486 | return cvtfp32_bf16(o1, o2); |
487 | } |
488 | Vectorized<BFloat16> pow(const Vectorized<BFloat16> &b) const { |
489 | __m256 lo, hi; |
490 | __m256 b1, b2; |
491 | cvtbf16_fp32(values, lo, hi); |
492 | cvtbf16_fp32(b.values, b1, b2); |
493 | auto o1 = Sleef_powf8_u10(lo, b1); |
494 | auto o2 = Sleef_powf8_u10(hi, b2); |
495 | return cvtfp32_bf16(o1, o2); |
496 | } |
497 | |
498 | Vectorized<BFloat16> inline operator>(const Vectorized<BFloat16>& other) const; |
499 | Vectorized<BFloat16> inline operator<(const Vectorized<BFloat16>& other) const; |
500 | Vectorized<BFloat16> inline operator>=(const Vectorized<BFloat16>& other) const; |
501 | Vectorized<BFloat16> inline operator<=(const Vectorized<BFloat16>& other) const; |
502 | Vectorized<BFloat16> inline operator==(const Vectorized<BFloat16>& other) const; |
503 | Vectorized<BFloat16> inline operator!=(const Vectorized<BFloat16>& other) const; |
504 | |
505 | Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const; |
506 | Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const; |
507 | Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const; |
508 | Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const; |
509 | Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const; |
510 | Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const; |
511 | }; |
512 | |
513 | template<typename Op> |
514 | Vectorized<BFloat16> static inline bfloat16_binary_op_as_fp32(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, Op op) { |
515 | __m256 a_lo, a_hi; |
516 | __m256 b_lo, b_hi; |
517 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
518 | cvtbf16_fp32(__m256i(b), b_lo, b_hi); |
519 | auto o1 = op(a_lo, b_lo); |
520 | auto o2 = op(a_hi, b_hi); |
521 | return cvtfp32_bf16(o1, o2); |
522 | } |
523 | |
524 | template<typename Op> |
525 | Vectorized<BFloat16> static inline bfloat16_compare_as_fp32(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, Op op) { |
526 | __m256 a_lo, a_hi; |
527 | __m256 b_lo, b_hi; |
528 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
529 | cvtbf16_fp32(__m256i(b), b_lo, b_hi); |
530 | auto o1 = op(a_lo, b_lo); |
531 | auto o2 = op(a_hi, b_hi); |
532 | return merge_compare_result(o1, o2); |
533 | } |
534 | |
535 | Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const { |
536 | return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) { |
537 | return _mm256_cmp_ps(x, y, _CMP_GT_OQ); |
538 | }); |
539 | } |
540 | Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const { |
541 | return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) { |
542 | return _mm256_cmp_ps(x, y, _CMP_LT_OQ); |
543 | }); |
544 | } |
545 | Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const { |
546 | return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) { |
547 | return _mm256_cmp_ps(x, y, _CMP_GE_OQ); |
548 | }); |
549 | } |
550 | Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const { |
551 | return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) { |
552 | return _mm256_cmp_ps(x, y, _CMP_LE_OQ); |
553 | }); |
554 | } |
555 | Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const { |
556 | return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) { |
557 | return _mm256_cmp_ps(x, y, _CMP_EQ_OQ); |
558 | }); |
559 | } |
560 | Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const { |
561 | return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) { |
562 | return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ); |
563 | }); |
564 | } |
565 | |
566 | Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
567 | return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_add_ps(x, y); }); |
568 | } |
569 | Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
570 | return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_sub_ps(x, y); }); |
571 | } |
572 | Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
573 | return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_mul_ps(x, y); }); |
574 | } |
575 | Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
576 | return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_div_ps(x, y); }); |
577 | } |
578 | |
579 | Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
580 | return _mm256_and_si256(a, b); |
581 | } |
582 | Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
583 | return _mm256_or_si256(a, b); |
584 | } |
585 | Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
586 | return _mm256_xor_si256(a, b); |
587 | } |
588 | |
589 | inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const { |
590 | return (*this == other) & Vectorized<BFloat16>(1.0f); |
591 | } |
592 | |
593 | inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const { |
594 | return (*this != other) & Vectorized<BFloat16>(1.0f); |
595 | } |
596 | |
597 | inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const { |
598 | return (*this > other) & Vectorized<BFloat16>(1.0f); |
599 | } |
600 | |
601 | inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const { |
602 | return (*this >= other) & Vectorized<BFloat16>(1.0f); |
603 | } |
604 | |
605 | inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const { |
606 | return (*this < other) & Vectorized<BFloat16>(1.0f); |
607 | } |
608 | |
609 | inline Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const { |
610 | return (*this <= other) & Vectorized<BFloat16>(1.0f); |
611 | } |
612 | |
613 | // frac. Implement this here so we can use subtraction |
614 | inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const { |
615 | return *this - this->trunc(); |
616 | } |
617 | |
618 | // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if |
619 | // either input is a NaN. |
620 | template <> |
621 | Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
622 | __m256 a_lo, a_hi; |
623 | __m256 b_lo, b_hi; |
624 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
625 | cvtbf16_fp32(__m256i(b), b_lo, b_hi); |
626 | auto max_lo = _mm256_max_ps(a_lo, b_lo); |
627 | auto max_hi = _mm256_max_ps(a_hi, b_hi); |
628 | auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); |
629 | auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); |
630 | // Exploit the fact that all-ones is a NaN. |
631 | auto o1 = _mm256_or_ps(max_lo, nan_lo); |
632 | auto o2 = _mm256_or_ps(max_hi, nan_hi); |
633 | return cvtfp32_bf16(o1, o2); |
634 | } |
635 | |
636 | // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if |
637 | // either input is a NaN. |
638 | template <> |
639 | Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) { |
640 | __m256 a_lo, a_hi; |
641 | __m256 b_lo, b_hi; |
642 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
643 | cvtbf16_fp32(__m256i(b), b_lo, b_hi); |
644 | auto min_lo = _mm256_min_ps(a_lo, b_lo); |
645 | auto min_hi = _mm256_min_ps(a_hi, b_hi); |
646 | auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q); |
647 | auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q); |
648 | // Exploit the fact that all-ones is a NaN. |
649 | auto o1 = _mm256_or_ps(min_lo, nan_lo); |
650 | auto o2 = _mm256_or_ps(min_hi, nan_hi); |
651 | return cvtfp32_bf16(o1, o2); |
652 | } |
653 | |
654 | template <> |
655 | Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a, |
656 | const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) { |
657 | __m256 a_lo, a_hi; |
658 | __m256 min_lo, min_hi; |
659 | __m256 max_lo, max_hi; |
660 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
661 | cvtbf16_fp32(__m256i(min), min_lo, min_hi); |
662 | cvtbf16_fp32(__m256i(max), max_lo, max_hi); |
663 | auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo)); |
664 | auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi)); |
665 | return cvtfp32_bf16(o1, o2); |
666 | } |
667 | |
668 | template <> |
669 | Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) { |
670 | __m256 a_lo, a_hi; |
671 | __m256 max_lo, max_hi; |
672 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
673 | cvtbf16_fp32(__m256i(max), max_lo, max_hi); |
674 | auto o1 = _mm256_min_ps(max_lo, a_lo); |
675 | auto o2 = _mm256_min_ps(max_hi, a_hi); |
676 | return cvtfp32_bf16(o1, o2); |
677 | } |
678 | |
679 | template <> |
680 | Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) { |
681 | __m256 a_lo, a_hi; |
682 | __m256 min_lo, min_hi; |
683 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
684 | cvtbf16_fp32(__m256i(min), min_lo, min_hi); |
685 | auto o1 = _mm256_max_ps(min_lo, a_lo); |
686 | auto o2 = _mm256_max_ps(min_hi, a_hi); |
687 | return cvtfp32_bf16(o1, o2); |
688 | } |
689 | |
690 | template <> |
691 | inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) { |
692 | int64_t i; |
693 | #pragma unroll |
694 | for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) { |
695 | auto vsrc = _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i))); |
696 | _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc); |
697 | } |
698 | #pragma unroll |
699 | for (; i < n; i++) { |
700 | dst[i] = src[i]; |
701 | } |
702 | } |
703 | |
704 | template <> |
705 | inline void convert(const float* src, BFloat16* dst, int64_t n) { |
706 | int64_t i; |
707 | for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) { |
708 | __m256 a = _mm256_loadu_ps(&src[i]); |
709 | __m256 b = _mm256_loadu_ps(&src[i + 8]); |
710 | |
711 | __m256i bf = cvtfp32_bf16(a, b); |
712 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); |
713 | } |
714 | for (; i < n; i++) { |
715 | dst[i] = c10::convert<BFloat16>(src[i]); |
716 | } |
717 | } |
718 | |
719 | template <> |
720 | inline void convert(const double* src, BFloat16* dst, int64_t n) { |
721 | auto load_float = [](const double *src) -> __m256 { |
722 | // Load one float vector from an array of doubles |
723 | __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src)); |
724 | __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4)); |
725 | return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); |
726 | }; |
727 | |
728 | int64_t i; |
729 | for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) { |
730 | __m256 a = load_float(&src[i]); |
731 | __m256 b = load_float(&src[i + 8]); |
732 | |
733 | __m256i bf = cvtfp32_bf16(a, b); |
734 | _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf); |
735 | } |
736 | for (; i < n; i++) { |
737 | dst[i] = c10::convert<BFloat16>(src[i]); |
738 | } |
739 | } |
740 | |
741 | template <> |
742 | Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a, |
743 | const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) { |
744 | __m256 a_lo, a_hi; |
745 | __m256 b_lo, b_hi; |
746 | __m256 c_lo, c_hi; |
747 | cvtbf16_fp32(__m256i(a), a_lo, a_hi); |
748 | cvtbf16_fp32(__m256i(b), b_lo, b_hi); |
749 | cvtbf16_fp32(__m256i(c), c_lo, c_hi); |
750 | auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo); |
751 | auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi); |
752 | return cvtfp32_bf16(o1, o2); |
753 | } |
754 | |
755 | inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) { |
756 | __m256 o1, o2; |
757 | cvtbf16_fp32(__m256i(a), o1, o2); |
758 | return std::make_tuple(o1, o2); |
759 | } |
760 | |
761 | inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) { |
762 | return cvtfp32_bf16(__m256(a), __m256(b)); |
763 | } |
764 | |
765 | |
766 | #else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
767 | |
768 | inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) { |
769 | constexpr int64_t K = Vectorized<BFloat16>::size(); |
770 | __at_align__ float arr[K]; |
771 | __at_align__ BFloat16 arr2[K]; |
772 | a.store(arr2); |
773 | convert(arr2, arr, K); |
774 | return std::make_tuple( |
775 | Vectorized<float>::loadu(arr), |
776 | Vectorized<float>::loadu(arr + Vectorized<float>::size())); |
777 | } |
778 | |
779 | inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) { |
780 | constexpr int64_t K = Vectorized<BFloat16>::size(); |
781 | __at_align__ float arr[K]; |
782 | __at_align__ BFloat16 arr2[K]; |
783 | a.store(arr); |
784 | b.store(arr + Vectorized<float>::size()); |
785 | convert(arr, arr2, K); |
786 | return Vectorized<BFloat16>::loadu(arr2); |
787 | } |
788 | |
789 | #endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
790 | |
791 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
792 | inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) { |
793 | auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data)); |
794 | __m256 out_values; |
795 | cvtbf16_fp32(values, out_values); |
796 | out = out_values; |
797 | } |
798 | |
799 | inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) { |
800 | auto vec = Vectorized<c10::BFloat16>::loadu(data); |
801 | __m256 out1_values, out2_values; |
802 | cvtbf16_fp32(vec, out1_values, out2_values); |
803 | out1 = out1_values; |
804 | out2 = out2_values; |
805 | } |
806 | #else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
807 | inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) { |
808 | __at_align__ float values[Vectorized<float>::size()]; |
809 | for (const auto k : c10::irange(Vectorized<float>::size())) { |
810 | values[k] = data[k]; |
811 | } |
812 | out = Vectorized<float>::loadu(values); |
813 | } |
814 | |
815 | inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) { |
816 | load_fp32_from_bf16(data, out1); |
817 | data += Vectorized<float>::size(); |
818 | load_fp32_from_bf16(data, out2); |
819 | } |
820 | #endif |
821 | |
822 | }}} |
823 | |
824 | #pragma GCC diagnostic pop |
825 | |