1 | #pragma once |
2 | #include <c10/macros/Macros.h> |
3 | #include <string> |
4 | |
// JITERATOR_HOST_DEVICE expands to C10_HOST_DEVICE in every configuration
// except MSVC + CUDA, where it expands to nothing:
// NVRTC on Windows errors if __host__ __device__ attribute is
// present on kernel.
//   error: attribute "__host__" does not apply here
//   error: attribute "__device__" does not apply here
//
// The two definitions sit in mutually exclusive branches so the macro is
// never redefined with a different replacement list (which is ill-formed
// without an intervening #undef and triggers redefinition diagnostics).
#if defined(_MSC_VER) && defined(__CUDACC__)
#define JITERATOR_HOST_DEVICE
#else
#define JITERATOR_HOST_DEVICE C10_HOST_DEVICE
#endif
13 | |
// jiterator_also_stringify_as macro is used to define code (in every build)
// and, when compiling with CUDA or ROCm (i.e. `__CUDACC__` or `__HIPCC__` is
// defined), to also generate the code string for `jiterator`.
// Usage :
//      jiterator_also_stringify_as(
//          jiterator_code(template <typename T> T identity(T x) { return x; }),
//          identity_string);
// This will define the template `identity` as present in code and,
// if this is being compiled for CUDA or ROCm, also define
// `std::string identity_string` with the code as the string.
23 | |
// `jiterator_code` forwards its arguments unchanged. Wrapping kernel code in
// it keeps the `,`s inside that code from being interpreted by the
// preprocessor as separators between macro arguments.
#define jiterator_code(...) __VA_ARGS__
#if !defined(__CUDACC__) && !defined(__HIPCC__)
// Neither __CUDACC__ nor __HIPCC__ is defined:
// only the code itself is emitted; `str_name` is not defined.
#define jiterator_also_stringify_as(code, str_name) code
#else
// __CUDACC__ or __HIPCC__ is defined:
// emit the code and also define `const std::string str_name` holding its text.
// `stringify_code` stays a separate macro so that `code` is macro-expanded
// before `#` is applied; the expanded text is what ends up in the string.
#define stringify_code(...) #__VA_ARGS__
#define jiterator_also_stringify_as(code, str_name) \
  code /* define the function */                    \
  const std::string str_name = stringify_code(code);
#endif
39 | |