#include <ATen/NamedTensorUtils.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <c10/util/irange.h>

#include <bitset>
#include <sstream>

namespace at {

// Returns "Tensor['N', 'C', 'H', 'W']" for a tensor with names ('N', 'C', 'H', 'W').
static std::string toDimnameRepr(const Tensor& tensor) {
  std::ostringstream os;
  os << "Tensor" << tensor.names();
  return os.str();
}

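// Converts a dimension name to its 0-indexed position in `tensor`, erroring
// if the name is absent. e.g. for a tensor named ('N', 'C', 'H', 'W'),
// looking up 'C' returns 1.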
int64_t dimname_to_position(const Tensor& tensor, Dimname dim) {
  TORCH_CHECK(dim.type() != NameType::WILDCARD,
      "Please look up dimensions by name, got: name = None.");
  TORCH_CHECK(tensor.has_names(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");
  const auto names = tensor.names();

  const auto it = std::find(names.begin(), names.end(), dim);
  TORCH_CHECK(it != names.end(),
      "Name ", dim, " not found in ", toDimnameRepr(tensor), ".");

  return std::distance(names.begin(), it);
}

std::vector<int64_t> dimnames_to_positions(const Tensor& tensor, DimnameList dims) {
  std::vector<int64_t> result;
  result.reserve(dims.size());
  for (const auto& name : dims) {
    result.push_back(dimname_to_position(tensor, name));
  }
  return result;
}

static void report_positional_error(
    const Dimname& name,
    const Dimname& other_name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(false,
      "Error when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " and dim ", other_name, " are at the same position "
      "from the right but do not match.");
}

static void check_for_misalignment(
    const Dimname& name,
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  if (name.isWildcard()) {
    return;
  }
  auto it = std::find(other_names.begin(), other_names.end(), name);
  // TODO(zou3519): Can improve message by checking if names are alignable and suggesting workarounds
  TORCH_CHECK(it == other_names.end(),
      "Misaligned dims when attempting to ", action, " dims ", names, " and dims ",
      other_names, ": dim ", name, " appears in a different position from the right "
      "across both lists.");
}

// Assumption: A DimnameList can have no duplicate full names with
// the exception of wildcards
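// Unifies the two name lists by aligning them from the right; None (wildcard)
// unifies with any name. Illustrative examples:
//   unify_from_right(['N', None, 'W'], ['N', 'H', 'W']) -> ['N', 'H', 'W']
//   unify_from_right(['N', 'C'], ['C', None]) errors: 'C' is misaligned.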
std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other_names,
    const char* action) {
  const auto wildcard = Dimname::wildcard();
  const auto size = std::max(names.size(), other_names.size());
  auto result = std::vector<Dimname>(size, wildcard);

  auto names_it = names.rbegin();
  auto other_it = other_names.rbegin();
  auto result_it = result.rbegin();
  while (names_it != names.rend() || other_it != other_names.rend()) {
    const auto& name = names_it == names.rend() ? wildcard : *names_it;
    const auto& other_name = other_it == other_names.rend() ? wildcard : *other_it;

    // Step 1: Check that the names match
    const auto maybeName = name.unify(other_name);
    if (!maybeName) {
      report_positional_error(name, other_name, names, other_names, action);
    }
    *result_it = *maybeName;

    // Step 2: Check that the names are not misaligned
    if (!name.isBasic() || !other_name.isBasic()) {
      // Let: N = max(len(names), len(other_names))
      //      K = # of special names among names and other_names.
      // This search (including the outer loop) is O(N*K) but typically # of dims is small.
      check_for_misalignment(name, names, other_names, action);
      check_for_misalignment(other_name, other_names, names, action);
    }

    if (names_it != names.rend()) {
      ++names_it;
    }
    if (other_it != other_names.rend()) {
      ++other_it;
    }
    ++result_it;
  }
  return result;
}

namespace namedinference {

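// Returns a bitset marking the dims to keep: the complement of `excluded_idxs`
// over `ndims` dims. e.g. excluded_idxs = {0, 2} with ndims = 4 keeps {1, 3}.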
static std::bitset<dim_bitset_size>
compute_included_idxs(IntArrayRef excluded_idxs, int64_t ndims) {
  auto result = dim_list_to_bitset(excluded_idxs, ndims);
  result.flip();
  return result;
}

static void assert_names_equal(DimnameList a, DimnameList b) {
  TORCH_CHECK(a == b,
      "Name mismatch: the specified out tensor has names ", a,
      ", which do not match the computed output names ", b,
      ". Please rename the out tensor's dims with `Tensor.rename`.");
}

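// The propagate_names_if_{present_and_,}nonempty variants below are no-ops
// when the given names are empty; empty names signal that no name inference
// was performed.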
const Tensor& propagate_names_if_present_and_nonempty(const Tensor& result,
    c10::optional<DimnameList> maybe_names,
    bool validate_names) {
  auto maybe_name_list = maybe_names.value_or(at::ArrayRef<Dimname>{});
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_name_list, validate_names);
  return result;
}

const Tensor& propagate_names_if_nonempty(const Tensor& result,
    DimnameList maybe_names,
    bool validate_names) {
  propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
  return result;
}

TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names) {
  if (maybe_names.empty()) {
    return result;
  }
  return propagate_names(result, maybe_names, validate_names);
}

const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) {
  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
  return result;
}

TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate_names) {
  if (result->dim() > 0) {
    TORCH_INTERNAL_ASSERT(
        !names.empty(),
        "propagate_names: passed in empty names to propagate to result with",
        " shape ", result->sizes(), ". Empty names means that name inference did",
        " not occur; use `propagate_names_if_nonempty` instead of `propagate_names`.");
  }
  if (!impl::has_names(result)) {
    impl::internal_set_names_inplace(result, names, validate_names);
  } else {
    assert_names_equal(impl::get_names(result), names);
  }
  return result;
}

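// Propagates src's names to result, skipping the dims listed in excluded_idxs.
// e.g. with src names ('N', 'C', 'H', 'W') and excluded_idxs = {1}, result is
// named ('N', 'H', 'W').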
void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
  if (!result.has_names() && !src.has_names()) {
    return;
  }
  const auto src_names = src.names();
  const auto result_dim = static_cast<int64_t>(result.dim());
  const auto src_dim = static_cast<int64_t>(src_names.size());
  const auto excluded_dim = static_cast<int64_t>(excluded_idxs.size());
  TORCH_INTERNAL_ASSERT(src_dim - excluded_dim == result_dim);

  // fast path
  if (excluded_idxs.size() == 1) {
    std::vector<Dimname> outnames = src_names.vec();
    outnames.erase(outnames.begin() + maybe_wrap_dim(excluded_idxs[0], src_dim));
    propagate_names(result, outnames);
    return;
  }

  std::vector<Dimname> outnames;
  outnames.reserve(result_dim);
  auto included_idxs = compute_included_idxs(excluded_idxs, src_dim);
  for (const auto dim : c10::irange(src_dim)) {
    if (included_idxs[dim]) {
      outnames.push_back(src_names[dim]);
    }
  }
  propagate_names(result, outnames);
}

void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
  if (keepdim) {
    propagate_names(result, src);
    return;
  }
  // An empty reduced_dims list actually means "full reduction"
  if (reduced_dims.empty()) {
    return;
  }
  propagate_names_except(result, src, reduced_dims);
}

void propagate_names(const Tensor& result, const Tensor& src) {
  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}

void propagate_names(TensorImpl* result, TensorImpl* src) {
  if (result == src) {
    return;
  }
  if (!impl::has_names(result) && !impl::has_names(src)) {
    return;
  }
  propagate_names(result, impl::get_names(src));
}

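// squeeze() drops all size-1 dims, so its output keeps only the names of dims
// whose size is not 1. e.g. names ('N', 'C', 'H') with sizes (2, 1, 3) squeeze
// to ('N', 'H').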
std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}

std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor, std::bitset<dim_bitset_size> dims) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (!dims.test(d) || tensor.sym_sizes()[d] != 1) {
      outnames.push_back(tensor_names[d]);
    }
  }
  return outnames;
}

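// diagonal(dim1, dim2) removes the two given dims and appends a new dim for
// the diagonal itself; that new dim has no name. e.g. for names
// ('N', 'C', 'H', 'W') with dim1 = 2 and dim2 = 3, the output is named
// ('N', 'C', None).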
std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2) {
  if (!tensor.has_names()) {
    return {};
  }
  std::vector<Dimname> outnames;
  auto tensor_names = tensor.names();
  for (const auto d : c10::irange(tensor.dim())) {
    if (d == dim1 || d == dim2) {
      continue;
    }
    outnames.push_back(tensor_names[d]);
  }
  outnames.push_back(Dimname::wildcard());
  return outnames;
}

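// Checks that the two feature (non-batch) dims of a matmul output do not share
// a name. e.g. matmul(Tensor['A', 'B'], Tensor['B', 'A']) would produce the
// duplicate names ('A', 'A'), which is an error.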
static void check_feature_names_are_distinct(
    DimnameList self_names,
    DimnameList other_names,
    const DimnameList& outnames) {
  if (self_names.size() < 2 || other_names.size() < 2) {
    // There are fewer than 2 feature dims in outnames, so there is nothing to check
    return;
  }
  auto feature0 = outnames[outnames.size() - 2];
  auto feature1 = outnames[outnames.size() - 1];
  TORCH_CHECK(
    feature0 == Dimname::wildcard() || feature0 != feature1,
    "Matrix multiplying Tensor", self_names,
    " with Tensor", other_names,
    " would produce output tensor with duplicate names ",
    outnames,
    ". Please rename the input tensors with `Tensor.rename` to prevent this.");
}

static int64_t num_batch_dims(DimnameList names) {
  if (names.size() <= 2) {
    return 0;
  }
  return names.size() - 2;
}

static std::vector<Dimname> compute_matmul_outnames(
    DimnameList self_names,
    DimnameList other_names) {
  TORCH_CHECK(!self_names.empty() && !other_names.empty(),
      "both arguments to matmul need to be at least 1D, but they are ",
      self_names.size(), "D and ", other_names.size(), "D");

  // matmul performs a batch matrix multiply between self and other, each of which
  // can be:
  // - a batch of matrices (if dim > 2)
  // - a matrix (if dim == 2)
  // - a vector (if dim == 1)
  //
  // To compute the output names, we first unify the batch dimensions (which are
  // broadcastable) to get the output batch dimensions.
  //
  // After that, we append the names produced by the matmul of the non-batch
  // dimensions: the names of the dimensions that were contracted away are
  // removed. We always contract the last dim of the first tensor with the
  // first feature dimension of the second.

  // Get the output's batch dimension names
  auto wrapped_self_names = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto wrapped_other_names = TensorNames(other_names, 0, num_batch_dims(other_names));
  auto& working_names = wrapped_self_names.unifyFromRightInplace(wrapped_other_names, "matmul");

  // Append the result of each individual (non-batched) matmul.
  // If either self or other has dim 1, it is a vector. Vectors get
  // completely contracted away during matmul, so we don't take any names from them.
  if (self_names.size() >= 2) {
    working_names.append(TensorName(self_names, -2));
  }
  if (other_names.size() >= 2) {
    working_names.append(TensorName(other_names, -1));
  }
  auto result = working_names.toDimnameVec();

  check_feature_names_are_distinct(self_names, other_names, result);
  return result;
}

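// addmv computes mat @ vec + bias; the output names are the matmul names
// unified with bias's names. e.g. mat ('N', 'C'), vec ('C',), bias ('N',)
// gives ('N',).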
std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias) {
  if (!mat.has_names() &&
      !vec.has_names() && !bias.has_names()) {
    return std::vector<Dimname>{};
  }
  auto mv_outnames = compute_matmul_outnames(mat.names(), vec.names());
  return unify_from_right(mv_outnames, bias.names());
}

std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias) {
  if (!m1.has_names() && !m2.has_names() &&
      !bias.has_names()) {
    return std::vector<Dimname>{};
  }

  auto mm_outnames = compute_matmul_outnames(m1.names(), m2.names());
  return unify_from_right(mm_outnames, bias.names());
}

void check_names_for_dot(
    TensorImpl* vec1,
    TensorImpl* vec2) {
  if (!impl::has_names(vec1) && !impl::has_names(vec2)) {
    return;
  }
  compute_matmul_outnames(impl::get_names(vec1), impl::get_names(vec2));
}

// expand adds new None dimensions. This is consistent with name inference
// rules for binary ops that expect the named dims to line up positionally
// from the right, i.e.,
// Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
void propagate_names_for_expand(const Tensor& result, const Tensor& self) {
  if (!self.has_names()) {
    return;
  }
  auto result_dim = result.dim();
  if (self.dim() == result_dim) {
    propagate_names(result, self);
    return;
  }
  std::vector<Dimname> outnames(result_dim, Dimname::wildcard());
  std::copy(
      self.opt_names()->begin(),
      self.opt_names()->end(),
      outnames.begin() + result_dim - self.dim());
  propagate_names(result, outnames);
}

std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return unify_from_right(self.names(), other.names());
}

std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name) {
  if (!tensor.has_names() && !reference_tensor.has_names()) {
    return {};
  }
  auto reference_names = reference_tensor.names();
  auto tensor_names = tensor.names();
  TORCH_CHECK(
      reference_names.size() >= tensor_names.size(),
      op_name, ": attempted to broadcast Tensor", tensor_names, " to Tensor",
      reference_names, " but the number of dims (", tensor_names.size(),
      ") must be less than or equal to the number of dims in the tensor (",
      reference_names.size(), ")");
  return unify_from_right(reference_names, tensor_names);
}

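// cat() unifies the names of all of its inputs. e.g. concatenating
// Tensor['N', None] with Tensor[None, 'C'] produces names ('N', 'C').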
std::vector<Dimname> compute_cat_outnames(const MaterializedITensorListRef& tensors) {
  if (!at::has_names(tensors)) {
    return {};
  }
  std::vector<Dimname> result;
  for (const Tensor& tensor : tensors) {
    const auto tensor_names = tensor.names();
    TORCH_CHECK(!tensor_names.empty(), "zero-dimensional tensor cannot be concatenated");
    TORCH_CHECK(result.empty() || tensor_names.size() == result.size(),
        "Tensors must have same number of dimensions: got ", result.size(),
        " and ", tensor_names.size());
    result = unify_from_right(result, tensor_names, "cat");
  }
  return result;
}

std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other) {
  if (!self.has_names() && !other.has_names()) {
    return {};
  }
  const auto self_names = self.names();
  const auto other_names = other.names();

  auto self_batch = TensorNames(self_names, 0, num_batch_dims(self_names));
  const auto other_batch = TensorNames(other_names, 0, num_batch_dims(other_names));

  auto& result = self_batch.unifyFromRightInplace(other_batch, "cdist");

  // cdist treats self and other like batches of M x D and N x D tensors, respectively.
  // It computes the pairwise distance between each of the M vectors (of size D)
  // in `self` and each of the N vectors in `other`, returning a batch of M x N
  // distance values. We propagate the names of the dimension of size M (in self)
  // and the dimension of size N (in other), both of which are second-from-last.
  result.append(TensorName(self_names, -2));
  result.append(TensorName(other_names, -2));
  result.checkUnique("cdist");

  return result.toDimnameVec();
}

std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other) {
  if (!result.has_names() && !self.has_names() && !other.has_names()) {
    return {};
  }
  return compute_matmul_outnames(self.names(), other.names());
}

std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias) {
  if (!result.has_names() && !self.has_names()
      && !other.has_names() && !bias.has_names()) {
    return {};
  }
  auto bmm_names = compute_matmul_outnames(self.names(), other.names());
  auto baddbmm_names = unify_from_right(bias.names(), bmm_names);
  return baddbmm_names;
}

bool are_names_equal(TensorImpl* self, TensorImpl* other) {
  if (!impl::has_names(self) && !impl::has_names(other)) {
    return true;
  }
  return impl::get_names(self) == impl::get_names(other);
}

} // namespace namedinference
} // namespace at