/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

class WrongBlockIterTypeError : public ScheduleError {
 public:
  explicit WrongBlockIterTypeError(IRModule mod, ForKind for_kind, Var loop_var, Block block)
      : mod_(std::move(mod)), loop_var_(std::move(loop_var)), block_(std::move(block)) {
    op_str_ = for_kind == ForKind::kParallel
                  ? "parallel"
                  : (for_kind == ForKind::kVectorized ? "vectorize" : "bind");
  }
  String FastErrorString() const final {
    std::ostringstream os;
    os << "ScheduleError: The \"" << op_str_
       << "\" cannot be fulfilled with regard to some of its underlying block";
    return os.str();
  }
  String DetailRenderTemplate() const final {
    std::ostringstream os;
    if (op_str_ != "bind") {
      os << "The \"" << op_str_
         << "\" cannot be fulfilled with regard to block {0} because some block iter whose block "
            "binding contains the loop var is not a data parallel block iter";
    } else {
      os << "The \"bind\" cannot be fulfilled with regard to block {0}. This is because some of its"
            " block iter whose block binding contains "
         << loop_var_
         << " does not meet any of the conditions:\n1) the block iter is data parallel;\n2) the "
            "block iter is a reduction block iter, and the thread axis to be bound is "
            "\"threadIdx.x/y/z\"";
    }
    return os.str();
  }
  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
  IRModule mod_;
  std::string op_str_;
  Var loop_var_;
  Block block_;
};

/*!
 * \brief Check if a loop can be parallelized/vectorized/bound with regard to a specific block
 * \details There are two conditions:
 * 1) The block is required to have affine bindings, and
 * 2) For each block iter whose binding contains the input loop variable, either
 *   - the block iter is data parallel, or
 *   - the block iter is a reduction block iter, and the input `thread_tag` starts with "threadIdx"
 *     in case of cross-thread reduction.
 * \param self The schedule state
 * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are
 * allowed)
 * \param loop_var The loop variable of the loop to be checked
 * \param block_realize The block-realize of the block to be checked
 * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if
 * the operation is not "bind"
 * \throws ScheduleError If the input loop cannot be parallelized/vectorized/bound with regard to
 * the input block
 */
void CheckLoopParallelizableInBlock(const ScheduleState& self, ForKind for_kind,
                                    const Var& loop_var, const BlockRealize& block_realize,
                                    runtime::ThreadScope thread_scope) {
  const Block& block = block_realize->block;

  // Cond 1. The block is required to have affine bindings.
  CheckAffineBinding(self, block);

  // Cond 2. For each block iter whose binding contains `loop_var`, only two cases are allowed.
  ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size());
  int n_iters = static_cast<int>(block->iter_vars.size());
  for (int i = 0; i < n_iters; ++i) {
    const IterVar& iter_var = block->iter_vars[i];
    const PrimExpr& binding = block_realize->iter_values[i];

    if (!UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
      continue;
    }
    // Only two cases are allowed:
    // - The block iter is data parallel, or
    // - The block iter is a reduction block iter, and the `thread_scope` is "threadIdx.x/y/z"
    //   in case of cross-thread reduction.
    IterVarType iter_type = iter_var->iter_type;
    if (!(iter_type == kDataPar ||
          (iter_type == kCommReduce && thread_scope.rank == 1 && thread_scope.dim_index != -1))) {
      throw WrongBlockIterTypeError(self->mod, for_kind, loop_var, block);
    }
  }
}
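
// An illustrative sketch of the rule enforced above (not an exhaustive list of cases): consider a
// block with two iters bound as `vi = i` (data parallel) and `vk = k` (reduction).
//   - Parallelizing or vectorizing loop `i` is accepted, since only the data parallel iter `vi`
//     depends on `i`.
//   - Parallelizing or vectorizing loop `k` is rejected, because the reduction iter `vk` depends
//     on `k`.
//   - Binding loop `k` to "threadIdx.x" is accepted (cross-thread reduction), whereas binding it
//     to "blockIdx.x" is rejected, since the latter has thread scope rank 0.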

/*!
 * \brief For each block (recursively) under the given loop, check whether the input loop can be
 * parallelized/vectorized/bound with regard to the block
 * \param self The schedule state
 * \param loop The loop to be parallelized/vectorized/bound
 * \param for_kind The desired ForKind (only `kParallel`, `kVectorized` and `kThreadBinding` are
 * allowed)
 * \param thread_scope The thread scope of the thread axis to be bound, which is an invalid value if
 * the operation is not "bind"
 */
void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind for_kind,
                            runtime::ThreadScope thread_scope) {
  PreOrderVisit(loop, [&](const ObjectRef& node) {
    if (const auto* realize = node.as<BlockRealizeNode>()) {
      // If this block doesn't have a corresponding StmtSRef in the schedule state, it must be a
      // block inside `tir.init()`. We don't check the condition for such blocks.
      if (!self->stmt2ref.count(realize->block.get())) {
        return false;
      }
      CheckLoopParallelizableInBlock(self, for_kind, loop->loop_var, GetRef<BlockRealize>(realize),
                                     thread_scope);
    }
    return true;
  });
}

/*!
 * \brief The implementation of parallelizing/vectorizing/binding a given loop
 * \param self The schedule state
 * \param loop_sref The sref of the loop to be parallelized/vectorized/bound
 * \param for_kind The type of the operation (only `kParallel`, `kVectorized` and `kThreadBinding`
 * are allowed)
 * \param thread_axis The thread axis that the input loop is bound to, which is defined only when
 * `for_kind` is `kThreadBinding`
 */
void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind,
                            Optional<IterVar> thread_axis) {
  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);

  /*
   * Checks:
   * - 1. The subtree rooted at the input loop in the sref tree has compact data flow.
   * - 2. All the blocks under the given loop have affine block bindings.
   * - 3. The input loop can only be bound to data parallel block iters, or it can be bound to a
   *   reduction block iter if `thread_axis` is `threadIdx.x/y/z` in case of cross-thread reduction.
   * When all of the above conditions are satisfied, the input loop can be
   * parallelized/vectorized/bound.
   */
  // Step 1. Check whether the subtree rooted at `loop` in the sref tree has compact data flow.
  CheckSubtreeCompactDataflow(self, loop_sref);

  // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
  // underlying block.
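  // Note: `ThreadScope::Create` parses the thread tag of the axis being bound; for instance,
  // "threadIdx.x" is expected to yield rank 1 with a valid dim_index, which is what the reduction
  // case in `CheckLoopParallelizableInBlock` requires. For non-bind operations, the sentinel
  // ThreadScope{-1, -1} marks the scope as invalid.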
  CheckParallelizability(self, GetRef<For>(loop), for_kind,
                         thread_axis.defined()
                             ? runtime::ThreadScope::Create(thread_axis.value()->thread_tag)
                             : runtime::ThreadScope{-1, -1});

  // Step 3. Loop update and IR replacement
  ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
  new_loop->kind = for_kind;
  new_loop->thread_binding = std::move(thread_axis);
  self->Replace(loop_sref, For(new_loop), {});
}

void Parallel(ScheduleState self, const StmtSRef& loop_sref) {
  ParallelizeComputation(self, loop_sref, ForKind::kParallel, NullOpt);
}

void Vectorize(ScheduleState self, const StmtSRef& loop_sref) {
  ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt);
}

void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) {
  ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis);
}

void Unroll(ScheduleState self, const StmtSRef& loop_sref) {
  const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
  ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
  new_loop->kind = ForKind::kUnrolled;
  new_loop->thread_binding = NullOpt;
  self->Replace(loop_sref, For(new_loop), {});
}
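
// A minimal usage sketch of how these primitives are typically reached through the concrete
// `Schedule` API. The construction arguments and the block name "C" below are assumptions made
// only for illustration, not part of this file:
//
//   Schedule sch = Schedule::Concrete(mod, /*seed=*/-1, /*debug_mask=*/0,
//                                     ScheduleErrorRenderLevel::kDetail);
//   Array<LoopRV> loops = sch->GetLoops(sch->GetBlock("C"));
//   sch->Parallel(loops[0]);  // routes to Parallel() above
//   sch->Unroll(loops[1]);    // routes to Unroll() above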

/******** InstructionKind Registration ********/
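
// These trait structs connect the primitives above to the instruction/trace machinery. As a rough
// sketch of what `UnpackedAsPython` renders in a trace (the RV name `l0` is just a placeholder),
// a recorded `Bind` instruction is expected to print along the lines of:
//
//   sch.bind(loop=l0, thread_axis="threadIdx.x")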

struct ParallelTraits : public UnpackedInstTraits<ParallelTraits> {
  static constexpr const char* kName = "Parallel";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
    return sch->Parallel(loop_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
    PythonAPICall py("parallel");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct VectorizeTraits : public UnpackedInstTraits<VectorizeTraits> {
  static constexpr const char* kName = "Vectorize";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) {
    return sch->Vectorize(loop_rv);
  }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
    PythonAPICall py("vectorize");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct BindTraits : public UnpackedInstTraits<BindTraits> {
  static constexpr const char* kName = "Bind";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, String thread) {
    return sch->Bind(loop_rv, thread);
  }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv, String thread) {
    PythonAPICall py("bind");
    py.Input("loop", loop_rv);
    py.Input("thread_axis", thread);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct UnrollTraits : public UnpackedInstTraits<UnrollTraits> {
  static constexpr const char* kName = "Unroll";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv) { return sch->Unroll(loop_rv); }

  static String UnpackedAsPython(Array<String> outputs, String loop_rv) {
    PythonAPICall py("unroll");
    py.Input("loop", loop_rv);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(ParallelTraits);
TVM_REGISTER_INST_KIND_TRAITS(VectorizeTraits);
TVM_REGISTER_INST_KIND_TRAITS(BindTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnrollTraits);

}  // namespace tir
}  // namespace tvm