1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/block_reduction.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
4 | namespace nvfuser_resources { |
5 | |
6 | constexpr const char* block_reduction_cu = R"( |
7 | // [Z,Y,X]_THREADS is the number of participating threads in the z, y, x |
8 | // dimension of the block. If set to false the dimension doesn't |
9 | // participate in the reduction. We could start with warp reductions, then |
10 | // reduce the warps, this could save some shared memory, but could be slower in |
11 | // some instances. |
12 | // |
13 | // EXAMPLE USAGE: |
14 | // blockReduceSum<X_THREADS, Y_THREADS, Z_THREADS> |
15 | // (output[output_index], inputs[input_index], |
16 | // [] __device__ (T& a, const T b) { a += b; }); |
17 | // |
// Note: We aggressively template functions taking dim3 in the functions below
19 | // because ROCM uses different types for the various dim3 and maps them |
20 | // directly to intrinsics, but they're dim3 when used after modification. |
21 | // |
// Block-wide tree reduction over the thread-block dimensions selected by
// the [XYZ]_REDUCE flags.
//
// Each "reduction segment" (the set of threads agreeing on all
// non-reduced thread indices) is laid out contiguously in shared memory
// starting at reduction_idx * reduction_size, so `shared_mem` must hold
// at least one T per participating thread of the block.
//
// NOTE: the reduced value is ACCUMULATED into `out` via reduction_op
// (out is combined with the segment result), not overwritten.
//
// Parameters:
//   out          - per-thread output; updated only on threads whose
//                  reduced-dimension indices are all zero, and only when
//                  write_pred is true.
//   inp_val      - this thread's contribution; used only when read_pred
//                  is true, otherwise init_val is contributed instead.
//   reduction_op - binary functor with signature (T& lhs, const T rhs),
//                  folding rhs into lhs.
//   thread_idx   - threadIdx (templated for ROCm, see note above).
//   block_dim    - blockDim (templated for ROCm).
//   shared_mem   - scratch buffer, one T per thread of the block.
//   read_pred    - predicate guarding the read of inp_val.
//   write_pred   - predicate guarding the final write to out.
//   init_val     - identity value of reduction_op.
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename Func,
    typename _dim3,
    typename _dim3_2>
__device__ void blockReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem,
    bool read_pred,
    bool write_pred,
    T init_val) {
  // If this thread will output a final result (all reduced-dimension
  // thread indices are zero).
  bool should_write =
      index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(thread_idx);

  // Size of the reduction segments (number of threads cooperating on one
  // reduced value).
  unsigned int reduction_size =
      index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim);

  // Index into the reduction segment
  unsigned int reduction_tid =
      index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>(
          thread_idx, block_dim);

  // Index of the reduction segment (offset over the non-reduced
  // dimensions, hence the negated flags).
  unsigned int reduction_idx =
      index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(
          thread_idx, block_dim);

  // Offset into smem for the current thread
  unsigned int smem_offset = reduction_idx * reduction_size + reduction_tid;

  // Initialize shared memory; threads with a false read predicate
  // contribute the reduction identity instead of their input.
  if (read_pred) {
    shared_mem[smem_offset] = inp_val;
  } else {
    shared_mem[smem_offset] = init_val;
  }

  // Barrier: all writes above must be visible before any thread reads a
  // neighbor's slot below.
  block_sync::sync();
  // Reduce down to nearest power of 2 for the tree reduction:
  // np2 = largest power of two <= reduction_size (reduction_size >= 1).
  int np2 = 1 << (31 - __clz(reduction_size));

  // Fold the tail [np2, reduction_size) into [0, np2) so the tree loop
  // below only has to handle a power-of-two element count.
  if (reduction_tid < np2 && reduction_tid + np2 < reduction_size) {
    reduction_op(shared_mem[smem_offset], shared_mem[smem_offset + np2]);
  }
  block_sync::sync();

  // loop peel the final iteration to save one syncthread for the end
  // (the loop stops at factor == 2; the factor == 1 combine happens in
  // the should_write branch below, so no barrier is needed after it).
  for (int factor = np2 / 2; factor > 1; factor >>= 1) {
    if (reduction_tid < factor) {
      reduction_op(shared_mem[smem_offset], shared_mem[smem_offset + factor]);
    }
    block_sync::sync();
  }

  if (should_write && write_pred) {
    // Accumulate into the existing value of `out` rather than overwrite.
    T result = out;
    reduction_op(result, shared_mem[smem_offset]);
    if (reduction_size > 1) {
      // Peeled factor == 1 step of the tree reduction.
      reduction_op(result, shared_mem[smem_offset + 1]);
    }
    out = result;
  }
  // Final barrier so shared_mem can be safely reused afterwards
  // (presumably by a subsequent reduction in the same kernel — confirm
  // against callers).
  block_sync::sync();
}
95 | |
// Convenience overload of blockReduce that applies one predicate,
// `read_write_pred`, to both the shared-memory read of inp_val and the
// final write to `out`. Behavior is otherwise identical to the
// two-predicate overload above.
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename Func,
    typename _dim3,
    typename _dim3_2>
__device__ void blockReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem,
    bool read_write_pred,
    T init_val) {
  // Forward to the two-predicate overload, passing read_write_pred for
  // both read_pred and write_pred. The type template parameters are
  // deduced from the call arguments.
  blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE>(
      out,
      inp_val,
      reduction_op,
      thread_idx,
      block_dim,
      shared_mem,
      /*read_pred=*/read_write_pred,
      /*write_pred=*/read_write_pred,
      init_val);
}
125 | )" ; |
126 | |
127 | } // namespace nvfuser_resources |
128 | |