1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/grid_broadcast.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
namespace nvfuser_resources {

// Embedded CUDA source for nvFuser's runtime (JIT/NVRTC) compilation.
//
// The raw string below is NOT compiled by the host toolchain; it is handed
// verbatim to the runtime compiler when a fused kernel needs a grid-wide
// broadcast. Do not edit the string contents here — this file is
// auto-generated from third_party/nvfuser/runtime/grid_broadcast.cu (see the
// header comment); regenerate instead so the two stay in sync.
//
// Summary of the embedded helper, grid_broadcast::broadcast<...>():
//  - The source value for each broadcast segment is taken from the LAST block
//    along each broadcast grid dimension (blockIdx == gridDim - 1) and from
//    threadIdx == 0 along each broadcast thread dimension; that thread writes
//    inp_val into work_buf and issues a __threadfence() to publish it.
//  - Two grid_sync::sync calls bracket the read: the first makes the write
//    visible to all participating blocks before anyone reads, the second keeps
//    any block from overwriting work_buf before all readers are done.
//  - read_write_pred gates both the write and the read, but NOT the syncs —
//    presumably every thread of every participating block must still reach
//    both grid_sync::sync calls to avoid deadlock (TODO confirm against
//    grid_sync's contract in the sibling runtime sources).
constexpr const char* grid_broadcast_cu = R"(
namespace grid_broadcast {

// Broadcasts per-thread values across threads and blocks.
//
// Function parameters:
// - out: Per-thread output location
// - inp_val: Per-thread input value
// - work_buf: Temporary buffer for communication across threads/blocks
// - sync_flags: A vector of integers for synchronizations
//
// Template parameters:
// - X/Y/Z_BLOCK: When true, broadcasts across thread blocks along the X/Y/Z
//   dimensions
// - X/Y/Z_THREAD: When true, broadcasts across threads along the X/Y/Z
//   dimensions
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    typename T>
__device__ void broadcast(
    T& out,
    const T& inp_val,
    volatile T* work_buf,
    Tensor<int64_t, 1> sync_flags,
    bool read_write_pred) {
  // Number of values broadcasted in the grid dimensions
  const auto grid_seg_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the broadcast we're performing out of the grid_seg_size
  const auto grid_seg_idx =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);

  // Number of threads not participating in a broadcast dimension, this is the
  // number of thread entries to expect in the work buffer, therefore a striding
  const auto block_stride =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  // Which broadcast in the block this is to line up the entry with the work
  // buffer
  const auto thread_offset =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);

  const bool has_valid_data = (!X_BLOCK || blockIdx.x == gridDim.x - 1) &&
      (!Y_BLOCK || blockIdx.y == gridDim.y - 1) &&
      (!Z_BLOCK || blockIdx.z == gridDim.z - 1) &&
      (!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) &&
      (!Z_THREAD || threadIdx.z == 0);

  if (has_valid_data && read_write_pred) {
    work_buf[grid_seg_idx * block_stride + thread_offset] = inp_val;
    __threadfence();
  }

  grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true>(
      sync_flags[grid_seg_idx], grid_seg_size);

  if (read_write_pred) {
    out = work_buf[grid_seg_idx * block_stride + thread_offset];
  }

  // Make sure everyone has read from the buffer before continuing the kernel
  // and potentially overwriting
  grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true>(
      sync_flags[grid_seg_idx], grid_seg_size);
}
} // namespace grid_broadcast
)" ;

} // namespace nvfuser_resources
83 | |