1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/block_sync_atomic.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
namespace nvfuser_resources {

// CUDA source (as a raw string, compiled at runtime via NVRTC) implementing a
// counter-based emulation of __syncthreads() for debugging/validating
// synchronization. The embedded code defines, in namespace `block_sync`:
//
//   init() — thread (0,0,0) zeroes the block-shared `sync_counter`, then all
//            threads barrier on __syncthreads() so the zero is visible.
//   sync() — each thread atomicInc's `sync_counter` and spin-waits (with
//            exponential backoff: 8 ns doubling up to 256 ns, using
//            __nanosleep on SM70+ only) until the counter reaches the next
//            multiple of blockDim.x*blockDim.y*blockDim.z.
//
// Wrap-around handling: atomicInc is capped at `counter_max - 1`, where
// counter_max is the largest multiple of num_threads representable in
// CounterType, so the counter always wraps on a sync-round boundary. If the
// polled value becomes <= `old` (i.e. wrapped), the `old < local_sync_counter`
// half of the loop condition fails and the thread exits the wait, since a wrap
// implies every thread has already incremented past this round.
//
// NOTE: the counter is re-read through a `volatile CounterType*` cast to force
// the load from shared memory on each poll; __threadfence_block() before the
// increment orders this thread's prior shared/global writes for the block.
//
// Do NOT edit the string contents here — this file is generated from
// third_party/nvfuser/runtime/block_sync_atomic.cu (see header stamp above).
constexpr const char* block_sync_atomic_cu = R"(

// Counter-based block synchronization. Only meant to be used for
// debugging and validating synchronization. This should be replaced
// with cuda::barrier::arrive_and_wait as that should be more robust.

namespace block_sync {

using CounterType = unsigned int;
static constexpr CounterType COUNTER_TYPE_MAX = ~(CounterType)0;
__shared__ CounterType sync_counter;

__device__ void init() {
  const unsigned int tid = threadIdx.x + threadIdx.y * blockDim.x +
      threadIdx.z * blockDim.x * blockDim.y;
  if (tid == 0) {
    sync_counter = 0;
  }
  __syncthreads();
}

// Emulate __syncthreads() with a synchronization counter
__device__ void sync() {
  unsigned int backoff = 8;
  const unsigned int backoff_max = 256;
  const unsigned int num_threads = blockDim.x * blockDim.y * blockDim.z;

  __threadfence_block();

  // Use counter range only up to a limit so that the next val won't
  // overflow.

  const auto counter_max = (COUNTER_TYPE_MAX / num_threads) * num_threads;
  const auto old = atomicInc(&sync_counter, counter_max - 1);

  const auto next = (old / num_threads) * num_threads + num_threads;

  auto local_sync_counter = *(volatile CounterType*)(&sync_counter);

  // sync_counter may wrap around, which means local_sync_counter
  // becomes smaller than old. In that case, it's guaranteed that all
  // threads have incremented the counter.
  while (local_sync_counter < next && old < local_sync_counter) {
#if __CUDA_ARCH__ >= 700
    // __nanosleep only available on compute capability 7.0 or higher
    __nanosleep(backoff); // avoids busy waiting
#endif
    if (backoff < backoff_max) {
      backoff *= 2;
    }
    local_sync_counter = *(volatile CounterType*)(&sync_counter);
  }
}

} // namespace block_sync
)" ;

} // namespace nvfuser_resources
64 | |