#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h>
#endif

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-qualifiers"

namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

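// Widen 8 bf16 values to 8 fp32 values: zero-extend each 16-bit element to
// 32 bits, then shift it into the high half of the fp32 bit pattern (bf16 is
// simply the upper 16 bits of an IEEE-754 fp32).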
static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
  o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a), 16));
}

static inline void cvtbf16_fp32(const __m256i& a, __m256& o1, __m256& o2) {
  __m128i lo = _mm256_extractf128_si256(a, 0);
  __m128i hi = _mm256_extractf128_si256(a, 1);
  cvtbf16_fp32(lo, o1);
  cvtbf16_fp32(hi, o2);
}
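// Narrow two fp32 vectors to one bf16 vector with round-to-nearest-even,
// forcing NaN lanes to the 0xffff payload. A scalar sketch of the same
// rounding rule (illustrative only, not part of this header):
//
//   uint16_t fp32_to_bf16_rne(float f) {
//     uint32_t u;
//     std::memcpy(&u, &f, sizeof(u));
//     if (f != f) return 0xffff;      // NaN
//     uint32_t lsb = (u >> 16) & 1;   // parity of the lowest bit we keep
//     u += 0x7fff + lsb;              // round half to even
//     return static_cast<uint16_t>(u >> 16);
//   }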
static inline __m256i cvtfp32_bf16(const __m256& a, const __m256& b) {
  __m256i lo = _mm256_castps_si256(a);
  __m256i hi = _mm256_castps_si256(b);
  __m256i nan = _mm256_set1_epi32(0xffff);
  __m256i mask_lo = _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q));
  __m256i mask_hi = _mm256_castps_si256(_mm256_cmp_ps(b, b, _CMP_ORD_Q));
  __m256i ones = _mm256_set1_epi32(0x1);
  __m256i vec_bias = _mm256_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_lo = _mm256_and_si256(_mm256_srli_epi32(lo, 16), ones);
  auto t_hi = _mm256_and_si256(_mm256_srli_epi32(hi, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_lo = _mm256_add_epi32(t_lo, vec_bias);
  t_hi = _mm256_add_epi32(t_hi, vec_bias);
  // input += rounding_bias;
  t_lo = _mm256_add_epi32(t_lo, lo);
  t_hi = _mm256_add_epi32(t_hi, hi);
  // input = input >> 16;
  t_lo = _mm256_srli_epi32(t_lo, 16);
  t_hi = _mm256_srli_epi32(t_hi, 16);
  // Check NaN before converting back to bf16
  t_lo = _mm256_blendv_epi8(nan, t_lo, mask_lo);
  t_hi = _mm256_blendv_epi8(nan, t_hi, mask_hi);

  t_lo = _mm256_packus_epi32(t_lo, t_hi); // t_hi[4-7] t_lo[4-7] t_hi[0-3] t_lo[0-3]
  return _mm256_permute4x64_epi64(t_lo, 0xd8); // 11 01 10 00
}

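// Pack two fp32 compare results into one bf16 mask vector. Compare masks are
// all-ones or all-zeros per 32-bit lane, so truncating to the upper 16 bits
// (no rounding) preserves the mask exactly.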
static inline __m256i merge_compare_result(const __m256& a, const __m256& b) {
  __m256i lo = _mm256_castps_si256(a);
  __m256i hi = _mm256_castps_si256(b);
  lo = _mm256_srli_epi32(lo, 16);
  hi = _mm256_srli_epi32(hi, 16);
  auto out = _mm256_packus_epi32(lo, hi);
  return _mm256_permute4x64_epi64(out, 0xd8);
}

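// Holds 16 bf16 values in one __m256i. Most operations below up-convert to
// two __m256 fp32 vectors, compute in fp32, and narrow the result back to
// bf16, trading conversion cost for fp32 accuracy and the fp32 Sleef kernels.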
template <> class Vectorized<BFloat16> {
private:
  __m256i values;
public:
  using value_type = uint16_t;
  using size_type = int;
  static constexpr size_type size() {
    return 16;
  }
  Vectorized() {}
  Vectorized(__m256i v) : values(v) {}
  Vectorized(BFloat16 val) {
    value_type uw = val.x;
    values = _mm256_set1_epi16(uw);
  }
  Vectorized(BFloat16 val1, BFloat16 val2, BFloat16 val3, BFloat16 val4,
         BFloat16 val5, BFloat16 val6, BFloat16 val7, BFloat16 val8,
         BFloat16 val9, BFloat16 val10, BFloat16 val11, BFloat16 val12,
         BFloat16 val13, BFloat16 val14, BFloat16 val15, BFloat16 val16) {
    values = _mm256_setr_epi16(
        val1.x, val2.x, val3.x, val4.x, val5.x, val6.x, val7.x, val8.x,
        val9.x, val10.x, val11.x, val12.x, val13.x, val14.x, val15.x, val16.x);
  }
  operator __m256i() const {
    return values;
  }
  BFloat16& operator[](int idx) = delete;
  const BFloat16& operator[](int idx) const = delete;
  int zero_mask() const {
    // Returns an integer mask where all zero elements are translated to 1-bit
    // and others are translated to 0-bit. Note that movemask operates per
    // byte, so each 16-bit element contributes two mask bits.
    __m256i cmp = _mm256_cmpeq_epi16(values, _mm256_set1_epi16(0));
    return _mm256_movemask_epi8(cmp);
  }
  static Vectorized<BFloat16> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vectorized<BFloat16> loadu(const void* ptr, int16_t count) {
    __at_align__ int16_t tmp_values[size()];
    // Zero-fill first so lanes beyond `count` hold a defined value rather
    // than uninitialized stack contents.
    std::memset(tmp_values, 0, sizeof(tmp_values));
    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align__ int16_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
    }
  }
  template <int64_t mask>
  static Vectorized<BFloat16> blend(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
    __at_align__ int16_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
    return loadu(tmp_values);
  }
  static Vectorized<BFloat16> blendv(const Vectorized<BFloat16>& a,
      const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  template<typename step_t>
  static Vectorized<BFloat16> arange(BFloat16 base = 0.f, step_t step = static_cast<step_t>(1)) {
    return Vectorized<BFloat16>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
  }
  static Vectorized<BFloat16> set(const Vectorized<BFloat16>& a,
      const Vectorized<BFloat16>& b, int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
    }
    return b;
  }
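  // Apply an fp32 vector function (e.g. a Sleef kernel) to both halves and
  // narrow the results back to bf16; used by most of the math ops below.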
  Vectorized<BFloat16> map(const __m256 (*const vop)(__m256)) const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    const auto o1 = vop(lo);
    const auto o2 = vop(hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> abs() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    const auto mask = _mm256_set1_ps(-0.f);
    const auto o1 = _mm256_andnot_ps(mask, lo);
    const auto o2 = _mm256_andnot_ps(mask, hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> angle() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto angle_lambda = [](__m256 values) {
      const auto zero_vec = _mm256_set1_ps(0.f);
      const auto nan_vec = _mm256_set1_ps(NAN);
      const auto not_nan_mask = _mm256_cmp_ps(values, values, _CMP_EQ_OQ);
      const auto nan_mask = _mm256_cmp_ps(not_nan_mask, zero_vec, _CMP_EQ_OQ);
      const auto pi = _mm256_set1_ps(c10::pi<float>);

      const auto neg_mask = _mm256_cmp_ps(values, zero_vec, _CMP_LT_OQ);
      auto angle = _mm256_blendv_ps(zero_vec, pi, neg_mask);
      angle = _mm256_blendv_ps(angle, nan_vec, nan_mask);
      return angle;
    };
    auto o1 = angle_lambda(lo);
    auto o2 = angle_lambda(hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> real() const {
    return *this;
  }
  Vectorized<BFloat16> imag() const {
    return _mm256_set1_epi16(0);
  }
  Vectorized<BFloat16> conj() const {
    return *this;
  }
  Vectorized<BFloat16> acos() const {
    return map(Sleef_acosf8_u10);
  }
  Vectorized<BFloat16> asin() const {
    return map(Sleef_asinf8_u10);
  }
  Vectorized<BFloat16> atan() const {
    return map(Sleef_atanf8_u10);
  }
  Vectorized<BFloat16> atan2(const Vectorized<BFloat16> &b) const {
    __m256 lo, hi;
    __m256 b1, b2;
    cvtbf16_fp32(values, lo, hi);
    cvtbf16_fp32(b.values, b1, b2);
    auto o1 = Sleef_atan2f8_u10(lo, b1);
    auto o2 = Sleef_atan2f8_u10(hi, b2);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> copysign(const Vectorized<BFloat16> &sign) const {
    // copy sign bit (0x8000) from sign and remaining bits from values
    __m256i mask_value = _mm256_set1_epi32(~0x80008000);
    __m256i mask_signbit = _mm256_set1_epi32(0x80008000);
    return Vectorized<BFloat16>(
      _mm256_or_si256(
        _mm256_and_si256(values, mask_value),
        _mm256_and_si256(sign, mask_signbit)));
  }
  Vectorized<BFloat16> erf() const {
    return map(Sleef_erff8_u10);
  }
  Vectorized<BFloat16> erfc() const {
    return map(Sleef_erfcf8_u15);
  }
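  // The ops below have no Sleef kernel: spill the fp32 halves to aligned
  // scratch arrays, apply the scalar calc_* helpers element-wise, and reload
  // before narrowing back to bf16.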
  Vectorized<BFloat16> erfinv() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_erfinv(tmp1[i]);
      tmp2[i] = calc_erfinv(tmp2[i]);
    }
    auto o1 = _mm256_loadu_ps(tmp1);
    auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> exp() const {
    return map(Sleef_expf8_u10);
  }
  Vectorized<BFloat16> exp2() const {
    return map(Sleef_exp2f8_u10);
  }
  Vectorized<BFloat16> expm1() const {
    return map(Sleef_expm1f8_u10);
  }
  Vectorized<BFloat16> fmod(const Vectorized<BFloat16> & q) const {
    __m256 x_lo, x_hi;
    cvtbf16_fp32(values, x_lo, x_hi);
    __m256 q_lo, q_hi;
    cvtbf16_fp32(q.values, q_lo, q_hi);
    auto o1 = Sleef_fmodf8(x_lo, q_lo);
    auto o2 = Sleef_fmodf8(x_hi, q_hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> hypot(const Vectorized<BFloat16> &b) const {
    __m256 lo, hi;
    __m256 b1, b2;
    cvtbf16_fp32(values, lo, hi);
    cvtbf16_fp32(b.values, b1, b2);
    auto o1 = Sleef_hypotf8_u05(lo, b1);
    auto o2 = Sleef_hypotf8_u05(hi, b2);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> i0() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    for (int64_t i = 0; i < size() / 2; i++) {
      tmp1[i] = calc_i0(tmp1[i]);
      tmp2[i] = calc_i0(tmp2[i]);
    }
    auto o1 = _mm256_loadu_ps(tmp1);
    auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> i0e() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    constexpr auto sz = size();
    __at_align__ float tmp1[sz / 2], tmp2[sz / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);

    for (auto i = decltype(sz){0}; i < sz / 2; i++) {
      tmp1[i] = calc_i0e(tmp1[i]);
      tmp2[i] = calc_i0e(tmp2[i]);
    }
    const auto o1 = _mm256_loadu_ps(tmp1);
    const auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> igamma(const Vectorized<BFloat16> &x) const {
    __m256 lo, hi;
    __m256 xlo, xhi;
    cvtbf16_fp32(values, lo, hi);
    cvtbf16_fp32(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm256_loadu_ps(tmp1);
    auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }

  Vectorized<BFloat16> igammac(const Vectorized<BFloat16> &x) const {
    __m256 lo, hi;
    __m256 xlo, xhi;
    cvtbf16_fp32(values, lo, hi);
    cvtbf16_fp32(x.values, xlo, xhi);
    __at_align__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igammac(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igammac(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm256_loadu_ps(tmp1);
    auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> log() const {
    return map(Sleef_logf8_u10);
  }
  Vectorized<BFloat16> log2() const {
    return map(Sleef_log2f8_u10);
  }
  Vectorized<BFloat16> log10() const {
    return map(Sleef_log10f8_u10);
  }
  Vectorized<BFloat16> log1p() const {
    return map(Sleef_log1pf8_u10);
  }
  Vectorized<BFloat16> frac() const;
  Vectorized<BFloat16> sin() const {
    return map(Sleef_sinf8_u10);
  }
  Vectorized<BFloat16> sinh() const {
    return map(Sleef_sinhf8_u10);
  }
  Vectorized<BFloat16> cos() const {
    return map(Sleef_cosf8_u10);
  }
  Vectorized<BFloat16> cosh() const {
    return map(Sleef_coshf8_u10);
  }
  Vectorized<BFloat16> ceil() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto o1 = _mm256_ceil_ps(lo);
    auto o2 = _mm256_ceil_ps(hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> floor() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto o1 = _mm256_floor_ps(lo);
    auto o2 = _mm256_floor_ps(hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> neg() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto mask = _mm256_set1_ps(-0.f);
    auto o1 = _mm256_xor_ps(mask, lo);
    auto o2 = _mm256_xor_ps(mask, hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> round() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> tan() const {
    return map(Sleef_tanf8_u10);
  }
  Vectorized<BFloat16> tanh() const {
    return map(Sleef_tanhf8_u10);
  }
  Vectorized<BFloat16> trunc() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto o1 = _mm256_round_ps(lo, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    auto o2 = _mm256_round_ps(hi, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> lgamma() const {
    return map(Sleef_lgammaf8_u10);
  }
  Vectorized<BFloat16> sqrt() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto o1 = _mm256_sqrt_ps(lo);
    auto o2 = _mm256_sqrt_ps(hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> reciprocal() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto ones = _mm256_set1_ps(1);
    auto o1 = _mm256_div_ps(ones, lo);
    auto o2 = _mm256_div_ps(ones, hi);
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> rsqrt() const {
    __m256 lo, hi;
    cvtbf16_fp32(values, lo, hi);
    auto ones = _mm256_set1_ps(1);
    auto o1 = _mm256_div_ps(ones, _mm256_sqrt_ps(lo));
    auto o2 = _mm256_div_ps(ones, _mm256_sqrt_ps(hi));
    return cvtfp32_bf16(o1, o2);
  }
  Vectorized<BFloat16> pow(const Vectorized<BFloat16> &b) const {
    __m256 lo, hi;
    __m256 b1, b2;
    cvtbf16_fp32(values, lo, hi);
    cvtbf16_fp32(b.values, b1, b2);
    auto o1 = Sleef_powf8_u10(lo, b1);
    auto o2 = Sleef_powf8_u10(hi, b2);
    return cvtfp32_bf16(o1, o2);
  }

  Vectorized<BFloat16> inline operator>(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> inline operator<(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> inline operator>=(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> inline operator<=(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> inline operator==(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> inline operator!=(const Vectorized<BFloat16>& other) const;

  Vectorized<BFloat16> eq(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ne(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> gt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> ge(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> lt(const Vectorized<BFloat16>& other) const;
  Vectorized<BFloat16> le(const Vectorized<BFloat16>& other) const;
};

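// Lift an fp32 binary lambda to bf16: up-convert both operands, apply `op`
// to the low and high fp32 halves, then narrow back to bf16 (rounded for
// arithmetic, truncated via merge_compare_result for comparisons).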
template<typename Op>
Vectorized<BFloat16> static inline bfloat16_binary_op_as_fp32(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, Op op) {
  __m256 a_lo, a_hi;
  __m256 b_lo, b_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
  auto o1 = op(a_lo, b_lo);
  auto o2 = op(a_hi, b_hi);
  return cvtfp32_bf16(o1, o2);
}

template<typename Op>
Vectorized<BFloat16> static inline bfloat16_compare_as_fp32(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b, Op op) {
  __m256 a_lo, a_hi;
  __m256 b_lo, b_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
  auto o1 = op(a_lo, b_lo);
  auto o2 = op(a_hi, b_hi);
  return merge_compare_result(o1, o2);
}

Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>(const Vectorized<BFloat16>& other) const {
  return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) {
    return _mm256_cmp_ps(x, y, _CMP_GT_OQ);
  });
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<(const Vectorized<BFloat16>& other) const {
  return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) {
    return _mm256_cmp_ps(x, y, _CMP_LT_OQ);
  });
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator>=(const Vectorized<BFloat16>& other) const {
  return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) {
    return _mm256_cmp_ps(x, y, _CMP_GE_OQ);
  });
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator<=(const Vectorized<BFloat16>& other) const {
  return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) {
    return _mm256_cmp_ps(x, y, _CMP_LE_OQ);
  });
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator==(const Vectorized<BFloat16>& other) const {
  return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) {
    return _mm256_cmp_ps(x, y, _CMP_EQ_OQ);
  });
}
Vectorized<BFloat16> inline Vectorized<BFloat16>::operator!=(const Vectorized<BFloat16>& other) const {
  return bfloat16_compare_as_fp32(*this, other, [](__m256 x, __m256 y) {
    return _mm256_cmp_ps(x, y, _CMP_NEQ_UQ);
  });
}

Vectorized<BFloat16> inline operator+(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_add_ps(x, y); });
}
Vectorized<BFloat16> inline operator-(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_sub_ps(x, y); });
}
Vectorized<BFloat16> inline operator*(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_mul_ps(x, y); });
}
Vectorized<BFloat16> inline operator/(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return bfloat16_binary_op_as_fp32(a, b, [](const __m256& x, const __m256& y) { return _mm256_div_ps(x, y); });
}

Vectorized<BFloat16> inline operator&(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return _mm256_and_si256(a, b);
}
Vectorized<BFloat16> inline operator|(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return _mm256_or_si256(a, b);
}
Vectorized<BFloat16> inline operator^(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  return _mm256_xor_si256(a, b);
}

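// A comparison yields an all-ones mask per lane; ANDing that mask with the
// bit pattern of bf16 1.0f turns it into a numeric 0.0/1.0 result.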
inline Vectorized<BFloat16> Vectorized<BFloat16>::eq(const Vectorized<BFloat16>& other) const {
  return (*this == other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::ne(const Vectorized<BFloat16>& other) const {
  return (*this != other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::gt(const Vectorized<BFloat16>& other) const {
  return (*this > other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::ge(const Vectorized<BFloat16>& other) const {
  return (*this >= other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::lt(const Vectorized<BFloat16>& other) const {
  return (*this < other) & Vectorized<BFloat16>(1.0f);
}

inline Vectorized<BFloat16> Vectorized<BFloat16>::le(const Vectorized<BFloat16>& other) const {
  return (*this <= other) & Vectorized<BFloat16>(1.0f);
}

// frac. Implement this here so we can use subtraction
inline Vectorized<BFloat16> Vectorized<BFloat16>::frac() const {
  return *this - this->trunc();
}

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline maximum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  __m256 a_lo, a_hi;
  __m256 b_lo, b_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
  auto max_lo = _mm256_max_ps(a_lo, b_lo);
  auto max_hi = _mm256_max_ps(a_hi, b_hi);
  auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm256_or_ps(max_lo, nan_lo);
  auto o2 = _mm256_or_ps(max_hi, nan_hi);
  return cvtfp32_bf16(o1, o2);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<BFloat16> inline minimum(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& b) {
  __m256 a_lo, a_hi;
  __m256 b_lo, b_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
  auto min_lo = _mm256_min_ps(a_lo, b_lo);
  auto min_hi = _mm256_min_ps(a_hi, b_hi);
  auto nan_lo = _mm256_cmp_ps(a_lo, b_lo, _CMP_UNORD_Q);
  auto nan_hi = _mm256_cmp_ps(a_hi, b_hi, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  auto o1 = _mm256_or_ps(min_lo, nan_lo);
  auto o2 = _mm256_or_ps(min_hi, nan_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp(const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& min, const Vectorized<BFloat16>& max) {
  __m256 a_lo, a_hi;
  __m256 min_lo, min_hi;
  __m256 max_lo, max_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(min), min_lo, min_hi);
  cvtbf16_fp32(__m256i(max), max_lo, max_hi);
  auto o1 = _mm256_min_ps(max_lo, _mm256_max_ps(min_lo, a_lo));
  auto o2 = _mm256_min_ps(max_hi, _mm256_max_ps(min_hi, a_hi));
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp_max(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& max) {
  __m256 a_lo, a_hi;
  __m256 max_lo, max_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(max), max_lo, max_hi);
  auto o1 = _mm256_min_ps(max_lo, a_lo);
  auto o2 = _mm256_min_ps(max_hi, a_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
Vectorized<BFloat16> inline clamp_min(const Vectorized<BFloat16>& a, const Vectorized<BFloat16>& min) {
  __m256 a_lo, a_hi;
  __m256 min_lo, min_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(min), min_lo, min_hi);
  auto o1 = _mm256_max_ps(min_lo, a_lo);
  auto o2 = _mm256_max_ps(min_hi, a_hi);
  return cvtfp32_bf16(o1, o2);
}

template <>
inline void convert(const BFloat16* src, BFloat16* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<BFloat16>::size()); i += Vectorized<BFloat16>::size()) {
    auto vsrc = _mm256_loadu_si256(reinterpret_cast<__m256i*>((void*)(src + i)));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>((void*)(dst + i)), vsrc);
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

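// fp32 -> bf16: each 16-element bf16 store consumes two 8-float fp32 loads;
// the scalar tail converts element-wise through c10::convert.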
template <>
inline void convert(const float* src, BFloat16* dst, int64_t n) {
  int64_t i;
  for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
    __m256 a = _mm256_loadu_ps(&src[i]);
    __m256 b = _mm256_loadu_ps(&src[i + 8]);

    __m256i bf = cvtfp32_bf16(a, b);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<BFloat16>(src[i]);
  }
}

template <>
inline void convert(const double* src, BFloat16* dst, int64_t n) {
  auto load_float = [](const double *src) -> __m256 {
    // Load one float vector from an array of doubles
    __m128 a = _mm256_cvtpd_ps(_mm256_loadu_pd(src));
    __m128 b = _mm256_cvtpd_ps(_mm256_loadu_pd(src + 4));
    return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1);
  };

  int64_t i;
  for (i = 0; i + Vectorized<BFloat16>::size() <= n; i += Vectorized<BFloat16>::size()) {
    __m256 a = load_float(&src[i]);
    __m256 b = load_float(&src[i + 8]);

    __m256i bf = cvtfp32_bf16(a, b);
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(&dst[i]), bf);
  }
  for (; i < n; i++) {
    dst[i] = c10::convert<BFloat16>(src[i]);
  }
}

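// FMA on the fp32 halves: a * b + c is rounded once in fp32 by
// _mm256_fmadd_ps and once more when the result is narrowed to bf16.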
template <>
Vectorized<BFloat16> inline fmadd(const Vectorized<BFloat16>& a,
    const Vectorized<BFloat16>& b, const Vectorized<BFloat16>& c) {
  __m256 a_lo, a_hi;
  __m256 b_lo, b_hi;
  __m256 c_lo, c_hi;
  cvtbf16_fp32(__m256i(a), a_lo, a_hi);
  cvtbf16_fp32(__m256i(b), b_lo, b_hi);
  cvtbf16_fp32(__m256i(c), c_lo, c_hi);
  auto o1 = _mm256_fmadd_ps(a_lo, b_lo, c_lo);
  auto o2 = _mm256_fmadd_ps(a_hi, b_hi, c_hi);
  return cvtfp32_bf16(o1, o2);
}

inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
  __m256 o1, o2;
  cvtbf16_fp32(__m256i(a), o1, o2);
  return std::make_tuple(o1, o2);
}

inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
  return cvtfp32_bf16(__m256(a), __m256(b));
}

#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

inline std::tuple<Vectorized<float>, Vectorized<float>> convert_bfloat16_float(const Vectorized<BFloat16>& a) {
  constexpr int64_t K = Vectorized<BFloat16>::size();
  __at_align__ float arr[K];
  __at_align__ BFloat16 arr2[K];
  a.store(arr2);
  convert(arr2, arr, K);
  return std::make_tuple(
      Vectorized<float>::loadu(arr),
      Vectorized<float>::loadu(arr + Vectorized<float>::size()));
}

inline Vectorized<BFloat16> convert_float_bfloat16(const Vectorized<float>& a, const Vectorized<float>& b) {
  constexpr int64_t K = Vectorized<BFloat16>::size();
  __at_align__ float arr[K];
  __at_align__ BFloat16 arr2[K];
  a.store(arr);
  b.store(arr + Vectorized<float>::size());
  convert(arr, arr2, K);
  return Vectorized<BFloat16>::loadu(arr2);
}

#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

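// A minimal round-trip sketch of the helpers above (illustrative only;
// scale_bf16 is not part of this header and assumes 16 elements at `data`):
//
//   void scale_bf16(BFloat16* data, float s) {
//     auto v = Vectorized<BFloat16>::loadu(data);
//     Vectorized<float> lo, hi;
//     std::tie(lo, hi) = convert_bfloat16_float(v);
//     lo = lo * Vectorized<float>(s);
//     hi = hi * Vectorized<float>(s);
//     convert_float_bfloat16(lo, hi).store(data);
//   }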
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
  auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data));
  __m256 out_values;
  cvtbf16_fp32(values, out_values);
  out = out_values;
}

inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
  auto vec = Vectorized<c10::BFloat16>::loadu(data);
  __m256 out1_values, out2_values;
  cvtbf16_fp32(vec, out1_values, out2_values);
  out1 = out1_values;
  out2 = out2_values;
}
#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out) {
  __at_align__ float values[Vectorized<float>::size()];
  for (const auto k : c10::irange(Vectorized<float>::size())) {
    values[k] = data[k];
  }
  out = Vectorized<float>::loadu(values);
}

inline void load_fp32_from_bf16(const c10::BFloat16 *data, Vectorized<float>& out1, Vectorized<float>& out2) {
  load_fp32_from_bf16(data, out1);
  data += Vectorized<float>::size();
  load_fp32_from_bf16(data, out2);
}
#endif

} // namespace CPU_CAPABILITY
} // namespace vec
} // namespace at

#pragma GCC diagnostic pop