1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file ptx.h
22 * \brief Code generation with inlined PTX code.
23 */
24#ifndef TVM_TARGET_SOURCE_PTX_H_
25#define TVM_TARGET_SOURCE_PTX_H_
26
27#include <tvm/runtime/logging.h>
28
29#include <string>
30#include <tuple>
31
32namespace tvm {
33namespace codegen {
34
/*!
 * \brief Print MMA assembly string given parameters.
 *
 * This header only declares the function; every operand is passed as a string.
 *
 * \param shape The shape string mMnNkK
 * \param A_layout The layout of multiplicand A, can be either "row" or "col".
 * \param B_layout The layout of multiplicand B, can be either "row" or "col".
 * \param A_dtype The data type of multiplicand A.
 * \param B_dtype The data type of multiplicand B.
 * \param C_dtype The data type of accumulator C.
 * \param a_ptr Pointer to buffer A.
 * \param a_offset The offset of element in A.
 * \param b_ptr Pointer to buffer B.
 * \param b_offset The offset of element in B.
 * \param c_ptr Pointer to buffer C.
 * \param c_offset The offset of element in C.
 * \param metadata Pointer to metadata buffer (only used for sparse mma).
 * \param metadata_offset The offset of element in metadata.
 * \param sparsity_selector The sparsity selector in sparse mma.
 * \param bit_op The bit operator used in 1-bit mma, can be either "xor" or "and".
 * \param sparse Whether it's sparse mma or not.
 * \param saturate Whether saturate output or not.
 * \return The MMA assembly string.
 */
std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layout,
                             const std::string& B_layout, const std::string& A_dtype,
                             const std::string& B_dtype, const std::string& C_dtype,
                             const std::string& a_ptr, const std::string& a_offset,
                             const std::string& b_ptr, const std::string& b_offset,
                             const std::string& c_ptr, const std::string& c_offset,
                             const std::string& metadata, const std::string& metadata_offset,
                             const std::string& sparsity_selector, const std::string& bit_op,
                             bool sparse, bool saturate);
65
/*!
 * \brief Print ldmatrix assembly string given parameters.
 *
 * This header only declares the function.
 *
 * \param trans Whether the matrix is loaded in column major format or not.
 * \param num Number of matrices to load.
 * \param type The data type in the matrix, .b16 is the only accepted data type.
 * \param local_ptr Pointer to local buffer.
 * \param local_elem_offset The offset of the element to store in the local buffer.
 * \param smem_ptr Pointer to the shared memory buffer to load.
 * \param smem_elem_offset The offset of the start element of the row to load in shared memory.
 * \return The ldmatrix assembly string.
 */
std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type,
                                    const std::string& local_ptr,
                                    const std::string& local_elem_offset,
                                    const std::string& smem_ptr,
                                    const std::string& smem_elem_offset);
81
/*!
 * \brief Print ptx cp.async assembly string given parameters.
 *
 * This header only declares the function; note that the byte count is passed
 * as a string, like the other operands.
 *
 * \param shared_ptr The pointer to the destination shared memory.
 * \param shared_elem_offset The offset into the shared memory.
 * \param global_ptr The pointer to the global memory.
 * \param global_elem_offset The offset into the global memory.
 * \param bytes The number of bytes to copy, valid values are 4, 8, and 16.
 * \return The cp.async assembly string.
 */
std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
                                 const std::string& shared_elem_offset,
                                 const std::string& global_ptr,
                                 const std::string& global_elem_offset, const std::string& bytes);
94
95} // namespace codegen
96} // namespace tvm
97
98#endif // TVM_TARGET_SOURCE_PTX_H_
99