/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file inject_virtual_thread.cc
 * \brief Inject virtual thread loops and rewrite buffer access patterns
 *        so that each virtual thread owns its own data when necessary.
 */
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../arith/ir_mutator_with_analyzer.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

// Checks whether an expression or statement is touched by any of the given variables.
class ExprTouched final : public StmtExprVisitor {
 public:
  explicit ExprTouched(const std::unordered_set<const VarNode*>& touched, bool check_write)
      : touched_var_(touched), check_write_(check_write) {}

  void VisitExpr(const PrimExpr& n) final {
    // early stopping
    if (expr_touched_ && !check_write_) return;
    StmtExprVisitor::VisitExpr(n);
  }
  void VisitStmt(const Stmt& n) final {
    // early stopping
    if (expr_touched_ && !check_write_) return;
    StmtExprVisitor::VisitStmt(n);
  }
  void VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }
  void VisitExpr_(const BufferLoadNode* op) final {
    HandleUseVar(op->buffer->data.get());
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const VarNode* op) final { HandleUseVar(op); }
  void VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
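      // args of tvm_access_ptr: {dtype expr, buffer var, offset, extent, rw_mask},
      // matching how args[1]..args[4] are unpacked below.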
      const auto* rw_mask = op->args[4].as<IntImmNode>();
      const VarNode* buffer_var = op->args[1].as<VarNode>();
      ICHECK(buffer_var);
      ICHECK(rw_mask);
      // read
      if (rw_mask->value & 1) {
        HandleUseVar(buffer_var);
      }
      if (rw_mask->value & 2) {
        HandleWriteVar(buffer_var);
      }
      this->VisitExpr(op->args[2]);
    } else {
      StmtExprVisitor::VisitExpr_(op);
    }
  }
  void HandleUseVar(const VarNode* var) {
    auto it = touched_var_.find(var);
    if (it != touched_var_.end()) {
      expr_touched_ = true;
    }
    // remember the used vars,
    // in case the var gets touched later in a loop.
    if (!expr_touched_) {
      used_vars_.push_back(var);
    }
  }
  void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); }
  // the fields.
  bool expr_touched_{false};
  std::vector<const VarNode*> used_vars_;
  std::vector<const VarNode*> write_vars_;
  const std::unordered_set<const VarNode*>& touched_var_;
  bool check_write_;
};

// Analyze whether buffers and variables are invariant to the value of var.
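// For example, given `let x = vthread * 4` followed by `B[x] = 0`, both `x`
// and the data var of `B` are recorded as touched by the virtual thread var.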
class VarTouchedAnalysis : public StmtVisitor {
 public:
  void VisitStmt_(const LetStmtNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->value);
    Record(op->var.get(), tc);
    this->VisitStmt(op->body);
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  void VisitStmt_(const BufferStoreNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->value);
    for (const auto& index : op->indices) {
      tc(index);
    }
    Record(op->buffer->data.get(), tc);
  }
  void VisitStmt_(const ForNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->min);
    tc(op->extent);
    Record(op->loop_var.get(), tc);
    this->VisitStmt(op->body);
  }
  // external function call
  void VisitStmt_(const EvaluateNode* op) final {
    ExprTouched tc(touched_var_, true);
    tc(op->value);
    for (const VarNode* var : tc.write_vars_) {
      Record(var, tc);
    }
  }
  void VisitStmt_(const AllocateNode* op) final {
    ExprTouched tc(touched_var_, false);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      tc(op->extents[i]);
    }
    tc.VisitExpr(op->condition);
    Record(op->buffer_var.get(), tc);
    this->VisitStmt(op->body);
  }
  void Record(const VarNode* var, const ExprTouched& tc) {
    if (touched_var_.count(var)) return;
    if (tc.expr_touched_) {
      touched_var_.insert(var);
    } else {
      for (const VarNode* r : tc.used_vars_) {
        if (r != var) {
          affect_[r].push_back(var);
        }
      }
    }
  }

  std::unordered_set<const VarNode*> TouchedVar(const Stmt& stmt, const VarNode* var) {
    touched_var_.insert(var);
    this->VisitStmt(stmt);
    // Do a DFS to propagate the touched property along dependencies.
    std::vector<const VarNode*> pending(touched_var_.begin(), touched_var_.end());
    while (!pending.empty()) {
      const VarNode* v = pending.back();
      pending.pop_back();
      for (const VarNode* r : affect_[v]) {
        if (!touched_var_.count(r)) {
          touched_var_.insert(r);
          pending.push_back(r);
        }
      }
    }
    return std::move(touched_var_);
  }

 private:
  // Variables known to be touched by the thread variable.
  std::unordered_set<const VarNode*> touched_var_;
  // x -> all the variables whose definitions read x.
  std::unordered_map<const VarNode*, std::vector<const VarNode*>> affect_;
};

// Inject the virtual thread loop and rewrite buffer access patterns
// when necessary.
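//
// A sketch of the transformation (names illustrative): with two virtual
// threads, an allocation `A[n]` that is written inside the vthread scope is
// enlarged to `A[2 * n]`, its accesses become `A[i + vthread * n]`, and the
// vthread scope is replaced by an unrolled sequence or a serial loop over
// the thread index.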
class VTInjector : public arith::IRMutatorWithAnalyzer {
 public:
  using IRMutatorWithAnalyzer::VisitExpr_;
  using IRMutatorWithAnalyzer::VisitStmt_;

  // constructor
  VTInjector(arith::Analyzer* analyzer, Var var, int num_threads,
             const std::unordered_set<const VarNode*>& touched_var, bool allow_share)
      : IRMutatorWithAnalyzer(analyzer),
        var_(var),
        num_threads_(num_threads),
        touched_var_(touched_var),
        allow_share_(allow_share) {}
  // Inject VTLoop when needed.
  Stmt VisitStmt(const Stmt& s) final {
    ICHECK(!visit_touched_var_);
    auto stmt = StmtExprMutator::VisitStmt(s);
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_) {
        return InjectVTLoop(stmt, false);
      }
      visit_touched_var_ = false;
      trigger_base_inject_ = false;
    }
    return stmt;
  }
  // Variable
  PrimExpr VisitExpr_(const VarNode* op) final {
    ICHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread";
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
    return GetRef<PrimExpr>(op);
  }
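  // Rewrite a flat index into the current thread's slice: thread `var_` owns
  // the contiguous range [var_ * alloc_extent, (var_ + 1) * alloc_extent).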
  PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
    return analyzer_->Simplify(index + var_ * alloc_extent);
  }
  // Expression.
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      ICHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      auto it = alloc_remap_.find(buffer);
      if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
      visit_touched_var_ = true;
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
      offset = RewriteIndex(offset, stride);

      return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
    } else if (op->op.same_as(builtin::tvm_context_id())) {
      return allow_share_ ? GetRef<PrimExpr>(op) : var_;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
  Stmt VisitStmt_(const EvaluateNode* op) final {
    trigger_base_inject_ = !allow_share_;
    return StmtExprMutator::VisitStmt_(op);
  }
  // Load
  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }
  // Store
  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }
  // BufferLoad
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }
  // BufferStore
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    trigger_base_inject_ = !allow_share_;
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    if (touched_var_.count(node->buffer->data.get())) {
      visit_touched_var_ = true;
    }

    auto it = alloc_remap_.find(node->buffer->data.get());
    if (it != alloc_remap_.end()) {
      ICHECK_EQ(node->indices.size(), 1)
          << "InjectVirtualThread expects rewritten allocations to be flat memory.";
      auto writer = node.CopyOnWrite();
      writer->buffer = GetRemappedBuffer(node->buffer, it->second);
      writer->indices = {RewriteIndex(node->indices[0], it->second)};
    }

    return node;
  }

  Buffer GetRemappedBuffer(Buffer buf, PrimExpr alloc_extent) {
    auto key = buf.get();
    auto it = buf_remap_.find(key);
    if (it != buf_remap_.end()) {
      return it->second;
    }

    ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened.";
    auto writer = buf.CopyOnWrite();
    writer->shape = {buf->shape[0] * alloc_extent};

    buf_remap_[key] = buf;
    return buf;
  }

  // Attribute
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else if (!allow_share_ && !vt_loop_injected_ &&
               (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else {
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      } else {
        return AttrStmt(op->node, op->attr_key, value, body);
      }
    }
  }
  // LetStmt
  Stmt VisitStmt_(const LetStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return LetStmt(op->var, value, body);
    }
  }
  // For
  Stmt VisitStmt_(const ForNode* op) final {
    ICHECK(is_zero(op->min));
    PrimExpr extent = this->VisitExpr(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    ++max_loop_depth_;
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = CopyOnWrite(op);
      n->extent = std::move(extent);
      n->body = std::move(body);
      return Stmt(n);
    }
  }
  // IfThenElse
  Stmt VisitStmt_(const IfThenElseNode* op) final {
    PrimExpr condition = this->VisitExpr(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    ICHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->VisitStmt(op->then_case);
    Optional<Stmt> else_case = NullOpt;
    if (op->else_case) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->VisitStmt(op->else_case.value());
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return GetRef<Stmt>(op);
    } else {
      return IfThenElse(condition, then_case, else_case);
    }
  }

  // While
  Stmt VisitStmt_(const WhileNode* op) final {
    // TODO(masahi): What should we do for While nodes?
    LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet";
  }

  // Seq
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    ICHECK_EQ(max_loop_depth_, 0);
    auto fmutate = [this](const Stmt& s) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      max_loop_depth_ = std::max(max_loop_depth_, temp);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }
  // Allocate
  Stmt VisitStmt_(const AllocateNode* op) final {
    Allocate node = GetRef<Allocate>(op);

    PrimExpr condition = this->VisitExpr(op->condition);

    Array<PrimExpr> extents =
        op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); });

    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }

    visit_touched_var_ = false;

    // Rewrite the buffer if its shape or any value stored in it
    // depends on the virtual thread var. If `allow_share_` is false,
    // then the buffer is always rewritten, even if separate virtual
    // threads only read from the buffer.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
      // Place the virtual thread index on the highest dimension, so each
      // thread owns a contiguous slice of the enlarged allocation.
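      // For example (a sketch): `allocate A[n]` becomes
      // `allocate A[n * num_threads_]`, and thread `t` accesses `A[i + t * n]`.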

      // TODO(Lunderberg): Move pass to apply before
      // StorageFlatten/FlattenBuffer. Would rewrite the Buffer to
      // add the injected virtual thread as the first index.
      ICHECK_EQ(extents.size(), 1)
          << "InjectVirtualThread expects rewritten allocations to be flat memory.";
      PrimExpr stride = extents[0];
      extents = {stride * num_threads_};

      // Mark the buffer var as touched. BufferLoad/BufferStore should
      // access locations at `current_index + stride*vthread_var`.
      alloc_remap_[op->buffer_var.get()] = stride;
    }

    // Mutate the body. Depends on alloc_remap_.
    auto body = this->VisitStmt(op->body);

    if (extents.same_as(op->extents) && body.same_as(op->body) &&
        condition.same_as(op->condition)) {
      return GetRef<Stmt>(op);
    } else {
      return Allocate(op->buffer_var, op->dtype, extents, condition, body);
    }
  }

  // Inject the virtual thread loop around stmt.
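  // With num_threads_ == 2 and no inner loop, a body S(var_) unrolls into the
  // sequence S(0); S(1). Otherwise it is wrapped as a serial loop:
  //   for (var_.s, 0, 2) { S(var_.s) }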
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    ICHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    trigger_base_inject_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->VisitStmt(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    // Only unroll when the number of vthreads is small and this is the
    // innermost scope (no loop was encountered during mutation).
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
      Array<Stmt> seq;
      for (int i = 0; i < num_threads_; ++i) {
        seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
      }
      return SeqStmt::Flatten(seq);
    } else {
      // insert a for loop
      Var idx(var_->name_hint + ".s", var_->dtype);
      Map<Var, PrimExpr> values{{var_, idx}};
      stmt = Substitute(stmt, values);
      return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
                 ForKind::kSerial, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // number of virtual threads (lanes)
  int num_threads_;
  // whether the vthread loop has already been injected.
  bool vt_loop_injected_{false};
  // whether the current expression touches the vthread var.
  bool visit_touched_var_{false};
  // whether injection at the base statement is triggered.
  bool trigger_base_inject_{false};
  // number of loops encountered after mutation.
  int max_loop_depth_{0};
  // The variables that get touched.
  const std::unordered_set<const VarNode*>& touched_var_;
  // Whether sharing data between virtual threads is allowed.
  bool allow_share_;
  /*! \brief The allocations that get touched -> extent
   *
   * Maps from the buffer_var of an allocate node to the original
   * extent of the allocation. Used when rewriting the indices of
   * BufferLoad/BufferStore.
   */
  std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
  /*! \brief Map of buffers that are modified.
   *
   * Buffers allocated or written to within the virtual thread loop
   * must have one copy per virtual thread. This is done by enlarging
   * the allocated buffer size, then modifying the indices at which
   * each virtual thread accesses the buffer.
   */
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};

class VirtualThreadInjector : public arith::IRMutatorWithAnalyzer {
 public:
  using IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
  using IRMutatorWithAnalyzer::VisitStmt_;

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<AttrStmtNode>();
    if (op->attr_key == attr::virtual_thread) {
      IterVar iv = Downcast<IterVar>(op->node);
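      // Thread tags starting with "vthread" may share read-only data across
      // threads; other tags (e.g. "cthread") get a private copy of every buffer.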
      bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread";
      int nthread = static_cast<int>(op->value.as<IntImmNode>()->value);
      VarTouchedAnalysis vs;
      auto touched = vs.TouchedVar(op->body, iv->var.get());
      VTInjector injector(analyzer_, iv->var, nthread, touched, allow_share);
      return injector(op->body);
    } else {
      return stmt;
    }
  }

  Stmt VisitStmt_(const ProducerStoreNode* op) final {
    LOG(FATAL) << "Need to call StorageFlatten first";
  }
};

namespace transform {

Pass InjectVirtualThread() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    auto* n = f.CopyOnWrite();

    arith::Analyzer analyzer;

    n->body = VirtualThreadInjector(&analyzer)(std::move(n->body));
    n->body = ConvertSSA(std::move(n->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread", {});
}
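
// Example invocation from Python (a sketch; assumes a lowered IRModule `mod`):
//   mod = tvm.tir.transform.InjectVirtualThread()(mod)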

TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread);

}  // namespace transform

}  // namespace tir
}  // namespace tvm