1 | #define TORCH_ASSERT_NO_OPERATORS |
2 | |
3 | #include <ATen/native/ufunc/add.h> |
4 | #include <ATen/native/DispatchStub.h> |
5 | #include <ATen/TensorIterator.h> |
6 | #include <ATen/native/cpu/Loops.h> |
7 | #include <ATen/cpu/vec/vec.h> |
8 | #include <ATen/Dispatch.h> |
9 | #include <c10/core/Scalar.h> |
10 | |
11 | namespace at { |
12 | namespace native { |
13 | |
14 | namespace { |
15 | |
// CPU kernel for the `add` ufunc, dispatched over the iterator's common dtype.
// For each supported dtype this installs a per-element loop that forwards to
// ufunc::add(self, other, alpha) — presumably computing self + alpha * other;
// confirm against ATen/native/ufunc/add.h, which defines the actual op.
//
// Inside each AT_DISPATCH_CASE body, `scalar_t` is the concrete C++ type for
// that case (bound by the dispatch macro, not declared here).
//
// - Bool gets a scalar-only loop via cpu_kernel.
// - Every other dtype gets a (scalar, vectorized) loop pair via cpu_kernel_vec,
//   with alpha converted once per dispatch: `_s_alpha` as a scalar and
//   `_v_alpha` broadcast into an at::vec::Vectorized register.
void add_kernel(TensorIteratorBase& iter, const at::Scalar & alpha) {
  AT_DISPATCH_SWITCH(iter.common_dtype(), "add_stub",

    // Bool: no vectorized arithmetic path — scalar loop only.
    AT_DISPATCH_CASE(at::ScalarType::Bool,
      [&]() {
        // Convert alpha once, outside the per-element loop.
        auto _s_alpha = alpha.to<scalar_t>();
        cpu_kernel(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); }
        );
      }
    )

    // All remaining cases are structurally identical; only the dtype differs.
    AT_DISPATCH_CASE(at::ScalarType::Byte,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        // Broadcast alpha into a vector lane-filled value once per dispatch.
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Char,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Int,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Long,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Short,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Float,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Double,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::ComplexFloat,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::ComplexDouble,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    // Reduced-precision dtypes (BFloat16/Half/ComplexHalf) also take the
    // vectorized path; at::vec::Vectorized provides an implementation for
    // these types even on targets without dedicated lanes for them.
    AT_DISPATCH_CASE(at::ScalarType::BFloat16,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::Half,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

    AT_DISPATCH_CASE(at::ScalarType::ComplexHalf,
      [&]() {
        auto _s_alpha = alpha.to<scalar_t>();
        auto _v_alpha = at::vec::Vectorized<scalar_t>(_s_alpha);
        cpu_kernel_vec(iter,
          [=](scalar_t self, scalar_t other) { return ufunc::add(self, other, _s_alpha); },
          [=](at::vec::Vectorized<scalar_t> self, at::vec::Vectorized<scalar_t> other) { return ufunc::add(self, other, _v_alpha); }
        );
      }
    )

  );
}
200 | |
201 | } // anonymous namespace |
202 | |
// Signature of the kernel entry point stored in the dispatch stub.
using add_fn = void(*)(TensorIteratorBase&, const at::Scalar &);
// Declare the `add_stub` dispatch slot and register the CPU implementation
// above as its kernel for the current compilation target.
DECLARE_DISPATCH(add_fn, add_stub);
REGISTER_DISPATCH(add_stub, &add_kernel);
206 | }} // namespace at::native |
207 | |