/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <tvm/node/node.h>

#include <optional>
#include <variant>

#include "../../../arith/ir_mutator_with_analyzer.h"
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief Planning stage prior to rewriting in TransformLayoutRewriter
 *
 * There are four ways that a transformation may be handled.  Each
 * updates the buffer shape and the indices used to access the buffer
 * in BufferStore/BufferLoad nodes, but they differ in how they handle
 * the `pad_value`.  In order of preference, the different strategies
 * are as follows:
 *
 * 1. NoPaddingRequired.  The transformation does not introduce
 * padding, so only local changes to update the indices of
 * BufferLoad/BufferStore nodes are required.  No blocks are added,
 * removed, or replaced.
 *
 * 2. ProloguePlan.  The transformation introduces padding, but the
 * analyzed block has no write stages for the transformed buffer.
 * This buffer is an input, and the caller is responsible for ensuring
 * that the padding contains the specified `pad_value`.  The generated
 * prologue contains `builtin::assume()` calls that expose this
 * known value during scheduling/simplification, but are removed
 * during lowering.
 *
 * 3. ReplacementPlan.  The transformation introduces padding, has at
 * least one write stage for the transformed buffer, and at least one
 * of those write stages writes to all pre-transformation indices
 * following a row-major traversal.  These write stages are rewritten
 * to be row-major traversals of the post-transformation indices, with
 * a `tir::if_then_else` call that writes either the specified
 * `pad_value` into padding or the computed value into non-padding.
 *
 * 4. EpiloguePlan.  The transformation introduces padding, has at
 * least one write stage for the transformed buffer, but no write
 * stage can be rewritten to use `tir::if_then_else`.  The
 * transformation still requires the `pad_value` to be written into
 * the padding, so a new block is inserted after the last write stage
 * to explicitly fill the padding.
 *
 */
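// Illustrative sketch (assumed shapes and values, not taken from a test):
// applying transform_layout with index_map = lambda i: [i // 4, i % 4] and
// pad_value = 0.0 to a buffer A of shape [14] produces a buffer of shape
// [4, 4] with padding predicate (io == 3 and ii >= 2).  A row-major write
// stage such as
//
//     for i in range(14):
//         A[i] = f(i)
//
// would then be rewritten by the ReplacementPlan as roughly
//
//     for io, ii in T.grid(4, 4):
//         A[io, ii] = T.if_then_else(io == 3 and ii >= 2, 0.0,
//                                    f(io * 4 + ii))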
class TransformLayoutPlanner : private StmtExprVisitor {
 public:
  // Statement to be inserted prior to the analyzed block
  struct ProloguePlan {
    Stmt prologue;
  };

  // Loops within the analyzed block that should be replaced
  struct ReplacementPlan {
    Map<For, Stmt> replacements;
    Map<Block, Block> new_block_to_old;
  };

  // The block to be inserted, along with the location at which it
  // should be inserted.  The location will be either a For or a
  // Block, and will be after all writes to the transformed buffer.
  struct EpiloguePlan {
    Stmt insert_after;
    Stmt new_block;
  };

  struct NoPaddingRequired {};

  using TransformPlan =
      std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>;

  static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map,
                            IndexMap inverse, PrimExpr padding_predicate,
                            Optional<IndexMap> pad_value) {
    ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1)
        << "Internal error: Should be caught by ScheduleError checks prior to this point";
    TransformLayoutPlanner visitor(old_buffer);
    visitor(block);
    return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value);
  }

 private:
  struct WriteInfo {
    // The BufferStore object
    BufferStore store;

    // The block realize that contains the store, if any.
    Optional<BlockRealize> innermost_block_realize;

    // The nested loops whose values contribute to the indices used in
    // the store.  Not all loop variables in the loopnest need to
    // contribute, but the first and last must.
    std::vector<For> dependent_loopnest;

    // Whether the padding could be represented as a tir::if_then_else
    // node.  This requires that the surrounding loop iterators
    // iterate over all pre-transformation buffer axes, that there are
    // no data dependencies between loop iterations, and that each
    // index is exactly the corresponding serial loop variable, with
    // min 0 and extent equal to the buffer axis.
    bool contains_row_major_traversal{false};
  };

  explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {}

  void VisitStmt_(const ForNode* op) override {
    BindLoopVar context(this, GetRef<For>(op));
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const LetStmtNode* op) override {
    BindVariableDefinition context(this, op->var, op->value);
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const BlockRealizeNode* op) override {
    BindBlockRealize context(this, GetRef<BlockRealize>(op));
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const BufferStoreNode* op) override {
    if (!op->buffer.same_as(old_buffer_)) {
      return;
    }

    std::optional<std::pair<size_t, size_t>> loop_dependency_range = std::nullopt;
    for (const auto& index : op->indices) {
      if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) {
        if (loop_dependency_range) {
          loop_dependency_range = {
              std::min(loop_dependency_range.value().first, index_depth.value().first),
              std::max(loop_dependency_range.value().second, index_depth.value().second)};
        } else {
          loop_dependency_range = index_depth;
        }
      }
    }

    WriteInfo write_info;
    write_info.store = GetRef<BufferStore>(op);
    if (loop_dependency_range) {
      size_t i = loop_dependency_range.value().first;
      size_t j = loop_dependency_range.value().second;
      ICHECK_LT(i, active_loops_.size());
      ICHECK_LT(j, active_loops_.size());

      write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1};
    }
    write_info.innermost_block_realize = innermost_block_realize_;

    write_info.contains_row_major_traversal = [&]() -> bool {
      const auto& loopnest = write_info.dependent_loopnest;
      if (loopnest.empty()) {
        return false;
      }

      if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) {
        return false;
      }

      for (size_t i = 0; i < loopnest.size(); i++) {
        const For& loop = loopnest[i];
        const PrimExpr& buffer_dim = old_buffer_->shape[i];
        PrimExpr index = Substitute(op->indices[i], active_var_bindings_);
        bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) &&
                                 ExprDeepEqual()(loop->extent, buffer_dim) &&
                                 loop->kind == ForKind::kSerial;
        if (!is_loop_over_axis) {
          return false;
        }
      }

      return true;
    }();

    write_info_.push_back(write_info);

    // Don't need to continue recursing, as the entire goal was to
    // find the BufferStore.
  }

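  // Illustrative example (assumed loops): with active loops i (depth 0)
  // and j (depth 1), LoopDependencyRange(i * 16 + j) returns {0, 1},
  // LoopDependencyRange(j) returns {1, 1}, and a loop-invariant
  // expression returns std::nullopt.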
  std::optional<std::pair<size_t, size_t>> LoopDependencyRange(const PrimExpr& expr) const {
    std::optional<std::pair<size_t, size_t>> prev = std::nullopt;
    for (const auto& var : UndefinedVars(expr)) {
      auto it = loop_depth_lookup_.find(var.get());
      if (it != loop_depth_lookup_.end()) {
        if (prev.has_value()) {
          prev = {std::min(prev.value().first, it->second.first),
                  std::max(prev.value().second, it->second.second)};
        } else {
          prev = it->second;
        }
      }
    }

    return prev;
  }

  class BufferStoreReplacer : public StmtExprMutator {
   public:
    BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate,
                        const IndexMap& inverse, const Optional<IndexMap>& pad_value,
                        Map<Block, Block>* new_block_to_old)
        : info(info),
          new_buffer(new_buffer),
          new_indices(inverse->initial_indices.Map([](const Var& var) -> PrimExpr { return var; })),
          padding_predicate(padding_predicate),
          inverse(inverse),
          pad_value(pad_value),
          new_block_to_old(*new_block_to_old) {
      ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
      for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
        Var var = info.dependent_loopnest[i]->loop_var;
        PrimExpr expr = inverse->final_indices[i];
        var_remap.Set(var, expr);
      }

      DefineBlockUpdates();
    }

    bool is_all_stores_replaced() const { return all_stores_replaced; }

   private:
    void DefineBlockUpdates() {
      if (!info.innermost_block_realize) {
        return;
      }

      BlockRealize block_realize = info.innermost_block_realize.value();
      const auto& block = block_realize->block;
      const Array<PrimExpr>& old_indices = info.store->indices;
      const auto& old_iter_vars = block->iter_vars;

      this->new_iter_vars = old_iter_vars;
      this->new_iter_values = block_realize->iter_values;

      if (old_indices.empty()) {
        return;
      }

      // Find the block iterators that are used to access the buffer.  Must be in the same
      // order as they appear in the indices.
      if (block->iter_vars.size() < old_indices.size()) {
        return;
      }

      size_t block_index_start = 0;
      for (; block_index_start < old_iter_vars.size() - old_indices.size(); block_index_start++) {
        if (old_indices[0].same_as(old_iter_vars[block_index_start]->var)) {
          break;
        }
      }
      if (block_index_start > old_iter_vars.size() - old_indices.size()) {
        return;
      }

      for (size_t i = 0; i < old_indices.size(); i++) {
        if (!old_indices[i].same_as(old_iter_vars[block_index_start + i]->var) ||
            old_iter_vars[block_index_start + i]->iter_type != kDataPar) {
          return;
        }
      }

      // If we got to this point, all indices used to access the
      // buffer are virtual indices defined in the innermost block.
      // Therefore, generate new virtual indices for iterating over
      // the post-transform buffer.

      new_indices = inverse->initial_indices.Map([](Var var) -> PrimExpr {
        std::stringstream ss;
        ss << "v_" << var->name_hint;
        return Var(ss.str(), var.dtype());
      });

      Map<Var, PrimExpr>
          loop_var_to_virtual_var;  // For updating padding_predicate in terms of the new indices
      Array<PrimExpr> new_iter_values;  // For BlockRealize
      Array<IterVar> new_iter_vars;     // For Block

      for (size_t i = 0; i < block_index_start; i++) {
        new_iter_vars.push_back(old_iter_vars[i]);
        new_iter_values.push_back(block_realize->iter_values[i]);
      }

      ICHECK_EQ(new_indices.size(), new_buffer->shape.size());
      for (size_t i = 0; i < new_indices.size(); i++) {
        Var var = inverse->initial_indices[i];
        Var virtual_var = Downcast<Var>(new_indices[i]);
        PrimExpr dim = new_buffer->shape[i];
        new_iter_values.push_back(var);
        new_iter_vars.push_back(
            IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
        loop_var_to_virtual_var.Set(var, virtual_var);
      }

      for (size_t i = block_index_start + old_indices.size(); i < old_iter_vars.size(); i++) {
        new_iter_vars.push_back(old_iter_vars[i]);
        new_iter_values.push_back(block_realize->iter_values[i]);
      }

      ICHECK_EQ(inverse->final_indices.size(), old_indices.size());
      for (size_t i = 0; i < old_indices.size(); i++) {
        Var var = Downcast<Var>(old_indices[i]);
        PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var);
        var_remap.Set(var, expr);
      }

      padding_predicate = Substitute(padding_predicate, loop_var_to_virtual_var);

      this->new_iter_vars = new_iter_vars;
      this->new_iter_values = new_iter_values;
    }

    Stmt VisitStmt_(const BufferStoreNode* op) final {
      bool can_replace = [&]() -> bool {
        if (!op->buffer.same_as(info.store->buffer)) {
          return false;
        }

        const Array<PrimExpr>& old_indices = info.store->indices;

        ICHECK_EQ(old_indices.size(), op->indices.size());
        ExprDeepEqual expr_equal;
        for (size_t i = 0; i < old_indices.size(); i++) {
          if (!expr_equal(old_indices[i], op->indices[i])) {
            return false;
          }
        }
        return true;
      }();

      BufferStore store = GetRef<BufferStore>(op);
      if (can_replace) {
        PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0];
        store =
            BufferStore(new_buffer, if_then_else(padding_predicate, pad_value_at_index, op->value),
                        new_indices);
      } else {
        all_stores_replaced = false;
      }
      return StmtExprMutator::VisitStmt_(store.get());
    }

    Stmt VisitStmt_(const BlockRealizeNode* op) final {
      BlockRealize realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op));

      if (op == info.innermost_block_realize.get()) {
        Block block = realize->block;
        if (!block->iter_vars.same_as(this->new_iter_vars)) {
          block.CopyOnWrite()->iter_vars = this->new_iter_vars;
          RecordReplacement(op->block, block);
        }

        if (!block.same_as(realize->block) ||
            !realize->iter_values.same_as(this->new_iter_values)) {
          auto write_ptr = realize.CopyOnWrite();
          write_ptr->block = block;
          write_ptr->iter_values = this->new_iter_values;
        }
      }

      return std::move(realize);
    }

    Stmt VisitStmt_(const BlockNode* op) final {
      Block orig = GetRef<Block>(op);
      Block mutated = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

      RecordReplacement(orig, mutated);
      return std::move(mutated);
    }

    PrimExpr VisitExpr_(const VarNode* op) final {
      Var var = GetRef<Var>(op);
      if (auto opt = var_remap.Get(var)) {
        return opt.value();
      } else {
        return std::move(var);
      }
    }

    void RecordReplacement(Block before, Block after) {
      if (before.same_as(after)) {
        return;
      }

      ICHECK(!new_block_to_old.count(after));

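      // If `before` was itself produced by an earlier replacement,
      // follow the chain so that the map always records the original
      // block as the value.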
      while (true) {
        if (auto opt = new_block_to_old.Get(before)) {
          before = opt.value();
        } else {
          break;
        }
      }

      new_block_to_old.Set(after, before);
    }

    const WriteInfo& info;
    const Buffer& new_buffer;
    Array<PrimExpr> new_indices;
    Array<IterVar> new_iter_vars;
    Array<PrimExpr> new_iter_values;
    PrimExpr padding_predicate;
    const IndexMap& inverse;
    const Optional<IndexMap>& pad_value;
    Map<Block, Block>& new_block_to_old;
    bool all_stores_replaced{true};

    Map<Var, PrimExpr> var_remap;
  };

  TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse,
                         PrimExpr padding_predicate, Optional<IndexMap> pad_value) const {
    if (auto prologue_plan =
            FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value);
        prologue_plan.has_value()) {
      return prologue_plan.value();
    } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse,
                                                               padding_predicate, pad_value);
               replacement_plan.has_value()) {
      return replacement_plan.value();
    } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse,
                                                         padding_predicate, pad_value);
               epilogue_plan.has_value()) {
      return epilogue_plan.value();
    } else {
      return NoPaddingRequired();
    }
  }

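  // Illustrative sketch (assumed buffer and pad_value): for a 14-element
  // input buffer transformed to shape [4, 4] with pad_value 0.0, the
  // generated prologue is roughly
  //
  //     for v_io, v_ii in T.grid(4, 4):
  //         T.assume(not (v_io == 3 and v_ii >= 2) or A[v_io, v_ii] == 0.0)
  //
  // which records the contents of the padding without writing to it.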
  std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map,
                                                   IndexMap inverse, PrimExpr padding_predicate,
                                                   Optional<IndexMap> pad_value) const {
    if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) {
      return std::nullopt;
    }

    Array<IterVar> iter_vars;
    Array<PrimExpr> iter_values;
    Array<PrimExpr> indices;
    Map<Var, PrimExpr> loop_indices_to_block_indices;
    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
      const auto& loop_var = inverse->initial_indices[i];
      const auto& dim = new_buffer->shape[i];
      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
      IterVar iter_var(Range(0, dim), block_var, kDataPar);
      loop_indices_to_block_indices.Set(loop_var, block_var);
      indices.push_back(iter_var->var);
      iter_vars.push_back(iter_var);
      iter_values.push_back(loop_var);
    }
    padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices);

    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
    PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value_at_index);
    Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));

    std::stringstream block_name;
    block_name << "buffer_" << new_buffer->name << "_assumptions";
    auto read_region = BufferRegion::FromPoint(new_buffer, indices);
    stmt = BlockRealize(iter_values, Bool(true),
                        Block(iter_vars, {read_region}, {}, block_name.str(), stmt));

    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
      Var loop_var = inverse->initial_indices[i];
      PrimExpr extent = new_buffer->shape[i];
      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
    }
    return ProloguePlan{stmt};
  }

  std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map,
                                                         IndexMap inverse,
                                                         PrimExpr padding_predicate,
                                                         Optional<IndexMap> pad_value) const {
    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
      return std::nullopt;
    }

    Map<Block, Block> new_block_to_old;
    auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> {
      if (!info.contains_row_major_traversal || !pad_value.defined() ||
          is_zero(padding_predicate)) {
        return NullOpt;
      }

      BufferStoreReplacer replacer(info, new_buffer, padding_predicate, inverse, pad_value,
                                   &new_block_to_old);
      Stmt stmt = replacer(info.dependent_loopnest.back()->body);
      if (!replacer.is_all_stores_replaced()) {
        return NullOpt;
      }

      ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
      for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
        size_t i = (inverse->initial_indices.size() - 1) - rev_i;
        Var loop_var = inverse->initial_indices[i];
        PrimExpr extent = new_buffer->shape[i];
        stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
      }

      return stmt;
    };

    Map<For, Stmt> loop_replacements;

    for (const auto& info : write_info_) {
      if (info.dependent_loopnest.size()) {
        if (auto opt_stmt = generate_if_then_else_block(info)) {
          loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value());
        }
      }
    }

    if (loop_replacements.size()) {
      return ReplacementPlan{std::move(loop_replacements), std::move(new_block_to_old)};
    } else {
      return std::nullopt;
    }
  }

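  // Illustrative sketch (assumed buffer and pad_value): when no write
  // stage can be rewritten in place, the epilogue inserted after the
  // last write is roughly
  //
  //     for v_io, v_ii in T.grid(4, 4):
  //         if v_io == 3 and v_ii >= 2:
  //             A[v_io, v_ii] = 0.0
  //
  // expressed as a block whose predicate is the padding predicate.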
  std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map,
                                                   IndexMap inverse, PrimExpr padding_predicate,
                                                   Optional<IndexMap> pad_value) const {
    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
      return std::nullopt;
    }

    Array<IterVar> iter_vars;
    Array<PrimExpr> iter_values;
    Array<PrimExpr> indices;
    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
      const auto& loop_var = inverse->initial_indices[i];
      const auto& dim = new_buffer->shape[i];
      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
      IterVar iter_var(Range(0, dim), block_var, kDataPar);
      indices.push_back(iter_var->var);
      iter_vars.push_back(iter_var);
      iter_values.push_back(loop_var);
    }

    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
    Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices);

    std::stringstream block_name;
    block_name << "buffer_" << new_buffer->name << "_padding";
    auto write_region = BufferRegion::FromPoint(new_buffer, indices);
    stmt = BlockRealize(iter_values, padding_predicate,
                        Block(iter_vars, {}, {write_region}, block_name.str(), stmt));

    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
      Var loop_var = inverse->initial_indices[i];
      PrimExpr extent = new_buffer->shape[i];
      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
    }

    const auto& info = write_info_.back();
    Stmt insert_after = [&]() -> Stmt {
      if (info.dependent_loopnest.size()) {
        return info.dependent_loopnest.front();
      } else if (info.innermost_block_realize) {
        return info.innermost_block_realize.value();
      } else {
        LOG(FATAL) << "Write occurred outside of any block/loop";
      }
    }();
    return EpiloguePlan{insert_after, stmt};
  }

  struct BindLoopVar {
    BindLoopVar(TransformLayoutPlanner* self, For for_node)
        : self_(self), var_(for_node->loop_var) {
      size_t loop_depth = self_->active_loops_.size();
      self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth};
      self_->active_loops_.push_back(std::move(for_node));
    }
    ~BindLoopVar() {
      self_->active_loops_.pop_back();
      self_->loop_depth_lookup_.erase(var_.get());
    }
    BindLoopVar(const BindLoopVar&) = delete;
    BindLoopVar& operator=(const BindLoopVar&) = delete;
    BindLoopVar(BindLoopVar&&) = delete;
    BindLoopVar& operator=(BindLoopVar&&) = delete;

    TransformLayoutPlanner* self_{nullptr};
    Var var_;
  };

  struct BindVariableDefinition {
    BindVariableDefinition() {}
    BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value)
        : self_(self), var_(var) {
      if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) {
        self_->loop_depth_lookup_[var_.get()] = loop_depth.value();
        self_->active_var_bindings_[var_.get()] = Substitute(value, self_->active_var_bindings_);
      }
    }
    ~BindVariableDefinition() {
      if (self_) {
        self_->loop_depth_lookup_.erase(var_.get());
        self_->active_var_bindings_.erase(var_.get());
      }
    }
    BindVariableDefinition(const BindVariableDefinition&) = delete;
    BindVariableDefinition& operator=(const BindVariableDefinition&) = delete;
    BindVariableDefinition(BindVariableDefinition&& other) : BindVariableDefinition() {
      swap(other);
    }
    BindVariableDefinition& operator=(BindVariableDefinition&& other) {
      swap(other);
      return *this;
    }
    void swap(BindVariableDefinition& other) {
      std::swap(self_, other.self_);
      std::swap(var_, other.var_);
    }

    TransformLayoutPlanner* self_{nullptr};
    Var var_;
  };

  struct BindBlockRealize {
    BindBlockRealize(TransformLayoutPlanner* self, BlockRealize block_realize) : self_(self) {
      ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size());
      for (size_t i = 0; i < block_realize->iter_values.size(); i++) {
        bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var,
                                 block_realize->iter_values[i]);
      }
      cache_ = std::move(block_realize);
      std::swap(self_->innermost_block_realize_, cache_);
    }
    ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); }
    BindBlockRealize(const BindBlockRealize&) = delete;
    BindBlockRealize& operator=(const BindBlockRealize&) = delete;
    BindBlockRealize(BindBlockRealize&&) = delete;
    BindBlockRealize& operator=(BindBlockRealize&&) = delete;

    TransformLayoutPlanner* self_{nullptr};
    Optional<BlockRealize> cache_;
    std::vector<BindVariableDefinition> bound_vars_;
  };

  /*! \brief Collected information about each BufferStore */
  std::vector<WriteInfo> write_info_;

  /*! \brief The loop iterators surrounding the current node
   *
   * The outermost loop iterator is `active_loops_.front()`, and the
   * innermost loop iterator is `active_loops_.back()`.
   *
   * Used to fill the `WriteInfo::dependent_loopnest` field.
   */
  std::vector<For> active_loops_;

  /*! \brief Lookup for the outer/inner loops
   *
   * Used to fill the `WriteInfo::dependent_loopnest` field.
   */
  std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_;

  /*! \brief The variable mappings that are currently in-scope
   *
   * Used to determine whether the indices of a BufferStore are a
   * row-major traversal, even if they are rebound in let/block
   * mappings.
   */
  std::unordered_map<const VarNode*, PrimExpr> active_var_bindings_;

  /*! \brief The innermost BlockRealize surrounding the current node
   *
   * Used to fill the `WriteInfo::innermost_block_realize` field.
   */
  Optional<BlockRealize> innermost_block_realize_{NullOpt};

  /*! \brief The buffer to be replaced */
  Buffer old_buffer_;
};

class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
 public:
  /*!
   * \brief Rewrite the access to the buffer after the transformation
   * \param scope_stmt The parent statement that contains all accesses to the target buffer
   * \param old_buffer The target buffer before transformation
   * \param new_buffer The new buffer after transformation
   * \param index_map The transformation applied to the buffer
   * \return The new AST rooted at the original parent scope, and the map from each old block to
   * the new block
   */
  static std::pair<Stmt, Map<Block, Block>> Rewrite(
      const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer,
      const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate,
      const Optional<IndexMap>& pad_value) {
    auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse,
                                             padding_predicate, pad_value);

    arith::Analyzer analyzer;
    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer);
    Block result = Downcast<Block>(rewriter(scope_stmt));
    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
      auto write_ptr = result.CopyOnWrite();
      write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
    }

    Map<Block, Block> block_sref_reuse;
    for (auto [after, before] : rewriter.new_block_to_old_) {
      while (auto opt = rewriter.new_block_to_old_.Get(before)) {
        before = opt.value();
      }
      while (auto opt = block_sref_reuse.Get(after)) {
        after = opt.value();
      }

      block_sref_reuse.Set(before, after);
    }

    return {result, block_sref_reuse};
  }

 private:
  TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
                          const IndexMap& index_map,
                          const TransformLayoutPlanner::TransformPlan& plan,
                          arith::Analyzer* analyzer)
      : IRMutatorWithAnalyzer(analyzer),
        old_buffer_(old_buffer),
        new_buffer_(new_buffer),
        index_map_(index_map),
        plan_(plan),
        buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {
    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ReplacementPlan>(&plan_)) {
      new_block_to_old_ = plan_ptr->new_block_to_old;
    }
  }

  void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
    *buffer = new_buffer_;
    *indices = index_map_->MapIndices(*indices);
    (*indices).MutateByApply(
        [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); });
  }

  using Parent = arith::IRMutatorWithAnalyzer;
  using Parent::VisitExpr_;
  using Parent::VisitStmt_;

  Stmt VisitStmt(const Stmt& stmt) final {
    Stmt output = Parent::VisitStmt(stmt);
    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::EpiloguePlan>(&plan_)) {
      if (plan_ptr->insert_after.same_as(stmt)) {
        return SeqStmt({output, plan_ptr->new_block});
      }
    }
    return output;
  }

  Stmt VisitStmt_(const ForNode* op) final {
    // Some replacements may include the original statement, such as
    // replacing `loop` with `{loop, post_proc}`.  In this case, avoid
    // infinite recursion.

    For node = GetRef<For>(op);
    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ReplacementPlan>(&plan_)) {
      auto it = plan_ptr->replacements.find(node);
      if (it != plan_ptr->replacements.end()) {
        return VisitStmt((*it).second);
      }
    }
    return Parent::VisitStmt_(op);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    BufferLoad buffer_load = Downcast<BufferLoad>(Parent::VisitExpr_(op));
    if (buffer_load->buffer.same_as(old_buffer_)) {
      auto* n = buffer_load.CopyOnWrite();
      RewriteBufferAccess(&n->buffer, &n->indices);
    }
    return std::move(buffer_load);
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore buffer_store = Downcast<BufferStore>(Parent::VisitStmt_(op));
    if (buffer_store->buffer.same_as(old_buffer_)) {
      auto* n = buffer_store.CopyOnWrite();
      RewriteBufferAccess(&n->buffer, &n->indices);
    }
    return std::move(buffer_store);
  }

  void RewriteAccessRegion(Array<BufferRegion>* old_access_regions,
                           const Array<BufferRegion>& infered_access_regions) {
    auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) {
      if (buffer_region->buffer.same_as(old_buffer_)) {
        ICHECK(infered_access_regions.size() == 1);
        return infered_access_regions[0];
      }
      return buffer_region;
    };
    (*old_access_regions).MutateByApply(fmutate);
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    Block orig = [&]() {
      Block block = GetRef<Block>(op);
      while (true) {
        if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) {
          block = (*it).second;
        } else {
          break;
        }
      }
      return block;
    }();

    Block block = Downcast<Block>(Parent::VisitStmt_(op));

    auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
    auto* n = block.CopyOnWrite();
    RewriteAccessRegion(&n->reads, infered_access_regions[0]);
    RewriteAccessRegion(&n->writes, infered_access_regions[1]);
    n->alloc_buffers.MutateByApply([this](const Buffer& buffer) {
      if (buffer.same_as(old_buffer_)) {
        return new_buffer_;
      } else {
        return buffer;
      }
    });

    RecordReplacement(orig, block);
    return std::move(block);
  }

  void RecordReplacement(Block before, Block after) {
    if (before.same_as(after)) {
      return;
    }

    ICHECK(!new_block_to_old_.count(after));

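    // If `before` was itself produced by an earlier replacement,
    // follow the chain so that the map always records the original
    // block as the value.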
    while (true) {
      if (auto opt = new_block_to_old_.Get(before)) {
        before = opt.value();
      } else {
        break;
      }
    }

    new_block_to_old_.Set(after, before);
  }

  const Buffer& old_buffer_;
  const Buffer& new_buffer_;
  const IndexMap& index_map_;
  const TransformLayoutPlanner::TransformPlan& plan_;
  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Block, Block> new_block_to_old_;
};

class BufferIsSubregionError : public ScheduleError {
 public:
  explicit BufferIsSubregionError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {}

  String FastErrorString() const final {
    return "ScheduleError: The input buffer is defined in `match_buffer` of a block; it is "
           "expected to be a function parameter or allocated by a block";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "ScheduleError: The input buffer " << buffer_->name << " is defined in `match_buffer` of "
       << "a block; it is expected to be a function parameter or allocated by a block.";
    return os.str();
  }

  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
  IRModule mod() const final { return mod_; }

 private:
  IRModule mod_;
  Buffer buffer_;
};

class TransformationPaddingIndexMapError : public ScheduleError {
 public:
  TransformationPaddingIndexMapError(IRModule mod, IndexMap pad_value)
      : mod_(mod), pad_value_(pad_value) {}

  String FastErrorString() const final {
    std::ostringstream ss;
    ss << "ScheduleError: The IndexMap specifying pad_value has "
       << pad_value_->final_indices.size() << " outputs, but should only have one output";
    return ss.str();
  }

  String DetailRenderTemplate() const final {
    std::ostringstream ss;
    ss << "ScheduleError: Pad value is specified as " << pad_value_ << " which has "
       << pad_value_->final_indices.size() << " outputs, but should only have one output";
    return ss.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {}; }

 private:
  IRModule mod_;
  IndexMap pad_value_;
};

class TransformationPaddingTypeError : public ScheduleError {
 public:
  TransformationPaddingTypeError(IRModule mod, Buffer buffer, IndexMap pad_value)
      : mod_(mod), buffer_(buffer), pad_value_(pad_value) {
    ICHECK_EQ(pad_value_->final_indices.size(), 1);
    pad_value_dtype_ = pad_value_->final_indices[0].dtype();
  }

  String FastErrorString() const final {
    std::ostringstream ss;
    ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_dtype_;
    return ss.str();
  }

  String DetailRenderTemplate() const final {
    std::ostringstream ss;
    ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype
       << ", but the transformation fills padding with " << pad_value_ << ", which is of type "
       << pad_value_dtype_;
    return ss.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {}; }

 private:
  IRModule mod_;
  Buffer buffer_;
  IndexMap pad_value_;
  DataType pad_value_dtype_;
};

class TransformationPaddingExpressionError : public ScheduleError {
 public:
  static void Check(IRModule mod, Buffer buffer, IndexMap pad_value) {
    Visitor visitor(buffer);
    ICHECK_EQ(pad_value->final_indices.size(), 1)
        << "Internal error: Should be caught by ScheduleError checks prior to this point";
    visitor(pad_value->final_indices[0]);
    if (visitor.illegal_load) {
      throw TransformationPaddingExpressionError(mod, buffer, pad_value,
                                                 visitor.illegal_load.value());
    }
  }

 private:
  struct Visitor : ExprVisitor {
    explicit Visitor(const Buffer& buffer) : buffer_(buffer) {}

    void VisitExpr_(const BufferLoadNode* op) final {
      if (!op->buffer.same_as(buffer_)) {
        illegal_load = GetRef<BufferLoad>(op);
      }
      ExprVisitor::VisitExpr_(op);
    }

    const Buffer& buffer_;
    Optional<BufferLoad> illegal_load;
  };

  TransformationPaddingExpressionError(IRModule mod, Buffer buffer, IndexMap pad_value,
                                       BufferLoad illegal_load)
      : mod_(mod), buffer_(buffer), pad_value_(pad_value), illegal_load_(illegal_load) {}

  String FastErrorString() const final {
    std::ostringstream ss;
    ss << "ScheduleError: Pad value may not contain a load from " << illegal_load_->buffer->name;
    return ss.str();
  }

  String DetailRenderTemplate() const final {
    std::ostringstream ss;
    ss << "ScheduleError: Pad value may only contain BufferLoad from the transformed buffer "
       << buffer_->name << ", but pad_value " << pad_value_ << " contains expression "
       << illegal_load_;
    return ss.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {}; }

  IRModule mod_;
  Buffer buffer_;
  IndexMap pad_value_;
  BufferLoad illegal_load_;
};

class TransformationIntroducesPaddingError : public ScheduleError {
 public:
  TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map,
                                       PrimExpr padding_predicate)
      : mod_(std::move(mod)),
        buffer_(std::move(buffer)),
        index_map_(std::move(index_map)),
        padding_predicate_(std::move(padding_predicate)) {}

  String FastErrorString() const final {
    std::ostringstream ss;
    ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << ".";
    return ss.str();
  }

  String DetailRenderTemplate() const final {
    auto new_shape = index_map_->MapShape(buffer_->shape);
    std::ostringstream os;
    os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name
       << " of shape " << buffer_->shape << " would result in shape " << new_shape
       << ". However, this would introduce padding wherever " << padding_predicate_ << " is true.";
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {}; }

 private:
  IRModule mod_;
  Buffer buffer_;
  IndexMap index_map_;
  PrimExpr padding_predicate_;
};

// Make the dtypes of the indices in an IndexMap match the dtypes of the buffer shape, to avoid
// dtype-mismatch issues later.
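// Illustrative example (assumed dtypes): if index_map was written as
// lambda i: [i // 8, i % 8] with `i` declared as int32, but the buffer
// shape is [T.int64(64)], the returned IndexMap re-declares `i` as
// int64 and rewrites the final indices (and any inverse map) in terms
// of the new variable.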
IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>& args) {
  const auto& initial_indices_orig = index_map->initial_indices;
  ICHECK(args.size() == initial_indices_orig.size());

  Array<Var> initial_indices;
  Map<Var, PrimExpr> var_map;

  for (size_t i = 0; i < args.size(); ++i) {
    if (args[i]->dtype != initial_indices_orig[i].dtype()) {
      auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype);
      initial_indices.push_back(new_idx);
      var_map.Set(initial_indices_orig[i], new_idx);
    } else {
      initial_indices.push_back(initial_indices_orig[i]);
    }
  }

  if (!var_map.empty()) {
    auto final_indices = index_map->final_indices.Map([&](PrimExpr index) {
      return SubstituteWithDataTypeLegalization(index,
                                                [&](const Var& var) { return var_map.Get(var); });
    });
    Optional<IndexMap> opt_inverse_index_map =
        Downcast<Optional<IndexMap>>(index_map->inverse_index_map);
    if (opt_inverse_index_map.defined()) {
      opt_inverse_index_map = LegalizeIndexMapDType(opt_inverse_index_map.value(), final_indices);
    }
    return IndexMap(initial_indices, final_indices, opt_inverse_index_map);
  }
  return index_map;
}

void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                     BufferIndexType buffer_index_type, const IndexMap& index_map_orig,
                     const Optional<IndexMap>& pad_value) {
  // Step 1: Input handling and error checking
  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
  Buffer old_buffer =
      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);

  auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape);

  auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer);
  if (defining_site_sref.defined() && !is_alloc) {
    throw BufferIsSubregionError(self->mod, old_buffer);
  }
  if (pad_value) {
    if (pad_value.value()->final_indices.size() != 1) {
      throw TransformationPaddingIndexMapError(self->mod, pad_value.value());
    }
    if (pad_value.value()->final_indices[0]->dtype != old_buffer->dtype) {
      throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value());
    }

    TransformationPaddingExpressionError::Check(self->mod, old_buffer, pad_value.value());
  }

  StmtSRef scope_sref = defining_site_sref.defined()
                            ? defining_site_sref.value()
                            : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);

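  // Illustrative example (assumed shape): for a buffer of shape [14] and
  // index_map = lambda i: [i // 4, i % 4], NonSurjectiveInverse returns
  // the inverse (io, ii) -> [io * 4 + ii] together with the padding
  // predicate (io == 3 and ii >= 2), which marks destination elements
  // that correspond to no source element.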
  auto [inverse, padding_predicate] = [&]() {
    Array<Range> region;
    for (const auto& dim : old_buffer->shape) {
      region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
    }
    return index_map.NonSurjectiveInverse(region);
  }();

  bool has_padding = !is_zero(padding_predicate);
  if (has_padding && !pad_value.defined()) {
    throw TransformationIntroducesPaddingError(self->mod, old_buffer, index_map, padding_predicate);
  }

  // Step 2: Infer the shape of the new buffer
  Buffer new_buffer = old_buffer;
  new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape);

  // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block
  // alloc_buffers.
  auto [new_stmt, block_sref_reuse] =
      TransformLayoutRewriter::Rewrite(GetRef<Block>(scope_block), old_buffer, new_buffer,
                                       index_map, inverse, padding_predicate, pad_value);
  Block new_scope_block = Downcast<Block>(new_stmt);

  // Step 4: Rewrite buffer_map of the PrimFunc if necessary.
  if (!defining_site_sref.defined()) {
    GlobalVar g_var;
    GetRootPrimFunc(self->mod, scope_block, &g_var);
    IRModuleNode* new_mod = self->mod.CopyOnWrite();
    MapNode* new_map = new_mod->functions.CopyOnWrite();
    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
    MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
    for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) {
      if ((*it).second.same_as(old_buffer)) {
        (*it).second = new_buffer;
      }
    }
    new_map->at(g_var) = std::move(ref_new_func);
  }

  // Step 5: Replace the scope block with the new block
  self->Replace(scope_sref, new_scope_block, block_sref_reuse);
}

/*!
 * \brief Detect the block iter type associated with the expression
 *
 * This function collects the block iters in the expression and checks whether they all have the
 * same iter type.  If they do, the detected iter type is that shared iter type; otherwise, the
 * detected iter type is kOpaque.
 *
 * \param expr The expression
 * \param block_iter_type_map The mapping from block iter to iter type
 * \return The detected block iter type
 */
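// Illustrative example (assumed iter types): for expr = i * 4 + j where
// both i and j are kDataPar block iters, the result is kDataPar; if i
// were kDataPar and j kCommReduce, the result would be kOpaque.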
IterVarType DetectNewBlockIterType(
    const PrimExpr& expr,
    const std::unordered_map<const VarNode*, IterVarType>& block_iter_type_map) {
  IterVarType result{kOpaque};
  bool found = false;
  PostOrderVisit(expr, [&](const ObjectRef& obj) {
    if (const VarNode* var = obj.as<VarNode>()) {
      auto it = block_iter_type_map.find(var);
      if (it != block_iter_type_map.end()) {
        if (!found) {
          found = true;
          result = it->second;
        } else if (result != it->second) {
          result = kOpaque;
          return false;
        }
      }
    }
    return true;
  });
  return result;
}

class NotBijectiveAffineIndexMapError : public ScheduleError {
 public:
  NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map)
      : mod_(std::move(mod)), index_map_(std::move(index_map)) {}
  String FastErrorString() const final {
    return "ScheduleError: The index map is not bijective affine.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The index map " << index_map_->ToPythonString() << " is not bijective affine.";
    return os.str();
  }

  IRModule mod() const final { return mod_; }

  Array<ObjectRef> LocationsOfInterest() const final { return {}; }

 private:
  IRModule mod_;
  IndexMap index_map_;
};

class IndexMapNotApplicableToBlockIterError : public ScheduleError {
 public:
  static void Check(const IRModule mod, const Block& block, const IndexMap& index_map) {
    if (index_map->initial_indices.size() != block->iter_vars.size()) {
      throw IndexMapNotApplicableToBlockIterError(mod, block, index_map);
    }
  }
  explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map)
      : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {}

  String FastErrorString() const final {
    return "ScheduleError: The index map can't be applied to block iters because the number of "
           "parameters does not match.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "The index map " << index_map_->ToPythonString()
       << " can't be applied to block iters of {0} because the number of parameters does not "
          "match. Expected: "
       << index_map_->initial_indices.size() << ", actual: " << block_->iter_vars.size();
    return os.str();
  }

  IRModule mod() const final { return mod_; }

  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

 private:
  IRModule mod_;
  Block block_;
  IndexMap index_map_;
};

class OpaqueNewIterTypeError : public ScheduleError {
 public:
  explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value)
      : mod_(std::move(mod)), block_(std::move(block)), iter_value_(std::move(iter_value)) {}

  String FastErrorString() const final {
    return "ScheduleError: Cannot detect the new block iter type because it contains more than one "
           "type of original iter vars.";
  }

  String DetailRenderTemplate() const final {
    std::ostringstream os;
    os << "Cannot detect the block iter type for new iter value " << iter_value_
       << " in {0} because it contains more than one type of original iter vars.";
    return os.str();
  }

  IRModule mod() const final { return mod_; }
  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

 private:
  IRModule mod_;
  Block block_;
  PrimExpr iter_value_;
};

void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
                          const IndexMap& index_map) {
  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
  const Block& block = GetRef<Block>(block_ptr);
  arith::Analyzer analyzer;

  // Step 1: Collect outer loops and loop vars
  Array<StmtSRef> loops = GetLoops(block_sref);  // outer loops of the block
  std::unordered_set<const VarNode*> loop_vars;  // loop vars of the outer loops
  for (const StmtSRef& loop_sref : loops) {
    CheckLoopStartsWithZero(self, loop_sref, &analyzer);
    loop_vars.emplace(loop_sref->StmtAs<ForNode>()->loop_var.get());
  }

  // Step 2: Check that all outer loops have a single child and that the block bindings are
  // trivial (all binding values are loop vars)
  StmtSRef scope_sref{nullptr};  // the scope statement for replacement
  if (!loops.empty()) {
    scope_sref = loops.front();
    CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front());
  } else {
    scope_sref = block_sref;
  }

  BlockRealize block_realize = GetBlockRealize(self, block_sref);
  CheckBlockHasTrivialBinding(self, block_sref);

  // Step 3: Collect information of block iter vars
  Array<PrimExpr> block_vars;      // iter_var->var of each block iter
  Map<Var, Range> block_iter_dom;  // domain of block iter
  std::unordered_map<const VarNode*, IterVarType> block_iter_type;  // iter type of block iter

  Array<PrimExpr>
      block_iter_range_array;  // array of block iter extents in the same order as block iters
  for (const auto& iter_var : block->iter_vars) {
    block_vars.push_back(iter_var->var);
    block_iter_dom.Set(iter_var->var, iter_var->dom);
    block_iter_type[iter_var->var.get()] = iter_var->iter_type;
    ICHECK(is_zero(iter_var->dom->min));
    block_iter_range_array.push_back(iter_var->dom->extent);
  }

  // Step 4: Apply the IndexMap to block iters.
  IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map);
  Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
  Array<PrimExpr> new_block_iter_range = index_map->MapShape(block_iter_range_array);
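  // For example (illustrative, assumed extents): index_map =
  // lambda i, j: [i * 4 + j] applied to block iters (i, j) with extents
  // (3, 4) gives transformed_block_iters = [i * 4 + j] and
  // new_block_iter_range = [12].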

  // Step 5: Create the new block after transformation.

  // Step 5.1: Create new block iters.  After applying the IndexMap f to block iters
  // ax_0, ..., ax_n, create a block iter for each expression in f(ax_0, ..., ax_n).
  Array<IterVar> new_block_iters;  // new block iters
  Array<PrimExpr> new_block_vars;  // iter_var->var of new block iters
  for (size_t i = 0; i < transformed_block_iters.size(); ++i) {
    Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype};
    new_block_vars.push_back(new_block_var);
    IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
    if (iter_type == kOpaque) {
      throw OpaqueNewIterTypeError(self->mod, GetRef<Block>(block_ptr), transformed_block_iters[i]);
    }
    auto dtype = new_block_var.dtype();
    new_block_iters.push_back(IterVar(
        /*dom=*/Range::FromMinExtent(make_zero(dtype), new_block_iter_range[i]),
        /*var=*/std::move(new_block_var), /*iter_type=*/iter_type));
  }

  // Step 5.2: Update the block body.  Use the inverse map f^{-1} to replace the original block
  // iters in the body.
  Map<Var, PrimExpr> inverse_subst_map;
  // Construct the inverse map
  {
    Array<Range> initial_ranges;
    for (const PrimExpr& extent : block_iter_range_array) {
      initial_ranges.push_back(Range::FromMinExtent(make_const(extent.dtype(), 0), extent));
    }
    IndexMap inverse_index_map{nullptr};
    try {
      inverse_index_map = index_map.Inverse(initial_ranges);
    } catch (...) {
      throw NotBijectiveAffineIndexMapError(self->mod, index_map);
    }

    Array<PrimExpr> inversed_new_block_vars = inverse_index_map->MapIndices(
        new_block_vars);  // old block vars written in terms of new block vars

    for (int i = 0, n = block_vars.size(); i < n; ++i) {
      inverse_subst_map.Set(Downcast<Var>(block_vars[i]), inversed_new_block_vars[i]);
    }
  }
  Block new_block = Downcast<Block>(Substitute(GetRef<Block>(block_ptr), inverse_subst_map));
  new_block.CopyOnWrite()->iter_vars = new_block_iters;
  new_block = Downcast<Block>(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer));

  // Step 5.3: Create outer loops for each new block iter.

  // Make new loop vars
  Array<PrimExpr> new_loop_vars;
  for (int i = 0; i < static_cast<int>(new_block_iters.size()); ++i) {
    new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_iters[i]->var.dtype()));
  }

  // Make new block realize
  BlockRealizeNode* new_block_realize = block_realize.CopyOnWrite();
  new_block_realize->iter_values = new_loop_vars;
  new_block_realize->block = new_block;

  // Generate outer loops
  Stmt body = GetRef<Stmt>(new_block_realize);
  for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; --i) {
    body = For(Downcast<Var>(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial,
               std::move(body));
  }

  // Step 6: Do the actual replacement
  self->Replace(scope_sref, body, {{block, new_block}});
}

class BufferAxisSeparatorMutator : private ReplaceBufferMutator {
 public:
  static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer,
                      Map<Block, Block>* block_sref_reuse) {
    BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse);
    return Downcast<Block>(mutator.VisitStmt(scope_block));
  }

 private:
  BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer,
                             Map<Block, Block>* block_sref_reuse)
      : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {}

  MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
    auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
    if (it != buffer_var_map_.end()) {
      const Buffer& new_source_buffer = it->second;
      Buffer new_target_buffer = match_buffer->buffer;
      new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators;
      if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) {
        LOG(WARNING)
            << "Target buffer in match_buffer doesn't have the same dimensionality as its source "
               "buffer. `axis_separators` for the target buffer might be incorrect.";
      }
      buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer;
      return MatchBufferRegion(new_target_buffer,
                               BufferRegion(new_source_buffer, match_buffer->source->region));
    }
    return match_buffer;
  }
};

void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
                      BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) {
  const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
  Buffer old_buffer =
      GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
  auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer);
  if (defining_site_sref.defined() && !is_alloc) {
    throw BufferIsSubregionError(self->mod, old_buffer);
  }

  StmtSRef scope_sref = defining_site_sref.defined()
                            ? defining_site_sref.value()
                            : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);
  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref);

  // Step 1: Check and update axis_separators of the buffer.
  Buffer new_buffer = old_buffer;
  new_buffer.CopyOnWrite()->axis_separators = axis_separators;

  Map<Block, Block> block_sref_reuse;

  // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc.
  Block new_scope_block = BufferAxisSeparatorMutator::Mutate(GetRef<Block>(scope_block), old_buffer,
                                                             new_buffer, &block_sref_reuse);
  if (!defining_site_sref.defined()) {
    // mutate buffer_map of the PrimFunc
    GlobalVar g_var;
    GetRootPrimFunc(self->mod, scope_block, &g_var);
    IRModuleNode* new_mod = self->mod.CopyOnWrite();
    MapNode* new_map = new_mod->functions.CopyOnWrite();
    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
    MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
    for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) {
      if ((*it).second.same_as(old_buffer)) {
        (*it).second = new_buffer;
      }
    }
    new_map->at(g_var) = std::move(ref_new_func);
  }

  // Step 3: Replace the scope block with the new block
  self->Replace(scope_sref, new_scope_block, block_sref_reuse);
}

/******** InstructionKind Registration ********/

struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits> {
  static constexpr const char* kName = "TransformLayout";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 2;
  static constexpr size_t kNumAttrs = 3;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map,
                                      Integer buffer_index, Integer buffer_index_type,
                                      Optional<IndexMap> pad_value) {
    return sch->TransformLayout(block_rv, buffer_index.IntValue(),
                                static_cast<BufferIndexType>(buffer_index_type->value), index_map,
                                pad_value);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map,
                                 Integer buffer_index, Integer buffer_index_type,
                                 Optional<IndexMap> pad_value) {
    PythonAPICall py("transform_layout");
    py.Input("block", block_rv);

    std::ostringstream os;
    os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
       << "\", " << buffer_index << ")";
    py.Input("buffer", os.str());
    py.Input("index_map", index_map->ToPythonString());
    py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None");

    return py.Str();
  }

 public:
  static ObjectRef AttrsAsJSON(const Array<ObjectRef>& attrs) {
    Array<ObjectRef> attrs_record;
    attrs_record.reserve(kNumAttrs);
    attrs_record.push_back(attrs[0]);
    attrs_record.push_back(attrs[1]);
    if (attrs[2].defined()) {
      attrs_record.push_back(String(::tvm::SaveJSON(attrs[2])));
    } else {
      attrs_record.push_back(attrs[2]);
    }
    return std::move(attrs_record);
  }

  static Array<ObjectRef> AttrsFromJSON(const ObjectRef& attrs_record_) {
    Array<ObjectRef> attrs_record = Downcast<Array<ObjectRef>>(attrs_record_);
    Array<ObjectRef> attrs;
    attrs.push_back(attrs_record[0]);
    attrs.push_back(attrs_record[1]);
    if (attrs_record[2].defined()) {
      attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[2])));
    } else {
      attrs.push_back(attrs_record[2]);
    }
    return attrs;
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct TransformBlockLayoutTraits : public UnpackedInstTraits<TransformBlockLayoutTraits> {
  static constexpr const char* kName = "TransformBlockLayout";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 1;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map) {
    return sch->TransformBlockLayout(block_rv, index_map);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map) {
    PythonAPICall py("transform_block_layout");
    py.Input("block", block_rv);
    py.Input("index_map", index_map->ToPythonString());
    return py.Str();
  }

 public:
  static ObjectRef AttrsAsJSON(const Array<ObjectRef>& attrs) {
    Array<ObjectRef> attrs_record;
    attrs_record.reserve(kNumAttrs);
    attrs_record.push_back(String(::tvm::SaveJSON(attrs[0])));
    return std::move(attrs_record);
  }

  static Array<ObjectRef> AttrsFromJSON(const ObjectRef& attrs_record_) {
    Array<ObjectRef> attrs_record = Downcast<Array<ObjectRef>>(attrs_record_);
    Array<ObjectRef> attrs;
    attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[0])));
    return attrs;
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

struct SetAxisSeparatorTraits : public UnpackedInstTraits<SetAxisSeparatorTraits> {
  static constexpr const char* kName = "SetAxisSeparator";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 3;
  static constexpr size_t kNumDecisions = 0;

  static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
                                      Integer buffer_index_type, Array<IntImm> axis_separators) {
    return sch->SetAxisSeparator(block_rv, buffer_index.IntValue(),
                                 static_cast<BufferIndexType>(buffer_index_type->value),
                                 axis_separators);
  }

  static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
                                 Integer buffer_index_type, Array<IntImm> axis_separators) {
    PythonAPICall py("set_axis_separator");
    py.Input("block", block_rv);

    std::ostringstream os;
    os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
       << "\", " << buffer_index << ")";
    py.Input("buffer", os.str());

    py.Input("axis_separators", axis_separators);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits);
TVM_REGISTER_INST_KIND_TRAITS(TransformBlockLayoutTraits);
TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits);

}  // namespace tir
}  // namespace tvm