1 | #include <c10/core/DispatchKeySet.h> |
2 | #include <c10/util/irange.h> |
#include <ostream>
#include <sstream>
4 | |
5 | namespace c10 { |
6 | |
7 | // backend_dispatch_keyset includes all dispatch keys that map to backends. |
8 | // Alias key DispatchKey::CompositeExplicitAutograd maps to |
9 | // backend_dispatch_keyset |
constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
    // Include all of the backend bits, so that per-backend runtime keys such
    // as DispatchKey::CPU (Dense functionality + CPUBit) are members of the
    // set (and so the backend-bit removals below have something to remove).
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask) |
    DispatchKeySet(DispatchKey::Dense);
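
// A quick illustration (not an exhaustive spec) of what membership means
// here: per-backend runtime keys built from the Dense functionality bit are
// in the set, while autograd keys are not.
//
//   backend_dispatch_keyset.has(DispatchKey::CPU);          // true (Dense + CPUBit)
//   backend_dispatch_keyset.has(DispatchKey::AutogradCPU);  // false (no autograd bit)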
12 | |
13 | // See Note [CompositeExplicitAutogradNonFunctional Key] |
14 | // We have several types of decompositions in aten, that each have their own |
15 | // alias key. You should register your decomposition to the |
// `CompositeExplicitAutogradNonFunctional` key if:
// (1) It's an out-of-place op
// (2) It decomposes into one or more mutation ops
18 | // (3) It has a derivative formula |
19 | // (In theory we could also have a separate key for |
20 | // "CompositeImplicitAutogradNonFunctional", but there isn't much of a use |
21 | // case for it currently). |
22 | // This key is important for "functional" backends like LazyTensor / XLA. |
23 | // If you're a backend that only expects to deal with "functional ops", |
24 | // then you don't want to decompose a functional op into an op that causes |
25 | // aliasing. You should just directly write a kernel for that functional op |
26 | // instead! |
27 | constexpr DispatchKeySet non_functional_backend_dispatch_keyset = |
28 | backend_dispatch_keyset |
29 | // XLA and LazyTensor are currently the only 2 backends in core |
        // that use the functionalization pass in eager mode.
31 | .remove(DispatchKey::Sparse) |
32 | .remove_backend(BackendComponent::XLABit) |
33 | .remove_backend(BackendComponent::LazyBit); |
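
// In practice (an illustrative consequence of the above): a decomposition
// registered to CompositeExplicitAutogradNonFunctional will run for, e.g.,
// DispatchKey::CPU, but not for XLA or Lazy; those backends are expected to
// register their own kernel for the functional op instead.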
34 | |
35 | bool isBackendDispatchKey(DispatchKey t) { |
36 | return t != DispatchKey::Undefined |
37 | // See Note [No Alias Keys in DispatchKeySet] |
38 | && !isAliasDispatchKey(t) |
39 | // Note [NestedTensor Not Included in Backend Keys] |
40 | // NestedTensor has been explicitly removed from the "backend keyset" due |
41 | // to incompatibility with some kernels, so we don't want it to be |
42 | // included in CompositeExplicitAutograd kernels. |
43 | && t != DispatchKey::NestedTensor && backend_dispatch_keyset.has(t); |
44 | } |
45 | |
// math_dispatch_keyset contains all keys in backend_dispatch_keyset and
// autograd_dispatch_keyset. Alias key DispatchKey::CompositeImplicitAutograd
// maps to [math_dispatch_keyset x full_backend_mask].
49 | constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | |
50 | autograd_dispatch_keyset | |
51 | // See Note [NestedTensor Not Included in Backend Keys] |
    // The caveat to that note is that nested_tensor is a special case
    // where we would like to support composite implicit kernels but not
    // explicit kernels; therefore, we manually add the key to the
    // math_dispatch_keyset.
56 | DispatchKeySet{DispatchKey::NestedTensor}; |
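
// e.g. (illustrative): both of these hold, so a CompositeImplicitAutograd
// kernel can be dispatched to for autograd keys and for NestedTensor:
//
//   math_dispatch_keyset.has(DispatchKey::AutogradCPU);   // true
//   math_dispatch_keyset.has(DispatchKey::NestedTensor);  // true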
57 | |
58 | constexpr DispatchKeySet nested_dispatch_keyset = |
59 | DispatchKeySet( |
60 | {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) | |
61 | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); |
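
// e.g. (illustrative): nested_dispatch_keyset.has(DispatchKey::NestedTensorCPU)
// is true, since the NestedTensor functionality bits are paired with every
// backend bit here.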
62 | |
63 | DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { |
64 | TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); |
65 | switch (t) { |
66 | case DispatchKey::Autograd: |
67 | // See Note [autograd_dispatch_keyset Does Not Include Backend Bits] |
68 | // That's why we OR it with a mask of the backend bits here. |
69 | // getRuntimeDispatchKeySet() expects to return a keyset of runtime |
70 | // dispatch keys, like AutogradCPU, but that requires having backend bits. |
71 | return autograd_dispatch_keyset | |
72 | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); |
73 | case DispatchKey::CompositeImplicitAutograd: |
74 | return math_dispatch_keyset; |
75 | case DispatchKey::CompositeImplicitAutogradNestedTensor: |
76 | return nested_dispatch_keyset; |
77 | case DispatchKey::CompositeExplicitAutograd: |
78 | return backend_dispatch_keyset; |
79 | case DispatchKey::CompositeExplicitAutogradNonFunctional: |
80 | return non_functional_backend_dispatch_keyset; |
81 | default: |
82 | return DispatchKeySet(t); |
83 | } |
84 | } |
85 | |
86 | bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k) { |
87 | TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); |
88 | switch (t) { |
89 | case DispatchKey::Autograd: |
90 | return autograd_dispatch_keyset.has(toFunctionalityKey(k)); |
91 | case DispatchKey::CompositeImplicitAutograd: |
92 | // See Note [NestedTensor Not Included in Backend Keys] |
93 | return math_dispatch_keyset.has(k); |
94 | case DispatchKey::CompositeImplicitAutogradNestedTensor: |
95 | // See Note [NestedTensor Not Included in Backend Keys] |
96 | return nested_dispatch_keyset.has(k); |
97 | case DispatchKey::CompositeExplicitAutograd: |
98 | // See Note [NestedTensor Not Included in Backend Keys] |
99 | return k != DispatchKey::NestedTensor && backend_dispatch_keyset.has(k); |
100 | case DispatchKey::CompositeExplicitAutogradNonFunctional: |
101 | // See Note [NestedTensor Not Included in Backend Keys] |
102 | return k != DispatchKey::NestedTensor && |
103 | non_functional_backend_dispatch_keyset.has(k); |
104 | case DispatchKey::FuncTorchBatchedDecomposition: |
105 | return functorch_batched_ks.has(k); |
106 | default: |
107 | return t == k; |
108 | } |
109 | } |
110 | |
// For a given autograd key, returns the (guaranteed nonempty) set of
// associated backend keys. For a non-autograd key, returns the empty keyset.
113 | DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) { |
114 | switch (t) { |
115 | case DispatchKey::AutogradCPU: |
116 | return DispatchKeySet(DispatchKey::CPU); |
117 | case DispatchKey::AutogradCUDA: |
118 | return DispatchKeySet(DispatchKey::CUDA); |
119 | case DispatchKey::AutogradXLA: |
120 | return DispatchKeySet(DispatchKey::XLA); |
121 | case DispatchKey::AutogradLazy: |
122 | return DispatchKeySet(DispatchKey::Lazy); |
123 | case DispatchKey::AutogradMeta: |
124 | return DispatchKeySet(DispatchKey::Meta); |
125 | case DispatchKey::AutogradMPS: |
126 | return DispatchKeySet(DispatchKey::MPS); |
127 | case DispatchKey::AutogradHPU: |
128 | return DispatchKeySet(DispatchKey::HPU); |
129 | case DispatchKey::AutogradIPU: |
130 | return DispatchKeySet(DispatchKey::IPU); |
131 | case DispatchKey::AutogradXPU: |
132 | return DispatchKeySet(DispatchKey::XPU); |
133 | case DispatchKey::AutogradPrivateUse1: |
134 | return DispatchKeySet(DispatchKey::PrivateUse1); |
135 | case DispatchKey::AutogradPrivateUse2: |
136 | return DispatchKeySet(DispatchKey::PrivateUse2); |
137 | case DispatchKey::AutogradPrivateUse3: |
138 | return DispatchKeySet(DispatchKey::PrivateUse3); |
139 | case DispatchKey::AutogradNestedTensor: |
140 | return DispatchKeySet(DispatchKey::NestedTensor) | |
141 | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); |
142 | case DispatchKey::AutogradOther: |
143 | return autogradother_backends; |
144 | default: |
145 | return DispatchKeySet(); |
146 | } |
147 | } |
148 | |
149 | bool isIncludedInAlias(DispatchKey k, DispatchKey alias) { |
150 | return k != DispatchKey::Undefined && runtimeDispatchKeySetHas(alias, k); |
151 | } |
152 | |
153 | std::string toString(DispatchKeySet ts) { |
154 | std::stringstream ss; |
155 | ss << ts; |
156 | return ss.str(); |
157 | } |
158 | |
159 | std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) { |
160 | if (ts.empty()) { |
161 | os << "DispatchKeySet()" ; |
162 | return os; |
163 | } |
164 | os << "DispatchKeySet(" ; |
165 | bool first = true; |
166 | for (auto k : ts) { |
167 | if (!first) { |
168 | os << ", " ; |
169 | } |
170 | os << k; |
171 | first = false; |
172 | } |
173 | os << ")" ; |
174 | return os; |
175 | } |
176 | |
177 | DispatchKeySet::iterator& DispatchKeySet::iterator::operator++() { |
178 | TORCH_INTERNAL_ASSERT(next_functionality_ <= iterator::end_iter_mask_val); |
179 | TORCH_INTERNAL_ASSERT(next_backend_ <= num_backends, next_backend_); |
180 | |
181 | // Create a masked version of the set representation to ignore previous |
182 | // keys that we've iterated through. |
183 | uint64_t masked_functionality_bits = |
184 | llvm::maskTrailingZeros<uint64_t>(next_functionality_) & *data_ptr_; |
185 | uint64_t masked_backend_bits = |
186 | llvm::maskTrailingZeros<uint64_t>(next_backend_) & full_backend_mask & |
187 | *data_ptr_; |
188 | |
189 | uint64_t first_functionality_idx = |
190 | llvm::findFirstSet(masked_functionality_bits); |
191 | uint64_t first_backendcomponent_idx = llvm::findFirstSet(masked_backend_bits); |
192 | |
193 | // If there are no keys, set to end iterator value |
194 | if (first_functionality_idx == std::numeric_limits<uint64_t>::max() || |
195 | next_functionality_ == iterator::end_iter_mask_val) { |
196 | // Set up state to be the same as end() |
197 | next_functionality_ = iterator::end_iter_mask_val; |
198 | current_dispatchkey_idx_ = iterator::end_iter_key_val; |
199 | next_backend_ = 0; |
200 | current_backendcomponent_idx_ = iterator::end_iter_key_val; |
201 | return *this; |
202 | } |
203 | |
204 | // The +1 is because of DispatchKey::Undefined and |
205 | // BackendComponent::InvalidBit |
206 | auto new_next_functionality = first_functionality_idx + 1; |
207 | auto new_backendcomponent_idx = first_backendcomponent_idx + 1; |
208 | // and the -num_backends is because the first <num_backends> bits in the |
209 | // keyset are not Dispatch Keys. |
210 | auto next_dispatchkey_idx = new_next_functionality - num_backends; |
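
  // e.g. (illustrative): if the lowest unvisited functionality bit is Dense,
  // which lives at raw bit index num_backends, then new_next_functionality ==
  // num_backends + 1 and next_dispatchkey_idx == 1, matching
  // static_cast<uint64_t>(DispatchKey::Dense).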
211 | |
212 | // If the current functionality bit is a per-backend bit, we need special |
213 | // handling |
214 | if (isPerBackendFunctionalityKey( |
215 | static_cast<DispatchKey>(next_dispatchkey_idx))) { |
216 | // case 1: if the current backend is undefined, then there is no valid |
217 | // backend instance of this functionality key so we can skip it. |
218 | if (first_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) { |
219 | // increment the functionality mask so we skip the current functionality |
220 | // bit on the next increment. |
221 | next_functionality_ = new_next_functionality; |
222 | ++(*this); |
223 | return *this; |
224 | } |
225 | |
226 | // Otherwise, at this point we know what the current backend and |
227 | // functionality bits are. |
228 | current_dispatchkey_idx_ = next_dispatchkey_idx; |
229 | current_backendcomponent_idx_ = new_backendcomponent_idx; |
230 | |
231 | // Next, we need to set up the masks for the next increment. |
232 | uint64_t next_backendcomponent_bits = |
233 | llvm::maskTrailingZeros<uint64_t>(first_backendcomponent_idx + 1) & |
234 | full_backend_mask & *data_ptr_; |
235 | uint64_t next_backendcomponent_idx = |
236 | llvm::findFirstSet(next_backendcomponent_bits); |
237 | if (next_backendcomponent_idx == std::numeric_limits<uint64_t>::max()) { |
238 | // case 2: the current backend is valid, but there is not another backend |
239 | // in the keyset. In this case, we need to bump the functionality mask and |
240 | // reset the backend mask for the next increment |
241 | next_functionality_ = new_next_functionality; |
242 | next_backend_ = 0; |
243 | } else { |
244 | // case 3: we have another backend to iterate over. We want to iterate |
245 | // over the same functionality bit next time, but a different backend bit. |
246 | next_backend_ = first_backendcomponent_idx + 1; |
247 | } |
248 | } else { |
249 | // Functionality bits that aren't per backend are simpler to handle. We can |
250 | // ignore the backend bits. |
251 | TORCH_INTERNAL_ASSERT(next_backend_ == 0); |
252 | current_dispatchkey_idx_ = next_dispatchkey_idx; |
253 | next_functionality_ = new_next_functionality; |
254 | } |
255 | return *this; |
256 | } |
257 | |
258 | std::array<FunctionalityOffsetAndMask, num_functionality_keys> |
259 | initializeFunctionalityOffsetsAndMasks() { |
260 | std::array<FunctionalityOffsetAndMask, num_functionality_keys> |
261 | offsets_and_masks; |
  // manually set the first entry, which corresponds to Undefined.
263 | offsets_and_masks[0] = FunctionalityOffsetAndMask(0, 0); |
264 | // loop through every functionality key (aside from Undefined). |
265 | for (const auto functionality_idx : c10::irange(1, num_functionality_keys)) { |
266 | // functionality_idx should be Dense -> 1, ... |
267 | auto prev_offset_and_mask = offsets_and_masks[functionality_idx - 1]; |
268 | auto k = static_cast<DispatchKey>(functionality_idx); |
269 | |
270 | // If the previous functionality was not per-backend, then we can just |
271 | // increment the previous offset. Otherwise, the next offset = |
272 | // previous_offset + num_backends. |
273 | auto next_offset = prev_offset_and_mask.offset + |
274 | (prev_offset_and_mask.mask == 0 ? 1 : num_backends); |
275 | // the mask is used in the runtime index calculation to find the offset of |
276 | // the backend. For non-per-backend functionalities, this offset should |
277 | // always be 0. Otherwise, we need to get the index of the backend (which we |
278 | // can do using a backend mask). |
279 | auto next_mask = isPerBackendFunctionalityKey(k) ? full_backend_mask : 0; |
280 | offsets_and_masks[functionality_idx] = |
281 | FunctionalityOffsetAndMask(next_offset, next_mask); |
282 | } |
283 | // Sanity check that the computed offset index of the last functionality key |
284 | // is correct. This assumes that the highest priority functionality key is not |
285 | // per backend. |
286 | TORCH_INTERNAL_ASSERT( |
287 | offsets_and_masks[num_functionality_keys - 1].offset == |
288 | (num_runtime_entries - 1), |
289 | "num_runtime_entries: " , |
290 | num_runtime_entries, |
291 | "last_offset: " , |
292 | offsets_and_masks[num_functionality_keys - 1].offset); |
293 | return offsets_and_masks; |
294 | } |
295 | |
296 | } // namespace c10 |
297 | |