1// Generated from "/code/pytorch/third_party/nvfuser/runtime/warp.cu"
2// 2023-02-12 08:01:26
3
namespace nvfuser_resources {

// CUDA source for nvfuser's warp-level reduction runtime, embedded as a raw
// string and compiled at runtime. Per the file banner this is generated from
// runtime/warp.cu — prefer fixing the generator input over hand-editing; the
// changes below are comment-only inside the string, so the compiled device
// code is unchanged.
constexpr const char* warp_cu = R"(
namespace warp {

// Block-wide reduction along threadIdx.x, built from warp shuffles plus one
// shared-memory exchange across warps.
//
// Preconditions (not checked here):
//  - block_dim.x is a multiple of 32: the full-mask __shfl_xor_sync below
//    requires every lane of every warp to participate.
//  - shared_mem holds at least
//    block_dim.y * block_dim.z * (block_dim.x / 32) elements of T — one slot
//    per warp per (y, z) reduction group, per the smem_offset indexing.
//
// Threads whose read_write_pred is false contribute init_val instead of
// inp_val. NOTE(review): after the cross-warp phase only lane 0 of warp 0
// holds the complete result; lane 0 of every other warp folds its own warp
// partial into its `out` register. Callers presumably consume `out` only
// from thread_idx.x == 0 of each (y, z) group — confirm against the code
// generator.
template <
    bool SINGLE_WARP,
    typename T,
    typename Func,
    typename _dim3ti,
    typename _dim3bd>
__device__ void warpReduceTIDX(
    T& out,
    const T& inp_val,
    Func reduction_op,
    const _dim3ti& thread_idx,
    const _dim3bd& block_dim,
    T* shared_mem,
    bool read_write_pred,
    T init_val) {
  constexpr int WARP_SIZE = 32;

  // Assume input padded to multiples of a warp
  T reduce_val = init_val;

  // Predicated-off threads contribute the identity value.
  if (read_write_pred) {
    reduce_val = inp_val;
  }

  // Reduce within each warp: xor (butterfly) shuffle over offsets
  // 16, 8, 4, 2, 1 leaves every lane of the warp holding the warp result.
  for (int i = 16; i >= 1; i /= 2) {
    reduction_op(
        reduce_val, __shfl_xor_sync(0xffffffff, reduce_val, i, WARP_SIZE));
  }

  // Reduce across warps if needed: stage one partial per warp in shared
  // memory, then let warp 0 combine them.
  if (!SINGLE_WARP) {
    unsigned int warp_idx = thread_idx.x / WARP_SIZE;
    unsigned int lane_idx = thread_idx.x % WARP_SIZE;
    // Independent reductions are laid out along y/z; each (y, z) group owns
    // its own run of num_of_warps slots in shared_mem.
    unsigned int reduce_group_id = thread_idx.z * block_dim.y + thread_idx.y;
    bool is_warp_head = lane_idx == 0;
    unsigned int reduction_size = block_dim.x;
    unsigned int num_of_warps = reduction_size / WARP_SIZE;
    unsigned int smem_offset = reduce_group_id * num_of_warps;

    // Make sure any earlier use of shared_mem has completed before we
    // overwrite it.
    block_sync::sync();

    if (is_warp_head) {
      shared_mem[smem_offset + warp_idx] = reduce_val;
    }

    block_sync::sync();

    if (warp_idx == 0) {
      // The single-warp gather below requires num_of_warps <= 32, i.e.
      // block_dim.x <= 1024 — the CUDA per-block maximum, so this holds for
      // any legal launch. (num_of_warps == 32 is fine: lane_idx <
      // num_of_warps then covers all 32 lanes.)
      assert(num_of_warps <= 32);

      // Lanes beyond the warp count load the identity so the shuffle
      // reduction below stays correct for partial warp counts.
      reduce_val = lane_idx < num_of_warps ? shared_mem[smem_offset + lane_idx]
                                           : init_val;

      // Reduce within warp 0
      for (int i = 16; i >= 1; i /= 2) {
        reduction_op(
            reduce_val, __shfl_xor_sync(0xffffffff, reduce_val, i, 32));
      }
    }

    if (is_warp_head) {
      reduction_op(out, reduce_val);
    }
  } else {
    reduction_op(out, reduce_val);
  }
}

} // namespace warp
)";

} // namespace nvfuser_resources
86