#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/util/irange.h>

#include <bitset>
#include <sstream>

namespace at {

// Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
static std::string toDimnameRepr(const Tensor& tensor) {
  std::ostringstream os;
  os << "Tensor" << tensor.names();
  return os.str();
}

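// For example (illustrative only): for a tensor with names ('N', 'C', 'H', 'W'),
// dimname_to_position returns 1 for the name 'C' and 3 for the name 'W' (positions
// are 0-indexed from the left); passing None (a wildcard) or a name that is not
// present is an error.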
int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
  TORCH_CHECK(dim.type() != NameType::WILDCARD,
      "Please look up dimensions by name, got: name = None.");
  TORCH_CHECK(tensor.has_names(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
  const auto names = tensor.names();

  const auto it = std::find(names.begin(), names.end(), dim);
  TORCH_CHECK(it != names.end(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");

  return std::distance(names.begin(), it);
}

std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims) {
  std::vector<int64_t> result;
  result.reserve(dims.size());
  for (const auto& name : dims) {
    result.push_back(dimname_to_position(tensor, name));
  }
  return result;
}

static void report_positional_error(
    const Dimname& name,
    const Dimname& other_name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(false,
      "Error when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " and dim ", other_name, " are at the same position "
      "from the right but do not match.");
}

static void check_for_misalignment(
    const Dimname& name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  if (name.isWildcard()) {
    return;
  }
  auto it = std::find(other_names.begin(), other_names.end(), name);
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(it == other_names.end(),
      "Misaligned dims when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " appears in a different position from the right "
      "across both lists.");
}

// Assumption: A DimnameList can have no duplicate full names with
// the exception of wildcards
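//
// For example (illustrative only; the names below are arbitrary):
//   unify_from_right(['N', 'C'],  ['C'],      action) -> ['N', 'C']
//   unify_from_right(['N', None], ['N', 'C'], action) -> ['N', 'C']
//   unify_from_right(['N', 'C'],  ['H', 'C'], action) -> error: 'N' and 'H' are at
//       the same position from the right but do not match.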
std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  const auto wildcard = Dimname::wildcard();
  const auto size = std::max(names.size(), other_names.size());
  auto result = std::vector<Dimname>(size, wildcard);

  auto names_it = names.rbegin();
  auto other_it = other_names.rbegin();
  auto result_it = result.rbegin();
  while (names_it != names.rend() || other_it != other_names.rend()) {
    const auto& name = names_it == names.rend() ? wildcard : *names_it;
    const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it;

    // Step 1: Check that the names match
    const auto maybeName = name.unify(other_name);
    if (!maybeName) {
      report_positional_error(name, other_name, names, other_names, action);
    }
    *result_it = *maybeName;

    // Step 2: Check that the names are not misaligned
    if (!name.isBasic() || !other_name.isBasic()) {
      // Let: N = max(len(names), len(other_names))
      //      K = # of special names among names and other_names.
      // This search (including the outer loop) is O(N*K) but typically # of dims is small.
      check_for_misalignment(name, names, other_names, action);
      check_for_misalignment(other_name, other_names, names, action);
    }

    if (names_it != names.rend()) {
      ++names_it;
    }
    if (other_it != other_names.rend()) {
      ++other_it;
    }
    ++result_it;
  }
  return result;
}

namespace namedinference {

static std::bitset<dim_bitset_size>
compute_included_idxs(IntArrayRef excluded_idxs, int64_t ndims) {
  auto result = dim_list_to_bitset(excluded_idxs, ndims);
  result.flip();
  return result;
}

static void assert_names_equal(DimnameList a, DimnameList b) {
  TORCH_CHECK(a == b,
      "Name mismatch: specified out tensor with names ", a,
      " are not the same as the computed output names ", b,
      ". Please rename the out tensor's dims with `Tensor.rename`.");
}

const Tensor& propagate_names_if_present_and_nonempty(const Tensor& result,
    c10::optional<DimnameList> maybe_names,
    bool validate_names) {
  auto maybe_name_list = maybe_names.value_or(at::ArrayRef<Dimname>{});
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_name_list, validate_names);
  return result;
}

const Tensor& propagate_names_if_nonempty(const Tensor& result,
    DimnameList maybe_names,
    bool validate_names) {
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
  return result;
}

TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names) {
  if (maybe_names.empty()) {
    return result;
  }
  return propagate_names(result, maybe_names, validate_names);
}

const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) {
  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
  return result;
}

TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) {
  if (result->dim() > 0) {
    TORCH_INTERNAL_ASSERT(
        !names.empty(),
        "propagate_names: passed in empty names to propagate to result with",
        " shape ", result->sizes(), ". Empty names means that name inference did",
        " not occur; use `propagate_names_if_nonempty` instead of `propagate_names`.");
  }
  if (!impl::has_names(result)) {
    impl::internal_set_names_inplace(result, names, validate_names);
  } else {
    assert_names_equal(impl::get_names(result), names);
  }
  return result;
}

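// For example (illustrative only): if src has names ['N', 'C', 'H', 'W'] and
// excluded_idxs is {1}, the result is given names ['N', 'H', 'W'].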
void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
  if (!result.has_names() && !src.has_names()) {
    return;
  }
  const auto src_names = src.names();
  const auto result_dim = static_cast<int64_t>(result.dim());
  const auto src_dim = static_cast<int64_t>(src_names.size());
  const auto excluded_dim = static_cast<int64_t>(excluded_idxs.size());
  TORCH_INTERNAL_ASSERT(src_dim - excluded_dim == result_dim);

  // fast path
  if (excluded_idxs.size() == 1) {
    std::vector<Dimname> outnames = src_names.vec();
    outnames.erase(outnames.begin() + maybe_wrap_dim(excluded_idxs[0], src_dim));
    propagate_names(result, outnames);
    return;
  }

  std::vector<Dimname> outnames;
  outnames.reserve(result_dim);
  auto included_idxs = compute_included_idxs(excluded_idxs, src_dim);
  for (const auto dim : c10::irange(src_dim)) {
    if (included_idxs[dim]) {
      outnames.push_back(src_names[dim]);
    }
  }
  propagate_names(result, outnames);
}

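// For example (illustrative only): reducing a tensor named ['N', 'C', 'H', 'W'] over
// dims {2, 3} propagates ['N', 'C'] when keepdim=false and the full ['N', 'C', 'H', 'W']
// when keepdim=true; no names are propagated for a full reduction (reduced_dims empty).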
void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
  if (keepdim) {
    propagate_names(result, src);
    return;
  }
  // This actually means "full reduction"
  if (reduced_dims.empty()) {
    return;
  }
  propagate_names_except(result, src, reduced_dims);
}

void propagate_names(const Tensor& result, const Tensor& src) {
  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}

void propagate_names(TensorImpl* result, TensorImpl* src) {
  if (result == src) {
    return;
  }
  if (!impl::has_names(result) && !impl::has_names(src)) {
    return;
  }
  propagate_names(result, impl::get_names(src));
}

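// For example (illustrative only): squeezing a tensor with names ['N', 'C', 'H', 'W']
// and sizes [2, 1, 3, 1] drops the size-1 dims, so the output names are ['N', 'H'].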
std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}

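// For example (illustrative only): taking the diagonal of a tensor named
// ['N', 'C', 'H', 'W'] with dim1=1 and dim2=2 keeps ['N', 'W'] and appends an
// unnamed dim for the diagonal itself, giving ['N', 'W', None].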
std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (d == dim1 || d == dim2) {
      continue;
    }
    outnames.push_back(tensor_names[d]);
  }
  outnames.push_back(Dimname::wildcard());
  return outnames;
}

static void check_feature_names_are_distinct(
    DimnameList self_names,
    DimnameList other_names,
    const DimnameList& outnames) {
  if (self_names.size() < 2 || other_names.size() < 2) {
    // There are fewer than 2 feature dims in outnames so there is nothing to check
    return;
  }
  auto feature0 = outnames[outnames.size() - 2];
  auto feature1 = outnames[outnames.size() - 1];
  TORCH_CHECK(
      feature0 == Dimname::wildcard() || feature0 != feature1,
      "Matrix multiplying Tensor", self_names,
      " with Tensor", other_names,
      " would produce output tensor with duplicate names ",
      outnames,
      ". Please rename the input tensors with `Tensor.rename` to prevent this.");
}

static int64_t num_batch_dims(DimnameList names) {
  if (names.size() <= 2) {
    return 0;
  }
  return names.size() - 2;
}

static std::vector<Dimname> compute_matmul_outnames(
    DimnameList self_names,
    DimnameList other_names) {
  TORCH_CHECK(!self_names.empty() && !other_names.empty(),
      "both arguments to matmul need to be at least 1D, but they are ",
      self_names.size(), "D and ", other_names.size(), "D");

  // matmul performs a batch matrix multiply between self and other, each of which
  // can either be:
  // - a batch of matrices (if dim > 2)
  // - a matrix (if dim == 2)
  // - a vector (if dim == 1)
  //
  // To compute output names, we unify the batch dimensions because those are
  // broadcastable to get the output batch dimensions.
  //
  // After that, we append some names that are equal to the result of the matmul
  // without batch dimensions. Those names are computed by removing the names
  // of the dimensions that were contracted away. We always contract the
  // last dim of the first tensor with the first feature dimension of the second.
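  //
  // For example (illustrative only; the names are arbitrary):
  //   self_names  = ['B', 'N', 'M']  (a batch of N x M matrices)
  //   other_names = ['B', 'M', 'P']  (a batch of M x P matrices)
  //   The batch dims unify to ['B']; we then append 'N' (dim -2 of self) and
  //   'P' (dim -1 of other), so the output names are ['B', 'N', 'P'].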

  // Get the output's batch dimension names
  auto wrapped_self_names = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto wrapped_other_names = TensorNames(other_names, 0, num_batch_dims(other_names));
  auto& working_names = wrapped_self_names.unifyFromRightInplace(wrapped_other_names, "matmul");

  // Append the result of each individual (non-batched) matmul.
  // If either of self or other has dim 1, that means it is a vector. Vectors get
  // completely contracted away during matmul so we don't take any names from them.
  if (self_names.size() >= 2) {
    working_names.append(TensorName(self_names, -2));
  }
  if (other_names.size() >= 2) {
    working_names.append(TensorName(other_names, -1));
  }
  auto result = working_names.toDimnameVec();

  check_feature_names_are_distinct(self_names, other_names, result);
  return result;
}

std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias) {
  if (!mat.has_names() &&
      !vec.has_names() && !bias.has_names()) {
    return std::vector<Dimname>{};
  }
  auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names());
  return unify_from_right(mv_outnames, bias.names());
}

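// For example (illustrative only): addmm with m1 named ['N', 'M'], m2 named
// ['M', 'P'], and a bias named ['N', 'P'] (or unnamed) produces output names
// ['N', 'P']; the bias names are unified from the right with the matmul names.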
std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias) {
  if (!m1.has_names() && !m2.has_names() &&
      !bias.has_names()) {
    return std::vector<Dimname>{};
  }

  auto mm_outnames = compute_matmul_outnames(m1.names(), m2.names());
  return unify_from_right(mm_outnames, bias.names());
}

void check_names_for_dot(
    TensorImpl* vec1,
    TensorImpl* vec2) {
  if (!impl::has_names(vec1) && !impl::has_names(vec2)) {
    return;
  }
  compute_matmul_outnames(impl::get_names(vec1), impl::get_names(vec2));
}

// expand adds new None dimensions. This is consistent with name inference
// rules for binary ops that expect the named dims to line up positionally
// from the right. i.e.,
// Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
void propagate_names_for_expand(const Tensor& result, const Tensor& self) {
  if (!self.has_names()) {
    return;
  }
  auto result_dim = result.dim();
  if (self.dim() == result_dim) {
    propagate_names(result, self);
    return;
  }
  std::vector<Dimname> outnames(result_dim, Dimname::wildcard());
  std::copy(
      self.opt_names()->begin(),
      self.opt_names()->end(),
      outnames.begin() + result_dim - self.dim());
  propagate_names(result, outnames);
}

std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return unify_from_right(self.names(), other.names());
}

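// For example (illustrative only): broadcasting a tensor named ['C', 'H', 'W']
// to a reference tensor named ['N', None, 'H', 'W'] gives output names
// ['N', 'C', 'H', 'W']; it is an error for the tensor to have more dims than
// the reference tensor.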
std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name) {
  if (!tensor.has_names() && !reference_tensor.has_names()) {
    return {};
  }
  auto reference_names = reference_tensor.names();
  auto tensor_names = tensor.names();
  TORCH_CHECK(
      reference_names.size() >= tensor_names.size(),
      op_name, ": attempted to broadcast Tensor", tensor_names, " to Tensor",
      reference_names, " but the number of dims (", tensor_names.size(),
      ") must be less than or equal to the number of dims in the reference tensor (",
      reference_names.size(), ")");
  return unify_from_right(reference_names, tensor_names);
}

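// For example (illustrative only): concatenating tensors named ['N', 'C'] and
// [None, 'C'] gives output names ['N', 'C']; concatenating tensors whose names
// conflict at the same position (e.g. ['N', 'C'] and ['N', 'H']) is an error.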
std::vector<Dimname> compute_cat_outnames(const MaterializedITensorListRef& tensors) {
  if (!at::has_names(tensors)) {
    return {};
  }
  std::vector<Dimname> result;
  for (const Tensor& tensor : tensors) {
    const auto tensor_names = tensor.names();
    TORCH_CHECK(!tensor_names.empty(), "zero-dimensional tensor cannot be concatenated");
    TORCH_CHECK(result.empty() || tensor_names.size() == result.size(),
        "Tensors must have same number of dimensions: got ", result.size(),
        " and ", tensor_names.size());
    result = unify_from_right(result, tensor_names, "cat");
  }
  return result;
}

std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  const auto self_names = self.names();
  const auto other_names = other.names();

  auto self_batch = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto other_batch = TensorNames(other_names, 0, num_batch_dims(other_names));

  auto& result = self_batch.unifyFromRightInplace(other_batch, "cdist");

  // cdist treats self and other like batches of M x D and N x D tensors, respectively.
  // It computes the pairwise distance between each of the M vectors (of size D)
  // in `self` and each of the N vectors in `other`, returning a batch of M x N
  // distance values. We propagate the names of the dimension of size M (in self)
  // and the dimension of size N (in other), both of which are second-from-last.
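  //
  // For example (illustrative only): self named ['B', 'M', 'D'] and other named
  // ['B', 'N', 'D'] produce output names ['B', 'M', 'N'] (the batch dims unify,
  // then the second-from-last names of self and other are appended).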
  result.append(TensorName(self_names, -2));
  result.append(TensorName(other_names, -2));
  result.checkUnique("cdist");

  return result.toDimnameVec();
}

std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other) {
  if (!result.has_names() && !self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias) {
  if (!result.has_names() && !self.has_names()
      && !other.has_names() && !bias.has_names()) {
    return {};
  }
  auto bmm_names = compute_matmul_outnames(self.names(), other.names());
  auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
  return baddbmm_names;
}

bool are_names_equal(TensorImpl* self, TensorImpl* other) {
  if (!impl::has_names(self) && !impl::has_names(other)) {
    return true;
  }
  return impl::get_names(self) == impl::get_names(other);
}

} // namespace namedinference
} // namespace at