1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Infer TensorCore metadata from tensor intrinsic.
22 * \file tensorcore_fragment.cc
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/expr.h>
26#include <tvm/tir/stmt_functor.h>
27#include <tvm/tir/transform.h>
28
29#include <unordered_map>
30#include <unordered_set>
31
32#include "../../runtime/thread_storage_scope.h"
33#include "ir_utils.h"
34#include "storage_access.h"
35
36namespace tvm {
37namespace tir {
38
39// Get fragment information from tensor intrinsics
40class FragmentGetter : public StmtExprVisitor {
41 public:
42 void VisitExpr_(const CallNode* op) final {
43 StmtExprVisitor::VisitExpr_(op);
44
45 if (op->op.same_as(builtin::tvm_load_matrix_sync()) ||
46 op->op.same_as(builtin::tvm_store_matrix_sync())) {
47 // Get shape and layout information from load and store intrinsic
48 ICHECK_EQ(op->args.size(), 8U);
49 const VarNode* buffer_var = op->args[0].as<VarNode>();
50 ICHECK(buffer_var);
51 // Get shape
52 const IntImmNode* m = op->args[1].as<IntImmNode>();
53 const IntImmNode* n = op->args[2].as<IntImmNode>();
54 const IntImmNode* k = op->args[3].as<IntImmNode>();
55 const StringImmNode* layout = op->args[7].as<StringImmNode>();
56 ICHECK(m);
57 ICHECK(n);
58 ICHECK(k);
59 ICHECK(layout);
60
61 std::string scope = GetPtrStorageScope(GetRef<Var>(buffer_var));
62 if (fragments.count(buffer_var)) {
63 // check if the fragment has met before
64 FragmentInfo info = fragments[buffer_var];
65 ICHECK_EQ(m->value, info.m);
66 ICHECK_EQ(n->value, info.n);
67 ICHECK_EQ(k->value, info.k);
68 if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
69 ICHECK_EQ(layout->value, info.layout);
70 }
71 } else {
72 // store metadata
73 FragmentInfo info;
74 if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
75 info = FragmentInfo(m->value, n->value, k->value, layout->value, scope);
76 } else if (scope == "wmma.accumulator") {
77 info = FragmentInfo(m->value, n->value, k->value, "", scope);
78 }
79 fragments[buffer_var] = info;
80 }
81 } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
82 // Get shape information from fill intrinsic
83 ICHECK_EQ(op->args.size(), 6U);
84 const VarNode* buffer_var = op->args[0].as<VarNode>();
85 ICHECK(buffer_var);
86 // Get shape
87 const IntImmNode* m = op->args[1].as<IntImmNode>();
88 const IntImmNode* n = op->args[2].as<IntImmNode>();
89 const IntImmNode* k = op->args[3].as<IntImmNode>();
90 ICHECK(m);
91 ICHECK(n);
92 ICHECK(k);
93
94 std::string scope = GetPtrStorageScope(GetRef<Var>(buffer_var));
95 if (fragments.count(buffer_var)) {
96 FragmentInfo info = fragments[buffer_var];
97 ICHECK_EQ(m->value, info.m);
98 ICHECK_EQ(n->value, info.n);
99 ICHECK_EQ(k->value, info.k);
100 } else {
101 // default to row major ordering
102 FragmentInfo info(m->value, n->value, k->value, "row_major", scope);
103 fragments[buffer_var] = info;
104 }
105 }
106 }
107
108 // Get memory scope
109 void VisitStmt_(const AttrStmtNode* op) final { StmtExprVisitor::VisitStmt_(op); }
110
111 // Fragment metadata for all fragments
112 std::unordered_map<const VarNode*, FragmentInfo> fragments;
113};
114
115std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const Stmt& stmt) {
116 FragmentGetter getter;
117 getter(stmt);
118 return std::move(getter.fragments);
119}
120
121// Check shape of fragment making sure it is a valid shape for tvm_mma_sync
122class FragmentChecker : public StmtExprVisitor {
123 public:
124 explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {}
125
126 void VisitExpr_(const CallNode* op) final {
127 StmtExprVisitor::VisitExpr_(op);
128 // Check shape when calling tvm_mma_sync
129 if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) {
130 ICHECK_EQ(op->args.size(), 8U);
131 const VarNode* buffer_var_d = op->args[0].as<VarNode>();
132 const VarNode* buffer_var_a = op->args[2].as<VarNode>();
133 const VarNode* buffer_var_b = op->args[4].as<VarNode>();
134 const VarNode* buffer_var_c = op->args[6].as<VarNode>();
135 ICHECK(buffer_var_d);
136 ICHECK(buffer_var_a);
137 ICHECK(buffer_var_b);
138 ICHECK(buffer_var_c);
139
140 // Check all fragment A, B, C and D have the same shape
141 ICHECK(CheckShape(buffer_var_d, buffer_var_a));
142 ICHECK(CheckShape(buffer_var_d, buffer_var_b));
143 ICHECK(CheckShape(buffer_var_d, buffer_var_c));
144 }
145 }
146
147 private:
148 // A tool for checking shapes of two fragments
149 bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) {
150 CHECK(fragment_getter.fragments.count(buffer1))
151 << "Tensorecore fragment " << buffer1->name_hint
152 << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before "
153 "use.";
154 CHECK(fragment_getter.fragments.count(buffer2))
155 << "Tensorecore fragment " << buffer2->name_hint
156 << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before "
157 "use.";
158 FragmentInfo info1 = fragment_getter.fragments.at(buffer1);
159 FragmentInfo info2 = fragment_getter.fragments.at(buffer2);
160 return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k;
161 }
162 // Fragment infomation
163 const FragmentGetter& fragment_getter;
164};
165
166// Store the metadata into attributes
167class InferFragmenter : public StmtMutator {
168 public:
169 explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {}
170
171 Stmt VisitStmt_(const AllocateNode* op) final {
172 Stmt stmt = StmtMutator::VisitStmt_(op);
173 const VarNode* buffer = op->buffer_var.get();
174 if (fragment_getter.fragments.count(buffer)) {
175 // Add attribute to fragments allocation
176 FragmentInfo info = fragment_getter.fragments.at(buffer);
177
178 // Add shape attribute to all fragments
179 std::string shape =
180 std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k);
181 PrimExpr shape_expr = StringImm(shape);
182 Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt);
183 if (info.layout != "") {
184 // Add shape attribute to matrix_a and matrix_b
185 Stmt layout_attr =
186 AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr);
187 return layout_attr;
188 } else {
189 return shape_attr;
190 }
191 }
192 return stmt;
193 }
194
195 private:
196 // Fragment infomation
197 const FragmentGetter& fragment_getter;
198};
199
200Stmt InferFragment(Stmt stmt) {
201 FragmentGetter getter;
202 getter(stmt);
203 FragmentChecker checker(getter);
204 checker(stmt);
205 stmt = InferFragmenter(getter)(std::move(stmt));
206 return stmt;
207}
208
209namespace transform {
210
211Pass InferFragment() {
212 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
213 auto* n = f.CopyOnWrite();
214 n->body = InferFragment(std::move(n->body));
215 return f;
216 };
217 return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment", {});
218}
219
220TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment);
221
222} // namespace transform
223} // namespace tir
224} // namespace tvm
225