1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file ptx.h |
22 | * \brief Code generation with inlined PTX code. |
23 | */ |
24 | #ifndef TVM_TARGET_SOURCE_PTX_H_ |
25 | #define TVM_TARGET_SOURCE_PTX_H_ |
26 | |
27 | #include <tvm/runtime/logging.h> |
28 | |
29 | #include <string> |
30 | #include <tuple> |
31 | |
namespace tvm {
namespace codegen {

/*!
 * \brief Print the inline PTX assembly string for an MMA (matrix multiply-accumulate)
 *        instruction given its parameters.
 * \param shape The shape string of the form mMnNkK, e.g. "m16n8k16".
 * \param A_layout The layout of multiplicand A, either "row" or "col".
 * \param B_layout The layout of multiplicand B, either "row" or "col".
 * \param A_dtype The data type of multiplicand A.
 * \param B_dtype The data type of multiplicand B.
 * \param C_dtype The data type of the accumulator C (C is added to the product A * B;
 *        it is not a multiplicand).
 * \param a_ptr Pointer to buffer A.
 * \param a_offset The offset of the element in A.
 * \param b_ptr Pointer to buffer B.
 * \param b_offset The offset of the element in B.
 * \param c_ptr Pointer to buffer C.
 * \param c_offset The offset of the element in C.
 * \param metadata Pointer to the metadata buffer (only used for sparse MMA).
 * \param metadata_offset The offset of the element in the metadata buffer.
 * \param sparsity_selector The sparsity selector used in sparse MMA.
 * \param bit_op The bit operator used in 1-bit MMA, either "xor" or "and".
 * \param sparse Whether to generate a sparse MMA instruction.
 * \param saturate Whether to saturate the output.
 * \return The generated assembly code as a string.
 */
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                             const std::string& B_layout, const std::string& A_dtype,
                             const std::string& B_dtype, const std::string& C_dtype,
                             const std::string& a_ptr, const std::string& a_offset,
                             const std::string& b_ptr, const std::string& b_offset,
                             const std::string& c_ptr, const std::string& c_offset,
                             const std::string& metadata, const std::string& metadata_offset,
                             const std::string& sparsity_selector, const std::string& bit_op,
                             bool sparse, bool saturate);

/*!
 * \brief Print the inline PTX ldmatrix assembly string given its parameters.
 * \param trans Whether the matrix is loaded in column-major format.
 * \param num The number of matrices to load.
 * \param type The data type of the matrix elements; ".b16" is the only accepted type.
 * \param local_ptr Pointer to the local buffer.
 * \param local_elem_offset The offset of the element to store in the local buffer.
 * \param smem_ptr Pointer to the shared memory buffer to load from.
 * \param smem_elem_offset The offset of the start element of the row to load in shared memory.
 * \return The generated assembly code as a string.
 */
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
                                    const std::string& local_ptr,
                                    const std::string& local_elem_offset,
                                    const std::string& smem_ptr,
                                    const std::string& smem_elem_offset);

/*!
 * \brief Print the inline PTX cp.async assembly string given its parameters.
 * \param shared_ptr Pointer to the destination shared memory.
 * \param shared_elem_offset The element offset into the shared memory.
 * \param global_ptr Pointer to the source global memory.
 * \param global_elem_offset The element offset into the global memory.
 * \param bytes The number of bytes to copy; valid values are 4, 8, and 16.
 * \return The generated assembly code as a string.
 */
std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes);

}  // namespace codegen
}  // namespace tvm
97 | |
98 | #endif // TVM_TARGET_SOURCE_PTX_H_ |
99 | |