#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
4
5namespace at { namespace namedinference {
6
7
// Returns the dimension name this TensorName wraps (its origin metadata is
// only used for error reporting, not for the returned value).
Dimname TensorName::toDimname() const {
  return name_;
}
11
12const TensorName& TensorName::unify(const TensorName& other, const char* op_name) const {
13 // unify(None, None)
14 if (name_.isWildcard() && other.name_.isWildcard()) {
15 return *this;
16 }
17
18 // unify(A, A)
19 if (name_ == other.name_) {
20 return *this;
21 }
22
23 // unify(A, None)
24 if (other.name_.isWildcard()) {
25 const auto it = std::find(other.origin_.begin(), other.origin_.end(), name_);
26 TORCH_CHECK(it == other.origin_.end(),
27 op_name, ":",
28 " Cannot match ", *this, " with ", other,
29 " because the latter names already have ", name_, ".",
30 " Are your tensors misaligned?");
31 return *this;
32 }
33
34 // unify(None, A)
35 if (name_.isWildcard()) {
36 return other.unify(*this, op_name);
37 }
38
39 // unify(A, B)
40 TORCH_CHECK(name_ == other.name_,
41 op_name, ":",
42 " Expected ", *this,
43 " to match ", other,
44 " but they do not match.");
45 return *this;
46}
47
48TensorNames::TensorNames(ArrayRef<Dimname> names) {
49 names_.reserve(names.size());
50 for (const auto idx : c10::irange(names.size())) {
51 names_.emplace_back(names, idx);
52 }
53}
54
55TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) {
56 start = maybe_wrap_dim(start, names.size());
57 end = maybe_wrap_dim(end, names.size());
58 names_.reserve(end - start);
59 for (const auto idx : c10::irange(start, end)) {
60 names_.emplace_back(names, idx);
61 }
62}
63
// Unifies `other` into *this in place, aligning the two name lists from the
// right (trailing dimensions are matched first, as in broadcasting). The
// result is stored in *this, which grows to max(len(*this), len(other)).
// `op_name` is forwarded to TensorName::unify for error messages.
TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) {

  if (names_.size() > other.names_.size()) {
    // *this is longer: its leading `size_diff` names have no counterpart in
    // `other` and are kept as-is; the rest unify with the right-aligned
    // names of `other` (hence the `idx - size_diff` offset).
    const auto size_diff = names_.size() - other.names_.size();
    for (const auto idx : c10::irange(size_diff, names_.size())) {
      names_[idx] = names_[idx].unify(other.names_[idx - size_diff], op_name);
    }
  } else {
    const auto size_diff = other.names_.size() - names_.size();
    // pad names_ to the same length as other.names_ before unification
    names_.insert(
        names_.begin(),
        other.names_.begin(),
        other.names_.begin() + size_diff);
    // The borrowed leading names are copies of `other`'s, so only the
    // positions that existed in *this before padding need unification.
    for (const auto idx : c10::irange(size_diff, names_.size())) {
      names_[idx] = names_[idx].unify(other.names_[idx], op_name);
    }
  }

  return *this;
}
85
86void TensorNames::append(TensorName&& name) {
87 names_.emplace_back(name);
88}
89
90void TensorNames::checkUnique(const char* op_name) const {
91 // O(N^2), but named tensors can have at most N = 64 dimensions, so this
92 // doesn't matter unless benchmarking tells us it does. The alternative is
93 // to create some sort of set data structure but the overhead of that
94 // might dominate for small sizes.
95 for (auto it = names_.begin(); it != names_.end(); ++it) {
96 const auto name = it->toDimname();
97 if (name.isWildcard()) continue;
98
99 auto dup = std::find_if(it + 1, names_.end(),
100 [&](const TensorName& other) { return other.toDimname() == name; });
101 TORCH_CHECK(dup == names_.end(),
102 op_name, ": ",
103 "Attempted to propagate dims ", *it, " and ", *dup, " to the output, ",
104 "but that would create a tensor with duplicate names [", toDimnameVec(),
105 "]. Please rename your inputs with Tensor.rename to prevent this.");
106 }
107}
108
109// Let's say the TensorName represents 'C' in ['N', 'C', 'H, 'W'].
110// It should print like:
111// 'C' (index 1 of ['N', 'C', 'H', 'W'])
112std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) {
113 out << tensorname.name_ << " (index ";
114 out << tensorname.origin_idx_ << " of ";
115 out << tensorname.origin_ << ")";
116 return out;
117}
118
119std::vector<Dimname> TensorNames::toDimnameVec() const {
120 std::vector<Dimname> result;
121 result.reserve(names_.size());
122 for (const auto& tensor_name : names_) {
123 result.emplace_back(tensor_name.toDimname());
124 }
125 return result;
126}
127
128
129}} // namespace at::namedinference
130