// Generated from "/code/pytorch/third_party/nvfuser/runtime/memory.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* memory_cu = R"(
// Utility macro for this file
#define DEVICE_INLINE __device__ inline

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))

namespace Turing {

namespace util {

// Utility for converting a generic pointer to a 32-bit SMEM address for PTX.
// The SMEM data-movement PTX instructions only move Global -> SMEM,
// SMEM -> Local, and Local -> SMEM, and they require the shared-memory
// operand to be given as a 32-bit address in the shared state space, which
// this helper produces. Vectorized loads/stores through shared memory should
// be reviewed with this constraint in mind.
DEVICE_INLINE unsigned toSmem(const void* raw_ptr) {
  unsigned smem_ptr_uint;
  // Declare a 64-bit register, convert the generic address to the shared
  // state space, then truncate it to a 32-bit shared-memory address.
  asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
      : "=r"(smem_ptr_uint)
      : "l"(raw_ptr));

  return smem_ptr_uint;
}

// LdMatrix has .x1, .x2 and .x4 options; currently we actively use .x2 and
// .x4. With the .x2 option, the address registers of the upper half warp
// (lanes 16-31) are unused, but on Turing ([sm75, sm80)) these unused
// addresses still need to be valid, in the sense that:
// 1. The data they point to has to be within the allocated shared memory
//    buffer.
// 2. The address needs to be 16-byte aligned.
// See also:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
// This function addresses 2. by masking out the sub-16B component of the
// address in the upper half warp; 1. is guaranteed by the ldmatrix swizzle
// utility.
// This does **not** affect functionality in any way; it only modifies unused
// addresses to satisfy the alignment requirement on Turing hardware.
// The alignment requirement is lifted on sm80+, so this function is a no-op
// on Ampere and above.
DEVICE_INLINE void adjustPartialLdMatrixAddrInTuring(unsigned& addr_in_byte) {
#if (__CUDA_ARCH__ < 800)
  const unsigned thread_id = threadIdx.x;
  // The upper half warp has an 8-byte offset from the aligned address in the
  // .x2 option of ldmatrix. There is currently no support for .x1, so always
  // assume the half-warp adjustment.
  constexpr unsigned half_warp = 16;
  // Need to adjust to 16-byte alignment: mask out the unaligned component.
  constexpr unsigned mask_out = 16 - 1;
  // Adjust only in the upper half warp.
  // Use bit math instead of a comparison for strength reduction.
  if (thread_id & half_warp) {
    // Clear the bits where mask_out has 1, i.e. the sub-16B offset.
    addr_in_byte &= (~mask_out);
  }
#endif //(__CUDA_ARCH__ < 800)
}
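
// Worked example (illustrative values only, not from the original source):
// with mask_out = 15 (0xF), a lane in the upper half warp holding the byte
// address 0x1038 is rewritten to 0x1038 & ~0xF = 0x1030, dropping the sub-16B
// offset, while lanes 0-15 keep their addresses unchanged.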

} // namespace util

// Load Matrix (a per-warp instruction) moves data from SMEM into local
// (register) memory, handling the vectorized loads needed to feed the MMA
// operations.
// The base form loads an 8x8 matrix into a warp: threads 0-7 provide the
// pointer to the start of each row, and all other threads just need to point
// at something valid (including 0).
// The .x2 modifier loads two 8x8 matrices to form a 16x8 tile, with threads
// 0-15 specifying the start of each row.
// Finally, the .x4 modifier loads a 32x8 tile, using the addresses provided
// by all 32 threads of the warp.
DEVICE_INLINE void ldMatrix(Array<__half, 4, 4>& out, void const* ptr) {
  uint2& val = reinterpret_cast<uint2&>(out);
  unsigned addr = util::toSmem(ptr);
  util::adjustPartialLdMatrixAddrInTuring(addr);
  asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0,%1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "r"(addr));
}
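
// Usage sketch (an assumption for illustration, not part of the upstream
// runtime): how a warp could load a 16x8 half tile from shared memory into
// registers with the x2 wrapper above. The function name and row-major tile
// layout are hypothetical.
DEVICE_INLINE void exampleLoad16x8(
    Array<__half, 4, 4>& frag,
    __half const* smem_tile) {
  // Lanes 0-15 each supply the start of one 8-element (16-byte) row. The
  // remaining lanes point back into the tile, which keeps their unused
  // addresses valid; the 16-byte alignment of those addresses is handled by
  // adjustPartialLdMatrixAddrInTuring inside ldMatrix.
  const unsigned row = threadIdx.x % 16;
  ldMatrix(frag, smem_tile + row * 8);
}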

// Same as above, but after the vectorized load the 8x8 matrix is scattered to
// perform a transpose, so each thread ends up holding 2 values down a column
// instead of 2 values across a row as in the previous instruction.
DEVICE_INLINE void ldMatrixT(Array<__half, 4, 4>& out, void const* ptr) {
  uint2& val = reinterpret_cast<uint2&>(out);
  unsigned addr = util::toSmem(ptr);
  util::adjustPartialLdMatrixAddrInTuring(addr);
  asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0,%1}, [%2];"
               : "=r"(val.x), "=r"(val.y)
               : "r"(addr));
}

DEVICE_INLINE void ldMatrix(Array<__half, 8, 8>& out, void const* ptr) {
  uint4& val = reinterpret_cast<uint4&>(out);
  unsigned addr = util::toSmem(ptr);
  asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];"
               : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
               : "r"(addr));
}

DEVICE_INLINE void ldMatrixT(Array<__half, 8, 8>& out, void const* ptr) {
  uint4& val = reinterpret_cast<uint4&>(out);
  unsigned addr = util::toSmem(ptr);
  asm volatile(
      "ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];"
      : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
      : "r"(addr));
}

} // namespace Turing

#endif // Arch 75

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

namespace Ampere {

// MMA instruction wrappers (sm_80+):

namespace util {

// Utility for cp_async: converts a generic pointer to a 32-bit SMEM address.
DEVICE_INLINE unsigned toSmem(void* ptr) {
  unsigned smem_ptr_uint;

  // Declare a 64-bit register smem_ptr,
  // convert the input to a shared-memory pointer,
  // then convert it to an unsigned 32-bit pointer.
  asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
      : "=r"(smem_ptr_uint)
      : "l"(ptr));

  return smem_ptr_uint;
}

} // namespace util

// Asynchronous copy from global memory to SMEM; the copy is not guaranteed
// to have completed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsync(
    Array<dtype, len, len>* smem_ptr,
    void const* gmem_ptr) {
  unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
  constexpr int byte_size = sizeof(dtype) * len;

  static_assert(
      byte_size == 4 || byte_size == 8 || byte_size == 16,
      "cp_async : unsupported byte size");

  asm volatile(
      "cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(smem_addr),
      "l"(gmem_ptr),
      "n"(byte_size));
}

// Asynchronous copy from global memory to SMEM, with a runtime predicate:
// the copy is only issued when `predicate` is true. As above, completion is
// not guaranteed until cpAsyncBarrier() is called.
template <typename dtype, int len>
DEVICE_INLINE void cpAsync(
    Array<dtype, len, len>* smem_ptr,
    void const* gmem_ptr,
    bool predicate) {
  unsigned smem_addr = util::toSmem(&(smem_ptr->array[0]));
  constexpr int byte_size = sizeof(dtype) * len;

  static_assert(
      byte_size == 4 || byte_size == 8 || byte_size == 16,
      "cp_async : unsupported byte size");

  asm volatile(
      "{\n"
      "  .reg .pred p;\n"
      "  setp.ne.b32 p, %3, 0;\n"
      "@p cp.async.ca.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem_addr),
      "l"(gmem_ptr),
      "n"(byte_size),
      "r"((int)predicate));
}
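
// Usage sketch (an assumption for illustration, not part of the upstream
// runtime): the predicated overload is typically used for boundary handling,
// issuing the copy only for in-bounds elements. The names and the bounds
// check here are hypothetical.
template <typename dtype, int len>
DEVICE_INLINE void exampleGuardedCopy(
    Array<dtype, len, len>* smem_buf,
    dtype const* gmem_src,
    int index,
    int extent) {
  // When the predicate is false, the PTX predicate disables the cp.async
  // instruction, so the out-of-bounds address is never dereferenced.
  cpAsync(smem_buf, gmem_src, index < extent);
}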

// TODO: we might want a different category of sync if we build this out
// further.
// Wait until all issued cp.async operations have completed.
DEVICE_INLINE void cpAsyncBarrier() {
  asm volatile("cp.async.wait_all;");
}

// Close the current group of cp.async operations.
DEVICE_INLINE void cpAsyncCommit() {
  asm volatile("cp.async.commit_group;");
}

// Wait until at most `keep_stages` committed cp.async groups are still
// pending.
template <int keep_stages>
DEVICE_INLINE void cpAsyncPartialBarrier() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(keep_stages));
}
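
// Usage sketch (an assumption for illustration, not part of the upstream
// runtime): the intended pairing of the wrappers above for a software
// pipeline is to issue a batch of async copies, commit them as a group, and
// later wait until at most `keep_stages` groups are still in flight before
// reading the data. The function name, tile type, and indexing are
// hypothetical.
template <int keep_stages>
DEVICE_INLINE void examplePipelinedStage(
    Array<float, 4, 4>* smem_stage,
    float const* gmem_src) {
  // Each thread issues one 16-byte asynchronous copy for this stage.
  cpAsync(smem_stage + threadIdx.x, gmem_src + threadIdx.x * 4);
  // Close the current group of async copies.
  cpAsyncCommit();
  // Allow up to keep_stages committed groups to remain outstanding; earlier
  // stages are complete after this point.
  cpAsyncPartialBarrier<keep_stages>();
  // Make the completed stages visible to the whole thread block.
  __syncthreads();
}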

} // namespace Ampere

#endif // Arch 80

#undef DEVICE_INLINE
)";

} // namespace nvfuser_resources