// Generated from "/code/pytorch/third_party/nvfuser/runtime/array.cu"
// 2023-02-12 08:01:26

namespace nvfuser_resources {

constexpr const char* array_cu = R"(
// Aligned register array for vectorized load/store
template <typename scalar_t, int size, int align_size>
struct alignas(sizeof(scalar_t) * align_size) Array {
  scalar_t array[size];

  __device__ void set(scalar_t v) {
#pragma unroll
    for (int i = 0; i < size; ++i) {
      array[i] = v;
    }
  }

  __device__ scalar_t& operator[](const unsigned int i) {
    return array[i];
  }
};
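
// Usage sketch (illustrative, not part of the original resource; the name
// below is hypothetical). Array<float, 4, 4> is 16-byte aligned, so the
// compiler can keep it in registers and service it with wide moves.
__device__ float exampleFillArray() {
  Array<float, 4, 4> vals; // four floats, aligned to 16 bytes
  vals.set(0.0f); // broadcast-initialize every element
  vals[2] = 1.0f; // per-element access through operator[]
  return vals[2];
}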

// Used for vectorized allocations that are not in registers
template <typename scalar_t, int vec_size>
__device__ void arraySet(scalar_t* buff, scalar_t val) {
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    buff[i] = val;
  }
}
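
// Usage sketch (illustrative, hypothetical name): clearing a vectorized
// shared-memory slot, where Array (a register structure) does not apply.
__device__ void exampleClearSharedSlot(float* smem_slot) {
  arraySet<float, 4>(smem_slot, 0.0f); // unrolled scalar stores
}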

template <typename scalar_t, int vec_size>
__device__ void loadGeneric(scalar_t* to, scalar_t* from) {
  // It would be really nice to use memcpy here, but one example was failing
  // with:
  //
  // memcpy(to, from, vec_size * sizeof(scalar_t));
  //
  // yet passing with:
  //
  // for (int i = 0; i < vec_size; i++) {
  //   to[i] = from[i];
  // }

  switch (sizeof(scalar_t) * vec_size) {
    case 1:
      *reinterpret_cast<uchar1*>(to) = *reinterpret_cast<uchar1*>(from);
      break;
    case 2:
      *reinterpret_cast<uchar2*>(to) = *reinterpret_cast<uchar2*>(from);
      break;
    case 4:
      *reinterpret_cast<uint1*>(to) = *reinterpret_cast<uint1*>(from);
      break;
    case 8:
      *reinterpret_cast<uint2*>(to) = *reinterpret_cast<uint2*>(from);
      break;
    case 12:
      *reinterpret_cast<uint3*>(to) = *reinterpret_cast<uint3*>(from);
      break;
    case 16:
      *reinterpret_cast<uint4*>(to) = *reinterpret_cast<uint4*>(from);
      break;
  }
}
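
// Usage sketch (illustrative, hypothetical name): copying four floats
// (16 bytes) takes the uint4 branch above, so the whole transfer becomes a
// single 128-bit load/store pair. Both pointers are assumed to be 16-byte
// aligned, as reinterpreting to uint4 requires.
__device__ void exampleVectorCopy(float* dst, float* src) {
  loadGeneric<float, 4>(dst, src);
}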

// The volatile version only works with C++ fundamental types; otherwise
// operator= is not defined
template <
    typename scalar_t,
    int vec_size,
    bool is_volatile_to,
    bool is_volatile_from>
__device__ void loadGenericVolatile(
    typename MaybeVolatile<scalar_t, is_volatile_to>::type* to,
    typename MaybeVolatile<scalar_t, is_volatile_from>::type* from) {
  switch (sizeof(scalar_t) * vec_size) {
    // A reinterpret_cast like this with volatile types only works for C++
    // fundamental types; otherwise the = operator is not defined
    case 1:
      *reinterpret_cast<
          typename MaybeVolatile<unsigned char, is_volatile_to>::type*>(to) =
          *reinterpret_cast<
              typename MaybeVolatile<unsigned char, is_volatile_from>::type*>(
              from);
      break;
    case 2:
      *reinterpret_cast<typename MaybeVolatile<short, is_volatile_to>::type*>(
          to) =
          *reinterpret_cast<
              typename MaybeVolatile<short, is_volatile_from>::type*>(from);
      break;
    case 4:
      *reinterpret_cast<
          typename MaybeVolatile<unsigned int, is_volatile_to>::type*>(to) =
          *reinterpret_cast<
              typename MaybeVolatile<unsigned int, is_volatile_from>::type*>(
              from);
      break;
    case 8:
      *reinterpret_cast<typename MaybeVolatile<double, is_volatile_to>::type*>(
          to) =
          *reinterpret_cast<
              typename MaybeVolatile<double, is_volatile_from>::type*>(from);
      break;
  }
}
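
// Usage sketch (illustrative, hypothetical name): an 8-byte volatile store
// to a flag in global memory that other blocks poll, with a plain
// (non-volatile) register source.
__device__ void exampleVolatileStore(volatile double* flag, double* src) {
  loadGenericVolatile<double, 1, true, false>(flag, src);
}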

template <typename scalar_t, int vec_size, bool is_volatile>
__device__ void loadLocalToGlobal(
    typename MaybeVolatile<scalar_t, is_volatile>::type* to,
    scalar_t* from) {
  switch (sizeof(scalar_t) * vec_size) {
    case 1:
    case 2:
    case 4:
      loadGenericVolatile<scalar_t, vec_size, is_volatile, false>(to, from);
      break;
    case 8: {
      uint2 const& data = *reinterpret_cast<uint2*>(from);
      if (is_volatile) {
        asm volatile(
            "st.volatile.global.v2.s32 [%0], {%1,%2};" ::"l"(
                (typename MaybeVolatile<uint2, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y));
      } else {
        asm volatile(
            "st.global.cs.v2.s32 [%0], {%1,%2};" ::"l"(
                (typename MaybeVolatile<uint2, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y));
      }
      break;
    }
    case 16: {
      uint4 const& data = *reinterpret_cast<uint4*>(from);
      if (is_volatile) {
        asm volatile(
            "st.volatile.global.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"(
                (typename MaybeVolatile<uint4, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y),
            "r"(data.z),
            "r"(data.w));
      } else {
        asm volatile(
            "st.global.cs.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"(
                (typename MaybeVolatile<uint4, is_volatile>::type*)to),
            "r"(data.x),
            "r"(data.y),
            "r"(data.z),
            "r"(data.w));
      }
      break;
    }
  }
}
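
// Usage sketch (illustrative, hypothetical name): flushing a 16-byte
// register fragment to global memory. With is_volatile = false this takes
// the st.global.cs path, whose cache-streaming hint marks the line as
// unlikely to be reused.
__device__ void exampleStoreVec4(float* gmem, Array<float, 4, 4>& regs) {
  loadLocalToGlobal<float, 4, false>(gmem, regs.array);
}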

template <typename scalar_t, int vec_size, bool is_volatile>
__device__ void loadGlobalToLocal(
    scalar_t* to,
    typename MaybeVolatile<scalar_t, is_volatile>::type* from) {
  switch (sizeof(scalar_t) * vec_size) {
    case 1:
    case 2:
    case 4:
      loadGenericVolatile<scalar_t, vec_size, false, is_volatile>(to, from);
      break;
    case 8: {
      uint2& data = *reinterpret_cast<uint2*>(to);
      if (is_volatile) {
        asm volatile("ld.volatile.global.v2.s32 {%0,%1}, [%2];"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"((uint2*)from));
      } else {
        asm volatile("ld.global.cs.v2.s32 {%0,%1}, [%2];"
                     : "=r"(data.x), "=r"(data.y)
                     : "l"((uint2*)from));
      }
      break;
    }
    case 16: {
      uint4& data = *reinterpret_cast<uint4*>(to);
      if (is_volatile) {
        asm volatile("ld.volatile.global.v4.s32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"((uint4*)from));
      } else {
        asm volatile("ld.global.cs.v4.s32 {%0,%1,%2,%3}, [%4];"
                     : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w)
                     : "l"((uint4*)from));
      }
      break;
    }
  }
}
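
// Usage sketch (illustrative, hypothetical name): the matching 16-byte
// gather, filling a register fragment from global memory with a single
// ld.global.cs (cache-streaming) instruction.
__device__ void exampleLoadVec4(Array<float, 4, 4>& regs, float* gmem) {
  loadGlobalToLocal<float, 4, false>(regs.array, gmem);
}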

template <
    typename scalar_t,
    int vec_size,
    bool is_volatile_to,
    bool is_volatile_from>
__device__ void loadGlobalToGlobal(
    typename MaybeVolatile<scalar_t, is_volatile_to>::type* to,
    typename MaybeVolatile<scalar_t, is_volatile_from>::type* from) {
  switch (sizeof(scalar_t) * vec_size) {
    // A reinterpret_cast like this with volatile types only works for C++
    // fundamental types; otherwise the = operator is not defined
    case 1:
    case 2:
    case 4:
    case 8:
      loadGenericVolatile<scalar_t, vec_size, is_volatile_to, is_volatile_from>(
          to, from);
      break;
    case 12: {
      uint3 local_intermediate;
      loadGlobalToLocal<scalar_t, vec_size, is_volatile_from>(
          reinterpret_cast<scalar_t*>(&local_intermediate), from);
      loadLocalToGlobal<scalar_t, vec_size, is_volatile_to>(
          to, reinterpret_cast<scalar_t*>(&local_intermediate));
      break;
    }
    case 16: {
      uint4 local_intermediate;
      loadGlobalToLocal<scalar_t, vec_size, is_volatile_from>(
          reinterpret_cast<scalar_t*>(&local_intermediate), from);
      loadLocalToGlobal<scalar_t, vec_size, is_volatile_to>(
          to, reinterpret_cast<scalar_t*>(&local_intermediate));
      break;
    }
  }
}
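
// Usage sketch (illustrative, hypothetical name): a 16-byte global-to-global
// copy staged through a register intermediate, here with neither side
// volatile.
__device__ void exampleGlobalCopy(float* dst, float* src) {
  loadGlobalToGlobal<float, 4, false, false>(dst, src);
}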
238 | )" ; |
239 | |
240 | } // namespace nvfuser_resources |
241 | |