#pragma once

#include <ATen/CollapseDims.h>
#include <ATen/Parallel.h>
#include <ATen/TensorUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace at {

/*
 * The basic strategy for apply is as follows:
 *
 * 1. Starting with the outermost index, loop until we reach a dimension where
 * the data is no longer contiguous, i.e. the stride at that dimension does not
 * equal the number of elements spanned by the inner dimensions (stride[i] !=
 * size[i + 1] * stride[i + 1]). Let's call this outer (contiguous) tensor A.
 * Note that if the Tensor is contiguous, then A is equal to the entire Tensor.
 * Let's call the inner tensor B.
 *
 * 2. We loop through the indices in B, starting at its outermost dimension.
 * For example, if B is a 2x2 matrix, then we do:
 *
 * B[0][0]
 * B[0][1]
 * B[1][0]
 * B[1][1]
 *
 * We set the offset into the underlying storage as (storageOffset + stride_B *
 * index_B), i.e. basically we compute the offset into the storage as we would
 * normally for a Tensor. But because we are guaranteed the subsequent data is
 * contiguous in memory, we can simply loop for numel(A) iterations and perform
 * the operation, without having to follow the order described by the strides
 * of A.
 *
 * 3. As an optimization, we merge dimensions of A that are contiguous in
 * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3
 * tensor, then the first two dimensions can be merged for the purposes of
 * APPLY, reducing the number of nested loops. A worked example follows this
 * comment.
 */
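
/*
 * A worked example of the merging in step 3 (illustrative): a contiguous
 * 3x3x4x3 tensor has strides (36, 12, 3, 1). Narrowing dimension 2 to 3
 * elements keeps those strides but shrinks the sizes to 3x3x3x3. The first
 * two dimensions still satisfy stride[0] == size[1] * stride[1]
 * (36 == 3 * 12), so they collapse into a single dimension of size 9, while
 * stride[1] != size[2] * stride[2] (12 != 3 * 3) blocks merging across the
 * narrowed dimension. The last two dimensions likewise collapse (3 == 3 * 1),
 * leaving a 9x9 iteration space with strides (12, 1).
 */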
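
// Returns a permutation of tensor_ with its dimensions reordered so that the
// strides are descending, i.e. the most major (largest-stride) dimension
// comes first.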
inline Tensor sort_strides(Tensor& tensor_) {
  IntArrayRef strides = tensor_.strides();
  std::vector<int64_t> indices;
  indices.reserve(tensor_.ndimension());
  for (const auto i : c10::irange(tensor_.ndimension())) {
    indices.push_back(i);
  }
  std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) {
    return strides[i1] > strides[i2];
  });
  Tensor tensor = tensor_.permute(indices);
  return tensor;
}

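// Fixed-capacity iteration state over a tensor with at most N dimensions.
// Sizes and strides are copied into stack arrays and collapsed on
// construction to reduce the depth of the iteration loop nest.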
template <typename T, int N>
struct strided_tensor_iter_fixed {
 public:
  T* data_ = nullptr;
  int64_t dim_ = 0;

  int64_t counter_[N] = {0};
  int64_t sizes_[N] = {0};
  int64_t strides_[N] = {0};

  strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete;
  void operator=(strided_tensor_iter_fixed const& x) = delete;
  strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default;
  strided_tensor_iter_fixed(Tensor& tensor, bool sort_strides = false)
      : data_(tensor.data_ptr<T>()) {
    (void)sort_strides; // Suppress unused variable warning
    std::memset(counter_, 0, sizeof(int64_t) * N);
    if (tensor.dim() > 0) {
      std::memcpy(
          sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t));
      std::memcpy(
          strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t));
    }
    dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension()));
  }
};

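// Heap-backed iteration state for tensors of arbitrary dimensionality;
// otherwise identical in behavior to strided_tensor_iter_fixed.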
template <typename T>
struct strided_tensor_iter {
 public:
  T* data_ = nullptr;
  int64_t dim_ = 0;

  std::vector<int64_t> counter_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;

  strided_tensor_iter(strided_tensor_iter const&) = delete;
  void operator=(strided_tensor_iter const& x) = delete;
  strided_tensor_iter(strided_tensor_iter&&) = default;
  strided_tensor_iter(Tensor& tensor)
      : data_(tensor.data_ptr<T>()),
        dim_(tensor.ndimension()),
        counter_(dim_, 0),
        sizes_(tensor.sizes().vec()),
        strides_(tensor.strides().vec()) {
    dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_));
  }
};

inline bool _all_equal_numel(at::ArrayRef<Tensor> tensors) {
  if (tensors.empty())
    return true;
  int64_t all_numel = tensors[0].numel();
  for (const auto i : c10::irange(1, tensors.size())) {
    if (tensors[i].numel() != all_numel)
      return false;
  }
  return true;
}

inline std::string _all_equal_numel_error(at::ArrayRef<Tensor> tensors) {
  std::ostringstream oss;
  oss << "inconsistent tensor size, expected ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].sizes() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].sizes()
      << " to have the same number of elements, but got ";
  for (size_t i = 0; i < tensors.size() - 1; i++) {
    oss << tensors[i].numel() << ", ";
  }
  oss << "and " << tensors[tensors.size() - 1].numel()
      << " elements respectively";
  return oss.str();
}

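// Validates that all tensors are strided CPU tensors with the same number of
// elements. Returns false when there is nothing to do (some tensor is empty).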
inline bool _apply_preamble(ArrayRef<Tensor> tensors) {
  checkDeviceType("CPU_tensor_apply", tensors, kCPU);
  checkLayout("CPU_tensor_apply", tensors, kStrided);
  if (!_all_equal_numel(tensors))
    AT_ERROR(_all_equal_numel_error(tensors));
  // An empty tensor has no elements
  for (auto& t : tensors)
    if (t.numel() == 0)
      return false;
  return true;
}

inline int64_t _max_dim_tensors(ArrayRef<Tensor> tensors) {
  int64_t dim = 0;
  for (auto& t : tensors)
    dim = std::max(dim, t.ndimension());
  return dim;
}

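// Advances each iterator `size` steps along its innermost (collapsed)
// dimension.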
inline void iterate(int64_t /*size*/) {}

template <typename Arg, typename... Args>
inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) {
  iter.counter_[iter.dim_ - 1] += size;
  iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1];
  iterate(size, iter_tail...);
}

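// True while every iterator still has elements left along its innermost
// dimension.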
inline bool iterate_continue() {
  return true;
}

template <typename Arg, typename... Args>
inline bool iterate_continue(Arg& iter, Args&... iter_tail) {
  return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] &&
      iterate_continue(iter_tail...);
}

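// The largest step that can be taken along the innermost dimension before
// some iterator overflows it.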
inline int64_t max_iterate_size() {
  return std::numeric_limits<int64_t>::max();
}

template <typename Arg, typename... Args>
inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) {
  return std::min(
      (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]),
      max_iterate_size(iter_tail...));
}

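// Handles carry propagation once an innermost counter has reached its size:
// each overflowed counter is reset to zero, the next outer counter is
// incremented, and the data pointer is adjusted to match.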
inline void iterate_overflow() {}

template <typename Arg, typename... Args>
inline void iterate_overflow(Arg& iter, Args&... iter_tail) {
  if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) {
    for (int64_t i = iter.dim_ - 1; i > 0; i--) {
      if (iter.counter_[i] == iter.sizes_[i]) {
        iter.counter_[i] = 0;
        iter.counter_[i - 1]++;
        iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) +
            iter.strides_[i - 1];
      }
    }
  }
  iterate_overflow(iter_tail...);
}

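// Positions each iterator at linear element offset `offset` by decomposing
// the offset into per-dimension indices (least significant digit first, i.e.
// the innermost dimension). Assumes all counters start at zero.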
inline void forward(int64_t /*offset*/) {}

template <typename Arg, typename... Args>
inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) {
  int64_t multi = offset;
  for (int64_t i = iter.dim_ - 1; i >= 0; i--) {
    int64_t inc = multi % iter.sizes_[i];
    multi = multi / iter.sizes_[i];
    iter.data_ = iter.data_ + inc * iter.strides_[i];
    iter.counter_[i] += inc;
  }
  forward(offset, iter_tail...);
}

inline int64_t max_dim() {
  return 0;
}

template <typename Arg, typename... Args>
inline int64_t max_dim(Arg& iter, Args&... iter_tail) {
  return std::max(iter.dim_, max_dim(iter_tail...));
}

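// Applies `op` elementwise to `numel` elements, starting `offset` elements
// into the (conceptually flattened) tensors. The inner loop runs until some
// iterator exhausts its innermost dimension; iterate_overflow then carries
// the counters and the outer loop resumes.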
inline void apply_op() {}

template <typename Op, typename... Args>
inline void apply_op(
    int64_t numel,
    int64_t offset,
    const Op& op,
    Args... iters) {
  // For 0-dim tensors
  if (numel == 1 && max_dim(iters...) == 0) {
    op(*iters.data_...);
    return;
  }
  if (offset > 0)
    forward(offset, iters...);
  // Splitting this into chunks helps the compiler create faster assembly
  for (int64_t i = 0; i < numel;) {
    for (; iterate_continue(iters...) && i < numel;) {
      op(*iters.data_...);
      iterate(1, iters...);
      i++;
    }
    iterate_overflow(iters...);
  }
}

/*
  Apply a pointwise operator to a sequence of tensors.

  The calling convention for op is a function/functor that takes the same
  number of scalar references as the number of given tensors: apply_op invokes
  it as op(*iters.data_...), passing one dereferenced element pointer per
  tensor. For example, to compute a = b * c, op would be of the form:

  [](scalar1& a_val, const scalar2& b_val, const scalar3& c_val) {
    a_val = b_val * c_val;
  };
*/
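
/*
  A minimal usage sketch (illustrative; assumes a, b, and c are float CPU
  tensors with the same number of elements):

    at::Tensor a = at::empty({4, 4});
    at::Tensor b = at::rand({4, 4});
    at::Tensor c = at::rand({4, 4});
    at::CPU_tensor_apply3<float, float, float>(
        a, b, c, [](float& a_val, const float& b_val, const float& c_val) {
          a_val = b_val * c_val;
        });
*/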

template <typename scalar1, typename scalar2, typename Op>
inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) {
  if (!_apply_preamble({tensor1, tensor2}))
    return;
  if (_max_dim_tensors({tensor1, tensor2}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2));
  }
}

template <typename scalar1, typename scalar2, typename scalar3, typename Op>
inline void CPU_tensor_apply3(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3));
  }
}

template <
    typename scalar1,
    typename scalar2,
    typename scalar3,
    typename scalar4,
    typename Op>
inline void CPU_tensor_apply4(
    Tensor tensor1,
    Tensor tensor2,
    Tensor tensor3,
    Tensor tensor4,
    const Op op) {
  if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4}))
    return;
  if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter_fixed<scalar1, 8>(tensor1),
        strided_tensor_iter_fixed<scalar2, 8>(tensor2),
        strided_tensor_iter_fixed<scalar3, 8>(tensor3),
        strided_tensor_iter_fixed<scalar4, 8>(tensor4));
  } else {
    apply_op(
        tensor1.numel(),
        0,
        op,
        strided_tensor_iter<scalar1>(tensor1),
        strided_tensor_iter<scalar2>(tensor2),
        strided_tensor_iter<scalar3>(tensor3),
        strided_tensor_iter<scalar4>(tensor4));
  }
}

} // namespace at