// Generated from "/code/pytorch/third_party/nvfuser/runtime/grid_reduction.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* grid_reduction_cu = R"(
// Inter-block reduction.
//
// The gridReduce function performs point-wise reductions of scalars across
// thread blocks. Thread blocks are disjointly partitioned into groups,
// "reduction segments", that are collectively defined by boolean template
// parameters, X_BLOCK, Y_BLOCK and Z_BLOCK. Each of X/Y/Z_BLOCK determines
// whether thread blocks along the dimension should be grouped into the same
// reduction segment. Cross-block reductions are done independently within each
// segment and generate distinct results per segment. For instance, if all of
// X/Y/Z_BLOCK are true, reductions will be done across all thread blocks
// since there will be just a single segment consisting of all thread blocks. If
// none of them are true, each thread block will become a segment by itself, so
// no reduction will be performed.
//
// The input scalars to reduce within each segment are a certain subset of
// thread-private scalars provided as part of the gridReduce function
// parameters. Boolean template parameters, X_THREAD, Y_THREAD and Z_THREAD,
// determine which subset of the scalars should be used for inter-block
// reductions. Specifically, all the input scalars of threads along each
// dimension will be used when X/Y/Z_THREAD are true. Otherwise, only the value
// held at offset 0 of each dimension will be used. Thus, for example, if all of
// X/Y/Z_THREAD are true, the scalars of all threads in each block will
// participate in inter-block reductions. If all of them are false, only the
// scalar of the thread at threadIdx.x == threadIdx.y == threadIdx.z == 0 will
// be used. In the code below, we call the subset of threads a "reduction
// block". "Participating" thread dimensions here are similar to the
// "non-participating" block dimensions: they come from a block dimension that
// has not been reduced before hitting this grid reduction.
//
// Inter-block reductions perform point-wise reductions of scalars of reduction
// blocks within each reduction segment. More specifically, let rb be a
// reduction block and rs be a reduction segment. Let IN(thread_idx, block_idx)
// denote the input scalar of the thread at thread_idx and block_idx. The result
// of each reduction segment, OUT(thread_idx, block_idx_out), is defined only
// for each thread_idx in thread block block_idx_out in the segment as follows:
//
//   OUT(thread_idx, block_idx_out) =
//     Reduction of IN(thread_idx, block_idx) for
//       all block_idx in a reduction segment
//
// OUT is not defined for threads that are not in block_idx_out or not in the
// reduction block.
//
// See also the function comment of gridReduce.
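//
// For example, suppose the kernel is launched with gridDim = {4, 2, 3} and
// X/Y/Z_BLOCK = {true, true, false}. Blocks that share the same blockIdx.z
// form one reduction segment, so there are 3 segments of 4*2 = 8 thread
// blocks each, and gridReduce produces 3 independent sets of results, one
// per segment.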

namespace reduction {

// Reduces all the reduction blocks in each reduction segment. This is the
// "cleanup" stage of a grid reduction.
//
// This is only called by one thread block per reduction segment. The input
// reduction blocks of the segment are stored in an intermediate buffer pointed
// to by the parameter in. Template parameters X/Y/Z_THREAD denote how the
// reduction block is formed.
//
// The size of a reduction block is by definition smaller than or equal to the
// size of a thread block. We use the remaining threads to parallelize
// reductions across reduction blocks. For example, when X/Y/Z_THREAD = {true,
// false, false}, we use blockDim.y*blockDim.z threads for each output value.
// This is done first by loading the input values in parallel and then by
// reducing across threads of dimensions whose X/Y/Z_THREAD are false.
//
// Note that what is done here after the loading from global memory is similar
// to what the existing blockReduce function does.
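//
// For example, with blockDim = {32, 4, 2} and X/Y/Z_THREAD = {true, false,
// false}, each participating block has written 32 values (one per threadIdx.x)
// to the intermediate buffer. The blockDim.y*blockDim.z = 8 threads that share
// a threadIdx.x cooperatively stride over the grid_reduction_segment_size
// entries for that threadIdx.x, and the final blockReduce combines their
// partial results into one value per threadIdx.x.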
template <
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    typename T,
    typename Func>
__device__ void gridReduceLastBlock(
    T& out,
    const volatile T* in,
    const nvfuser_index_t
        grid_reduction_segment_size, // Number of reductions across
                                     // grid reduce dimensions
    const nvfuser_index_t
        block_reduction_segment_size, // Number of reductions across the block
    Func reduction_op,
    T* shared_buf,
    bool write_pred,
    T init_val) {
  // We have to do num_reductions across reduction_size. The reductions are
  // contiguous, but offset by reduction_size. There is an entry in "in" for
  // every block, and every thread marked as true. Threads in dimensions marked
  // as false can be used to parallelize the reduction.

  // Find the reduction id of the participating threads
  const auto block_reduction_segment_idx =
      index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>(
          threadIdx, blockDim);

  // Find an id associated within a reduction segment for all
  // "non-participating" threads, which will parallelize the reductions for the
  // "participating" threads
  const auto id_in_block_segment =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);

  // Stride by the "non-participating" threads
  const auto input_stride_for_thread_in_segment =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  T inp = init_val;

  // Block stride across the reduction until we only have one value per thread
  for (nvfuser_index_t reduction_i = id_in_block_segment;
       reduction_i < grid_reduction_segment_size;
       reduction_i += input_stride_for_thread_in_segment) {
    auto work_buf_offset = reduction_i * block_reduction_segment_size +
        block_reduction_segment_idx;
    reduction_op(inp, in[work_buf_offset]);
  }

  // Block reduce the per thread values into per "participating" thread values
  T inp_tmp = init_val;
  blockReduce<!X_THREAD, !Y_THREAD, !Z_THREAD>(
      inp_tmp,
      inp,
      reduction_op,
      threadIdx,
      blockDim,
      shared_buf,
      true,
      init_val);
  const bool should_write = (X_THREAD || threadIdx.x == 0) &&
      (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0);
  if (should_write && write_pred) {
    reduction_op(out, inp_tmp);
  }
}

// Reduces per-thread values across threads and thread blocks.
//
// Function parameters:
// - out: Per-thread output location
// - inp_val: Per-thread input value
// - reduction_op: Scalar reduction function
// - work_buf: Temporary buffer for cross-block reductions
// - sync_flags: A vector of integers for synchronizations
// - shared_buf: Shared memory buffer for intra-block reduction
//
// A thread has valid results only if it belongs to the last block in the grid
// reduction dimension.
//
// Template parameters:
// - X/Y/Z_BLOCK/THREAD: When true, reduces across thread blocks along the X/Y/Z
//   dimensions
// - PERSISTENT_REDUCTION: Indicates that the grid reduction will be called in a
//   loop, or that the result of the grid reduction will be broadcasted and used
//   across the grid. These cases require cross-grid communication, and the grid
//   synchronizations here must actually synchronize across the entire grid.
//   When false, the grid is not fully synchronized: the last block simply waits
//   for all other blocks to finish, and the other blocks can exit early.
// - T: Scalar data type of input/output data
// - Func: Type of scalar reduction function
//
// Template parameters X/Y/Z_BLOCK define a group of thread blocks that are
// reduced together. We call it a reduction segment. Some examples are:
//
// Case 1: X/Y/Z_BLOCK == true/true/true -> There is only one segment, which
// includes all thread blocks. It is effectively the same as the grid.
//
// Case 2: X/Y/Z_BLOCK == false/false/false -> Each thread block comprises an
// individual segment by itself.
//
// Case 3: X/Y/Z_BLOCK == true/false/false -> Each segment contains thread
// blocks that have the same blockIdx.y and blockIdx.z. There will be
// gridDim.y*gridDim.z such segments.
//
// X/Y/Z_THREAD work similarly to X/Y/Z_BLOCK and define a group of threads
// that are reduced together.
//
// After the function completes, only one thread block per reduction segment
// gets valid reduction results. There is no guarantee which particular block
// gets the final results.
//
// Multiple entrances (entrance_ind, n_entrances) are only allowed when
// PERSISTENT_REDUCTION = false. If the grid reduction is called only once per
// thread, entrance_ind == 0 and n_entrances == 1. However, the grid reduction
// can be called in a loop in a thread; in that case, entrance_ind is the number
// of times the function has been called so far, and n_entrances is the total
// number of times it will be called.
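//
// Buffer sizing (derived from the indexing below; the caller is expected to
// allocate accordingly): work_buf appears to need one element of T per
// (entrance, reduction segment, block in the segment, output thread of a
// block), i.e. roughly n_entrances * number-of-segments *
// grid_reduction_segment_size * block_reduction_segment_size elements.
// sync_flags appears to need one int64_t per reduction segment, or per segment
// and entrance when PERSISTENT_REDUCTION is false.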
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T,
    typename Func>
__device__ void gridReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    volatile T* work_buf,
    int64_t* sync_flags,
    T* shared_buf,
    bool read_pred,
    bool write_pred,
    T init_val,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances) {
  T block_reduction_val = init_val;

  // Do block reduction when required
  if (X_THREAD || Y_THREAD || Z_THREAD) {
    blockReduce<X_THREAD, Y_THREAD, Z_THREAD>(
        block_reduction_val,
        inp_val,
        reduction_op,
        threadIdx,
        blockDim,
        shared_buf,
        read_pred,
        true,
        init_val);
  } else if (read_pred) {
    block_reduction_val = inp_val;
  }

  // Number of values to reduce in the reduction segment
  const auto grid_reduction_segment_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the reduction we're performing out of the
  // grid_reduction_segment_size
  const auto idx_in_grid_segment =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);

  // Number of threads we can use in the final reduction. This seems to assume
  // all threads in the block participate.
  const auto block_reduction_segment_size =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  // Number of reductions in the grid
  const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION
      ? 1
      : index_utils::maskedSize<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(gridDim);

  // Advance to the offset for this segment:
  // index of reduction * size of the reduction * size of threads
  work_buf += (entrance_ind * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;
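
  // Layout note (derived from the indexing here and in gridReduceLastBlock):
  // within this segment's slice of work_buf, the value written by block
  // block_offset for output thread thread_offset lands at
  // work_buf[block_offset * block_reduction_segment_size + thread_offset],
  // i.e. the slice is a row-major
  // [grid_reduction_segment_size x block_reduction_segment_size] array that
  // gridReduceLastBlock later reads with the same indexing.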
  if ((!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) &&
      (!Z_THREAD || threadIdx.z == 0)) {
    auto block_offset =
        index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
    auto thread_offset =
        index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
            threadIdx, blockDim);
    auto work_buf_offset =
        block_offset * block_reduction_segment_size + thread_offset;
    work_buf[work_buf_offset] = block_reduction_val;
  }

  if (PERSISTENT_REDUCTION) {
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  } else {
    // Use a different sync flag for each call
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment],
        grid_reduction_segment_size);
  }

  bool last_block =
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);

  if (last_block) {
    // Cleanup with block reduction
    gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD>(
        out,
        (T*)work_buf,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        reduction_op,
        shared_buf,
        write_pred,
        init_val);
  }

  if (PERSISTENT_REDUCTION) {
    // Make sure we're done with global memory before we allow the kernel to
    // continue
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  }
}
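
// Illustrative (hypothetical) call site: the generated kernels, not this file,
// decide the template arguments and buffer pointers. A sum reduction across
// blockIdx.x with a single call per thread might look roughly like:
//
//   float out = 0.f;
//   reduction::gridReduce<true, false, false, false, false, false, false>(
//       out,
//       partial_val,
//       [](float& a, float b) { a = a + b; },
//       work_buf,      // global scratch, sized as described above
//       sync_flags,    // zero-initialized global sync flags
//       shared_buf,    // dynamic shared memory
//       true,          // read_pred
//       true,          // write_pred
//       0.f,           // init_val
//       0,             // entrance_ind
//       1);            // n_entrances
//
// The names partial_val, work_buf, sync_flags and shared_buf above are
// placeholders for values and buffers provided by the caller.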

// This is just a wrapper of the above grid reduction routine to
// measure the elapsed cycles. The measurement must be done just by
// one thread, and in this case it should be done by one of the
// threads in the last thread block.
#ifdef PYTORCH_NVFUSER_PROFILE_KERNEL
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T,
    typename Func>
__device__ void gridReduce(
    T& out,
    const T& inp_val,
    Func reduction_op,
    volatile T* work_buf,
    int64_t* sync_flags,
    T* shared_buf,
    bool read_pred,
    bool write_pred,
    T init_val,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances,
    int64_t& cycles,
    int64_t& count) {
  int64_t start_counter = 0;

  if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) &&
      index_utils::maskedIsZero<true, true, true>(threadIdx)) {
    start_counter = readCycleCounter();
  }

  gridReduce<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD,
      PERSISTENT_REDUCTION,
      T,
      Func>(
      out,
      inp_val,
      reduction_op,
      work_buf,
      sync_flags,
      shared_buf,
      read_pred,
      write_pred,
      init_val,
      entrance_ind,
      n_entrances);

  if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) &&
      index_utils::maskedIsZero<true, true, true>(threadIdx)) {
    cycles += readCycleCounter() - start_counter;
    ++count;
  }
}
#endif // PYTORCH_NVFUSER_PROFILE_KERNEL

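// Helper for gridReduceGroup below: performs the per-block portion of a grid
// reduction (an optional block reduction over the participating thread
// dimensions, followed by the write of this block's partial results into
// work_buf), without any grid synchronization or final cleanup. gridReduceGroup
// calls this once per fused reduction before issuing a single grid sync.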
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    typename T,
    typename Func>
__device__ void gridReduce2PartialReduction(
    const T& inp_val,
    T init_val,
    Func reduction_op,
    volatile T* work_buf,
    T* shared_buf,
    bool read_pred,
    nvfuser_index_t grid_reduction_segment_size,
    nvfuser_index_t idx_in_grid_segment,
    nvfuser_index_t block_reduction_segment_size) {
  T block_reduction_val = init_val;

  // Do block reduction when required
  if (X_THREAD || Y_THREAD || Z_THREAD) {
    blockReduce<X_THREAD, Y_THREAD, Z_THREAD>(
        block_reduction_val,
        inp_val,
        reduction_op,
        threadIdx,
        blockDim,
        shared_buf,
        read_pred,
        true,
        init_val);
  } else if (read_pred) {
    block_reduction_val = inp_val;
  }

  if ((!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) &&
      (!Z_THREAD || threadIdx.z == 0)) {
    auto block_offset =
        index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
    auto thread_offset =
        index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
            threadIdx, blockDim);
    auto work_buf_offset =
        block_offset * block_reduction_segment_size + thread_offset;
    work_buf[work_buf_offset] = block_reduction_val;
  }
}

// 2-way horizontally fused grid reduction
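//
// Semantics are the same as calling gridReduce twice with the same template
// parameters, except that the two reductions share the same grid
// synchronizations rather than each issuing their own, which is the point of
// the fusion. The two reductions may use different data types (T1/T2) and
// reduction ops (Func1/Func2), but they reuse the same shared_buf one after
// the other, so shared_buf is presumably expected to be large enough for
// whichever of the two block reductions needs more space.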
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T1,
    typename Func1,
    typename T2,
    typename Func2>
__device__ void gridReduceGroup(
    T1& out1,
    const T1& inp_val1,
    T1 init_val1,
    Func1 reduction_op1,
    volatile T1* work_buf1,
    T2& out2,
    const T2& inp_val2,
    T2 init_val2,
    Func2 reduction_op2,
    volatile T2* work_buf2,
    int64_t* sync_flags,
    void* shared_buf,
    bool read_pred,
)"
R"(
    bool write_pred,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances) {
  // Number of values to reduce in the reduction segment
  const auto grid_reduction_segment_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the reduction we're performing out of the
  // grid_reduction_segment_size
  const auto idx_in_grid_segment =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);

  // Number of threads we can use in the final reduction. This seems to assume
  // all threads in the block participate.
  const auto block_reduction_segment_size =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  // Number of reductions in the grid
  const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION
      ? 1
      : index_utils::maskedSize<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(gridDim);

  // Advance to the offset for this segment:
  // index of reduction * size of the reduction * size of threads
  work_buf1 += (entrance_ind * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;

  work_buf2 += (entrance_ind * grid_segment_size + idx_in_grid_segment) *
      grid_reduction_segment_size * block_reduction_segment_size;

  gridReduce2PartialReduction<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD>(
      inp_val1,
      init_val1,
      reduction_op1,
      work_buf1,
      (T1*)shared_buf,
      read_pred,
      grid_reduction_segment_size,
      idx_in_grid_segment,
      block_reduction_segment_size);

  gridReduce2PartialReduction<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD>(
      inp_val2,
      init_val2,
      reduction_op2,
      work_buf2,
      (T2*)shared_buf,
      read_pred,
      grid_reduction_segment_size,
      idx_in_grid_segment,
      block_reduction_segment_size);

  if (PERSISTENT_REDUCTION) {
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  } else {
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment],
        grid_reduction_segment_size);
  }

  bool last_block =
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);

  if (last_block) {
    // Cleanup with block reduction
    gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD>(
        out1,
        work_buf1,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        reduction_op1,
        (T1*)shared_buf,
        write_pred,
        init_val1);
    gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD>(
        out2,
        work_buf2,
        grid_reduction_segment_size,
        block_reduction_segment_size,
        reduction_op2,
        (T2*)shared_buf,
        write_pred,
        init_val2);
  }

  if (PERSISTENT_REDUCTION) {
    // Make sure we're done with global memory before we allow the kernel to
    // continue
    grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION>(
        sync_flags[idx_in_grid_segment], grid_reduction_segment_size);
  }
}

#ifdef PYTORCH_NVFUSER_PROFILE_KERNEL
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    bool PERSISTENT_REDUCTION,
    typename T1,
    typename Func1,
    typename T2,
    typename Func2>
__device__ void gridReduceGroup(
    T1& out1,
    const T1& inp_val1,
    T1 init_val1,
    Func1 reduction_op1,
    volatile T1* work_buf1,
    T2& out2,
    const T2& inp_val2,
    T2 init_val2,
    Func2 reduction_op2,
    volatile T2* work_buf2,
    int64_t* sync_flags,
    void* shared_buf,
    bool read_pred,
    bool write_pred,
    const nvfuser_index_t entrance_ind,
    const nvfuser_index_t n_entrances,
    int64_t& cycles,
    int64_t& count) {
  int64_t start_counter = 0;

  if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) &&
      index_utils::maskedIsZero<true, true, true>(threadIdx)) {
    start_counter = readCycleCounter();
  }

  gridReduceGroup<
      X_BLOCK,
      Y_BLOCK,
      Z_BLOCK,
      X_THREAD,
      Y_THREAD,
      Z_THREAD,
      PERSISTENT_REDUCTION,
      T1,
      Func1,
      T2,
      Func2>(
      out1,
      inp_val1,
      init_val1,
      reduction_op1,
      work_buf1,
      out2,
      inp_val2,
      init_val2,
      reduction_op2,
      work_buf2,
      sync_flags,
      shared_buf,
      read_pred,
      write_pred,
      entrance_ind,
      n_entrances);

  if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) &&
      index_utils::maskedIsZero<true, true, true>(threadIdx)) {
    cycles += readCycleCounter() - start_counter;
    ++count;
  }
}
#endif // PYTORCH_NVFUSER_PROFILE_KERNEL

} // namespace reduction
)";

} // namespace nvfuser_resources