/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file lower_warp_memory.cc
 * \brief Lower warp memory to use local memory
 *  and shuffle intrinsics.
 */
// Thanks to Andrew Adams and Vinod Grover for
// explaining the concept of warp shuffle.
#include <tvm/arith/analyzer.h>
#include <tvm/arith/pattern.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../arith/pattern_match.h"
#include "../../runtime/thread_storage_scope.h"
#include "ir_utils.h"
#include "update_pointer_storage_scope.h"

namespace tvm {
namespace tir {

// Rewrite Rule
//
// There is no special warp memory in most GPUs.
// Instead, we can stripe the data across threads
// and store the data in local memory.
//
// This requires us to do the following rewriting:
// - Rewrite allocation to use local memory.
// - Rewrite store of warp memory to local store.
// - Rewrite load of warp memory to a local load plus a shuffle.
//
// Define a generic shuffle intrinsic warp_shuffle(data, warp_index).
// We can use the following rewriting rule.
//
// Before rewrite:
//
//   alloc warp warp_mem[n * width * m]
//   store warp_mem[m * warp_index + (width * m) * y + x]
//   load warp_mem[m * z + (width * m) * y + x]
//   subject to x \in [0, m), y \in [0, n)
//
// where width equals the extent of threadIdx.x, which must be
// no larger than the warp size.
//
// After rewrite:
//
//   alloc local local_mem[n * m]
//   store local_mem[m * y + x]
//   warp_shuffle(load local_mem[m * y + x], z)
//   subject to (m * y + x) being invariant to warp_index
//
// If width == warp size, we are shuffling on full warps.
// Otherwise, we are virtually shuffling on sub-warps,
// whose size equals width. In this case, you can imagine that
// a warp consists of only `width` threads. Width is passed as
// an argument to the shuffle primitive, and is lowered to
// device code if the target supports it.
//
// A limitation of this sub-warp approach is that users
// cannot shuffle across the sub-warp boundary (i.e. shuffle
// with threadIdx.y or threadIdx.z indices). This can be solved
// by fusing threadIdx.x up to the warp size, or by improving the
// analyzer to handle all three thread axes, both of which are
// left as future improvements.
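//
// As a concrete illustration (not part of the rule itself), take
// m = 2, n = 1 and width = 4. Before the rewrite, each thread
// stores two elements at warp_mem[2 * warp_index + x], x \in [0, 2),
// and a load of warp_mem[2 * z + x] reads element x owned by lane z.
// After the rewrite, each thread keeps its two elements in
// local_mem[x], and the load becomes
// warp_shuffle(load local_mem[x], z), fetching the value from
// lane z of the sub-warp of width 4.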

// Algorithm
//
// To implement this rewrite rule, we take the following steps
// for each warp memory alloc:
// - Use the linear pattern detector on the store index to find m.
// - Deduce n given width and the alloc size.
// - Now that we have m, n and width, proceed with the rewrite.
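//
// For instance, for the store index 2 * warp_index + 8 * y + x with
// x \in [0, 2), DetectLinearEquation(index, {warp_index}) in the
// visitor below returns {2, 8 * y + x}, so the warp coefficient
// is m = 2.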

// Visitor to find m in the pattern
// store warp_mem[m * warp_index + (width * m) * y + x]
class WarpStoreCoeffFinder : private StmtExprVisitor {
 public:
  WarpStoreCoeffFinder(const VarNode* buffer, Var warp_index, arith::Analyzer* analyzer)
      : buffer_(buffer), warp_index_(warp_index), analyzer_(analyzer) {}
  // find the warp coefficient m in the statement
  int Find(const Stmt& stmt) {
    this->VisitStmt(stmt);
    return warp_coeff_;
  }

 private:
  /// Visitor implementation
  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::ptx_ldmatrix()) && op->args[3].as<VarNode>() == buffer_) {
      UpdatePattern(op->args[4]);
    } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as<VarNode>() == buffer_) {
      auto* local_size = op->args[0].as<IntImmNode>();
      ICHECK(local_size) << "Integer expected for the first argument of mma_fill";
      warp_coeff_ = local_size->value;
    }

    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    if (op->buffer->data.get() != buffer_) {
      StmtVisitor::VisitStmt_(op);
      return;
    }

    ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. "
                                     << "Has StorageFlatten (TE-based schedules) or "
                                     << "FlattenBuffer (TIR-based schedules) been run?";

    PrimExpr index = op->indices[0];
    if (op->value.dtype().lanes() != 1) {
      arith::PVar<PrimExpr> base;
      ICHECK(arith::ramp(base, 1, op->value.dtype().lanes()).Match(index))
          << "LowerWarpMemory failed due to store index=" << index
          << ", can only handle continuous store";
      index = base.Eval();
    }

    UpdatePattern(index);
  }

  void UpdatePattern(const PrimExpr& index) {
    Array<PrimExpr> m = arith::DetectLinearEquation(index, {warp_index_});
    ICHECK_EQ(m.size(), 2U)
        << "LowerWarpMemory failed. Could not simplify the store index `" << index
        << "` into the form ax + by + cz + ... Warp memory is approximated by storing values in "
           "thread local registers and shuffling values between these registers. Currently only "
           "linear equation indices are supported.";
    PrimExpr mcoeff = analyzer_->canonical_simplify(m[0]);
    const auto* mcoeff_as_int = mcoeff.as<IntImmNode>();
    ICHECK(mcoeff_as_int && mcoeff_as_int->value > 0)
        << "LowerWarpMemory failed due to store index=" << index
        << ", require a positive constant coefficient on the warp index " << warp_index_
        << " but get " << mcoeff;

    if (warp_coeff_ != 0) {
      ICHECK_EQ(warp_coeff_, mcoeff_as_int->value)
          << "LowerWarpMemory failed due to two different store coefficients on the warp index";
    } else {
      warp_coeff_ = mcoeff_as_int->value;
    }
  }

  // The buffer variable
  const VarNode* buffer_;
  // the warp index
  Var warp_index_;
  // the coefficient
  int64_t warp_coeff_{0};
  // analyzer.
  arith::Analyzer* analyzer_;
};

// Visitor to find the warp index
class WarpIndexFinder : private StmtVisitor {
 public:
  explicit WarpIndexFinder(int warp_size) : warp_size_(warp_size) {}
  // find the warp index and the shuffle width in the statement
  std::pair<Var, int> Find(const Stmt& stmt) {
    this->VisitStmt(stmt);
    ICHECK(warp_index_.defined())
        << "Cannot find the warp index (threadIdx.x) within the scope of warp memory";
    return std::make_pair(warp_index_->var, width_);
  }

 private:
  /// Visitor implementation
  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
      if (iv->thread_tag == "threadIdx.x") {
        auto* value_as_int = op->value.as<IntImmNode>();
        ICHECK(value_as_int && value_as_int->value <= warp_size_ &&
               warp_size_ % value_as_int->value == 0)
            << "Expect the extent of threadIdx.x to be no larger than, and a factor of,"
            << " the warp size (" << warp_size_ << ")"
            << " to enable warp memory,"
            << " but get " << op->value << " instead";
        if (warp_index_.defined()) {
          ICHECK(warp_index_.same_as(iv))
              << "Found two instances of " << warp_index_->thread_tag << " in the same kernel. "
              << "Please create the axis with thread_axis once and reuse it "
              << "across multiple binds in the same kernel";
        } else {
          width_ = value_as_int->value;
          warp_index_ = iv;
        }
      }
    }
    StmtVisitor::VisitStmt_(op);
  }
  // warp size
  int warp_size_{0};
  // number of threads involved in one shuffle
  int width_{0};
  // the warp index
  IterVar warp_index_{nullptr};
};
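
// For example, with warp_size = 32 and an enclosing
// attr::thread_extent that binds threadIdx.x to 16,
// WarpIndexFinder(32).Find(body) returns (threadIdx.x, 16),
// i.e. a sub-warp shuffle of width 16.
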
// Mutator to rewrite the access pattern of warp memory
class WarpAccessRewriter : protected StmtExprMutator {
 public:
  explicit WarpAccessRewriter(int warp_size, arith::Analyzer* analyzer)
      : warp_size_(warp_size), analyzer_(analyzer) {}
  // Rewrite the allocate statement, which transforms
  // warp memory to local memory.
  Stmt Rewrite(const AllocateNode* op) {
    buffer_ = op->buffer_var.get();
    int alloc_size = op->ConstantAllocationSize();
    ICHECK_GT(alloc_size, 0) << "warp memory only supports constant alloc size";
    alloc_size *= op->dtype.lanes();
    std::tie(warp_index_, width_) = WarpIndexFinder(warp_size_).Find(op->body);
    warp_coeff_ = WarpStoreCoeffFinder(buffer_, warp_index_, analyzer_).Find(op->body);

    // Align the local memory size. The number of elements may not
    // be a multiple of width_ * warp_coeff_; round it up.
    int factor = width_ * warp_coeff_;
    ICHECK_NE(factor, 0) << "Divide by zero";
    warp_group_ = (alloc_size + (factor - 1)) / factor;
    alloc_size = warp_group_ * factor;
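    // For example, alloc_size = 20 elements with width_ = 4 and
    // warp_coeff_ = 2 gives factor = 8, warp_group_ = 3, and a padded
    // alloc_size of 24, so each thread holds 24 / 4 = 6 elements.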

    return Allocate(op->buffer_var, op->dtype, {make_const(DataType::Int(32), alloc_size / width_)},
                    op->condition, this->VisitStmt(op->body));
  }

 protected:
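  // Rewrite the buffer indices of an intrinsic call: for each i in
  // `indices`, args[i] is expected to hold a buffer variable and
  // args[i + 1] its element index; if args[i] is the warp buffer,
  // the index is replaced by its thread-local component.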
  PrimExpr RewriteIndicesAt(const CallNode* op, const std::vector<int>& indices) {
    Array<PrimExpr> new_args = op->args;
    for (int i : indices) {
      if (op->args[i].get() == buffer_) {
        PrimExpr local_index = SplitIndexByGroup(op->args[i + 1]).first;
        new_args.Set(i + 1, local_index);
      }
    }
    return Call(op->dtype, op->op, new_args);
  }

  PrimExpr VisitExpr_(const CallNode* op) override {
    if (op->op.same_as(builtin::ptx_mma())) {
      return RewriteIndicesAt(op, {6, 8, 10});
    }

    if (op->op.same_as(builtin::ptx_ldmatrix())) {
      return RewriteIndicesAt(op, {3});
    }

    if (op->op.same_as(builtin::mma_store())) {
      return RewriteIndicesAt(op, {3});
    }

    if (op->op.same_as(builtin::mma_fill())) {
      return RewriteIndicesAt(op, {1});
    }

    return StmtExprMutator::VisitExpr_(op);
  }

  PrimExpr VisitExpr_(const VarNode* op) override {
    ICHECK(op != buffer_) << "Cannot access address of warp memory directly";
    return StmtExprMutator::VisitExpr_(op);
  }

  Stmt VisitStmt_(const StoreNode* op) override {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const LoadNode* op) override {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const BufferStoreNode* op) override {
    auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));

    if (store->buffer->data.get() == buffer_) {
      ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. "
                                          << "Has StorageFlatten (TE-based schedules) or "
                                          << "FlattenBuffer (TIR-based schedules) been run?";

      auto [local_index, group] = SplitIndexByGroup(store->indices[0]);
      (void)group;  // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767

      auto writer = store.CopyOnWrite();
      writer->indices = {local_index};
    }

    return std::move(store);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) override {
    auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));

    if (load->buffer->data.get() != buffer_) {
      return std::move(load);
    }

    ICHECK_EQ(load->indices.size(), 1) << "Expected flat memory to use as warp memory. "
                                       << "Has StorageFlatten (TE-based schedules) or "
                                       << "FlattenBuffer (TIR-based schedules) been run?";

    auto [local_index, group] = SplitIndexByGroup(load->indices[0]);
    // invariant: the local index must not contain the warp index
    ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); }))
        << "LowerWarpMemory failed to rewrite load to shuffle for index " << load->indices[0]
        << " local_index=" << local_index;

    auto writer = load.CopyOnWrite();
    writer->indices = {local_index};

    if (analyzer_->CanProveEqual(group, warp_index_)) {
      return std::move(load);
    }

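    // Emit a full-active-mask shuffle; on CUDA this intrinsic is
    // expected to lower to __shfl_sync (the exact form is decided by
    // target-specific intrinsic lowering rules).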
    PrimExpr mask = Call(DataType::UInt(32), builtin::tvm_warp_activemask(), {});
    return Call(load.dtype(), builtin::tvm_warp_shuffle(), {mask, load, group, width_, warp_size_});
  }

  // Split the index into two components:
  //   <local_index, source_index>
  // local_index is the index within the thread-local copy;
  // source_index identifies the lane that owns the requested value
  // in this access pattern.
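  //
  // For example, with m = warp_coeff_ = 2, width_ = 4 and
  // warp_group_ = 2, index 13 = 2 * z + (width * m) * y + x with
  // x = 1, y = 1, z = 2 splits into local_index = y * m + x = 3 and
  // group = 2: the value lives in slot 3 of lane 2's local buffer.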
  std::pair<PrimExpr, PrimExpr> SplitIndexByGroup(const PrimExpr& index) {
    if (index.dtype().lanes() != 1) {
      arith::PVar<PrimExpr> base;
      ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index));

      auto [local_index, group] = SplitIndexByGroup(base.Eval());
      local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes());
      return std::make_pair(local_index, group);
    }
    PrimExpr m = make_const(index.dtype(), warp_coeff_);

    // simple case: the warp index is the outermost dimension.
    if (warp_group_ == 1) {
      PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
      PrimExpr z = analyzer_->canonical_simplify(indexdiv(index, m));
      return std::make_pair(x, z);
    } else {
      PrimExpr x = analyzer_->canonical_simplify(indexmod(index, m));
      PrimExpr y = index / make_const(index.dtype(), warp_coeff_ * width_);
      y = y * m + x;
      PrimExpr z = indexdiv(indexmod(index, make_const(index.dtype(), warp_coeff_ * width_)), m);
      return std::make_pair(analyzer_->canonical_simplify(y), analyzer_->canonical_simplify(z));
    }
  }

 private:
  // the warp size
  int warp_size_{0};
  // The buffer variable
  const VarNode* buffer_;
  // number of threads involved in one shuffle
  int width_{0};
  // Warp index
  Var warp_index_;
  // the coefficient m
  int warp_coeff_{0};
  // the coefficient n
  int warp_group_{0};
  // Internal analyzer
  arith::Analyzer* analyzer_;
};

// Bind bound information of variables to make the analyzer more effective
// TODO(tqchen): consider a pass to inline the bound info into the expr
// so analysis can be context independent.
class BindVarBoundInfo : public StmtVisitor {
 public:
  explicit BindVarBoundInfo(arith::Analyzer* analyzer) : analyzer_(analyzer) {}

  void VisitStmt_(const ForNode* op) final {
    const Var& loop_var = op->loop_var;
    analyzer_->Bind(loop_var, Range::FromMinExtent(op->min, op->extent));
    StmtVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
      IterVar iv = Downcast<IterVar>(op->node);
      ICHECK_NE(iv->thread_tag.length(), 0U);
      if (!var_dom_.count(iv->var.get())) {
        Range dom = Range::FromMinExtent(0, op->value);
        var_dom_[iv->var.get()] = dom;
        analyzer_->Bind(iv->var, dom);
      }
    }
    StmtVisitor::VisitStmt_(op);
  }

 protected:
  // internal analyzer.
  arith::Analyzer* analyzer_;
  // variable domain
  std::unordered_map<const VarNode*, Range> var_dom_;
};
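
// For example, after binding threadIdx.x to [0, 16), the analyzer
// can simplify expressions such as indexmod(threadIdx.x, 16) down to
// threadIdx.x, which is what lets SplitIndexByGroup produce local
// indices that are free of the warp index.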

// Mutator that rewrites warp memory allocations to local memory
// plus shuffle-based accesses.
class WarpMemoryRewriter : private StmtMutator {
 public:
  explicit WarpMemoryRewriter(int warp_size) : warp_size_(warp_size) {}

  Stmt Rewrite(Stmt stmt) {
    if (warp_size_ == 1) return stmt;
    BindVarBoundInfo binder(&analyzer_);
    binder(stmt);
    stmt = operator()(std::move(stmt));
    return stmt;
  }

  std::unordered_map<const VarNode*, String> new_storage_scopes_;

 private:
  Stmt VisitStmt_(const AllocateNode* op) final {
    auto ret = StmtMutator::VisitStmt_(op);
    op = ret.as<AllocateNode>();
    if (GetPtrStorageScope(op->buffer_var) == "warp") {
      new_storage_scopes_[op->buffer_var.get()] = "local";
      WarpAccessRewriter rewriter(warp_size_, &analyzer_);
      ret = rewriter.Rewrite(op);
    }
    return ret;
  }

  int warp_size_{0};
  arith::Analyzer analyzer_;
};

namespace transform {

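// Lower warp memory accesses into local memory plus warp shuffle
// intrinsics. A minimal usage sketch from C++ (assuming `mod` is an
// IRModule whose PrimFuncs carry a GPU target attribute):
//
//   IRModule lowered = tvm::tir::transform::LowerWarpMemory()(mod);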
Pass LowerWarpMemory() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();
    auto target = f->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";
    int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value().IntValue();
    WarpMemoryRewriter warp_memory_rewriter(warp_size);
    auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
    n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerWarpMemory").set_body_typed(LowerWarpMemory);

}  // namespace transform

}  // namespace tir
}  // namespace tvm