1 | /** |
 * This is a handwritten file that accompanies the code-generated header
3 | * LazyShapeDtype.h |
4 | * |
 * The purpose of these shape/dtype inference methods is to fill gaps
6 | * where we do not yet have structured kernels in pytorch core. Ops |
7 | * for which there _are_ structured kernels can use meta::op() to infer |
8 | * shape/dtype, and codegen makes use of this. Ops for which there are not |
9 | * yet structured kernels can still be used with lazy_tensor codegen, but |
10 | * require manual intervention to implement compute_shape_{op} and |
11 | * compute_dtype_{op}. |
12 | * |
13 | * READ THIS! |
14 | * |
15 | * 1. Beware: Tech Debt! |
16 | * --------------------- |
17 | * These functions are tech debt. We want to delete them all and use structured |
 * kernels instead, but it's a lot faster to write these, so we're decoupling
 * the two efforts in order to move fast on adding support for codegenned
 * Lazy Tensor ops.
20 | * |
21 | * Codegenned Lazy Tensor ops with handwritten shape formulae are still better |
22 | * than fully handwritten Lazy Tensor ops (which also have handwritten shape |
23 | * formulae). |
24 | * |
25 | * 2. Structured Kernels For The Win |
26 | * --------------------------------- |
27 | * Long term, more and more ops should be supported as 'structured kernels'. |
28 | * Consider doing your part and porting an op. As ops get ported over, the |
29 | * codegen will automatically notice and stop generating declarations for these |
30 | * shape formulae, so we'll need to manually clean up the unused functions in |
31 | * this file, or somehow automate that. |
32 | * |
33 | * https://dev-discuss.pytorch.org/t/slides-from-structured-kernel-presentation/179 |
34 | * |
35 | * 3. How to figure out the shape/dtype |
36 | * ------------------------------------ |
 * Unfortunately there isn't a one-stop shop for learning the output shape
38 | * formulae for all operators. This is partly because some operators are not |
39 | * part of our 'public' API, including backward operators which users don't |
40 | * directly invoke. |
41 | * |
 * Check the JIT symbolic shape registry:
43 | * https://github.com/pytorch/pytorch/blob/13b859983183ea9938deb5030ac9a0747841f0a8/torch/csrc/jit/runtime/symbolic_shape_registry.cpp |
44 | * |
 * Read the manual (for ops that are 1:1 with the Python frontend):
46 | * https://pytorch.org/docs/stable/generated/torch.trace.html |
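 *
 * If neither covers the op, one way to sanity-check a formula (illustrative
 * only, not part of the codegen flow) is to run the op in eager mode on small
 * CPU tensors and inspect the result's metadata, e.g.:
 *
 *   auto out = at::smooth_l1_loss(self, target, at::Reduction::None, 1.0);
 *   out.sizes();        // the shape your formula should reproduce
 *   out.scalar_type();  // the dtype your formula should reproduce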
47 | * |
48 | */ |
49 | |
50 | #include <torch/csrc/lazy/core/shape_inference.h> |
51 | |
52 | #include <ATen/AccumulateType.h> |
53 | #include <ATen/CompositeExplicitAutogradFunctions.h> |
54 | #include <ATen/Dispatch.h> |
55 | #include <ATen/ExpandUtils.h> |
56 | #include <ATen/Functions.h> |
57 | #include <ATen/InferSize.h> |
58 | #include <ATen/NativeFunctions.h> |
59 | #include <ATen/WrapDimUtils.h> |
60 | #include <ATen/native/ConvUtils.h> |
61 | #include <ATen/native/ReduceOpsUtils.h> |
62 | #include <ATen/native/TensorConversions.h> |
63 | #include <c10/core/ScalarType.h> |
64 | #include <torch/csrc/api/include/torch/enum.h> |
65 | #include <torch/csrc/lazy/core/ops/utils.h> |
66 | #include <torch/csrc/lazy/core/shape.h> |
67 | #include <torch/csrc/lazy/core/util.h> |
68 | #include <torch/csrc/lazy/ts_backend/dynamic_ir.h> |
69 | #include <ostream> |
70 | #include <vector> |
71 | |
72 | namespace torch { |
73 | namespace lazy { |
74 | |
// Copied from ATen/native/utils/ParamUtils.h, which apparently I can't include
76 | // from here? |
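// For example (illustrative): expand_param_if_needed({2}, "stride", 3)
// returns {2, 2, 2}, while a list that already has 3 entries is returned
// unchanged and any other length raises an error.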
77 | std::vector<int64_t> expand_param_if_needed( |
78 | at::IntArrayRef list_param, |
79 | const char* param_name, |
80 | int64_t expected_dim) { |
81 | if (list_param.size() == 1) { |
82 | return std::vector<int64_t>(expected_dim, list_param[0]); |
83 | } else if ((int64_t)list_param.size() != expected_dim) { |
84 | std::ostringstream ss; |
85 | ss << "expected " << param_name << " to be a single integer value or a " |
86 | << "list of " << expected_dim << " values to match the convolution " |
87 | << "dimensions, but got " << param_name << "=" << list_param; |
88 | AT_ERROR(ss.str()); |
89 | } else { |
90 | return list_param.vec(); |
91 | } |
92 | } |
93 | |
94 | // It seems more common to not use parameters than to use them, so disable |
95 | // unused-parameter warning |
96 | #pragma GCC diagnostic push |
97 | #pragma GCC diagnostic ignored "-Wunused-parameter" |
98 | |
99 | TORCH_API std::vector<Shape> compute_shape_arange_out( |
100 | const at::Scalar& start, |
101 | const at::Scalar& end, |
102 | const at::Scalar& step, |
103 | at::Tensor& out) { |
104 | double size_d = 0; |
105 | // shape inference code copied from RangeFactories.cpp arange_out function |
  // Note: AT_DISPATCH_ALL_TYPES_AND is just a macro that defines the correct
  // C++ scalar_t type depending on the out tensor's dtype.
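  // (Roughly speaking, the macro switches on out.scalar_type() and invokes
  // the lambda once with scalar_t bound to the matching C++ type; the extra
  // c10::kBFloat16 argument adds BFloat16 to the dispatched set.)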
108 | |
109 | AT_DISPATCH_ALL_TYPES_AND( |
110 | c10::kBFloat16, out.scalar_type(), "compute_shape_arange_out" , [&]() { |
        // Note: acc_type further defines an accumulation type depending on the
        // scalar_t and whether it's on CUDA vs CPU.
113 | using accscalar_t = at::acc_type<scalar_t, false>; |
114 | auto xstart = start.to<accscalar_t>(); |
115 | auto xend = end.to<accscalar_t>(); |
116 | auto xstep = step.to<accscalar_t>(); |
117 | |
        // We use double precision for (end - start) / step to compute size_d
        // for consistency across devices. The problem with using accscalar_t
        // is that accscalar_t might be float32 on gpu for a float32 scalar_t,
        // but double on cpu for the same, and the effective output size
        // starts differing on CPU vs GPU because of precision issues, which
        // we don't want. The corner case we do want to take into account is
        // int64_t, which has higher precision than double.
        // NOLINTNEXTLINE(bugprone-branch-clone)
126 | if (std::is_same<scalar_t, int64_t>::value) { |
127 | size_d = std::ceil( |
128 | static_cast<double>( |
129 | end.to<accscalar_t>() - start.to<accscalar_t>()) / |
130 | step.to<accscalar_t>()); |
131 | } else { |
132 | size_d = std::ceil( |
133 | static_cast<double>(end.to<double>() - start.to<double>()) / |
134 | step.to<double>()); |
135 | } |
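        // Worked example (illustrative): arange(0, 10, 3) gives
        // size_d = ceil((10 - 0) / 3) = 4, i.e. {0, 3, 6, 9}.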
136 | |
137 | TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero" ); |
138 | TORCH_CHECK( |
139 | std::isfinite(static_cast<double>(xstart)) && |
140 | std::isfinite(static_cast<double>(xend)), |
141 | "unsupported range: " , |
142 | xstart, |
143 | " -> " , |
144 | xend); |
145 | TORCH_CHECK( |
146 | ((xstep > 0) && (xend >= xstart)) || |
147 | ((xstep < 0) && (xend <= xstart)), |
148 | "upper bound and larger bound inconsistent with step sign" ); |
149 | |
150 | TORCH_CHECK( |
151 | size_d >= 0 && |
152 | size_d <= |
153 | static_cast<double>(std::numeric_limits<int64_t>::max()), |
154 | "invalid size, possible overflow?" ); |
155 | }); |
156 | |
157 | int64_t size = static_cast<int64_t>(size_d); |
158 | |
159 | // From torch.arange docs: |
160 | // dtype (torch.dtype, optional) – the desired data type of returned tensor. |
161 | // Default: if None, uses a global default (see |
162 | // torch.set_default_tensor_type()). If dtype is not given, infer the data |
  // type from the other input arguments. If any of start, end, or step are
  // floating-point, the dtype is inferred to be the default dtype, see
  // get_default_dtype(). Otherwise, the dtype is inferred to be torch.int64.
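  // E.g. (illustrative): torch.arange(5) infers int64, while
  // torch.arange(0.0, 5) infers the default floating point dtype; here we
  // simply trust the dtype already resolved on the `out` tensor.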
166 | |
167 | return {Shape(out.scalar_type(), {size})}; |
168 | } |
169 | |
170 | std::vector<Shape> compute_shape_abs(const at::Tensor& self) { |
171 | if (self.is_complex()) { |
172 | const auto float_type = c10::toRealValueType(self.scalar_type()); |
173 | return {Shape(float_type, self.sizes().vec())}; |
174 | } |
175 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
176 | } |
177 | |
178 | std::vector<Shape> compute_shape_bernoulli( |
179 | const at::Tensor& self, |
180 | c10::optional<at::Generator> generator) { |
181 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
182 | } |
183 | |
184 | std::vector<Shape> compute_shape_bernoulli( |
185 | const at::Tensor& self, |
186 | double p, |
187 | c10::optional<at::Generator> generator) { |
188 | return compute_shape_bernoulli(self, generator); |
189 | } |
190 | |
191 | std::vector<Shape> compute_shape_binary_cross_entropy( |
192 | const at::Tensor& self, |
193 | const at::Tensor& target, |
194 | const c10::optional<at::Tensor>& weight, |
195 | int64_t reduction) { |
196 | if (reduction == at::Reduction::None) { |
197 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
198 | } |
199 | return {Shape(self.scalar_type(), {})}; |
200 | } |
201 | |
202 | std::vector<Shape> compute_shape_binary_cross_entropy_backward( |
203 | const at::Tensor& grad_output, |
204 | const at::Tensor& self, |
205 | const at::Tensor& target, |
206 | const c10::optional<at::Tensor>& weight, |
207 | int64_t reduction) { |
208 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
209 | } |
210 | |
211 | std::vector<Shape> compute_shape_constant_pad_nd( |
212 | const at::Tensor& self, |
213 | at::IntArrayRef pad, |
214 | const at::Scalar& value) { |
215 | // Based on aten/src/ATen/native/ConstantPadNd.cpp::constant_pad_nd |
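  // Padding is specified from the last dimension backwards, two values per
  // padded dimension. E.g. (illustrative): self.sizes() = {2, 3} with
  // pad = {1, 2} pads only the last dimension, giving {2, 3 + 1 + 2} = {2, 6}.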
216 | TORCH_CHECK( |
217 | pad.size() % 2 == 0, |
218 | "Length of pad must be even but instead it equals " , |
219 | pad.size()); |
220 | |
221 | auto input_sizes = self.sizes(); |
222 | auto l_inp = self.dim(); |
223 | |
224 | auto l_pad = pad.size() / 2; |
225 | auto l_diff = l_inp - l_pad; |
  TORCH_CHECK(
      l_inp >= (int64_t)l_pad,
      "Length of pad should be no more than twice the number of "
      "dimensions of the input. Pad length is ",
      pad.size(),
      " while the input has ",
      l_inp,
      " dimensions.");
234 | |
235 | std::vector<int64_t> new_shape; |
236 | for (size_t i = 0; i < (size_t)l_diff; i++) { |
237 | new_shape.emplace_back(input_sizes[i]); |
238 | } |
239 | |
240 | for (const auto i : c10::irange((size_t)l_pad)) { |
241 | auto pad_idx = pad.size() - ((i + 1) * 2); |
242 | auto new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]; |
243 | TORCH_CHECK( |
244 | new_dim > 0, |
245 | "The input size " , |
246 | input_sizes[l_diff + i], |
247 | ", plus negative padding " , |
248 | pad[pad_idx], |
249 | " and " , |
250 | pad[pad_idx + 1], |
251 | " resulted in a negative output size, " |
252 | "which is invalid. Check dimension " , |
253 | l_diff + i, |
254 | " of your input." ); |
255 | new_shape.emplace_back(new_dim); |
256 | } |
257 | return {Shape(self.scalar_type(), new_shape)}; |
258 | } |
259 | |
260 | std::vector<Shape> compute_shape_convolution_backward( |
261 | const at::Tensor& grad_output, |
262 | const at::Tensor& input, |
263 | const at::Tensor& weight, |
264 | at::OptionalIntArrayRef bias_sizes, |
265 | at::IntArrayRef stride, |
266 | at::IntArrayRef padding, |
267 | at::IntArrayRef dilation, |
268 | bool transposed, |
269 | at::IntArrayRef output_padding, |
270 | int64_t groups, |
271 | ::std::array<bool, 3> output_mask) { |
272 | if (bias_sizes.has_value()) { |
273 | return { |
274 | Shape(input.scalar_type(), input.sizes().vec()), |
275 | Shape(weight.scalar_type(), weight.sizes().vec()), |
276 | Shape(grad_output.scalar_type(), bias_sizes.value().vec())}; |
277 | } else { |
278 | // TODO(whc) not sure whether to return 2 shapes here, or a 3rd one that is |
279 | // empty |
280 | return { |
281 | Shape(input.scalar_type(), input.sizes().vec()), |
282 | Shape(weight.scalar_type(), weight.sizes().vec())}; |
283 | } |
284 | } |
285 | |
286 | std::vector<Shape> compute_shape_convolution( |
287 | const at::Tensor& input, |
288 | const at::Tensor& weight, |
289 | const c10::optional<at::Tensor>& bias, |
290 | at::IntArrayRef stride, |
291 | at::IntArrayRef padding, |
292 | at::IntArrayRef dilation, |
293 | bool transposed, |
294 | at::IntArrayRef output_padding, |
295 | int64_t groups) { |
296 | int64_t dim = weight.ndimension() - 2; |
297 | TORCH_CHECK(dim > 0, "weight should have at least three dimensions" ); |
298 | |
  // at::convolution performs parameter expansion before running kernels on
  // expanded parameters, so we must do the same. Shape formulae access
  // different dimensions of e.g. output_padding, but output_padding may be
  // passed in as a scalar. Sadly, accessing output_padding[1] in that case
  // gives incorrect results rather than an indexing error.
  auto expanded_stride = expand_param_if_needed(stride, "stride", dim);
  auto expanded_padding = expand_param_if_needed(padding, "padding", dim);
  auto expanded_dilation = expand_param_if_needed(dilation, "dilation", dim);
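  // E.g. (illustrative): for a 2-d convolution (dim == 2), stride = {1}
  // expands to {1, 1}, while stride = {2, 1} is kept as-is.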
307 | if (!transposed) { |
308 | return {Shape( |
309 | input.scalar_type(), |
310 | at::native::conv_output_size( |
311 | input.sizes(), |
312 | weight.sizes(), |
313 | expanded_padding, |
314 | expanded_stride, |
315 | expanded_dilation))}; |
316 | } else { |
317 | auto expanded_output_padding = |
318 | expand_param_if_needed(output_padding, "output_padding" , dim); |
319 | auto out_shape = at::native::conv_input_size( |
320 | input.sizes(), |
321 | weight.sizes(), |
322 | expanded_padding, |
323 | expanded_output_padding, |
324 | expanded_stride, |
325 | expanded_dilation, |
326 | groups); |
327 | return {Shape(input.scalar_type(), out_shape)}; |
328 | } |
329 | } |
330 | |
331 | std::vector<Shape> compute_shape_masked_fill( |
332 | const at::Tensor& self, |
333 | const at::Tensor& mask, |
334 | const at::Scalar& value) { |
335 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
336 | } |
337 | |
338 | std::vector<Shape> compute_shape_masked_fill( |
339 | const at::Tensor& self, |
340 | const at::Tensor& mask, |
341 | const at::Tensor& value) { |
342 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
343 | } |
344 | |
345 | std::vector<Shape> compute_shape_max(const at::Tensor& self) { |
346 | TORCH_CHECK( |
347 | self.numel() > 0, |
348 | "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument." ); |
349 | return {Shape(self.scalar_type(), {})}; |
350 | } |
351 | |
352 | std::vector<Shape> compute_shape_min(const at::Tensor& self) { |
353 | TORCH_CHECK( |
354 | self.numel() > 0, |
355 | "min(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument." ); |
356 | return {Shape(self.scalar_type(), {})}; |
357 | } |
358 | |
359 | std::vector<Shape> compute_shape_nonzero(const at::Tensor& t, bool as_tuple) { |
360 | if (as_tuple) { |
361 | auto res = std::vector<Shape>(); |
362 | for (auto dim_size : t.sizes()) { |
363 | res.emplace_back(Shape(at::kLong, {dim_size})); |
364 | } |
365 | return res; |
366 | } |
367 | int64_t max_elements = 1; |
368 | for (auto dim_size : t.sizes()) { |
369 | max_elements *= dim_size; |
370 | } |
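  // In the worst case every element is nonzero, so the dense result has at
  // most numel rows of ndim coordinates each; e.g. (illustrative) a 2x3
  // input yields a shape of at most {6, 2}.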
371 | return {Shape(at::kLong, {max_elements, (int64_t)t.sizes().size()})}; |
372 | } |
373 | |
374 | std::vector<Shape> compute_shape_nonzero(const at::Tensor& self) { |
375 | return compute_shape_nonzero(self, false); |
376 | } |
377 | |
378 | std::vector<Shape> compute_shape_embedding( |
379 | const at::Tensor& weight, |
380 | const at::Tensor& indices, |
381 | int64_t padding_idx, |
382 | bool scale_grad_by_freq, |
383 | bool sparse) { |
384 | // Based on aten/src/ATen/native/Embedding.cpp::embedding. |
385 | std::vector<int64_t> out_sizes = indices.sizes().vec(); |
386 | out_sizes.emplace_back(weight.size(1)); |
387 | return {Shape(weight.scalar_type(), out_sizes)}; |
388 | } |
389 | |
390 | std::vector<Shape> compute_shape_std(const at::Tensor& self, bool unbiased) { |
391 | return compute_shape_std(self, c10::nullopt, c10::nullopt, false); |
392 | } |
393 | std::vector<Shape> compute_shape_std( |
394 | const at::Tensor& self, |
395 | at::OptionalIntArrayRef dim, |
396 | bool unbiased, |
397 | bool keepdim) { |
398 | return compute_shape_std(self, dim, c10::nullopt, keepdim); |
399 | } |
400 | std::vector<Shape> compute_shape_std( |
401 | const at::Tensor& self, |
402 | at::OptionalIntArrayRef dim, |
403 | c10::optional<int64_t> correction, |
404 | bool keepdim) { |
405 | if (dim.has_value()) { |
406 | auto shape = at::native::shape_from_dim_mask( |
407 | self, at::native::make_dim_mask(dim.value(), self.dim()), keepdim); |
408 | return {Shape( |
409 | self.scalar_type(), std::vector<int64_t>(shape.begin(), shape.end()))}; |
410 | } |
411 | return {Shape(self.scalar_type(), {})}; |
412 | } |
413 | |
414 | std::vector<Shape> compute_shape_embedding_dense_backward( |
415 | const at::Tensor& grad_output, |
416 | const at::Tensor& indices, |
417 | int64_t num_weights, |
418 | int64_t padding_idx, |
419 | bool scale_grad_by_freq) { |
420 | // Based on aten/src/ATen/native/Embedding.cpp::embedding_dense_backward_cpu. |
421 | return { |
422 | Shape(grad_output.scalar_type(), {num_weights, grad_output.size(-1)})}; |
423 | } |
424 | |
425 | std::vector<Shape> compute_shape_expand( |
426 | const at::Tensor& self, |
427 | at::IntArrayRef size, |
428 | bool implicit) { |
429 | TORCH_CHECK_GE(size.size(), self.dim()); |
430 | int64_t num_new_dimensions = size.size() - self.dim(); |
431 | std::vector<int64_t> padded_self(num_new_dimensions, 0); |
432 | padded_self.insert( |
433 | padded_self.end(), self.sizes().begin(), self.sizes().end()); |
434 | std::vector<int64_t> target_size(size.size()); |
435 | for (const auto idx : c10::irange(size.size())) { |
436 | target_size[idx] = size[idx] == -1 ? padded_self[idx] : size[idx]; |
437 | } |
438 | return {Shape(self.scalar_type(), target_size)}; |
439 | } |
440 | |
441 | std::vector<Shape> compute_shape_expand( |
442 | const at::Tensor& self, |
443 | c10::SymIntArrayRef size, |
444 | bool implicit) { |
445 | TORCH_CHECK_GE(size.size(), self.dim()); |
446 | std::vector<c10::SymInt> _sizes = ToVector<c10::SymInt>(size); |
447 | int64_t num_new_dimensions = _sizes.size() - self.dim(); |
448 | std::vector<int64_t> padded_self(num_new_dimensions, 0); |
449 | padded_self.insert( |
450 | padded_self.end(), self.sizes().begin(), self.sizes().end()); |
451 | std::vector<int64_t> target_size(_sizes.size()); |
452 | for (const auto idx : c10::irange(_sizes.size())) { |
453 | if (_sizes[idx].is_symbolic()) { |
454 | c10::SymNode symbolicIntNode = _sizes[idx].toSymNodeImpl(); |
455 | auto* lazySymNode = |
456 | dynamic_cast<torch::lazy::SymNodeImpl*>(symbolicIntNode.get()); |
457 | TORCH_INTERNAL_ASSERT(lazySymNode); |
458 | auto size_node = lazySymNode->node_; |
459 | auto static_value = |
460 | std::dynamic_pointer_cast<torch::lazy::DimensionNode>(size_node) |
461 | ->getStaticValue(); |
462 | target_size[idx] = static_value; |
    } else {
      if (_sizes[idx].as_int_unchecked() == -1) {
        // -1 can't be specified for non-existing dimensions
        TORCH_CHECK(idx >= num_new_dimensions);
        target_size[idx] = padded_self[idx];
      } else {
        target_size[idx] = _sizes[idx].as_int_unchecked();
      }
    }
472 | } |
473 | } |
474 | return {Shape(self.scalar_type(), target_size)}; |
475 | } |
476 | |
477 | std::vector<Shape> compute_shape_index_select( |
478 | const at::Tensor& self, |
479 | int64_t dim, |
480 | const at::Tensor& index) { |
481 | // Based on definition of |
482 | // https://pytorch.org/docs/stable/generated/torch.index_select.html. Promote |
483 | // Rank 0 index tensor to a 1 * 1 tensor. |
484 | dim = at::maybe_wrap_dim(dim, self); |
485 | auto index_dim = index.dim() > 0 ? index.dim() : 1; |
486 | auto index_size = index.dim() > 0 ? index.size(0) : 1; |
487 | TORCH_CHECK(index_dim == 1); |
488 | |
489 | auto self_sizes = self.sizes(); |
490 | std::vector<int64_t> output_sizes(self_sizes.begin(), self_sizes.end()); |
491 | TORCH_CHECK(!output_sizes.empty(), "Empty output_sizes is not supported." ); |
492 | output_sizes[dim] = index_size; |
493 | |
494 | return {Shape(self.scalar_type(), output_sizes)}; |
495 | } |
496 | |
497 | std::vector<Shape> compute_shape_inverse(const at::Tensor& self) { |
498 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
499 | } |
500 | |
501 | std::vector<Shape> compute_shape_isnan(const at::Tensor& self) { |
502 | return {Shape(c10::ScalarType::Bool, self.sizes().vec())}; |
503 | } |
504 | |
505 | std::vector<Shape> compute_shape_cat(at::TensorList tensors, int64_t dim) { |
506 | // TODO(whc) support cat in codegen and move this to compute_*_cat functions |
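  // All inputs must match outside of `dim`; only the concat dimension grows.
  // E.g. (illustrative): cat of {2, 3} and {4, 3} along dim 0 gives {6, 3}.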
507 | std::vector<int64_t> out_shape( |
508 | tensors[0].sizes().begin(), tensors[0].sizes().end()); |
509 | |
510 | dim = at::maybe_wrap_dim(dim, tensors); |
511 | size_t extended_dim_shape = 0; |
512 | for (auto& tensor : tensors) { |
513 | extended_dim_shape += tensor.sizes()[dim]; |
514 | } |
515 | TORCH_CHECK(!out_shape.empty(), "Scalar tensors are not supported in cat." ); |
516 | TORCH_CHECK( |
517 | extended_dim_shape <= std::numeric_limits<int64_t>::max(), |
518 | "Size overflow" ); |
519 | out_shape[dim] = extended_dim_shape; |
520 | return {Shape(tensors[0].scalar_type(), out_shape)}; |
521 | } |
522 | |
523 | TORCH_API std::vector<torch::lazy::Shape> compute_shape_cholesky( |
524 | const at::Tensor& self, |
525 | bool upper) { |
526 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
527 | } |
528 | |
529 | std::vector<torch::lazy::Shape> compute_shape_native_batch_norm( |
530 | const at::Tensor& input, |
531 | const c10::optional<at::Tensor>& weight, |
532 | const c10::optional<at::Tensor>& bias, |
533 | const c10::optional<at::Tensor>& running_mean, |
534 | const c10::optional<at::Tensor>& running_var, |
535 | bool training, |
536 | double momentum, |
537 | double eps) { |
538 | std::vector<torch::lazy::Shape> shapes; |
539 | shapes.reserve(3); |
540 | shapes.emplace_back(input.scalar_type(), input.sizes().vec()); |
541 | |
542 | // A separate mean and var needs to be kept for each channel. |
543 | TORCH_CHECK( |
544 | input.sizes().size() >= 2, |
545 | "Input tensor must have at least batch and channel dimensions!" ); |
546 | int64_t num_features = input.size(1); |
547 | |
548 | if (running_mean.has_value()) { |
549 | shapes.emplace_back( |
550 | running_mean.value().scalar_type(), running_mean.value().sizes().vec()); |
551 | } else { |
552 | shapes.emplace_back( |
553 | at::get_default_dtype_as_scalartype(), |
554 | std::vector<int64_t>{num_features}); |
555 | } |
556 | |
557 | if (running_var.has_value()) { |
558 | shapes.emplace_back( |
559 | running_var.value().scalar_type(), running_var.value().sizes().vec()); |
560 | } else { |
561 | shapes.emplace_back( |
562 | at::get_default_dtype_as_scalartype(), |
563 | std::vector<int64_t>{num_features}); |
564 | } |
565 | return shapes; |
566 | } |
567 | |
568 | std::vector<torch::lazy::Shape> compute_shape_native_batch_norm_backward( |
569 | const at::Tensor& grad_out, |
570 | const at::Tensor& input, |
571 | const c10::optional<at::Tensor>& weight, |
572 | const c10::optional<at::Tensor>& running_mean, |
573 | const c10::optional<at::Tensor>& running_var, |
574 | const c10::optional<at::Tensor>& save_mean, |
575 | const c10::optional<at::Tensor>& save_invstd, |
576 | bool train, |
577 | double eps, |
578 | ::std::array<bool, 3> output_mask) { |
579 | std::vector<torch::lazy::Shape> shapes; |
580 | shapes.reserve(3); |
581 | shapes.emplace_back(input.scalar_type(), input.sizes().vec()); |
582 | |
583 | // A separate mean and var needs to be kept for each channel. |
584 | TORCH_CHECK( |
585 | input.sizes().size() >= 2, |
586 | "Input tensor must have at least batch and channel dimensions!" ); |
587 | int64_t num_features = input.size(1); |
588 | |
  // `weight` and `bias` are vectors of length C (number of channels).
590 | shapes.emplace_back( |
591 | at::get_default_dtype_as_scalartype(), |
592 | std::vector<int64_t>{num_features}); |
593 | shapes.emplace_back( |
594 | at::get_default_dtype_as_scalartype(), |
595 | std::vector<int64_t>{num_features}); |
596 | |
597 | return shapes; |
598 | } |
599 | |
600 | std::vector<Shape> compute_shape_native_layer_norm( |
601 | const at::Tensor& input, |
602 | at::IntArrayRef normalized_shape, |
603 | const c10::optional<at::Tensor>& weight, |
604 | const c10::optional<at::Tensor>& bias, |
605 | double eps) { |
606 | // Copied from aten/src/ATen/native/layer_norm.cpp::layer_norm_cpu_out. |
607 | auto input_shape = input.sizes().vec(); |
608 | const size_t axis = input.dim() - normalized_shape.size(); |
609 | |
610 | std::vector<int64_t> stat_shape; |
611 | for (const auto idx : c10::irange(axis)) { |
612 | TORCH_CHECK(idx < input_shape.size(), "Shape mismatch" ); |
613 | stat_shape.emplace_back(input_shape[idx]); |
614 | } |
615 | for (const auto idx : c10::irange(axis, input.dim())) { |
616 | (void)idx; // Suppress unused variable warning |
617 | stat_shape.emplace_back(1); |
618 | } |
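  // E.g. (illustrative): input {2, 3, 4, 5} with normalized_shape {4, 5}
  // gives axis = 2, so mean/rstd get stat_shape {2, 3, 1, 1}.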
619 | |
620 | return { |
621 | Shape(input.scalar_type(), input_shape), |
622 | Shape(input.scalar_type(), stat_shape), |
623 | Shape(input.scalar_type(), stat_shape)}; |
624 | } |
625 | |
626 | std::vector<Shape> compute_shape_native_layer_norm_backward( |
627 | const at::Tensor& grad_out, |
628 | const at::Tensor& input, |
629 | at::IntArrayRef normalized_shape, |
630 | const at::Tensor& mean, |
631 | const at::Tensor& rstd, |
632 | const c10::optional<at::Tensor>& weight, |
633 | const c10::optional<at::Tensor>& bias, |
634 | ::std::array<bool, 3> output_mask) { |
635 | std::vector<Shape> shapes; |
636 | shapes.emplace_back( |
637 | input.scalar_type(), |
638 | output_mask[0] ? input.sizes().vec() : std::vector<int64_t>{}); |
639 | shapes.emplace_back( |
640 | weight && weight->defined() ? weight->scalar_type() : input.scalar_type(), |
641 | output_mask[1] && weight ? weight->sizes().vec() |
642 | : std::vector<int64_t>{}); |
643 | shapes.emplace_back( |
      bias && bias->defined() ? bias->scalar_type() : input.scalar_type(),
645 | output_mask[2] && bias ? bias->sizes().vec() : std::vector<int64_t>{}); |
646 | return shapes; |
647 | } |
648 | |
649 | std::vector<Shape> compute_shape_mean( |
650 | const at::Tensor& self, |
651 | c10::optional<at::ScalarType> dtype) { |
652 | if (dtype.has_value()) { |
653 | return {Shape(dtype.value(), {})}; |
654 | } |
655 | return {Shape(self.scalar_type(), {})}; |
656 | } |
657 | |
658 | std::vector<Shape> compute_shape_new_empty_strided( |
659 | const at::Tensor& self, |
660 | at::IntArrayRef size, |
661 | at::IntArrayRef stride, |
662 | c10::optional<at::ScalarType> dtype, |
663 | c10::optional<at::Layout> layout, |
664 | c10::optional<at::Device> device, |
665 | c10::optional<bool> pin_memory) { |
666 | return {Shape(dtype.has_value() ? *dtype : self.scalar_type(), size.vec())}; |
667 | } |
668 | |
669 | std::vector<Shape> compute_shape_mv( |
670 | const at::Tensor& self, |
671 | const at::Tensor& vec) { |
672 | return {Shape(self.scalar_type(), {self.size(0)})}; |
673 | } |
674 | |
675 | std::vector<Shape> compute_shape_native_dropout( |
676 | const at::Tensor& input, |
677 | double p, |
678 | c10::optional<bool> train) { |
679 | return { |
680 | Shape(input.scalar_type(), input.sizes().vec()), |
681 | Shape(c10::ScalarType::Bool, input.sizes().vec())}; |
682 | } |
683 | |
684 | std::vector<Shape> compute_shape_native_dropout_backward( |
685 | const at::Tensor& grad_output, |
686 | const at::Tensor& mask, |
687 | double scale) { |
688 | return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())}; |
689 | } |
690 | |
691 | std::vector<Shape> compute_shape_random( |
692 | const at::Tensor& self, |
693 | c10::optional<at::Generator> generator) { |
694 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
695 | } |
696 | |
697 | std::vector<Shape> compute_shape_random( |
698 | const at::Tensor& self, |
699 | int64_t to, |
700 | c10::optional<at::Generator> generator) { |
701 | return compute_shape_random(self, generator); |
702 | } |
703 | |
704 | std::vector<Shape> compute_shape_random( |
705 | const at::Tensor& self, |
706 | int64_t from, |
707 | c10::optional<int64_t> to, |
708 | c10::optional<at::Generator> generator) { |
709 | return compute_shape_random(self, generator); |
710 | } |
711 | |
712 | std::vector<Shape> compute_shape_relu(const at::Tensor& self) { |
713 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
714 | } |
715 | |
716 | std::vector<Shape> compute_shape_bitwise_and( |
717 | const at::Tensor& self, |
718 | const at::Scalar& other) { |
719 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
720 | } |
721 | |
722 | std::vector<Shape> compute_shape_sum( |
723 | const at::Tensor& self, |
724 | c10::optional<at::ScalarType> dtype) { |
725 | if (dtype.has_value()) { |
726 | return {Shape(dtype.value(), {})}; |
727 | } |
728 | // It's undocumented, but torch::sum promotes all integral types to int64_t by |
729 | // default |
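  // E.g. (illustrative): torch::ones({3}, torch::kInt).sum() produces a
  // kLong scalar, while a kFloat input keeps its dtype.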
730 | if (isIntegralType(self.scalar_type(), /*includeBool*/ true)) { |
731 | return {Shape(c10::ScalarType::Long, {})}; |
732 | } |
733 | return {Shape(self.scalar_type(), {})}; |
735 | } |
736 | |
737 | std::vector<Shape> compute_shape_zero(const at::Tensor& self) { |
738 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
739 | } |
740 | |
741 | TORCH_API std::vector<torch::lazy::Shape> compute_shape_take( |
742 | const at::Tensor& self, |
743 | const at::Tensor& index) { |
744 | return {Shape(self.scalar_type(), index.sizes().vec())}; |
745 | } |
746 | |
747 | std::vector<Shape> compute_shape_trace(const at::Tensor& self) { |
748 | return {Shape(self.scalar_type(), {})}; |
749 | } |
750 | |
751 | std::vector<Shape> compute_shape_sort( |
752 | const at::Tensor& self, |
753 | int64_t dim, |
754 | bool descending) { |
755 | return { |
756 | Shape(self.scalar_type(), self.sizes().vec()), |
757 | Shape(c10::ScalarType::Long, self.sizes().vec())}; |
758 | } |
759 | |
760 | std::vector<Shape> compute_shape_smooth_l1_loss( |
761 | const at::Tensor& self, |
762 | const at::Tensor& target, |
763 | int64_t reduction, |
764 | double beta) { |
765 | // Taken from definition of 'Output' shape here: |
766 | // https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html |
767 | switch (reduction) { |
768 | case at::Reduction::None: |
769 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
770 | default: |
771 | return {Shape(self.scalar_type(), {})}; |
772 | } |
773 | } |
774 | |
775 | std::vector<Shape> compute_shape_slogdet(const at::Tensor& self) { |
776 | // assumes self.shape is {*, n, n} and returns shape * |
777 | TORCH_INTERNAL_ASSERT(self.dim() >= 2); |
778 | std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2); |
779 | // Doesn't check input dtype, but output dtype either matches it, |
780 | // or the actual slogdet operation will throw if it's an unsupported type. |
781 | // Sign and det outputs hold the same shape, dtype. |
782 | return { |
783 | Shape(self.scalar_type(), out_sizes), |
784 | Shape(self.scalar_type(), out_sizes)}; |
785 | } |
786 | |
787 | std::vector<torch::lazy::Shape> compute_shape_logical_and( |
788 | const at::Tensor& self, |
789 | const at::Tensor& other) { |
790 | TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); |
791 | return {Shape( |
792 | c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; |
793 | } |
794 | |
795 | std::vector<torch::lazy::Shape> compute_shape_logical_not( |
796 | const at::Tensor& self) { |
797 | return {Shape(c10::ScalarType::Bool, self.sizes().vec())}; |
798 | } |
799 | |
800 | std::vector<torch::lazy::Shape> compute_shape_logical_or( |
801 | const at::Tensor& self, |
802 | const at::Tensor& other) { |
803 | TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); |
804 | return {Shape( |
805 | c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; |
806 | } |
807 | |
808 | std::vector<torch::lazy::Shape> compute_shape_logical_xor( |
809 | const at::Tensor& self, |
810 | const at::Tensor& other) { |
811 | TORCH_INTERNAL_ASSERT(at::are_expandable(self.sizes(), other.sizes())); |
812 | return {Shape( |
813 | c10::ScalarType::Bool, at::infer_size(self.sizes(), other.sizes()))}; |
814 | } |
815 | |
816 | std::vector<Shape> compute_shape_smooth_l1_loss_backward( |
817 | const at::Tensor& grad_output, |
818 | const at::Tensor& self, |
819 | const at::Tensor& target, |
820 | int64_t reduction, |
821 | double beta) { |
822 | // The `grad_output` tensor is really the input to this kernel, and while its |
823 | // shape may vary following the logic of the forward output, the output of |
824 | // this kernel should have fixed shapes matching the inputs to the forward |
825 | // kernel. |
826 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
827 | } |
828 | |
829 | std::vector<Shape> compute_shape_logdet(const at::Tensor& self) { |
830 | // assumes self.shape is {*, n, n} and returns shape * |
831 | TORCH_INTERNAL_ASSERT(self.dim() >= 2); |
832 | std::vector<int64_t> out_sizes(self.sizes().begin(), self.sizes().end() - 2); |
833 | // Doesn't check input dtype, but output dtype either matches it, |
834 | // or the actual logdet operation will throw if it's an unsupported type |
835 | return {Shape(self.scalar_type(), out_sizes)}; |
836 | } |
837 | |
838 | std::vector<Shape> compute_shape_log_sigmoid_forward(const at::Tensor& self) { |
839 | // Based on definition of |
840 | // aten/src/ATen/native/Activation.cpp::log_sigmoid_forward_out_cpu. |
841 | return { |
842 | Shape(self.scalar_type(), self.sizes().vec()), |
843 | Shape(self.scalar_type(), self.sizes().vec())}; |
844 | } |
845 | |
846 | std::vector<Shape> compute_shape_log_sigmoid_backward( |
847 | const at::Tensor& grad_output, |
848 | const at::Tensor& self, |
849 | const at::Tensor& buffer) { |
850 | // Based on definition of |
851 | // aten/src/ATen/native/Activation.cpp::log_sigmoid_backward_cpu*. |
852 | return {Shape(grad_output.scalar_type(), grad_output.sizes().vec())}; |
853 | } |
854 | |
855 | std::vector<Shape> compute_shape_nll_loss2d_forward( |
856 | const at::Tensor& self, |
857 | const at::Tensor& target, |
858 | const c10::optional<at::Tensor>& weight, |
859 | int64_t reduction, |
860 | int64_t ignore_index) { |
861 | // Based on definition of |
862 | // aten/src/ATen/native/LossNLL2d.cpp:nll_loss2d_forward_cpu |
  auto sizes =
      (reduction == at::Reduction::None ? target.sizes().vec()
                                        : std::vector<int64_t>{});
866 | return {Shape(self.scalar_type(), sizes), Shape(self.scalar_type(), {})}; |
867 | } |
868 | |
869 | std::vector<Shape> compute_shape_nll_loss2d_backward( |
870 | const at::Tensor& grad_output, |
871 | const at::Tensor& self, |
872 | const at::Tensor& target, |
873 | const c10::optional<at::Tensor>& weight, |
874 | int64_t reduction, |
875 | int64_t ignore_index, |
876 | const at::Tensor& total_weight) { |
877 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
878 | } |
879 | |
880 | std::vector<Shape> compute_shape_grid_sampler_2d( |
881 | const at::Tensor& input, |
882 | const at::Tensor& grid, |
883 | int64_t interpolation_mode, |
884 | int64_t padding_mode, |
885 | bool align_corners) { |
  // From `aten/src/ATen/native/cpu/GridSamplerKernel.cpp`.
887 | int64_t N = input.size(0); |
888 | int64_t C = input.size(1); |
889 | int64_t H = grid.size(1); |
890 | int64_t W = grid.size(2); |
891 | return {Shape(input.scalar_type(), {N, C, H, W})}; |
892 | } |
893 | |
894 | std::vector<Shape> compute_shape_grid_sampler_2d_backward( |
895 | const at::Tensor& grad_output, |
896 | const at::Tensor& input, |
897 | const at::Tensor& grid, |
898 | int64_t interpolation_mode, |
899 | int64_t padding_mode, |
900 | bool align_corners, |
901 | ::std::array<bool, 2> output_mask) { |
  // From `aten/src/ATen/native/cpu/GridSamplerKernel.cpp`.
903 | auto grad_input_shape = Shape(input.scalar_type(), input.sizes().vec()); |
904 | auto grad_grid_shape = Shape(grid.scalar_type(), grid.sizes().vec()); |
905 | return {grad_input_shape, grad_grid_shape}; |
906 | } |
907 | |
908 | std::vector<Shape> compute_shape_flip( |
909 | const at::Tensor& self, |
910 | at::IntArrayRef dims) { |
911 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
912 | } |
913 | |
914 | std::vector<Shape> compute_shape__adaptive_avg_pool2d( |
915 | const at::Tensor& self, |
916 | at::IntArrayRef output_size) { |
917 | // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` |
918 | // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp` |
919 | TORCH_CHECK( |
920 | output_size.size() == 2, "adaptive_avg_pool2d: output_size must be 2" ); |
921 | TORCH_CHECK( |
922 | (output_size[0] >= 0 && output_size[1] >= 0), |
923 | "adaptive_avg_pool2d: elements of output_size must be greater than or equal to 0 " , |
924 | "but received {" , |
925 | output_size[0], |
926 | ", " , |
927 | output_size[1], |
928 | "}" ); |
929 | int64_t ndim = self.ndimension(); |
930 | for (const auto i : c10::irange(1, ndim)) { |
931 | TORCH_CHECK( |
932 | self.size(i) > 0, |
933 | "adaptive_avg_pool2d(): Expected self to have non-zero size for non-batch dimensions, " |
934 | "but Tensor has sizes " , |
935 | self.sizes(), |
936 | " with dimension " , |
937 | i, |
938 | " being " |
939 | "empty" ); |
940 | } |
941 | TORCH_CHECK( |
942 | (ndim == 3 || ndim == 4), |
943 | "adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got " , |
944 | self.sizes()); |
945 | |
946 | int64_t channels = self.size(-3); |
947 | int64_t output_height = output_size[0]; |
948 | int64_t output_width = output_size[1]; |
949 | |
950 | if (ndim == 3) { |
951 | return {Shape(self.scalar_type(), {channels, output_height, output_width})}; |
952 | } else { |
953 | int64_t nbatch = self.size(0); |
954 | return {Shape( |
955 | self.scalar_type(), {nbatch, channels, output_height, output_width})}; |
956 | } |
957 | } |
958 | |
959 | std::vector<Shape> compute_shape__adaptive_avg_pool2d_backward( |
960 | const at::Tensor& grad_output, |
961 | const at::Tensor& self) { |
962 | // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` |
963 | int64_t ndim = grad_output.ndimension(); |
964 | |
965 | for (const auto i : c10::irange(1, ndim)) { |
966 | TORCH_CHECK( |
967 | grad_output.size(i) > 0, |
968 | "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, " |
969 | "but grad_output has sizes " , |
970 | grad_output.sizes(), |
971 | " with dimension " , |
972 | i, |
973 | " being " |
974 | "empty" ); |
975 | } |
976 | |
977 | TORCH_CHECK( |
978 | (ndim == 3 || ndim == 4), |
979 | "adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got " , |
980 | self.sizes()); |
981 | TORCH_CHECK( |
982 | self.dtype() == grad_output.dtype(), |
983 | "expected dtype " , |
984 | self.dtype(), |
985 | " for `grad_output` but got dtype " , |
986 | grad_output.dtype()); |
987 | |
988 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
989 | } |
990 | |
991 | std::vector<Shape> compute_shape__adaptive_avg_pool3d( |
992 | const at::Tensor& self, |
993 | at::IntArrayRef output_size) { |
994 | // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` |
995 | // and on `aten/src/ATen/native/cpu/AdaptiveAvgPoolKernel.cpp` |
996 | TORCH_CHECK( |
997 | output_size.size() == 3, "adaptive_avg_pool3d: output_size must be 3" ); |
998 | TORCH_CHECK( |
999 | (output_size[0] >= 0 && output_size[1] >= 0 && output_size[2] >= 0), |
1000 | "adaptive_avg_pool3d: elements of output_size must be greater than or equal to 0 " , |
1001 | "but received {" , |
1002 | output_size[0], |
1003 | ", " , |
1004 | output_size[1], |
1005 | ", " , |
1006 | output_size[2], |
1007 | "}" ); |
1008 | int64_t ndim = self.ndimension(); |
1009 | for (const auto i : c10::irange(1, ndim)) { |
1010 | TORCH_CHECK( |
1011 | self.size(i) > 0, |
1012 | "adaptive_avg_pool3d(): Expected self to have non-zero size for non-batch dimensions, " |
1013 | "but Tensor has sizes " , |
1014 | self.sizes(), |
1015 | " with dimension " , |
1016 | i, |
1017 | " being " |
1018 | "empty" ); |
1019 | } |
1020 | TORCH_CHECK( |
1021 | (ndim == 4 || ndim == 5), |
1022 | "adaptive_avg_pool3d(): Expected 4D or 5D tensor, but got " , |
1023 | self.sizes()); |
1024 | |
1025 | int64_t channels = self.size(-3); |
1026 | int64_t output_depth = output_size[0]; |
1027 | int64_t output_height = output_size[1]; |
1028 | int64_t output_width = output_size[2]; |
1029 | |
1030 | if (ndim == 4) { |
1031 | return {Shape( |
1032 | self.scalar_type(), |
1033 | {channels, output_depth, output_height, output_width})}; |
1034 | } else { |
1035 | int64_t nbatch = self.size(0); |
1036 | return {Shape( |
1037 | self.scalar_type(), |
1038 | {nbatch, channels, output_depth, output_height, output_width})}; |
1039 | } |
1040 | } |
1041 | |
1042 | std::vector<Shape> compute_shape__adaptive_avg_pool3d_backward( |
1043 | const at::Tensor& grad_output, |
1044 | const at::Tensor& self) { |
1045 | // Checks based on `aten/src/ATen/native/AdaptiveAveragePooling.cpp` |
1046 | int64_t ndim = grad_output.ndimension(); |
1047 | |
1048 | for (const auto i : c10::irange(1, ndim)) { |
1049 | TORCH_CHECK( |
1050 | grad_output.size(i) > 0, |
1051 | "adaptive_avg_pool3d_backward(): Expected grad_output to have non-zero size for non-batch dimensions, " |
1052 | "but grad_output has sizes " , |
1053 | grad_output.sizes(), |
1054 | " with dimension " , |
1055 | i, |
1056 | " being " |
1057 | "empty" ); |
1058 | } |
1059 | |
1060 | TORCH_CHECK( |
1061 | (ndim == 4 || ndim == 5), |
1062 | "adaptive_avg_pool3d_backward(): Expected 4D or 5D tensor, but got " , |
1063 | self.sizes()); |
1064 | TORCH_CHECK( |
1065 | self.dtype() == grad_output.dtype(), |
1066 | "expected dtype " , |
1067 | self.dtype(), |
1068 | " for `grad_output` but got dtype " , |
1069 | grad_output.dtype()); |
1070 | |
1071 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1072 | } |
1073 | |
1074 | std::vector<Shape> compute_shape_glu_backward( |
1075 | const at::Tensor& grad_output, |
1076 | const at::Tensor& self, |
1077 | int64_t dim) { |
1078 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1079 | } |
1080 | |
1081 | std::vector<Shape> compute_shape_glu_jvp( |
1082 | const at::Tensor& glu, |
1083 | const at::Tensor& x, |
1084 | const at::Tensor& dx, |
1085 | int64_t dim) { |
1086 | return {Shape(glu.scalar_type(), glu.sizes().vec())}; |
1087 | } |
1088 | |
1089 | std::vector<Shape> compute_shape_clamp_min( |
1090 | const at::Tensor& self, |
1091 | const at::Scalar& min) { |
1092 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1093 | } |
1094 | |
1095 | std::vector<Shape> compute_shape__to_copy( |
1096 | const at::Tensor& self, |
1097 | c10::optional<at::ScalarType> dtype, |
1098 | c10::optional<at::Layout> layout, |
1099 | c10::optional<at::Device> device, |
1100 | c10::optional<bool> pin_memory, |
1101 | bool non_blocking, |
1102 | c10::optional<at::MemoryFormat> memory_format) { |
1103 | if (dtype) { |
1104 | return {Shape(*dtype, self.sizes().vec())}; |
1105 | } |
1106 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1107 | } |
1108 | |
1109 | TORCH_API std::vector<Shape> compute_shape_clone( |
1110 | const at::Tensor& self, |
1111 | c10::optional<at::MemoryFormat> memory_format) { |
1112 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1113 | } |
1114 | |
1115 | std::vector<Shape> compute_shape_stack(at::TensorList tensors, int64_t dim) { |
1116 | TORCH_CHECK(!tensors.empty(), "stack expects a non-empty TensorList" ); |
1117 | auto wrapped_dim = at::maybe_wrap_dim(dim, tensors[0].ndimension() + 1); |
1118 | |
1119 | // Copied from 'check_stack_inputs' in TensorShape.cpp |
1120 | at::IntArrayRef entry_shape = tensors[0].sizes(); |
1121 | for (const auto i : c10::irange(1, tensors.size())) { |
1122 | TORCH_CHECK( |
1123 | tensors[i].sizes() == entry_shape, |
1124 | "stack expects each tensor to be equal size, but got " , |
1125 | entry_shape, |
1126 | " at entry 0 and " , |
1127 | tensors[i].sizes(), |
1128 | " at entry " , |
1129 | i); |
1130 | } |
1131 | |
1132 | auto result_sizes = tensors[0].sizes().vec(); |
1133 | result_sizes.insert(result_sizes.begin() + wrapped_dim, tensors.size()); |
1134 | return {Shape(tensors[0].scalar_type(), result_sizes)}; |
1135 | } |
1136 | |
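// repeat tiles the tensor; repeats must have at least self.dim() entries, and
// extra leading entries add new dimensions. E.g. (illustrative): self {2, 3}
// with repeats {2, 1, 2} pads sizes to {1, 2, 3} and yields {2, 2, 6}.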
1137 | std::vector<Shape> compute_shape_repeat( |
1138 | const at::Tensor& self, |
1139 | at::IntArrayRef repeats) { |
1140 | TORCH_CHECK_GE(repeats.size(), self.dim()); |
1141 | int64_t num_new_dimensions = repeats.size() - self.dim(); |
1142 | std::vector<int64_t> padded_size(num_new_dimensions, 1); |
1143 | padded_size.insert( |
1144 | padded_size.end(), self.sizes().begin(), self.sizes().end()); |
1145 | std::vector<int64_t> target_size(repeats.size()); |
1146 | for (const auto idx : c10::irange(repeats.size())) { |
1147 | target_size[idx] = padded_size[idx] * repeats[idx]; |
1148 | } |
1149 | return {Shape(self.scalar_type(), target_size)}; |
1150 | } |
1151 | |
1152 | std::vector<Shape> compute_shape_narrow_copy_symint( |
1153 | const at::Tensor& self, |
1154 | int64_t dim, |
1155 | int64_t start, |
1156 | c10::SymInt length) { |
1157 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1158 | } |
1159 | |
1160 | std::vector<Shape> compute_shape_hardswish(const at::Tensor& self) { |
1161 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1162 | } |
1163 | |
1164 | std::vector<Shape> compute_shape_hardswish_backward( |
1165 | const at::Tensor& grad_output, |
1166 | const at::Tensor& self) { |
1167 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1168 | } |
1169 | |
1170 | std::vector<Shape> compute_shape_selu(const at::Tensor& self) { |
1171 | return {Shape(self.scalar_type(), self.sizes().vec())}; |
1172 | } |
1173 | |
1174 | // Non-Native Ops |
1175 | std::vector<Shape> compute_shape_scalar( |
1176 | const at::Scalar& value, |
1177 | const at::ScalarType& type) { |
1178 | return {Shape(type, {})}; |
1179 | } |
1180 | std::vector<Shape> compute_shape_expand( |
1181 | const Output& input, |
1182 | const std::vector<int64_t>& size, |
1183 | const bool& is_scalar_expand) { |
1184 | return {Shape(input.shape().scalar_type(), size)}; |
1185 | } |
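// at::infer_size resolves a single -1 entry against the input's numel.
// E.g. (illustrative): output_sizes {-1, 4} on a 12-element input -> {3, 4}.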
1186 | std::vector<Shape> compute_shape_view( |
1187 | const Output& input, |
1188 | const std::vector<int64_t>& output_sizes) { |
1189 | const Shape& input_shape = input.shape(); |
1190 | const auto complete_output_sizes = |
1191 | at::infer_size(output_sizes, input_shape.numel()); |
1192 | return {Shape(input_shape.scalar_type(), complete_output_sizes)}; |
1193 | } |
1194 | std::vector<Shape> compute_shape_cast( |
1195 | const Output& input, |
1196 | const at::ScalarType& dtype, |
1197 | const c10::optional<at::ScalarType>& stype) { |
1198 | Shape shape = input.shape(); |
1199 | shape.set_scalar_type(dtype); |
1200 | return {shape}; |
1201 | } |
1202 | |
1203 | // View Ops |
1204 | std::vector<Shape> compute_shape_as_strided_view_update( |
1205 | const Output& target, |
1206 | const Output& input, |
1207 | const std::vector<int64_t>& size, |
1208 | const std::vector<int64_t>& stride, |
1209 | const int64_t& storage_offset) { |
1210 | return {Shape(target.shape().scalar_type(), size)}; |
1211 | } |
1212 | std::vector<Shape> compute_shape_as_strided( |
1213 | const Output& input, |
1214 | const std::vector<int64_t>& size, |
1215 | const std::vector<int64_t>& stride, |
1216 | const int64_t& storage_offset) { |
1217 | return {Shape(input.shape().scalar_type(), size)}; |
1218 | } |
1219 | std::vector<Shape> compute_shape_diagonal_view_update( |
1220 | const Output& target, |
1221 | const Output& input, |
1222 | const int64_t& offset, |
1223 | const int64_t& dim1, |
1224 | const int64_t& dim2) { |
1225 | return {target.shape()}; |
1226 | } |
1227 | std::vector<Shape> compute_shape_diagonal( |
1228 | const Output& input, |
1229 | const int64_t& offset, |
1230 | const int64_t& dim1, |
1231 | const int64_t& dim2) { |
1232 | return {MakeDiagonalShape(input.shape(), offset, dim1, dim2)}; |
1233 | } |
1234 | std::vector<Shape> compute_shape_narrow_view_update( |
1235 | const Output& input, |
1236 | const Output& source, |
1237 | const std::vector<int64_t>& base_indices) { |
1238 | return {input.shape()}; |
1239 | } |
1240 | std::vector<Shape> compute_shape_narrow( |
1241 | const Output& input, |
1242 | const std::vector<int64_t>& base_indices, |
1243 | const std::vector<int64_t>& sizes) { |
1244 | return {Shape(input.shape().scalar_type(), sizes)}; |
1245 | } |
1246 | std::vector<Shape> compute_shape_permute( |
1247 | const Output& input, |
1248 | const std::vector<int64_t>& dims) { |
1249 | return {MakePermuteShape(input.shape(), dims)}; |
1250 | } |
1251 | std::vector<Shape> compute_shape_resize( |
1252 | const Output& input, |
1253 | const std::vector<int64_t>& size) { |
1254 | return {Shape(input.shape().scalar_type(), size)}; |
1255 | } |
1256 | std::vector<Shape> compute_shape_select_view_update( |
1257 | const Output& target, |
1258 | const Output& source, |
1259 | const int64_t& dim, |
1260 | const int64_t& start, |
1261 | const int64_t& end, |
1262 | const int64_t& stride) { |
1263 | return {target.shape()}; |
1264 | } |
1265 | std::vector<Shape> compute_shape_select( |
1266 | const Output& input, |
1267 | const int64_t& dim, |
1268 | const int64_t& start, |
1269 | const int64_t& end, |
1270 | const int64_t& stride) { |
1271 | return {MakeSelectShape(input.shape(), dim, start, end, stride)}; |
1272 | } |
1273 | std::vector<Shape> compute_shape_squeeze(const Output& input, const int& dim) { |
1274 | const auto& input_shape = input.shape(); |
1275 | return {torch::lazy::Shape( |
1276 | input_shape.scalar_type(), |
1277 | BuildSqueezedDimensions(input_shape.sizes(), dim))}; |
1278 | } |
1279 | std::vector<Shape> compute_shape_unsqueeze( |
1280 | const Output& input, |
1281 | const int& dim) { |
1282 | const auto& input_shape = input.shape(); |
1283 | return {torch::lazy::Shape( |
1284 | input_shape.scalar_type(), |
1285 | BuildUnsqueezedDimensions(input_shape.sizes(), dim))}; |
1286 | } |
1287 | |
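// The *_scatter shape helpers below all follow the same pattern: build
// "meta" tensors that carry only sizes/strides/dtype (no real storage), run
// the CompositeExplicitAutograd kernel on them, and read the output
// shape/dtype off the resulting meta tensor, so no real compute happens.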
1288 | std::vector<Shape> compute_shape_select_scatter( |
1289 | const at::Tensor& self, |
1290 | const at::Tensor& src, |
1291 | int64_t dim, |
1292 | int64_t index) { |
1293 | auto self_meta = at::native::empty_strided_meta_symint( |
1294 | self.sym_sizes(), |
1295 | self.sym_strides(), |
1296 | /*dtype=*/c10::make_optional(self.scalar_type()), |
1297 | /*layout=*/c10::make_optional(self.layout()), |
1298 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1299 | /*pin_memory=*/c10::nullopt); |
1300 | auto src_meta = at::native::empty_strided_meta_symint( |
1301 | src.sym_sizes(), |
1302 | src.sym_strides(), |
1303 | /*dtype=*/c10::make_optional(src.scalar_type()), |
1304 | /*layout=*/c10::make_optional(src.layout()), |
1305 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1306 | /*pin_memory=*/c10::nullopt); |
1307 | auto out_meta = at::compositeexplicitautograd::select_scatter( |
1308 | self_meta, src_meta, dim, index); |
1309 | return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; |
1310 | } |
1311 | |
1312 | std::vector<Shape> compute_shape_diagonal_scatter( |
1313 | const at::Tensor& self, |
1314 | const at::Tensor& src, |
1315 | int64_t offset, |
1316 | int64_t dim1, |
1317 | int64_t dim2) { |
1318 | auto self_meta = at::native::empty_strided_meta_symint( |
1319 | self.sym_sizes(), |
1320 | self.sym_strides(), |
1321 | /*dtype=*/c10::make_optional(self.scalar_type()), |
1322 | /*layout=*/c10::make_optional(self.layout()), |
1323 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1324 | /*pin_memory=*/c10::nullopt); |
1325 | auto src_meta = at::native::empty_strided_meta_symint( |
1326 | src.sym_sizes(), |
1327 | src.sym_strides(), |
1328 | /*dtype=*/c10::make_optional(src.scalar_type()), |
1329 | /*layout=*/c10::make_optional(src.layout()), |
1330 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1331 | /*pin_memory=*/c10::nullopt); |
1332 | auto out_meta = at::compositeexplicitautograd::diagonal_scatter( |
1333 | self_meta, src_meta, offset, dim1, dim2); |
1334 | return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; |
1335 | } |
1336 | |
1337 | std::vector<Shape> compute_shape_slice_scatter_symint( |
1338 | const at::Tensor& self, |
1339 | const at::Tensor& src, |
1340 | int64_t dim, |
1341 | c10::optional<c10::SymInt> start, |
1342 | c10::optional<c10::SymInt> end, |
1343 | c10::SymInt step) { |
1344 | auto self_meta = at::native::empty_strided_meta_symint( |
1345 | self.sym_sizes(), |
1346 | self.sym_strides(), |
1347 | /*dtype=*/c10::make_optional(self.scalar_type()), |
1348 | /*layout=*/c10::make_optional(self.layout()), |
1349 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1350 | /*pin_memory=*/c10::nullopt); |
1351 | auto src_meta = at::native::empty_strided_meta_symint( |
1352 | src.sym_sizes(), |
1353 | src.sym_strides(), |
1354 | /*dtype=*/c10::make_optional(src.scalar_type()), |
1355 | /*layout=*/c10::make_optional(src.layout()), |
1356 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1357 | /*pin_memory=*/c10::nullopt); |
1358 | auto out_meta = at::compositeexplicitautograd::slice_scatter_symint( |
1359 | self_meta, src_meta, dim, start, end, step); |
1360 | return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; |
1361 | } |
1362 | |
1363 | std::vector<Shape> compute_shape_as_strided_scatter_symint( |
1364 | const at::Tensor& self, |
1365 | const at::Tensor& src, |
1366 | at::SymIntArrayRef size, |
1367 | at::SymIntArrayRef stride, |
1368 | c10::optional<c10::SymInt> storage_offset) { |
1369 | auto self_meta = at::native::empty_strided_meta_symint( |
1370 | self.sym_sizes(), |
1371 | self.sym_strides(), |
1372 | /*dtype=*/c10::make_optional(self.scalar_type()), |
1373 | /*layout=*/c10::make_optional(self.layout()), |
1374 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1375 | /*pin_memory=*/c10::nullopt); |
1376 | auto src_meta = at::native::empty_strided_meta_symint( |
1377 | src.sym_sizes(), |
1378 | src.sym_strides(), |
1379 | /*dtype=*/c10::make_optional(src.scalar_type()), |
1380 | /*layout=*/c10::make_optional(src.layout()), |
1381 | /*device=*/c10::make_optional(c10::Device(c10::kMeta)), |
1382 | /*pin_memory=*/c10::nullopt); |
1383 | auto out_meta = at::compositeexplicitautograd::as_strided_scatter_symint( |
1384 | self_meta, src_meta, size, stride, storage_offset); |
1385 | return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; |
1386 | } |
1387 | |
1388 | // Restore unused-parameters warnings |
1389 | #pragma GCC diagnostic pop |
1390 | |
1391 | } // namespace lazy |
1392 | } // namespace torch |
1393 | |