1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | #include <tvm/node/node.h> |
21 | |
22 | #include <optional> |
23 | #include <variant> |
24 | |
25 | #include "../../../arith/ir_mutator_with_analyzer.h" |
26 | #include "../utils.h" |
27 | |
28 | namespace tvm { |
29 | namespace tir { |
30 | |
31 | /*! \brief Planning stage prior to rewriting in TransformLayoutRewriter |
32 | * |
33 | * There are four ways that transformation may be handled. Each |
 * updates the buffer shape and the indices used to access the buffer
 * in BufferStore/BufferLoad nodes, but they differ in how they handle the
36 | * `pad_value`. In order of preference, the different strategies are |
37 | * as follows: |
38 | * |
39 | * 1. NoPaddingRequired. The transformation does not introduce |
40 | * padding, so only local changes to update the indices of |
41 | * BufferLoad/BufferStore nodes are required. No blocks are added, |
42 | * removed, or replaced. |
43 | * |
44 | * 2. ProloguePlan. The transformation introduces padding, but the |
45 | * analyzed block has no write stages for the transformed buffer. |
46 | * This buffer is an input and the caller is responsible for ensuring |
47 | * that the padding contains the specified `pad_value`. The generated |
48 | * prologue contains `builtin::assume()` calls that will expose this |
49 | * known value during scheduling/simplification, but will be removed |
50 | * during lowering. |
51 | * |
52 | * 3. ReplacementPlan. The transformation introduces padding, has at |
53 | * least one write stage for the transformed buffer, and at least one |
54 | * of those write stages writes to all pre-transformation indices |
 * following a row-major traversal. That write stage is rewritten to
 * be a row-major traversal of the post-transformation indices, with a
57 | * `tir::if_then_else` call to write either the specified `pad_value` |
58 | * into padding or the computed value into non-padding. |
59 | * |
60 | * 4. EpiloguePlan. The transformation introduces padding, has at |
61 | * least one write stage for the transformed buffer, but no write |
62 | * stage can be rewritten to use `tir::if_then_else`. The |
63 | * transformation still requires the `pad_value` to be written into |
64 | * the padding, so a new block is inserted after the last write stage |
65 | * to explicitly fill the padding. |
66 | * |
67 | */ |
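/* A rough illustration (all names and shapes below are made up for
 * exposition, not taken from any particular workload): suppose a buffer
 * `A` of shape [14] is transformed with `lambda i: [i // 4, i % 4]`.  The
 * post-transformation shape is [4, 4], and the padding predicate returned
 * by the non-surjective inverse is roughly `4*i0 + i1 >= 14`.
 *
 * - If the transformation introduced no padding (e.g. a shape of [16]),
 *   NoPaddingRequired applies and only the indices are rewritten.
 * - If the analyzed block never writes `A`, a ProloguePlan inserts
 *   assumptions of the form
 *   `assume(not (4*i0 + i1 >= 14) or A[i0, i1] == pad_value)`.
 * - If some write stage is a row-major traversal of `A`, a ReplacementPlan
 *   rewrites it to loop over the [4, 4] shape and store
 *   `if_then_else(4*i0 + i1 >= 14, pad_value, <original value>)`.
 * - Otherwise, an EpiloguePlan appends a block after the last write that
 *   stores `pad_value` wherever the padding predicate holds.
 */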
68 | class TransformLayoutPlanner : private StmtExprVisitor { |
69 | public: |
70 | // Statement to be inserted prior to the analyzed block |
71 | struct ProloguePlan { |
72 | Stmt prologue; |
73 | }; |
74 | |
75 | // Loops within the analyzed block that should be replaced |
76 | struct ReplacementPlan { |
77 | Map<For, Stmt> replacements; |
78 | Map<Block, Block> new_block_to_old; |
79 | }; |
80 | |
81 | // The block to be inserted, along with the location at which it |
82 | // should be inserted. The location will be either a For or a |
  // Block, and will be after all writes to the transformed buffer.
84 | struct EpiloguePlan { |
85 | Stmt insert_after; |
86 | Stmt new_block; |
87 | }; |
88 | |
89 | struct NoPaddingRequired {}; |
90 | |
91 | using TransformPlan = |
92 | std::variant<ProloguePlan, ReplacementPlan, EpiloguePlan, NoPaddingRequired>; |
93 | |
94 | static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, |
95 | IndexMap inverse, PrimExpr padding_predicate, |
96 | Optional<IndexMap> pad_value) { |
97 | ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) |
        << "Internal error: Should be caught by ScheduleError checks prior to this point";
99 | TransformLayoutPlanner visitor(old_buffer); |
100 | visitor(block); |
101 | return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value); |
102 | } |
103 | |
104 | private: |
105 | struct WriteInfo { |
106 | // The BufferStore object |
107 | BufferStore store; |
108 | |
109 | // The block realize that contains the store, if any. |
110 | Optional<BlockRealize> innermost_block_realize; |
111 | |
112 | // The nested loops whose values contribute to the indices used in |
113 | // the store. Not all loop variables in the loopnest need to |
114 | // contribute, but the first and last must. |
115 | std::vector<For> dependent_loopnest; |
116 | |
117 | // Whether the padding could be represented as a tir::if_then_else |
    // node. This requires that the surrounding loop iterators
    // iterate over all pre-transformation buffer axes, that there are
    // no data dependencies between loop iterations, and that each loop
    // is a serial loop starting at zero whose extent matches the
    // corresponding pre-transformation buffer dimension.
121 | bool contains_row_major_traversal{false}; |
122 | }; |
123 | |
124 | explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} |
125 | |
126 | void VisitStmt_(const ForNode* op) override { |
127 | BindLoopVar context(this, GetRef<For>(op)); |
128 | StmtExprVisitor::VisitStmt_(op); |
129 | } |
130 | |
131 | void VisitStmt_(const LetStmtNode* op) override { |
132 | BindVariableDefinition context(this, op->var, op->value); |
133 | StmtExprVisitor::VisitStmt_(op); |
134 | } |
135 | |
136 | void VisitStmt_(const BlockRealizeNode* op) override { |
137 | BindBlockRealize context(this, GetRef<BlockRealize>(op)); |
138 | StmtExprVisitor::VisitStmt_(op); |
139 | } |
140 | |
141 | void VisitStmt_(const BufferStoreNode* op) override { |
142 | if (!op->buffer.same_as(old_buffer_)) { |
143 | return; |
144 | } |
145 | |
146 | std::optional<std::pair<size_t, size_t>> loop_dependency_range = std::nullopt; |
147 | for (const auto& index : op->indices) { |
148 | if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) { |
149 | if (loop_dependency_range) { |
150 | loop_dependency_range = { |
151 | std::min(loop_dependency_range.value().first, index_depth.value().first), |
152 | std::max(loop_dependency_range.value().second, index_depth.value().second)}; |
153 | } else { |
154 | loop_dependency_range = index_depth; |
155 | } |
156 | } |
157 | } |
158 | |
159 | WriteInfo write_info; |
160 | write_info.store = GetRef<BufferStore>(op); |
161 | if (loop_dependency_range) { |
162 | size_t i = loop_dependency_range.value().first; |
163 | size_t j = loop_dependency_range.value().second; |
164 | ICHECK_LT(i, active_loops_.size()); |
165 | ICHECK_LT(j, active_loops_.size()); |
166 | |
167 | write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1}; |
168 | } |
169 | write_info.innermost_block_realize = innermost_block_realize_; |
170 | |
171 | write_info.contains_row_major_traversal = [&]() -> bool { |
172 | const auto& loopnest = write_info.dependent_loopnest; |
173 | if (loopnest.empty()) { |
174 | return false; |
175 | } |
176 | |
177 | if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) { |
178 | return false; |
179 | } |
180 | |
181 | for (size_t i = 0; i < loopnest.size(); i++) { |
182 | const For& loop = loopnest[i]; |
183 | const PrimExpr& buffer_dim = old_buffer_->shape[i]; |
184 | PrimExpr index = Substitute(op->indices[i], active_var_bindings_); |
185 | bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) && |
186 | ExprDeepEqual()(loop->extent, buffer_dim) && |
187 | loop->kind == ForKind::kSerial; |
188 | if (!is_loop_over_axis) { |
189 | return false; |
190 | } |
191 | } |
192 | |
193 | return true; |
194 | }(); |
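    // Illustrative example (not from any particular schedule): for a buffer
    // A of shape [14], the loop nest `for i in range(14): A[i] = f(i)`
    // satisfies the checks above (serial loop, min 0, extent equal to the
    // buffer dimension, index equal to the loop var), so the store counts
    // as a row-major traversal.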
195 | |
196 | write_info_.push_back(write_info); |
197 | |
198 | // Don't need to continue recursing, as the entire goal was to |
199 | // find the BufferStore. |
200 | } |
201 | |
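  // Returns the range of loop depths that the expression depends on, if any.
  // For example (illustrative only): with active loops [i, j, k] at depths
  // 0, 1, and 2, the expression `i*16 + k` yields the pair (0, 2), and the
  // corresponding dependent loopnest spans all three loops.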
202 | std::optional<std::pair<size_t, size_t>> LoopDependencyRange(const PrimExpr& expr) const { |
203 | std::optional<std::pair<size_t, size_t>> prev = std::nullopt; |
204 | for (const auto& var : UndefinedVars(expr)) { |
205 | auto it = loop_depth_lookup_.find(var.get()); |
206 | if (it != loop_depth_lookup_.end()) { |
207 | if (prev.has_value()) { |
208 | prev = {std::min(prev.value().first, it->second.first), |
209 | std::max(prev.value().second, it->second.second)}; |
210 | } else { |
211 | prev = it->second; |
212 | } |
213 | } |
214 | } |
215 | |
216 | return prev; |
217 | } |
218 | |
219 | class BufferStoreReplacer : public StmtExprMutator { |
220 | public: |
221 | BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer, PrimExpr padding_predicate, |
222 | const IndexMap& inverse, const Optional<IndexMap>& pad_value, |
223 | Map<Block, Block>* new_block_to_old) |
224 | : info(info), |
225 | new_buffer(new_buffer), |
226 | new_indices(inverse->initial_indices.Map([](const Var& var) -> PrimExpr { return var; })), |
227 | padding_predicate(padding_predicate), |
228 | inverse(inverse), |
229 | pad_value(pad_value), |
230 | new_block_to_old(*new_block_to_old) { |
231 | ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size()); |
232 | for (size_t i = 0; i < info.dependent_loopnest.size(); i++) { |
233 | Var var = info.dependent_loopnest[i]->loop_var; |
234 | PrimExpr expr = inverse->final_indices[i]; |
235 | var_remap.Set(var, expr); |
236 | } |
237 | |
238 | DefineBlockUpdates(); |
239 | } |
240 | |
241 | bool is_all_stores_replaced() const { return all_stores_replaced; } |
242 | |
243 | private: |
244 | void DefineBlockUpdates() { |
245 | if (!info.innermost_block_realize) { |
246 | return; |
247 | } |
248 | |
249 | BlockRealize block_realize = info.innermost_block_realize.value(); |
250 | const auto& block = block_realize->block; |
251 | const Array<PrimExpr>& old_indices = info.store->indices; |
252 | const auto& old_iter_vars = block->iter_vars; |
253 | |
254 | this->new_iter_vars = old_iter_vars; |
255 | this->new_iter_values = block_realize->iter_values; |
256 | |
257 | if (old_indices.empty()) { |
258 | return; |
259 | } |
260 | |
261 | // Find the block iterators that are used to access the buffer. Must be in the same |
262 | // order as they appear in the indices. |
263 | if (block->iter_vars.size() < old_indices.size()) { |
264 | return; |
265 | } |
266 | |
267 | size_t block_index_start = 0; |
268 | for (; block_index_start < old_iter_vars.size() - old_indices.size(); block_index_start++) { |
269 | if (old_indices[0].same_as(old_iter_vars[block_index_start]->var)) { |
270 | break; |
271 | } |
272 | } |
273 | if (block_index_start > old_iter_vars.size() - old_indices.size()) { |
274 | return; |
275 | } |
276 | |
277 | for (size_t i = 0; i < old_indices.size(); i++) { |
278 | if (!old_indices[i].same_as(old_iter_vars[block_index_start + i]->var) || |
279 | old_iter_vars[block_index_start + i]->iter_type != kDataPar) { |
280 | return; |
281 | } |
282 | } |
283 | |
284 | // If we got to this point, all indices used to access the |
285 | // buffer are virtual indices defined in the innermost block. |
286 | // Therefore, generate new virtual indices for iterating over |
287 | // the post-transform buffer. |
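      // For example (illustrative only): if the old block defined iters
      // (vi, vj) over the pre-transformation shape and the store indices
      // were [vi, vj], the rewritten block instead defines iters such as
      // (v_i0, v_i1) over the post-transformation shape, and vi/vj are
      // re-expressed via the inverse map in terms of those new iters.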
288 | |
289 | new_indices = inverse->initial_indices.Map([](Var var) -> PrimExpr { |
290 | std::stringstream ss; |
291 | ss << "v_" << var->name_hint; |
292 | return Var(ss.str(), var.dtype()); |
293 | }); |
294 | |
295 | Map<Var, PrimExpr> |
296 | loop_var_to_virtual_var; // For updating padding_predicate in terms of the new indices |
297 | Array<PrimExpr> new_iter_values; // For BlockRealize |
298 | Array<IterVar> new_iter_vars; // For Block |
299 | |
300 | for (size_t i = 0; i < block_index_start; i++) { |
301 | new_iter_vars.push_back(old_iter_vars[i]); |
302 | new_iter_values.push_back(block_realize->iter_values[i]); |
303 | } |
304 | |
305 | ICHECK_EQ(new_indices.size(), new_buffer->shape.size()); |
306 | for (size_t i = 0; i < new_indices.size(); i++) { |
307 | Var var = inverse->initial_indices[i]; |
308 | Var virtual_var = Downcast<Var>(new_indices[i]); |
309 | PrimExpr dim = new_buffer->shape[i]; |
310 | new_iter_values.push_back(var); |
311 | new_iter_vars.push_back( |
312 | IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar)); |
313 | loop_var_to_virtual_var.Set(var, virtual_var); |
314 | } |
315 | |
316 | for (size_t i = block_index_start + old_indices.size(); i < old_iter_vars.size(); i++) { |
317 | new_iter_vars.push_back(old_iter_vars[i]); |
318 | new_iter_values.push_back(block_realize->iter_values[i]); |
319 | } |
320 | |
321 | ICHECK_EQ(inverse->final_indices.size(), old_indices.size()); |
322 | for (size_t i = 0; i < old_indices.size(); i++) { |
323 | Var var = Downcast<Var>(old_indices[i]); |
324 | PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var); |
325 | var_remap.Set(var, expr); |
326 | } |
327 | |
328 | padding_predicate = Substitute(padding_predicate, loop_var_to_virtual_var); |
329 | |
330 | this->new_iter_vars = new_iter_vars; |
331 | this->new_iter_values = new_iter_values; |
332 | } |
333 | |
334 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
335 | bool can_replace = [&]() -> bool { |
336 | if (!op->buffer.same_as(info.store->buffer)) { |
337 | return false; |
338 | } |
339 | |
340 | const Array<PrimExpr>& old_indices = info.store->indices; |
341 | |
342 | ICHECK_EQ(old_indices.size(), op->indices.size()); |
343 | ExprDeepEqual expr_equal; |
344 | for (size_t i = 0; i < old_indices.size(); i++) { |
345 | if (!expr_equal(old_indices[i], op->indices[i])) { |
346 | return false; |
347 | } |
348 | } |
349 | return true; |
350 | }(); |
351 | |
352 | BufferStore store = GetRef<BufferStore>(op); |
353 | if (can_replace) { |
354 | PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0]; |
355 | store = |
356 | BufferStore(new_buffer, if_then_else(padding_predicate, pad_value_at_index, op->value), |
357 | new_indices); |
358 | } else { |
359 | all_stores_replaced = false; |
360 | } |
361 | return StmtExprMutator::VisitStmt_(store.get()); |
362 | } |
363 | |
364 | Stmt VisitStmt_(const BlockRealizeNode* op) final { |
365 | BlockRealize realize = Downcast<BlockRealize>(StmtExprMutator::VisitStmt_(op)); |
366 | |
367 | if (op == info.innermost_block_realize.get()) { |
368 | Block block = realize->block; |
369 | if (!block->iter_vars.same_as(this->new_iter_vars)) { |
370 | block.CopyOnWrite()->iter_vars = this->new_iter_vars; |
371 | RecordReplacement(op->block, block); |
372 | } |
373 | |
374 | if (!block.same_as(realize->block) || |
375 | !realize->iter_values.same_as(this->new_iter_values)) { |
376 | auto write_ptr = realize.CopyOnWrite(); |
377 | write_ptr->block = block; |
378 | write_ptr->iter_values = this->new_iter_values; |
379 | } |
380 | } |
381 | |
382 | return std::move(realize); |
383 | } |
384 | |
385 | Stmt VisitStmt_(const BlockNode* op) final { |
386 | Block orig = GetRef<Block>(op); |
387 | Block mutated = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); |
388 | |
389 | RecordReplacement(orig, mutated); |
390 | return std::move(mutated); |
391 | } |
392 | |
393 | PrimExpr VisitExpr_(const VarNode* op) final { |
394 | Var var = GetRef<Var>(op); |
395 | if (auto opt = var_remap.Get(var)) { |
396 | return opt.value(); |
397 | } else { |
398 | return std::move(var); |
399 | } |
400 | } |
401 | |
402 | void RecordReplacement(Block before, Block after) { |
403 | if (before.same_as(after)) { |
404 | return; |
405 | } |
406 | |
407 | ICHECK(!new_block_to_old.count(after)); |
408 | |
409 | while (true) { |
410 | if (auto opt = new_block_to_old.Get(before)) { |
411 | before = opt.value(); |
412 | } else { |
413 | break; |
414 | } |
415 | } |
416 | |
417 | new_block_to_old.Set(after, before); |
418 | } |
419 | |
420 | const WriteInfo& info; |
421 | const Buffer& new_buffer; |
422 | Array<PrimExpr> new_indices; |
423 | Array<IterVar> new_iter_vars; |
424 | Array<PrimExpr> new_iter_values; |
425 | PrimExpr padding_predicate; |
426 | const IndexMap& inverse; |
427 | const Optional<IndexMap>& pad_value; |
428 | Map<Block, Block>& new_block_to_old; |
429 | bool all_stores_replaced{true}; |
430 | |
431 | Map<Var, PrimExpr> var_remap; |
432 | }; |
433 | |
434 | TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, |
435 | PrimExpr padding_predicate, Optional<IndexMap> pad_value) const { |
436 | if (auto prologue_plan = |
437 | FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); |
438 | prologue_plan.has_value()) { |
439 | return prologue_plan.value(); |
440 | } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse, |
441 | padding_predicate, pad_value); |
442 | replacement_plan.has_value()) { |
443 | return replacement_plan.value(); |
444 | } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse, |
445 | padding_predicate, pad_value); |
446 | epilogue_plan.has_value()) { |
447 | return epilogue_plan.value(); |
448 | } else { |
449 | return NoPaddingRequired(); |
450 | } |
451 | } |
452 | |
453 | std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, |
454 | IndexMap inverse, PrimExpr padding_predicate, |
455 | Optional<IndexMap> pad_value) const { |
456 | if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { |
457 | return std::nullopt; |
458 | } |
459 | |
460 | Array<IterVar> iter_vars; |
461 | Array<PrimExpr> iter_values; |
462 | Array<PrimExpr> indices; |
463 | Map<Var, PrimExpr> loop_indices_to_block_indices; |
464 | ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); |
465 | for (size_t i = 0; i < inverse->initial_indices.size(); i++) { |
466 | const auto& loop_var = inverse->initial_indices[i]; |
467 | const auto& dim = new_buffer->shape[i]; |
468 | Var block_var("v_" + loop_var->name_hint, loop_var->dtype); |
469 | IterVar iter_var(Range(0, dim), block_var, kDataPar); |
470 | loop_indices_to_block_indices.Set(loop_var, block_var); |
471 | indices.push_back(iter_var->var); |
472 | iter_vars.push_back(iter_var); |
473 | iter_values.push_back(loop_var); |
474 | } |
475 | padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); |
476 | |
477 | PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0]; |
478 | PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value_at_index); |
479 | Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr})); |
480 | |
481 | std::stringstream block_name; |
    block_name << "buffer_" << new_buffer->name << "_assumptions";
483 | auto read_region = BufferRegion::FromPoint(new_buffer, indices); |
484 | stmt = BlockRealize(iter_values, Bool(true), |
485 | Block(iter_vars, {read_region}, {}, block_name.str(), stmt)); |
486 | |
487 | for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { |
488 | size_t i = (inverse->initial_indices.size() - 1) - rev_i; |
489 | Var loop_var = inverse->initial_indices[i]; |
490 | PrimExpr extent = new_buffer->shape[i]; |
491 | stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); |
492 | } |
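    // At this point `stmt` is roughly the following (sketch only; names and
    // shapes are illustrative, not taken from the surrounding code):
    //
    //   for i0, i1 in grid(new_shape):
    //     with block("buffer_A_assumptions"):
    //       v0, v1 = remapped data-parallel block iters
    //       assume(not padding_predicate(v0, v1) or
    //              A_new[v0, v1] == pad_value(v0, v1))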
493 | return ProloguePlan{stmt}; |
494 | } |
495 | |
496 | std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map, |
497 | IndexMap inverse, |
498 | PrimExpr padding_predicate, |
499 | Optional<IndexMap> pad_value) const { |
500 | if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { |
501 | return std::nullopt; |
502 | } |
503 | |
504 | Map<Block, Block> new_block_to_old; |
505 | auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional<Stmt> { |
506 | if (!info.contains_row_major_traversal || !pad_value.defined() || |
507 | is_zero(padding_predicate)) { |
508 | return NullOpt; |
509 | } |
510 | |
511 | BufferStoreReplacer replacer(info, new_buffer, padding_predicate, inverse, pad_value, |
512 | &new_block_to_old); |
513 | Stmt stmt = replacer(info.dependent_loopnest.back()->body); |
514 | if (!replacer.is_all_stores_replaced()) { |
515 | return NullOpt; |
516 | } |
517 | |
518 | ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); |
519 | for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { |
520 | size_t i = (inverse->initial_indices.size() - 1) - rev_i; |
521 | Var loop_var = inverse->initial_indices[i]; |
522 | PrimExpr extent = new_buffer->shape[i]; |
523 | stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); |
524 | } |
525 | |
526 | return stmt; |
527 | }; |
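    // Sketch of what generate_if_then_else_block produces (illustrative
    // only): a row-major write stage such as
    //   for i in range(14): A[i] = f(i)
    // becomes a loop nest over the post-transformation shape whose body
    // stores if_then_else(<padding predicate>, <pad value>, f(...)) into
    // the new buffer, as constructed by BufferStoreReplacer.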
528 | |
529 | Map<For, Stmt> loop_replacements; |
530 | |
531 | for (const auto& info : write_info_) { |
532 | if (info.dependent_loopnest.size()) { |
533 | if (auto opt_stmt = generate_if_then_else_block(info)) { |
534 | loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value()); |
535 | } |
536 | } |
537 | } |
538 | |
539 | if (loop_replacements.size()) { |
540 | return ReplacementPlan{std::move(loop_replacements), std::move(new_block_to_old)}; |
541 | } else { |
542 | return std::nullopt; |
543 | } |
544 | } |
545 | |
546 | std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map, |
547 | IndexMap inverse, PrimExpr padding_predicate, |
548 | Optional<IndexMap> pad_value) const { |
549 | if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { |
550 | return std::nullopt; |
551 | } |
552 | |
553 | Array<IterVar> iter_vars; |
554 | Array<PrimExpr> iter_values; |
555 | Array<PrimExpr> indices; |
556 | ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); |
557 | for (size_t i = 0; i < inverse->initial_indices.size(); i++) { |
558 | const auto& loop_var = inverse->initial_indices[i]; |
559 | const auto& dim = new_buffer->shape[i]; |
560 | Var block_var("v_" + loop_var->name_hint, loop_var->dtype); |
561 | IterVar iter_var(Range(0, dim), block_var, kDataPar); |
562 | indices.push_back(iter_var->var); |
563 | iter_vars.push_back(iter_var); |
564 | iter_values.push_back(loop_var); |
565 | } |
566 | |
567 | PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0]; |
568 | Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices); |
569 | |
570 | std::stringstream block_name; |
    block_name << "buffer_" << new_buffer->name << "_padding";
572 | auto write_region = BufferRegion::FromPoint(new_buffer, indices); |
573 | stmt = BlockRealize(iter_values, padding_predicate, |
574 | Block(iter_vars, {}, {write_region}, block_name.str(), stmt)); |
575 | |
576 | ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); |
577 | for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { |
578 | size_t i = (inverse->initial_indices.size() - 1) - rev_i; |
579 | Var loop_var = inverse->initial_indices[i]; |
580 | PrimExpr extent = new_buffer->shape[i]; |
581 | stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); |
582 | } |
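    // At this point `stmt` is roughly the following (sketch only; names are
    // illustrative):
    //
    //   for i0, i1 in grid(new_shape):
    //     with block("buffer_A_padding"):
    //       where(padding_predicate(v0, v1))
    //       A_new[v0, v1] = pad_value(v0, v1)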
583 | |
584 | const auto& info = write_info_.back(); |
585 | Stmt insert_after = [&]() -> Stmt { |
586 | if (info.dependent_loopnest.size()) { |
587 | return info.dependent_loopnest.front(); |
588 | } else if (info.innermost_block_realize) { |
589 | return info.innermost_block_realize.value(); |
590 | } else { |
        LOG(FATAL) << "Write occurred outside of any block/loop";
592 | } |
593 | }(); |
594 | return EpiloguePlan{insert_after, stmt}; |
595 | } |
596 | |
597 | struct BindLoopVar { |
598 | BindLoopVar(TransformLayoutPlanner* self, For for_node) |
599 | : self_(self), var_(for_node->loop_var) { |
600 | size_t loop_depth = self_->active_loops_.size(); |
601 | self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth}; |
602 | self_->active_loops_.push_back(std::move(for_node)); |
603 | } |
604 | ~BindLoopVar() { |
605 | self_->active_loops_.pop_back(); |
606 | self_->loop_depth_lookup_.erase(var_.get()); |
607 | } |
608 | BindLoopVar(const BindLoopVar&) = delete; |
609 | BindLoopVar& operator=(const BindLoopVar&) = delete; |
610 | BindLoopVar(BindLoopVar&&) = delete; |
611 | BindLoopVar& operator=(BindLoopVar&&) = delete; |
612 | |
613 | TransformLayoutPlanner* self_{nullptr}; |
614 | Var var_; |
615 | }; |
616 | |
617 | struct BindVariableDefinition { |
618 | BindVariableDefinition() {} |
619 | BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value) |
620 | : self_(self), var_(var) { |
621 | if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) { |
622 | self_->loop_depth_lookup_[var_.get()] = loop_depth.value(); |
623 | self_->active_var_bindings_[var_.get()] = Substitute(value, self_->active_var_bindings_); |
624 | } |
625 | } |
626 | ~BindVariableDefinition() { |
627 | if (self_) { |
628 | self_->loop_depth_lookup_.erase(var_.get()); |
629 | self_->active_var_bindings_.erase(var_.get()); |
630 | } |
631 | } |
632 | BindVariableDefinition(const BindVariableDefinition&) = delete; |
633 | BindVariableDefinition& operator=(const BindVariableDefinition&) = delete; |
634 | BindVariableDefinition(BindVariableDefinition&& other) : BindVariableDefinition() { |
635 | swap(other); |
636 | } |
637 | BindVariableDefinition& operator=(BindVariableDefinition&& other) { |
638 | swap(other); |
639 | return *this; |
640 | } |
641 | void swap(BindVariableDefinition& other) { |
642 | std::swap(self_, other.self_); |
643 | std::swap(var_, other.var_); |
644 | } |
645 | |
646 | TransformLayoutPlanner* self_{nullptr}; |
647 | Var var_; |
648 | }; |
649 | |
650 | struct BindBlockRealize { |
651 | BindBlockRealize(TransformLayoutPlanner* self, BlockRealize block_realize) : self_(self) { |
652 | ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size()); |
653 | for (size_t i = 0; i < block_realize->iter_values.size(); i++) { |
654 | bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var, |
655 | block_realize->iter_values[i]); |
656 | } |
657 | cache_ = std::move(block_realize); |
658 | std::swap(self_->innermost_block_realize_, cache_); |
659 | } |
660 | ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); } |
661 | BindBlockRealize(const BindBlockRealize&) = delete; |
662 | BindBlockRealize& operator=(const BindBlockRealize&) = delete; |
663 | BindBlockRealize(BindBlockRealize&&) = delete; |
664 | BindBlockRealize& operator=(BindBlockRealize&&) = delete; |
665 | |
666 | TransformLayoutPlanner* self_{nullptr}; |
667 | Optional<BlockRealize> cache_; |
668 | std::vector<BindVariableDefinition> bound_vars_; |
669 | }; |
670 | |
671 | /*! \brief Collected information about each BufferStore */ |
672 | std::vector<WriteInfo> write_info_; |
673 | |
674 | /*! \brief The loop iterators surrounding the current node |
675 | * |
676 | * The outermost loop iterator is `active_loops_.front()`, and the |
677 | * innermost loop iterator is `active_loops_.back()`. |
678 | * |
679 | * Used to fill the `WriteInfo::dependent_loopnest` field. |
680 | */ |
681 | std::vector<For> active_loops_; |
682 | |
683 | /*! \brief Lookup for the outer/inner loops |
684 | * |
685 | * Used to fill the `WriteInfo::dependent_loopnest` field. |
686 | */ |
687 | std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_; |
688 | |
689 | /*! \brief The variable mappings that are currently in-scope |
690 | * |
691 | * Used to determine whether the indices of a BufferStore are a |
692 | * row-major traversal, even if they are rebound in let/block |
693 | * mappings. |
694 | */ |
695 | std::unordered_map<const VarNode*, PrimExpr> active_var_bindings_; |
696 | |
697 | /*! \brief The innermost BlockRealize surrounding the current node |
698 | * |
   * Used to fill the `WriteInfo::innermost_block_realize` field.
700 | */ |
701 | Optional<BlockRealize> innermost_block_realize_{NullOpt}; |
702 | |
703 | /*! \brief The buffer to be replaced */ |
704 | Buffer old_buffer_; |
705 | }; |
706 | |
707 | class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { |
708 | public: |
709 | /*! |
710 | * \brief Rewrite the access to the buffer after the transformation |
711 | * \param scope_stmt The parent statement that contains all accesses to the target buffer |
712 | * \param old_buffer The target buffer before transformation |
713 | * \param new_buffer The new buffer after transformation |
   * \param index_map The transformation applied to the buffer
   * \param inverse The inverse of index_map, used to iterate over the post-transformation buffer
   * \param padding_predicate A predicate in terms of the post-transformation indices that is true
   *     for indices corresponding to padding introduced by the transformation
   * \param pad_value If defined, the value to be written into the introduced padding
   * \return The new AST rooted at the original parent scope and the map from the old block to the
   * new block
   */
718 | static std::pair<Stmt, Map<Block, Block>> Rewrite( |
719 | const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer, |
720 | const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate, |
721 | const Optional<IndexMap>& pad_value) { |
722 | auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse, |
723 | padding_predicate, pad_value); |
724 | |
725 | arith::Analyzer analyzer; |
726 | TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer); |
727 | Block result = Downcast<Block>(rewriter(scope_stmt)); |
728 | if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) { |
729 | auto write_ptr = result.CopyOnWrite(); |
730 | write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body}); |
731 | } |
732 | |
733 | Map<Block, Block> block_sref_reuse; |
734 | for (auto [after, before] : rewriter.new_block_to_old_) { |
735 | while (auto opt = rewriter.new_block_to_old_.Get(before)) { |
736 | before = opt.value(); |
737 | } |
738 | while (auto opt = block_sref_reuse.Get(after)) { |
739 | after = opt.value(); |
740 | } |
741 | |
742 | block_sref_reuse.Set(before, after); |
743 | } |
744 | |
745 | return {result, block_sref_reuse}; |
746 | } |
747 | |
748 | private: |
749 | TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer, |
750 | const IndexMap& index_map, |
751 | const TransformLayoutPlanner::TransformPlan& plan, |
752 | arith::Analyzer* analyzer) |
753 | : IRMutatorWithAnalyzer(analyzer), |
754 | old_buffer_(old_buffer), |
755 | new_buffer_(new_buffer), |
756 | index_map_(index_map), |
757 | plan_(plan), |
758 | buffer_data_to_buffer_{{new_buffer->data, new_buffer}} { |
759 | if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ReplacementPlan>(&plan_)) { |
760 | new_block_to_old_ = plan_ptr->new_block_to_old; |
761 | } |
762 | } |
763 | |
764 | void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) { |
765 | *buffer = new_buffer_; |
766 | *indices = index_map_->MapIndices(*indices); |
767 | (*indices).MutateByApply( |
768 | [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); }); |
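    // e.g. (hypothetical map): with index_map_ = lambda i, j: [i // 4, j, i % 4],
    // an access A[i, j] is rewritten to A_new[i // 4, j, i % 4], with each
    // index simplified by the analyzer where possible.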
769 | } |
770 | |
771 | using Parent = arith::IRMutatorWithAnalyzer; |
772 | using Parent::VisitExpr_; |
773 | using Parent::VisitStmt_; |
774 | |
775 | Stmt VisitStmt(const Stmt& stmt) final { |
776 | Stmt output = Parent::VisitStmt(stmt); |
777 | if (auto plan_ptr = std::get_if<TransformLayoutPlanner::EpiloguePlan>(&plan_)) { |
778 | if (plan_ptr->insert_after.same_as(stmt)) { |
779 | return SeqStmt({output, plan_ptr->new_block}); |
780 | } |
781 | } |
782 | return output; |
783 | } |
784 | |
785 | Stmt VisitStmt_(const ForNode* op) final { |
    // Some replacements may include the original statement, such as
787 | // replacing `loop` with `{loop, post_proc}`. In this case, avoid |
788 | // infinite recursion. |
789 | |
790 | For node = GetRef<For>(op); |
791 | if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ReplacementPlan>(&plan_)) { |
792 | auto it = plan_ptr->replacements.find(node); |
793 | if (it != plan_ptr->replacements.end()) { |
794 | return VisitStmt((*it).second); |
795 | } |
796 | } |
797 | return Parent::VisitStmt_(op); |
798 | } |
799 | |
800 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
801 | BufferLoad buffer_load = Downcast<BufferLoad>(Parent::VisitExpr_(op)); |
802 | if (buffer_load->buffer.same_as(old_buffer_)) { |
803 | auto* n = buffer_load.CopyOnWrite(); |
804 | RewriteBufferAccess(&n->buffer, &n->indices); |
805 | } |
806 | return std::move(buffer_load); |
807 | } |
808 | |
809 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
810 | BufferStore buffer_store = Downcast<BufferStore>(Parent::VisitStmt_(op)); |
811 | if (buffer_store->buffer.same_as(old_buffer_)) { |
812 | auto* n = buffer_store.CopyOnWrite(); |
813 | RewriteBufferAccess(&n->buffer, &n->indices); |
814 | } |
815 | return std::move(buffer_store); |
816 | } |
817 | |
818 | void RewriteAccessRegion(Array<BufferRegion>* old_access_regions, |
819 | const Array<BufferRegion>& infered_access_regions) { |
820 | auto fmutate = [this, &infered_access_regions](const BufferRegion& buffer_region) { |
821 | if (buffer_region->buffer.same_as(old_buffer_)) { |
822 | ICHECK(infered_access_regions.size() == 1); |
823 | return infered_access_regions[0]; |
824 | } |
825 | return buffer_region; |
826 | }; |
827 | (*old_access_regions).MutateByApply(fmutate); |
828 | } |
829 | |
830 | Stmt VisitStmt_(const BlockNode* op) final { |
831 | Block orig = [&]() { |
832 | Block block = GetRef<Block>(op); |
833 | while (true) { |
834 | if (auto it = new_block_to_old_.find(block); it != new_block_to_old_.end()) { |
835 | block = (*it).second; |
836 | } else { |
837 | break; |
838 | } |
839 | } |
840 | return block; |
841 | }(); |
842 | |
843 | Block block = Downcast<Block>(Parent::VisitStmt_(op)); |
844 | |
845 | auto infered_access_regions = GetBlockReadWriteRegion(block, buffer_data_to_buffer_); |
846 | auto* n = block.CopyOnWrite(); |
847 | RewriteAccessRegion(&n->reads, infered_access_regions[0]); |
848 | RewriteAccessRegion(&n->writes, infered_access_regions[1]); |
849 | n->alloc_buffers.MutateByApply([this](const Buffer& buffer) { |
850 | if (buffer.same_as(old_buffer_)) { |
851 | return new_buffer_; |
852 | } else { |
853 | return buffer; |
854 | } |
855 | }); |
856 | |
857 | RecordReplacement(orig, block); |
858 | return std::move(block); |
859 | } |
860 | |
861 | void RecordReplacement(Block before, Block after) { |
862 | if (before.same_as(after)) { |
863 | return; |
864 | } |
865 | |
866 | ICHECK(!new_block_to_old_.count(after)); |
867 | |
868 | while (true) { |
869 | if (auto opt = new_block_to_old_.Get(before)) { |
870 | before = opt.value(); |
871 | } else { |
872 | break; |
873 | } |
874 | } |
875 | |
876 | new_block_to_old_.Set(after, before); |
877 | } |
878 | |
879 | const Buffer& old_buffer_; |
880 | const Buffer& new_buffer_; |
881 | const IndexMap& index_map_; |
882 | const TransformLayoutPlanner::TransformPlan& plan_; |
883 | Map<Var, Buffer> buffer_data_to_buffer_; |
884 | Map<Block, Block> new_block_to_old_; |
885 | }; |
886 | |
887 | class BufferIsSubregionError : public ScheduleError { |
888 | public: |
889 | explicit BufferIsSubregionError(IRModule mod, Buffer buffer) : mod_(mod), buffer_(buffer) {} |
890 | |
891 | String FastErrorString() const final { |
    return "ScheduleError: The input buffer is defined in `match_buffer` of a block; it is expected"
           " to be a function parameter or allocated by a block";
894 | } |
895 | |
896 | String DetailRenderTemplate() const final { |
897 | std::ostringstream os; |
    os << "ScheduleError: The input buffer " << buffer_->name << " is defined in `match_buffer` of "
       << "a block; it is expected to be a function parameter or allocated by a block.";
900 | return os.str(); |
901 | } |
902 | |
903 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
904 | IRModule mod() const final { return mod_; } |
905 | |
906 | private: |
907 | IRModule mod_; |
908 | Buffer buffer_; |
909 | }; |
910 | |
911 | class TransformationPaddingIndexMapError : public ScheduleError { |
912 | public: |
913 | TransformationPaddingIndexMapError(IRModule mod, IndexMap pad_value) |
914 | : mod_(mod), pad_value_(pad_value) {} |
915 | |
916 | String FastErrorString() const final { |
917 | std::ostringstream ss; |
918 | ss << "ScheduleError: The IndexMap specifying pad_value has " |
       << pad_value_->final_indices.size() << " outputs, but should only have one output";
920 | return ss.str(); |
921 | } |
922 | |
923 | String DetailRenderTemplate() const final { |
924 | std::ostringstream ss; |
925 | ss << "ScheduleError: Pad value is specified as " << pad_value_ << " which has " |
       << pad_value_->final_indices.size() << " outputs, but should only have one output";
927 | return ss.str(); |
928 | } |
929 | |
930 | IRModule mod() const final { return mod_; } |
931 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
932 | |
933 | private: |
934 | IRModule mod_; |
935 | IndexMap pad_value_; |
936 | }; |
937 | |
938 | class TransformationPaddingTypeError : public ScheduleError { |
939 | public: |
940 | TransformationPaddingTypeError(IRModule mod, Buffer buffer, IndexMap pad_value) |
941 | : mod_(mod), buffer_(buffer), pad_value_(pad_value) { |
942 | ICHECK_EQ(pad_value_->final_indices.size(), 1); |
943 | pad_value_dtype_ = pad_value_->final_indices[0].dtype(); |
944 | } |
945 | |
946 | String FastErrorString() const final { |
947 | std::ostringstream ss; |
948 | ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_dtype_; |
949 | return ss.str(); |
950 | } |
951 | |
952 | String DetailRenderTemplate() const final { |
953 | std::ostringstream ss; |
954 | ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype |
955 | << ", but the transformation fills padding with " << pad_value_ << ", which is of type " |
956 | << pad_value_dtype_; |
957 | return ss.str(); |
958 | } |
959 | |
960 | IRModule mod() const final { return mod_; } |
961 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
962 | |
963 | private: |
964 | IRModule mod_; |
965 | Buffer buffer_; |
966 | IndexMap pad_value_; |
967 | DataType pad_value_dtype_; |
968 | }; |
969 | |
970 | class TransformationPaddingExpressionError : public ScheduleError { |
971 | public: |
972 | static void Check(IRModule mod, Buffer buffer, IndexMap pad_value) { |
973 | Visitor visitor(buffer); |
974 | ICHECK_EQ(pad_value->final_indices.size(), 1) |
        << "Internal error: Should be caught by ScheduleError checks prior to this point";
976 | visitor(pad_value->final_indices[0]); |
977 | if (visitor.illegal_load) { |
978 | throw TransformationPaddingExpressionError(mod, buffer, pad_value, |
979 | visitor.illegal_load.value()); |
980 | } |
981 | } |
982 | |
983 | private: |
984 | struct Visitor : ExprVisitor { |
985 | explicit Visitor(const Buffer& buffer) : buffer_(buffer) {} |
986 | |
987 | void VisitExpr_(const BufferLoadNode* op) final { |
988 | if (!op->buffer.same_as(buffer_)) { |
989 | illegal_load = GetRef<BufferLoad>(op); |
990 | } |
991 | ExprVisitor::VisitExpr_(op); |
992 | } |
993 | |
994 | const Buffer& buffer_; |
995 | Optional<BufferLoad> illegal_load; |
996 | }; |
997 | |
998 | TransformationPaddingExpressionError(IRModule mod, Buffer buffer, IndexMap pad_value, |
999 | BufferLoad illegal_load) |
1000 | : mod_(mod), buffer_(buffer), pad_value_(pad_value), illegal_load_(illegal_load) {} |
1001 | |
1002 | String FastErrorString() const final { |
1003 | std::ostringstream ss; |
    ss << "ScheduleError: Pad value may not contain a load from " << illegal_load_->buffer->name;
1005 | return ss.str(); |
1006 | } |
1007 | |
1008 | String DetailRenderTemplate() const final { |
1009 | std::ostringstream ss; |
1010 | ss << "ScheduleError: Pad value may only contain BufferLoad from the transformed buffer " |
1011 | << buffer_->name << ", but pad_value " << pad_value_ << " contains expression " |
1012 | << illegal_load_; |
1013 | return ss.str(); |
1014 | } |
1015 | |
1016 | IRModule mod() const final { return mod_; } |
1017 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
1018 | |
1019 | IRModule mod_; |
1020 | Buffer buffer_; |
1021 | IndexMap pad_value_; |
1022 | BufferLoad illegal_load_; |
1023 | }; |
1024 | |
1025 | class TransformationIntroducesPaddingError : public ScheduleError { |
1026 | public: |
1027 | TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map, |
1028 | PrimExpr padding_predicate) |
1029 | : mod_(std::move(mod)), |
1030 | buffer_(std::move(buffer)), |
1031 | index_map_(std::move(index_map)), |
1032 | padding_predicate_(std::move(padding_predicate)) {} |
1033 | |
1034 | String FastErrorString() const final { |
1035 | std::ostringstream ss; |
    ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << ".";
1037 | return ss.str(); |
1038 | } |
1039 | |
1040 | String DetailRenderTemplate() const final { |
1041 | auto new_shape = index_map_->MapShape(buffer_->shape); |
1042 | std::ostringstream os; |
1043 | os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name |
1044 | << " of shape " << buffer_->shape << " would result in shape " << new_shape |
       << ". However, this would introduce padding wherever " << padding_predicate_ << " is true.";
1046 | return os.str(); |
1047 | } |
1048 | |
1049 | IRModule mod() const final { return mod_; } |
1050 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
1051 | |
1052 | private: |
1053 | IRModule mod_; |
1054 | Buffer buffer_; |
1055 | IndexMap index_map_; |
1056 | PrimExpr padding_predicate_; |
1057 | }; |
1058 | |
// Make the dtypes of the indices in the IndexMap match the dtype of the buffer shape, to avoid
// dtype-mismatch issues later.
1061 | IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>& args) { |
1062 | const auto& initial_indices_orig = index_map->initial_indices; |
1063 | ICHECK(args.size() == initial_indices_orig.size()); |
1064 | |
1065 | Array<Var> initial_indices; |
1066 | Map<Var, PrimExpr> var_map; |
1067 | |
1068 | for (size_t i = 0; i < args.size(); ++i) { |
1069 | if (args[i]->dtype != initial_indices_orig[i].dtype()) { |
1070 | auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype); |
1071 | initial_indices.push_back(new_idx); |
1072 | var_map.Set(initial_indices_orig[i], new_idx); |
1073 | } else { |
1074 | initial_indices.push_back(initial_indices_orig[i]); |
1075 | } |
1076 | } |
1077 | |
1078 | if (!var_map.empty()) { |
1079 | auto final_indices = index_map->final_indices.Map([&](PrimExpr index) { |
1080 | return SubstituteWithDataTypeLegalization(index, |
1081 | [&](const Var& var) { return var_map.Get(var); }); |
1082 | }); |
1083 | Optional<IndexMap> opt_inverse_index_map = |
1084 | Downcast<Optional<IndexMap>>(index_map->inverse_index_map); |
1085 | if (opt_inverse_index_map.defined()) { |
1086 | opt_inverse_index_map = LegalizeIndexMapDType(opt_inverse_index_map.value(), final_indices); |
1087 | } |
1088 | return IndexMap(initial_indices, final_indices, opt_inverse_index_map); |
1089 | } |
1090 | return index_map; |
1091 | } |
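// For instance (hypothetical): if the buffer shape is expressed in int64 but
// the user-provided IndexMap uses int32 variables, LegalizeIndexMapDType
// re-creates those variables (and, recursively, any stored inverse map) with
// int64 dtype so that the mapped indices match the shape dtype.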
1092 | |
1093 | void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
1094 | BufferIndexType buffer_index_type, const IndexMap& index_map_orig, |
1095 | const Optional<IndexMap>& pad_value) { |
1096 | // Step 1: Input handling and error checking |
1097 | const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); |
1098 | Buffer old_buffer = |
1099 | GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type); |
1100 | |
1101 | auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape); |
1102 | |
1103 | auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); |
1104 | if (defining_site_sref.defined() && !is_alloc) { |
1105 | throw BufferIsSubregionError(self->mod, old_buffer); |
1106 | } |
1107 | if (pad_value) { |
1108 | if (pad_value.value()->final_indices.size() != 1) { |
1109 | throw TransformationPaddingIndexMapError(self->mod, pad_value.value()); |
1110 | } |
1111 | if (pad_value.value()->final_indices[0]->dtype != old_buffer->dtype) { |
1112 | throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value()); |
1113 | } |
1114 | |
1115 | TransformationPaddingExpressionError::Check(self->mod, old_buffer, pad_value.value()); |
1116 | } |
1117 | |
1118 | StmtSRef scope_sref = defining_site_sref.defined() |
1119 | ? defining_site_sref.value() |
1120 | : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); |
1121 | const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); |
1122 | |
1123 | auto [inverse, padding_predicate] = [&]() { |
1124 | Array<Range> region; |
1125 | for (const auto& dim : old_buffer->shape) { |
1126 | region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim)); |
1127 | } |
1128 | return index_map.NonSurjectiveInverse(region); |
1129 | }(); |
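  // Example (illustrative only): for an old shape of [14] and
  // index_map = lambda i: [i // 4, i % 4], NonSurjectiveInverse returns the
  // inverse (i0, i1) -> [4*i0 + i1] together with a padding predicate along
  // the lines of `4*i0 + i1 >= 14`, which is non-trivial exactly when the
  // transformation introduces padding.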
1130 | |
1131 | bool has_padding = !is_zero(padding_predicate); |
1132 | if (has_padding && !pad_value.defined()) { |
1133 | throw TransformationIntroducesPaddingError(self->mod, old_buffer, index_map, padding_predicate); |
1134 | } |
1135 | |
1136 | // Step 2: Infer the shape of the new buffer |
1137 | Buffer new_buffer = old_buffer; |
1138 | new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape); |
1139 | |
1140 | // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block |
1141 | // alloc_buffers. |
1142 | auto [new_stmt, block_sref_reuse] = |
1143 | TransformLayoutRewriter::Rewrite(GetRef<Block>(scope_block), old_buffer, new_buffer, |
1144 | index_map, inverse, padding_predicate, pad_value); |
1145 | Block new_scope_block = Downcast<Block>(new_stmt); |
1146 | |
1147 | // Step 4: Rewrite buffer_map of the PrimFunc if necessary. |
1148 | if (!defining_site_sref.defined()) { |
1149 | GlobalVar g_var; |
1150 | GetRootPrimFunc(self->mod, scope_block, &g_var); |
1151 | IRModuleNode* new_mod = self->mod.CopyOnWrite(); |
1152 | MapNode* new_map = new_mod->functions.CopyOnWrite(); |
1153 | PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var))); |
1154 | PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); |
1155 | MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite(); |
1156 | for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) { |
1157 | if ((*it).second.same_as(old_buffer)) { |
1158 | (*it).second = new_buffer; |
1159 | } |
1160 | } |
1161 | new_map->at(g_var) = std::move(ref_new_func); |
1162 | } |
1163 | |
  // Step 5: Replace the scope block with the new block
1165 | self->Replace(scope_sref, new_scope_block, block_sref_reuse); |
1166 | } |
1167 | |
1168 | /*! |
 * \brief Detect the block iter type associated with the expression
1170 | * |
 * This function collects the block iters in the expression and checks whether they all have the
 * same iter type. If they do, the detected iter type is that common type; otherwise the detected
 * iter type is kOpaque.
1174 | * |
1175 | * \param expr The expression |
1176 | * \param block_iter_type_map The mapping from block iter to iter type |
1177 | * \return The detected block iter type |
1178 | */ |
1179 | IterVarType DetectNewBlockIterType( |
1180 | const PrimExpr& expr, |
1181 | const std::unordered_map<const VarNode*, IterVarType>& block_iter_type_map) { |
1182 | IterVarType result{kOpaque}; |
1183 | bool found = false; |
1184 | PostOrderVisit(expr, [&](const ObjectRef& obj) { |
1185 | if (const VarNode* var = obj.as<VarNode>()) { |
1186 | auto it = block_iter_type_map.find(var); |
1187 | if (it != block_iter_type_map.end()) { |
1188 | if (!found) { |
1189 | found = true; |
1190 | result = it->second; |
1191 | } else if (result != it->second) { |
1192 | result = kOpaque; |
1193 | return false; |
1194 | } |
1195 | } |
1196 | } |
1197 | return true; |
1198 | }); |
1199 | return result; |
1200 | } |
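// Example for DetectNewBlockIterType (illustrative only): if block iters i and
// j are both kDataPar, the expression i*16 + j is detected as kDataPar; if i
// is kDataPar but j is kCommReduce, the mixed expression is reported as
// kOpaque and rejected by TransformBlockLayout below.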
1201 | |
1202 | class NotBijectiveAffineIndexMapError : public ScheduleError { |
1203 | public: |
1204 | NotBijectiveAffineIndexMapError(IRModule mod, IndexMap index_map) |
1205 | : mod_(std::move(mod)), index_map_(std::move(index_map)) {} |
1206 | String FastErrorString() const final { |
    return "ScheduleError: The index map is not bijective affine.";
1208 | } |
1209 | |
1210 | String DetailRenderTemplate() const final { |
1211 | std::ostringstream os; |
    os << "The index map " << index_map_->ToPythonString() << " is not bijective affine.";
1213 | return os.str(); |
1214 | } |
1215 | |
1216 | IRModule mod() const final { return mod_; } |
1217 | |
1218 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
1219 | |
1220 | private: |
1221 | IRModule mod_; |
1222 | IndexMap index_map_; |
1223 | }; |
1224 | |
1225 | class IndexMapNotApplicableToBlockIterError : public ScheduleError { |
1226 | public: |
1227 | static void Check(const IRModule mod, const Block& block, const IndexMap& index_map) { |
1228 | if (index_map->initial_indices.size() != block->iter_vars.size()) { |
1229 | throw IndexMapNotApplicableToBlockIterError(mod, block, index_map); |
1230 | } |
1231 | } |
1232 | explicit IndexMapNotApplicableToBlockIterError(IRModule mod, Block block, IndexMap index_map) |
1233 | : mod_(std::move(mod)), block_(std::move(block)), index_map_(std::move(index_map)) {} |
1234 | |
1235 | String FastErrorString() const final { |
    return "ScheduleError: The index map can't be applied to the block iters because the number of "
           "parameters does not match.";
1238 | } |
1239 | |
1240 | String DetailRenderTemplate() const final { |
1241 | std::ostringstream os; |
1242 | os << "The index map " << index_map_->ToPythonString() |
       << " can't be applied to the block iters of {0} because the number of parameters does not "
          "match. Expected: "
1245 | << index_map_->initial_indices.size() << ", actual: " << block_->iter_vars.size(); |
1246 | return os.str(); |
1247 | } |
1248 | |
1249 | IRModule mod() const final { return mod_; } |
1250 | |
1251 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
1252 | |
1253 | private: |
1254 | IRModule mod_; |
1255 | Block block_; |
1256 | IndexMap index_map_; |
1257 | }; |
1258 | |
1259 | class OpaqueNewIterTypeError : public ScheduleError { |
1260 | public: |
1261 | explicit OpaqueNewIterTypeError(IRModule mod, Block block, PrimExpr iter_value) |
1262 | : mod_(std::move(mod)), block_(std::move(block)), iter_value_(std::move(iter_value)) {} |
1263 | |
1264 | String FastErrorString() const final { |
    return "ScheduleError: Cannot detect the new block iter type because it contains more than one "
           "type of original iter var.";
1267 | } |
1268 | |
1269 | String DetailRenderTemplate() const final { |
1270 | std::ostringstream os; |
1271 | os << "Cannot detect the block iter type for new iter value " << iter_value_ |
       << " in {0} because it contains more than one type of original iter var.";
1273 | return os.str(); |
1274 | } |
1275 | |
1276 | IRModule mod() const final { return mod_; } |
1277 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
1278 | |
1279 | private: |
1280 | IRModule mod_; |
1281 | Block block_; |
1282 | PrimExpr iter_value_; |
1283 | }; |
1284 | |
1285 | void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref, |
1286 | const IndexMap& index_map) { |
1287 | const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); |
1288 | const Block& block = GetRef<Block>(block_ptr); |
1289 | arith::Analyzer analyzer; |
1290 | |
1291 | // Step 1: Collect outer loops and loop vars |
1292 | Array<StmtSRef> loops = GetLoops(block_sref); // outer loops of the block |
1293 | std::unordered_set<const VarNode*> loop_vars; // loop vars of the outer loops |
1294 | for (const StmtSRef& loop_sref : loops) { |
1295 | CheckLoopStartsWithZero(self, loop_sref, &analyzer); |
1296 | loop_vars.emplace(loop_sref->StmtAs<ForNode>()->loop_var.get()); |
1297 | } |
1298 | |
  // Step 2: Check that all outer loops have a single child and that the block bindings are trivial
  // (all binding values are loop vars)
1301 | StmtSRef scope_sref{nullptr}; // the scope statement for replacement |
1302 | if (!loops.empty()) { |
1303 | scope_sref = loops.front(); |
1304 | CheckGetSingleChildBlockRealizeOnSRefTree(self, loops.front()); |
1305 | } else { |
1306 | scope_sref = block_sref; |
1307 | } |
1308 | |
1309 | BlockRealize block_realize = GetBlockRealize(self, block_sref); |
1310 | CheckBlockHasTrivialBinding(self, block_sref); |
1311 | |
1312 | // Step 3: Collect information of block iter vars |
1313 | Array<PrimExpr> block_vars; // iter_var->var of each block iter |
1314 | Map<Var, Range> block_iter_dom; // domain of block iter |
1315 | std::unordered_map<const VarNode*, IterVarType> block_iter_type; // iter type of block iter |
1316 | |
1317 | Array<PrimExpr> |
1318 | block_iter_range_array; // array of block iter extents in the same order as block iters |
1319 | for (const auto& iter_var : block->iter_vars) { |
1320 | block_vars.push_back(iter_var->var); |
1321 | block_iter_dom.Set(iter_var->var, iter_var->dom); |
1322 | block_iter_type[iter_var->var.get()] = iter_var->iter_type; |
1323 | ICHECK(is_zero(iter_var->dom->min)); |
1324 | block_iter_range_array.push_back(iter_var->dom->extent); |
1325 | } |
1326 | |
1327 | // Step 4: Apply the IndexMap to block iters. |
1328 | IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map); |
1329 | Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars); |
1330 | Array<PrimExpr> new_block_iter_range = index_map->MapShape(block_iter_range_array); |
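  // For example (illustrative only): with block iters (i, j) of extents
  // (128, 16) and index_map = lambda i, j: [i * 16 + j], the transformed
  // block iters are [i*16 + j] and the new extent array is [2048].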
1331 | |
1332 | // Step 5: Create the new block after transformation. |
1333 | |
  // Step 5.1: Create new block iters. After applying the IndexMap f to the block iters
  // ax_0, ..., ax_n, create a new block iter for each expression in f(ax_0, ..., ax_n).
1336 | Array<IterVar> new_block_iters; // new block iters |
1337 | Array<PrimExpr> new_block_vars; // iter_var->var of new block iters |
1338 | for (size_t i = 0; i < transformed_block_iters.size(); ++i) { |
1339 | Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype}; |
1340 | new_block_vars.push_back(new_block_var); |
1341 | IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type); |
1342 | if (iter_type == kOpaque) { |
1343 | throw OpaqueNewIterTypeError(self->mod, GetRef<Block>(block_ptr), transformed_block_iters[i]); |
1344 | } |
1345 | auto dtype = new_block_var.dtype(); |
1346 | new_block_iters.push_back(IterVar( |
1347 | /*dom=*/Range::FromMinExtent(make_zero(dtype), new_block_iter_range[i]), |
1348 | /*var=*/std::move(new_block_var), /*iter_type=*/iter_type)); |
1349 | } |
1350 | |
1351 | // Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters |
1352 | // in the body. |
1353 | Map<Var, PrimExpr> inverse_subst_map; |
1354 | // Construct the inverse map |
1355 | { |
1356 | Array<Range> initial_ranges; |
1357 | for (const PrimExpr& extent : block_iter_range_array) { |
1358 | initial_ranges.push_back(Range::FromMinExtent(make_const(extent.dtype(), 0), extent)); |
1359 | } |
1360 | IndexMap inverse_index_map{nullptr}; |
1361 | try { |
1362 | inverse_index_map = index_map.Inverse(initial_ranges); |
1363 | } catch (...) { |
1364 | throw NotBijectiveAffineIndexMapError(self->mod, index_map); |
1365 | } |
1366 | |
1367 | Array<PrimExpr> inversed_new_block_vars = inverse_index_map->MapIndices( |
1368 | new_block_vars); // old block vars written in terms of new block vars |
1369 | |
1370 | for (int i = 0, n = block_vars.size(); i < n; ++i) { |
1371 | inverse_subst_map.Set(Downcast<Var>(block_vars[i]), inversed_new_block_vars[i]); |
1372 | } |
1373 | } |
1374 | Block new_block = Downcast<Block>(Substitute(GetRef<Block>(block_ptr), inverse_subst_map)); |
1375 | new_block.CopyOnWrite()->iter_vars = new_block_iters; |
1376 | new_block = Downcast<Block>(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer)); |
1377 | |
1378 | // Step 5.3: Create outer loops for each new block iter. |
1379 | |
1380 | // Make new loop vars |
1381 | Array<PrimExpr> new_loop_vars; |
1382 | for (int i = 0; i < static_cast<int>(new_block_iters.size()); ++i) { |
1383 | new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_iters[i]->var.dtype())); |
1384 | } |
1385 | |
1386 | // Make new block realize |
1387 | BlockRealizeNode* new_block_realize = block_realize.CopyOnWrite(); |
1388 | new_block_realize->iter_values = new_loop_vars; |
1389 | new_block_realize->block = new_block; |
1390 | |
1391 | // Generate outer loops |
1392 | Stmt body = GetRef<Stmt>(new_block_realize); |
1393 | for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; --i) { |
1394 | body = For(Downcast<Var>(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, |
1395 | std::move(body)); |
1396 | } |
1397 | |
1398 | // Step 6: Do the actual replacement |
1399 | self->Replace(scope_sref, body, {{block, new_block}}); |
1400 | } |
1401 | |
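/*! \brief Mutator that replaces a buffer with a copy carrying new `axis_separators`.
 *
 * The new axis_separators are also propagated to the target buffers of match_buffer regions
 * whose source buffer is the buffer being replaced.
 */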
1402 | class BufferAxisSeparatorMutator : private ReplaceBufferMutator { |
1403 | public: |
1404 | static Block Mutate(const Block& scope_block, const Buffer& old_buffer, Buffer new_buffer, |
1405 | Map<Block, Block>* block_sref_reuse) { |
1406 | BufferAxisSeparatorMutator mutator(old_buffer, std::move(new_buffer), block_sref_reuse); |
1407 | return Downcast<Block>(mutator.VisitStmt(scope_block)); |
1408 | } |
1409 | |
1410 | private: |
1411 | BufferAxisSeparatorMutator(const Buffer& old_buffer, Buffer new_buffer, |
1412 | Map<Block, Block>* block_sref_reuse) |
1413 | : ReplaceBufferMutator(old_buffer, new_buffer, block_sref_reuse) {} |
1414 | |
1415 | MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final { |
1416 | auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get()); |
1417 | if (it != buffer_var_map_.end()) { |
1418 | const Buffer& new_source_buffer = it->second; |
1419 | Buffer new_target_buffer = match_buffer->buffer; |
1420 | new_target_buffer.CopyOnWrite()->axis_separators = new_source_buffer->axis_separators; |
1421 | if (new_target_buffer->shape.size() != new_source_buffer->shape.size()) { |
1422 | LOG(WARNING) |
1423 | << "Target buffer in match_buffer doesn't have the same dimensionality as its source " |
1424 | "buffer. `axis_separators` for the target buffer might be incorrect." ; |
1425 | } |
1426 | buffer_var_map_[new_target_buffer->data.get()] = new_target_buffer; |
1427 | return MatchBufferRegion(new_target_buffer, |
1428 | BufferRegion(new_source_buffer, match_buffer->source->region)); |
1429 | } |
1430 | return match_buffer; |
1431 | } |
1432 | }; |
1433 | |
1434 | void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer_index, |
1435 | BufferIndexType buffer_index_type, const Array<IntImm>& axis_separators) { |
1436 | const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); |
1437 | Buffer old_buffer = |
1438 | GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type); |
1439 | auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); |
1440 | if (defining_site_sref.defined() && !is_alloc) { |
1441 | throw BufferIsSubregionError(self->mod, old_buffer); |
1442 | } |
1443 | |
1444 | StmtSRef scope_sref = defining_site_sref.defined() |
1445 | ? defining_site_sref.value() |
1446 | : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); |
1447 | const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); |
1448 | |
1449 | // Step 1: Check and update axis_separators of the buffer. |
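  // `axis_separators` mark where the buffer's logical axes are split into separate physical axes
  // when the buffer is flattened, e.g. axis_separators=[2] on a 4-d buffer yields the physical
  // shape [d0*d1, d2*d3]; this is used, for example, for 2-d memory such as textures.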
1450 | Buffer new_buffer = old_buffer; |
1451 | new_buffer.CopyOnWrite()->axis_separators = axis_separators; |
1452 | |
1453 | Map<Block, Block> block_sref_reuse; |
1454 | |
1455 | // Step 2: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. |
1456 | Block new_scope_block = BufferAxisSeparatorMutator::Mutate(GetRef<Block>(scope_block), old_buffer, |
1457 | new_buffer, &block_sref_reuse); |
1458 | if (!defining_site_sref.defined()) { |
1459 | // mutate buffer_map of the PrimFunc |
1460 | GlobalVar g_var; |
1461 | GetRootPrimFunc(self->mod, scope_block, &g_var); |
1462 | IRModuleNode* new_mod = self->mod.CopyOnWrite(); |
1463 | MapNode* new_map = new_mod->functions.CopyOnWrite(); |
1464 | PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var))); |
1465 | PrimFuncNode* new_func = ref_new_func.CopyOnWrite(); |
1466 | MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite(); |
1467 | for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) { |
1468 | if ((*it).second.same_as(old_buffer)) { |
1469 | (*it).second = new_buffer; |
1470 | } |
1471 | } |
1472 | new_map->at(g_var) = std::move(ref_new_func); |
1473 | } |
1474 | |
  // Step 3: Replace the scope block with the new block
1476 | self->Replace(scope_sref, new_scope_block, block_sref_reuse); |
1477 | } |
1478 | |
1479 | /******** InstructionKind Registration ********/ |
1480 | |
1481 | struct TransformLayoutTraits : public UnpackedInstTraits<TransformLayoutTraits> { |
1482 | static constexpr const char* kName = "TransformLayout" ; |
1483 | static constexpr bool kIsPure = false; |
1484 | |
1485 | private: |
1486 | static constexpr size_t kNumInputs = 2; |
1487 | static constexpr size_t kNumAttrs = 3; |
1488 | static constexpr size_t kNumDecisions = 0; |
1489 | |
1490 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map, |
1491 | Integer buffer_index, Integer buffer_index_type, |
1492 | Optional<IndexMap> pad_value) { |
1493 | return sch->TransformLayout(block_rv, buffer_index.IntValue(), |
1494 | static_cast<BufferIndexType>(buffer_index_type->value), index_map, |
1495 | pad_value); |
1496 | } |
1497 | |
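  // Emits a python statement of roughly the following form (illustrative):
  //   sch.transform_layout(block=b0, buffer=("write", 0),
  //                        index_map=lambda i, j: (i // 4, j, i % 4), pad_value=None)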
1498 | static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map, |
1499 | Integer buffer_index, Integer buffer_index_type, |
1500 | Optional<IndexMap> pad_value) { |
1501 | PythonAPICall py("transform_layout" ); |
1502 | py.Input("block" , block_rv); |
1503 | |
1504 | std::ostringstream os; |
1505 | os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value)) |
1506 | << "\", " << buffer_index << ")" ; |
1507 | py.Input("buffer" , os.str()); |
1508 | py.Input("index_map" , index_map->ToPythonString()); |
1509 | py.Input("pad_value" , pad_value ? pad_value.value()->ToPythonString() : "None" ); |
1510 | |
1511 | return py.Str(); |
1512 | } |
1513 | |
1514 | public: |
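  // The pad_value attribute (attrs[2]) is an IndexMap, which is round-tripped through
  // SaveJSON/LoadJSON as a string so that it can be stored in the serialized trace; an
  // undefined pad_value is recorded as-is.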
1515 | static ObjectRef AttrsAsJSON(const Array<ObjectRef>& attrs) { |
1516 | Array<ObjectRef> attrs_record; |
1517 | attrs_record.reserve(kNumAttrs); |
1518 | attrs_record.push_back(attrs[0]); |
1519 | attrs_record.push_back(attrs[1]); |
1520 | if (attrs[2].defined()) { |
1521 | attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); |
1522 | } else { |
1523 | attrs_record.push_back(attrs[2]); |
1524 | } |
1525 | return std::move(attrs_record); |
1526 | } |
1527 | |
1528 | static Array<ObjectRef> AttrsFromJSON(const ObjectRef& attrs_record_) { |
1529 | Array<ObjectRef> attrs_record = Downcast<Array<ObjectRef>>(attrs_record_); |
1530 | Array<ObjectRef> attrs; |
1531 | attrs.push_back(attrs_record[0]); |
1532 | attrs.push_back(attrs_record[1]); |
1533 | if (attrs_record[2].defined()) { |
1534 | attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[2]))); |
1535 | } else { |
1536 | attrs.push_back(attrs_record[2]); |
1537 | } |
1538 | return attrs; |
1539 | } |
1540 | |
1541 | template <typename> |
1542 | friend struct ::tvm::tir::UnpackedInstTraits; |
1543 | }; |
1544 | |
1545 | struct TransformBlockLayoutTraits : public UnpackedInstTraits<TransformBlockLayoutTraits> { |
1546 | static constexpr const char* kName = "TransformBlockLayout" ; |
1547 | static constexpr bool kIsPure = false; |
1548 | |
1549 | private: |
1550 | static constexpr size_t kNumInputs = 1; |
1551 | static constexpr size_t kNumAttrs = 1; |
1552 | static constexpr size_t kNumDecisions = 0; |
1553 | |
1554 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, IndexMap index_map) { |
1555 | return sch->TransformBlockLayout(block_rv, index_map); |
1556 | } |
1557 | |
1558 | static String UnpackedAsPython(Array<String> outputs, String block_rv, IndexMap index_map) { |
1559 | PythonAPICall py("transform_block_layout" ); |
1560 | py.Input("block" , block_rv); |
1561 | py.Input("index_map" , index_map->ToPythonString()); |
1562 | return py.Str(); |
1563 | } |
1564 | |
1565 | public: |
1566 | static ObjectRef AttrsAsJSON(const Array<ObjectRef>& attrs) { |
1567 | Array<ObjectRef> attrs_record; |
1568 | attrs_record.reserve(kNumAttrs); |
1569 | attrs_record.push_back(String(::tvm::SaveJSON(attrs[0]))); |
1570 | return std::move(attrs_record); |
1571 | } |
1572 | |
1573 | static Array<ObjectRef> AttrsFromJSON(const ObjectRef& attrs_record_) { |
1574 | Array<ObjectRef> attrs_record = Downcast<Array<ObjectRef>>(attrs_record_); |
1575 | Array<ObjectRef> attrs; |
1576 | attrs.push_back(::tvm::LoadJSON(Downcast<String>(attrs_record[0]))); |
1577 | return attrs; |
1578 | } |
1579 | |
1580 | template <typename> |
1581 | friend struct ::tvm::tir::UnpackedInstTraits; |
1582 | }; |
1583 | |
1584 | struct SetAxisSeparatorTraits : public UnpackedInstTraits<SetAxisSeparatorTraits> { |
1585 | static constexpr const char* kName = "SetAxisSeparator" ; |
1586 | static constexpr bool kIsPure = false; |
1587 | |
1588 | private: |
1589 | static constexpr size_t kNumInputs = 1; |
1590 | static constexpr size_t kNumAttrs = 3; |
1591 | static constexpr size_t kNumDecisions = 0; |
1592 | |
1593 | static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, |
1594 | Integer buffer_index_type, Array<IntImm> axis_separators) { |
1595 | return sch->SetAxisSeparator(block_rv, buffer_index.IntValue(), |
1596 | static_cast<BufferIndexType>(buffer_index_type->value), |
1597 | axis_separators); |
1598 | } |
1599 | |
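  // Emits a python statement of roughly the following form (illustrative):
  //   sch.set_axis_separator(block=b0, buffer=("write", 0), axis_separators=[2])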
1600 | static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index, |
1601 | Integer buffer_index_type, Array<IntImm> axis_separators) { |
1602 | PythonAPICall py("set_axis_separator" ); |
1603 | py.Input("block" , block_rv); |
1604 | |
1605 | std::ostringstream os; |
1606 | os << "(\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value)) |
1607 | << "\", " << buffer_index << ")" ; |
1608 | py.Input("buffer" , os.str()); |
1609 | |
1610 | py.Input("axis_separators" , axis_separators); |
1611 | return py.Str(); |
1612 | } |
1613 | |
1614 | template <typename> |
1615 | friend struct ::tvm::tir::UnpackedInstTraits; |
1616 | }; |
1617 | |
1618 | TVM_REGISTER_INST_KIND_TRAITS(TransformLayoutTraits); |
1619 | TVM_REGISTER_INST_KIND_TRAITS(TransformBlockLayoutTraits); |
1620 | TVM_REGISTER_INST_KIND_TRAITS(SetAxisSeparatorTraits); |
1621 | |
1622 | } // namespace tir |
1623 | } // namespace tvm |
1624 | |