// Generated from "/code/pytorch/third_party/nvfuser/runtime/fused_welford_impl.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* fused_welford_impl_cu = R"(
namespace fused_reduction {

namespace impl {

//! Implementation helper for welfordEach.
template <int ValIdx, typename Triplet0, typename Triplet1>
struct WelfordForEach {
  static __inline__ __device__ void call(
      Triplet0& triplet0,
      nvfuser_index_t offset0,
      const Triplet1& triplet1,
      nvfuser_index_t offset1) {
    static_assert(
        Triplet0::num_vals == Triplet1::num_vals, "Invalid Triplet types");
    static_assert(
        IsSameType<typename Triplet0::DataType, typename Triplet1::DataType>::
            value,
        "Invalid Triplet types");
    static_assert(
        IsSameType<typename Triplet0::IndexType, typename Triplet1::IndexType>::
            value,
        "Invalid Triplet types");

    using DataType = typename Triplet0::DataType;
    using IndexType = typename Triplet0::IndexType;

    WelfordForEach<ValIdx - 1, Triplet0, Triplet1>::call(
        triplet0, offset0, triplet1, offset1);
    welfordCombine<DataType, IndexType>(
        triplet0.avg.val<ValIdx>(offset0),
        triplet0.var.val<ValIdx>(offset0),
        triplet0.N.val<ValIdx>(offset0),
        triplet1.avg.val<ValIdx>(offset1),
        triplet1.var.val<ValIdx>(offset1),
        triplet1.N.val<ValIdx>(offset1));
  }
};

template <typename Triplet0, typename Triplet1>
struct WelfordForEach<-1, Triplet0, Triplet1> {
  __inline__ __device__ static void call(
      Triplet0& triplet0,
      nvfuser_index_t offset0,
      const Triplet1& triplet1,
      nvfuser_index_t offset1) {}
};

//! Call welfordCombine with each of the triplet tuples. This is a
//! Welford version of reduceEach.
template <typename Triplet0, typename Triplet1>
__inline__ __device__ static void welfordEach(
    Triplet0& triplet0,
    nvfuser_index_t offset0,
    const Triplet1& triplet1,
    nvfuser_index_t offset1) {
  WelfordForEach<Triplet0::num_vals - 1, Triplet0, Triplet1>::call(
      triplet0, offset0, triplet1, offset1);
}
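
// Note: welfordCombine above is defined elsewhere in the runtime and is
// expected to implement the standard parallel Welford merge of (avg, M2, N)
// triplets, with the var slot holding the running M2 (unnormalized sum of
// squared deviations):
//   N_ab   = N_a + N_b
//   delta  = avg_b - avg_a
//   avg_ab = avg_a + delta * N_b / N_ab
//   M2_ab  = M2_a + M2_b + delta * delta * N_a * N_b / N_ab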

// Welford version of BlockReduceEach
template <
    int idx,
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalWelfordTripletTupleT>
struct BlockWelfordEach {
  __inline__ __device__ static void reduce(
      LocalWelfordTripletTupleT& block_result,
      const LocalWelfordTripletTupleT& partial_result,
      PtrTuple<
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::IndexType> shared_buf,
      bool has_block_result,
      int tid_in_reduction,
      int num_threads_per_reduction,
      int num_elements_per_reduction,
      int reduction_idx) {
    // Finish the reduction of each tuple value with a smaller offset
    BlockWelfordEach<idx - 1, BROADCAST, true, LocalWelfordTripletTupleT>::
        reduce(
            block_result,
            partial_result,
            shared_buf,
            has_block_result,
            tid_in_reduction,
            num_threads_per_reduction,
            num_elements_per_reduction,
            reduction_idx);

    if (num_elements_per_reduction == 1) {
      if (has_block_result) {
        copyWelfordTripletTuple(block_result, partial_result);
      }
      return;
    }

    using DataType = typename LocalWelfordTripletTupleT::DataType;
    using IndexType = typename LocalWelfordTripletTupleT::IndexType;

    LocalTuple<DataType, DataType, IndexType> block_result_i(
        partial_result.avg.val<idx>(0),
        partial_result.var.val<idx>(0),
        partial_result.N.val<idx>(0));

    const auto smem_offset =
        reduction_idx * num_threads_per_reduction + tid_in_reduction;

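    // np2 below is the largest power of two not exceeding
    // num_elements_per_reduction (e.g. 13 -> 8). Threads at index np2 and
    // above first stash their values in shared memory so the low np2 threads
    // can fold them in; a power-of-two tree reduction then runs on np2 values.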
    const int np2 = 1 << (31 - __clz(num_elements_per_reduction));

    // Thread values are initialized, so all threads can participate here
    if (tid_in_reduction >= np2) {
      copyTuple(shared_buf, smem_offset, block_result_i);
    }

    block_sync::sync();
    if (tid_in_reduction < np2 &&
        tid_in_reduction + np2 < num_elements_per_reduction) {
      impl::reduceTuple(
          block_result_i,
          0,
          shared_buf,
          smem_offset + np2,
          welfordCombine<DataType, IndexType>);
    }

    if (tid_in_reduction < np2) {
      copyTuple(shared_buf, smem_offset, block_result_i);
    }

    // Always sync when communicating across smem
    block_sync::sync();

    // Tree-reduce down to 2 values; the thread holding the block result does
    // the final combine, which saves a block sync
    for (int factor = np2 / 2; factor > 1; factor >>= 1) {
      if (tid_in_reduction < factor) {
        impl::reduceTuple(
            shared_buf,
            smem_offset,
            shared_buf,
            smem_offset + factor,
            welfordCombine<DataType, IndexType>);
      }
      block_sync::sync();
    }

    copyTuple(block_result_i, shared_buf, smem_offset);

    // Do the last reduction
    if (has_block_result) {
      impl::reduceTuple(
          block_result_i,
          0,
          shared_buf,
          smem_offset + 1,
          welfordCombine<DataType, IndexType>);
    }

    if (BROADCAST) {
      if (has_block_result) {
        // Put the result back in shared memory, in the first entry of the
        // reduction segment's buffer
        copyTuple(
            shared_buf,
            reduction_idx * num_threads_per_reduction,
            block_result_i);
      }

      // Sync threads to make sure the result is in smem
      block_sync::sync();

      copyTuple(
          block_result_i,
          shared_buf,
          reduction_idx * num_threads_per_reduction);
    }

    block_result.avg.val<idx>(0) = block_result_i.val<0>(0);
    block_result.var.val<idx>(0) = block_result_i.val<1>(0);
    block_result.N.val<idx>(0) = block_result_i.val<2>(0);

    if (FORWARD_PROTECT_SMEM) {
      block_sync::sync();
    }
  }
};

// Specialization for idx == -1, i.e., no value to reduce.
template <
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalWelfordTripletTupleT>
struct BlockWelfordEach<
    -1,
    BROADCAST,
    FORWARD_PROTECT_SMEM,
    LocalWelfordTripletTupleT> {
  __inline__ __device__ static void reduce(
      LocalWelfordTripletTupleT& block_result,
      const LocalWelfordTripletTupleT& partial_result,
      PtrTuple<
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::DataType,
          typename LocalWelfordTripletTupleT::IndexType> shared_buf,
      bool has_block_result,
      int tid_in_reduction,
      int num_threads_per_reduction,
      int num_elements_per_reduction,
      int reduction_idx) {}
};

//! Welford version of blockReduceEach. Perform block-parallel Welford
//! reduction of each Welford triplet.
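//! The shared_buf tuple is indexed as
//!   reduction_idx * num_threads_per_reduction + tid_in_reduction,
//! so it is assumed to provide at least one avg/var/N entry per thread of
//! the block (reduction segments times threads per reduction).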
template <
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalWelfordTripletTupleT>
__inline__ __device__ void blockWelfordEach(
    LocalWelfordTripletTupleT& block_result,
    const LocalWelfordTripletTupleT& partial_result,
    PtrTuple<
        typename LocalWelfordTripletTupleT::DataType,
        typename LocalWelfordTripletTupleT::DataType,
        typename LocalWelfordTripletTupleT::IndexType> shared_buf,
    bool has_block_result,
    int tid_in_reduction,
    int num_threads_per_reduction,
    int num_elements_per_reduction,
    int reduction_idx) {
  BlockWelfordEach<
      LocalWelfordTripletTupleT::num_vals - 1,
      BROADCAST,
      FORWARD_PROTECT_SMEM,
      LocalWelfordTripletTupleT>::
      reduce(
          block_result,
          partial_result,
          shared_buf,
          has_block_result,
          tid_in_reduction,
          num_threads_per_reduction,
          num_elements_per_reduction,
          reduction_idx);
}

} // namespace impl

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <int NumArgs, typename DataType, typename IndexType>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    welfordGroup(
        typename MakeRefTuple<NumArgs, DataType>::type out_avg,
        typename MakeRefTuple<NumArgs, DataType>::type out_var,
        typename MakeRefTuple<NumArgs, IndexType>::type out_N,
        const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_avg,
        const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_var,
        const typename MakeConstRefTuple<NumArgs, IndexType>::type& inp_N,
        const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg,
        const typename MakeLocalTuple<NumArgs, DataType>::type& init_var,
        const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N,
        typename MakeVolatilePtrTuple<NumArgs, DataType>::type
            global_work_buffer_avg,
        typename MakeVolatilePtrTuple<NumArgs, DataType>::type
            global_work_buffer_var,
        typename MakeVolatilePtrTuple<NumArgs, IndexType>::type
            global_work_buffer_N,
        int64_t* global_sync_buffer,
        PtrTuple<DataType, DataType, IndexType> shared_buf,
        const typename MakeLocalTuple<NumArgs, bool>::type& read_preds,
        const typename MakeLocalTuple<NumArgs, bool>::type& write_preds) {
  const ConstRefWelfordTripletTuple<NumArgs, DataType, IndexType> inp(
      inp_avg, inp_var, inp_N);
  RefWelfordTripletTuple<NumArgs, DataType, IndexType> out(
      out_avg, out_var, out_N);

  // If no reduction needed, just return input
  if (!BLOCK_REDUCE && !GRID_REDUCE) {
    copyWelfordTripletTupleIf(out, inp, read_preds && write_preds);
    return;
  }

  // Don't read/write in temporary buffers if in a predicated dimension
  const bool block_reduce_participate = index_utils::
      maskedIsZero<isPred(X_THREAD), isPred(Y_THREAD), isPred(Z_THREAD)>(
          threadIdx);

  // Only threads with id == 0 in the dimensions being reduced will have a
  // valid result
  const bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

  LocalWelfordTripletTuple<NumArgs, DataType, IndexType> block_result(
      init_avg, init_var, init_N);

  // Initial per-block reduction. Result is broadcast if specified
  // and this call is block reduction only.
  welfordGroupBlock<!GRID_REDUCE && BROADCAST>(
      block_result, inp, shared_buf, read_preds, block_reduce_participate);

  // If block reduction only, save to out and exit
  if (!GRID_REDUCE) {
    copyWelfordTripletTupleIf(
        out,
        block_result,
        write_preds &&
            (block_reduce_participate && (BROADCAST || has_block_result)));

    // Need a block sync here as welfordGroupBlock does not
    // forward-protect the smem buffer. This block sync is not
    // necessary when a grid reduction follows since a block sync is
    // done just before the grid sync.
    block_sync::sync();
    return;
  }

  // -- START GRID REDUCTION -- //
  // Grid reductions are more challenging for two reasons: (1) the reduction
  // itself is 3D instead of 2D because we now have an iter domain space in
  // the grid dimension; (2) a tree reduction isn't performed; instead, all
  // blocks populate GMEM and one block finishes the grid reduction.
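  //
  // Concretely, the flow below is: (1) each participating block writes its
  // block-level partial triplet to the global work buffer, (2) blocks in the
  // same grid reduction synchronize through global_sync_buffer, and (3) the
  // last block to arrive (every block in the persistent case) reads the
  // partials back and completes the reduction in welfordGroupLastBlock.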

  // Size of the grid reduction. The block reduction has already been
  // performed, so it does not need to be taken into consideration.
  const auto grid_red_size = index_utils::
      maskedSize<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          gridDim);

  // Which ID in the reduction is this block. Threads can participate in
  // multiple grid reductions, but the block will have the same relative index
  // in those reductions
  const auto idx_in_grid_red = index_utils::
      maskedOffset<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the grid dimension
  const auto num_block_iters = index_utils::
      maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(gridDim);

  // Which grid reduction does this block participate in, in the grid
  // dimension
  const auto block_red_idx_offset = index_utils::
      maskedOffset<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the block dimension
  const auto num_thread_iters = index_utils::
      maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          blockDim);

  // Which grid reduction does this thread participate in, in the block
  // dimension
  const auto thread_red_idx_offset = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // 3D buffer of reductions:
  // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
  // Offset into the work buffer
  auto work_buf_offset =
      (idx_in_grid_red * num_block_iters + block_red_idx_offset) *
          num_thread_iters +
      thread_red_idx_offset;
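  // For illustration (numbers are hypothetical): with num_block_iters = 4 and
  // num_thread_iters = 32, a block with idx_in_grid_red = 2,
  // block_red_idx_offset = 1, and thread_red_idx_offset = 5 uses entry
  // (2 * 4 + 1) * 32 + 5 = 293 of the work buffer.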

  // Don't read/write in temporary buffers if in a predicated dimension
  bool grid_reduce_participate = index_utils::
      maskedIsZero<isPred(X_BLOCK), isPred(Y_BLOCK), isPred(Z_BLOCK)>(blockIdx);

  VolatilePtrWelfordTripletTuple<NumArgs, DataType, IndexType>
      global_work_buffer(
          global_work_buffer_avg, global_work_buffer_var, global_work_buffer_N);

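  // In the persistent case the work buffer is double buffered: when flip is
  // set, this iteration is offset by one full
  // [grid iter] x [block iter] x [grid reduction] slab, so new partial results
  // do not overwrite entries another block may still be reading from the
  // previous iteration.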
  if (PERSISTENT_REDUCTION && flip) {
    auto global_buffer_size =
        index_utils::
            maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
                gridDim) *
        index_utils::
            maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
                blockDim) *
        grid_red_size;
    global_work_buffer += global_buffer_size;
  }
  flip = !flip;

  // Per-block partial reduction to global work buffer
  if (grid_reduce_participate && block_reduce_participate && has_block_result) {
    copyWelfordTripletTuple(global_work_buffer, work_buf_offset, block_result);
  }

  // -- GLOBAL BUFFER FILLED -- //

  bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if (grid_reduce_participate) {
    // Don't need to sync up blocks that are not participating in this
    // reduction
    grid_sync::sync<
        isReduce(X_BLOCK),
        isReduce(Y_BLOCK),
        isReduce(Z_BLOCK),
        PERSISTENT_REDUCTION>(
        global_sync_buffer[block_red_idx_offset], grid_red_size, last_block);
  }

  // -- START BLOCK CLEANUP -- //
  welfordGroupLastBlock(
      out,
      global_work_buffer,
      LocalWelfordTripletTuple<NumArgs, DataType, IndexType>(
          init_avg, init_var, init_N),
      shared_buf,
      block_red_idx_offset,
      num_thread_iters,
      num_block_iters,
      thread_red_idx_offset,
      grid_red_size,
      write_preds,
      block_reduce_participate,
      grid_reduce_participate);

  // Forward protect the smem buffer
  block_sync::sync();
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <
    bool BLOCK_BROADCAST,
    int NumVals,
    typename DataType,
    typename IndexType>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
)"
R"(
    welfordGroupBlock(
        LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result,
        const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp,
        PtrTuple<DataType, DataType, IndexType> shared_buf,
        const typename MakeLocalTuple<NumVals, bool>::type& read_preds,
        bool block_reduce_participate) {
  const bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

  copyWelfordTripletTupleIf(
      block_result, inp, block_reduce_participate && read_preds);

  // Size of the block reduction segment, can be an int since it's limited
  // to the number of threads
  const int block_reduction_size = index_utils::
      maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          blockDim);

  // Index in the reduction segment, can be an int since it's limited to the
  // number of threads
  const int tid_in_block_reduction = index_utils::
      maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx, blockDim);

  // ID of the block reduction this thread is participating in
  //
  // If any of the parallel dimensions are predicated out, that means
  // they've already been reduced, so we only care about the first thread in
  // that dimension. Therefore don't expand the reduction_idx by that
  // dimension
  const int block_reduction_idx = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // Do not protect the smem buffer as it's not always necessary.
  impl::blockWelfordEach<
      BLOCK_BROADCAST,
      false,
      LocalWelfordTripletTuple<NumVals, DataType, IndexType>>(
      block_result,
      block_result,
      shared_buf,
      has_block_result,
      tid_in_block_reduction,
      block_reduction_size,
      block_reduction_size,
      block_reduction_idx);
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <int NumVals, typename DataType, typename IndexType>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    welfordGroupLastBlock(
        RefWelfordTripletTuple<NumVals, DataType, IndexType>& out,
        const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>&
            global_work_buffer,
        const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val,
        PtrTuple<DataType, DataType, IndexType> shared_buf,
        nvfuser_index_t block_red_idx_offset,
        nvfuser_index_t num_thread_iters,
        nvfuser_index_t num_block_iters,
        nvfuser_index_t thread_red_idx_offset,
        nvfuser_index_t grid_red_size,
        const typename MakeLocalTuple<NumVals, bool>::type& write_preds,
        bool block_reduce_participate,
        bool grid_reduce_participate) {
  // Initialize block result
  auto last_block_result = init_val;

  const bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) {
    // Use the last block to reduce all the values the blocks filled in. Any
    // thread that is in a predicated or reduced dimension can take part in
    // this reduction, but no block associated with an iteration domain can.

    // Start with a serial, non-block reduction

    // Index in the reduction segment
    int tid_in_block_reduction = index_utils::maskedOffset<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx, blockDim);

    int block_reduction_size = index_utils::maskedSize<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(blockDim);

    bool has_block_result = index_utils::maskedIsZero<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx);

    // 3D buffer of reductions:
    // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
    // Change the offset: we want to keep the last two dimensions, but the
    // first dimension is what we will reduce over
    const auto work_buf_offset =
        block_red_idx_offset * num_thread_iters + thread_red_idx_offset;
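
    // Each participating thread strides through the grid-reduction dimension,
    // serially welford-combining every block_reduction_size-th partial result
    // from the global work buffer into its local last_block_result. The
    // per-thread partials are then merged with a block-level Welford
    // reduction below.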
    for (auto reduction_i = tid_in_block_reduction; reduction_i < grid_red_size;
         reduction_i += block_reduction_size) {
      impl::welfordEach(
          last_block_result,
          0,
          global_work_buffer,
          work_buf_offset + reduction_i * num_block_iters * num_thread_iters);
    }

    // Which block reduction this thread is participating in
    int block_reduction_idx = index_utils::
        maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
            threadIdx, blockDim);
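
    // Combine the per-thread partials across the block. At most
    // min(grid_red_size, block_reduction_size) entries hold real data: when
    // there are fewer grid partials than threads, the remaining threads still
    // carry init_val, so they are excluded from the reduction.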

    impl::blockWelfordEach<
        BROADCAST,
        false,
        LocalWelfordTripletTuple<NumVals, DataType, IndexType>>(
        last_block_result,
        last_block_result,
        shared_buf,
        has_block_result,
        tid_in_block_reduction,
        block_reduction_size,
        min(grid_red_size, block_reduction_size),
        block_reduction_idx);

    copyWelfordTripletTupleIf(
        out,
        last_block_result,
        write_preds &&
            (block_reduce_participate && (BROADCAST || has_block_result)));
  }
}

} // namespace fused_reduction
)";

} // namespace nvfuser_resources