1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/warp.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
namespace nvfuser_resources {

// CUDA source for nvfuser's warp-level TIDX reduction, embedded as a raw
// string so it can be compiled at runtime (this file is generated — do not
// edit the payload here; change runtime/warp.cu and regenerate instead).
//
// The embedded warp::warpReduceTIDX reduces `inp_val` across threadIdx.x
// within a block using a __shfl_xor_sync butterfly (full 0xffffffff mask,
// warp size 32). Threads with `read_write_pred == false` contribute
// `init_val`. When SINGLE_WARP is false, each warp's lane 0 stages its
// partial into `shared_mem` (one slot per warp, one group of slots per
// (threadIdx.y, threadIdx.z) reduction group, separated by block_sync::sync()
// barriers), warp 0 combines the per-warp partials with a second shuffle
// butterfly, and each group's warp-head lane folds the result into `out`
// via `reduction_op`; in the SINGLE_WARP path every thread already holds
// the full result after the shuffle and applies `reduction_op(out, ...)`
// directly. The code assumes blockDim.x is padded to a multiple of 32 and,
// per the assert on num_of_warps, at most 32 warps (<= 1024 threads) in x.
// NOTE(review): callers must size `shared_mem` to at least
// blockDim.y * blockDim.z * (blockDim.x / 32) elements — confirm at the
// launch site; the requirement is implied by the smem_offset indexing.
constexpr const char* warp_cu = R"(
namespace warp {

template <
    bool SINGLE_WARP,
    typename T,
    typename Func,
    typename _dim3ti,
    typename _dim3bd>
__device__ void warpReduceTIDX(
    T& out,
    const T& inp_val,
    Func reduction_op,
    const _dim3ti& thread_idx,
    const _dim3bd& block_dim,
    T* shared_mem,
    bool read_write_pred,
    T init_val) {
  constexpr int WARP_SIZE = 32;

  // Assume input padded to multiples of a warp
  T reduce_val = init_val;

  // Do warp reduction
  if (read_write_pred) {
    reduce_val = inp_val;
  }

  // Reduce within each warp
  for (int i = 16; i >= 1; i /= 2) {
    reduction_op(
        reduce_val, __shfl_xor_sync(0xffffffff, reduce_val, i, WARP_SIZE));
  }

  // Reduce across warp if needed
  // Load value to shared mem
  if (!SINGLE_WARP) {
    unsigned int warp_idx = thread_idx.x / WARP_SIZE;
    unsigned int lane_idx = thread_idx.x % WARP_SIZE;
    unsigned int reduce_group_id = thread_idx.z * block_dim.y + thread_idx.y;
    bool is_warp_head = lane_idx == 0;
    unsigned int reduction_size = block_dim.x;
    unsigned int num_of_warps = reduction_size / WARP_SIZE;
    unsigned int smem_offset = reduce_group_id * num_of_warps;

    block_sync::sync();

    if (is_warp_head) {
      shared_mem[smem_offset + warp_idx] = reduce_val;
    }

    block_sync::sync();

    if (warp_idx == 0) {
      // This assumes num_of_warps will be < 32, meaning < 1024 threads.
      // Should be true for long enough.
      assert(num_of_warps <= 32);

      reduce_val = lane_idx < num_of_warps ? shared_mem[smem_offset + lane_idx]
                                           : init_val;

      // Reduce within warp 0
      for (int i = 16; i >= 1; i /= 2) {
        reduction_op(
            reduce_val, __shfl_xor_sync(0xffffffff, reduce_val, i, 32));
      }
    }

    if (is_warp_head) {
      reduction_op(out, reduce_val);
    }
  } else {
    reduction_op(out, reduce_val);
  }
}

} // namespace warp
)" ;

} // namespace nvfuser_resources
86 | |