1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/fused_reduction.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
4 | namespace nvfuser_resources { |
5 | |
6 | constexpr const char* fused_reduction_cu = R"( |
7 | namespace fused_reduction { |
8 | |
9 | namespace impl { |
10 | |
11 | //! Let f_i denote the i-th function of the binary function |
12 | //! parameters. The function is called as: f_i(x, y). |
13 | template <int i, typename DataType, typename Func, typename... Funcs> |
14 | struct FuncSelector { |
15 | static __device__ void call( |
16 | DataType& x, |
17 | const DataType y, |
18 | Func f, |
19 | Funcs... funcs) { |
20 | // Here, i is guaranteed to be larger than 0 as there's a |
21 | // specialization for i == 0 below. Recursively call FuncSelector |
22 | // by dropping f and decrementing i. |
23 | FuncSelector<i - 1, DataType, Funcs...>::call(x, y, funcs...); |
24 | } |
25 | }; |
26 | |
27 | //! Specialization of FuncSelector when i == 0, so f_i is f. |
28 | template <typename DataType, typename Func, typename... Funcs> |
29 | struct FuncSelector<0, DataType, Func, Funcs...> { |
30 | static __device__ void call( |
31 | DataType& x, |
32 | const DataType y, |
33 | Func f, |
34 | Funcs... funcs) { |
35 | f(x, y); |
36 | } |
37 | }; |
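// Example (hypothetical functors AddOp and MulOp, shown for illustration
// only): a call such as
//   FuncSelector<1, float, AddOp, MulOp>::call(x, y, add, mul);
// recurses through the primary template, dropping add and decrementing the
// index, so the i == 0 specialization finally invokes mul(x, y), i.e. the
// function at index 1.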
38 | |
39 | //! Call each of the first i+1 functions with the first i+1 values of |
40 | //! tuples. Here, i is guaranteed to be larger than -1 as there's a |
41 | //! specialization for i == -1. |
42 | template <int i, typename TupleType0, typename TupleType1, typename... Funcs> |
43 | struct FuncForEach { |
44 | static __device__ void call( |
45 | TupleType0& val0, |
46 | nvfuser_index_t offset0, |
47 | const TupleType1& val1, |
48 | nvfuser_index_t offset1, |
49 | Funcs... funcs) { |
50 | static_assert( |
51 | IsSameType< |
52 | typename TupleType0::template ValType<i>, |
53 | typename TupleType1::template ValType<i>>::value, |
54 | "Invalid tuple types"); |
55 | // Process the first i functions first. |
56 | FuncForEach<i - 1, TupleType0, TupleType1, Funcs...>::call( |
57 | val0, offset0, val1, offset1, funcs...); |
58 | // Call the i+1-th function |
59 | FuncSelector<i, typename TupleType0::template ValType<i>, Funcs...>::call( |
60 | val0.val<i>(offset0), val1.val<i>(offset1), funcs...); |
61 | } |
62 | }; |
63 | |
64 | //! Specialization of FuncForEach when i == -1, which means no |
65 | //! function to call. Just for stopping the recursive pattern here. |
66 | template <typename TupleType0, typename TupleType1, typename... Funcs> |
67 | struct FuncForEach<-1, TupleType0, TupleType1, Funcs...> { |
68 | static __device__ void call( |
69 | TupleType0& val0, |
70 | nvfuser_index_t offset0, |
71 | const TupleType1& val1, |
72 | nvfuser_index_t offset1, |
73 | Funcs... funcs) {} |
74 | }; |
75 | |
76 | //! Reduce one value of a tuple using one of the reduction ops. The |
77 | //! value at val_idx is reduced by the function at func_idx. |
78 | template < |
79 | int func_idx, |
80 | int val_idx, |
81 | typename TupleType0, |
82 | typename TupleType1, |
83 | typename... Funcs> |
84 | __inline__ __device__ static void reduceVal( |
85 | TupleType0& val0, |
86 | nvfuser_index_t offset0, |
87 | const TupleType1& val1, |
88 | nvfuser_index_t offset1, |
89 | Funcs... reduction_ops) { |
90 | static_assert( |
91 | IsSameType< |
92 | typename TupleType0::template ValType<val_idx>, |
93 | typename TupleType1::template ValType<val_idx>>::value, |
94 | "Invalid tuple types"); |
95 | FuncSelector< |
96 | func_idx, |
97 | typename TupleType0::template ValType<val_idx>, |
98 | Funcs...>:: |
99 | call( |
100 | val0.val<val_idx>(offset0), |
101 | val1.val<val_idx>(offset1), |
102 | reduction_ops...); |
103 | } |
104 | |
105 | //! Accumulate each value of a given pair of tuples using its corresponding |
106 | //! function. Let f_i denote the i-th reduction function. f_i is called as: |
107 | //! f_i(val0.val<i>(offset0), val1.val<i>(offset1)). |
108 | template <typename TupleType0, typename TupleType1, typename... Funcs> |
109 | __inline__ __device__ static void reduceEach( |
110 | TupleType0& val0, |
111 | nvfuser_index_t offset0, |
112 | const TupleType1& val1, |
113 | nvfuser_index_t offset1, |
114 | Funcs... reduction_ops) { |
115 | constexpr int num_funcs = sizeof...(reduction_ops); |
116 | FuncForEach<num_funcs - 1, TupleType0, TupleType1, Funcs...>::call( |
117 | val0, offset0, val1, offset1, reduction_ops...); |
118 | } |
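// Example (hypothetical binary functors f0 and f1, for illustration only):
// for two-value tuples, a call such as
//   impl::reduceEach(val0, offset0, val1, offset1, f0, f1);
// expands through FuncForEach and FuncSelector into
//   f0(val0.val<0>(offset0), val1.val<0>(offset1));
//   f1(val0.val<1>(offset0), val1.val<1>(offset1));
// i.e. each tuple value is reduced independently by its own functor.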
119 | |
120 | template <typename TupleType0, typename TupleType1, typename Func, int num_vals> |
121 | struct TupleReduce {}; |
122 | |
123 | template <typename TupleType0, typename TupleType1, typename Func> |
124 | struct TupleReduce<TupleType0, TupleType1, Func, 1> { |
125 | __inline__ __device__ static void reduce( |
126 | TupleType0& val0, |
127 | nvfuser_index_t offset0, |
128 | const TupleType1& val1, |
129 | nvfuser_index_t offset1, |
130 | Func reduction_op) { |
131 | static_assert( |
132 | IsSameType< |
133 | typename TupleType0::ValTypes, |
134 | typename TupleType1::ValTypes>::value, |
135 | "Invalid value types"); |
136 | reduction_op(val0.val<0>(offset0), val1.val<0>(offset1)); |
137 | } |
138 | }; |
139 | |
140 | template <typename TupleType0, typename TupleType1, typename Func> |
141 | struct TupleReduce<TupleType0, TupleType1, Func, 2> { |
142 | __inline__ __device__ static void reduce( |
143 | TupleType0& val0, |
144 | nvfuser_index_t offset0, |
145 | const TupleType1& val1, |
146 | nvfuser_index_t offset1, |
147 | Func reduction_op) { |
148 | static_assert( |
149 | IsSameType< |
150 | typename TupleType0::ValTypes, |
151 | typename TupleType1::ValTypes>::value, |
152 | "Invalid value types"); |
153 | reduction_op( |
154 | val0.val<0>(offset0), |
155 | val0.val<1>(offset0), |
156 | val1.val<0>(offset1), |
157 | val1.val<1>(offset1)); |
158 | } |
159 | }; |
160 | |
161 | template <typename TupleType0, typename TupleType1, typename Func> |
162 | struct TupleReduce<TupleType0, TupleType1, Func, 3> { |
163 | __inline__ __device__ static void reduce( |
164 | TupleType0& val0, |
165 | nvfuser_index_t offset0, |
166 | const TupleType1& val1, |
167 | nvfuser_index_t offset1, |
168 | Func reduction_op) { |
169 | static_assert( |
170 | IsSameType< |
171 | typename TupleType0::ValTypes, |
172 | typename TupleType1::ValTypes>::value, |
173 | "Invalid value types"); |
174 | reduction_op( |
175 | val0.val<0>(offset0), |
176 | val0.val<1>(offset0), |
177 | val0.val<2>(offset0), |
178 | val1.val<0>(offset1), |
179 | val1.val<1>(offset1), |
180 | val1.val<2>(offset1)); |
181 | } |
182 | }; |
183 | |
184 | //! Reduce all values of a tuple together. The reduction function must |
185 | //! have the same number of inputs as the number of values of each tuple. |
186 | template <typename TupleType0, typename TupleType1, typename Func> |
187 | __inline__ __device__ void reduceTuple( |
188 | TupleType0& val0, |
189 | nvfuser_index_t offset0, |
190 | const TupleType1& val1, |
191 | nvfuser_index_t offset1, |
192 | Func reduction_op) { |
193 | static_assert( |
194 | TupleType0::num_vals == TupleType1::num_vals, "Invalid number of values"); |
195 | TupleReduce<TupleType0, TupleType1, Func, TupleType0::num_vals>::reduce( |
196 | val0, offset0, val1, offset1, reduction_op); |
197 | } |
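// Example (for illustration only): with three-value tuples, as in a Welford
// triplet of (avg, M2, N), reduceTuple forwards all six values to the single
// functor in one call:
//   reduction_op(
//       val0.val<0>(offset0), val0.val<1>(offset0), val0.val<2>(offset0),
//       val1.val<0>(offset1), val1.val<1>(offset1), val1.val<2>(offset1));
// unlike reduceEach, which reduces each value independently.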
198 | |
199 | // Reduces all of the first (idx+1) values by a thread block |
200 | template < |
201 | int idx, |
202 | bool BROADCAST, |
203 | bool FORWARD_PROTECT_SMEM, |
204 | typename LocalTupleT, |
205 | typename... Funcs> |
206 | struct BlockReduceEach { |
207 | __inline__ __device__ static void reduce( |
208 | LocalTupleT& block_result, |
209 | const LocalTupleT& partial_result, |
210 | void* shared_mem, |
211 | bool has_block_result, |
212 | int tid_in_reduction, |
213 | int num_threads_per_reduction, |
214 | int num_elements_per_reduction, |
215 | int reduction_idx, |
216 | Funcs... funcs) { |
217 | // Finish the reduction of each tuple value with a smaller offset |
218 | BlockReduceEach<idx - 1, BROADCAST, true, LocalTupleT, Funcs...>::reduce( |
219 | block_result, |
220 | partial_result, |
221 | shared_mem, |
222 | has_block_result, |
223 | tid_in_reduction, |
224 | num_threads_per_reduction, |
225 | num_elements_per_reduction, |
226 | reduction_idx, |
227 | funcs...); |
228 | |
229 | if (num_elements_per_reduction == 1) { |
230 | if (has_block_result) { |
231 | block_result.val<idx>(0) = partial_result.val<idx>(0); |
232 | } |
233 | return; |
234 | } |
235 | |
236 | using DataType = typename LocalTupleT::template ValType<idx>; |
237 | |
238 | PtrTuple<DataType> shared_buf(static_cast<DataType*>(shared_mem)); |
239 | |
240 | LocalTuple<DataType> block_result_i(partial_result.val<idx>(0)); |
241 | |
242 | const auto smem_offset = |
243 | reduction_idx * num_threads_per_reduction + tid_in_reduction; |
244 | |
245 | const int np2 = 1 << (31 - __clz(num_elements_per_reduction)); |
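// For example, num_elements_per_reduction == 96 gives __clz(96) == 25 and
// np2 == 1 << 6 == 64, the largest power of 2 not exceeding the segment size.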
246 | |
247 | // Thread values are initialized, so all threads can participate here |
248 | if (tid_in_reduction >= np2) { |
249 | copyTuple(shared_buf, smem_offset, block_result_i); |
250 | } |
251 | |
252 | block_sync::sync(); |
253 | |
254 | if (tid_in_reduction < np2 && |
255 | tid_in_reduction + np2 < num_elements_per_reduction) { |
256 | impl::reduceVal<idx, 0>( |
257 | block_result_i, 0, shared_buf, smem_offset + np2, funcs...); |
258 | } |
259 | |
260 | if (tid_in_reduction < np2) { |
261 | copyTuple(shared_buf, smem_offset, block_result_i); |
262 | } |
263 | |
264 | // Always sync when communicating across smem |
265 | block_sync::sync(); |
266 | |
267 | // Reduce down to 2 values; the last thread will do the final reduction and |
268 | // can save a syncthreads this way |
269 | for (int factor = np2 / 2; factor > 1; factor >>= 1) { |
270 | if (tid_in_reduction < factor) { |
271 | impl::reduceVal<idx, 0>( |
272 | shared_buf, |
273 | smem_offset, |
274 | shared_buf, |
275 | smem_offset + factor, |
276 | funcs...); |
277 | } |
278 | block_sync::sync(); |
279 | } |
280 | |
281 | copyTuple(block_result_i, shared_buf, smem_offset); |
282 | |
283 | // Do the last reduction |
284 | if (has_block_result) { |
285 | impl::reduceVal<idx, 0>( |
286 | block_result_i, 0, shared_buf, smem_offset + 1, funcs...); |
287 | } |
288 | |
289 | if (BROADCAST) { |
290 | if (has_block_result) { |
291 | // Put result back in shared memory, put in the first entry of the |
292 | // reduction segment's buffer |
293 | copyTuple( |
294 | shared_buf, |
295 | reduction_idx * num_threads_per_reduction, |
296 | block_result_i); |
297 | } |
298 | |
299 | // Sync threads to make sure result is in smem |
300 | block_sync::sync(); |
301 | |
302 | copyTuple( |
303 | block_result_i, |
304 | shared_buf, |
305 | reduction_idx * num_threads_per_reduction); |
306 | } |
307 | |
308 | block_result.val<idx>(0) = block_result_i.val<0>(0); |
309 | |
310 | if (FORWARD_PROTECT_SMEM) { |
311 | block_sync::sync(); |
312 | } |
313 | } |
314 | }; |
315 | |
316 | // Specialization for idx == -1, i.e., no value to reduce. |
317 | template < |
318 | bool BROADCAST, |
319 | bool FORWARD_PROTECT_SMEM, |
320 | typename LocalTupleT, |
321 | typename... Funcs> |
322 | struct BlockReduceEach< |
323 | -1, |
324 | BROADCAST, |
325 | FORWARD_PROTECT_SMEM, |
326 | LocalTupleT, |
327 | Funcs...> { |
328 | __inline__ __device__ static void reduce( |
329 | LocalTupleT& block_result, |
330 | const LocalTupleT& partial_result, |
331 | void* shared_mem, |
332 | bool has_block_result, |
333 | int tid_in_reduction, |
334 | int num_threads_per_reduction, |
335 | int num_elements_per_reduction, |
336 | int reduction_idx, |
337 | Funcs... funcs) {} |
338 | }; |
339 | |
340 | //! Reduce each value of a tuple by a thread block. |
341 | //! |
342 | //! The final result is broadcast when BROADCAST is true. |
343 | //! |
344 | //! \param block_result result of the block reduction |
345 | //! \param partial_result Per-thread input tuple |
346 | //! \param shared_mem |
347 | //! \param has_block_result |
348 | //! \param tid_in_reduction |
349 | //! \param num_threads_per_reduction |
350 | //! \param num_elements_per_reduction |
351 | //! \param reduction_idx |
352 | //! \param reduction_ops |
353 | template < |
354 | bool BROADCAST, |
355 | bool FORWARD_PROTECT_SMEM, |
356 | typename LocalTupleT, |
357 | typename... Funcs> |
358 | __inline__ __device__ void blockReduceEach( |
359 | LocalTupleT& block_result, |
360 | const LocalTupleT& partial_result, |
361 | void* shared_mem, |
362 | bool has_block_result, |
363 | int tid_in_reduction, |
364 | int num_threads_per_reduction, |
365 | int num_elements_per_reduction, |
366 | int reduction_idx, |
367 | Funcs... reduction_ops) { |
368 | BlockReduceEach< |
369 | LocalTupleT::num_vals - 1, |
370 | BROADCAST, |
371 | FORWARD_PROTECT_SMEM, |
372 | LocalTupleT, |
373 | Funcs...>:: |
374 | reduce( |
375 | block_result, |
376 | partial_result, |
377 | shared_mem, |
378 | has_block_result, |
379 | tid_in_reduction, |
380 | num_threads_per_reduction, |
381 | num_elements_per_reduction, |
382 | reduction_idx, |
383 | reduction_ops...); |
384 | } |
385 | |
386 | } // namespace impl |
387 | |
388 | // We have 6 dimensions, 3 in the grid and 3 in the block. |
389 | // Each can be in one of 3 states (see the example instantiation below): |
390 | // Reduction Domain - TEMPLATE STATE 0 |
391 | // - Participating in the reduction, has values coming in, one value coming |
392 | // out across the dimension |
393 | // Iteration Domain - TEMPLATE STATE 1 |
394 | // - Not participating in the reduction, has values across the dimension after |
395 | // the reduction |
396 | // Collapsed Domain - TEMPLATE STATE 2 |
397 | // - Previously reduced, doesn't need to be reduced on that dimension, doesn't |
398 | // have values across that dimension |
399 | constexpr __device__ bool isReduce(int STATE) { |
400 | return STATE == 0; |
401 | } |
402 | |
403 | constexpr __device__ bool isIter(int STATE) { |
404 | return STATE == 1; |
405 | } |
406 | |
407 | constexpr __device__ bool isPred(int STATE) { |
408 | return STATE == 2; |
409 | } |
410 | |
411 | constexpr __device__ bool inactive(int STATE) { |
412 | return STATE == 3; |
413 | } |
414 | |
415 | constexpr __device__ bool activeNotIter(int STATE) { |
416 | return STATE != 3 && STATE != 1; |
417 | } |
418 | |
419 | constexpr __device__ bool isReduceOrIter(int STATE) { |
420 | return isReduce(STATE) || isIter(STATE); |
421 | } |
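// Example (hypothetical instantiation, for illustration only): a kernel that
// reduces across threadIdx.x and blockIdx.x while iterating over threadIdx.y
// could instantiate the ParallelReduce class defined below as
//   ParallelReduce<0 /*X_BLOCK: reduce*/, 2 /*Y_BLOCK: collapsed*/,
//                  2 /*Z_BLOCK: collapsed*/, 0 /*X_THREAD: reduce*/,
//                  1 /*Y_THREAD: iter*/, 2 /*Z_THREAD: collapsed*/,
//                  false /*PERSISTENT_REDUCTION*/, false /*BROADCAST*/>
// so that isReduce(X_BLOCK) and isReduce(X_THREAD) are true, isIter(Y_THREAD)
// is true, and the collapsed dimensions are skipped for indexing and
// predicated off for reads and writes.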
422 | |
423 | // When generating an index into the reduction, we have to stride by iteration |
424 | // domains and reduction domains. Collapsed domains we can ignore, but we need |
425 | // to make sure they never read or write (need to be predicated to correct |
426 | // participation). |
427 | |
428 | // All inclusive reduction with option to re-broadcast. This reduction class |
429 | // does not use predication of parallelization in the read or write predicates. |
430 | // Instead there are 3 states each dimension of parallelization can have, |
431 | // described above. Predication, indexing, and reduction will be done based on |
432 | // this information. |
433 | template < |
434 | int X_BLOCK, |
435 | int Y_BLOCK, |
436 | int Z_BLOCK, |
437 | int X_THREAD, |
438 | int Y_THREAD, |
439 | int Z_THREAD, |
440 | bool PERSISTENT_REDUCTION, |
441 | bool BROADCAST> |
442 | class ParallelReduce { |
443 | static_assert( |
444 | !BROADCAST || PERSISTENT_REDUCTION, |
445 | "Broadcast requires persistent reduction"); |
446 | |
447 | static constexpr bool BLOCK_REDUCE = |
448 | isReduce(X_THREAD) || isReduce(Y_THREAD) || isReduce(Z_THREAD); |
449 | |
450 | static constexpr bool GRID_REDUCE = |
451 | isReduce(X_BLOCK) || isReduce(Y_BLOCK) || isReduce(Z_BLOCK); |
452 | |
453 | // ping-pong between global buffers to avoid a second sync |
454 | bool flip = false; |
455 | |
456 | public: |
457 | __device__ ParallelReduce() {} |
458 | |
459 | // reduceGroup does not support Welford-style reductions that reduce |
460 | // all values of a tuple together, so this is the only entry point |
461 | // for Welford for now. |
462 | template <typename Func, typename... Types> |
463 | __device__ __inline__ void reduce( |
464 | RefTuple<Types...> out, |
465 | const ConstRefTuple<Types...>& inp, |
466 | VolatilePtrTuple<Types...> global_work_buffer, |
467 | int64_t* global_sync_buffer, // Allocated as the product of all |
468 | // non-participating grid dimensions |
469 | PtrTuple<Types...> shared_buf, |
470 | bool read_pred, // Prevent reading from out of bounds memory |
471 | bool write_pred, // Prevent from writing out of bounds |
472 | const LocalTuple<Types...>& init_val, |
473 | Func reduction_op); |
474 | |
475 | //! Profiled version |
476 | template <typename Func, typename... Types> |
477 | __device__ __inline__ void reduce( |
478 | RefTuple<Types...> out, |
479 | const ConstRefTuple<Types...>& inp, |
480 | VolatilePtrTuple<Types...> global_work_buffer, |
481 | int64_t* global_sync_buffer, // Allocated as the product of all |
482 | // non-participating grid dimensions |
483 | PtrTuple<Types...> shared_buf, |
484 | bool read_pred, // Prevent reading from out of bounds memory |
485 | bool write_pred, // Prevent from writing out of bounds |
486 | const LocalTuple<Types...>& init_val, |
487 | Func reduction_op, |
488 | int64_t& cycles, |
489 | int64_t& count); |
490 | |
491 | //! Each value of a tuple is independently reduced by the |
492 | //! corresponding reduction op. Thus, Welford-like reductions are |
493 | //! not supported by this interface. |
494 | //! |
495 | //! Note that out is purely used as the output parameter, and its |
496 | //! initial value is not used but just overwritten. Since grid |
497 | )" |
498 | R"( |
499 | //! reductions do not allow serial reduction IterDomains, there is |
500 | //! no need to accumulate into the out parameter. |
501 | template <typename... DataTypes, typename... Funcs, typename... BoolTypes> |
502 | __device__ __inline__ void reduceGroup( |
503 | RefTuple<DataTypes...> out, |
504 | const ConstRefTuple<DataTypes...>& inp, |
505 | VolatilePtrTuple<DataTypes...> global_work_buffer, |
506 | const LocalTuple<DataTypes...>& init_val, |
507 | int64_t* global_sync_buffer, |
508 | void* shared_mem, |
509 | const LocalTuple<BoolTypes...>& read_preds, |
510 | const LocalTuple<BoolTypes...>& write_preds, |
511 | Funcs... funcs); |
512 | |
513 | //! Profiled version |
514 | template <typename... DataTypes, typename... Funcs, typename... BoolTypes> |
515 | __device__ __inline__ void reduceGroup( |
516 | RefTuple<DataTypes...> out, |
517 | const ConstRefTuple<DataTypes...>& inp, |
518 | VolatilePtrTuple<DataTypes...> global_work_buffer, |
519 | const LocalTuple<DataTypes...>& init_val, |
520 | int64_t* global_sync_buffer, |
521 | void* shared_mem, |
522 | const LocalTuple<BoolTypes...>& read_preds, |
523 | const LocalTuple<BoolTypes...>& write_preds, |
524 | int64_t& cycles, |
525 | int64_t& count, |
526 | Funcs... funcs); |
527 | |
528 | template <int NumArgs, typename DataType, typename IndexType> |
529 | __device__ __inline__ void welfordGroup( |
530 | typename MakeRefTuple<NumArgs, DataType>::type out_avg, |
531 | typename MakeRefTuple<NumArgs, DataType>::type out_var, |
532 | typename MakeRefTuple<NumArgs, IndexType>::type out_N, |
533 | const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_avg, |
534 | const typename MakeConstRefTuple<NumArgs, DataType>::type& inp_var, |
535 | const typename MakeConstRefTuple<NumArgs, IndexType>::type& inp_N, |
536 | const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, |
537 | const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, |
538 | const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, |
539 | typename MakeVolatilePtrTuple<NumArgs, DataType>::type |
540 | global_work_buffer_avg, |
541 | typename MakeVolatilePtrTuple<NumArgs, DataType>::type |
542 | global_work_buffer_var, |
543 | typename MakeVolatilePtrTuple<NumArgs, IndexType>::type |
544 | global_work_buffer_N, |
545 | int64_t* global_sync_buffer, |
546 | PtrTuple<DataType, DataType, IndexType> shared_buf, |
547 | const typename MakeLocalTuple<NumArgs, bool>::type& read_preds, |
548 | const typename MakeLocalTuple<NumArgs, bool>::type& write_preds); |
549 | |
550 | private: |
551 | __device__ static bool isLastBlockInGrid() { |
552 | return index_utils::maskedIsLast< |
553 | isReduceOrIter(X_BLOCK), |
554 | isReduceOrIter(Y_BLOCK), |
555 | isReduceOrIter(Z_BLOCK)>(blockIdx, gridDim) && |
556 | index_utils::maskedIsZero< |
557 | !isReduceOrIter(X_BLOCK), |
558 | !isReduceOrIter(Y_BLOCK), |
559 | !isReduceOrIter(Z_BLOCK)>(blockIdx); |
560 | } |
561 | |
562 | //! Initial per-CTA reduction of each value of a tuple. Each value |
563 | //! is reduced individually, so the shared memory buffer just needs |
564 | //! to be large enough for each value. NOTE that the smem buffer is |
565 | //! not forward protected. |
566 | template < |
567 | bool BLOCK_BROADCAST, |
568 | typename... DataTypes, |
569 | typename... Funcs, |
570 | typename... BoolTypes> |
571 | __device__ __inline__ static LocalTuple<DataTypes...> reduceGroupBlock( |
572 | const ConstRefTuple<DataTypes...>& inp, |
573 | const LocalTuple<DataTypes...>& init_val, |
574 | void* shared_mem, |
575 | const LocalTuple<BoolTypes...>& read_preds, |
576 | bool block_reduce_participate, |
577 | Funcs... funcs); |
578 | |
579 | //! Final reduction of partial results. Done by all blocks |
580 | //! redundantly when BROADCAST is true, or just one block otherwise. |
581 | //! The smem buffer is assumed synchronized when it is passed in, |
582 | //! but it isn't synchronized when returning from this function. |
583 | template <typename... DataTypes, typename... Funcs, typename... BoolTypes> |
584 | __device__ __inline__ static void reduceGroupLastBlock( |
585 | RefTuple<DataTypes...>& out, |
586 | const VolatilePtrTuple<DataTypes...>& global_work_buffer, |
587 | const LocalTuple<DataTypes...>& init_val, |
588 | void* shared_mem, |
589 | nvfuser_index_t block_red_idx_offset, |
590 | nvfuser_index_t num_thread_iters, |
591 | nvfuser_index_t num_block_iters, |
592 | nvfuser_index_t thread_red_idx_offset, |
593 | nvfuser_index_t grid_red_size, |
594 | const LocalTuple<BoolTypes...>& write_preds, |
595 | bool block_reduce_participate, |
596 | bool grid_reduce_participate, |
597 | Funcs... reduction_ops); |
598 | |
599 | //! Welford version of reduceGroupBlock |
600 | template < |
601 | bool BLOCK_BROADCAST, |
602 | int NumVals, |
603 | typename DataType, |
604 | typename IndexType> |
605 | __device__ __inline__ static void welfordGroupBlock( |
606 | LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, |
607 | const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp, |
608 | PtrTuple<DataType, DataType, IndexType> shared_buf, |
609 | const typename MakeLocalTuple<NumVals, bool>::type& read_preds, |
610 | bool block_reduce_participate); |
611 | |
612 | //! Welford version of reduceGroupLastBlock |
613 | template <int NumVals, typename DataType, typename IndexType> |
614 | __device__ __inline__ static void welfordGroupLastBlock( |
615 | RefWelfordTripletTuple<NumVals, DataType, IndexType>& out, |
616 | const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>& |
617 | global_work_buffer, |
618 | const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val, |
619 | PtrTuple<DataType, DataType, IndexType> shared_buf, |
620 | nvfuser_index_t block_red_idx_offset, |
621 | nvfuser_index_t num_thread_iters, |
622 | nvfuser_index_t num_block_iters, |
623 | nvfuser_index_t thread_red_idx_offset, |
624 | nvfuser_index_t grid_red_size, |
625 | const typename MakeLocalTuple<NumVals, bool>::type& write_preds, |
626 | bool block_reduce_participate, |
627 | bool grid_reduce_participate); |
628 | |
629 | // End Parallel reduce class |
630 | }; |
631 | |
632 | template < |
633 | int X_BLOCK, |
634 | int Y_BLOCK, |
635 | int Z_BLOCK, |
636 | int X_THREAD, |
637 | int Y_THREAD, |
638 | int Z_THREAD, |
639 | bool PERSISTENT_REDUCTION, |
640 | bool BROADCAST> |
641 | template <typename Func, typename... Types> |
642 | __device__ __inline__ void ParallelReduce< |
643 | X_BLOCK, |
644 | Y_BLOCK, |
645 | Z_BLOCK, |
646 | X_THREAD, |
647 | Y_THREAD, |
648 | Z_THREAD, |
649 | PERSISTENT_REDUCTION, |
650 | BROADCAST>:: |
651 | reduce( |
652 | RefTuple<Types...> out, |
653 | const ConstRefTuple<Types...>& inp, |
654 | VolatilePtrTuple<Types...> global_work_buffer, |
655 | int64_t* global_sync_buffer, // Allocated as the product of all |
656 | // non-participating grid dimensions |
657 | PtrTuple<Types...> shared_buf, |
658 | bool read_pred, // Prevent reading from out of bounds memory |
659 | bool write_pred, // Prevent from writing out of bounds |
660 | const LocalTuple<Types...>& init_val, |
661 | Func reduction_op) { |
662 | // If no reduction needed, just return input |
663 | if (!BLOCK_REDUCE && !GRID_REDUCE) { |
664 | if (read_pred && write_pred) { |
665 | out = inp; |
666 | } |
667 | return; |
668 | } |
669 | |
670 | // Don't read/write in temporary buffers if in a predicated dimension |
671 | bool block_reduce_participate = index_utils:: |
672 | maskedIsZero<isPred(X_THREAD), isPred(Y_THREAD), isPred(Z_THREAD)>( |
673 | threadIdx); |
674 | |
675 | // Initialize block result |
676 | LocalTuple<Types...> block_result = init_val; |
677 | |
678 | // Grab input data if participating in the reduction, set to block_result in |
679 | // the case there is no block reduction |
680 | if (block_reduce_participate && read_pred) { |
681 | block_result = inp; |
682 | } |
683 | |
684 | // Only threads with id == 0 in the dimensions being reduced will |
685 | // have a valid result |
686 | bool has_block_result = index_utils:: |
687 | maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( |
688 | threadIdx); |
689 | |
690 | if (BLOCK_REDUCE) { |
691 | // -- START BLOCK REDUCTION -- // |
692 | |
693 | // Size of the block reduction segment, can be an int since it's limited |
694 | // to number of threads |
695 | int block_reduction_size = index_utils:: |
696 | maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( |
697 | blockDim); |
698 | |
699 | // Index in the reduction segment, can be an int since it's limited to |
700 | // number of threads |
701 | int tid_in_block_reduction = index_utils::maskedOffset< |
702 | isReduce(X_THREAD), |
703 | isReduce(Y_THREAD), |
704 | isReduce(Z_THREAD)>(threadIdx, blockDim); |
705 | |
706 | // ID of the block reduction this thread is participating in |
707 | // |
708 | // If any of the parallel dimensions are predicated out, that means |
709 | // they've already been reduced, so we only care about the first thread in |
710 | // that dimension. Therefore don't expand the reduction_idx by that |
711 | // dimension |
712 | int block_reduction_idx = index_utils:: |
713 | maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
714 | threadIdx, blockDim); |
715 | |
716 | // Shared memory buffer is 2D |
717 | // [iter dimension, reduction dimension] |
718 | |
719 | // Offset into smem for the current thread |
720 | int block_reduce_smem_offset = |
721 | block_reduction_idx * block_reduction_size + tid_in_block_reduction; |
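// Example (for illustration only): if TIDy is a reduction dimension and TIDx
// an iteration dimension with blockDim == (32, 8, 1), then
// block_reduction_size == 8, block_reduction_idx == threadIdx.x, and thread
// (3, 5, 0) uses smem entry 3 * 8 + 5 == 29.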
722 | |
723 | // Initialize shared memory |
724 | if (block_reduce_participate) { |
725 | copyTuple(shared_buf, block_reduce_smem_offset, block_result); |
726 | } |
727 | |
728 | // Sync to make sure smem is completely initialized |
729 | block_sync::sync(); |
730 | |
731 | // Round reduction size down to nearest power of 2 |
732 | int np2 = 1 << (31 - __clz(block_reduction_size)); |
733 | |
734 | // Perform an initial reduction leaving np2 elements |
735 | if (block_reduce_participate && tid_in_block_reduction < np2 && |
736 | tid_in_block_reduction + np2 < block_reduction_size) { |
737 | impl::reduceTuple( |
738 | shared_buf, |
739 | block_reduce_smem_offset, |
740 | shared_buf, |
741 | block_reduce_smem_offset + np2, |
742 | reduction_op); |
743 | } |
744 | |
745 | // Always need to sync while operating on shared memory |
746 | block_sync::sync(); |
747 | |
748 | // Reduce down to 2 values; leaving 2 values allows us to manually |
749 | // perform the last reduction and avoid a syncthreads |
750 | for (int factor = np2 / 2; factor > 1; factor >>= 1) { |
751 | if (tid_in_block_reduction < factor && block_reduce_participate) { |
752 | impl::reduceTuple( |
753 | shared_buf, |
754 | block_reduce_smem_offset, |
755 | shared_buf, |
756 | block_reduce_smem_offset + factor, |
757 | reduction_op); |
758 | } |
759 | block_sync::sync(); |
760 | } |
761 | |
762 | // Accumulate that last valid result |
763 | if (has_block_result) { |
764 | copyTuple(block_result, shared_buf, block_reduce_smem_offset); |
765 | if (block_reduction_size > 1) { |
766 | impl::reduceTuple( |
767 | block_result, |
768 | 0, |
769 | shared_buf, |
770 | block_reduce_smem_offset + 1, |
771 | reduction_op); |
772 | } |
773 | } |
774 | |
775 | // ===== BLOCK REDUCTION CLEANUP ======= |
776 | if (!GRID_REDUCE) { |
777 | // If no grid reduction, we don't have to continue. Either broadcast |
778 | // back across the block or return the correct reduction |
779 | if (has_block_result && write_pred) { |
780 | impl::reduceTuple(block_result, 0, out, 0, reduction_op); |
781 | out = block_result; |
782 | } |
783 | if (BROADCAST) { |
784 | // No grid reduce, but need to broadcast, perform block broadcast |
785 | if (has_block_result && write_pred) { |
786 | // Put result back in shared memory, put in the first entry of the |
787 | // reduction segment's buffer |
788 | copyTuple( |
789 | shared_buf, |
790 | block_reduction_idx * block_reduction_size, |
791 | block_result); |
792 | } |
793 | |
794 | // Sync threads to make sure result is in smem |
795 | block_sync::sync(); |
796 | // If the thread is participating, and is not attempting to write out |
797 | // of bounds, return the broadcasted value. |
798 | if (block_reduce_participate && write_pred) { |
799 | copyTuple( |
800 | out, shared_buf, block_reduction_idx * block_reduction_size); |
801 | } |
802 | } |
803 | |
804 | // Forward protect shared memory, don't want threads to continue to |
805 | // another reduction/broadcast and pollute shared memory before the |
806 | // reduction is completely finished. |
807 | // |
808 | // This could be avoided in some cases if we added thread syncs from |
809 | // block reductions in the syncthread insertion pass. |
810 | block_sync::sync(); |
811 | return; |
812 | } |
813 | } |
814 | |
815 | // -- START GRID REDUCTION -- // |
816 | // Grid reductions are more challenging for two reasons: (1) the reduction |
817 | // itself is 3D instead of 2D because we now have an iter domain space in |
818 | // the grid dimension; (2) a tree reduction isn't performed; instead, all |
819 | // blocks populate GMEM and one block finishes the grid reduction. |
820 | |
821 | // What is the grid reduction size, block reduction already performed so |
822 | // that doesn't have to be taken into consideration |
823 | const auto grid_red_size = index_utils:: |
824 | maskedSize<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
825 | gridDim); |
826 | |
827 | // Which ID in the reduction is this block. Threads can participate in |
828 | // multiple grid reductions, but the block will have the same relative index |
829 | // in those reductions |
830 | const auto idx_in_grid_red = index_utils:: |
831 | maskedOffset<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
832 | blockIdx, gridDim); |
833 | |
834 | if (PERSISTENT_REDUCTION && flip) { |
835 | auto global_buffer_size = |
836 | index_utils:: |
837 | maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>( |
838 | gridDim) * |
839 | index_utils:: |
840 | maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
841 | blockDim) * |
842 | grid_red_size; |
843 | global_work_buffer += global_buffer_size; |
844 | } |
845 | flip = !flip; |
846 | |
847 | // How many grid reductions have to be performed, in the grid dimension |
848 | const auto num_block_iters = index_utils:: |
849 | maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(gridDim); |
850 | |
851 | // Which grid reduction does this block participate in, in the grid |
852 | // dimension |
853 | const auto block_red_idx_offset = index_utils:: |
854 | maskedOffset<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>( |
855 | blockIdx, gridDim); |
856 | |
857 | // How many grid reductions have to be performed, in the block dimension |
858 | const auto num_thread_iters = index_utils:: |
859 | maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
860 | blockDim); |
861 | |
862 | // Which grid reduction does this thread participate in, in the block |
863 | // dimension |
864 | const auto thread_red_idx_offset = index_utils:: |
865 | maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
866 | threadIdx, blockDim); |
867 | |
868 | // 3D buffer of reductions: |
869 | // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] |
870 | // Offset into the work buffer |
871 | const auto work_buf_offset = |
872 | (idx_in_grid_red * num_block_iters + block_red_idx_offset) * |
873 | num_thread_iters + |
874 | thread_red_idx_offset; |
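// Example (for illustration only): with grid_red_size == 4, num_block_iters
// == 2, and num_thread_iters == 32, the work buffer holds 4 * 2 * 32 entries,
// and a thread with idx_in_grid_red == 3, block_red_idx_offset == 1, and
// thread_red_idx_offset == 7 writes entry (3 * 2 + 1) * 32 + 7 == 231.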
875 | |
876 | // Don't read/write in temporary buffers if in a predicated dimension |
877 | bool grid_reduce_participate = index_utils:: |
878 | maskedIsZero<isPred(X_BLOCK), isPred(Y_BLOCK), isPred(Z_BLOCK)>(blockIdx); |
879 | |
880 | if (grid_reduce_participate && block_reduce_participate) { |
881 | if (has_block_result) { |
882 | copyTuple(global_work_buffer, work_buf_offset, block_result); |
883 | } |
884 | } |
885 | |
886 | // -- GLOBAL BUFFER FILLED -- // |
887 | |
888 | bool last_block = index_utils:: |
889 | maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
890 | blockIdx, gridDim); |
891 | |
892 | if (grid_reduce_participate) { |
893 | // Don't need to sync up blocks that are not participating in this |
894 | // reduction |
895 | grid_sync::sync< |
896 | isReduce(X_BLOCK), |
897 | isReduce(Y_BLOCK), |
898 | isReduce(Z_BLOCK), |
899 | PERSISTENT_REDUCTION>( |
900 | global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); |
901 | } |
902 | |
903 | // -- START BLOCK CLEANUP -- // |
904 | // All blocks perform the last cleanup, so every block, and every thread |
905 | )" |
906 | R"( |
907 | // will have the final result |
908 | |
909 | // Initialize block result |
910 | LocalTuple<Types...> last_block_result(init_val); |
911 | |
912 | if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) { |
913 | // Can use the last block to reduce all the values the blocks filled in. |
914 | // Can use any thread that has been predicated or reduced to do this |
915 | // reduction, but cannot use any block that's associated with an |
916 | // iteration domain. |
917 | |
918 | // Start with non-block reduction |
919 | |
920 | // Index in the reduction segment |
921 | int tid_in_block_reduction_2 = index_utils::maskedOffset< |
922 | activeNotIter(X_THREAD), |
923 | activeNotIter(Y_THREAD), |
924 | activeNotIter(Z_THREAD)>(threadIdx, blockDim); |
925 | |
926 | int block_reduction_size_2 = index_utils::maskedSize< |
927 | activeNotIter(X_THREAD), |
928 | activeNotIter(Y_THREAD), |
929 | activeNotIter(Z_THREAD)>(blockDim); |
930 | |
931 | // 3D buffer of reductions: |
932 | // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] |
933 | // Change the offset, we want to keep the last two dimensions, but the |
934 | // first dimension is what we will reduce over |
935 | const auto work_buf_offset_2 = |
936 | block_red_idx_offset * num_thread_iters + thread_red_idx_offset; |
937 | for (auto reduction_i = tid_in_block_reduction_2; |
938 | reduction_i < grid_red_size; |
939 | reduction_i += block_reduction_size_2) { |
940 | impl::reduceTuple( |
941 | last_block_result, |
942 | 0, |
943 | global_work_buffer, |
944 | work_buf_offset_2 + |
945 | reduction_i * num_block_iters * |
946 | num_thread_iters, // Iterating over the outer most |
947 | // dimension, so need to stride by the |
948 | // total number of grid reductions. Could |
949 | // come back and change it so this is the |
950 | // contiguous dimension |
951 | reduction_op); |
952 | } |
953 | |
954 | // -- START LAST BLOCK - BLOCK REDUCTION -- // |
955 | |
956 | // Reduced so we have one value per thread, we need to further reduce any |
957 | // dimension that is not an iter dimension |
958 | |
959 | // Which block reduction this thread is participating in |
960 | int block_reduction_idx = index_utils:: |
961 | maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
962 | threadIdx, blockDim); |
963 | |
964 | // Offset in smem for this thread's result |
965 | auto smem_offset = |
966 | block_reduction_idx * block_reduction_size_2 + tid_in_block_reduction_2; |
967 | |
968 | // Similar as before, reduce down to nearest power of 2 so we can do a |
969 | // tree reduction |
970 | int np2 = 1 << (31 - __clz(min(block_reduction_size_2, grid_red_size))); |
971 | |
972 | // Thread values are initialized, so all threads can participate here |
973 | if (tid_in_block_reduction_2 >= np2) { |
974 | copyTuple(shared_buf, smem_offset, last_block_result); |
975 | } |
976 | |
977 | block_sync::sync(); |
978 | |
979 | if (tid_in_block_reduction_2 < np2 && |
980 | tid_in_block_reduction_2 + np2 < |
981 | min(block_reduction_size_2, grid_red_size)) { |
982 | impl::reduceTuple( |
983 | last_block_result, 0, shared_buf, smem_offset + np2, reduction_op); |
984 | } |
985 | |
986 | if (tid_in_block_reduction_2 < np2) { |
987 | copyTuple(shared_buf, smem_offset, last_block_result); |
988 | } |
989 | |
990 | // Always sync when communicating across smem |
991 | block_sync::sync(); |
992 | |
993 | // Reduce down to 2 values; the last thread will do the final reduction and |
994 | // can save a syncthreads this way |
995 | for (int factor = np2 / 2; factor > 1; factor >>= 1) { |
996 | if (tid_in_block_reduction_2 < factor) { |
997 | impl::reduceTuple( |
998 | shared_buf, |
999 | smem_offset, |
1000 | shared_buf, |
1001 | smem_offset + factor, |
1002 | reduction_op); |
1003 | } |
1004 | block_sync::sync(); |
1005 | } |
1006 | |
1007 | // Whether this thread in each block has the final result before |
1008 | // broadcasting to all other threads in the block |
1009 | bool has_block_result_2 = index_utils::maskedIsZero< |
1010 | activeNotIter(X_THREAD), |
1011 | activeNotIter(Y_THREAD), |
1012 | activeNotIter(Z_THREAD)>(threadIdx); |
1013 | // Do the last reduction, protected by the write predicate |
1014 | copyTuple(last_block_result, shared_buf, smem_offset); |
1015 | if (has_block_result && grid_reduce_participate) { |
1016 | impl::reduceTuple(last_block_result, 0, out, 0, reduction_op); |
1017 | if (min(block_reduction_size_2, grid_red_size) > 1) { |
1018 | impl::reduceTuple( |
1019 | last_block_result, 0, shared_buf, smem_offset + 1, reduction_op); |
1020 | } |
1021 | } |
1022 | if (grid_reduce_participate && PERSISTENT_REDUCTION) { |
1023 | // If persistent reduction, always broadcast reduced values |
1024 | copyTuple(shared_buf, smem_offset, last_block_result); |
1025 | block_sync::sync(); |
1026 | if (write_pred && block_reduce_participate) { |
1027 | copyTuple( |
1028 | out, shared_buf, block_reduction_idx * block_reduction_size_2); |
1029 | } |
1030 | // For persistent kernels we double the global buffer allocation so we |
1031 | // don't need to protect those buffers every iteration preventing the |
1032 | // need of an additional grid_sync. Since we flip back and forth between |
1033 | // sections of the buffer, the one grid sync protects the other part of |
1034 | // the buffer. |
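// Example (for illustration only): with 2 iterating blocks, 32 iterating
// threads, and grid_red_size == 4, global_buffer_size == 2 * 32 * 4 == 256,
// so the doubled allocation holds 512 entries and successive persistent
// iterations alternate between entries [0, 256) and [256, 512).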
1035 | } else { |
1036 | if (grid_reduce_participate) { |
1037 | if (last_block && has_block_result && block_reduce_participate && |
1038 | write_pred) { |
1039 | copyTuple( |
1040 | out, shared_buf, block_reduction_idx * block_reduction_size_2); |
1041 | } |
1042 | } |
1043 | } |
1044 | // Forward protect the smem used in this reduction |
1045 | block_sync::sync(); |
1046 | } |
1047 | } |
1048 | |
1049 | //! Profiled version |
1050 | template < |
1051 | int X_BLOCK, |
1052 | int Y_BLOCK, |
1053 | int Z_BLOCK, |
1054 | int X_THREAD, |
1055 | int Y_THREAD, |
1056 | int Z_THREAD, |
1057 | bool PERSISTENT_REDUCTION, |
1058 | bool BROADCAST> |
1059 | template <typename Func, typename... Types> |
1060 | __device__ __inline__ void ParallelReduce< |
1061 | X_BLOCK, |
1062 | Y_BLOCK, |
1063 | Z_BLOCK, |
1064 | X_THREAD, |
1065 | Y_THREAD, |
1066 | Z_THREAD, |
1067 | PERSISTENT_REDUCTION, |
1068 | BROADCAST>:: |
1069 | reduce( |
1070 | RefTuple<Types...> out, |
1071 | const ConstRefTuple<Types...>& inp, |
1072 | VolatilePtrTuple<Types...> global_work_buffer, |
1073 | int64_t* global_sync_buffer, // Allocated as the product of all |
1074 | // non-participating grid dimensions |
1075 | PtrTuple<Types...> shared_buf, |
1076 | bool read_pred, // Prevent reading from out of bounds memory |
1077 | bool write_pred, // Prevent from writing out of bounds |
1078 | const LocalTuple<Types...>& init_val, |
1079 | Func reduction_op, |
1080 | int64_t& cycles, |
1081 | int64_t& count) { |
1082 | int64_t start_counter = 0; |
1083 | |
1084 | if (isLastBlockInGrid() && |
1085 | index_utils::maskedIsZero<true, true, true>(threadIdx)) { |
1086 | start_counter = readCycleCounter(); |
1087 | } |
1088 | |
1089 | reduce( |
1090 | out, |
1091 | inp, |
1092 | global_work_buffer, |
1093 | global_sync_buffer, |
1094 | shared_buf, |
1095 | read_pred, |
1096 | write_pred, |
1097 | init_val, |
1098 | reduction_op); |
1099 | |
1100 | if (isLastBlockInGrid() && |
1101 | index_utils::maskedIsZero<true, true, true>(threadIdx)) { |
1102 | cycles += readCycleCounter() - start_counter; |
1103 | ++count; |
1104 | } |
1105 | } |
1106 | |
1107 | template < |
1108 | int X_BLOCK, |
1109 | int Y_BLOCK, |
1110 | int Z_BLOCK, |
1111 | int X_THREAD, |
1112 | int Y_THREAD, |
1113 | int Z_THREAD, |
1114 | bool PERSISTENT_REDUCTION, |
1115 | bool BROADCAST> |
1116 | template <typename... DataTypes, typename... Funcs, typename... BoolTypes> |
1117 | __device__ __inline__ void ParallelReduce< |
1118 | X_BLOCK, |
1119 | Y_BLOCK, |
1120 | Z_BLOCK, |
1121 | X_THREAD, |
1122 | Y_THREAD, |
1123 | Z_THREAD, |
1124 | PERSISTENT_REDUCTION, |
1125 | BROADCAST>:: |
1126 | reduceGroup( |
1127 | RefTuple<DataTypes...> out, |
1128 | const ConstRefTuple<DataTypes...>& inp, |
1129 | VolatilePtrTuple<DataTypes...> global_work_buffer, |
1130 | const LocalTuple<DataTypes...>& init_val, |
1131 | int64_t* global_sync_buffer, |
1132 | void* shared_mem, |
1133 | const LocalTuple<BoolTypes...>& read_preds, |
1134 | const LocalTuple<BoolTypes...>& write_preds, |
1135 | Funcs... funcs) { |
1136 | static_assert( |
1137 | sizeof...(DataTypes) == sizeof...(Funcs), |
1138 | "Mismatched number of Tuple values and functions"); |
1139 | static_assert( |
1140 | sizeof...(DataTypes) == sizeof...(BoolTypes), |
1141 | "Mismatched number of Tuple values and predicate values"); |
1142 | |
1143 | // If no reduction needed, just return input |
1144 | if (!BLOCK_REDUCE && !GRID_REDUCE) { |
1145 | copyTupleIf(out, inp, read_preds && write_preds); |
1146 | return; |
1147 | } |
1148 | |
1149 | // Don't read/write in temporary buffers if in a predicated dimension |
1150 | const bool block_reduce_participate = index_utils:: |
1151 | maskedIsZero<isPred(X_THREAD), isPred(Y_THREAD), isPred(Z_THREAD)>( |
1152 | threadIdx); |
1153 | |
1154 | // Only threads with id == 0 in the dimensions being reduced will |
1155 | // have a valid result |
1156 | const bool has_block_result = index_utils:: |
1157 | maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( |
1158 | threadIdx); |
1159 | |
1160 | // Initial per-block reduction. Result is broadcast if specified |
1161 | // and this call is block reduction only. |
1162 | const auto block_result = reduceGroupBlock<!GRID_REDUCE && BROADCAST>( |
1163 | inp, |
1164 | init_val, |
1165 | shared_mem, |
1166 | read_preds, |
1167 | block_reduce_participate, |
1168 | funcs...); |
1169 | // If block reduction only, save to out and exit |
1170 | if (!GRID_REDUCE) { |
1171 | copyTupleIf( |
1172 | out, |
1173 | block_result, |
1174 | write_preds && |
1175 | (block_reduce_participate && (BROADCAST || has_block_result))); |
1176 | |
1177 | // Need a block sync here as reduceGroupBlock does not |
1178 | // forward-protect the smem buffer. This block sync is not |
1179 | // necessary when a grid reduction follows since a block sync is |
1180 | // done just before the grid sync. |
1181 | block_sync::sync(); |
1182 | return; |
1183 | } |
1184 | |
1185 | // -- START GRID REDUCTION -- // |
1186 | // Grid reductions are more challenging for two reasons: (1) the reduction |
1187 | // itself is 3D instead of 2D because we now have an iter domain space in |
1188 | // the grid dimension; (2) a tree reduction isn't performed; instead, all |
1189 | // blocks populate GMEM and one block finishes the grid reduction. |
1190 | |
1191 | // What is the grid reduction size, block reduction already performed so |
1192 | // that doesn't have to be taken into consideration |
1193 | const auto grid_red_size = index_utils:: |
1194 | maskedSize<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
1195 | gridDim); |
1196 | |
1197 | // Which ID in the reduction is this block. Threads can participate in |
1198 | // multiple grid reductions, but the block will have the same relative index |
1199 | // in those reductions |
1200 | const auto idx_in_grid_red = index_utils:: |
1201 | maskedOffset<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
1202 | blockIdx, gridDim); |
1203 | |
1204 | // How many grid reductions have to be performed, in the grid dimension |
1205 | const auto num_block_iters = index_utils:: |
1206 | maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>(gridDim); |
1207 | |
1208 | // Which grid reduction does this block participate in, in the grid |
1209 | // dimension |
1210 | const auto block_red_idx_offset = index_utils:: |
1211 | maskedOffset<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>( |
1212 | blockIdx, gridDim); |
1213 | |
1214 | // How many grid reductions have to be performed, in the block dimension |
1215 | const auto num_thread_iters = index_utils:: |
1216 | maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
1217 | blockDim); |
1218 | |
1219 | // Which grid reduction does this thread participate in, in the block |
1220 | // dimension |
1221 | const auto thread_red_idx_offset = index_utils:: |
1222 | maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
1223 | threadIdx, blockDim); |
1224 | |
1225 | // 3D buffer of reductions: |
1226 | // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] |
1227 | // Offset into the work buffer |
1228 | const auto work_buf_offset = |
1229 | (idx_in_grid_red * num_block_iters + block_red_idx_offset) * |
1230 | num_thread_iters + |
1231 | thread_red_idx_offset; |
1232 | |
1233 | // Don't read/write in temporary buffers if in a predicated dimension |
1234 | bool grid_reduce_participate = index_utils:: |
1235 | maskedIsZero<isPred(X_BLOCK), isPred(Y_BLOCK), isPred(Z_BLOCK)>(blockIdx); |
1236 | |
1237 | if (PERSISTENT_REDUCTION && flip) { |
1238 | auto global_buffer_size = |
1239 | index_utils:: |
1240 | maskedSize<isIter(X_BLOCK), isIter(Y_BLOCK), isIter(Z_BLOCK)>( |
1241 | gridDim) * |
1242 | index_utils:: |
1243 | maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
1244 | blockDim) * |
1245 | grid_red_size; |
1246 | global_work_buffer += global_buffer_size; |
1247 | } |
1248 | flip = !flip; |
1249 | |
1250 | // Per-block partial reduction to global work buffer |
1251 | if (grid_reduce_participate && block_reduce_participate && has_block_result) { |
1252 | copyTuple(global_work_buffer, work_buf_offset, block_result); |
1253 | } |
1254 | |
1255 | // -- GLOBAL BUFFER FILLED -- // |
1256 | |
1257 | bool last_block = index_utils:: |
1258 | maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
1259 | blockIdx, gridDim); |
1260 | |
1261 | if (grid_reduce_participate) { |
1262 | // Don't need to sync up blocks that are not participating in this |
1263 | // reduction |
1264 | grid_sync::sync< |
1265 | isReduce(X_BLOCK), |
1266 | isReduce(Y_BLOCK), |
1267 | isReduce(Z_BLOCK), |
1268 | PERSISTENT_REDUCTION>( |
1269 | global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); |
1270 | } |
1271 | |
1272 | // -- START BLOCK CLEANUP -- // |
1273 | reduceGroupLastBlock( |
1274 | out, |
1275 | global_work_buffer, |
1276 | init_val, |
1277 | shared_mem, |
1278 | block_red_idx_offset, |
1279 | num_thread_iters, |
1280 | num_block_iters, |
1281 | thread_red_idx_offset, |
1282 | grid_red_size, |
1283 | write_preds, |
1284 | block_reduce_participate, |
1285 | grid_reduce_participate, |
1286 | funcs...); |
1287 | |
1288 | // Forward protect the smem buffer |
1289 | block_sync::sync(); |
1290 | } |
1291 | |
1292 | template < |
1293 | int X_BLOCK, |
1294 | int Y_BLOCK, |
1295 | int Z_BLOCK, |
1296 | int X_THREAD, |
1297 | int Y_THREAD, |
1298 | int Z_THREAD, |
1299 | bool PERSISTENT_REDUCTION, |
1300 | bool BROADCAST> |
1301 | template <typename... DataTypes, typename... Funcs, typename... BoolTypes> |
1302 | __device__ __inline__ void ParallelReduce< |
1303 | X_BLOCK, |
1304 | Y_BLOCK, |
1305 | Z_BLOCK, |
1306 | X_THREAD, |
1307 | Y_THREAD, |
1308 | Z_THREAD, |
1309 | PERSISTENT_REDUCTION, |
1310 | BROADCAST>:: |
1311 | reduceGroup( |
1312 | RefTuple<DataTypes...> out, |
1313 | const ConstRefTuple<DataTypes...>& inp, |
1314 | VolatilePtrTuple<DataTypes...> global_work_buffer, |
1315 | const LocalTuple<DataTypes...>& init_val, |
1316 | int64_t* global_sync_buffer, |
1317 | void* shared_mem, |
1318 | const LocalTuple<BoolTypes...>& read_preds, |
1319 | const LocalTuple<BoolTypes...>& write_preds, |
1320 | int64_t& cycles, |
1321 | int64_t& count, |
1322 | Funcs... funcs) { |
1323 | int64_t start_counter = 0; |
1324 | |
1325 | if (isLastBlockInGrid() && |
1326 | index_utils::maskedIsZero<true, true, true>(threadIdx)) { |
1327 | start_counter = readCycleCounter(); |
1328 | } |
1329 | |
1330 | reduceGroup( |
1331 | out, |
1332 | inp, |
1333 | global_work_buffer, |
1334 | init_val, |
1335 | global_sync_buffer, |
1336 | shared_mem, |
1337 | read_preds, |
1338 | write_preds, |
1339 | funcs...); |
1340 | |
1341 | if (isLastBlockInGrid() && |
1342 | index_utils::maskedIsZero<true, true, true>(threadIdx)) { |
1343 | cycles += readCycleCounter() - start_counter; |
1344 | ++count; |
1345 | } |
1346 | } |
1347 | |
1348 | template < |
1349 | int X_BLOCK, |
1350 | int Y_BLOCK, |
1351 | int Z_BLOCK, |
1352 | int X_THREAD, |
1353 | int Y_THREAD, |
1354 | int Z_THREAD, |
1355 | bool PERSISTENT_REDUCTION, |
1356 | bool BROADCAST> |
1357 | template < |
1358 | bool BLOCK_BROADCAST, |
1359 | typename... DataTypes, |
1360 | typename... Funcs, |
1361 | typename... BoolTypes> |
1362 | __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< |
1363 | X_BLOCK, |
1364 | Y_BLOCK, |
1365 | Z_BLOCK, |
1366 | X_THREAD, |
1367 | Y_THREAD, |
1368 | Z_THREAD, |
1369 | PERSISTENT_REDUCTION, |
1370 | BROADCAST>:: |
1371 | reduceGroupBlock( |
1372 | const ConstRefTuple<DataTypes...>& inp, |
1373 | const LocalTuple<DataTypes...>& init_val, |
1374 | void* shared_mem, |
1375 | const LocalTuple<BoolTypes...>& read_preds, |
1376 | bool block_reduce_participate, |
1377 | Funcs... funcs) { |
1378 | const bool has_block_result = index_utils:: |
1379 | maskedIsZero<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( |
1380 | threadIdx); |
1381 | |
1382 | )" |
1383 | R"( |
1384 | // Initialize block result |
1385 | LocalTuple<DataTypes...> block_result = init_val; |
1386 | |
1387 | copyTupleIf(block_result, inp, block_reduce_participate && read_preds); |
1388 | |
1389 | // Size of the block reduction segment, can be an int since it's limited |
1390 | // to number of threads |
1391 | const int block_reduction_size = index_utils:: |
1392 | maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( |
1393 | blockDim); |
1394 | |
1395 | // Index in the reduction segment, can be an int since it's limited to |
1396 | // number of threads |
1397 | const int tid_in_block_reduction = index_utils:: |
1398 | maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( |
1399 | threadIdx, blockDim); |
1400 | |
1401 | // ID of the block reduction this thread is participating in |
1402 | // |
1403 | // If any of the parallel dimensions are predicated out, that means |
1404 | // they've already been reduced, so we only care about the first thread in |
1405 | // that dimension. Therefore don't expand the reduction_idx by that |
1406 | // dimension |
1407 | const int block_reduction_idx = index_utils:: |
1408 | maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
1409 | threadIdx, blockDim); |
1410 | |
1411 | // Do not protect the smem buffer as it's not always necessary. |
1412 | impl::blockReduceEach< |
1413 | BLOCK_BROADCAST, |
1414 | false, |
1415 | LocalTuple<DataTypes...>, |
1416 | Funcs...>( |
1417 | block_result, |
1418 | block_result, |
1419 | shared_mem, |
1420 | has_block_result, |
1421 | tid_in_block_reduction, |
1422 | block_reduction_size, |
1423 | block_reduction_size, |
1424 | block_reduction_idx, |
1425 | funcs...); |
1426 | |
1427 | return block_result; |
1428 | } |
1429 | |
1430 | template < |
1431 | int X_BLOCK, |
1432 | int Y_BLOCK, |
1433 | int Z_BLOCK, |
1434 | int X_THREAD, |
1435 | int Y_THREAD, |
1436 | int Z_THREAD, |
1437 | bool PERSISTENT_REDUCTION, |
1438 | bool BROADCAST> |
1439 | template <typename... DataTypes, typename... Funcs, typename... BoolTypes> |
1440 | __device__ __inline__ void ParallelReduce< |
1441 | X_BLOCK, |
1442 | Y_BLOCK, |
1443 | Z_BLOCK, |
1444 | X_THREAD, |
1445 | Y_THREAD, |
1446 | Z_THREAD, |
1447 | PERSISTENT_REDUCTION, |
1448 | BROADCAST>:: |
1449 | reduceGroupLastBlock( |
1450 | RefTuple<DataTypes...>& out, |
1451 | const VolatilePtrTuple<DataTypes...>& global_work_buffer, |
1452 | const LocalTuple<DataTypes...>& init_val, |
1453 | void* shared_mem, |
1454 | nvfuser_index_t block_red_idx_offset, |
1455 | nvfuser_index_t num_thread_iters, |
1456 | nvfuser_index_t num_block_iters, |
1457 | nvfuser_index_t thread_red_idx_offset, |
1458 | nvfuser_index_t grid_red_size, |
1459 | const LocalTuple<BoolTypes...>& write_preds, |
1460 | bool block_reduce_participate, |
1461 | bool grid_reduce_participate, |
1462 | Funcs... reduction_ops) { |
1463 | // Initialize block result |
1464 | LocalTuple<DataTypes...> last_block_result(init_val); |
1465 | |
1466 | const bool last_block = index_utils:: |
1467 | maskedIsLast<isReduce(X_BLOCK), isReduce(Y_BLOCK), isReduce(Z_BLOCK)>( |
1468 | blockIdx, gridDim); |
1469 | |
1470 | if ((PERSISTENT_REDUCTION || last_block) && grid_reduce_participate) { |
1471 | // Can use the last block to reduce all the values the blocks filled in. |
1472 | // Can use any thread that has been predicated or reduced to do this |
1473 | // reduction, but cannot use any block that's associated with an |
1474 | // iteration domain. |
1475 | |
1476 | // Start with non-block reduction |
1477 | |
1478 | // Index in the reduction segment |
1479 | int tid_in_block_reduction = index_utils::maskedOffset< |
1480 | activeNotIter(X_THREAD), |
1481 | activeNotIter(Y_THREAD), |
1482 | activeNotIter(Z_THREAD)>(threadIdx, blockDim); |
1483 | |
1484 | int block_reduction_size = index_utils::maskedSize< |
1485 | activeNotIter(X_THREAD), |
1486 | activeNotIter(Y_THREAD), |
1487 | activeNotIter(Z_THREAD)>(blockDim); |
1488 | |
1489 | bool has_block_result = index_utils::maskedIsZero< |
1490 | activeNotIter(X_THREAD), |
1491 | activeNotIter(Y_THREAD), |
1492 | activeNotIter(Z_THREAD)>(threadIdx); |
1493 | |
1494 | // 3D buffer of reductions: |
1495 | // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] |
1496 | // Change the offset, we want to keep the last two dimensions, but the |
1497 | // first dimension is what we will reduce over |
1498 | const auto work_buf_offset = |
1499 | block_red_idx_offset * num_thread_iters + thread_red_idx_offset; |
1500 | for (auto reduction_i = tid_in_block_reduction; reduction_i < grid_red_size; |
1501 | reduction_i += block_reduction_size) { |
1502 | impl::reduceEach( |
1503 | last_block_result, |
1504 | 0, |
1505 | global_work_buffer, |
1506 | work_buf_offset + |
1507 | reduction_i * num_block_iters * |
1508 | num_thread_iters, // Iterating over the outer most |
1509 | // dimension, so need to stride by the |
1510 | // total number of grid reductions. Could |
1511 | // come back and change it so this is the |
1512 | // contiguous dimension |
1513 | reduction_ops...); |
1514 | } |
1515 | |
1516 | // Which block reduction this thread is participating in |
1517 | int block_reduction_idx = index_utils:: |
1518 | maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( |
1519 | threadIdx, blockDim); |
1520 | |
1521 | impl::blockReduceEach<BROADCAST, false, LocalTuple<DataTypes...>, Funcs...>( |
1522 | last_block_result, |
1523 | last_block_result, |
1524 | shared_mem, |
1525 | has_block_result, |
1526 | tid_in_block_reduction, |
1527 | block_reduction_size, |
1528 | min(grid_red_size, block_reduction_size), |
1529 | block_reduction_idx, |
1530 | reduction_ops...); |
1531 | |
1532 | copyTupleIf( |
1533 | out, |
1534 | last_block_result, |
1535 | write_preds && |
1536 | (block_reduce_participate && (BROADCAST || has_block_result))); |
1537 | } |
1538 | } |
1539 | |
1540 | } // namespace fused_reduction |
1541 | )"; |
1542 | |
1543 | } // namespace nvfuser_resources |
1544 | |