#pragma once

#include <ATen/TensorMeta.h>
#include <ATen/core/Dimname.h>
#include <ATen/core/Range.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/DynamicCast.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <c10/util/irange.h>

#include <array>
#include <bitset>

C10_CLANG_DIAGNOSTIC_PUSH()
#if C10_CLANG_HAS_WARNING("-Wshorten-64-to-32")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wshorten-64-to-32")
#endif
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
#endif

namespace at {
class Tensor;
class OptionalTensorRef;
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
} // namespace at

// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input_a)
//     .add_input(input_b)
//     .build();
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Order of Construction]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When setting up the tensor iterator configuration, the output Tensors
// have to be added first via
// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs,
// the inputs can be added via
// TensorIteratorConfig::add_owned_input(at::Tensor).
// Adding another output after inputs have been added will raise an exception.
//
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
// "computation dtype" where all inputs are cast to one dtype, the
// operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
// and it computes one using type promotion rules on its inputs if
// promote_inputs_to_common_dtype_ is true. Attempting to query
// a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.
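//
// For illustration, a minimal sketch of an op that computes in a promoted
// dtype (the tensor names here are hypothetical):
//
//   // int_a: a kInt tensor, float_b: a kFloat tensor
//   auto iter = TensorIteratorConfig()
//     .add_output(out)
//     .add_input(int_a)
//     .add_input(float_b)
//     .promote_inputs_to_common_dtype(true)
//     .build();
//   // With promotion enabled, the common dtype is computable; here it would
//   // be kFloat under the usual type promotion rules.
//   ScalarType compute_dtype = iter.common_dtype();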

namespace at {

namespace internal {
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr int64_t GRAIN_SIZE = 32768;
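// As an illustration of how GRAIN_SIZE is typically consumed (a sketch, not
// part of this header): CPU loops pass it as the grain size of
// at::parallel_for, which only splits the range across threads when the range
// exceeds the grain.
//
//   at::parallel_for(0, numel, GRAIN_SIZE, [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; i++) {
//       out[i] = a[i] + b[i];  // hypothetical raw pointers
//     }
//   });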

// Storage for a non-owning Tensor, without needing to include Tensor.h
class TORCH_API OpaqueOptionalTensorRef {
  alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_;

 public:
  OpaqueOptionalTensorRef();
  ~OpaqueOptionalTensorRef();

  OptionalTensorRef* get() {
    return reinterpret_cast<OptionalTensorRef*>(data_.data());
  }
  const OptionalTensorRef* get() const {
    return reinterpret_cast<const OptionalTensorRef*>(data_.data());
  }

  OptionalTensorRef& operator*() {
    return *get();
  }
  const OptionalTensorRef& operator*() const {
    return *get();
  }
  OptionalTensorRef* operator->() {
    return get();
  }
  const OptionalTensorRef* operator->() const {
    return get();
  }

  const Tensor& getTensor() const;
};
} // namespace internal

struct TORCH_API OperandInfo {
  using StrideVector = SmallVector<int64_t, 6>;
  OperandInfo() = default;
  C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) {
    if (t->defined()) {
      device = t->device();
      target_dtype = t->scalar_type();
      current_dtype = target_dtype;
    }
    tensor(std::move(t));
    validate();
  }

  C10_ALWAYS_INLINE ~OperandInfo() = default;

  /// Stride after broadcasting. The stride is in bytes, not number of
  /// elements.
  StrideVector stride_bytes;

  /// The desired device and type for the operand. For inputs, this specifies
  /// that the input should be converted to this type if necessary. For
  /// outputs, this specifies which type to allocate. target_dtype and device
  /// are initialized with the dtype and device of the tensor, but during type
  /// promotion target_dtype may become different from the tensor's dtype.
  /// Also, during type promotion target_dtype and device can be set for an
  /// undefined tensor so that the tensor can be properly constructed later.
  c10::optional<Device> device = c10::nullopt;
  ScalarType target_dtype = ScalarType::Undefined;
  // Caches the dtype of the tensor, because scalar_type() is an expensive
  // operation. If the dtype of the tensor is changed (e.g. as a result of
  // type promotion or in allocate_outputs), this value should be changed too.
  ScalarType current_dtype = ScalarType::Undefined;

  bool is_device_defined() const {
    return device.has_value();
  }
  bool is_type_defined() const {
    return target_dtype != ScalarType::Undefined;
  }
  TensorOptions options() const {
    return TensorOptions(target_dtype).device(device);
  }

  /// The data pointer. This may be different from tensor->data_ptr() if the
  /// iterator is split.
  void* data = nullptr;

  bool is_output = false;

  bool will_resize = false;

  bool is_read_write = false;

  void validate() {
    TORCH_CHECK(
        !tensor_base_->defined() || tensor_base_->layout() == kStrided,
        "unsupported tensor layout: ",
        tensor_base_->layout());
  }

  /// The tensor operand. Note that the strides, data pointer, and
  /// other attributes may differ due to dimension reordering and
  /// coalescing.
  const Tensor& tensor() const {
    return tensor_storage_.getTensor();
  }
  const TensorBase& tensor_base() const {
    return *tensor_base_;
  }
  void tensor(c10::MaybeOwned<TensorBase>&& tensor);

  // Saves the original tensor operand in cases when an output is modified
  // (e.g. if its dtype is changed)
  const Tensor& original_tensor() const {
    return original_tensor_storage_.getTensor();
  }
  const TensorBase& original_tensor_base() const {
    return *original_tensor_base_;
  }

  // Sets tensor to a new value, storing the old tensor value in
  // original_tensor. Should only ever be called once for the lifetime of an
  // operand.
  void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor);

  // Moves original_tensor back into tensor; exchange_tensor must have been
  // called beforehand.
  void restore_original_tensor();
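
  // For illustration, a sketch of the intended exchange/restore flow (the
  // helper name is hypothetical; the real call sites live in
  // TensorIteratorBase, e.g. when casting outputs to a common dtype):
  //
  //   op.exchange_tensor(
  //       c10::MaybeOwned<TensorBase>::owned(make_temp_in_common_dtype()));
  //   ... // run the kernel against the temporary
  //   op.original_tensor().copy_(op.tensor());  // write results back
  //   op.restore_original_tensor();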

 private:
  c10::MaybeOwned<TensorBase> tensor_base_;
  c10::MaybeOwned<TensorBase> original_tensor_base_ =
      c10::MaybeOwned<TensorBase>::owned(c10::in_place);

  // We store TensorBase visibly in the header to allow inline access.
  // However, we sometimes need a genuine `const Tensor &` for the
  // TensorIterator API. So, we also store a non-owning `Tensor`
  // object in these `_storage_` variables.
  internal::OpaqueOptionalTensorRef tensor_storage_;
  internal::OpaqueOptionalTensorRef original_tensor_storage_;
};

struct SplitUntil32Bit;

enum class FastSetupType : uint8_t {
  NONE,
  CONTIGUOUS,
  CHANNELS_LAST,
  NON_OVERLAPPING_DENSE
};

class TensorIteratorConfig;
struct TensorIterator;

struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // The inner-loop function operates on the fastest moving dimension. It
  // implements element-wise operations in terms of 1-d strided tensors.
  //
  // Arguments:
  //  data: data pointers for each operand (length `ntensors`)
  //  strides: stride in bytes for each operand (length `ntensors`)
  //  size: size of inner loop
  //
  // The `size` often matches shape[0], but may be smaller due to
  // parallelization of the inner loop.
  //
  // For the 2-d variant below, the strides array holds the `ntensors` inner
  // strides followed by the `ntensors` outer strides, and size0/size1 are the
  // inner and outer loop counts (see loop_2d_from_1d).
  using loop2d_t = c10::function_ref<
      void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;

  using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;

  void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true);

  int ndim() const {
    return shape_.size();
  }
  IntArrayRef shape() const {
    return shape_;
  }
  int64_t numel() const;
  int ntensors() const {
    return operands_.size();
  }
  int noutputs() const {
    return num_outputs_;
  }
  int ninputs() const {
    return ntensors() - noutputs();
  }
  IntArrayRef view_offsets() const {
    return view_offsets_;
  }

  /// Number of elements in the output operand. This is the same as numel() for
  /// operations that are not reductions.
  int64_t num_output_elements() const;

  /// Number of reduced dimensions in a reduction operation
  int num_reduce_dims() const;

  /// 1-dimensional iteration and no buffering or type conversion
  bool is_trivial_1d() const;
  /// Reducible to 1-dimensional and all operands are contiguous
  bool is_contiguous() const;
  bool is_dim_reduced(int dim) const;

  /// Accessors for each operand
  IntArrayRef strides(int arg) const {
    return operands_[arg].stride_bytes;
  }
  void* data_ptr(int arg) const;
  ScalarType dtype(int arg = 0) const {
    return operands_[arg].current_dtype;
  }
  ScalarType common_dtype() const {
    TORCH_INTERNAL_ASSERT(
        common_dtype_ != ScalarType::Undefined,
        "Queried for invalid common dtype!");
    return common_dtype_;
  }
  ScalarType input_dtype(int arg = 0) const {
    return operands_[num_outputs_ + arg].current_dtype;
  }
  Device device(int arg = 0) const {
    return operands_[arg].device.value();
  }
  DeviceType device_type(int arg = 0) const {
    return device(arg).type();
  }
  int64_t element_size(int arg) const {
    return elementSize(dtype(arg));
  }
  bool is_scalar(int arg) const;
  bool is_cpu_scalar(int arg) const;

  const TensorBase& tensor_base(int arg) const {
    return operands_[arg].tensor_base();
  }
  const Tensor& tensor(int arg) const {
    return operands_[arg].tensor();
  }

  const TensorBase& output_base(int arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor_base(arg);
  }

  const Tensor& output(int arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor(arg);
  }

  const TensorBase& input_base(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor_base(num_outputs_ + arg);
  }
  const Tensor& input(int arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor(num_outputs_ + arg);
  }

  // Copies from temporary outputs back to the original outputs
  // NOTE: only used on CPU
  void cast_outputs();

  /// Removes an operand from this iterator
  void remove_operand(int arg);
  /// Shrinks an iterated dimension
  void narrow(int dim, int64_t start, int64_t size);
  /// Narrows every dim after and including `start_dim` to size one.
  void select_all_keeping_dim(int start_dim, IntArrayRef starts);
  /// Replaces the data pointer for the operand at index `arg`.
  /// The new pointer should have the same sizes, strides and dtype as the
  /// original.
  void unsafe_replace_operand(int arg, void* data);

  /// Splits this TensorIterator into two iterators. Together they iterate over
  /// the entire operation. Used by `with_32bit_indexing()`.
  std::unique_ptr<TensorIterator> split(int dim);

  /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
  int get_dim_to_split() const;

  template <typename T>
  T scalar_value(int arg) {
    auto& op = operands_[arg];
    return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data);
  }
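
  // For illustration, a sketch of how a kernel might special-case a CPU
  // scalar operand (the surrounding kernel machinery is hypothetical; a
  // similar pattern appears in the CUDA binary kernel helpers):
  //
  //   if (iter.is_cpu_scalar(1)) {
  //     auto scalar = iter.scalar_value<float>(1);
  //     iter.remove_operand(1);
  //     // launch a kernel that closes over `scalar` and reads only the
  //     // remaining operands
  //   }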

 private:
  template <typename loop1d_t>
  auto loop_2d_from_1d(const loop1d_t& loop) {
    return
        [loop, ntensor = ntensors()](
            char** base, const int64_t* strides, int64_t size0, int64_t size1) {
          PtrVector data(base, base + ntensor);
          const int64_t* outer_strides = &strides[ntensor];
          for (const auto i : c10::irange(size1)) {
            if (i > 0) {
              for (const auto arg : c10::irange(ntensor)) {
                data[arg] += outer_strides[arg];
              }
            }
            loop(data.data(), strides, size0);
          }
        };
  }

 public:
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>::value,
          int> = 0>
  void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) {
    for_each(loop_2d_from_1d(loop), grain_size);
  }

  void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);
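
  // For illustration, a minimal sketch of driving the iterator with a raw
  // 1-d loop (this assumes both operands are float; real kernels dispatch on
  // dtype(), e.g. via cpu_kernel in Loops.h):
  //
  //   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
  //     char* out = data[0];
  //     char* in = data[1];
  //     for (int64_t i = 0; i < n; i++) {
  //       *reinterpret_cast<float*>(out + i * strides[0]) =
  //           *reinterpret_cast<float*>(in + i * strides[1]) * 2.0f;
  //     }
  //   });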

  void parallel_reduce(loop2d_t loop);

  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>::value,
          int> = 0>
  void serial_for_each(loop1d_t loop, Range range) {
    serial_for_each(loop_2d_from_1d(loop), range);
  }

  void serial_for_each(loop2d_t loop, Range range) const;

  /// Create a strides array for a Tensor with shape of this iterator. The
  /// parameter `element_size` specifies the size of Tensor's data type in
  /// bytes (e.g. `4` for `float`)
  StrideVector compatible_stride(int element_size) const;

  /// Inverts the re-ordering done by reorder_dimensions. This can only be
  /// called *before* coalesce_dimensions() is called.
  DimVector invert_perm(IntArrayRef input) const;

  /// Reapplies the same re-ordering done by reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
  DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;

  /// Helper functions for CPU iteration
  StrideVector get_dim_strides(int dim) const;
  StrideVector get_strides() const;
  StrideVector get_inner_strides() const {
    return get_dim_strides(0);
  }
  PtrVector get_base_ptrs() const;

  // Helper functions for advanced stride manipulations (e.g. torch.flip)
  void _unsafe_set_arg_strides(const int arg, IntArrayRef strides) {
    operands_[arg].stride_bytes = std::move(strides);
  }
  void _unsafe_set_arg_data(const int arg, void* data) {
    operands_[arg].data = data;
  }

  /// true if the stride computation can use 32-bit arithmetic. Used by GPU
  /// kernels
  bool can_use_32bit_indexing() const;

  /// An "iterable" object that recursively splits this iterator into
  /// sub-iterators that can use 32-bit indexing.
  SplitUntil32Bit with_32bit_indexing() const;
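
  // For illustration, the typical usage pattern in CUDA kernels (a sketch;
  // `launch_my_kernel` is hypothetical):
  //
  //   if (!iter.can_use_32bit_indexing()) {
  //     for (auto& sub_iter : iter.with_32bit_indexing()) {
  //       launch_my_kernel(sub_iter);
  //     }
  //     return;
  //   }
  //   launch_my_kernel(iter);  // safe to index with int32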

  /// If the kernel should accumulate into the output. Only relevant for CUDA
  /// reductions.
  bool should_accumulate() const {
    return accumulate_;
  }

  /// Whether this iterator produces the actual output,
  /// as opposed to something that will be accumulated further. Only relevant
  /// for CUDA reductions.
  bool is_final_output() const {
    return final_output_;
  }

  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;
    }

    int num_tensors = ntensors();
    for (const auto i : c10::irange(num_tensors)) {
      if (strides(i)[0] != element_size(i)) {
        return false;
      }
    }
    return true;
  }

  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;

#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic) \
  maybestatic void methodname( \
      TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \
  maybestatic void methodname( \
      const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname( \
      const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname( \
      TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname( \
      TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname( \
      const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete; \
  maybestatic void methodname( \
      TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete;

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, )

  void build_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op)
  void build_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
  void build_unary_float_op(const TensorBase& out, const TensorBase& a);
  void build_borrowing_unary_float_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
  void build_unary_op(const TensorBase& out, const TensorBase& a);
  // Odd special case needed for pow. Has to borrow the output because
  // it's a structured kernel, but the argument is potentially a copy.
  void build_output_borrowing_argument_owning_unary_op(
      const TensorBase& out,
      const TensorBase& a);
  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
  void build_borrowing_unary_force_boolean_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
  void build_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
  // Another special case: we need to own the second argument for comparison
  // ops.
  void build_borrowing_except_last_argument_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_ternary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b,
      const TensorBase& c);
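
  // For illustration, a sketch of how a structured kernel's meta function
  // typically uses one of the build_* helpers above (the op name and body
  // here are hypothetical, not part of this header):
  //
  //   TORCH_META_FUNC(my_binary_op) (const Tensor& self, const Tensor& other) {
  //     build_borrowing_binary_op(maybe_get_output(), self, other);
  //   }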

#undef TORCH_DISALLOW_TEMPORARIES
 protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&);
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&);
  void compute_mem_overlaps(const TensorIteratorConfig&);
  void compute_shape(const TensorIteratorConfig&);
  void compute_strides(const TensorIteratorConfig&);
  void reorder_dimensions();
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&);
  ScalarType compute_common_dtype();
  void allocate_or_resize_outputs();
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

 protected:
  /// Records the "computation" shape of the output tensor. The computation
  /// shape is different from the regular shape in a few ways:
  ///
  ///   - The shape may be permuted (via permute_dimensions) so that we
  ///     process the dimensions in the most computationally efficient order
  ///     (rather than the logical order given to us by the users.)
  ///   - The shape may have adjacent dimensions collapsed (via
  ///     coalesce_dimensions) so that we minimize the number of
  ///     dimensions we have to explicitly iterate over. For example,
  ///     a pointwise operation on a contiguous tensor "computationally"
  ///     consists of only a single dimension.
  ///
  /// In other words, the computation shape is the output shape as it
  /// actually matters for implementing the kernel, but not necessarily the
  /// output shape that the user will see in the end.
  ///
  /// The lifecycle of mutations to shape_ in TensorIterator:
  ///   - declare_static_shape() sets an initial shape explicitly
  ///     provided by the user; otherwise
  ///   - compute_shape() computes the true (non-computational) shape
  ///     specified by the user.
  ///   - reorder_dimensions() reorders dimensions to improve coalescing.
  ///   - coalesce_dimensions() then coalesces adjacent dimensions when
  ///     possible.
  ///
  /// The shape may also be further modified if we create sub-TensorIterators,
  /// e.g., via narrow or select_all_keeping_dim.
  DimVector shape_;

  /// Temporarily records the permutation computed by reorder_dimensions.
  /// This permutation maps the computation output dimension (dim) to
  /// the original true output dimension (perm_[dim]). It is used by
  /// invert_perm to undo the permutation. After coalesce_dimensions is
  /// called, the permutation is no longer valid (as, in general, there
  /// is no permutation that maps computation dimensions to
  /// output dimensions); methods that manipulate perm_ are obligated
  /// to test that !has_coalesced_dimensions
  DimVector perm_;

  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_set_up())
  /// been called? This is SOLELY used to check validity of perm_.
  bool has_coalesced_dimensions_ = false;

  /// Whether iteration must be fixed. This disables dimension permuting and
  /// also changes how for_each divides work among threads.
  bool enforce_linear_iteration_ = false;

  /// The index offsets into the original tensors for each dimension.
  /// This is only non-zero when you narrow() a TensorIterator (e.g.,
  /// when you make sub-TensorIterators).
  DimVector view_offsets_;

  /// The computed names of the output tensor. Computed by compute_names()
  NameVector names_;

  /// The operands of the TensorIterator: both the inputs and outputs. The
  /// outputs MUST come first in the operands_ list. There is always an
  /// operand for each output of the TensorIterator, even if TensorIterator
  /// will ultimately be responsible for allocating the output; in those
  /// cases, tensor is simply undefined (and will be populated later
  /// during build()).
  ///
  /// This list is initially populated prior to build(), but build() mutates
  /// OperandInfo to populate more information.
  SmallVector<OperandInfo, 4> operands_;

  /// Number of outputs in operands_ (the length of the outputs prefix
  /// in operands_).
  int num_outputs_ = 0;

  /// Whether or not all operands have the same shape and are 1d+. Having all
  /// the same shape affects whether or not the iterator is eligible for fast
  /// setup.
  bool all_ops_same_shape_ = false;
  /// Whether or not all operands are 0d; this affects type promotion
  bool all_ops_are_scalars_ = false;

  /// The "computation" dtype of TensorIterator, specifying the dtype in which
  /// we will do the internal computation. Typically,
  /// this matches the dtype of the output tensors, but not always!
  ScalarType common_dtype_ = ScalarType::Undefined;

  /// This is currently defined as kCPU, or the device of the first non-CPU
  /// tensor argument. See TensorIteratorBase::compute_types for details.
  Device common_device_ = kCPU;

  /// Set by split(), see should_accumulate() and is_final_output()
  bool accumulate_ = false;
  bool final_output_ = true;

  // From TensorIteratorConfig
  bool is_reduction_ = false;

  /// Set by populate_operands(), says if we're handling meta tensors
  bool is_meta_ = false;
};

struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  static TensorIterator binary_float_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator binary_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
  static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
  static TensorIterator nullary_op(TensorBase& out);
  static TensorIterator borrowing_nullary_op(const TensorBase& out);
  static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
  static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
  static TensorIterator reduce_op(
      TensorBase& out1,
      TensorBase& out2,
      const TensorBase& a);
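
  // For illustration, a sketch of using one of the factory helpers above for
  // a non-structured op (`my_cpu_stub` is hypothetical; real code would
  // dispatch through a DispatchStub or a cpu_kernel/gpu_kernel helper):
  //
  //   Tensor out = at::empty_like(self);
  //   auto iter = TensorIterator::binary_op(out, self, other);
  //   my_cpu_stub(iter.device_type(), iter);
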
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;
};

class TORCH_API TensorIteratorConfig final {
 public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() = default;

  C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig);

  /// Construction
  // Stores input/output Tensors without incrementing the reference count.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_output(const TensorBase& output) {
    return add_borrowed_output(output);
  }
  TensorIteratorConfig& add_input(const TensorBase& input) {
    return add_borrowed_input(input);
  }

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_input(TensorBase&& input) = delete;

  // Stores input/output Tensors while incrementing the reference count.
  // Note that add_{in,out}put are nearly always what you
  // want, and the exception (adding an unnamed temporary) won't
  // compile.
  TensorIteratorConfig& add_owned_output(const TensorBase& output);
  TensorIteratorConfig& add_owned_input(const TensorBase& input);

  // Advanced API: stores input/output Tensors without incrementing
  // the reference count. The caller must ensure that these Tensors
  // live at least as long as this TensorIteratorConfig and any
  // TensorIteratorBase built from this TensorIteratorConfig.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_borrowed_output(const TensorBase& output);
  TensorIteratorConfig& add_borrowed_input(const TensorBase& input);

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete;

  // Sets the check_mem_overlap_ flag, which is true by default.
  // If true, inputs are checked for partial overlap with the outputs and
  // outputs are checked for internal overlap (e.g. broadcasted views). An
  // error is raised if unacceptable overlap is detected.
  // If you're migrating an existing operator to using TensorIterator, please
  // consider if the previous implementation checked memory overlap. If it did
  // not, and if the operator is idempotent (for example, Tensor.fill_(0)),
  // then checking memory overlap is BC-breaking. Please don't check memory
  // overlap in that case.
  TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
    check_mem_overlap_ = check_mem_overlap;
    return *this;
  }

  // Sets the check_all_same_dtype_ flag, which is true by default.
  // If true, checks that all inputs and defined outputs have the same dtype.
  // Setting either of promote_inputs_to_common_dtype_
  // or cast_common_dtype_to_outputs_ to true will set
  // check_all_same_dtype_ to false.
  TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

  // Sets the check_all_same_device_ flag, which is true by default.
  // If true, all operands must be on the same device, with the possible
  // exception of CPU scalars, which can be passed to some CUDA kernels
  // as kernel arguments.
  TensorIteratorConfig& check_all_same_device(
      const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

  // Sets the enforce_safe_casting_to_output_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and
  // canCast(common dtype, output dtype) must be true for all outputs.
  TensorIteratorConfig& enforce_safe_casting_to_output(
      const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }

  // Sets the enforce_linear_iteration_ flag, which is false by default.
  // If true, iteration goes in the same order as a C-contiguous tensor
  // is laid out in memory, i.e. the last dimension iterates fastest.
  //
  // This iteration order can be less efficient and may even prevent
  // vectorization. So only use if the correctness of your kernel depends on it.
  TensorIteratorConfig& enforce_linear_iteration(
      const bool _enforce_linear_iteration = true) {
    enforce_linear_iteration_ = _enforce_linear_iteration;
    return *this;
  }

  // Sets the promote_inputs_to_common_dtype_ flag, which is false by default.
  // If true, the iterator's "common dtype" is always computed (see the
  // [Common Dtype Computation] note) and, on the CPU, temporary copies of
  // the inputs in the common dtype are passed as the actual inputs to
  // the operation.
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& promote_inputs_to_common_dtype(
      const bool _promote_inputs_to_common_dtype) {
    promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype;
    if (_promote_inputs_to_common_dtype) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // NOTE: If set to true, promote_inputs_to_common_dtype_ must also be true.
  // If true and the iterator's "common dtype" is an integral type (including
  // bool), the common dtype is changed to the default float scalar type
  // (as used, for example, by true division of integer tensors).
  TensorIteratorConfig& promote_integer_inputs_to_float(
      const bool _promote_integer_inputs_to_float) {
    promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float;
    TORCH_INTERNAL_ASSERT(
        !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_);
    return *this;
  }

  TensorIteratorConfig& is_reduction(const bool _is_reduction) {
    is_reduction_ = _is_reduction;
    return *this;
  }

  TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) {
    allow_cpu_scalars_ = _allow_cpu_scalars;
    return *this;
  }

  // Sets the cast_common_dtype_to_outputs_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and, on the CPU, temporary
  // copies of the outputs are passed as the actual output to the operation.
  // These temporaries are then copied to the original outputs after
  // the operation is performed (see cast_outputs()).
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& cast_common_dtype_to_outputs(
      const bool _cast_common_dtype_to_outputs) {
    cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs;
    if (_cast_common_dtype_to_outputs) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  TensorIteratorConfig& resize_outputs(bool resize_outputs) {
    resize_outputs_ = resize_outputs;
    return *this;
  }

  // Bypass output dtype/device computation and fix the dtype/device as
  // specified here.
  TensorIteratorConfig& declare_static_dtype_and_device(
      ScalarType dtype,
      Device device);
  TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
  TensorIteratorConfig& declare_static_device(Device device);
  TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
  TensorIteratorConfig& declare_static_shape(
      IntArrayRef shape,
      IntArrayRef squash_dims);
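
  // For illustration, a sketch of fixing the iterated shape up front while
  // excluding one dimension from iteration so the kernel can loop over it
  // itself (the tensors and `dim` here are hypothetical; kernels that work
  // "per row", such as softmax-style kernels, use a pattern like this):
  //
  //   auto iter = TensorIteratorConfig()
  //     .add_output(out)
  //     .add_input(self)
  //     .resize_outputs(false)
  //     .declare_static_shape(self.sizes(), /*squash_dims=*/{dim})
  //     .build();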

  // It would be better if this was && qualified, but this would be at the cost
  // of a lot of boilerplate above
  TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

 private:
  SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_;
  int num_outputs_ = 0;
  int num_inputs_ = 0;

  c10::optional<DimVector> static_shape_ = c10::nullopt;
  c10::optional<ScalarType> static_dtype_ = c10::nullopt;
  c10::optional<Device> static_device_ = c10::nullopt;
  bool check_mem_overlap_ = true;
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false;
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true;
  bool check_all_same_device_ = true;
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false;
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false;
};

/// A container-like struct that acts as if it contains splits of a
/// TensorIterator that can use 32-bit indexing. Taken together the splits
/// cover the original TensorIterator.
struct TORCH_API SplitUntil32Bit {
  struct TORCH_API iterator {
    iterator() = default;
    iterator(const TensorIteratorBase& iter);
    iterator(iterator&&) = default;

    // Guaranteed to be a TensorIterator proper!
    TensorIterator& operator*() const;
    iterator& operator++();
    bool operator==(const iterator& other) const {
      // two iterators are equal if they are the same object or they're both
      // empty
      return this == &other || (vec.empty() && other.vec.empty());
    }
    // needed for C++11 range-based for loop
    bool operator!=(const iterator& other) const {
      return !(*this == other);
    }

    /// stack of TensorIterators to be split
    std::vector<std::unique_ptr<TensorIterator>> vec;
  };

  SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  const TensorIteratorBase& iter;
};

} // namespace at

C10_CLANG_DIAGNOSTIC_POP()