1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */ |
25 | class BlockPredicateAppender : public StmtMutator { |
26 | public: |
27 | /*! |
28 | * \brief Constructor |
29 | * \param to_append The predicate to be appended to BlockRealizeNode |
30 | */ |
31 | explicit BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {} |
32 | |
33 | private: |
34 | // For each direct child of type BlockRealizeNode, append the predicate |
35 | Stmt VisitStmt_(const BlockRealizeNode* realize) final { |
36 | // We do not recursively do this |
37 | ObjectPtr<BlockRealizeNode> n = CopyOnWrite(realize); |
38 | n->predicate = n->predicate && to_append_; |
39 | return BlockRealize(n); |
40 | } |
41 | |
42 | /*! \brief The predicate to be appended */ |
43 | const PrimExpr& to_append_; |
44 | }; |
45 | |
46 | /*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ |
47 | class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { |
48 | public: |
49 | explicit SubstituteVarAndCollectOpaqueBlock(std::function<Optional<PrimExpr>(const Var&)> vmap, |
50 | Map<Block, Block>* opaque_blocks) |
51 | : vmap_(vmap), opaque_blocks_(opaque_blocks) {} |
52 | |
53 | private: |
54 | PrimExpr VisitExpr_(const VarNode* op) final { |
55 | Var var = GetRef<Var>(op); |
56 | if (Optional<PrimExpr> ret = vmap_(var)) { |
57 | return tvm::cast(var.dtype(), ret.value()); |
58 | } else { |
59 | return std::move(var); |
60 | } |
61 | } |
62 | |
63 | Stmt VisitStmt_(const BlockRealizeNode* op) final { |
64 | BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op)); |
65 | if (realize->block->iter_vars.empty()) { |
66 | opaque_blocks_->Set(op->block, realize->block); |
67 | } |
68 | return std::move(realize); |
69 | } |
70 | |
71 | /*! \brief The substitute function */ |
72 | std::function<Optional<PrimExpr>(const Var&)> vmap_; |
73 | /*! \brief The reuse mapping of opaque blocks */ |
74 | Map<Block, Block>* opaque_blocks_; |
75 | }; |
76 | |
77 | /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ |
78 | class IterMapSimplifyBlockBinding : public StmtExprMutator { |
79 | public: |
80 | explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent, |
81 | bool preserve_unit_iters) |
82 | : opaque_blocks_(opaque_blocks), |
83 | loop_var2extent_(loop_var2extent), |
84 | preserve_unit_iters_(preserve_unit_iters) {} |
85 | |
86 | static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs, MapNode* opaque_blocks, |
87 | bool preserve_unit_iters) { |
88 | Map<Var, Range> loop_var2extent; |
89 | for (const StmtSRef& sref : loop_srefs) { |
90 | const ForNode* loop = TVM_SREF_TO_FOR(sref); |
91 | loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); |
92 | } |
93 | return Downcast<For>(IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent), |
94 | preserve_unit_iters)(std::move(stmt))); |
95 | } |
96 | |
97 | private: |
98 | Stmt VisitStmt_(const ForNode* op) final { |
99 | loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); |
100 | Stmt res = StmtMutator::VisitStmt_(op); |
101 | loop_var2extent_.erase(op->loop_var); |
102 | return res; |
103 | } |
104 | |
105 | Stmt VisitStmt_(const BlockRealizeNode* op) final { |
106 | // skip opaque block and update mapping |
107 | if (op->iter_values.empty()) { |
108 | Block block = op->block; |
109 | BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op)); |
110 | for (const std::pair<ObjectRef, ObjectRef>& entry : *opaque_blocks_) { |
111 | if (entry.second.same_as(block)) { |
112 | opaque_blocks_->at(entry.first) = realize->block; |
113 | break; |
114 | } |
115 | } |
116 | return std::move(realize); |
117 | } |
118 | Array<PrimExpr> v = |
119 | arith::IterMapSimplify(/*indices=*/op->iter_values, |
120 | /*input_iters=*/loop_var2extent_, |
121 | /*input_pred=*/op->predicate, |
122 | /*check_level=*/arith::IterMapLevel::Surjective, |
123 | /*simplify_trivial_iterators=*/!preserve_unit_iters_); |
124 | if (v.same_as(op->iter_values)) { |
125 | return GetRef<Stmt>(op); |
126 | } else { |
127 | ObjectPtr<BlockRealizeNode> n = CopyOnWrite(op); |
128 | n->iter_values = std::move(v); |
129 | return Stmt(n); |
130 | } |
131 | } |
132 | |
133 | /*! \brief The reuse mapping */ |
134 | MapNode* opaque_blocks_; |
135 | /*! \brief The range of loops */ |
136 | Map<Var, Range> loop_var2extent_; |
137 | /*! \brief Whether or not to simplify unit iterators */ |
138 | bool preserve_unit_iters_; |
139 | }; |
140 | |
141 | class BlockPropertyError : public ScheduleError { |
142 | public: |
143 | /*! |
144 | * \brief Check that all the blocks under the specific stmt have affine bindings |
145 | * wrt top loop sref and only have data-parallel or reduction block iters |
146 | * \param self The state of the schedule |
147 | * \param sref The sref to the specific stmt |
148 | */ |
149 | static void CheckBlockIterTypeAndAffineBinding(const ScheduleState& self, const StmtSRefNode* top, |
150 | const StmtSRefNode* sref) { |
151 | class BlockIterTypeAndAffineBindingChecker : public StmtVisitor { |
152 | public: |
153 | explicit BlockIterTypeAndAffineBindingChecker(const ScheduleState& state, |
154 | const StmtSRefNode* top) |
155 | : state_(state), top_(top) {} |
156 | |
157 | private: |
158 | void VisitStmt_(const BlockNode* op) final { |
159 | for (const IterVar& iter_var : op->iter_vars) { |
160 | if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { |
161 | throw BlockPropertyError(state_->mod, GetRef<Block>(op)); |
162 | } |
163 | Optional<StmtSRef> high_exclusive = |
164 | top_->parent ? GetRef<StmtSRef>(top_->parent) : Optional<StmtSRef>(NullOpt); |
165 | CheckPartialAffineBinding(state_, GetRef<Block>(op), high_exclusive); |
166 | } |
167 | } |
168 | const ScheduleState& state_; |
169 | const StmtSRefNode* top_; |
170 | }; |
171 | |
172 | BlockIterTypeAndAffineBindingChecker checker(self, top); |
173 | checker(GetRef<Stmt>(sref->stmt)); |
174 | } |
175 | |
176 | explicit BlockPropertyError(IRModule mod, Block block) : mod_(mod), block_(std::move(block)) {} |
177 | |
178 | String FastErrorString() const final { |
179 | return "ScheduleError: The block under the loops to be reordered have block iter type other " |
180 | "than data-parallel or reduction" ; |
181 | } |
182 | |
183 | String DetailRenderTemplate() const final { |
184 | return "The block {0} under the loops to be reordered have block iter type other than " |
185 | "data-parallel or reduction" ; |
186 | } |
187 | |
188 | IRModule mod() const final { return mod_; } |
189 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
190 | |
191 | IRModule mod_; |
192 | Block block_; |
193 | }; |
194 | |
195 | class HasAnnotationOrThreadBindingError : public ScheduleError { |
196 | public: |
197 | explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop) |
198 | : mod_(mod), loop_(std::move(loop)) {} |
199 | |
200 | String FastErrorString() const final { |
201 | return "ScheduleError: The primitive can't be applied because the loop has annotation or " |
202 | "thread binding" ; |
203 | } |
204 | |
205 | String DetailRenderTemplate() const final { |
206 | return "The primitive can't be applied because the loop {0} has annotation or thread binding" ; |
207 | } |
208 | |
209 | IRModule mod() const final { return mod_; } |
210 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
211 | |
212 | IRModule mod_; |
213 | For loop_; |
214 | }; |
215 | |
216 | class OuterNotInnerParent : public ScheduleError { |
217 | public: |
218 | explicit OuterNotInnerParent(IRModule mod, For outer, For inner) |
219 | : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} |
220 | |
221 | String FastErrorString() const final { |
222 | return "ScheduleError: The outer loop is not the parent of the inner loop" ; |
223 | } |
224 | |
225 | String DetailRenderTemplate() const final { |
226 | return "The loops can't be fused because the outer loop {0} is not the parent of the inner " |
227 | "loop {1}" ; |
228 | } |
229 | |
230 | IRModule mod() const final { return mod_; } |
231 | Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; } |
232 | |
233 | IRModule mod_; |
234 | For outer_; |
235 | For inner_; |
236 | }; |
237 | |
238 | class NotOnlyChildError : public ScheduleError { |
239 | public: |
240 | explicit NotOnlyChildError(IRModule mod, For outer, For inner) |
241 | : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} |
242 | |
243 | String FastErrorString() const final { |
244 | return "ScheduleError: The inner loop is not the only child of outer loop" ; |
245 | } |
246 | |
247 | String DetailRenderTemplate() const final { |
248 | return "The loops can't be fused because the inner loop {1} is not the only child of outer " |
249 | "loop {0}." ; |
250 | } |
251 | |
252 | IRModule mod() const final { return mod_; } |
253 | Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; } |
254 | |
255 | IRModule mod_; |
256 | For outer_; |
257 | For inner_; |
258 | }; |
259 | |
260 | class NotSingleInferFactorError : public ScheduleError { |
261 | public: |
262 | explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} |
263 | |
264 | String FastErrorString() const final { |
265 | return "ScheduleError: only one factor can be specified as -1 or none" ; |
266 | } |
267 | |
268 | String DetailRenderTemplate() const final { |
269 | return "Only one factor can be specified as -1 or none" ; |
270 | } |
271 | |
272 | IRModule mod() const final { return mod_; } |
273 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
274 | |
275 | IRModule mod_; |
276 | }; |
277 | |
278 | class WrongFactorProductError : public ScheduleError { |
279 | public: |
280 | explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} |
281 | |
282 | String FastErrorString() const final { |
283 | return "ScheduleError: The product of factors is not larger than or equal to the extent of " |
284 | "loop" ; |
285 | } |
286 | |
287 | String DetailRenderTemplate() const final { |
288 | return "The product of factors is not larger than or equal to the extent of loop {0}" ; |
289 | } |
290 | |
291 | IRModule mod() const final { return mod_; } |
292 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
293 | |
294 | IRModule mod_; |
295 | For loop_; |
296 | }; |
297 | |
298 | class LoopMultiAppearanceError : public ScheduleError { |
299 | public: |
300 | explicit LoopMultiAppearanceError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} |
301 | |
302 | String FastErrorString() const final { |
303 | return "ScheduleError: Some loop appears in the input array for multiple times." ; |
304 | } |
305 | |
306 | String DetailRenderTemplate() const final { |
307 | return "Loop {0} appears in the input array for multiple times." ; |
308 | } |
309 | |
310 | IRModule mod() const final { return mod_; } |
311 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
312 | |
313 | IRModule mod_; |
314 | For loop_; |
315 | }; |
316 | |
317 | class LoopsNotAChainError : public ScheduleError { |
318 | public: |
319 | enum class ProblemKind { kNotUnderAScope, kHaveNonSingleBranchStmt }; |
320 | |
321 | explicit LoopsNotAChainError(IRModule mod, Optional<Stmt> problematic_loop, ProblemKind kind) |
322 | : mod_(mod), problematic_loop_(std::move(problematic_loop)), kind_(kind) {} |
323 | |
324 | String FastErrorString() const final { return "ScheduleError: the loops are not in a chain" ; } |
325 | |
326 | String DetailRenderTemplate() const final { |
327 | std::stringstream ss; |
328 | ss << "The loops are not in a chain because" ; |
329 | if (kind_ == ProblemKind::kNotUnderAScope) { |
330 | ss << " they are not under the same scope." ; |
331 | } else { |
332 | ss << " there is a non-single-branch stmt in between. Problematic stmt: {0}" ; |
333 | } |
334 | return ss.str(); |
335 | } |
336 | |
337 | IRModule mod() const final { return mod_; } |
338 | Array<ObjectRef> LocationsOfInterest() const final { |
339 | if (kind_ == ProblemKind::kNotUnderAScope) { |
340 | return {}; |
341 | } else { |
342 | ICHECK(problematic_loop_.defined()); |
343 | return {problematic_loop_.value()}; |
344 | } |
345 | } |
346 | |
347 | IRModule mod_; |
348 | Optional<Stmt> problematic_loop_; |
349 | ProblemKind kind_; |
350 | }; |
351 | |
352 | class DependentLoopError : public ScheduleError { |
353 | public: |
354 | enum class PrimitiveKind { kFuse, kReorder }; |
355 | explicit DependentLoopError(IRModule mod, For loop, String inner_var, PrimitiveKind kind) |
356 | : mod_(mod), loop_(std::move(loop)), inner_var_(std::move(inner_var)), kind_(kind) {} |
357 | |
358 | String FastErrorString() const final { |
359 | if (kind_ == PrimitiveKind::kReorder) { |
360 | return "ScheduleError: An outer loop's `min` or `extent` is dependent on an inner loop " |
361 | "in the new order" ; |
362 | } else { |
363 | return "ScheduleError: A loop's `extent` is dependent on another loop" ; |
364 | } |
365 | } |
366 | |
367 | String DetailRenderTemplate() const final { |
368 | if (kind_ == PrimitiveKind::kReorder) { |
369 | return "Outer Loop {0}'s `min` or `extent` is dependent on an inner loop " + inner_var_ + |
370 | " in the new order" ; |
371 | } else { |
372 | return "A loop {0}'s `extent` is dependent on another loop " + inner_var_; |
373 | } |
374 | } |
375 | |
376 | IRModule mod() const final { return mod_; } |
377 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
378 | |
379 | IRModule mod_; |
380 | For loop_; |
381 | String inner_var_; |
382 | PrimitiveKind kind_; |
383 | }; |
384 | |
385 | Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref, const Array<PrimExpr>& factors, |
386 | bool preserve_unit_iters) { |
387 | // Invariance |
388 | // - The total repeat number has not changed for each direct child block with updating predicate. |
389 | // - The execution order has not changed. (The block executes with the same args and the same |
390 | // order with before. |
391 | // Step 1. Check correctness |
392 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
393 | if (!loop->annotations.empty() || loop->thread_binding.defined()) { |
394 | throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop)); |
395 | } |
396 | // Currently, loops not starting with 0 are not supported |
397 | arith::Analyzer analyzer; |
398 | CheckLoopStartsWithZero(self, loop_sref, &analyzer); |
399 | |
400 | // Find the most common dtype |
401 | DataType dtype; |
402 | { |
403 | int bits = loop->loop_var.dtype().bits(); |
404 | for (const PrimExpr& factor : factors) { |
405 | bits = std::max(bits, factor.dtype().bits()); |
406 | } |
407 | dtype = DataType::Int(bits); |
408 | } |
409 | int n = factors.size(); |
410 | PrimExpr substitute_value = make_const(dtype, 0); |
411 | std::vector<Var> new_loop_vars; |
412 | new_loop_vars.reserve(n); |
413 | for (int i = 0; i < n; i++) { |
414 | const PrimExpr& factor = factors[i]; |
415 | Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)).copy_with_dtype(dtype); |
416 | substitute_value = substitute_value * factor + var; |
417 | analyzer.Bind(var, Range::FromMinExtent(make_const(dtype, 0), tvm::cast(dtype, factor))); |
418 | new_loop_vars.emplace_back(std::move(var)); |
419 | } |
420 | Map<Block, Block> opaque_block_reuse; |
421 | Stmt new_stmt = loop->body; |
422 | new_stmt = SubstituteVarAndCollectOpaqueBlock( |
423 | [&](const Var& v) -> Optional<PrimExpr> { |
424 | if (v.same_as(loop->loop_var)) { |
425 | return substitute_value; |
426 | } else { |
427 | return NullOpt; |
428 | } |
429 | }, |
430 | &opaque_block_reuse)(std::move(new_stmt)); |
431 | // Step 3. Update predicate to guard the loop |
432 | PrimExpr predicate = substitute_value < loop->extent; |
433 | if (!analyzer.CanProve(predicate)) { |
434 | new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); |
435 | } |
436 | // Step 4. Generate nested loops to replace the original loop and simplify the binding |
437 | for (int i = n - 1; i >= 0; i--) { |
438 | new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); |
439 | } |
440 | new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), |
441 | opaque_block_reuse.CopyOnWrite(), |
442 | preserve_unit_iters); |
443 | self->Replace(loop_sref, new_stmt, opaque_block_reuse); |
444 | Array<StmtSRef> result_srefs; |
445 | result_srefs.reserve(n); |
446 | for (int i = 0; i < n; i++) { |
447 | result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); |
448 | const ForNode* outer_loop = TVM_TYPE_AS(new_stmt, ForNode); |
449 | new_stmt = outer_loop->body; |
450 | } |
451 | return result_srefs; |
452 | } |
453 | |
454 | StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs, bool preserve_unit_iters) { |
455 | // Invariance |
456 | // - The total repeat number has not changed for each direct child block. |
457 | // - The execution order has not changed. (The block executes with the same |
458 | // args and the same order with before.) |
459 | std::vector<const ForNode*> loops; |
460 | loops.reserve(loop_srefs.size()); |
461 | StmtSRef outer_loop_sref{nullptr}; |
462 | const ForNode* outer_loop = nullptr; |
463 | arith::Analyzer analyzer; |
464 | std::unordered_set<const VarNode*> outer_loop_vars; |
465 | // Step 1. check correctness |
466 | for (const StmtSRef& sref : loop_srefs) { |
467 | const ForNode* loop = TVM_SREF_TO_FOR(sref); |
468 | if (!loop->annotations.empty() || loop->thread_binding.defined()) { |
469 | throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop)); |
470 | } |
471 | if (outer_loop_sref.defined()) { |
472 | if (sref->parent != outer_loop_sref.get()) { |
473 | throw OuterNotInnerParent(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop)); |
474 | } |
475 | if (!outer_loop->body.same_as(GetRef<For>(loop))) { |
476 | throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop)); |
477 | } |
478 | } |
479 | outer_loop_sref = sref; |
480 | outer_loop = loop; |
481 | CheckLoopStartsWithZero(self, sref, &analyzer); |
482 | const VarNode* used_var = nullptr; |
483 | auto f_contain = [&outer_loop_vars, &used_var](const VarNode* var) { |
484 | if (outer_loop_vars.count(var)) { |
485 | used_var = var; |
486 | return true; |
487 | } |
488 | return false; |
489 | }; |
490 | if (UsesVar(loop->extent, f_contain)) { |
491 | throw DependentLoopError(self->mod, GetRef<For>(loop), used_var->name_hint, |
492 | DependentLoopError::PrimitiveKind::kFuse); |
493 | } |
494 | outer_loop_vars.insert(loop->loop_var.get()); |
495 | loops.push_back(loop); |
496 | } |
497 | // Step 2. Create fused loop var and replace the original loop vars |
498 | std::string suffix; |
499 | int n = loops.size(); |
500 | int bits = loops[0]->loop_var.dtype().bits(); |
501 | for (int i = 1; i < n; i++) { |
502 | suffix += "_" + loops[i]->loop_var->name_hint; |
503 | bits = std::max(bits, loops[i]->loop_var.dtype().bits()); |
504 | } |
505 | suffix += "_fused" ; |
506 | Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits)); |
507 | Array<PrimExpr> substitute_value; |
508 | substitute_value.resize(loops.size()); |
509 | PrimExpr lower = 1; |
510 | for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) { |
511 | substitute_value.Set(i, is_one(loops[i]->extent) |
512 | ? 0 |
513 | : floordiv(floormod(fused_var, lower * loops[i]->extent), lower)); |
514 | lower = lower * loops[i]->extent; |
515 | } |
516 | substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, lower)); |
517 | Stmt new_stmt = loops.back()->body; |
518 | Map<Block, Block> opaque_block_reuse; |
519 | auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> { |
520 | for (int i = 0; i < n; i++) { |
521 | if (v.same_as(loops[i]->loop_var)) { |
522 | return substitute_value[i]; |
523 | } |
524 | } |
525 | return NullOpt; |
526 | }; |
527 | new_stmt = |
528 | SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); |
529 | // Step 3. Generate a loop to replace the original loops |
530 | PrimExpr fused_extent = 1; |
531 | for (int i = 0; i < n; i++) { |
532 | fused_extent *= loops[i]->extent; |
533 | } |
534 | fused_extent = analyzer.Simplify(fused_extent); |
535 | new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); |
536 | new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( |
537 | std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite(), |
538 | preserve_unit_iters); |
539 | self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); |
540 | return self->stmt2ref.at(new_stmt.get()); |
541 | } |
542 | |
543 | /*! |
544 | * \brief Collect an array of loop srefs into a set |
545 | * \param self The schedule state |
546 | * \param ordered_loop_srefs The array of loop srefs |
547 | * \return A set containing all loops in the array |
548 | * \throws ScheduleError If there are duplicate loops in the array |
549 | */ |
550 | std::unordered_set<const StmtSRefNode*> CollectLoopsIntoSet( |
551 | const ScheduleState& self, const Array<StmtSRef>& ordered_loop_srefs) { |
552 | std::unordered_set<const StmtSRefNode*> loop_srefs; |
553 | loop_srefs.reserve(ordered_loop_srefs.size()); |
554 | for (const StmtSRef& loop_sref : ordered_loop_srefs) { |
555 | auto inserted = loop_srefs.insert(loop_sref.get()); |
556 | if (!inserted.second) { |
557 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
558 | throw LoopMultiAppearanceError(self->mod, GetRef<For>(loop)); |
559 | } |
560 | } |
561 | return loop_srefs; |
562 | } |
563 | |
/*!
 * \brief Get the top and bottom boundary of the reorder range (which should be a chain)
 *
 * For every loop in `loop_srefs` that has not been visited yet, walk up the parent pointers until
 * reaching either a block or an already-visited sref:
 * - During the first walk, `top` becomes the highest sref on the path that is in `loop_srefs`.
 * - Every later walk must stop exactly at the current `bottom`. The loop that started the walk
 *   then becomes the new `bottom`.
 * - Reaching a block on a later walk means the loops are not all under one scope.
 *
 * \param self The schedule state
 * \param loop_srefs The set containing the srefs to the loops to be reordered; assumed non-empty,
 *        since `*loop_srefs.begin()` is dereferenced
 * \return A pair (top, bottom) of the reorder range
 * \throws ScheduleError (LoopsNotAChainError) If the loops to be reordered are not in a chain
 */
std::pair<const StmtSRefNode*, const StmtSRefNode*> GetBoundaryOfReorderRange(
    const ScheduleState& self, const std::unordered_set<const StmtSRefNode*>& loop_srefs) {
  const StmtSRefNode* top = nullptr;
  // Initially the start of the first walk; moved down whenever a later walk joins the chain here.
  const StmtSRefNode* bottom = *loop_srefs.begin();
  std::unordered_set<const StmtSRefNode*> visited;
  bool scope_block_visited = false;
  bool first_traversal = true;
  for (const StmtSRefNode* loop_sref : loop_srefs) {
    // Already on the path of an earlier walk: nothing new to learn from it.
    if (visited.count(loop_sref)) {
      continue;
    }
    for (const StmtSRefNode* v = loop_sref;; v = v->parent) {
      // Case 1. If `v` corresponds to a block, stop traversal.
      // Only one walk (the first one) may end at a block; a second one means another scope.
      if (v->stmt->IsInstance<BlockNode>()) {
        if (scope_block_visited) {
          throw LoopsNotAChainError(self->mod, NullOpt,
                                    LoopsNotAChainError::ProblemKind::kNotUnderAScope);
        }
        scope_block_visited = true;
        break;
      }
      // Case 2. If `v` corresponds to a previously-visited loop, stop traversal and update
      // `bottom`.
      // Joining the known path anywhere other than `bottom` means the paths branch.
      if (visited.count(v)) {
        if (v != bottom) {
          throw LoopsNotAChainError(self->mod, GetRef<Stmt>(v->stmt),
                                    LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt);
        }
        bottom = loop_sref;
        break;
      }
      // Case 3. Add `v` into `visited`
      visited.insert(v);
      // If it's the first traversal and the loop corresponding to `v` is in the input array,
      // update `top`.
      if (first_traversal && loop_srefs.count(v)) {
        top = v;
      }
    }
    first_traversal = false;
  }
  return std::make_pair(top, bottom);
}
614 | |
615 | /*! |
616 | * \brief Get all the loops in the reorder range |
617 | * \param self The schedule state |
618 | * \param top The top boundary of the reorder range |
619 | * \param bottom The bottom boundary of the reorder range |
620 | * \return An array containing all the loops in the reorder range |
621 | * \throws ScheduleError If some loop in the reorder range is not single-branch |
622 | */ |
623 | std::vector<const StmtSRefNode*> GetLoopsInReorderRange(const ScheduleState& self, |
624 | const StmtSRefNode* top, |
625 | const StmtSRefNode* bottom) { |
626 | std::vector<const StmtSRefNode*> chain; |
627 | for (const StmtSRefNode* loop_sref = bottom; loop_sref != top;) { |
628 | const StmtSRefNode* parent_loop_sref = loop_sref->parent; |
629 | const ForNode* outer = parent_loop_sref->StmtAs<ForNode>(); |
630 | const ForNode* inner = loop_sref->StmtAs<ForNode>(); |
631 | ICHECK(outer != nullptr && inner != nullptr); |
632 | if (outer->body.get() != inner) { |
633 | throw LoopsNotAChainError(self->mod, GetRef<For>(outer), |
634 | LoopsNotAChainError::ProblemKind::kHaveNonSingleBranchStmt); |
635 | } |
636 | chain.push_back(loop_sref); |
637 | loop_sref = parent_loop_sref; |
638 | } |
639 | chain.push_back(top); |
640 | return chain; |
641 | } |
642 | |
643 | /*! |
644 | * \brief Construct a loop chain in the new order |
645 | * \param self The schedule state |
646 | * \param chain The loops in the reorder range |
647 | * \param ordered_loop_srefs The loop srefs to be reordered |
648 | * \param loop_srefs The set containing loop srefs to be reordered |
649 | * \return The new loop chain |
650 | * \throws ScheduleError If the domain of an outer loop depends on any of the inner loops after |
651 | * reordering |
652 | */ |
653 | For ConstructNewLoopChain(const ScheduleState& self, std::vector<const StmtSRefNode*> chain, |
654 | const Array<StmtSRef>& ordered_loop_srefs, |
655 | const std::unordered_set<const StmtSRefNode*>& loop_srefs) { |
656 | std::unordered_set<const VarNode*> inner_vars; |
657 | inner_vars.reserve(chain.size()); |
658 | For new_loop{nullptr}; |
659 | int index = static_cast<int>(ordered_loop_srefs.size()) - 1; |
660 | for (const StmtSRefNode* loop_sref : chain) { |
661 | const ForNode* copy = nullptr; |
662 | if (loop_srefs.count(loop_sref)) { |
663 | copy = ordered_loop_srefs[index]->StmtAs<ForNode>(); |
664 | --index; |
665 | } else { |
666 | copy = loop_sref->StmtAs<ForNode>(); |
667 | } |
668 | ICHECK(copy != nullptr); |
669 | ObjectPtr<ForNode> n = make_object<ForNode>(*copy); |
670 | if (new_loop.defined()) { |
671 | n->body = new_loop; |
672 | } else { |
673 | n->body = loop_sref->StmtAs<ForNode>()->body; |
674 | } |
675 | const VarNode* used_var = nullptr; |
676 | auto f_contain = [&inner_vars, &used_var](const VarNode* var) { |
677 | if (inner_vars.count(var)) { |
678 | used_var = var; |
679 | return true; |
680 | } |
681 | return false; |
682 | }; |
683 | if (UsesVar(copy->min, f_contain) || UsesVar(copy->extent, f_contain)) { |
684 | throw DependentLoopError(self->mod, GetRef<For>(copy), used_var->name_hint, |
685 | DependentLoopError::PrimitiveKind::kReorder); |
686 | } |
687 | inner_vars.insert(copy->loop_var.get()); |
688 | new_loop = For(std::move(n)); |
689 | } |
690 | return new_loop; |
691 | } |
692 | |
693 | void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) { |
694 | if (ordered_loop_srefs.size() <= 1) { |
695 | return; |
696 | } |
697 | // Step 1. Check uniqueness and collect the input loop srefs into a set |
698 | std::unordered_set<const StmtSRefNode*> loop_srefs = |
699 | CollectLoopsIntoSet(self, ordered_loop_srefs); |
700 | // Step 2. Gather loops to be reordered |
701 | // For each loop sref in the input sref array, traverse upwards along its parent pointer in the |
702 | // sref tree, and stop on either a block, or a previously-visited loop |
703 | // - the top of the reorder range is the last loop visited in the first traversal which exists in |
704 | // the input array |
705 | // - the bottom of the reorder range is the last loop in the input array which is not visited in |
706 | // the previous traversals |
707 | auto [top, bottom] = GetBoundaryOfReorderRange(self, loop_srefs); |
708 | // Step 3. Collect all loops in the chain and check the loops are single-branch |
709 | std::vector<const StmtSRefNode*> chain = GetLoopsInReorderRange(self, top, bottom); |
710 | // Step 4. Check the block below has all its block_var to be data-parallel or reduction, |
711 | // and the block has an affine binding wrt top of the loop range. |
712 | BlockPropertyError::CheckBlockIterTypeAndAffineBinding(self, top, bottom); |
713 | // Step 5. Replace the original loops with the reordered loops and check that outer loop is |
714 | // not dependent on inner loop |
715 | For new_loop = ConstructNewLoopChain(self, std::move(chain), ordered_loop_srefs, loop_srefs); |
716 | self->Replace(GetRef<StmtSRef>(top), new_loop, {}); |
717 | } |
718 | |
719 | StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) { |
720 | if (sref->stmt->IsInstance<ForNode>()) { |
721 | For new_loop(Var("u" , DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt)); |
722 | self->Replace(sref, new_loop, {}); |
723 | return self->stmt2ref.at(new_loop.get()); |
724 | } |
725 | class NewLoopCreator : public StmtMutator { |
726 | public: |
727 | explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {} |
728 | |
729 | Stmt VisitStmt_(const BlockRealizeNode* realize) final { |
730 | if (realize->block.get() == src_block_) { |
731 | new_loop_ = |
732 | For(Var("u" , DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<BlockRealize>(realize)); |
733 | return new_loop_; |
734 | } |
735 | return StmtMutator::VisitStmt_(realize); |
736 | } |
737 | |
738 | const StmtNode* src_block_; |
739 | For new_loop_{nullptr}; |
740 | }; |
741 | |
742 | CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block" ; |
743 | StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent); |
744 | NewLoopCreator creator(sref->stmt); |
745 | Stmt new_stmt = creator(GetRef<Stmt>(parent_sref->stmt)); |
746 | if (new_stmt->IsInstance<ForNode>()) { |
747 | self->Replace(parent_sref, std::move(new_stmt), {}); |
748 | } else { |
749 | Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>()); |
750 | Block new_parent_block = Downcast<Block>(new_stmt); |
751 | self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}}); |
752 | } |
753 | return self->stmt2ref.at(creator.new_loop_.get()); |
754 | } |
755 | |
756 | /******** InstructionKind Registration ********/ |
757 | |
758 | struct SplitTraits : public UnpackedInstTraits<SplitTraits> { |
759 | static constexpr const char* kName = "Split" ; |
760 | static constexpr bool kIsPure = false; |
761 | |
762 | private: |
763 | static constexpr size_t kNumInputs = 2; |
764 | static constexpr size_t kNumAttrs = 1; |
765 | static constexpr size_t kNumDecisions = 0; |
766 | |
767 | template <size_t delta> |
768 | static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, |
769 | const Array<ObjectRef>& inputs) { |
770 | thread_local ObjectRef loop_rv{nullptr}; |
771 | thread_local Array<ObjectRef> factors{nullptr}; |
772 | loop_rv = inputs[0]; |
773 | factors = Array<ObjectRef>{inputs.begin() + 1, inputs.end()}; |
774 | setter(delta, loop_rv); |
775 | setter(delta + 1, factors); |
776 | } |
777 | |
778 | static Array<LoopRV> UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, |
779 | Array<Optional<ExprRV>> factors, |
780 | Bool preserve_unit_iters) { |
781 | return sch->Split(loop_rv, factors, preserve_unit_iters.operator bool()); |
782 | } |
783 | |
784 | static String UnpackedAsPython(Array<String> outputs, String loop_rv, Array<ObjectRef> factors, |
785 | Bool preserve_unit_iters) { |
786 | PythonAPICall py("split" ); |
787 | py.Input("loop" , loop_rv); |
788 | py.Input("factors" , factors); |
789 | py.Input("preserve_unit_iters" , preserve_unit_iters.operator bool()); |
790 | py.OutputList(outputs); |
791 | return py.Str(); |
792 | } |
793 | |
794 | template <typename> |
795 | friend struct ::tvm::tir::UnpackedInstTraits; |
796 | }; |
797 | |
798 | struct FuseTraits : public UnpackedInstTraits<FuseTraits> { |
799 | static constexpr const char* kName = "Fuse" ; |
800 | static constexpr bool kIsPure = false; |
801 | |
802 | private: |
803 | static constexpr size_t kNumInputs = 1; |
804 | static constexpr size_t kNumAttrs = 1; |
805 | static constexpr size_t kNumDecisions = 0; |
806 | |
807 | template <size_t delta> |
808 | static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, |
809 | const Array<ObjectRef>& inputs) { |
810 | setter(delta, inputs); |
811 | } |
812 | |
813 | static LoopRV UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs, |
814 | Bool preserve_unit_iters) { |
815 | return sch->Fuse(loop_rvs, preserve_unit_iters.operator bool()); |
816 | } |
817 | |
818 | static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs, |
819 | Bool preserve_unit_iters) { |
820 | PythonAPICall py("fuse" ); |
821 | for (const String& loop_rv : loop_rvs) { |
822 | py.Input("" , loop_rv); |
823 | } |
824 | py.Input("preserve_unit_iters" , preserve_unit_iters.operator bool()); |
825 | py.SingleOutput(outputs); |
826 | return py.Str(); |
827 | } |
828 | |
829 | template <typename> |
830 | friend struct ::tvm::tir::UnpackedInstTraits; |
831 | }; |
832 | |
833 | struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> { |
834 | static constexpr const char* kName = "Reorder" ; |
835 | static constexpr bool kIsPure = false; |
836 | |
837 | private: |
838 | static constexpr size_t kNumInputs = 1; |
839 | static constexpr size_t kNumAttrs = 0; |
840 | static constexpr size_t kNumDecisions = 0; |
841 | |
842 | template <size_t delta> |
843 | static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter, |
844 | const Array<ObjectRef>& inputs) { |
845 | setter(delta, inputs); |
846 | } |
847 | |
848 | static void UnpackedApplyToSchedule(Schedule sch, Array<LoopRV> loop_rvs) { |
849 | return sch->Reorder(loop_rvs); |
850 | } |
851 | |
852 | static String UnpackedAsPython(Array<String> outputs, Array<String> loop_rvs) { |
853 | PythonAPICall py("reorder" ); |
854 | for (const String& loop_rv : loop_rvs) { |
855 | py.Input("" , loop_rv); |
856 | } |
857 | return py.Str(); |
858 | } |
859 | |
860 | template <typename> |
861 | friend struct ::tvm::tir::UnpackedInstTraits; |
862 | }; |
863 | |
864 | struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> { |
865 | static constexpr const char* kName = "AddUnitLoop" ; |
866 | static constexpr bool kIsPure = false; |
867 | |
868 | private: |
869 | static constexpr size_t kNumInputs = 1; |
870 | static constexpr size_t kNumAttrs = 0; |
871 | static constexpr size_t kNumDecisions = 0; |
872 | |
873 | static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) { |
874 | if (const auto* block = rv.as<BlockRVNode>()) { |
875 | return sch->AddUnitLoop(GetRef<BlockRV>(block)); |
876 | } else if (const auto* loop = rv.as<LoopRVNode>()) { |
877 | return sch->AddUnitLoop(GetRef<LoopRV>(loop)); |
878 | } else { |
879 | LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block" ; |
880 | throw; |
881 | } |
882 | } |
883 | |
884 | static String UnpackedAsPython(Array<String> outputs, String rv) { |
885 | PythonAPICall py("add_unit_loop" ); |
886 | py.Input("block_or_loop" , rv); |
887 | py.SingleOutput(outputs); |
888 | return py.Str(); |
889 | } |
890 | |
891 | template <typename> |
892 | friend struct ::tvm::tir::UnpackedInstTraits; |
893 | }; |
894 | |
// Register the instruction-kind traits defined above (Split, Fuse, Reorder, AddUnitLoop).
// Each macro expands to a file-scope registration. Keep them as separate statements and do not
// reorder them casually: registration order may be observable in the registry.
TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);
899 | |
900 | } // namespace tir |
901 | } // namespace tvm |
902 | |