1 | #include <gtest/gtest.h> |
2 | |
3 | #include <cstddef> |
4 | #include <iterator> |
5 | #include <unordered_set> |
6 | |
7 | #include <c10/core/DispatchKeySet.h> |
8 | #include <c10/util/irange.h> |
9 | |
10 | using namespace c10; |
11 | |
12 | // This test exists not to be comprehensive, but to more clearly show |
13 | // what the semantics of DispatchKeySet are. |
14 | TEST(DispatchKeySet, ShowSemantics) { |
15 | // the "CPU" dispatch key is an instance of a per-backend-functionality key. |
16 | // It corresponds to "dense" functionality, "CPU" backend. |
17 | // This means that it gets a dense functionality bit, and a cpu backend bit |
18 | // set. |
19 | auto dense_cpu_set = DispatchKeySet(DispatchKey::CPU); |
20 | ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense)); |
21 | ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit)); |
22 | ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU)); |
23 | |
24 | auto dense_lazy_set = DispatchKeySet(DispatchKey::Lazy); |
25 | ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Dense)); |
26 | ASSERT_TRUE(dense_lazy_set.has_backend(BackendComponent::LazyBit)); |
27 | ASSERT_TRUE(dense_lazy_set.has(DispatchKey::Lazy)); |
28 | |
29 | // You can think of "Dense/Sparse", and "CPUBit/CUDABit", as "building block" |
30 | // dispatch keys. You are allowed to directly create keysets out of them! |
31 | auto dense_cpu_set_from_building_blocks = DispatchKeySet(DispatchKey::Dense) | |
32 | DispatchKeySet(BackendComponent::CPUBit); |
33 | ASSERT_TRUE(dense_cpu_set.has(DispatchKey::Dense)); |
34 | ASSERT_TRUE(dense_cpu_set.has_backend(BackendComponent::CPUBit)); |
35 | ASSERT_TRUE(dense_cpu_set.has(DispatchKey::CPU)); |
36 | ASSERT_EQ(dense_cpu_set, dense_cpu_set_from_building_blocks); |
37 | |
38 | // Similarly, the AutogradCUDA key gets 2 bits in the keyset: |
39 | // The "Autograd" functionality bit, and the "CUDA" backend bit |
40 | auto autograd_cuda = DispatchKeySet(DispatchKey::AutogradCUDA); |
41 | ASSERT_TRUE(autograd_cuda.has(DispatchKey::AutogradFunctionality)); |
42 | ASSERT_TRUE(autograd_cuda.has_backend(BackendComponent::CUDABit)); |
43 | |
44 | // Because DispatchKeySet uses a condensed internal representation, you cannot |
45 | // use it to represent the FULL cross product of backends and functionalities |
46 | // for example: |
47 | auto autograd_dense_cpu_cuda = DispatchKeySet( |
48 | {DispatchKey::AutogradFunctionality, |
49 | DispatchKey::Dense, |
50 | DispatchKey::CUDA, |
51 | DispatchKey::CPU}); |
52 | // this keyset has all of the building block keys: |
53 | ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradFunctionality)); |
54 | ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::Dense)); |
55 | ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CUDABit)); |
56 | ASSERT_TRUE(autograd_dense_cpu_cuda.has_backend(BackendComponent::CPUBit)); |
57 | |
58 | // and it also has the "runtime" keys that correspond to the full |
59 | // cross-product of functionality |
60 | ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU)); |
61 | ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::AutogradCPU)); |
62 | ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CPU)); |
63 | ASSERT_TRUE(autograd_dense_cpu_cuda.has(DispatchKey::CUDA)); |
64 | |
65 | // This means that there's no way to represent a keyset with, say, only |
66 | // Autograd CUDA + Dense CPU. Instead, you should think of a keyset as |
67 | // inheriting the full set of functionalities + backends of its keys. This |
68 | // means that the below keysets are all indistinguishable from each other. |
69 | ASSERT_EQ( |
70 | autograd_dense_cpu_cuda, |
71 | DispatchKeySet( |
72 | {DispatchKey::AutogradCUDA, |
73 | DispatchKey::AutogradCPU, |
74 | DispatchKey::CUDA, |
75 | DispatchKey::CPU})); |
76 | ASSERT_EQ( |
77 | autograd_dense_cpu_cuda, |
78 | DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CPU})); |
79 | ASSERT_EQ( |
80 | autograd_dense_cpu_cuda, |
81 | DispatchKeySet({DispatchKey::CUDA, DispatchKey::AutogradCPU})); |
82 | |
83 | // ~~~~~~~~~~ DispatchKeySet iterators ~~~~~~~~~~~ |
84 | |
85 | // Iterators allow you to iterate individually through the DispatchKey's in a |
86 | // DispatchKeySet |
87 | auto empty_set = DispatchKeySet(); |
88 | ASSERT_EQ(*empty_set.begin(), *empty_set.end()); |
89 | |
90 | // However, only keys that correspond to actual runtime indices of kernels in |
91 | // the operator table show up when you iterate through a keyset. i.e. |
92 | // DispatchKey::Dense, and BackendComponent::CPUBit won't show up in an |
93 | // iterator. |
94 | auto dense_cpu_iter = dense_cpu_set.begin(); |
95 | ASSERT_EQ(*dense_cpu_iter++, DispatchKey::CPU); |
96 | ASSERT_EQ(*dense_cpu_iter, *dense_cpu_set.end()); |
97 | |
98 | auto autograd_dense_cpu_cuda_iter = autograd_dense_cpu_cuda.begin(); |
99 | ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CPU); |
100 | ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::CUDA); |
101 | ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCPU); |
102 | ASSERT_EQ(*autograd_dense_cpu_cuda_iter++, DispatchKey::AutogradCUDA); |
103 | ASSERT_EQ(*autograd_dense_cpu_cuda_iter, *autograd_dense_cpu_cuda.end()); |
104 | |
105 | // But other "functionality bits" that are not defined per-backend DO get |
106 | // their own slots in the operator table. |
107 | auto mixed_keyset = DispatchKeySet(BackendComponent::CPUBit) | |
108 | DispatchKeySet( |
109 | {DispatchKey::FPGA, // runtime key |
110 | DispatchKey::Functionalize, // runtime key |
111 | DispatchKey::Dense}); // NOT a runtime key |
112 | auto mixed_iter = mixed_keyset.begin(); |
113 | ASSERT_EQ(*mixed_iter++, DispatchKey::CPU); |
114 | ASSERT_EQ(*mixed_iter++, DispatchKey::FPGA); |
115 | ASSERT_EQ(*mixed_iter++, DispatchKey::Functionalize); |
116 | ASSERT_EQ(*mixed_iter, *mixed_keyset.end()); |
117 | } |
118 | |
119 | TEST(DispatchKeySet, Empty) { |
120 | DispatchKeySet empty_set; |
121 | for (uint8_t i = 0; |
122 | i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys); |
123 | i++) { |
124 | auto tid = static_cast<DispatchKey>(i); |
125 | if (tid == DispatchKey::Undefined) |
126 | continue; |
127 | ASSERT_FALSE(empty_set.has(tid)); |
128 | } |
129 | ASSERT_TRUE(empty_set.empty()); |
130 | DispatchKeySet empty_set2; |
131 | ASSERT_TRUE(empty_set == empty_set2); |
132 | } |
133 | |
134 | // This covers all keys that correspond to a single backend bit, e.g. |
135 | // BackendComponent::CPUBit. Even though these are NOT runtime keys, we still |
136 | // allow adding them directly to a keyset |
137 | TEST(DispatchKeySet, SingletonBackendComponent) { |
138 | for (const auto i : c10::irange(1, num_backends)) { |
139 | auto tid = static_cast<DispatchKey>(i); |
140 | DispatchKeySet sing(tid); |
141 | ASSERT_EQ(sing, sing); |
142 | ASSERT_EQ(sing, DispatchKeySet().add(tid)); |
143 | ASSERT_EQ(sing, sing.add(tid)); |
144 | ASSERT_EQ(sing, sing | sing); |
145 | ASSERT_FALSE(sing.empty()); |
146 | ASSERT_TRUE(sing.has(tid)); |
147 | } |
148 | } |
149 | |
150 | // This covers all keys that correspond to a single functionality bit: |
151 | // - runtime, not-per-backend functionality keys, e.g. |
152 | // DispatchKey::FuncTorchBatched |
153 | // - runtime, "fake backend" keys, e.g. DispatchKey::FPGA |
154 | // - NOT-runtime, per-backend functionality keys, e.g. DispatchKey::Dense |
155 | // Even though it's not a runtime key, we still allow adding it directly to a |
156 | // keyset. |
157 | // DispatchKey:: |
158 | TEST(DispatchKeySet, SingletonFunctionalityKeys) { |
159 | for (const auto i : c10::irange(1, num_functionality_keys)) { |
160 | auto tid = static_cast<DispatchKey>(i); |
161 | DispatchKeySet sing(tid); |
162 | ASSERT_EQ(sing, sing); |
163 | ASSERT_EQ(sing, DispatchKeySet().add(tid)); |
164 | ASSERT_EQ(sing, sing.add(tid)); |
165 | ASSERT_EQ(sing, sing | sing); |
166 | ASSERT_FALSE(sing.empty()); |
167 | ASSERT_TRUE(sing.has(tid)); |
168 | ASSERT_EQ(sing.remove(tid), DispatchKeySet()); |
169 | } |
170 | } |
171 | |
172 | // This covers runtime keys that are per-backend, |
173 | // and take up more than one bit in a DispatchKeySet. They take up one |
174 | // functionality bit + one backend bit. e.g. CPU, CUDA, SparseCPU, SparseCUDA, |
175 | // AutogradCPU, AutogradCUDA |
176 | TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) { |
177 | for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends); |
178 | i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys); |
179 | i++) { |
180 | auto tid = static_cast<DispatchKey>(i); |
181 | // Skip these because they aren't real keys. |
182 | if (tid == DispatchKey::StartOfDenseBackends || |
183 | tid == DispatchKey::StartOfSparseBackends || |
184 | tid == DispatchKey::StartOfQuantizedBackends || |
185 | tid == DispatchKey::StartOfAutogradFunctionalityBackends) { |
186 | continue; |
187 | } |
188 | DispatchKeySet sing(tid); |
189 | ASSERT_EQ(sing, sing); |
190 | ASSERT_EQ(sing, DispatchKeySet().add(tid)); |
191 | ASSERT_EQ(sing, sing.add(tid)); |
192 | ASSERT_EQ(sing, sing | sing); |
193 | ASSERT_FALSE(sing.empty()); |
194 | ASSERT_TRUE(sing.has(tid)); |
195 | |
196 | auto functionality_key = toFunctionalityKey(tid); |
197 | auto backend_key = toBackendComponent(tid); |
198 | // These two sets should be equivalent: |
199 | // DispatchKeySet(DispatchKey::CPU) |
200 | // DispatchKeySet({DispatchKey::Dense, BackendComponent::CPUBit}) |
201 | auto expected_ks = |
202 | DispatchKeySet(functionality_key) | DispatchKeySet(backend_key); |
203 | ASSERT_EQ(sing, expected_ks); |
204 | // These two sets should be equivalent: |
205 | // DispatchKeySet(DispatchKey::CPU).remove(DispatchKey::Dense) |
206 | // DispatchKeySet(BackendComponent::CPUBit) |
207 | expected_ks = DispatchKeySet(toBackendComponent(tid)); |
208 | ASSERT_EQ(sing.remove(tid), expected_ks); |
209 | } |
210 | } |
211 | |
212 | TEST(DispatchKeySet, DoubletonPerBackend) { |
213 | for (uint8_t i = static_cast<uint8_t>(DispatchKey::StartOfDenseBackends); |
214 | i <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys); |
215 | i++) { |
216 | for (uint8_t j = i + 1; |
217 | j <= static_cast<uint8_t>(DispatchKey::EndOfRuntimeBackendKeys); |
218 | j++) { |
219 | ASSERT_LT(i, j); |
220 | auto tid1 = static_cast<DispatchKey>(i); |
221 | auto tid2 = static_cast<DispatchKey>(j); |
222 | |
223 | // Skip these because they aren't real keys. |
224 | if (tid1 == DispatchKey::StartOfDenseBackends || |
225 | tid1 == DispatchKey::StartOfSparseBackends || |
226 | tid1 == DispatchKey::StartOfQuantizedBackends || |
227 | tid1 == DispatchKey::StartOfNestedTensorBackends || |
228 | tid1 == DispatchKey::StartOfAutogradFunctionalityBackends) |
229 | continue; |
230 | if (tid2 == DispatchKey::StartOfDenseBackends || |
231 | tid2 == DispatchKey::StartOfSparseBackends || |
232 | tid2 == DispatchKey::StartOfQuantizedBackends || |
233 | tid2 == DispatchKey::StartOfNestedTensorBackends || |
234 | tid2 == DispatchKey::StartOfAutogradFunctionalityBackends) |
235 | continue; |
236 | |
237 | auto backend1 = toBackendComponent(tid1); |
238 | auto backend2 = toBackendComponent(tid2); |
239 | auto functionality1 = toFunctionalityKey(tid1); |
240 | auto functionality2 = toFunctionalityKey(tid2); |
241 | |
242 | auto combined = DispatchKeySet({tid1, tid2}); |
243 | // The combined set has the backend bits |
244 | ASSERT_TRUE(combined.has_backend(backend1)); |
245 | ASSERT_TRUE(combined.has_backend(backend2)); |
246 | // and it has the backend bits |
247 | ASSERT_TRUE(combined.has(functionality1)); |
248 | ASSERT_TRUE(combined.has(functionality2)); |
249 | // and it has the original two runtime keys |
250 | ASSERT_TRUE(combined.has(tid1)); |
251 | ASSERT_TRUE(combined.has(tid2)); |
252 | |
253 | // Add all of the keys in the keyset to a real set |
254 | std::unordered_set<DispatchKey> visited_keys; |
255 | auto iter = combined.begin(); |
256 | while (*iter != *combined.end()) { |
257 | visited_keys.insert(*iter); |
258 | ++iter; |
259 | } |
260 | std::unordered_set<DispatchKey> expected_keys; |
261 | expected_keys.insert( |
262 | toRuntimePerBackendFunctionalityKey(functionality1, backend1)); |
263 | expected_keys.insert( |
264 | toRuntimePerBackendFunctionalityKey(functionality1, backend2)); |
265 | expected_keys.insert( |
266 | toRuntimePerBackendFunctionalityKey(functionality2, backend1)); |
267 | expected_keys.insert( |
268 | toRuntimePerBackendFunctionalityKey(functionality2, backend2)); |
269 | ASSERT_EQ(expected_keys, visited_keys); |
270 | |
271 | if (backend1 == backend2 || functionality1 == functionality2) { |
272 | // We have two runtime keys, with either the same backend or the same |
273 | // per-backend functionalities. E.g. {AutogradCUDA, CUDA} or |
274 | // {AutogradCPU, AutogradCUDA} There should be 2 total runtime keys in |
275 | // this set. |
276 | ASSERT_EQ(2, visited_keys.size()); |
277 | } else { |
278 | // since i and j are different keys, they should not have the same |
279 | // functionality and backend |
280 | ASSERT_TRUE(backend1 != backend2 && functionality1 != functionality2); |
281 | // We have two runtime keys, that have different backends + per-backend |
282 | // functionalities. So we should expect the full cross product of |
283 | // runtime keys to be in the set. e.g. if i = AutogradCUDA, and j = CPU, |
284 | // then combined = {AutogradCUDA, AutogradCPU, CUDA, CPU} |
285 | ASSERT_EQ(4, visited_keys.size()); |
286 | } |
287 | } |
288 | } |
289 | } |
290 | |
291 | TEST(DispatchKeySet, Full) { |
292 | DispatchKeySet full(DispatchKeySet::FULL); |
293 | for (const auto i : c10::irange(1, num_functionality_keys)) { |
294 | auto tid = static_cast<DispatchKey>(i); |
295 | ASSERT_TRUE(full.has(tid)); |
296 | } |
297 | ASSERT_FALSE(full.has(DispatchKey::EndOfFunctionalityKeys)); |
298 | } |
299 | |
300 | TEST(DispatchKeySet, IteratorBasicOps) { |
301 | DispatchKeySet empty_set; |
302 | DispatchKeySet full_set(DispatchKeySet::FULL); |
303 | DispatchKeySet mutated_set = empty_set.add(DispatchKey::CPU); |
304 | |
305 | // Constructor + Comparison |
306 | ASSERT_EQ(*empty_set.begin(), DispatchKey::EndOfFunctionalityKeys); |
307 | ASSERT_EQ(*empty_set.end(), DispatchKey::EndOfFunctionalityKeys); |
308 | ASSERT_EQ(*mutated_set.begin(), DispatchKey::CPU); |
309 | |
310 | ASSERT_TRUE(empty_set.begin() == empty_set.end()); |
311 | ASSERT_TRUE(full_set.begin() != full_set.end()); |
312 | |
313 | // Increment Ops |
314 | ASSERT_TRUE(full_set.begin() == full_set.begin()++); |
315 | ASSERT_TRUE(full_set.begin() != ++full_set.begin()); |
316 | } |
317 | |
318 | TEST(DispatchKeySet, getHighestPriorityBackendTypeId) { |
319 | // AutogradCPU isn't a backend key so it is ignored |
320 | DispatchKeySet dense_cpu({DispatchKey::AutogradCPU, DispatchKey::CPU}); |
321 | ASSERT_EQ(DispatchKey::CPU, c10::highestPriorityBackendTypeId(dense_cpu)); |
322 | |
323 | // Functionalize isn't a backend key so it is ignored |
324 | DispatchKeySet sparse_cuda( |
325 | {DispatchKey::Functionalize, DispatchKey::SparseCUDA}); |
326 | ASSERT_EQ( |
327 | DispatchKey::SparseCUDA, c10::highestPriorityBackendTypeId(sparse_cuda)); |
328 | |
329 | // quantizedCUDA has higher priority than CUDA |
330 | DispatchKeySet quantized_cuda( |
331 | {DispatchKey::CUDA, DispatchKey::QuantizedCUDA}); |
332 | ASSERT_EQ( |
333 | DispatchKey::QuantizedCUDA, |
334 | c10::highestPriorityBackendTypeId(quantized_cuda)); |
335 | } |
336 | |
337 | TEST(DispatchKeySet, IteratorEmpty) { |
338 | DispatchKeySet empty_set; |
339 | uint8_t i = 0; |
340 | |
341 | for (auto it = empty_set.begin(); it != empty_set.end(); ++it) { |
342 | i++; |
343 | } |
344 | ASSERT_EQ(i, 0); |
345 | } |
346 | |
347 | TEST(DispatchKeySet, IteratorCrossProduct) { |
348 | // The iterator should return all runtime keys in the set, |
349 | // including the cross product of {backends} x {functionalities} |
350 | auto ks = |
351 | DispatchKeySet({BackendComponent::CPUBit, BackendComponent::CUDABit}) | |
352 | DispatchKeySet( |
353 | {DispatchKey::Dense, |
354 | DispatchKey::FPGA, |
355 | DispatchKey::AutogradFunctionality}); |
356 | |
357 | auto iter = ks.begin(); |
358 | // iterate through dense backends first. |
359 | ASSERT_EQ(DispatchKey::CPU, *(iter++)); |
360 | ASSERT_EQ(DispatchKey::CUDA, *(iter++)); |
361 | // FPGA doesn't have a backend bit, so it isn't included in the cross product. |
362 | ASSERT_EQ(DispatchKey::FPGA, *(iter++)); |
363 | // iterate through the autograd keys laster. |
364 | ASSERT_EQ(DispatchKey::AutogradCPU, *(iter++)); |
365 | ASSERT_EQ(DispatchKey::AutogradCUDA, *(iter++)); |
366 | } |
367 | |
368 | TEST(DispatchKeySet, IteratorFull) { |
369 | DispatchKeySet full_set(DispatchKeySet::FULL); |
370 | std::ptrdiff_t count = std::distance(full_set.begin(), full_set.end()); |
371 | |
372 | // Total # of runtime entries includes an entry for DispatchKey::Undefined, |
373 | // which is not included when iterating through the DispatchKeySet. |
374 | ASSERT_EQ(count, std::ptrdiff_t{num_runtime_entries} - 1); |
375 | } |
376 | TEST(DispatchKeySet, FailAtEndIterator) { |
377 | DispatchKeySet full_set(DispatchKeySet::FULL); |
378 | uint64_t raw_repr = full_set.raw_repr(); |
379 | |
380 | // doesn't throw |
381 | DispatchKeySet::iterator(&raw_repr, num_backends + num_functionality_keys); |
382 | // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto) |
383 | EXPECT_THROW( |
384 | DispatchKeySet::iterator( |
385 | &raw_repr, num_backends + num_functionality_keys + 1), |
386 | c10::Error); |
387 | } |
388 | |
389 | TEST(DispatchKeySet, TestBackendComponentToString) { |
390 | std::unordered_set<std::string> seen_strings; |
391 | for (int64_t i = 0; |
392 | i <= static_cast<int64_t>(BackendComponent::EndOfBackendKeys); |
393 | i++) { |
394 | auto k = static_cast<BackendComponent>(i); |
395 | auto res = std::string(toString(k)); |
396 | ASSERT_FALSE(res == "UNKNOWN_BACKEND_BIT" ); |
397 | ASSERT_FALSE(seen_strings.count(res) > 0); |
398 | seen_strings.insert(res); |
399 | } |
400 | } |
401 | |
402 | TEST(DispatchKeySet, TestEndOfRuntimeBackendKeysAccurate) { |
403 | DispatchKey k; |
404 | #define SETTER(fullname, prefix) k = DispatchKey::EndOf##fullname##Backends; |
405 | C10_FORALL_FUNCTIONALITY_KEYS(SETTER) |
406 | #undef SETTER |
407 | ASSERT_TRUE(k == DispatchKey::EndOfRuntimeBackendKeys); |
408 | } |
409 | |
410 | TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) { |
411 | std::unordered_set<std::string> seen_strings; |
412 | for (int i = 0; i <= static_cast<int>(DispatchKey::EndOfAliasKeys); i++) { |
413 | auto k = static_cast<DispatchKey>(i); |
414 | // These synthetic keys never actually get used and don't need |
415 | // to be printed |
416 | if (k == DispatchKey::EndOfFunctionalityKeys || |
417 | k == DispatchKey::StartOfDenseBackends || |
418 | k == DispatchKey::StartOfQuantizedBackends || |
419 | k == DispatchKey::StartOfSparseBackends || |
420 | k == DispatchKey::StartOfNestedTensorBackends || |
421 | k == DispatchKey::StartOfAutogradFunctionalityBackends) |
422 | continue; |
423 | auto res = std::string(toString(k)); |
424 | ASSERT_TRUE(res.find("Unknown" ) == std::string::npos) |
425 | << i << " (before is " << toString(static_cast<DispatchKey>(i - 1)) |
426 | << ")" ; |
427 | ASSERT_TRUE(seen_strings.count(res) == 0); |
428 | seen_strings.insert(res); |
429 | } |
430 | } |
431 | |