1#pragma once
2
3#include <c10/core/DeviceType.h>
4#include <c10/macros/Export.h>
5
6#include <atomic>
7#include <utility>
8
9// Implements instruction set specific function dispatch.
10//
11// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
12// compiled multiple times with different compiler flags (e.g. -mavx2). A
13// DispatchStub contains a table of function pointers for a kernel. At runtime,
14// the fastest available kernel is chosen based on the features reported by
15// cpuinfo.
16//
17// Example:
18//
19// In native/MyKernel.h:
20// using fn_type = void(*)(const Tensor& x);
21// DECLARE_DISPATCH(fn_type, stub);
22//
23// In native/MyKernel.cpp
24// DEFINE_DISPATCH(stub);
25//
26// In native/cpu/MyKernel.cpp:
27// namespace {
28// // use anonymous namespace so that different cpu versions won't conflict
29// void kernel(const Tensor& x) { ... }
30// }
31// REGISTER_DISPATCH(stub, &kernel);
32//
33// To call:
34// stub(kCPU, tensor);
35//
36// TODO: CPU instruction set selection should be folded into whatever
37// the main dispatch mechanism is.
38
// ignore warnings about DispatchStub::DEFAULT, AVX512, AVX2, etc. defined elsewhere
40#if defined(__clang__)
41#pragma clang diagnostic push
42#pragma clang diagnostic ignored "-Wundefined-var-template"
43#endif
44
45namespace at { namespace native {
46
// Ranks the instruction-set variants a CPU kernel may be compiled for.
// Exactly one family is visible per build: VSX, ZVECTOR, or the AVX2/AVX512
// pair. Higher values correspond to more specialized kernels (selection
// order is implemented by DispatchStubImpl::choose_cpu_impl).
// NUM_OPTIONS is a count sentinel, not a real capability.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};
59
// Returns the capability detected at runtime (per the file header, based on
// features reported by cpuinfo); used to pick the fastest available kernel.
CPUCapability get_cpu_capability();

// Primary template is only declared; the sole definition is the
// function-pointer partial specialization below.
template <typename FnPtr, typename T>
struct DispatchStub;
64
/**
 * The sole purpose of this class is to outline methods that don't need to be
 * specialized or otherwise inlined and duplicated (by the compiler due to
 * template expansion), since it causes size bloat if there are a significant
 * number of specialization of the DispatchStub<> class.
 */
struct TORCH_API DispatchStubImpl {
  // Resolves the kernel pointer to call for `device_type`. The per-arch CPU
  // candidates are passed explicitly (as type-erased void*) so this helper can
  // live out of line, template-free; the #ifdef layout here must stay in
  // positional lockstep with the caller in DispatchStub::get_call_ptr below.
  void* get_call_ptr(
    DeviceType device_type
    , void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  /**
   * The CPU Dispatch actual method is chosen in decreasing order of preference by
   * DispatchStubImpl::choose_cpu_impl() in case none is found by
   * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr.
   */
  void* choose_cpu_impl(
    void *DEFAULT
#ifdef HAVE_AVX512_CPU_DEFINITION
    , void *AVX512
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
    , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
    , void *VSX
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
    , void *ZVECTOR
#endif
  );

  // Fixing dispatch error in Windows debug builds.
  // See https://github.com/pytorch/pytorch/issues/22681 for more details.
  #if defined(_MSC_VER) && defined(_DEBUG)
  std::atomic<void*> cpu_dispatch_ptr;   // cached resolved CPU kernel
  void* cuda_dispatch_ptr;               // written by set_cuda_dispatch_ptr
  void* hip_dispatch_ptr;                // written by set_hip_dispatch_ptr
  void* mps_dispatch_ptr;                // written by set_mps_dispatch_ptr
  #else
  // cpu_dispatch_ptr is atomic: it is filled in lazily on first dispatch and
  // may be raced by multiple threads; the device pointers are set once at
  // static-registration time.
  std::atomic<void*> cpu_dispatch_ptr{nullptr};
  void* cuda_dispatch_ptr = nullptr;
  void* hip_dispatch_ptr = nullptr;
  void* mps_dispatch_ptr = nullptr;
  #endif
};
124
// Partial specialization of DispatchStub for function-pointer kernel types.
// `T` is the concrete stub struct produced by DECLARE_DISPATCH; it makes each
// stub a distinct type, so each stub gets its own set of static per-arch
// kernel pointers (DEFAULT/AVX512/AVX2/VSX/ZVECTOR) below.
template <typename rT, typename T, typename... Args>
struct DispatchStub<rT (*)(Args...), T> {
  using FnPtr = rT (*) (Args...);

  DispatchStub() = default;
  DispatchStub(const DispatchStub&) = delete;
  DispatchStub& operator=(const DispatchStub&) = delete;

private:
  // Typed wrapper over DispatchStubImpl::get_call_ptr: passes the per-arch
  // static kernel pointers type-erased to void*, then casts the chosen one
  // back to FnPtr. The #ifdef chain must match the declaration order in
  // DispatchStubImpl exactly, since the arguments are positional.
  FnPtr get_call_ptr(DeviceType device_type) {
    return reinterpret_cast<FnPtr>(
      impl.get_call_ptr(device_type
      , reinterpret_cast<void*>(DEFAULT)
#ifdef HAVE_AVX512_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX512)
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
      , reinterpret_cast<void*>(AVX2)
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
      , reinterpret_cast<void*>(VSX)
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
      , reinterpret_cast<void*>(ZVECTOR)
#endif
      )
    );
  }

public:
  // Dispatches to the kernel registered for `device_type`, perfectly
  // forwarding the arguments. NOTE(review): call_ptr is dereferenced
  // unconditionally here — presumably DispatchStubImpl::get_call_ptr errors
  // out when no kernel is registered; confirm in its implementation.
  template <typename... ArgTypes>
  rT operator()(DeviceType device_type, ArgTypes&&... args) {
    FnPtr call_ptr = get_call_ptr(device_type);
    return (*call_ptr)(std::forward<ArgTypes>(args)...);
  }

  // The set_*_dispatch_ptr methods are invoked at static-initialization time
  // by the Register*Dispatch helper objects (see REGISTER_*_DISPATCH macros).
  void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
    impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_hip_dispatch_ptr(FnPtr fn_ptr) {
    impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  void set_mps_dispatch_ptr(FnPtr fn_ptr) {
    impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
  }

  // Per-arch kernel pointers. Each is a static member of this template
  // specialization, defined once per stub in the arch-specific translation
  // unit via REGISTER_ARCH_DISPATCH (hence the -Wundefined-var-template
  // suppression at the top of this file).
  static TORCH_API FnPtr DEFAULT;
#ifdef HAVE_AVX512_CPU_DEFINITION
  static TORCH_API FnPtr AVX512;
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  static TORCH_API FnPtr AVX2;
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  static TORCH_API FnPtr VSX;
#endif
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
  static TORCH_API FnPtr ZVECTOR;
#endif
private:
  DispatchStubImpl impl;
};
189
namespace {
// Each helper below exists solely so that REGISTER_*_DISPATCH can declare a
// file-scope static whose constructor performs the registration at
// static-initialization time.

// Installs `fn` as the CUDA kernel of `stub`.
template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    stub.set_cuda_dispatch_ptr(fn);
  }
};

// Installs `fn` as the MPS kernel of `stub`.
template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    stub.set_mps_dispatch_ptr(fn);
  }
};

// Installs `fn` for HIP builds.
template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub& stub, typename DispatchStub::FnPtr fn) {
    // TODO: make this point at hip_dispatch_ptr
    stub.set_cuda_dispatch_ptr(fn);
  }
};

} // anonymous namespace
// Compiler will complain if you put things like std::tuple<Tensor, Tensor> in
// the `fn` argument of DECLARE_DISPATCH. Some possible workarounds, e.g.,
// adding parentheses and using helper struct to get rid of the parentheses, do
// not work with MSVC. So do a `using`-declaration if you need to pass in such
// `fn`, e.g., grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
//
// Declares a stub named `name` for kernels with signature `fn` (put this in a
// header). The wrapper struct gives each stub its own distinct type — needed
// so REGISTER_ARCH_DISPATCH can define per-stub static members — and the
// trailing extern declares the single global stub object.
#define DECLARE_DISPATCH(fn, name) \
  struct name : DispatchStub<fn, name> { \
    name() = default; \
    name(const name&) = delete; \
    name& operator=(const name&) = delete; \
  }; \
  extern TORCH_API struct name name

// Provides the one definition of the stub object declared above; use in
// exactly one .cpp file.
#define DEFINE_DISPATCH(name) struct name name
228
// Defines the static per-arch member (e.g. name::AVX2) of the stub's
// DispatchStub specialization, initializing it to `fn`. Emitted once per
// (stub, arch) pair from the arch-specific translation unit.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;

// The per-arch helpers below expand to a real registration only when the
// build defines the matching HAVE_*_CPU_DEFINITION; otherwise they are no-ops.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
255
// Macro to register the same kernel for all CPU arch types. This is useful
// if a kernel does not benefit from being recompiled across different arch types.
// (Arch slots whose HAVE_*_CPU_DEFINITION is absent expand to nothing.)
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn) \
  REGISTER_AVX2_DISPATCH(name, fn) \
  REGISTER_VSX_DISPATCH(name, fn) \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

// Marks a stub as having no CPU implementation by registering nullptr in
// every CPU arch slot.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)

// Device registrations: each expands to a file-scope static whose constructor
// stores `fn` into the stub's device slot at static-initialization time.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## __register(name, fn);

// NOTE(review): RegisterHIPDispatch currently writes the CUDA slot (see its
// TODO), so this macro does not yet populate hip_dispatch_ptr.
#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## __register(name, fn);
276
// NB: This macro must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
//
// REGISTER_DISPATCH resolves to the appropriate backend registration based on
// which compiler/flags are processing the current translation unit.
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// CPU kernel TUs: CPU_CAPABILITY names the arch this TU is compiled for, so
// the same REGISTER_DISPATCH line fills a different arch slot per TU.
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
// Opts a stub out of AVX512 by registering nullptr in that slot (selection
// then falls through to the next preference; see choose_cpu_impl's doc).
#define REGISTER_NO_AVX512_DISPATCH(name) \
  REGISTER_AVX512_DISPATCH(name, nullptr)
#endif
294
295
296}} // namespace at::native
297
298
299#if defined(__clang__)
300#pragma clang diagnostic pop
301#endif
302