1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/tir/data_type_rewriter.h>
20
21#include <functional>
22
23#include "../ir_comparator.h"
24#include "../utils.h"
25
26namespace tvm {
27namespace tir {
28
29template <class T>
30bool UsesVar(const T& x, const Var& var) {
31 return UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; });
32}
33
34Range RangeFromExtent(const PrimExpr& extent) {
35 return Range::FromMinExtent(make_zero(extent->dtype), extent);
36}
37
38template <class T>
39T DeepCopy(const T& stmt) {
40 return Downcast<T>(LoadJSON(SaveJSON(stmt)));
41}
42
43/*!
44 * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace
45 * represented by the outer loops.
46 */
47class SubspaceNotDivisibleError : public ScheduleError {
48 public:
49 explicit SubspaceNotDivisibleError(IRModule mod, For scope_loop, Block inner_block)
50 : mod_(std::move(mod)),
51 scope_loop_(std::move(scope_loop)),
52 inner_block_(std::move(inner_block)) {}
53
54 String FastErrorString() const final {
55 return "ScheduleError: The bindings of the inner block can not be blockized.";
56 }
57
58 String DetailRenderTemplate() const final {
59 return "ScheduleError: The bindings of the inner block {0} can not be blockized by the loops "
60 "starting at {1}.";
61 }
62
63 IRModule mod() const final { return mod_; }
64
65 Array<ObjectRef> LocationsOfInterest() const final { return {inner_block_, scope_loop_}; }
66
67 private:
68 IRModule mod_;
69 For scope_loop_;
70 Block inner_block_;
71};
72
73/*!
74 * \brief Detect if bindings are a trivial case of the subspace division where we can divide the
75 * block iter bindings into two categories:
76 * 1. The binding covers no inner loop vars.
77 * 2. The binding covers only inner loop vars.
78 *
79 * The bindings are not required to be quasi-affine. Trivial block iters are always preserved.
80 *
81 * \param iter_vars The input iterators
82 * \param bindings The values of iter_vars
83 * \param predicate The predicate constraint on the input iterators.
84 * \param outer_iters The iters of the outer space
85 * \param inner_iters The iters of the inner space
86 * \return The result of the subspace division.
87 */
88Array<Array<arith::IterMark>> TrivialSubspaceDivision(const Array<IterVar>& iter_vars,
89 const Array<PrimExpr>& bindings,
90 const PrimExpr& predicate,
91 const Array<Var>& outer_iters,
92 const Array<Var>& inner_iters) {
93 if (!is_one(predicate)) return {};
94 Array<Array<arith::IterMark>> res;
95 std::unordered_set<const VarNode*> outer_loop_vars;
96 std::unordered_set<const VarNode*> inner_loop_vars;
97
98 auto make_uses_var = [](const Array<Var>& vars) -> std::function<bool(const PrimExpr& expr)> {
99 std::unordered_set<const VarNode*> var_set;
100 var_set.reserve(vars.size());
101 for (const Var& var : vars) {
102 var_set.insert(var.get());
103 }
104 return [var_set = std::move(var_set)](const PrimExpr& expr) -> bool {
105 return UsesVar(expr, [&var_set](const VarNode* var) {
106 return var_set.count(var); //
107 });
108 };
109 };
110 auto use_outer_loop_vars = make_uses_var(outer_iters);
111 auto use_inner_loop_vars = make_uses_var(inner_iters);
112 arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1);
113
114 for (int i = 0, n = bindings.size(); i < n; ++i) {
115 bool outer = use_outer_loop_vars(bindings[i]);
116 bool inner = use_inner_loop_vars(bindings[i]);
117 arith::IterMark iter_mark;
118 if (bindings[i]->IsInstance<VarNode>()) {
119 iter_mark = arith::IterMark(
120 arith::IterSplitExpr(arith::IterMark(bindings[i], iter_vars[i]->dom->extent)),
121 iter_vars[i]->dom->extent);
122 } else {
123 iter_mark = arith::IterMark(arith::IterSumExpr({}, bindings[i]), iter_vars[i]->dom->extent);
124 }
125 if (outer && !inner) {
126 res.push_back({/*outer_iter=*/iter_mark, /*inner_iter=*/unit_iter_mark});
127 } else if (inner && !outer) {
128 res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/iter_mark});
129 } else if (!outer && !inner) {
130 res.push_back({/*outer_iter=*/unit_iter_mark, /*inner_iter=*/unit_iter_mark});
131 } else {
132 return {};
133 }
134 }
135 res.push_back({arith::IterMark(arith::IterSumExpr({}, 0), Bool(true)),
136 arith::IterMark(arith::IterSumExpr({}, 0), Bool(true))});
137 return res;
138}
139
140/*!
141 * \brief Subspace division. The space is divided into two subspaces:
142 * 1. The subspace represented by the outer loops above `loop_sref` (exclusive).
143 * 2. The subspace represented by the inner loops below `loop_sref` (inclusive).
144 * \param realize The inner block
145 * \param block_sref The sref to the inner block
146 * \param loop_sref The loop that is the root of the second subspace.
147 * \param loops The loops that represents the second part of the subspace.
148 * \param analyzer The arithmetic analyzer to use.
149 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
150 */
151Array<Array<arith::IterMark>> SubspaceDivide(const BlockRealize& realize,
152 const StmtSRef& block_sref, //
153 const StmtSRef& loop_sref, //
154 std::vector<const ForNode*>* loops,
155 arith::Analyzer* analyzer, bool preserve_unit_iters) {
156 Array<Var> inner_vars;
157 Array<Var> outer_vars;
158 Map<Var, Range> loop_var_domain;
159 bool inner = true;
160 for (StmtSRefNode* sref = block_sref->parent; //
161 sref && sref->stmt->IsInstance<ForNode>(); //
162 sref = sref->parent) {
163 const ForNode* loop = static_cast<const ForNode*>(sref->stmt);
164 if (inner) {
165 loops->push_back(loop);
166 inner_vars.push_back(loop->loop_var);
167 } else {
168 outer_vars.push_back(loop->loop_var);
169 }
170 loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
171 if (sref == loop_sref.get()) {
172 inner = false;
173 }
174 }
175 Array<Array<arith::IterMark>> result =
176 arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate,
177 arith::IterMapLevel::Surjective, analyzer,
178 /*simplify_trivial_iterators=*/!preserve_unit_iters);
179 if (!result.empty()) {
180 return result;
181 }
182 return TrivialSubspaceDivision(realize->block->iter_vars,
183 realize->iter_values, //
184 realize->predicate, //
185 outer_vars, inner_vars);
186}
187
188/*!
189 * \brief Derive the block bindings for both inner and outer block
190 * \param iter_vars The original block iterators to the inner block
191 * \param division The subspace division.
192 * \param outer_iter_vars The outer block iterators.
193 * \param outer_bindings The outer block bindings.
194 * \param inner_iter_vars The inner block iterators.
195 * \param inner_bindings The inner block bindings.
196 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings
197 * \return A substitution plan to the iterators in the original inner block.
198 */
199Map<Var, PrimExpr> DeriveBlockBinding(const Array<IterVar>& iter_vars, //
200 const Array<Array<arith::IterMark>>& division, //
201 Array<IterVar>* outer_iter_vars, //
202 Array<PrimExpr>* outer_bindings, //
203 Array<IterVar>* inner_iter_vars, //
204 Array<PrimExpr>* inner_bindings, bool preserve_unit_iters) {
205 using arith::IterMapExpr;
206 using arith::IterMapExprNode;
207 using arith::NormalizeIterMapToExpr;
208 Map<Var, PrimExpr> block_var_subst;
209 ICHECK_EQ(iter_vars.size() + 1, division.size());
210 for (int i = 0, n = iter_vars.size(); i < n; ++i) {
211 const IterVar& iter_var = iter_vars[i];
212 arith::IterMark outer_mark = division[i][0];
213 arith::IterMark inner_mark = division[i][1];
214 IterMapExpr outer_binding = Downcast<IterMapExpr>(outer_mark->source);
215 IterMapExpr inner_binding = Downcast<IterMapExpr>(inner_mark->source);
216 // After computing the subspace division, bindings[i] can be written as
217 // outer_binding * inner_binding->extent + inner_binding
218 // The outer block will have binding: iter_outer -> outer_binding
219 // The inner block will have binding: iter_inner -> inner_binding
220 // The iter in the original block will be substituted with base + iter_inner where
221 // base == iter_outer * iter_inner_extent
222 if (is_one(inner_mark->extent)) { // IsOuter
223 // extract this iter var to outer block directly
224 outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
225 outer_iter_vars->push_back(iter_var);
226 continue;
227 }
228 // create iter var for the outer block
229 IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent),
230 /*var=*/iter_var->var.copy_with_suffix("_o"),
231 /*iter_type=*/iter_var->iter_type);
232 outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding));
233 outer_iter_vars->push_back(outer_iter);
234 // create iter var for the inner block
235 IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent),
236 /*var=*/iter_var->var.copy_with_suffix("_i"),
237 /*iter_type=*/iter_var->iter_type);
238 inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding));
239 inner_iter_vars->push_back(inner_iter);
240 // substitution
241 PrimExpr sub{nullptr};
242 if (is_one(outer_mark->extent)) {
243 sub = inner_iter->var;
244 } else {
245 sub = outer_iter * inner_mark->extent + inner_iter->var;
246 }
247 block_var_subst.Set(iter_var->var, sub);
248 }
249 return block_var_subst;
250}
251
252/*!
253 * \brief Generate the inner block for blockization
254 * \param is_write_reduction Whether the write regions of the inner block are actually reduction.
255 * \param iter_vars IterVars used in the inner block.
256 * \param iter_values IterVar bindings used in the inner block.
257 * \param predicate The predicate of the inner block.
258 * \param block The inner block as a template to be created from. This method will modify its
259 * `iter_vars`, `init` and `reads` fields.
260 * \return The inner block created.
261 */
262BlockRealize GenerateInner(bool is_write_reduction,
263 const Array<IterVar>& iter_vars, //
264 const Array<PrimExpr>& iter_values, //
265 const PrimExpr& predicate, //
266 Block block) {
267 BlockNode* n = block.CopyOnWrite();
268 n->iter_vars = iter_vars;
269 n->init = NullOpt;
270 if (is_write_reduction) {
271 Array<BufferRegion> reads;
272 reads.reserve(block->writes.size() + block->reads.size());
273 reads.insert(reads.end(), block->writes.begin(), block->writes.end());
274 reads.insert(reads.end(), block->reads.begin(), block->reads.end());
275 n->reads = std::move(reads);
276 }
277 return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate,
278 /*block=*/block);
279}
280
281/*!
282 * \brief Generate the init stmt for the outer block
283 * \param block The original block with init.
284 * \param inner_realize The block realize of the inner block after blockize.
285 * \param loops The inner loops after blockize.
286 * \return The subtree of the init block and its outer loops.
287 */
288Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize,
289 const std::vector<const ForNode*>& loops, String block_name) {
290 const Block& inner_block = inner_realize->block;
291 Map<Var, PrimExpr> subst_map;
292 // Step 1: Create new block vars for the block inside the init stmt of outer block
293 // A iter is used in the block if
294 // 1) It is data parallel
295 // 2) It is used in the original init block
296 Array<IterVar> iter_vars;
297 Array<PrimExpr> iter_values;
298 ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size());
299 int n = inner_block->iter_vars.size();
300 iter_vars.reserve(n);
301 iter_values.reserve(n);
302 for (int i = 0; i < n; ++i) {
303 const IterVar& old_iter_var = inner_block->iter_vars[i];
304 const PrimExpr& iter_value = inner_realize->iter_values[i];
305 if (old_iter_var->iter_type == IterVarType::kDataPar &&
306 UsesVar(block_init, old_iter_var->var)) {
307 ObjectPtr<IterVarNode> new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
308 new_iter_var->var = new_iter_var->var.copy_with_suffix("_init");
309 subst_map.Set(old_iter_var->var, new_iter_var->var);
310 iter_vars.push_back(IterVar(new_iter_var));
311 iter_values.push_back(iter_value);
312 }
313 }
314 // Step 2: Generate the block inside init stmt of outer block
315 Stmt stmt = BlockRealize(
316 /*iter_values=*/iter_values,
317 /*predicate=*/inner_realize->predicate,
318 /*block=*/
319 Block(/*iter_vars=*/iter_vars,
320 /*reads=*/{},
321 /*writes=*/inner_block->writes,
322 /*name_hint=*/block_name,
323 /*body=*/block_init,
324 /*init=*/NullOpt));
325 // Step 3. Create the loop nest on top of the block
326 for (const ForNode* loop : loops) {
327 bool is_init_loop = false;
328 for (const PrimExpr& init_binding : iter_values) {
329 if (UsesVar(init_binding, loop->loop_var)) {
330 is_init_loop = true;
331 break;
332 }
333 }
334 if (is_init_loop) {
335 ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
336 new_loop->loop_var = loop->loop_var.copy_with_suffix("");
337 new_loop->body = std::move(stmt);
338 subst_map.Set(loop->loop_var, new_loop->loop_var);
339 stmt = For(new_loop);
340 }
341 }
342 // Step 4: Substitute the iter vars and loop vars
343 return Substitute(stmt, subst_map);
344}
345
346/*!
347 * \brief Substitute variables in the stmt, do simplification and track block substitution
348 * \param stmt The stmt to be substituted.
349 * \param sub The substitution map.
350 * \param block_sref_reuse The block substitution happens during the substitution.
351 * \param analyzer The analyzer for arithmetic simplification.
352 * \return The substituted stmt.
353 */
354Stmt Substitute(const Stmt& stmt, const Map<Var, PrimExpr>& sub,
355 Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer) {
356 struct Replacer : public StmtExprMutator {
357 explicit Replacer(const Map<Var, PrimExpr>& sub, Map<Block, Block>* block_sref_reuse,
358 arith::Analyzer* analyzer)
359 : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {}
360
361 PrimExpr VisitExpr(const PrimExpr& op) final {
362 PrimExpr result = StmtExprMutator::VisitExpr(op);
363 if (!result.same_as(op)) {
364 return analyzer_->Simplify(result);
365 }
366 return result;
367 }
368
369 PrimExpr VisitExpr_(const VarNode* op) final {
370 if (Optional<PrimExpr> e = sub_.Get(GetRef<Var>(op))) {
371 return e.value();
372 }
373 return StmtExprMutator::VisitExpr_(op);
374 }
375
376 Stmt VisitStmt_(const BlockNode* op) final {
377 Block src = GetRef<Block>(op);
378 Block tgt = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
379 if (!src.same_as(tgt)) {
380 block_sref_reuse_->Set(src, tgt);
381 }
382 return std::move(tgt);
383 }
384
385 const Map<Var, PrimExpr>& sub_;
386 Map<Block, Block>* block_sref_reuse_;
387 arith::Analyzer* analyzer_;
388 };
389 return Replacer(sub, block_sref_reuse, analyzer)(stmt);
390}
391
392/*!
393 * \brief Relax the variables for the given regions
394 * \param regions The regions to be relaxed.
395 * \param dom_map The variables to be relaxed
396 * \return The relaxed regions
397 */
398Array<BufferRegion> EvalSetRegions(const Array<BufferRegion>& regions,
399 const Map<Var, arith::IntSet>& dom_map) {
400 Array<BufferRegion> results;
401 results.reserve(regions.size());
402 for (const BufferRegion& buffer_region : regions) {
403 const Buffer& buffer = buffer_region->buffer;
404 Array<arith::IntSet> relaxed = arith::EvalSet(buffer_region->region, dom_map);
405 ICHECK_EQ(relaxed.size(), buffer->shape.size());
406 int ndim = buffer->shape.size();
407 Array<Range> new_region;
408 new_region.reserve(ndim);
409 for (int i = 0; i < ndim; ++i) {
410 new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i])));
411 }
412 results.push_back(BufferRegion(buffer, new_region));
413 }
414 return results;
415}
416
417/*!
418 * \brief Create the loop nest on top of the given stmt.
419 * \param stmt The stmt to be wrapped.
420 * \param loops The loop nests
421 * \return The wrapped stmt.
422 */
423Stmt MakeLoopNest(Stmt stmt, const std::vector<const ForNode*>& loops) {
424 for (const ForNode* loop : loops) {
425 ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
426 new_loop->body = std::move(stmt);
427 stmt = For(new_loop);
428 }
429 return stmt;
430}
431
/*!
 * \brief Blockize the subtree rooted at the given loop: create an outer block whose body
 * contains the rewritten inner block wrapped in the inner loops.
 * \param self The schedule state.
 * \param loop_sref The sref of the loop to be blockized.
 * \param block_sref_reuse Output: block substitutions made during the rewrite.
 * \param analyzer The arithmetic analyzer to use.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
 * \return The block realize of the new outer block.
 * \throws SubspaceNotDivisibleError if the block bindings cannot be divided by the loop subspace.
 */
BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref,
                          Map<Block, Block>* block_sref_reuse, arith::Analyzer* analyzer,
                          bool preserve_unit_iters) {
  TVM_SREF_TO_FOR(loop_sref);
  // Step 1: Check and get the only block under `loop`.
  BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref);
  Block block = block_realize->block;
  StmtSRef block_sref = self->stmt2ref.at(block.get());
  // Step 2: Derive subspace division
  std::vector<const ForNode*> loops;
  Array<Array<arith::IterMark>> division =
      SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer, preserve_unit_iters);
  if (division.empty()) {
    throw SubspaceNotDivisibleError(self->mod, GetRef<For>(loops.back()), block);
  }
  // The trailing entry of the division carries the outer/inner predicates.
  PrimExpr outer_predicate = division.back()[0]->extent;
  PrimExpr inner_predicate = division.back()[1]->extent;
  // Step 3. Derive block bindings for both outer and inner block.
  Array<IterVar> outer_iter_vars;
  Array<IterVar> inner_iter_vars;
  Array<PrimExpr> outer_bindings;
  Array<PrimExpr> inner_bindings;
  Map<Var, PrimExpr> block_var_subst =                 //
      DeriveBlockBinding(block->iter_vars, division,   //
                         &outer_iter_vars, &outer_bindings,  //
                         &inner_iter_vars, &inner_bindings,  //
                         preserve_unit_iters);
  // Step 4: Do var substitution to adjust to the new block bindings
  Map<Var, arith::IntSet> inner_iter_dom;
  for (const IterVar& iter : inner_iter_vars) {
    inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom));
    analyzer->Bind(iter->var, iter->dom);
  }
  Block block_subst =
      Downcast<Block>(Substitute(block, block_var_subst, block_sref_reuse, analyzer));
  // Step 5: Generate the inner block. The write regions of the inner blocks will be reduction if
  // 1. The original block has init stmt.
  // 2. There are outer reduction iter vars.
  bool has_outer_reduction = false;
  if (block_subst->init.defined()) {
    for (const IterVar& iter_var : outer_iter_vars) {
      if (iter_var->iter_type == kCommReduce) {
        has_outer_reduction = true;
        break;
      }
    }
  }
  BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction,
                                             /*iter_vars=*/inner_iter_vars,
                                             /*iter_values*/ inner_bindings,
                                             /*predicate=*/inner_predicate,
                                             /*block=*/block_subst);
  block_sref_reuse->Set(block, inner_realize->block);
  // Step 6: Generate the outer block. Its read/write regions are the inner regions
  // relaxed over the inner iter domains.
  return BlockRealize(
      /*iter_values=*/std::move(outer_bindings),
      /*predicate=*/std::move(outer_predicate),
      /*block=*/
      Block(/*iter_vars=*/std::move(outer_iter_vars),
            /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom),
            /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom),
            /*name_hint=*/block_subst->name_hint + "_o",
            /*body=*/MakeLoopNest(inner_realize, loops),
            /*init=*/
            block_subst->init.defined()  //
                ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops,
                                    block_subst->name_hint + "_init")
                : Optional<Stmt>(NullOpt)));
}
501
/*!
 * \brief Schedule primitive: wrap the subtree rooted at `loop_sref` into a new outer block.
 * \param self The schedule state.
 * \param loop_sref The sref of the loop to be blockized.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
 * \return The sref of the newly created outer block.
 */
StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref, bool preserve_unit_iters) {
  arith::Analyzer analyzer;
  Map<Block, Block> block_sref_reuse;
  BlockRealize blockized =
      BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer, preserve_unit_iters);
  self->Replace(loop_sref, blockized, block_sref_reuse);
  StmtSRef result = self->stmt2ref.at(blockized->block.get());
  // Refresh the cached scope info of the enclosing scope, but preserve the scope root's
  // original affine-binding flag, which UpdateScopeBlockInfo may otherwise overwrite.
  StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
  bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root);
  self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root));
  self->block_info[scope_root].affine_binding = scope_block_affine_binding;
  return result;
}
515
/*!
 * \brief Schedule primitive: tensorize the computation at the given block or loop
 * with a tensor intrinsic, replacing the matched body with the intrinsic's impl.
 * \param self The schedule state.
 * \param sref The sref of the block or loop to be tensorized.
 * \param intrin The tensor intrinsic; its `desc` is matched, its `impl` is inlined.
 * \param preserve_unit_iters Whether or not to preserve unit iterators in block bindings.
 */
void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin,
               bool preserve_unit_iters) {
  // Step 1: Blockize the subtree rooted at the given loop if needed
  BlockRealize block_realize{nullptr};
  Optional<Block> old_block = NullOpt;
  if (sref->stmt->IsInstance<BlockNode>()) {
    block_realize = GetBlockRealize(self, sref);
    old_block = block_realize->block;
  } else if (sref->stmt->IsInstance<ForNode>()) {
    arith::Analyzer analyzer;
    Map<Block, Block> block_sref_reuse;
    block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer, preserve_unit_iters);
  } else {
    LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: "
               << GetRef<Stmt>(sref->stmt);
    throw;  // unreachable: LOG(FATAL) does not return
  }
  PrimFunc intrin_desc = intrin->desc;
  // Deep-copy the impl so the rewrite below does not mutate the registered intrinsic.
  PrimFunc intrin_impl = DeepCopy(intrin->impl);

  // Normalize the index dtype of the intrin impl to the widest index dtype observed
  // in the read/write regions of the block being tensorized.
  int index_dtype_bits = -1;
  auto f_update_max_dtype_bits_from_region = [&](const Array<BufferRegion>& buffer_regions) {
    for (const BufferRegion& buffer_region : buffer_regions) {
      for (const auto& range : buffer_region->region) {
        index_dtype_bits = std::max(index_dtype_bits, range->min.dtype().bits());
      }
    }
  };
  f_update_max_dtype_bits_from_region(block_realize->block->reads);
  f_update_max_dtype_bits_from_region(block_realize->block->writes);
  ICHECK(index_dtype_bits > 0);
  intrin_impl = IndexDataTypeNormalizer(DataType::Int(index_dtype_bits)).Rewrite(intrin_impl);
  // Step 2: Structural pattern matching
  TensorizeComparator comparator(self->mod, /*assert_mode=*/true);
  comparator.VisitStmt(block_realize, intrin_desc->body);
  // Step 3: Prepare necessary mapping
  // 1) Buffer mapping from intrin impl buffers to intrin desc buffers.
  // 2) Buffer mapping from intrin impl buffers to buffers in the current AST.
  // 3) Mapping impl buffers to their accessed regions.
  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2desc;
  ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size());
  for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) {
    const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]];
    const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]];
    impl2desc[impl] = desc;
  }
  std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> impl2cur;
  for (const auto& pair : impl2desc) {
    const Buffer& impl = pair.first;
    const Buffer& desc = pair.second;
    // The comparator recorded which AST buffer each desc buffer matched.
    ICHECK(comparator.rhs_buffer_map_.count(desc));
    impl2cur[impl] = comparator.rhs_buffer_map_[desc];
  }
  std::unordered_map<Buffer, Array<Range>, ObjectPtrHash, ObjectPtrEqual> impl2region;
  Block impl_block = Downcast<BlockRealize>(intrin_impl->body)->block;
  for (const BufferRegion& read : impl_block->reads) {
    impl2region.emplace(read->buffer, read->region);
  }
  for (const BufferRegion& write : impl_block->writes) {
    impl2region.emplace(write->buffer, write->region);
  }
  // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor
  // intrin to make them subregions of the buffer in the original IR.
  Array<MatchBufferRegion> match_buffer_regions;
  match_buffer_regions.reserve(intrin_impl->params.size());
  for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) {
    const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]);
    const Buffer& cur = impl2cur.at(impl);
    const Array<Range>& old_region = impl2region.at(impl);
    const std::vector<PrimExpr>& indices_base = comparator.buffer_indices_.at(cur);
    // The current buffer may have more leading dimensions than the intrin region;
    // those extra leading dimensions become unit-extent ranges.
    int offset = static_cast<int>(indices_base.size()) - static_cast<int>(old_region.size());
    ICHECK(offset >= 0);
    Array<Range> new_region;
    new_region.reserve(cur->shape.size());
    for (int i = 0; i < offset; i++) {
      PrimExpr min = indices_base[i];
      PrimExpr extent = make_const(min.dtype(), 1);
      new_region.push_back(Range::FromMinExtent(min, extent));
    }
    for (int i = 0; i < static_cast<int>(old_region.size()); i++) {
      PrimExpr min = indices_base[i + offset];
      PrimExpr extent = cast(min.dtype(), old_region[i]->extent);
      new_region.push_back(Range::FromMinExtent(min, extent));
    }
    match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region)));
  }
  // Step 5: Replace the subtree in the original IR with the tensor intrin impl.
  {
    BlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite();
    block->body = impl_block->body;
    block->match_buffers = std::move(match_buffer_regions);
  }
  if (old_block.defined()) {
    self->Replace(sref, block_realize->block, {{old_block.value(), block_realize->block}});
  } else {
    self->Replace(sref, block_realize, {});
  }
  // Step 6: Update the cached flags.
  StmtSRef result = self->stmt2ref.at(block_realize->block.get());
  StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false);
  self->UpdateScopeBlockInfo(scope_root->StmtAs<BlockNode>()->body);
}
618
619/******** InstructionKind Registration ********/
620
621struct BlockizeTraits : public UnpackedInstTraits<BlockizeTraits> {
622 static constexpr const char* kName = "Blockize";
623 static constexpr bool kIsPure = false;
624
625 private:
626 static constexpr size_t kNumInputs = 1;
627 static constexpr size_t kNumAttrs = 1;
628 static constexpr size_t kNumDecisions = 0;
629
630 static BlockRV UnpackedApplyToSchedule(Schedule sch, LoopRV loop_rv, Bool preserve_unit_iters) {
631 return sch->Blockize(loop_rv, preserve_unit_iters.operator bool());
632 }
633
634 static String UnpackedAsPython(Array<String> outputs, String loop_rv, Bool preserve_unit_iters) {
635 PythonAPICall py("blockize");
636 py.Input("loop", loop_rv);
637 py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
638 py.SingleOutput(outputs);
639 return py.Str();
640 }
641
642 template <typename>
643 friend struct ::tvm::tir::UnpackedInstTraits;
644};
645
646struct TensorizeTraits : public UnpackedInstTraits<TensorizeTraits> {
647 static constexpr const char* kName = "Tensorize";
648 static constexpr bool kIsPure = false;
649
650 private:
651 static constexpr size_t kNumInputs = 1;
652 static constexpr size_t kNumAttrs = 2;
653 static constexpr size_t kNumDecisions = 0;
654
655 static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, String intrin,
656 Bool preserve_unit_iters) {
657 if (const auto* block = block_or_loop_rv.as<BlockRVNode>()) {
658 sch->Tensorize(GetRef<BlockRV>(block), intrin, preserve_unit_iters.operator bool());
659 } else if (const auto* loop = block_or_loop_rv.as<LoopRVNode>()) {
660 sch->Tensorize(GetRef<LoopRV>(loop), intrin, preserve_unit_iters.operator bool());
661 } else {
662 LOG(FATAL) << "TypeError: Expected Block or Loop, but gets: "
663 << block_or_loop_rv->GetTypeKey();
664 }
665 }
666
667 static String UnpackedAsPython(Array<String> outputs, String block_or_loop_rv, String intrin,
668 Bool preserve_unit_iters) {
669 PythonAPICall py("tensorize");
670 py.Input("block_or_loop", block_or_loop_rv);
671 py.Input("tensor_intrin", intrin);
672 py.Input("preserve_unit_iters", preserve_unit_iters.operator bool());
673 return py.Str();
674 }
675
676 template <typename>
677 friend struct ::tvm::tir::UnpackedInstTraits;
678};
679
// Register both instruction kinds with the global instruction-kind registry.
TVM_REGISTER_INST_KIND_TRAITS(BlockizeTraits);
TVM_REGISTER_INST_KIND_TRAITS(TensorizeTraits);
682
683} // namespace tir
684} // namespace tvm
685