1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2#define TORCH_ASSERT_NO_OPERATORS
3#include <ATen/TensorIterator.h>
4#undef TORCH_ASSERT_NO_OPERATORS
5
6#include <ATen/core/Tensor.h>
7
8#include <ATen/ExpandUtils.h>
9#include <ATen/Parallel.h>
10#include <ATen/native/TypeProperties.h>
11#include <ATen/MemoryOverlap.h>
12#include <ATen/native/Resize.h>
13#include <ATen/NamedTensorUtils.h>
14#include <ATen/TensorOperators.h>
15#include <ATen/TensorIteratorInternal.h>
16
17#ifndef AT_PER_OPERATOR_HEADERS
18#include <ATen/Functions.h>
19#else
20#include <ATen/ops/empty.h>
21#include <ATen/ops/empty_strided.h>
22#endif
23
24#include <c10/util/irange.h>
25#include <c10/util/SmallBuffer.h>
26
27#include <array>
28#include <algorithm>
29#include <cmath>
30
31namespace at {
32
33using DimMask = TensorIteratorBase::DimMask;
34using PtrVector = TensorIteratorBase::PtrVector;
35using loop2d_t = TensorIteratorBase::loop2d_t;
36using StrideVector = TensorIteratorBase::StrideVector;
37
38namespace {
39
40inline void get_base_ptrs(char** ptrs, ArrayRef<OperandInfo> operands) {
41 std::transform(operands.begin(), operands.end(), ptrs, [](const OperandInfo& op) {
42 return static_cast<char*>(op.data);
43 });
44}
45
46inline void get_strides(int64_t* strides, ArrayRef<OperandInfo> operands, int64_t ndim) {
47 for (const auto dim : c10::irange(ndim)) {
48 for (const auto arg : c10::irange(operands.size())) {
49 *strides++ = operands[arg].stride_bytes[dim];
50 }
51 }
52 // Always at least 2d strides to support 2d for_each loops
53 if (ndim < 2) {
54 const int64_t ntensors = operands.size();
55 std::fill_n(strides, (2 - ndim) * ntensors, 0);
56 }
57}
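
// The strides written above are interleaved per dimension: for operands A and B
// and dimensions d0 and d1, the layout is {A.d0, B.d0, A.d1, B.d1}.
// Illustrative example (values made up): two operands with byte strides
// {4, 24} and {0, 4} are emitted as {4, 0, 24, 4}.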
58
59static OptionalTensorRef make_otr(const TensorBase &tensor) {
60 if (tensor.defined()) {
61 return OptionalTensorRef(tensor);
62 } else {
63 return OptionalTensorRef();
64 }
65}
66
} // namespace
68
69namespace internal {
70
71OpaqueOptionalTensorRef::OpaqueOptionalTensorRef() {
72 static_assert(alignof(OptionalTensorRef) == alignof(TensorBase));
73 static_assert(sizeof(OptionalTensorRef) == sizeof(TensorBase));
74 new (data_.data()) OptionalTensorRef();
75}
76
77OpaqueOptionalTensorRef::~OpaqueOptionalTensorRef() {
78 get()->~OptionalTensorRef();
79}
80
81const Tensor& OpaqueOptionalTensorRef::getTensor() const {
82 return get()->getTensorRef();
83}
84
} // namespace internal
86
87void OperandInfo::tensor(c10::MaybeOwned<TensorBase> &&tensor) {
88 tensor_base_ = std::move(tensor);
89 *tensor_storage_ = make_otr(*tensor_base_);
90}
91
92void OperandInfo::exchange_tensor(c10::MaybeOwned<TensorBase> &&new_tensor) {
93 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!original_tensor_base_->defined());
94 original_tensor_base_ = std::exchange(tensor_base_, new_tensor);
95 *original_tensor_storage_ = std::exchange(*tensor_storage_, make_otr(*tensor_base_));
96}
97
98void OperandInfo::restore_original_tensor() {
99 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(original_tensor_base_->defined());
100 tensor_base_ = std::move(original_tensor_base_);
101 *tensor_storage_ = std::exchange(*original_tensor_storage_, OptionalTensorRef{});
102}
103
104/// Construction
105TensorIteratorConfig& TensorIteratorConfig::add_owned_output(const TensorBase& output) {
106 TORCH_INTERNAL_ASSERT(
107 num_inputs_ == 0,
108 "Keep in mind that you have to add all outputs first before adding any input. "
109 "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
110 tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(c10::in_place, output));
111 num_outputs_++;
112 return *this;
113}
114
115TensorIteratorConfig& TensorIteratorConfig::add_owned_input(const TensorBase& input) {
116 tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(c10::in_place, input));
117 num_inputs_++;
118 return *this;
119}
120
121TensorIteratorConfig& TensorIteratorConfig::add_borrowed_output(const TensorBase& output) {
122 TORCH_INTERNAL_ASSERT(
123 num_inputs_ == 0,
124 "Keep in mind that you have to add all outputs first before adding any input. "
125 "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
126 tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(output));
127 num_outputs_++;
128 return *this;
129}
130
131TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase& input) {
132 tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
133 num_inputs_++;
134 return *this;
135}
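
// A minimal construction sketch (illustrative only; `out`, `a` and `b` stand for
// caller-provided tensors and are not defined here). Outputs must be added
// before any input:
//
//   auto iter = TensorIteratorConfig()
//       .add_owned_output(out)
//       .add_owned_input(a)
//       .add_owned_input(b)
//       .build();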
136
137TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
138 TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
139 static_dtype_ = dtype;
140 static_device_ = device;
141 return *this;
142}
143
144TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype(ScalarType dtype) {
145 TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
146 static_dtype_ = dtype;
147 return *this;
148}
149
150TensorIteratorConfig& TensorIteratorConfig::declare_static_device(Device device) {
151 static_device_ = device;
152 return *this;
153}
154
155TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape) {
156 // WARNING:
157 // This will bypass all shape checking in the TensorIterator. Kernels which call this method
158 // are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
159 TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)")
160 static_shape_ = c10::make_optional(DimVector(shape));
161 return *this;
162}
163
164TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) {
165 declare_static_shape(shape);
166 if (static_shape_->empty()) return *this;
167 for (const auto& squash_dim : squash_dims) {
168 TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast<int64_t>(static_shape_->size()),
169 "squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ").");
170 (*static_shape_)[squash_dim] = 1;
171 }
172 return *this;
173}
174
// NOTE: [Computing output strides]
// We use the following algorithm to compute output strides.
// If a correctly sized output is provided, we respect its strides and don't change them.
// Otherwise, if the provided output is of incorrect size or no output is provided,
// we try to recover the permutation that was applied to the inputs
// by sorting the strides of the inputs. Precedence is given to the inputs in the order they were added,
// and to permutations involving non-broadcasted dimensions.
// 1. we loop over inputs starting from the first
// 2. for all inputs, strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. If one
//    of the dimensions being compared has a stride of 0, we move on to the next tensor to determine if
//    these dimensions need to be swapped.
// 3. strides of dimensions equal to 1 participate in sorting
// 4. if 2 strides are equal and neither is 0, we try to break the tie by looking at the corresponding dimensions
//    of the tensor. Dimensions were permuted if, when iterating from the end, dimensions corresponding to the
//    same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie.
//
// Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly
// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially
// improve traversal order of the second tensor. We chose the former option to better propagate channels-last layout,
// for example, for a tensor with sizes N1H1.
// These rules result in the intuitive behavior that in most cases recovers the permutation of either the first argument (if all
// arguments are of the same size) or the argument that is not broadcasted, regardless of its position.
// As a bonus, it also results in a reasonably well-behaved traversal order of the inputs and outputs - in the kernels
// the output is traversed linearly, and since it closely follows input layouts, inputs are traversed linearly as well.
//
// Examples:
// full-size tensor + broadcasted tensor with 0 or 1 non-trivial dimensions => strides of the output are the same
// as the strides of the full-size input, regardless of the order
// 2 tensors of the same size but different strides => output strides are the same as those of the first argument
//
// We also have a fast path for memory-dense inputs with the same strides (or, trivially, a single memory-dense input)
// that outputs a tensor with the same strides as the inputs. The only difference in the result from the algorithm described
// above is in the strides of trivial (size-1) dimensions, where in ambiguous cases we default to contiguous strides
// for performance reasons.
// Example: a tensor with sizes NC11 and strides C1CC will produce an output with strides C111 (note the differences are only
// in the strides of trivial dimensions, so the physical layout is unaffected but permutation information is lost).
// We might change this behavior in the future once performance considerations are resolved.
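//
// Worked example (illustrative only, not taken from a specific test):
//   input a: sizes [N, C, H, W] with channels-last strides [H*W*C, 1, W*C, C]
//   input b: sizes [1, C, 1, 1], broadcasted over N, H and W
//   No correctly sized output is provided, so the permutation is recovered by
//   sorting the input strides. b's broadcasted dimensions have stride 0 and
//   are skipped, so a's strides decide the order, and the freshly allocated
//   output also gets channels-last strides [H*W*C, 1, W*C, C].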
212
213void TensorIteratorBase::reorder_dimensions() {
  // Sort the dimensions based on strides in ascending order with reduced dims
  // at the front. Note that this inverts the order of C-contiguous tensors:
  // strides[0] is the fastest moving dimension instead of strides[ndim - 1].
  // See NOTE: [Computing output strides] and inline comments for a more detailed description.
218
219 perm_.resize(ndim());
220 if (ndim() == 1) {
221 perm_[0] = 0;
222 return;
223 }
224
225 // initialize perm with n-1, n-2, ..., 1, 0
226 std::iota(perm_.rbegin(), perm_.rend(), 0);
227
  // Reordering dimensions changes iteration order
229 if (enforce_linear_iteration_) {
230 permute_dimensions(perm_);
231 return;
232 }
233
  // Returns 1 if dim0 should come after dim1, -1 if dim0 should come
  // before dim1, and 0 if the comparison is ambiguous.
236 auto should_swap = [&](size_t dim0, size_t dim1) {
237 for (const auto arg : c10::irange(ntensors())) {
238 // ignore undefined or incorrectly sized tensors
239 if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) {
240 continue;
241 }
242 int64_t stride0 = operands_[arg].stride_bytes[dim0];
243 int64_t stride1 = operands_[arg].stride_bytes[dim1];
244 if (is_reduction_ && operands_[arg].is_output) {
245 // move reduced dimensions to the front
246 // strides of reduced dimensions are always set to 0 by review_reduce_result
247 if ((stride0 == 0) != (stride1 == 0)) {
248 return stride1 == 0 ? 1 : -1;
249 }
250 }
251 //move on to the next input if one of the dimensions is broadcasted
252 if (stride0 == 0 || stride1 == 0) {
253 continue;
254 // it is important to return here only with strict comparisons, for equal strides we try to break the tie later
255 // by comparing corresponding dimensions or if that does not work, moving on to the next tensor
256 } else if (stride0 < stride1) {
257 return -1;
258 } else if (stride0 > stride1) {
259 return 1;
260 } else { //equal strides, use dimensions themselves as the tie-breaker.
261 //at this point, with zero strides out of the way, we are guaranteed that operand dimensions are equal to shape_
262 auto t_dim0 = shape_[dim0];
263 auto t_dim1 = shape_[dim1];
264 //return only if dimensions should be swapped, otherwise move on to the next tensor
265 if (t_dim0 > t_dim1) {
266 return 1;
267 }
268 }
269 }
270 return 0;
271 };
272
273 // insertion sort with support for ambiguous comparisons
274 for (const auto i : c10::irange(1, ndim())) {
275 int dim1 = i;
276 for (int dim0 = i - 1; dim0 >= 0; dim0--) {
277 int comparison = should_swap(perm_[dim0], perm_[dim1]);
278 if (comparison > 0) {
279 std::swap(perm_[dim0], perm_[dim1]);
280 dim1 = dim0;
281 } else if (comparison < 0) {
282 break;
283 }
284 }
285 }
286
287 // perform re-ordering of shape and strides
288 permute_dimensions(perm_);
289}
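
// Illustrative examples of the resulting perm_ (assuming no reduction and
// enforce_linear_iteration_ == false):
//   - a contiguous tensor with sizes [N, C, H, W] and strides [C*H*W, H*W, W, 1]
//     sorts to perm_ = {3, 2, 1, 0} (W is the fastest moving dimension);
//   - a channels-last tensor with strides [H*W*C, 1, W*C, C] sorts to
//     perm_ = {1, 3, 2, 0} (C is the fastest moving dimension).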
290
291// Computes a common dtype using type promotion
292// See the [Common Dtype Computation] note
293ScalarType TensorIteratorBase::compute_common_dtype() {
294 at::native::ResultTypeState state = {};
295 for (const auto& op : operands_) {
296 if (op.is_output) {
297 continue;
298 }
299
300 state = at::native::update_result_type_state(op.tensor(), state);
301 }
302
303 common_dtype_ = at::native::result_type(state);
304 TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);
305
306 return common_dtype_;
307}
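
// Examples of the promotion computed above (these follow the usual
// at::result_type rules and are only meant as illustration):
//   int64 tensor + float32 tensor     -> common dtype float32
//   float16 tensor + float32 tensor   -> common dtype float32
//   float32 tensor + complex64 tensor -> common dtype complex64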
308
309TensorOptions original_options(const OperandInfo& op) {
310 if (op.original_tensor_base().defined()) {
311 return op.original_tensor_base().options();
312 } else {
313 return op.options();
314 }
315}
316
// Implements the behavior of the following flags:
318// - check_all_same_dtype_
319// - check_all_same_device_
320// - enforce_safe_casting_to_output_
321// - promote_inputs_to_common_dtype_
322// - cast_common_dtype_to_outputs_
323//
324// See their descriptions in TensorIterator.h for details.
325// NOTE: Checks for more specific behaviors (e.g. the first and second
326// inputs must share a dtype, but the third must have the long dtype)
327// should be implemented directly and outside of TensorIterator.
328void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
329 // Reviews operands (1/2)
330 // - validates that all input tensors are defined
331 // - computes common device
332 // - determines if there are undefined outputs
333 // - determines if there are different dtypes and attempts
334 // to quickly acquire a common dtype
335 Device common_device = kCPU;
336 common_dtype_ = ScalarType::Undefined;
  // NB: despite output_dtype's generic sounding name, it is only
  // used in a nontrivial way if check_all_same_dtype is true
339 ScalarType output_dtype = ScalarType::Undefined;
340 bool has_different_input_dtypes = false;
341 bool has_different_output_dtypes = false;
342 bool has_undefined_outputs = false;
343
344 for (auto& op : operands_) {
345 // Validates that all inputs have type information, and that
346 // if an output is missing type information that we can infer
347 // the device it should be allocated on.
348 if (!op.is_type_defined()) {
349 TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");
350
351 if (config.static_dtype_.has_value()) {
352 op.target_dtype = config.static_dtype_.value();
353 } else {
354 has_undefined_outputs = true;
355 }
356
357 if (config.static_device_.has_value()) {
358 op.device = config.static_device_.value();
359 } else {
360 TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
361 }
362
363 if (has_undefined_outputs || !op.device.has_value()) {
364 continue;
365 }
366 }
367
368 // Validates input tensors are defined
369 if (!op.tensor_base().defined()) {
370 TORCH_INTERNAL_ASSERT(op.is_output, "Found undefined input tensor!");
371 continue;
372 }
373
    TORCH_INTERNAL_ASSERT(op.target_dtype == op.current_dtype);
375
376 // Acquires the first non-CPU device (if any) as the common device
377 if (common_device == kCPU && !op.tensor_base().is_cpu()) {
378 common_device = op.tensor_base().device();
379 }
380
381 if (!op.is_output) {
382 // Determines if there are varying input dtypes
383 // NOTE: the common dtype is set to the first defined input dtype observed
384 if (op.target_dtype != common_dtype_) {
385 if (common_dtype_ == ScalarType::Undefined) {
386 common_dtype_ = op.target_dtype;
387 } else {
388 has_different_input_dtypes = true;
389 }
390 }
391 } else { // op.is_output
392 // Determines if there are varying output dtypes
393 // NOTE: the output dtype is set to the first defined output dtype observed
394 if (op.target_dtype != output_dtype) {
395 if (output_dtype == ScalarType::Undefined) {
396 output_dtype = op.target_dtype;
397 } else {
398 has_different_output_dtypes = true;
399 }
400 }
401 }
402 }
403
404 // Checks that either the computation type is computable or unneeded
405 TORCH_INTERNAL_ASSERT(!(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ &&
406 (has_undefined_outputs || config.enforce_safe_casting_to_output_ ||
407 config.cast_common_dtype_to_outputs_)));
408
409 // Checks that all inputs and defined outputs are the same dtype, if requested
410 if (config.check_all_same_dtype_ &&
411 (has_different_input_dtypes || has_different_output_dtypes ||
412 (common_dtype_ != output_dtype && output_dtype != ScalarType::Undefined))) {
413 // Throws an informative error message
414 for (auto& op : operands_) {
415 if (!op.tensor_base().defined()) {
416 continue;
417 }
418
419 TORCH_CHECK(op.target_dtype == common_dtype_,
420 "Found dtype ", op.target_dtype, " but expected ", common_dtype_);
421 }
422 }
423
424 // Short-circuits if no additional work required
425 if (!has_undefined_outputs && !config.check_all_same_device_ &&
426 !config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
427 !config.enforce_safe_casting_to_output_) {
428 // Invalidates common_dtype_ if it could not be inferred
429 common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
430 return;
431 }
432
433 // Computes a common dtype, if needed
434 if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) {
435 common_dtype_ = compute_common_dtype();
436 }
437
438 // Promotes common dtype to the default float scalar type, if needed
439 if (config.promote_integer_inputs_to_float_ &&
440 c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
441 common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
442 }
443
444 // Reviews operands (2/2)
445 // - sets metadata for undefined outputs
446 // - checks that all tensors are on the same device, if requested
447 // - checks that the common dtype can safely cast to each output, if requested
448 // - creates temporaries for CPU operations, if needed and requested
449 common_device_ = common_device;
450 int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
451 int current_cpu_scalars_on_non_cpu = 0;
452 for (auto& op : operands_) {
453 bool is_type_defined = op.is_type_defined();
454 bool is_device_defined = op.is_device_defined();
455
456 if (!is_type_defined) {
457 op.target_dtype = common_dtype_;
458 }
459 if (!is_device_defined) {
460 op.device = common_device;
461 }
462
463 if (!is_type_defined && !is_device_defined) {
464 continue;
465 }
466
467 // Skips undefined tensors
468 if (!op.tensor_base().defined()) {
469 continue;
470 }
471
472 // Checks all tensors are on the same device, if requested
473 if (config.check_all_same_device_) {
474 // Handles CPU scalars on CUDA kernels that support them
475 if (!common_device.is_cpu() &&
476 config.allow_cpu_scalars_ && !op.is_output && op.tensor_base().dim() == 0 &&
477 op.tensor_base().is_cpu()) {
478 TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
479 "Trying to pass too many CPU scalars to non-CPU kernel!");
480 ++current_cpu_scalars_on_non_cpu;
481 } else if (op.device.value() != common_device) {
482 TORCH_CHECK(false,
483 "Expected all tensors to be on the same device, but "
484 "found at least two devices, ", common_device, " and ", op.device.value(), "!");
485 }
486 }
487
488 // Checks safe casting, if requested
489 if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
490 TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
491 "result type ", common_dtype_, " can't be cast to the "
492 "desired output type ", op.current_dtype);
493 }
494
495 // Creates temporaries for CPU operations, if needed and requested
496 // TODO: reuse temporaries when possible (e.g. for inplace operations)
497 if (common_device == kCPU) {
498 // Casts to outputs by creating temporaries of the correct dtype (if needed)
499 // NB: we skip this on is_meta_, because the temporary allocation here is
500 // unnecessary if we aren't going to actually do the compute
501 if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
502 TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
503 // Marker [Output original_tensor is set]
504 // NB: do NOT use set_output here, as the temporary is NOT a true output;
505 // op.tensor is the true output and it was pre-provided for us.
506 // TODO: The logic for cast_outputs will need to be handled by the
507 // structured kernels implementation. What probably should happen
508 // is that we pass in the inferred dtype into the out kernel, and
509 // then after calling the out kernel, do the conversion (which
510 // is cast_outputs here), but integrating this with existing
511 // TensorIterator will take a little doing
512 op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
513 at::empty_like(op.tensor(),
514 op.tensor_base().options().dtype(common_dtype_),
515 LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
516 if (!names_.empty()) {
517 namedinference::propagate_names(op.tensor_base(), names_);
518 }
519 op.current_dtype = common_dtype_;
520 op.target_dtype = common_dtype_;
521 }
522
523 // Promotes inputs by creating temporaries of the correct dtype
524 if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
525 op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(op.tensor().to(common_dtype_)));
526 op.current_dtype = common_dtype_;
527 op.target_dtype = common_dtype_;
528 }
529 }
530 }
531}
532
533StrideVector TensorIteratorBase::compatible_stride(int element_size) const {
534 auto stride = StrideVector();
535 int64_t next_stride = element_size;
536 for (const auto dim : c10::irange(ndim())) {
537 stride.push_back(next_stride);
538 next_stride *= shape_[dim];
539 }
540 return stride;
541}
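
// For example (illustrative): with shape_ = {6, 5, 4} (fastest dimension first,
// as produced by reorder_dimensions) and element_size = 4 bytes, the loop above
// yields the byte strides {4, 24, 120}, i.e. a dense layout in the iterator's
// own dimension order.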
542
543DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const {
544 // Invert the permutation caused by reorder_dimensions. This is not valid
545 // after coalesce_dimensions is called.
546 TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_);
547 TORCH_INTERNAL_ASSERT(input.size()==perm_.size());
548 auto res = DimVector(input.size()); //no initialization needed, every value in res should be written to.
549 for (const auto dim : c10::irange(ndim())) {
550 res[perm_[dim]] = input[dim];
551 }
552 return res;
553}
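
// For example (illustrative): with perm_ = {2, 0, 1}, invert_perm({a, b, c})
// returns {b, c, a}, because res[perm_[i]] = input[i] places input[0] at
// position 2, input[1] at position 0 and input[2] at position 1.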
554
555void TensorIteratorBase::allocate_or_resize_outputs() {
556 for (const auto i : c10::irange(num_outputs_)) {
557 auto& op = operands_[i];
558 if (!op.tensor_base().defined() || op.will_resize) {
559 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
560 int element_size = elementSize(op.target_dtype);
561 op.stride_bytes = compatible_stride(element_size);
562 // check if permutation is just an inverted order
563 bool inverted = true;
564 for (const auto j : c10::irange(ndim())) {
565 if (perm_[j] != ndim() - j - 1) {
566 inverted = false;
567 break;
568 }
569 }
570 auto tensor_shape = invert_perm(shape_);
571 if (inverted) {
572 // can just return contiguous output
573 // it is faster because it avoids allocating 0 size tensor and
574 // resizing and restriding it
575 set_output_raw_strided(i, tensor_shape, {}, original_options(op), names_);
576 } else {
577 auto tensor_stride = invert_perm(op.stride_bytes);
578 for (const auto dim : c10::irange(ndim())) {
579 tensor_stride[dim] /= element_size;
580 }
581 set_output_raw_strided(i, tensor_shape, tensor_stride, original_options(op), names_);
582 }
583 op.current_dtype = op.target_dtype;
584 } else if (op.tensor_base().defined()) {
585 // Even if we don't resize, we still need to tell set_output about
586 // the output, so that we properly set guard and propagate names
587 set_output_raw_strided(i, op.tensor_base().sizes(), {}, original_options(op), names_);
588 }
589 }
590}
591
592void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) {
593 bool should_infer_names = std::any_of(
594 operands_.begin(),
595 operands_.end(),
596 [](const OperandInfo& op) {
597 return op.tensor_base().defined() && op.tensor_base().has_names();
598 });
599 if (!should_infer_names) {
600 return;
601 }
602
603 for (auto& op : operands_) {
604 if (!op.tensor_base().defined()) continue;
605 // Don't include output tensors if we are resizing, since we will
606 // clobber their names in any case. (If the output tensor was
607 // also an input tensor, we'll pick it up when it shows up again
608 // in operands).
609 if (config.resize_outputs_ && op.is_output) continue;
610 // perform name inference
611 if (names_.empty()) {
612 names_ = op.tensor_base().names();
613 } else {
614 names_ = NameVector(unify_from_right(names_, op.tensor_base().names()));
615 }
616 }
617}
618
619void TensorIteratorBase::coalesce_dimensions() {
620 if (ndim() <= 1) {
621 return;
622 }
623
  // We can coalesce two adjacent dimensions if either dim has size 1 or if:
  // shape[n] * stride[n] == stride[n + 1].
626 auto can_coalesce = [&](int dim0, int dim1) {
627 auto shape0 = shape_[dim0];
628 auto shape1 = shape_[dim1];
629 if (shape0 == 1 || shape1 == 1) {
630 return true;
631 }
632 for (const auto i : c10::irange(ntensors())) {
633 auto& stride = operands_[i].stride_bytes;
634 if (shape0 * stride[dim0] != stride[dim1]) {
635 return false;
636 }
637 }
638 return true;
639 };
640
  // replace each operand's stride at dim0 with its stride at dim1
642 auto replace_stride = [&](int dim0, int dim1) {
643 for (const auto i : c10::irange(ntensors())) {
644 auto& stride = operands_[i].stride_bytes;
645 stride[dim0] = stride[dim1];
646 }
647 };
648
649 int prev_dim = 0;
650 for (const auto dim : c10::irange(1, ndim())) {
651 if (can_coalesce(prev_dim, dim)) {
652 if (shape_[prev_dim] == 1) {
653 replace_stride(prev_dim, dim);
654 }
655 shape_[prev_dim] *= shape_[dim];
656 } else {
657 prev_dim++;
658 if (prev_dim != dim) {
659 replace_stride(prev_dim, dim);
660 shape_[prev_dim] = shape_[dim];
661 }
662 }
663 }
664
665 shape_.resize(prev_dim + 1);
666 for (const auto i : c10::irange(ntensors())) {
667 operands_[i].stride_bytes.resize(ndim());
668 }
669 has_coalesced_dimensions_ = true;
670}
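
// Worked example (illustrative): a contiguous float tensor of sizes [4, 5, 6]
// reaches this point with shape_ = {6, 5, 4} and byte strides {4, 24, 120}.
// Since 6 * 4 == 24 and (6 * 5) * 4 == 120, all three dimensions coalesce into
// a single dimension with shape_ = {120} and stride_bytes = {4}.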
671
672int64_t TensorIteratorBase::numel() const {
673 int64_t numel = 1;
674 for (int64_t size : shape_) {
675 numel *= size;
676 }
677 return numel;
678}
679
680StrideVector TensorIteratorBase::get_dim_strides(int dim) const {
681 auto dims = ndim();
682 auto inner_strides = StrideVector();
683 for (auto& op : operands_) {
684 inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
685 }
686 return inner_strides;
687}
688
689SmallVector<char*, 4> TensorIteratorBase::get_base_ptrs() const {
690 auto ptrs = SmallVector<char*, 4>(ntensors());
691 at::get_base_ptrs(ptrs.data(), operands_);
692 return ptrs;
693}
694
695bool TensorIteratorBase::is_dim_reduced(int dim) const {
696 for (auto& op : operands_) {
697 if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) {
698 return true;
699 }
700 }
701 return false;
702}
703
704void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
705 TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));
706
707 auto reorder = [perm](IntArrayRef data) {
708 auto res = DimVector(data.size(), 0);
709 for (const auto i : c10::irange(perm.size())) {
710 res[i] = data[perm[i]];
711 }
712 return res;
713 };
714
715 // Update shape and strides
716 shape_ = reorder(shape_);
717 for (auto& op : operands_) {
718 if (!op.stride_bytes.empty()) {
719 op.stride_bytes = reorder(op.stride_bytes);
720 }
721 }
722}
723
724int64_t TensorIteratorBase::num_output_elements() const {
725 int64_t elem = 1;
726 for (const auto dim : c10::irange(ndim())) {
727 if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0) {
728 elem *= shape_[dim];
729 }
730 }
731 return elem;
732}
733
734int TensorIteratorBase::num_reduce_dims() const {
735 int count = 0;
736 for (const auto dim : c10::irange(ndim())) {
737 if (operands_[0].stride_bytes[dim] == 0) {
738 count++;
739 }
740 }
741 return count;
742}
743
744void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
745 int64_t numel = this->numel();
746 if (numel == 0) {
747 return;
748 } else if (numel < grain_size || at::get_num_threads() == 1) {
749 return serial_for_each(loop, {0, numel});
750 } else {
751 at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
752 serial_for_each(loop, {begin, end});
753 });
754 }
755}
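
// A minimal sketch of a caller-supplied 1-D loop (illustrative only; it assumes
// every operand is float and relies on the loop1d_t overload of for_each
// declared in TensorIterator.h; real kernels typically go through helpers such
// as cpu_kernel/gpu_kernel instead). Pointers and strides are passed in operand
// order, outputs first:
//
//   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
//     for (int64_t i = 0; i < n; i++) {
//       *reinterpret_cast<float*>(data[0] + i * strides[0]) =
//           *reinterpret_cast<float*>(data[1] + i * strides[1]) +
//           *reinterpret_cast<float*>(data[2] + i * strides[2]);
//     }
//   });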
756
757StrideVector TensorIteratorBase::get_strides() const {
758 const auto dim = ndim();
759 StrideVector strides(std::max(dim, 2) * ntensors());
760 at::get_strides(strides.data(), operands_, dim);
761 return strides;
762}
763
764void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
765 if (range.size() == 0) {
766 return;
767 }
768
769 const auto ntensors = this->ntensors();
770 const auto ndim = this->ndim();
771
772 c10::SmallBuffer<char*, 4> ptrs(ntensors);
773 c10::SmallBuffer<int64_t, 8> strides(ntensors * std::max(ndim, 2));
774
775 at::get_base_ptrs(ptrs.data(), operands_);
776 at::get_strides(strides.data(), operands_, ndim);
777 at::internal::serial_for_each(
778 shape_, strides, ptrs.data(), ptrs.size(), loop, range);
779}
780
781bool TensorIteratorBase::is_trivial_1d() const {
782 // TODO: check for casting once it's supported
783 return ndim() == 1;
784}
785
786bool TensorIteratorBase::is_contiguous() const {
787 if (numel() == 1) {
788 return true;
789 }
790 if (ndim() != 1) {
791 return false;
792 }
793 return has_contiguous_first_dim();
794}
795
796
797bool TensorIteratorBase::is_scalar(int arg) const {
798 const auto& stride = operands_[arg].stride_bytes;
799 for (const auto i : c10::irange(ndim())) {
800 if (stride[i] != 0 && shape_[i] != 1) {
801 return false;
802 }
803 }
804 return true;
805}
806
807bool TensorIteratorBase::is_cpu_scalar(int arg) const {
808 return is_scalar(arg) && device(arg).is_cpu();
809}
810
811void TensorIteratorBase::cast_outputs() {
812 for (auto& op : operands_) {
813 if (op.is_output && op.original_tensor_base().defined() &&
814 op.original_tensor_base().scalar_type() != op.current_dtype) {
815 // TODO: Now that set_output resizes both the original_tensor
816 // and tensor, this condition should no longer ever be true
817 const auto &original_tensor = op.original_tensor();
818 const auto &tensor = op.tensor();
819 if (original_tensor.sizes() != tensor.sizes()){
820 original_tensor.resize_as_(tensor).as_strided_(tensor.sizes(), tensor.strides());
821 }
822 original_tensor.copy_(tensor);
823 op.restore_original_tensor();
824 }
825 }
826}
827
828void* TensorIteratorBase::data_ptr(int arg) const {
829 return operands_[arg].data;
830}
831
832void TensorIteratorBase::remove_operand(int arg) {
833 operands_.erase(operands_.begin() + arg);
834}
835
836void TensorIteratorBase::unsafe_replace_operand(int arg, void* data) {
837 operands_[arg].data = data;
838}
839
840void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
841 TORCH_INTERNAL_ASSERT(dim < ndim() && size >= 1);
842 shape_[dim] = size;
843 view_offsets_[dim] += start;
844 for (auto& op : operands_) {
845 op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
846 }
847 if (size == 1 && !is_reduction_) {
848 coalesce_dimensions();
849 }
850}
851
852void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
853 TORCH_INTERNAL_ASSERT(start_dim <= ndim());
854 for (const auto i : c10::irange(start_dim, ndim())) {
855 for (auto& op : operands_) {
856 op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
857 }
858 shape_[i] = 1;
859 }
860}
861
862#define BINARY_FLOAT_OP_CONFIG() \
863 TensorIteratorConfig() \
864 .set_check_mem_overlap(true) \
865 .allow_cpu_scalars(true) \
866 .promote_inputs_to_common_dtype(true) \
867 .cast_common_dtype_to_outputs(true) \
868 .enforce_safe_casting_to_output(true) \
869 .promote_integer_inputs_to_float(true)
870
871// Helper to construct a binary op that promotes integer inputs to float.
872void TensorIteratorBase::build_binary_float_op(
873 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
874 build(BINARY_FLOAT_OP_CONFIG()
875 .add_owned_output(out)
876 .add_owned_input(a)
877 .add_owned_input(b));
878}
879
880void TensorIteratorBase::build_borrowing_binary_float_op(
881 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
882 build(BINARY_FLOAT_OP_CONFIG()
883 .add_output(out)
884 .add_input(a)
885 .add_input(b));
886}
887
888static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
889 config.set_check_mem_overlap(true);
890 config.allow_cpu_scalars(true);
891 config.promote_inputs_to_common_dtype(true);
892
893 // When 'out' isn't defined (e.g. for the functional operator 'a == b'), we
894 // want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
895 // don't coerce the output.
896 if (!out.defined()) {
897 config.declare_static_dtype(kBool);
898 }
899
900 // Note [special-case bool outputs]
901 // We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
902 // has `bool` dtype. This is a performance optimization: the functional
903 // version of all comparison/logical ops uses a bool output tensor, and we'd like to
904 // avoid creating a temporary copy of the output.
905 // However, note that all kernels using this TensorIterator will need to special-case when
906 // the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
907 if (out.defined() && out.scalar_type() != kBool) {
908 config.cast_common_dtype_to_outputs(true);
909 }
910}
911
912void TensorIteratorBase::build_comparison_op(
913 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
914 TensorIteratorConfig config;
915 set_up_comparison_op_config(config, out);
916
917 config.add_owned_output(out);
918 config.add_owned_input(a);
919 config.add_owned_input(b);
920 build(config);
921}
922
923void TensorIteratorBase::build_borrowing_comparison_op(
924 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
925 TensorIteratorConfig config;
926 set_up_comparison_op_config(config, out);
927
928 config.add_borrowed_output(out);
929 config.add_borrowed_input(a);
930 config.add_borrowed_input(b);
931 build(config);
932}
933
934void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
935 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
936 TensorIteratorConfig config;
937 set_up_comparison_op_config(config, out);
938
939 config.add_borrowed_output(out);
940 config.add_borrowed_input(a);
941 config.add_owned_input(b);
942 build(config);
943}
944
945void TensorIteratorBase::build_ternary_op(
946 const TensorBase& out, const TensorBase& a,
947 const TensorBase& b, const TensorBase& c) {
948 build(TensorIteratorConfig()
949 .promote_inputs_to_common_dtype(true)
950 .enforce_safe_casting_to_output(true)
951 .add_owned_output(out)
952 .add_owned_input(a)
953 .add_owned_input(b)
954 .add_owned_input(c));
955}
956
957// This cannot be a function because TensorIteratorConfig is not
958// copyable or movable, so it can't be returned from the function.
959#define BINARY_OP_CONFIG() \
960 TensorIteratorConfig() \
961 .set_check_mem_overlap(true) \
962 .allow_cpu_scalars(true) \
963 .promote_inputs_to_common_dtype(true) \
964 .cast_common_dtype_to_outputs(true) \
965 .enforce_safe_casting_to_output(true) \
966
967void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
968 build(BINARY_OP_CONFIG()
969 .add_owned_output(out)
970 .add_owned_input(a)
971 .add_owned_input(b));
972}
973
974void TensorIteratorBase::build_borrowing_binary_op(
975 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
976 build(BINARY_OP_CONFIG()
977 .add_output(out)
978 .add_input(a)
979 .add_input(b));
980}
981
982// This cannot be a function because TensorIteratorConfig is not
983// copyable or movable, so it can't be returned from the function.
984#define UNARY_FLOAT_OP_CONFIG() \
985 TensorIteratorConfig() \
986 .set_check_mem_overlap(true) \
987 .promote_inputs_to_common_dtype(true) \
988 .cast_common_dtype_to_outputs(true) \
989 .enforce_safe_casting_to_output(true) \
990 .promote_integer_inputs_to_float(true)
991
992void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
993 build(UNARY_FLOAT_OP_CONFIG()
994 .add_owned_output(out)
995 .add_owned_input(a));
996}
997
998void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
999 build(UNARY_FLOAT_OP_CONFIG()
1000 .add_output(out)
1001 .add_input(a));
1002}
1003
1004// This cannot be a function because TensorIteratorConfig is not
1005// copyable or movable, so it can't be returned from the function.
1006#define UNARY_OP_CONFIG() \
1007 TensorIteratorConfig() \
1008 .set_check_mem_overlap(true) \
1009 .cast_common_dtype_to_outputs(false) \
1010 .enforce_safe_casting_to_output(false) \
1011 .check_all_same_dtype(true)
1012
1013void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
1014 build(UNARY_OP_CONFIG()
1015 .add_owned_output(out)
1016 .add_owned_input(a));
1017}
1018
1019void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
1020 build(UNARY_OP_CONFIG()
1021 .add_output(out)
1022 .add_input(a));
1023}
1024
1025void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
1026 build(UNARY_OP_CONFIG()
1027 .add_output(out)
1028 .add_owned_input(a));
1029}
1030
// Helper to construct a unary op that forcibly promotes the output to boolean.
// Only to be used when the output tensor must have boolean type.
1033void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
1034 build(TensorIteratorConfig()
1035 .set_check_mem_overlap(true)
1036 .check_all_same_dtype(false)
1037 .declare_static_dtype(at::kBool)
1038 .declare_static_device(a.device())
1039 .add_output(out)
1040 .add_input(a));
1041}
1042
1043TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
1044 TensorIterator iter;
1045 iter.build_binary_op(out, a, b);
1046 return iter;
1047}
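
// Typical call pattern in native code (illustrative sketch; my_kernel_stub is a
// hypothetical DispatchStub, not something defined in ATen):
//
//   Tensor out = at::empty({0}, a.options());
//   auto iter = TensorIterator::binary_op(out, a, b);
//   my_kernel_stub(iter.device_type(), iter);  // the kernel reads/writes through iter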
1048
1049TensorIterator TensorIterator::borrowing_binary_op(
1050 const TensorBase& out, const TensorBase& a, const TensorBase& b) {
1051 TensorIterator iter;
1052 iter.build_borrowing_binary_op(out, a, b);
1053 return iter;
1054}
1055
1056TensorIterator TensorIterator::binary_float_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
1057 TensorIterator iter;
1058 iter.build_binary_float_op(out, a, b);
1059 return iter;
1060}
1061
1062TensorIterator TensorIterator::comparison_op(TensorBase& out, const TensorBase& a,
1063 const TensorBase& b) {
1064 TensorIterator iter;
1065 iter.build_comparison_op(out, a, b);
1066 return iter;
1067}
1068
1069TensorIterator TensorIterator::unary_op(TensorBase& out, const TensorBase& a) {
1070 TensorIterator iter;
1071 iter.build_unary_op(out, a);
1072 return iter;
1073}
1074
1075TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase& a) {
1076 TensorIterator iter;
1077 iter.build_unary_float_op(out, a);
1078 return iter;
1079}
1080
1081#define NULLARY_OP_CONFIG() \
1082 TensorIteratorConfig() \
1083 .set_check_mem_overlap(true) \
1084 .check_all_same_dtype(false) \
1085 /* FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 */ \
1086 .resize_outputs(false)
1087
1088TensorIterator TensorIterator::nullary_op(TensorBase& out) {
1089 return NULLARY_OP_CONFIG()
1090 .add_owned_output(out)
1091 .build();
1092}
1093
1094TensorIterator TensorIterator::borrowing_nullary_op(const TensorBase& out) {
1095 return NULLARY_OP_CONFIG()
1096 .add_output(out)
1097 .build();
1098}
1099
1100TensorIterator TensorIterator::reduce_op(TensorBase& out, const TensorBase& a) {
1101 TORCH_INTERNAL_ASSERT(out.defined());
1102 return TensorIteratorConfig()
1103 .set_check_mem_overlap(false)
1104 .add_owned_output(out)
1105 .add_owned_input(a)
1106 .resize_outputs(false)
1107 .is_reduction(true)
1108 // TODO: not supporting casting to outputs is only really necessary for arg{min,max}
1109 .promote_inputs_to_common_dtype(true)
1110 .build();
1111}
1112
1113TensorIterator TensorIterator::reduce_op(TensorBase& out1, TensorBase& out2, const TensorBase& a) {
1114 TORCH_INTERNAL_ASSERT(out1.defined());
1115 TORCH_INTERNAL_ASSERT(out2.defined());
1116 TORCH_CHECK(a.device() == out1.device() && out1.device() == out2.device(),
1117 "reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
1118 ", output1 is on ", out1.device(), " and output2 is on", out2.device());
1119 TORCH_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
1120 " and output2 has ", out2.dim());
1121 TORCH_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
1122 " and output2 has ", out2.sizes());
1123 TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
1124 " and output2 has ", out2.strides());
1125 return TensorIteratorConfig()
1126 .set_check_mem_overlap(false)
1127 .add_owned_output(out1)
1128 .add_owned_output(out2)
1129 .add_owned_input(a)
1130 .resize_outputs(false)
1131 .is_reduction(true)
1132 .check_all_same_dtype(false)
1133 .build();
1134}
1135
1136void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
1137 for (auto& tensor: config.tensors_) {
1138 // If *any* of the arguments is a meta tensor, the overall
1139 // computation is a meta computation (don't do any work,
1140 // just compute output information). This aligns with
1141 // our multiple dispatch semantics.
1142 if (tensor->is_meta()) {
1143 is_meta_ = true;
1144 }
1145 operands_.emplace_back(std::move(tensor));
1146 }
1147 num_outputs_ = config.num_outputs_;
1148}
1149
1150void TensorIteratorBase::mark_outputs() {
1151 // TODO: merge this into populate_operands
1152 for (const auto i : c10::irange(num_outputs_)) {
1153 operands_[i].is_output = true;
1154 const auto& output = tensor(i);
1155 if (!output.defined()) continue;
1156
1157 // check if output is also an input
1158 for (const auto arg : c10::irange(num_outputs_, ntensors())) {
1159 const auto& input = tensor(arg);
1160 if (output.is_same(input)) {
1161 operands_[i].is_read_write = true;
1162 }
1163 }
1164 }
1165}
1166
1167void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) {
1168 // Outputs cannot be broadcasted. Check that the shape of the outputs matches
1169 // the inferred shape. There's an exception for write-only tensors to support
1170 // our legacy behavior that functions with `out=` arguments resize their
1171 // outputs.
1172 if (config.static_shape_.has_value()) {
1173 return;
1174 }
1175 for (const auto i : c10::irange(num_outputs_)) {
1176 const auto& output = tensor(i);
1177 if (output.defined() && !output.sizes().equals(shape_)) {
1178 if (config.resize_outputs_ && !operands_[i].is_read_write) {
1179 operands_[i].will_resize = true;
1180 continue;
1181 }
      // For reductions, the output size does not match shape_, since the output is the reduced size
      // while shape_ is the size of the input.
1183 TORCH_CHECK(is_reduction_, "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
1184 shape_);
1185 }
1186 }
1187}
1188
1189void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
1190 if (!config.check_mem_overlap_) {
1191 return;
1192 }
1193 for (const auto i : c10::irange(num_outputs_)) {
1194 const auto& output = tensor_base(i);
1195 if (!output.defined()) continue;
1196 assert_no_internal_overlap(output);
1197 for (const auto j : c10::irange(num_outputs_, ntensors())) {
1198 const auto& input = tensor_base(j);
1199 if (!input.is_same(output)) {
1200 assert_no_partial_overlap(output, input);
1201 }
1202 }
1203 }
1204}
1205
1206void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {
1207 if (config.static_shape_.has_value()) {
1208 shape_ = *config.static_shape_;
1209 return;
1210 }
1211
1212 all_ops_same_shape_ = true;
1213 bool has_scalars = false;
1214 bool has_tensors = false;
1215 for (auto& op : operands_) {
1216 if (!op.tensor_base().defined()) continue;
1217
1218 // For now, don't include output tensors when we're resizing outputs.
1219 // These shapes don't participate in shape computation.
1220 // This preserves the legacy behavior where torch.add(..., out=dst) resizes
1221 // the destination tensor. If the output tensor is also an input, we'll
1222 // pick it up later in the operands.
1223 if (config.resize_outputs_ && op.is_output) continue;
1224 TORCH_CHECK(!op.tensor_base().unsafeGetTensorImpl()->has_symbolic_sizes_strides(),
1225 "TensorIterator does not support symbolic shapes; please implement this operator in torch/_refs "
1226 "using the elementwise or reduction helpers (look at backtrace to find out what operator this is)");
1227 auto shape = op.tensor_base().sizes();
1228 if (shape.empty()) {
1229 has_scalars = true;
1230 } else {
1231 has_tensors = true;
1232 }
1233 if (has_scalars && has_tensors) {
1234 all_ops_same_shape_ = false;
1235 }
1236 if (shape_.empty()) {
1237 shape_ = shape;
1238 } else if (!shape.equals(shape_)) {
1239 all_ops_same_shape_ = false;
1240 shape_ = infer_size_dimvector(shape_, shape);
1241 }
1242 }
1243 all_ops_are_scalars_ = !has_tensors;
1244}
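
// Broadcast example for the shape inference above (illustrative): operands of
// sizes [4, 1, 3] and [2, 3] produce shape_ = [4, 2, 3]; adding a 0-dim
// (scalar) operand leaves shape_ unchanged but clears all_ops_same_shape_.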
1245
1246void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) {
1247 for (auto& op : operands_) {
1248 if (op.tensor_base().defined() && !op.will_resize) {
1249 IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor_base().sizes();
1250 auto original_stride = op.tensor_base().strides();
1251 auto element_size_in_bytes = op.tensor_base().element_size();
1252 auto offset = ndim() - original_shape.size();
1253 if (offset > 0)
1254 op.stride_bytes.resize(ndim(), 0);
1255 else
1256 op.stride_bytes.resize(ndim());
1257 for (const auto i : c10::irange(original_shape.size())) {
1258 // see NOTE: [Computing output strides]
1259 if (original_shape[i] == 1 && shape_[offset + i] !=1) {
1260 op.stride_bytes[offset + i] = 0;
1261 } else {
1262 op.stride_bytes[offset + i] = original_stride[i] * element_size_in_bytes;
1263 }
1264 }
1265 }
1266 }
1267}
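
// Example for the stride computation above (illustrative, float32 operand):
// an input of sizes [3] and strides [1] broadcast against shape_ = [2, 3] gets
// stride_bytes = {0, 4}; the missing leading dimension (and any size-1
// dimension expanded by broadcasting) contributes a 0 byte stride.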
1268
1269bool TensorIteratorBase::can_use_32bit_indexing() const {
1270 int64_t max_value = std::numeric_limits<int32_t>::max();
1271 if (numel() > max_value) {
1272 return false;
1273 }
1274 for (auto& op : operands_) {
1275 int64_t max_offset = 1;
1276 for (const auto dim : c10::irange(ndim())) {
1277 max_offset += (shape_[dim] - 1) * op.stride_bytes[dim];
1278 }
1279 if (max_offset > max_value) {
1280 return false;
1281 }
1282 }
1283 return true;
1284}
1285
1286std::unique_ptr<TensorIterator> TensorIteratorBase::split(int dim) {
1287 TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && shape()[dim] >= 2);
1288 std::unique_ptr<TensorIterator> copy(new TensorIterator(*this));
1289
1290 bool overlaps = is_dim_reduced(dim);
1291 auto copy_size = shape_[dim] / 2;
1292 auto this_size = shape_[dim] - copy_size;
1293 copy->narrow(dim, 0, copy_size);
1294 copy->final_output_ &= !overlaps;
1295 this->narrow(dim, copy_size, this_size);
1296 this->accumulate_ |= overlaps;
1297
1298 return copy;
1299}
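
// A common pattern built on top of the two functions above (illustrative
// sketch; launch_my_kernel is a hypothetical function, not defined in ATen):
//
//   if (!iter.can_use_32bit_indexing()) {
//     for (auto& sub_iter : iter.with_32bit_indexing()) {
//       launch_my_kernel(sub_iter);  // each sub-iterator fits 32-bit offsets
//     }
//     return;
//   }
//   launch_my_kernel(iter);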
1300
1301
1302int TensorIteratorBase::get_dim_to_split() const {
1303 TORCH_INTERNAL_ASSERT(ndim() >= 1);
1304 int64_t max_extent = -1;
1305 int dim_to_split = -1;
1306 for (int dim = ndim() - 1; dim >= 0; dim--) {
1307 const int64_t size = shape_[dim];
1308 if (size == 0) {
1309 continue;
1310 }
1311 for (auto& op : operands_) {
1312 // std::abs is necessary to handle some special cases where we support negative strides
1313 // see the CUDA backend of at::flip
1314 const int64_t extent = (size - 1) * std::abs(op.stride_bytes[dim]);
1315 if (extent > max_extent) {
1316 max_extent = extent;
1317 dim_to_split = dim;
1318 }
1319 }
1320 }
1321 TORCH_INTERNAL_ASSERT(max_extent >= 0);
1322 return dim_to_split;
1323}
1324
1325bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) {
  // This function tries to do a fast setup to avoid needless reordering of dimensions and tracking output strides.
  // Returns true if it could do the fast setup, false otherwise.
1328 // TODO enable fast handling for reductions
1329 FastSetupType setup_type = compute_fast_setup_type(config);
1330 if (setup_type == FastSetupType::NONE) {
1331 return false;
1332 }
1333
1334 // allocate memory for output, memory format depends on setup_type
1335 switch (setup_type) {
1336 case FastSetupType::CONTIGUOUS:
1337 {
1338 for (const auto i : c10::irange(num_outputs_)) {
1339 auto& op = operands_[i];
1340 if (!op.tensor_base().defined()) {
1341 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1342 }
1343 set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::Contiguous), names_);
1344 }
1345 break;
1346 }
1347 case FastSetupType::CHANNELS_LAST:
1348 {
1349 for (const auto i : c10::irange(num_outputs_)) {
1350 auto& op = operands_[i];
1351 if (!op.tensor_base().defined()) {
1352 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1353 }
1354 set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::ChannelsLast), names_);
1355 }
1356 break;
1357 }
1358 case FastSetupType::NON_OVERLAPPING_DENSE:
1359 {
      // find the index of a defined tensor in operands_, starting from the input tensors
1361 int i_defined; // NOLINT(cppcoreguidelines-init-variables)
1362 for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) {
1363 if (tensor(i_defined).defined()) break;
1364 }
1365 TORCH_CHECK(i_defined >= 0, "Can not find a defined tensor when fast allocating memory to outputs");
1366 for (const auto i : c10::irange(num_outputs_)) {
1367 auto& op = operands_[i];
1368 if (!op.tensor_base().defined()) {
1369 TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
1370 }
1371 set_output_raw_strided(i, shape_, tensor_base(i_defined).strides(), original_options(op), names_);
1372 }
1373 break;
1374 }
1375 default:
1376 TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type", c10::to_string((int)setup_type));
1377 }
  // Coalescing dimensions collapses the iteration space to a single dimension (we are limited to contiguous no-broadcast cases here)
1379 if (ndim() > 1){
1380 has_coalesced_dimensions_ = true;
1381 }
1382 if (ndim() >= 1) {
1383 shape_[0] = numel();
1384 shape_.resize(1);
1385 }
1386 for (auto& op : operands_ ) {
1387 auto element_size_in_bytes = op.tensor_base().element_size();
1388 op.stride_bytes.resize(ndim());
1389 if (ndim()>0) {
1390 op.stride_bytes[0] = element_size_in_bytes;
1391 }
1392 }
1393 return true;
1394}
1395
1396FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) {
1397 if (is_reduction_ || !all_ops_same_shape_) {
1398 return FastSetupType::NONE;
1399 }
1400
1401 // For linear iteration, only contiguous tensors can be coalesced
1402 // Fast setup of any other format requires changing iteration order
1403 if (enforce_linear_iteration_) {
1404 for (const auto& op : operands_) {
1405 if (op.tensor_base().defined() && !op.will_resize) {
1406 auto is_contiguous = op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
1407 if (!is_contiguous) {
1408 return FastSetupType::NONE;
1409 }
1410 }
1411 }
1412 return FastSetupType::CONTIGUOUS;
1413 }
1414
1415 bool is_contiguous = true;
1416 bool is_channels_last = true;
1417 bool is_non_overlapping_and_dense = true;
1418 for (const auto& op : operands_) {
1419 if (op.tensor_base().defined() && !op.will_resize) {
1420 is_contiguous &= op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous);
1421 is_channels_last &= op.tensor_base().is_contiguous(at::MemoryFormat::ChannelsLast);
1422 is_non_overlapping_and_dense &= op.tensor_base().is_non_overlapping_and_dense();
1423 }
1424 }
  // TODO: this causes ambiguous cases (NC11) to always be treated as contiguous
1426 if (is_contiguous) {
1427 return FastSetupType::CONTIGUOUS;
1428 }
1429 if (is_channels_last) {
1430 return FastSetupType::CHANNELS_LAST;
1431 }
1432 if (is_non_overlapping_and_dense) {
1433 int64_t prev = -1;
    // Fast setup is allowed only when all the defined tensors have the same shape and strides.
    // Iterate from the back to check input tensors' strides first, then output tensors'.
1436 for (int64_t i = ntensors() - 1; i >= 0; --i) {
1437 const auto& op = operands_[i];
1438 if (op.tensor_base().defined() && !op.will_resize) {
1439 if (prev < 0) {
1440 prev = i;
1441 continue;
1442 }
1443 if (!tensor_base(prev).strides().equals(op.tensor_base().strides())) {
          // [Note: stride check for non-contiguous tensors in fast setup]
          // We prevent fast setup in 3 cases here:
          // 1. input tensors have different strides.
          // 2. output tensors won't be resized and have different strides.
          // 3. input tensors have the same strides, but output tensors have different strides from the input tensors.
          // We don't allow re-striding output tensors in this last case since it is not compatible with
          // numpy. The behavior in numpy is that if the output tensor has the same shape as the input
          // tensor but different strides, the strides of the output tensor are preserved, so we do
          // the same in TensorIterator.
1453 return FastSetupType::NONE;
1454 }
1455 }
1456 }
1457 return FastSetupType::NON_OVERLAPPING_DENSE;
1458 }
1459 return FastSetupType::NONE;
1460}
1461
1462TensorIteratorBase::TensorIteratorBase() = default;
1463
1464void TensorIteratorBase::build(TensorIteratorConfig& config) {
1465 // populate some persistent configuration fields
1466 is_reduction_ = config.is_reduction_;
1467 enforce_linear_iteration_ = config.enforce_linear_iteration_;
1468
1469 // fill in operands_ based on configuration
1470 populate_operands(config);
1471 // set is_output and is_read_write flags on appropriate tensors
1472 mark_outputs();
1473 // Check that the outputs have no internal overlap
1474 // and do not share memory with inputs.
1475 compute_mem_overlaps(config);
1476 // Check that input dimensions are aligned correctly & compute outnames.
1477 compute_names(config);
1478 // compute the broadcasted shape
1479 compute_shape(config);
1480 // mark outputs for resizing if necessary
1481 mark_resize_outputs(config);
1482 // compute the result dtype and device
1483 compute_types(config);
1484 // try fast setup output tensor, if failed, fallback to normal setup
1485 if (!fast_set_up(config)) {
1486 // compute each tensor's stride after broadcasting
1487 compute_strides(config);
1488 // re-order dimensions to improve coalescing
1489 reorder_dimensions();
1490 // allocate the output tensor if it's not provided
1491 allocate_or_resize_outputs();
1492 // coalesce adjacent dimensions when possible
1493 if (!is_meta_) coalesce_dimensions();
1494 }
1495
1496 if (is_meta_) return;
1497
1498 auto has_storage = true;
1499 for (auto& op : operands_) {
1500 has_storage &= op.tensor_base().has_storage();
1501 }
1502 auto privateuse1_without_storage =
1503 common_device_.type() == DeviceType::PrivateUse1 &&
1504 !has_storage;
1505
  // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer.
  // Nothing beyond this point is important for meta functions, so it's fine to exit early here.
  // Extend the condition to ORT tensors, as ORT tensors also don't have storage.
1509 if (privateuse1_without_storage ||
1510 common_device_.type() == DeviceType::XLA ||
1511 common_device_.type() == DeviceType::IPU ||
1512 common_device_.type() == DeviceType::Lazy ||
1513 common_device_.type() == DeviceType::ORT ||
1514 common_device_.type() == DeviceType::HPU) return;
1515
1516 for (auto& op : operands_) {
1517 TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
1518 op.data = op.tensor_base().data_ptr();
1519 }
1520
  // Zero out offsets. If the tensor is a scalar, we still leave room for one
  // offset so that index translations in reductions can access a valid value.
1525 int64_t ndim_offsets = (ndim() ? ndim() : 1);
1526 view_offsets_ = DimVector(ndim_offsets, 0);
1527}
1528
// This is the structured kernels' implementation of set_output. It is
// NEVER actually called directly; instead, a subclass of TensorIteratorBase
// will override set_output to actually do the operation, and then call
// set_output on the TensorIteratorBase to set up TI's metadata.
// The precondition for this function is that maybe_get_output() now
// unconditionally returns a real Tensor (prior to output setting,
// it may return an undefined tensor).
void TensorIteratorBase::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
  auto& op = operands_[output_idx];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  const auto& t = maybe_get_output(output_idx);
  TORCH_INTERNAL_ASSERT(t.defined());
  if (!op.tensor_base().defined()) {
    op.tensor(c10::MaybeOwned<TensorBase>::borrowed(t));
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.target_dtype == t.scalar_type());
  } else if (op.will_resize) {
    if (op.original_tensor_base().defined()) {
      // OK, so this is pretty weird. To understand how we can end up in
      // this situation, first look at Marker [Output original_tensor is set].
      // That is the sole site where original_tensor may be set on an
      // output operand. Essentially, when we are given an explicit output
      // tensor whose dtype doesn't match the computed common dtype from
      // the input operands, we do a switcheroo: we replace the (incorrectly
      // typed) output tensor with a correctly typed, *temporary* tensor,
      // and remember the original tensor in original_tensor (which will
      // then get written back to when we cast_outputs).
      //
      // Now, what if the given output tensor also happened to be zero
      // size (meaning that we will_resize it)? Well, at the call site
      // above, we don't necessarily(*) know what the correct shape should
      // be, so we give the temporary tensor the same shape as the original.
      // At the time of set_output we DO know what the correct size
      // is, and the structured subclass's implementation of set_output
      // is responsible for resizing original_tensor. But we still have this
      // incorrectly sized temporary output which the structured subclass
      // knows nothing about, so we are obligated to also resize it here.
      //
      // This is a slight memory pessimization, because previously
      // original_tensor only got resized at the end of the computation, rather
      // than at the beginning (as happens here). However, the peak memory
      // usage is the same, since you need to materialize both original tensor
      // and temporary tensor to do the copy.
      //
      // (*) Actually, technically, we probably do know what the shape
      // should be, since we do shape computation before dtype computation.
      // So hypothetically we could figure out what the correct shape is
      // at that point in time and directly allocate the temporary at
      // the right size.
      //
      // But a better solution is to delay allocation of temporaries until
      // after TensorIterator builder, waiting until we actually want
      // to do the computation. That would also remove the necessity
      // for the is_meta_ test.
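      //
      // A concrete (hypothetical) way to reach this branch: the caller passes
      // an explicit out= tensor that is both zero-sized and of a different
      // dtype than the computed common dtype (say a Half output for a Float
      // binary op). The dtype mismatch triggers the temporary switcheroo
      // described above, and will_resize is true because the output started
      // out zero-sized, so both the original and the temporary get resized.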
      TORCH_INTERNAL_ASSERT(op.original_tensor_base().is_same(t));
      TORCH_INTERNAL_ASSERT(!op.tensor_base().is_same(t));
      OptionalTensorRef tensor(op.tensor());
      at::native::resize_output(*tensor, sizes);
      if (!strides.empty()) {
        TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
        tensor->as_strided_(sizes, strides);
      } else if (options.memory_format_opt().has_value()) {
        tensor->unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
      }
    }
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      op.tensor_base().is_same(t) || op.current_dtype == op.tensor_base().scalar_type());
  // For simplicity, just always update the cached current_dtype.
  op.current_dtype = op.tensor_base().scalar_type();
}

// This is the "traditional" implementation of set_output. On TensorIterator
// instances, it is invoked directly from various call sites in this file. No
// funny business.
void TensorIterator::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) {
  // NB: intentionally no superclass call
  auto& op = operands_[output_idx];
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  if (!op.tensor_base().defined()) {
    if (strides.empty()) {
      op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty(sizes, options)));
    } else {
      op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty_strided(sizes, strides, options)));
    }
    op.current_dtype = op.target_dtype;
  } else if (op.will_resize) {
    at::native::resize_output(op.tensor(), sizes);
    if (!strides.empty()) {
      TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
      op.tensor().as_strided_(sizes, strides);
    } else if (options.memory_format_opt().has_value()) {
      op.tensor_base().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
    }
  }
  if (!names.empty()) {
    TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
    namedinference::propagate_names(op.tensor_base(), names);
  }
}

// Not actually used by anything (TensorIterator subclass calls
// its own implementation of set_output which knows exactly where
// all the outputs are), but we have to provide all pure virtual methods
// for MetaBase
const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) {
  return output(output_idx);
}

SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const {
  return SplitUntil32Bit(*this);
}

/// SplitUntil32Bit. Recursively splits an iterator into sub-iterators that
/// can use 32-bit indexing.
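///
/// A usage sketch (assuming `iter` is a fully built iterator whose kernel
/// requires 32-bit index arithmetic): every sub-iterator yielded by the loop
/// satisfies can_use_32bit_indexing(), so a 32-bit indexed kernel can be
/// launched on each one directly.
///
///   for (auto& sub_iter : iter.with_32bit_indexing()) {
///     // launch the kernel on sub_iter
///   }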

SplitUntil32Bit::iterator::iterator(const TensorIteratorBase& iter) {
  vec.emplace_back(new TensorIterator(iter));
  vec.emplace_back(nullptr); // ++ first pops the last element
  ++(*this);
}

SplitUntil32Bit::iterator& SplitUntil32Bit::iterator::operator++() {
  vec.pop_back();
  while (!vec.empty() && !vec.back()->can_use_32bit_indexing()) {
    auto& iter = *vec.back();
    int64_t split_dim = iter.get_dim_to_split();
    vec.emplace_back(iter.split(split_dim));
  }
  return *this;
}

TensorIterator& SplitUntil32Bit::iterator::operator*() const {
  return *vec.back();
}

SplitUntil32Bit::iterator SplitUntil32Bit::begin() const {
  return SplitUntil32Bit::iterator(iter);
}

SplitUntil32Bit::iterator SplitUntil32Bit::end() const {
  return SplitUntil32Bit::iterator();
}

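// DimCounter tracks the multi-dimensional counter that corresponds to a linear
// offset within [range.begin, range.end). As a worked example (hypothetical
// values): with shape = {3, 4} and range.begin = 5, the loop below yields
// values[0] = 5 % 3 = 2 and values[1] = (5 / 3) % 4 = 1, leaving
// linear_offset == 0 as the final assert expects.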
DimCounter::DimCounter(IntArrayRef shape, Range range)
  : shape(shape)
  , range(range)
  , values(shape.size())
  , offset(range.begin) {
  std::fill(values.begin(), values.end(), 0);
  if (range.begin == 0) {
    return;
  }

  int64_t linear_offset = range.begin;
  int64_t ndim = values.size();
  for (const auto dim : c10::irange(ndim)) {
    int64_t size = shape[dim];
    if (size > 0) {
      values[dim] = linear_offset % size;
      linear_offset /= size;
    }
  }
  TORCH_INTERNAL_ASSERT(linear_offset == 0);
}

bool DimCounter::is_done() const {
  return offset >= range.end;
}

void DimCounter::increment(const std::array<int64_t, 2>& step) {
  offset += step[0] * step[1];
  int64_t ndim = values.size();
  int64_t overflow = step[0];
  int i = 0;
  if (step[1] != 1) {
    TORCH_INTERNAL_ASSERT(step[0] == shape[0] && values[0] == 0);
    i = 1;
    overflow = step[1];
  }
  for (; i < ndim && overflow > 0; i++) {
    auto size = shape[i];
    auto prev = values[i];
    auto value = prev + overflow;
    if (value >= size) {
      overflow = 1;
      value -= size;
      TORCH_INTERNAL_ASSERT(value < size);
    } else {
      overflow = 0;
    }
    values[i] = value;
  }
  TORCH_INTERNAL_ASSERT(overflow == 0 || overflow == 1);
}

std::array<int64_t, 2> DimCounter::max_2d_step() const {
  int64_t step0 = std::min(shape[0] - values[0], range.end - offset);
  int64_t step1 = 1;
  if (step0 == shape[0] && shape.size() >= 2) {
    step1 = std::min(shape[1] - values[1], (range.end - offset) / shape[0]);
  }
  return {step0, step1};
}
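
// A worked example of the 2d stepping (hypothetical values): with shape = {3, 4},
// range = [5, 12) and values = {2, 1}, max_2d_step() returns {1, 1} because only
// one element is left in dim 0. After increment({1, 1}), offset == 6 and values
// become {0, 2}, so the next max_2d_step() returns {3, 2}: a full dim-0 column
// times two steps of dim 1, handled by a single 2d loop invocation.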

} // namespace at