#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with AVX]

#include <ATen/cpu/vec/intrinsics.h>
#include <ATen/cpu/vec/vec_base.h>
#include <c10/util/irange.h>
#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
#include <sleef.h>
#endif

namespace at {
namespace vec {
// See Note [CPU_CAPABILITY namespace]
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)

template <> class Vectorized<double> {
private:
  __m256d values;
public:
  using value_type = double;
  using size_type = int;
  static constexpr size_type size() {
    return 4;
  }
  Vectorized() {}
  Vectorized(__m256d v) : values(v) {}
  Vectorized(double val) {
    values = _mm256_set1_pd(val);
  }
  Vectorized(double val1, double val2, double val3, double val4) {
    values = _mm256_setr_pd(val1, val2, val3, val4);
  }
  operator __m256d() const {
    return values;
  }
  template <int64_t mask>
  static Vectorized<double> blend(const Vectorized<double>& a, const Vectorized<double>& b) {
    return _mm256_blend_pd(a.values, b.values, mask);
  }
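  // Note on blend: bit i of the immediate `mask` selects lane i from `b`; a
  // cleared bit keeps lane i of `a`. Illustrative example, with
  //   a = {a0, a1, a2, a3} and b = {b0, b1, b2, b3}:
  //   blend<0b0101>(a, b) == {b0, a1, b2, a3}.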
  static Vectorized<double> blendv(const Vectorized<double>& a, const Vectorized<double>& b,
                                   const Vectorized<double>& mask) {
    return _mm256_blendv_pd(a.values, b.values, mask.values);
  }
  template<typename step_t>
  static Vectorized<double> arange(double base = 0., step_t step = static_cast<step_t>(1)) {
    return Vectorized<double>(base, base + step, base + 2 * step, base + 3 * step);
  }
  static Vectorized<double> set(const Vectorized<double>& a, const Vectorized<double>& b,
                                int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
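  // Note on set: the result takes the first `count` lanes from `b` and the rest
  // from `a`; the masks 1, 3 and 7 select the low one, two and three lanes of
  // `b`, e.g. set(a, b, 2) == {b0, b1, a2, a3}.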
  static Vectorized<double> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));

    __at_align__ double tmp_values[size()];
    // Ensure uninitialized memory does not change the output value. See
    // https://github.com/pytorch/pytorch/issues/32502 for more details. We do not
    // initialize arrays to zero using "={0}" because gcc would compile it to two
    // instructions while a loop would be compiled to one instruction.
    for (const auto i : c10::irange(size())) {
      tmp_values[i] = 0.0;
    }
    std::memcpy(
        tmp_values,
        reinterpret_cast<const double*>(ptr),
        count * sizeof(double));
    return _mm256_load_pd(tmp_values);
  }
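  // For a partial load such as loadu(ptr, 2), the two low lanes come from `ptr`
  // and the remaining lanes are zero-filled, so the result is well defined even
  // when fewer than size() elements are valid.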
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      double tmp_values[size()];
      _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(double));
    }
  }
  const double& operator[](int idx) const = delete;
  double& operator[](int idx) = delete;
  int zero_mask() const {
    // returns an integer mask where all zero elements are translated to 1-bit and others are translated to 0-bit
    __m256d cmp = _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_EQ_OQ);
    return _mm256_movemask_pd(cmp);
  }
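  // Illustrative example: Vectorized<double>(0.0, 1.0, 0.0, 2.0).zero_mask()
  // returns 0b0101, since lane i maps to bit i and lanes 0 and 2 are zero.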
  Vectorized<double> isnan() const {
    return _mm256_cmp_pd(values, _mm256_set1_pd(0.0), _CMP_UNORD_Q);
  }
  Vectorized<double> map(double (*const f)(double)) const {
    __at_align__ double tmp[size()];
    store(tmp);
    for (const auto i : c10::irange(size())) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
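  // abs() clears the sign bit: the bit pattern of -0.0 has only the sign bit
  // set, so andnot(-0.0, x) keeps every bit of x except the sign.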
  Vectorized<double> abs() const {
    auto mask = _mm256_set1_pd(-0.f);
    return _mm256_andnot_pd(mask, values);
  }
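  // angle() of a real double is pi for negative inputs, 0 for non-negative
  // inputs, and NaN when the input is NaN (detected via the self-comparison
  // below, since NaN != NaN).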
  Vectorized<double> angle() const {
    const auto zero_vec = _mm256_set1_pd(0.f);
    const auto nan_vec = _mm256_set1_pd(NAN);
    const auto not_nan_mask = _mm256_cmp_pd(values, values, _CMP_EQ_OQ);
    const auto nan_mask = _mm256_cmp_pd(not_nan_mask, zero_vec, _CMP_EQ_OQ);
    const auto pi = _mm256_set1_pd(c10::pi<double>);

    const auto neg_mask = _mm256_cmp_pd(values, zero_vec, _CMP_LT_OQ);
    auto angle = _mm256_blendv_pd(zero_vec, pi, neg_mask);
    angle = _mm256_blendv_pd(angle, nan_vec, nan_mask);
    return angle;
  }
  Vectorized<double> real() const {
    return *this;
  }
  Vectorized<double> imag() const {
    return _mm256_set1_pd(0);
  }
  Vectorized<double> conj() const {
    return *this;
  }
  Vectorized<double> acos() const {
    return Vectorized<double>(Sleef_acosd4_u10(values));
  }
  Vectorized<double> asin() const {
    return Vectorized<double>(Sleef_asind4_u10(values));
  }
  Vectorized<double> atan() const {
    return Vectorized<double>(Sleef_atand4_u10(values));
  }
  Vectorized<double> atan2(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_atan2d4_u10(values, b));
  }
  Vectorized<double> copysign(const Vectorized<double> &sign) const {
    return Vectorized<double>(Sleef_copysignd4(values, sign));
  }
  Vectorized<double> erf() const {
    return Vectorized<double>(Sleef_erfd4_u10(values));
  }
  Vectorized<double> erfc() const {
    return Vectorized<double>(Sleef_erfcd4_u15(values));
  }
  Vectorized<double> erfinv() const {
    return map(calc_erfinv);
  }
  Vectorized<double> exp() const {
    return Vectorized<double>(Sleef_expd4_u10(values));
  }
  Vectorized<double> exp2() const {
    return Vectorized<double>(Sleef_exp2d4_u10(values));
  }
  Vectorized<double> expm1() const {
    return Vectorized<double>(Sleef_expm1d4_u10(values));
  }
  Vectorized<double> fmod(const Vectorized<double>& q) const {
    return Vectorized<double>(Sleef_fmodd4(values, q));
  }
  Vectorized<double> hypot(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_hypotd4_u05(values, b));
  }
  Vectorized<double> i0() const {
    return map(calc_i0);
  }
  Vectorized<double> i0e() const {
    return map(calc_i0e);
  }
  Vectorized<double> igamma(const Vectorized<double> &x) const {
    __at_align__ double tmp[size()];
    __at_align__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> igammac(const Vectorized<double> &x) const {
    __at_align__ double tmp[size()];
    __at_align__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (const auto i : c10::irange(size())) {
      tmp[i] = calc_igammac(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vectorized<double> log() const {
    return Vectorized<double>(Sleef_logd4_u10(values));
  }
  Vectorized<double> log2() const {
    return Vectorized<double>(Sleef_log2d4_u10(values));
  }
  Vectorized<double> log10() const {
    return Vectorized<double>(Sleef_log10d4_u10(values));
  }
  Vectorized<double> log1p() const {
    return Vectorized<double>(Sleef_log1pd4_u10(values));
  }
  Vectorized<double> sin() const {
    return Vectorized<double>(Sleef_sind4_u10(values));
  }
  Vectorized<double> sinh() const {
    return Vectorized<double>(Sleef_sinhd4_u10(values));
  }
  Vectorized<double> cos() const {
    return Vectorized<double>(Sleef_cosd4_u10(values));
  }
  Vectorized<double> cosh() const {
    return Vectorized<double>(Sleef_coshd4_u10(values));
  }
  Vectorized<double> ceil() const {
    return _mm256_ceil_pd(values);
  }
  Vectorized<double> floor() const {
    return _mm256_floor_pd(values);
  }
  Vectorized<double> frac() const;
  Vectorized<double> neg() const {
    return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
  }
  Vectorized<double> nextafter(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_nextafterd4(values, b));
  }
  Vectorized<double> round() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> tan() const {
    return Vectorized<double>(Sleef_tand4_u10(values));
  }
  Vectorized<double> tanh() const {
    return Vectorized<double>(Sleef_tanhd4_u10(values));
  }
  Vectorized<double> trunc() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vectorized<double> lgamma() const {
    return Vectorized<double>(Sleef_lgammad4_u10(values));
  }
  Vectorized<double> sqrt() const {
    return _mm256_sqrt_pd(values);
  }
  Vectorized<double> reciprocal() const {
    return _mm256_div_pd(_mm256_set1_pd(1), values);
  }
  Vectorized<double> rsqrt() const {
    return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
  }
  Vectorized<double> pow(const Vectorized<double> &b) const {
    return Vectorized<double>(Sleef_powd4_u10(values, b));
  }
  // Comparison using the _CMP_**_OQ predicate.
  //   `O`: get false if an operand is NaN
  //   `Q`: do not raise if an operand is NaN
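  // For example, comparing a NaN lane with anything under EQ_OQ produces an
  // all-zero lane (false). operator!= is the one exception below: it uses
  // _CMP_NEQ_UQ (unordered), so NaN != x is true, matching IEEE semantics.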
  Vectorized<double> operator==(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
  }

  Vectorized<double> operator!=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_UQ);
  }

  Vectorized<double> operator<(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
  }

  Vectorized<double> operator<=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
  }

  Vectorized<double> operator>(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
  }

  Vectorized<double> operator>=(const Vectorized<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
  }

  Vectorized<double> eq(const Vectorized<double>& other) const;
  Vectorized<double> ne(const Vectorized<double>& other) const;
  Vectorized<double> lt(const Vectorized<double>& other) const;
  Vectorized<double> le(const Vectorized<double>& other) const;
  Vectorized<double> gt(const Vectorized<double>& other) const;
  Vectorized<double> ge(const Vectorized<double>& other) const;
};

template <>
Vectorized<double> inline operator+(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_add_pd(a, b);
}

template <>
Vectorized<double> inline operator-(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_sub_pd(a, b);
}

template <>
Vectorized<double> inline operator*(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_mul_pd(a, b);
}

template <>
Vectorized<double> inline operator/(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_div_pd(a, b);
}

// frac. Implement this here so we can use subtraction.
inline Vectorized<double> Vectorized<double>::frac() const {
  return *this - this->trunc();
}
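// Worked example: frac(-2.75) == -2.75 - trunc(-2.75) == -2.75 - (-2.0) == -0.75,
// so the result keeps the sign of the input.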

// Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
// either input is a NaN.
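// When either lane holds a NaN, _CMP_UNORD_Q sets that lane of `isnan` to all
// ones; OR-ing an all-ones bit pattern into `max` turns that lane into a NaN,
// which is how the propagation below is achieved.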
template <>
Vectorized<double> inline maximum(const Vectorized<double>& a, const Vectorized<double>& b) {
  Vectorized<double> max = _mm256_max_pd(a, b);
  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  return _mm256_or_pd(max, isnan);
}

// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
// either input is a NaN.
template <>
Vectorized<double> inline minimum(const Vectorized<double>& a, const Vectorized<double>& b) {
  Vectorized<double> min = _mm256_min_pd(a, b);
  Vectorized<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that all-ones is a NaN.
  return _mm256_or_pd(min, isnan);
}

template <>
Vectorized<double> inline clamp(const Vectorized<double>& a, const Vectorized<double>& min, const Vectorized<double>& max) {
  return _mm256_min_pd(max, _mm256_max_pd(min, a));
}

template <>
Vectorized<double> inline clamp_min(const Vectorized<double>& a, const Vectorized<double>& min) {
  return _mm256_max_pd(min, a);
}

template <>
Vectorized<double> inline clamp_max(const Vectorized<double>& a, const Vectorized<double>& max) {
  return _mm256_min_pd(max, a);
}

template <>
Vectorized<double> inline operator&(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_and_pd(a, b);
}

template <>
Vectorized<double> inline operator|(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_or_pd(a, b);
}

template <>
Vectorized<double> inline operator^(const Vectorized<double>& a, const Vectorized<double>& b) {
  return _mm256_xor_pd(a, b);
}

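// eq/ne/gt/ge/lt/le return 1.0 for true lanes and 0.0 for false lanes (rather
// than the all-ones masks of the operators above): the comparison mask is
// bitwise-ANDed with the bit pattern of 1.0, which leaves either 1.0 or 0.0.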
inline Vectorized<double> Vectorized<double>::eq(const Vectorized<double>& other) const {
  return (*this == other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ne(const Vectorized<double>& other) const {
  return (*this != other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::gt(const Vectorized<double>& other) const {
  return (*this > other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::ge(const Vectorized<double>& other) const {
  return (*this >= other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::lt(const Vectorized<double>& other) const {
  return (*this < other) & Vectorized<double>(1.0);
}

inline Vectorized<double> Vectorized<double>::le(const Vectorized<double>& other) const {
  return (*this <= other) & Vectorized<double>(1.0);
}

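// convert copies n doubles from src to dst: full vectors of size() elements are
// moved with unaligned AVX loads/stores, then a scalar loop handles the tail.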
template <>
inline void convert(const double* src, double* dst, int64_t n) {
  int64_t i;
#pragma unroll
  for (i = 0; i <= (n - Vectorized<double>::size()); i += Vectorized<double>::size()) {
    _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
  }
#pragma unroll
  for (; i < n; i++) {
    dst[i] = src[i];
  }
}

#ifdef CPU_CAPABILITY_AVX2
template <>
Vectorized<double> inline fmadd(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
  return _mm256_fmadd_pd(a, b, c);
}

template <>
Vectorized<double> inline fmsub(const Vectorized<double>& a, const Vectorized<double>& b, const Vectorized<double>& c) {
  return _mm256_fmsub_pd(a, b, c);
}
#endif

#endif

}}}
428