1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file inject_virtual_thread.cc |
22 | */ |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/tir/builtin.h> |
25 | #include <tvm/tir/expr.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | #include <tvm/tir/transform.h> |
28 | |
29 | #include <unordered_set> |
30 | |
31 | #include "../../arith/ir_mutator_with_analyzer.h" |
32 | #include "ir_utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | // If expression is touched by var. |
38 | class ExprTouched final : public StmtExprVisitor { |
39 | public: |
40 | explicit ExprTouched(const std::unordered_set<const VarNode*>& touched, bool check_write) |
41 | : touched_var_(touched), check_write_(check_write) {} |
42 | |
43 | void VisitExpr(const PrimExpr& n) final { |
44 | // early stopping |
45 | if (expr_touched_ && !check_write_) return; |
46 | StmtExprVisitor::VisitExpr(n); |
47 | } |
48 | void VisitStmt(const Stmt& n) final { |
49 | // early stopping |
50 | if (expr_touched_ && !check_write_) return; |
51 | StmtExprVisitor::VisitStmt(n); |
52 | } |
53 | void VisitExpr_(const LoadNode* op) final { |
54 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
55 | } |
56 | void VisitExpr_(const BufferLoadNode* op) final { |
57 | HandleUseVar(op->buffer->data.get()); |
58 | StmtExprVisitor::VisitExpr_(op); |
59 | } |
60 | void VisitExpr_(const VarNode* op) final { HandleUseVar(op); } |
61 | void VisitExpr_(const CallNode* op) final { |
62 | if (op->op.same_as(builtin::tvm_access_ptr())) { |
63 | const auto* rw_mask = op->args[4].as<IntImmNode>(); |
64 | const VarNode* buffer_var = op->args[1].as<VarNode>(); |
65 | ICHECK(buffer_var); |
66 | ICHECK(rw_mask); |
67 | // read |
68 | if (rw_mask->value & 1) { |
69 | HandleUseVar(buffer_var); |
70 | } |
71 | if (rw_mask->value & 2) { |
72 | HandleWriteVar(buffer_var); |
73 | } |
74 | this->VisitExpr(op->args[2]); |
75 | } else { |
76 | StmtExprVisitor::VisitExpr_(op); |
77 | } |
78 | } |
79 | void HandleUseVar(const VarNode* var) { |
80 | auto it = touched_var_.find(var); |
81 | if (it != touched_var_.end()) { |
82 | expr_touched_ = true; |
83 | } |
84 | // rember the used vars |
85 | // in case the var get touched later in a loop. |
86 | if (!expr_touched_) { |
87 | used_vars_.push_back(var); |
88 | } |
89 | } |
90 | void HandleWriteVar(const VarNode* var) { write_vars_.push_back(var); } |
91 | // the fields. |
92 | bool expr_touched_{false}; |
93 | std::vector<const VarNode*> used_vars_; |
94 | std::vector<const VarNode*> write_vars_; |
95 | const std::unordered_set<const VarNode*>& touched_var_; |
96 | bool check_write_; |
97 | }; |
98 | |
// Analyze which variables/buffers transitively depend on the value of a
// given var (the virtual thread var). Builds direct dependencies while
// walking the statement, then closes them transitively in TouchedVar.
class VarTouchedAnalysis : public StmtVisitor {
 public:
  // Let binding: op->var depends on whatever op->value reads.
  void VisitStmt_(const LetStmtNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->value);
    Record(op->var.get(), tc);
    this->VisitStmt(op->body);
  }

  void VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  // Buffer store: the buffer's data var depends on the stored value and
  // on every index expression.
  void VisitStmt_(const BufferStoreNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->value);
    for (const auto& index : op->indices) {
      tc(index);
    }
    Record(op->buffer->data.get(), tc);
  }
  // For loop: the loop var depends on min and extent.
  void VisitStmt_(const ForNode* op) final {
    ExprTouched tc(touched_var_, false);
    tc(op->min);
    tc(op->extent);
    Record(op->loop_var.get(), tc);
    this->VisitStmt(op->body);
  }
  // external function call: buffers written via tvm_access_ptr depend on
  // everything the call expression reads (check_write = true).
  void VisitStmt_(const EvaluateNode* op) final {
    ExprTouched tc(touched_var_, true);
    tc(op->value);
    for (const VarNode* var : tc.write_vars_) {
      Record(var, tc);
    }
  }
  // Allocation: the buffer var depends on the extents and the condition.
  void VisitStmt_(const AllocateNode* op) final {
    ExprTouched tc(touched_var_, false);
    for (size_t i = 0; i < op->extents.size(); ++i) {
      tc(op->extents[i]);
    }
    tc.VisitExpr(op->condition);
    Record(op->buffer_var.get(), tc);
    this->VisitStmt(op->body);
  }
  // Record that `var`'s definition has the dependencies collected in tc.
  // If tc already saw a touched var, `var` is touched immediately;
  // otherwise remember edges so a later touch can propagate to it.
  void Record(const VarNode* var, const ExprTouched& tc) {
    if (touched_var_.count(var)) return;
    if (tc.expr_touched_) {
      touched_var_.insert(var);
    } else {
      for (const VarNode* r : tc.used_vars_) {
        if (r != var) {
          affect_[r].push_back(var);
        }
      }
    }
  }

  // Return the full set of vars that transitively depend on `var`
  // within `stmt` (including `var` itself).
  std::unordered_set<const VarNode*> TouchedVar(const Stmt& stmt, const VarNode* var) {
    touched_var_.insert(var);
    this->VisitStmt(stmt);
    // do a DFS to push affect around dependency.
    std::vector<const VarNode*> pending(touched_var_.begin(), touched_var_.end());
    while (!pending.empty()) {
      const VarNode* v = pending.back();
      pending.pop_back();
      for (const VarNode* r : affect_[v]) {
        if (!touched_var_.count(r)) {
          touched_var_.insert(r);
          pending.push_back(r);
        }
      }
    }
    return std::move(touched_var_);
  }

 private:
  // Variables known (so far) to depend on the thread variable.
  std::unordered_set<const VarNode*> touched_var_;
  // Dependency edges: r -> vars whose definition reads r, so a touch of r
  // propagates to them during the DFS in TouchedVar.
  std::unordered_map<const VarNode*, std::vector<const VarNode*>> affect_;
};
182 | |
// Inject virtual thread loop
// rewrite the buffer access pattern when necessary.
//
// The mutator walks the body of a virtual-thread attr. Whenever it meets
// a statement that reads a touched variable (or a construct that cannot
// be shared across vthreads), it wraps the innermost such region in a
// loop/unroll over the vthread var via InjectVTLoop.
class VTInjector : public arith::IRMutatorWithAnalyzer {
 public:
  using IRMutatorWithAnalyzer::VisitExpr_;
  using IRMutatorWithAnalyzer::VisitStmt_;

  // constructor
  VTInjector(arith::Analyzer* analyzer, Var var, int num_threads,
             const std::unordered_set<const VarNode*>& touched_var, bool allow_share)
      : IRMutatorWithAnalyzer(analyzer),
        var_(var),
        num_threads_(num_threads),
        touched_var_(touched_var),
        allow_share_(allow_share) {}
  // Inject VTLoop when needed: if the visited statement touched the
  // vthread var (or requested a base inject) and no loop has been
  // injected yet, wrap it here.
  Stmt VisitStmt(const Stmt& s) final {
    ICHECK(!visit_touched_var_);
    auto stmt = StmtExprMutator::VisitStmt(s);
    if (visit_touched_var_ || trigger_base_inject_) {
      if (!vt_loop_injected_) {
        return InjectVTLoop(stmt, false);
      }
      // A loop is already active; clear the flags and continue.
      visit_touched_var_ = false;
      trigger_base_inject_ = false;
    }
    return stmt;
  }
  // Variable: flag a read of any touched var.
  PrimExpr VisitExpr_(const VarNode* op) final {
    ICHECK(!alloc_remap_.count(op)) << "Buffer address may get rewritten in virtual thread";
    if (touched_var_.count(op)) {
      visit_touched_var_ = true;
    }
    return GetRef<PrimExpr>(op);
  }
  // Shift an index by the per-thread stride: index + var_ * alloc_extent.
  PrimExpr RewriteIndex(PrimExpr index, PrimExpr alloc_extent) const {
    return analyzer_->Simplify(index + var_ * alloc_extent);
  }
  // Expression.
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->op.same_as(builtin::tvm_access_ptr())) {
      // args layout: (dtype, buffer_var, offset, extent, rw_mask)
      ICHECK_EQ(op->args.size(), 5U);
      DataType dtype = op->args[0].dtype();
      const VarNode* buffer = op->args[1].as<VarNode>();
      auto it = alloc_remap_.find(buffer);
      if (it == alloc_remap_.end()) return StmtExprMutator::VisitExpr_(op);
      visit_touched_var_ = true;
      PrimExpr offset = this->VisitExpr(op->args[2]);
      PrimExpr extent = this->VisitExpr(op->args[3]);
      // Stride in elements of `dtype` (remapped extent is in lanes).
      PrimExpr stride = it->second / make_const(offset.dtype(), dtype.lanes());
      offset = RewriteIndex(offset, stride);

      return Call(op->dtype, op->op, {op->args[0], op->args[1], offset, extent, op->args[4]});
    } else if (op->op.same_as(builtin::tvm_context_id())) {
      // Each non-sharing vthread acts as its own context.
      return allow_share_ ? GetRef<PrimExpr>(op) : var_;
    } else {
      return StmtExprMutator::VisitExpr_(op);
    }
  }
  Stmt VisitStmt_(const EvaluateNode* op) final {
    // Opaque side effects cannot be shared across vthreads.
    trigger_base_inject_ = !allow_share_;
    return StmtExprMutator::VisitStmt_(op);
  }
  // Load
  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }
  // Store
  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }
  // BufferLoad
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }
  // BufferStore: writes force a base inject when sharing is not allowed.
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    trigger_base_inject_ = !allow_share_;
    return VisitBufferAccess(std::move(node));
  }

  // Shared handling for BufferLoad/BufferStore: flag touched buffers and
  // rewrite index/buffer for remapped (per-thread duplicated) allocations.
  template <typename Node>
  Node VisitBufferAccess(Node node) {
    if (touched_var_.count(node->buffer->data.get())) {
      visit_touched_var_ = true;
    }

    auto it = alloc_remap_.find(node->buffer->data.get());
    if (it != alloc_remap_.end()) {
      ICHECK_EQ(node->indices.size(), 1)
          << "InjectVirtualThread expects rewritten allocations to be flat memory.";
      auto writer = node.CopyOnWrite();
      writer->buffer = GetRemappedBuffer(node->buffer, it->second);
      writer->indices = {RewriteIndex(node->indices[0], it->second)};
    }

    return node;
  }

  // Return (and cache) a Buffer whose flat shape is enlarged by
  // alloc_extent, so each vthread owns a disjoint slice.
  Buffer GetRemappedBuffer(Buffer buf, PrimExpr alloc_extent) {
    auto key = buf.get();
    auto it = buf_remap_.find(key);
    if (it != buf_remap_.end()) {
      return it->second;
    }

    ICHECK_EQ(buf->shape.size(), 1) << "Expected buffers being rewritten to already be flattened.";
    auto writer = buf.CopyOnWrite();
    writer->shape = {buf->shape[0] * alloc_extent};

    buf_remap_[key] = buf;
    return buf;
  }

  // Attribute
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else if (!allow_share_ && !vt_loop_injected_ &&
               (op->attr_key == attr::coproc_uop_scope || op->attr_key == attr::coproc_scope)) {
      // Coprocessor scopes cannot be shared across vthreads.
      return InjectVTLoop(GetRef<Stmt>(op), true);
    } else {
      Stmt body = this->VisitStmt(op->body);
      if (value.same_as(op->value) && body.same_as(op->body)) {
        return GetRef<Stmt>(op);
      } else {
        return AttrStmt(op->node, op->attr_key, value, body);
      }
    }
  }
  // LetStmt
  Stmt VisitStmt_(const LetStmtNode* op) final {
    PrimExpr value = this->VisitExpr(op->value);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    if (value.same_as(op->value) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      return LetStmt(op->var, value, body);
    }
  }
  // For
  Stmt VisitStmt_(const ForNode* op) final {
    ICHECK(is_zero(op->min));
    PrimExpr extent = this->VisitExpr(op->extent);
    if (visit_touched_var_ && !vt_loop_injected_) {
      Stmt stmt = InjectVTLoop(GetRef<Stmt>(op), true);
      ++max_loop_depth_;
      return stmt;
    }
    visit_touched_var_ = false;
    Stmt body = this->VisitStmt(op->body);
    // Count this loop so InjectVTLoop only unrolls innermost regions.
    ++max_loop_depth_;
    if (extent.same_as(op->extent) && body.same_as(op->body)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = CopyOnWrite(op);
      n->extent = std::move(extent);
      n->body = std::move(body);
      return Stmt(n);
    }
  }
  // IfThenElse
  Stmt VisitStmt_(const IfThenElseNode* op) final {
    PrimExpr condition = this->VisitExpr(op->condition);
    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }
    visit_touched_var_ = false;
    ICHECK_EQ(max_loop_depth_, 0);
    Stmt then_case = this->VisitStmt(op->then_case);
    Optional<Stmt> else_case = NullOpt;
    if (op->else_case) {
      // Each branch tracks its own loop depth; keep the maximum.
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      else_case = this->VisitStmt(op->else_case.value());
      max_loop_depth_ = std::max(temp, max_loop_depth_);
    }
    if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
        else_case.same_as(op->else_case)) {
      return GetRef<Stmt>(op);
    } else {
      return IfThenElse(condition, then_case, else_case);
    }
  }

  // While
  Stmt VisitStmt_(const WhileNode* op) final {
    // TODO(masahi): What should we do for While nodes?
    LOG(FATAL) << "WhileNode in InjectVirtualThread not supported yet";
  }

  // Seq: children are visited with independent loop-depth counters; the
  // sequence's depth is the maximum over its children.
  Stmt VisitStmt_(const SeqStmtNode* op) final {
    ICHECK_EQ(max_loop_depth_, 0);
    auto fmutate = [this](const Stmt& s) {
      int temp = max_loop_depth_;
      max_loop_depth_ = 0;
      Stmt ret = this->VisitStmt(s);
      max_loop_depth_ = std::max(max_loop_depth_, temp);
      return ret;
    };
    return StmtMutator::VisitSeqStmt_(op, false, fmutate);
  }
  // Allocate
  Stmt VisitStmt_(const AllocateNode* op) final {
    Allocate node = GetRef<Allocate>(op);

    PrimExpr condition = this->VisitExpr(op->condition);

    Array<PrimExpr> extents =
        op->extents.Map([this](const PrimExpr& extent) { return this->VisitExpr(extent); });

    if (visit_touched_var_ && !vt_loop_injected_) {
      return InjectVTLoop(GetRef<Stmt>(op), true);
    }

    visit_touched_var_ = false;

    // Rewrite the buffer if its shape or any value stored in it
    // depends on the virtual thread var. If `allow_share_` is false,
    // then the buffer is always rewritten, even if separate virtual
    // threads only read from the buffer.
    if (touched_var_.count(op->buffer_var.get()) || !allow_share_) {
      // place v on highest dimension.

      // TODO(Lunderberg): Move pass to apply before
      // StorageFlatten/FlattenBuffer. Would rewrite the Buffer to
      // add the injected virtual thread as the first index.
      ICHECK_EQ(extents.size(), 1)
          << "InjectVirtualThread expects rewritten allocations to be flat memory.";
      PrimExpr stride = extents[0];
      extents = {stride * num_threads_};

      // Mark the buffer var as touched. BufferLoad/BufferStore should
      // access locations at `current_index + stride*vthread_var`.
      alloc_remap_[op->buffer_var.get()] = stride;
    }

    // Mutate the body. Depends on alloc_remap_.
    auto body = this->VisitStmt(op->body);

    if (extents.same_as(op->extents) && body.same_as(op->body) &&
        condition.same_as(op->condition)) {
      return GetRef<Stmt>(op);
    } else {
      return Allocate(op->buffer_var, op->dtype, extents, condition, body);
    }
  }

  // inject vthread loop over `stmt`. When `before_mutation` is true the
  // statement has not been mutated yet and is visited here first.
  Stmt InjectVTLoop(Stmt stmt, bool before_mutation) {
    ICHECK(!vt_loop_injected_);
    // reset the flags
    visit_touched_var_ = false;
    trigger_base_inject_ = false;
    vt_loop_injected_ = true;
    if (before_mutation) {
      stmt = this->VisitStmt(stmt);
    }
    // reset the flags after processing.
    vt_loop_injected_ = false;
    visit_touched_var_ = false;
    // only unroll if number of vthreads are small
    if (max_loop_depth_ == 0 && num_threads_ < 16) {
      // do unrolling if it is inside innermost content.
      Array<Stmt> seq;
      for (int i = 0; i < num_threads_; ++i) {
        seq.push_back(Substitute(stmt, {{var_, make_const(var_.dtype(), i)}}));
      }
      return SeqStmt::Flatten(seq);
    } else {
      // insert a for loop over a fresh var and substitute it for var_.
      Var idx(var_->name_hint + ".s", var_->dtype);
      Map<Var, PrimExpr> values{{var_, idx}};
      stmt = Substitute(stmt, values);
      return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
                 ForKind::kSerial, stmt);
    }
  }

 private:
  // vthread variable
  Var var_;
  // the number of threads/lanes
  int num_threads_;
  // whether the loop is already injected.
  bool vt_loop_injected_{false};
  // whether current expression read a touched variable.
  bool visit_touched_var_{false};
  // Whether the enclosing statement must become the injection base
  // (set on writes/opaque effects when sharing is not allowed).
  bool trigger_base_inject_{false};
  // the counter of loops encountered after mutation (0 == innermost).
  int max_loop_depth_{0};
  // The variables that get touched by the vthread var.
  const std::unordered_set<const VarNode*>& touched_var_;
  // Whether read-only sharing between vthreads is allowed.
  bool allow_share_;
  /*! \brief The allocations that get touched -> extent
   *
   * Maps from the buffer_var of an allocate node to the original
   * extent of the allocation. Used when rewriting the indices of
   * BufferLoad/BufferStore.
   */
  std::unordered_map<const VarNode*, PrimExpr> alloc_remap_;
  /*! \brief Map of buffers that are modified.
   *
   * Buffers allocated or written to within the virtual thread loop
   * must have one copy per virtual thread. This is done by enlarging
   * the allocated buffer size, then modifying the indices at which
   * each virtual thread accesses the buffer.
   */
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};
504 | |
505 | class VirtualThreadInjector : public arith::IRMutatorWithAnalyzer { |
506 | public: |
507 | using IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; |
508 | using IRMutatorWithAnalyzer::VisitStmt_; |
509 | |
510 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
511 | Stmt stmt = StmtMutator::VisitStmt_(op); |
512 | op = stmt.as<AttrStmtNode>(); |
513 | if (op->attr_key == attr::virtual_thread) { |
514 | IterVar iv = Downcast<IterVar>(op->node); |
515 | bool allow_share = std::string(iv->thread_tag).substr(0, 7) == "vthread" ; |
516 | int nthread = static_cast<int>(op->value.as<IntImmNode>()->value); |
517 | VarTouchedAnalysis vs; |
518 | auto touched = vs.TouchedVar(op->body, iv->var.get()); |
519 | VTInjector injector(analyzer_, iv->var, nthread, touched, allow_share); |
520 | return injector(op->body); |
521 | } else { |
522 | return stmt; |
523 | } |
524 | } |
525 | |
526 | Stmt VisitStmt_(const ProducerStoreNode* op) final { |
527 | LOG(FATAL) << "Need to call StorageFlatten first" ; |
528 | } |
529 | }; |
530 | |
531 | namespace transform { |
532 | |
533 | Pass InjectVirtualThread() { |
534 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
535 | auto* n = f.CopyOnWrite(); |
536 | |
537 | arith::Analyzer analyzer; |
538 | |
539 | n->body = VirtualThreadInjector(&analyzer)(std::move(n->body)); |
540 | n->body = ConvertSSA(std::move(n->body)); |
541 | return f; |
542 | }; |
543 | return CreatePrimFuncPass(pass_func, 0, "tir.InjectVirtualThread" , {}); |
544 | } |
545 | |
// Expose the pass constructor through the FFI (tvm.tir.transform.InjectVirtualThread).
TVM_REGISTER_GLOBAL("tir.transform.InjectVirtualThread").set_body_typed(InjectVirtualThread);
547 | |
548 | } // namespace transform |
549 | |
550 | } // namespace tir |
551 | } // namespace tvm |
552 | |