// Generated from "/code/pytorch/third_party/nvfuser/runtime/broadcast.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* broadcast_cu = R"(

namespace broadcast {
// Broadcasts within partitioned groups of threads.
//
// X_THREAD: Broadcast from threadIdx.x == 0 if true
// Y_THREAD: Broadcast from threadIdx.y == 0 if true
// Z_THREAD: Broadcast from threadIdx.z == 0 if true
// inp_val: Per-thread source value. Only valid when the thread is a source.
// out: Per-thread output location
// shared_mem: Shared-memory workspace with one element per broadcast group,
//             i.e. per combination of indices along the non-broadcast dims.
// read_write_pred: When false, the thread skips both the shared-memory write
//                  and the read back into out, but still reaches the syncs.
//
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
__device__ void blockBroadcast(
    T& out,
    const T& inp_val,
    T* shared_mem,
    bool read_write_pred) {
  const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) &&
      (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0);

  const auto shared_offset =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);

  if (has_valid_data && read_write_pred) {
    shared_mem[shared_offset] = inp_val;
  }

  block_sync::sync();

  if (read_write_pred) {
    out = shared_mem[shared_offset];
  }

  block_sync::sync();
}
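
// Usage sketch (illustrative only, not part of the original runtime file):
// broadcast the value held by threadIdx.x == 0 to every thread in its
// (y, z) row. Assumes the caller provides a shared-memory buffer with
// blockDim.y * blockDim.z elements of T, one slot per broadcast group.
// The helper name is hypothetical.
template <typename T>
__device__ T broadcastFromX0Example(const T& inp_val, T* shared_mem) {
  T out;
  blockBroadcast<true, false, false, T>(
      out, inp_val, shared_mem, /*read_write_pred=*/true);
  return out;
}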

} // namespace broadcast
)";

} // namespace nvfuser_resources