// Generated from "/code/pytorch/third_party/nvfuser/runtime/memory.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* memory_cu = R"(
// Utility macro for this file
#define DEVICE_INLINE __device__ inline

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))

namespace Turing {

namespace util {

// Utility for converting a generic pointer to an SMEM pointer for use in PTX.
// We should review vectorized load/stores with shared memory.
// The PTX memory-movement instructions involving SMEM only cover
// Global -> SMEM, SMEM -> Local, and Local -> SMEM, and this conversion is
// needed to provide the SMEM pointer to those instructions.
DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
  unsigned smem_ptr_uint;
  asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
      : "=r"(smem_ptr_uint)
      : "l"(raw_ptr));

  return smem_ptr_uint;
}
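
// Minimal usage sketch (illustrative only; storeSharedF32Example is a
// hypothetical helper, not part of the nvfuser runtime API): toSmem produces
// the 32-bit shared-space address expected by PTX instructions that take a
// shared-memory operand, e.g. an explicit st.shared store.
DEVICE_INLINE void storeSharedF32Example(float* smem_dst, float value) {
  // Convert the generic pointer to a 32-bit shared-memory address.
  unsigned addr = toSmem(smem_dst);
  // Store one float to shared memory through PTX using that address.
  asm volatile("st.shared.f32 [%0], %1;" ::"r"(addr), "f"(value));
}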

// LdMatrix has .x1, .x2 and .x4 options; currently we actively use .x2 and
// .x4. In the .x2 option, the address registers of the upper half warp
// (lanes 16-31) are unused, but on the Turing [sm75, sm80) architecture these
// unused addresses still need to be valid, in the sense that:
// 1. The data they point to has to be within the allocated shared mem buffer.
// 2. The address needs to be aligned to 16 bytes.
// See also:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
// This function addresses 2. above by masking out the sub-16B component
// of the address in the upper half warp; 1. is guaranteed by the ldmatrix
// swizzle util.
// This will **not** affect any functionality. It only modifies unused
// pointers to satisfy the alignment requirement on Turing hardware.
// The alignment requirement is lifted on sm80+,
// so this function is a no-op on Ampere or above.
DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) {
#if (__CUDA_ARCH__ < 800)
  const unsigned thread_id = threadIdx.x;
  // In the .x2 option of ldmatrix, the upper half warp's addresses are
  // offset 8 bytes from a 16-byte boundary. There is currently no support
  // for .x1, so always assume the adjustment applies to the upper half warp.
  constexpr unsigned half_warp = 16;
  // Need to adjust to 16 byte alignment, so mask out the un-aligned component.
  constexpr unsigned mask_out = 16 - 1;
  // Adjust only in the upper half warp;
  // use bit math instead of a comparison (strength reduction).
  if (thread_id & half_warp) {
    // Mask out the bits where mask_out has 1.
    addr_in_byte &= (~mask_out);
  }
#endif //(__CUDA_ARCH__ < 800)
}
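
// Worked example (illustrative, assumed values): with mask_out = 16 - 1, an
// upper-half-warp address sitting 8 bytes past a 16-byte boundary, such as
// 0x128, is masked down to the 16-byte-aligned address 0x120.
static_assert(
    (0x128u & ~(16u - 1u)) == 0x120u,
    "example of the ldmatrix address masking above");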

} // namespace util

// Load Matrix (per warp instruction) takes data from SMEM to Local Memory.
// It automatically handles the vectorized loads/stores used in the MMA
// operation.
// Loads an 8x8 matrix into a warp. Threads 0-7 provide the pointer to the
// start of each row. All other threads can simply point to something valid
// (including 0).
// The .x2 modifier on the instruction actually loads two sets of 8 rows to
// make a 16x8, with threads 0-15 specifying the start of each row.
// Finally, the .x4 modifier produces a 32x8, using the addresses from
// threads 0-31 in each warp.
DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
  uint2& val = reinterpret_cast<uint2&>(out);
  unsigned addr = util::toSmem(ptr);
  util::adjustPartialLdMatrixAddrInTuring(addr);
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "r"(addr));
}

// Same as the previous, except that the 8x8 matrix is vector-loaded and then
// scattered (to perform a transpose), so threads will hold 2 values down a
// column (instead of across a row as in the previous instruction).
DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) {
  uint2& val = reinterpret_cast<uint2&>(out);
  unsigned addr = util::toSmem(ptr);
  util::adjustPartialLdMatrixAddrInTuring(addr);
  asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "r"(addr));
}

DEVICE_INLINE void ldMatrix(Array<__half, 8, 8>& out, void const* ptr) {
  uint4& val = reinterpret_cast<uint4&>(out);
  unsigned addr = util::toSmem(ptr);
  asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];"
               : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
               : "r"(addr));
}

DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) {
  uint4& val = reinterpret_cast<uint4&>(out);
  unsigned addr = util::toSmem(ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];"
      : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
      : "r"(addr));
}
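
// Minimal usage sketch (illustrative only; ldMatrixX2Example and its row
// addressing are assumptions for illustration, not the swizzled layout
// nvfuser actually generates): each thread passes a pointer into the shared
// memory tile, and the .x2 form consumes the row-start addresses provided by
// threads 0-15.
DEVICE_INLINE void ldMatrixX2Example(
    Array<__half, 4, 4>& frag,
    const __half* smem_tile) {
  // Threads 0-15 point at the start of rows 0-15 of a 16x8 half tile
  // (16 bytes per row, so the addresses stay 16-byte aligned assuming
  // smem_tile itself is 16-byte aligned); threads 16-31 reuse a valid
  // address, which ldMatrix re-aligns on Turing via
  // adjustPartialLdMatrixAddrInTuring.
  const unsigned row = threadIdx.x & 15;
  ldMatrix(frag, smem_tile + row * 8);
}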

} // namespace Turing

#endif // Arch 75

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

namespace Ampere {

// MMA instruction wrappers (sm_80+):

namespace util {

// Special utility for cp_async
DEVICE_INLINE unsigned toSmem(void* ptr) {
  unsigned smem_ptr_uint;

  // Declare 64 bit register smem_ptr
  // Convert the input to a shared memory pointer
  // Convert to unsigned 32 bit pointer
  asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
      : "=r"(smem_ptr_uint)
      : "l"(ptr));

  return smem_ptr_uint;
}

} // namespace util

// Global to SMEM load that is asynchronous,
// not guaranteed to be completed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsync(
    Array<dtype, len, len>* smem_ptr,
    void const* gmem_ptr) {
  unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
  constexpr int byte_size = sizeof(dtype) * len;

  static_assert(
      byte_size == 4 || byte_size == 8 || byte_size == 16,
      "cp_async : unsupported byte size");

  asm volatile(
      "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr),
      "l"(gmem_ptr),
      "n"(byte_size));
}

// Global to SMEM load that is asynchronous,
// not guaranteed to be completed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsync(
    Array<dtype, len, len>* smem_ptr,
    void const* gmem_ptr,
    bool predicate) {
  unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
  constexpr int byte_size = sizeof(dtype) * len;

  static_assert(
      byte_size == 4 || byte_size == 8 || byte_size == 16,
      "cp_async : unsupported byte size");

  asm volatile(
      "{\n"
      " .reg .pred p;\n"
      " setp.ne.b32 p, %3, 0;\n"
      "@p cp.async.ca.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(gmem_ptr),
      "n"(byte_size),
      "r"((int)predicate));
}

// TODO: Might need a different category of sync if we want to build this out:
DEVICE_INLINE void cpAsyncBarrier() {
  asm volatile("cp.async.wait_all;");
}

DEVICE_INLINE void cpAsyncCommit() {
  asm volatile("cp.async.commit_group;");
}

template <int keep_stages>
DEVICE_INLINE void cpAsyncPartialBarrier() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(keep_stages));
}
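
// Minimal pipeline sketch (illustrative only; cpAsyncCopyExample is a
// hypothetical helper, not part of the nvfuser runtime API): the usual
// cp.async sequence is issue -> commit -> wait -> block sync before the
// shared-memory data is read.
template <typename dtype, int len>
DEVICE_INLINE void cpAsyncCopyExample(
    Array<dtype, len, len>* smem_buf,
    void const* gmem_src) {
  // Issue the asynchronous global -> shared copy for this thread.
  cpAsync(smem_buf, gmem_src);
  // Close the current cp.async group.
  cpAsyncCommit();
  // Wait until no committed groups remain in flight ...
  cpAsyncPartialBarrier<0>();
  // ... then make the loaded data visible to the whole thread block.
  __syncthreads();
}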

} // namespace Ampere

#endif // Arch 80

#undef DEVICE_INLINE
)";

} // namespace nvfuser_resources