1 | #pragma once |
2 | |
3 | #include <c10/core/DeviceType.h> |
4 | #include <c10/macros/Export.h> |
5 | |
6 | #include <atomic> |
7 | #include <utility> |
8 | |
9 | // Implements instruction set specific function dispatch. |
10 | // |
11 | // Kernels that may make use of specialized instruction sets (e.g. AVX2) are |
12 | // compiled multiple times with different compiler flags (e.g. -mavx2). A |
13 | // DispatchStub contains a table of function pointers for a kernel. At runtime, |
14 | // the fastest available kernel is chosen based on the features reported by |
15 | // cpuinfo. |
16 | // |
17 | // Example: |
18 | // |
19 | // In native/MyKernel.h: |
20 | // using fn_type = void(*)(const Tensor& x); |
21 | // DECLARE_DISPATCH(fn_type, stub); |
22 | // |
23 | // In native/MyKernel.cpp |
24 | // DEFINE_DISPATCH(stub); |
25 | // |
26 | // In native/cpu/MyKernel.cpp: |
27 | // namespace { |
28 | // // use anonymous namespace so that different cpu versions won't conflict |
29 | // void kernel(const Tensor& x) { ... } |
30 | // } |
31 | // REGISTER_DISPATCH(stub, &kernel); |
32 | // |
33 | // To call: |
34 | // stub(kCPU, tensor); |
35 | // |
36 | // TODO: CPU instruction set selection should be folded into whatever |
37 | // the main dispatch mechanism is. |
38 | |
// ignore warnings about DispatchStub::DEFAULT, AVX2, AVX512, etc. defined elsewhere
40 | #if defined(__clang__) |
41 | #pragma clang diagnostic push |
42 | #pragma clang diagnostic ignored "-Wundefined-var-template" |
43 | #endif |
44 | |
45 | namespace at { namespace native { |
46 | |
// CPU instruction-set capability levels. Higher numeric values are preferred
// (the dispatch picks implementations "in decreasing order of preference" —
// see DispatchStubImpl::choose_cpu_impl below). Exactly one branch is active
// per build, selected by the HAVE_*_CPU_DEFINITION configure-time macros.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  // Sentinel: number of capability levels in this build, not a capability.
  NUM_OPTIONS
};
59 | |
// Returns the best CPUCapability available on the running machine
// (defined out of line; per the file comment above, the choice is based on
// features reported by cpuinfo).
CPUCapability get_cpu_capability();

// Primary template is only declared; the function-pointer partial
// specialization below is the one actually used.
template <typename FnPtr, typename T>
struct DispatchStub;
64 | |
/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specialization of the DispatchStub<> class.
 */
struct TORCH_API DispatchStubImpl {
  // Returns the (type-erased) function pointer to invoke for `device_type`.
  // The per-capability CPU kernel pointers are passed in by the templated
  // DispatchStub wrapper below; each parameter is named after the
  // CPUCapability it corresponds to and only exists when the matching
  // HAVE_*_CPU_DEFINITION macro was set at build time.
  void* get_call_ptr(
    DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  /**
   * The CPU Dispatch actual method is chosen in decreasing order of preference by
   * DispatchStubImpl::choose_cpu_impl() in case none is found by
   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
#if defined(_MSC_VER) && defined(_DEBUG)
  // Cached chosen CPU kernel; atomic because stubs may be called from
  // multiple threads. Only the CPU slot is resolved lazily, hence only it
  // needs atomicity; the device slots are written by registration.
  std::atomic<void*> cpu_dispatch_ptr;
  void* cuda_dispatch_ptr;
  void* hip_dispatch_ptr;
  void* mps_dispatch_ptr;
#else
  std::atomic<void*> cpu_dispatch_ptr{nullptr};  // lazily-chosen CPU kernel
  void* cuda_dispatch_ptr = nullptr;  // set via RegisterCUDADispatch
  void* hip_dispatch_ptr = nullptr;   // currently unused (HIP shares CUDA slot)
  void* mps_dispatch_ptr = nullptr;   // set via RegisterMPSDispatch
#endif
};
124 | |
// Partial specialization for function-pointer kernel types — the form every
// stub actually instantiates. `T` is the concrete stub struct declared by
// DECLARE_DISPATCH; it is part of the template key so each stub gets its own
// set of static per-capability members.
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

  private:
  // Type-erases the per-capability static members and defers to the shared,
  // non-templated DispatchStubImpl (outlined to avoid per-stub code bloat)
  // to pick the kernel for `device_type`.
  FnPtr get_call_ptr(DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

  public:
  // Invokes the kernel registered for `device_type`, perfectly forwarding
  // `args`. NOTE(review): there is no visible null check here — presumably
  // DispatchStubImpl::get_call_ptr reports an error when no kernel is
  // registered for the device; confirm before relying on it.
  template <typename... ArgTypes>
  rT operator()(DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  // Registration hooks used by the Register*Dispatch helpers below; they
  // store the kernel pointer (type-erased) in the shared impl.
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Per-capability CPU kernel pointers. Each is defined (via
  // REGISTER_ARCH_DISPATCH) in the kernel's per-ISA compiled .cpp file; the
  // clang pragma at the top of this header silences -Wundefined-var-template
  // for these declarations.
  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif
  private:
  DispatchStubImpl impl;
};
189 | |
190 | namespace { |
// Registrar helper: constructing an instance installs `fn` as the CUDA
// implementation of `stub` (see REGISTER_CUDA_DISPATCH).
template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    stub.set_cuda_dispatch_ptr(fn);
  }
};
197 | |
// Registrar helper: constructing an instance installs `fn` as the MPS
// implementation of `stub` (see REGISTER_MPS_DISPATCH).
template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    stub.set_mps_dispatch_ptr(fn);
  }
};
204 | |
// Registrar helper for HIP kernels. Deliberately routes through the CUDA
// slot for now, since the HIPify build registers via the CUDA path (see the
// __HIPCC__ branch of REGISTER_DISPATCH).
// TODO: make this point at hip_dispatch_ptr
template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    stub.set_cuda_dispatch_ptr(fn);
  }
};
212 | |
213 | } // anonymous namespace |
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.

// Declares a stub type named `name` (deriving from DispatchStub, so it
// inherits operator() and the registration hooks) plus an extern instance of
// it, also named `name`. Goes in a header.
#define DECLARE_DISPATCH(fn, name) \
  struct name : DispatchStub<fn, name> { \
    name() = default; \
    name(const name&) = delete; \
    name& operator=(const name&) = delete; \
  }; \
  extern TORCH_API struct name name

// Defines the single instance declared by DECLARE_DISPATCH. Goes in one .cpp.
#define DEFINE_DISPATCH(name) struct name name

// Defines the stub's static per-arch function-pointer member (DEFAULT,
// AVX2, ...) for one architecture, initializing it to `fn`.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;

// Per-ISA registration helpers: each expands to nothing when the
// corresponding kernel flavor is not compiled into this build.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn) \
  REGISTER_AVX2_DISPATCH(name, fn) \
  REGISTER_VSX_DISPATCH(name, fn) \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

// Registers nullptr for every CPU arch, for stubs with no CPU implementation.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Device registrations: a file-local registrar object whose constructor runs
// at static-initialization time and stores `fn` in the stub.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## __register(name, fn);

// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// In per-ISA CPU compilations, CPU_CAPABILITY names the arch being built.
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
// Explicitly opts a stub out of AVX512 (falls back to lower capabilities).
#define REGISTER_NO_AVX512_DISPATCH(name) \
  REGISTER_AVX512_DISPATCH(name, nullptr)
#endif
294 | |
295 | |
296 | }} // namespace at::native |
297 | |
298 | |
299 | #if defined(__clang__) |
300 | #pragma clang diagnostic pop |
301 | #endif |
302 | |