1#pragma once
2
3#ifndef _TRITON_DRIVER_DISPATCH_H_
4#define _TRITON_DRIVER_DISPATCH_H_
5
#include <dlfcn.h>
#include <string>
#include <type_traits>
8
9//CUDA Backend
10#include "triton/external/CUDA/cuda.h"
11#include "triton/external/CUDA/nvml.h"
12
13//// HIP backend
14//#define __HIP_PLATFORM_AMD__
15#include "triton/external/hip.h"
16
17//Exceptions
18#include <iostream>
19#include <stdexcept>
20
// Opaque LLVM types, declared by name only so that this header does not have
// to include any LLVM header.
// NOTE(review): neither type is referenced elsewhere in this header; presumably
// kept for translation units that include it — confirm before removing.
namespace llvm {
  class Module;
  class PassRegistry;
} // namespace llvm
25
26namespace triton
27{
28namespace driver
29{
30
// Forward declaration: only the name of the context class is needed here.
class cu_context;

// Catch-all status checker: accepts a value of any type and ignores it.
// Non-template check() overloads, where declared, are preferred by overload
// resolution for the types they name.
template <typename Status>
void check(Status) {}
// Status checkers for CUDA driver (CUresult) and HIP (hipError_t) return codes.
// Only the declarations live in this header. dispatch::f_impl passes the value
// returned by each dispatched call to check(), so these overloads are the ones
// selected when that value is a CUresult or a hipError_t.
// NOTE(review): presumably these throw on a non-success code — confirm against
// the definitions in the implementation file.
void check(CUresult err);
void check(hipError_t err);
36
37class dispatch
38{
39protected:
40 template <class F>
41 struct return_type;
42
43 template <class R, class... A>
44 struct return_type<R (*)(A...)>
45 { typedef R type; };
46
47 typedef bool (*f_init_t)();
48
49 template<f_init_t initializer, typename FunPtrT, typename... Args>
50 static typename return_type<FunPtrT>::type f_impl(void*& lib_h, FunPtrT, void*& cache, const char * name, Args... args)
51 {
52 initializer();
53 if(cache == nullptr){
54 cache = dlsym(lib_h, name);
55 if(cache == 0)
56 throw std::runtime_error("dlsym unable to load function");
57 }
58 FunPtrT fptr;
59 *reinterpret_cast<void **>(&fptr) = cache;
60 typename return_type<FunPtrT>::type res = (*fptr)(args...);
61 check(res);
62 return res;
63 }
64
65public:
66 static void release();
67 // Nvidia
68 static bool nvmlinit();
69 static bool cuinit();
70 // AMD
71 static bool hipinit();
72
73 /* ------------------- *
74 * CUDA
75 * ------------------- */
76 // context management
77 static CUresult cuInit(unsigned int Flags);
78 static CUresult cuCtxDestroy_v2(CUcontext ctx);
79 static CUresult cuCtxCreate_v2(CUcontext *pctx, unsigned int flags, CUdevice dev);
80 static CUresult cuCtxPushCurrent_v2(CUcontext ctx);
81 static CUresult cuCtxPopCurrent_v2(CUcontext *pctx);
82 static CUresult cuCtxGetDevice(CUdevice* result);
83 static CUresult cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int flags);
84 static CUresult cuDriverGetVersion(int *driverVersion);
85 // device management
86 static CUresult cuDeviceGet(CUdevice *device, int ordinal);
87 static CUresult cuDeviceGetName(char *name, int len, CUdevice dev);
88 static CUresult cuDeviceGetPCIBusId(char *id, int len, CUdevice dev);
89 static CUresult cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev);
90 static CUresult cuDeviceGetCount(int *count);
91 // link management
92 static CUresult cuLinkAddFile_v2(CUlinkState state, CUjitInputType type, const char *path, unsigned int numOptions, CUjit_option *options, void **optionValues);
93 static CUresult cuLinkAddData_v2(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues);
94 static CUresult cuLinkCreate_v2(unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut);
95 static CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut);
96 static CUresult cuLinkDestroy(CUlinkState state);
97 // module management
98 static CUresult cuModuleGetGlobal_v2(CUdeviceptr *dptr, size_t* bytes, CUmodule hmod, const char *name);
99 static CUresult cuModuleLoad(CUmodule *module, const char *fname);
100 static CUresult cuModuleLoadData(CUmodule* module, const void* image);
101 static CUresult cuModuleUnload(CUmodule hmod);
102 static CUresult cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues);
103 static CUresult cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name);
104 // stream management
105 static CUresult cuStreamCreate(CUstream *phStream, unsigned int Flags);
106 static CUresult cuStreamSynchronize(CUstream hStream);
107 static CUresult cuStreamGetCtx(CUstream hStream, CUcontext* pctx);
108 static CUresult cuStreamDestroy_v2(CUstream hStream);
109 static CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra);
110 // function management
111 static CUresult cuFuncGetAttribute(int* pi, CUfunction_attribute attrib, CUfunction hfunc);
112 static CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value);
113 static CUresult cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config);
114 // memory management
115 static CUresult cuMemAlloc_v2(CUdeviceptr *dptr, size_t bytesize);
116 static CUresult cuPointerGetAttribute(void * data, CUpointer_attribute attribute, CUdeviceptr ptr);
117 static CUresult cuMemsetD8Async(CUdeviceptr dst, unsigned char x, size_t N, CUstream stream);
118 static CUresult cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount);
119 static CUresult cuMemFree_v2(CUdeviceptr dptr);
120 static CUresult cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream);
121 static CUresult cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream);
122 static CUresult cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount);
123 // event management
124 static CUresult cuEventCreate(CUevent *phEvent, unsigned int Flags);
125 static CUresult cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd);
126 static CUresult cuEventRecord(CUevent hEvent, CUstream hStream);
127 static CUresult cuEventDestroy_v2(CUevent hEvent);
128
129
130 /* ------------------- *
131 * NVML
132 * ------------------- */
133 static nvmlReturn_t nvmlDeviceGetHandleByPciBusId_v2( const char* pciBusId, nvmlDevice_t* device);
134 static nvmlReturn_t nvmlDeviceGetClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
135 static nvmlReturn_t nvmlDeviceGetMaxClockInfo(nvmlDevice_t device, nvmlClockType_t type, unsigned int *clock);
136 static nvmlReturn_t nvmlDeviceSetApplicationsClocks(nvmlDevice_t device, unsigned int mem_clock, unsigned int sm_clock);
137
138 /* ------------------- *
139 * HIP
140 * ------------------- */
141 // context management
142 static hipError_t hipInit(unsigned int Flags);
143 static hipError_t hipCtxDestroy(hipCtx_t ctx);
144 static hipError_t hipCtxCreate(hipCtx_t *pctx, unsigned int flags, hipDevice_t dev);
145 static hipError_t hipCtxPushCurrent(hipCtx_t ctx);
146 static hipError_t hipCtxPopCurrent(hipCtx_t *pctx);
147 static hipError_t hipCtxGetDevice(hipDevice_t* result);
148 static hipError_t hipCtxEnablePeerAccess(hipCtx_t peerContext, unsigned int flags);
149 static hipError_t hipDriverGetVersion(int *driverVersion);
150 // device management
151 static hipError_t hipGetDevice(hipDevice_t *device, int ordinal);
152 static hipError_t hipDeviceGetName(char *name, int len, hipDevice_t dev);
153 static hipError_t hipDeviceGetPCIBusId(char *id, int len, hipDevice_t dev);
154 static hipError_t hipDeviceGetAttribute(int *pi, hipDeviceAttribute_t attrib, hipDevice_t dev);
155 static hipError_t hipGetDeviceCount(int *count);
156 // module management
157 static hipError_t hipModuleGetGlobal(hipDeviceptr_t *dptr, size_t* bytes, hipModule_t hmod, const char *name);
158 static hipError_t hipModuleLoad(hipModule_t *module, const char *fname);
159 static hipError_t hipModuleLoadData(hipModule_t* module, const void* image);
160 static hipError_t hipModuleUnload(hipModule_t hmod);
161 static hipError_t hipModuleLoadDataEx(hipModule_t *module, const void *image, unsigned int numOptions, hipJitOption *options, void **optionValues);
162 static hipError_t hipModuleGetFunction(hipFunction_t *hfunc, hipModule_t hmod, const char *name);
163 // stream management
164 static hipError_t hipStreamCreate(hipStream_t *phStream, unsigned int Flags);
165 static hipError_t hipStreamSynchronize(hipStream_t hStream);
166 static hipError_t hipStreamDestroy(hipStream_t hStream);
167 static hipError_t hipModuleLaunchKernel(hipFunction_t f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, hipStream_t hStream, void **kernelParams, void **extra);
168 // function management
169 static hipError_t hipFuncGetAttributes(hipFuncAttributes* attrib, void* hfunc);
170 static hipError_t hipFuncSetAttribute(hipFunction_t hfunc, hipFuncAttribute attrib, int value);
171 static hipError_t hipFuncSetCacheConfig(hipFunction_t hfunc, hipFuncCache_t config);
172 // memory management
173 static hipError_t hipMalloc(hipDeviceptr_t *dptr, size_t bytesize);
174 static hipError_t hipPointerGetAttribute(void * data, CUpointer_attribute attribute, hipDeviceptr_t ptr);
175 static hipError_t hipMemsetD8Async(hipDeviceptr_t dst, unsigned char x, size_t N, hipStream_t stream);
176 static hipError_t hipMemcpyDtoH(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount);
177 static hipError_t hipFree(hipDeviceptr_t dptr);
178 static hipError_t hipMemcpyDtoHAsync(void *dstHost, hipDeviceptr_t srcDevice, size_t ByteCount, hipStream_t hStream);
179 static hipError_t hipMemcpyHtoDAsync(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount, hipStream_t hStream);
180 static hipError_t hipMemcpyHtoD(hipDeviceptr_t dstDevice, const void *srcHost, size_t ByteCount);
181 // event management
182 static hipError_t hipEventCreate(hipEvent_t *phEvent, unsigned int Flags);
183 static hipError_t hipEventElapsedTime(float *pMilliseconds, hipEvent_t hStart, hipEvent_t hEnd);
184 static hipError_t hipEventRecord(hipEvent_t hEvent, hipStream_t hStream);
185 static hipError_t hipEventDestroy(hipEvent_t hEvent);
186
187
188
189private:
190
191 // Libraries
192 static void* cuda_;
193 static void* nvml_;
194 static void* hip_;
195
196
197 /* ------------------- *
198 * CUDA
199 * ------------------- */
200 // context management
201 static void* cuCtxGetCurrent_;
202 static void* cuCtxSetCurrent_;
203 static void* cuCtxDestroy_v2_;
204 static void* cuCtxCreate_v2_;
205 static void* cuCtxGetDevice_;
206 static void* cuCtxPushCurrent_v2_;
207 static void* cuCtxPopCurrent_v2_;
208 static void* cuCtxEnablePeerAccess_;
209 static void* cuDriverGetVersion_;
210 static void* cuInit_;
211 // device management
212 static void* cuDeviceGet_;
213 static void* cuDeviceGetName_;
214 static void* cuDeviceGetPCIBusId_;
215 static void* cuDeviceGetAttribute_;
216 static void* cuDeviceGetCount_;
217 // link management
218 static void* cuLinkAddFile_v2_;
219 static void* cuLinkAddData_v2_;
220 static void* cuLinkCreate_v2_;
221 static void* cuLinkDestroy_;
222 static void* cuLinkComplete_;
223 // module management
224 static void* cuModuleGetGlobal_v2_;
225 static void* cuModuleLoad_;
226 static void* cuModuleUnload_;
227 static void* cuModuleLoadDataEx_;
228 static void* cuModuleLoadData_;
229 static void* cuModuleGetFunction_;
230 // stream management
231 static void* cuStreamCreate_;
232 static void* cuStreamSynchronize_;
233 static void* cuStreamDestroy_v2_;
234 static void* cuStreamGetCtx_;
235 static void* cuLaunchKernel_;
236 // function management
237 static void* cuFuncGetAttribute_;
238 static void* cuFuncSetAttribute_;
239 static void* cuFuncSetCacheConfig_;
240 // memory management
241 static void* cuMemcpyDtoH_v2_;
242 static void* cuMemFree_v2_;
243 static void* cuMemcpyDtoHAsync_v2_;
244 static void* cuMemcpyHtoDAsync_v2_;
245 static void* cuMemcpyHtoD_v2_;
246 static void* cuMemAlloc_v2_;
247 static void* cuMemsetD8Async_;
248 static void* cuPointerGetAttribute_;
249 // event management
250 static void* cuEventCreate_;
251 static void* cuEventElapsedTime_;
252 static void* cuEventRecord_;
253 static void* cuEventDestroy_v2_;
254
255 /* ------------------- *
256 * NVML
257 * ------------------- */
258 static void* nvmlInit_v2_;
259 static void* nvmlDeviceGetHandleByPciBusId_v2_;
260 static void* nvmlDeviceGetClockInfo_;
261 static void* nvmlDeviceGetMaxClockInfo_;
262 static void* nvmlDeviceSetApplicationsClocks_;
263
264 /* ------------------- *
265 * HIP
266 * ------------------- */
267 // context management
268 static void* hipInit_;
269 static void* hipCtxDestroy_;
270 static void* hipCtxCreate_;
271 static void* hipCtxPushCurrent_;
272 static void* hipCtxPopCurrent_;
273 static void* hipCtxGetDevice_;
274 static void* hipCtxEnablePeerAccess_;
275 static void* hipDriverGetVersion_;
276 // device management
277 static void* hipGetDevice_;
278 static void* hipDeviceGetName_;
279 static void* hipDeviceGetPCIBusId_;
280 static void* hipDeviceGetAttribute_;
281 static void* hipGetDeviceCount_;
282 // module management
283 static void* hipModuleGetGlobal_;
284 static void* hipModuleLoad_;
285 static void* hipModuleLoadData_;
286 static void* hipModuleUnload_;
287 static void* hipModuleLoadDataEx_;
288 static void* hipModuleGetFunction_;
289 // stream management
290 static void* hipStreamCreate_;
291 static void* hipStreamSynchronize_;
292 static void* hipStreamDestroy_;
293 static void* hipModuleLaunchKernel_;;
294 // function management
295 static void* hipFuncGetAttributes_;
296 static void* hipFuncSetAttribute_;
297 static void* hipFuncSetCacheConfig_;
298 // memory management
299 static void* hipMalloc_;
300 static void* hipPointerGetAttribute_;
301 static void* hipMemsetD8Async_;
302 static void* hipMemcpyDtoH_;
303 static void* hipFree_;
304 static void* hipMemcpyDtoHAsync_;
305 static void* hipMemcpyHtoDAsync_;
306 static void* hipMemcpyHtoD_;
307 // event management
308 static void* hipEventCreate_;
309 static void* hipEventElapsedTime_;
310 static void* hipEventRecord_;
311 static void* hipEventDestroy_;
312};
313
314}
315}
316
317
318#endif
319