1/* Copyright 2015-2017 Philippe Tillet
2*
3* Permission is hereby granted, free of charge, to any person obtaining
4* a copy of this software and associated documentation files
5* (the "Software"), to deal in the Software without restriction,
6* including without limitation the rights to use, copy, modify, merge,
7* publish, distribute, sublicense, and/or sell copies of the Software,
8* and to permit persons to whom the Software is furnished to do so,
9* subject to the following conditions:
10*
11* The above copyright notice and this permission notice shall be
12* included in all copies or substantial portions of the Software.
13*
14* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
15* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
16* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
17* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
18* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
19* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
20* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
21*/
22
23#include "triton/driver/dispatch.h"
24
25namespace triton
26{
27namespace driver
28{
29
// Helpers for function definition.
//
// DEFINEn(init, hlib, ret, fname, t1, ..., tn) expands to two definitions:
//   1. the static member function `ret dispatch::fname(t1 a, ..., tn ...)`,
//      whose body forwards every argument to
//      `f_impl<dispatch::init>(hlib, fname, fname_, "fname", a, ...)`;
//   2. the static data member `void* dispatch::fname_`, passed to f_impl
//      alongside the stringized symbol name (presumably the slot f_impl uses to
//      cache the resolved address -- confirm against f_impl in dispatch.h).
//
// Parameters:
//   init  - name of the dispatch initializer used as f_impl's template argument
//   hlib  - library handle expression passed as f_impl's first argument
//   ret   - return type of the generated function
//   fname - function name; also pasted into `fname_` and stringized
//   t1..  - parameter types, in order
//
// Each expansion already ends with a `;`, so invocations need no trailing
// semicolon. Only arities 0-11, 13 and 19 are provided.
#define DEFINE0(init, hlib, ret, fname) ret dispatch::fname()\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname); }\
void* dispatch::fname ## _;

#define DEFINE1(init, hlib, ret, fname, t1) ret dispatch::fname(t1 a)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a); }\
void* dispatch::fname ## _;

#define DEFINE2(init, hlib, ret, fname, t1, t2) ret dispatch::fname(t1 a, t2 b)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b); }\
void* dispatch::fname ## _;

#define DEFINE3(init, hlib, ret, fname, t1, t2, t3) ret dispatch::fname(t1 a, t2 b, t3 c)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c); }\
void* dispatch::fname ## _;

#define DEFINE4(init, hlib, ret, fname, t1, t2, t3, t4) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d); }\
void* dispatch::fname ## _;

#define DEFINE5(init, hlib, ret, fname, t1, t2, t3, t4, t5) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e); }\
void* dispatch::fname ## _;

#define DEFINE6(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f); }\
void* dispatch::fname ## _;

#define DEFINE7(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g); }\
void* dispatch::fname ## _;

#define DEFINE8(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h); }\
void* dispatch::fname ## _;

#define DEFINE9(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i); }\
void* dispatch::fname ## _;

#define DEFINE10(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j); }\
void* dispatch::fname ## _;

#define DEFINE11(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k); }\
void* dispatch::fname ## _;

// Note: there is no DEFINE12; the family jumps from 11 to 13 parameters.
#define DEFINE13(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m); }\
void* dispatch::fname ## _;

// Note: arities 14-18 are likewise not defined.
#define DEFINE19(init, hlib, ret, fname, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13, t14, t15, t16, t17, t18, t19) ret dispatch::fname(t1 a, t2 b, t3 c, t4 d, t5 e, t6 f, t7 g, t8 h, t9 i, t10 j, t11 k, t12 l, t13 m, t14 n, t15 o, t16 p, t17 q, t18 r, t19 s)\
{return f_impl<dispatch::init>(hlib, fname, fname ## _, #fname, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s); }\
void* dispatch::fname ## _;
86
87
88/* ------------------- *
89 * CUDA
90 * ------------------- */
91
92bool dispatch::cuinit(){
93 if(cuda_==nullptr){
94 #ifdef _WIN32
95 cuda_ = dlopen("cudart64_110.dll", RTLD_LAZY);
96 #else
97 cuda_ = dlopen("libcuda.so", RTLD_LAZY);
98 if(!cuda_)
99 cuda_ = dlopen("libcuda.so.1", RTLD_LAZY);
100 #endif
101 if(!cuda_)
102 throw std::runtime_error("Could not find `libcuda.so`. Make sure it is in your LD_LIBRARY_PATH.");
103 }
104 if(cuda_ == nullptr)
105 return false;
106 CUresult (*fptr)(unsigned int);
107 cuInit_ = dlsym(cuda_, "cuInit");
108 *reinterpret_cast<void **>(&fptr) = cuInit_;
109 CUresult res = (*fptr)(0);
110 check(res);
111 return true;
112}
113
// CUDA flavour of the DEFINEn helpers: every wrapper pins the initializer to
// `cuinit` and the library handle to `cuda_`, then forwards the return type,
// the function name and the parameter-type list untouched to DEFINEn.
#define CUDA_DEFINE1(ret, fname, ...)  DEFINE1(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE2(ret, fname, ...)  DEFINE2(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE3(ret, fname, ...)  DEFINE3(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE4(ret, fname, ...)  DEFINE4(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE5(ret, fname, ...)  DEFINE5(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE6(ret, fname, ...)  DEFINE6(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE7(ret, fname, ...)  DEFINE7(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE8(ret, fname, ...)  DEFINE8(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE9(ret, fname, ...)  DEFINE9(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE10(ret, fname, ...) DEFINE10(cuinit, cuda_, ret, fname, __VA_ARGS__)
#define CUDA_DEFINE11(ret, fname, ...) DEFINE11(cuinit, cuda_, ret, fname, __VA_ARGS__)
125
126// context management
127CUDA_DEFINE1(CUresult, cuCtxDestroy_v2, CUcontext)
128CUDA_DEFINE3(CUresult, cuCtxCreate_v2, CUcontext *, unsigned int, CUdevice)
129CUDA_DEFINE1(CUresult, cuCtxGetDevice, CUdevice*)
130CUDA_DEFINE2(CUresult, cuCtxEnablePeerAccess, CUcontext, unsigned int)
131CUDA_DEFINE1(CUresult, cuInit, unsigned int)
132CUDA_DEFINE1(CUresult, cuDriverGetVersion, int *)
133// device management
134CUDA_DEFINE2(CUresult, cuDeviceGet, CUdevice *, int)
135CUDA_DEFINE3(CUresult, cuDeviceGetName, char *, int, CUdevice)
136CUDA_DEFINE3(CUresult, cuDeviceGetPCIBusId, char *, int, CUdevice)
137CUDA_DEFINE3(CUresult, cuDeviceGetAttribute, int *, CUdevice_attribute, CUdevice)
138CUDA_DEFINE1(CUresult, cuDeviceGetCount, int*)
139
140// link management
141CUDA_DEFINE6(CUresult, cuLinkAddFile_v2, CUlinkState, CUjitInputType, const char *, unsigned int , CUjit_option *, void **);
142CUDA_DEFINE8(CUresult, cuLinkAddData_v2, CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
143CUDA_DEFINE4(CUresult, cuLinkCreate_v2, unsigned int, CUjit_option*, void**, CUlinkState*);
144CUDA_DEFINE1(CUresult, cuLinkDestroy, CUlinkState);
145CUDA_DEFINE3(CUresult, cuLinkComplete, CUlinkState, void**, size_t*);
146// module management
147CUDA_DEFINE4(CUresult, cuModuleGetGlobal_v2, CUdeviceptr*, size_t*, CUmodule, const char*)
148CUDA_DEFINE2(CUresult, cuModuleLoad, CUmodule *, const char *)
149CUDA_DEFINE1(CUresult, cuModuleUnload, CUmodule)
150CUDA_DEFINE2(CUresult, cuModuleLoadData, CUmodule *, const void *)
151CUDA_DEFINE5(CUresult, cuModuleLoadDataEx, CUmodule *, const void *, unsigned int, CUjit_option *, void **)
152CUDA_DEFINE3(CUresult, cuModuleGetFunction, CUfunction *, CUmodule, const char *)
153// stream management
154CUDA_DEFINE2(CUresult, cuStreamCreate, CUstream *, unsigned int)
155CUDA_DEFINE1(CUresult, cuStreamSynchronize, CUstream)
156CUDA_DEFINE1(CUresult, cuStreamDestroy_v2, CUstream)
157CUDA_DEFINE2(CUresult, cuStreamGetCtx, CUstream, CUcontext*)
158CUDA_DEFINE11(CUresult, cuLaunchKernel, CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, CUstream, void **, void **)
159// function management
160CUDA_DEFINE3(CUresult, cuFuncGetAttribute, int*, CUfunction_attribute, CUfunction)
161CUDA_DEFINE3(CUresult, cuFuncSetAttribute, CUfunction, CUfunction_attribute, int)
162CUDA_DEFINE2(CUresult, cuFuncSetCacheConfig, CUfunction, CUfunc_cache)
163// memory management
164CUDA_DEFINE3(CUresult, cuMemcpyDtoH_v2, void *, CUdeviceptr, size_t)
165CUDA_DEFINE1(CUresult, cuMemFree_v2, CUdeviceptr)
166CUDA_DEFINE4(CUresult, cuMemcpyDtoHAsync_v2, void *, CUdeviceptr, size_t, CUstream)
167CUDA_DEFINE4(CUresult, cuMemcpyHtoDAsync_v2, CUdeviceptr, const void *, size_t, CUstream)
168CUDA_DEFINE3(CUresult, cuMemcpyHtoD_v2, CUdeviceptr, const void *, size_t )
169CUDA_DEFINE2(CUresult, cuMemAlloc_v2, CUdeviceptr*, size_t)
170CUDA_DEFINE3(CUresult, cuPointerGetAttribute, void*, CUpointer_attribute, CUdeviceptr)
171CUDA_DEFINE4(CUresult, cuMemsetD8Async, CUdeviceptr, unsigned char, size_t, CUstream)
172// event management
173CUDA_DEFINE2(CUresult, cuEventCreate, CUevent *, unsigned int)
174CUDA_DEFINE3(CUresult, cuEventElapsedTime, float *, CUevent, CUevent)
175CUDA_DEFINE2(CUresult, cuEventRecord, CUevent, CUstream)
176CUDA_DEFINE1(CUresult, cuEventDestroy_v2, CUevent)
177
178
179
180/* ------------------- *
181 * NVML
182 * ------------------- */
183bool dispatch::nvmlinit(){
184 #ifdef _WIN32
185 if(nvml_==nullptr)
186 nvml_ = dlopen("nvml.dll", RTLD_LAZY);
187 #else
188 if(nvml_==nullptr)
189 nvml_ = dlopen("libnvidia-ml.so", RTLD_LAZY);
190 #endif
191 nvmlReturn_t (*fptr)();
192 nvmlInit_v2_ = dlsym(nvml_, "nvmlInit_v2");
193 *reinterpret_cast<void **>(&fptr) = nvmlInit_v2_;
194 nvmlReturn_t res = (*fptr)();
195 check(res);
196 return res;
197}
198
// NVML flavour of the DEFINEn helpers: the initializer is fixed to `nvmlinit`
// and the library handle to `nvml_`; the return type, function name and
// parameter types are forwarded unchanged.
#define NVML_DEFINE0(ret, fname)      DEFINE0(nvmlinit, nvml_, ret, fname)
#define NVML_DEFINE1(ret, fname, ...) DEFINE1(nvmlinit, nvml_, ret, fname, __VA_ARGS__)
#define NVML_DEFINE2(ret, fname, ...) DEFINE2(nvmlinit, nvml_, ret, fname, __VA_ARGS__)
#define NVML_DEFINE3(ret, fname, ...) DEFINE3(nvmlinit, nvml_, ret, fname, __VA_ARGS__)
203
// NVML forwarders. Each line generates `dispatch::<name>` with the listed
// return type and parameter types, plus its `<name>_` cache slot.
NVML_DEFINE2(nvmlReturn_t, nvmlDeviceGetHandleByPciBusId_v2, const char *, nvmlDevice_t*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceGetMaxClockInfo, nvmlDevice_t, nvmlClockType_t, unsigned int*)
NVML_DEFINE3(nvmlReturn_t, nvmlDeviceSetApplicationsClocks, nvmlDevice_t, unsigned int, unsigned int)
208
209/* ------------------- *
210 * HIP
211 * ------------------- */
212bool dispatch::hipinit(){
213 if(hip_==nullptr)
214 hip_ = dlopen("libamdhip64.so", RTLD_LAZY);
215 if(hip_ == nullptr)
216 return false;
217 hipError_t (*fptr)();
218 hipInit_ = dlsym(hip_, "hipInit");
219 *reinterpret_cast<void **>(&fptr) = hipInit_;
220 hipError_t res = (*fptr)();
221 check(res);
222 return res;
223}
224
// HIP flavour of the DEFINEn helpers: the initializer is fixed to `hipinit`
// and the library handle to `hip_`; the return type, function name and
// parameter-type list are forwarded unchanged to DEFINEn.
#define HIP_DEFINE1(ret, fname, ...)  DEFINE1(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE2(ret, fname, ...)  DEFINE2(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE3(ret, fname, ...)  DEFINE3(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE4(ret, fname, ...)  DEFINE4(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE5(ret, fname, ...)  DEFINE5(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE6(ret, fname, ...)  DEFINE6(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE7(ret, fname, ...)  DEFINE7(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE8(ret, fname, ...)  DEFINE8(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE9(ret, fname, ...)  DEFINE9(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE10(ret, fname, ...) DEFINE10(hipinit, hip_, ret, fname, __VA_ARGS__)
#define HIP_DEFINE11(ret, fname, ...) DEFINE11(hipinit, hip_, ret, fname, __VA_ARGS__)
236
// HIP forwarders. Each line generates `dispatch::<name>` with the listed
// return type and parameter types, plus its `<name>_` cache slot.

// context management
HIP_DEFINE1(hipError_t, hipCtxDestroy, hipCtx_t)
HIP_DEFINE3(hipError_t, hipCtxCreate, hipCtx_t *, unsigned int, hipDevice_t)
HIP_DEFINE1(hipError_t, hipCtxGetDevice, hipDevice_t*)
HIP_DEFINE1(hipError_t, hipCtxPushCurrent, hipCtx_t)
HIP_DEFINE1(hipError_t, hipCtxPopCurrent, hipCtx_t*)
HIP_DEFINE2(hipError_t, hipCtxEnablePeerAccess, hipCtx_t, unsigned int)
HIP_DEFINE1(hipError_t, hipInit, unsigned int)
HIP_DEFINE1(hipError_t, hipDriverGetVersion, int *)
// device management
// NOTE(review): the (hipDevice_t*, int) parameter list looks like the shape of
// `hipDeviceGet` rather than `hipGetDevice` -- confirm against the HIP headers
// and the declaration in dispatch.h before relying on this forwarder.
HIP_DEFINE2(hipError_t, hipGetDevice, hipDevice_t *, int)
HIP_DEFINE3(hipError_t, hipDeviceGetName, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetPCIBusId, char *, int, hipDevice_t)
HIP_DEFINE3(hipError_t, hipDeviceGetAttribute, int *, hipDeviceAttribute_t, hipDevice_t)
HIP_DEFINE1(hipError_t, hipGetDeviceCount, int *)
// module management
HIP_DEFINE4(hipError_t, hipModuleGetGlobal, hipDeviceptr_t*, size_t*, hipModule_t, const char*)
HIP_DEFINE2(hipError_t, hipModuleLoad, hipModule_t *, const char *)
HIP_DEFINE1(hipError_t, hipModuleUnload, hipModule_t)
HIP_DEFINE2(hipError_t, hipModuleLoadData, hipModule_t *, const void *)
HIP_DEFINE5(hipError_t, hipModuleLoadDataEx, hipModule_t *, const void *, unsigned int, hipJitOption *, void **)
HIP_DEFINE3(hipError_t, hipModuleGetFunction, hipFunction_t *, hipModule_t, const char *)
// stream management
HIP_DEFINE2(hipError_t, hipStreamCreate, hipStream_t *, unsigned int)
HIP_DEFINE1(hipError_t, hipStreamSynchronize, hipStream_t)
HIP_DEFINE1(hipError_t, hipStreamDestroy, hipStream_t)
HIP_DEFINE11(hipError_t, hipModuleLaunchKernel, hipFunction_t, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, hipStream_t, void **, void **)
// function management
HIP_DEFINE2(hipError_t, hipFuncGetAttributes, hipFuncAttributes*, void*)
HIP_DEFINE2(hipError_t, hipFuncSetCacheConfig, hipFunction_t, hipFuncCache_t)
// memory management
HIP_DEFINE3(hipError_t, hipMemcpyDtoH, void *, hipDeviceptr_t, size_t)
HIP_DEFINE1(hipError_t, hipFree, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemcpyDtoHAsync, void *, hipDeviceptr_t, size_t, hipStream_t)
HIP_DEFINE4(hipError_t, hipMemcpyHtoDAsync, hipDeviceptr_t, const void *, size_t, hipStream_t)
HIP_DEFINE3(hipError_t, hipMemcpyHtoD, hipDeviceptr_t, const void *, size_t )
HIP_DEFINE2(hipError_t, hipMalloc, hipDeviceptr_t*, size_t)
// NOTE(review): the attribute parameter uses the CUDA type
// `CUpointer_attribute` in a HIP forwarder -- presumably kept to match the
// declaration in dispatch.h; confirm this is intentional.
HIP_DEFINE3(hipError_t, hipPointerGetAttribute, void*, CUpointer_attribute, hipDeviceptr_t)
HIP_DEFINE4(hipError_t, hipMemsetD8Async, hipDeviceptr_t, unsigned char, size_t, hipStream_t)
// event management
HIP_DEFINE2(hipError_t, hipEventCreate, hipEvent_t *, unsigned int)
HIP_DEFINE3(hipError_t, hipEventElapsedTime, float *, hipEvent_t, hipEvent_t)
HIP_DEFINE2(hipError_t, hipEventRecord, hipEvent_t, hipStream_t)
HIP_DEFINE1(hipError_t, hipEventDestroy, hipEvent_t)
281
282
283/* ------------------- *
284 * COMMON
285 * ------------------- */
286
287// Release
288void dispatch::release(){
289 if(cuda_){
290 dlclose(cuda_);
291 cuda_ = nullptr;
292 }
293}
294
295void* dispatch::cuda_;
296void* dispatch::nvml_;
297void* dispatch::nvmlInit_v2_;
298void* dispatch::hip_;
299
300
301}
302}
303