1/**
2 * This is a handwritten file that accompanies codegenerated header
3 * LazyShapeDtype.h
4 *
 * The purpose of these shape/dtype inference methods is to fill gaps
6 * where we do not yet have structured kernels in pytorch core. Ops
7 * for which there _are_ structured kernels can use meta::op() to infer
8 * shape/dtype, and codegen makes use of this. Ops for which there are not
9 * yet structured kernels can still be used with lazy_tensor codegen, but
10 * require manual intervention to implement compute_shape_{op} and
11 * compute_dtype_{op}.
12 *
13 * READ THIS!
14 *
15 * 1. Beware: Tech Debt!
16 * ---------------------
17 * These functions are tech debt. We want to delete them all and use structured
18 * kernels instead, but it's a lot faster to write these so we're decoupling the
19 * two efforts to move fast for adding support for codegenned Lazy Tensor ops.
20 *
21 * Codegenned Lazy Tensor ops with handwritten shape formulae are still better
22 * than fully handwritten Lazy Tensor ops (which also have handwritten shape
23 * formulae).
24 *
25 * 2. Structured Kernels For The Win
26 * ---------------------------------
27 * Long term, more and more ops should be supported as 'structured kernels'.
28 * Consider doing your part and porting an op. As ops get ported over, the
29 * codegen will automatically notice and stop generating declarations for these
30 * shape formulae, so we'll need to manually clean up the unused functions in
31 * this file, or somehow automate that.
32 *
33 * https://dev-discuss.pytorch.org/t/slides-from-structured-kernel-presentation/179
34 *
35 * 3. How to figure out the shape/dtype
36 * ------------------------------------
 * Unfortunately there isn't a one-stop-shop for learning the output shape
38 * formulae for all operators. This is partly because some operators are not
39 * part of our 'public' API, including backward operators which users don't
40 * directly invoke.
41 *
 * Check our symbolic shape registry:
43 * https://github.com/pytorch/pytorch/blob/13b859983183ea9938deb5030ac9a0747841f0a8/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
44 *
45 * Read the manual (for ops that are 1:1 with python frontend):
46 * https://pytorch.org/docs/stable/generated/torch.trace.html
47 *
48 */
49
50#include <torch/csrc/lazy/core/shape_inference.h>
51
52#include <ATen/AccumulateType.h>
53#include <ATen/CompositeExplicitAutogradFunctions.h>
54#include <ATen/Dispatch.h>
55#include <ATen/ExpandUtils.h>
56#include <ATen/Functions.h>
57#include <ATen/InferSize.h>
58#include <ATen/NativeFunctions.h>
59#include <ATen/WrapDimUtils.h>
60#include <ATen/native/ConvUtils.h>
61#include <ATen/native/ReduceOpsUtils.h>
62#include <ATen/native/TensorConversions.h>
63#include <c10/core/ScalarType.h>
64#include <torch/csrc/api/include/torch/enum.h>
65#include <torch/csrc/lazy/core/ops/utils.h>
66#include <torch/csrc/lazy/core/shape.h>
67#include <torch/csrc/lazy/core/util.h>
68#include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
69#include <ostream>
70#include <vector>
71
72namespace torch {
73namespace lazy {
74
75// Copied from ATen/native/utils/ParamUtils.h, which aparently I can't include
76// from here?
77std::vector<int64_t> expand_param_if_needed(
78 at::IntArrayRef list_param,
79 const char* param_name,
80 int64_t expected_dim) {
81 if (list_param.size() == 1) {
82 return std::vector<int64_t>(expected_dim, list_param[0]);
83 } else if ((int64_t)list_param.size() != expected_dim) {
84 std::ostringstream ss;
85 ss << "expected " << param_name << " to be a single integer value or a "
86 << "list of " << expected_dim << " values to match the convolution "
87 << "dimensions, but got " << param_name << "=" << list_param;
88 AT_ERROR(ss.str());
89 } else {
90 return list_param.vec();
91 }
92}
93
94// It seems more common to not use parameters than to use them, so disable
95// unused-parameter warning
96#pragma GCC diagnostic push
97#pragma GCC diagnostic ignored "-Wunused-parameter"
98
99TORCH_API std::vector<Shape> compute_shape_arange_out(
100 const at::Scalar& start,
101 const at::Scalar& end,
102 const at::Scalar& step,
103 at::Tensor& out) {
104 double size_d = 0;
105 // shape inference code copied from RangeFactories.cpp arange_out function
106 // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct
107 // c++ scalar_t type depending on out tensor
108
109 AT_DISPATCH_ALL_TYPES_AND(
110 c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out", [&]() {
111 // Note: acc_type further defines an accumulataion type depending on the
112 // scalar_t and whether its on cuda vs cpu.
113 using accscalar_t = at::acc_type<scalar_t, false>;
114 auto xstart = start.to<accscalar_t>();
115 auto xend = end.to<accscalar_t>();
116 auto xstep = step.to<accscalar_t>();
117
118 // we use double precision for (start - end) / step
119 // to compute size_d for consistency across devices.
120 // The problem with using accscalar_t is that accscalar_t might be
121 // float32 on gpu for a float32 scalar_t, but double on cpu for the
122 // same, and the effective output size starts differing on CPU vs GPU
123 // because of precision issues, which we dont want. the corner-case we
124 // do want to take into account is int64_t, which has higher precision
125 // than double NOLINTNEXTLINE(bugprone-branch-clone)
126 if (std::is_same<scalar_t, int64_t>::value) {
127 size_d = std::ceil(
128 static_cast<double>(
129 end.to<accscalar_t>() - start.to<accscalar_t>()) /
130 step.to<accscalar_t>());
131 } else {
132 size_d = std::ceil(
133 static_cast<double>(end.to<double>() - start.to<double>()) /
134 step.to<double>());
135 }
136
137 TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero");
138 TORCH_CHECK(
139 std::isfinite(static_cast<double>(xstart)) &&
140 std::isfinite(static_cast<double>(xend)),
141 "unsupported range: ",
142 xstart,
143 " -> ",
144 xend);
145 TORCH_CHECK(
146 ((xstep > 0) && (xend >= xstart)) ||
147 ((xstep < 0) && (xend <= xstart)),
148 "upper bound and larger bound inconsistent with step sign");
149
150 TORCH_CHECK(
151 size_d >= 0 &&
152 size_d <=
153 static_cast<double>(std::numeric_limits<int64_t>::max()),
154 "invalid size, possible overflow?");
155 });
156
157 int64_t size = static_cast<int64_t>(size_d);
158
159 // From torch.arange docs:
160 // dtype (torch.dtype, optional) – the desired data type of returned tensor.
161 // Default: if None, uses a global default (see
162 // torch.set_default_tensor_type()). If dtype is not given, infer the data
163 // type from the other input arguments. If any of start, end, or stop are
164 // floating-point, the dtype is inferred to be the default dtype, see
165 // get_default_dtype(). Otherwise, the dtype is inferred to be torch.int64.
166
167 return {Shape(out.scalar_type(), {size})};
168}
169
170std::vector<Shape> compute_shape_abs(const at::Tensor& self) {
171 if (self.is_complex()) {
172 const auto float_type = c10::toRealValueType(self.scalar_type());
173 return {Shape(float_type, self.sizes().vec())};
174 }
175 return {Shape(self.scalar_type(), self.sizes().vec())};
176}
177
178std::vector<Shape> compute_shape_bernoulli(
179 const at::Tensor& self,
180 c10::optional<at::Generator> generator) {
181 return {Shape(self.scalar_type(), self.sizes().vec())};
182}
183
184std::vector<Shape> compute_shape_bernoulli(
185 const at::Tensor& self,
186 double p,
187 c10::optional<at::Generator> generator) {
188 return compute_shape_bernoulli(self, generator);
189}
190
191std::vector<Shape> compute_shape_binary_cross_entropy(
192 const at::Tensor& self,
193 const at::Tensor& target,
194 const c10::optional<at::Tensor>& weight,
195 int64_t reduction) {
196 if (reduction == at::Reduction::None) {
197 return {Shape(self.scalar_type(), self.sizes().vec())};
198 }
199 return {Shape(self.scalar_type(), {})};
200}
201
202std::vector<Shape> compute_shape_binary_cross_entropy_backward(
203 const at::Tensor& grad_output,
204 const at::Tensor& self,
205 const at::Tensor& target,
206 const c10::optional<at::Tensor>& weight,
207 int64_t reduction) {
208 return {Shape(self.scalar_type(), self.sizes().vec())};
209}
210
211std::vector<Shape> compute_shape_constant_pad_nd(
212 const at::Tensor& self,
213 at::IntArrayRef pad,
214 const at::Scalar& value) {
215 // Based on aten/src/ATen/native/ConstantPadNd.cpp::constant_pad_nd
216 TORCH_CHECK(
217 pad.size() % 2 == 0,
218 "Length of pad must be even but instead it equals ",
219 pad.size());
220
221 auto input_sizes = self.sizes();
222 auto l_inp = self.dim();
223
224 auto l_pad = pad.size() / 2;
225 auto l_diff = l_inp - l_pad;
226 TORCH_CHECK(
227 l_inp >= (int64_t)l_pad,
228 "Length of pad should be no more than twice the number of "
229 "dimensions of the input. Pad length is ",
230 pad.size(),
231 "while the input has ",
232 l_inp,
233 "dimensions.");
234
235 std::vector<int64_t> new_shape;
236 for (size_t i = 0; i < (size_t)l_diff; i++) {
237 new_shape.emplace_back(input_sizes[i]);
238 }
239
240 for (const auto i : c10::irange((size_t)l_pad)) {
241 auto pad_idx = pad.size() - ((i + 1) * 2);
242 auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1];
243 TORCH_CHECK(
244 new_dim > 0,
245 "The input size ",
246 input_sizes[l_diff + i],
247 ", plus negative padding ",
248 pad[pad_idx],
249 " and ",
250 pad[pad_idx + 1],
251 " resulted in a negative output size, "
252 "which is invalid. Check dimension ",
253 l_diff + i,
254 " of your input.");
255 new_shape.emplace_back(new_dim);
256 }
257 return {Shape(self.scalar_type(), new_shape)};
258}
259
260std::vector<Shape> compute_shape_convolution_backward(
261 const at::Tensor& grad_output,
262 const at::Tensor& input,
263 const at::Tensor& weight,
264 at::OptionalIntArrayRef bias_sizes,
265 at::IntArrayRef stride,
266 at::IntArrayRef padding,
267 at::IntArrayRef dilation,
268 bool transposed,
269 at::IntArrayRef output_padding,
270 int64_t groups,
271 ::std::array<bool, 3> output_mask) {
272 if (bias_sizes.has_value()) {
273 return {
274 Shape(input.scalar_type(), input.sizes().vec()),
275 Shape(weight.scalar_type(), weight.sizes().vec()),
276 Shape(grad_output.scalar_type(), bias_sizes.value().vec())};
277 } else {
278 // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is
279 // empty
280 return {
281 Shape(input.scalar_type(), input.sizes().vec()),
282 Shape(weight.scalar_type(), weight.sizes().vec())};
283 }
284}
285
286std::vector<Shape> compute_shape_convolution(
287 const at::Tensor& input,
288 const at::Tensor& weight,
289 const c10::optional<at::Tensor>& bias,
290 at::IntArrayRef stride,
291 at::IntArrayRef padding,
292 at::IntArrayRef dilation,
293 bool transposed,
294 at::IntArrayRef output_padding,
295 int64_t groups) {
296 int64_t dim = weight.ndimension() - 2;
297 TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
298
299 // at::convolution performs parameter expansion before running kernels on
300 // expanded parameters we must do the same. Shape formulae access differnent
301 // dimensions of e.g. output_padding, but output_padding may be passed in as a
302 // scalar. Sadly, accessing output_padding[1] in this case gives incorrect
303 // results rather than indexing error
304 auto expanded_stride = expand_param_if_needed(stride, "stride", dim);
305 auto expanded_padding = expand_param_if_needed(padding, "padding", dim);
306 auto expanded_dilation = expand_param_if_needed(dilation, "dilation", dim);
307 if (!transposed) {
308 return {Shape(
309 input.scalar_type(),
310 at::native::conv_output_size(
311 input.sizes(),
312 weight.sizes(),
313 expanded_padding,
314 expanded_stride,
315 expanded_dilation))};
316 } else {
317 auto expanded_output_padding =
318 expand_param_if_needed(output_padding, "output_padding", dim);
319 auto out_shape = at::native::conv_input_size(
320 input.sizes(),
321 weight.sizes(),
322 expanded_padding,
323 expanded_output_padding,
324 expanded_stride,
325 expanded_dilation,
326 groups);
327 return {Shape(input.scalar_type(), out_shape)};
328 }
329}
330
331std::vector<Shape> compute_shape_masked_fill(
332 const at::Tensor& self,
333 const at::Tensor& mask,
334 const at::Scalar& value) {
335 return {Shape(self.scalar_type(), self.sizes().vec())};
336}
337
338std::vector<Shape> compute_shape_masked_fill(
339 const at::Tensor& self,
340 const at::Tensor& mask,
341 const at::Tensor& value) {
342 return {Shape(self.scalar_type(), self.sizes().vec())};
343}
344
345std::vector<Shape> compute_shape_max(const at::Tensor& self) {
346 TORCH_CHECK(
347 self.numel() > 0,
348 "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");
349 return {Shape(self.scalar_type(), {})};
350}
351
352std::vector<Shape> compute_shape_min(const at::Tensor& self) {
353 TORCH_CHECK(
354 self.numel() > 0,
355 "min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument.");
356 return {Shape(self.scalar_type(), {})};
357}
358
359std::vector<Shape> compute_shape_nonzero(const at::Tensor& t, bool as_tuple) {
360 if (as_tuple) {
361 auto res = std::vector<Shape>();
362 for (auto dim_size : t.sizes()) {
363 res.emplace_back(Shape(at::kLong, {dim_size}));
364 }
365 return res;
366 }
367 int64_t max_elements = 1;
368 for (auto dim_size : t.sizes()) {
369 max_elements *= dim_size;
370 }
371 return {Shape(at::kLong, {max_elements, (int64_t)t.sizes().size()})};
372}
373
374std::vector<Shape> compute_shape_nonzero(const at::Tensor& self) {
375 return compute_shape_nonzero(self, false);
376}
377
378std::vector<Shape> compute_shape_embedding(
379 const at::Tensor& weight,
380 const at::Tensor& indices,
381 int64_t padding_idx,
382 bool scale_grad_by_freq,
383 bool sparse) {
384 // Based on aten/src/ATen/native/Embedding.cpp::embedding.
385 std::vector<int64_t> out_sizes = indices.sizes().vec();
386 out_sizes.emplace_back(weight.size(1));
387 return {Shape(weight.scalar_type(), out_sizes)};
388}
389
390std::vector<Shape> compute_shape_std(const at::Tensor& self, bool unbiased) {
391 return compute_shape_std(self, c10::nullopt, c10::nullopt, false);
392}
393std::vector<Shape> compute_shape_std(
394 const at::Tensor& self,
395 at::OptionalIntArrayRef dim,
396 bool unbiased,
397 bool keepdim) {
398 return compute_shape_std(self, dim, c10::nullopt, keepdim);
399}
400std::vector<Shape> compute_shape_std(
401 const at::Tensor& self,
402 at::OptionalIntArrayRef dim,
403 c10::optional<int64_t> correction,
404 bool keepdim) {
405 if (dim.has_value()) {
406 auto shape = at::native::shape_from_dim_mask(
407 self, at::native::make_dim_mask(dim.value(), self.dim()), keepdim);
408 return {Shape(
409 self.scalar_type(), std::vector<int64_t>(shape.begin(), shape.end()))};
410 }
411 return {Shape(self.scalar_type(), {})};
412}
413
414std::vector<Shape> compute_shape_embedding_dense_backward(
415 const at::Tensor& grad_output,
416 const at::Tensor& indices,
417 int64_t num_weights,
418 int64_t padding_idx,
419 bool scale_grad_by_freq) {
420 // Based on aten/src/ATen/native/Embedding.cpp::embedding_dense_backward_cpu.
421 return {
422 Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})};
423}
424
425std::vector<Shape> compute_shape_expand(
426 const at::Tensor& self,
427 at::IntArrayRef size,
428 bool implicit) {
429 TORCH_CHECK_GE(size.size(), self.dim());
430 int64_t num_new_dimensions = size.size() - self.dim();
431 std::vector<int64_t> padded_self(num_new_dimensions, 0);
432 padded_self.insert(
433 padded_self.end(), self.sizes().begin(), self.sizes().end());
434 std::vector<int64_t> target_size(size.size());
435 for (const auto idx : c10::irange(size.size())) {
436 target_size[idx] = size[idx] == -1 ? padded_self[idx] : size[idx];
437 }
438 return {Shape(self.scalar_type(), target_size)};
439}
440
441std::vector<Shape> compute_shape_expand(
442 const at::Tensor& self,
443 c10::SymIntArrayRef size,
444 bool implicit) {
445 TORCH_CHECK_GE(size.size(), self.dim());
446 std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size);
447 int64_t num_new_dimensions = _sizes.size() - self.dim();
448 std::vector<int64_t> padded_self(num_new_dimensions, 0);
449 padded_self.insert(
450 padded_self.end(), self.sizes().begin(), self.sizes().end());
451 std::vector<int64_t> target_size(_sizes.size());
452 for (const auto idx : c10::irange(_sizes.size())) {
453 if (_sizes[idx].is_symbolic()) {
454 c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl();
455 auto* lazySymNode =
456 dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get());
457 TORCH_INTERNAL_ASSERT(lazySymNode);
458 auto size_node = lazySymNode->node_;
459 auto static_value =
460 std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node)
461 ->getStaticValue();
462 target_size[idx] = static_value;
463 } else {
464 target_size[idx] = _sizes[idx].as_int_unchecked();
465 if (_sizes[idx].as_int_unchecked() == -1) {
466 // -1 can't be specified for non-existing dimensions
467 TORCH_CHECK(idx >= num_new_dimensions);
468 target_size[idx] = padded_self[idx];
469 } else {
470 target_size[idx] = _sizes[idx].as_int_unchecked();
471 }
472 }
473 }
474 return {Shape(self.scalar_type(), target_size)};
475}
476
477std::vector<Shape> compute_shape_index_select(
478 const at::Tensor& self,
479 int64_t dim,
480 const at::Tensor& index) {
481 // Based on definition of
482 // https://pytorch.org/docs/stable/generated/torch.index_select.html. Promote
483 // Rank 0 index tensor to a 1 * 1 tensor.
484 dim = at::maybe_wrap_dim(dim, self);
485 auto index_dim = index.dim() > 0 ? index.dim() : 1;
486 auto index_size = index.dim() > 0 ? index.size(0) : 1;
487 TORCH_CHECK(index_dim == 1);
488
489 auto self_sizes = self.sizes();
490 std::vector<int64_t> output_sizes(self_sizes.begin(), self_sizes.end());
491 TORCH_CHECK(!output_sizes.empty(), "Empty output_sizes is not supported.");
492 output_sizes[dim] = index_size;
493
494 return {Shape(self.scalar_type(), output_sizes)};
495}
496
497std::vector<Shape> compute_shape_inverse(const at::Tensor& self) {
498 return {Shape(self.scalar_type(), self.sizes().vec())};
499}
500
501std::vector<Shape> compute_shape_isnan(const at::Tensor& self) {
502 return {Shape(c10::ScalarType::Bool, self.sizes().vec())};
503}
504
505std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) {
506 // TODO(whc) support cat in codegen and move this to compute_*_cat functions
507 std::vector<int64_t> out_shape(
508 tensors[0].sizes().begin(), tensors[0].sizes().end());
509
510 dim = at::maybe_wrap_dim(dim, tensors);
511 size_t extended_dim_shape = 0;
512 for (auto& tensor : tensors) {
513 extended_dim_shape += tensor.sizes()[dim];
514 }
515 TORCH_CHECK(!out_shape.empty(), "Scalar tensors are not supported in cat.");
516 TORCH_CHECK(
517 extended_dim_shape <= std::numeric_limits<int64_t>::max(),
518 "Size overflow");
519 out_shape[dim] = extended_dim_shape;
520 return {Shape(tensors[0].scalar_type(), out_shape)};
521}
522
523TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky(
524 const at::Tensor& self,
525 bool upper) {
526 return {Shape(self.scalar_type(), self.sizes().vec())};
527}
528
529std::vector<torch::lazy::Shape> compute_shape_native_batch_norm(
530 const at::Tensor& input,
531 const c10::optional<at::Tensor>& weight,
532 const c10::optional<at::Tensor>& bias,
533 const c10::optional<at::Tensor>& running_mean,
534 const c10::optional<at::Tensor>& running_var,
535 bool training,
536 double momentum,
537 double eps) {
538 std::vector<torch::lazy::Shape> shapes;
539 shapes.reserve(3);
540 shapes.emplace_back(input.scalar_type(), input.sizes().vec());
541
542 // A separate mean and var needs to be kept for each channel.
543 TORCH_CHECK(
544 input.sizes().size() >= 2,
545 "Input tensor must have at least batch and channel dimensions!");
546 int64_t num_features = input.size(1);
547
548 if (running_mean.has_value()) {
549 shapes.emplace_back(
550 running_mean.value().scalar_type(), running_mean.value().sizes().vec());
551 } else {
552 shapes.emplace_back(
553 at::get_default_dtype_as_scalartype(),
554 std::vector<int64_t>{num_features});
555 }
556
557 if (running_var.has_value()) {
558 shapes.emplace_back(
559 running_var.value().scalar_type(), running_var.value().sizes().vec());
560 } else {
561 shapes.emplace_back(
562 at::get_default_dtype_as_scalartype(),
563 std::vector<int64_t>{num_features});
564 }
565 return shapes;
566}
567
568std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward(
569 const at::Tensor& grad_out,
570 const at::Tensor& input,
571 const c10::optional<at::Tensor>& weight,
572 const c10::optional<at::Tensor>& running_mean,
573 const c10::optional<at::Tensor>& running_var,
574 const c10::optional<at::Tensor>& save_mean,
575 const c10::optional<at::Tensor>& save_invstd,
576 bool train,
577 double eps,
578 ::std::array<bool, 3> output_mask) {
579 std::vector<torch::lazy::Shape> shapes;
580 shapes.reserve(3);
581 shapes.emplace_back(input.scalar_type(), input.sizes().vec());
582
583 // A separate mean and var needs to be kept for each channel.
584 TORCH_CHECK(
585 input.sizes().size() >= 2,
586 "Input tensor must have at least batch and channel dimensions!");
587 int64_t num_features = input.size(1);
588
589 // `weight` and `bias` are vectors of length C (number of channels)`
590 shapes.emplace_back(
591 at::get_default_dtype_as_scalartype(),
592 std::vector<int64_t>{num_features});
593 shapes.emplace_back(
594 at::get_default_dtype_as_scalartype(),
595 std::vector<int64_t>{num_features});
596
597 return shapes;
598}
599
600std::vector<Shape> compute_shape_native_layer_norm(
601 const at::Tensor& input,
602 at::IntArrayRef normalized_shape,
603 const c10::optional<at::Tensor>& weight,
604 const c10::optional<at::Tensor>& bias,
605 double eps) {
606 // Copied from aten/src/ATen/native/layer_norm.cpp::layer_norm_cpu_out.
607 auto input_shape = input.sizes().vec();
608 const size_t axis = input.dim() - normalized_shape.size();
609
610 std::vector<int64_t> stat_shape;
611 for (const auto idx : c10::irange(axis)) {
612 TORCH_CHECK(idx < input_shape.size(), "Shape mismatch");
613 stat_shape.emplace_back(input_shape[idx]);
614 }
615 for (const auto idx : c10::irange(axis, input.dim())) {
616 (void)idx; // Suppress unused variable warning
617 stat_shape.emplace_back(1);
618 }
619
620 return {
621 Shape(input.scalar_type(), input_shape),
622 Shape(input.scalar_type(), stat_shape),
623 Shape(input.scalar_type(), stat_shape)};
624}
625
626std::vector<Shape> compute_shape_native_layer_norm_backward(
627 const at::Tensor& grad_out,
628 const at::Tensor& input,
629 at::IntArrayRef normalized_shape,
630 const at::Tensor& mean,
631 const at::Tensor& rstd,
632 const c10::optional<at::Tensor>& weight,
633 const c10::optional<at::Tensor>& bias,
634 ::std::array<bool, 3> output_mask) {
635 std::vector<Shape> shapes;
636 shapes.emplace_back(
637 input.scalar_type(),
638 output_mask[0] ? input.sizes().vec() : std::vector<int64_t>{});
639 shapes.emplace_back(
640 weight && weight->defined() ? weight->scalar_type() : input.scalar_type(),
641 output_mask[1] && weight ? weight->sizes().vec()
642 : std::vector<int64_t>{});
643 shapes.emplace_back(
644 bias && weight->defined() ? bias->scalar_type() : input.scalar_type(),
645 output_mask[2] && bias ? bias->sizes().vec() : std::vector<int64_t>{});
646 return shapes;
647}
648
649std::vector<Shape> compute_shape_mean(
650 const at::Tensor& self,
651 c10::optional<at::ScalarType> dtype) {
652 if (dtype.has_value()) {
653 return {Shape(dtype.value(), {})};
654 }
655 return {Shape(self.scalar_type(), {})};
656}
657
658std::vector<Shape> compute_shape_new_empty_strided(
659 const at::Tensor& self,
660 at::IntArrayRef size,
661 at::IntArrayRef stride,
662 c10::optional<at::ScalarType> dtype,
663 c10::optional<at::Layout> layout,
664 c10::optional<at::Device> device,
665 c10::optional<bool> pin_memory) {
666 return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())};
667}
668
669std::vector<Shape> compute_shape_mv(
670 const at::Tensor& self,
671 const at::Tensor& vec) {
672 return {Shape(self.scalar_type(), {self.size(0)})};
673}
674
675std::vector<Shape> compute_shape_native_dropout(
676 const at::Tensor& input,
677 double p,
678 c10::optional<bool> train) {
679 return {
680 Shape(input.scalar_type(), input.sizes().vec()),
681 Shape(c10::ScalarType::Bool, input.sizes().vec())};
682}
683
684std::vector<Shape> compute_shape_native_dropout_backward(
685 const at::Tensor& grad_output,
686 const at::Tensor& mask,
687 double scale) {
688 return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
689}
690
691std::vector<Shape> compute_shape_random(
692 const at::Tensor& self,
693 c10::optional<at::Generator> generator) {
694 return {Shape(self.scalar_type(), self.sizes().vec())};
695}
696
697std::vector<Shape> compute_shape_random(
698 const at::Tensor& self,
699 int64_t to,
700 c10::optional<at::Generator> generator) {
701 return compute_shape_random(self, generator);
702}
703
704std::vector<Shape> compute_shape_random(
705 const at::Tensor& self,
706 int64_t from,
707 c10::optional<int64_t> to,
708 c10::optional<at::Generator> generator) {
709 return compute_shape_random(self, generator);
710}
711
712std::vector<Shape> compute_shape_relu(const at::Tensor& self) {
713 return {Shape(self.scalar_type(), self.sizes().vec())};
714}
715
716std::vector<Shape> compute_shape_bitwise_and(
717 const at::Tensor& self,
718 const at::Scalar& other) {
719 return {Shape(self.scalar_type(), self.sizes().vec())};
720}
721
722std::vector<Shape> compute_shape_sum(
723 const at::Tensor& self,
724 c10::optional<at::ScalarType> dtype) {
725 if (dtype.has_value()) {
726 return {Shape(dtype.value(), {})};
727 }
728 // It's undocumented, but torch::sum promotes all integral types to int64_t by
729 // default
730 if (isIntegralType(self.scalar_type(), /*includeBool*/ true)) {
731 return {Shape(c10::ScalarType::Long, {})};
732 }
733 return {Shape(self.scalar_type(), {})};
734 ;
735}
736
737std::vector<Shape> compute_shape_zero(const at::Tensor& self) {
738 return {Shape(self.scalar_type(), self.sizes().vec())};
739}
740
741TORCH_API std::vector<torch::lazy::Shape> compute_shape_take(
742 const at::Tensor& self,
743 const at::Tensor& index) {
744 return {Shape(self.scalar_type(), index.sizes().vec())};
745}
746
747std::vector<Shape> compute_shape_trace(const at::Tensor& self) {
748 return {Shape(self.scalar_type(), {})};
749}
750
751std::vector<Shape> compute_shape_sort(
752 const at::Tensor& self,
753 int64_t dim,
754 bool descending) {
755 return {
756 Shape(self.scalar_type(), self.sizes().vec()),
757 Shape(c10::ScalarType::Long, self.sizes().vec())};
758}
759
760std::vector<Shape> compute_shape_smooth_l1_loss(
761 const at::Tensor& self,
762 const at::Tensor& target,
763 int64_t reduction,
764 double beta) {
765 // Taken from definition of 'Output' shape here:
766 // https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html
767 switch (reduction) {
768 case at::Reduction::None:
769 return {Shape(self.scalar_type(), self.sizes().vec())};
770 default:
771 return {Shape(self.scalar_type(), {})};
772 }
773}
774
775std::vector<Shape> compute_shape_slogdet(const at::Tensor& self) {
776 // assumes self.shape is {*, n, n} and returns shape *
777 TORCH_INTERNAL_ASSERT(self.dim() >= 2);
778 std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2);
779 // Doesn't check input dtype, but output dtype either matches it,
780 // or the actual slogdet operation will throw if it's an unsupported type.
781 // Sign and det outputs hold the same shape, dtype.
782 return {
783 Shape(self.scalar_type(), out_sizes),
784 Shape(self.scalar_type(), out_sizes)};
785}
786
787std::vector<torch::lazy::Shape> compute_shape_logical_and(
788 const at::Tensor& self,
789 const at::Tensor& other) {
790 TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
791 return {Shape(
792 c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
793}
794
795std::vector<torch::lazy::Shape> compute_shape_logical_not(
796 const at::Tensor& self) {
797 return {Shape(c10::ScalarType::Bool, self.sizes().vec())};
798}
799
800std::vector<torch::lazy::Shape> compute_shape_logical_or(
801 const at::Tensor& self,
802 const at::Tensor& other) {
803 TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
804 return {Shape(
805 c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
806}
807
808std::vector<torch::lazy::Shape> compute_shape_logical_xor(
809 const at::Tensor& self,
810 const at::Tensor& other) {
811 TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes()));
812 return {Shape(
813 c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))};
814}
815
816std::vector<Shape> compute_shape_smooth_l1_loss_backward(
817 const at::Tensor& grad_output,
818 const at::Tensor& self,
819 const at::Tensor& target,
820 int64_t reduction,
821 double beta) {
822 // The `grad_output` tensor is really the input to this kernel, and while its
823 // shape may vary following the logic of the forward output, the output of
824 // this kernel should have fixed shapes matching the inputs to the forward
825 // kernel.
826 return {Shape(self.scalar_type(), self.sizes().vec())};
827}
828
829std::vector<Shape> compute_shape_logdet(const at::Tensor& self) {
830 // assumes self.shape is {*, n, n} and returns shape *
831 TORCH_INTERNAL_ASSERT(self.dim() >= 2);
832 std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2);
833 // Doesn't check input dtype, but output dtype either matches it,
834 // or the actual logdet operation will throw if it's an unsupported type
835 return {Shape(self.scalar_type(), out_sizes)};
836}
837
838std::vector<Shape> compute_shape_log_sigmoid_forward(const at::Tensor& self) {
839 // Based on definition of
840 // aten/src/ATen/native/Activation.cpp::log_sigmoid_forward_out_cpu.
841 return {
842 Shape(self.scalar_type(), self.sizes().vec()),
843 Shape(self.scalar_type(), self.sizes().vec())};
844}
845
846std::vector<Shape> compute_shape_log_sigmoid_backward(
847 const at::Tensor& grad_output,
848 const at::Tensor& self,
849 const at::Tensor& buffer) {
850 // Based on definition of
851 // aten/src/ATen/native/Activation.cpp::log_sigmoid_backward_cpu*.
852 return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())};
853}
854
855std::vector<Shape> compute_shape_nll_loss2d_forward(
856 const at::Tensor& self,
857 const at::Tensor& target,
858 const c10::optional<at::Tensor>& weight,
859 int64_t reduction,
860 int64_t ignore_index) {
861 // Based on definition of
862 // aten/src/ATen/native/LossNLL2d.cpp:nll_loss2d_forward_cpu
863 auto sizes =
864 (reduction == at::Reduction::Reduction::None ? target.sizes().vec()
865 : std::vector<int64_t>{});
866 return {Shape(self.scalar_type(), sizes), Shape(self.scalar_type(), {})};
867}
868
869std::vector<Shape> compute_shape_nll_loss2d_backward(
870 const at::Tensor& grad_output,
871 const at::Tensor& self,
872 const at::Tensor& target,
873 const c10::optional<at::Tensor>& weight,
874 int64_t reduction,
875 int64_t ignore_index,
876 const at::Tensor& total_weight) {
877 return {Shape(self.scalar_type(), self.sizes().vec())};
878}
879
880std::vector<Shape> compute_shape_grid_sampler_2d(
881 const at::Tensor& input,
882 const at::Tensor& grid,
883 int64_t interpolation_mode,
884 int64_t padding_mode,
885 bool align_corners) {
886 // from `aten/src/ATen/native/cpu/GridSamplerKernel.cpp
887 int64_t N = input.size(0);
888 int64_t C = input.size(1);
889 int64_t H = grid.size(1);
890 int64_t W = grid.size(2);
891 return {Shape(input.scalar_type(), {N, C, H, W})};
892}
893
894std::vector<Shape> compute_shape_grid_sampler_2d_backward(
895 const at::Tensor& grad_output,
896 const at::Tensor& input,
897 const at::Tensor& grid,
898 int64_t interpolation_mode,
899 int64_t padding_mode,
900 bool align_corners,
901 ::std::array<bool, 2> output_mask) {
902 // from `aten/src/ATen/native/cpu/GridSamplerKernel.cpp
903 auto grad_input_shape = Shape(input.scalar_type(), input.sizes().vec());
904 auto grad_grid_shape = Shape(grid.scalar_type(), grid.sizes().vec());
905 return {grad_input_shape, grad_grid_shape};
906}
907
908std::vector<Shape> compute_shape_flip(
909 const at::Tensor& self,
910 at::IntArrayRef dims) {
911 return {Shape(self.scalar_type(), self.sizes().vec())};
912}
913
914std::vector<Shape> compute_shape__adaptive_avg_pool2d(
915 const at::Tensor& self,
916 at::IntArrayRef output_size) {
917 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
918 // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp`
919 TORCH_CHECK(
920 output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2");
921 TORCH_CHECK(
922 (output_size[0] >= 0 && output_size[1] >= 0),
923 "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 ",
924 "but received {",
925 output_size[0],
926 ", ",
927 output_size[1],
928 "}");
929 int64_t ndim = self.ndimension();
930 for (const auto i : c10::irange(1, ndim)) {
931 TORCH_CHECK(
932 self.size(i) > 0,
933 "adaptive_avg_pool2d(): Expected self to have non-zero size for non-batch dimensions, "
934 "but Tensor has sizes ",
935 self.sizes(),
936 " with dimension ",
937 i,
938 " being "
939 "empty");
940 }
941 TORCH_CHECK(
942 (ndim == 3 || ndim == 4),
943 "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got ",
944 self.sizes());
945
946 int64_t channels = self.size(-3);
947 int64_t output_height = output_size[0];
948 int64_t output_width = output_size[1];
949
950 if (ndim == 3) {
951 return {Shape(self.scalar_type(), {channels, output_height, output_width})};
952 } else {
953 int64_t nbatch = self.size(0);
954 return {Shape(
955 self.scalar_type(), {nbatch, channels, output_height, output_width})};
956 }
957}
958
959std::vector<Shape> compute_shape__adaptive_avg_pool2d_backward(
960 const at::Tensor& grad_output,
961 const at::Tensor& self) {
962 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
963 int64_t ndim = grad_output.ndimension();
964
965 for (const auto i : c10::irange(1, ndim)) {
966 TORCH_CHECK(
967 grad_output.size(i) > 0,
968 "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
969 "but grad_output has sizes ",
970 grad_output.sizes(),
971 " with dimension ",
972 i,
973 " being "
974 "empty");
975 }
976
977 TORCH_CHECK(
978 (ndim == 3 || ndim == 4),
979 "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got ",
980 self.sizes());
981 TORCH_CHECK(
982 self.dtype() == grad_output.dtype(),
983 "expected dtype ",
984 self.dtype(),
985 " for `grad_output` but got dtype ",
986 grad_output.dtype());
987
988 return {Shape(self.scalar_type(), self.sizes().vec())};
989}
990
991std::vector<Shape> compute_shape__adaptive_avg_pool3d(
992 const at::Tensor& self,
993 at::IntArrayRef output_size) {
994 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
995 // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp`
996 TORCH_CHECK(
997 output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3");
998 TORCH_CHECK(
999 (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0),
1000 "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 ",
1001 "but received {",
1002 output_size[0],
1003 ", ",
1004 output_size[1],
1005 ", ",
1006 output_size[2],
1007 "}");
1008 int64_t ndim = self.ndimension();
1009 for (const auto i : c10::irange(1, ndim)) {
1010 TORCH_CHECK(
1011 self.size(i) > 0,
1012 "adaptive_avg_pool3d(): Expected self to have non-zero size for non-batch dimensions, "
1013 "but Tensor has sizes ",
1014 self.sizes(),
1015 " with dimension ",
1016 i,
1017 " being "
1018 "empty");
1019 }
1020 TORCH_CHECK(
1021 (ndim == 4 || ndim == 5),
1022 "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got ",
1023 self.sizes());
1024
1025 int64_t channels = self.size(-3);
1026 int64_t output_depth = output_size[0];
1027 int64_t output_height = output_size[1];
1028 int64_t output_width = output_size[2];
1029
1030 if (ndim == 4) {
1031 return {Shape(
1032 self.scalar_type(),
1033 {channels, output_depth, output_height, output_width})};
1034 } else {
1035 int64_t nbatch = self.size(0);
1036 return {Shape(
1037 self.scalar_type(),
1038 {nbatch, channels, output_depth, output_height, output_width})};
1039 }
1040}
1041
1042std::vector<Shape> compute_shape__adaptive_avg_pool3d_backward(
1043 const at::Tensor& grad_output,
1044 const at::Tensor& self) {
1045 // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp`
1046 int64_t ndim = grad_output.ndimension();
1047
1048 for (const auto i : c10::irange(1, ndim)) {
1049 TORCH_CHECK(
1050 grad_output.size(i) > 0,
1051 "adaptive_avg_pool3d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, "
1052 "but grad_output has sizes ",
1053 grad_output.sizes(),
1054 " with dimension ",
1055 i,
1056 " being "
1057 "empty");
1058 }
1059
1060 TORCH_CHECK(
1061 (ndim == 4 || ndim == 5),
1062 "adaptive_avg_pool3d_backward(): Expected 4D or 5D tensor, but got ",
1063 self.sizes());
1064 TORCH_CHECK(
1065 self.dtype() == grad_output.dtype(),
1066 "expected dtype ",
1067 self.dtype(),
1068 " for `grad_output` but got dtype ",
1069 grad_output.dtype());
1070
1071 return {Shape(self.scalar_type(), self.sizes().vec())};
1072}
1073
1074std::vector<Shape> compute_shape_glu_backward(
1075 const at::Tensor& grad_output,
1076 const at::Tensor& self,
1077 int64_t dim) {
1078 return {Shape(self.scalar_type(), self.sizes().vec())};
1079}
1080
1081std::vector<Shape> compute_shape_glu_jvp(
1082 const at::Tensor& glu,
1083 const at::Tensor& x,
1084 const at::Tensor& dx,
1085 int64_t dim) {
1086 return {Shape(glu.scalar_type(), glu.sizes().vec())};
1087}
1088
1089std::vector<Shape> compute_shape_clamp_min(
1090 const at::Tensor& self,
1091 const at::Scalar& min) {
1092 return {Shape(self.scalar_type(), self.sizes().vec())};
1093}
1094
1095std::vector<Shape> compute_shape__to_copy(
1096 const at::Tensor& self,
1097 c10::optional<at::ScalarType> dtype,
1098 c10::optional<at::Layout> layout,
1099 c10::optional<at::Device> device,
1100 c10::optional<bool> pin_memory,
1101 bool non_blocking,
1102 c10::optional<at::MemoryFormat> memory_format) {
1103 if (dtype) {
1104 return {Shape(*dtype, self.sizes().vec())};
1105 }
1106 return {Shape(self.scalar_type(), self.sizes().vec())};
1107}
1108
1109TORCH_API std::vector<Shape> compute_shape_clone(
1110 const at::Tensor& self,
1111 c10::optional<at::MemoryFormat> memory_format) {
1112 return {Shape(self.scalar_type(), self.sizes().vec())};
1113}
1114
1115std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) {
1116 TORCH_CHECK(!tensors.empty(), "stack expects a non-empty TensorList");
1117 auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1);
1118
1119 // Copied from 'check_stack_inputs' in TensorShape.cpp
1120 at::IntArrayRef entry_shape = tensors[0].sizes();
1121 for (const auto i : c10::irange(1, tensors.size())) {
1122 TORCH_CHECK(
1123 tensors[i].sizes() == entry_shape,
1124 "stack expects each tensor to be equal size, but got ",
1125 entry_shape,
1126 " at entry 0 and ",
1127 tensors[i].sizes(),
1128 " at entry ",
1129 i);
1130 }
1131
1132 auto result_sizes = tensors[0].sizes().vec();
1133 result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size());
1134 return {Shape(tensors[0].scalar_type(), result_sizes)};
1135}
1136
1137std::vector<Shape> compute_shape_repeat(
1138 const at::Tensor& self,
1139 at::IntArrayRef repeats) {
1140 TORCH_CHECK_GE(repeats.size(), self.dim());
1141 int64_t num_new_dimensions = repeats.size() - self.dim();
1142 std::vector<int64_t> padded_size(num_new_dimensions, 1);
1143 padded_size.insert(
1144 padded_size.end(), self.sizes().begin(), self.sizes().end());
1145 std::vector<int64_t> target_size(repeats.size());
1146 for (const auto idx : c10::irange(repeats.size())) {
1147 target_size[idx] = padded_size[idx] * repeats[idx];
1148 }
1149 return {Shape(self.scalar_type(), target_size)};
1150}
1151
1152std::vector<Shape> compute_shape_narrow_copy_symint(
1153 const at::Tensor& self,
1154 int64_t dim,
1155 int64_t start,
1156 c10::SymInt length) {
1157 return {Shape(self.scalar_type(), self.sizes().vec())};
1158}
1159
1160std::vector<Shape> compute_shape_hardswish(const at::Tensor& self) {
1161 return {Shape(self.scalar_type(), self.sizes().vec())};
1162}
1163
1164std::vector<Shape> compute_shape_hardswish_backward(
1165 const at::Tensor& grad_output,
1166 const at::Tensor& self) {
1167 return {Shape(self.scalar_type(), self.sizes().vec())};
1168}
1169
1170std::vector<Shape> compute_shape_selu(const at::Tensor& self) {
1171 return {Shape(self.scalar_type(), self.sizes().vec())};
1172}
1173
1174// Non-Native Ops
1175std::vector<Shape> compute_shape_scalar(
1176 const at::Scalar& value,
1177 const at::ScalarType& type) {
1178 return {Shape(type, {})};
1179}
1180std::vector<Shape> compute_shape_expand(
1181 const Output& input,
1182 const std::vector<int64_t>& size,
1183 const bool& is_scalar_expand) {
1184 return {Shape(input.shape().scalar_type(), size)};
1185}
1186std::vector<Shape> compute_shape_view(
1187 const Output& input,
1188 const std::vector<int64_t>& output_sizes) {
1189 const Shape& input_shape = input.shape();
1190 const auto complete_output_sizes =
1191 at::infer_size(output_sizes, input_shape.numel());
1192 return {Shape(input_shape.scalar_type(), complete_output_sizes)};
1193}
1194std::vector<Shape> compute_shape_cast(
1195 const Output& input,
1196 const at::ScalarType& dtype,
1197 const c10::optional<at::ScalarType>& stype) {
1198 Shape shape = input.shape();
1199 shape.set_scalar_type(dtype);
1200 return {shape};
1201}
1202
1203// View Ops
1204std::vector<Shape> compute_shape_as_strided_view_update(
1205 const Output& target,
1206 const Output& input,
1207 const std::vector<int64_t>& size,
1208 const std::vector<int64_t>& stride,
1209 const int64_t& storage_offset) {
1210 return {Shape(target.shape().scalar_type(), size)};
1211}
1212std::vector<Shape> compute_shape_as_strided(
1213 const Output& input,
1214 const std::vector<int64_t>& size,
1215 const std::vector<int64_t>& stride,
1216 const int64_t& storage_offset) {
1217 return {Shape(input.shape().scalar_type(), size)};
1218}
1219std::vector<Shape> compute_shape_diagonal_view_update(
1220 const Output& target,
1221 const Output& input,
1222 const int64_t& offset,
1223 const int64_t& dim1,
1224 const int64_t& dim2) {
1225 return {target.shape()};
1226}
1227std::vector<Shape> compute_shape_diagonal(
1228 const Output& input,
1229 const int64_t& offset,
1230 const int64_t& dim1,
1231 const int64_t& dim2) {
1232 return {MakeDiagonalShape(input.shape(), offset, dim1, dim2)};
1233}
1234std::vector<Shape> compute_shape_narrow_view_update(
1235 const Output& input,
1236 const Output& source,
1237 const std::vector<int64_t>& base_indices) {
1238 return {input.shape()};
1239}
1240std::vector<Shape> compute_shape_narrow(
1241 const Output& input,
1242 const std::vector<int64_t>& base_indices,
1243 const std::vector<int64_t>& sizes) {
1244 return {Shape(input.shape().scalar_type(), sizes)};
1245}
1246std::vector<Shape> compute_shape_permute(
1247 const Output& input,
1248 const std::vector<int64_t>& dims) {
1249 return {MakePermuteShape(input.shape(), dims)};
1250}
1251std::vector<Shape> compute_shape_resize(
1252 const Output& input,
1253 const std::vector<int64_t>& size) {
1254 return {Shape(input.shape().scalar_type(), size)};
1255}
1256std::vector<Shape> compute_shape_select_view_update(
1257 const Output& target,
1258 const Output& source,
1259 const int64_t& dim,
1260 const int64_t& start,
1261 const int64_t& end,
1262 const int64_t& stride) {
1263 return {target.shape()};
1264}
1265std::vector<Shape> compute_shape_select(
1266 const Output& input,
1267 const int64_t& dim,
1268 const int64_t& start,
1269 const int64_t& end,
1270 const int64_t& stride) {
1271 return {MakeSelectShape(input.shape(), dim, start, end, stride)};
1272}
1273std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) {
1274 const auto& input_shape = input.shape();
1275 return {torch::lazy::Shape(
1276 input_shape.scalar_type(),
1277 BuildSqueezedDimensions(input_shape.sizes(), dim))};
1278}
1279std::vector<Shape> compute_shape_unsqueeze(
1280 const Output& input,
1281 const int& dim) {
1282 const auto& input_shape = input.shape();
1283 return {torch::lazy::Shape(
1284 input_shape.scalar_type(),
1285 BuildUnsqueezedDimensions(input_shape.sizes(), dim))};
1286}
1287
1288std::vector<Shape> compute_shape_select_scatter(
1289 const at::Tensor& self,
1290 const at::Tensor& src,
1291 int64_t dim,
1292 int64_t index) {
1293 auto self_meta = at::native::empty_strided_meta_symint(
1294 self.sym_sizes(),
1295 self.sym_strides(),
1296 /*dtype=*/c10::make_optional(self.scalar_type()),
1297 /*layout=*/c10::make_optional(self.layout()),
1298 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1299 /*pin_memory=*/c10::nullopt);
1300 auto src_meta = at::native::empty_strided_meta_symint(
1301 src.sym_sizes(),
1302 src.sym_strides(),
1303 /*dtype=*/c10::make_optional(src.scalar_type()),
1304 /*layout=*/c10::make_optional(src.layout()),
1305 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1306 /*pin_memory=*/c10::nullopt);
1307 auto out_meta = at::compositeexplicitautograd::select_scatter(
1308 self_meta, src_meta, dim, index);
1309 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1310}
1311
1312std::vector<Shape> compute_shape_diagonal_scatter(
1313 const at::Tensor& self,
1314 const at::Tensor& src,
1315 int64_t offset,
1316 int64_t dim1,
1317 int64_t dim2) {
1318 auto self_meta = at::native::empty_strided_meta_symint(
1319 self.sym_sizes(),
1320 self.sym_strides(),
1321 /*dtype=*/c10::make_optional(self.scalar_type()),
1322 /*layout=*/c10::make_optional(self.layout()),
1323 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1324 /*pin_memory=*/c10::nullopt);
1325 auto src_meta = at::native::empty_strided_meta_symint(
1326 src.sym_sizes(),
1327 src.sym_strides(),
1328 /*dtype=*/c10::make_optional(src.scalar_type()),
1329 /*layout=*/c10::make_optional(src.layout()),
1330 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1331 /*pin_memory=*/c10::nullopt);
1332 auto out_meta = at::compositeexplicitautograd::diagonal_scatter(
1333 self_meta, src_meta, offset, dim1, dim2);
1334 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1335}
1336
1337std::vector<Shape> compute_shape_slice_scatter_symint(
1338 const at::Tensor& self,
1339 const at::Tensor& src,
1340 int64_t dim,
1341 c10::optional<c10::SymInt> start,
1342 c10::optional<c10::SymInt> end,
1343 c10::SymInt step) {
1344 auto self_meta = at::native::empty_strided_meta_symint(
1345 self.sym_sizes(),
1346 self.sym_strides(),
1347 /*dtype=*/c10::make_optional(self.scalar_type()),
1348 /*layout=*/c10::make_optional(self.layout()),
1349 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1350 /*pin_memory=*/c10::nullopt);
1351 auto src_meta = at::native::empty_strided_meta_symint(
1352 src.sym_sizes(),
1353 src.sym_strides(),
1354 /*dtype=*/c10::make_optional(src.scalar_type()),
1355 /*layout=*/c10::make_optional(src.layout()),
1356 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1357 /*pin_memory=*/c10::nullopt);
1358 auto out_meta = at::compositeexplicitautograd::slice_scatter_symint(
1359 self_meta, src_meta, dim, start, end, step);
1360 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1361}
1362
1363std::vector<Shape> compute_shape_as_strided_scatter_symint(
1364 const at::Tensor& self,
1365 const at::Tensor& src,
1366 at::SymIntArrayRef size,
1367 at::SymIntArrayRef stride,
1368 c10::optional<c10::SymInt> storage_offset) {
1369 auto self_meta = at::native::empty_strided_meta_symint(
1370 self.sym_sizes(),
1371 self.sym_strides(),
1372 /*dtype=*/c10::make_optional(self.scalar_type()),
1373 /*layout=*/c10::make_optional(self.layout()),
1374 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1375 /*pin_memory=*/c10::nullopt);
1376 auto src_meta = at::native::empty_strided_meta_symint(
1377 src.sym_sizes(),
1378 src.sym_strides(),
1379 /*dtype=*/c10::make_optional(src.scalar_type()),
1380 /*layout=*/c10::make_optional(src.layout()),
1381 /*device=*/c10::make_optional(c10::Device(c10::kMeta)),
1382 /*pin_memory=*/c10::nullopt);
1383 auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint(
1384 self_meta, src_meta, size, stride, storage_offset);
1385 return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
1386}
1387
1388// Restore unused-parameters warnings
1389#pragma GCC diagnostic pop
1390
1391} // namespace lazy
1392} // namespace torch
1393