1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * Lower warp memory to use local memory |
22 | * and shuffle intrinsics. |
23 | * |
24 | * \file lower_warp_memory.cc |
25 | */ |
26 | // Thanks to Andrew Adams and Vinod Grover for |
27 | // explaining the concept of warp shuffle. |
28 | #include <tvm/arith/analyzer.h> |
29 | #include <tvm/arith/pattern.h> |
30 | #include <tvm/runtime/registry.h> |
31 | #include <tvm/target/target.h> |
32 | #include <tvm/tir/analysis.h> |
33 | #include <tvm/tir/builtin.h> |
34 | #include <tvm/tir/expr.h> |
35 | #include <tvm/tir/op.h> |
36 | #include <tvm/tir/stmt_functor.h> |
37 | #include <tvm/tir/transform.h> |
38 | |
39 | #include <unordered_set> |
40 | |
41 | #include "../../arith/pattern_match.h" |
42 | #include "../../runtime/thread_storage_scope.h" |
43 | #include "ir_utils.h" |
44 | #include "update_pointer_storage_scope.h" |
45 | |
46 | namespace tvm { |
47 | namespace tir { |
48 | |
49 | // Rewrite Rule |
50 | // |
51 | // There is no special warp memory in most GPUs. |
52 | // Instead, we can stripe the data into threads |
53 | // and store the data into local memory. |
54 | // |
55 | // This requires us to do the following rewriting: |
56 | // - Rewrite allocation to use local memory. |
57 | // - Rewrite store of warp memory to local store. |
58 | // - Rewrite load of warp memory to local plus a shuffle. |
59 | // |
60 | // Define a generic shuffle intrinsic warp_shuffle(data, warp_index). |
61 | // We can use the following rewriting rule |
62 | // |
63 | // Before rewrite, |
64 | // |
65 | // alloc warp warp_mem[n * width * m] |
66 | // store warp_mem[m * warp_index + (width * m) * y + x] |
67 | // load warp_mem[m * z + (width * m) * y + x] |
68 | // subject to x \in [0, m), y \in [0, n) |
69 | // |
// where width equals the extent of threadIdx.x, which should
// be no larger than the warp size
72 | // |
73 | // After rewrite: |
74 | // |
//   alloc local local_mem[n * m]
//   store local_mem[m * y + x]
//   warp_shuffle(load local_mem[m * y + x], z)
//   subject to (m * y + x) is invariant to warp_index
79 | // |
// If width == warp size, we are shuffling on full warps.
// Otherwise, we are virtually shuffling on sub-warps,
// whose size equals width. In this case, you can imagine
// that a warp consists of only `width` threads. Width is passed
// as an argument to the shuffle primitive, and will be
// lowered to the device code if the target supports it.
86 | // |
// A limitation of this sub-warp approach is that users
// cannot shuffle across the sub-warp boundary (i.e. shuffle
// with threadIdx.y or threadIdx.z indices). It can be solved
// by fusing threadIdx.x to the warp size, or by improving the
// analyzer to detect all 3 thread axes, which is left for
// future improvements.
93 | |
94 | // Algorithm |
95 | // |
// To implement this rewrite rule, we can follow these steps:
// For each warp memory alloc
// - Use linear pattern detector on store index to find m
// - Deduce n given width and alloc size
// - Now that we have m, n, width, we can proceed with the rewrite
101 | |
102 | // Visitor to find m in pattern |
103 | // store warp_mem[m * warp_index + (width * m) * y + x] |
104 | class WarpStoreCoeffFinder : private StmtExprVisitor { |
105 | public: |
106 | WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer) |
107 | : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {} |
108 | // find the warp co-efficient in the statement given the warp size |
109 | int Find(const Stmt& stmt) { |
110 | this->VisitStmt(stmt); |
111 | return warp_coeff_; |
112 | } |
113 | |
114 | private: |
115 | /// Visitor implementation |
116 | void VisitExpr_(const CallNode* op) final { |
117 | if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) { |
118 | UpdatePattern(op->args[4]); |
119 | } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) { |
120 | auto* local_size = op->args[0].as<IntImmNode>(); |
121 | ICHECK(local_size) << "Integer expected for the first argument of mma_fill" ; |
122 | warp_coeff_ = local_size->value; |
123 | } |
124 | |
125 | StmtExprVisitor::VisitExpr_(op); |
126 | } |
127 | |
128 | void VisitStmt_(const StoreNode* op) final { |
129 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
130 | } |
131 | |
132 | void VisitStmt_(const BufferStoreNode* op) final { |
133 | if (op->buffer->data.get() != buffer_) { |
134 | StmtVisitor::VisitStmt_(op); |
135 | return; |
136 | } |
137 | |
138 | ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " |
139 | << "Has StorageFlatten (TE-based schedule) or " |
140 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
141 | |
142 | PrimExpr index = op->indices[0]; |
143 | if (op->value.dtype().lanes() != 1) { |
144 | arith::PVar<PrimExpr> base; |
145 | ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index)) |
146 | << "LowerWarpMemory failed due to store index=" << index |
147 | << ", can only handle continuous store" ; |
148 | UpdatePattern(base.Eval()); |
149 | |
150 | index = base.Eval(); |
151 | } |
152 | |
153 | UpdatePattern(index); |
154 | } |
155 | |
156 | void UpdatePattern(const PrimExpr& index) { |
157 | Array<PrimExpr> m = arith::DetectLinearEquation(index, {warp_index_}); |
158 | ICHECK_EQ(m.size(), 2U) |
159 | << "LowerWarpMemory failed. Could not simplify the store index `" << index |
160 | << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in " |
161 | "thread local registers and shuffling values between these registers. Currently only " |
162 | "linear equation indices are supported." ; |
163 | PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]); |
164 | const auto* mcoeff_as_int = mcoeff.as<IntImmNode>(); |
165 | ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0) |
166 | << "LowerWarpMemory failed due to store index=" << index |
167 | << ", require positive constant coefficient on warp index " << warp_index_ << " but get " |
168 | << mcoeff; |
169 | |
170 | if (warp_coeff_ != 0) { |
171 | ICHECK_EQ(warp_coeff_, mcoeff_as_int->value) |
172 | << "LowerWarpMemory failed due to two different store coefficient to warp index" ; |
173 | } else { |
174 | warp_coeff_ = mcoeff_as_int->value; |
175 | } |
176 | } |
177 | |
178 | // The buffer variable |
179 | const VarNode* buffer_; |
180 | // the warp index |
181 | Var warp_index_; |
182 | // the coefficient |
183 | int64_t warp_coeff_{0}; |
184 | // analyzer. |
185 | arith::Analyzer* analyzer_; |
186 | }; |
187 | |
188 | // Visitor to find the warp index |
189 | class WarpIndexFinder : private StmtVisitor { |
190 | public: |
191 | explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {} |
192 | // find the warp co-efficient and the shuffle width in the statement |
193 | std::pair<Var, int> Find(const Stmt& stmt) { |
194 | this->VisitStmt(stmt); |
195 | ICHECK(warp_index_.defined()) |
196 | << "Cannot find warp index(threadIdx.x) within the scope of warp memory" ; |
197 | return std::make_pair(warp_index_->var, width_); |
198 | } |
199 | |
200 | private: |
201 | /// Visitor implementation |
202 | void VisitStmt_(const AttrStmtNode* op) final { |
203 | if (op->attr_key == attr::thread_extent) { |
204 | IterVar iv = Downcast<IterVar>(op->node); |
205 | if (iv->thread_tag == "threadIdx.x" ) { |
206 | auto* value_as_int = op->value.as<IntImmNode>(); |
207 | ICHECK(value_as_int && value_as_int->value <= warp_size_ && |
208 | warp_size_ % value_as_int->value == 0) |
209 | << "Expect threadIdx.x 's size to be no larger than, and a factor of" |
210 | << " warp size(" << warp_size_ << ")" |
211 | << " to enable warp memory" |
212 | << " but get " << op->value << " instead" ; |
213 | if (warp_index_.defined()) { |
214 | ICHECK(warp_index_.same_as(iv)) |
215 | << "Find two instance of " << warp_index_->thread_tag << " in the same kernel. " |
216 | << "Please create it using thread_axis once and reuse the axis " |
217 | << "across multiple binds in the same kernel" ; |
218 | } else { |
219 | width_ = value_as_int->value; |
220 | warp_index_ = iv; |
221 | } |
222 | } |
223 | } |
224 | StmtVisitor::VisitStmt_(op); |
225 | } |
226 | // warp size |
227 | int warp_size_{0}; |
228 | // number of threads involved in one shuffle |
229 | int width_{0}; |
230 | // the warp index |
231 | IterVar warp_index_{nullptr}; |
232 | }; |
233 | // Mutator to change the read pattern |
234 | class WarpAccessRewriter : protected StmtExprMutator { |
235 | public: |
236 | explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer) |
237 | : warp_size_(warp_size), analyzer_(analyzer) {} |
238 | // Rewrite the allocate statement which transforms |
239 | // warp memory to local memory. |
240 | Stmt Rewrite(const AllocateNode* op) { |
241 | buffer_ = op->buffer_var.get(); |
242 | int alloc_size = op->ConstantAllocationSize(); |
243 | ICHECK_GT(alloc_size, 0) << "warp memory only support constant alloc size" ; |
244 | alloc_size *= op->dtype.lanes(); |
245 | std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body); |
246 | warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body); |
247 | |
248 | // Align the local memory size. The number of elements may not |
249 | // be a multiple of width_ * warp_coeff_; round it up. |
250 | int factor = width_ * warp_coeff_; |
251 | ICHECK_NE(factor, 0) << "Divide by zero" ; |
252 | warp_group_ = (alloc_size + (factor - 1)) / factor; |
253 | alloc_size = warp_group_ * factor; |
254 | |
255 | return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)}, |
256 | op->condition, this->VisitStmt(op->body)); |
257 | } |
258 | |
259 | protected: |
260 | PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector<int>& indices) { |
261 | Array<PrimExpr> new_args = op->args; |
262 | for (int i : indices) { |
263 | if (op->args[i].get() == buffer_) { |
264 | PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first; |
265 | new_args.Set(i + 1, local_index); |
266 | } |
267 | } |
268 | return Call(op->dtype, op->op, new_args); |
269 | } |
270 | |
271 | PrimExpr VisitExpr_(const CallNode* op) override { |
272 | if (op->op.same_as(builtin::ptx_mma())) { |
273 | return RewriteIndicesAt(op, {6, 8, 10}); |
274 | } |
275 | |
276 | if (op->op.same_as(builtin::ptx_ldmatrix())) { |
277 | return RewriteIndicesAt(op, {3}); |
278 | } |
279 | |
280 | if (op->op.same_as(builtin::mma_store())) { |
281 | return RewriteIndicesAt(op, {3}); |
282 | } |
283 | |
284 | if (op->op.same_as(builtin::mma_fill())) { |
285 | return RewriteIndicesAt(op, {1}); |
286 | } |
287 | |
288 | return StmtExprMutator::VisitExpr_(op); |
289 | } |
290 | |
291 | PrimExpr VisitExpr_(const VarNode* op) override { |
292 | ICHECK(op != buffer_) << "Cannot access address of warp memory directly" ; |
293 | return StmtExprMutator::VisitExpr_(op); |
294 | } |
295 | |
296 | Stmt VisitStmt_(const StoreNode* op) override { |
297 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
298 | } |
299 | |
300 | PrimExpr VisitExpr_(const LoadNode* op) override { |
301 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
302 | } |
303 | |
304 | Stmt VisitStmt_(const BufferStoreNode* op) override { |
305 | auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
306 | |
307 | if (store->buffer->data.get() == buffer_) { |
308 | ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. " |
309 | << "Has StorageFlatten (TE-based schedule) or " |
310 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
311 | |
312 | auto [local_index, group] = SplitIndexByGroup(store->indices[0]); |
313 | (void)group; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
314 | |
315 | auto writer = store.CopyOnWrite(); |
316 | writer->indices = {local_index}; |
317 | } |
318 | |
319 | return std::move(store); |
320 | } |
321 | |
322 | PrimExpr VisitExpr_(const BufferLoadNode* op) override { |
323 | auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
324 | |
325 | if (load->buffer->data.get() != buffer_) { |
326 | return std::move(load); |
327 | } |
328 | |
329 | ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " |
330 | << "Has StorageFlatten (TE-based schedule) or " |
331 | << "FlattenBuffer (TIR-based schedules) been run?" ; |
332 | |
333 | auto [local_index, group] = SplitIndexByGroup(op->indices[0]); |
334 | // invariance: local index must do not contain warp id |
335 | ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) |
336 | << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] |
337 | << " local_index=" << local_index; |
338 | |
339 | auto writer = load.CopyOnWrite(); |
340 | writer->indices = {local_index}; |
341 | |
342 | if (analyzer_->CanProveEqual(group, warp_index_)) { |
343 | return std::move(load); |
344 | } |
345 | |
346 | PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {}); |
347 | return Call(load.dtype(), builtin::tvm_warp_shuffle(), {mask, load, group, width_, warp_size_}); |
348 | } |
349 | |
350 | // Split the index to the two component |
351 | // <local_index, source_index> |
352 | // local index is the index in the local |
353 | // source index is the corresponding source index |
354 | // in this access pattern. |
355 | std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) { |
356 | if (index.dtype().lanes() != 1) { |
357 | arith::PVar<PrimExpr> base; |
358 | ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); |
359 | |
360 | auto [local_index, group] = SplitIndexByGroup(base.Eval()); |
361 | local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); |
362 | return std::make_pair(local_index, group); |
363 | } |
364 | PrimExpr m = make_const(index.dtype(), warp_coeff_); |
365 | |
366 | // simple case, warp index is on the highest. |
367 | if (warp_group_ == 1) { |
368 | PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); |
369 | PrimExpr z = analyzer_->canonical_simplify(indexdiv(index, m)); |
370 | return std::make_pair(x, z); |
371 | } else { |
372 | PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m)); |
373 | PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_); |
374 | y = y * m + x; |
375 | PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m); |
376 | return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z)); |
377 | } |
378 | } |
379 | |
380 | private: |
381 | // the warp size |
382 | int warp_size_{0}; |
383 | // The buffer variable |
384 | const VarNode* buffer_; |
385 | // number of threads involved in one shuffle |
386 | int width_{0}; |
387 | // Warp index |
388 | Var warp_index_; |
389 | // the coefficient m |
390 | int warp_coeff_{0}; |
391 | // the coefficient n |
392 | int warp_group_{0}; |
393 | // Internal analyzer |
394 | arith::Analyzer* analyzer_; |
395 | }; |
396 | |
397 | // Bind bound information of variables to make analyzer more effective |
398 | // TODO(tqchen): consider a pass to inline the bound info into the expr |
399 | // so analysis can be context independent. |
400 | class BindVarBoundInfo : public StmtVisitor { |
401 | public: |
402 | explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {} |
403 | |
404 | void VisitStmt_(const ForNode* op) final { |
405 | const Var& loop_var = op->loop_var; |
406 | analyzer_->Bind(loop_var, Range::FromMinExtent(op->min, op->extent)); |
407 | StmtVisitor::VisitStmt_(op); |
408 | } |
409 | |
410 | void VisitStmt_(const AttrStmtNode* op) { |
411 | if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { |
412 | IterVar iv = Downcast<IterVar>(op->node); |
413 | ICHECK_NE(iv->thread_tag.length(), 0U); |
414 | if (!var_dom_.count(iv->var.get())) { |
415 | Range dom = Range::FromMinExtent(0, op->value); |
416 | var_dom_[iv->var.get()] = dom; |
417 | analyzer_->Bind(iv->var, dom); |
418 | } |
419 | } |
420 | StmtVisitor::VisitStmt_(op); |
421 | } |
422 | |
423 | protected: |
424 | // internal analyzer. |
425 | arith::Analyzer* analyzer_; |
426 | // variable domain |
427 | std::unordered_map<const VarNode*, Range> var_dom_; |
428 | }; |
429 | |
430 | // Mutator to change the read pattern |
431 | class WarpMemoryRewriter : private StmtMutator { |
432 | public: |
433 | explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {} |
434 | |
435 | Stmt Rewrite(Stmt stmt) { |
436 | if (warp_size_ == 1) return stmt; |
437 | BindVarBoundInfo binder(&analyzer_); |
438 | binder(stmt); |
439 | stmt = operator()(std::move(stmt)); |
440 | return stmt; |
441 | } |
442 | |
443 | std::unordered_map<const VarNode*, String> new_storage_scopes_; |
444 | |
445 | private: |
446 | Stmt VisitStmt_(const AllocateNode* op) { |
447 | auto ret = StmtMutator::VisitStmt_(op); |
448 | op = ret.as<AllocateNode>(); |
449 | if (GetPtrStorageScope(op->buffer_var) == "warp" ) { |
450 | new_storage_scopes_[op->buffer_var.get()] = "local" ; |
451 | WarpAccessRewriter rewriter(warp_size_, &analyzer_); |
452 | ret = rewriter.Rewrite(op); |
453 | } |
454 | return ret; |
455 | } |
456 | |
457 | int warp_size_{0}; |
458 | arith::Analyzer analyzer_; |
459 | // variable domain |
460 | std::unordered_map<const VarNode*, Range> var_dom_; |
461 | }; |
462 | |
463 | namespace transform { |
464 | |
465 | Pass LowerWarpMemory() { |
466 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
467 | auto* n = f.CopyOnWrite(); |
468 | auto target = f->GetAttr<Target>(tvm::attr::kTarget); |
469 | ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute" ; |
470 | int warp_size = target.value()->GetAttr<Integer>("thread_warp_size" , 1).value().IntValue(); |
471 | WarpMemoryRewriter warp_memory_rewriter(warp_size); |
472 | auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); |
473 | n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); |
474 | return f; |
475 | }; |
476 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory" , {}); |
477 | } |
478 | |
479 | TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory" ).set_body_typed(LowerWarpMemory); |
480 | |
481 | } // namespace transform |
482 | |
483 | } // namespace tir |
484 | } // namespace tvm |
485 | |