// Generated from "/code/pytorch/third_party/nvfuser/runtime/fused_reduction.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* fused_reduction_cu = R"(
namespace fused_reduction {

namespace impl {

//! Let f_i be the i-th function of the variadic list of binary
//! functions. The function is called as: f_i(x, y).
template <int i, typename DataType, typename Func, typename... Funcs>
struct FuncSelector {
  static __device__ void call(
      DataType& x,
      const DataType y,
      Func f,
      Funcs... funcs) {
    // Here, i is guaranteed to be larger than 0 as there's a
    // specialization for i == 0 below. Recursively call FuncSelector
    // by dropping f and decrementing i.
    FuncSelector<i - 1, DataType, Funcs...>::call(x, y, funcs...);
  }
};

//! Specialization of FuncSelector when i == 0, so f_i is f.
template <typename DataType, typename Func, typename... Funcs>
struct FuncSelector<0, DataType, Func, Funcs...> {
  static __device__ void call(
      DataType& x,
      const DataType y,
      Func f,
      Funcs... funcs) {
    f(x, y);
  }
};
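
// Illustrative sketch (hypothetical functors, not part of this runtime):
// given
//
//   auto add = [] __device__ (float& a, float b) { a += b; };
//   auto max = [] __device__ (float& a, float b) { a = a > b ? a : b; };
//
// FuncSelector<0, float, decltype(add), decltype(max)>::call(x, y, add, max)
// applies add(x, y), while the i == 1 instantiation peels off add through the
// primary template above and ends up applying max(x, y).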

//! Call each of the first i+1 functions with the first i+1 values of
//! tuples. Here, i is guaranteed to be larger than -1 as there's a
//! specialization for i == -1.
template <int i, typename TupleType0, typename TupleType1, typename... Funcs>
struct FuncForEach {
  static __device__ void call(
      TupleType0& val0,
      nvfuser_index_t offset0,
      const TupleType1& val1,
      nvfuser_index_t offset1,
      Funcs... funcs) {
    static_assert(
        IsSameType<
            typename TupleType0::template ValType<i>,
            typename TupleType1::template ValType<i>>::value,
        "Invalid tuple types");
    // Process the first i functions first.
    FuncForEach<i - 1, TupleType0, TupleType1, Funcs...>::call(
        val0, offset0, val1, offset1, funcs...);
    // Call the i+1-th function
    FuncSelector<i, typename TupleType0::template ValType<i>, Funcs...>::call(
        val0.val<i>(offset0), val1.val<i>(offset1), funcs...);
  }
};

//! Specialization of FuncForEach when i == -1, which means no
//! function to call. Just for stopping the recursive pattern here.
template <typename TupleType0, typename TupleType1, typename... Funcs>
struct FuncForEach<-1, TupleType0, TupleType1, Funcs...> {
  static __device__ void call(
      TupleType0& val0,
      nvfuser_index_t offset0,
      const TupleType1& val1,
      nvfuser_index_t offset1,
      Funcs... funcs) {}
};

//! Reduce one value of a tuple using one of the reduction ops. The
//! value at val_idx is reduced by the function at func_idx.
template <
    int func_idx,
    int val_idx,
    typename TupleType0,
    typename TupleType1,
    typename... Funcs>
__inline__ __device__ static void reduceVal(
    TupleType0& val0,
    nvfuser_index_t offset0,
    const TupleType1& val1,
    nvfuser_index_t offset1,
    Funcs... reduction_ops) {
  static_assert(
      IsSameType<
          typename TupleType0::template ValType<val_idx>,
          typename TupleType1::template ValType<val_idx>>::value,
      "Invalid tuple types");
  FuncSelector<
      func_idx,
      typename TupleType0::template ValType<val_idx>,
      Funcs...>::
      call(
          val0.val<val_idx>(offset0),
          val1.val<val_idx>(offset1),
          reduction_ops...);
}
//! Accumulate each value of a given pair of tuples using its corresponding
//! function. Let f_i be the i-th reduction function. f_i is called as:
//! f_i(val0.val<i>(offset0), val1.val<i>(offset1)).
template <typename TupleType0, typename TupleType1, typename... Funcs>
__inline__ __device__ static void reduceEach(
    TupleType0& val0,
    nvfuser_index_t offset0,
    const TupleType1& val1,
    nvfuser_index_t offset1,
    Funcs... reduction_ops) {
  constexpr int num_funcs = sizeof...(reduction_ops);
  FuncForEach<num_funcs - 1, TupleType0, TupleType1, Funcs...>::call(
      val0, offset0, val1, offset1, reduction_ops...);
}
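
// Illustrative sketch (hypothetical operators, and assuming the value
// constructors of the companion tuple helpers in this runtime): with
//
//   LocalTuple<float, int64_t> acc(0.0f, int64_t(0));
//   LocalTuple<float, int64_t> in(3.0f, int64_t(1));
//   reduceEach(acc, 0, in, 0, sumFloat, sumInt);
//
// value 0 of acc is combined with value 0 of in using sumFloat, and value 1
// with value 1 using sumInt, i.e. the i-th tuple value gets the i-th operator.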

template <typename TupleType0, typename TupleType1, typename Func, int num_vals>
struct TupleReduce {};

template <typename TupleType0, typename TupleType1, typename Func>
struct TupleReduce<TupleType0, TupleType1, Func, 1> {
  __inline__ __device__ static void reduce(
      TupleType0& val0,
      nvfuser_index_t offset0,
      const TupleType1& val1,
      nvfuser_index_t offset1,
      Func reduction_op) {
    static_assert(
        IsSameType<
            typename TupleType0::ValTypes,
            typename TupleType1::ValTypes>::value,
        "Invalid value types");
    reduction_op(val0.val<0>(offset0), val1.val<0>(offset1));
  }
};

template <typename TupleType0, typename TupleType1, typename Func>
struct TupleReduce<TupleType0, TupleType1, Func, 2> {
  __inline__ __device__ static void reduce(
      TupleType0& val0,
      nvfuser_index_t offset0,
      const TupleType1& val1,
      nvfuser_index_t offset1,
      Func reduction_op) {
    static_assert(
        IsSameType<
            typename TupleType0::ValTypes,
            typename TupleType1::ValTypes>::value,
        "Invalid value types");
    reduction_op(
        val0.val<0>(offset0),
        val0.val<1>(offset0),
        val1.val<0>(offset1),
        val1.val<1>(offset1));
  }
};

template <typename TupleType0, typename TupleType1, typename Func>
struct TupleReduce<TupleType0, TupleType1, Func, 3> {
  __inline__ __device__ static void reduce(
      TupleType0& val0,
      nvfuser_index_t offset0,
      const TupleType1& val1,
      nvfuser_index_t offset1,
      Func reduction_op) {
    static_assert(
        IsSameType<
            typename TupleType0::ValTypes,
            typename TupleType1::ValTypes>::value,
        "Invalid value types");
    reduction_op(
        val0.val<0>(offset0),
        val0.val<1>(offset0),
        val0.val<2>(offset0),
        val1.val<0>(offset1),
        val1.val<1>(offset1),
        val1.val<2>(offset1));
  }
};

//! Reduce all values of a tuple together. The reduction function must
//! have the same number of inputs as the number of values of each tuple.
template <typename TupleType0, typename TupleType1, typename Func>
__inline__ __device__ void reduceTuple(
    TupleType0& val0,
    nvfuser_index_t offset0,
    const TupleType1& val1,
    nvfuser_index_t offset1,
    Func reduction_op) {
  static_assert(
      TupleType0::num_vals == TupleType1::num_vals, "Invalid number of values");
  TupleReduce<TupleType0, TupleType1, Func, TupleType0::num_vals>::reduce(
      val0, offset0, val1, offset1, reduction_op);
}
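
// Illustrative note: reduceTuple above applies ONE functor to all values of
// the tuple at once, which is the shape a Welford combine needs. With
// 3-value tuples holding (avg, var, N), the call forwards
//
//   op(dst_avg, dst_var, dst_N, src_avg, src_var, src_N)
//
// through TupleReduce<..., 3>, whereas reduceEach applies an independent
// binary op to each tuple value.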

// Reduces all of the first (idx+1) values by a thread block
template <
    int idx,
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalTupleT,
    typename... Funcs>
struct BlockReduceEach {
  __inline__ __device__ static void reduce(
      LocalTupleT& block_result,
      const LocalTupleT& partial_result,
      void* shared_mem,
      bool has_block_result,
      int tid_in_reduction,
      int num_threads_per_reduction,
      int num_elements_per_reduction,
      int reduction_idx,
      Funcs... funcs) {
    // Finish the reduction of each tuple value with a smaller offset
    BlockReduceEach<idx - 1, BROADCAST, true, LocalTupleT, Funcs...>::reduce(
        block_result,
        partial_result,
        shared_mem,
        has_block_result,
        tid_in_reduction,
        num_threads_per_reduction,
        num_elements_per_reduction,
        reduction_idx,
        funcs...);

    if (num_elements_per_reduction == 1) {
      if (has_block_result) {
        block_result.val<idx>(0) = partial_result.val<idx>(0);
      }
      return;
    }

    using DataType = typename LocalTupleT::template ValType<idx>;

    PtrTuple<DataType> shared_buf(static_cast<DataType*>(shared_mem));

    LocalTuple<DataType> block_result_i(partial_result.val<idx>(0));

    const auto smem_offset =
        reduction_idx * num_threads_per_reduction + tid_in_reduction;

    const int np2 = 1 << (31 - __clz(num_elements_per_reduction));
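    // Note: __clz counts leading zero bits, so the expression above rounds
    // num_elements_per_reduction down to the largest power of two that is
    // <= its value (e.g. 384 -> 256). The tree reduction below then runs
    // over a power-of-two extent.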

    // Thread values are initialized, so all can participate here
    if (tid_in_reduction >= np2) {
      copyTuple(shared_buf, smem_offset, block_result_i);
    }

    block_sync::sync();

    if (tid_in_reduction < np2 &&
        tid_in_reduction + np2 < num_elements_per_reduction) {
      impl::reduceVal<idx, 0>(
          block_result_i, 0, shared_buf, smem_offset + np2, funcs...);
    }

    if (tid_in_reduction < np2) {
      copyTuple(shared_buf, smem_offset, block_result_i);
    }

    // Always sync when communicating across smem
    block_sync::sync();

    // Reduce down to 2 values, last thread will do the final reduction and
    // can save a syncthreads this way
    for (int factor = np2 / 2; factor > 1; factor >>= 1) {
      if (tid_in_reduction < factor) {
        impl::reduceVal<idx, 0>(
            shared_buf,
            smem_offset,
            shared_buf,
            smem_offset + factor,
            funcs...);
      }
      block_sync::sync();
    }

    copyTuple(block_result_i, shared_buf, smem_offset);

    // Do the last reduction
    if (has_block_result) {
      impl::reduceVal<idx, 0>(
          block_result_i, 0, shared_buf, smem_offset + 1, funcs...);
    }

    if (BROADCAST) {
      if (has_block_result) {
        // Put result back in shared memory, put in the first entry of the
        // reduction segment's buffer
        copyTuple(
            shared_buf,
            reduction_idx * num_threads_per_reduction,
            block_result_i);
      }

      // Sync threads to make sure result is in smem
      block_sync::sync();

      copyTuple(
          block_result_i,
          shared_buf,
          reduction_idx * num_threads_per_reduction);
    }

    block_result.val<idx>(0) = block_result_i.val<0>(0);

    if (FORWARD_PROTECT_SMEM) {
      block_sync::sync();
    }
  }
};

// Specialization for idx == -1, i.e., no value to reduce.
template <
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalTupleT,
    typename... Funcs>
struct BlockReduceEach<
    -1,
    BROADCAST,
    FORWARD_PROTECT_SMEM,
    LocalTupleT,
    Funcs...> {
  __inline__ __device__ static void reduce(
      LocalTupleT& block_result,
      const LocalTupleT& partial_result,
      void* shared_mem,
      bool has_block_result,
      int tid_in_reduction,
      int num_threads_per_reduction,
      int num_elements_per_reduction,
      int reduction_idx,
      Funcs... funcs) {}
};

//! Reduce each value of a tuple by a thread block.
//!
//! The final result is broadcast when BROADCAST is true.
//!
//! \param block_result Result of the block reduction
//! \param partial_result Per-thread input tuple
//! \param shared_mem Shared memory workspace for the reduction
//! \param has_block_result True for the thread that holds the valid result
//! \param tid_in_reduction Thread index within the reduction segment
//! \param num_threads_per_reduction Number of threads per reduction segment
//! \param num_elements_per_reduction Number of elements reduced per segment
//! \param reduction_idx Index of the reduction segment this thread belongs to
//! \param reduction_ops Reduction operator for each tuple value
template <
    bool BROADCAST,
    bool FORWARD_PROTECT_SMEM,
    typename LocalTupleT,
    typename... Funcs>
__inline__ __device__ void blockReduceEach(
    LocalTupleT& block_result,
    const LocalTupleT& partial_result,
    void* shared_mem,
    bool has_block_result,
    int tid_in_reduction,
    int num_threads_per_reduction,
    int num_elements_per_reduction,
    int reduction_idx,
    Funcs... reduction_ops) {
  BlockReduceEach<
      LocalTupleT::num_vals - 1,
      BROADCAST,
      FORWARD_PROTECT_SMEM,
      LocalTupleT,
      Funcs...>::
      reduce(
          block_result,
          partial_result,
          shared_mem,
          has_block_result,
          tid_in_reduction,
          num_threads_per_reduction,
          num_elements_per_reduction,
          reduction_idx,
          reduction_ops...);
}

} // namespace impl

// We have 6 dimensions, 3 in the grid, 3 in the block.
// Each dimension can be in one of the following states:
// Reduction Domain - TEMPLATE STATE 0
//   - Participating in the reduction; has values coming in, one value coming
//     out across the dimension
// Iteration Domain - TEMPLATE STATE 1
//   - Not participating in the reduction; has values across the dimension
//     after the reduction
// Collapsed Domain - TEMPLATE STATE 2
//   - Previously reduced; doesn't need to be reduced on that dimension and
//     doesn't have values across that dimension
// Inactive Domain - TEMPLATE STATE 3
//   - Not used by this reduction
constexpr __device__ bool isReduce(int STATE) {
  return STATE == 0;
}

constexpr __device__ bool isIter(int STATE) {
  return STATE == 1;
}

constexpr __device__ bool isPred(int STATE) {
  return STATE == 2;
}

constexpr __device__ bool inactive(int STATE) {
  return STATE == 3;
}

constexpr __device__ bool activeNotIter(int STATE) {
  return STATE != 3 && STATE != 1;
}

constexpr __device__ bool isReduceOrIter(int STATE) {
  return isReduce(STATE) || isIter(STATE);
}
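
// Illustrative sketch: the ParallelReduce class below takes one of these
// state values for each of the six parallel dimensions. For example,
//
//   ParallelReduce<0, 3, 3,  // blockIdx.x reduces; blockIdx.y/z inactive
//                  0, 1, 3,  // threadIdx.x reduces; threadIdx.y iterates
//                  false, false>
//
// describes a grid+block reduction across the x dimensions that keeps an
// independent result per threadIdx.y. State 2 (collapsed/predicated) marks a
// dimension that was already reduced and must not read or write the buffers.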

// When generating an index into the reduction, we have to stride by iteration
// domains and reduction domains. Collapsed domains we can ignore, but we need
// to make sure they never read or write (they need to be predicated to ensure
// correct participation).

// All-inclusive reduction with the option to re-broadcast. This reduction
// class does not encode the parallelization in the read or write predicates.
// Instead, each parallel dimension can be in one of the states described
// above, and predication, indexing, and reduction are done based on this
// information.
template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
class ParallelReduce {
  static_assert(
      !BROADCAST || PERSISTENT_REDUCTION,
      "Broadcast requires persistent reduction");

  static constexpr bool BLOCK_REDUCE =
      isReduce(X_THREAD) || isReduce(Y_THREAD) || isReduce(Z_THREAD);

  static constexpr bool GRID_REDUCE =
      isReduce(X_BLOCK) || isReduce(Y_BLOCK) || isReduce(Z_BLOCK);

  // ping-pong between global buffers to avoid a second sync
  bool flip = false;

 public:
  __device__ ParallelReduce() {}

  // reduceGroup does not support Welford-style reductions that reduce
  // all values of a tuple together, so this is the only entry point
  // for Welford for now.
  template <typename Func, typename... Types>
  __device__ __inline__ void reduce(
      RefTuple<Types...> out,
      const ConstRefTuple<Types...>& inp,
      VolatilePtrTuple<Types...> global_work_buffer,
      int64_t* global_sync_buffer, // Allocated as product of all
                                   // non-participating grid dimensions
      PtrTuple<Types...> shared_buf,
      bool read_pred, // Prevent reading from out of bounds memory
      bool write_pred, // Prevent from writing out of bounds
      const LocalTuple<Types...>& init_val,
      Func reduction_op);

  //! Profiled version
  template <typename Func, typename... Types>
  __device__ __inline__ void reduce(
      RefTuple<Types...> out,
      const ConstRefTuple<Types...>& inp,
      VolatilePtrTuple<Types...> global_work_buffer,
      int64_t* global_sync_buffer, // Allocated as product of all
                                   // non-participating grid dimensions
      PtrTuple<Types...> shared_buf,
      bool read_pred, // Prevent reading from out of bounds memory
      bool write_pred, // Prevent from writing out of bounds
      const LocalTuple<Types...>& init_val,
      Func reduction_op,
      int64_t& cycles,
      int64_t& count);

  //! Each value of a tuple is independently reduced by the
  //! corresponding reduction op. Thus, Welford-like reductions are
  //! not supported by this interface.
  //!
  //! Note that out is purely used as the output parameter, and its
  //! initial value is not used but just overwritten. Since grid
)"
R"(
  //! reductions do not allow serial reduction IterDomains, there is
  //! no need to accumulate into the out parameter.
  template <typename... DataTypes, typename... Funcs, typename... BoolTypes>
  __device__ __inline__ void reduceGroup(
      RefTuple<DataTypes...> out,
      const ConstRefTuple<DataTypes...>& inp,
      VolatilePtrTuple<DataTypes...> global_work_buffer,
      const LocalTuple<DataTypes...>& init_val,
      int64_t* global_sync_buffer,
      void* shared_mem,
      const LocalTuple<BoolTypes...>& read_preds,
      const LocalTuple<BoolTypes...>& write_preds,
      Funcs... funcs);

  //! Profiled version
  template <typename... DataTypes, typename... Funcs, typename... BoolTypes>
  __device__ __inline__ void reduceGroup(
      RefTuple<DataTypes...> out,
      const ConstRefTuple<DataTypes...>& inp,
      VolatilePtrTuple<DataTypes...> global_work_buffer,
      const LocalTuple<DataTypes...>& init_val,
      int64_t* global_sync_buffer,
      void* shared_mem,
      const LocalTuple<BoolTypes...>& read_preds,
      const LocalTuple<BoolTypes...>& write_preds,
      int64_t& cycles,
      int64_t& count,
      Funcs... funcs);

  template <int NumArgs, typename DataType, typename IndexType>
  __device__ __inline__ void welfordGroup(
      typename MakeRefTuple<NumArgs, DataType>::type out_avg,
      typename MakeRefTuple<NumArgs, DataType>::type out_var,
      typename MakeRefTuple<NumArgs, IndexType>::type out_N,
      const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_avg,
      const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_var,
      const typename MakeConstRefTuple<NumArgs, IndexType>::type& inp_N,
      const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg,
      const typename MakeLocalTuple<NumArgs, DataType>::type& init_var,
      const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N,
      typename MakeVolatilePtrTuple<NumArgs, DataType>::type
          global_work_buffer_avg,
      typename MakeVolatilePtrTuple<NumArgs, DataType>::type
          global_work_buffer_var,
      typename MakeVolatilePtrTuple<NumArgs, IndexType>::type
          global_work_buffer_N,
      int64_t* global_sync_buffer,
      PtrTuple<DataType, DataType, IndexType> shared_buf,
      const typename MakeLocalTuple<NumArgs, bool>::type& read_preds,
      const typename MakeLocalTuple<NumArgs, bool>::type& write_preds);

 private:
  __device__ static bool isLastBlockInGrid() {
    return index_utils::maskedIsLast<
               isReduceOrIter(X_BLOCK),
               isReduceOrIter(Y_BLOCK),
               isReduceOrIter(Z_BLOCK)>(blockIdx, gridDim) &&
        index_utils::maskedIsZero<
            !isReduceOrIter(X_BLOCK),
            !isReduceOrIter(Y_BLOCK),
            !isReduceOrIter(Z_BLOCK)>(blockIdx);
  }

  //! Initial per-CTA reduction of each value of a tuple. Each value
  //! is reduced individually, so the shared memory buffer just needs
  //! to be large enough for each value. NOTE that the smem buffer is
  //! not forward protected.
  template <
      bool BLOCK_BROADCAST,
      typename... DataTypes,
      typename... Funcs,
      typename... BoolTypes>
  __device__ __inline__ static LocalTuple<DataTypes...> reduceGroupBlock(
      const ConstRefTuple<DataTypes...>& inp,
      const LocalTuple<DataTypes...>& init_val,
      void* shared_mem,
      const LocalTuple<BoolTypes...>& read_preds,
      bool block_reduce_participate,
      Funcs... funcs);

  //! Final reduction of partial results. Done by all blocks
  //! redundantly when BROADCAST is true, or just one block otherwise.
  //! The smem buffer is assumed synchronized when it is passed in,
  //! but it isn't synchronized when returning from this function.
  template <typename... DataTypes, typename... Funcs, typename... BoolTypes>
  __device__ __inline__ static void reduceGroupLastBlock(
      RefTuple<DataTypes...>& out,
      const VolatilePtrTuple<DataTypes...>& global_work_buffer,
      const LocalTuple<DataTypes...>& init_val,
      void* shared_mem,
      nvfuser_index_t block_red_idx_offset,
      nvfuser_index_t num_thread_iters,
      nvfuser_index_t num_block_iters,
      nvfuser_index_t thread_red_idx_offset,
      nvfuser_index_t grid_red_size,
      const LocalTuple<BoolTypes...>& write_preds,
      bool block_reduce_participate,
      bool grid_reduce_participate,
      Funcs... reduction_ops);

  //! Welford version of reduceGroupBlock
  template <
      bool BLOCK_BROADCAST,
      int NumVals,
      typename DataType,
      typename IndexType>
  __device__ __inline__ static void welfordGroupBlock(
      LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result,
      const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp,
      PtrTuple<DataType, DataType, IndexType> shared_buf,
      const typename MakeLocalTuple<NumVals, bool>::type& read_preds,
      bool block_reduce_participate);

  //! Welford version of reduceGroupLastBlock
  template <int NumVals, typename DataType, typename IndexType>
  __device__ __inline__ static void welfordGroupLastBlock(
      RefWelfordTripletTuple<NumVals, DataType, IndexType>& out,
      const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>&
          global_work_buffer,
      const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val,
      PtrTuple<DataType, DataType, IndexType> shared_buf,
      nvfuser_index_t block_red_idx_offset,
      nvfuser_index_t num_thread_iters,
      nvfuser_index_t num_block_iters,
      nvfuser_index_t thread_red_idx_offset,
      nvfuser_index_t grid_red_size,
      const typename MakeLocalTuple<NumVals, bool>::type& write_preds,
      bool block_reduce_participate,
      bool grid_reduce_participate);

  // End Parallel reduce class
};
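
// Illustrative usage sketch (not part of this file; buffer allocation is
// elided and the tuple constructors are assumed to follow the companion
// tuple helpers of this runtime):
//
//   ParallelReduce<0, 3, 3, 0, 1, 3, false, false> pr;
//   pr.reduce(
//       RefTuple<float>(out_ref),          // reduction output
//       ConstRefTuple<float>(in_ref),      // per-thread input
//       VolatilePtrTuple<float>(work_ptr), // global work buffer
//       sync_flags,                        // global sync counters
//       PtrTuple<float>(smem_ptr),         // dynamic shared memory
//       read_pred,
//       write_pred,
//       LocalTuple<float>(0.0f),           // init value
//       [] __device__ (float& a, float b) { a += b; });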

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <typename Func, typename... Types>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    reduce(
        RefTuple<Types...> out,
        const ConstRefTuple<Types...>& inp,
        VolatilePtrTuple<Types...> global_work_buffer,
        int64_t* global_sync_buffer, // Allocated as product of all
                                     // non-participating grid dimensions
        PtrTuple<Types...> shared_buf,
        bool read_pred, // Prevent reading from out of bounds memory
        bool write_pred, // Prevent from writing out of bounds
        const LocalTuple<Types...>& init_val,
        Func reduction_op) {
  // If no reduction needed, just return input
  if (!BLOCK_REDUCE && !GRID_REDUCE) {
    if (read_pred && write_pred) {
      out = inp;
    }
    return;
  }

  // Don't read/write in temporary buffers if in a predicated dimension
  bool block_reduce_participate = index_utils::
      maskedIsZero<isPred(X_THREAD), isPred(Y_THREAD), isPred(Z_THREAD)>(
          threadIdx);

  // Initialize block result
  LocalTuple<Types...> block_result = init_val;

  // Grab input data if participating in the reduction, set to block_result in
  // the case there is no block reduction
  if (block_reduce_participate && read_pred) {
    block_result = inp;
  }

  // Only threads with id == 0 in the dimensions being reduced will
  // have a valid result
  bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

  if (BLOCK_REDUCE) {
    // -- START BLOCK REDUCTION -- //

    // Size of the block reduction segment, can be an int since it's limited
    // to number of threads
    int block_reduction_size = index_utils::
        maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
            blockDim);

    // Index in the reduction segment, can be an int since it's limited to
    // number of threads
    int tid_in_block_reduction = index_utils::maskedOffset<
        isReduce(X_THREAD),
        isReduce(Y_THREAD),
        isReduce(Z_THREAD)>(threadIdx, blockDim);

    // ID of the block reduction this thread is participating in
    //
    // If any of the parallel dimensions are predicated out, that means
    // they've already been reduced, so we only care about the first thread in
    // that dimension. Therefore don't expand the reduction_idx by that
    // dimension
    int block_reduction_idx = index_utils::
        maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
            threadIdx, blockDim);

    // Shared memory buffer is 2D
    // [iter dimension, reduction dimension]

    // Offset into smem for the current thread
    int block_reduce_smem_offset =
        block_reduction_idx * block_reduction_size + tid_in_block_reduction;

    // Initialize shared memory
    if (block_reduce_participate) {
      copyTuple(shared_buf, block_reduce_smem_offset, block_result);
    }

    // Sync to make sure smem is completely initialized
    block_sync::sync();

    // Round reduction size down to nearest power of 2
    int np2 = 1 << (31 - __clz(block_reduction_size));

    // Perform an initial reduction leaving np2 elements
    if (block_reduce_participate && tid_in_block_reduction < np2 &&
        tid_in_block_reduction + np2 < block_reduction_size) {
      impl::reduceTuple(
          shared_buf,
          block_reduce_smem_offset,
          shared_buf,
          block_reduce_smem_offset + np2,
          reduction_op);
    }

    // Always need to sync while operating on shared memory
    block_sync::sync();

    // Reduce down until 2 values, leaving 2 values allows us to manually
    // perform the last reduction and avoid a syncthreads
    for (int factor = np2 / 2; factor > 1; factor >>= 1) {
      if (tid_in_block_reduction < factor && block_reduce_participate) {
        impl::reduceTuple(
            shared_buf,
            block_reduce_smem_offset,
            shared_buf,
            block_reduce_smem_offset + factor,
            reduction_op);
      }
      block_sync::sync();
    }

    // Accumulate that last valid result
    if (has_block_result) {
      copyTuple(block_result, shared_buf, block_reduce_smem_offset);
      if (block_reduction_size > 1) {
        impl::reduceTuple(
            block_result,
            0,
            shared_buf,
            block_reduce_smem_offset + 1,
            reduction_op);
      }
    }

    // ===== BLOCK REDUCTION CLEANUP =======
    if (!GRID_REDUCE) {
      // If no grid reduction, we don't have to continue. Either broadcast
      // back across the block or return the correct reduction
      if (has_block_result && write_pred) {
        impl::reduceTuple(block_result, 0, out, 0, reduction_op);
        out = block_result;
      }
      if (BROADCAST) {
        // No grid reduce, but need to broadcast, perform block broadcast
        if (has_block_result && write_pred) {
          // Put result back in shared memory, put in the first entry of the
          // reduction segment's buffer
          copyTuple(
              shared_buf,
              block_reduction_idx * block_reduction_size,
              block_result);
        }

        // Sync threads to make sure result is in smem
        block_sync::sync();
        // If the thread is participating, and is not attempting to write out
        // of bounds, return the broadcasted value.
        if (block_reduce_participate && write_pred) {
          copyTuple(
              out, shared_buf, block_reduction_idx * block_reduction_size);
        }
      }

      // Forward protect shared memory, don't want threads to continue to
      // another reduction/broadcast and pollute shared memory before the
      // reduction is completely finished.
      //
      // This could be avoided in some cases if we added thread syncs from
      // block reductions in the syncthread insertion pass.
      block_sync::sync();
      return;
    }
  }

  // -- START GRID REDUCTION -- //
  // Grid reductions are more challenging for two reasons: (1) the reduction
  // itself is 3D instead of 2D because we now have an iter domain space in
  // the grid dimension; (2) a tree reduction isn't performed; instead all
  // blocks will populate GMEM and one block will finish the grid reduction.

  // Size of the grid reduction. The block reduction has already been
  // performed, so it doesn't have to be taken into consideration.
  const auto grid_red_size = index_utils::
      maskedSize<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          gridDim);

  // Which ID in the reduction is this block. Threads can participate in
  // multiple grid reductions, but the block will have the same relative index
  // in those reductions
  const auto idx_in_grid_red = index_utils::
      maskedOffset<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if (PERSISTENT_REDUCTION && flip) {
    auto global_buffer_size =
        index_utils::
            maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
                gridDim) *
        index_utils::
            maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
                blockDim) *
        grid_red_size;
    global_work_buffer += global_buffer_size;
  }
  flip = !flip;
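  // Note on the ping-pong above: for persistent reductions the global work
  // buffer is allocated at twice global_buffer_size, and successive calls
  // alternate between the two halves ([0, size) when flip was false,
  // [size, 2 * size) when it was true), so the single grid sync below also
  // protects the half that the next iteration will overwrite.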

  // How many grid reductions have to be performed, in the grid dimension
  const auto num_block_iters = index_utils::
      maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(gridDim);

  // Which grid reduction does this block participate in, in the grid
  // dimension
  const auto block_red_idx_offset = index_utils::
      maskedOffset<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the block dimension
  const auto num_thread_iters = index_utils::
      maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          blockDim);

  // Which grid reduction does this thread participate in, in the block
  // dimension
  const auto thread_red_idx_offset = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // 3D buffer of reductions:
  // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
  // Offset into the work buffer
  const auto work_buf_offset =
      (idx_in_grid_red * num_block_iters + block_red_idx_offset) *
          num_thread_iters +
      thread_red_idx_offset;
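  // Illustrative example: if blockIdx.x is the only reduction dimension,
  // blockIdx.y the only grid iteration dimension, and threadIdx.x the only
  // block iteration dimension, the offset above evaluates to
  //   (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x,
  // i.e. the grid reduction index is the slowest-varying coordinate of the
  // work buffer.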

  // Don't read/write in temporary buffers if in a predicated dimension
  bool grid_reduce_participate = index_utils::
      maskedIsZero<isPred(X_BLOCK), isPred(Y_BLOCK), isPred(Z_BLOCK)>(blockIdx);

  if (grid_reduce_participate && block_reduce_participate) {
    if (has_block_result) {
      copyTuple(global_work_buffer, work_buf_offset, block_result);
    }
  }

  // -- GLOBAL BUFFER FILLED -- //

  bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if (grid_reduce_participate) {
    // Don't need to sync up blocks that are not participating in this
    // reduction
    grid_sync::sync<
        isReduce(X_BLOCK),
        isReduce(Y_BLOCK),
        isReduce(Z_BLOCK),
        PERSISTENT_REDUCTION>(
        global_sync_buffer[block_red_idx_offset], grid_red_size, last_block);
  }

  // -- START BLOCK CLEANUP -- //
  // All blocks perform the last cleanup, so every block, and every thread
)"
R"(
  // will have the final result

  // Initialize block result
  LocalTuple<Types...> last_block_result(init_val);

  if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) {
    // Use the last block to reduce all the values the blocks filled in.
    // Any thread that has been predicated or reduced can be used to do
    // this reduction, but no block that's associated with an iteration
    // domain can be used.

    // Start with non-block reduction

    // Index in the reduction segment
    int tid_in_block_reduction_2 = index_utils::maskedOffset<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx, blockDim);

    int block_reduction_size_2 = index_utils::maskedSize<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(blockDim);

    // 3D buffer of reductions:
    // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
    // Change the offset, we want to keep the last two dimensions, but the
    // first dimension is what we will reduce over
    const auto work_buf_offset_2 =
        block_red_idx_offset * num_thread_iters + thread_red_idx_offset;
    for (auto reduction_i = tid_in_block_reduction_2;
         reduction_i < grid_red_size;
         reduction_i += block_reduction_size_2) {
      impl::reduceTuple(
          last_block_result,
          0,
          global_work_buffer,
          work_buf_offset_2 +
              reduction_i * num_block_iters *
                  num_thread_iters, // Iterating over the outermost
                                    // dimension, so need to stride by the
                                    // total number of grid reductions. Could
                                    // come back and change it so this is the
                                    // contiguous dimension
          reduction_op);
    }

    // -- START LAST BLOCK - BLOCK REDUCTION -- //

    // Reduced so we have one value per thread, we need to further reduce any
    // dimension that is not an iter dimension

    // Which block reduction this thread is participating in
    int block_reduction_idx = index_utils::
        maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
            threadIdx, blockDim);

    // Offset in smem for this thread's result
    auto smem_offset =
        block_reduction_idx * block_reduction_size_2 + tid_in_block_reduction_2;

    // Similar as before, reduce down to nearest power of 2 so we can do a
    // tree reduction
    int np2 = 1 << (31 - __clz(min(block_reduction_size_2, grid_red_size)));

    // Thread values are initialized, so all can participate here
    if (tid_in_block_reduction_2 >= np2) {
      copyTuple(shared_buf, smem_offset, last_block_result);
    }

    block_sync::sync();

    if (tid_in_block_reduction_2 < np2 &&
        tid_in_block_reduction_2 + np2 <
            min(block_reduction_size_2, grid_red_size)) {
      impl::reduceTuple(
          last_block_result, 0, shared_buf, smem_offset + np2, reduction_op);
    }

    if (tid_in_block_reduction_2 < np2) {
      copyTuple(shared_buf, smem_offset, last_block_result);
    }

    // Always sync when communicating across smem
    block_sync::sync();

    // Reduce down to 2 values, last thread will do the final reduction and
    // can save a syncthreads this way
    for (int factor = np2 / 2; factor > 1; factor >>= 1) {
      if (tid_in_block_reduction_2 < factor) {
        impl::reduceTuple(
            shared_buf,
            smem_offset,
            shared_buf,
            smem_offset + factor,
            reduction_op);
      }
      block_sync::sync();
    }

    // Whether this thread in each block has the final result before
    // broadcasting it to all other threads in the block
    bool has_block_result_2 = index_utils::maskedIsZero<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx);
    // Do the last reduction, protected by the write predicate
    copyTuple(last_block_result, shared_buf, smem_offset);
    if (has_block_result && grid_reduce_participate) {
      impl::reduceTuple(last_block_result, 0, out, 0, reduction_op);
      if (min(block_reduction_size_2, grid_red_size) > 1) {
        impl::reduceTuple(
            last_block_result, 0, shared_buf, smem_offset + 1, reduction_op);
      }
    }
    if (grid_reduce_participate && PERSISTENT_REDUCTION) {
      // If persistent reduction, always broadcast reduced values
      copyTuple(shared_buf, smem_offset, last_block_result);
      block_sync::sync();
      if (write_pred && block_reduce_participate) {
        copyTuple(
            out, shared_buf, block_reduction_idx * block_reduction_size_2);
      }
      // For persistent kernels we double the global buffer allocation so we
      // don't need to protect those buffers every iteration preventing the
      // need of an additional grid_sync. Since we flip back and forth between
      // sections of the buffer, the one grid sync protects the other part of
      // the buffer.
    } else {
      if (grid_reduce_participate) {
        if (last_block && has_block_result && block_reduce_participate &&
            write_pred) {
          copyTuple(
              out, shared_buf, block_reduction_idx * block_reduction_size_2);
        }
      }
    }
    // Forward protect the smem used in this reduction
    block_sync::sync();
  }
}

//! Profiled version
template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <typename Func, typename... Types>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    reduce(
        RefTuple<Types...> out,
        const ConstRefTuple<Types...>& inp,
        VolatilePtrTuple<Types...> global_work_buffer,
        int64_t* global_sync_buffer, // Allocated as product of all
                                     // non-participating grid dimensions
        PtrTuple<Types...> shared_buf,
        bool read_pred, // Prevent reading from out of bounds memory
        bool write_pred, // Prevent from writing out of bounds
        const LocalTuple<Types...>& init_val,
        Func reduction_op,
        int64_t& cycles,
        int64_t& count) {
  int64_t start_counter = 0;

  if (isLastBlockInGrid() &&
      index_utils::maskedIsZero<true, true, true>(threadIdx)) {
    start_counter = readCycleCounter();
  }

  reduce(
      out,
      inp,
      global_work_buffer,
      global_sync_buffer,
      shared_buf,
      read_pred,
      write_pred,
      init_val,
      reduction_op);

  if (isLastBlockInGrid() &&
      index_utils::maskedIsZero<true, true, true>(threadIdx)) {
    cycles += readCycleCounter() - start_counter;
    ++count;
  }
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <typename... DataTypes, typename... Funcs, typename... BoolTypes>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    reduceGroup(
        RefTuple<DataTypes...> out,
        const ConstRefTuple<DataTypes...>& inp,
        VolatilePtrTuple<DataTypes...> global_work_buffer,
        const LocalTuple<DataTypes...>& init_val,
        int64_t* global_sync_buffer,
        void* shared_mem,
        const LocalTuple<BoolTypes...>& read_preds,
        const LocalTuple<BoolTypes...>& write_preds,
        Funcs... funcs) {
  static_assert(
      sizeof...(DataTypes) == sizeof...(Funcs),
      "Mismatched number of Tuple values and functions");
  static_assert(
      sizeof...(DataTypes) == sizeof...(BoolTypes),
      "Mismatched number of Tuple values and predicate values");

  // If no reduction needed, just return input
  if (!BLOCK_REDUCE && !GRID_REDUCE) {
    copyTupleIf(out, inp, read_preds && write_preds);
    return;
  }

  // Don't read/write in temporary buffers if in a predicated dimension
  const bool block_reduce_participate = index_utils::
      maskedIsZero<isPred(X_THREAD), isPred(Y_THREAD), isPred(Z_THREAD)>(
          threadIdx);

  // Only threads with id == 0 in the dimensions being reduced will
  // have a valid result
  const bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

  // Initial per-block reduction. Result is broadcast if specified
  // and this call is block reduction only.
  const auto block_result = reduceGroupBlock<!GRID_REDUCE && BROADCAST>(
      inp,
      init_val,
      shared_mem,
      read_preds,
      block_reduce_participate,
      funcs...);
  // If block reduction only, save to out and exit
  if (!GRID_REDUCE) {
    copyTupleIf(
        out,
        block_result,
        write_preds &&
            (block_reduce_participate && (BROADCAST || has_block_result)));

    // Need a block sync here as reduceGroupBlock does not
    // forward-protect the smem buffer. This block sync is not
    // necessary when a grid reduction follows since a block sync is
    // done just before the grid sync.
    block_sync::sync();
    return;
  }

  // -- START GRID REDUCTION -- //
  // Grid reductions are more challenging for two reasons: (1) the reduction
  // itself is 3D instead of 2D because we now have an iter domain space in
  // the grid dimension; (2) a tree reduction isn't performed; instead all
  // blocks will populate GMEM and one block will finish the grid reduction.

  // Size of the grid reduction. The block reduction has already been
  // performed, so it doesn't have to be taken into consideration.
  const auto grid_red_size = index_utils::
      maskedSize<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          gridDim);

  // Which ID in the reduction is this block. Threads can participate in
  // multiple grid reductions, but the block will have the same relative index
  // in those reductions
  const auto idx_in_grid_red = index_utils::
      maskedOffset<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the grid dimension
  const auto num_block_iters = index_utils::
      maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(gridDim);

  // Which grid reduction does this block participate in, in the grid
  // dimension
  const auto block_red_idx_offset = index_utils::
      maskedOffset<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
          blockIdx, gridDim);

  // How many grid reductions have to be performed, in the block dimension
  const auto num_thread_iters = index_utils::
      maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          blockDim);

  // Which grid reduction does this thread participate in, in the block
  // dimension
  const auto thread_red_idx_offset = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // 3D buffer of reductions:
  // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
  // Offset into the work buffer
  const auto work_buf_offset =
      (idx_in_grid_red * num_block_iters + block_red_idx_offset) *
          num_thread_iters +
      thread_red_idx_offset;

  // Don't read/write in temporary buffers if in a predicated dimension
  bool grid_reduce_participate = index_utils::
      maskedIsZero<isPred(X_BLOCK), isPred(Y_BLOCK), isPred(Z_BLOCK)>(blockIdx);

  if (PERSISTENT_REDUCTION && flip) {
    auto global_buffer_size =
        index_utils::
            maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(
                gridDim) *
        index_utils::
            maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
                blockDim) *
        grid_red_size;
    global_work_buffer += global_buffer_size;
  }
  flip = !flip;

  // Per-block partial reduction to global work buffer
  if (grid_reduce_participate && block_reduce_participate && has_block_result) {
    copyTuple(global_work_buffer, work_buf_offset, block_result);
  }

  // -- GLOBAL BUFFER FILLED -- //

  bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if (grid_reduce_participate) {
    // Don't need to sync up blocks that are not participating in this
    // reduction
    grid_sync::sync<
        isReduce(X_BLOCK),
        isReduce(Y_BLOCK),
        isReduce(Z_BLOCK),
        PERSISTENT_REDUCTION>(
        global_sync_buffer[block_red_idx_offset], grid_red_size, last_block);
  }

  // -- START BLOCK CLEANUP -- //
  reduceGroupLastBlock(
      out,
      global_work_buffer,
      init_val,
      shared_mem,
      block_red_idx_offset,
      num_thread_iters,
      num_block_iters,
      thread_red_idx_offset,
      grid_red_size,
      write_preds,
      block_reduce_participate,
      grid_reduce_participate,
      funcs...);

  // Forward protect the smem buffer
  block_sync::sync();
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <
    bool BLOCK_BROADCAST,
    typename... DataTypes,
    typename... Funcs,
    typename... BoolTypes>
__device__ __inline__ LocalTuple<DataTypes...> ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    reduceGroupBlock(
        const ConstRefTuple<DataTypes...>& inp,
        const LocalTuple<DataTypes...>& init_val,
        void* shared_mem,
        const LocalTuple<BoolTypes...>& read_preds,
        bool block_reduce_participate,
        Funcs... funcs) {
  const bool has_block_result = index_utils::
      maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx);

)"
R"(
  // Initialize block result
  LocalTuple<DataTypes...> block_result = init_val;

  copyTupleIf(block_result, inp, block_reduce_participate && read_preds);

  // Size of the block reduction segment, can be an int since it's limited
  // to number of threads
  const int block_reduction_size = index_utils::
      maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          blockDim);

  // Index in the reduction segment, can be an int since it's limited to
  // number of threads
  const int tid_in_block_reduction = index_utils::
      maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>(
          threadIdx, blockDim);

  // ID of the block reduction this thread is participating in
  //
  // If any of the parallel dimensions are predicated out, that means
  // they've already been reduced, so we only care about the first thread in
  // that dimension. Therefore don't expand the reduction_idx by that
  // dimension
  const int block_reduction_idx = index_utils::
      maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
          threadIdx, blockDim);

  // Do not protect the smem buffer as it's not always necessary.
  impl::blockReduceEach<
      BLOCK_BROADCAST,
      false,
      LocalTuple<DataTypes...>,
      Funcs...>(
      block_result,
      block_result,
      shared_mem,
      has_block_result,
      tid_in_block_reduction,
      block_reduction_size,
      block_reduction_size,
      block_reduction_idx,
      funcs...);

  return block_result;
}

template <
    int X_BLOCK,
    int Y_BLOCK,
    int Z_BLOCK,
    int X_THREAD,
    int Y_THREAD,
    int Z_THREAD,
    bool PERSISTENT_REDUCTION,
    bool BROADCAST>
template <typename... DataTypes, typename... Funcs, typename... BoolTypes>
__device__ __inline__ void ParallelReduce<
    X_BLOCK,
    Y_BLOCK,
    Z_BLOCK,
    X_THREAD,
    Y_THREAD,
    Z_THREAD,
    PERSISTENT_REDUCTION,
    BROADCAST>::
    reduceGroupLastBlock(
        RefTuple<DataTypes...>& out,
        const VolatilePtrTuple<DataTypes...>& global_work_buffer,
        const LocalTuple<DataTypes...>& init_val,
        void* shared_mem,
        nvfuser_index_t block_red_idx_offset,
        nvfuser_index_t num_thread_iters,
        nvfuser_index_t num_block_iters,
        nvfuser_index_t thread_red_idx_offset,
        nvfuser_index_t grid_red_size,
        const LocalTuple<BoolTypes...>& write_preds,
        bool block_reduce_participate,
        bool grid_reduce_participate,
        Funcs... reduction_ops) {
  // Initialize block result
  LocalTuple<DataTypes...> last_block_result(init_val);

  const bool last_block = index_utils::
      maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>(
          blockIdx, gridDim);

  if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) {
    // Use the last block to reduce all the values the blocks filled in.
    // Any thread that has been predicated or reduced can be used to do
    // this reduction, but no block that's associated with an iteration
    // domain can be used.

    // Start with non-block reduction

    // Index in the reduction segment
    int tid_in_block_reduction = index_utils::maskedOffset<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx, blockDim);

    int block_reduction_size = index_utils::maskedSize<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(blockDim);

    bool has_block_result = index_utils::maskedIsZero<
        activeNotIter(X_THREAD),
        activeNotIter(Y_THREAD),
        activeNotIter(Z_THREAD)>(threadIdx);

    // 3D buffer of reductions:
    // [reduction_offset(grid), iter_offset(grid), iter_offset(block)]
    // Change the offset, we want to keep the last two dimensions, but the
    // first dimension is what we will reduce over
    const auto work_buf_offset =
        block_red_idx_offset * num_thread_iters + thread_red_idx_offset;
    for (auto reduction_i = tid_in_block_reduction; reduction_i < grid_red_size;
         reduction_i += block_reduction_size) {
      impl::reduceEach(
          last_block_result,
          0,
          global_work_buffer,
          work_buf_offset +
              reduction_i * num_block_iters *
                  num_thread_iters, // Iterating over the outermost
                                    // dimension, so need to stride by the
                                    // total number of grid reductions. Could
                                    // come back and change it so this is the
                                    // contiguous dimension
          reduction_ops...);
    }

    // Which block reduction this thread is participating in
    int block_reduction_idx = index_utils::
        maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>(
            threadIdx, blockDim);

    impl::blockReduceEach<BROADCAST, false, LocalTuple<DataTypes...>, Funcs...>(
        last_block_result,
        last_block_result,
        shared_mem,
        has_block_result,
        tid_in_block_reduction,
        block_reduction_size,
        min(grid_red_size, block_reduction_size),
        block_reduction_idx,
        reduction_ops...);

    copyTupleIf(
        out,
        last_block_result,
        write_preds &&
            (block_reduce_participate && (BROADCAST || has_block_result)));
  }
}

} // namespace fused_reduction
)";

} // namespace nvfuser_resources