1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Infer TensorCore metadata from tensor intrinsic. |
22 | * \file tensorcore_fragment.cc |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/expr.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | #include <tvm/tir/transform.h> |
28 | |
29 | #include <unordered_map> |
30 | #include <unordered_set> |
31 | |
32 | #include "../../runtime/thread_storage_scope.h" |
33 | #include "ir_utils.h" |
34 | #include "storage_access.h" |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | // Get fragment information from tensor intrinsics |
40 | class FragmentGetter : public StmtExprVisitor { |
41 | public: |
42 | void VisitExpr_(const CallNode* op) final { |
43 | StmtExprVisitor::VisitExpr_(op); |
44 | |
45 | if (op->op.same_as(builtin::tvm_load_matrix_sync()) || |
46 | op->op.same_as(builtin::tvm_store_matrix_sync())) { |
47 | // Get shape and layout information from load and store intrinsic |
48 | ICHECK_EQ(op->args.size(), 8U); |
49 | const VarNode* buffer_var = op->args[0].as<VarNode>(); |
50 | ICHECK(buffer_var); |
51 | // Get shape |
52 | const IntImmNode* m = op->args[1].as<IntImmNode>(); |
53 | const IntImmNode* n = op->args[2].as<IntImmNode>(); |
54 | const IntImmNode* k = op->args[3].as<IntImmNode>(); |
55 | const StringImmNode* layout = op->args[7].as<StringImmNode>(); |
56 | ICHECK(m); |
57 | ICHECK(n); |
58 | ICHECK(k); |
59 | ICHECK(layout); |
60 | |
61 | std::string scope = GetPtrStorageScope(GetRef<Var>(buffer_var)); |
62 | if (fragments.count(buffer_var)) { |
63 | // check if the fragment has met before |
64 | FragmentInfo info = fragments[buffer_var]; |
65 | ICHECK_EQ(m->value, info.m); |
66 | ICHECK_EQ(n->value, info.n); |
67 | ICHECK_EQ(k->value, info.k); |
68 | if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b" ) { |
69 | ICHECK_EQ(layout->value, info.layout); |
70 | } |
71 | } else { |
72 | // store metadata |
73 | FragmentInfo info; |
74 | if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b" ) { |
75 | info = FragmentInfo(m->value, n->value, k->value, layout->value, scope); |
76 | } else if (scope == "wmma.accumulator" ) { |
77 | info = FragmentInfo(m->value, n->value, k->value, "" , scope); |
78 | } |
79 | fragments[buffer_var] = info; |
80 | } |
81 | } else if (op->op.same_as(builtin::tvm_fill_fragment())) { |
82 | // Get shape information from fill intrinsic |
83 | ICHECK_EQ(op->args.size(), 6U); |
84 | const VarNode* buffer_var = op->args[0].as<VarNode>(); |
85 | ICHECK(buffer_var); |
86 | // Get shape |
87 | const IntImmNode* m = op->args[1].as<IntImmNode>(); |
88 | const IntImmNode* n = op->args[2].as<IntImmNode>(); |
89 | const IntImmNode* k = op->args[3].as<IntImmNode>(); |
90 | ICHECK(m); |
91 | ICHECK(n); |
92 | ICHECK(k); |
93 | |
94 | std::string scope = GetPtrStorageScope(GetRef<Var>(buffer_var)); |
95 | if (fragments.count(buffer_var)) { |
96 | FragmentInfo info = fragments[buffer_var]; |
97 | ICHECK_EQ(m->value, info.m); |
98 | ICHECK_EQ(n->value, info.n); |
99 | ICHECK_EQ(k->value, info.k); |
100 | } else { |
101 | // default to row major ordering |
102 | FragmentInfo info(m->value, n->value, k->value, "row_major" , scope); |
103 | fragments[buffer_var] = info; |
104 | } |
105 | } |
106 | } |
107 | |
108 | // Get memory scope |
109 | void VisitStmt_(const AttrStmtNode* op) final { StmtExprVisitor::VisitStmt_(op); } |
110 | |
111 | // Fragment metadata for all fragments |
112 | std::unordered_map<const VarNode*, FragmentInfo> fragments; |
113 | }; |
114 | |
115 | std::unordered_map<const VarNode*, FragmentInfo> GetTensorCoreFragmentInfo(const Stmt& stmt) { |
116 | FragmentGetter getter; |
117 | getter(stmt); |
118 | return std::move(getter.fragments); |
119 | } |
120 | |
121 | // Check shape of fragment making sure it is a valid shape for tvm_mma_sync |
122 | class FragmentChecker : public StmtExprVisitor { |
123 | public: |
124 | explicit FragmentChecker(const FragmentGetter& getter) : fragment_getter(getter) {} |
125 | |
126 | void VisitExpr_(const CallNode* op) final { |
127 | StmtExprVisitor::VisitExpr_(op); |
128 | // Check shape when calling tvm_mma_sync |
129 | if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) { |
130 | ICHECK_EQ(op->args.size(), 8U); |
131 | const VarNode* buffer_var_d = op->args[0].as<VarNode>(); |
132 | const VarNode* buffer_var_a = op->args[2].as<VarNode>(); |
133 | const VarNode* buffer_var_b = op->args[4].as<VarNode>(); |
134 | const VarNode* buffer_var_c = op->args[6].as<VarNode>(); |
135 | ICHECK(buffer_var_d); |
136 | ICHECK(buffer_var_a); |
137 | ICHECK(buffer_var_b); |
138 | ICHECK(buffer_var_c); |
139 | |
140 | // Check all fragment A, B, C and D have the same shape |
141 | ICHECK(CheckShape(buffer_var_d, buffer_var_a)); |
142 | ICHECK(CheckShape(buffer_var_d, buffer_var_b)); |
143 | ICHECK(CheckShape(buffer_var_d, buffer_var_c)); |
144 | } |
145 | } |
146 | |
147 | private: |
148 | // A tool for checking shapes of two fragments |
149 | bool CheckShape(const VarNode* buffer1, const VarNode* buffer2) { |
150 | CHECK(fragment_getter.fragments.count(buffer1)) |
151 | << "Tensorecore fragment " << buffer1->name_hint |
152 | << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before " |
153 | "use." ; |
154 | CHECK(fragment_getter.fragments.count(buffer2)) |
155 | << "Tensorecore fragment " << buffer2->name_hint |
156 | << " must be filled (with tvm_fill_fragment) or loaded (with tvm_load_matrix_sync) before " |
157 | "use." ; |
158 | FragmentInfo info1 = fragment_getter.fragments.at(buffer1); |
159 | FragmentInfo info2 = fragment_getter.fragments.at(buffer2); |
160 | return info1.m == info2.m && info1.n == info2.n && info1.k == info2.k; |
161 | } |
162 | // Fragment infomation |
163 | const FragmentGetter& fragment_getter; |
164 | }; |
165 | |
166 | // Store the metadata into attributes |
167 | class InferFragmenter : public StmtMutator { |
168 | public: |
169 | explicit InferFragmenter(const FragmentGetter& getter) : fragment_getter(getter) {} |
170 | |
171 | Stmt VisitStmt_(const AllocateNode* op) final { |
172 | Stmt stmt = StmtMutator::VisitStmt_(op); |
173 | const VarNode* buffer = op->buffer_var.get(); |
174 | if (fragment_getter.fragments.count(buffer)) { |
175 | // Add attribute to fragments allocation |
176 | FragmentInfo info = fragment_getter.fragments.at(buffer); |
177 | |
178 | // Add shape attribute to all fragments |
179 | std::string shape = |
180 | std::to_string(info.m) + ", " + std::to_string(info.n) + ", " + std::to_string(info.k); |
181 | PrimExpr shape_expr = StringImm(shape); |
182 | Stmt shape_attr = AttrStmt(op->buffer_var, attr::fragment_shape, shape_expr, stmt); |
183 | if (info.layout != "" ) { |
184 | // Add shape attribute to matrix_a and matrix_b |
185 | Stmt layout_attr = |
186 | AttrStmt(op->buffer_var, attr::fragment_layout, StringImm(info.layout), shape_attr); |
187 | return layout_attr; |
188 | } else { |
189 | return shape_attr; |
190 | } |
191 | } |
192 | return stmt; |
193 | } |
194 | |
195 | private: |
196 | // Fragment infomation |
197 | const FragmentGetter& fragment_getter; |
198 | }; |
199 | |
200 | Stmt InferFragment(Stmt stmt) { |
201 | FragmentGetter getter; |
202 | getter(stmt); |
203 | FragmentChecker checker(getter); |
204 | checker(stmt); |
205 | stmt = InferFragmenter(getter)(std::move(stmt)); |
206 | return stmt; |
207 | } |
208 | |
209 | namespace transform { |
210 | |
211 | Pass InferFragment() { |
212 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
213 | auto* n = f.CopyOnWrite(); |
214 | n->body = InferFragment(std::move(n->body)); |
215 | return f; |
216 | }; |
217 | return CreatePrimFuncPass(pass_func, 0, "tir.InferFragment" , {}); |
218 | } |
219 | |
// Expose the pass constructor to the FFI / Python frontend.
TVM_REGISTER_GLOBAL("tir.transform.InferFragment").set_body_typed(InferFragment);
221 | |
222 | } // namespace transform |
223 | } // namespace tir |
224 | } // namespace tvm |
225 | |