// Generated from "/code/pytorch/third_party/nvfuser/runtime/welford.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* welford_cu = R"(
// -------------------------------------------------------------------------------------------------
// Block Welford Primitives
// -------------------------------------------------------------------------------------------------
// Basic utility for a Welford update. Can be used to accumulate a single new
// value, or to merge two Welford results.
template <typename T, typename TN>
__inline__ __device__ void welfordCombine(
    T& a_avg,
    T& a_M2,
    TN& a_N,
    const T b_avg,
    const T b_M2,
    TN b_N) {
  if (b_N == 0) {
    return;
  }
  TN ab_N = a_N + b_N;
  T b_N_div_ab_N = ((T)(nvfuser_index_t)(b_N)) / ((T)(nvfuser_index_t)(ab_N));
  T delta = b_avg - a_avg;
  a_avg += delta * b_N_div_ab_N;
  a_M2 += b_M2 + delta * delta * ((T)(nvfuser_index_t)(a_N)) * b_N_div_ab_N;
  a_N = ab_N;
}
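
// For reference, welfordCombine implements the standard parallel Welford
// merge (Chan et al.): given two partial results (a_avg, a_M2, a_N) and
// (b_avg, b_M2, b_N), where M2 is the sum of squared deviations from the mean,
//
//   ab_N   = a_N + b_N
//   delta  = b_avg - a_avg
//   ab_avg = a_avg + delta * b_N / ab_N
//   ab_M2  = a_M2 + b_M2 + delta^2 * a_N * b_N / ab_N
//
// The population variance of the merged set is ab_M2 / ab_N. A single new
// sample x is scanned in by merging with (x, 0, 1).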

// The X/Y/Z_REDUCE template flags select which of the x, y, z thread
// dimensions of the block participate in the reduction.
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename TN,
    typename _dim3,
    typename _dim3_2>
__inline__ __device__ void blockWelford(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const T& in_avg,
    const T& in_M2,
    const TN& in_N,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem_avg,
    T* shared_mem_M2,
    TN* shared_mem_N,
    bool read_pred,
    bool write_pred,
    T init_val) {
  // If this thread will output a final result
  bool should_write =
      index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(thread_idx);

  // Size of the reduction segments
  unsigned int reduction_size =
      index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim);

  // Index into the reduction segment
  unsigned int reduction_tid =
      index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>(
          thread_idx, block_dim);

  // Index of the reduction segment
  unsigned int reduction_idx =
      index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(
          thread_idx, block_dim);

  // Offset into smem for the current thread
  unsigned int smem_offset = reduction_idx * reduction_size + reduction_tid;

  if (read_pred) {
    shared_mem_avg[smem_offset] = in_avg;
    shared_mem_M2[smem_offset] = in_M2;
    shared_mem_N[smem_offset] = in_N;
  } else {
    shared_mem_avg[smem_offset] = init_val;
    shared_mem_M2[smem_offset] = init_val;
    shared_mem_N[smem_offset] = 0;
  }

  block_sync::sync();
  // Reduce down to the nearest power of 2 first: np2 is the largest power of
  // two not exceeding reduction_size (e.g. reduction_size 13 -> np2 8).
  int np2 = 1 << (31 - __clz(reduction_size));

  if (reduction_tid < np2 && reduction_tid + np2 < reduction_size) {
    welfordCombine(
        shared_mem_avg[smem_offset],
        shared_mem_M2[smem_offset],
        shared_mem_N[smem_offset],
        shared_mem_avg[smem_offset + np2],
        shared_mem_M2[smem_offset + np2],
        shared_mem_N[smem_offset + np2]);
  }
  block_sync::sync();

  // Loop-peel the final iteration to save one block sync at the end
  for (int factor = np2 / 2; factor > 1; factor >>= 1) {
    if (reduction_tid < factor) {
      welfordCombine(
          shared_mem_avg[smem_offset],
          shared_mem_M2[smem_offset],
          shared_mem_N[smem_offset],
          shared_mem_avg[smem_offset + factor],
          shared_mem_M2[smem_offset + factor],
          shared_mem_N[smem_offset + factor]);
    }
    block_sync::sync();
  }
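
  // Example of the reduction shape: with reduction_size = 13, np2 = 8. The
  // combine before the loop folds entries 8..12 into entries 0..4, the loop
  // then runs with factor = 4 and factor = 2, and the remaining pair of
  // partial results in entries 0 and 1 is the peeled iteration, combined in
  // the write-out below.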

  if (should_write && write_pred) {
    T res_avg = out_avg;
    T res_M2 = out_M2;
    TN res_N = out_N;
    welfordCombine(
        res_avg,
        res_M2,
        res_N,
        shared_mem_avg[smem_offset],
        shared_mem_M2[smem_offset],
        shared_mem_N[smem_offset]);
    if (reduction_size > 1) {
      welfordCombine(
          res_avg,
          res_M2,
          res_N,
          shared_mem_avg[smem_offset + 1],
          shared_mem_M2[smem_offset + 1],
          shared_mem_N[smem_offset + 1]);
    }
    out_avg = res_avg;
    out_M2 = res_M2;
    out_N = res_N;
  }
  block_sync::sync();
}

// Use the same pred for both reads and writes
template <
    bool X_REDUCE,
    bool Y_REDUCE,
    bool Z_REDUCE,
    typename T,
    typename TN,
    typename _dim3,
    typename _dim3_2>
__inline__ __device__ void blockWelford(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const T& in_avg,
    const T& in_M2,
    const TN& in_N,
    const _dim3& thread_idx,
    const _dim3_2& block_dim,
    T* shared_mem_avg,
    T* shared_mem_M2,
    TN* shared_mem_N,
    bool read_write_pred,
    T init_val) {
  blockWelford<X_REDUCE, Y_REDUCE, Z_REDUCE, T, TN, _dim3, _dim3_2>(
      out_avg,
      out_M2,
      out_N,
      in_avg,
      in_M2,
      in_N,
      thread_idx,
      block_dim,
      shared_mem_avg,
      shared_mem_M2,
      shared_mem_N,
      read_write_pred,
      read_write_pred,
      init_val);
}
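
// Usage sketch (illustrative only; BLOCK_THREADS and x are hypothetical
// caller-side names): reducing one float sample per thread across the x
// dimension of the block.
//
//   __shared__ float smem_avg[BLOCK_THREADS]; // one entry per thread
//   __shared__ float smem_M2[BLOCK_THREADS];
//   __shared__ nvfuser_index_t smem_N[BLOCK_THREADS];
//   float avg = 0.f, M2 = 0.f;
//   nvfuser_index_t N = 0;
//   blockWelford<true, false, false>(
//       avg, M2, N,                 // running result, updated in place
//       x, 0.f, (nvfuser_index_t)1, // one new sample per thread
//       threadIdx, blockDim,
//       smem_avg, smem_M2, smem_N,
//       /*read_write_pred=*/true,
//       /*init_val=*/0.f);
//   // Only threads with threadIdx.x == 0 receive the combined block result.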
// -------------------------------------------------------------------------------------------------
// Grid Welford Prototype
// -------------------------------------------------------------------------------------------------
namespace welford {

template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T, typename TN>
__device__ void gridWelfordLastBlock(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const volatile T* in_avg,
    const volatile T* in_M2,
    const volatile TN* in_N,
    const nvfuser_index_t
        grid_reduction_segment_size, // Number of reductions across
                                     // grid reduce dimensions
    const nvfuser_index_t
        block_reduction_segment_size, // Number of reductions across the block
    T* shared_buf_avg,
    T* shared_buf_M2,
    TN* shared_buf_N,
    bool write_pred,
    T init_val) {
  // We have to produce one result per "participating" thread (the threads in
  // dimensions marked true), each reducing grid_reduction_segment_size values.
  // There is an entry in "in" for every block and every participating thread;
  // entries for one block are contiguous, and consecutive blocks are offset by
  // block_reduction_segment_size. Threads in dimensions marked false are used
  // to parallelize the reductions.
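  //
  // Example: with 4 blocks in the grid segment and 32 participating threads
  // per block, "in" holds 4 * 32 entries laid out as
  //   [b0t0, b0t1, ..., b0t31, b1t0, ..., b3t31],
  // and the reduction for participating-thread index t covers entries
  // t, 32 + t, 64 + t and 96 + t (split across the non-participating threads
  // below and then combined with blockWelford).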

  // Find the reduction id of the participating threads
  const auto block_reduction_segment_idx =
      index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>(
          threadIdx, blockDim);

  // Find an id associated within a reduction segment for all
  // "non-participating" threads, which will parallelize the reductions for the
  // "participating" threads
  const auto id_in_block_segment =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);

  // Stride by the "non-participating" threads
  const auto input_stride_for_thread_in_segment =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  T inp_avg = init_val;
  T inp_M2 = init_val;
  TN inp_N = 0;

  // Block stride across the reduction until we only have one value per thread
  for (nvfuser_index_t reduction_i = id_in_block_segment;
       reduction_i < grid_reduction_segment_size;
       reduction_i += input_stride_for_thread_in_segment) {
    auto work_buf_offset = reduction_i * block_reduction_segment_size +
        block_reduction_segment_idx;
    welfordCombine(
        inp_avg,
        inp_M2,
        inp_N,
        in_avg[work_buf_offset],
        in_M2[work_buf_offset],
        in_N[work_buf_offset]);
  }

  // Block reduce the per thread values into per "participating" thread values
  T inp_avg_tmp = init_val;
  T inp_M2_tmp = init_val;
  TN inp_N_tmp = 0;
  blockWelford<!X_THREAD, !Y_THREAD, !Z_THREAD>(
      inp_avg_tmp,
      inp_M2_tmp,
      inp_N_tmp,
      inp_avg,
      inp_M2,
      inp_N,
      threadIdx,
      blockDim,
      shared_buf_avg,
      shared_buf_M2,
      shared_buf_N,
      true,
      init_val);
  const bool should_write = (X_THREAD || threadIdx.x == 0) &&
      (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0);
  if (should_write && write_pred) {
    welfordCombine(out_avg, out_M2, out_N, inp_avg_tmp, inp_M2_tmp, inp_N_tmp);
  }
}

// Grid welford combine. See GridReduction for more information
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T,
    typename TN>
__device__ void gridWelford(
    T& out_avg,
    T& out_M2,
    TN& out_N,
    const T& inp_avg,
    const T& inp_M2,
    const TN& inp_N,
    volatile T* work_buf_avg,
    volatile T* work_buf_M2,
    volatile TN* work_buf_N,
    Tensor<int64_t, 1> sync_flags,
    T* shared_buf_avg,
    T* shared_buf_M2,
    TN* shared_buf_N,
    bool read_pred,
    bool write_pred,
    T init_val,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances) {
  // entrance index only matters for non-persistent re-entrant grid reductions.
  const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind;
  const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances;
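  // For example, a non-persistent grid reduction that the kernel enters twice
  // (n_entrances = 2) gets disjoint work-buffer and sync-flag slices for
  // entrance 0 and entrance 1, while a persistent reduction reuses a single
  // slice, so entrance_ind_ / n_entrances_ collapse to 0 / 1.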

  // Number of values to reduce in the reduction segment
  const auto grid_reduction_segment_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the reduction we're performing out of the
  // grid_reduction_segment_size
  const auto idx_in_grid_segment =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);

  // Number of threads we can use in the final reduction; this seems to assume
  // all threads in the block participate.
  const auto block_reduction_segment_size =
      index_utils::maskedSize<X_THREAD, Y_THREAD, Z_THREAD>(blockDim);

  // Number of reductions in the grid
  const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION
      ? 1
      : index_utils::maskedSize<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(gridDim);

  // Advance the work buffers to the start of this segment's slice:
  //   (entrance index * number of grid segments + index of this segment)
  //     * values per grid reduction * participating threads per block
  work_buf_avg += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
  work_buf_M2 += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
  work_buf_N += (entrance_ind_ * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
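
  // Sizing example: with n_entrances = 2, grid_segment_size = 8,
  // grid_reduction_segment_size = 16 blocks and block_reduction_segment_size =
  // 32 participating threads, each of work_buf_avg / work_buf_M2 / work_buf_N
  // spans 2 * 8 * 16 * 32 entries; every (entrance, segment) pair owns one
  // 16 * 32 slice, indexed below as
  // block_offset * block_reduction_segment_size + thread_offset.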

  if ((X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) &&
      (Z_THREAD || threadIdx.z == 0)) {
    auto block_offset =
        index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
    auto thread_offset =
        index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>(
            threadIdx, blockDim);
    auto work_buf_offset =
        block_offset * block_reduction_segment_size + thread_offset;
    if (read_pred) {
      work_buf_avg[work_buf_offset] = inp_avg;
      work_buf_M2[work_buf_offset] = inp_M2;
      work_buf_N[work_buf_offset] = inp_N;
    } else {
      work_buf_avg[work_buf_offset] = init_val;
      work_buf_M2[work_buf_offset] = init_val;
      work_buf_N[work_buf_offset] = 0;
    }
  }

  if (PERSISTENT_REDUCTION) {
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  } else {
    // Use a different sync flag for each call
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[entrance_ind_ * grid_segment_size + idx_in_grid_segment],
        grid_reduction_segment_size);
  }

  bool last_block =
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);

  if (last_block) {
    // final reduction
    gridWelfordLastBlock<X_THREAD, Y_THREAD, Z_THREAD>(
        out_avg,
        out_M2,
        out_N,
        work_buf_avg,
        work_buf_M2,
        work_buf_N,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        shared_buf_avg,
        shared_buf_M2,
        shared_buf_N,
        write_pred,
        init_val);
  }

  if (PERSISTENT_REDUCTION) {
    // Make sure we're done with global memory before we allow the kernel to
    // continue
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  }
}

} // namespace welford
)";

} // namespace nvfuser_resources