// Generated from "/code/pytorch/third_party/nvfuser/runtime/array.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* array_cu = R"(
// aligned register array for vectorized load/store
template <typename scalar_t, int size, int align_size>
struct alignas(sizeof(scalar_t) * align_size) Array {
  scalar_t array[size];

  __device__ void set(scalar_t v) {
#pragma unroll
    for (int i = 0; i < size; ++i) {
      array[i] = v;
    }
  }

  __device__ scalar_t& operator[](const unsigned int i) {
    return array[i];
  }
};
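
// NOTE: illustrative sketch added editorially; not part of the original
// array.cu. exampleUseArray is a hypothetical helper showing the intended
// usage pattern: Array is a register-resident buffer whose alignment matches
// the vector width, so the helpers below can move it with a single vectorized
// access.
template <typename scalar_t>
__device__ void exampleUseArray(scalar_t v) {
  // alignas(sizeof(scalar_t) * 4), e.g. 16 bytes when scalar_t is float
  Array<scalar_t, 4, 4> buf;
  buf.set(v); // fill every element
  buf[0] = v; // element access through operator[]
}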

// Used for vectorized allocations that are not in registers
template <typename scalar_t, int vec_size>
__device__ void arraySet(scalar_t* buff, scalar_t val) {
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    buff[i] = val;
  }
}

template <typename scalar_t, int vec_size>
__device__ void loadGeneric(scalar_t* to, scalar_t* from) {
  // It would be really nice to use memcpy here, but one example was failing
  // with:
  //
  //  memcpy(to, from, vec_size * sizeof(scalar_t));
  //
  // Yet passing with:
  //
  //  for(int i = 0; i < vec_size; i++){
  //    to[i] = from[i];
  //  }

  switch (sizeof(scalar_t) * vec_size) {
    case 1:
      *reinterpret_cast<uchar1*>(to) = *reinterpret_cast<uchar1*>(from);
      break;
    case 2:
      *reinterpret_cast<uchar2*>(to) = *reinterpret_cast<uchar2*>(from);
      break;
    case 4:
      *reinterpret_cast<uint1*>(to) = *reinterpret_cast<uint1*>(from);
      break;
    case 8:
      *reinterpret_cast<uint2*>(to) = *reinterpret_cast<uint2*>(from);
      break;
    case 12:
      *reinterpret_cast<uint3*>(to) = *reinterpret_cast<uint3*>(from);
      break;
    case 16:
      *reinterpret_cast<uint4*>(to) = *reinterpret_cast<uint4*>(from);
      break;
  }
}
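
// NOTE: illustrative sketch added editorially; not part of the original
// array.cu. exampleVectorizedCopy is a hypothetical helper: it stages
// vec_size elements through an aligned register Array so that loadGeneric can
// turn each copy into a single vector access for the widths handled above
// (1/2/4/8/12/16 bytes) instead of an element-wise loop.
template <typename scalar_t, int vec_size>
__device__ void exampleVectorizedCopy(scalar_t* dst, scalar_t* src) {
  Array<scalar_t, vec_size, vec_size> tmp;
  loadGeneric<scalar_t, vec_size>(tmp.array, src); // one vector load
  loadGeneric<scalar_t, vec_size>(dst, tmp.array); // one vector store
}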

// The volatile version only works with C++ fundamental types
template <
    typename scalar_t,
    int vec_size,
    bool is_volatile_to,
    bool is_volatile_from>
__device__ void loadGenericVolatile(
    typename MaybeVolatile<scalar_t, is_volatile_to>::type* to,
    typename MaybeVolatile<scalar_t, is_volatile_from>::type* from) {
  switch (sizeof(scalar_t) * vec_size) {
    // Reinterpret casts like this with volatile types only work for C++
    // fundamental types; otherwise the = operator is not defined.
    case 1:
      *reinterpret_cast<
          typename MaybeVolatile<unsigned char, is_volatile_to>::type*>(to) =
          *reinterpret_cast<
              typename MaybeVolatile<unsigned char, is_volatile_from>::type*>(
              from);
      break;
    case 2:
      *reinterpret_cast<typename MaybeVolatile<short, is_volatile_to>::type*>(
          to) =
          *reinterpret_cast<
              typename MaybeVolatile<short, is_volatile_from>::type*>(from);
      break;
    case 4:
      *reinterpret_cast<
          typename MaybeVolatile<unsigned int, is_volatile_to>::type*>(to) =
          *reinterpret_cast<
              typename MaybeVolatile<unsigned int, is_volatile_from>::type*>(
              from);
      break;
    case 8:
      *reinterpret_cast<typename MaybeVolatile<double, is_volatile_to>::type*>(
          to) =
          *reinterpret_cast<
              typename MaybeVolatile<double, is_volatile_from>::type*>(from);
      break;
  }
}
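
// NOTE: illustrative sketch added editorially; not part of the original
// array.cu. exampleVolatileStore is a hypothetical helper: it publishes
// vec_size elements to a volatile global buffer (e.g. an inter-block
// communication slot) through loadGenericVolatile, which reinterprets the data
// as a fundamental type of the same total width so the volatile assignment
// stays a single access. This assumes MaybeVolatile<T, true>::type is
// volatile T (MaybeVolatile is defined in a separate nvfuser runtime file).
template <typename scalar_t, int vec_size>
__device__ void exampleVolatileStore(
    volatile scalar_t* global_slot,
    scalar_t* local_src) {
  // is_volatile_to = true, is_volatile_from = false
  loadGenericVolatile<scalar_t, vec_size, true, false>(global_slot, local_src);
}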

template <typename scalar_t, int vec_size, bool is_volatile>
__device__ void loadLocalToGlobal(
    typename MaybeVolatile<scalar_t, is_volatile>::type* to,
    scalar_t* from) {
  switch (sizeof(scalar_t) * vec_size) {
    case 1:
    case 2:
    case 4:
      loadGenericVolatile<scalar_t, vec_size, is_volatile, false>(to, from);
      break;
    case 8: {
      uint2 const& data = *reinterpret_cast<uint2*>(from);
      if (is_volatile) {
        asm volatile(
            "st.volatile.global.v2.s32 [%0], {%1,%2};" ::"l"(
                (typename MaybeVolatile<uint2, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y));
      } else {
        asm volatile(
            "st.global.cs.v2.s32 [%0], {%1,%2};" ::"l"(
                (typename MaybeVolatile<uint2, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y));
      }
      break;
    }
    case 16: {
      uint4 const& data = *reinterpret_cast<uint4*>(from);
      if (is_volatile) {
        asm volatile(
            "st.volatile.global.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"(
                (typename MaybeVolatile<uint4, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y),
            "r"(data.z),
            "r"(data.w));
      } else {
        asm volatile(
            "st.global.cs.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"(
                (typename MaybeVolatile<uint4, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y),
            "r"(data.z),
            "r"(data.w));
      }
      break;
    }
  }
}
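
// NOTE: illustrative sketch added editorially; not part of the original
// array.cu. exampleStoreResult is a hypothetical helper showing the intended
// direction of loadLocalToGlobal: registers -> global memory. For 8- and
// 16-byte widths the function above uses inline PTX: st.volatile.global when
// the destination must not be cached (e.g. a synchronization flag), and the
// st.global.cs cache-streaming hint otherwise, since kernel outputs are
// typically not re-read by the same kernel.
template <typename scalar_t, int vec_size>
__device__ void exampleStoreResult(scalar_t* global_out, scalar_t* regs) {
  // Non-volatile path: plain global pointer, streaming store hint.
  loadLocalToGlobal<scalar_t, vec_size, false>(global_out, regs);
}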

template <typename scalar_t, int vec_size, bool is_volatile>
__device__ void loadGlobalToLocal(
    scalar_t* to,
    typename MaybeVolatile<scalar_t, is_volatile>::type* from) {
  switch (sizeof(scalar_t) * vec_size) {
    case 1:
    case 2:
    case 4:
      loadGenericVolatile<scalar_t, vec_size, false, is_volatile>(to, from);
      break;
    case 8: {
      uint2& data = *reinterpret_cast<uint2*>(to);
      if (is_volatile) {
        asm volatile("ld.volatile.global.v2.s32 {%0,%1}, [%2];"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"((uint2*)from));
      } else {
        asm volatile("ld.global.cs.v2.s32 {%0,%1}, [%2];"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"((uint2*)from));
      }
      break;
    }
    case 16: {
      uint4& data = *reinterpret_cast<uint4*>(to);
      if (is_volatile) {
        asm volatile("ld.volatile.global.v4.s32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"((uint4*)from));
      } else {
        asm volatile("ld.global.cs.v4.s32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"((uint4*)from));
      }
      break;
    }
  }
}
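
// NOTE: illustrative sketch added editorially; not part of the original
// array.cu. exampleLoadInput is a hypothetical helper showing the typical
// pairing: loadGlobalToLocal pulls vec_size elements from global memory into
// an aligned register Array with one vectorized (optionally volatile or
// cache-streaming) load, after which the values can be consumed element-wise
// from registers.
template <typename scalar_t, int vec_size>
__device__ scalar_t exampleLoadInput(scalar_t* global_in) {
  Array<scalar_t, vec_size, vec_size> regs;
  loadGlobalToLocal<scalar_t, vec_size, false>(regs.array, global_in);
  return regs[0];
}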

template <
    typename scalar_t,
    int vec_size,
    bool is_volatile_to,
    bool is_volatile_from>
__device__ void loadGlobalToGlobal(
    typename MaybeVolatile<scalar_t, is_volatile_to>::type* to,
    typename MaybeVolatile<scalar_t, is_volatile_from>::type* from) {
  switch (sizeof(scalar_t) * vec_size) {
    // Reinterpret casts like this with volatile types only work for C++
    // fundamental types; otherwise the = operator is not defined.
    case 1:
    case 2:
    case 4:
    case 8:
      loadGenericVolatile<scalar_t, vec_size, is_volatile_to, is_volatile_from>(
          to, from);
      break;
    case 12: {
      uint3 local_intermediate;
      loadGlobalToLocal<scalar_t, vec_size, is_volatile_from>(
          reinterpret_cast<scalar_t*>(&local_intermediate), from);
      loadLocalToGlobal<scalar_t, vec_size, is_volatile_to>(
          to, reinterpret_cast<scalar_t*>(&local_intermediate));
      break;
    }
    case 16: {
      uint4 local_intermediate;
      loadGlobalToLocal<scalar_t, vec_size, is_volatile_from>(
          reinterpret_cast<scalar_t*>(&local_intermediate), from);
      loadLocalToGlobal<scalar_t, vec_size, is_volatile_to>(
          to, reinterpret_cast<scalar_t*>(&local_intermediate));
      break;
    }
  }
}
)";

} // namespace nvfuser_resources