1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/tir/data_type_rewriter.h> |
20 | |
21 | #include <functional> |
22 | |
23 | #include "../ir_comparator.h" |
24 | #include "../utils.h" |
25 | |
26 | namespace tvm { |
27 | namespace tir { |
28 | |
29 | template <class T> |
30 | bool UsesVar(const T& x, const Var& var) { |
31 | return UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; }); |
32 | } |
33 | |
34 | Range RangeFromExtent(const PrimExpr& extent) { |
35 | return Range::FromMinExtent(make_zero(extent->dtype), extent); |
36 | } |
37 | |
38 | template <class T> |
39 | T DeepCopy(const T& stmt) { |
40 | return Downcast<T>(LoadJSON(SaveJSON(stmt))); |
41 | } |
42 | |
43 | /*! |
44 | * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace |
45 | * represented by the outer loops. |
46 | */ |
47 | class SubspaceNotDivisibleError : public ScheduleError { |
48 | public: |
49 | explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block) |
50 | : mod_(std::move(mod)), |
51 | scope_loop_(std::move(scope_loop)), |
52 | inner_block_(std::move(inner_block)) {} |
53 | |
54 | String FastErrorString() const final { |
55 | return "ScheduleError: The bindings of the inner block can not be blockized." ; |
56 | } |
57 | |
58 | String DetailRenderTemplate() const final { |
59 | return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops " |
60 | "starting at {1}." ; |
61 | } |
62 | |
63 | IRModule mod() const final { return mod_; } |
64 | |
65 | Array<ObjectRef> LocationsOfInterest() const final { return {inner_block_, scope_loop_}; } |
66 | |
67 | private: |
68 | IRModule mod_; |
69 | For scope_loop_; |
70 | Block inner_block_; |
71 | }; |
72 | |
73 | /*! |
74 | * \brief Detect if bindings are a trivial case of the subspace division where we can divide the |
75 | * block iter bindings into two categories: |
76 | * 1. The binding covers no inner loop vars. |
77 | * 2. The binding covers only inner loop vars. |
78 | * |
79 | * The bindings are not required to be quasi-affine. Trivial block iters are always preserved. |
80 | * |
81 | * \param iter_vars The input iterators |
82 | * \param bindings The values of iter_vars |
83 | * \param predicate The predicate constraint on the input iterators. |
84 | * \param outer_iters The iters of the outer space |
85 | * \param inner_iters The iters of the inner space |
86 | * \return The result of the subspace division. |
87 | */ |
88 | Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter_vars, |
89 | const Array<PrimExpr>& bindings, |
90 | const PrimExpr& predicate, |
91 | const Array<Var>& outer_iters, |
92 | const Array<Var>& inner_iters) { |
93 | if (!is_one(predicate)) return {}; |
94 | Array<Array<arith::IterMark>> res; |
95 | std::unordered_set<const VarNode*> outer_loop_vars; |
96 | std::unordered_set<const VarNode*> inner_loop_vars; |
97 | |
98 | auto make_uses_var = [](const Array<Var>& vars) -> std::function<bool(const PrimExpr& expr)> { |
99 | std::unordered_set<const VarNode*> var_set; |
100 | var_set.reserve(vars.size()); |
101 | for (const Var& var : vars) { |
102 | var_set.insert(var.get()); |
103 | } |
104 | return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool { |
105 | return UsesVar(expr, [&var_set](const VarNode* var) { |
106 | return var_set.count(var); // |
107 | }); |
108 | }; |
109 | }; |
110 | auto use_outer_loop_vars = make_uses_var(outer_iters); |
111 | auto use_inner_loop_vars = make_uses_var(inner_iters); |
112 | arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); |
113 | |
114 | for (int i = 0, n = bindings.size(); i < n; ++i) { |
115 | bool outer = use_outer_loop_vars(bindings[i]); |
116 | bool inner = use_inner_loop_vars(bindings[i]); |
117 | arith::IterMark iter_mark; |
118 | if (bindings[i]->IsInstance<VarNode>()) { |
119 | iter_mark = arith::IterMark( |
120 | arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)), |
121 | iter_vars[i]->dom->extent); |
122 | } else { |
123 | iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent); |
124 | } |
125 | if (outer && !inner) { |
126 | res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark}); |
127 | } else if (inner && !outer) { |
128 | res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark}); |
129 | } else if (!outer && !inner) { |
130 | res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/unit_iter_mark}); |
131 | } else { |
132 | return {}; |
133 | } |
134 | } |
135 | res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)), |
136 | arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))}); |
137 | return res; |
138 | } |
139 | |
/*!
 * \brief Subspace division. The space is divided into two subspaces:
 * 1. The subspace represented by the outer loops above `loop_sref` (exclusive).
 * 2. The subspace represented by the inner loops below `loop_sref` (inclusive).
 * \param realize The inner block
 * \param block_sref The sref to the inner block
 * \param loop_sref The loop that is the root of the second subspace.
 * \param loops Output: the loops of the second (inner) subspace, collected innermost-first.
 * \param analyzer The arithmetic analyzer to use.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 * \return The affine subspace division; falls back to TrivialSubspaceDivision when the affine
 *         division fails, and returns an empty array if that fails too.
 */
Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
                                             const StmtSRef& block_sref,  //
                                             const StmtSRef& loop_sref,   //
                                             std::vector<const ForNode*>* loops,
                                             arith::Analyzer* analyzer, bool preserve_unit_iters) {
  Array<Var> inner_vars;
  Array<Var> outer_vars;
  Map<Var, Range> loop_var_domain;
  bool inner = true;
  // Walk upward from the block's parent through the chain of enclosing loops.
  // Loops encountered up to and including `loop_sref` are inner; loops above it are outer.
  for (StmtSRefNode* sref = block_sref->parent;    //
       sref && sref->stmt->IsInstance<ForNode>();  //
       sref = sref->parent) {
    const ForNode* loop = static_cast<const ForNode*>(sref->stmt);
    if (inner) {
      // Note: because we walk from the block upward, `loops` ends up innermost-first.
      loops->push_back(loop);
      inner_vars.push_back(loop->loop_var);
    } else {
      outer_vars.push_back(loop->loop_var);
    }
    loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    // Flip to "outer" only after visiting `loop_sref`, so `loop_sref` itself counts as inner.
    if (sref == loop_sref.get()) {
      inner = false;
    }
  }
  Array<Array<arith::IterMark>> result =
      arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate,
                            arith::IterMapLevel::Surjective, analyzer,
                            /*simplify_trivial_iterators=*/!preserve_unit_iters);
  if (!result.empty()) {
    return result;
  }
  // The affine division failed; try the trivial (non-affine) division instead.
  return TrivialSubspaceDivision(realize->block->iter_vars,
                                 realize->iter_values,  //
                                 realize->predicate,    //
                                 outer_vars, inner_vars);
}
187 | |
/*!
 * \brief Derive the block bindings for both inner and outer block
 * \param iter_vars The original block iterators to the inner block
 * \param division The subspace division.
 * \param outer_iter_vars Output: the outer block iterators.
 * \param outer_bindings Output: the outer block bindings.
 * \param inner_iter_vars Output: the inner block iterators.
 * \param inner_bindings Output: the inner block bindings.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
 * \return A substitution plan to the iterators in the original inner block.
 * \note NOTE(review): `preserve_unit_iters` is not referenced in this function body; unit-iter
 *       simplification appears to be handled earlier in SubspaceDivide — confirm before removal.
 */
Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars,                //
                                      const Array<Array<arith::IterMark>>& division,  //
                                      Array<IterVar>* outer_iter_vars,                //
                                      Array<PrimExpr>* outer_bindings,                //
                                      Array<IterVar>* inner_iter_vars,                //
                                      Array<PrimExpr>* inner_bindings, bool preserve_unit_iters) {
  using arith::IterMapExpr;
  using arith::IterMapExprNode;
  using arith::NormalizeIterMapToExpr;
  Map<Var, PrimExpr> block_var_subst;
  // `division` carries one extra trailing entry holding the outer/inner predicates.
  ICHECK_EQ(iter_vars.size() + 1, division.size());
  for (int i = 0, n = iter_vars.size(); i < n; ++i) {
    const IterVar& iter_var = iter_vars[i];
    arith::IterMark outer_mark = division[i][0];
    arith::IterMark inner_mark = division[i][1];
    IterMapExpr outer_binding = Downcast<IterMapExpr>(outer_mark->source);
    IterMapExpr inner_binding = Downcast<IterMapExpr>(inner_mark->source);
    // After computing the subspace division, bindings[i] can be written as
    // outer_binding * inner_binding->extent + inner_binding
    // The outer block will have binding: iter_outer -> outer_binding
    // The inner block will have binding: iter_inner -> inner_binding
    // The iter in the original block will be substituted with base + iter_inner where
    // base == iter_outer * iter_inner_extent
    if (is_one(inner_mark->extent)) {  // IsOuter
      // extract this iter var to outer block directly
      outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
      outer_iter_vars->push_back(iter_var);
      continue;
    }
    // create iter var for the outer block
    IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
                       /*var=*/iter_var->var.copy_with_suffix("_o"),
                       /*iter_type=*/iter_var->iter_type);
    outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
    outer_iter_vars->push_back(outer_iter);
    // create iter var for the inner block
    IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
                       /*var=*/iter_var->var.copy_with_suffix("_i"),
                       /*iter_type=*/iter_var->iter_type);
    inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
    inner_iter_vars->push_back(inner_iter);
    // substitution
    PrimExpr sub{nullptr};
    if (is_one(outer_mark->extent)) {
      // The outer part is a singleton, so the original iter equals the inner iter alone.
      sub = inner_iter->var;
    } else {
      sub = outer_iter * inner_mark->extent + inner_iter->var;
    }
    block_var_subst.Set(iter_var->var, sub);
  }
  return block_var_subst;
}
251 | |
252 | /*! |
253 | * \brief Generate the inner block for blockization |
254 | * \param is_write_reduction Whether the write regions of the inner block are actually reduction. |
255 | * \param iter_vars IterVars used in the inner block. |
256 | * \param iter_values IterVar bindings used in the inner block. |
257 | * \param predicate The predicate of the inner block. |
258 | * \param block The inner block as a template to be created from. This method will modify its |
259 | * `iter_vars`, `init` and `reads` fields. |
260 | * \return The inner block created. |
261 | */ |
262 | BlockRealize GenerateInner(bool is_write_reduction, |
263 | const Array<IterVar>& iter_vars, // |
264 | const Array<PrimExpr>& iter_values, // |
265 | const PrimExpr& predicate, // |
266 | Block block) { |
267 | BlockNode* n = block.CopyOnWrite(); |
268 | n->iter_vars = iter_vars; |
269 | n->init = NullOpt; |
270 | if (is_write_reduction) { |
271 | Array<BufferRegion> reads; |
272 | reads.reserve(block->writes.size() + block->reads.size()); |
273 | reads.insert(reads.end(), block->writes.begin(), block->writes.end()); |
274 | reads.insert(reads.end(), block->reads.begin(), block->reads.end()); |
275 | n->reads = std::move(reads); |
276 | } |
277 | return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate, |
278 | /*block=*/block); |
279 | } |
280 | |
/*!
 * \brief Generate the init stmt for the outer block
 * \param block_init The init stmt of the original block.
 * \param inner_realize The block realize of the inner block after blockize.
 * \param loops The inner loops after blockize (innermost-first, as collected by SubspaceDivide).
 * \param block_name The name hint for the generated init block.
 * \return The subtree of the init block and its outer loops.
 */
Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize,
                       const std::vector<const ForNode*>& loops, String block_name) {
  const Block& inner_block = inner_realize->block;
  Map<Var, PrimExpr> subst_map;
  // Step 1: Create new block vars for the block inside the init stmt of outer block
  // A iter is used in the block if
  // 1) It is data parallel
  // 2) It is used in the original init block
  Array<IterVar> iter_vars;
  Array<PrimExpr> iter_values;
  ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size());
  int n = inner_block->iter_vars.size();
  iter_vars.reserve(n);
  iter_values.reserve(n);
  for (int i = 0; i < n; ++i) {
    const IterVar& old_iter_var = inner_block->iter_vars[i];
    const PrimExpr& iter_value = inner_realize->iter_values[i];
    if (old_iter_var->iter_type == IterVarType::kDataPar &&
        UsesVar(block_init, old_iter_var->var)) {
      // Clone the iter var under a "_init" name and record old-var -> new-var for Step 4.
      ObjectPtr<IterVarNode> new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
      new_iter_var->var = new_iter_var->var.copy_with_suffix("_init");
      subst_map.Set(old_iter_var->var, new_iter_var->var);
      iter_vars.push_back(IterVar(new_iter_var));
      iter_values.push_back(iter_value);
    }
  }
  // Step 2: Generate the block inside init stmt of outer block
  Stmt stmt = BlockRealize(
      /*iter_values=*/iter_values,
      /*predicate=*/inner_realize->predicate,
      /*block=*/
      Block(/*iter_vars=*/iter_vars,
            /*reads=*/{},
            /*writes=*/inner_block->writes,
            /*name_hint=*/block_name,
            /*body=*/block_init,
            /*init=*/NullOpt));
  // Step 3. Create the loop nest on top of the block
  // Only loops whose loop var actually appears in the init bindings are recreated.
  for (const ForNode* loop : loops) {
    bool is_init_loop = false;
    for (const PrimExpr& init_binding : iter_values) {
      if (UsesVar(init_binding, loop->loop_var)) {
        is_init_loop = true;
        break;
      }
    }
    if (is_init_loop) {
      // Copy the loop with a fresh loop var (same name) so the init nest does not share
      // vars with the main body; record the substitution for Step 4.
      ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
      new_loop->loop_var = loop->loop_var.copy_with_suffix("");
      new_loop->body = std::move(stmt);
      subst_map.Set(loop->loop_var, new_loop->loop_var);
      stmt = For(new_loop);
    }
  }
  // Step 4: Substitute the iter vars and loop vars
  return Substitute(stmt, subst_map);
}
345 | |
/*!
 * \brief Substitute variables in the stmt, do simplification and track block substitution
 * \param stmt The stmt to be substituted.
 * \param sub The substitution map.
 * \param block_sref_reuse Output map that records every block rewritten during the substitution,
 *        mapping the old block to its rewritten counterpart (used for sref reuse).
 * \param analyzer The analyzer for arithmetic simplification.
 * \return The substituted stmt.
 */
Stmt Substitute(const Stmt& stmt, const Map<Var, PrimExpr>& sub,
                Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
  struct Replacer : public StmtExprMutator {
    explicit Replacer(const Map<Var, PrimExpr>& sub, Map<Block, Block>* block_sref_reuse,
                      arith::Analyzer* analyzer)
        : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {}

    PrimExpr VisitExpr(const PrimExpr& op) final {
      PrimExpr result = StmtExprMutator::VisitExpr(op);
      // Simplify only expressions that actually changed, leaving untouched ones intact.
      if (!result.same_as(op)) {
        return analyzer_->Simplify(result);
      }
      return result;
    }

    PrimExpr VisitExpr_(const VarNode* op) final {
      // Replace the variable if the substitution map provides a value for it.
      if (Optional<PrimExpr> e = sub_.Get(GetRef<Var>(op))) {
        return e.value();
      }
      return StmtExprMutator::VisitExpr_(op);
    }

    Stmt VisitStmt_(const BlockNode* op) final {
      Block src = GetRef<Block>(op);
      Block tgt = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
      // Record the old -> new block mapping whenever the mutation produced a new block.
      if (!src.same_as(tgt)) {
        block_sref_reuse_->Set(src, tgt);
      }
      return std::move(tgt);
    }

    const Map<Var, PrimExpr>& sub_;
    Map<Block, Block>* block_sref_reuse_;
    arith::Analyzer* analyzer_;
  };
  return Replacer(sub, block_sref_reuse, analyzer)(stmt);
}
391 | |
392 | /*! |
393 | * \brief Relax the variables for the given regions |
394 | * \param regions The regions to be relaxed. |
395 | * \param dom_map The variables to be relaxed |
396 | * \return The relaxed regions |
397 | */ |
398 | Array<BufferRegion> EvalSetRegions(const Array<BufferRegion>& regions, |
399 | const Map<Var, arith::IntSet>& dom_map) { |
400 | Array<BufferRegion> results; |
401 | results.reserve(regions.size()); |
402 | for (const BufferRegion& buffer_region : regions) { |
403 | const Buffer& buffer = buffer_region->buffer; |
404 | Array<arith::IntSet> relaxed = arith::EvalSet(buffer_region->region, dom_map); |
405 | ICHECK_EQ(relaxed.size(), buffer->shape.size()); |
406 | int ndim = buffer->shape.size(); |
407 | Array<Range> new_region; |
408 | new_region.reserve(ndim); |
409 | for (int i = 0; i < ndim; ++i) { |
410 | new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i]))); |
411 | } |
412 | results.push_back(BufferRegion(buffer, new_region)); |
413 | } |
414 | return results; |
415 | } |
416 | |
417 | /*! |
418 | * \brief Create the loop nest on top of the given stmt. |
419 | * \param stmt The stmt to be wrapped. |
420 | * \param loops The loop nests |
421 | * \return The wrapped stmt. |
422 | */ |
423 | Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) { |
424 | for (const ForNode* loop : loops) { |
425 | ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop); |
426 | new_loop->body = std::move(stmt); |
427 | stmt = For(new_loop); |
428 | } |
429 | return stmt; |
430 | } |
431 | |
/*!
 * \brief The internal implementation of blockize: wrap the subtree rooted at `loop_sref` in a
 *        new outer block whose body contains the (rewritten) original block.
 * \param self The schedule state.
 * \param loop_sref The loop to be blockized.
 * \param block_sref_reuse Output map recording block replacements for sref reuse.
 * \param analyzer The arithmetic analyzer to use.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
 * \return The BlockRealize of the newly created outer block.
 * \throws SubspaceNotDivisibleError if the inner block's bindings cannot be divided.
 */
BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
                          Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer,
                          bool preserve_unit_iters) {
  TVM_SREF_TO_FOR(loop_sref);
  // Step 1: Check and get the only block under `loop`.
  BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
  Block block = block_realize->block;
  StmtSRef block_sref = self->stmt2ref.at(block.get());
  // Step 2: Derive subspace division
  std::vector<const ForNode*> loops;
  Array<Array<arith::IterMark>> division =
      SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer, preserve_unit_iters);
  if (division.empty()) {
    throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), block);
  }
  // The trailing entry of the division carries the outer/inner predicates.
  PrimExpr outer_predicate = division.back()[0]->extent;
  PrimExpr inner_predicate = division.back()[1]->extent;
  // Step 3. Derive block bindings for both outer and inner block.
  Array<IterVar> outer_iter_vars;
  Array<IterVar> inner_iter_vars;
  Array<PrimExpr> outer_bindings;
  Array<PrimExpr> inner_bindings;
  Map<Var, PrimExpr> block_var_subst =                       //
      DeriveBlockBinding(block->iter_vars, division,         //
                         &outer_iter_vars, &outer_bindings,  //
                         &inner_iter_vars, &inner_bindings,  //
                         preserve_unit_iters);
  // Step 4: Do var substitution to adjust to the new block bindings
  // Also collect the inner iter domains, used later to relax the outer block's regions,
  // and bind them in the analyzer so substituted expressions simplify against them.
  Map<Var, arith::IntSet> inner_iter_dom;
  for (const IterVar& iter : inner_iter_vars) {
    inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom));
    analyzer->Bind(iter->var, iter->dom);
  }
  Block block_subst =
      Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse, analyzer));
  // Step 5: Generate the inner block. The write regions of the inner blocks will be reduction if
  // 1. The original block has init stmt.
  // 2. There are outer reduction iter vars.
  bool has_outer_reduction = false;
  if (block_subst->init.defined()) {
    for (const IterVar& iter_var : outer_iter_vars) {
      if (iter_var->iter_type == kCommReduce) {
        has_outer_reduction = true;
        break;
      }
    }
  }
  BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction,
                                             /*iter_vars=*/inner_iter_vars,
                                             /*iter_values*/ inner_bindings,
                                             /*predicate=*/inner_predicate,
                                             /*block=*/block_subst);
  block_sref_reuse->Set(block, inner_realize->block);
  // Step 6: Generate the outer block.
  // Its read/write regions are the substituted block's regions relaxed over the inner domains.
  return BlockRealize(
      /*iter_values=*/std::move(outer_bindings),
      /*predicate=*/std::move(outer_predicate),
      /*block=*/
      Block(/*iter_vars=*/std::move(outer_iter_vars),
            /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom),
            /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom),
            /*name_hint=*/block_subst->name_hint + "_o",
            /*body=*/MakeLoopNest(inner_realize, loops),
            /*init=*/
            block_subst->init.defined()  //
                ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops,
                                    block_subst->name_hint + "_init")
                : Optional<Stmt>(NullOpt)));
}
501 | |
/*!
 * \brief Schedule primitive entry: blockize the subtree rooted at `loop_sref` and commit the
 *        result into the schedule state.
 * \param self The schedule state.
 * \param loop_sref The loop to be blockized.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
 * \return The sref of the newly created outer block.
 */
StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) {
  arith::Analyzer analyzer;
  Map<Block, Block> block_sref_reuse;
  BlockRealize blockized =
      BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters);
  self->Replace(loop_sref, blockized, block_sref_reuse);
  StmtSRef result = self->stmt2ref.at(blockized->block.get());
  StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
  // Save the scope block's affine-binding flag and restore it after the info refresh below,
  // since UpdateScopeBlockInfo recomputes block info for the whole scope.
  bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
  self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
  self->block_info[scope_root].affine_binding = scope_block_affine_binding;
  return result;
}
515 | |
/*!
 * \brief Schedule primitive entry: tensorize the computation rooted at `sref` with the given
 *        tensor intrinsic, replacing the block body with the intrinsic's implementation.
 * \param self The schedule state.
 * \param sref The sref of the block or loop to be tensorized.
 * \param intrin The tensor intrinsic, holding a description PrimFunc and an impl PrimFunc.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
 */
void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin,
               bool preserve_unit_iters) {
  // Step 1: Blockize the subtree rooted at the given loop if needed
  BlockRealize block_realize{nullptr};
  Optional<Block> old_block = NullOpt;
  if (sref->stmt->IsInstance<BlockNode>()) {
    block_realize = GetBlockRealize(self, sref);
    old_block = block_realize->block;
  } else if (sref->stmt->IsInstance<ForNode>()) {
    arith::Analyzer analyzer;
    Map<Block, Block> block_sref_reuse;
    block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters);
  } else {
    LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: "
               << GetRef<Stmt>(sref->stmt);
    throw;  // unreachable after LOG(FATAL); kept to satisfy control-flow analysis
  }
  PrimFunc intrin_desc = intrin->desc;
  // Deep-copy the impl so the rewrites below do not mutate the registered intrinsic.
  PrimFunc intrin_impl = DeepCopy(intrin->impl);

  // Normalize the impl's index dtype to the widest index dtype found in the read/write
  // regions of the block being tensorized.
  int index_dtype_bits = -1;
  auto f_update_max_dtype_bits_from_region = [&](const Array<BufferRegion>& buffer_regions) {
    for (const BufferRegion& buffer_region : buffer_regions) {
      for (const auto& range : buffer_region->region) {
        index_dtype_bits = std::max(index_dtype_bits, range->min.dtype().bits());
      }
    }
  };
  f_update_max_dtype_bits_from_region(block_realize->block->reads);
  f_update_max_dtype_bits_from_region(block_realize->block->writes);
  ICHECK(index_dtype_bits > 0);
  intrin_impl = IndexDataTypeNormalizer(DataType::Int(index_dtype_bits)).Rewrite(intrin_impl);
  // Step 2: Structural pattern matching
  TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
  comparator.VisitStmt(block_realize, intrin_desc->body);
  // Step 3: Prepare necessary mapping
  // 1) Buffer mapping from intrin impl buffers to intrin desc buffers.
  // 2) Buffer mapping from intrin impl buffers to buffers in the current AST.
  // 3) Mapping impl buffers to their accessed regions.
  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2desc;
  ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size());
  // Params of desc and impl correspond positionally.
  for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) {
    const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]];
    const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]];
    impl2desc[impl] = desc;
  }
  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2cur;
  for (const auto& pair : impl2desc) {
    const Buffer& impl = pair.first;
    const Buffer& desc = pair.second;
    // The comparator has matched each desc buffer to a buffer in the current AST.
    ICHECK(comparator.rhs_buffer_map_.count(desc));
    impl2cur[impl] = comparator.rhs_buffer_map_[desc];
  }
  std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> impl2region;
  Block impl_block = Downcast<BlockRealize>(intrin_impl->body)->block;
  for (const BufferRegion& read : impl_block->reads) {
    impl2region.emplace(read->buffer, read->region);
  }
  for (const BufferRegion& write : impl_block->writes) {
    impl2region.emplace(write->buffer, write->region);
  }
  // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
  // intrin to make them subregions of the buffer in the original IR.
  Array<MatchBufferRegion> match_buffer_regions;
  match_buffer_regions.reserve(intrin_impl->params.size());
  for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) {
    const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]);
    const Buffer& cur = impl2cur.at(impl);
    const Array<Range>& old_region = impl2region.at(impl);
    const std::vector<PrimExpr>& indices_base = comparator.buffer_indices_.at(cur);
    // The current buffer may have more leading dimensions than the impl region covers.
    int offset = static_cast<int>(indices_base.size()) - static_cast<int>(old_region.size());
    ICHECK(offset >= 0);
    Array<Range> new_region;
    new_region.reserve(cur->shape.size());
    // NOTE(review): the two loops below redeclare `i`, shadowing the outer loop index.
    // Leading (uncovered) dimensions are pinned to a single element each.
    for (int i = 0; i < offset; i++) {
      PrimExpr min = indices_base[i];
      PrimExpr extent = make_const(min.dtype(), 1);
      new_region.push_back(Range::FromMinExtent(min, extent));
    }
    // The remaining dimensions take their extents from the impl region, offset into `cur`.
    for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
      PrimExpr min = indices_base[i + offset];
      PrimExpr extent = cast(min.dtype(), old_region[i]->extent);
      new_region.push_back(Range::FromMinExtent(min, extent));
    }
    match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region)));
  }
  // Step 5: Replace the subtree in the original IR with the tensor intrin impl.
  {
    BlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite();
    block->body = impl_block->body;
    block->match_buffers = std::move(match_buffer_regions);
  }
  if (old_block.defined()) {
    // `sref` pointed at a block: replace the block and reuse its sref.
    self->Replace(sref, block_realize->block, {{old_block.value(), block_realize->block}});
  } else {
    // `sref` pointed at a loop: replace the whole blockized realize.
    self->Replace(sref, block_realize, {});
  }
  // Step 6: Update the cached flags.
  StmtSRef result = self->stmt2ref.at(block_realize->block.get());
  StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
  self->UpdateScopeBlockInfo(scope_root->StmtAs<BlockNode>()->body);
}
618 | |
619 | /******** InstructionKind Registration ********/ |
620 | |
621 | struct BlockizeTraits : public UnpackedInstTraits<BlockizeTraits> { |
622 | static constexpr const char* kName = "Blockize" ; |
623 | static constexpr bool kIsPure = false; |
624 | |
625 | private: |
626 | static constexpr size_t kNumInputs = 1; |
627 | static constexpr size_t kNumAttrs = 1; |
628 | static constexpr size_t kNumDecisions = 0; |
629 | |
630 | static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool preserve_unit_iters) { |
631 | return sch->Blockize(loop_rv, preserve_unit_iters.operator bool()); |
632 | } |
633 | |
634 | static String UnpackedAsPython(Array<String> outputs, String loop_rv, Bool preserve_unit_iters) { |
635 | PythonAPICall py("blockize" ); |
636 | py.Input("loop" , loop_rv); |
637 | py.Input("preserve_unit_iters" , preserve_unit_iters.operator bool()); |
638 | py.SingleOutput(outputs); |
639 | return py.Str(); |
640 | } |
641 | |
642 | template <typename> |
643 | friend struct ::tvm::tir::UnpackedInstTraits; |
644 | }; |
645 | |
646 | struct TensorizeTraits : public UnpackedInstTraits<TensorizeTraits> { |
647 | static constexpr const char* kName = "Tensorize" ; |
648 | static constexpr bool kIsPure = false; |
649 | |
650 | private: |
651 | static constexpr size_t kNumInputs = 1; |
652 | static constexpr size_t kNumAttrs = 2; |
653 | static constexpr size_t kNumDecisions = 0; |
654 | |
655 | static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin, |
656 | Bool preserve_unit_iters) { |
657 | if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) { |
658 | sch->Tensorize(GetRef<BlockRV>(block), intrin, preserve_unit_iters.operator bool()); |
659 | } else if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) { |
660 | sch->Tensorize(GetRef<LoopRV>(loop), intrin, preserve_unit_iters.operator bool()); |
661 | } else { |
662 | LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: " |
663 | << block_or_loop_rv->GetTypeKey(); |
664 | } |
665 | } |
666 | |
667 | static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv, String intrin, |
668 | Bool preserve_unit_iters) { |
669 | PythonAPICall py("tensorize" ); |
670 | py.Input("block_or_loop" , block_or_loop_rv); |
671 | py.Input("tensor_intrin" , intrin); |
672 | py.Input("preserve_unit_iters" , preserve_unit_iters.operator bool()); |
673 | return py.Str(); |
674 | } |
675 | |
676 | template <typename> |
677 | friend struct ::tvm::tir::UnpackedInstTraits; |
678 | }; |
679 | |
TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits);   // Register the "Blockize" instruction kind.
TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits);  // Register the "Tensorize" instruction kind.
682 | |
683 | } // namespace tir |
684 | } // namespace tvm |
685 | |