#define TORCH_ASSERT_NO_OPERATORS
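// Defining TORCH_ASSERT_NO_OPERATORS before any includes makes it a build
// error for this translation unit to pull in the heavyweight Tensor operator
// headers, which keeps rebuilds of this file fast.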

#include <ATen/native/ufunc/add.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <c10/core/Scalar.h>
#include <ATen/native/cuda/Loops.cuh>

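// <ATen/native/cuda/Loops.cuh> supplies gpu_kernel(), the generic elementwise
// CUDA launcher used below; <ATen/native/ufunc/add.h> supplies the shared
// host/device ufunc::add body that all the functors call.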
namespace at {

// NB: this is explicitly copied here (via codegen) rather than
// included via NativeFunctions.h to avoid recompiling this file when
// NativeFunctions.h changes
namespace meta {
struct TORCH_API structured_add_Tensor : public TensorIteratorBase {
  void meta(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha);
};
} // namespace meta

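// This is the usual structured-kernel split: meta() (declared above)
// validates the inputs and configures the TensorIterator, while impl()
// (declared below and defined via TORCH_IMPL_FUNC at the bottom of this
// file) performs the actual computation into the allocated output.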
namespace native {

struct TORCH_API structured_ufunc_add_CUDA : public at::meta::structured_add_Tensor {
  void impl(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out);
};

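// Three functor variants follow. When one operand is a scalar living on the
// CPU, its value is captured inside the functor so the GPU loop degrades from
// a binary to a unary kernel; the last variant handles the general
// tensor-tensor case. All three do their arithmetic in opmath_t, a wider
// type than scalar_t for reduced-precision dtypes.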
// `other` is a CPU scalar: fold it into the functor and iterate over `self` only.
template <typename scalar_t>
struct CUDAFunctorOnSelf_add {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t other_;
  opmath_t alpha_;
  CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
  __device__ scalar_t operator()(scalar_t self) const {
    return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  }
};

// `self` is a CPU scalar: fold it into the functor and iterate over `other` only.
template <typename scalar_t>
struct CUDAFunctorOnOther_add {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t self_;
  opmath_t alpha_;
  CUDAFunctorOnOther_add(opmath_t self, opmath_t alpha) : self_(self), alpha_(alpha) {}
  __device__ scalar_t operator()(scalar_t other) const {
    return ufunc::add(self_, static_cast<opmath_t>(other), alpha_);
  }
};

// General case: both operands are CUDA tensors.
template <typename scalar_t>
struct CUDAFunctor_add {
  using opmath_t = at::opmath_type<scalar_t>;
  opmath_t alpha_;
  CUDAFunctor_add(opmath_t alpha) : alpha_(alpha) {}
  __device__ scalar_t operator()(scalar_t self, scalar_t other) const {
    return ufunc::add(static_cast<opmath_t>(self), static_cast<opmath_t>(other), alpha_);
  }
};

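// A minimal sketch of the shared ufunc body these functors call (the real
// definition lives in <ATen/native/ufunc/add.h>; this is an assumption about
// its shape, not a copy of it):
//
//   template <typename T>
//   C10_HOST_DEVICE T add(T self, T other, T alpha) {
//     return self + alpha * other;
//   }
//
// so every path below computes out = self + alpha * other, widened to
// opmath_t (e.g. float for Half/BFloat16 inputs) and narrowed back to
// scalar_t on the way out.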
using add_fn = void (*)(TensorIteratorBase&, const at::Scalar&);
DECLARE_DISPATCH(add_fn, add_stub);

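// add_kernel switches on the iterator's common dtype. Within each dtype case:
// TensorIterator operand 0 is the output, operand 1 is `self`, and operand 2
// is `other`; if either input is a CPU scalar it is removed from the iterator
// and baked into the corresponding unary functor before launch.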
void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
  AT_DISPATCH_SWITCH(iter.common_dtype(), "ufunc_add_CUDA",
    AT_DISPATCH_CASE(at::ScalarType::Bool, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Byte, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Char, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Int, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Long, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Short, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Float, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Double, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::Half, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
    AT_DISPATCH_CASE(at::ScalarType::ComplexHalf, [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      if (iter.is_cpu_scalar(1)) {
        CUDAFunctorOnOther_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(1), alpha.to<opmath_t>());
        iter.remove_operand(1);
        gpu_kernel(iter, ufunctor);
      } else if (iter.is_cpu_scalar(2)) {
        CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha.to<opmath_t>());
        iter.remove_operand(2);
        gpu_kernel(iter, ufunctor);
      } else {
        gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha.to<opmath_t>()));
      }
    })
  );
}
REGISTER_DISPATCH(add_stub, &add_kernel);
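// Hypothetical call path (a sketch, not code from this file): a CUDA add
// such as
//
//   at::Tensor a = at::rand({4}, at::kCUDA);
//   at::Tensor b = at::rand({4}, at::kCUDA);
//   at::Tensor c = at::add(a, b, /*alpha=*/2.0);  // c = a + 2.0 * b
//
// routes through the structured op below, whose TensorIterator state then
// reaches add_kernel via the stub registered above.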

// The structured framework has already configured *this (a TensorIterator)
// from self/other/out in meta(), so impl() only needs to forward alpha.
TORCH_IMPL_FUNC(ufunc_add_CUDA)(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, const at::Tensor & out) {
  add_kernel(*this, alpha);
}
}} // namespace at::native