1 | #pragma once |
2 | |
3 | // DO NOT DEFINE STATIC DATA IN THIS HEADER! |
4 | // See Note [Do not compile initializers with AVX] |
5 | |
6 | #include <ATen/cpu/vec/intrinsics.h> |
7 | #include <ATen/cpu/vec/vec_base.h> |
#include <c10/util/irange.h>
#include <cstring>  // std::memcpy
9 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
10 | #include <sleef.h> |
11 | #endif |
12 | |
13 | namespace at { |
14 | namespace vec { |
15 | // See Note [CPU_CAPABILITY namespace] |
16 | inline namespace CPU_CAPABILITY { |
17 | |
18 | |
19 | #if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) |
20 | |
21 | template <> class Vectorized<double> { |
22 | private: |
23 | __m256d values; |
24 | public: |
25 | using value_type = double; |
26 | using size_type = int; |
27 | static constexpr size_type size() { |
28 | return 4; |
29 | } |
30 | Vectorized() {} |
31 | Vectorized(__m256d v) : values(v) {} |
32 | Vectorized(double val) { |
33 | values = _mm256_set1_pd(val); |
34 | } |
35 | Vectorized(double val1, double val2, double val3, double val4) { |
36 | values = _mm256_setr_pd(val1, val2, val3, val4); |
37 | } |
38 | operator __m256d() const { |
39 | return values; |
40 | } |
41 | template <int64_t mask> |
42 | static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) { |
43 | return _mm256_blend_pd(a.values, b.values, mask); |
44 | } |
45 | static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b, |
46 | const Vectorized<double>& mask) { |
47 | return _mm256_blendv_pd(a.values, b.values, mask.values); |
48 | } |
49 | template<typename step_t> |
50 | static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) { |
51 | return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step); |
52 | } |
53 | static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b, |
54 | int64_t count = size()) { |
55 | switch (count) { |
56 | case 0: |
57 | return a; |
58 | case 1: |
59 | return blend<1>(a, b); |
60 | case 2: |
61 | return blend<3>(a, b); |
62 | case 3: |
63 | return blend<7>(a, b); |
64 | } |
65 | return b; |
66 | } |
67 | static Vectorized<double> loadu(const void* ptr, int64_t count = size()) { |
68 | if (count == size()) |
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));

    __at_align__ double tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do not
    // initialize arrays to zero using "={0}" because gcc would compile it to two
    // instructions, while a loop would be compiled to one instruction.
76 | for (const auto i : c10::irange(size())) { |
77 | tmp_values[i] = 0.0; |
78 | } |
79 | std::memcpy( |
80 | tmp_values, |
81 | reinterpret_cast<const double*>(ptr), |
82 | count * sizeof(double)); |
83 | return _mm256_load_pd(tmp_values); |
84 | } |
85 | void store(void* ptr, int count = size()) const { |
86 | if (count == size()) { |
87 | _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values); |
88 | } else if (count > 0) { |
89 | double tmp_values[size()]; |
90 | _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values); |
91 | std::memcpy(ptr, tmp_values, count * sizeof(double)); |
92 | } |
93 | } |
94 | const double& operator[](int idx) const = delete; |
95 | double& operator[](int idx) = delete; |
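  // Illustrative: for {0.0, 1.0, 0.0, 3.0}, zero_mask() returns 0b0101
  // (bits 0 and 2 set).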
96 | int zero_mask() const { |
    // Returns a bitmask with bit i set if element i is zero, and cleared otherwise.
98 | __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ); |
99 | return _mm256_movemask_pd(cmp); |
100 | } |
101 | Vectorized<double> isnan() const { |
102 | return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q); |
103 | } |
104 | Vectorized<double> map(double (*const f)(double)) const { |
105 | __at_align__ double tmp[size()]; |
106 | store(tmp); |
107 | for (const auto i : c10::irange(size())) { |
108 | tmp[i] = f(tmp[i]); |
109 | } |
110 | return loadu(tmp); |
111 | } |
112 | Vectorized<double> abs() const { |
    auto mask = _mm256_set1_pd(-0.0);  // only the sign bit is set
    return _mm256_andnot_pd(mask, values);  // clear the sign bit -> |x|
115 | } |
116 | Vectorized<double> angle() const { |
    const auto zero_vec = _mm256_set1_pd(0.0);
    const auto nan_vec = _mm256_set1_pd(NAN);
    // values == values is false (all-zero) exactly in NaN lanes, so comparing that
    // mask against zero produces an all-ones mask in the NaN lanes.
    const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
    const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
121 | const auto pi = _mm256_set1_pd(c10::pi<double>); |
122 | |
123 | const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ); |
124 | auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask); |
125 | angle = _mm256_blendv_pd(angle, nan_vec, nan_mask); |
126 | return angle; |
127 | } |
128 | Vectorized<double> real() const { |
129 | return *this; |
130 | } |
131 | Vectorized<double> imag() const { |
132 | return _mm256_set1_pd(0); |
133 | } |
134 | Vectorized<double> conj() const { |
135 | return *this; |
136 | } |
137 | Vectorized<double> acos() const { |
138 | return Vectorized<double>(Sleef_acosd4_u10(values)); |
139 | } |
140 | Vectorized<double> asin() const { |
141 | return Vectorized<double>(Sleef_asind4_u10(values)); |
142 | } |
143 | Vectorized<double> atan() const { |
144 | return Vectorized<double>(Sleef_atand4_u10(values)); |
145 | } |
146 | Vectorized<double> atan2(const Vectorized<double> &b) const { |
147 | return Vectorized<double>(Sleef_atan2d4_u10(values, b)); |
148 | } |
149 | Vectorized<double> copysign(const Vectorized<double> &sign) const { |
150 | return Vectorized<double>(Sleef_copysignd4(values, sign)); |
151 | } |
152 | Vectorized<double> erf() const { |
153 | return Vectorized<double>(Sleef_erfd4_u10(values)); |
154 | } |
155 | Vectorized<double> erfc() const { |
156 | return Vectorized<double>(Sleef_erfcd4_u15(values)); |
157 | } |
158 | Vectorized<double> erfinv() const { |
159 | return map(calc_erfinv); |
160 | } |
161 | Vectorized<double> exp() const { |
162 | return Vectorized<double>(Sleef_expd4_u10(values)); |
163 | } |
164 | Vectorized<double> exp2() const { |
165 | return Vectorized<double>(Sleef_exp2d4_u10(values)); |
166 | } |
167 | Vectorized<double> expm1() const { |
168 | return Vectorized<double>(Sleef_expm1d4_u10(values)); |
169 | } |
170 | Vectorized<double> fmod(const Vectorized<double>& q) const { |
171 | return Vectorized<double>(Sleef_fmodd4(values, q)); |
172 | } |
173 | Vectorized<double> hypot(const Vectorized<double> &b) const { |
174 | return Vectorized<double>(Sleef_hypotd4_u05(values, b)); |
175 | } |
176 | Vectorized<double> i0() const { |
177 | return map(calc_i0); |
178 | } |
179 | Vectorized<double> i0e() const { |
180 | return map(calc_i0e); |
181 | } |
182 | Vectorized<double> igamma(const Vectorized<double> &x) const { |
183 | __at_align__ double tmp[size()]; |
184 | __at_align__ double tmp_x[size()]; |
185 | store(tmp); |
186 | x.store(tmp_x); |
187 | for (const auto i : c10::irange(size())) { |
188 | tmp[i] = calc_igamma(tmp[i], tmp_x[i]); |
189 | } |
190 | return loadu(tmp); |
191 | } |
192 | Vectorized<double> igammac(const Vectorized<double> &x) const { |
193 | __at_align__ double tmp[size()]; |
194 | __at_align__ double tmp_x[size()]; |
195 | store(tmp); |
196 | x.store(tmp_x); |
197 | for (const auto i : c10::irange(size())) { |
198 | tmp[i] = calc_igammac(tmp[i], tmp_x[i]); |
199 | } |
200 | return loadu(tmp); |
201 | } |
202 | Vectorized<double> log() const { |
203 | return Vectorized<double>(Sleef_logd4_u10(values)); |
204 | } |
205 | Vectorized<double> log2() const { |
206 | return Vectorized<double>(Sleef_log2d4_u10(values)); |
207 | } |
208 | Vectorized<double> log10() const { |
209 | return Vectorized<double>(Sleef_log10d4_u10(values)); |
210 | } |
211 | Vectorized<double> log1p() const { |
212 | return Vectorized<double>(Sleef_log1pd4_u10(values)); |
213 | } |
214 | Vectorized<double> sin() const { |
215 | return Vectorized<double>(Sleef_sind4_u10(values)); |
216 | } |
217 | Vectorized<double> sinh() const { |
218 | return Vectorized<double>(Sleef_sinhd4_u10(values)); |
219 | } |
220 | Vectorized<double> cos() const { |
221 | return Vectorized<double>(Sleef_cosd4_u10(values)); |
222 | } |
223 | Vectorized<double> cosh() const { |
224 | return Vectorized<double>(Sleef_coshd4_u10(values)); |
225 | } |
226 | Vectorized<double> ceil() const { |
227 | return _mm256_ceil_pd(values); |
228 | } |
229 | Vectorized<double> floor() const { |
230 | return _mm256_floor_pd(values); |
231 | } |
232 | Vectorized<double> frac() const; |
233 | Vectorized<double> neg() const { |
234 | return _mm256_xor_pd(_mm256_set1_pd(-0.), values); |
235 | } |
236 | Vectorized<double> nextafter(const Vectorized<double> &b) const { |
237 | return Vectorized<double>(Sleef_nextafterd4(values, b)); |
238 | } |
239 | Vectorized<double> round() const { |
240 | return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); |
241 | } |
242 | Vectorized<double> tan() const { |
243 | return Vectorized<double>(Sleef_tand4_u10(values)); |
244 | } |
245 | Vectorized<double> tanh() const { |
246 | return Vectorized<double>(Sleef_tanhd4_u10(values)); |
247 | } |
248 | Vectorized<double> trunc() const { |
249 | return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)); |
250 | } |
251 | Vectorized<double> lgamma() const { |
252 | return Vectorized<double>(Sleef_lgammad4_u10(values)); |
253 | } |
254 | Vectorized<double> sqrt() const { |
255 | return _mm256_sqrt_pd(values); |
256 | } |
257 | Vectorized<double> reciprocal() const { |
258 | return _mm256_div_pd(_mm256_set1_pd(1), values); |
259 | } |
260 | Vectorized<double> rsqrt() const { |
261 | return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values)); |
262 | } |
263 | Vectorized<double> pow(const Vectorized<double> &b) const { |
264 | return Vectorized<double>(Sleef_powd4_u10(values, b)); |
265 | } |
  // Comparisons use the _CMP_**_OQ (ordered, quiet) predicates:
  //   `O`: the result is false if an operand is NaN
  //   `Q`: no exception is raised if an operand is NaN
  // operator!= is the exception: it uses _CMP_NEQ_UQ (unordered) so that a
  // comparison involving NaN evaluates to true.
269 | Vectorized<double> operator==(const Vectorized<double>& other) const { |
270 | return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ); |
271 | } |
272 | |
273 | Vectorized<double> operator!=(const Vectorized<double>& other) const { |
274 | return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ); |
275 | } |
276 | |
277 | Vectorized<double> operator<(const Vectorized<double>& other) const { |
278 | return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ); |
279 | } |
280 | |
281 | Vectorized<double> operator<=(const Vectorized<double>& other) const { |
282 | return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ); |
283 | } |
284 | |
285 | Vectorized<double> operator>(const Vectorized<double>& other) const { |
286 | return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ); |
287 | } |
288 | |
289 | Vectorized<double> operator>=(const Vectorized<double>& other) const { |
290 | return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ); |
291 | } |
292 | |
293 | Vectorized<double> eq(const Vectorized<double>& other) const; |
294 | Vectorized<double> ne(const Vectorized<double>& other) const; |
295 | Vectorized<double> lt(const Vectorized<double>& other) const; |
296 | Vectorized<double> le(const Vectorized<double>& other) const; |
297 | Vectorized<double> gt(const Vectorized<double>& other) const; |
298 | Vectorized<double> ge(const Vectorized<double>& other) const; |
299 | }; |
300 | |
301 | template <> |
302 | Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) { |
303 | return _mm256_add_pd(a, b); |
304 | } |
305 | |
306 | template <> |
307 | Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) { |
308 | return _mm256_sub_pd(a, b); |
309 | } |
310 | |
311 | template <> |
312 | Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) { |
313 | return _mm256_mul_pd(a, b); |
314 | } |
315 | |
316 | template <> |
317 | Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) { |
318 | return _mm256_div_pd(a, b); |
319 | } |
320 | |
// frac. Implemented here so that it can use the subtraction operator defined above.
322 | inline Vectorized<double> Vectorized<double>::frac() const { |
323 | return *this - this->trunc(); |
324 | } |
325 | |
// Implements the IEEE 754-2019 `maximum` operation, which propagates NaN if
// either input is a NaN.
328 | template <> |
329 | Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) { |
330 | Vectorized<double> max = _mm256_max_pd(a, b); |
331 | Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); |
332 | // Exploit the fact that all-ones is a NaN. |
333 | return _mm256_or_pd(max, isnan); |
334 | } |
335 | |
// Implements the IEEE 754-2019 `minimum` operation, which propagates NaN if
// either input is a NaN.
338 | template <> |
339 | Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) { |
340 | Vectorized<double> min = _mm256_min_pd(a, b); |
341 | Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q); |
342 | // Exploit the fact that all-ones is a NaN. |
343 | return _mm256_or_pd(min, isnan); |
344 | } |
345 | |
346 | template <> |
347 | Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) { |
348 | return _mm256_min_pd(max, _mm256_max_pd(min, a)); |
349 | } |
350 | |
351 | template <> |
352 | Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) { |
353 | return _mm256_max_pd(min, a); |
354 | } |
355 | |
356 | template <> |
357 | Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) { |
358 | return _mm256_min_pd(max, a); |
359 | } |
360 | |
361 | template <> |
362 | Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) { |
363 | return _mm256_and_pd(a, b); |
364 | } |
365 | |
366 | template <> |
367 | Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) { |
368 | return _mm256_or_pd(a, b); |
369 | } |
370 | |
371 | template <> |
372 | Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) { |
373 | return _mm256_xor_pd(a, b); |
374 | } |
375 | |
376 | inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const { |
377 | return (*this == other) & Vectorized<double>(1.0); |
378 | } |
379 | |
380 | inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const { |
381 | return (*this != other) & Vectorized<double>(1.0); |
382 | } |
383 | |
384 | inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const { |
385 | return (*this > other) & Vectorized<double>(1.0); |
386 | } |
387 | |
388 | inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const { |
389 | return (*this >= other) & Vectorized<double>(1.0); |
390 | } |
391 | |
392 | inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const { |
393 | return (*this < other) & Vectorized<double>(1.0); |
394 | } |
395 | |
396 | inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const { |
397 | return (*this <= other) & Vectorized<double>(1.0); |
398 | } |
399 | |
400 | template <> |
401 | inline void convert(const double* src, double* dst, int64_t n) { |
402 | int64_t i; |
403 | #pragma unroll |
404 | for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) { |
405 | _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i)); |
406 | } |
407 | #pragma unroll |
408 | for (; i < n; i++) { |
409 | dst[i] = src[i]; |
410 | } |
411 | } |
412 | |
413 | #ifdef CPU_CAPABILITY_AVX2 |
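// Fused multiply-add: fmadd computes a * b + c and fmsub computes a * b - c, each
// in a single instruction with one rounding step.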
414 | template <> |
415 | Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) { |
416 | return _mm256_fmadd_pd(a, b, c); |
417 | } |
418 | |
419 | template <> |
420 | Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) { |
421 | return _mm256_fmsub_pd(a, b, c); |
422 | } |
423 | #endif |
424 | |
425 | #endif |
426 | |
427 | }}} |
428 | |