1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file unify_thread_binding.cc |
22 | */ |
23 | |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | #include <tvm/tir/transform.h> |
28 | |
29 | #include "../../support/utils.h" |
30 | #include "ir_utils.h" |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | using support::StartsWith; |
36 | |
37 | /*! |
38 | * \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar |
39 | * of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same |
40 | * IterVar |
41 | */ |
42 | class ThreadBindingUnifier : public StmtExprMutator { |
43 | public: |
44 | static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); } |
45 | |
46 | private: |
47 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
48 | // If this AttrStmt is not thread binding attribute, return as usual. |
49 | if (op->attr_key != attr::thread_extent && op->attr_key != attr::virtual_thread) { |
50 | return StmtMutator::VisitStmt_(op); |
51 | } |
52 | IterVar old_iter_var = Downcast<IterVar>(op->node); |
53 | return UnifyThreadBindingImpl(op, old_iter_var->var, old_iter_var, old_iter_var->dom); |
54 | } |
55 | |
56 | Stmt VisitStmt_(const ForNode* op) final { |
57 | // If this For is not thread binding attribute, return as usual. |
58 | if (op->kind != ForKind::kThreadBinding) { |
59 | return StmtExprMutator::VisitStmt_(op); |
60 | } |
61 | Map<String, ObjectRef> annotations = op->annotations; |
62 | Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(), |
63 | Range::FromMinExtent(op->min, op->extent)); |
64 | if (annotations.empty()) { |
65 | return stmt; |
66 | } |
67 | For new_loop = Downcast<For>(stmt); |
68 | new_loop.CopyOnWrite()->annotations = std::move(annotations); |
69 | return std::move(new_loop); |
70 | } |
71 | |
72 | template <typename Node> |
73 | Stmt UnifyThreadBindingImpl(const Node* op, const Var& old_var, const IterVar& old_iter_var, |
74 | const Range& dom) { |
75 | // Step 1. Fetch the thread tag. |
76 | IterVar new_iter_var{nullptr}; |
77 | const String& thread_tag = old_iter_var->thread_tag; |
78 | |
79 | // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the |
80 | // thread block depth is 0 before the increment, it means we are entering a new kernel, and |
81 | // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have |
82 | // thread axes with different extents. |
83 | bool is_kernel_launch_scope = false; |
84 | int old_thread_block_depth = thread_block_depth_; |
85 | if (StartsWith(thread_tag, "blockIdx." ) || !thread_block_depth_) { |
86 | if (!thread_block_depth_) { |
87 | thread_tag2iter_var_map_.clear(); |
88 | is_kernel_launch_scope = true; |
89 | } |
90 | ++thread_block_depth_; |
91 | } |
92 | |
93 | // Step 3. See if an IterVar for this kind of thread binding was created before. If so, we use |
94 | // the created IterVar. Otherwise, we create a new IterVar for this thread binding and store the |
95 | // IterVar in mapping `thread_tag2iter_var_map_`. |
96 | Map<String, IterVar>::iterator it = thread_tag2iter_var_map_.find(thread_tag); |
97 | if (it != thread_tag2iter_var_map_.end()) { |
98 | new_iter_var = (*it).second; |
99 | ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min)); |
100 | CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent)) |
101 | << "ValueError: All loops that are bound to `" << thread_tag |
102 | << "` should have the same extent. However, there are two loops with extent " |
103 | << new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal" ; |
104 | } else { |
105 | new_iter_var = IterVar(dom, Var(thread_tag, dom->extent.dtype()), old_iter_var->iter_type, |
106 | old_iter_var->thread_tag); |
107 | thread_tag2iter_var_map_.Set(thread_tag, new_iter_var); |
108 | launch_threads_.push_back(new_iter_var); |
109 | } |
110 | |
111 | // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the |
112 | // new variable in further mutation. Thus, we store the mapping entry. Cast to old dtype if |
113 | // needed (we assume both old and new dtype are valid for the range of the thread extent). |
114 | var_substitution_map_.Set(old_var, cast(old_var.dtype(), new_iter_var->var)); |
115 | |
116 | // Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth |
117 | // counter. Emit for-loops to launch threads if current statement is the outermost thread |
118 | // binding of the kernel. |
119 | Stmt new_stmt = StmtMutator::VisitStmt_(op); |
120 | auto* new_node = new_stmt.as<Node>(); |
121 | ICHECK(new_node); |
122 | thread_block_depth_ = old_thread_block_depth; |
123 | if (is_kernel_launch_scope) { |
124 | return EmitLaunchThreads(new_node->body); |
125 | } else { |
126 | return new_node->body; |
127 | } |
128 | } |
129 | |
130 | /*! |
131 | * \brief Emit loop nests representing all thread bindings of the kernel |
132 | * \param body The body of the innermost loop of the thread bindings. |
133 | * \return The loop nests of the thread bindings. |
134 | */ |
135 | Stmt EmitLaunchThreads(const Stmt& body) { |
136 | Stmt result = body; |
137 | while (!launch_threads_.empty()) { |
138 | const IterVar& thread_binding = launch_threads_.back(); |
139 | // Recreate the IterVar as we don't duplicate `dom` in both For and IterVar. This is |
140 | // necessary for unit tests. |
141 | result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent, |
142 | ForKind::kThreadBinding, result, |
143 | IterVar(NullValue<Range>(), Var("" ), IterVarType::kThreadIndex, |
144 | thread_binding->thread_tag)); |
145 | launch_threads_.pop_back(); |
146 | } |
147 | return result; |
148 | } |
149 | |
150 | PrimExpr VisitExpr_(const VarNode* var) final { |
151 | // If this variable appears as a key in `var_substitution_map_`, we substitute it with its |
152 | // corresponding value in the mapping. |
153 | Map<Var, PrimExpr>::iterator it = var_substitution_map_.find(GetRef<Var>(var)); |
154 | return it != var_substitution_map_.end() ? (*it).second : GetRef<Var>(var); |
155 | } |
156 | |
157 | /*! |
158 | * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all |
159 | * occurrences of the thread tag |
160 | */ |
161 | Map<String, IterVar> thread_tag2iter_var_map_; |
162 | /*! |
163 | * \brief A list of IterVar corresponding to threads in current kernel. This will be used to |
164 | * generate for-loops to launch threads. |
165 | */ |
166 | Array<IterVar> launch_threads_; |
167 | /*! \brief A mapping from old variables to new variables, which is used for substitution */ |
168 | Map<Var, PrimExpr> var_substitution_map_; |
169 | /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */ |
170 | int thread_block_depth_ = 0; |
171 | /*! \brief An analyzer used for equality proof */ |
172 | arith::Analyzer ana; |
173 | }; |
174 | |
175 | PrimFunc UnifyThreadBinding(PrimFunc f) { |
176 | // Only apply this pass to TIR that is not from TE schedules |
177 | if (!IsFromLegacyTESchedule(f)) { |
178 | PrimFuncNode* fptr = f.CopyOnWrite(); |
179 | fptr->body = ThreadBindingUnifier::Unify(std::move(f->body)); |
180 | return f; |
181 | } else { |
182 | return f; |
183 | } |
184 | } |
185 | |
186 | namespace transform { |
187 | |
188 | Pass UnifyThreadBinding() { |
189 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
190 | return UnifyThreadBinding(std::move(f)); |
191 | }; |
192 | return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding" , {}); |
193 | } |
194 | |
195 | TVM_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding" ).set_body_typed(UnifyThreadBinding); |
196 | |
197 | } // namespace transform |
198 | |
199 | } // namespace tir |
200 | } // namespace tvm |
201 | |