// Generated from "/code/pytorch/third_party/nvfuser/runtime/block_reduction.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* block_reduction_cu = R"(
// X_REDUCE, Y_REDUCE, and Z_REDUCE mark which of the x, y, and z dimensions
// of the block participate in the reduction. If a flag is false, that
// dimension doesn't participate in the reduction. We could start with warp
// reductions, then reduce the warps; this could save some shared memory, but
// could be slower in some instances.
//
// EXAMPLE USAGE:
// blockReduceSum<X_THREADS, Y_THREADS, Z_THREADS>
//   (output[output_index], inputs[input_index],
//    [] __device__ (T& a, const T b) { a += b; });
//
// Note: We aggressively template the dim3 parameters of the functions below
// because ROCm uses different types for the components of dim3 and maps them
// directly to intrinsics, but they behave as dim3 once modified.
//
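// For illustration only, a minimal sketch (excluded from compilation via
// #if 0) of calling blockReduce from a kernel. The kernel name, float type,
// 1D launch, and 256-element shared buffer are our assumptions, not part of
// this resource; index_utils and block_sync come from companion runtime
// files concatenated with this one.
#if 0
__global__ void exampleBlockSum(const float* in, float* out) {
  __shared__ float smem[256]; // assumes a 1D block with blockDim.x <= 256
  float result = 0.f;
  // Reduce over all block dimensions (y and z are trivially 1 here).
  blockReduce<true, true, true>(
      result,
      in[blockIdx.x * blockDim.x + threadIdx.x],
      [](float& a, const float b) { a += b; },
      threadIdx,
      blockDim,
      smem,
      true, // read_pred: every thread holds a valid input
      true, // write_pred: the segment's thread 0 may write
      0.f); // init_val: identity for sum
  if (threadIdx.x == 0) {
    out[blockIdx.x] = result;
  }
}
#endif
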
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename Func,
    typename _dim3,
    typename _dim3_2>
__device__ void blockReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem,
    bool read_pred,
    bool write_pred,
    T init_val) {
  // If this thread will output a final result
  bool should_write =
      index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(thread_idx);

  // Size of the reduction segments
  unsigned int reduction_size =
      index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim);
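  // e.g. X_REDUCE = true, Y_REDUCE = false, Z_REDUCE = true with
  // block_dim = {32, 4, 2} gives reduction_size = 32 * 2 = 64,
  // i.e. 4 independent 64-thread segments in a 256-thread block.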

  // Index into the reduction segment
  unsigned int reduction_tid =
      index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>(
          thread_idx, block_dim);

  // Index of the reduction segment
  unsigned int reduction_idx =
      index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(
          thread_idx, block_dim);

  // Offset into smem for the current thread
  unsigned int smem_offset = reduction_idx * reduction_size + reduction_tid;
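  // shared_mem is laid out as [num_segments][reduction_size]: each segment
  // is contiguous, so the tree reduction below only ever combines elements
  // belonging to the same segment.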

  // Initialize shared memory
  if (read_pred) {
    shared_mem[smem_offset] = inp_val;
  } else {
    shared_mem[smem_offset] = init_val;
  }

  block_sync::sync();
  // Round the segment size down to the nearest power of 2 for the tree
  // reduction:
  int np2 = 1 << (31 - __clz(reduction_size));
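  // e.g. reduction_size = 384 -> __clz(384) = 23 -> np2 = 1 << 8 = 256.
  // The fold below combines elements [np2, reduction_size) into
  // [0, reduction_size - np2), leaving a power-of-2-sized tree.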

  if (reduction_tid < np2 && reduction_tid + np2 < reduction_size) {
    reduction_op(shared_mem[smem_offset], shared_mem[smem_offset + np2]);
  }
  block_sync::sync();

  // Loop-peel the final iteration (factor == 1) to save one block sync at
  // the end; the last pairwise combine happens in the write-out below.
  for (int factor = np2 / 2; factor > 1; factor >>= 1) {
    if (reduction_tid < factor) {
      reduction_op(shared_mem[smem_offset], shared_mem[smem_offset + factor]);
    }
    block_sync::sync();
  }

  if (should_write && write_pred) {
    T result = out;
    reduction_op(result, shared_mem[smem_offset]);
    if (reduction_size > 1) {
      reduction_op(result, shared_mem[smem_offset + 1]);
    }
    out = result;
  }
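  // Sync before returning so the shared memory buffer can be safely reused
  // by a subsequent reduction.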
  block_sync::sync();
}

// Convenience overload that uses the same predicate for both reads and
// writes.
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename Func,
    typename _dim3,
    typename _dim3_2>
__device__ void blockReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem,
    bool read_write_pred,
    T init_val) {
  blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE, T, Func, _dim3, _dim3_2>(
      out,
      inp_val,
      reduction_op,
      thread_idx,
      block_dim,
      shared_mem,
      read_write_pred,
      read_write_pred,
      init_val);
}
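
// The EXAMPLE USAGE comment above mentions a blockReduceSum helper that is
// not defined in this file. A minimal sketch of such a wrapper (our
// assumption of its shape, excluded from compilation via #if 0), expressed
// with the boolean REDUCE flags this file actually uses:
#if 0
template <bool X_REDUCE, bool Y_REDUCE, bool Z_REDUCE, typename T>
__device__ void blockReduceSum(T& out, const T& inp_val, T* shared_mem) {
  blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE>(
      out,
      inp_val,
      [](T& a, const T b) { a += b; },
      threadIdx,
      blockDim,
      shared_mem,
      true, // read_write_pred
      T(0)); // init_val: identity for sum
}
#endif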
)";

} // namespace nvfuser_resources