// Generated from "/code/pytorch/third_party/nvfuser/runtime/welford.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* welford_cu = R"(
// -----------------------------------------------------------------------------------------------
// Block Welford Primitives
// -----------------------------------------------------------------------------------------------
// Basic utility for a Welford update. Can be used to scan one value, or to
// merge two Welford results.
template <typename T, typename TN>
__inline__ __device__ void welfordCombine(
    T& a_avg,
    T& a_M2,
    TN& a_N,
    const T b_avg,
    const T b_M2,
    TN b_N) {
  if (b_N == 0) {
    return;
  }
  TN ab_N = a_N + b_N;
  T b_N_div_ab_N = ((T)(nvfuser_index_t)(b_N)) / ((T)(nvfuser_index_t)(ab_N));
  T delta = b_avg - a_avg;
  a_avg += delta * b_N_div_ab_N;
  a_M2 += b_M2 + delta * delta * ((T)(nvfuser_index_t)(a_N)) * b_N_div_ab_N;
  a_N = ab_N;
}
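
// Illustrative sketch (not part of the original nvfuser runtime): folding a
// single new value x into a running Welford result is the b_N == 1 special
// case of welfordCombine above, equivalent to the classic streaming update
//   delta = x - avg; avg += delta / (N + 1); M2 += delta * (x - avg);
// The helper name welfordScanOne is hypothetical and is not referenced by
// nvfuser's generated kernels.
template <typename T, typename TN>
__inline__ __device__ void welfordScanOne(T& avg, T& M2, TN& N, const T x) {
  welfordCombine(avg, M2, N, x, (T)0, (TN)1);
}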

// [X,Y,Z]_REDUCE: whether the x, y, z thread dimension of the block
// participates in this reduction. block_dim holds the number of threads in
// each dimension of the block.
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename TN,
    typename _dim3,
    typename _dim3_2>
__inline__ __device__ void blockWelford(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const T& in_avg,
    const T& in_M2,
    const TN& in_N,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem_avg,
    T* shared_mem_M2,
    TN* shared_mem_N,
    bool read_pred,
    bool write_pred,
    T init_val) {
  // If this thread will output a final result
  bool should_write =
      index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(thread_idx);

  // Size of the reduction segments
  unsigned int reduction_size =
      index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim);

  // Index into the reduction segment
  unsigned int reduction_tid =
      index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>(
          thread_idx, block_dim);

  // Index of the reduction segment
  unsigned int reduction_idx =
      index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(
          thread_idx, block_dim);

  // Offset into smem for the current thread
  unsigned int smem_offset = reduction_idx * reduction_size + reduction_tid;

  if (read_pred) {
    shared_mem_avg[smem_offset] = in_avg;
    shared_mem_M2[smem_offset] = in_M2;
    shared_mem_N[smem_offset] = in_N;
  } else {
    shared_mem_avg[smem_offset] = init_val;
    shared_mem_M2[smem_offset] = init_val;
    shared_mem_N[smem_offset] = 0;
  }

  block_sync::sync();
  // Round the reduction size down to the nearest power of 2:
  int np2 = 1 << (31 - __clz(reduction_size));
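  // Worked example (illustrative): with reduction_size == 96, np2 == 64. The
  // step below folds segment offsets [64, 96) onto [0, 32), leaving 64 valid
  // partials, and the tree loop then halves 64 -> 32 -> ... -> 2. The final
  // combine of offsets 0 and 1 is peeled into the write-back at the end.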

  if (reduction_tid < np2 && reduction_tid + np2 < reduction_size) {
    welfordCombine(
        shared_mem_avg[smem_offset],
        shared_mem_M2[smem_offset],
        shared_mem_N[smem_offset],
        shared_mem_avg[smem_offset + np2],
        shared_mem_M2[smem_offset + np2],
        shared_mem_N[smem_offset + np2]);
  }
  block_sync::sync();

  // Loop peel the final iteration to save one block sync at the end
  for (int factor = np2 / 2; factor > 1; factor >>= 1) {
    if (reduction_tid < factor) {
      welfordCombine(
          shared_mem_avg[smem_offset],
          shared_mem_M2[smem_offset],
          shared_mem_N[smem_offset],
          shared_mem_avg[smem_offset + factor],
          shared_mem_M2[smem_offset + factor],
          shared_mem_N[smem_offset + factor]);
    }
    block_sync::sync();
  }

  if (should_write && write_pred) {
    T res_avg = out_avg;
    T res_M2 = out_M2;
    TN res_N = out_N;
    welfordCombine(
        res_avg,
        res_M2,
        res_N,
        shared_mem_avg[smem_offset],
        shared_mem_M2[smem_offset],
        shared_mem_N[smem_offset]);
    if (reduction_size > 1) {
      welfordCombine(
          res_avg,
          res_M2,
          res_N,
          shared_mem_avg[smem_offset + 1],
          shared_mem_M2[smem_offset + 1],
          shared_mem_N[smem_offset + 1]);
    }
    out_avg = res_avg;
    out_M2 = res_M2;
    out_N = res_N;
  }
  block_sync::sync();
}

// Overload that uses the same predicate for both reads and writes
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename TN,
    typename _dim3,
    typename _dim3_2>
__inline__ __device__ void blockWelford(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const T& in_avg,
    const T& in_M2,
    const TN& in_N,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem_avg,
    T* shared_mem_M2,
    TN* shared_mem_N,
    bool read_write_pred,
    T init_val) {
  blockWelford<X_REDUCE, Y_REDUCE, Z_REDUCE, T, TN, _dim3, _dim3_2>(
      out_avg,
      out_M2,
      out_N,
      in_avg,
      in_M2,
      in_N,
      thread_idx,
      block_dim,
      shared_mem_avg,
      shared_mem_M2,
      shared_mem_N,
      read_write_pred,
      read_write_pred,
      init_val);
}
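
// Usage sketch (illustrative only, not emitted by nvfuser's code generator):
// reducing welford partials across threadIdx.x of a 2D block could be
// instantiated roughly as
//
//   __shared__ float smem_avg[BLOCK_X * BLOCK_Y];
//   __shared__ float smem_M2[BLOCK_X * BLOCK_Y];
//   __shared__ nvfuser_index_t smem_N[BLOCK_X * BLOCK_Y];
//   blockWelford<true, false, false>(
//       out_avg, out_M2, out_N,
//       in_avg, in_M2, in_N,
//       threadIdx, blockDim,
//       smem_avg, smem_M2, smem_N,
//       true, // read_write_pred
//       0.f); // init_val
//
// Here reduction_size == blockDim.x and reduction_idx == threadIdx.y, so each
// row of the block reduces independently. BLOCK_X, BLOCK_Y, the shared
// buffers, and the in_/out_ variables are hypothetical placeholders.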
// -----------------------------------------------------------------------------------------------
// Grid Welford Prototype
// -----------------------------------------------------------------------------------------------
namespace welford {

template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T, typename TN>
__device__ void gridWelfordLastBlock(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const volatile T* in_avg,
    const volatile T* in_M2,
    const volatile TN* in_N,
    const nvfuser_index_t
        grid_reduction_segment_size, // Number of reductions across
                                     // grid reduce dimensions
    const nvfuser_index_t
        block_reduction_segment_size, // Number of reductions across the block
    T* shared_buf_avg,
    T* shared_buf_M2,
    TN* shared_buf_N,
    bool write_pred,
    T init_val) {
  // We have to perform block_reduction_segment_size reductions, each across
  // grid_reduction_segment_size values. There is an entry in "in" for every
  // block and for every thread marked as true; consecutive values of the same
  // reduction are strided by block_reduction_segment_size. Threads in
  // dimensions marked as false can be used to parallelize the reduction.
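  //
  // Layout example (illustrative): with grid_reduction_segment_size == 4 and
  // block_reduction_segment_size == 256, "in" holds 4 * 256 partial results,
  // and the value produced by participating thread t of block b is stored at
  // in[b * 256 + t], which matches work_buf_offset in the loop below.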

  // Find the reduction id of the participating threads
  const auto block_reduction_segment_idx =
      index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>(
          threadIdx, blockDim);

  // Find an id within a reduction segment for all "non-participating" threads,
  // which will parallelize the reductions for the "participating" threads
  const auto id_in_block_segment =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);

  // Stride by the "non-participating" threads
  const auto input_stride_for_thread_in_segment =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  T inp_avg = init_val;
  T inp_M2 = init_val;
  TN inp_N = 0;

  // Block stride across the reduction until we only have one value per thread
  for (nvfuser_index_t reduction_i = id_in_block_segment;
       reduction_i < grid_reduction_segment_size;
       reduction_i += input_stride_for_thread_in_segment) {
    auto work_buf_offset = reduction_i * block_reduction_segment_size +
        block_reduction_segment_idx;
    welfordCombine(
        inp_avg,
        inp_M2,
        inp_N,
        in_avg[work_buf_offset],
        in_M2[work_buf_offset],
        in_N[work_buf_offset]);
  }

  // Block reduce the per-thread values into per "participating" thread values
  T inp_avg_tmp = init_val;
  T inp_M2_tmp = init_val;
  TN inp_N_tmp = 0;
  blockWelford<!X_THREAD, !Y_THREAD, !Z_THREAD>(
      inp_avg_tmp,
      inp_M2_tmp,
      inp_N_tmp,
      inp_avg,
      inp_M2,
      inp_N,
      threadIdx,
      blockDim,
      shared_buf_avg,
      shared_buf_M2,
      shared_buf_N,
      true,
      init_val);
  const bool should_write = (X_THREAD || threadIdx.x == 0) &&
      (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0);
  if (should_write && write_pred) {
    welfordCombine(out_avg, out_M2, out_N, inp_avg_tmp, inp_M2_tmp, inp_N_tmp);
  }
}

// Grid welford combine. See GridReduction for more information
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T,
    typename TN>
__device__ void gridWelford(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const T& inp_avg,
    const T& inp_M2,
    const TN& inp_N,
    volatile T* work_buf_avg,
    volatile T* work_buf_M2,
    volatile TN* work_buf_N,
    Tensor<int64_t, 1> sync_flags,
    T* shared_buf_avg,
    T* shared_buf_M2,
    TN* shared_buf_N,
    bool read_pred,
    bool write_pred,
    T init_val,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances) {
  // Entrance index only matters for non-persistent re-entrant grid reductions.
  const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind;
  const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances;

  // Number of values to reduce in the reduction segment
  const auto grid_reduction_segment_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the reduction we're performing, out of all reduction segments in
  // the grid
  const auto idx_in_grid_segment =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);

  // Number of threads we can use in the final reduction. This seems to assume
  // that all threads in the block participate.
  const auto block_reduction_segment_size =
      index_utils::maskedSize<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);

  // Number of reductions in the grid
  const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION
      ? 1
      : index_utils::maskedSize<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(gridDim);

  // Advance to the offset for this segment:
  // index of reduction * size of the reduction * size of threads
  work_buf_avg += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
  work_buf_M2 += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
  work_buf_N += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
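
  // Sizing note (illustrative): with this indexing, a non-persistent reduction
  // needs n_entrances * grid_segment_size * grid_reduction_segment_size *
  // block_reduction_segment_size entries in each of the three work buffers.
  // For example, 2 entrances, 8 grid segments, 16 blocks per segment, and 128
  // participating threads per block require 2 * 8 * 16 * 128 entries per
  // buffer; the concrete numbers are hypothetical.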

  if ((X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) &&
      (Z_THREAD || threadIdx.z == 0)) {
    auto block_offset =
        index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
    auto thread_offset =
        index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>(
            threadIdx, blockDim);
    auto work_buf_offset =
        block_offset * block_reduction_segment_size + thread_offset;
    if (read_pred) {
      work_buf_avg[work_buf_offset] = inp_avg;
      work_buf_M2[work_buf_offset] = inp_M2;
      work_buf_N[work_buf_offset] = inp_N;
    } else {
      work_buf_avg[work_buf_offset] = init_val;
      work_buf_M2[work_buf_offset] = init_val;
      work_buf_N[work_buf_offset] = 0;
    }
  }

  if (PERSISTENT_REDUCTION) {
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  } else {
    // Use a different sync flag for each call
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[entrance_ind_ * grid_segment_size + idx_in_grid_segment],
        grid_reduction_segment_size);
  }
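
  // Note: for the non-persistent path, sync_flags is indexed by
  // entrance_ind_ * grid_segment_size + idx_in_grid_segment, so it must
  // provide at least n_entrances * grid_segment_size flags; the persistent
  // path reuses a single flag per grid segment.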

  bool last_block =
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);

  if (last_block) {
    // final reduction
    gridWelfordLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
        out_avg,
        out_M2,
        out_N,
        work_buf_avg,
        work_buf_M2,
        work_buf_N,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        shared_buf_avg,
        shared_buf_M2,
        shared_buf_N,
        write_pred,
        init_val);
  }

  if (PERSISTENT_REDUCTION) {
    // Make sure we're done with global memory before we allow the kernel to
    // continue
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  }
}

} // namespace welford
)";

} // namespace nvfuser_resources