1 | #pragma once |
2 | |
3 | #ifndef _TRITON_DRIVER_DISPATCH_H_ |
4 | #define _TRITON_DRIVER_DISPATCH_H_ |
5 | |
#include <dlfcn.h>
#include <string>
#include <type_traits>
8 | |
9 | //CUDA Backend |
10 | #include "triton/external/CUDA/cuda.h" |
11 | #include "triton/external/CUDA/nvml.h" |
12 | |
13 | //// HIP backend |
14 | //#define __HIP_PLATFORM_AMD__ |
15 | #include "triton/external/hip.h" |
16 | |
17 | //Exceptions |
18 | #include <iostream> |
19 | #include <stdexcept> |
20 | |
// Forward declarations of LLVM classes; only the names are introduced here,
// no LLVM header is included by this file.
namespace llvm {
class Module;
class PassRegistry;
}
25 | |
26 | namespace triton |
27 | { |
28 | namespace driver |
29 | { |
30 | |
// Forward declaration only; cu_context is not defined in this header.
class cu_context;
32 | |
// Generic fallback: a result of any type without a dedicated check()
// overload is accepted and ignored (the body is intentionally empty).
template <typename ResultT>
void check(ResultT /*ignored*/) {}
// Overload selected for CUDA driver result codes; only declared in this header.
void check(CUresult err);
// Overload selected for HIP result codes; only declared in this header.
void check(hipError_t err);
36 | |
37 | class dispatch |
38 | { |
39 | protected: |
40 | template <class F> |
41 | struct return_type; |
42 | |
43 | template <class R, class... A> |
44 | struct return_type<R (*)(A...)> |
45 | { typedef R type; }; |
46 | |
47 | typedef bool (*f_init_t)(); |
48 | |
49 | template<f_init_t initializer, typename FunPtrT, typename... Args> |
50 | static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args) |
51 | { |
52 | initializer(); |
53 | if(cache == nullptr){ |
54 | cache = dlsym(lib_h, name); |
55 | if(cache == 0) |
56 | throw std::runtime_error("dlsym unable to load function" ); |
57 | } |
58 | FunPtrT fptr; |
59 | *reinterpret_cast<void **>(&fptr) = cache; |
60 | typename return_type<FunPtrT>::type res = (*fptr)(args...); |
61 | check(res); |
62 | return res; |
63 | } |
64 | |
65 | public: |
66 | static void release(); |
67 | // Nvidia |
68 | static bool nvmlinit(); |
69 | static bool cuinit(); |
70 | // AMD |
71 | static bool hipinit(); |
72 | |
73 | /* ------------------- * |
74 | * CUDA |
75 | * ------------------- */ |
76 | // context management |
77 | static CUresult cuInit(unsigned int Flags); |
78 | static CUresult cuCtxDestroy_v2(CUcontext ctx); |
79 | static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev); |
80 | static CUresult cuCtxPushCurrent_v2(CUcontext ctx); |
81 | static CUresult cuCtxPopCurrent_v2(CUcontext *pctx); |
82 | static CUresult cuCtxGetDevice(CUdevice* result); |
83 | static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags); |
84 | static CUresult cuDriverGetVersion(int *driverVersion); |
85 | // device management |
86 | static CUresult cuDeviceGet(CUdevice *device, int ordinal); |
87 | static CUresult cuDeviceGetName(char *name, int len, CUdevice dev); |
88 | static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev); |
89 | static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); |
90 | static CUresult cuDeviceGetCount(int *count); |
91 | // link management |
92 | static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues); |
93 | static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues); |
94 | static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut); |
95 | static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut); |
96 | static CUresult cuLinkDestroy(CUlinkState state); |
97 | // module management |
98 | static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name); |
99 | static CUresult cuModuleLoad(CUmodule *module, const char *fname); |
100 | static CUresult cuModuleLoadData(CUmodule* module, const void* image); |
101 | static CUresult cuModuleUnload(CUmodule hmod); |
102 | static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues); |
103 | static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); |
104 | // stream management |
105 | static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags); |
106 | static CUresult cuStreamSynchronize(CUstream hStream); |
107 | static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx); |
108 | static CUresult cuStreamDestroy_v2(CUstream hStream); |
109 | static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **); |
110 | // function management |
111 | static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc); |
112 | static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); |
113 | static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); |
114 | // memory management |
115 | static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize); |
116 | static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr); |
117 | static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream); |
118 | static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); |
119 | static CUresult cuMemFree_v2(CUdeviceptr dptr); |
120 | static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); |
121 | static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); |
122 | static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); |
123 | // event management |
124 | static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags); |
125 | static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd); |
126 | static CUresult cuEventRecord(CUevent hEvent, CUstream hStream); |
127 | static CUresult cuEventDestroy_v2(CUevent hEvent); |
128 | |
129 | |
130 | /* ------------------- * |
131 | * NVML |
132 | * ------------------- */ |
133 | static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device); |
134 | static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); |
135 | static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock); |
136 | static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock); |
137 | |
138 | /* ------------------- * |
139 | * HIP |
140 | * ------------------- */ |
141 | // context management |
142 | static hipError_t hipInit(unsigned int Flags); |
143 | static hipError_t hipCtxDestroy(hipCtx_t ctx); |
144 | static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev); |
145 | static hipError_t hipCtxPushCurrent(hipCtx_t ctx); |
146 | static hipError_t hipCtxPopCurrent(hipCtx_t *pctx); |
147 | static hipError_t hipCtxGetDevice(hipDevice_t* result); |
148 | static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags); |
149 | static hipError_t hipDriverGetVersion(int *driverVersion); |
150 | // device management |
151 | static hipError_t hipGetDevice(hipDevice_t *device, int ordinal); |
152 | static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev); |
153 | static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev); |
154 | static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev); |
155 | static hipError_t hipGetDeviceCount(int *count); |
156 | // module management |
157 | static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name); |
158 | static hipError_t hipModuleLoad(hipModule_t *module, const char *fname); |
159 | static hipError_t hipModuleLoadData(hipModule_t* module, const void* image); |
160 | static hipError_t hipModuleUnload(hipModule_t hmod); |
161 | static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues); |
162 | static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name); |
163 | // stream management |
164 | static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags); |
165 | static hipError_t hipStreamSynchronize(hipStream_t hStream); |
166 | static hipError_t hipStreamDestroy(hipStream_t hStream); |
167 | static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **); |
168 | // function management |
169 | static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc); |
170 | static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value); |
171 | static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config); |
172 | // memory management |
173 | static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize); |
174 | static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr); |
175 | static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream); |
176 | static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount); |
177 | static hipError_t hipFree(hipDeviceptr_t dptr); |
178 | static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream); |
179 | static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream); |
180 | static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount); |
181 | // event management |
182 | static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags); |
183 | static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd); |
184 | static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream); |
185 | static hipError_t hipEventDestroy(hipEvent_t hEvent); |
186 | |
187 | |
188 | |
189 | private: |
190 | |
191 | // Libraries |
192 | static void* cuda_; |
193 | static void* nvml_; |
194 | static void* hip_; |
195 | |
196 | |
197 | /* ------------------- * |
198 | * CUDA |
199 | * ------------------- */ |
200 | // context management |
201 | static void* cuCtxGetCurrent_; |
202 | static void* cuCtxSetCurrent_; |
203 | static void* cuCtxDestroy_v2_; |
204 | static void* cuCtxCreate_v2_; |
205 | static void* cuCtxGetDevice_; |
206 | static void* cuCtxPushCurrent_v2_; |
207 | static void* cuCtxPopCurrent_v2_; |
208 | static void* cuCtxEnablePeerAccess_; |
209 | static void* cuDriverGetVersion_; |
210 | static void* cuInit_; |
211 | // device management |
212 | static void* cuDeviceGet_; |
213 | static void* cuDeviceGetName_; |
214 | static void* cuDeviceGetPCIBusId_; |
215 | static void* cuDeviceGetAttribute_; |
216 | static void* cuDeviceGetCount_; |
217 | // link management |
218 | static void* cuLinkAddFile_v2_; |
219 | static void* cuLinkAddData_v2_; |
220 | static void* cuLinkCreate_v2_; |
221 | static void* cuLinkDestroy_; |
222 | static void* cuLinkComplete_; |
223 | // module management |
224 | static void* cuModuleGetGlobal_v2_; |
225 | static void* cuModuleLoad_; |
226 | static void* cuModuleUnload_; |
227 | static void* cuModuleLoadDataEx_; |
228 | static void* cuModuleLoadData_; |
229 | static void* cuModuleGetFunction_; |
230 | // stream management |
231 | static void* cuStreamCreate_; |
232 | static void* cuStreamSynchronize_; |
233 | static void* cuStreamDestroy_v2_; |
234 | static void* cuStreamGetCtx_; |
235 | static void* cuLaunchKernel_; |
236 | // function management |
237 | static void* cuFuncGetAttribute_; |
238 | static void* cuFuncSetAttribute_; |
239 | static void* cuFuncSetCacheConfig_; |
240 | // memory management |
241 | static void* cuMemcpyDtoH_v2_; |
242 | static void* cuMemFree_v2_; |
243 | static void* cuMemcpyDtoHAsync_v2_; |
244 | static void* cuMemcpyHtoDAsync_v2_; |
245 | static void* cuMemcpyHtoD_v2_; |
246 | static void* cuMemAlloc_v2_; |
247 | static void* cuMemsetD8Async_; |
248 | static void* cuPointerGetAttribute_; |
249 | // event management |
250 | static void* cuEventCreate_; |
251 | static void* cuEventElapsedTime_; |
252 | static void* cuEventRecord_; |
253 | static void* cuEventDestroy_v2_; |
254 | |
255 | /* ------------------- * |
256 | * NVML |
257 | * ------------------- */ |
258 | static void* nvmlInit_v2_; |
259 | static void* nvmlDeviceGetHandleByPciBusId_v2_; |
260 | static void* nvmlDeviceGetClockInfo_; |
261 | static void* nvmlDeviceGetMaxClockInfo_; |
262 | static void* nvmlDeviceSetApplicationsClocks_; |
263 | |
264 | /* ------------------- * |
265 | * HIP |
266 | * ------------------- */ |
267 | // context management |
268 | static void* hipInit_; |
269 | static void* hipCtxDestroy_; |
270 | static void* hipCtxCreate_; |
271 | static void* hipCtxPushCurrent_; |
272 | static void* hipCtxPopCurrent_; |
273 | static void* hipCtxGetDevice_; |
274 | static void* hipCtxEnablePeerAccess_; |
275 | static void* hipDriverGetVersion_; |
276 | // device management |
277 | static void* hipGetDevice_; |
278 | static void* hipDeviceGetName_; |
279 | static void* hipDeviceGetPCIBusId_; |
280 | static void* hipDeviceGetAttribute_; |
281 | static void* hipGetDeviceCount_; |
282 | // module management |
283 | static void* hipModuleGetGlobal_; |
284 | static void* hipModuleLoad_; |
285 | static void* hipModuleLoadData_; |
286 | static void* hipModuleUnload_; |
287 | static void* hipModuleLoadDataEx_; |
288 | static void* hipModuleGetFunction_; |
289 | // stream management |
290 | static void* hipStreamCreate_; |
291 | static void* hipStreamSynchronize_; |
292 | static void* hipStreamDestroy_; |
293 | static void* hipModuleLaunchKernel_;; |
294 | // function management |
295 | static void* hipFuncGetAttributes_; |
296 | static void* hipFuncSetAttribute_; |
297 | static void* hipFuncSetCacheConfig_; |
298 | // memory management |
299 | static void* hipMalloc_; |
300 | static void* hipPointerGetAttribute_; |
301 | static void* hipMemsetD8Async_; |
302 | static void* hipMemcpyDtoH_; |
303 | static void* hipFree_; |
304 | static void* hipMemcpyDtoHAsync_; |
305 | static void* hipMemcpyHtoDAsync_; |
306 | static void* hipMemcpyHtoD_; |
307 | // event management |
308 | static void* hipEventCreate_; |
309 | static void* hipEventElapsedTime_; |
310 | static void* hipEventRecord_; |
311 | static void* hipEventDestroy_; |
312 | }; |
313 | |
314 | } |
315 | } |
316 | |
317 | |
318 | #endif |
319 | |