// Generated from "/code/pytorch/third_party/nvfuser/runtime/fused_welford_impl.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* fused_welford_impl_cu = R"(
namespace fused_reduction {

namespace impl {

//! Implementation helper for welfordEach.
template <int ValIdx, typename Triplet0, typename Triplet1>
struct WelfordForEach {
  static __inline__ __device__ void call(
      Triplet0& triplet0,
      nvfuser_index_t offset0,
      const Triplet1& triplet1,
      nvfuser_index_t offset1) {
    static_assert(
        Triplet0::num_vals == Triplet1::num_vals, "Invalid Triplet types");
    static_assert(
        IsSameType<typename Triplet0::DataType, typename Triplet1::DataType>::
            value,
        "Invalid Triplet types");
    static_assert(
        IsSameType<typename Triplet0::IndexType, typename Triplet1::IndexType>::
            value,
        "Invalid Triplet types");

    using DataType = typename Triplet0::DataType;
    using IndexType = typename Triplet0::IndexType;

    WelfordForEach<ValIdx - 1, Triplet0, Triplet1>::call(
        triplet0, offset0, triplet1, offset1);
    welfordCombine<DataType, IndexType>(
        triplet0.avg.val<ValIdx>(offset0),
        triplet0.var.val<ValIdx>(offset0),
        triplet0.N.val<ValIdx>(offset0),
        triplet1.avg.val<ValIdx>(offset1),
        triplet1.var.val<ValIdx>(offset1),
        triplet1.N.val<ValIdx>(offset1));
  }
};

template <typename Triplet0, typename Triplet1>
struct WelfordForEach<-1, Triplet0, Triplet1> {
  __inline__ __device__ static void call(
      Triplet0& triplet0,
      nvfuser_index_t offset0,
      const Triplet1& triplet1,
      nvfuser_index_t offset1) {}
};

//! Call welfordCombine with each of the triplet tuples. This is the
//! Welford version of reduceEach.
template <typename Triplet0, typename Triplet1>
__inline__ __device__ static void welfordEach(
    Triplet0& triplet0,
    nvfuser_index_t offset0,
    const Triplet1& triplet1,
    nvfuser_index_t offset1) {
  WelfordForEach<Triplet0::num_vals - 1, Triplet0, Triplet1>::call(
      triplet0, offset0, triplet1, offset1);
}
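
// For reference, welfordCombine (defined elsewhere in the runtime) is
// assumed to implement the standard Chan et al. parallel merge of two
// partial Welford results, where var holds the running sum of squared
// deviations (M2) rather than the normalized variance:
//
//   N     = N_a + N_b
//   delta = avg_b - avg_a
//   avg   = avg_a + delta * N_b / N
//   M2    = M2_a + M2_b + delta * delta * N_a * N_b / N
//
// WelfordForEach unrolls this merge across the tuple at compile time; e.g.,
// for Triplet0::num_vals == 2, welfordEach expands to a combine of the
// val<0> triplet followed by a combine of the val<1> triplet.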

// Welford version of BlockReduceEach
template <
    int idx,
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalWelfordTripletTupleT>
struct BlockWelfordEach {
  __inline__ __device__ static void reduce(
      LocalWelfordTripletTupleT& block_result,
      const LocalWelfordTripletTupleT& partial_result,
      PtrTuple<
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::IndexType> shared_buf,
      bool has_block_result,
      int tid_in_reduction,
      int num_threads_per_reduction,
      int num_elements_per_reduction,
      int reduction_idx) {
    // First finish the reduction of the tuple values at smaller indices
    BlockWelfordEach<idx - 1, BROADCAST, true, LocalWelfordTripletTupleT>::
        reduce(
            block_result,
            partial_result,
            shared_buf,
            has_block_result,
            tid_in_reduction,
            num_threads_per_reduction,
            num_elements_per_reduction,
            reduction_idx);

    if (num_elements_per_reduction == 1) {
      if (has_block_result) {
        copyWelfordTripletTuple(block_result, partial_result);
      }
      return;
    }

    using DataType = typename LocalWelfordTripletTupleT::DataType;
    using IndexType = typename LocalWelfordTripletTupleT::IndexType;

    LocalTuple<DataType, DataType, IndexType> block_result_i(
        partial_result.avg.val<idx>(0),
        partial_result.var.val<idx>(0),
        partial_result.N.val<idx>(0));

    const auto smem_offset =
        reduction_idx * num_threads_per_reduction + tid_in_reduction;

    const int np2 = 1 << (31 - __clz(num_elements_per_reduction));
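    // np2 is now the largest power of two that is <= the segment size
    // (e.g., 37 elements -> np2 == 32). Folding the tail values at
    // [np2, num_elements_per_reduction) into the first np2 values first
    // lets the loop below run a clean power-of-two tree reduction.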

    // All thread values are initialized, so every thread can participate here
    if (tid_in_reduction >= np2) {
      copyTuple(shared_buf, smem_offset, block_result_i);
    }

    block_sync::sync();
    if (tid_in_reduction < np2 &&
        tid_in_reduction + np2 < num_elements_per_reduction) {
      impl::reduceTuple(
          block_result_i,
          0,
          shared_buf,
          smem_offset + np2,
          welfordCombine<DataType, IndexType>);
    }

    if (tid_in_reduction < np2) {
      copyTuple(shared_buf, smem_offset, block_result_i);
    }

    // Always sync when communicating across smem
    block_sync::sync();

    // Reduce down to 2 values; the thread that owns the block result does
    // the final combine below, which saves one block sync
    for (int factor = np2 / 2; factor > 1; factor >>= 1) {
      if (tid_in_reduction < factor) {
        impl::reduceTuple(
            shared_buf,
            smem_offset,
            shared_buf,
            smem_offset + factor,
            welfordCombine<DataType, IndexType>);
      }
      block_sync::sync();
    }
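    // A concrete sketch (hypothetical sizes): with np2 == 8 the loop above
    // runs factor = 4, then 2, combining smem[tid] with smem[tid + factor]
    // each round, leaving two partial values at smem_offset and
    // smem_offset + 1 for the final combine below.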

    copyTuple(block_result_i, shared_buf, smem_offset);

    // Do the last reduction
    if (has_block_result) {
      impl::reduceTuple(
          block_result_i,
          0,
          shared_buf,
          smem_offset + 1,
          welfordCombine<DataType, IndexType>);
    }

    if (BROADCAST) {
      if (has_block_result) {
        // Put the result back in shared memory, in the first entry of this
        // reduction segment's buffer
        copyTuple(
            shared_buf,
            reduction_idx * num_threads_per_reduction,
            block_result_i);
      }

      // Sync threads to make sure result is in smem
      block_sync::sync();

      copyTuple(
          block_result_i,
          shared_buf,
          reduction_idx * num_threads_per_reduction);
    }

    block_result.avg.val<idx>(0) = block_result_i.val<0>(0);
    block_result.var.val<idx>(0) = block_result_i.val<1>(0);
    block_result.N.val<idx>(0) = block_result_i.val<2>(0);

    if (FORWARD_PROTECT_SMEM) {
      block_sync::sync();
    }
  }
};

// Specialization for idx == -1, i.e., no value to reduce.
template <
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalWelfordTripletTupleT>
struct BlockWelfordEach<
    -1,
    BROADCAST,
    FORWARD_PROTECT_SMEM,
    LocalWelfordTripletTupleT> {
  __inline__ __device__ static void reduce(
      LocalWelfordTripletTupleT& block_result,
      const LocalWelfordTripletTupleT& partial_result,
      PtrTuple<
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::IndexType> shared_buf,
      bool has_block_result,
      int tid_in_reduction,
      int num_threads_per_reduction,
      int num_elements_per_reduction,
      int reduction_idx) {}
};

//! Welford version of blockReduceEach. Perform block-parallel Welford
//! reduction of each Welford triplet.
template <
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalWelfordTripletTupleT>
__inline__ __device__ void blockWelfordEach(
    LocalWelfordTripletTupleT& block_result,
    const LocalWelfordTripletTupleT& partial_result,
    PtrTuple<
        typename LocalWelfordTripletTupleT::DataType,
        typename LocalWelfordTripletTupleT::DataType,
        typename LocalWelfordTripletTupleT::IndexType> shared_buf,
    bool has_block_result,
    int tid_in_reduction,
    int num_threads_per_reduction,
    int num_elements_per_reduction,
    int reduction_idx) {
  BlockWelfordEach<
      LocalWelfordTripletTupleT::num_vals - 1,
      BROADCAST,
      FORWARD_PROTECT_SMEM,
      LocalWelfordTripletTupleT>::
      reduce(
          block_result,
          partial_result,
          shared_buf,
          has_block_result,
          tid_in_reduction,
          num_threads_per_reduction,
          num_elements_per_reduction,
          reduction_idx);
}

} // namespace impl

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <int NumArgs, typename DataType, typename IndexType>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    welfordGroup(
        typename MakeRefTuple<NumArgs, DataType>::type out_avg,
        typename MakeRefTuple<NumArgs, DataType>::type out_var,
        typename MakeRefTuple<NumArgs, IndexType>::type out_N,
        const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_avg,
        const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_var,
        const typename MakeConstRefTuple<NumArgs, IndexType>::type& inp_N,
        const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg,
        const typename MakeLocalTuple<NumArgs, DataType>::type& init_var,
        const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N,
        typename MakeVolatilePtrTuple<NumArgs, DataType>::type
            global_work_buffer_avg,
        typename MakeVolatilePtrTuple<NumArgs, DataType>::type
            global_work_buffer_var,
        typename MakeVolatilePtrTuple<NumArgs, IndexType>::type
            global_work_buffer_N,
        int64_t* global_sync_buffer,
        PtrTuple<DataType, DataType, IndexType> shared_buf,
        const typename MakeLocalTuple<NumArgs, bool>::type& read_preds,
        const typename MakeLocalTuple<NumArgs, bool>::type& write_preds) {
  const ConstRefWelfordTripletTuple<NumArgs, DataType, IndexType> inp(
      inp_avg, inp_var, inp_N);
  RefWelfordTripletTuple<NumArgs, DataType, IndexType> out(
      out_avg, out_var, out_N);

  // If no reduction needed, just return input
  if (!BLOCK_REDUCE && !GRID_REDUCE) {
    copyWelfordTripletTupleIf(out, inp, read_preds && write_preds);
    return;
  }

  // Don't read/write in temporary buffers if in a predicated dimension
  const bool block_reduce_participate = index_utils::
      maskedIsZero<isPred(X_THREAD), isPred(Y_THREAD), isPred(Z_THREAD)>(
          threadIdx);

  // Only threads with id == 0 in the dimensions being reduced will have a
  // valid result
  const bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

  LocalWelfordTripletTuple<NumArgs, DataType, IndexType> block_result(
      init_avg, init_var, init_N);

  // Initial per-block reduction. Result is broadcast if specified
  // and this call is block reduction only.
  welfordGroupBlock<!GRID_REDUCE && BROADCAST>(
      block_result, inp, shared_buf, read_preds, block_reduce_participate);

  // If block reduction only, save to out and exit
  if (!GRID_REDUCE) {
    copyWelfordTripletTupleIf(
        out,
        block_result,
        write_preds &&
            (block_reduce_participate && (BROADCAST || has_block_result)));

    // Need a block sync here as welfordGroupBlock does not
    // forward-protect the smem buffer. This block sync is not
    // necessary when a grid reduction follows since a block sync is
    // done just before the grid sync.
    block_sync::sync();
    return;
  }

  // -- START GRID REDUCTION -- //
  // Grid reductions are more challenging for two reasons: (1) the reduction
  // itself is 3D instead of 2D because we now have an iter domain space in
  // the grid dimensions, and (2) a tree reduction isn't performed; instead,
  // all blocks populate GMEM and one block finishes the grid reduction.

  // Size of the grid reduction. The block reduction has already been
  // performed, so it doesn't need to be taken into consideration here.
  const auto grid_red_size = index_utils::
      maskedSize<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          gridDim);

  // Index of this block within the grid reduction. Threads can participate
  // in multiple grid reductions, but the block will have the same relative
  // index in all of them.
  const auto idx_in_grid_red = index_utils::
      maskedOffset<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the grid dimension
  const auto num_block_iters = index_utils::
      maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(gridDim);

  // Which grid reduction does this block participate in, in the grid
  // dimension
  const auto block_red_idx_offset = index_utils::
      maskedOffset<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the block dimension
  const auto num_thread_iters = index_utils::
      maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          blockDim);

  // Which grid reduction does this thread participate in, in the block
  // dimension
  const auto thread_red_idx_offset = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // 3D buffer of reductions:
  //   [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
  // Offset into the work buffer
  auto work_buf_offset =
      (idx_in_grid_red * num_block_iters + block_red_idx_offset) *
          num_thread_iters +
      thread_red_idx_offset;
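  // For example (hypothetical sizes): with num_block_iters == 2 and
  // num_thread_iters == 128, the block with idx_in_grid_red == 3 and
  // block_red_idx_offset == 1 writes at (3 * 2 + 1) * 128 +
  // thread_red_idx_offset, i.e., each step of the grid reduction owns one
  // contiguous [iter(grid), iter(block)] slab of the work buffer.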

  // Don't read/write in temporary buffers if in a predicated dimension
  bool grid_reduce_participate = index_utils::
      maskedIsZero<isPred(X_BLOCK), isPred(Y_BLOCK), isPred(Z_BLOCK)>(blockIdx);

  VolatilePtrWelfordTripletTuple<NumArgs, DataType, IndexType>
      global_work_buffer(
          global_work_buffer_avg, global_work_buffer_var, global_work_buffer_N);

  if (PERSISTENT_REDUCTION && flip) {
    auto global_buffer_size =
        index_utils::
            maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
                gridDim) *
        index_utils::
            maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
                blockDim) *
        grid_red_size;
    global_work_buffer += global_buffer_size;
  }
  flip = !flip;
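  // flip implements double buffering for persistent kernels: alternating
  // rounds use the two halves of the work-buffer allocation, so a new
  // round never overwrites partial results that another block may still
  // be reading. (flip is assumed to be ParallelReduce member state that
  // persists across welfordGroup calls.)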

  // Per-block partial reduction to global work buffer
  if (grid_reduce_participate && block_reduce_participate && has_block_result) {
    copyWelfordTripletTuple(global_work_buffer, work_buf_offset, block_result);
  }

  // -- GLOBAL BUFFER FILLED -- //

  bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if (grid_reduce_participate) {
    // Don't need to sync up blocks that are not participating in this
    // reduction
    grid_sync::sync<
        isReduce(X_BLOCK),
        isReduce(Y_BLOCK),
        isReduce(Z_BLOCK),
        PERSISTENT_REDUCTION>(
        global_sync_buffer[block_red_idx_offset], grid_red_size, last_block);
  }

  // -- START BLOCK CLEANUP -- //
  welfordGroupLastBlock(
      out,
      global_work_buffer,
      LocalWelfordTripletTuple<NumArgs, DataType, IndexType>(
          init_avg, init_var, init_N),
      shared_buf,
      block_red_idx_offset,
      num_thread_iters,
      num_block_iters,
      thread_red_idx_offset,
      grid_red_size,
      write_preds,
      block_reduce_participate,
      grid_reduce_participate);

  // Forward protect the smem buffer
  block_sync::sync();
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <
    bool BLOCK_BROADCAST,
    int NumVals,
    typename DataType,
    typename IndexType>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
)"
R"(
    welfordGroupBlock(
        LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result,
        const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp,
        PtrTuple<DataType, DataType, IndexType> shared_buf,
        const typename MakeLocalTuple<NumVals, bool>::type& read_preds,
        bool block_reduce_participate) {
  const bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

  copyWelfordTripletTupleIf(
      block_result, inp, block_reduce_participate && read_preds);

  // Size of the block reduction segment; can be an int since it's limited
  // to the number of threads
  const int block_reduction_size = index_utils::
      maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          blockDim);

  // Index in the reduction segment; can be an int since it's limited to
  // the number of threads
  const int tid_in_block_reduction = index_utils::
      maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx, blockDim);

  // ID of the block reduction this thread is participating in
  //
  // If any of the parallel dimensions are predicated out, that means
  // they've already been reduced, so we only care about the first thread in
  // that dimension. Therefore don't expand the reduction_idx by that
  // dimension
  const int block_reduction_idx = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // Do not protect the smem buffer as it's not always necessary.
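  // block_reduction_size is passed below as both the thread count and the
  // element count: every participating thread holds an initialized partial
  // result at this point (the input where its read predicate held, the
  // init value otherwise), so the whole segment is reduced.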
  impl::blockWelfordEach<
      BLOCK_BROADCAST,
      false,
      LocalWelfordTripletTuple<NumVals, DataType, IndexType>>(
      block_result,
      block_result,
      shared_buf,
      has_block_result,
      tid_in_block_reduction,
      block_reduction_size,
      block_reduction_size,
      block_reduction_idx);
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <int NumVals, typename DataType, typename IndexType>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    welfordGroupLastBlock(
        RefWelfordTripletTuple<NumVals, DataType, IndexType>& out,
        const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>&
            global_work_buffer,
        const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val,
        PtrTuple<DataType, DataType, IndexType> shared_buf,
        nvfuser_index_t block_red_idx_offset,
        nvfuser_index_t num_thread_iters,
        nvfuser_index_t num_block_iters,
        nvfuser_index_t thread_red_idx_offset,
        nvfuser_index_t grid_red_size,
        const typename MakeLocalTuple<NumVals, bool>::type& write_preds,
        bool block_reduce_participate,
        bool grid_reduce_participate) {
  // Initialize block result
  auto last_block_result = init_val;

  const bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) {
    // Use the last block to reduce all the values the blocks filled in.
    // Any thread whose reduced dimensions are predicated out or already
    // reduced can take part in this reduction; no block associated with an
    // iteration domain can be used.

    // Start with non-block reduction

    // Index in the reduction segment
    int tid_in_block_reduction = index_utils::maskedOffset<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx, blockDim);

    int block_reduction_size = index_utils::maskedSize<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(blockDim);

    bool has_block_result = index_utils::maskedIsZero<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx);

    // 3D buffer of reductions:
    //   [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
    // Change the offset: we want to keep the last two dimensions, but the
    // first dimension is what we will reduce over
    const auto work_buf_offset =
        block_red_idx_offset * num_thread_iters + thread_red_idx_offset;
    for (auto reduction_i = tid_in_block_reduction; reduction_i < grid_red_size;
         reduction_i += block_reduction_size) {
      impl::welfordEach(
          last_block_result,
          0,
          global_work_buffer,
          work_buf_offset + reduction_i * num_block_iters * num_thread_iters);
    }
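    // The loop above is a grid-stride-style sweep: each participating
    // thread folds every block_reduction_size-th entry along the
    // grid-reduction axis into last_block_result; the block-level
    // reduction below then combines these per-thread partials.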

    // Which block reduction this thread is participating in
    int block_reduction_idx = index_utils::
        maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
            threadIdx, blockDim);

    impl::blockWelfordEach<
        BROADCAST,
        false,
        LocalWelfordTripletTuple<NumVals, DataType, IndexType>>(
        last_block_result,
        last_block_result,
        shared_buf,
        has_block_result,
        tid_in_block_reduction,
        block_reduction_size,
        min(grid_red_size, block_reduction_size),
        block_reduction_idx);

    copyWelfordTripletTupleIf(
        out,
        last_block_result,
        write_preds &&
            (block_reduce_participate && (BROADCAST || has_block_result)));
  }
}

} // namespace fused_reduction
)";

} // namespace nvfuser_resources