1#pragma once
2
#include <ATen/WrapDimUtils.h>

#include <utility>
#include <vector>
4
5namespace at {
6namespace namedinference {
7
8// TensorName and TensorNames are wrappers around Dimname and DimnameList
9// that contain helper functions to make writing name inference rules easier.
10//
11// A TensorName represents a Dimname associated with some DimnameList (from a
12// Tensor). This encapsulates all the information that is needed to check if
13// names *match* and to *unify* names.
14//
15// Definition: Two names in two tensors *match* if they are equal, or if at
16// least one of them is a wildcard that can be *refined* to the other name.
17//
18// Definition: unify(name, other) fails if the names do not match. Otherwise,
19// it returns the most refined of name and other.
20//
21// Here is an example of checking if two names match.
22// tensor: Tensor[A, None]
23// other: Tensor[A]
24//
25// Let's say we wish to check if tensor.names[-1] matches other.names[-1].
26// None (in tensor) cannot match A (in other) because if the None were refined
27// to A, `tensor` would have duplicate names [A, A]. Therefore we need to check
28// tensor.names [A, None] for the existence of A.
29struct TORCH_API TensorName {
30 explicit TensorName(ArrayRef<Dimname> origin, int origin_idx)
31 : origin_(origin),
32 name_(origin[maybe_wrap_dim(origin_idx, origin.size())]),
33 origin_idx_(origin_idx) {}
34
35 // op_name is only used for error reporting.
36 const TensorName& unify(const TensorName& other, const char* op_name) const;
37 Dimname toDimname() const;
38
39 private:
40 ArrayRef<Dimname> origin_;
41 Dimname name_;
42 int origin_idx_; // A named tensor can have at most 64 dims.
43
44 TORCH_API friend std::ostream& operator<<(
45 std::ostream& out,
46 const TensorName& tensorname);
47};
48
49using TensorNameVec = SmallVector<TensorName, 10>;
50
51struct TORCH_API TensorNames {
52 explicit TensorNames(ArrayRef<Dimname> names);
53
54 // Create TensorNames from names[start:end]. Each individual TensorName stores
55 // `names`, NOT names[start:end], because the original tensor's names are
56 // `names`.
57 explicit TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end);
58
59 // op_name is only used for error reporting.
60 TensorNames& unifyFromRightInplace(
61 const TensorNames& other,
62 const char* op_name = "unify");
63 void checkUnique(const char* op_name) const;
64
65 void append(TensorName&& name);
66 std::vector<Dimname> toDimnameVec() const;
67
68 private:
69 explicit TensorNames(TensorNameVec&& names) : names_(names){};
70
71 TensorNameVec names_;
72};
73
74} // namespace namedinference
75} // namespace at
76