1 | #pragma once |
2 | |
3 | #include <c10/core/ScalarType.h> |
4 | #include <c10/macros/Macros.h> |
5 | #include <c10/util/Load.h> |
6 | #include <c10/util/TypeCast.h> |
7 | |
8 | namespace c10 { |
9 | |
10 | // Dynamic type casting utils: |
11 | // - fetch_and_cast |
12 | // - cast_and_store |
13 | // |
// fetch_and_cast fetches a value with dynamic type specified by a ScalarType
// from a void pointer and casts it to a static type.
16 | // |
// cast_and_store casts a statically typed value into the dynamic type
// specified by a ScalarType, and stores it into a void pointer.
19 | // |
20 | // NOTE: |
21 | // |
22 | // Dynamic casting allows us to support type promotion without blowing up |
23 | // the combination space: For example, without dynamic cast, in order to |
24 | // implement `add_` with type promotion, we would need something like |
25 | // |
26 | // AT_DISPATCH_ALL_TYPES(output.dtype(), |
27 | // AT_DISPATCH_ALL_TYPES(input1.dtype(), |
28 | // AT_DISPATCH_ALL_TYPES(input2.dtype(), |
29 | // [](arg0_t a, arg1_t b) -> out_t { return a + b; } |
30 | // ) |
31 | // ) |
32 | // ) |
33 | // |
// If we support N dtypes, the above code would generate an a+b kernel for
// each of the N * N * N combinations of supported types; the compilation
// time and binary size would become horrible.
37 | // |
// Dynamic casting might sound like a bad idea in terms of performance.
// Especially if you ever do it in a loop, you are going to do a billion tests.
40 | // But in practice it is not as bad as it might look: |
41 | // |
42 | // - on CPU, this is a branch that always has the same outcome, therefore |
43 | // hopefully the branch predictor could do the job pretty well |
44 | // - on GPU, these branches will not diverge, so we could still have the same |
45 | // warp executing the same line of code |
46 | // - Most kernels, like `add`, are bandwidth bound, adding a few clock cycles to |
47 | // check an integer does not hurt the performance much because the ALUs would |
48 | // wait for load instructions anyway. |
49 | // |
50 | // For the discussion and benchmark, refer to: |
51 | // - https://github.com/pytorch/pytorch/pull/28343 |
52 | // - https://github.com/pytorch/pytorch/pull/28344 |
53 | // - https://github.com/pytorch/pytorch/pull/28345 |
54 | // |
55 | |
// Reporting an unsupported dtype: in code compiled for the device we cannot
// throw, so we trap via CUDA_KERNEL_ASSERT; otherwise we raise a regular
// TORCH_CHECK error.
// NOTE(review): this keys on whether C10_HOST_DEVICE is *defined*, which is
// controlled by c10/macros/Macros.h (included above) — confirm against that
// header if the host-side branch ever appears not to be taken.
#ifdef C10_HOST_DEVICE
#define ERROR_UNSUPPORTED_CAST CUDA_KERNEL_ASSERT(false);
#else
#define ERROR_UNSUPPORTED_CAST TORCH_CHECK(false, "Unexpected scalar type");
#endif
61 | |
// Fetch a value with dynamic type src_type from ptr, and cast it to static
// type dest_t.
//
// One switch case per scalar type: read the element through c10::load (see
// c10/util/Load.h, included above) and convert it to dest_t via c10::convert.
#define FETCH_AND_CAST_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    return c10::convert<dest_t>(c10::load<type>(ptr));

// Reads one element whose runtime dtype is `src_type` from `ptr` and returns
// it converted to the static type dest_t. A `src_type` not covered by
// AT_FORALL_SCALAR_TYPES_WITH_COMPLEX lands in `default:` and triggers
// ERROR_UNSUPPORTED_CAST (defined above).
template <typename dest_t>
C10_HOST_DEVICE inline dest_t fetch_and_cast(
    const ScalarType src_type,
    const void* ptr) {
  switch (src_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(FETCH_AND_CAST_CASE)
    default:
      ERROR_UNSUPPORTED_CAST
  }
  // Unreachable: every case returns and the default aborts/throws, but the
  // compiler cannot see that through the macros.
  return dest_t(0); // just to avoid compiler warning
}
79 | |
// Cast a value with static type src_t into dynamic dest_type, and store it to
// ptr.
//
// One switch case per scalar type: convert `value` to the target type and
// write it through `ptr`, then return from the enclosing function.
#define CAST_AND_STORE_CASE(type, scalartype) \
  case ScalarType::scalartype:                \
    *(type*)ptr = c10::convert<type>(value);  \
    return;
// Converts `value` to the runtime dtype `dest_type` and stores it at `ptr`.
// Every supported case returns from inside the switch; the empty `default:`
// falls through, so ERROR_UNSUPPORTED_CAST below is reached only when
// `dest_type` is not covered by AT_FORALL_SCALAR_TYPES_WITH_COMPLEX.
template <typename src_t>
C10_HOST_DEVICE inline void cast_and_store(
    const ScalarType dest_type,
    void* ptr,
    src_t value) {
  switch (dest_type) {
    AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(CAST_AND_STORE_CASE)
    default:;
  }
  ERROR_UNSUPPORTED_CAST
}
97 | |
// Quantized types get no cross-dtype casting: these full specializations of
// fetch_and_cast / cast_and_store assert that the dynamic type is exactly T
// (the identity "cast") and otherwise just load/store the value unchanged.
// CUDA_KERNEL_ASSERT is used (not TORCH_CHECK) so the check also works in
// device code.
#define DEFINE_UNCASTABLE(T, scalartype_)                     \
  template <>                                                 \
  C10_HOST_DEVICE inline T fetch_and_cast<T>(                 \
      const ScalarType src_type, const void* ptr) {           \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == src_type);  \
    return c10::load<T>(ptr);                                 \
  }                                                           \
  template <>                                                 \
  C10_HOST_DEVICE inline void cast_and_store<T>(              \
      const ScalarType dest_type, void* ptr, T value) {       \
    CUDA_KERNEL_ASSERT(ScalarType::scalartype_ == dest_type); \
    *(T*)ptr = value;                                         \
  }

// Instantiate the identity-only specializations for all quantized dtypes.
AT_FORALL_QINT_TYPES(DEFINE_UNCASTABLE)
113 | |
// The helper macros above are implementation details of this header; undefine
// them so they do not leak into every translation unit that includes it.
#undef FETCH_AND_CAST_CASE
#undef CAST_AND_STORE_CASE
#undef DEFINE_UNCASTABLE
#undef ERROR_UNSUPPORTED_CAST
118 | |
119 | } // namespace c10 |
120 | |