#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtils.h>
#include <c10/util/irange.h>

#include <algorithm>
#include <utility>
4 | |
5 | namespace at { namespace namedinference { |
6 | |
7 | |
8 | Dimname TensorName::toDimname() const { |
9 | return name_; |
10 | } |
11 | |
12 | const TensorName& TensorName::unify(const TensorName& other, const char* op_name) const { |
13 | // unify(None, None) |
14 | if (name_.isWildcard() && other.name_.isWildcard()) { |
15 | return *this; |
16 | } |
17 | |
18 | // unify(A, A) |
19 | if (name_ == other.name_) { |
20 | return *this; |
21 | } |
22 | |
23 | // unify(A, None) |
24 | if (other.name_.isWildcard()) { |
25 | const auto it = std::find(other.origin_.begin(), other.origin_.end(), name_); |
26 | TORCH_CHECK(it == other.origin_.end(), |
27 | op_name, ":" , |
28 | " Cannot match " , *this, " with " , other, |
29 | " because the latter names already have " , name_, "." , |
30 | " Are your tensors misaligned?" ); |
31 | return *this; |
32 | } |
33 | |
34 | // unify(None, A) |
35 | if (name_.isWildcard()) { |
36 | return other.unify(*this, op_name); |
37 | } |
38 | |
39 | // unify(A, B) |
40 | TORCH_CHECK(name_ == other.name_, |
41 | op_name, ":" , |
42 | " Expected " , *this, |
43 | " to match " , other, |
44 | " but they do not match." ); |
45 | return *this; |
46 | } |
47 | |
48 | TensorNames::TensorNames(ArrayRef<Dimname> names) { |
49 | names_.reserve(names.size()); |
50 | for (const auto idx : c10::irange(names.size())) { |
51 | names_.emplace_back(names, idx); |
52 | } |
53 | } |
54 | |
55 | TensorNames::TensorNames(ArrayRef<Dimname> names, int64_t start, int64_t end) { |
56 | start = maybe_wrap_dim(start, names.size()); |
57 | end = maybe_wrap_dim(end, names.size()); |
58 | names_.reserve(end - start); |
59 | for (const auto idx : c10::irange(start, end)) { |
60 | names_.emplace_back(names, idx); |
61 | } |
62 | } |
63 | |
64 | TensorNames& TensorNames::unifyFromRightInplace(const TensorNames& other, const char* op_name) { |
65 | |
66 | if (names_.size() > other.names_.size()) { |
67 | const auto size_diff = names_.size() - other.names_.size(); |
68 | for (const auto idx : c10::irange(size_diff, names_.size())) { |
69 | names_[idx] = names_[idx].unify(other.names_[idx - size_diff], op_name); |
70 | } |
71 | } else { |
72 | const auto size_diff = other.names_.size() - names_.size(); |
73 | // pad names_ to the same length as other.names_ before unification |
74 | names_.insert( |
75 | names_.begin(), |
76 | other.names_.begin(), |
77 | other.names_.begin() + size_diff); |
78 | for (const auto idx : c10::irange(size_diff, names_.size())) { |
79 | names_[idx] = names_[idx].unify(other.names_[idx], op_name); |
80 | } |
81 | } |
82 | |
83 | return *this; |
84 | } |
85 | |
86 | void TensorNames::append(TensorName&& name) { |
87 | names_.emplace_back(name); |
88 | } |
89 | |
90 | void TensorNames::checkUnique(const char* op_name) const { |
91 | // O(N^2), but named tensors can have at most N = 64 dimensions, so this |
92 | // doesn't matter unless benchmarking tells us it does. The alternative is |
93 | // to create some sort of set data structure but the overhead of that |
94 | // might dominate for small sizes. |
95 | for (auto it = names_.begin(); it != names_.end(); ++it) { |
96 | const auto name = it->toDimname(); |
97 | if (name.isWildcard()) continue; |
98 | |
99 | auto dup = std::find_if(it + 1, names_.end(), |
100 | [&](const TensorName& other) { return other.toDimname() == name; }); |
101 | TORCH_CHECK(dup == names_.end(), |
102 | op_name, ": " , |
103 | "Attempted to propagate dims " , *it, " and " , *dup, " to the output, " , |
104 | "but that would create a tensor with duplicate names [" , toDimnameVec(), |
105 | "]. Please rename your inputs with Tensor.rename to prevent this." ); |
106 | } |
107 | } |
108 | |
109 | // Let's say the TensorName represents 'C' in ['N', 'C', 'H, 'W']. |
110 | // It should print like: |
111 | // 'C' (index 1 of ['N', 'C', 'H', 'W']) |
112 | std::ostream& operator<<(std::ostream& out, const TensorName& tensorname) { |
113 | out << tensorname.name_ << " (index " ; |
114 | out << tensorname.origin_idx_ << " of " ; |
115 | out << tensorname.origin_ << ")" ; |
116 | return out; |
117 | } |
118 | |
119 | std::vector<Dimname> TensorNames::toDimnameVec() const { |
120 | std::vector<Dimname> result; |
121 | result.reserve(names_.size()); |
122 | for (const auto& tensor_name : names_) { |
123 | result.emplace_back(tensor_name.toDimname()); |
124 | } |
125 | return result; |
126 | } |
127 | |
128 | |
129 | }} // namespace at::namedinference |
130 | |