/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

class WrongBlockIterTypeError : public ScheduleError {
 public:
  explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, Block block)
      : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) {
    op_str_ = for_kind == ForKind::kParallel
                  ? "parallel"
                  : (for_kind == ForKind::kVectorized ? "vectorize" : "bind");
  }
  String FastErrorString() const final {
    std::ostringstream os;
    os << "ScheduleError: The \"" << op_str_
       << "\" cannot be fulfilled with regard to some of its underlying blocks";
    return os.str();
  }
  String DetailRenderTemplate() const final {
    std::ostringstream os;
    if (op_str_ != "bind") {
      os << "The \"" << op_str_
         << "\" cannot be fulfilled with regard to block {0} because some block iter whose block "
            "binding contains the loop var is not a data parallel block iter";
    } else {
      os << "The \"bind\" cannot be fulfilled with regard to block {0}. This is because some of "
            "its block iters whose block bindings contain "
         << loop_var_
         << " do not meet any of the conditions:\n1) the block iter is data parallel;\n2) the "
            "block iter is a reduction block iter, and the thread axis to be bound is "
            "\"threadIdx.x/y/z\"";
    }
    return os.str();
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
  IRModule mod_;
  std::string op_str_;
  Var loop_var_;
  Block block_;
};

/*!
 * \brief Check if a loop can be parallelized/vectorized/bound with regard to a specific block
 * \details There are two conditions:
 * 1) The block is required to have affine bindings, and
 * 2) For each block iter whose binding contains the input loop variable, either
 *    - the block iter is data parallel, or
 *    - the block iter is a reduction block iter, and the input `thread_tag` starts with
 *      "threadIdx" in case of cross-thread reduction.
 * \param self The schedule state
 * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are
 * allowed)
 * \param loop_var The loop variable of the loop to be checked
 * \param block_realize The block-realize of the block to be checked
 * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value
 * if the operation is not "bind"
 * \throws ScheduleError If the input loop cannot be parallelized/vectorized/bound with regard to
 * the input block
 */
void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind,
                                    const Var& loop_var, const BlockRealize& block_realize,
                                    runtime::ThreadScope thread_scope) {
  const Block& block = block_realize->block;

  // Cond 1. The block is required to have affine bindings.
  CheckAffineBinding(self, block);

  // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed.
  ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size());
  int n_iters = static_cast<int>(block->iter_vars.size());
  for (int i = 0; i < n_iters; ++i) {
    const IterVar& iter_var = block->iter_vars[i];
    const PrimExpr& binding = block_realize->iter_values[i];

    if (!UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
      continue;
    }
    // Only two cases are allowed:
    // - The block iter is data parallel, or
    // - The block iter is a reduction block iter, and the `thread_scope` is "threadIdx.x/y/z"
    //   in case of cross-thread reduction.
    IterVarType iter_type = iter_var->iter_type;
    if (!(iter_type == kDataPar ||
          (iter_type == kCommReduce && thread_scope.rank == 1 && thread_scope.dim_index != -1))) {
      throw WrongBlockIterTypeError(self->mod, for_kind, loop_var, block);
    }
  }
}
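// Illustrative example (a hypothetical schedule for exposition, not taken from this file):
// suppose a block has a reduction block iter `vk` with binding `vk = k`. Binding the loop `k` to
// "threadIdx.x" passes the check above, because ThreadScope::Create("threadIdx.x") has rank 1 and
// a valid dim_index (cross-thread reduction). Parallelizing or vectorizing `k` instead reaches
// this function with the invalid ThreadScope{-1, -1}, so the kCommReduce iter satisfies neither
// case and a WrongBlockIterTypeError is thrown.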

/*!
 * \brief For each block (recursive) under the given loop, check whether the input loop can be
 * parallelized/vectorized/bound with regard to the block
 * \param self The schedule state
 * \param loop The loop to be parallelized/vectorized/bound
 * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are
 * allowed)
 * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value
 * if the operation is not "bind"
 */
void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind for_kind,
                            runtime::ThreadScope thread_scope) {
  PreOrderVisit(loop, [&](const ObjectRef& node) {
    if (const auto* realize = node.as<BlockRealizeNode>()) {
      // If this block doesn't have a corresponding StmtSRef in the schedule state, it must be a
      // block inside `tir.init()`. We don't check the condition for such blocks.
      if (!self->stmt2ref.count(realize->block.get())) {
        return false;
      }
      CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef<BlockRealize>(realize),
                                     thread_scope);
    }
    return true;
  });
}

/*!
 * \brief The implementation of parallelizing/vectorizing/binding a given loop
 * \param self The schedule state
 * \param loop_sref The sref of the loop to be parallelized/vectorized/bound
 * \param for_kind The type of the operation (only `kParallel`, `kVectorized` and `kThreadBinding`
 * are allowed)
 * \param thread_axis The thread axis that the input loop is bound to, which is defined only when
 * `for_kind` is `kThreadBinding`
 */
void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind,
                            Optional<IterVar> thread_axis) {
  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);

  /*
   * Check:
   * - 1. the subtree rooted from the input loop in the sref tree has compact data flow
   * - 2. all the blocks under the given loop have affine block bindings
   * - 3. the input loop can only be bound to data parallel block iters, or to a reduction block
   *   iter if `thread_axis` is `threadIdx.x/y/z` in case of cross-thread reduction
   * When the above conditions are all satisfied, the input loop can be
   * parallelized/vectorized/bound.
   */
  // Step 1. Check whether the subtree rooted from the `loop` in the sref tree has compact data
  // flow.
  CheckSubtreeCompactDataflow(self, loop_sref);

  // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
  // underlying block.
  CheckParallelizability(self, GetRef<For>(loop), for_kind,
                         thread_axis.defined()
                             ? runtime::ThreadScope::Create(thread_axis.value()->thread_tag)
                             : runtime::ThreadScope{-1, -1});

  // Step 3. Loop update and IR replacement
  ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
  new_loop->kind = for_kind;
  new_loop->thread_binding = std::move(thread_axis);
  self->Replace(loop_sref, For(new_loop), {});
}
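// Illustrative sketch of the effect (hypothetical loop, for exposition only): given a serial loop
// `for (i, 0, 128)` that passes the checks above, calling ParallelizeComputation with
// ForKind::kThreadBinding and a thread axis whose thread_tag is "threadIdx.x" replaces the loop
// with one whose kind is kThreadBinding and whose thread_binding is that IterVar. Only `kind` and
// `thread_binding` of the For node change; the loop var, extent and body are copied unchanged by
// make_object<ForNode>(*loop).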

void Parallel(ScheduleState self, const StmtSRef& loop_sref) {
  ParallelizeComputation(self, loop_sref, ForKind::kParallel, NullOpt);
}

void Vectorize(ScheduleState self, const StmtSRef& loop_sref) {
  ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt);
}

void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) {
  ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis);
}

void Unroll(ScheduleState self, const StmtSRef& loop_sref) {
  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
  ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
  new_loop->kind = ForKind::kUnrolled;
  new_loop->thread_binding = NullOpt;
  self->Replace(loop_sref, For(new_loop), {});
}
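// Rough user-facing equivalents, as a guide (the exact Python syntax is an assumption inferred
// from the `UnpackedAsPython` printers below, not something defined in this file):
//   sch.parallel(loop)             -> Parallel
//   sch.vectorize(loop)            -> Vectorize
//   sch.bind(loop, "threadIdx.x")  -> Bind
//   sch.unroll(loop)               -> Unroll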

/******** InstructionKind Registration ********/

struct ParallelTraits : public UnpackedInstTraits<ParallelTraits> {
  static constexpr const char* kName = "Parallel";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
    return sch->Parallel(loop_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
    PythonAPICall py("parallel");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct VectorizeTraits : public UnpackedInstTraits<VectorizeTraits> {
  static constexpr const char* kName = "Vectorize";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
    return sch->Vectorize(loop_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
    PythonAPICall py("vectorize");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct BindTraits : public UnpackedInstTraits<BindTraits> {
  static constexpr const char* kName = "Bind";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) {
    return sch->Bind(loop_rv, thread);
  }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv, String thread) {
    PythonAPICall py("bind");
    py.Input("loop", loop_rv);
    py.Input("thread_axis", thread);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct UnrollTraits : public UnpackedInstTraits<UnrollTraits> {
  static constexpr const char* kName = "Unroll";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
    PythonAPICall py("unroll");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(ParallelTraits);
TVM_REGISTER_INST_KIND_TRAITS(VectorizeTraits);
TVM_REGISTER_INST_KIND_TRAITS(BindTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnrollTraits);

}  // namespace tir
}  // namespace tvm