1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file ptx.cc |
22 | */ |
23 | |
#include "ptx.h"

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
31 | |
32 | namespace tvm { |
33 | namespace codegen { |
34 | |
35 | // PTX related data structures and functions. |
36 | namespace ptx { |
37 | |
38 | /*! |
39 | * \brief PTX data type. |
40 | * \note |
41 | * PTX fundamental data types: |
42 | * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types |
43 | * PTX matrix data types: |
44 | * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types |
45 | */ |
46 | enum class DataType : int { |
47 | kInt4 = 0, |
48 | kUInt4 = 1, |
49 | kInt8 = 2, |
50 | kUInt8 = 3, |
51 | kInt16 = 4, |
52 | kUInt16 = 5, |
53 | kInt32 = 6, |
54 | kUInt32 = 7, |
55 | kInt64 = 8, |
56 | kUInt64 = 9, |
57 | kFloat16 = 10, |
58 | kBFloat16 = 11, |
59 | kFloat16x2 = 12, |
60 | kFloat32 = 13, |
61 | kTensorFloat32 = 14, |
62 | kFloat64 = 15, |
63 | kBit1 = 16, |
64 | kBit8 = 17, |
65 | kBit16 = 18, |
66 | kBit32 = 19, |
67 | kBit64 = 20, |
68 | }; |
69 | |
// PTX type suffix for each DataType; indexed by the enum's integer value and
// must stay in the same order as the DataType declaration above.
static const char* dtype_str[] = {".s4",   ".u4",  ".s8",  ".u8",  ".s16",   ".u16", ".s32",
                                  ".u32",  ".s64", ".u64", ".f16", ".bf16",  ".f16x2", ".f32",
                                  ".tf32", ".f64", ".b1",  ".b8",  ".b16",   ".b32", ".b64"};
// Bit width of each DataType; indexed by the enum's integer value.
// Note: kFloat16x2 counts as 32 bits (two packed 16-bit halves).
static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
                                    16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
75 | |
76 | /*! |
77 | * \brief Create PTX data type from string. |
78 | */ |
79 | inline DataType DTypeFromString(const std::string str) { |
80 | if (str == "int4" || str == ".s4" ) { |
81 | return DataType::kInt4; |
82 | } else if (str == "uint4" || str == ".u4" ) { |
83 | return DataType::kUInt4; |
84 | } else if (str == "int8" || str == ".s8" ) { |
85 | return DataType::kInt8; |
86 | } else if (str == "uint8" || str == ".u8" ) { |
87 | return DataType::kUInt8; |
88 | } else if (str == "int16" || str == ".s16" ) { |
89 | return DataType::kInt16; |
90 | } else if (str == "uint16" || str == ".u16" ) { |
91 | return DataType::kUInt16; |
92 | } else if (str == "int32" || str == ".s32" ) { |
93 | return DataType::kInt32; |
94 | } else if (str == "uint32" || str == ".u32" ) { |
95 | return DataType::kUInt32; |
96 | } else if (str == "int64" || str == ".s64" ) { |
97 | return DataType::kInt64; |
98 | } else if (str == "uint64" || str == ".u64" ) { |
99 | return DataType::kUInt64; |
100 | } else if (str == "float16" || str == "fp16" || str == ".f16" ) { |
101 | return DataType::kFloat16; |
102 | } else if (str == "bfloat16" || str == "bf16" ) { |
103 | return DataType::kBFloat16; |
104 | } else if (str == ".f16x2" ) { |
105 | return DataType::kFloat16x2; |
106 | } else if (str == "float32" || str == "fp32" || str == ".f32" ) { |
107 | return DataType::kFloat32; |
108 | } else if (str == "tf32" ) { |
109 | return DataType::kTensorFloat32; |
110 | } else if (str == "float64" || str == "fp64" || str == ".f64" ) { |
111 | return DataType::kFloat64; |
112 | } else if (str == "int1" || str == ".b1" ) { |
113 | return DataType::kBit1; |
114 | } else if (str == ".b8" ) { |
115 | return DataType::kBit8; |
116 | } else if (str == ".b16" ) { |
117 | return DataType::kBit16; |
118 | } else if (str == ".b32" ) { |
119 | return DataType::kBit32; |
120 | } else if (str == ".b64" ) { |
121 | return DataType::kBit64; |
122 | } else { |
123 | LOG(FATAL) << "Unrecognized PTX data type " << str; |
124 | } |
125 | } |
126 | |
127 | /*! |
128 | * \brief Get the string representation of given PTX data type. |
129 | */ |
130 | inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast<int>(dtype)]; } |
131 | |
132 | /*! |
133 | * \brief Get the number of bits of given PTX data type. |
134 | */ |
135 | inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast<int>(dtype)]; } |
136 | |
137 | /*! |
138 | * \brief Extract the value m, n, k from string m*n*k* |
139 | */ |
140 | inline std::tuple<int, int, int> ParseMMAShape(const std::string& str) { |
141 | size_t pos_m = str.find("m" ), pos_n = str.find("n" ), pos_k = str.find("k" ); |
142 | CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos) |
143 | << "Cannot parse MMA shape " << str; |
144 | int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)), |
145 | n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1)); |
146 | return std::make_tuple(m, n, k); |
147 | } |
148 | |
149 | /*! |
150 | * \brief Layout Type |
151 | */ |
152 | enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 }; |
153 | |
154 | /*! |
155 | * \brief Parse layout type |
156 | */ |
157 | LayoutType LayoutTypeFromString(const std::string& str) { |
158 | if (str == "row" ) { |
159 | return LayoutType::kRowMajor; |
160 | } else if (str == "col" ) { |
161 | return LayoutType::kColumnMajor; |
162 | } else { |
163 | LOG(FATAL) << "Unrecognized layout type " << str; |
164 | } |
165 | } |
166 | |
// String form of each layout type; indexed by LayoutType's integer value.
static const char* layout_type_str[] = {"row", "col"};

/*!
 * \brief Convert layout type to string.
 * \note No bounds check: layout must be a valid LayoutType enumerator.
 */
inline std::string LayoutTypeToString(LayoutType layout) {
  return layout_type_str[static_cast<int>(layout)];
}
175 | |
176 | /*! |
177 | * \brief MMA Configurations, used to determine validity. |
178 | */ |
179 | struct MMAConfig { |
180 | explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse) |
181 | : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {} |
182 | int m, n, k; |
183 | DataType dtype_mul; |
184 | bool use_bit_op; |
185 | bool sparse; |
186 | inline bool operator==(const MMAConfig& other) { |
187 | return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul && |
188 | use_bit_op == other.use_bit_op && sparse == other.sparse; |
189 | } |
190 | }; |
191 | |
192 | /*! |
193 | * \brief Valid MMA configurations |
194 | * \note Reference: |
195 | * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape |
196 | */ |
197 | const MMAConfig valid_mma_configs[] = { |
198 | MMAConfig(8, 8, 4, DataType::kFloat64, false, false), |
199 | MMAConfig(8, 8, 4, DataType::kFloat16, false, false), |
200 | MMAConfig(16, 8, 8, DataType::kFloat16, false, false), |
201 | MMAConfig(16, 8, 16, DataType::kFloat16, false, false), |
202 | MMAConfig(16, 8, 8, DataType::kBFloat16, false, false), |
203 | MMAConfig(16, 8, 16, DataType::kBFloat16, false, false), |
204 | MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false), |
205 | MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false), |
206 | MMAConfig(8, 8, 16, DataType::kInt8, false, false), |
207 | MMAConfig(16, 8, 16, DataType::kInt8, false, false), |
208 | MMAConfig(16, 8, 32, DataType::kInt8, false, false), |
209 | MMAConfig(8, 8, 16, DataType::kUInt8, false, false), |
210 | MMAConfig(16, 8, 16, DataType::kUInt8, false, false), |
211 | MMAConfig(16, 8, 32, DataType::kUInt8, false, false), |
212 | MMAConfig(8, 8, 32, DataType::kInt4, false, false), |
213 | MMAConfig(16, 8, 32, DataType::kInt4, false, false), |
214 | MMAConfig(16, 8, 64, DataType::kInt4, false, false), |
215 | MMAConfig(8, 8, 32, DataType::kUInt4, false, false), |
216 | MMAConfig(16, 8, 32, DataType::kUInt4, false, false), |
217 | MMAConfig(16, 8, 64, DataType::kUInt4, false, false), |
218 | MMAConfig(8, 8, 128, DataType::kBit1, true, false), |
219 | MMAConfig(16, 8, 128, DataType::kBit1, true, false), |
220 | MMAConfig(16, 8, 256, DataType::kBit1, true, false), |
221 | MMAConfig(16, 8, 16, DataType::kFloat16, false, true), |
222 | MMAConfig(16, 8, 32, DataType::kFloat16, false, true), |
223 | MMAConfig(16, 8, 16, DataType::kBFloat16, false, true), |
224 | MMAConfig(16, 8, 32, DataType::kBFloat16, false, true), |
225 | MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true), |
226 | MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true), |
227 | MMAConfig(16, 8, 32, DataType::kInt8, false, true), |
228 | MMAConfig(16, 8, 64, DataType::kInt8, false, true), |
229 | MMAConfig(16, 8, 32, DataType::kUInt8, false, true), |
230 | MMAConfig(16, 8, 64, DataType::kUInt8, false, true), |
231 | MMAConfig(16, 8, 64, DataType::kInt4, false, true), |
232 | MMAConfig(16, 8, 128, DataType::kInt4, false, true), |
233 | MMAConfig(16, 8, 64, DataType::kUInt4, false, true), |
234 | MMAConfig(16, 8, 128, DataType::kUInt4, false, true), |
235 | }; |
236 | |
237 | /*! |
238 | * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA |
239 | * computation. |
240 | * \param dtype_a The data type of multiplicand a. |
241 | * \param dtype_b The data type of multiplicand b. |
242 | * \param dtype_c The data type of accumulator c. |
243 | * \note Reference: |
244 | * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types |
245 | */ |
246 | void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) { |
247 | std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) + |
248 | DTypeToString(dtype_b) + " do not match." ; |
249 | // check a and b |
250 | switch (dtype_a) { |
251 | case DataType::kBit1: |
252 | case DataType::kFloat16: |
253 | case DataType::kBFloat16: |
254 | case DataType::kTensorFloat32: |
255 | case DataType::kFloat64: |
256 | CHECK(dtype_a == dtype_b) << ab_not_match_err_str; |
257 | break; |
258 | case DataType::kInt4: |
259 | case DataType::kUInt4: |
260 | CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str; |
261 | break; |
262 | case DataType::kInt8: |
263 | case DataType::kUInt8: |
264 | CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str; |
265 | break; |
266 | default: |
267 | CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a) |
268 | << DTypeToString(dtype_b); |
269 | } |
270 | // check a,b and c |
271 | switch (dtype_a) { |
272 | case DataType::kBit1: |
273 | case DataType::kInt4: |
274 | case DataType::kUInt4: |
275 | case DataType::kInt8: |
276 | case DataType::kUInt8: |
277 | CHECK(dtype_c == DataType::kInt32) |
278 | << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b) |
279 | << ", accumulator data type should be s32." ; |
280 | break; |
281 | case DataType::kFloat16: |
282 | CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32) |
283 | << "For multiplicand data type f16, accumulator data type should be f16/f32." ; |
284 | break; |
285 | case DataType::kBFloat16: |
286 | case DataType::kTensorFloat32: |
287 | CHECK(dtype_c == DataType::kFloat32) |
288 | << "For multiplicand data type bf16/tf32, accumulator data type can only be f32." ; |
289 | break; |
290 | case DataType::kFloat64: |
291 | CHECK(dtype_c == DataType::kFloat64) |
292 | << "For multiplicand data type f64, accumulator data type can only be f64." ; |
293 | break; |
294 | default: |
295 | CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a) |
296 | << DTypeToString(dtype_b) << DTypeToString(dtype_c) << "." ; |
297 | } |
298 | } |
299 | |
300 | /*! |
301 | * \brief Check whether the given configuration is valid for MMA computation. |
302 | * \param m The M in mMnNkK of MMA instructions. |
303 | * \param n The N in mMnNkK of MMA instructions. |
304 | * \param k The K in mMnNkK of MMA instructions. |
305 | * \param layout_a The layout of multiplicand A (row/col). |
306 | * \param layout_b The layout of multiplicand B (row/col). |
307 | * \param dtype_a The data type of multiplicand A. |
308 | * \param dtype_b The data type of multiplicand B. |
309 | * \param dtype_c The data type of accumulator C. |
310 | * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not |
311 | * 1-bit MMA). |
312 | * \param sparse Whether it's Sparse MMA or not. |
313 | * \param saturate Whether saturate output or not. |
314 | */ |
315 | void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b, |
316 | DataType dtype_a, DataType dtype_b, DataType dtype_c, |
317 | const std::string& bit_op, bool sparse, bool saturate) { |
318 | CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "" ) |
319 | << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and." ; |
320 | bool use_bit_op = !bit_op.empty(); |
321 | if (use_bit_op) { |
322 | CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand." ; |
323 | } |
324 | CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c); |
325 | if (saturate) { |
326 | CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 || |
327 | dtype_a == DataType::kUInt8) |
328 | << "Output saturation only applicable to multiplicand type s4/u4/s8/u8." ; |
329 | } |
330 | |
331 | if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) { |
332 | // Only MMA on m8n8k4 for fp16 supports customized layouts. |
333 | CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor) |
334 | << "Invalid layout combination " << LayoutTypeToString(layout_a) << "," |
335 | << LayoutTypeToString(layout_b) << "." ; |
336 | } |
337 | |
338 | MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse); |
339 | bool match = false; |
340 | for (const MMAConfig& valid_config : valid_mma_configs) { |
341 | if (config == valid_config) { |
342 | match = true; |
343 | break; |
344 | } |
345 | } |
346 | CHECK(match) << "Cannot find matched MMA configurations." ; |
347 | } |
348 | |
349 | /*! |
350 | * \brief Fragment attributes |
351 | */ |
352 | class FragAttrs { |
353 | public: |
354 | explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type) |
355 | : reg_type(reg_type), size(size), ptr_type(ptr_type) {} |
356 | /*! \brief PTX register type */ |
357 | char reg_type; |
358 | /*! \brief Fragment size */ |
359 | uint32_t size; |
360 | /*! \brief Fragment pointer type */ |
361 | std::string ptr_type; |
362 | }; |
363 | |
364 | /*! |
365 | * \brief Fragment attributes of given data type. |
366 | */ |
367 | inline FragAttrs GetFragAttrs(DataType dtype) { |
368 | switch (dtype) { |
369 | case DataType::kBit1: |
370 | case DataType::kInt4: |
371 | case DataType::kUInt4: |
372 | case DataType::kInt8: |
373 | case DataType::kUInt8: |
374 | case DataType::kBit16: |
375 | case DataType::kFloat16: // .f16x2 register |
376 | case DataType::kBFloat16: |
377 | case DataType::kTensorFloat32: |
378 | return FragAttrs('r', 32, "(unsigned *)" ); |
379 | case DataType::kInt32: |
380 | return FragAttrs('r', 32, "(int *)" ); |
381 | case DataType::kFloat32: |
382 | return FragAttrs('f', 32, "(float *)" ); |
383 | case DataType::kFloat64: |
384 | return FragAttrs('d', 64, "(double *)" ); |
385 | default: |
386 | ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA." ; |
387 | return FragAttrs('\0', 0, "" ); |
388 | } |
389 | } |
390 | |
391 | }; // namespace ptx |
392 | |
393 | /*! |
394 | * \brief Replace patterns with replacement strings. |
395 | * \note should use std::format instead when codebase is ported to C++20. |
396 | */ |
397 | class Replacer { |
398 | public: |
399 | void register_rule(const std::string& pattern, const std::string& replacement) { |
400 | _rules.emplace_back(pattern, replacement); |
401 | } |
402 | std::string rewrite(std::string str) { |
403 | for (auto&& rule : _rules) { |
404 | auto [pattern, replacement] = rule; |
405 | size_t len = pattern.size(); |
406 | size_t new_len = replacement.size(); |
407 | size_t pos = str.find(pattern); |
408 | while (pos != std::string::npos) { |
409 | str = str.replace(pos, len, replacement); |
410 | pos = str.find(pattern, pos + new_len); |
411 | } |
412 | } |
413 | return str; |
414 | } |
415 | void empty_rules() { _rules.clear(); } |
416 | |
417 | private: |
418 | std::vector<std::pair<std::string, std::string>> _rules; |
419 | }; |
420 | |
421 | /*! |
422 | * \brief Get the number of MMA computations for given shape and datatype. |
423 | */ |
424 | inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) { |
425 | if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) { |
426 | // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one. |
427 | return 4; |
428 | } else { |
429 | return 1; |
430 | } |
431 | } |
432 | |
433 | /*! |
434 | * \brief Return template string, input operands string and output operands string. |
435 | * \param m The M in mMnNkK of MMA instructions. |
436 | * \param n The N in mMnNkK of MMA instructions. |
437 | * \param k The K in mMnNkK of MMA instructions. |
438 | * \param dtype_a The data type of multiplicand a. |
439 | * \param dtype_b The data type of multiplicand b. |
440 | * \param dtype_c The data type of accumulator c. |
441 | * \param sparse Whether it's Sparse MMA or not. |
442 | */ |
443 | inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, int n, int k, |
444 | ptx::DataType dtype_a, |
445 | ptx::DataType dtype_b, |
446 | ptx::DataType dtype_c, |
447 | bool sparse) { |
448 | std::stringstream templates, inputs, outputs; |
449 | const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a), |
450 | frag_attr_b = ptx::GetFragAttrs(dtype_b), |
451 | frag_attr_c = ptx::GetFragAttrs(dtype_c); |
452 | constexpr uint32_t warp_size = 32; |
453 | const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a); |
454 | const int num_operands_a = |
455 | (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1), |
456 | num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads, |
457 | num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads; |
458 | |
459 | // generate templates; |
460 | int arg_counter = 0; |
461 | templates << "{" |
462 | << "%" << arg_counter++; |
463 | for (int i = 1; i < num_operands_c; ++i) { |
464 | templates << ", %" << arg_counter++; |
465 | } |
466 | templates << "}, {" |
467 | << "%" << arg_counter++; |
468 | for (int i = 1; i < num_operands_a; ++i) { |
469 | templates << ", %" << arg_counter++; |
470 | } |
471 | templates << "}, {" |
472 | << "%" << arg_counter++; |
473 | for (int i = 1; i < num_operands_b; ++i) { |
474 | templates << ", %" << arg_counter++; |
475 | } |
476 | templates << "}, {" |
477 | << "%" << arg_counter++; |
478 | for (int i = 1; i < num_operands_c; ++i) { |
479 | templates << ", %" << arg_counter++; |
480 | } |
481 | templates << "}" ; |
482 | // templates of metadata and sparse selector for sparse mma. |
483 | if (sparse) { |
484 | templates << ", %" << (arg_counter++) << ", F" ; |
485 | } |
486 | |
487 | // generate inputs |
488 | for (int i = 0; i < num_operands_a; ++i) { |
489 | if (i != 0) { |
490 | inputs << ", " ; |
491 | } |
492 | inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i |
493 | << "])" ; |
494 | } |
495 | for (int i = 0; i < num_operands_b; ++i) { |
496 | inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i |
497 | << "])" ; |
498 | } |
499 | for (int i = 0; i < num_operands_c; ++i) { |
500 | inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i |
501 | << "])" ; |
502 | } |
503 | // input of metadata for sparse mma. |
504 | if (sparse) { |
505 | inputs << ", \"r\"(((unsigned *)(E))[0])" ; |
506 | } |
507 | |
508 | // generate outputs |
509 | for (int i = 0; i < num_operands_c; ++i) { |
510 | if (i != 0) { |
511 | outputs << "," ; |
512 | } |
513 | outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i |
514 | << "])" ; |
515 | } |
516 | return std::make_tuple(templates.str(), inputs.str(), outputs.str()); |
517 | } |
518 | |
/*!
 * \brief Print the CUDA inline-assembly snippet for an (optionally sparse) MMA instruction.
 * \param shape The shape string mMnNkK of the MMA, e.g. "m16n8k16".
 * \param A_layout The layout of multiplicand A ("row"/"col").
 * \param B_layout The layout of multiplicand B ("row"/"col").
 * \param A_dtype The data type string of multiplicand A.
 * \param B_dtype The data type string of multiplicand B.
 * \param C_dtype The data type string of accumulator C; also used as the
 *        output (D) type below.
 * \param a_ptr Pointer expression for A.
 * \param a_elem_offset Element offset expression added to a_ptr.
 * \param b_ptr Pointer expression for B.
 * \param b_elem_offset Element offset expression added to b_ptr.
 * \param c_ptr Pointer expression for C; D uses the same pointer, so the
 *        accumulator is updated in place.
 * \param c_elem_offset Element offset expression added to c_ptr.
 * \param metadata Sparse-metadata pointer expression (sparse MMA only).
 * \param metadata_offset Offset expression added to metadata (sparse MMA only).
 * \param sparsity_selector Sparsity selector expression (sparse MMA only).
 * \param bit_op The bit operator for 1-bit MMA ("xor"/"and"), or "" otherwise.
 * \param sparse Whether it's Sparse MMA or not.
 * \param saturate Whether saturate output or not.
 * \return CUDA source of the inline-asm block as a string.
 */
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                             const std::string& B_layout, const std::string& A_dtype,
                             const std::string& B_dtype, const std::string& C_dtype,
                             const std::string& a_ptr, const std::string& a_elem_offset,
                             const std::string& b_ptr, const std::string& b_elem_offset,
                             const std::string& c_ptr, const std::string& c_elem_offset,
                             const std::string& metadata, const std::string& metadata_offset,
                             const std::string& sparsity_selector, const std::string& bit_op,
                             bool sparse, bool saturate) {
  ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype),
                dtype_c = ptx::DTypeFromString(C_dtype);
  ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
                  layout_b = ptx::LayoutTypeFromString(B_layout);
  auto [m, n, k] = ptx::ParseMMAShape(shape);
  // Validate the requested configuration before emitting any code.
  CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse,
                         saturate);
  std::string asm_code = R"(
  {
    __asm__ __volatile__(
      "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
      "{templates};\n"
      : {outputs}
      : {inputs});
  }
)";
  auto [templates_str, inputs_str, outputs_str] =
      GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);

  // replace patterns
  // First pass: fill in the instruction modifiers and operand templates. The
  // operand strings still contain the single-letter placeholders A/B/C/D/E/F.
  Replacer replacer;
  replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
  replacer.register_rule("{.shape}", "." + shape);
  replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
  replacer.register_rule("{.alayout}", "." + A_layout);
  replacer.register_rule("{.blayout}", "." + B_layout);
  replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
  replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
  replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
  // The output (d) type equals the accumulator (c) type.
  replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
  replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
  replacer.register_rule("{templates}", templates_str);
  replacer.register_rule("{outputs}", outputs_str);
  replacer.register_rule("{inputs}", inputs_str);
  asm_code = replacer.rewrite(asm_code);
  replacer.empty_rules();
  // Second pass: substitute the pointer placeholders. Done separately so the
  // user-supplied pointer/offset expressions are not themselves scanned for
  // the modifier patterns above. D shares C's pointer (in-place accumulation);
  // E is the sparse metadata and F the sparsity selector.
  replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
  replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
  replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
  replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
  replacer.register_rule("E", metadata + " + " + metadata_offset);
  replacer.register_rule("F", sparsity_selector);
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}
573 | |
/*!
 * \brief Build the inline-asm template and output-operand strings for ldmatrix.
 * \param num Number of matrices to load (1, 2 or 4).
 * \param local_ptr Destination (register fragment) pointer expression.
 * \param local_elem_offset Element offset expression added to local_ptr.
 * \return A tuple of (template string, output operands string).
 */
inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
  std::stringstream templates, outputs;
  // Template has the form "{%0, %1, ...}, [%N]": one "%i" per destination
  // register, followed by the shared-memory address operand at index num.
  templates << "{%0";
  for (int i = 1; i < num; ++i) {
    templates << ", %" << i;
  }
  templates << "}, [%" << num << "]";
  // Each output constraint aliases one 32-bit word of the local fragment.
  const std::string ptr_type = "(unsigned *)";
  for (int i = 0; i < num; ++i) {
    if (i > 0) {
      outputs << ", ";
    }
    outputs << "\"=r\"((" << ptr_type << "(" << local_ptr << " + " << local_elem_offset << "))["
            << i << "])";
  }
  return std::make_tuple(templates.str(), outputs.str());
}
595 | |
/*!
 * \brief Print the CUDA inline-assembly snippet for a warp-level ldmatrix load.
 * \param trans Whether to load the matrices transposed.
 * \param num Number of matrices to load; must be 1, 2 or 4.
 * \param type Matrix element type string; only ".b16" is accepted.
 * \param local_ptr Destination (register fragment) pointer expression.
 * \param local_elem_offset Element offset expression added to local_ptr.
 * \param smem_ptr Source shared-memory pointer expression.
 * \param smem_elem_offset Element offset expression added to smem_ptr.
 * \return CUDA source of the inline-asm block as a string.
 */
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
                                    const std::string& local_ptr,
                                    const std::string& local_elem_offset,
                                    const std::string& smem_ptr,
                                    const std::string& smem_elem_offset) {
  CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices.";
  ptx::DataType data_type = ptx::DTypeFromString(type);
  CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
  // The first asm block converts the generic pointer to a 32-bit shared-memory
  // address (cvta.to.shared + cvt); the second issues the ldmatrix itself.
  std::string asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    __asm__ __volatile__(
      "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
      "{templates};\n"
      : {outputs}
      : "r"(addr)
    );
  }
)";
  auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);

  Replacer replacer;
  replacer.register_rule("{.shape}", ".m8n8");  // ldmatrix always loads 8x8 tiles
  replacer.register_rule("{.num}", ".x" + std::to_string(num));
  replacer.register_rule("{.trans}", trans ? ".trans" : "");
  replacer.register_rule("{.ss}", ".shared");
  replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
  replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
  replacer.register_rule("{templates}", templates_str);
  replacer.register_rule("{outputs}", outputs_str);
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}
634 | |
/*!
 * \brief Print the CUDA inline-assembly snippet for an async global-to-shared copy (cp.async).
 * \param shared_ptr Destination shared-memory pointer expression.
 * \param shared_elem_offset Element offset expression added to shared_ptr.
 * \param global_ptr Source global-memory pointer expression.
 * \param global_elem_offset Element offset expression added to global_ptr.
 * \param bytes Number of bytes to copy, as a string; substituted into an "n"
 *        (immediate) asm constraint, so it must be a compile-time constant.
 * \return CUDA source of the inline-asm block as a string.
 */
std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes) {
  // First asm block converts the generic pointer to a 32-bit shared-memory
  // address; the second issues the cp.async.
  std::string asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
    );
  }
)";
  Replacer replacer;
  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
  replacer.register_rule("{bytes}", bytes);
  // 16-byte copies use the .cg (cache-global) variant; everything else uses .ca.
  replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}
661 | |
662 | } // namespace codegen |
663 | } // namespace tvm |
664 | |