#pragma once

// This file provides two functions to help write elementwise kernels:
//
//   cpu_kernel(TensorIterator iter, <lambda>)
//   cpu_kernel_vec(TensorIterator iter, <lambda>, <vec_lambda>)
//
// Both functions may generate vectorized code. The cpu_kernel implementation
// relies on the compiler's auto-vectorization. The cpu_kernel_vec
// implementation uses x86 SIMD intrinsics when available. These functions
// are only intended to be used in the ATen/native/cpu subdirectory, since files
// in other directories are not compiled with AVX/AVX2 enabled. See README.md
// for more details.
//
// For example, to write a multiplication kernel for float:
//
//   cpu_kernel(iter, [](float a, float b) { return a * b; });
//
// Or you may write:
//
//   cpu_kernel_vec(iter,
//     [](float a, float b) { return a * b; },
//     [](Vectorized<float> a, Vectorized<float> b) { return a * b; });
//
// See BinaryOpsKernel.cpp for the complete implementation
//

#include <stdint.h>
#include <c10/util/C++17.h>
#include <c10/util/Load.h>
#include <c10/util/irange.h>
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/cpu/IsContiguous.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cpu/vec/vec.h>

#include <utility>

namespace at { namespace native { inline namespace CPU_CAPABILITY {

using namespace vec;

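// `dereference` loads the i-th element of each argument tensor and packs the
// values into a tuple matching the op's parameter types: argument k is read
// from data[k] + i * strides[k] via c10::load.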
template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_impl(char* C10_RESTRICT data[], const int64_t* strides, int64_t i,
                 std::index_sequence<INDEX...>) {
  return std::make_tuple(
      c10::load<typename traits::template arg<INDEX>::type>(
          data[INDEX] + i * strides[INDEX])...);
}

template <typename traits>
typename traits::ArgsTuple
dereference(char* C10_RESTRICT data[], const int64_t* strides, int64_t i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return dereference_impl<traits>(data, strides, i, Indices{});
}

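// `dereference_vec` is the vectorized analogue: it loads a Vectorized lane
// from each argument, except for the argument at (1-based) position S, for
// which the pre-broadcast scalar vector `opt_scalar` is substituted.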
template <typename traits, std::size_t... INDEX>
typename traits::ArgsTuple
dereference_vec_impl(char* C10_RESTRICT data[],
                     const typename traits::result_type& opt_scalar,
                     size_t S,
                     int64_t i,
                     std::index_sequence<INDEX...>) {
  using Vec = typename traits::result_type;
  using scalar_t = typename Vec::value_type;
  return std::make_tuple(
      S == INDEX + 1 ?
      opt_scalar :
      Vec::loadu(data[INDEX] + i * sizeof(scalar_t))...);
}

template <typename traits>
typename traits::ArgsTuple
dereference_vec(char* C10_RESTRICT data[], const typename traits::result_type& opt_scalar, size_t S, int64_t i) {
  using Indices = std::make_index_sequence<traits::arity>;
  return dereference_vec_impl<traits>(data, opt_scalar, S, i, Indices{});
}

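// `execute_op` is the scalar inner loop. The first overload (non-void result)
// writes op's return value to the output tensor at data[0]; the second
// overload (void result) invokes op purely for its side effects, which is
// used by kernels that have no outputs.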
template <typename func_t,
    typename std::enable_if<!std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  using result_type = typename traits::result_type;
  for (; i < n; i++) {
    result_type* out_ptr = (result_type*)(data[0] + i * strides[0]);
    *out_ptr = c10::guts::apply(std::forward<func_t>(op), dereference<traits>(
        &data[1],
        &strides[1],
        i));
  }
}

template <typename func_t,
    typename std::enable_if<std::is_void<typename function_traits<func_t>::result_type>::value>::type* = nullptr>
static inline void
execute_op(char* C10_RESTRICT data[], const int64_t* strides, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  for (; i < n; i++) {
    c10::guts::apply(std::forward<func_t>(op), dereference<traits>(
        &data[0],
        &strides[0],
        i));
  }
}

// Basic loop operation (one output, N inputs). May be auto-vectorized
// by the compiler. Supports inputs and outputs of different types.
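//
// A minimal sketch of the buffer layout this expects (hypothetical data,
// assuming a float(float, float) op): data[0] is the output, data[1..arity]
// are the inputs, and element i of tensor k lives at data[k] + i * strides[k].
//
//   float out[4], a[4] = {1, 2, 3, 4}, b[4] = {5, 6, 7, 8};
//   char* data[] = {(char*)out, (char*)a, (char*)b};
//   int64_t strides[] = {sizeof(float), sizeof(float), sizeof(float)};
//   basic_loop(data, strides, 0, 4, [](float x, float y) { return x + y; });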
template <typename func_t>
static inline void
basic_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;
  constexpr int ntensors = traits::arity + 1;

  // Copying strides to temporary array helps auto vectorization in older GCC
  // versions.
  int64_t strides[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    strides[arg] = strides_[arg];
  }

  execute_op(data, strides, i, n, std::forward<func_t>(op));
}

// the recursive variadic template for iterating over the returned tuple
template<class T, size_t N>
struct TupleOutput {
  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
                     const T &tuple) {
    TupleOutput<T, N - 1>::handle(data, strides, i, tuple);

    auto output = std::get<N - 1>(tuple);
    using output_type = decltype(output);
    output_type * out_ptr = (output_type *)(data[N - 1] + i * strides[N - 1]);
    *out_ptr = output;
  }
};

// Base case for the above recursive template
template<class T>
struct TupleOutput<T, 1> {
  static void handle(char *C10_RESTRICT data[], const int64_t *strides, int64_t i,
                     const T &tuple) {
    auto output = std::get<0>(tuple);
    using output_type = decltype(output);
    output_type* out_ptr = (output_type *)(data[0] + i * strides[0]);
    *out_ptr = output;
  }
};

template<class... Args>
void handle_tuple_outputs(char* C10_RESTRICT data[],
                          const int64_t* strides,
                          int64_t i,
                          const std::tuple<Args...> &tuple) {
  TupleOutput<decltype(tuple), sizeof...(Args)>::handle(data, strides, i, tuple);
}
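
// For example (illustrative), if the op returns std::tuple<float, int64_t>,
// handle_tuple_outputs stores std::get<0>(tuple) to data[0] + i * strides[0]
// and std::get<1>(tuple) to data[1] + i * strides[1].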

// Loop operation for `cpu_kernel_multiple_outputs`.
// 1. Use `c10::guts::apply` to invoke the lambda passed to
//    `cpu_kernel_multiple_outputs` with the dereferenced input arguments.
// 2. Iterate over the members of the returned tuple and write each one to the
//    corresponding output tensor via `handle_tuple_outputs`.
template <typename func_t>
static inline void
multiple_outputs_loop(char* C10_RESTRICT data[], const int64_t* strides_, int64_t i, int64_t n, func_t&& op) {
  using traits = function_traits<func_t>;

  using result_type = typename traits::result_type;
  constexpr int num_outputs = std::tuple_size<result_type>::value;
  constexpr int ntensors = traits::arity + num_outputs;

  // Copying strides to temporary array helps auto vectorization in older GCC
  // versions.
  int64_t strides[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    strides[arg] = strides_[arg];
  }

  for (; i < n; i++) {
    auto output = c10::guts::apply(op, dereference<traits>(
        &data[num_outputs],
        &strides[num_outputs],
        i));
    handle_tuple_outputs(data, strides, i, output);
  }
}

// Explicitly vectorized loop implementation. All inputs and outputs must be
// the same type and contiguous with one exception: a single input may be
// a scalar (stride 0). Its position is indicated by the argument `S`. If `S`
// is 0, then there are no scalar inputs.
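//
// For example (an illustrative case), for a binary op whose second input is a
// broadcast scalar, S == 2: `opt_scalar` is built from the single value at
// data[2], and dereference_vec substitutes it for that argument instead of
// loading a vector from memory.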
template <typename func_t, typename vec_func_t>
static inline void
vectorized_loop(char** C10_RESTRICT data_, int64_t n, int64_t S, func_t&& op, vec_func_t&& vop) {
  using traits = function_traits<vec_func_t>;
  using scalar_t = typename function_traits<func_t>::result_type;
  using Vec = Vectorized<scalar_t>;
  constexpr int ntensors = traits::arity + 1;

  char* C10_RESTRICT data[ntensors];
  for (const auto arg : c10::irange(ntensors)) {
    data[arg] = data_[arg];
  }

  Vec opt_scalar = Vec(S > 0 ? *(scalar_t*)data[S] : scalar_t(0));
  int64_t i = 0;
  for (; i <= n - 2 * Vec::size(); i += 2 * Vec::size()) {
    auto args1 = dereference_vec<traits>(&data[1], opt_scalar, S, i);
    auto args2 = dereference_vec<traits>(&data[1], opt_scalar, S, i + Vec::size());
    auto out1 = c10::guts::apply(std::forward<vec_func_t>(vop), std::move(args1));
    auto out2 = c10::guts::apply(std::forward<vec_func_t>(vop), std::move(args2));
    out1.store(data[0] + i * sizeof(scalar_t));
    out2.store(data[0] + (i + Vec::size()) * sizeof(scalar_t));
  }
  if (i < n) {
    int64_t strides[ntensors];
    for (const auto arg : c10::irange(ntensors)) {
      strides[arg] = (S > 0 && arg == S) ? 0 : sizeof(scalar_t);
    }
    basic_loop(data, strides, i, n, std::forward<func_t>(op));
  }
}

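// Recursively checks, at compile time over the argument indices, whether the
// strides describe a layout that is contiguous except for a single scalar
// (stride-0) input. Invokes `cb` with the 1-based position of that scalar
// input, or with 0 if no such layout matches.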
template <typename traits, typename cb_t>
static inline void unroll_contiguous_scalar_checks(
    const int64_t* /*strides*/,
    std::index_sequence<>,
    cb_t&& cb) {
  cb(0);
}

template <typename traits, typename cb_t, size_t INDEX0, size_t ...INDEX>
static inline void unroll_contiguous_scalar_checks(
    const int64_t* strides,
    std::index_sequence<INDEX0, INDEX...>,
    cb_t&& cb) {
  if (is_contiguous_scalar<traits, INDEX0 + 1>(strides)) {
    cb(INDEX0 + 1);
  } else {
    unroll_contiguous_scalar_checks<traits>(strides, std::index_sequence<INDEX...>{}, std::forward<cb_t>(cb));
  }
}

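// 2-D loop body passed to TensorIterator::for_each / serial_for_each. For each
// row of length size0, it uses `vectorized_loop` when the inner dimension is
// contiguous (optionally with one scalar input) and falls back to `basic_loop`
// otherwise, then advances the data pointers by the outer strides.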
template <typename op_t, typename vop_t>
struct VectorizedLoop2d {
  op_t op;
  vop_t vop;

  using traits = function_traits<op_t>;
  static constexpr int ntensors = traits::arity + 1;
  using data_t = std::array<char*, ntensors>;

  VectorizedLoop2d(const op_t &op, vop_t vop):
    op(op), vop(std::move(vop)) {}

  static void advance(data_t &data, const int64_t *outer_strides) {
    for (const auto arg : c10::irange(data.size())) {
      data[arg] += outer_strides[arg];
    }
  }

  void operator()(char** base, const int64_t *strides, int64_t size0, int64_t size1) {
    data_t data;
    std::copy_n(base, ntensors, data.data());
    const int64_t *outer_strides = &strides[ntensors];

    if (is_contiguous<traits>(strides)) {
      for (const auto i C10_UNUSED : c10::irange(size1)) {
        vectorized_loop(data.data(), size0, 0, op, vop);
        advance(data, outer_strides);
      }
    } else {
      using Indices = std::make_index_sequence<traits::arity>;
      unroll_contiguous_scalar_checks<traits>(strides, Indices{}, [&](size_t idx) {
        if (idx) {
          for (const auto i C10_UNUSED : c10::irange(size1)) {
            vectorized_loop(data.data(), size0, idx, op, vop);
            advance(data, outer_strides);
          }
        } else {
          for (const auto i C10_UNUSED : c10::irange(size1)) {
            basic_loop(data.data(), strides, 0, size0, op);
            advance(data, outer_strides);
          }
        }
      });
    }
  }
};

template <typename op_t, typename vop_t>
VectorizedLoop2d<op_t, vop_t> make_vectorized_loop2d(
    const op_t &op, const vop_t &vop) {
  return VectorizedLoop2d<op_t, vop_t>(op, vop);
}

template <typename func_t>
void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    // basic_loop can handle 1d slices with arbitrary strides, and 1d slices are
    // all that iter.for_each ever sends to the loop lambda
    basic_loop(data, strides, 0, n, std::forward<func_t>(op));
  }, grain_size);
  iter.cast_outputs();
}

// This function helps write elementwise kernels that require multiple outputs.
// It follows a structure similar to that of cpu_kernel.
// Instead of the `basic_loop` function, a `multiple_outputs_loop` function is
// used to handle the multiple return values.
// For now the `needs_dynamic_casting` check is not added, since the lambda
// (`func_t`) passed to `multiple_outputs_loop` returns a `std::tuple` instead
// of a `scalar_t`.
// `gpu_kernel_multiple_outputs` is also implemented without this check; we
// could extend `needs_dynamic_casting` to support both `std::tuple` and
// `thrust::tuple` in the future.
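//
// For example (an illustrative sketch, not taken from an existing kernel), a
// kernel that produces both the sum and the difference of two float inputs
// could be written as:
//
//   cpu_kernel_multiple_outputs(iter, [](float a, float b) {
//     return std::tuple<float, float>(a + b, a - b);
//   });
//
// where `iter` is configured with two outputs followed by the two inputs.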
template <typename func_t>
void cpu_kernel_multiple_outputs(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);

  iter.for_each([&](char** data, const int64_t* strides, int64_t n) {
    multiple_outputs_loop(data, strides, 0, n, std::forward<func_t>(op));
  }, grain_size);
  iter.cast_outputs();
}

template <bool check_dynamic_cast=true, typename func_t, typename vec_func_t>
void cpu_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, int64_t grain_size = at::internal::GRAIN_SIZE) {
  using traits = function_traits<func_t>;
  // this could be extended to work with void return types
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
  TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
  // dynamic casting not currently supported on CPU, but some kernels (like Fill)
  // explicitly dynamic_cast, so we give the opt-out of checking.
  c10::guts::if_constexpr<check_dynamic_cast>([&] {
    TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
  });

  iter.for_each(make_vectorized_loop2d(op, vop), grain_size);
  iter.cast_outputs();
}

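// Serial (single-threaded) variant of cpu_kernel, restricted to the given
// Range of the iterator. Unlike cpu_kernel, it also accepts a void-returning
// op, in which case the iterator must have no outputs and the op is invoked
// purely for its side effects. For example (an illustrative sketch, assuming
// `iter` was built with a single float input and no outputs):
//
//   int64_t count = 0;
//   cpu_serial_kernel(iter, [&](float a) -> void {
//     if (a > 0) count++;
//   });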
template <typename func_t>
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op, const Range& range) {
  using traits = function_traits<func_t>;
  constexpr bool result_void = std::is_void<typename traits::result_type>::value;
  TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity &&
                        ((result_void && iter.noutputs() == 0) || (!result_void && iter.noutputs() == 1)));
  // dynamic casting not currently supported on CPU
  TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

  iter.serial_for_each([&](char** data, const int64_t* strides, int64_t n) {
    basic_loop(data, strides, 0, n, std::forward<func_t>(op));
  }, range);
  iter.cast_outputs();
}

template <typename func_t>
void cpu_serial_kernel(TensorIteratorBase& iter, func_t&& op) {
  cpu_serial_kernel(iter, op, {0, iter.numel()});
}

375
376template <typename func_t, typename vec_func_t>
377void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop, const Range& range) {
378 using traits = function_traits<func_t>;
379 // this could be extended to work with void return types
380 TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
381 TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
382 // dynamic casting not currently supported on CPU
383 TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));
384
385 iter.serial_for_each(make_vectorized_loop2d(op, vop), range);
386 iter.cast_outputs();
387}
388
389template <typename func_t, typename vec_func_t>
390void cpu_serial_kernel_vec(TensorIteratorBase& iter, func_t&& op, vec_func_t&& vop) {
391 cpu_serial_kernel_vec(iter, op, vop, {0, iter.numel()});
392}
393
394}}} // namespace at::native::<anonymous>
395