#pragma once

#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace at {

/*
 * The basic strategy for apply is as follows:
 *
 * 1. Starting with the outermost index, loop until we reach a dimension where
 * the data is no longer contiguous, i.e. the stride at that dimension is not
 * equal to the number of elements in the tensor defined by the remaining
 * (inner) dimensions. Let's call this outer (contiguous) tensor A. Note that
 * if the Tensor is contiguous, then A is equal to the entire Tensor. Let's
 * call the inner tensor B.
 *
 * 2. We loop through the indices in B, starting at its outermost dimension.
 * For example, if B is a 2x2 matrix, then we do:
 *
 *   B[0][0]
 *   B[0][1]
 *   B[1][0]
 *   B[1][1]
 *
 * We set the offset into the underlying storage as (storageOffset + stride_B *
 * index_B), i.e. we compute the offset into the storage as we normally would
 * for a Tensor. But because we are guaranteed the subsequent data is
 * contiguous in memory, we can simply loop for numel(A) iterations and perform
 * the operation, without having to follow the order described by the strides
 * of A.
 *
 * 3. As an optimization, we merge dimensions of A that are contiguous in
 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor,
 * then the first two dimensions can be merged for the purposes of APPLY,
 * reducing the number of nested loops.
 */
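
/*
 * A concrete sketch of step 2 above (illustrative only): consider a tensor
 * with sizes {2, 2, 3}, strides {12, 6, 1} and storageOffset 0 (e.g. a
 * contiguous 2x2x6 tensor narrowed to 2x2x3). The outer 2x2 indices play the
 * role of B and the inner contiguous run of 3 elements plays the role of A,
 * so the traversal touches storage offsets:
 *
 *   B[0][0] -> offset 0,  then elements 0, 1, 2
 *   B[0][1] -> offset 6,  then elements 6, 7, 8
 *   B[1][0] -> offset 12, then elements 12, 13, 14
 *   B[1][1] -> offset 18, then elements 18, 19, 20
 */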

inline Tensor sort_strides(Tensor& tensor_) {
  IntArrayRef strides = tensor_.strides();
  std::vector<int64_t> indices;
  indices.reserve(tensor_.ndimension());
  for (const auto i : c10::irange(tensor_.ndimension())) {
    indices.push_back(i);
  }
  std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
    return strides[i1] > strides[i2];
  });
  Tensor tensor = tensor_.permute(indices);
  return tensor;
}

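// Per-tensor iteration state with a compile-time capacity of N dimensions:
// counters, sizes and strides live in fixed-size arrays, so constructing the
// iterator performs no heap allocation. The constructor copies the tensor's
// sizes/strides and collapses adjacent contiguous dimensions via collapse_dims
// to reduce the depth of the iteration.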
template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = nullptr;
  int64_t dim_ = 0;

  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    (void)sort_strides; // Suppress unused variable warning
    std::memset(counter_, 0, sizeof(int64_t) * N);
    if (tensor.dim() > 0) {
      std::memcpy(
          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
      std::memcpy(
          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
    }
    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  }
};

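// Heap-backed variant of the iterator above for tensors whose dimensionality
// is not known to fit a fixed bound: counters, sizes and strides are stored in
// std::vector. As above, adjacent contiguous dimensions are collapsed in the
// constructor.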
template <typename T>
struct strided_tensor_iter {
 public:
  T* data_ = nullptr;
  int64_t dim_;

  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};

inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.empty())
    return true;
  int64_t all_numel = tensors[0].numel();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].numel() != all_numel)
      return false;
  }
  return true;
}

inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].numel()
      << " elements respectively";
  return oss.str();
}

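// Validates the tensors before an apply: they must all live on the CPU, be
// strided, and have the same number of elements. Returns false when any
// tensor is empty, in which case there is nothing to do.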
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  if (!_all_equal_numel(tensors))
    AT_ERROR(_all_equal_numel_error(tensors));
  // An empty tensor has no elements to apply the op to.
  for (auto& t : tensors)
    if (t.numel() == 0)
      return false;
  return true;
}

inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t dim = 0;
  for (auto& t : tensors)
    dim = std::max(dim, t.ndimension());
  return dim;
}

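// The helpers below all follow the same variadic pattern: each overload peels
// one iterator off the parameter pack, updates or queries it, and recurses on
// the tail; the zero-argument overload is the base case. This lets apply_op
// advance an arbitrary number of tensor iterators in lock-step.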
inline void iterate(int64_t /*size*/) {}

template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}

inline bool iterate_continue() {
  return true;
}

template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}

inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}

inline void iterate_overflow() {}

template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}

inline void forward(int64_t /*offset*/) {}

template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}

inline int64_t max_dim() {
  return 0;
}

template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  return std::max(iter.dim_, max_dim(iter_tail...));
}

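// apply_op invokes `op` once per element: all iterators are first advanced to
// the linear `offset`, then `op` is applied to the current element of every
// iterator, the iterators are stepped forward together, and per-dimension
// carries are resolved in iterate_overflow once an innermost dimension is
// exhausted.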
inline void apply_op() {}

template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}

/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes the same
  number of scalar references as the number of given tensors. For example, to
  compute a = b * c, op would be of the form:

    [](scalar& a_val, const scalar& b_val, const scalar& c_val) {
      a_val = b_val * c_val;
    };
*/
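
/*
  A minimal usage sketch (illustrative only, assuming float tensors named a, b
  and c that already have the same number of elements):

    at::CPU_tensor_apply3<float, float, float>(
        a, b, c, [](float& a_val, const float& b_val, const float& c_val) {
          a_val = b_val * c_val;
        });
*/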

template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  if (!_apply_preamble({tensor1, tensor2}))
    return;
  if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2));
  }
}

template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3));
  }
}

template <
    typename scalar1,
    typename scalar2,
    typename scalar3,
    typename scalar4,
    typename Op>
inline void CPU_tensor_apply4(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    Tensor tensor4,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3),
        strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3),
        strided_tensor_iter<scalar4>(tensor4));
  }
}

} // namespace at