1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file unify_thread_binding.cc
22 */
23
24#include <tvm/arith/analyzer.h>
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/stmt_functor.h>
27#include <tvm/tir/transform.h>
28
29#include "../../support/utils.h"
30#include "ir_utils.h"
31
32namespace tvm {
33namespace tir {
34
35using support::StartsWith;
36
37/*!
38 * \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar
39 * of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same
40 * IterVar
41 */
42class ThreadBindingUnifier : public StmtExprMutator {
43 public:
44 static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); }
45
46 private:
47 Stmt VisitStmt_(const AttrStmtNode* op) final {
48 // If this AttrStmt is not thread binding attribute, return as usual.
49 if (op->attr_key != attr::thread_extent && op->attr_key != attr::virtual_thread) {
50 return StmtMutator::VisitStmt_(op);
51 }
52 IterVar old_iter_var = Downcast<IterVar>(op->node);
53 return UnifyThreadBindingImpl(op, old_iter_var->var, old_iter_var, old_iter_var->dom);
54 }
55
56 Stmt VisitStmt_(const ForNode* op) final {
57 // If this For is not thread binding attribute, return as usual.
58 if (op->kind != ForKind::kThreadBinding) {
59 return StmtExprMutator::VisitStmt_(op);
60 }
61 Map<String, ObjectRef> annotations = op->annotations;
62 Stmt stmt = UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(),
63 Range::FromMinExtent(op->min, op->extent));
64 if (annotations.empty()) {
65 return stmt;
66 }
67 For new_loop = Downcast<For>(stmt);
68 new_loop.CopyOnWrite()->annotations = std::move(annotations);
69 return std::move(new_loop);
70 }
71
72 template <typename Node>
73 Stmt UnifyThreadBindingImpl(const Node* op, const Var& old_var, const IterVar& old_iter_var,
74 const Range& dom) {
75 // Step 1. Fetch the thread tag.
76 IterVar new_iter_var{nullptr};
77 const String& thread_tag = old_iter_var->thread_tag;
78
79 // Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the
80 // thread block depth is 0 before the increment, it means we are entering a new kernel, and
81 // therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have
82 // thread axes with different extents.
83 bool is_kernel_launch_scope = false;
84 int old_thread_block_depth = thread_block_depth_;
85 if (StartsWith(thread_tag, "blockIdx.") || !thread_block_depth_) {
86 if (!thread_block_depth_) {
87 thread_tag2iter_var_map_.clear();
88 is_kernel_launch_scope = true;
89 }
90 ++thread_block_depth_;
91 }
92
93 // Step 3. See if an IterVar for this kind of thread binding was created before. If so, we use
94 // the created IterVar. Otherwise, we create a new IterVar for this thread binding and store the
95 // IterVar in mapping `thread_tag2iter_var_map_`.
96 Map<String, IterVar>::iterator it = thread_tag2iter_var_map_.find(thread_tag);
97 if (it != thread_tag2iter_var_map_.end()) {
98 new_iter_var = (*it).second;
99 ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min));
100 CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent))
101 << "ValueError: All loops that are bound to `" << thread_tag
102 << "` should have the same extent. However, there are two loops with extent "
103 << new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal";
104 } else {
105 new_iter_var = IterVar(dom, Var(thread_tag, dom->extent.dtype()), old_iter_var->iter_type,
106 old_iter_var->thread_tag);
107 thread_tag2iter_var_map_.Set(thread_tag, new_iter_var);
108 launch_threads_.push_back(new_iter_var);
109 }
110
111 // Step 4. We will substitute the occurrences of the old variable in the old IterVar with the
112 // new variable in further mutation. Thus, we store the mapping entry. Cast to old dtype if
113 // needed (we assume both old and new dtype are valid for the range of the thread extent).
114 var_substitution_map_.Set(old_var, cast(old_var.dtype(), new_iter_var->var));
115
116 // Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth
117 // counter. Emit for-loops to launch threads if current statement is the outermost thread
118 // binding of the kernel.
119 Stmt new_stmt = StmtMutator::VisitStmt_(op);
120 auto* new_node = new_stmt.as<Node>();
121 ICHECK(new_node);
122 thread_block_depth_ = old_thread_block_depth;
123 if (is_kernel_launch_scope) {
124 return EmitLaunchThreads(new_node->body);
125 } else {
126 return new_node->body;
127 }
128 }
129
130 /*!
131 * \brief Emit loop nests representing all thread bindings of the kernel
132 * \param body The body of the innermost loop of the thread bindings.
133 * \return The loop nests of the thread bindings.
134 */
135 Stmt EmitLaunchThreads(const Stmt& body) {
136 Stmt result = body;
137 while (!launch_threads_.empty()) {
138 const IterVar& thread_binding = launch_threads_.back();
139 // Recreate the IterVar as we don't duplicate `dom` in both For and IterVar. This is
140 // necessary for unit tests.
141 result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent,
142 ForKind::kThreadBinding, result,
143 IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
144 thread_binding->thread_tag));
145 launch_threads_.pop_back();
146 }
147 return result;
148 }
149
150 PrimExpr VisitExpr_(const VarNode* var) final {
151 // If this variable appears as a key in `var_substitution_map_`, we substitute it with its
152 // corresponding value in the mapping.
153 Map<Var, PrimExpr>::iterator it = var_substitution_map_.find(GetRef<Var>(var));
154 return it != var_substitution_map_.end() ? (*it).second : GetRef<Var>(var);
155 }
156
157 /*!
158 * \brief A mapping from a thread tag to its corresponding IterVar that is shared by all
159 * occurrences of the thread tag
160 */
161 Map<String, IterVar> thread_tag2iter_var_map_;
162 /*!
163 * \brief A list of IterVar corresponding to threads in current kernel. This will be used to
164 * generate for-loops to launch threads.
165 */
166 Array<IterVar> launch_threads_;
167 /*! \brief A mapping from old variables to new variables, which is used for substitution */
168 Map<Var, PrimExpr> var_substitution_map_;
169 /*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */
170 int thread_block_depth_ = 0;
171 /*! \brief An analyzer used for equality proof */
172 arith::Analyzer ana;
173};
174
175PrimFunc UnifyThreadBinding(PrimFunc f) {
176 // Only apply this pass to TIR that is not from TE schedules
177 if (!IsFromLegacyTESchedule(f)) {
178 PrimFuncNode* fptr = f.CopyOnWrite();
179 fptr->body = ThreadBindingUnifier::Unify(std::move(f->body));
180 return f;
181 } else {
182 return f;
183 }
184}
185
186namespace transform {
187
188Pass UnifyThreadBinding() {
189 auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
190 return UnifyThreadBinding(std::move(f));
191 };
192 return CreatePrimFuncPass(pass_func, 0, "tir.UnifyThreadBinding", {});
193}
194
195TVM_REGISTER_GLOBAL("tir.transform.UnifyThreadBinding").set_body_typed(UnifyThreadBinding);
196
197} // namespace transform
198
199} // namespace tir
200} // namespace tvm
201