1#include <gtest/gtest.h>
2
3#include <cstddef>
4#include <iterator>
5#include <unordered_set>
6
7#include <c10/core/DispatchKeySet.h>
8#include <c10/util/irange.h>
9
10using namespace c10;
11
12// This test exists not to be comprehensive, but to more clearly show
13// what the semantics of DispatchKeySet are.
14TEST(DispatchKeySet, ShowSemantics) {
15 // the "CPU" dispatch key is an instance of a per-backend-functionality key.
16 // It corresponds to "dense" functionality, "CPU" backend.
17 // This means that it gets a dense functionality bit, and a cpu backend bit
18 // set.
19 auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU);
20 ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
21 ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
22 ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
23
24 auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy);
25 ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense));
26 ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit));
27 ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy));
28
29 // You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block"
30 // dispatch keys. You are allowed to directly create keysets out of them!
31 auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) |
32 DispatchKeySet(BackendComponent::CPUBit);
33 ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense));
34 ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit));
35 ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU));
36 ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks);
37
38 // Similarly, the AutogradCUDA key gets 2 bits in the keyset:
39 // The "Autograd" functionality bit, and the "CUDA" backend bit
40 auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA);
41 ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality));
42 ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit));
43
44 // Because DispatchKeySet uses a condensed internal representation, you cannot
45 // use it to represent the FULL cross product of backends and functionalities
46 // for example:
47 auto autograd_dense_cpu_cuda = DispatchKeySet(
48 {DispatchKey::AutogradFunctionality,
49 DispatchKey::Dense,
50 DispatchKey::CUDA,
51 DispatchKey::CPU});
52 // this keyset has all of the building block keys:
53 ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality));
54 ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense));
55 ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit));
56 ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit));
57
58 // and it also has the "runtime" keys that correspond to the full
59 // cross-product of functionality
60 ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
61 ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU));
62 ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU));
63 ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA));
64
65 // This means that there's no way to represent a keyset with, say, only
66 // Autograd CUDA + Dense CPU. Instead, you should think of a keyset as
67 // inheriting the full set of functionalities + backends of its keys. This
68 // means that the below keysets are all indistinguishable from each other.
69 ASSERT_EQ(
70 autograd_dense_cpu_cuda,
71 DispatchKeySet(
72 {DispatchKey::AutogradCUDA,
73 DispatchKey::AutogradCPU,
74 DispatchKey::CUDA,
75 DispatchKey::CPU}));
76 ASSERT_EQ(
77 autograd_dense_cpu_cuda,
78 DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU}));
79 ASSERT_EQ(
80 autograd_dense_cpu_cuda,
81 DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU}));
82
83 // ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~
84
85 // Iterators allow you to iterate individually through the DispatchKey's in a
86 // DispatchKeySet
87 auto empty_set = DispatchKeySet();
88 ASSERT_EQ(*empty_set.begin(), *empty_set.end());
89
90 // However, only keys that correspond to actual runtime indices of kernels in
91 // the operator table show up when you iterate through a keyset. i.e.
92 // DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an
93 // iterator.
94 auto dense_cpu_iter = dense_cpu_set.begin();
95 ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU);
96 ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end());
97
98 auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin();
99 ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU);
100 ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA);
101 ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU);
102 ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA);
103 ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end());
104
105 // But other "functionality bits" that are not defined per-backend DO get
106 // their own slots in the operator table.
107 auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) |
108 DispatchKeySet(
109 {DispatchKey::FPGA, // runtime key
110 DispatchKey::Functionalize, // runtime key
111 DispatchKey::Dense}); // NOT a runtime key
112 auto mixed_iter = mixed_keyset.begin();
113 ASSERT_EQ(*mixed_iter++, DispatchKey::CPU);
114 ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA);
115 ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize);
116 ASSERT_EQ(*mixed_iter, *mixed_keyset.end());
117}
118
119TEST(DispatchKeySet, Empty) {
120 DispatchKeySet empty_set;
121 for (uint8_t i = 0;
122 i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
123 i++) {
124 auto tid = static_cast<DispatchKey>(i);
125 if (tid == DispatchKey::Undefined)
126 continue;
127 ASSERT_FALSE(empty_set.has(tid));
128 }
129 ASSERT_TRUE(empty_set.empty());
130 DispatchKeySet empty_set2;
131 ASSERT_TRUE(empty_set == empty_set2);
132}
133
134// This covers all keys that correspond to a single backend bit, e.g.
135// BackendComponent::CPUBit. Even though these are NOT runtime keys, we still
136// allow adding them directly to a keyset
137TEST(DispatchKeySet, SingletonBackendComponent) {
138 for (const auto i : c10::irange(1, num_backends)) {
139 auto tid = static_cast<DispatchKey>(i);
140 DispatchKeySet sing(tid);
141 ASSERT_EQ(sing, sing);
142 ASSERT_EQ(sing, DispatchKeySet().add(tid));
143 ASSERT_EQ(sing, sing.add(tid));
144 ASSERT_EQ(sing, sing | sing);
145 ASSERT_FALSE(sing.empty());
146 ASSERT_TRUE(sing.has(tid));
147 }
148}
149
// This covers all keys that correspond to a single functionality bit:
// - runtime, not-per-backend functionality keys, e.g.
//   DispatchKey::FuncTorchBatched
// - runtime, "fake backend" keys, e.g. DispatchKey::FPGA
// - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense.
//   Even though these are not runtime keys, we still allow adding them
//   directly to a keyset.
158TEST(DispatchKeySet, SingletonFunctionalityKeys) {
159 for (const auto i : c10::irange(1, num_functionality_keys)) {
160 auto tid = static_cast<DispatchKey>(i);
161 DispatchKeySet sing(tid);
162 ASSERT_EQ(sing, sing);
163 ASSERT_EQ(sing, DispatchKeySet().add(tid));
164 ASSERT_EQ(sing, sing.add(tid));
165 ASSERT_EQ(sing, sing | sing);
166 ASSERT_FALSE(sing.empty());
167 ASSERT_TRUE(sing.has(tid));
168 ASSERT_EQ(sing.remove(tid), DispatchKeySet());
169 }
170}
171
172// This covers runtime keys that are per-backend,
173// and take up more than one bit in a DispatchKeySet. They take up one
174// functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA,
175// AutogradCPU, AutogradCUDA
176TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
177 for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
178 i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
179 i++) {
180 auto tid = static_cast<DispatchKey>(i);
181 // Skip these because they aren't real keys.
182 if (tid == DispatchKey::StartOfDenseBackends ||
183 tid == DispatchKey::StartOfSparseBackends ||
184 tid == DispatchKey::StartOfQuantizedBackends ||
185 tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
186 continue;
187 }
188 DispatchKeySet sing(tid);
189 ASSERT_EQ(sing, sing);
190 ASSERT_EQ(sing, DispatchKeySet().add(tid));
191 ASSERT_EQ(sing, sing.add(tid));
192 ASSERT_EQ(sing, sing | sing);
193 ASSERT_FALSE(sing.empty());
194 ASSERT_TRUE(sing.has(tid));
195
196 auto functionality_key = toFunctionalityKey(tid);
197 auto backend_key = toBackendComponent(tid);
198 // These two sets should be equivalent:
199 // DispatchKeySet(DispatchKey::CPU)
200 // DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit})
201 auto expected_ks =
202 DispatchKeySet(functionality_key) | DispatchKeySet(backend_key);
203 ASSERT_EQ(sing, expected_ks);
204 // These two sets should be equivalent:
205 // DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::Dense)
206 // DispatchKeySet(BackendComponent::CPUBit)
207 expected_ks = DispatchKeySet(toBackendComponent(tid));
208 ASSERT_EQ(sing.remove(tid), expected_ks);
209 }
210}
211
212TEST(DispatchKeySet, DoubletonPerBackend) {
213 for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends);
214 i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
215 i++) {
216 for (uint8_t j = i + 1;
217 j <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys);
218 j++) {
219 ASSERT_LT(i, j);
220 auto tid1 = static_cast<DispatchKey>(i);
221 auto tid2 = static_cast<DispatchKey>(j);
222
223 // Skip these because they aren't real keys.
224 if (tid1 == DispatchKey::StartOfDenseBackends ||
225 tid1 == DispatchKey::StartOfSparseBackends ||
226 tid1 == DispatchKey::StartOfQuantizedBackends ||
227 tid1 == DispatchKey::StartOfNestedTensorBackends ||
228 tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
229 continue;
230 if (tid2 == DispatchKey::StartOfDenseBackends ||
231 tid2 == DispatchKey::StartOfSparseBackends ||
232 tid2 == DispatchKey::StartOfQuantizedBackends ||
233 tid2 == DispatchKey::StartOfNestedTensorBackends ||
234 tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
235 continue;
236
237 auto backend1 = toBackendComponent(tid1);
238 auto backend2 = toBackendComponent(tid2);
239 auto functionality1 = toFunctionalityKey(tid1);
240 auto functionality2 = toFunctionalityKey(tid2);
241
242 auto combined = DispatchKeySet({tid1, tid2});
243 // The combined set has the backend bits
244 ASSERT_TRUE(combined.has_backend(backend1));
245 ASSERT_TRUE(combined.has_backend(backend2));
246 // and it has the backend bits
247 ASSERT_TRUE(combined.has(functionality1));
248 ASSERT_TRUE(combined.has(functionality2));
249 // and it has the original two runtime keys
250 ASSERT_TRUE(combined.has(tid1));
251 ASSERT_TRUE(combined.has(tid2));
252
253 // Add all of the keys in the keyset to a real set
254 std::unordered_set<DispatchKey> visited_keys;
255 auto iter = combined.begin();
256 while (*iter != *combined.end()) {
257 visited_keys.insert(*iter);
258 ++iter;
259 }
260 std::unordered_set<DispatchKey> expected_keys;
261 expected_keys.insert(
262 toRuntimePerBackendFunctionalityKey(functionality1, backend1));
263 expected_keys.insert(
264 toRuntimePerBackendFunctionalityKey(functionality1, backend2));
265 expected_keys.insert(
266 toRuntimePerBackendFunctionalityKey(functionality2, backend1));
267 expected_keys.insert(
268 toRuntimePerBackendFunctionalityKey(functionality2, backend2));
269 ASSERT_EQ(expected_keys, visited_keys);
270
271 if (backend1 == backend2 || functionality1 == functionality2) {
272 // We have two runtime keys, with either the same backend or the same
273 // per-backend functionalities. E.g. {AutogradCUDA, CUDA} or
274 // {AutogradCPU, AutogradCUDA} There should be 2 total runtime keys in
275 // this set.
276 ASSERT_EQ(2, visited_keys.size());
277 } else {
278 // since i and j are different keys, they should not have the same
279 // functionality and backend
280 ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2);
281 // We have two runtime keys, that have different backends + per-backend
282 // functionalities. So we should expect the full cross product of
283 // runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU,
284 // then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU}
285 ASSERT_EQ(4, visited_keys.size());
286 }
287 }
288 }
289}
290
291TEST(DispatchKeySet, Full) {
292 DispatchKeySet full(DispatchKeySet::FULL);
293 for (const auto i : c10::irange(1, num_functionality_keys)) {
294 auto tid = static_cast<DispatchKey>(i);
295 ASSERT_TRUE(full.has(tid));
296 }
297 ASSERT_FALSE(full.has(DispatchKey::EndOfFunctionalityKeys));
298}
299
300TEST(DispatchKeySet, IteratorBasicOps) {
301 DispatchKeySet empty_set;
302 DispatchKeySet full_set(DispatchKeySet::FULL);
303 DispatchKeySet mutated_set = empty_set.add(DispatchKey::CPU);
304
305 // Constructor + Comparison
306 ASSERT_EQ(*empty_set.begin(), DispatchKey::EndOfFunctionalityKeys);
307 ASSERT_EQ(*empty_set.end(), DispatchKey::EndOfFunctionalityKeys);
308 ASSERT_EQ(*mutated_set.begin(), DispatchKey::CPU);
309
310 ASSERT_TRUE(empty_set.begin() == empty_set.end());
311 ASSERT_TRUE(full_set.begin() != full_set.end());
312
313 // Increment Ops
314 ASSERT_TRUE(full_set.begin() == full_set.begin()++);
315 ASSERT_TRUE(full_set.begin() != ++full_set.begin());
316}
317
318TEST(DispatchKeySet, getHighestPriorityBackendTypeId) {
319 // AutogradCPU isn't a backend key so it is ignored
320 DispatchKeySet dense_cpu({DispatchKey::AutogradCPU, DispatchKey::CPU});
321 ASSERT_EQ(DispatchKey::CPU, c10::highestPriorityBackendTypeId(dense_cpu));
322
323 // Functionalize isn't a backend key so it is ignored
324 DispatchKeySet sparse_cuda(
325 {DispatchKey::Functionalize, DispatchKey::SparseCUDA});
326 ASSERT_EQ(
327 DispatchKey::SparseCUDA, c10::highestPriorityBackendTypeId(sparse_cuda));
328
329 // quantizedCUDA has higher priority than CUDA
330 DispatchKeySet quantized_cuda(
331 {DispatchKey::CUDA, DispatchKey::QuantizedCUDA});
332 ASSERT_EQ(
333 DispatchKey::QuantizedCUDA,
334 c10::highestPriorityBackendTypeId(quantized_cuda));
335}
336
337TEST(DispatchKeySet, IteratorEmpty) {
338 DispatchKeySet empty_set;
339 uint8_t i = 0;
340
341 for (auto it = empty_set.begin(); it != empty_set.end(); ++it) {
342 i++;
343 }
344 ASSERT_EQ(i, 0);
345}
346
347TEST(DispatchKeySet, IteratorCrossProduct) {
348 // The iterator should return all runtime keys in the set,
349 // including the cross product of {backends} x {functionalities}
350 auto ks =
351 DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) |
352 DispatchKeySet(
353 {DispatchKey::Dense,
354 DispatchKey::FPGA,
355 DispatchKey::AutogradFunctionality});
356
357 auto iter = ks.begin();
358 // iterate through dense backends first.
359 ASSERT_EQ(DispatchKey::CPU, *(iter++));
360 ASSERT_EQ(DispatchKey::CUDA, *(iter++));
361 // FPGA doesn't have a backend bit, so it isn't included in the cross product.
362 ASSERT_EQ(DispatchKey::FPGA, *(iter++));
363 // iterate through the autograd keys laster.
364 ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++));
365 ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++));
366}
367
368TEST(DispatchKeySet, IteratorFull) {
369 DispatchKeySet full_set(DispatchKeySet::FULL);
370 std::ptrdiff_t count = std::distance(full_set.begin(), full_set.end());
371
372 // Total # of runtime entries includes an entry for DispatchKey::Undefined,
373 // which is not included when iterating through the DispatchKeySet.
374 ASSERT_EQ(count, std::ptrdiff_t{num_runtime_entries} - 1);
375}
376TEST(DispatchKeySet, FailAtEndIterator) {
377 DispatchKeySet full_set(DispatchKeySet::FULL);
378 uint64_t raw_repr = full_set.raw_repr();
379
380 // doesn't throw
381 DispatchKeySet::iterator(&raw_repr, num_backends + num_functionality_keys);
382 // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
383 EXPECT_THROW(
384 DispatchKeySet::iterator(
385 &raw_repr, num_backends + num_functionality_keys + 1),
386 c10::Error);
387}
388
389TEST(DispatchKeySet, TestBackendComponentToString) {
390 std::unordered_set<std::string> seen_strings;
391 for (int64_t i = 0;
392 i <= static_cast<int64_t>(BackendComponent::EndOfBackendKeys);
393 i++) {
394 auto k = static_cast<BackendComponent>(i);
395 auto res = std::string(toString(k));
396 ASSERT_FALSE(res == "UNKNOWN_BACKEND_BIT");
397 ASSERT_FALSE(seen_strings.count(res) > 0);
398 seen_strings.insert(res);
399 }
400}
401
402TEST(DispatchKeySet, TestEndOfRuntimeBackendKeysAccurate) {
403 DispatchKey k;
404#define SETTER(fullname, prefix) k = DispatchKey::EndOf##fullname##Backends;
405 C10_FORALL_FUNCTIONALITY_KEYS(SETTER)
406#undef SETTER
407 ASSERT_TRUE(k == DispatchKey::EndOfRuntimeBackendKeys);
408}
409
410TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
411 std::unordered_set<std::string> seen_strings;
412 for (int i = 0; i <= static_cast<int>(DispatchKey::EndOfAliasKeys); i++) {
413 auto k = static_cast<DispatchKey>(i);
414 // These synthetic keys never actually get used and don't need
415 // to be printed
416 if (k == DispatchKey::EndOfFunctionalityKeys ||
417 k == DispatchKey::StartOfDenseBackends ||
418 k == DispatchKey::StartOfQuantizedBackends ||
419 k == DispatchKey::StartOfSparseBackends ||
420 k == DispatchKey::StartOfNestedTensorBackends ||
421 k == DispatchKey::StartOfAutogradFunctionalityBackends)
422 continue;
423 auto res = std::string(toString(k));
424 ASSERT_TRUE(res.find("Unknown") == std::string::npos)
425 << i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
426 << ")";
427 ASSERT_TRUE(seen_strings.count(res) == 0);
428 seen_strings.insert(res);
429 }
430}
431