1// Generated from "/code/pytorch/third_party/nvfuser/runtime/broadcast.cu"
2// 2023-02-12 08:01:26
3
namespace nvfuser_resources {

// NVRTC resource string: the device-side block-broadcast helper that nvFuser
// compiles at runtime. This file is auto-generated from
// third_party/nvfuser/runtime/broadcast.cu — edit that file, not this one.
//
// Contract of the embedded blockBroadcast (as visible in the string below):
//  * The X/Y/Z_THREAD template flags select which thread dimensions
//    broadcast from index 0; a thread is a data source only when it sits at
//    index 0 of every broadcast dimension.
//  * shared_mem is indexed by an offset computed over the NON-broadcast
//    dimensions only, so it must hold one T per distinct non-broadcast
//    thread coordinate (presumably sized by the caller from blockDim —
//    TODO(review): confirm against the launch-side allocation).
//  * read_write_pred gates both the shared-memory write and the read-back;
//    threads with a false predicate neither publish nor consume a value.
//  * Two barriers: the first orders the source write before all reads; the
//    second prevents a later call from overwriting shared_mem while slow
//    threads are still reading (write-after-read hazard on buffer reuse).
//  * NOTE(review): block_sync::sync() and index_utils::maskedOffset are
//    supplied by sibling runtime resource strings — their exact semantics
//    (e.g. whether sync() is a full __syncthreads equivalent) are not
//    visible here; verify against those resources.
constexpr const char* broadcast_cu = R"(

namespace broadcast {
// Broadcasts within partitioned groups of threads.
//
// X_THREAD: Broadcast from threadIdx.x == 0 if true
// Y_THREAD: Broadcast from threadIdx.y == 0 if true
// Z_THREAD: Broadcast from threadIdx.z == 0 if true
// inp_val: Per-thread source value. Only valid when the thread is a source.
// out: Per-thread output location
//
template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, typename T>
__device__ void blockBroadcast(
    T& out,
    const T& inp_val,
    T* shared_mem,
    bool read_write_pred) {
  const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) &&
      (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0);

  const auto shared_offset =
      index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>(
          threadIdx, blockDim);

  if (has_valid_data && read_write_pred) {
    shared_mem[shared_offset] = inp_val;
  }

  block_sync::sync();

  if (read_write_pred) {
    out = shared_mem[shared_offset];
  }

  block_sync::sync();
}

} // namespace broadcast
)";

} // namespace nvfuser_resources
47