#pragma once
#include <ATen/NamedTensor.h>
#include <ATen/TensorNames.h>
#include <ATen/WrapDimUtilsMulti.h>

#include <ATen/core/DimVector.h>
#include <ATen/core/Tensor.h>
#include <functional>

namespace at {

using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;

inline bool has_names(ITensorListRef tensors) {
  return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) {
    return t.has_names();
  });
}
// Converts dim to a positional index. Errors if `dim` cannot be used to
// refer to any dimension of tensor.
TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim);
TORCH_API std::vector<int64_t> dimnames_to_positions(
    const Tensor& tensor,
    DimnameList dims);
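
// For example (an illustrative sketch; `x` is a hypothetical tensor with
// names ["N", "C"]):
//
//   int64_t pos =
//       dimname_to_position(x, Dimname::fromSymbol(Symbol::dimname("C")));
//   // pos == 1. Passing a name that does not refer to any dimension of `x`
//   // (e.g. "H") raises an error.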

// Unifies two DimnameLists to produce a third. This is useful for implementing
// the named inference rule for binary broadcasting operations like add.
//
// There are three main constraints:
// 1) Check matching: Names must match positionally from the right.
// 2) Check misaligned: If a name `n` is in `names`, then it must appear at
//    the same index from the right in `other`.
// 3) The output names are obtained by unifying the names individually from the
//    right.
TORCH_API std::vector<Dimname> unify_from_right(
    DimnameList names,
    DimnameList other,
    const char* action = "broadcast");
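
// For example (an illustrative sketch; None denotes the wildcard name):
//
//   unify_from_right(["N", None], [None, "C"]) -> ["N", "C"]
//   unify_from_right(["C"], ["N", "C"])        -> ["N", "C"] (aligned from the right)
//   unify_from_right(["N", "C"], ["N", "H"])   -> error ("C" and "H" don't match)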

[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) {
  TORCH_CHECK(
      false,
      op_name,
      ": You passed a dimname (string) to this op in place of a dimension "
      "index but it does not yet support this behavior. Please pass a dimension "
      "index to work around this.");
}

// [NOTE] Writing name inference rules
//
// Operators that support named tensors are either composed of operations that
// support named tensors or implement some name inference rule. An op that
// implements its own name inference rule generally looks like the following:
//
// Tensor op(...) {
//   perform_shape_checks(...);
//   # (1)
//   auto maybe_outnames = compute_outnames(...);
//   auto result = [&]() {
//     NoNamesGuard guard;
//     return op_impl(...);
//   }();
//   # (2)
//   propagate_names_if_nonempty(result, maybe_outnames);
//   return result;
// }
//
// Each op has (1) a compute outnames step and (2) a propagate names step.
//
// compute_outnames is responsible for checking that the input names match and
// determining what the output names should be. It returns either:
// - {} (if the input tensors are all unnamed)
// - non-empty outnames.
//
// propagate_names_if_nonempty propagates the outnames, if they exist, to the
// result tensors.
//
// The {} case is an optimization; if the user does not use named tensors they
// pay no perf cost for it.
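
// A concrete sketch of the pattern (a hypothetical `my_add`; real ops follow
// the same shape elsewhere in ATen):
//
//   Tensor my_add(const Tensor& self, const Tensor& other) {
//     // (1) Check that names match and compute the unified output names.
//     auto maybe_outnames =
//         namedinference::compute_broadcast_outnames(self, other);
//     auto result = [&]() {
//       NoNamesGuard guard;  // the impl below ignores names entirely
//       return self.add(other);
//     }();
//     // (2) No-op when maybe_outnames is empty (all inputs were unnamed).
//     namedinference::propagate_names_if_nonempty(result, maybe_outnames);
//     return result;
//   }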

namespace namedinference {

const Tensor& propagate_names_if_present_and_nonempty(
    const Tensor& result,
    c10::optional<DimnameList> maybe_names,
    bool validate_names = false);
// Propagates `names` to `result` if `names` is not empty.
// `names` can be empty; see [NOTE] Writing name inference rules
// If `names` is not empty, `names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
TORCH_API const Tensor& propagate_names_if_nonempty(
    const Tensor& result,
    DimnameList maybe_names,
    bool validate_names = false);

// Propagates `names` to `result`. Only use this if we are certain that there
// are names to propagate (i.e., `names` is not empty).
TORCH_API const Tensor& propagate_names(
    const Tensor& result,
    DimnameList names,
    bool validate_names = false);

// Propagates all names from src to result.
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);

// Propagates all names except for those at the excluded_idxs.
TORCH_API void propagate_names_except(
    const Tensor& result,
    const Tensor& src,
    IntArrayRef excluded_idxs);
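
// For example (illustrative): if `src` has names ["N", "C", "H", "W"] and
// excluded_idxs is [1], `result` receives the names ["N", "H", "W"].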

// Used for reduction ops that have a `keepdim` arg.
TORCH_API void propagate_names_for_reduction(
    const Tensor& result,
    const Tensor& src,
    IntArrayRef excluded_idxs,
    bool keepdim);
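
// For example (illustrative): reducing over dim 1 of `src` named ["N", "C", "H"]:
// - keepdim=false: `result` receives ["N", "H"] (the reduced name is dropped).
// - keepdim=true:  `result` keeps rank 3, so all of ["N", "C", "H"] propagate.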

TORCH_API void propagate_names_for_expand(
    const Tensor& result,
    const Tensor& self);

TORCH_API std::vector<Dimname> compute_cat_outnames(
    const MaterializedITensorListRef& tensors);

TORCH_API std::vector<Dimname> compute_broadcast_outnames(
    const Tensor& self,
    const Tensor& other);

TORCH_API std::vector<Dimname> broadcast_to_outnames(
    const Tensor& tensor,
    const Tensor& reference_tensor,
    const char* op_name);

TORCH_API std::vector<Dimname> compute_matmul_outnames(
    const Tensor& self,
    const Tensor& other);

TORCH_API std::vector<Dimname> compute_cdist_outnames(
    const Tensor& self,
    const Tensor& other);

TORCH_API std::vector<Dimname> compute_bmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other);

TORCH_API std::vector<Dimname> compute_squeeze_outnames(const Tensor& tensor);
TORCH_API std::vector<Dimname> compute_squeeze_outnames(
    const Tensor& tensor,
    std::bitset<dim_bitset_size> dims);

std::vector<Dimname> compute_diagonal_outnames(
    const Tensor& tensor,
    int64_t dim1,
    int64_t dim2);

// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly.

TORCH_API TensorImpl* propagate_names_if_nonempty(
    TensorImpl* result,
    DimnameList maybe_names,
    bool validate_names = false);

TORCH_API TensorImpl* propagate_names(
    TensorImpl* result,
    DimnameList names,
    bool validate_names = false);

TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src);

TORCH_API inline void propagate_names(
    const TensorBase& result,
    DimnameList names,
    bool validate_names = false) {
  propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
}

TORCH_API inline void propagate_names_if_nonempty(
    const TensorBase& result,
    DimnameList names,
    bool validate_names = false) {
  propagate_names_if_nonempty(
      result.unsafeGetTensorImpl(), names, validate_names);
}

TORCH_API inline void propagate_names(
    const TensorBase& result,
    const TensorBase& src) {
  propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}

// result = m1 @ m2 + bias
TORCH_API std::vector<Dimname> propagate_names_for_addmm(
    const Tensor& m1,
    const Tensor& m2,
    const Tensor& bias);
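
// For example (illustrative): with m1 named ["N", "K"] and m2 named ["K", "M"],
// the matmul part contributes the outnames ["N", "M"], which are then unified
// from the right with bias's names.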

TORCH_API std::vector<Dimname> propagate_names_for_addmv(
    const Tensor& mat,
    const Tensor& vec,
    const Tensor& bias);

TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2);

TORCH_API std::vector<Dimname> compute_baddbmm_outnames(
    const Tensor& result,
    const Tensor& self,
    const Tensor& other,
    const Tensor& bias);

TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other);

} // namespace namedinference

} // namespace at