1#define TORCH_ASSERT_NO_OPERATORS
2
3#include <ATen/native/ufunc/add.h>
4#include <ATen/native/DispatchStub.h>
5#include <ATen/TensorIterator.h>
6#include <ATen/native/cpu/Loops.h>
7#include <ATen/cpu/vec/vec.h>
8#include <ATen/Dispatch.h>
9#include <c10/core/Scalar.h>
10
11namespace at {
12namespace native {
13
14namespace {
15
16void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
17 AT_DISPATCH_SWITCH(iter.common_dtype(), "add_stub",
18
19AT_DISPATCH_CASE(at::ScalarType::Bool,
20 [&]() {
21
22auto _s_alpha = alpha.to<scalar_t>();
23cpu_kernel(iter,
24 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); }
25);
26
27 }
28)
29
30
31AT_DISPATCH_CASE(at::ScalarType::Byte,
32 [&]() {
33
34auto _s_alpha = alpha.to<scalar_t>();
35auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
36cpu_kernel_vec(iter,
37 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
38 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
39);
40
41 }
42)
43
44
45AT_DISPATCH_CASE(at::ScalarType::Char,
46 [&]() {
47
48auto _s_alpha = alpha.to<scalar_t>();
49auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
50cpu_kernel_vec(iter,
51 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
52 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
53);
54
55 }
56)
57
58
59AT_DISPATCH_CASE(at::ScalarType::Int,
60 [&]() {
61
62auto _s_alpha = alpha.to<scalar_t>();
63auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
64cpu_kernel_vec(iter,
65 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
66 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
67);
68
69 }
70)
71
72
73AT_DISPATCH_CASE(at::ScalarType::Long,
74 [&]() {
75
76auto _s_alpha = alpha.to<scalar_t>();
77auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
78cpu_kernel_vec(iter,
79 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
80 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
81);
82
83 }
84)
85
86
87AT_DISPATCH_CASE(at::ScalarType::Short,
88 [&]() {
89
90auto _s_alpha = alpha.to<scalar_t>();
91auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
92cpu_kernel_vec(iter,
93 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
94 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
95);
96
97 }
98)
99
100
101AT_DISPATCH_CASE(at::ScalarType::Float,
102 [&]() {
103
104auto _s_alpha = alpha.to<scalar_t>();
105auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
106cpu_kernel_vec(iter,
107 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
108 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
109);
110
111 }
112)
113
114
115AT_DISPATCH_CASE(at::ScalarType::Double,
116 [&]() {
117
118auto _s_alpha = alpha.to<scalar_t>();
119auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
120cpu_kernel_vec(iter,
121 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
122 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
123);
124
125 }
126)
127
128
129AT_DISPATCH_CASE(at::ScalarType::ComplexFloat,
130 [&]() {
131
132auto _s_alpha = alpha.to<scalar_t>();
133auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
134cpu_kernel_vec(iter,
135 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
136 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
137);
138
139 }
140)
141
142
143AT_DISPATCH_CASE(at::ScalarType::ComplexDouble,
144 [&]() {
145
146auto _s_alpha = alpha.to<scalar_t>();
147auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
148cpu_kernel_vec(iter,
149 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
150 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
151);
152
153 }
154)
155
156
157AT_DISPATCH_CASE(at::ScalarType::BFloat16,
158 [&]() {
159
160auto _s_alpha = alpha.to<scalar_t>();
161auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
162cpu_kernel_vec(iter,
163 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
164 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
165);
166
167 }
168)
169
170
171AT_DISPATCH_CASE(at::ScalarType::Half,
172 [&]() {
173
174auto _s_alpha = alpha.to<scalar_t>();
175auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
176cpu_kernel_vec(iter,
177 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
178 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
179);
180
181 }
182)
183
184
185AT_DISPATCH_CASE(at::ScalarType::ComplexHalf,
186 [&]() {
187
188auto _s_alpha = alpha.to<scalar_t>();
189auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
190cpu_kernel_vec(iter,
191 [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
192 [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
193);
194
195 }
196)
197
198 );
199}
200
201} // anonymous namespace
202
203using add_fn = void(*)(TensorIteratorBase&, const at::Scalar &);
204DECLARE_DISPATCH(add_fn, add_stub);
205REGISTER_DISPATCH(add_stub, &add_kernel);
206}} // namespace at::native
207