/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./utils.h"

namespace tvm {
namespace tir {

/******** Annotation ********/

Block WithAnnotation(const BlockNode* block, const String& attr_key, const ObjectRef& attr_value) {
  Map<String, ObjectRef> annotations = block->annotations;
  annotations.Set(attr_key, attr_value);
  ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block);
  new_block->annotations = std::move(annotations);
  return Block(new_block);
}
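
// Usage sketch (illustrative only): attach an annotation to a block without
// mutating the original node. `block_ptr` stands for any valid BlockNode
// pointer; the key/value below are hypothetical placeholders.
//
//   Block annotated = WithAnnotation(block_ptr, "example.custom_hint", Integer(1));
//   // `annotated` carries the new annotation; the block behind `block_ptr` is untouched.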

/******** Buffer Related ********/
Buffer WithScope(const Buffer& buffer, const String& scope) {
  ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
  ObjectPtr<VarNode> new_var = make_object<VarNode>(*buffer->data.get());
  const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
  new_var->type_annotation = PointerType(ptr_type->element_type, scope);
  new_buffer->data = Var(new_var->name_hint + "_" + scope, new_var->type_annotation);
  new_buffer->name = buffer->name + "_" + scope;
  return Buffer(new_buffer);
}
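
// Usage sketch (illustrative only): derive a shared-memory copy of an existing
// buffer. Assuming `buf` is a Buffer named "A", the result is a new buffer
// named "A_shared" whose data var carries the "shared" storage scope.
//
//   Buffer shared_buf = WithScope(buf, "shared");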

Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
                                  const Buffer& target) {
  regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {
    if (region->buffer.same_as(source)) {
      ObjectPtr<BufferRegionNode> n = make_object<BufferRegionNode>(*region.get());
      n->buffer = target;
      return BufferRegion(n);
    }
    return region;
  });
  return regions;
}

Array<MatchBufferRegion> ReplaceBuffer(Array<MatchBufferRegion> match_buffers, const Buffer& source,
                                       const Buffer& target) {
  match_buffers.MutateByApply([&source,
                               &target](MatchBufferRegion match_buffer) -> MatchBufferRegion {
    if (match_buffer->source->buffer.same_as(source)) {
      ObjectPtr<MatchBufferRegionNode> n = make_object<MatchBufferRegionNode>(*match_buffer.get());
      n->source = BufferRegion(target, n->source->region);
      return MatchBufferRegion(n);
    }
    return match_buffer;
  });
  return match_buffers;
}
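
// Usage sketch (illustrative only): rewrite the read regions of a block so
// that every region touching `old_buf` refers to `new_buf` instead, keeping
// the region ranges intact. `blk` is a hypothetical Block handle.
//
//   Array<BufferRegion> new_reads = ReplaceBuffer(blk->reads, old_buf, new_buf);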

Array<BufferRegion> ReplaceBufferRegion(Array<BufferRegion> regions, const Buffer& source_buffer,
                                        const BufferRegion& target) {
  regions.MutateByApply([&source_buffer, &target](const BufferRegion& region) -> BufferRegion {
    if (region->buffer.same_as(source_buffer)) {
      return target;
    }
    return region;
  });
  return regions;
}

Array<MatchBufferRegion> ReplaceBufferRegion(Array<MatchBufferRegion> match_buffers,
                                             const Buffer& source_buffer,
                                             const BufferRegion& target) {
  match_buffers.MutateByApply([&source_buffer, &target](
                                  const MatchBufferRegion& match_buffer) -> MatchBufferRegion {
    if (match_buffer->source->buffer.same_as(source_buffer)) {
      ObjectPtr<MatchBufferRegionNode> n = make_object<MatchBufferRegionNode>(*match_buffer.get());
      n->source = target;
      return MatchBufferRegion(n);
    }
    return match_buffer;
  });
  return match_buffers;
}
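
// Usage sketch (illustrative only): unlike ReplaceBuffer, which swaps the
// buffer but preserves each region's ranges, ReplaceBufferRegion substitutes
// the entire region. E.g., redirecting every region of `old_buf` to one fixed
// region `new_region` (a hypothetical BufferRegion):
//
//   Array<BufferRegion> new_writes = ReplaceBufferRegion(blk->writes, old_buf, new_region);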

/******** ReplaceBufferMutator ********/
ReplaceBufferMutator::ReplaceBufferMutator(const Buffer& old_buffer, Buffer new_buffer,
                                           Map<Block, Block>* block_sref_reuse)
    : block_sref_reuse_(block_sref_reuse) {
  buffer_var_map_[old_buffer->data.get()] = std::move(new_buffer);
}

ReplaceBufferMutator::ReplaceBufferMutator(const Map<Buffer, Buffer>& buffer_map,
                                           Map<Block, Block>* block_sref_reuse)
    : block_sref_reuse_(block_sref_reuse) {
  for (const auto& [old_buffer, new_buffer] : buffer_map) {
    buffer_var_map_[old_buffer->data.get()] = new_buffer;
  }
}

PrimExpr ReplaceBufferMutator::VisitExpr_(const VarNode* var) {
  auto it = buffer_var_map_.find(var);
  return it != buffer_var_map_.end() ? it->second->data : GetRef<Var>(var);
}

Stmt ReplaceBufferMutator::VisitStmt_(const BufferStoreNode* op) {
  auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
  return VisitBufferAccess(std::move(node));
}

PrimExpr ReplaceBufferMutator::VisitExpr_(const BufferLoadNode* op) {
  auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
  return VisitBufferAccess(std::move(node));
}

MatchBufferRegion ReplaceBufferMutator::VisitMatchBufferRegion(
    const MatchBufferRegion& match_buffer) {
  auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
  if (it != buffer_var_map_.end()) {
    return MatchBufferRegion(match_buffer->buffer,
                             BufferRegion(it->second, match_buffer->source->region));
  } else {
    return match_buffer;
  }
}

Stmt ReplaceBufferMutator::VisitStmt_(const BlockNode* block) {
  // To reduce the number of blocks in the block sref reuse map, we check whether the block is
  // actually mutated (i.e., the old buffer appears in the block). If so, we return the block
  // after mutation; otherwise we return the original block.

  auto f_mutate_match_buffer = [this](const MatchBufferRegion& match_buffer) {
    return this->VisitMatchBufferRegion(match_buffer);
  };
  auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
    auto region = MutateArray(buffer_region->region, [this](const Range& range) {
      PrimExpr min = VisitExpr(range->min);
      PrimExpr extent = VisitExpr(range->extent);
      if (min.same_as(range->min) && extent.same_as(range->extent)) {
        return range;
      } else {
        return Range::FromMinExtent(min, extent);
      }
    });

    Buffer buf = [&]() {
      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
      if (it == buffer_var_map_.end()) {
        return buffer_region->buffer;
      } else {
        return it->second;
      }
    }();

    if (buf.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
      return buffer_region;
    } else {
      return BufferRegion(buf, region);
    }
  };
  auto f_mutate_alloc_buffers = [this](const Buffer& buffer) {
    auto it = buffer_var_map_.find(buffer->data.get());
    return it == buffer_var_map_.end() ? buffer : it->second;
  };
  // Step 1. Mutate `match_buffers`. If an old buffer appears as the source of a
  // MatchBufferRegion, replace it with the new buffer while keeping the region.
  Array<MatchBufferRegion> match_buffers = block->match_buffers.Map(f_mutate_match_buffer);
  // Step 2. Mutate the read/write regions.
  Array<BufferRegion> reads = block->reads.Map(f_mutate_read_write_region);
  Array<BufferRegion> writes = block->writes.Map(f_mutate_read_write_region);
  // Step 3. Mutate `alloc_buffers` for the old buffer allocated in this block.
  Array<Buffer> alloc_buffers = block->alloc_buffers.Map(f_mutate_alloc_buffers);
  // Step 4. Recursively mutate the block.
  Block mutated_block = Downcast<Block>(StmtMutator::VisitStmt_(block));

  if (mutated_block.get() == block && reads.same_as(mutated_block->reads) &&
      writes.same_as(mutated_block->writes) &&
      alloc_buffers.same_as(mutated_block->alloc_buffers) &&
      match_buffers.same_as(mutated_block->match_buffers)) {
    return GetRef<Block>(block);
  } else {
    ObjectPtr<BlockNode> n = CopyOnWrite(mutated_block.get());
    n->reads = std::move(reads);
    n->writes = std::move(writes);
    n->alloc_buffers = std::move(alloc_buffers);
    n->match_buffers = std::move(match_buffers);

    Block new_block(n);
    if (block_sref_reuse_ != nullptr) {
      block_sref_reuse_->Set(GetRef<Block>(block), new_block);
    }
    return std::move(new_block);
  }
}
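
// Usage sketch (illustrative only): redirect every access of `old_buf` to
// `new_buf` inside a statement, recording which blocks were rewritten so their
// srefs can be reused afterwards. `old_buf`, `new_buf`, and `body` are
// hypothetical names; applying the mutator via operator() assumes the usual
// StmtExprMutator calling convention.
//
//   Map<Block, Block> block_sref_reuse;
//   ReplaceBufferMutator mutator(old_buf, new_buf, &block_sref_reuse);
//   Stmt rewritten = mutator(std::move(body));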

/******** Block Removal ********/

void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_sref,
                          Stmt* src_stmt, Stmt* tgt_stmt) {
  class OnlyLeafError : public ScheduleError {
   public:
    explicit OnlyLeafError(IRModule mod, Block leaf_block, Block scope_root)
        : mod_(mod), leaf_block_(leaf_block), scope_root_(scope_root) {}

    String FastErrorString() const final {
      return "ScheduleError: Cannot remove the only leaf in the scope";
    }

    String DetailRenderTemplate() const final {
      return "Block {0} is the only leaf in the scope {1}, so it cannot be removed; otherwise "
             "the scope would become empty.";
    }

    IRModule mod() const final { return mod_; }
    Array<ObjectRef> LocationsOfInterest() const final { return {leaf_block_, scope_root_}; }

    IRModule mod_;
    Block leaf_block_;
    Block scope_root_;
  };

  // Go upwards until we find an ancestor with more than one child
  const StmtNode* last_stmt = leaf_block_sref->stmt;
  StmtSRefNode* sref = leaf_block_sref->parent;
  for (;; last_stmt = sref->stmt, sref = sref->parent) {
    if (const auto* loop = sref->StmtAs<ForNode>()) {
      if (const auto* seq = loop->body.as<SeqStmtNode>()) {
        if (seq->size() > 1) {
          break;
        }
      }
    } else {
      // Removal does not go beyond the scope level: when we encounter a block,
      // i.e. the scope root, we simply stop.
      break;
    }
  }
  if (const auto* block = sref->StmtAs<BlockNode>()) {
    auto body = block->body;
    // Peel off AllocateConst nodes at the beginning of the block body.
    std::vector<const AllocateConstNode*> allocs;
    while (const auto* alloc = body.as<AllocateConstNode>()) {
      allocs.push_back(alloc);
      body = alloc->body;
    }
    if (const auto* seq = body.as<SeqStmtNode>()) {
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
      auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
      // Re-attach the AllocateConst nodes; wrapping innermost-first restores their
      // original nesting order.
      auto new_body = new_seq;
      for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
        auto alloc = allocs[allocs.size() - 1 - i];
        new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data,
                                 new_body, alloc->annotations, alloc->span);
      }
      n->body = new_body;
      *src_stmt = GetRef<Stmt>(block);
      *tgt_stmt = Stmt(std::move(n));
      return;
    }
  }
  if (const auto* loop = sref->StmtAs<ForNode>()) {
    if (const auto* seq = loop->body.as<SeqStmtNode>()) {
      ObjectPtr<ForNode> n = make_object<ForNode>(*loop);
      n->body = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
      *src_stmt = GetRef<Stmt>(loop);
      *tgt_stmt = Stmt(std::move(n));
      return;
    }
  }
  ICHECK(sref != nullptr && sref->stmt != nullptr);
  const auto* leaf_block = TVM_SREF_TO_BLOCK(leaf_block_sref);
  const auto* scope_block = TVM_SREF_TO_BLOCK(sref);
  throw OnlyLeafError(self->mod, GetRef<Block>(leaf_block), GetRef<Block>(scope_block));
}
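
// Usage sketch (illustrative only): compute the replacement pair for removing
// a leaf block, then apply it through ScheduleState::Replace. `self` and
// `leaf_sref` are hypothetical handles; the empty map means no sref reuse.
//
//   Stmt src_stmt, tgt_stmt;
//   LeafBlockRemovalPlan(self, leaf_sref, &src_stmt, &tgt_stmt);
//   self->Replace(self->stmt2ref.at(src_stmt.get()), tgt_stmt, {});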

Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
                                      const String& intrin_name, bool allow_padding) {
  Optional<tir::TensorizeInfo> opt_tensorize_info =
      GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv),
                              tir::TensorIntrin::Get(intrin_name).value()->desc, allow_padding);
  if (!opt_tensorize_info) return NullOpt;
  const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
  if (info->block_iter_paddings.defined()) {
    sch->PadEinsum(block_rv, info->block_iter_paddings.value());
  }
  // Construct a mapping from tir loops back to LoopRVs
  Map<tir::StmtSRef, LoopRV> loop2rv;
  {
    Array<LoopRV> loop_rvs = sch->GetLoops(block_rv);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop2rv.Set(sch->GetSRef(loop_rv), loop_rv);
    }
  }
  // Split the loops
  arith::Analyzer analyzer;
  std::unordered_set<const tir::StmtSRefNode*> inner_loops;
  std::vector<LoopRV> reorder_suffix;
  reorder_suffix.resize(info->loop_map.size());
  for (const auto& kv : info->loop_map) {
    // Extract mapping (block_loop => desc_loop)
    const tir::StmtSRef& block_loop_sref = kv.first;
    const tir::ForNode* block_loop = block_loop_sref->StmtAs<tir::ForNode>();
    const tir::ForNode* desc_loop = kv.second.get();
    ICHECK(block_loop != nullptr && desc_loop != nullptr);
    // Extract the loop extent
    PrimExpr block_extent = analyzer.Simplify(block_loop->extent);
    PrimExpr desc_extent = analyzer.Simplify(desc_loop->extent);
    const auto* int_block_extent = block_extent.as<IntImmNode>();
    const auto* int_desc_extent = desc_extent.as<IntImmNode>();
    ICHECK(int_block_extent != nullptr && int_desc_extent != nullptr);
    // Check divisibility
    int64_t total = int_block_extent->value;
    int64_t inner = int_desc_extent->value;
    ICHECK_EQ(total % inner, 0);
    // Do the split. Leave the outer extent as NullOpt (unspecified) so that the split factors
    // can be used for different extents (needed during tuning).
    Array<LoopRV> split = sch->Split(loop2rv.at(block_loop_sref), {NullOpt, Integer(inner)});
    ICHECK_EQ(split.size(), 2);
    inner_loops.insert(sch->GetSRef(split[1]).operator->());
    // The inner split will be reordered to the loop domain that is tensorized
    int desc_loop_index = info->desc_loop_indexer.at(GetRef<tir::For>(desc_loop)).IntValue();
    reorder_suffix[desc_loop_index] = split[1];
  }
  // Reorder the loops
  std::vector<LoopRV> reorder_list;
  bool meet = false;
  Array<LoopRV> all_loops = sch->GetLoops(block_rv);
  for (const LoopRV& loop : all_loops) {
    if (inner_loops.count(sch->GetSRef(loop).operator->())) {
      meet = true;
    } else if (meet) {
      reorder_list.push_back(loop);
    }
  }
  reorder_list.insert(reorder_list.end(), reorder_suffix.begin(), reorder_suffix.end());
  sch->Reorder(reorder_list);
  ICHECK(!reorder_suffix.empty());
  return reorder_suffix[0];
}

TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWithTensorIntrin);
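
// Usage sketch (illustrative only): tile a block so its innermost loops match
// a registered tensor intrinsic, then tensorize at the returned loop. The
// block name "my_block" and intrin name "example.intrin" are hypothetical
// placeholders for a real schedule.
//
//   tir::BlockRV block = sch->GetBlock("my_block");
//   if (Optional<LoopRV> loop =
//           TileWithTensorIntrin(sch, block, "example.intrin", /*allow_padding=*/false)) {
//     sch->Tensorize(loop.value(), "example.intrin");
//   }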

/******** BlockBufferAccessSimplifier ********/
void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_access_regions) {
  auto fmutate = [this](const BufferRegion& buffer_region) {
    std::vector<Range> new_buffer_region;
    for (const auto& range : buffer_region->region) {
      if (is_one(range->extent) && range->min->IsInstance<VarNode>()) {
        new_buffer_region.push_back(Range::FromMinExtent(
            SimplifyNonTrivialExpr(range->min, analyzer_), make_const(range->min.dtype(), 1)));
      } else {
        new_buffer_region.push_back(
            Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_),
                                 SimplifyNonTrivialExpr(range->extent, analyzer_)));
      }
    }
    return BufferRegion(buffer_region->buffer, new_buffer_region);
  };
  (*old_access_regions).MutateByApply(fmutate);
}

void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array<PrimExpr>* indices) {
  (*indices).MutateByApply(
      [this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, analyzer_); });
}

Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
  Block block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
  auto* n = block.CopyOnWrite();
  SimplifyAccessRegion(&n->reads);
  SimplifyAccessRegion(&n->writes);
  return std::move(block);
}

Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) {
  BufferStore node = Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
  SimplifyBufferIndices(&node.CopyOnWrite()->indices);
  return std::move(node);
}

PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
  BufferLoad node = Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
  SimplifyBufferIndices(&node.CopyOnWrite()->indices);
  return std::move(node);
}
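
// Usage sketch (illustrative only): simplify the buffer access regions and
// indices in a PrimFunc body. BlockBufferAccessSimplifier derives from
// arith::IRMutatorWithAnalyzer and so needs an analyzer; the static
// `Simplify` entry point shown here is a hypothetical convenience wrapper.
//
//   arith::Analyzer analyzer;
//   Stmt simplified = BlockBufferAccessSimplifier::Simplify(func->body, &analyzer);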

}  // namespace tir
}  // namespace tvm