/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file stmt_functor.cc
 */
#include <tvm/ir/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <functional>

#include "functor_common.h"

namespace tvm {
namespace tir {

void StmtVisitor::VisitStmt_(const LetStmtNode* op) {
  this->VisitExpr(op->value);
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const AttrStmtNode* op) {
  this->VisitExpr(op->value);
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const ForNode* op) {
  this->VisitExpr(op->min);
  this->VisitExpr(op->extent);
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const WhileNode* op) {
  this->VisitExpr(op->condition);
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const AllocateNode* op) {
  VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
  this->VisitStmt(op->body);
  this->VisitExpr(op->condition);
}

void StmtVisitor::VisitStmt_(const AllocateConstNode* op) {
  VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); });
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); }

void StmtVisitor::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
  this->VisitExpr(op->value);
  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}

void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
  VisitArray(op->bounds, [this](const Range& r) {
    this->VisitExpr(r->min);
    this->VisitExpr(r->extent);
  });
  this->VisitExpr(op->condition);
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
  this->VisitExpr(op->condition);
  this->VisitStmt(op->then_case);
  if (op->else_case) {
    this->VisitStmt(op->else_case.value());
  }
}

void StmtVisitor::VisitStmt_(const AssertStmtNode* op) {
  this->VisitExpr(op->condition);
  this->VisitExpr(op->message);
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) {
  VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
  this->VisitExpr(op->value);
}

void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) {
  VisitArray(op->bounds, [this](const Range& r) {
    this->VisitExpr(r->min);
    this->VisitExpr(r->extent);
  });
  this->VisitStmt(op->body);
  this->VisitExpr(op->condition);
}

void StmtVisitor::VisitStmt_(const PrefetchNode* op) {
  VisitArray(op->bounds, [this](const Range& r) {
    this->VisitExpr(r->min);
    this->VisitExpr(r->extent);
  });
}

void StmtVisitor::VisitStmt_(const SeqStmtNode* op) {
  VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); });
}

void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); }

void StmtVisitor::VisitStmt_(const BlockNode* op) {
  auto fvisit_buffer_region = [this](const BufferRegion& s) {
    for (const auto& range : s->region) {
      this->VisitExpr(range->min);
      this->VisitExpr(range->extent);
    }
  };
  VisitArray(op->iter_vars, [this](const IterVar& iter_var) {
    this->VisitExpr(iter_var->dom->min);
    this->VisitExpr(iter_var->dom->extent);
  });
  VisitArray(op->reads, fvisit_buffer_region);
  VisitArray(op->writes, fvisit_buffer_region);
  VisitArray(op->match_buffers,
             [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
               fvisit_buffer_region(match_buffer_region->source);
             });
  if (op->init.defined()) {
    this->VisitStmt(op->init.value());
  }
  this->VisitStmt(op->body);
}

void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) {
  VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); });
  this->VisitExpr(op->predicate);
  this->VisitStmt(op->block);
}
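
// The VisitStmt_ overloads above define the default traversal: each handler
// simply recurses into every child expression and statement of its node. A
// concrete analysis only overrides the handlers it cares about and lets the
// base class walk the rest. A minimal sketch (illustrative only, not part of
// this file): counting For nodes in a statement.
//
//   class ForCounter : public StmtVisitor {
//    public:
//     int count = 0;
//
//    private:
//     void VisitStmt_(const ForNode* op) final {
//       ++count;
//       StmtVisitor::VisitStmt_(op);  // keep recursing into the loop body
//     }
//   };
//
//   // usage: ForCounter counter; counter(stmt); -> counter.count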

class StmtMutator::Internal {
 public:
  /*!
   * \brief Mutate the elements of an array with the fmutate function.
   *
   * \note Take extra care with the copy-on-write setting.
   *
   * In particular, consider the following case of two reference chains:
   * - strongref0 -> loop0 -> loop1 -> loop2
   * - strongref1 -> loop3 -> loop1 -> loop2
   *
   * Think of calling MutateArray on loop1->loop2 (as a const reference).
   * While both strongref0 and strongref1 exist, the context does not allow copy
   * on write, even though loop1 uniquely refers to loop2.
   *
   * \param self The pointer to the mutator.
   * \param arr The array to be mutated; a const reference is used to allow copy-on-write
   *            mutation in a recursive visitor.
   * \param fmutate The mutator function.
   * \return The mutated array; a new copy may be created.
   */
  template <typename T, typename F>
  static Array<T> MutateArray(StmtMutator* self, const Array<T>& arr, F fmutate) {
    if (self->allow_copy_on_write_ && arr.unique()) {
      // if we allow copy on write, we can directly
      // call the inplace mutate function.
      const_cast<Array<T>&>(arr).MutateByApply(fmutate);
      return arr;
    } else {
      bool allow_cow = false;
      std::swap(allow_cow, self->allow_copy_on_write_);
      Array<T> copy = arr.Map(fmutate);
      std::swap(allow_cow, self->allow_copy_on_write_);
      return copy;
    }
  }

  static Array<IterVar> Mutate(StmtMutator* self, const Array<IterVar>& arr) {
    auto fmutate = [self](const IterVar& iter_var) {
      PrimExpr min = self->VisitExpr(iter_var->dom->min);
      PrimExpr extent = self->VisitExpr(iter_var->dom->extent);
      if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) {
        return iter_var;
      } else {
        return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type,
                       iter_var->thread_tag);
      }
    };
    return MutateArray(self, arr, fmutate);
  }

  static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) {
    auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); };
    return MutateArray(self, arr, fmutate);
  }

  static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) {
    auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); };
    return MutateArray(self, arr, fmutate);
  }

  static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) {
    auto fmutate = [self](const Range& r) {
      PrimExpr min = self->VisitExpr(r->min);
      PrimExpr extent = self->VisitExpr(r->extent);
      if (min.same_as(r->min) && extent.same_as(r->extent)) {
        return r;
      } else {
        return Range::FromMinExtent(min, extent);
      }
    };
    return MutateArray(self, arr, fmutate);
  }

  static Array<BufferRegion> Mutate(StmtMutator* self, const Array<BufferRegion>& arr) {
    auto fmutate = [self](const BufferRegion& buffer_region) {
      Array<Range> region = Mutate(self, buffer_region->region);
      if (region.same_as(buffer_region->region)) {
        return buffer_region;
      } else {
        return BufferRegion(buffer_region->buffer, region);
      }
    };
    return MutateArray(self, arr, fmutate);
  }

  static Array<MatchBufferRegion> Mutate(StmtMutator* self, const Array<MatchBufferRegion>& arr) {
    auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
      Array<Range> region = Mutate(self, match_buffer_region->source->region);
      if (region.same_as(match_buffer_region->source->region)) {
        return match_buffer_region;
      } else {
        return MatchBufferRegion(match_buffer_region->buffer,
                                 BufferRegion(match_buffer_region->source->buffer, region));
      }
    };
    return MutateArray(self, arr, fmutate);
  }
};
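
// Every StmtMutator::VisitStmt_ overload below follows the same copy-on-write
// pattern: mutate the children first, return the original node if nothing
// changed, and otherwise obtain a writable node via CopyOnWrite(op), which
// reuses the node in place when it is uniquely referenced and
// allow_copy_on_write_ is set. A minimal sketch of the pattern, using a
// hypothetical FooNode with a single `body` field for illustration:
//
//   Stmt StmtMutator::VisitStmt_(const FooNode* op) {
//     Stmt body = this->VisitStmt(op->body);
//     if (body.same_as(op->body)) {
//       return GetRef<Stmt>(op);   // unchanged: hand back the original node
//     } else {
//       auto n = CopyOnWrite(op);  // writable copy (or in-place if unique)
//       n->body = std::move(body);
//       return Stmt(n);
//     }
//   }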

Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) {
  PrimExpr value = this->VisitExpr(op->value);
  Stmt body = this->VisitStmt(op->body);
  if (value.same_as(op->value) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->value = std::move(value);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) {
  PrimExpr value = this->VisitExpr(op->value);
  Stmt body = this->VisitStmt(op->body);
  if (value.same_as(op->value) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->value = std::move(value);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const ForNode* op) {
  PrimExpr min = this->VisitExpr(op->min);
  PrimExpr extent = this->VisitExpr(op->extent);
  Stmt body = this->VisitStmt(op->body);
  if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->min = std::move(min);
    n->extent = std::move(extent);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const WhileNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition);
  Stmt body = this->VisitStmt(op->body);
  if (condition.same_as(op->condition) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->condition = std::move(condition);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const AllocateNode* op) {
  Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
  Stmt body = this->VisitStmt(op->body);
  PrimExpr condition = this->VisitExpr(op->condition);

  if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->extents = std::move(extents);
    n->body = std::move(body);
    n->condition = std::move(condition);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) {
  Array<PrimExpr> extents = Internal::Mutate(this, op->extents);
  Stmt body = this->VisitStmt(op->body);

  if (extents.same_as(op->extents) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->extents = std::move(extents);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
  Stmt body = this->VisitStmt(op->body);

  if (body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition);
  Stmt then_case = this->VisitStmt(op->then_case);
  Optional<Stmt> else_case = NullOpt;
  if (op->else_case) {
    else_case = this->VisitStmt(op->else_case.value());
  }
  if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
      else_case.same_as(op->else_case)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->condition = std::move(condition);
    n->then_case = std::move(then_case);
    n->else_case = std::move(else_case);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
  PrimExpr value = this->VisitExpr(op->value);
  Array<PrimExpr> indices = Internal::Mutate(this, op->indices);

  if (value.same_as(op->value) && indices.same_as(op->indices)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->value = std::move(value);
    n->indices = std::move(indices);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) {
  Region bounds = Internal::Mutate(this, op->bounds);
  PrimExpr condition = this->VisitExpr(op->condition);
  Stmt body = this->VisitStmt(op->body);

  if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->bounds = std::move(bounds);
    n->condition = std::move(condition);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) {
  Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
  PrimExpr value = this->VisitExpr(op->value);
  if (indices.same_as(op->indices) && value.same_as(op->value)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->indices = std::move(indices);
    n->value = std::move(value);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) {
  Region bounds = Internal::Mutate(this, op->bounds);
  Stmt body = this->VisitStmt(op->body);
  PrimExpr condition = this->VisitExpr(op->condition);
  if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->bounds = std::move(bounds);
    n->body = std::move(body);
    n->condition = std::move(condition);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) {
  Region bounds = Internal::Mutate(this, op->bounds);
  if (bounds.same_as(op->bounds)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->bounds = std::move(bounds);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) {
  Array<Stmt> seq = Internal::Mutate(this, op->seq);
  if (seq.same_as(op->seq)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->seq = std::move(seq);
    return Stmt(n);
  }
}

// advanced visit function for seqstmt.
Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
                                std::function<Stmt(const Stmt&)> fmutate) {
  if (flatten_before_visit) {
    // Pass 1, check if we need to flatten.
    bool need_flatten = false;
    for (size_t i = 0; i < op->seq.size(); ++i) {
      Stmt tmp = (*op)[i];
      if (tmp.as<SeqStmtNode>()) need_flatten = true;
    }
    flatten_before_visit = need_flatten;
  }
  // function to run the visit.
  auto frunvisit = [&](const SeqStmtNode* op) {
    Array<Stmt> seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate)
                                         : Internal::Mutate(this, op->seq);
    if (seq.same_as(op->seq)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = CopyOnWrite(op);
      n->seq = std::move(seq);
      return Stmt(n);
    }
  };
  if (flatten_before_visit) {
    Array<Stmt> seq;
    SeqStmt::Flattener flattener(&seq);
    flattener(0, op->seq);
    // NOTE: If copy on write is allowed
    // the assignment to seq below will
    // destruct the original seq.
    //
    // Such destruction removes duplicated reference
    // count to children and still enables COW for
    // child Stmt.
    ObjectPtr<SeqStmtNode> n = CopyOnWrite(op);
    n->seq = std::move(seq);
    return frunvisit(n.operator->());
  } else {
    return frunvisit(op);
  }
}
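
// VisitSeqStmt_ is not part of the dispatch table; a subclass calls it
// explicitly from its own SeqStmtNode handler when it wants nested sequences
// flattened (or a custom per-element rewrite via `fmutate`) before mutation.
// An illustrative sketch, assuming a user-defined MyMutator subclass:
//
//   Stmt MyMutator::VisitStmt_(const SeqStmtNode* op) {
//     // Flatten SeqStmt([a, SeqStmt([b, c])]) into SeqStmt([a, b, c]) before
//     // visiting; nullptr selects the default per-statement mutation.
//     return VisitSeqStmt_(op, /*flatten_before_visit=*/true, nullptr);
//   }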

Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) {
  PrimExpr condition = this->VisitExpr(op->condition);
  PrimExpr message = this->VisitExpr(op->message);
  Stmt body = this->VisitStmt(op->body);

  if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->condition = std::move(condition);
    n->message = std::move(message);
    n->body = std::move(body);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) {
  PrimExpr value = this->VisitExpr(op->value);
  if (value.same_as(op->value)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->value = std::move(value);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const BlockNode* op) {
  Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars);
  Array<BufferRegion> reads = Internal::Mutate(this, op->reads);
  Array<BufferRegion> writes = Internal::Mutate(this, op->writes);
  Array<MatchBufferRegion> match_buffers = Internal::Mutate(this, op->match_buffers);
  Optional<Stmt> init = NullOpt;
  if (op->init.defined()) {
    init = VisitStmt(op->init.value());
  }
  Stmt body = VisitStmt(op->body);
  if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) &&
      body.same_as(op->body) && init.same_as(op->init) &&
      match_buffers.same_as(op->match_buffers)) {
    return GetRef<Block>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->iter_vars = std::move(iter_vars);
    n->reads = std::move(reads);
    n->writes = std::move(writes);
    n->body = std::move(body);
    n->init = std::move(init);
    n->match_buffers = std::move(match_buffers);
    return Stmt(n);
  }
}

Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) {
  Array<PrimExpr> v = Internal::Mutate(this, op->iter_values);
  PrimExpr pred = this->VisitExpr(op->predicate);
  Stmt block = this->VisitStmt(op->block);
  if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) {
    return GetRef<Stmt>(op);
  } else {
    auto n = CopyOnWrite(op);
    n->iter_values = std::move(v);
    n->predicate = std::move(pred);
    n->block = Downcast<Block>(block);
    return Stmt(n);
  }
}

// Implementations of IRTransform, PostOrderVisit and Substitute
class IRApplyVisit : public StmtExprVisitor {
 public:
  explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {}

  void VisitExpr(const PrimExpr& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
    ExprVisitor::VisitExpr(node);
    f_(node);
  }

  void VisitStmt(const Stmt& node) final {
    if (visited_.count(node.get()) != 0) return;
    visited_.insert(node.get());
    StmtVisitor::VisitStmt(node);
    f_(node);
  }

 private:
  std::function<void(const ObjectRef&)> f_;
  std::unordered_set<const Object*> visited_;
};

void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) {
  if (node.as<StmtNode>()) {
    IRApplyVisit visitor(fvisit);
    visitor(Downcast<Stmt>(node));
  } else {
    IRApplyVisit visitor(fvisit);
    visitor(Downcast<PrimExpr>(node));
  }
}
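
// PostOrderVisit invokes `fvisit` on every reachable statement and expression
// node exactly once (duplicates are skipped via the visited_ set), children
// before parents. A usage sketch (illustrative only): collecting every
// variable that occurs in a statement.
//
//   std::vector<const VarNode*> vars;
//   PostOrderVisit(stmt, [&](const ObjectRef& node) {
//     if (const auto* var = node.as<VarNode>()) {
//       vars.push_back(var);
//     }
//   });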

class IRTransformer final : public StmtExprMutator {
 public:
  IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder,
                const std::unordered_set<uint32_t>& only_enable)
      : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {}

  Stmt VisitStmt(const Stmt& stmt) final {
    return MutateInternal<Stmt>(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); });
  }
  PrimExpr VisitExpr(const PrimExpr& expr) final {
    return MutateInternal<PrimExpr>(expr,
                                    [this](const PrimExpr& e) { return this->BaseVisitExpr(e); });
  }

 private:
  // NOTE: redirect to parent's call
  // This is used to get around limitation of gcc-4.8
  Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); }
  PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); }

  template <typename T, typename F>
  T MutateInternal(const T& node, F fmutate) {
    if (only_enable_.size() && !only_enable_.count(node->type_index())) {
      return fmutate(node);
    }
    if (f_preorder_ != nullptr) {
      T pre = f_preorder_(node);
      if (pre.defined()) return pre;
    }
    T new_node = fmutate(node);
    if (f_postorder_ != nullptr) {
      T post = f_postorder_(new_node);
      if (post.defined()) return post;
    }
    return new_node;
  }
  // The functions
  const runtime::PackedFunc& f_preorder_;
  const runtime::PackedFunc& f_postorder_;
  // type indices enabled.
  const std::unordered_set<uint32_t>& only_enable_;
};

Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder,
                 const runtime::PackedFunc& f_postorder, Optional<Array<String>> only_enable) {
  std::unordered_set<uint32_t> only_type_index;
  if (only_enable.defined()) {
    for (auto s : only_enable.value()) {
      only_type_index.insert(Object::TypeKey2Index(s.c_str()));
    }
  }
  IRTransformer transform(f_preorder, f_postorder, only_type_index);
  return transform(std::move(ir_node));
}
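
// Callback contract for IRTransform (summarizing MutateInternal above):
// `f_preorder` runs before a node's children are visited and `f_postorder`
// runs after. If either callback returns a defined object, that object
// replaces the current node; a defined preorder result also skips visiting the
// node's children. When `only_enable` is given (e.g. a list containing
// "tir.For"), the callbacks fire only on nodes whose type key is listed, while
// all other nodes are still traversed without invoking the callbacks.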

class IRSubstitute : public StmtExprMutator {
 public:
  explicit IRSubstitute(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {}

  PrimExpr VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto ret = vmap_(var);
    if (ret.defined()) {
      // Allow substitution of void variables with any expression. The TVM script parser
      // uses void variables for lambda parameters (since exact types are not known yet).
      if (!var.dtype().is_void()) {
        PrimExpr ret_ex = Downcast<PrimExpr>(ret.value());
        ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype()
                                              << " -> " << ret_ex << ":" << ret_ex.dtype();
      }
      return ret.value();
    }
    return std::move(var);
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

  Stmt VisitStmt_(const StoreNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    Buffer new_buf = GetRemappedBuffer(node->buffer);

    if (!new_buf.same_as(node->buffer)) {
      auto writer = node.CopyOnWrite();
      writer->buffer = new_buf;
    }

    return node;
  }

  Buffer GetRemappedBuffer(Buffer buf) {
    auto key = buf.get();
    auto it = buf_remap_.find(key);
    if (it != buf_remap_.end()) {
      return it->second;
    }

    auto new_buffer_var = vmap_(buf->data);
    if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) {
      auto writer = buf.CopyOnWrite();
      writer->data = Downcast<Var>(new_buffer_var);
    }

    buf_remap_[key] = buf;
    return buf;
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    Stmt ret = StmtExprMutator::VisitStmt_(op);
    op = ret.as<AttrStmtNode>();
    // remap var node in attr
    if (const auto* var_node = op->node.as<VarNode>()) {
      if (auto mapped_var = vmap_(GetRef<Var>(var_node))) {
        return AttrStmt(mapped_var, op->attr_key, op->value, op->body);
      }
    }
    return ret;
  }

 private:
  // Caller provided function that defines the variables to be remapped.
  std::function<Optional<PrimExpr>(const Var&)> vmap_;

  /* \brief Generated map to track buffers being remapped.
   *
   * If a `Var BufferNode::data` is remapped, then all buffers
   * containing that data pointer should also be remapped. This map
   * is used to track buffer modifications, and ensure all instances
   * of a buffer are replaced by the same modified buffer object.
   */
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};

Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstitute(vmap)(std::move(stmt));
}

PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstitute(vmap)(std::move(expr));
}

Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap) {
  Array<Range> result;
  result.reserve(region.size());
  for (const Range& range : region) {
    PrimExpr min = Substitute(range->min, vmap);
    PrimExpr extent = Substitute(range->extent, vmap);
    result.push_back(Range::FromMinExtent(std::move(min), std::move(extent)));
  }
  return result;
}
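
// A minimal usage sketch for the functional Substitute overload defined above
// (illustrative only; `x`, `y`, and `stmt` are hypothetical): replace every
// occurrence of `x` while leaving other variables untouched.
//
//   Var x("x", DataType::Int(32)), y("y", DataType::Int(32));
//   Stmt rewritten = Substitute(stmt, [&](const Var& v) -> Optional<PrimExpr> {
//     if (v.same_as(x)) return y + 1;  // replacement must keep the same dtype
//     return NullOpt;                  // NullOpt leaves the variable as-is
//   });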

void PreOrderVisit(const ObjectRef& stmt_or_expr,
                   const std::function<bool(const ObjectRef&)>& fvisit) {
  class PreOrderVisitor : public StmtExprVisitor {
   public:
    explicit PreOrderVisitor(const std::function<bool(const ObjectRef&)>& f) : f_(f) {}

   private:
    void VisitExpr(const PrimExpr& expr) final {
      const PrimExprNode* p_expr = expr.get();
      if (visited_.count(p_expr) == 0) {
        visited_.insert(p_expr);
        if (f_(expr)) {
          ExprVisitor::VisitExpr(expr);
        }
      }
    }

    void VisitStmt(const Stmt& stmt) final {
      const StmtNode* p_stmt = stmt.get();
      if (visited_.count(p_stmt) == 0) {
        visited_.insert(p_stmt);
        if (f_(stmt)) {
          StmtVisitor::VisitStmt(stmt);
        }
      }
    }

    const std::function<bool(const ObjectRef&)>& f_;
    std::unordered_set<const Object*> visited_;
  };

  PreOrderVisitor visitor(fvisit);
  if (const auto* stmt = stmt_or_expr.as<StmtNode>()) {
    visitor(GetRef<Stmt>(stmt));
  } else if (const auto* expr = stmt_or_expr.as<PrimExprNode>()) {
    visitor(GetRef<PrimExpr>(expr));
  } else {
    LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: "
               << stmt_or_expr->GetTypeKey();
  }
}
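
// Unlike PostOrderVisit, the callback here runs before the children and its
// boolean return value controls the recursion. A usage sketch (illustrative
// only; `stmt` is a hypothetical statement): detect whether a statement
// contains any Block node without descending into the blocks themselves.
//
//   bool has_block = false;
//   PreOrderVisit(stmt, [&](const ObjectRef& node) -> bool {
//     if (node.as<BlockNode>()) {
//       has_block = true;
//       return false;  // prune: do not visit this subtree's children
//     }
//     return true;  // keep visiting children
//   });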

class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer {
 public:
  explicit IRSubstituteWithDataTypeLegalization(std::function<Optional<PrimExpr>(const Var&)> vmap)
      : vmap_(vmap) {}

  using DataTypeLegalizer::VisitExpr_;
  using DataTypeLegalizer::VisitStmt_;

  PrimExpr VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto ret = vmap_(var);
    if (ret.defined()) {
      return ret.value();
    }
    return std::move(var);
  }

  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return VisitBufferAccess(std::move(node));
  }

  Stmt VisitStmt_(const BufferStoreNode* op) final {
    auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return VisitBufferAccess(std::move(node));
  }

  template <typename Node>
  Node VisitBufferAccess(Node node) {
    Buffer new_buf = GetRemappedBuffer(node->buffer);

    if (!new_buf.same_as(node->buffer)) {
      auto writer = node.CopyOnWrite();
      writer->buffer = new_buf;
    }

    return node;
  }

  Buffer GetRemappedBuffer(Buffer buf) {
    auto key = buf.get();
    auto it = buf_remap_.find(key);
    if (it != buf_remap_.end()) {
      return it->second;
    }

    auto new_buffer_var = vmap_(buf->data);
    if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) {
      auto writer = buf.CopyOnWrite();
      writer->data = Downcast<Var>(new_buffer_var);
    }

    buf_remap_[key] = buf;
    return buf;
  }

  Stmt VisitStmt_(const AttrStmtNode* op) final {
    Stmt ret = StmtExprMutator::VisitStmt_(op);
    op = ret.as<AttrStmtNode>();
    // remap var node in attr
    if (const auto* var_node = op->node.as<VarNode>()) {
      if (auto mapped_var = vmap_(GetRef<Var>(var_node))) {
        return AttrStmt(mapped_var, op->attr_key, op->value, op->body);
      }
    }
    return ret;
  }

 private:
  // Caller provided function that defines the variables to be remapped.
  std::function<Optional<PrimExpr>(const Var&)> vmap_;

  /* \brief Generated map to track buffers being remapped.
   *
   * If a `Var BufferNode::data` is remapped, then all buffers
   * containing that data pointer should also be remapped. This map
   * is used to track buffer modifications, and ensure all instances
   * of a buffer are replaced by the same modified buffer object.
   */
  std::unordered_map<const BufferNode*, Buffer> buf_remap_;
};

Stmt SubstituteWithDataTypeLegalization(Stmt stmt,
                                        std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt));
}

PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr,
                                            std::function<Optional<PrimExpr>(const Var&)> vmap) {
  return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr));
}

TVM_REGISTER_GLOBAL("tir.IRTransform").set_body_typed(IRTransform);

TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
  tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); });
});

TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) {
  tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); });
});

TVM_REGISTER_GLOBAL("tir.Substitute")
    .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef {
      if (node->IsInstance<StmtNode>()) {
        return Substitute(Downcast<Stmt>(node), vmap);
      } else {
        return Substitute(Downcast<PrimExpr>(node), vmap);
      }
    });

}  // namespace tir
}  // namespace tvm