// Generated from "/code/pytorch/third_party/nvfuser/runtime/grid_broadcast.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* grid_broadcast_cu = R"(
namespace grid_broadcast {

// Broadcasts per-thread values across threads and blocks.
//
// Function parameters:
// - out: Per-thread output location
// - inp_val: Per-thread input value
// - work_buf: Temporary buffer for communication across threads/blocks
// - sync_flags: A vector of integers used for inter-block synchronization
// - read_write_pred: Predicate guarding the reads and writes of work_buf
//
// Template parameters:
// - X/Y/Z_BLOCK: When true, broadcasts across thread blocks along the X/Y/Z
//   dimensions
// - X/Y/Z_THREAD: When true, broadcasts across threads along the X/Y/Z
//   dimensions
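//
// A hypothetical instantiation (illustrative only; thread_val, work_buf,
// and sync_flags are assumed to be set up by the caller): broadcasting
// along the grid and block X dimensions, so that every thread sharing the
// same Y/Z coordinates receives the source thread's value:
//
//   float result;
//   broadcast<true, false, false, true, false, false>(
//       result, thread_val, work_buf, sync_flags, true);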
template <
    bool X_BLOCK,
    bool Y_BLOCK,
    bool Z_BLOCK,
    bool X_THREAD,
    bool Y_THREAD,
    bool Z_THREAD,
    typename T>
__device__ void broadcast(
    T& out,
    const T& inp_val,
    volatile T* work_buf,
    Tensor<int64_t, 1> sync_flags,
    bool read_write_pred) {
  // Number of blocks participating in each broadcast, i.e., the size of
  // each broadcast segment in the grid dimensions
  const auto grid_seg_size =
      index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim);

  // Index of the broadcast segment this block belongs to, linearized over
  // the non-broadcast grid dimensions
  const auto grid_seg_idx =
      index_utils::maskedOffset<!X_BLOCK, !Y_BLOCK, !Z_BLOCK>(
          blockIdx, gridDim);
  // Number of threads not participating in the broadcast dimensions. This
  // is the number of per-thread entries each segment occupies in the work
  // buffer, and therefore the stride between segments.
  const auto block_stride =
      index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim);

  // Offset of this thread's entry within its segment's region of the work
  // buffer, linearized over the non-broadcast thread dimensions
  const auto thread_offset =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);
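
  // Worked example with illustrative values (assuming maskedSize is the
  // product of the enabled dimensions and maskedOffset the linear index
  // over them): with gridDim = (8, 1, 1), blockDim = (32, 4, 1),
  // X_BLOCK = true, and X_THREAD = true, each threadIdx.y row runs its
  // own broadcast, so grid_seg_size = 8, grid_seg_idx = 0,
  // block_stride = 4, and thread_offset = threadIdx.y.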

  // Only the broadcast "source" holds valid data: the last block along
  // each broadcast grid dimension and thread 0 along each broadcast
  // thread dimension
  const bool has_valid_data = (!X_BLOCK || blockIdx.x == gridDim.x - 1) &&
      (!Y_BLOCK || blockIdx.y == gridDim.y - 1) &&
      (!Z_BLOCK || blockIdx.z == gridDim.z - 1) &&
      (!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) &&
      (!Z_THREAD || threadIdx.z == 0);

  if (has_valid_data && read_write_pred) {
    // Publish this segment's value, then fence so the write is visible to
    // other blocks before the grid sync below
    work_buf[grid_seg_idx * block_stride + thread_offset] = inp_val;
    __threadfence();
  }

  // Wait until the source has written its value before any block reads it
  grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true>(
      sync_flags[grid_seg_idx], grid_seg_size);

  if (read_write_pred) {
    out = work_buf[grid_seg_idx * block_stride + thread_offset];
  }

  // Make sure everyone has read from the buffer before continuing the
  // kernel and potentially overwriting the work buffer
  grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true>(
      sync_flags[grid_seg_idx], grid_seg_size);
}
} // namespace grid_broadcast
)";

} // namespace nvfuser_resources