#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/TensorIterator.h>
#undef TORCH_ASSERT_NO_OPERATORS

#include <ATen/core/Tensor.h>

#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <ATen/native/TypeProperties.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/Resize.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorIteratorInternal.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_strided.h>
#endif

#include <c10/util/irange.h>
#include <c10/util/SmallBuffer.h>

#include <array>
#include <algorithm>
#include <cmath>

namespace at {

using DimMask = TensorIteratorBase::DimMask;
using PtrVector = TensorIteratorBase::PtrVector;
using loop2d_t = TensorIteratorBase::loop2d_t;
using StrideVector = TensorIteratorBase::StrideVector;
namespace {

inline void get_base_ptrs(char** ptrs, ArrayRef<OperandInfo> operands) {
  std::transform(operands.begin(), operands.end(), ptrs, [](const OperandInfo& op) {
    return static_cast<char*>(op.data);
  });
}

inline void get_strides(int64_t* strides, ArrayRef<OperandInfo> operands, int64_t ndim) {
  for (const auto dim : c10::irange(ndim)) {
    for (const auto arg : c10::irange(operands.size())) {
      *strides++ = operands[arg].stride_bytes[dim];
    }
  }
  // Always at least 2d strides to support 2d for_each loops
  if (ndim < 2) {
    const int64_t ntensors = operands.size();
    std::fill_n(strides, (2 - ndim) * ntensors, 0);
  }
}

static OptionalTensorRef make_otr(const TensorBase &tensor) {
  if (tensor.defined()) {
    return OptionalTensorRef(tensor);
  } else {
    return OptionalTensorRef();
  }
}

}  // namespace

namespace internal {

OpaqueOptionalTensorRef::OpaqueOptionalTensorRef() {
  static_assert(alignof(OptionalTensorRef) == alignof(TensorBase));
  static_assert(sizeof(OptionalTensorRef) == sizeof(TensorBase));
  new (data_.data()) OptionalTensorRef();
}

OpaqueOptionalTensorRef::~OpaqueOptionalTensorRef() {
  get()->~OptionalTensorRef();
}

const Tensor& OpaqueOptionalTensorRef::getTensor() const {
  return get()->getTensorRef();
}

}  // namespace internal

void OperandInfo::tensor(c10::MaybeOwned<TensorBase> &&tensor) {
  tensor_base_ = std::move(tensor);
  *tensor_storage_ = make_otr(*tensor_base_);
}

void OperandInfo::exchange_tensor(c10::MaybeOwned<TensorBase> &&new_tensor) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!original_tensor_base_->defined());
  original_tensor_base_ = std::exchange(tensor_base_, new_tensor);
  *original_tensor_storage_ = std::exchange(*tensor_storage_, make_otr(*tensor_base_));
}

void OperandInfo::restore_original_tensor() {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(original_tensor_base_->defined());
  tensor_base_ = std::move(original_tensor_base_);
  *tensor_storage_ = std::exchange(*original_tensor_storage_, OptionalTensorRef{});
}

/// Construction
TensorIteratorConfig& TensorIteratorConfig::add_owned_output(const TensorBase& output) {
  TORCH_INTERNAL_ASSERT(
      num_inputs_ == 0,
      "Keep in mind that you have to add all outputs first before adding any input. "
      "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
  tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(c10::in_place, output));
  num_outputs_++;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::add_owned_input(const TensorBase& input) {
  tensors_.push_back(c10::MaybeOwned<TensorBase>::owned(c10::in_place, input));
  num_inputs_++;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::add_borrowed_output(const TensorBase& output) {
  TORCH_INTERNAL_ASSERT(
      num_inputs_ == 0,
      "Keep in mind that you have to add all outputs first before adding any input. "
      "For more details, see https://github.com/pytorch/pytorch/wiki/How-to-use-TensorIterator.");
  tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(output));
  num_outputs_++;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::add_borrowed_input(const TensorBase& input) {
  tensors_.push_back(c10::MaybeOwned<TensorBase>::borrowed(input));
  num_inputs_++;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype_and_device(ScalarType dtype, Device device) {
  TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
  static_dtype_ = dtype;
  static_device_ = device;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::declare_static_dtype(ScalarType dtype) {
  TORCH_CHECK(!check_all_same_dtype_, "check_all_same_dtype(false) must be called before declare_static_dtype(...)");
  static_dtype_ = dtype;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::declare_static_device(Device device) {
  static_device_ = device;
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape) {
  // WARNING:
  //   This will bypass all shape checking in the TensorIterator. Kernels which call this method
  //   are expected to check shapes before calling `add_owned_input` or `add_owned_output`.
  TORCH_CHECK(!resize_outputs_, "resize_outputs() must be called before declare_static_shape(...)");
  static_shape_ = c10::make_optional(DimVector(shape));
  return *this;
}

TensorIteratorConfig& TensorIteratorConfig::declare_static_shape(IntArrayRef shape, IntArrayRef squash_dims) {
  declare_static_shape(shape);
  if (static_shape_->empty()) return *this;
  for (const auto& squash_dim : squash_dims) {
    TORCH_CHECK(squash_dim >= 0 && squash_dim < static_cast<int64_t>(static_shape_->size()),
                "squash_dim ", squash_dim, " must be in [0, ", static_shape_->size(), ").");
    (*static_shape_)[squash_dim] = 1;
  }
  return *this;
}

// NOTE: [Computing output strides]
// We use the following algorithm to compute output strides.
// If a correctly sized output is provided, we respect its strides and don't change them.
// Otherwise, if the provided output is of incorrect size or no output is provided,
// we try to recover the permutation that was applied to the inputs
// by sorting the strides of the inputs. Precedence is given to the inputs in the order they were added,
// and to permutations involving non-broadcasted dimensions.
// 1. we loop over inputs starting from the first
// 2. for all inputs, strides of broadcasted dimensions are set to 0, and 0 compares equal to anything. If one
//    of the dimensions being compared has a stride of 0, we move on to the next tensor to determine if
//    these dimensions need to be swapped.
// 3. strides of dimensions equal to 1 participate in sorting
// 4. if 2 strides are equal and neither is 0, we try to break the tie by looking at the corresponding dimensions
//    of the tensor. Dimensions were permuted if, when iterating from the end, dimensions corresponding to the
//    same strides are increasing. If dimensions are non-increasing, we move on to the next input to break the tie.
//
// Instead of applying rule 4 for tie breaking, we could move on to the next tensor directly. This would result in possibly
// losing the correct permutation of the first tensor if there are permuted trivial dimensions, but could potentially
// improve traversal order of the second tensor. We chose the former option to better propagate channels-last layout,
// for example, for a tensor with the sizes N1H1.
// These rules result in the intuitive behavior that in most cases recovers the permutation of either the first argument (if all
// arguments are of the same size) or the argument that is not broadcasted, regardless of its position.
// As a bonus, it also results in reasonably well-behaved traversal order of the inputs and outputs - in the kernels
// the output is traversed linearly, and since it closely follows input layouts, inputs are traversed linearly as well.
//
// Examples:
// full-size tensor + broadcasted tensor with 0 or 1 non-trivial dimensions => strides of output are the same
// as strides of the full-size input, regardless of the order
// 2 tensors of the same size but different strides => output strides are the same as those of the first argument
//
// We also have a fast path for memory-dense inputs with the same strides (or, trivially, a single memory-dense input)
// that outputs a tensor with the same strides as the inputs. The only difference in result from the algorithm described
// above is in the strides of trivial (size 1) dimensions, where in ambiguous cases we default to contiguous strides
// for performance reasons.
// Example: a tensor with sizes NC11 and strides C1CC will produce an output with strides C111 (note the differences are only
// in the strides of trivial dimensions, so the physical layout is unaffected, but permutation information is lost).
// We might change this behavior in the future once performance considerations are resolved.
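//
// Worked example (illustrative): a channels-last input with sizes
// [N, C, H, W] has strides [H*W*C, 1, W*C, C]. Sorting the strides in
// ascending order recovers the permutation (C, W, H, N), so a freshly
// allocated output is given channels-last strides as well, regardless of
// the order in which the operands were added.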

void TensorIteratorBase::reorder_dimensions() {
  // Sort the dimensions based on strides in ascending order with reduced dims
  // at the front. NOTE: this inverts the order of C-contiguous tensors.
  // strides[0] is the fastest moving dimension instead of strides[ndim - 1].
  // See NOTE: [Computing output strides] and inline comments for a more detailed description

  perm_.resize(ndim());
  if (ndim() == 1) {
    perm_[0] = 0;
    return;
  }

  // initialize perm with n-1, n-2, ..., 1, 0
  std::iota(perm_.rbegin(), perm_.rend(), 0);

  // Reordering dimensions changes iteration order
  if (enforce_linear_iteration_) {
    permute_dimensions(perm_);
    return;
  }

  // returns 1 if the dim0 should come after dim1, -1 if dim0 should come
  // before dim1, and 0 if the comparison is ambiguous.
  auto should_swap = [&](size_t dim0, size_t dim1) {
    for (const auto arg : c10::irange(ntensors())) {
      // ignore undefined or incorrectly sized tensors
      if (operands_[arg].stride_bytes.empty() || operands_[arg].will_resize) {
        continue;
      }
      int64_t stride0 = operands_[arg].stride_bytes[dim0];
      int64_t stride1 = operands_[arg].stride_bytes[dim1];
      if (is_reduction_ && operands_[arg].is_output) {
        // move reduced dimensions to the front
        // strides of reduced dimensions are always set to 0 by review_reduce_result
        if ((stride0 == 0) != (stride1 == 0)) {
          return stride1 == 0 ? 1 : -1;
        }
      }
      // move on to the next input if one of the dimensions is broadcasted
      if (stride0 == 0 || stride1 == 0) {
        continue;
      // it is important to return here only with strict comparisons, for equal strides we try to break the tie later
      // by comparing corresponding dimensions or if that does not work, moving on to the next tensor
      } else if (stride0 < stride1) {
        return -1;
      } else if (stride0 > stride1) {
        return 1;
      } else { // equal strides, use dimensions themselves as the tie-breaker.
        // at this point, with zero strides out of the way, we are guaranteed that operand dimensions are equal to shape_
        auto t_dim0 = shape_[dim0];
        auto t_dim1 = shape_[dim1];
        // return only if dimensions should be swapped, otherwise move on to the next tensor
        if (t_dim0 > t_dim1) {
          return 1;
        }
      }
    }
    return 0;
  };

  // insertion sort with support for ambiguous comparisons
  for (const auto i : c10::irange(1, ndim())) {
    int dim1 = i;
    for (int dim0 = i - 1; dim0 >= 0; dim0--) {
      int comparison = should_swap(perm_[dim0], perm_[dim1]);
      if (comparison > 0) {
        std::swap(perm_[dim0], perm_[dim1]);
        dim1 = dim0;
      } else if (comparison < 0) {
        break;
      }
    }
  }

  // perform re-ordering of shape and strides
  permute_dimensions(perm_);
}
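
// Example (illustrative): a contiguous float tensor of shape [2, 3] has
// byte strides {12, 4}. perm_ starts out as {1, 0}, no swap is needed, and
// permute_dimensions leaves the iterator with shape {3, 2} and strides
// {4, 12}, so dim 0 is the fastest-moving dimension as the loops expect.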

// Computes a common dtype using type promotion
// See the [Common Dtype Computation] note
ScalarType TensorIteratorBase::compute_common_dtype() {
  at::native::ResultTypeState state = {};
  for (const auto& op : operands_) {
    if (op.is_output) {
      continue;
    }

    state = at::native::update_result_type_state(op.tensor(), state);
  }

  common_dtype_ = at::native::result_type(state);
  TORCH_INTERNAL_ASSERT(common_dtype_ != ScalarType::Undefined);

  return common_dtype_;
}
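
// Example (illustrative): for inputs of dtypes (int64, float32) the common
// dtype is float32; for (uint8, int8) it is int16. Only inputs participate
// here; output dtypes are reconciled separately in compute_types.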

TensorOptions original_options(const OperandInfo& op) {
  if (op.original_tensor_base().defined()) {
    return op.original_tensor_base().options();
  } else {
    return op.options();
  }
}

// Implements the behavior of the following flags:
//   - check_all_same_dtype_
//   - check_all_same_device_
//   - enforce_safe_casting_to_output_
//   - promote_inputs_to_common_dtype_
//   - cast_common_dtype_to_outputs_
//
// See their descriptions in TensorIterator.h for details.
// NOTE: Checks for more specific behaviors (e.g. the first and second
//   inputs must share a dtype, but the third must have the long dtype)
//   should be implemented directly and outside of TensorIterator.
void TensorIteratorBase::compute_types(const TensorIteratorConfig& config) {
  // Reviews operands (1/2)
  //   - validates that all input tensors are defined
  //   - computes common device
  //   - determines if there are undefined outputs
  //   - determines if there are different dtypes and attempts
  //     to quickly acquire a common dtype
  Device common_device = kCPU;
  common_dtype_ = ScalarType::Undefined;
  // NB: despite output_dtype's generic sounding name, it only is
  // used in a nontrivial way if check_all_same_dtype is true
  ScalarType output_dtype = ScalarType::Undefined;
  bool has_different_input_dtypes = false;
  bool has_different_output_dtypes = false;
  bool has_undefined_outputs = false;

  for (auto& op : operands_) {
    // Validates that all inputs have type information, and that
    // if an output is missing type information that we can infer
    // the device it should be allocated on.
    if (!op.is_type_defined()) {
      TORCH_INTERNAL_ASSERT(op.is_output, "Found type undefined input tensor!");

      if (config.static_dtype_.has_value()) {
        op.target_dtype = config.static_dtype_.value();
      } else {
        has_undefined_outputs = true;
      }

      if (config.static_device_.has_value()) {
        op.device = config.static_device_.value();
      } else {
        TORCH_INTERNAL_ASSERT(config.check_all_same_device_);
      }

      if (has_undefined_outputs || !op.device.has_value()) {
        continue;
      }
    }

    // Validates input tensors are defined
    if (!op.tensor_base().defined()) {
      TORCH_INTERNAL_ASSERT(op.is_output, "Found undefined input tensor!");
      continue;
    }

    TORCH_INTERNAL_ASSERT(op.target_dtype == op.current_dtype);

    // Acquires the first non-CPU device (if any) as the common device
    if (common_device == kCPU && !op.tensor_base().is_cpu()) {
      common_device = op.tensor_base().device();
    }

    if (!op.is_output) {
      // Determines if there are varying input dtypes
      // NOTE: the common dtype is set to the first defined input dtype observed
      if (op.target_dtype != common_dtype_) {
        if (common_dtype_ == ScalarType::Undefined) {
          common_dtype_ = op.target_dtype;
        } else {
          has_different_input_dtypes = true;
        }
      }
    } else { // op.is_output
      // Determines if there are varying output dtypes
      // NOTE: the output dtype is set to the first defined output dtype observed
      if (op.target_dtype != output_dtype) {
        if (output_dtype == ScalarType::Undefined) {
          output_dtype = op.target_dtype;
        } else {
          has_different_output_dtypes = true;
        }
      }
    }
  }

  // Checks that either the computation type is computable or unneeded
  TORCH_INTERNAL_ASSERT(!(has_different_input_dtypes && !config.promote_inputs_to_common_dtype_ &&
                        (has_undefined_outputs || config.enforce_safe_casting_to_output_ ||
                        config.cast_common_dtype_to_outputs_)));

  // Checks that all inputs and defined outputs are the same dtype, if requested
  if (config.check_all_same_dtype_ &&
      (has_different_input_dtypes || has_different_output_dtypes ||
      (common_dtype_ != output_dtype && output_dtype != ScalarType::Undefined))) {
    // Throws an informative error message
    for (auto& op : operands_) {
      if (!op.tensor_base().defined()) {
        continue;
      }

      TORCH_CHECK(op.target_dtype == common_dtype_,
                  "Found dtype ", op.target_dtype, " but expected ", common_dtype_);
    }
  }

  // Short-circuits if no additional work required
  if (!has_undefined_outputs && !config.check_all_same_device_ &&
      !config.promote_inputs_to_common_dtype_ && !config.cast_common_dtype_to_outputs_ &&
      !config.enforce_safe_casting_to_output_) {
    // Invalidates common_dtype_ if it could not be inferred
    common_dtype_ = has_different_input_dtypes ? ScalarType::Undefined : common_dtype_;
    return;
  }

  // Computes a common dtype, if needed
  if ((has_different_input_dtypes || all_ops_are_scalars_) && config.promote_inputs_to_common_dtype_) {
    common_dtype_ = compute_common_dtype();
  }

  // Promotes common dtype to the default float scalar type, if needed
  if (config.promote_integer_inputs_to_float_ &&
      c10::isIntegralType(common_dtype_, /*includeBool=*/true)) {
    common_dtype_ = c10::typeMetaToScalarType(c10::get_default_dtype());
  }

  // Reviews operands (2/2)
  //   - sets metadata for undefined outputs
  //   - checks that all tensors are on the same device, if requested
  //   - checks that the common dtype can safely cast to each output, if requested
  //   - creates temporaries for CPU operations, if needed and requested
  common_device_ = common_device;
  int max_cpu_scalars_on_non_cpu = config.allow_cpu_scalars_ ? 1 : 0;
  int current_cpu_scalars_on_non_cpu = 0;
  for (auto& op : operands_) {
    bool is_type_defined = op.is_type_defined();
    bool is_device_defined = op.is_device_defined();

    if (!is_type_defined) {
      op.target_dtype = common_dtype_;
    }
    if (!is_device_defined) {
      op.device = common_device;
    }

    if (!is_type_defined && !is_device_defined) {
      continue;
    }

    // Skips undefined tensors
    if (!op.tensor_base().defined()) {
      continue;
    }

    // Checks all tensors are on the same device, if requested
    if (config.check_all_same_device_) {
      // Handles CPU scalars on CUDA kernels that support them
      if (!common_device.is_cpu() &&
          config.allow_cpu_scalars_ && !op.is_output && op.tensor_base().dim() == 0 &&
          op.tensor_base().is_cpu()) {
        TORCH_CHECK(current_cpu_scalars_on_non_cpu < max_cpu_scalars_on_non_cpu,
                    "Trying to pass too many CPU scalars to non-CPU kernel!");
        ++current_cpu_scalars_on_non_cpu;
      } else if (op.device.value() != common_device) {
        TORCH_CHECK(false,
                    "Expected all tensors to be on the same device, but "
                    "found at least two devices, ", common_device, " and ", op.device.value(), "!");
      }
    }

    // Checks safe casting, if requested
    if (config.enforce_safe_casting_to_output_ && op.is_output && op.current_dtype != common_dtype_) {
      TORCH_CHECK(canCast(common_dtype_, op.current_dtype),
                  "result type ", common_dtype_, " can't be cast to the "
                  "desired output type ", op.current_dtype);
    }

    // Creates temporaries for CPU operations, if needed and requested
    // TODO: reuse temporaries when possible (e.g. for inplace operations)
    if (common_device == kCPU) {
      // Casts to outputs by creating temporaries of the correct dtype (if needed)
      // NB: we skip this on is_meta_, because the temporary allocation here is
      // unnecessary if we aren't going to actually do the compute
      if (config.cast_common_dtype_to_outputs_ && op.is_output && op.current_dtype != common_dtype_ && !is_meta_) {
        TORCH_INTERNAL_ASSERT(op.tensor_base().defined());
        // Marker [Output original_tensor is set]
        // NB: do NOT use set_output here, as the temporary is NOT a true output;
        // op.tensor is the true output and it was pre-provided for us.
        // TODO: The logic for cast_outputs will need to be handled by the
        // structured kernels implementation. What probably should happen
        // is that we pass in the inferred dtype into the out kernel, and
        // then after calling the out kernel, do the conversion (which
        // is cast_outputs here), but integrating this with existing
        // TensorIterator will take a little doing
        op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(
            at::empty_like(op.tensor(),
                           op.tensor_base().options().dtype(common_dtype_),
                           LEGACY_CONTIGUOUS_MEMORY_FORMAT)));
        if (!names_.empty()) {
          namedinference::propagate_names(op.tensor_base(), names_);
        }
        op.current_dtype = common_dtype_;
        op.target_dtype = common_dtype_;
      }

      // Promotes inputs by creating temporaries of the correct dtype
      if (config.promote_inputs_to_common_dtype_ && !op.is_output && op.current_dtype != common_dtype_) {
        op.exchange_tensor(c10::MaybeOwned<TensorBase>::owned(op.tensor().to(common_dtype_)));
        op.current_dtype = common_dtype_;
        op.target_dtype = common_dtype_;
      }
    }
  }
}
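
// Example (illustrative): adding a float32 tensor to a float64 tensor on
// CPU promotes the float32 input by materializing a float64 temporary
// (promote_inputs_to_common_dtype_). If the provided output were int32,
// enforce_safe_casting_to_output_ would raise an error instead, since the
// common dtype float64 cannot be safely cast to an integral output.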

StrideVector TensorIteratorBase::compatible_stride(int element_size) const {
  auto stride = StrideVector();
  int64_t next_stride = element_size;
  for (const auto dim : c10::irange(ndim())) {
    stride.push_back(next_stride);
    next_stride *= shape_[dim];
  }
  return stride;
}
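
// Example (illustrative): with element_size == 4 and shape_ == {3, 2} (in
// iteration order, fastest dimension first), compatible_stride returns
// {4, 12}: contiguous byte strides for the already-permuted shape.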

DimVector TensorIteratorBase::invert_perm(IntArrayRef input) const {
  // Invert the permutation caused by reorder_dimensions. This is not valid
  // after coalesce_dimensions is called.
  TORCH_INTERNAL_ASSERT(!has_coalesced_dimensions_);
  TORCH_INTERNAL_ASSERT(input.size() == perm_.size());
  auto res = DimVector(input.size()); // no initialization needed, every value in res should be written to.
  for (const auto dim : c10::irange(ndim())) {
    res[perm_[dim]] = input[dim];
  }
  return res;
}

void TensorIteratorBase::allocate_or_resize_outputs() {
  for (const auto i : c10::irange(num_outputs_)) {
    auto& op = operands_[i];
    if (!op.tensor_base().defined() || op.will_resize) {
      TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand", i);
      int element_size = elementSize(op.target_dtype);
      op.stride_bytes = compatible_stride(element_size);
      // check if permutation is just an inverted order
      bool inverted = true;
      for (const auto j : c10::irange(ndim())) {
        if (perm_[j] != ndim() - j - 1) {
          inverted = false;
          break;
        }
      }
      auto tensor_shape = invert_perm(shape_);
      if (inverted) {
        // can just return contiguous output
        // it is faster because it avoids allocating 0 size tensor and
        // resizing and restriding it
        set_output_raw_strided(i, tensor_shape, {}, original_options(op), names_);
      } else {
        auto tensor_stride = invert_perm(op.stride_bytes);
        for (const auto dim : c10::irange(ndim())) {
          tensor_stride[dim] /= element_size;
        }
        set_output_raw_strided(i, tensor_shape, tensor_stride, original_options(op), names_);
      }
      op.current_dtype = op.target_dtype;
    } else if (op.tensor_base().defined()) {
      // Even if we don't resize, we still need to tell set_output about
      // the output, so that we properly set guard and propagate names
      set_output_raw_strided(i, op.tensor_base().sizes(), {}, original_options(op), names_);
    }
  }
}

void TensorIteratorBase::compute_names(const TensorIteratorConfig& config) {
  bool should_infer_names = std::any_of(
      operands_.begin(),
      operands_.end(),
      [](const OperandInfo& op) {
        return op.tensor_base().defined() && op.tensor_base().has_names();
      });
  if (!should_infer_names) {
    return;
  }

  for (auto& op : operands_) {
    if (!op.tensor_base().defined()) continue;
    // Don't include output tensors if we are resizing, since we will
    // clobber their names in any case. (If the output tensor was
    // also an input tensor, we'll pick it up when it shows up again
    // in operands).
    if (config.resize_outputs_ && op.is_output) continue;
    // perform name inference
    if (names_.empty()) {
      names_ = op.tensor_base().names();
    } else {
      names_ = NameVector(unify_from_right(names_, op.tensor_base().names()));
    }
  }
}

void TensorIteratorBase::coalesce_dimensions() {
  if (ndim() <= 1) {
    return;
  }

  // We can coalesce two adjacent dimensions if either dim has size 1 or if:
  // shape[n] * stride[n] == stride[n + 1].
  auto can_coalesce = [&](int dim0, int dim1) {
    auto shape0 = shape_[dim0];
    auto shape1 = shape_[dim1];
    if (shape0 == 1 || shape1 == 1) {
      return true;
    }
    for (const auto i : c10::irange(ntensors())) {
      auto& stride = operands_[i].stride_bytes;
      if (shape0 * stride[dim0] != stride[dim1]) {
        return false;
      }
    }
    return true;
  };

  // replace each operand's stride at dim0 with its stride at dim1
  auto replace_stride = [&](int dim0, int dim1) {
    for (const auto i : c10::irange(ntensors())) {
      auto& stride = operands_[i].stride_bytes;
      stride[dim0] = stride[dim1];
    }
  };

  int prev_dim = 0;
  for (const auto dim : c10::irange(1, ndim())) {
    if (can_coalesce(prev_dim, dim)) {
      if (shape_[prev_dim] == 1) {
        replace_stride(prev_dim, dim);
      }
      shape_[prev_dim] *= shape_[dim];
    } else {
      prev_dim++;
      if (prev_dim != dim) {
        replace_stride(prev_dim, dim);
        shape_[prev_dim] = shape_[dim];
      }
    }
  }

  shape_.resize(prev_dim + 1);
  for (const auto i : c10::irange(ntensors())) {
    operands_[i].stride_bytes.resize(ndim());
  }
  has_coalesced_dimensions_ = true;
}
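
// Example (illustrative): a contiguous float tensor viewed with iteration
// shape {4, 5} and byte strides {4, 16} satisfies
// shape[0] * stride[0] == stride[1], so the two dims coalesce into a single
// dim of shape {20} with stride {4}, and the inner loop can sweep all 20
// elements in one pass.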

int64_t TensorIteratorBase::numel() const {
  int64_t numel = 1;
  for (int64_t size : shape_) {
    numel *= size;
  }
  return numel;
}

StrideVector TensorIteratorBase::get_dim_strides(int dim) const {
  auto dims = ndim();
  auto inner_strides = StrideVector();
  for (auto& op : operands_) {
    inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
  }
  return inner_strides;
}

SmallVector<char*, 4> TensorIteratorBase::get_base_ptrs() const {
  auto ptrs = SmallVector<char*, 4>(ntensors());
  at::get_base_ptrs(ptrs.data(), operands_);
  return ptrs;
}

bool TensorIteratorBase::is_dim_reduced(int dim) const {
  for (auto& op : operands_) {
    if (op.is_output && op.stride_bytes[dim] == 0 && shape_[dim] > 1) {
      return true;
    }
  }
  return false;
}

void TensorIteratorBase::permute_dimensions(IntArrayRef perm) {
  TORCH_INTERNAL_ASSERT(perm.size() == static_cast<unsigned>(ndim()));

  auto reorder = [perm](IntArrayRef data) {
    auto res = DimVector(data.size(), 0);
    for (const auto i : c10::irange(perm.size())) {
      res[i] = data[perm[i]];
    }
    return res;
  };

  // Update shape and strides
  shape_ = reorder(shape_);
  for (auto& op : operands_) {
    if (!op.stride_bytes.empty()) {
      op.stride_bytes = reorder(op.stride_bytes);
    }
  }
}

int64_t TensorIteratorBase::num_output_elements() const {
  int64_t elem = 1;
  for (const auto dim : c10::irange(ndim())) {
    if (operands_[0].stride_bytes[dim] != 0 || shape_[dim] == 0) {
      elem *= shape_[dim];
    }
  }
  return elem;
}

int TensorIteratorBase::num_reduce_dims() const {
  int count = 0;
  for (const auto dim : c10::irange(ndim())) {
    if (operands_[0].stride_bytes[dim] == 0) {
      count++;
    }
  }
  return count;
}

void TensorIteratorBase::for_each(loop2d_t loop, int64_t grain_size) {
  int64_t numel = this->numel();
  if (numel == 0) {
    return;
  } else if (numel < grain_size || at::get_num_threads() == 1) {
    return serial_for_each(loop, {0, numel});
  } else {
    at::parallel_for(0, numel, grain_size, [&](int64_t begin, int64_t end) {
      serial_for_each(loop, {begin, end});
    });
  }
}
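
// Example (illustrative): with numel() == 1 << 20 and the default grain
// size (at::internal::GRAIN_SIZE), the range [0, numel) is split across
// the intra-op thread pool and each worker runs serial_for_each on its
// sub-range; problems smaller than the grain size stay on the calling
// thread.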

StrideVector TensorIteratorBase::get_strides() const {
  const auto dim = ndim();
  StrideVector strides(std::max(dim, 2) * ntensors());
  at::get_strides(strides.data(), operands_, dim);
  return strides;
}

void TensorIteratorBase::serial_for_each(loop2d_t loop, Range range) const {
  if (range.size() == 0) {
    return;
  }

  const auto ntensors = this->ntensors();
  const auto ndim = this->ndim();

  c10::SmallBuffer<char*, 4> ptrs(ntensors);
  c10::SmallBuffer<int64_t, 8> strides(ntensors * std::max(ndim, 2));

  at::get_base_ptrs(ptrs.data(), operands_);
  at::get_strides(strides.data(), operands_, ndim);
  at::internal::serial_for_each(
      shape_, strides, ptrs.data(), ptrs.size(), loop, range);
}

bool TensorIteratorBase::is_trivial_1d() const {
  // TODO: check for casting once it's supported
  return ndim() == 1;
}

bool TensorIteratorBase::is_contiguous() const {
  if (numel() == 1) {
    return true;
  }
  if (ndim() != 1) {
    return false;
  }
  return has_contiguous_first_dim();
}

bool TensorIteratorBase::is_scalar(int arg) const {
  const auto& stride = operands_[arg].stride_bytes;
  for (const auto i : c10::irange(ndim())) {
    if (stride[i] != 0 && shape_[i] != 1) {
      return false;
    }
  }
  return true;
}

bool TensorIteratorBase::is_cpu_scalar(int arg) const {
  return is_scalar(arg) && device(arg).is_cpu();
}

void TensorIteratorBase::cast_outputs() {
  for (auto& op : operands_) {
    if (op.is_output && op.original_tensor_base().defined() &&
        op.original_tensor_base().scalar_type() != op.current_dtype) {
      // TODO: Now that set_output resizes both the original_tensor
      // and tensor, this condition should no longer ever be true
      const auto &original_tensor = op.original_tensor();
      const auto &tensor = op.tensor();
      if (original_tensor.sizes() != tensor.sizes()) {
        original_tensor.resize_as_(tensor).as_strided_(tensor.sizes(), tensor.strides());
      }
      original_tensor.copy_(tensor);
      op.restore_original_tensor();
    }
  }
}

void* TensorIteratorBase::data_ptr(int arg) const {
  return operands_[arg].data;
}

void TensorIteratorBase::remove_operand(int arg) {
  operands_.erase(operands_.begin() + arg);
}

void TensorIteratorBase::unsafe_replace_operand(int arg, void* data) {
  operands_[arg].data = data;
}

void TensorIteratorBase::narrow(int dim, int64_t start, int64_t size) {
  TORCH_INTERNAL_ASSERT(dim < ndim() && size >= 1);
  shape_[dim] = size;
  view_offsets_[dim] += start;
  for (auto& op : operands_) {
    op.data = ((char*)op.data) + op.stride_bytes[dim] * start;
  }
  if (size == 1 && !is_reduction_) {
    coalesce_dimensions();
  }
}

void TensorIteratorBase::select_all_keeping_dim(int start_dim, IntArrayRef indices) {
  TORCH_INTERNAL_ASSERT(start_dim <= ndim());
  for (const auto i : c10::irange(start_dim, ndim())) {
    for (auto& op : operands_) {
      op.data = ((char*)op.data) + op.stride_bytes[i] * indices[i - start_dim];
    }
    shape_[i] = 1;
  }
}

#define BINARY_FLOAT_OP_CONFIG()                \
  TensorIteratorConfig()                        \
    .set_check_mem_overlap(true)                \
    .allow_cpu_scalars(true)                    \
    .promote_inputs_to_common_dtype(true)       \
    .cast_common_dtype_to_outputs(true)         \
    .enforce_safe_casting_to_output(true)       \
    .promote_integer_inputs_to_float(true)

// Helper to construct a binary op that promotes integer inputs to float.
void TensorIteratorBase::build_binary_float_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_FLOAT_OP_CONFIG()
        .add_owned_output(out)
        .add_owned_input(a)
        .add_owned_input(b));
}

void TensorIteratorBase::build_borrowing_binary_float_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_FLOAT_OP_CONFIG()
        .add_output(out)
        .add_input(a)
        .add_input(b));
}
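
// Usage sketch (illustrative): the factory TensorIterator::binary_float_op
// below does exactly
//   TensorIterator iter;
//   iter.build_binary_float_op(out, a, b);
//   return iter;
// so integer inputs (e.g. int64 / int64) produce a floating-point result,
// matching true-division semantics.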

static void set_up_comparison_op_config(TensorIteratorConfig& config, const TensorBase& out) {
  config.set_check_mem_overlap(true);
  config.allow_cpu_scalars(true);
  config.promote_inputs_to_common_dtype(true);

  // When 'out' isn't defined (e.g. for the functional operator 'a == b'), we
  // want the output to be bool. Otherwise (e.g. 'torch.eq(a, b, out=c)') we
  // don't coerce the output.
  if (!out.defined()) {
    config.declare_static_dtype(kBool);
  }

  // Note [special-case bool outputs]
  // We explicitly don't call `cast_common_dtype_to_outputs` when the output tensor
  // has `bool` dtype. This is a performance optimization: the functional
  // version of all comparison/logical ops uses a bool output tensor, and we'd like to
  // avoid creating a temporary copy of the output.
  // However, note that all kernels using this TensorIterator will need to special-case when
  // the output tensor has bool dtype, and provide a lambda of type (scalar_t, scalar_t -> bool).
  if (out.defined() && out.scalar_type() != kBool) {
    config.cast_common_dtype_to_outputs(true);
  }
}
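
// Example (illustrative): for the functional form `a == b` the undefined
// output is declared bool and kernels write bool results directly. For
// `torch.eq(a, b, out=c)` where `c` is non-bool and its dtype differs from
// the common dtype, cast_common_dtype_to_outputs has the kernel compute
// into a temporary of the common dtype, which is then copied into `c`.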

void TensorIteratorBase::build_comparison_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  TensorIteratorConfig config;
  set_up_comparison_op_config(config, out);

  config.add_owned_output(out);
  config.add_owned_input(a);
  config.add_owned_input(b);
  build(config);
}

void TensorIteratorBase::build_borrowing_comparison_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  TensorIteratorConfig config;
  set_up_comparison_op_config(config, out);

  config.add_borrowed_output(out);
  config.add_borrowed_input(a);
  config.add_borrowed_input(b);
  build(config);
}

void TensorIteratorBase::build_borrowing_except_last_argument_comparison_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  TensorIteratorConfig config;
  set_up_comparison_op_config(config, out);

  config.add_borrowed_output(out);
  config.add_borrowed_input(a);
  config.add_owned_input(b);
  build(config);
}

void TensorIteratorBase::build_ternary_op(
    const TensorBase& out, const TensorBase& a,
    const TensorBase& b, const TensorBase& c) {
  build(TensorIteratorConfig()
        .promote_inputs_to_common_dtype(true)
        .enforce_safe_casting_to_output(true)
        .add_owned_output(out)
        .add_owned_input(a)
        .add_owned_input(b)
        .add_owned_input(c));
}

// This cannot be a function because TensorIteratorConfig is not
// copyable or movable, so it can't be returned from the function.
#define BINARY_OP_CONFIG()                      \
  TensorIteratorConfig()                        \
    .set_check_mem_overlap(true)                \
    .allow_cpu_scalars(true)                    \
    .promote_inputs_to_common_dtype(true)       \
    .cast_common_dtype_to_outputs(true)         \
    .enforce_safe_casting_to_output(true)

void TensorIteratorBase::build_binary_op(const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_OP_CONFIG()
        .add_owned_output(out)
        .add_owned_input(a)
        .add_owned_input(b));
}

void TensorIteratorBase::build_borrowing_binary_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  build(BINARY_OP_CONFIG()
        .add_output(out)
        .add_input(a)
        .add_input(b));
}

// This cannot be a function because TensorIteratorConfig is not
// copyable or movable, so it can't be returned from the function.
#define UNARY_FLOAT_OP_CONFIG()                 \
  TensorIteratorConfig()                        \
    .set_check_mem_overlap(true)                \
    .promote_inputs_to_common_dtype(true)       \
    .cast_common_dtype_to_outputs(true)         \
    .enforce_safe_casting_to_output(true)       \
    .promote_integer_inputs_to_float(true)

void TensorIteratorBase::build_unary_float_op(const TensorBase& out, const TensorBase& a) {
  build(UNARY_FLOAT_OP_CONFIG()
        .add_owned_output(out)
        .add_owned_input(a));
}

void TensorIteratorBase::build_borrowing_unary_float_op(const TensorBase& out, const TensorBase& a) {
  build(UNARY_FLOAT_OP_CONFIG()
        .add_output(out)
        .add_input(a));
}

// This cannot be a function because TensorIteratorConfig is not
// copyable or movable, so it can't be returned from the function.
#define UNARY_OP_CONFIG()                       \
  TensorIteratorConfig()                        \
    .set_check_mem_overlap(true)                \
    .cast_common_dtype_to_outputs(false)        \
    .enforce_safe_casting_to_output(false)      \
    .check_all_same_dtype(true)

void TensorIteratorBase::build_unary_op(const TensorBase& out, const TensorBase& a) {
  build(UNARY_OP_CONFIG()
        .add_owned_output(out)
        .add_owned_input(a));
}

void TensorIteratorBase::build_borrowing_unary_op(const TensorBase& out, const TensorBase& a) {
  build(UNARY_OP_CONFIG()
        .add_output(out)
        .add_input(a));
}

void TensorIteratorBase::build_output_borrowing_argument_owning_unary_op(const TensorBase& out, const TensorBase& a) {
  build(UNARY_OP_CONFIG()
        .add_output(out)
        .add_owned_input(a));
}

// Helper to construct a unary op that forcibly promotes the output to boolean.
// Use it only when the output tensor must have boolean dtype.
void TensorIteratorBase::build_borrowing_unary_force_boolean_op(const TensorBase& out, const TensorBase& a) {
  build(TensorIteratorConfig()
        .set_check_mem_overlap(true)
        .check_all_same_dtype(false)
        .declare_static_dtype(at::kBool)
        .declare_static_device(a.device())
        .add_output(out)
        .add_input(a));
}

TensorIterator TensorIterator::binary_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
  TensorIterator iter;
  iter.build_binary_op(out, a, b);
  return iter;
}

TensorIterator TensorIterator::borrowing_binary_op(
    const TensorBase& out, const TensorBase& a, const TensorBase& b) {
  TensorIterator iter;
  iter.build_borrowing_binary_op(out, a, b);
  return iter;
}

TensorIterator TensorIterator::binary_float_op(TensorBase& out, const TensorBase& a, const TensorBase& b) {
  TensorIterator iter;
  iter.build_binary_float_op(out, a, b);
  return iter;
}

TensorIterator TensorIterator::comparison_op(TensorBase& out, const TensorBase& a,
    const TensorBase& b) {
  TensorIterator iter;
  iter.build_comparison_op(out, a, b);
  return iter;
}

TensorIterator TensorIterator::unary_op(TensorBase& out, const TensorBase& a) {
  TensorIterator iter;
  iter.build_unary_op(out, a);
  return iter;
}

TensorIterator TensorIterator::unary_float_op(TensorBase& out, const TensorBase& a) {
  TensorIterator iter;
  iter.build_unary_float_op(out, a);
  return iter;
}

#define NULLARY_OP_CONFIG()                     \
  TensorIteratorConfig()                        \
    .set_check_mem_overlap(true)                \
    .check_all_same_dtype(false)                \
    /* FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 */ \
    .resize_outputs(false)

TensorIterator TensorIterator::nullary_op(TensorBase& out) {
  return NULLARY_OP_CONFIG()
      .add_owned_output(out)
      .build();
}

TensorIterator TensorIterator::borrowing_nullary_op(const TensorBase& out) {
  return NULLARY_OP_CONFIG()
      .add_output(out)
      .build();
}

TensorIterator TensorIterator::reduce_op(TensorBase& out, const TensorBase& a) {
  TORCH_INTERNAL_ASSERT(out.defined());
  return TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .add_owned_output(out)
      .add_owned_input(a)
      .resize_outputs(false)
      .is_reduction(true)
      // TODO: not supporting casting to outputs is only really necessary for arg{min,max}
      .promote_inputs_to_common_dtype(true)
      .build();
}

TensorIterator TensorIterator::reduce_op(TensorBase& out1, TensorBase& out2, const TensorBase& a) {
  TORCH_INTERNAL_ASSERT(out1.defined());
  TORCH_INTERNAL_ASSERT(out2.defined());
  TORCH_CHECK(a.device() == out1.device() && out1.device() == out2.device(),
      "reduce_op(): expected input and both outputs to be on same device, but input is on ", a.device(),
      ", output1 is on ", out1.device(), " and output2 is on ", out2.device());
  TORCH_CHECK(out1.dim() == out2.dim(), "reduce_op(): expected both outputs to have same number of dims, but output1 has ", out1.dim(),
      " and output2 has ", out2.dim());
  TORCH_CHECK(out1.sizes() == out2.sizes(), "reduce_op(): expected both outputs to have same sizes, but output1 has ", out1.sizes(),
      " and output2 has ", out2.sizes());
  TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
      " and output2 has ", out2.strides());
  return TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .add_owned_output(out1)
      .add_owned_output(out2)
      .add_owned_input(a)
      .resize_outputs(false)
      .is_reduction(true)
      .check_all_same_dtype(false)
      .build();
}

void TensorIteratorBase::populate_operands(TensorIteratorConfig& config) {
  for (auto& tensor : config.tensors_) {
    // If *any* of the arguments is a meta tensor, the overall
    // computation is a meta computation (don't do any work,
    // just compute output information). This aligns with
    // our multiple dispatch semantics.
    if (tensor->is_meta()) {
      is_meta_ = true;
    }
    operands_.emplace_back(std::move(tensor));
  }
  num_outputs_ = config.num_outputs_;
}

void TensorIteratorBase::mark_outputs() {
  // TODO: merge this into populate_operands
  for (const auto i : c10::irange(num_outputs_)) {
    operands_[i].is_output = true;
    const auto& output = tensor(i);
    if (!output.defined()) continue;

    // check if output is also an input
    for (const auto arg : c10::irange(num_outputs_, ntensors())) {
      const auto& input = tensor(arg);
      if (output.is_same(input)) {
        operands_[i].is_read_write = true;
      }
    }
  }
}

void TensorIteratorBase::mark_resize_outputs(const TensorIteratorConfig& config) {
  // Outputs cannot be broadcasted. Check that the shape of the outputs matches
  // the inferred shape. There's an exception for write-only tensors to support
  // our legacy behavior that functions with `out=` arguments resize their
  // outputs.
  if (config.static_shape_.has_value()) {
    return;
  }
  for (const auto i : c10::irange(num_outputs_)) {
    const auto& output = tensor(i);
    if (output.defined() && !output.sizes().equals(shape_)) {
      if (config.resize_outputs_ && !operands_[i].is_read_write) {
        operands_[i].will_resize = true;
        continue;
      }
      // for reduction, output size does not match shape_, as output is reduced size, and shape_ is size of the input
      TORCH_CHECK(is_reduction_, "output with shape ", output.sizes(), " doesn't match the broadcast shape ",
                  shape_);
    }
  }
}

void TensorIteratorBase::compute_mem_overlaps(const TensorIteratorConfig& config) {
  if (!config.check_mem_overlap_) {
    return;
  }
  for (const auto i : c10::irange(num_outputs_)) {
    const auto& output = tensor_base(i);
    if (!output.defined()) continue;
    assert_no_internal_overlap(output);
    for (const auto j : c10::irange(num_outputs_, ntensors())) {
      const auto& input = tensor_base(j);
      if (!input.is_same(output)) {
        assert_no_partial_overlap(output, input);
      }
    }
  }
}

void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {
  if (config.static_shape_.has_value()) {
    shape_ = *config.static_shape_;
    return;
  }

  all_ops_same_shape_ = true;
  bool has_scalars = false;
  bool has_tensors = false;
  for (auto& op : operands_) {
    if (!op.tensor_base().defined()) continue;

    // For now, don't include output tensors when we're resizing outputs.
    // These shapes don't participate in shape computation.
    // This preserves the legacy behavior where torch.add(..., out=dst) resizes
    // the destination tensor. If the output tensor is also an input, we'll
    // pick it up later in the operands.
    if (config.resize_outputs_ && op.is_output) continue;
    TORCH_CHECK(!op.tensor_base().unsafeGetTensorImpl()->has_symbolic_sizes_strides(),
                "TensorIterator does not support symbolic shapes; please implement this operator in torch/_refs "
                "using the elementwise or reduction helpers (look at backtrace to find out what operator this is)");
    auto shape = op.tensor_base().sizes();
    if (shape.empty()) {
      has_scalars = true;
    } else {
      has_tensors = true;
    }
    if (has_scalars && has_tensors) {
      all_ops_same_shape_ = false;
    }
    if (shape_.empty()) {
      shape_ = shape;
    } else if (!shape.equals(shape_)) {
      all_ops_same_shape_ = false;
      shape_ = infer_size_dimvector(shape_, shape);
    }
  }
  all_ops_are_scalars_ = !has_tensors;
}

void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) {
  for (auto& op : operands_) {
    if (op.tensor_base().defined() && !op.will_resize) {
      IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor_base().sizes();
      auto original_stride = op.tensor_base().strides();
      auto element_size_in_bytes = op.tensor_base().element_size();
      auto offset = ndim() - original_shape.size();
      if (offset > 0)
        op.stride_bytes.resize(ndim(), 0);
      else
        op.stride_bytes.resize(ndim());
      for (const auto i : c10::irange(original_shape.size())) {
        // see NOTE: [Computing output strides]
        if (original_shape[i] == 1 && shape_[offset + i] != 1) {
          op.stride_bytes[offset + i] = 0;
        } else {
          op.stride_bytes[offset + i] = original_stride[i] * element_size_in_bytes;
        }
      }
    }
  }
}

bool TensorIteratorBase::can_use_32bit_indexing() const {
  int64_t max_value = std::numeric_limits<int32_t>::max();
  if (numel() > max_value) {
    return false;
  }
  for (auto& op : operands_) {
    int64_t max_offset = 1;
    for (const auto dim : c10::irange(ndim())) {
      max_offset += (shape_[dim] - 1) * op.stride_bytes[dim];
    }
    if (max_offset > max_value) {
      return false;
    }
  }
  return true;
}

std::unique_ptr<TensorIterator> TensorIteratorBase::split(int dim) {
  TORCH_INTERNAL_ASSERT(dim >= 0 && dim < ndim() && shape()[dim] >= 2);
  std::unique_ptr<TensorIterator> copy(new TensorIterator(*this));

  bool overlaps = is_dim_reduced(dim);
  auto copy_size = shape_[dim] / 2;
  auto this_size = shape_[dim] - copy_size;
  copy->narrow(dim, 0, copy_size);
  copy->final_output_ &= !overlaps;
  this->narrow(dim, copy_size, this_size);
  this->accumulate_ |= overlaps;

  return copy;
}
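
// Usage sketch (illustrative): backends that require 32-bit indexing
// recurse on split() until each piece fits, e.g.
//   if (!iter.can_use_32bit_indexing()) {
//     for (auto& sub_iter : iter.with_32bit_indexing()) { /* launch */ }
//   }
// where with_32bit_indexing() (declared in TensorIterator.h) repeatedly
// splits the dimension chosen by get_dim_to_split().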

int TensorIteratorBase::get_dim_to_split() const {
  TORCH_INTERNAL_ASSERT(ndim() >= 1);
  int64_t max_extent = -1;
  int dim_to_split = -1;
  for (int dim = ndim() - 1; dim >= 0; dim--) {
    const int64_t size = shape_[dim];
    if (size == 0) {
      continue;
    }
    for (auto& op : operands_) {
      // std::abs is necessary to handle some special cases where we support negative strides
      // see the CUDA backend of at::flip
      const int64_t extent = (size - 1) * std::abs(op.stride_bytes[dim]);
      if (extent > max_extent) {
        max_extent = extent;
        dim_to_split = dim;
      }
    }
  }
  TORCH_INTERNAL_ASSERT(max_extent >= 0);
  return dim_to_split;
}

1325 | bool TensorIteratorBase::fast_set_up(const TensorIteratorConfig& config) { |
1326 | // This function tries to do a fast setup to avoid needless reordering of dimensions and tracking output strides |
1327 | // Return true if it can do fast setup or false otherwise |
1328 | // TODO enable fast handling for reductions |
1329 | FastSetupType setup_type = compute_fast_setup_type(config); |
1330 | if (setup_type == FastSetupType::NONE) { |
1331 | return false; |
1332 | } |
1333 | |
1334 | // allocate memory for output, memory format depends on setup_type |
1335 | switch (setup_type) { |
1336 | case FastSetupType::CONTIGUOUS: |
1337 | { |
1338 | for (const auto i : c10::irange(num_outputs_)) { |
1339 | auto& op = operands_[i]; |
1340 | if (!op.tensor_base().defined()) { |
1341 | TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand" , i); |
1342 | } |
1343 | set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::Contiguous), names_); |
1344 | } |
1345 | break; |
1346 | } |
1347 | case FastSetupType::CHANNELS_LAST: |
1348 | { |
1349 | for (const auto i : c10::irange(num_outputs_)) { |
1350 | auto& op = operands_[i]; |
1351 | if (!op.tensor_base().defined()) { |
1352 | TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand" , i); |
1353 | } |
1354 | set_output_raw_strided(i, shape_, {}, original_options(op).memory_format(MemoryFormat::ChannelsLast), names_); |
1355 | } |
1356 | break; |
1357 | } |
1358 | case FastSetupType::NON_OVERLAPPING_DENSE: |
1359 | { |
1360 | // find the index of a defined tensor in operands_ start from input tensor |
1361 | int i_defined; // NOLINT(cppcoreguidelines-init-variables) |
1362 | for (i_defined = ntensors() - 1; i_defined >= 0; --i_defined) { |
1363 | if (tensor(i_defined).defined()) break; |
1364 | } |
1365 | TORCH_CHECK(i_defined >= 0, "Can not find a defined tensor when fast allocating memory to outputs" ); |
1366 | for (const auto i : c10::irange(num_outputs_)) { |
1367 | auto& op = operands_[i]; |
1368 | if (!op.tensor_base().defined()) { |
1369 | TORCH_INTERNAL_ASSERT(op.is_type_defined(), "no type for operand" , i); |
1370 | } |
1371 | set_output_raw_strided(i, shape_, tensor_base(i_defined).strides(), original_options(op), names_); |
1372 | } |
1373 | break; |
1374 | } |
1375 | default: |
1376 | TORCH_INTERNAL_ASSERT(false, "Unsupported fast setup type" , c10::to_string((int)setup_type)); |
1377 | } |
1378 | //coalescing dimensions consists of collapsing dimensions to 1 (we are limited to contiguous no-broadcast cases here) |
1379 | if (ndim() > 1){ |
1380 | has_coalesced_dimensions_ = true; |
1381 | } |
1382 | if (ndim() >= 1) { |
1383 | shape_[0] = numel(); |
1384 | shape_.resize(1); |
1385 | } |
1386 | for (auto& op : operands_ ) { |
1387 | auto element_size_in_bytes = op.tensor_base().element_size(); |
1388 | op.stride_bytes.resize(ndim()); |
1389 | if (ndim()>0) { |
1390 | op.stride_bytes[0] = element_size_in_bytes; |
1391 | } |
1392 | } |
1393 | return true; |
1394 | } |
1395 | |
1396 | FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorConfig& config) { |
1397 | if (is_reduction_ || !all_ops_same_shape_) { |
1398 | return FastSetupType::NONE; |
1399 | } |
1400 | |
1401 | // For linear iteration, only contiguous tensors can be coalesced |
1402 | // Fast setup of any other format requires changing iteration order |
1403 | if (enforce_linear_iteration_) { |
1404 | for (const auto& op : operands_) { |
1405 | if (op.tensor_base().defined() && !op.will_resize) { |
1406 | auto is_contiguous = op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous); |
1407 | if (!is_contiguous) { |
1408 | return FastSetupType::NONE; |
1409 | } |
1410 | } |
1411 | } |
1412 | return FastSetupType::CONTIGUOUS; |
1413 | } |
1414 | |
1415 | bool is_contiguous = true; |
1416 | bool is_channels_last = true; |
1417 | bool is_non_overlapping_and_dense = true; |
1418 | for (const auto& op : operands_) { |
1419 | if (op.tensor_base().defined() && !op.will_resize) { |
1420 | is_contiguous &= op.tensor_base().is_contiguous(at::MemoryFormat::Contiguous); |
1421 | is_channels_last &= op.tensor_base().is_contiguous(at::MemoryFormat::ChannelsLast); |
1422 | is_non_overlapping_and_dense &= op.tensor_base().is_non_overlapping_and_dense(); |
1423 | } |
1424 | } |
  // TODO: this causes ambiguous cases (e.g. NC11) to always be treated as contiguous
1426 | if (is_contiguous) { |
1427 | return FastSetupType::CONTIGUOUS; |
1428 | } |
1429 | if (is_channels_last) { |
1430 | return FastSetupType::CHANNELS_LAST; |
1431 | } |
1432 | if (is_non_overlapping_and_dense) { |
1433 | int64_t prev = -1; |
    // Fast setup is allowed only when all the defined tensors have the same shape and strides.
    // Iterate from the back so input tensors' strides are checked first, then output tensors'.
1436 | for (int64_t i = ntensors() - 1; i >= 0; --i) { |
1437 | const auto& op = operands_[i]; |
1438 | if (op.tensor_base().defined() && !op.will_resize) { |
1439 | if (prev < 0) { |
1440 | prev = i; |
1441 | continue; |
1442 | } |
1443 | if (!tensor_base(prev).strides().equals(op.tensor_base().strides())) { |
1444 | // [Note: stride check for non contiguous tensors in fast setup] |
          // We prevent 3 cases from doing fast setup here:
          // 1. input tensors have different strides.
          // 2. output tensors won't be resized and have different strides.
          // 3. input tensors have the same strides, but output tensors have different strides from the input tensors.
          //    We don't allow re-striding output tensors in this case since it is not compatible with
          //    numpy. The behavior in numpy is that if the output tensor has the same shape as the input
          //    tensor but different strides, the strides of the output tensor are preserved, so we do
          //    the same in TensorIterator.
1453 | return FastSetupType::NONE; |
1454 | } |
1455 | } |
1456 | } |
1457 | return FastSetupType::NON_OVERLAPPING_DENSE; |
1458 | } |
1459 | return FastSetupType::NONE; |
1460 | } |
1461 | |
1462 | TensorIteratorBase::TensorIteratorBase() = default; |
1463 | |
1464 | void TensorIteratorBase::build(TensorIteratorConfig& config) { |
1465 | // populate some persistent configuration fields |
1466 | is_reduction_ = config.is_reduction_; |
1467 | enforce_linear_iteration_ = config.enforce_linear_iteration_; |
1468 | |
1469 | // fill in operands_ based on configuration |
1470 | populate_operands(config); |
1471 | // set is_output and is_read_write flags on appropriate tensors |
1472 | mark_outputs(); |
1473 | // Check that the outputs have no internal overlap |
1474 | // and do not share memory with inputs. |
1475 | compute_mem_overlaps(config); |
1476 | // Check that input dimensions are aligned correctly & compute outnames. |
1477 | compute_names(config); |
1478 | // compute the broadcasted shape |
1479 | compute_shape(config); |
1480 | // mark outputs for resizing if necessary |
1481 | mark_resize_outputs(config); |
1482 | // compute the result dtype and device |
1483 | compute_types(config); |
  // try fast setup of the output tensors; if it fails, fall back to the normal setup path
1485 | if (!fast_set_up(config)) { |
1486 | // compute each tensor's stride after broadcasting |
1487 | compute_strides(config); |
1488 | // re-order dimensions to improve coalescing |
1489 | reorder_dimensions(); |
1490 | // allocate the output tensor if it's not provided |
1491 | allocate_or_resize_outputs(); |
1492 | // coalesce adjacent dimensions when possible |
1493 | if (!is_meta_) coalesce_dimensions(); |
1494 | } |
1495 | |
1496 | if (is_meta_) return; |
1497 | |
1498 | auto has_storage = true; |
1499 | for (auto& op : operands_) { |
1500 | has_storage &= op.tensor_base().has_storage(); |
1501 | } |
1502 | auto privateuse1_without_storage = |
1503 | common_device_.type() == DeviceType::PrivateUse1 && |
1504 | !has_storage; |
1505 | |
1506 | // XLA and lazy tensors don't have storage, so they don't have an underlying data pointer. |
1507 | // Nothing beyond this point is important for meta functions, so it's fine to exit early here. |
  // Extend the condition to ORT tensors, as they also don't have storage.
1509 | if (privateuse1_without_storage || |
1510 | common_device_.type() == DeviceType::XLA || |
1511 | common_device_.type() == DeviceType::IPU || |
1512 | common_device_.type() == DeviceType::Lazy || |
1513 | common_device_.type() == DeviceType::ORT || |
1514 | common_device_.type() == DeviceType::HPU) return; |
1515 | |
1516 | for (auto& op : operands_) { |
1517 | TORCH_INTERNAL_ASSERT(op.tensor_base().defined()); |
1518 | op.data = op.tensor_base().data_ptr(); |
1519 | } |
1520 | |
1521 | // zero out offsets |
  // If the iterator is zero-dim (a scalar), we still leave room for one offset
  // so that index translations in reductions can access a valid value.
1525 | int64_t ndim_offsets = (ndim() ? ndim() : 1); |
1526 | view_offsets_ = DimVector(ndim_offsets, 0); |
1527 | } |
1528 | |
// This is the structured kernels' implementation of set_output. It is
// NEVER actually called directly; instead, a subclass of TensorIteratorBase
// will override set_output to actually do the operation, and then call
// set_output on the TensorIteratorBase to set up TI's metadata.
// The precondition for this function is that maybe_get_output() now
// unconditionally returns a real Tensor (prior to output setting,
// maybe_get_output() may return an undefined tensor.)
1536 | void TensorIteratorBase::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) { |
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  auto& op = operands_[output_idx];
1539 | const auto& t = maybe_get_output(output_idx); |
1540 | TORCH_INTERNAL_ASSERT(t.defined()); |
1541 | if (!op.tensor_base().defined()) { |
1542 | op.tensor(c10::MaybeOwned<TensorBase>::borrowed(t)); |
1543 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(op.target_dtype == t.scalar_type()); |
1544 | } else if (op.will_resize) { |
1545 | if (op.original_tensor_base().defined()) { |
      // OK, so this is pretty weird. To understand how we can end up in
      // this situation, first look at Note [Output original_tensor is set].
1548 | // That is the sole site where original_tensor may be set on an |
1549 | // output operand. Essentially, when we are given an explicit output |
1550 | // tensor whose dtype doesn't match the computed common dtype from |
1551 | // the input operands, we do a switcheroo: we replace the (incorrectly |
1552 | // typed) output tensor with a correctly typed, *temporary* tensor, |
1553 | // and remember the original tensor in original_tensor (which will |
1554 | // then get written back to when we cast_outputs). |
1555 | // |
1556 | // Now, what if the given output tensor also happened to be zero |
1557 | // size (meaning that we will_resize it)? Well, at the call site |
1558 | // above, we don't necessarily(*) know what the correct shape should |
1559 | // be, so we give the temporary tensor the same shape as the original. |
      // set_output is the time when we DO know what the correct size is,
      // and the subclass's implementation of set_output in the structured
      // class is responsible for resizing original_tensor. But we still have this
1563 | // incorrectly sized temporary output which the structured subclass |
1564 | // knows nothing about, so we are obligated to also resize it here. |
1565 | // |
1566 | // This is a slight memory pessimization, because previously |
1567 | // original_tensor only got resized at the end of the computation, rather |
1568 | // than at the beginning (as happens here). However, the peak memory |
1569 | // usage is the same, since you need to materialize both original tensor |
1570 | // and temporary tensor to do the copy. |
1571 | // |
1572 | // (*) Actually, technically, we probably do know what the shape |
1573 | // should be, since we do shape computation before dtype computation. |
1574 | // So hypothetically we could figure out what the correct shape is |
1575 | // at that point in time and directly allocate the temporary at |
1576 | // the right size. |
1577 | // |
1578 | // But a better solution is to delay allocation of temporaries until |
1579 | // after TensorIterator builder, waiting until we actually want |
1580 | // to do the computation. That would also remove the necessity |
1581 | // for the is_meta_ test. |
1582 | TORCH_INTERNAL_ASSERT(op.original_tensor_base().is_same(t)); |
1583 | TORCH_INTERNAL_ASSERT(!op.tensor_base().is_same(t)); |
1584 | OptionalTensorRef tensor(op.tensor()); |
1585 | at::native::resize_output(*tensor, sizes); |
1586 | if (!strides.empty()) { |
1587 | TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); |
1588 | tensor->as_strided_(sizes, strides); |
1589 | } else if (options.memory_format_opt().has_value()) { |
1590 | tensor->unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); |
1591 | } |
1592 | } |
1593 | } |
1594 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
1595 | op.tensor_base().is_same(t) || op.current_dtype == op.tensor_base().scalar_type()); |
1596 | // For simplicity, just always update the cached current_type. |
1597 | op.current_dtype = op.tensor_base().scalar_type(); |
1598 | } |
1599 | |
1600 | // This is the "traditional" implementation of set_output. On TensorIterator |
1601 | // instances, it is invoked directly from various call sites in this file. No |
1602 | // funny business. |
1603 | void TensorIterator::set_output_raw_strided(int64_t output_idx, IntArrayRef sizes, IntArrayRef strides, TensorOptions options, DimnameList names) { |
1604 | // NB: intentionally no superclass call |
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(output_idx < num_outputs_);
  auto& op = operands_[output_idx];
1607 | if (!op.tensor_base().defined()) { |
1608 | if (strides.empty()) { |
1609 | op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty(sizes, options))); |
1610 | } else { |
1611 | op.tensor(c10::MaybeOwned<TensorBase>::owned(at::empty_strided(sizes, strides, options))); |
1612 | } |
1613 | op.current_dtype = op.target_dtype; |
1614 | } else if (op.will_resize) { |
1615 | at::native::resize_output(op.tensor(), sizes); |
1616 | if (!strides.empty()) { |
1617 | TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value()); |
1618 | op.tensor().as_strided_(sizes, strides); |
1619 | } else if (options.memory_format_opt().has_value()) { |
1620 | op.tensor_base().unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt()); |
1621 | } |
1622 | } |
1623 | if (!names.empty()) { |
1624 | TORCH_INTERNAL_ASSERT(op.tensor_base().defined()); |
1625 | namedinference::propagate_names(op.tensor_base(), names); |
1626 | } |
1627 | } |
1628 | |
1629 | // Not actually used by anything (TensorIterator subclass calls |
1630 | // its own implementation of set_output which knows exactly where |
1631 | // all the outputs are), but we have to provide all pure virtual methods |
1632 | // for MetaBase |
1633 | const Tensor& TensorIterator::maybe_get_output(int64_t output_idx) { |
1634 | return output(output_idx); |
1635 | } |
1636 | |
1637 | SplitUntil32Bit TensorIteratorBase::with_32bit_indexing() const { |
1638 | return SplitUntil32Bit(*this); |
1639 | } |
1640 | |
1641 | /// SplitUntil32Bit. Recursively splits an iterator into sub-iterators that |
1642 | /// can use 32-bit indexing. |
1643 | |
1644 | SplitUntil32Bit::iterator::iterator(const TensorIteratorBase& iter) { |
1645 | vec.emplace_back(new TensorIterator(iter)); |
1646 | vec.emplace_back(nullptr); // ++ first pops the last element |
1647 | ++(*this); |
1648 | } |
1649 | |
1650 | SplitUntil32Bit::iterator& SplitUntil32Bit::iterator::operator++() { |
1651 | vec.pop_back(); |
1652 | while (!vec.empty() && !vec.back()->can_use_32bit_indexing()) { |
1653 | auto& iter = *vec.back(); |
1654 | int64_t split_dim = iter.get_dim_to_split(); |
1655 | vec.emplace_back(iter.split(split_dim)); |
1656 | } |
1657 | return *this; |
1658 | } |
1659 | |
1660 | TensorIterator& SplitUntil32Bit::iterator::operator*() const { |
1661 | return *vec.back(); |
1662 | } |
1663 | |
1664 | SplitUntil32Bit::iterator SplitUntil32Bit::begin() const { |
1665 | return SplitUntil32Bit::iterator(iter); |
1666 | } |
1667 | |
1668 | SplitUntil32Bit::iterator SplitUntil32Bit::end() const { |
1669 | return SplitUntil32Bit::iterator(); |
1670 | } |
1671 | |
1672 | DimCounter::DimCounter(IntArrayRef shape, Range range) |
1673 | : shape(shape) |
1674 | , range(range) |
1675 | , values(shape.size()) |
1676 | , offset(range.begin) { |
1677 | std::fill(values.begin(), values.end(), 0); |
1678 | if (range.begin == 0) { |
1679 | return; |
1680 | } |
1681 | |
1682 | int64_t linear_offset = range.begin; |
1683 | int64_t ndim = values.size(); |
1684 | for (const auto dim : c10::irange(ndim)) { |
1685 | int64_t size = shape[dim]; |
1686 | if (size > 0) { |
1687 | values[dim] = linear_offset % size; |
1688 | linear_offset /= size; |
1689 | } |
1690 | } |
1691 | TORCH_INTERNAL_ASSERT(linear_offset == 0); |
1692 | } |
1693 | |
1694 | bool DimCounter::is_done() const { |
1695 | return offset >= range.end; |
1696 | } |
1697 | |
1698 | void DimCounter::increment(const std::array<int64_t, 2>& step) { |
1699 | offset += step[0] * step[1]; |
1700 | int64_t ndim = values.size(); |
1701 | int64_t overflow = step[0]; |
1702 | int i = 0; |
1703 | if (step[1] != 1) { |
1704 | TORCH_INTERNAL_ASSERT(step[0] == shape[0] && values[0] == 0); |
1705 | i = 1; |
1706 | overflow = step[1]; |
1707 | } |
1708 | for (; i < ndim && overflow > 0; i++) { |
1709 | auto size = shape[i]; |
1710 | auto prev = values[i]; |
1711 | auto value = prev + overflow; |
1712 | if (value >= size) { |
1713 | overflow = 1; |
1714 | value -= size; |
1715 | TORCH_INTERNAL_ASSERT(value < size); |
1716 | } else { |
1717 | overflow = 0; |
1718 | } |
1719 | values[i] = value; |
1720 | } |
1721 | TORCH_INTERNAL_ASSERT(overflow == 0 || overflow == 1); |
1722 | } |
1723 | |
1724 | std::array<int64_t, 2> DimCounter::max_2d_step() const { |
1725 | int64_t step0 = std::min(shape[0] - values[0], range.end - offset); |
1726 | int64_t step1 = 1; |
  if (step0 == shape[0] && shape.size() >= 2) {
1728 | step1 = std::min(shape[1] - values[1], (range.end - offset) / shape[0]); |
1729 | } |
1730 | return {step0, step1}; |
1731 | } |
1732 | |
1733 | } // namespace at |
1734 | |