// Generated from "/code/pytorch/third_party/nvfuser/runtime/grid_sync.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* grid_sync_cu = R"(
namespace grid_sync {

// Mask with only the first (most significant) bit of a 64-bit integer set,
// i.e. 0x8000000000000000
#define FIRST_UINT64_BIT ((uint64_t)1 << (sizeof(uint64_t) * 8 - 1))

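// Reads a value through a volatile reference so the load is not cached in a
// register and always observes the latest value written to global memory.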
template <typename T>
__device__ T globalAsVolatile(volatile T& global_val) {
  return global_val;
}

// A grid synchronization that can be called multiple times in a kernel,
// assuming all the blocks fit on the device at once. The semaphore is an
// integer semaphore assumed to be initialized to 0 before launching the
// kernel. The persistent option should be used if this sync will be called
// multiple times in one kernel (i.e. a grid reduce within a loop). Having
// multiple grid syncs that are each called once in the same kernel does not
// require persistent mode. Segment size is the number of blocks participating
// in the sync in the dimensions marked by [X,Y,Z]_BLOCK; those dimensions are
// the granularity of this sync. E.g. marking X and Y but not Z means there
// should be Z semaphores, each synchronizing X*Y blocks.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, bool PERSISTENT>
__device__ void sync(
    int64_t& semaphore,
    const uint64_t& segment_size,
    const bool last_block) {
  // Finish all global memory transactions before synchronizing
  __threadfence();

  // Synchronize all threads in a block before synchronizing blocks
  block_sync::sync();

  // Only allow linear_tid == 0 to participate in the synchronization
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
    // Get the increment value. Only a single block should have the large
    // increment (it does not really matter which one); the goal is to
    // flip/flop the first bit of a uint64_t value. Since our semaphores are
    // actually int64_t, we just reinterpret_cast them to act as uint64_t.
    uint64_t semaphore_increment = 1;

    // Assumes blocks are scheduled in increasing order. This is not guaranteed
    // by CUDA, but it is the current behavior and is unlikely to change.
    if (last_block) {
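      // The other (segment_size - 1) blocks each add 1, so adding
      // FIRST_UINT64_BIT - (segment_size - 1) here makes this round's total
      // contribution exactly FIRST_UINT64_BIT, flipping the semaphore's first
      // bit once all blocks have arrived.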
      semaphore_increment = FIRST_UINT64_BIT - (segment_size - 1);
    }

    uint64_t oldArrive =
        atomicAdd(reinterpret_cast<uint64_t*>(&semaphore), semaphore_increment);

    // For persistent kernels, all blocks spin here until the semaphore has
    // been reached; otherwise only the last block waits. Make sure we access
    // the semaphore as a volatile address so we see the global memory updates.
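    // Exponential backoff: the sleep below starts at 8 ns and doubles up to a
    // 256 ns cap.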
    unsigned int ns = 8;
    while ((PERSISTENT || last_block) &&
           ((oldArrive ^ globalAsVolatile(semaphore)) & FIRST_UINT64_BIT) ==
               0) {
      // Put a sleep here so we have some breaks in probing the global
      // semaphore, giving a better chance for other warps/blocks to catch up.
#if __CUDA_ARCH__ >= 700
      // __nanosleep only available on compute capability 7.0 or higher
      __nanosleep(ns); // avoids busy waiting
      if (ns < 256) {
        ns *= 2;
      }
#endif
    }
  }

  // Sync block to make sure all other threads are waiting on the sync
  block_sync::sync();
}

template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, bool PERSISTENT>
__device__ void sync(int64_t& semaphore, const uint64_t& segment_size) {
  sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT>(
      semaphore,
      segment_size,
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim));
}

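// Usage sketch (illustrative only; "semaphores" below is a hypothetical
// zero-initialized int64_t buffer, not part of this file): for a grid
// reduction over blockIdx.x and blockIdx.y, each blockIdx.z slice owns one
// semaphore, matching the "Z semaphores of size X*Y" note above:
//
//   grid_sync::sync<true, true, false, false>(
//       semaphores[blockIdx.z], gridDim.x * gridDim.y);
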
// Grid sync that can be called multiple times in the same kernel without all
// blocks being resident on the device. This allows a grid sync to be called
// multiple times as long as it's not broadcast on the parallel axis it was
// reduced on.
//
// n_entrances is how many times every block is expected to enter this
// function. All blocks must enter n_entrances times. The last block is only
// allowed to proceed once all other blocks have entered n_entrances times.
//
// Note that this is not currently used by grid and welford reduction, as they
// use a separate sync flag for each grid sync call.
template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK>
__device__ void sync(
    int64_t& semaphore,
    const uint64_t& segment_size,
    const nvfuser_index_t n_entrances) {
  // Finish all global memory transactions before synchronizing
  __threadfence();

  // Synchronize all threads in a block before synchronizing blocks
  block_sync::sync();

  // Only allow linear_tid == 0 to participate in the synchronization
  if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
    // Assumes blocks are scheduled in increasing order. This is not guaranteed
    // by CUDA, but it is the current behavior and is unlikely to change.
    bool last_block =
        index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
    if (last_block) {
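      // The last block may proceed only after every other block has entered
      // n_entrances times, i.e. after (number_of_blocks - 1) * n_entrances
      // total increments of the semaphore.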
      int64_t finished_val =
          ((int64_t)(
               index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim) -
               1)) *
          ((int64_t)n_entrances);

      unsigned int ns = 8;
      // Last block needs to wait for all other blocks to finish
      while (globalAsVolatile(semaphore) < finished_val) {
#if __CUDA_ARCH__ >= 700
        // __nanosleep only available on compute capability 7.0 or higher
        __nanosleep(ns); // avoids busy waiting
        if (ns < 256) {
          ns *= 2;
        }
#endif
      }
    } else {
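      // Non-last blocks simply record one more entrance and continue without
      // waiting.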
      atomicAdd(reinterpret_cast<uint64_t*>(&semaphore), 1);
    }
  }

  // Sync block to make sure all other threads are waiting on the sync
  block_sync::sync();
}

} // namespace grid_sync
)";

} // namespace nvfuser_resources