1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file ptx.cc
22 */
23
#include "ptx.h"

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
31
32namespace tvm {
33namespace codegen {
34
35// PTX related data structures and functions.
36namespace ptx {
37
/*!
 * \brief PTX data type.
 * \note
 * PTX fundamental data types:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types
 * PTX matrix data types:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
 * The explicit integer values are used to index the dtype_str and num_bits
 * tables below; keep all three in sync when adding an enumerator.
 */
enum class DataType : int {
  kInt4 = 0,
  kUInt4 = 1,
  kInt8 = 2,
  kUInt8 = 3,
  kInt16 = 4,
  kUInt16 = 5,
  kInt32 = 6,
  kUInt32 = 7,
  kInt64 = 8,
  kUInt64 = 9,
  kFloat16 = 10,
  kBFloat16 = 11,
  kFloat16x2 = 12,  // a pair of f16 values packed in one 32-bit register
  kFloat32 = 13,
  kTensorFloat32 = 14,
  kFloat64 = 15,
  kBit1 = 16,
  kBit8 = 17,
  kBit16 = 18,
  kBit32 = 19,
  kBit64 = 20,
};
69
// PTX suffix string for each ptx::DataType, indexed by the enum's integer value.
static constexpr const char* dtype_str[] = {".s4",  ".u4",   ".s8",  ".u8",  ".s16", ".u16", ".s32",
                                            ".u32", ".s64",  ".u64", ".f16", ".bf16", ".f16x2",
                                            ".f32", ".tf32", ".f64", ".b1",  ".b8",  ".b16",
                                            ".b32", ".b64"};
// Bit width of each ptx::DataType, indexed by the enum's integer value.
static constexpr uint32_t num_bits[] = {4,  4,  8,  8,  16, 16, 32, 32, 64, 64, 16,
                                        16, 32, 32, 32, 64, 1,  8,  16, 32, 64};
// Guard against the tables drifting out of sync with the 21 DataType enumerators.
static_assert(sizeof(dtype_str) / sizeof(dtype_str[0]) == 21,
              "dtype_str must have one entry per ptx::DataType enumerator");
static_assert(sizeof(num_bits) / sizeof(num_bits[0]) == 21,
              "num_bits must have one entry per ptx::DataType enumerator");
75
76/*!
77 * \brief Create PTX data type from string.
78 */
79inline DataType DTypeFromString(const std::string str) {
80 if (str == "int4" || str == ".s4") {
81 return DataType::kInt4;
82 } else if (str == "uint4" || str == ".u4") {
83 return DataType::kUInt4;
84 } else if (str == "int8" || str == ".s8") {
85 return DataType::kInt8;
86 } else if (str == "uint8" || str == ".u8") {
87 return DataType::kUInt8;
88 } else if (str == "int16" || str == ".s16") {
89 return DataType::kInt16;
90 } else if (str == "uint16" || str == ".u16") {
91 return DataType::kUInt16;
92 } else if (str == "int32" || str == ".s32") {
93 return DataType::kInt32;
94 } else if (str == "uint32" || str == ".u32") {
95 return DataType::kUInt32;
96 } else if (str == "int64" || str == ".s64") {
97 return DataType::kInt64;
98 } else if (str == "uint64" || str == ".u64") {
99 return DataType::kUInt64;
100 } else if (str == "float16" || str == "fp16" || str == ".f16") {
101 return DataType::kFloat16;
102 } else if (str == "bfloat16" || str == "bf16") {
103 return DataType::kBFloat16;
104 } else if (str == ".f16x2") {
105 return DataType::kFloat16x2;
106 } else if (str == "float32" || str == "fp32" || str == ".f32") {
107 return DataType::kFloat32;
108 } else if (str == "tf32") {
109 return DataType::kTensorFloat32;
110 } else if (str == "float64" || str == "fp64" || str == ".f64") {
111 return DataType::kFloat64;
112 } else if (str == "int1" || str == ".b1") {
113 return DataType::kBit1;
114 } else if (str == ".b8") {
115 return DataType::kBit8;
116 } else if (str == ".b16") {
117 return DataType::kBit16;
118 } else if (str == ".b32") {
119 return DataType::kBit32;
120 } else if (str == ".b64") {
121 return DataType::kBit64;
122 } else {
123 LOG(FATAL) << "Unrecognized PTX data type " << str;
124 }
125}
126
127/*!
128 * \brief Get the string representation of given PTX data type.
129 */
130inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast<int>(dtype)]; }
131
132/*!
133 * \brief Get the number of bits of given PTX data type.
134 */
135inline uint32_t DTypeBits(DataType dtype) { return num_bits[static_cast<int>(dtype)]; }
136
137/*!
138 * \brief Extract the value m, n, k from string m*n*k*
139 */
140inline std::tuple<int, int, int> ParseMMAShape(const std::string& str) {
141 size_t pos_m = str.find("m"), pos_n = str.find("n"), pos_k = str.find("k");
142 CHECK(pos_m != str.npos && pos_n != str.npos && pos_k != str.npos)
143 << "Cannot parse MMA shape " << str;
144 int m = std::stoi(str.substr(pos_m + 1, pos_n - pos_m - 1)),
145 n = std::stoi(str.substr(pos_n + 1, pos_k - pos_n - 1)), k = std::stoi(str.substr(pos_k + 1));
146 return std::make_tuple(m, n, k);
147}
148
/*!
 * \brief Layout Type
 * \note The integer values index the layout_type_str table below.
 */
enum class LayoutType : int { kRowMajor = 0, kColumnMajor = 1 };
153
154/*!
155 * \brief Parse layout type
156 */
157LayoutType LayoutTypeFromString(const std::string& str) {
158 if (str == "row") {
159 return LayoutType::kRowMajor;
160 } else if (str == "col") {
161 return LayoutType::kColumnMajor;
162 } else {
163 LOG(FATAL) << "Unrecognized layout type " << str;
164 }
165}
166
static const char* layout_type_str[] = {"row", "col"};  // indexed by LayoutType's integer value
168
169/*!
170 * \brief Convert layout type to string.
171 */
172inline std::string LayoutTypeToString(LayoutType layout) {
173 return layout_type_str[static_cast<int>(layout)];
174}
175
176/*!
177 * \brief MMA Configurations, used to determine validity.
178 */
179struct MMAConfig {
180 explicit MMAConfig(int m, int n, int k, DataType dtype_mul, bool use_bit_op, bool sparse)
181 : m(m), n(n), k(k), dtype_mul(dtype_mul), use_bit_op(use_bit_op), sparse(sparse) {}
182 int m, n, k;
183 DataType dtype_mul;
184 bool use_bit_op;
185 bool sparse;
186 inline bool operator==(const MMAConfig& other) {
187 return m == other.m && n == other.n && k == other.k && dtype_mul == other.dtype_mul &&
188 use_bit_op == other.use_bit_op && sparse == other.sparse;
189 }
190};
191
/*!
 * \brief Valid MMA configurations
 * \note Reference:
 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-shape
 */
const MMAConfig valid_mma_configs[] = {
    // dense, f64
    MMAConfig(8, 8, 4, DataType::kFloat64, false, false),
    // dense, f16
    MMAConfig(8, 8, 4, DataType::kFloat16, false, false),
    MMAConfig(16, 8, 8, DataType::kFloat16, false, false),
    MMAConfig(16, 8, 16, DataType::kFloat16, false, false),
    // dense, bf16
    MMAConfig(16, 8, 8, DataType::kBFloat16, false, false),
    MMAConfig(16, 8, 16, DataType::kBFloat16, false, false),
    // dense, tf32
    MMAConfig(16, 8, 4, DataType::kTensorFloat32, false, false),
    MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, false),
    // dense, s8/u8
    MMAConfig(8, 8, 16, DataType::kInt8, false, false),
    MMAConfig(16, 8, 16, DataType::kInt8, false, false),
    MMAConfig(16, 8, 32, DataType::kInt8, false, false),
    MMAConfig(8, 8, 16, DataType::kUInt8, false, false),
    MMAConfig(16, 8, 16, DataType::kUInt8, false, false),
    MMAConfig(16, 8, 32, DataType::kUInt8, false, false),
    // dense, s4/u4
    MMAConfig(8, 8, 32, DataType::kInt4, false, false),
    MMAConfig(16, 8, 32, DataType::kInt4, false, false),
    MMAConfig(16, 8, 64, DataType::kInt4, false, false),
    MMAConfig(8, 8, 32, DataType::kUInt4, false, false),
    MMAConfig(16, 8, 32, DataType::kUInt4, false, false),
    MMAConfig(16, 8, 64, DataType::kUInt4, false, false),
    // dense, 1-bit with bit operator
    MMAConfig(8, 8, 128, DataType::kBit1, true, false),
    MMAConfig(16, 8, 128, DataType::kBit1, true, false),
    MMAConfig(16, 8, 256, DataType::kBit1, true, false),
    // sparse, f16
    MMAConfig(16, 8, 16, DataType::kFloat16, false, true),
    MMAConfig(16, 8, 32, DataType::kFloat16, false, true),
    // sparse, bf16
    MMAConfig(16, 8, 16, DataType::kBFloat16, false, true),
    MMAConfig(16, 8, 32, DataType::kBFloat16, false, true),
    // sparse, tf32
    MMAConfig(16, 8, 8, DataType::kTensorFloat32, false, true),
    MMAConfig(16, 8, 16, DataType::kTensorFloat32, false, true),
    // sparse, s8/u8
    MMAConfig(16, 8, 32, DataType::kInt8, false, true),
    MMAConfig(16, 8, 64, DataType::kInt8, false, true),
    MMAConfig(16, 8, 32, DataType::kUInt8, false, true),
    MMAConfig(16, 8, 64, DataType::kUInt8, false, true),
    // sparse, s4/u4
    MMAConfig(16, 8, 64, DataType::kInt4, false, true),
    MMAConfig(16, 8, 128, DataType::kInt4, false, true),
    MMAConfig(16, 8, 64, DataType::kUInt4, false, true),
    MMAConfig(16, 8, 128, DataType::kUInt4, false, true),
};
236
237/*!
238 * \brief Check whether the multiplicand data type and accumulator data type is valid for MMA
239 * computation.
240 * \param dtype_a The data type of multiplicand a.
241 * \param dtype_b The data type of multiplicand b.
242 * \param dtype_c The data type of accumulator c.
243 * \note Reference:
244 * https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types
245 */
246void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_c) {
247 std::string ab_not_match_err_str = "The multiplicands' data type " + DTypeToString(dtype_a) +
248 DTypeToString(dtype_b) + " do not match.";
249 // check a and b
250 switch (dtype_a) {
251 case DataType::kBit1:
252 case DataType::kFloat16:
253 case DataType::kBFloat16:
254 case DataType::kTensorFloat32:
255 case DataType::kFloat64:
256 CHECK(dtype_a == dtype_b) << ab_not_match_err_str;
257 break;
258 case DataType::kInt4:
259 case DataType::kUInt4:
260 CHECK(dtype_b == DataType::kInt4 || dtype_b == DataType::kUInt4) << ab_not_match_err_str;
261 break;
262 case DataType::kInt8:
263 case DataType::kUInt8:
264 CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str;
265 break;
266 default:
267 CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a)
268 << DTypeToString(dtype_b);
269 }
270 // check a,b and c
271 switch (dtype_a) {
272 case DataType::kBit1:
273 case DataType::kInt4:
274 case DataType::kUInt4:
275 case DataType::kInt8:
276 case DataType::kUInt8:
277 CHECK(dtype_c == DataType::kInt32)
278 << "For multiplicand data type " << DTypeToString(dtype_a) << DTypeToString(dtype_b)
279 << ", accumulator data type should be s32.";
280 break;
281 case DataType::kFloat16:
282 CHECK(dtype_c == DataType::kFloat16 || dtype_c == DataType::kFloat32)
283 << "For multiplicand data type f16, accumulator data type should be f16/f32.";
284 break;
285 case DataType::kBFloat16:
286 case DataType::kTensorFloat32:
287 CHECK(dtype_c == DataType::kFloat32)
288 << "For multiplicand data type bf16/tf32, accumulator data type can only be f32.";
289 break;
290 case DataType::kFloat64:
291 CHECK(dtype_c == DataType::kFloat64)
292 << "For multiplicand data type f64, accumulator data type can only be f64.";
293 break;
294 default:
295 CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a)
296 << DTypeToString(dtype_b) << DTypeToString(dtype_c) << ".";
297 }
298}
299
300/*!
301 * \brief Check whether the given configuration is valid for MMA computation.
302 * \param m The M in mMnNkK of MMA instructions.
303 * \param n The N in mMnNkK of MMA instructions.
304 * \param k The K in mMnNkK of MMA instructions.
305 * \param layout_a The layout of multiplicand A (row/col).
306 * \param layout_b The layout of multiplicand B (row/col).
307 * \param dtype_a The data type of multiplicand A.
308 * \param dtype_b The data type of multiplicand B.
309 * \param dtype_c The data type of accumulator C.
310 * \param bit_op The bit operator for 1-bit MMA computation, can be "xor"/"and" or ""(if it's not
311 * 1-bit MMA).
312 * \param sparse Whether it's Sparse MMA or not.
313 * \param saturate Whether saturate output or not.
314 */
315void CheckMMAConfigValidity(int m, int n, int k, LayoutType layout_a, LayoutType layout_b,
316 DataType dtype_a, DataType dtype_b, DataType dtype_c,
317 const std::string& bit_op, bool sparse, bool saturate) {
318 CHECK(bit_op == "xor" || bit_op == "and" || bit_op == "")
319 << "Unrecognized 1-bit operation " << bit_op << " , can only be xor/and.";
320 bool use_bit_op = !bit_op.empty();
321 if (use_bit_op) {
322 CHECK(dtype_a == DataType::kBit1) << "Bit operator is only compatible with 1-bit multiplicand.";
323 }
324 CheckMMADTypeCompatible(dtype_a, dtype_b, dtype_c);
325 if (saturate) {
326 CHECK(dtype_a == DataType::kInt4 || dtype_a == DataType::kUInt4 || dtype_a == DataType::kInt8 ||
327 dtype_a == DataType::kUInt8)
328 << "Output saturation only applicable to multiplicand type s4/u4/s8/u8.";
329 }
330
331 if (!(m == 8 && n == 8 && k == 4 && dtype_a == ptx::DataType::kFloat16)) {
332 // Only MMA on m8n8k4 for fp16 supports customized layouts.
333 CHECK(layout_a == LayoutType::kRowMajor && layout_b == LayoutType::kColumnMajor)
334 << "Invalid layout combination " << LayoutTypeToString(layout_a) << ","
335 << LayoutTypeToString(layout_b) << ".";
336 }
337
338 MMAConfig config(m, n, k, dtype_a, use_bit_op, sparse);
339 bool match = false;
340 for (const MMAConfig& valid_config : valid_mma_configs) {
341 if (config == valid_config) {
342 match = true;
343 break;
344 }
345 }
346 CHECK(match) << "Cannot find matched MMA configurations.";
347}
348
/*!
 * \brief Fragment attributes
 */
class FragAttrs {
 public:
  // ptr_type is a sink parameter: taken by value and moved into the member to
  // avoid a second string copy.
  explicit FragAttrs(char reg_type, uint32_t size, std::string ptr_type)
      : reg_type(reg_type), size(size), ptr_type(std::move(ptr_type)) {}
  /*! \brief PTX register type */
  char reg_type;
  /*! \brief Fragment size */
  uint32_t size;
  /*! \brief Fragment pointer type */
  std::string ptr_type;
};
363
364/*!
365 * \brief Fragment attributes of given data type.
366 */
367inline FragAttrs GetFragAttrs(DataType dtype) {
368 switch (dtype) {
369 case DataType::kBit1:
370 case DataType::kInt4:
371 case DataType::kUInt4:
372 case DataType::kInt8:
373 case DataType::kUInt8:
374 case DataType::kBit16:
375 case DataType::kFloat16: // .f16x2 register
376 case DataType::kBFloat16:
377 case DataType::kTensorFloat32:
378 return FragAttrs('r', 32, "(unsigned *)");
379 case DataType::kInt32:
380 return FragAttrs('r', 32, "(int *)");
381 case DataType::kFloat32:
382 return FragAttrs('f', 32, "(float *)");
383 case DataType::kFloat64:
384 return FragAttrs('d', 64, "(double *)");
385 default:
386 ICHECK(false) << DTypeToString(dtype) << " is not matrix data type in MMA.";
387 return FragAttrs('\0', 0, "");
388 }
389}
390
391}; // namespace ptx
392
/*!
 * \brief Replace patterns with replacement strings.
 * \note should use std::format instead when codebase is ported to C++20.
 */
class Replacer {
 public:
  /*!
   * \brief Register a literal-substring rewrite rule; rules apply in registration order.
   * \param pattern The substring to search for.
   * \param replacement The text each occurrence is replaced with.
   */
  void register_rule(const std::string& pattern, const std::string& replacement) {
    _rules.emplace_back(pattern, replacement);
  }
  /*!
   * \brief Apply all registered rules to the given string.
   * \param str The string to rewrite (taken by value; mutated locally).
   * \return The rewritten string.
   */
  std::string rewrite(std::string str) {
    // Bind by const reference: the previous by-value structured binding copied
    // both strings of every rule on each iteration.
    for (const auto& [pattern, replacement] : _rules) {
      size_t len = pattern.size();
      size_t new_len = replacement.size();
      size_t pos = str.find(pattern);
      while (pos != std::string::npos) {
        str = str.replace(pos, len, replacement);
        // Resume after the inserted text so a replacement containing the
        // pattern cannot loop forever.
        pos = str.find(pattern, pos + new_len);
      }
    }
    return str;
  }
  /*! \brief Discard all registered rules. */
  void empty_rules() { _rules.clear(); }

 private:
  std::vector<std::pair<std::string, std::string>> _rules;
};
420
421/*!
422 * \brief Get the number of MMA computations for given shape and datatype.
423 */
424inline uint32_t GetNumMMAComputations(int m, int n, int k, ptx::DataType dtype) {
425 if (m == 8 && n == 8 && k == 4 && dtype == ptx::DataType::kFloat16) {
426 // MMA for m8n8k4 on fp16 would launch 4 MMA computations instead of one.
427 return 4;
428 } else {
429 return 1;
430 }
431}
432
/*!
 * \brief Return template string, input operands string and output operands string.
 * \param m The M in mMnNkK of MMA instructions.
 * \param n The N in mMnNkK of MMA instructions.
 * \param k The K in mMnNkK of MMA instructions.
 * \param dtype_a The data type of multiplicand a.
 * \param dtype_b The data type of multiplicand b.
 * \param dtype_c The data type of accumulator c.
 * \param sparse Whether it's Sparse MMA or not.
 * \return (templates, inputs, outputs) pieces of an inline-asm mma statement.
 *         The literal names A/B/C/D/E/F embedded in the pieces are placeholders
 *         substituted with real pointer expressions by PrintMMAAssembly.
 */
inline std::tuple<std::string, std::string, std::string> GetMMAOperands(int m, int n, int k,
                                                                        ptx::DataType dtype_a,
                                                                        ptx::DataType dtype_b,
                                                                        ptx::DataType dtype_c,
                                                                        bool sparse) {
  std::stringstream templates, inputs, outputs;
  const ptx::FragAttrs frag_attr_a = ptx::GetFragAttrs(dtype_a),
                       frag_attr_b = ptx::GetFragAttrs(dtype_b),
                       frag_attr_c = ptx::GetFragAttrs(dtype_c);
  constexpr uint32_t warp_size = 32;
  // m8n8k4 fp16 issues 4 computations per instruction, so each one uses 1/4 of the warp.
  const uint32_t threads = warp_size / GetNumMMAComputations(m, n, k, dtype_a);
  // Registers held per thread for each fragment: matrix bits / register width / threads.
  // Sparse MMA stores multiplicand A in compressed form, halving its register count.
  const int num_operands_a =
                (m * k) * ptx::DTypeBits(dtype_a) / frag_attr_a.size / threads / (sparse ? 2 : 1),
            num_operands_b = (k * n) * ptx::DTypeBits(dtype_b) / frag_attr_b.size / threads,
            num_operands_c = (m * n) * ptx::DTypeBits(dtype_c) / frag_attr_c.size / threads;

  // generate templates;
  // Positional operand groups in PTX mma order: D (output), A, B, C.
  int arg_counter = 0;
  // D fragment: same register count as C.
  templates << "{"
            << "%" << arg_counter++;
  for (int i = 1; i < num_operands_c; ++i) {
    templates << ", %" << arg_counter++;
  }
  // A fragment.
  templates << "}, {"
            << "%" << arg_counter++;
  for (int i = 1; i < num_operands_a; ++i) {
    templates << ", %" << arg_counter++;
  }
  // B fragment.
  templates << "}, {"
            << "%" << arg_counter++;
  for (int i = 1; i < num_operands_b; ++i) {
    templates << ", %" << arg_counter++;
  }
  // C fragment.
  templates << "}, {"
            << "%" << arg_counter++;
  for (int i = 1; i < num_operands_c; ++i) {
    templates << ", %" << arg_counter++;
  }
  templates << "}";
  // templates of metadata and sparse selector for sparse mma.
  // The literal "F" is replaced with the sparsity selector later by the caller.
  if (sparse) {
    templates << ", %" << (arg_counter++) << ", F";
  }

  // generate inputs
  for (int i = 0; i < num_operands_a; ++i) {
    if (i != 0) {
      inputs << ", ";
    }
    inputs << "\"" << frag_attr_a.reg_type << "\"((" << frag_attr_a.ptr_type << "(A))[" << i
           << "])";
  }
  for (int i = 0; i < num_operands_b; ++i) {
    inputs << ", \"" << frag_attr_b.reg_type << "\"((" << frag_attr_b.ptr_type << "(B))[" << i
           << "])";
  }
  for (int i = 0; i < num_operands_c; ++i) {
    inputs << ", \"" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(C))[" << i
           << "])";
  }
  // input of metadata for sparse mma.
  if (sparse) {
    inputs << ", \"r\"(((unsigned *)(E))[0])";
  }

  // generate outputs
  for (int i = 0; i < num_operands_c; ++i) {
    if (i != 0) {
      outputs << ",";
    }
    outputs << " \"=" << frag_attr_c.reg_type << "\"((" << frag_attr_c.ptr_type << "(D))[" << i
            << "])";
  }
  return std::make_tuple(templates.str(), inputs.str(), outputs.str());
}
518
/*!
 * \brief Print the CUDA source of one inline-asm mma(.sp) statement.
 * The a/b/c pointer+offset strings are CUDA expressions spliced into the
 * generated source; metadata/sparsity_selector are only used when sparse.
 * Validates the full configuration before emitting code.
 */
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                             const std::string& B_layout, const std::string& A_dtype,
                             const std::string& B_dtype, const std::string& C_dtype,
                             const std::string& a_ptr, const std::string& a_elem_offset,
                             const std::string& b_ptr, const std::string& b_elem_offset,
                             const std::string& c_ptr, const std::string& c_elem_offset,
                             const std::string& metadata, const std::string& metadata_offset,
                             const std::string& sparsity_selector, const std::string& bit_op,
                             bool sparse, bool saturate) {
  ptx::DataType dtype_a = ptx::DTypeFromString(A_dtype), dtype_b = ptx::DTypeFromString(B_dtype),
                dtype_c = ptx::DTypeFromString(C_dtype);
  ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
                  layout_b = ptx::LayoutTypeFromString(B_layout);
  auto [m, n, k] = ptx::ParseMMAShape(shape);
  CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse,
                         saturate);
  std::string asm_code = R"(
  {
    __asm__ __volatile__(
      "mma{.sparse}.sync.aligned{.shape}{.alayout}{.blayout}{.saturate}{.dtype}{.atype}{.btype}{.ctype}{.bitop}"
      "{templates};\n"
      : {outputs}
      : {inputs});
  }
)";
  auto [templates_str, inputs_str, outputs_str] =
      GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);

  // replace patterns
  // Phase 1: structural placeholders in the asm template above.
  Replacer replacer;
  replacer.register_rule("{.sparse}", sparse ? ".sp" : "");
  replacer.register_rule("{.shape}", "." + shape);
  replacer.register_rule("{.saturate}", saturate ? ".satfinite" : "");
  replacer.register_rule("{.alayout}", "." + A_layout);
  replacer.register_rule("{.blayout}", "." + B_layout);
  replacer.register_rule("{.atype}", ptx::DTypeToString(dtype_a));
  replacer.register_rule("{.btype}", ptx::DTypeToString(dtype_b));
  replacer.register_rule("{.ctype}", ptx::DTypeToString(dtype_c));
  // The D (output) operand shares the accumulator's data type.
  replacer.register_rule("{.dtype}", ptx::DTypeToString(dtype_c));
  replacer.register_rule("{.bitop}", bit_op.empty() ? "" : "." + bit_op + ".popc");
  replacer.register_rule("{templates}", templates_str);
  replacer.register_rule("{outputs}", outputs_str);
  replacer.register_rule("{inputs}", inputs_str);
  asm_code = replacer.rewrite(asm_code);
  replacer.empty_rules();
  // Phase 2: substitute the single-letter pointer placeholders A-F emitted by
  // GetMMAOperands. NOTE(review): these rules match ANY occurrence of the bare
  // capital letter in the rewritten code, so correctness relies on phase-1
  // output containing no stray 'A'..'F' characters — verify when changing the
  // asm template.
  replacer.register_rule("A", a_ptr + " + " + a_elem_offset);
  replacer.register_rule("B", b_ptr + " + " + b_elem_offset);
  replacer.register_rule("C", c_ptr + " + " + c_elem_offset);
  replacer.register_rule("D", c_ptr + " + " + c_elem_offset);
  replacer.register_rule("E", metadata + " + " + metadata_offset);
  replacer.register_rule("F", sparsity_selector);
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}
573
/*!
 * \brief Build the template string and output-constraint string for one
 *        inline-asm ldmatrix statement loading `num` fragments into the local
 *        buffer at local_ptr + local_elem_offset.
 * \return A tuple (templates, outputs).
 */
inline std::tuple<std::string, std::string> GetLoadMatrixOperands(
    int num, const std::string& local_ptr, const std::string& local_elem_offset) {
  std::stringstream fragment_list, constraint_list;
  int operand_id = 0;
  // Destination register list "{%0, %1, ...}" followed by the address operand "[%k]".
  fragment_list << "{%" << operand_id++;
  for (int idx = 1; idx < num; ++idx) {
    fragment_list << ", %" << operand_id++;
  }
  fragment_list << "}, [%" << operand_id++ << "]";
  // One "=r" output constraint per destination register; the local buffer is
  // reinterpreted as an array of 32-bit words.
  const std::string ptr_cast = "(unsigned *)";
  for (int idx = 0; idx < num; ++idx) {
    if (idx != 0) {
      constraint_list << ", ";
    }
    constraint_list << "\"=r\"((" << ptr_cast << "(" << local_ptr << " + " << local_elem_offset
                    << "))[" << idx << "])";
  }
  return std::make_tuple(fragment_list.str(), constraint_list.str());
}
595
/*!
 * \brief Print the CUDA source of one inline-asm ldmatrix statement that loads
 *        `num` 8x8 .b16 fragments from shared memory into local registers.
 * \param trans Whether to emit the .trans (transposed) variant.
 * \param num Number of matrices to load; must be 1, 2 or 4.
 * \param type Matrix element type; only ".b16" is accepted.
 */
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
                                    const std::string& local_ptr,
                                    const std::string& local_elem_offset,
                                    const std::string& smem_ptr,
                                    const std::string& smem_elem_offset) {
  CHECK(num == 1 || num == 2 || num == 4) << "ldmatrix only accept loading 1/2/4 matrices.";
  ptx::DataType data_type = ptx::DTypeFromString(type);
  CHECK(data_type == ptx::DataType::kBit16) << "ldmatrix only accept matrix with type .b16.";
  // First asm block converts the generic shared-memory pointer to a 32-bit
  // shared-space address (cvta + cvt); second block issues the ldmatrix itself.
  std::string asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    __asm__ __volatile__(
      "ldmatrix.sync.aligned{.shape}{.num}{.trans}{.ss}{.type}"
      "{templates};\n"
      : {outputs}
      : "r"(addr)
    );
  }
)";
  auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);

  Replacer replacer;
  replacer.register_rule("{.shape}", ".m8n8");  // the only shape ldmatrix supports here
  replacer.register_rule("{.num}", ".x" + std::to_string(num));
  replacer.register_rule("{.trans}", trans ? ".trans" : "");
  replacer.register_rule("{.ss}", ".shared");
  replacer.register_rule("{.type}", ptx::DTypeToString(data_type));
  replacer.register_rule("{smem_addr}", smem_ptr + " + " + smem_elem_offset);
  replacer.register_rule("{templates}", templates_str);
  replacer.register_rule("{outputs}", outputs_str);
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}
634
/*!
 * \brief Print the CUDA source of one inline-asm cp.async statement that copies
 *        `bytes` bytes from global memory to shared memory asynchronously.
 * \param bytes Copy size as a decimal string; selects the ".cg" variant when
 *        exactly "16", ".ca" otherwise.
 */
std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes) {
  // First asm block converts the generic shared-memory pointer to a 32-bit
  // shared-space address; second block issues the cp.async copy. The "n"
  // constraint requires {bytes} to be a compile-time constant in the emitted CUDA.
  std::string asm_code = R"(
  {
    unsigned int addr;
    __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void *)({smem_addr}))
    );
    __asm__ __volatile__(
      "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
      :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
    );
  }
)";
  Replacer replacer;
  replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
  replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
  replacer.register_rule("{bytes}", bytes);
  // String comparison against "16": only the exact literal selects cg.
  replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
  asm_code = replacer.rewrite(asm_code);
  return asm_code;
}
661
662} // namespace codegen
663} // namespace tvm
664