1// Generated from "/code/pytorch/third_party/nvfuser/runtime/index_utils.cu"
2// 2023-02-12 08:01:26
3
namespace nvfuser_resources {

// Source text of nvfuser's runtime/index_utils.cu, embedded verbatim as a raw
// string literal (see the "Generated from" header at the top of this file).
// The text declares namespace index_utils with templated __device__ helpers
// that operate on dim3-like (x, y, z) values:
//   - size:         product x * y * z.
//   - maskedOffset: row-major linear offset of idx within dim, using only the
//                   dimensions whose template flag is true.
//   - offset:       same linearization with all three dimensions.
//   - maskedDims:   copy of dim with disabled dimensions replaced by 1.
//   - maskedSize:   size of maskedDims(dim).
//   - maskedIsZero: idx is 0 on every enabled dimension.
//   - maskedIsLast: idx equals dim - 1 on every enabled dimension.
// The text uses nvfuser_index_t without defining it, so that type must be
// declared by whatever source is compiled ahead of this string.
// NOTE(review): this file is generated; the maskedIsLast comment/name fix
// below must also be applied to runtime/index_utils.cu, otherwise the next
// regeneration will revert it.
constexpr const char* index_utils_cu = R"(
namespace index_utils {

// Utility functions

// Total size of provided dimension
template <typename _dim3>
__device__ __forceinline__ nvfuser_index_t size(const _dim3& d) {
 return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z;
}

// Linearized indexing of idx based on dim, if bool==false that dimension does
// not participate
template <bool X, bool Y, bool Z, typename _dim3, typename _dim3_2>
__device__ nvfuser_index_t maskedOffset(const _dim3& idx, const _dim3_2& dim) {
 nvfuser_index_t offset = 0;
 if (Z)
 offset += idx.z;
 if (Y)
 offset = offset * dim.y + idx.y;
 if (X)
 offset = offset * dim.x + idx.x;
 return offset;
}

// Linearized indexing of idx based on dim. All dimensions participate.
template <typename _dim3, typename _dim3_2>
__device__ nvfuser_index_t offset(const _dim3& idx, const _dim3_2& dim) {
 nvfuser_index_t offset = idx.z;
 offset = offset * dim.y + idx.y;
 offset = offset * dim.x + idx.x;
 return offset;
}

// Masks the provided dim3, those == false get truncated to 1
template <bool X, bool Y, bool Z, typename _dim3>
__device__ dim3 maskedDims(const _dim3& dim) {
 return dim3{
 X ? (unsigned)dim.x : 1U,
 Y ? (unsigned)dim.y : 1U,
 Z ? (unsigned)dim.z : 1U};
}

// Provides total size of dim with masking, those dims == false do not
// participate in the size calculation
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, typename _dim3>
__device__ nvfuser_index_t maskedSize(const _dim3& dim) {
 return size(maskedDims<X_BLOCK, Y_BLOCK, Z_BLOCK>(dim));
}

// Checks if provided idx is zero on those dims == true
template <bool X, bool Y, bool Z, typename _dim3>
__device__ bool maskedIsZero(const _dim3& idx) {
 bool isZero = true;
 if (X)
 isZero = isZero && idx.x == 0;
 if (Y)
 isZero = isZero && idx.y == 0;
 if (Z)
 isZero = isZero && idx.z == 0;
 return isZero;
}

// Checks if provided idx is the last index (dim - 1) on those dims == true
template <bool X, bool Y, bool Z, typename _dim3, typename _dim3_2>
__device__ bool maskedIsLast(const _dim3& idx, const _dim3_2& dim) {
 bool isLast = true;
 if (X)
 isLast = isLast && idx.x == dim.x - 1;
 if (Y)
 isLast = isLast && idx.y == dim.y - 1;
 if (Z)
 isLast = isLast && idx.z == dim.z - 1;
 return isLast;
}

} // namespace index_utils
)";

} // namespace nvfuser_resources
86