#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
#include <iostream>

namespace c10 {

// backend_dispatch_keyset includes all dispatch keys that map to backends.
// Alias key DispatchKey::CompositeExplicitAutograd maps to
// backend_dispatch_keyset
constexpr DispatchKeySet backend_dispatch_keyset =
    autogradother_backends | DispatchKeySet(DispatchKey::Dense);

// See Note [CompositeExplicitAutogradNonFunctional Key]
// We have several types of decompositions in aten, each with its own
// alias key. You should register your decomposition to the
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) It's an out-of-place op
// (2) It decomposes into one or more mutation ops
// (3) It has a derivative formula
// (In theory we could also have a separate key for
// "CompositeImplicitAutogradNonFunctional", but there isn't much of a use
// case for it currently).
// This key is important for "functional" backends like LazyTensor / XLA.
// If you're a backend that only expects to deal with "functional ops",
// then you don't want to decompose a functional op into an op that causes
// aliasing. You should just directly write a kernel for that functional op
// instead!
constexpr DispatchKeySet non_functional_backend_dispatch_keyset =
    backend_dispatch_keyset
        // XLA and LazyTensor are currently the only 2 backends in core
        // that use the functionalization pass in eager mode.
        .remove(DispatchKey::Sparse)
        .remove_backend(BackendComponent::XLABit)
        .remove_backend(BackendComponent::LazyBit);
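
// As an illustration only (not something defined in this file), a
// decomposition meeting the criteria above would typically be registered to
// this alias key from the operator-registration layer, roughly like:
//
//   TORCH_LIBRARY_IMPL(aten, CompositeExplicitAutogradNonFunctional, m) {
//     // "my_op" and my_functional_decomposition are hypothetical
//     // placeholders; the point is only which alias key the registration
//     // targets.
//     m.impl("my_op", my_functional_decomposition);
//   }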

bool isBackendDispatchKey(DispatchKey t) {
  return t != DispatchKey::Undefined
      // See Note [No Alias Keys in DispatchKeySet]
      && !isAliasDispatchKey(t)
      // Note [NestedTensor Not Included in Backend Keys]
      // NestedTensor has been explicitly removed from the "backend keyset" due
      // to incompatibility with some kernels, so we don't want it to be
      // included in CompositeExplicitAutograd kernels.
      && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t);
}

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask]
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
    autograd_dispatch_keyset |
    // See Note [NestedTensor Not Included in Backend Keys]
    // The caveat to that note is that nested_tensor is a special case
    // where we would like to support composite implicit kernels but not
    // explicit kernels, so we manually add the key to the
    // math_dispatch_keyset.
    DispatchKeySet{DispatchKey::NestedTensor};

constexpr DispatchKeySet nested_dispatch_keyset =
    DispatchKeySet(
        {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      // See Note [autograd_dispatch_keyset Does Not Include Backend Bits]
      // That's why we OR it with a mask of the backend bits here.
      // getRuntimeDispatchKeySet() expects to return a keyset of runtime
      // dispatch keys, like AutogradCPU, but that requires having backend
      // bits.
      return autograd_dispatch_keyset |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::CompositeImplicitAutograd:
      return math_dispatch_keyset;
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      return nested_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutograd:
      return backend_dispatch_keyset;
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      return non_functional_backend_dispatch_keyset;
    default:
      return DispatchKeySet(t);
  }
}
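
// A rough usage sketch (illustrative only, not part of this file):
//
//   // Alias keys expand to the runtime keys they cover, e.g. the Autograd
//   // alias is expected to include AutogradCPU:
//   bool has_cpu_autograd = getRuntimeDispatchKeySet(DispatchKey::Autograd)
//                               .has(DispatchKey::AutogradCPU);
//   // Non-alias keys fall through to the default case and map to a
//   // singleton keyset:
//   DispatchKeySet single = getRuntimeDispatchKeySet(DispatchKey::CPU);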

bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) {
  TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
  switch (t) {
    case DispatchKey::Autograd:
      return autograd_dispatch_keyset.has(toFunctionalityKey(k));
    case DispatchKey::CompositeImplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return math_dispatch_keyset.has(k);
    case DispatchKey::CompositeImplicitAutogradNestedTensor:
      // See Note [NestedTensor Not Included in Backend Keys]
      return nested_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutograd:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k);
    case DispatchKey::CompositeExplicitAutogradNonFunctional:
      // See Note [NestedTensor Not Included in Backend Keys]
      return k != DispatchKey::NestedTensor &&
          non_functional_backend_dispatch_keyset.has(k);
    case DispatchKey::FuncTorchBatchedDecomposition:
      return functorch_batched_ks.has(k);
    default:
      return t == k;
  }
}

// for a given autograd key, return the (guaranteed nonempty) set of associated
// backend keys. for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
  switch (t) {
    case DispatchKey::AutogradCPU:
      return DispatchKeySet(DispatchKey::CPU);
    case DispatchKey::AutogradCUDA:
      return DispatchKeySet(DispatchKey::CUDA);
    case DispatchKey::AutogradXLA:
      return DispatchKeySet(DispatchKey::XLA);
    case DispatchKey::AutogradLazy:
      return DispatchKeySet(DispatchKey::Lazy);
    case DispatchKey::AutogradMeta:
      return DispatchKeySet(DispatchKey::Meta);
    case DispatchKey::AutogradMPS:
      return DispatchKeySet(DispatchKey::MPS);
    case DispatchKey::AutogradHPU:
      return DispatchKeySet(DispatchKey::HPU);
    case DispatchKey::AutogradIPU:
      return DispatchKeySet(DispatchKey::IPU);
    case DispatchKey::AutogradXPU:
      return DispatchKeySet(DispatchKey::XPU);
    case DispatchKey::AutogradPrivateUse1:
      return DispatchKeySet(DispatchKey::PrivateUse1);
    case DispatchKey::AutogradPrivateUse2:
      return DispatchKeySet(DispatchKey::PrivateUse2);
    case DispatchKey::AutogradPrivateUse3:
      return DispatchKeySet(DispatchKey::PrivateUse3);
    case DispatchKey::AutogradNestedTensor:
      return DispatchKeySet(DispatchKey::NestedTensor) |
          DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);
    case DispatchKey::AutogradOther:
      return autogradother_backends;
    default:
      return DispatchKeySet();
  }
}
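
// For example, getBackendKeySetFromAutograd(DispatchKey::AutogradCUDA) returns
// DispatchKeySet(DispatchKey::CUDA), while a non-autograd key such as
// DispatchKey::CPU falls through to the default case and yields an empty set.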

bool isIncludedInAlias(DispatchKey k, DispatchKey alias) {
  return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k);
}
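
// For example, per the NestedTensor handling above, NestedTensor participates
// in the implicit-autograd alias but is deliberately excluded from the
// explicit one (see Note [NestedTensor Not Included in Backend Keys]):
//
//   isIncludedInAlias(
//       DispatchKey::NestedTensor,
//       DispatchKey::CompositeImplicitAutograd); // true
//   isIncludedInAlias(
//       DispatchKey::NestedTensor,
//       DispatchKey::CompositeExplicitAutograd); // false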

std::string toString(DispatchKeySet ts) {
  std::stringstream ss;
  ss << ts;
  return ss.str();
}

std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) {
  if (ts.empty()) {
    os << "DispatchKeySet()";
    return os;
  }
  os << "DispatchKeySet(";
  bool first = true;
  for (auto k : ts) {
    if (!first) {
      os << ", ";
    }
    os << k;
    first = false;
  }
  os << ")";
  return os;
}

DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() {
  TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val);
  TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_);

  // Create a masked version of the set representation to ignore previous
  // keys that we've iterated through.
  uint64_t masked_functionality_bits =
      llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_;
  uint64_t masked_backend_bits =
      llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask &
      *data_ptr_;

  uint64_t first_functionality_idx =
      llvm::findFirstSet(masked_functionality_bits);
  uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits);

  // If there are no keys, set to end iterator value
  if (first_functionality_idx == std::numeric_limits<uint64_t>::max() ||
      next_functionality_ == iterator::end_iter_mask_val) {
    // Set up state to be the same as end()
    next_functionality_ = iterator::end_iter_mask_val;
    current_dispatchkey_idx_ = iterator::end_iter_key_val;
    next_backend_ = 0;
    current_backendcomponent_idx_ = iterator::end_iter_key_val;
    return *this;
  }

  // The +1 is because of DispatchKey::Undefined and
  // BackendComponent::InvalidBit
  auto new_next_functionality = first_functionality_idx + 1;
  auto new_backendcomponent_idx = first_backendcomponent_idx + 1;
  // and the -num_backends is because the first <num_backends> bits in the
  // keyset are not Dispatch Keys.
  auto next_dispatchkey_idx = new_next_functionality - num_backends;

  // If the current functionality bit is a per-backend bit, we need special
  // handling
  if (isPerBackendFunctionalityKey(
          static_cast<DispatchKey>(next_dispatchkey_idx))) {
    // case 1: if the current backend is undefined, then there is no valid
    // backend instance of this functionality key so we can skip it.
    if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // increment the functionality mask so we skip the current functionality
      // bit on the next increment.
      next_functionality_ = new_next_functionality;
      ++(*this);
      return *this;
    }

    // Otherwise, at this point we know what the current backend and
    // functionality bits are.
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    current_backendcomponent_idx_ = new_backendcomponent_idx;

    // Next, we need to set up the masks for the next increment.
    uint64_t next_backendcomponent_bits =
        llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) &
        full_backend_mask & *data_ptr_;
    uint64_t next_backendcomponent_idx =
        llvm::findFirstSet(next_backendcomponent_bits);
    if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) {
      // case 2: the current backend is valid, but there is not another backend
      // in the keyset. In this case, we need to bump the functionality mask
      // and reset the backend mask for the next increment.
      next_functionality_ = new_next_functionality;
      next_backend_ = 0;
    } else {
      // case 3: we have another backend to iterate over. We want to iterate
      // over the same functionality bit next time, but a different backend
      // bit.
      next_backend_ = first_backendcomponent_idx + 1;
    }
  } else {
    // Functionality bits that aren't per backend are simpler to handle. We can
    // ignore the backend bits.
    TORCH_INTERNAL_ASSERT(next_backend_ == 0);
    current_dispatchkey_idx_ = next_dispatchkey_idx;
    next_functionality_ = new_next_functionality;
  }
  return *this;
}
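
// A rough illustration of the iteration order produced above (a sketch, not
// part of this file). Backend bits are shared across functionality bits, so a
// set built from {CPU, CUDA, AutogradCPU} carries the Dense and Autograd
// functionality bits plus the CPU and CUDA backend bits, and iterating it is
// expected to visit CPU, CUDA, AutogradCPU, AutogradCUDA in that order:
//
//   auto ks = DispatchKeySet(
//       {DispatchKey::CPU, DispatchKey::CUDA, DispatchKey::AutogradCPU});
//   for (auto k : ks) {
//     std::cout << k << "\n";
//   }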

std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks() {
  std::array<FunctionalityOffsetAndMask, num_functionality_keys>
      offsets_and_masks;
  // manually set the first entry, which corresponds to Undefined.
  offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0);
  // loop through every functionality key (aside from Undefined).
  for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) {
    // functionality_idx should be Dense -> 1, ...
    auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1];
    auto k = static_cast<DispatchKey>(functionality_idx);

    // If the previous functionality was not per-backend, then we can just
    // increment the previous offset. Otherwise, the next offset =
    // previous_offset + num_backends.
    auto next_offset = prev_offset_and_mask.offset +
        (prev_offset_and_mask.mask == 0 ? 1 : num_backends);
    // the mask is used in the runtime index calculation to find the offset of
    // the backend. For non-per-backend functionalities, this offset should
    // always be 0. Otherwise, we need to get the index of the backend (which
    // we can do using a backend mask).
    auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0;
    offsets_and_masks[functionality_idx] =
        FunctionalityOffsetAndMask(next_offset, next_mask);
  }
  // Sanity check that the computed offset index of the last functionality key
  // is correct. This assumes that the highest priority functionality key is
  // not per backend.
  TORCH_INTERNAL_ASSERT(
      offsets_and_masks[num_functionality_keys - 1].offset ==
          (num_runtime_entries - 1),
      "num_runtime_entries: ",
      num_runtime_entries,
      "last_offset: ",
      offsets_and_masks[num_functionality_keys - 1].offset);
  return offsets_and_masks;
}
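
// Worked example of the accumulation above: Undefined occupies offset 0 with
// mask 0. Each subsequent functionality key starts at the previous offset plus
// one entry if the previous key was not per-backend, or plus num_backends
// entries if it was, and its mask is full_backend_mask iff it is itself
// per-backend. So Dense (per-backend) gets offset 1, and the functionality
// that follows it starts num_backends entries later in the runtime table.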

} // namespace c10