1 | #pragma once |
2 | |
3 | #include <ATen/DimVector.h> |
4 | #include <ATen/core/Dimname.h> |
5 | #include <c10/core/TensorOptions.h> |
6 | #include <c10/util/strides.h> |
7 | |
8 | C10_CLANG_DIAGNOSTIC_PUSH() |
9 | #if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor") |
10 | C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor" ) |
11 | #endif |
12 | |
13 | namespace at { |
14 | |
15 | class Tensor; |
16 | |
17 | namespace impl { |
18 | |
19 | // Use this to define the prototype for a meta function. There are two |
20 | // versions; one that takes one argument (just the operator name), or FUNC2 |
21 | // variant that takes two arguments (operator name and overload name). |
22 | // |
23 | // Example usage: |
24 | // |
25 | // TORCH_META_FUNC2(add, Tensor) ( |
26 | // const Tensor& self, const Tensor& other |
27 | // ) { |
28 | // ... compute sizes and options ... |
29 | // set_output(sizes, options); |
30 | // } |
31 | // |
// Expands to the definition header `void structured_<name>::meta`; the
// parameter list and body follow at the use site (see example above).
#define TORCH_META_FUNC(name) void structured_##name::meta
// Same, for an overloaded operator: `void structured_<name>_<overload>::meta`.
#define TORCH_META_FUNC2(name, overload) \
  void structured_##name##_##overload::meta
35 | |
36 | // These are versions of TORCH_META_FUNC(2) that include a precompute_out struct |
37 | // as a return value. They should be used when the kernel in question has |
38 | // precomputed values declared in native_functions.yaml and the corresponding |
39 | // implementation should return an instance of the aforementioned struct. |
// Expands to `structured_<name>::meta_return_ty structured_<name>::meta`,
// i.e. a meta function that returns the codegen'd precompute struct.
#define TORCH_PRECOMPUTE_META_FUNC(name) \
  structured_##name::meta_return_ty structured_##name::meta
// Same, for an overloaded operator.
#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \
  structured_##name##_##overload::meta_return_ty \
  structured_##name##_##overload::meta
45 | |
46 | // Use this to create a precompute struct in a meta function. |
// Names the empty specialization `structured_<name>::precompute_out<>`;
// the meta function builds it up via the codegen'd set_* chain and returns it.
#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<>
// Same, for an overloaded operator.
#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \
  structured_##name##_##overload::precompute_out<>
50 | |
51 | // Use this to define the prototype for an implementation. This takes only |
52 | // one argument, which is the name of the dispatch key entry you're |
53 | // implementing. |
54 | // |
55 | // Example usage: |
56 | // |
57 | // TORCH_IMPL_FUNC(add_cpu) ( |
58 | // Tensor& result, const Tensor& self, const Tensor& other |
59 | // ) { |
60 | // ... do the actual implementation ... |
61 | // } |
62 | // |
63 | #define TORCH_IMPL_FUNC(name) void structured_##name::impl |
64 | |
// Base class for all structured kernel classes. The set_output virtual
// method varies depending on whether the operator is
// functional/out/inplace, and could also be specialized for CPU/CUDA/etc
// (although presently it isn't).
//
// A notable subclass of this interface is TensorIteratorBase.
71 | struct TORCH_API MetaBase { |
72 | virtual const Tensor& maybe_get_output(int64_t output_idx) = 0; |
73 | |
74 | // Note: [set_output_*] |
75 | // See: https://github.com/pytorch/pytorch/issues/69813 |
76 | // Whenever defining the output properties in the META function of a |
77 | // structured kernel (what was usually done with `set_output`), use one of |
78 | // these 3 variants, instead. In order to decide which variant to use, check |
79 | // the following decision tree: |
80 | // |
81 | // - Can the kernel you are going to implement support output tensors |
82 | // with arbitrary strides? |
83 | // | |
84 | // -- YES: `set_output_raw_strided` |
85 | // | |
86 | // -- NO: Should the output tensor strides be contiguous? |
87 | // | |
88 | // -- YES: `set_output_contiguous` |
89 | // | |
90 | // -- NO: `set_output_strided` |
91 | // |
92 | // Use this function whenever the kernel requires specific strides for the |
93 | // output. If `strides` does not match the given output strides, proxy outputs |
94 | // will be created and passed to the IMPL function. |
95 | virtual void set_output_strided( |
96 | int64_t output_idx, |
97 | IntArrayRef sizes, |
98 | IntArrayRef strides, |
99 | TensorOptions options, |
100 | DimnameList names = {}) { |
101 | TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented." ); |
102 | } |
103 | |
104 | // Use this function whenever the kernel knows how to handle arbitrary strided |
105 | // outputs. This function has the same behavior as the old `set_output`: it |
106 | // will only re-stride if the given output was resized. |
107 | virtual void set_output_raw_strided( |
108 | int64_t output_idx, |
109 | IntArrayRef sizes, |
110 | IntArrayRef strides_hint, |
111 | TensorOptions options, |
112 | DimnameList names = {}) { |
113 | TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented." ); |
114 | } |
115 | |
116 | // Use this function if the kernel requires contiguous strides. |
117 | // Alias for `set_output_strided`, but with contiguous strides. |
118 | void set_output_contiguous( |
119 | int64_t output_idx, |
120 | IntArrayRef sizes, |
121 | TensorOptions options, |
122 | DimnameList names = {}) { |
123 | auto strides = c10::contiguous_strides(sizes); |
124 | set_output_strided(output_idx, sizes, strides, options, names); |
125 | } |
126 | |
127 | // Returns a reference to an undefined tensor if there is no presupplied |
128 | // output |
129 | const Tensor& maybe_get_output() { |
130 | return maybe_get_output(0); |
131 | } |
132 | virtual ~MetaBase() = default; |
133 | }; |
134 | |
135 | } // namespace impl |
136 | |
137 | } // namespace at |
138 | |
139 | C10_CLANG_DIAGNOSTIC_POP() |
140 | |