1 | #pragma once |
2 | |
3 | #include <ATen/TensorMeta.h> |
4 | #include <ATen/core/Dimname.h> |
5 | #include <ATen/core/Range.h> |
6 | #include <ATen/core/TensorBase.h> |
7 | #include <c10/core/DynamicCast.h> |
8 | #include <c10/util/FunctionRef.h> |
9 | #include <c10/util/MaybeOwned.h> |
10 | #include <c10/util/SmallVector.h> |
11 | #include <c10/util/TypeCast.h> |
12 | #include <c10/util/irange.h> |
13 | |
14 | #include <array> |
15 | #include <bitset> |
16 | |
17 | C10_CLANG_DIAGNOSTIC_PUSH() |
18 | #if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32") |
19 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32" ) |
20 | #endif |
21 | #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor") |
22 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor" ) |
23 | #endif |
24 | |
25 | namespace at { |
26 | class Tensor; |
27 | class OptionalTensorRef; |
28 | using NameVector = SmallVector<Dimname, kDimVectorStaticSize>; |
29 | } // namespace at |
30 | |
31 | // TensorIterator is a helper class for element-wise operations, such as |
32 | // arithmetic, comparisons, and trigonometric functions. It handles |
33 | // broadcasting and type conversions of operands. |
34 | // |
35 | // This is inspired by NumPy's Array Iterator API (NpyIter). |
36 | // |
37 | // The files Loops.h and Loops.cuh provide functions to build kernels that |
38 | // use TensorIterator. |
39 | // |
40 | // Example: |
41 | // |
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input_a)
//     .add_input(input_b)
//     .build();
46 | // |
47 | // [MyKernel.cpp / MyKernel.cu] |
48 | // cpu_kernel(iter, [](float a, float b) { |
49 | // return a + b; |
50 | // }); |
51 | // |
52 | // gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float { |
53 | // return a + b; |
54 | // }); |
55 | // |
56 | // Note [Order of Construction] |
57 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
58 | // When setting up the tensor iterator configuration, the output Tensors |
59 | // have to be added first via |
60 | // TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs, |
61 | // the inputs can be added via |
62 | // TensorIteratorConfig::add_owned_input(at::Tensor). |
// Adding another output after inputs have been added will raise an exception.
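//
// For illustration, a minimal sketch (assuming hypothetical tensors `out`,
// `a`, and `b`):
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(out)  // outputs first
//       .add_owned_input(a)     // then inputs
//       .add_owned_input(b)
//       .build();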
64 | // |
65 | // Note [Common Dtype Computation] |
66 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
67 | // Some operations have a natural notion of a "common dtype" or |
68 | // "computation dtype" where all inputs are cast to one dtype, the |
69 | // operation is performed, and then the results are cast to all outputs. |
70 | // |
71 | // TensorIterator infers a common dtype if all inputs have the same dtype, |
72 | // and it computes one using type promotion rules on its inputs if |
73 | // promote_inputs_to_common_dtype_ is true. Attempting to query |
74 | // a common dtype otherwise will throw an exception. |
75 | // |
76 | // Note that the outputs are not considered when computing a common dtype. |
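//
// For illustration, a minimal sketch of querying the common dtype under
// promotion (`out`, `a`, and `b` are hypothetical tensors):
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(out)
//       .add_owned_input(a)   // e.g. an int tensor
//       .add_owned_input(b)   // e.g. a float tensor
//       .promote_inputs_to_common_dtype(true)
//       .build();
//   ScalarType dtype = iter.common_dtype();  // kFloat, by promotion rules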
77 | |
78 | namespace at { |
79 | |
80 | namespace internal { |
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelize over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// chunks smaller than GRAIN_SIZE.
86 | constexpr int64_t GRAIN_SIZE = 32768; |
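
// For illustration, a sketch of how GRAIN_SIZE is typically used with
// at::parallel_for (declared in ATen/Parallel.h):
//
//   at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
//       [&](int64_t begin, int64_t end) {
//         // process elements in [begin, end)
//       });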
87 | |
88 | // Storage for a non-owning Tensor, without needing to include Tensor.h |
89 | class TORCH_API OpaqueOptionalTensorRef { |
90 | alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_; |
91 | |
92 | public: |
93 | OpaqueOptionalTensorRef(); |
94 | ~OpaqueOptionalTensorRef(); |
95 | |
96 | OptionalTensorRef* get() { |
97 | return reinterpret_cast<OptionalTensorRef*>(data_.data()); |
98 | } |
99 | const OptionalTensorRef* get() const { |
100 | return reinterpret_cast<const OptionalTensorRef*>(data_.data()); |
101 | } |
102 | |
103 | OptionalTensorRef& operator*() { |
104 | return *get(); |
105 | } |
106 | const OptionalTensorRef& operator*() const { |
107 | return *get(); |
108 | } |
109 | OptionalTensorRef* operator->() { |
110 | return get(); |
111 | } |
112 | const OptionalTensorRef* operator->() const { |
113 | return get(); |
114 | } |
115 | |
116 | const Tensor& getTensor() const; |
117 | }; |
118 | } // namespace internal |
119 | |
120 | struct TORCH_API OperandInfo { |
121 | using StrideVector = SmallVector<int64_t, 6>; |
122 | OperandInfo() = default; |
123 | C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) { |
124 | if (t->defined()) { |
125 | device = t->device(); |
126 | target_dtype = t->scalar_type(); |
127 | current_dtype = target_dtype; |
128 | } |
129 | tensor(std::move(t)); |
130 | validate(); |
131 | } |
132 | |
133 | C10_ALWAYS_INLINE ~OperandInfo() = default; |
134 | |
135 | /// Stride after broadcasting. The stride is in bytes, not number of elements. |
136 | StrideVector stride_bytes; |
137 | |
  /// The desired device and type for the operand. For inputs, this specifies
  /// that the input should be converted to this type if necessary. For
  /// outputs, this specifies which type to allocate. target_dtype and device
  /// are initialized with the dtype and device of the tensor, but during type
  /// promotion target_dtype may become different from the tensor's dtype.
  /// Also, during type promotion target_dtype and device can be set for an
  /// undefined tensor so that the tensor can be properly constructed later.
145 | c10::optional<Device> device = c10::nullopt; |
146 | ScalarType target_dtype = ScalarType::Undefined; |
  // Caches the dtype of the tensor, because scalar_type() is an expensive
  // operation. If the dtype of the tensor is changed (e.g. as a result of
  // type promotion or in allocate_outputs), this value should be changed too.
151 | ScalarType current_dtype = ScalarType::Undefined; |
152 | |
153 | bool is_device_defined() const { |
154 | return device.has_value(); |
155 | } |
156 | bool is_type_defined() const { |
157 | return target_dtype != ScalarType::Undefined; |
158 | } |
159 | TensorOptions options() const { |
160 | return TensorOptions(target_dtype).device(device); |
161 | } |
162 | |
163 | /// The data pointer. This may be different from tensor->data_ptr() if the |
164 | /// iterator is split. |
165 | void* data = nullptr; |
166 | |
167 | bool is_output = false; |
168 | |
169 | bool will_resize = false; |
170 | |
171 | bool is_read_write = false; |
172 | |
173 | void validate() { |
174 | TORCH_CHECK( |
175 | !tensor_base_->defined() || tensor_base_->layout() == kStrided, |
176 | "unsupported tensor layout: " , |
177 | tensor_base_->layout()); |
178 | } |
179 | |
180 | /// The tensor operand. Note that the strides, data pointer, and |
181 | /// other attributes may differ due to dimension reordering and |
182 | /// coalescing. |
183 | const Tensor& tensor() const { |
184 | return tensor_storage_.getTensor(); |
185 | } |
186 | const TensorBase& tensor_base() const { |
187 | return *tensor_base_; |
188 | } |
189 | void tensor(c10::MaybeOwned<TensorBase>&& tensor); |
190 | |
191 | // Save the original tensor operand in cases when an output is modified |
192 | // (e.g. if dtype is changed) |
193 | const Tensor& original_tensor() const { |
194 | return original_tensor_storage_.getTensor(); |
195 | } |
196 | const TensorBase& original_tensor_base() const { |
197 | return *original_tensor_base_; |
198 | } |
199 | |
  // Set tensor to a new value, and store the old tensor value in
  // original_tensor. Should only ever be called once for the lifetime of an
  // operand.
203 | void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor); |
204 | |
  // Move original_tensor back into tensor; exchange_tensor must have been
  // called before.
207 | void restore_original_tensor(); |
208 | |
209 | private: |
210 | c10::MaybeOwned<TensorBase> tensor_base_; |
211 | c10::MaybeOwned<TensorBase> original_tensor_base_ = |
212 | c10::MaybeOwned<TensorBase>::owned(c10::in_place); |
213 | |
214 | // We store TensorBase visibly in the header to allow inline access. |
215 | // However, we sometimes need a genuine `const Tensor &` for the |
216 | // TensorIterator API. So, we also store a non-owning `Tensor` |
217 | // object in these `_storage_` variables. |
218 | internal::OpaqueOptionalTensorRef tensor_storage_; |
219 | internal::OpaqueOptionalTensorRef original_tensor_storage_; |
220 | }; |
221 | |
222 | struct SplitUntil32Bit; |
223 | |
224 | enum class FastSetupType : uint8_t { |
225 | NONE, |
226 | CONTIGUOUS, |
227 | CHANNELS_LAST, |
228 | NON_OVERLAPPING_DENSE |
229 | }; |
230 | |
231 | class TensorIteratorConfig; |
232 | struct TensorIterator; |
233 | |
234 | struct TORCH_API TensorIteratorBase : public impl::MetaBase { |
235 | using DimMask = std::bitset<64>; |
236 | using PtrVector = SmallVector<char*, 4>; |
237 | using StrideVector = SmallVector<int64_t, 6>; |
238 | |
239 | TensorIteratorBase(); |
240 | void build(TensorIteratorConfig&); |
241 | |
242 | // The inner-loop function operates on the fastest moving dimension. It |
243 | // implements element-wise operations in terms of 1-d strided tensors. |
244 | // |
245 | // Arguments: |
246 | // data: data pointers for each operand (length `ntensors`) |
247 | // strides: stride for each operand (length `ntensors`) |
248 | // size: size of inner loop |
249 | // |
250 | // The `size` often matches shape[0], but may be smaller due to |
251 | // parallelization of the inner loop. |
252 | using loop2d_t = c10::function_ref< |
253 | void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>; |
254 | |
255 | using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>; |
256 | |
257 | void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true); |
258 | |
259 | int ndim() const { |
260 | return shape_.size(); |
261 | } |
262 | IntArrayRef shape() const { |
263 | return shape_; |
264 | } |
265 | int64_t numel() const; |
266 | int ntensors() const { |
267 | return operands_.size(); |
268 | } |
269 | int noutputs() const { |
270 | return num_outputs_; |
271 | } |
272 | int ninputs() const { |
273 | return ntensors() - noutputs(); |
274 | } |
275 | IntArrayRef view_offsets() const { |
276 | return view_offsets_; |
277 | } |
278 | |
279 | /// number of elements in the output operand. this is the same as numel() for |
280 | /// operations that are not reductions. |
281 | int64_t num_output_elements() const; |
282 | |
283 | /// number of reduced dimensions in a reduction operation |
284 | int num_reduce_dims() const; |
285 | |
286 | /// 1-dimensional iteration and no buffering or type conversion |
287 | bool is_trivial_1d() const; |
288 | /// Reducible to 1-dimensional and all operands are contiguous |
289 | bool is_contiguous() const; |
290 | bool is_dim_reduced(int dim) const; |
291 | |
292 | /// Accessors for each operand |
293 | IntArrayRef strides(int arg) const { |
294 | return operands_[arg].stride_bytes; |
295 | } |
296 | void* data_ptr(int arg) const; |
297 | ScalarType dtype(int arg = 0) const { |
298 | return operands_[arg].current_dtype; |
299 | } |
300 | ScalarType common_dtype() const { |
301 | TORCH_INTERNAL_ASSERT( |
302 | common_dtype_ != ScalarType::Undefined, |
303 | "Queried for invalid common dtype!" ); |
304 | return common_dtype_; |
305 | } |
306 | ScalarType input_dtype(int arg = 0) const { |
307 | return operands_[num_outputs_ + arg].current_dtype; |
308 | } |
309 | Device device(int arg = 0) const { |
310 | return operands_[arg].device.value(); |
311 | } |
312 | DeviceType device_type(int arg = 0) const { |
313 | return device(arg).type(); |
314 | } |
315 | int64_t element_size(int arg) const { |
316 | return elementSize(dtype(arg)); |
317 | } |
318 | bool is_scalar(int arg) const; |
319 | bool is_cpu_scalar(int arg) const; |
320 | |
321 | const TensorBase& tensor_base(int arg) const { |
322 | return operands_[arg].tensor_base(); |
323 | } |
324 | const Tensor& tensor(int arg) const { |
325 | return operands_[arg].tensor(); |
326 | } |
327 | |
328 | const TensorBase& output_base(int arg = 0) const { |
329 | AT_ASSERT(arg < num_outputs_); |
330 | return tensor_base(arg); |
331 | } |
332 | |
333 | const Tensor& output(int arg = 0) const { |
334 | AT_ASSERT(arg < num_outputs_); |
335 | return tensor(arg); |
336 | } |
337 | |
338 | const TensorBase& input_base(int arg = 0) const { |
339 | AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); |
340 | return tensor_base(num_outputs_ + arg); |
341 | } |
342 | const Tensor& input(int arg = 0) const { |
343 | AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); |
344 | return tensor(num_outputs_ + arg); |
345 | } |
346 | |
347 | // Copies from temporary outputs back to the original outputs |
348 | // NOTE: only used on CPU |
349 | void cast_outputs(); |
350 | |
351 | /// Removes an operand from this iterator |
352 | void remove_operand(int arg); |
353 | /// Shrinks an iterated dimension |
354 | void narrow(int dim, int64_t start, int64_t size); |
355 | /// Narrows every dim after and including `start_dim` to size one. |
356 | void select_all_keeping_dim(int start_dim, IntArrayRef starts); |
357 | /// Replaces the data pointer for the operand at index `arg`. |
358 | /// The new pointer should have the same sizes, strides and dtype as the |
359 | /// original |
360 | void unsafe_replace_operand(int arg, void* data); |
361 | |
362 | /// Splits this TensorIterator into two iterators. Together they iterate over |
363 | /// the entire operation. Used by `with_32bit_indexing()`. |
364 | std::unique_ptr<TensorIterator> split(int dim); |
365 | |
366 | /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim] |
367 | int get_dim_to_split() const; |
368 | |
369 | template <typename T> |
370 | T scalar_value(int arg) { |
371 | auto& op = operands_[arg]; |
372 | return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data); |
373 | } |
374 | |
375 | private: |
376 | template <typename loop1d_t> |
377 | auto loop_2d_from_1d(const loop1d_t& loop) { |
378 | return |
379 | [loop, ntensor = ntensors()]( |
380 | char** base, const int64_t* strides, int64_t size0, int64_t size1) { |
381 | PtrVector data(base, base + ntensor); |
382 | const int64_t* outer_strides = &strides[ntensor]; |
383 | for (const auto i : c10::irange(size1)) { |
384 | if (i > 0) { |
385 | for (const auto arg : c10::irange(ntensor)) { |
386 | data[arg] += outer_strides[arg]; |
387 | } |
388 | } |
389 | loop(data.data(), strides, size0); |
390 | } |
391 | }; |
392 | } |
393 | |
394 | public: |
395 | template < |
396 | typename loop1d_t, |
397 | std::enable_if_t< |
398 | std::is_convertible< |
399 | loop1d_t, |
400 | c10::function_ref< |
401 | void(char**, const int64_t* strides, int64_t size)>>::value, |
402 | int> = 0> |
403 | void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) { |
404 | for_each(loop_2d_from_1d(loop), grain_size); |
405 | } |
406 | |
407 | void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); |
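
  // For illustration, a sketch of a 1-d loop passed to for_each. It assumes
  // the iterator was built with one float output followed by two float
  // inputs; data pointers are ordered outputs-first and strides are in bytes:
  //
  //   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
  //     for (int64_t i = 0; i < n; i++) {
  //       float* out = reinterpret_cast<float*>(data[0] + i * strides[0]);
  //       float* a = reinterpret_cast<float*>(data[1] + i * strides[1]);
  //       float* b = reinterpret_cast<float*>(data[2] + i * strides[2]);
  //       *out = *a + *b;
  //     }
  //   });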
408 | |
409 | void parallel_reduce(loop2d_t loop); |
410 | |
411 | template < |
412 | typename loop1d_t, |
413 | std::enable_if_t< |
414 | std::is_convertible< |
415 | loop1d_t, |
416 | c10::function_ref< |
417 | void(char**, const int64_t* strides, int64_t size)>>::value, |
418 | int> = 0> |
419 | void serial_for_each(loop1d_t loop, Range range) { |
420 | serial_for_each(loop_2d_from_1d(loop), range); |
421 | } |
422 | |
423 | void serial_for_each(loop2d_t loop, Range range) const; |
424 | |
  /// Create a strides array for a Tensor with the shape of this iterator. The
  /// parameter `element_size` specifies the size of the Tensor's data type in
  /// bytes (e.g. `4` for `float`)
428 | StrideVector compatible_stride(int element_size) const; |
429 | |
430 | /// Inverts the re-ordering done by reorder_dimensions. This can only be |
431 | /// called *before* coalesce_dimensions() is called. |
432 | DimVector invert_perm(IntArrayRef input) const; |
433 | |
  /// Reapply the same reordering as done by reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
436 | DimVector apply_perm_and_mul(IntArrayRef input, int mul) const; |
437 | |
438 | /// Helper functions for CPU iteration |
439 | StrideVector get_dim_strides(int dim) const; |
440 | StrideVector get_strides() const; |
441 | StrideVector get_inner_strides() const { |
442 | return get_dim_strides(0); |
443 | } |
444 | PtrVector get_base_ptrs() const; |
445 | |
446 | // Helper functions for advanced stride manipulations (e.g. torch.flip) |
447 | void _unsafe_set_arg_strides(const int arg, IntArrayRef strides) { |
448 | operands_[arg].stride_bytes = std::move(strides); |
449 | } |
450 | void _unsafe_set_arg_data(const int arg, void* data) { |
451 | operands_[arg].data = data; |
452 | } |
453 | |
454 | /// true if the stride computation can use 32-bit arithmetic. Used by GPU |
455 | /// kernels |
456 | bool can_use_32bit_indexing() const; |
457 | |
458 | /// An "iteratable" object that recursively splits this iterator into |
459 | /// sub-iterators that can use 32-bit indexing. |
460 | SplitUntil32Bit with_32bit_indexing() const; |
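
  // For illustration, a sketch of the typical pattern in GPU kernels
  // (`launch_my_kernel` is a hypothetical helper):
  //
  //   if (!iter.can_use_32bit_indexing()) {
  //     for (auto& sub_iter : iter.with_32bit_indexing()) {
  //       launch_my_kernel(sub_iter);
  //     }
  //     return;
  //   }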
461 | |
  /// Whether the kernel should accumulate into the output. Only relevant for
  /// CUDA reductions.
464 | bool should_accumulate() const { |
465 | return accumulate_; |
466 | } |
467 | |
468 | /// Whether this iterator produces the actual output, |
469 | /// as opposed to something that will be accumulated further. Only relevant |
470 | /// for CUDA reductions. |
471 | bool is_final_output() const { |
472 | return final_output_; |
473 | } |
474 | |
475 | bool has_contiguous_first_dim() const { |
476 | if (ndim() == 0) { |
477 | return true; |
478 | } |
479 | |
480 | int num_tensors = ntensors(); |
481 | for (const auto i : c10::irange(num_tensors)) { |
482 | if (strides(i)[0] != element_size(i)) { |
483 | return false; |
484 | } |
485 | } |
486 | return true; |
487 | } |
488 | |
489 | void set_output_raw_strided( |
490 | int64_t output_idx, |
491 | IntArrayRef sizes, |
492 | IntArrayRef strides, |
493 | TensorOptions options, |
494 | DimnameList names) override; |
495 | |
496 | #define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic) \ |
497 | maybestatic void methodname( \ |
498 | TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \ |
499 | maybestatic void methodname( \ |
500 | const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \ |
501 | maybestatic void methodname( \ |
502 | const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \ |
503 | maybestatic void methodname( \ |
504 | TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete; \ |
505 | maybestatic void methodname( \ |
506 | TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete; \ |
507 | maybestatic void methodname( \ |
508 | const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete; \ |
509 | maybestatic void methodname( \ |
510 | TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete; |
511 | |
512 | #define TORCH_DISALLOW_TEMPORARIES(methodname) \ |
513 | TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, ) |
514 | |
515 | void build_binary_float_op( |
516 | const TensorBase& out, |
517 | const TensorBase& a, |
518 | const TensorBase& b); |
519 | void build_borrowing_binary_float_op( |
520 | const TensorBase& out, |
521 | const TensorBase& a, |
522 | const TensorBase& b); |
523 | TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op) |
524 | void build_binary_op( |
525 | const TensorBase& out, |
526 | const TensorBase& a, |
527 | const TensorBase& b); |
528 | void build_borrowing_binary_op( |
529 | const TensorBase& out, |
530 | const TensorBase& a, |
531 | const TensorBase& b); |
532 | TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op) |
533 | void build_unary_float_op(const TensorBase& out, const TensorBase& a); |
534 | void build_borrowing_unary_float_op( |
535 | const TensorBase& out, |
536 | const TensorBase& a); |
537 | TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op) |
538 | void build_unary_op(const TensorBase& out, const TensorBase& a); |
539 | // Odd special case needed for pow. Has to borrow the output because |
540 | // it's a structured kernel, but the argument is potentially a copy. |
541 | void build_output_borrowing_argument_owning_unary_op( |
542 | const TensorBase& out, |
543 | const TensorBase& a); |
544 | void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a); |
545 | TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op) |
546 | void build_borrowing_unary_force_boolean_op( |
547 | const TensorBase& out, |
548 | const TensorBase& a); |
549 | TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op) |
550 | void build_comparison_op( |
551 | const TensorBase& out, |
552 | const TensorBase& a, |
553 | const TensorBase& b); |
554 | void build_borrowing_comparison_op( |
555 | const TensorBase& out, |
556 | const TensorBase& a, |
557 | const TensorBase& b); |
558 | TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op) |
559 | // Another special case: we need to own the second argument for comparison |
560 | // ops. |
561 | void build_borrowing_except_last_argument_comparison_op( |
562 | const TensorBase& out, |
563 | const TensorBase& a, |
564 | const TensorBase& b); |
565 | void build_ternary_op( |
566 | const TensorBase& out, |
567 | const TensorBase& a, |
568 | const TensorBase& b, |
569 | const TensorBase& c); |
570 | |
571 | #undef TORCH_DISALLOW_TEMPORARIES |
572 | protected: |
573 | // Mutable reference as it moves tensors out of TensorIteratorConfig |
574 | void populate_operands(TensorIteratorConfig&); |
575 | void mark_outputs(); |
576 | void mark_resize_outputs(const TensorIteratorConfig&); |
577 | void compute_mem_overlaps(const TensorIteratorConfig&); |
578 | void compute_shape(const TensorIteratorConfig&); |
579 | void compute_strides(const TensorIteratorConfig&); |
580 | void reorder_dimensions(); |
581 | void permute_dimensions(IntArrayRef perm); |
582 | void compute_types(const TensorIteratorConfig&); |
583 | ScalarType compute_common_dtype(); |
584 | void allocate_or_resize_outputs(); |
585 | bool fast_set_up(const TensorIteratorConfig&); |
586 | FastSetupType compute_fast_setup_type(const TensorIteratorConfig&); |
587 | void compute_names(const TensorIteratorConfig&); |
588 | void propagate_names_to_outputs(); |
589 | void coalesce_dimensions(); |
590 | |
591 | protected: |
592 | /// Records the "computation" shape of the output tensor. The computation |
593 | /// shape is different from the regular shape in a few ways: |
594 | /// |
595 | /// - The shape may be permuted (via permute_dimensions) so that we |
596 | /// process the dimensions in the most computationally efficient order |
597 | /// (rather than the logical order given to us by the users.) |
598 | /// - The shape may have adjacent dimensions collapsed (via |
599 | /// coalesce_dimensions) so that we minimize the number of |
600 | /// dimensions we have to explicitly iterate over. For example, |
601 | /// a pointwise operation on a contiguous tensor "computationally" |
602 | /// consists of only a single dimension. |
603 | /// |
604 | /// In other words, the computation shape is the output shape as it |
605 | /// actually matters for implementing the kernel, but not necessarily the |
606 | /// output shape that the user will see in the end. |
607 | /// |
608 | /// The lifecycle of mutations to shape_ in TensorIterator: |
609 | /// - declare_static_shape() sets an initial shape explicitly |
610 | /// provided by user, otherwise |
611 | /// - compute_shape() computes the true (non-computational) shape |
612 | /// specified by the user. |
613 | /// - reorder_dimensions() reorders dimensions to improve coalescing. |
614 | /// - coalesce_dimensions() then coalesces adjacent dimensions when |
615 | /// possible. |
616 | /// |
617 | /// The shape may also be further modified if we create sub-TensorIterators, |
618 | /// e.g., via narrow or select_all_keeping_dim. |
619 | DimVector shape_; |
620 | |
621 | /// Temporarily records the permutation computed by reorder_dimensions. |
622 | /// This permutation maps the computation output dimension (dim) to |
623 | /// the original true output dimension (perm_[dim]). It is used by |
624 | /// invert_perm to undo the permutation. After coalesce_dimensions is |
625 | /// called, the permutation is no longer valid (as, in general, there |
  /// is no permutation that will map computation dimensions to
627 | /// output dimensions); methods that manipulate perm_ are obligated |
628 | /// to test that !has_coalesced_dimensions |
629 | DimVector perm_; |
630 | |
  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_set_up())
  /// been called? This is SOLELY used to check validity of perm_.
633 | bool has_coalesced_dimensions_ = false; |
634 | |
  /// Whether a linear iteration order must be enforced. This disables
  /// dimension permuting and also changes how for_each divides work among
  /// threads.
637 | bool enforce_linear_iteration_ = false; |
638 | |
639 | /// The index offsets into the original tensors for each dimension. |
640 | /// This is only non-zero when you narrow() a TensorIterator (e.g., |
641 | /// when you make sub-TensorIterators). |
642 | DimVector view_offsets_; |
643 | |
644 | /// The computed names of the output tensor. Computed by compute_names() |
645 | NameVector names_; |
646 | |
647 | /// The operands of the TensorIterator: both the inputs and outputs. The |
648 | /// outputs MUST come first in the operands_ list. There is always an |
649 | /// operand for each output of the TensorIterator, even if TensorIterator |
650 | /// will ultimately be responsible for allocating the output; in those |
651 | /// cases, tensor is simply undefined (and will be populated later |
652 | /// during build()). |
653 | /// |
654 | /// This list is initially populated prior to build(), but build() mutates |
655 | /// OperandInfo to populate more information. |
656 | SmallVector<OperandInfo, 4> operands_; |
657 | |
658 | /// Number of outputs in operands_ (the length of the outputs prefix |
659 | /// in operands_). |
660 | int num_outputs_ = 0; |
661 | |
662 | /// Whether or not all operands have the same shape and are 1d+. Having all |
663 | /// the same shape affects whether or not the iterator is eligible for fast |
664 | /// setup. |
665 | bool all_ops_same_shape_ = false; |
666 | /// Whether or not all operands are 0d, this affects type promotion |
667 | bool all_ops_are_scalars_ = false; |
668 | |
669 | /// The "computation" dtype of TensorIterator, specifying what the dtype |
670 | /// we will do the internal computation in TensorIterator. Typically, |
671 | /// this matches the dtype of the output tensors, but not always! |
672 | ScalarType common_dtype_ = ScalarType::Undefined; |
673 | |
674 | /// This is currently defined as kCPU, or the device of the first non-CPU |
675 | /// tensor argument. See TensorIteratorBase::compute_types for details. |
676 | Device common_device_ = kCPU; |
677 | |
678 | /// Set by split(), see should_accumulate() and is_final_output() |
679 | bool accumulate_ = false; |
680 | bool final_output_ = true; |
681 | |
682 | // From TensorIteratorConfig |
683 | bool is_reduction_ = false; |
684 | |
685 | /// Set by populate_operands(), says if we're handling meta tensors |
686 | bool is_meta_ = false; |
687 | }; |
688 | |
689 | struct TORCH_API TensorIterator final : public TensorIteratorBase { |
690 | TensorIterator() : TensorIteratorBase() {} |
  // Slicing is OK: TensorIterator is guaranteed NOT to add any fields beyond
  // TensorIteratorBase.
692 | TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {} |
693 | |
694 | #define TORCH_DISALLOW_TEMPORARIES(methodname) \ |
695 | TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static) |
696 | |
697 | static TensorIterator binary_float_op( |
698 | TensorBase& out, |
699 | const TensorBase& a, |
700 | const TensorBase& b); |
701 | static TensorIterator binary_op( |
702 | TensorBase& out, |
703 | const TensorBase& a, |
704 | const TensorBase& b); |
705 | static TensorIterator borrowing_binary_op( |
706 | const TensorBase& out, |
707 | const TensorBase& a, |
708 | const TensorBase& b); |
709 | TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op) |
710 | static TensorIterator comparison_op( |
711 | TensorBase& out, |
712 | const TensorBase& a, |
713 | const TensorBase& b); |
714 | static TensorIterator unary_op(TensorBase& out, const TensorBase& a); |
715 | static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a); |
716 | static TensorIterator nullary_op(TensorBase& out); |
717 | static TensorIterator borrowing_nullary_op(const TensorBase& out); |
718 | static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete; |
719 | static TensorIterator reduce_op(TensorBase& out, const TensorBase& a); |
720 | static TensorIterator reduce_op( |
721 | TensorBase& out1, |
722 | TensorBase& out2, |
723 | const TensorBase& a); |
724 | #undef TORCH_DISALLOW_TEMPORARIES |
725 | #undef TORCH_DISALLOW_TEMPORARIES_IMPL |
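
  // For illustration, a sketch of using a factory helper together with a CPU
  // kernel (cpu_kernel is provided by Loops.h; see the note at the top of
  // this file):
  //
  //   auto iter = TensorIterator::binary_op(out, a, b);
  //   cpu_kernel(iter, [](float x, float y) { return x + y; });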
726 | |
727 | const Tensor& maybe_get_output(int64_t output_idx) override; |
728 | void set_output_raw_strided( |
729 | int64_t output_idx, |
730 | IntArrayRef sizes, |
731 | IntArrayRef strides, |
732 | TensorOptions options, |
733 | DimnameList names) override; |
734 | }; |
735 | |
736 | class TORCH_API TensorIteratorConfig final { |
737 | public: |
738 | friend struct TensorIteratorBase; |
739 | friend struct TensorIterator; |
740 | |
741 | TensorIteratorConfig() = default; |
742 | |
743 | C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig); |
744 | |
745 | /// Construction |
746 | // Stores input/output Tensors without incrementing the reference count. |
747 | // Important: the outputs have to be added before the inputs. |
748 | TensorIteratorConfig& add_output(const TensorBase& output) { |
749 | return add_borrowed_output(output); |
750 | } |
751 | TensorIteratorConfig& add_input(const TensorBase& input) { |
752 | return add_borrowed_input(input); |
753 | } |
754 | |
755 | // Borrowing from temporaries is unlikely to go well. |
756 | TensorIteratorConfig& add_output(TensorBase&& output) = delete; |
757 | TensorIteratorConfig& add_input(TensorBase&& input) = delete; |
758 | |
759 | // Stores input/output Tensors while incrementing the reference count. |
760 | // Note that add_{in,out}put are nearly always what you |
761 | // want, and the exception (adding an unnamed temporary) won't |
762 | // compile. |
763 | TensorIteratorConfig& add_owned_output(const TensorBase& output); |
764 | TensorIteratorConfig& add_owned_input(const TensorBase& input); |
765 | |
766 | // Advanced API: stores input/output Tensors without incrementing |
767 | // the reference count. The caller must ensure that these Tensors |
768 | // live at least as long as this TensorIteratorConfig and any |
769 | // TensorIteratorBase built from this TensorIteratorConfig. |
770 | // Important: the outputs have to be added before the inputs. |
771 | TensorIteratorConfig& add_borrowed_output(const TensorBase& output); |
772 | TensorIteratorConfig& add_borrowed_input(const TensorBase& input); |
773 | |
774 | // Borrowing from temporaries is unlikely to go well. |
775 | TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete; |
776 | TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete; |
777 | |
778 | // Sets the check_mem_overlap_ flag, which is true by default. |
779 | // If true, inputs are checked for partial overlap with the outputs and |
780 | // outputs are checked for internal overlap (e.g. broadcasted views). An error |
781 | // is raised if unacceptable overlap is detected. |
782 | // If you're migrating an existing operator to using TensorIterator, please |
783 | // consider if the previous implementation checked memory overlap. If it did |
784 | // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then |
785 | // checking memory overlap is BC-breaking. Please don't check memory overlap |
786 | // in that case. |
787 | TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) { |
788 | check_mem_overlap_ = check_mem_overlap; |
789 | return *this; |
790 | } |
791 | |
792 | // Sets the check_all_same_dtype_ flag, which is true by default |
793 | // If true, checks that all inputs and defined outputs have the same dtype |
794 | // Setting either of promote_inputs_to_common_dtype_ |
795 | // or cast_common_dtype_to_outputs_ to true will set |
796 | // check_all_same_dtype_ to false. |
797 | TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) { |
798 | check_all_same_dtype_ = _check_all_same_dtype; |
799 | return *this; |
800 | } |
801 | |
802 | // Sets the check_all_same_device_ flag, which is true by default |
803 | // If true, all operands must be on the same device, with the possible |
804 | // exception of CPU scalars, which can be passed to some CUDA kernels |
805 | // as kernel arguments. |
806 | TensorIteratorConfig& check_all_same_device( |
807 | const bool _check_all_same_device) { |
808 | check_all_same_device_ = _check_all_same_device; |
809 | return *this; |
810 | } |
811 | |
812 | // Sets the enforce_safe_casting_to_output_ flag, which is false by default |
813 | // If true, the iterator's "common dtype" must be computable |
814 | // (see the [Common Dtype Computation] note) and |
815 | // canCast(common dtype, output dtype) must be true for all outputs. |
816 | TensorIteratorConfig& enforce_safe_casting_to_output( |
817 | const bool _enforce_safe_casting_to_output) { |
818 | enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output; |
819 | return *this; |
820 | } |
821 | |
822 | // Sets the enforce_linear_iteration_ flag, which is false by default. |
823 | // If true, iteration goes in the same order as a C-contiguous tensor |
  // is laid out in memory, i.e. the last dimension iterates fastest.
825 | // |
826 | // This iteration order can be less efficient and may even prevent |
827 | // vectorization. So only use if the correctness of your kernel depends on it. |
828 | TensorIteratorConfig& enforce_linear_iteration( |
829 | const bool _enforce_linear_iteration = true) { |
830 | enforce_linear_iteration_ = _enforce_linear_iteration; |
831 | return *this; |
832 | } |
833 | |
834 | // Sets the promote_inputs_to_common_dtype_ flag, which is false by default |
835 | // If true, the iterator's "common dtype" is always computed (see the |
836 | // [Common Dtype Computation] note) and, on the CPU, temporary copies of |
837 | // the inputs in the common dtype are passed as the actual inputs to |
838 | // the operation. |
839 | // Setting this flag to true sets check_all_same_dtype_ to false. |
840 | TensorIteratorConfig& promote_inputs_to_common_dtype( |
841 | const bool _promote_inputs_to_common_dtype) { |
842 | promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype; |
843 | if (_promote_inputs_to_common_dtype) { |
844 | check_all_same_dtype_ = false; |
845 | } |
846 | return *this; |
847 | } |
848 | |
  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // NOTE: If set to true, promote_inputs_to_common_dtype_ must also be true.
  // If true and the iterator's "common dtype" is an integral type (including
  // bool), then it is changed to the default float scalar type.
854 | TensorIteratorConfig& promote_integer_inputs_to_float( |
855 | const bool _promote_integer_inputs_to_float) { |
856 | promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float; |
857 | TORCH_INTERNAL_ASSERT( |
858 | !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_); |
859 | return *this; |
860 | } |
861 | |
862 | TensorIteratorConfig& is_reduction(const bool _is_reduction) { |
863 | is_reduction_ = _is_reduction; |
864 | return *this; |
865 | } |
866 | |
867 | TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) { |
868 | allow_cpu_scalars_ = _allow_cpu_scalars; |
869 | return *this; |
870 | } |
871 | |
872 | // Sets the cast_common_dtype_to_outputs_ flag, which is false by default |
  // If true, the iterator's "common dtype" must be computable
874 | // (see the [Common Dtype Computation] note) and, on the CPU, temporary |
875 | // copies of the outputs are passed as the actual output to the operation. |
876 | // These temporaries are then copied to the original outputs after |
877 | // the operation is performed (see cast_outputs()). |
878 | // Setting this flag to true sets check_all_same_dtype_ to false. |
879 | TensorIteratorConfig& cast_common_dtype_to_outputs( |
880 | const bool _cast_common_dtype_to_outputs) { |
881 | cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs; |
882 | if (_cast_common_dtype_to_outputs) { |
883 | check_all_same_dtype_ = false; |
884 | } |
885 | return *this; |
886 | } |
887 | |
888 | TensorIteratorConfig& resize_outputs(bool resize_outputs) { |
889 | resize_outputs_ = resize_outputs; |
890 | return *this; |
891 | } |
892 | |
893 | // Bypass output dtype/device computation and fix the dtype/device as |
894 | // specified here. |
895 | TensorIteratorConfig& declare_static_dtype_and_device( |
896 | ScalarType dtype, |
897 | Device device); |
898 | TensorIteratorConfig& declare_static_dtype(ScalarType dtype); |
899 | TensorIteratorConfig& declare_static_device(Device device); |
900 | TensorIteratorConfig& declare_static_shape(IntArrayRef shape); |
901 | TensorIteratorConfig& declare_static_shape( |
902 | IntArrayRef shape, |
903 | IntArrayRef squash_dims); |
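
  // For illustration, a sketch of fixing the output dtype regardless of the
  // inputs (`out` and `self` are hypothetical tensors):
  //
  //   auto iter = TensorIteratorConfig()
  //       .add_owned_output(out)
  //       .add_owned_input(self)
  //       .declare_static_dtype(at::kBool)
  //       .build();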
904 | |
  // It would be better if this were && qualified, but this would be at the
  // cost of a lot of boilerplate above.
907 | TensorIterator build() { |
908 | TensorIterator iter; |
909 | iter.build(*this); |
910 | return iter; |
911 | } |
912 | |
913 | private: |
914 | SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_; |
915 | int num_outputs_ = 0; |
916 | int num_inputs_ = 0; |
917 | |
918 | c10::optional<DimVector> static_shape_ = c10::nullopt; |
919 | c10::optional<ScalarType> static_dtype_ = c10::nullopt; |
920 | c10::optional<Device> static_device_ = c10::nullopt; |
921 | bool check_mem_overlap_ = true; |
922 | bool allow_cpu_scalars_ = false; |
923 | bool is_reduction_ = false; |
924 | bool resize_outputs_ = true; |
925 | bool check_all_same_dtype_ = true; |
926 | bool check_all_same_device_ = true; |
927 | bool enforce_safe_casting_to_output_ = false; |
928 | bool enforce_linear_iteration_ = false; |
929 | bool promote_inputs_to_common_dtype_ = false; |
930 | bool promote_integer_inputs_to_float_ = false; |
931 | bool cast_common_dtype_to_outputs_ = false; |
932 | }; |
933 | |
934 | /// A container-like struct that acts as if it contains splits of a |
935 | /// TensorIterator that can use 32-bit indexing. Taken together the splits cover |
936 | /// the original TensorIterator. |
937 | struct TORCH_API SplitUntil32Bit { |
938 | struct TORCH_API iterator { |
939 | iterator() = default; |
940 | iterator(const TensorIteratorBase& iter); |
941 | iterator(iterator&&) = default; |
942 | |
943 | // Guaranteed to be a TensorIterator proper! |
944 | TensorIterator& operator*() const; |
945 | iterator& operator++(); |
946 | bool operator==(const iterator& other) const { |
947 | // two iterators are equal if they are the same object or they're both |
948 | // empty |
949 | return this == &other || (vec.empty() && other.vec.empty()); |
950 | } |
951 | // needed for C++11 range-based for loop |
952 | bool operator!=(const iterator& other) const { |
953 | return !(*this == other); |
954 | } |
955 | |
956 | /// stack of TensorIterators to be split |
957 | std::vector<std::unique_ptr<TensorIterator>> vec; |
958 | }; |
959 | |
960 | SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {} |
961 | |
962 | iterator begin() const; |
963 | iterator end() const; |
964 | |
965 | private: |
966 | const TensorIteratorBase& iter; |
967 | }; |
968 | |
969 | } // namespace at |
970 | |
971 | C10_CLANG_DIAGNOSTIC_POP() |
972 | |