1#pragma once
2#include <c10/macros/Macros.h>
3#include <string>
4
// JITERATOR_HOST_DEVICE: function attribute used for jiterator code. It
// normally forwards to C10_HOST_DEVICE (from c10/macros/Macros.h).
//
// The macro is defined exactly once via #if/#else. Defining it first and then
// re-#defining it with a different replacement list is ill-formed for
// object-like macros and triggers redefinition warnings (e.g. MSVC C4005).
#if defined(_MSC_VER) && defined(__CUDACC__)
// NVRTC on Windows errors if __host__ __device__ attribute is
// present on kernel.
// error: attribute "__host__" does not apply here
// error: attribute "__device__" does not apply here
#define JITERATOR_HOST_DEVICE
#else
#define JITERATOR_HOST_DEVICE C10_HOST_DEVICE
#endif
13
// The `jiterator_also_stringify_as` macro emits the given code as ordinary C++
// definitions in every build. When compiling with a CUDA or HIP (ROCm)
// compiler, it also stores that code as a string for `jiterator`.
// Usage:
//   jiterator_also_stringify_as(
//       jiterator_code(template <typename T> T identity(T x) { return x; }),
//       identity_string);
// This defines the template `identity` exactly as written. When compiling for
// CUDA or ROCm, it also defines `const std::string identity_string` holding
// the source text of that code.
23
// `jiterator_code(...)` wraps kernel code that contains top-level commas, such
// as `template <typename T, typename U>`. Without the wrapper, the
// preprocessor would split that code at each unparenthesized `,` and pass it
// to a macro like `jiterator_also_stringify_as` as several arguments. The
// surrounding parentheses keep it as a single argument. This variadic macro
// then expands back to the original tokens, commas included.
#define jiterator_code(...) __VA_ARGS__
#if defined(__CUDACC__) || defined(__HIPCC__)
// CUDA or ROCm (HIP) compilation: emit the definitions AND keep their source
// text as a string for the jiterator.

// Turns its whole (possibly comma-containing) argument list into one string
// literal.
#define stringify_code(...) #__VA_ARGS__

// `code` is not an operand of `#` here, so it is fully macro-expanded before
// being passed on to `stringify_code`. For example, a `jiterator_code(...)`
// wrapper is removed, and the resulting string holds the unwrapped source
// text.
// The expansion ends in `;`, so a caller's own trailing `;` just forms an
// empty declaration.
#define jiterator_also_stringify_as(code, str_name) \
  code /* define the function */                    \
  const std::string str_name = stringify_code(code);
#else
// Any other compiler (neither CUDA nor HIP): only the definitions are needed.
// No string variable is generated.
#define jiterator_also_stringify_as(code, str_name) code
#endif
39