/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./utils.h"

namespace tvm {
namespace tir {

/******** Annotation ********/

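/*!
 * \brief Create a copy of a block with the given annotation added or updated.
 * \param block The block whose annotations are copied and extended.
 * \param attr_key The annotation key.
 * \param attr_value The annotation value.
 * \return A new block identical to the input, with `attr_key: attr_value` set in its annotations.
 */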
Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
  Map<String, ObjectRef> annotations = block->annotations;
  annotations.Set(attr_key, attr_value);
  ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
  new_block->annotations = std::move(annotations);
  return Block(new_block);
}

/******** Buffer Related ********/
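/*!
 * \brief Create a copy of a buffer bound to the given storage scope. Both the buffer and its
 *        backing data var are renamed with a `_<scope>` suffix, and the data var's pointer type
 *        is retyped to the new scope.
 * \param buffer The buffer to be copied.
 * \param scope The target storage scope, e.g. "shared" or "local".
 * \return The new buffer in the requested scope.
 *
 * A minimal usage sketch (illustrative; the buffer name "A" is hypothetical):
 *   Buffer a = decl_buffer({128, 128}, DataType::Float(32), "A");
 *   Buffer a_shared = WithScope(a, "shared");
 *   // a_shared->name == "A_shared"; its data var points to "shared" memory.
 */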
Buffer WithScope(const Buffer& buffer, const String& scope) {
  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
  ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
  const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
  new_var->type_annotation = PointerType(ptr_type->element_type, scope);
  new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation);
  new_buffer->name = buffer->name + "_" + scope;
  return Buffer(new_buffer);
}

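/*!
 * \brief Replace the buffer of every BufferRegion that accesses `source` with `target`,
 *        keeping the accessed region unchanged.
 * \param regions The regions to be mutated (copy-on-write).
 * \param source The buffer to be replaced.
 * \param target The replacement buffer.
 * \return The mutated regions.
 */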
Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
                                  const Buffer& target) {
  regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {
    if (region->buffer.same_as(source)) {
      ObjectPtr<BufferRegionNode> n = make_object<BufferRegionNode>(*region.get());
      n->buffer = target;
      return BufferRegion(n);
    }
    return region;
  });
  return regions;
}

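/*!
 * \brief Replace `source` with `target` in the source region of every MatchBufferRegion,
 *        keeping the matched region unchanged.
 * \param match_buffers The match-buffer regions to be mutated (copy-on-write).
 * \param source The buffer to be replaced.
 * \param target The replacement buffer.
 * \return The mutated match-buffer regions.
 */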
Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                       const Buffer& target) {
  match_buffers.MutateByApply([&source,
                               &target](MatchBufferRegion match_buffer) -> MatchBufferRegion {
    if (match_buffer->source->buffer.same_as(source)) {
      ObjectPtr<MatchBufferRegionNode> n = make_object<MatchBufferRegionNode>(*match_buffer.get());
      n->source = BufferRegion(target, n->source->region);
      return MatchBufferRegion(n);
    }
    return match_buffer;
  });
  return match_buffers;
}

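/*!
 * \brief Replace every BufferRegion whose buffer is `source_buffer` with the region `target`.
 *        Unlike ReplaceBuffer, this substitutes the entire region, not just the buffer.
 * \param regions The regions to be mutated (copy-on-write).
 * \param source_buffer The buffer whose regions are to be replaced.
 * \param target The replacement region.
 * \return The mutated regions.
 */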
Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const Buffer& source_buffer,
                                        const BufferRegion& target) {
  regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion {
    if (region->buffer.same_as(source_buffer)) {
      return target;
    }
    return region;
  });
  return regions;
}

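/*!
 * \brief Replace the source region of every MatchBufferRegion whose source buffer is
 *        `source_buffer` with the region `target`.
 * \param match_buffers The match-buffer regions to be mutated (copy-on-write).
 * \param source_buffer The buffer whose source regions are to be replaced.
 * \param target The replacement region.
 * \return The mutated match-buffer regions.
 */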
Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion> match_buffers,
                                             const Buffer& source_buffer,
                                             const BufferRegion& target) {
  match_buffers.MutateByApply([&source_buffer, &target](
                                  const MatchBufferRegion& match_buffer) -> MatchBufferRegion {
    if (match_buffer->source->buffer.same_as(source_buffer)) {
      ObjectPtr<MatchBufferRegionNode> n = make_object<MatchBufferRegionNode>(*match_buffer.get());
      n->source = target;
      return MatchBufferRegion(n);
    }
    return match_buffer;
  });
  return match_buffers;
}

/******** ReplaceBufferMutator ********/
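// ReplaceBufferMutator rewrites all references to one or more old buffers with their
// replacements. Buffers are identified by their backing data var (`buffer->data`), so loads,
// stores, read/write regions, match-buffers, and allocations of a replaced buffer are all
// redirected consistently. If `block_sref_reuse` is given, every mutated block is recorded in it
// so that the schedule can reuse the corresponding block srefs.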
ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
                                           Map<Block, Block>* block_sref_reuse)
    : block_sref_reuse_(block_sref_reuse) {
  buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
}

ReplaceBufferMutator::ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map,
                                           Map<Block, Block>* block_sref_reuse)
    : block_sref_reuse_(block_sref_reuse) {
  for (const auto& [old_buffer, new_buffer] : buffer_map) {
    buffer_var_map_[old_buffer->data.get()] = new_buffer;
  }
}

PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) {
  auto it = buffer_var_map_.find(var);
  return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
}

Stmt ReplaceBufferMutator::VisitStmt_(const BufferStoreNode* op) {
  auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
  return VisitBufferAccess(std::move(node));
}

PrimExpr ReplaceBufferMutator::VisitExpr_(const BufferLoadNode* op) {
  auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
  return VisitBufferAccess(std::move(node));
}

MatchBufferRegion ReplaceBufferMutator::VisitMatchBufferRegion(
    const MatchBufferRegion& match_buffer) {
  auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
  if (it != buffer_var_map_.end()) {
    return MatchBufferRegion(match_buffer->buffer,
                             BufferRegion(it->second, match_buffer->source->region));
  } else {
    return match_buffer;
  }
}

Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) {
  // To reduce the number of blocks in the block sref reuse map, we check whether the block is
  // really mutated (i.e., an old buffer appears in the block). If so, we return the block after
  // mutation. Otherwise we just return the original block.

  auto f_mutate_match_buffer = [this](const MatchBufferRegion& match_buffer) {
    return this->VisitMatchBufferRegion(match_buffer);
  };
  auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
    auto region = MutateArray(buffer_region->region, [this](const Range& range) {
      PrimExpr min = VisitExpr(range->min);
      PrimExpr extent = VisitExpr(range->extent);
      if (min.same_as(range->min) && extent.same_as(range->extent)) {
        return range;
      } else {
        return Range::FromMinExtent(min, extent);
      }
    });

    Buffer buf = [&]() {
      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
      if (it == buffer_var_map_.end()) {
        return buffer_region->buffer;
      } else {
        return it->second;
      }
    }();

    if (buf.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
      return buffer_region;
    } else {
      return BufferRegion(buf, region);
    }
  };
  auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
    auto it = buffer_var_map_.find(buffer->data.get());
    return it == buffer_var_map_.end() ? buffer : it->second;
  };

  // Step 1. Mutate `match_buffers`. If an old buffer appears as the source of a
  // MatchBufferRegion, it is replaced with the new buffer.
  Array<MatchBufferRegion> match_buffers = block->match_buffers.Map(f_mutate_match_buffer);
  // Step 2. Mutate the read/write regions.
  Array<BufferRegion> reads = block->reads.Map(f_mutate_read_write_region);
  Array<BufferRegion> writes = block->writes.Map(f_mutate_read_write_region);
  // Step 3. Mutate `alloc_buffers` for the old buffers allocated in this block.
  Array<Buffer> alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers);
  // Step 4. Recursively mutate the block.
  Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));

  if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
      writes.same_as(mutated_block->writes) &&
      alloc_buffers.same_as(mutated_block->alloc_buffers) &&
      match_buffers.same_as(mutated_block->match_buffers)) {
    return GetRef<Block>(block);
  } else {
    ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
    n->reads = std::move(reads);
    n->writes = std::move(writes);
    n->alloc_buffers = std::move(alloc_buffers);
    n->match_buffers = std::move(match_buffers);

    Block new_block(n);
    if (block_sref_reuse_ != nullptr) {
      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
    }
    return std::move(new_block);
  }
}

/******** Block Removal ********/

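/*!
 * \brief Compute the replacement needed to remove a leaf block from its scope: walk upwards to
 *        the nearest ancestor with more than one child, then produce that subtree with the
 *        leaf's branch removed. Throws OnlyLeafError if the leaf is the only one in its scope.
 * \param self The schedule state.
 * \param leaf_block_sref The sref of the leaf block to be removed.
 * \param src_stmt Output parameter: the statement to be replaced.
 * \param tgt_stmt Output parameter: the statement after removal.
 *
 * An illustrative sketch of the intended use (the caller code here is hypothetical):
 *   Stmt src_stmt{nullptr}, tgt_stmt{nullptr};
 *   LeafBlockRemovalPlan(self, leaf_block_sref, &src_stmt, &tgt_stmt);
 *   self->Replace(self->stmt2ref.at(src_stmt.get()), tgt_stmt, /*block_sref_reuse=*/{});
 */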
void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
                          Stmt* src_stmt, Stmt* tgt_stmt) {
  class OnlyLeafError : public ScheduleError {
   public:
    explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root)
        : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {}

    String FastErrorString() const final {
      return "ScheduleError: Cannot remove the only leaf in the scope";
    }

    String DetailRenderTemplate() const final {
      return "Block {0} is the only leaf in the scope {1}, which cannot be removed; otherwise the "
             "scope will be empty.";
    }

    IRModule mod() const final { return mod_; }
    Array<ObjectRef> LocationsOfInterest() const final { return {leaf_block_, scope_root_}; }

    IRModule mod_;
    Block leaf_block_;
    Block scope_root_;
  };

  // Go upwards until we find an ancestor with more than one child.
  const StmtNode* last_stmt = leaf_block_sref->stmt;
  StmtSRefNode* sref = leaf_block_sref->parent;
  for (;; last_stmt = sref->stmt, sref = sref->parent) {
    if (const auto* loop = sref->StmtAs<ForNode>()) {
      if (const auto* seq = loop->body.as<SeqStmtNode>()) {
        if (seq->size() > 1) {
          break;
        }
      }
    } else {
      // Removal is not done beyond the scope level: when we encounter a block,
      // i.e. the scope root, we simply stop.
      break;
    }
  }
  if (const auto* block = sref->StmtAs<BlockNode>()) {
    auto body = block->body;
    // Peel off AllocateConst nodes at the beginning of the block body.
    std::vector<const AllocateConstNode*> allocs;
    while (const auto* alloc = body.as<AllocateConstNode>()) {
      allocs.push_back(alloc);
      body = alloc->body;
    }
    if (const auto* seq = body.as<SeqStmtNode>()) {
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
      auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
      // Re-attach the peeled AllocateConst nodes, innermost first.
      auto new_body = new_seq;
      for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
        auto alloc = allocs[allocs.size() - 1 - i];
        new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data,
                                 new_body, alloc->annotations, alloc->span);
      }
      n->body = new_body;
      *src_stmt = GetRef<Stmt>(block);
      *tgt_stmt = Stmt(std::move(n));
      return;
    }
  }
  if (const auto* loop = sref->StmtAs<ForNode>()) {
    if (const auto* seq = loop->body.as<SeqStmtNode>()) {
      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
      n->body = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
      *src_stmt = GetRef<Stmt>(loop);
      *tgt_stmt = Stmt(std::move(n));
      return;
    }
  }
  ICHECK(sref != nullptr && sref->stmt != nullptr);
  const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref);
  const auto* scope_block = TVM_SREF_TO_BLOCK(sref);
  throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
}

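/*!
 * \brief Tile a block so that its innermost loops match the loop nest of a tensor intrinsic's
 *        description, making the block ready for tensorization. Each mapped loop is split with
 *        the intrinsic's extent as the inner factor, and the inner loops are reordered to match
 *        the intrinsic's loop order.
 * \param sch The schedule to be mutated.
 * \param block_rv The block to be tiled.
 * \param intrin_name The name of a registered tensor intrinsic.
 * \param allow_padding Whether to pad the block iterations (via PadEinsum) so that the extents
 *        become divisible by the intrinsic's extents.
 * \return The outermost of the reordered inner loops, i.e. the loop to tensorize at, or NullOpt
 *         if no valid mapping to the intrinsic is found.
 *
 * A minimal usage sketch (the intrinsic name is a placeholder for any registered intrinsic):
 *   if (Optional<LoopRV> loop = TileWithTensorIntrin(sch, block, "some_registered_intrin")) {
 *     sch->Tensorize(loop.value(), "some_registered_intrin");
 *   }
 */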
Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                      const String& intrin_name, bool allow_padding) {
  Optional<tir::TensorizeInfo> opt_tensorize_info =
      GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv),
                              tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding);
  if (!opt_tensorize_info) return NullOpt;
  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
  if (info->block_iter_paddings.defined()) {
    sch->PadEinsum(block_rv, info->block_iter_paddings.value());
  }
  // Construct a mapping from tir loops back to LoopRVs.
  Map<tir::StmtSRef, LoopRV> loop2rv;
  {
    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
    }
  }
  // Split the loops.
  arith::Analyzer analyzer;
  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
  std::vector<LoopRV> reorder_suffix;
  reorder_suffix.resize(info->loop_map.size());
  for (const auto& kv : info->loop_map) {
    // Extract the mapping (block_loop => desc_loop).
    const tir::StmtSRef& block_loop_sref = kv.first;
    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
    const tir::ForNode* desc_loop = kv.second.get();
    ICHECK(block_loop != nullptr && desc_loop != nullptr);
    // Extract the loop extents.
    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
    const auto* int_block_extent = block_extent.as<IntImmNode>();
    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
    // Check divisibility.
    int64_t total = int_block_extent->value;
    int64_t inner = int_desc_extent->value;
    ICHECK_EQ(total % inner, 0);
    // Do the split. Leave the outer extent as NullOpt (unspecified) so that the split factors
    // can be used for different extents (needed during tuning).
    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {NullOpt, Integer(inner)});
    ICHECK_EQ(split.size(), 2);
    inner_loops.insert(sch->GetSRef(split[1]).operator->());
    // The inner loop of the split will be reordered into the loop domain that is tensorized.
    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop)).IntValue();
    reorder_suffix[desc_loop_index] = split[1];
  }
  // Reorder the loops.
  std::vector<LoopRV> reorder_list;
  bool meet = false;
  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
  for (const LoopRV& loop : all_loops) {
    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
      meet = true;
    } else if (meet) {
      reorder_list.push_back(loop);
    }
  }
  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
  sch->Reorder(reorder_list);
  ICHECK(!reorder_suffix.empty());
  return reorder_suffix[0];
}

TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);

/******** BlockBufferAccessSimplifier ********/
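// BlockBufferAccessSimplifier simplifies buffer access regions and indices inside blocks using
// the bound analyzer. When a region dimension has extent 1 and its min is a plain variable, only
// the min is simplified and the extent is normalized to a constant 1 of the index dtype;
// otherwise both min and extent are simplified.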
void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_access_regions) {
  auto fmutate = [this](const BufferRegion& buffer_region) {
    std::vector<Range> new_buffer_region;
    for (const auto& range : buffer_region->region) {
      if (is_one(range->extent) && range->min->IsInstance<VarNode>()) {
        new_buffer_region.push_back(Range::FromMinExtent(
            SimplifyNonTrivialExpr(range->min, analyzer_), make_const(range->min.dtype(), 1)));
      } else {
        new_buffer_region.push_back(
            Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_),
                                 SimplifyNonTrivialExpr(range->extent, analyzer_)));
      }
    }
    return BufferRegion(buffer_region->buffer, new_buffer_region);
  };
  (*old_access_regions).MutateByApply(fmutate);
}

void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array<PrimExpr>* indices) {
  (*indices).MutateByApply(
      [this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, analyzer_); });
}

Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
  Block block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
  auto* n = block.CopyOnWrite();
  SimplifyAccessRegion(&n->reads);
  SimplifyAccessRegion(&n->writes);
  return std::move(block);
}

Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) {
  BufferStore node = Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
  SimplifyBufferIndices(&node.CopyOnWrite()->indices);
  return std::move(node);
}

PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
  BufferLoad node = Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
  SimplifyBufferIndices(&node.CopyOnWrite()->indices);
  return std::move(node);
}

}  // namespace tir
}  // namespace tvm