1 | #pragma once |
#include <ATen/NamedTensor.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>

#include <ATen/core/DimVector.h>
#include <ATen/core/Tensor.h>

#include <algorithm>
#include <functional>
9 | |
10 | namespace at { |
11 | |
12 | using NameVector = SmallVector<Dimname, kDimVectorStaticSize>; |
13 | |
14 | inline bool has_names(ITensorListRef tensors) { |
15 | return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) { |
16 | return t.has_names(); |
17 | }); |
18 | } |
19 | |
// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of `tensor`.
22 | TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim); |
23 | TORCH_API std::vector<int64_t> dimnames_to_positions( |
24 | const Tensor& tensor, |
25 | DimnameList dims); |
26 | |
27 | // Unifies two DimnameList to produce a third. This is useful for implementing |
28 | // the named inference rule for binary broadcasting operations like add. |
29 | // |
30 | // There are three main constraints: |
31 | // 1) Check matching: Names must match positionally from the right. |
32 | // 2) Check misaligned: If a name `n` is in `names`, then it must appear at |
33 | // the same index from the right in other. |
34 | // 3) The output names are obtained by unifying the names individually from the |
35 | // right. |
36 | TORCH_API std::vector<Dimname> unify_from_right( |
37 | DimnameList names, |
38 | DimnameList other, |
39 | const char* action = "broadcast" ); |
40 | |
// Raises an error telling the user that `op_name` does not yet accept a
// Dimname (string) in place of an integer dimension index.
// [[noreturn]] is accurate: TORCH_CHECK(false, ...) always throws.
[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
  TORCH_CHECK(
      false,
      op_name,
      ": You passed a dimname (string) to this op in place of a dimension "
      "index but it does not yet support this behavior. Please pass a dimension "
      "index to work around this." );
}
49 | |
50 | // [NOTE] Writing name inference rules |
51 | // |
52 | // Operators that support named tensors are either composed of operations that |
53 | // support named tensors or implement some name inference rule. An op that |
54 | // implements its own name inference rule generally looks like the following: |
55 | // |
56 | // Tensor op(...) { |
57 | // perform_shape_checks(...); |
58 | // # (1) |
59 | // auto maybe_outnames = compute_outnames(...); |
60 | // auto result = [&]() { |
61 | // NoNamesGuard guard; |
62 | // return op_impl(...); |
63 | // }(); |
64 | // # (2) |
65 | // propagate_names_if_nonempty(result, maybe_outnames); |
66 | // |
67 | // Each op has (1) a compute outnames step and (2) a propagate names step. |
68 | // |
69 | // compute_outnames is responsible for checking that input names match and |
70 | // determining what the output names should be. It returns either: |
// - {} (if the input tensors are all unnamed)
72 | // - non-empty outnames. |
73 | // |
74 | // propagate_names_if_nonempty propagates the outnames if they exist to the |
75 | // result tensors. |
76 | // |
77 | // The {} case is an optimization; if the user does not use named tensors they |
78 | // pay no perf cost for it. |
79 | |
80 | namespace namedinference { |
81 | |
// Propagates `maybe_names` to `result` only when it holds a value and that
// value is non-empty; otherwise `result` is returned unchanged.
// NOTE(review): unlike the surrounding declarations this one is not marked
// TORCH_API — presumably it is only called from inside the ATen library;
// confirm before relying on it across the DLL boundary.
const Tensor& propagate_names_if_present_and_nonempty(
    const Tensor& result,
    c10::optional<DimnameList> maybe_names,
    bool validate_names = false);
// Propagates `names` to `result` if `names` is not empty.
// `names` can be empty; see [NOTE] Writing name inference rules
// If `names` is not empty, `names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
TORCH_API const Tensor& propagate_names_if_nonempty(
    const Tensor& result,
    DimnameList maybe_names,
    bool validate_names = false);
94 | |
95 | // Propagates `names` to `result`. Only use this if we are certain that there |
96 | // are names to propagate (that names is not empty). |
97 | TORCH_API const Tensor& propagate_names( |
98 | const Tensor& result, |
99 | DimnameList names, |
100 | bool validate_names = false); |
101 | |
102 | // Propagates all names from src to result. |
103 | TORCH_API void propagate_names(const Tensor& result, const Tensor& src); |
104 | |
105 | // Propagates all names except for those at the excluded_idxs. |
106 | TORCH_API void propagate_names_except( |
107 | const Tensor& result, |
108 | const Tensor& src, |
109 | IntArrayRef excluded_idxs); |
110 | |
111 | // Used for reduction ops that have a `keepdim` arg. |
112 | TORCH_API void propagate_names_for_reduction( |
113 | const Tensor& result, |
114 | const Tensor& src, |
115 | IntArrayRef excluded_idxs, |
116 | bool keepdim); |
117 | |
118 | TORCH_API void propagate_names_for_expand( |
119 | const Tensor& result, |
120 | const Tensor& self); |
121 | |
// Computes the output names for a concatenation over `tensors`.
TORCH_API std::vector<Dimname> compute_cat_outnames(
    const MaterializedITensorListRef& tensors);

// Computes the output names for a binary broadcasting op between `self`
// and `other` (see unify_from_right above for the matching rules).
TORCH_API std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other);

// Output names for broadcasting `tensor` to the shape of
// `reference_tensor`; `op_name` is used in error messages.
TORCH_API std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name);

// Output names for matmul(self, other).
TORCH_API std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other);

// Output names for cdist(self, other).
TORCH_API std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other);

// Output names for a batched matrix multiply of `self` and `other`;
// `result` is the tensor the names will be applied to.
TORCH_API std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other);

// Output names after squeezing `tensor` (all size-1 dims).
TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
// Overload taking an explicit dim set; the bitset presumably marks the
// dimensions being squeezed — TODO(review) confirm against the definition.
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
    const Tensor& tensor,
    std::bitset<dim_bitset_size> dims);

// Output names for diagonal(tensor, ..., dim1, dim2).
// NOTE(review): not marked TORCH_API — presumably ATen-internal only.
std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2);
156 | |
157 | // TensorImpl* overloads for Legacy TH/THC code. Use these sparingly. |
158 | |
// TensorImpl* variant of propagate_names_if_nonempty above.
TORCH_API TensorImpl* propagate_names_if_nonempty(
    TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names = false);

// TensorImpl* variant of propagate_names(const Tensor&, DimnameList, bool).
TORCH_API TensorImpl* propagate_names(
    TensorImpl* result,
    DimnameList names,
    bool validate_names = false);

// Copies all names from `src` to `result`. The commented-out qualifier
// suggests `src` is logically const but cannot be declared so here.
TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);
170 | |
171 | TORCH_API inline void propagate_names( |
172 | const TensorBase& result, |
173 | DimnameList names, |
174 | bool validate_names = false) { |
175 | propagate_names(result.unsafeGetTensorImpl(), names, validate_names); |
176 | } |
177 | |
178 | TORCH_API inline void propagate_names_if_nonempty( |
179 | const TensorBase& result, |
180 | DimnameList names, |
181 | bool validate_names = false) { |
182 | propagate_names_if_nonempty( |
183 | result.unsafeGetTensorImpl(), names, validate_names); |
184 | } |
185 | |
186 | TORCH_API inline void propagate_names( |
187 | const TensorBase& result, |
188 | const TensorBase& src) { |
189 | propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl()); |
190 | } |
191 | |
192 | // result = m1 @ m2 + bias |
// Output names for addmm: result = m1 @ m2 + bias.
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias);

// Output names for addmv: result = mat @ vec + bias.
TORCH_API std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias);

// Checks that the names of `vec1` and `vec2` are compatible for a dot
// product; errors otherwise.
TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);

// Output names for baddbmm: result = bias + bmm(self, other).
TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias);

// Returns true if `self` and `other` have equal dimension names.
TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);
212 | |
213 | } // namespace namedinference |
214 | |
215 | } // namespace at |
216 | |