1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file stmt_functor.cc |
21 | */ |
22 | #include <tvm/ir/module.h> |
23 | #include <tvm/runtime/registry.h> |
24 | #include <tvm/tir/data_type_rewriter.h> |
25 | #include <tvm/tir/function.h> |
26 | #include <tvm/tir/stmt_functor.h> |
27 | |
28 | #include <functional> |
29 | |
30 | #include "functor_common.h" |
31 | |
32 | namespace tvm { |
33 | namespace tir { |
34 | |
35 | void StmtVisitor::VisitStmt_(const LetStmtNode* op) { |
36 | this->VisitExpr(op->value); |
37 | this->VisitStmt(op->body); |
38 | } |
39 | |
40 | void StmtVisitor::VisitStmt_(const AttrStmtNode* op) { |
41 | this->VisitExpr(op->value); |
42 | this->VisitStmt(op->body); |
43 | } |
44 | |
45 | void StmtVisitor::VisitStmt_(const ForNode* op) { |
46 | this->VisitExpr(op->min); |
47 | this->VisitExpr(op->extent); |
48 | this->VisitStmt(op->body); |
49 | } |
50 | |
51 | void StmtVisitor::VisitStmt_(const WhileNode* op) { |
52 | this->VisitExpr(op->condition); |
53 | this->VisitStmt(op->body); |
54 | } |
55 | |
56 | void StmtVisitor::VisitStmt_(const AllocateNode* op) { |
57 | VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
58 | this->VisitStmt(op->body); |
59 | this->VisitExpr(op->condition); |
60 | } |
61 | |
62 | void StmtVisitor::VisitStmt_(const AllocateConstNode* op) { |
63 | VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
64 | this->VisitStmt(op->body); |
65 | } |
66 | |
67 | void StmtVisitor::VisitStmt_(const DeclBufferNode* op) { this->VisitStmt(op->body); } |
68 | |
69 | void StmtVisitor::VisitStmt_(const StoreNode* op) { |
70 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
71 | } |
72 | |
73 | void StmtVisitor::VisitStmt_(const BufferStoreNode* op) { |
74 | this->VisitExpr(op->value); |
75 | VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
76 | } |
77 | |
78 | void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { |
79 | VisitArray(op->bounds, [this](const Range& r) { |
80 | this->VisitExpr(r->min); |
81 | this->VisitExpr(r->extent); |
82 | }); |
83 | this->VisitExpr(op->condition); |
84 | this->VisitStmt(op->body); |
85 | } |
86 | |
87 | void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { |
88 | this->VisitExpr(op->condition); |
89 | this->VisitStmt(op->then_case); |
90 | if (op->else_case) { |
91 | this->VisitStmt(op->else_case.value()); |
92 | } |
93 | } |
94 | |
95 | void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { |
96 | this->VisitExpr(op->condition); |
97 | this->VisitExpr(op->message); |
98 | this->VisitStmt(op->body); |
99 | } |
100 | |
101 | void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) { |
102 | VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
103 | this->VisitExpr(op->value); |
104 | } |
105 | |
106 | void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { |
107 | VisitArray(op->bounds, [this](const Range& r) { |
108 | this->VisitExpr(r->min); |
109 | this->VisitExpr(r->extent); |
110 | }); |
111 | this->VisitStmt(op->body); |
112 | this->VisitExpr(op->condition); |
113 | } |
114 | |
115 | void StmtVisitor::VisitStmt_(const PrefetchNode* op) { |
116 | VisitArray(op->bounds, [this](const Range& r) { |
117 | this->VisitExpr(r->min); |
118 | this->VisitExpr(r->extent); |
119 | }); |
120 | } |
121 | |
122 | void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { |
123 | VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); |
124 | } |
125 | |
126 | void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } |
127 | |
128 | void StmtVisitor::VisitStmt_(const BlockNode* op) { |
129 | auto fvisit_buffer_region = [this](const BufferRegion& s) { |
130 | for (const auto& range : s->region) { |
131 | this->VisitExpr(range->min); |
132 | this->VisitExpr(range->extent); |
133 | } |
134 | }; |
135 | VisitArray(op->iter_vars, [this](const IterVar& iter_var) { |
136 | this->VisitExpr(iter_var->dom->min); |
137 | this->VisitExpr(iter_var->dom->extent); |
138 | }); |
139 | VisitArray(op->reads, fvisit_buffer_region); |
140 | VisitArray(op->writes, fvisit_buffer_region); |
141 | VisitArray(op->match_buffers, |
142 | [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { |
143 | fvisit_buffer_region(match_buffer_region->source); |
144 | }); |
145 | if (op->init.defined()) { |
146 | this->VisitStmt(op->init.value()); |
147 | } |
148 | this->VisitStmt(op->body); |
149 | } |
150 | |
151 | void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { |
152 | VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); }); |
153 | this->VisitExpr(op->predicate); |
154 | this->VisitStmt(op->block); |
155 | } |
156 | |
157 | class StmtMutator::Internal { |
158 | public: |
159 | /*! |
160 | * \brief Mutate array's element by fmutate function. |
161 | * |
162 | * \note Use extra care for copy on write setting. |
163 | * |
164 | * In particular, consider the following case of two reference chains: |
165 | * - strongref0 -> loop0 -> loop1 -> loop2 |
166 | * - strongref1 -> loop3 -> loop1 -> loop2 |
167 | * |
168 | * Think of the case of calling MutateArray on loop1->loop2(as const reference). |
169 | * When both strongref0 and strongref1 exists, the context does not allow copy |
170 | * on write, even though loop1 uniquely refers to loop2. |
171 | * |
172 | * \param self The pointer to the mutator. |
173 | * \param arr Array to be mutated, const reference is used to allow copy on write |
174 | * mutation in a recursive visitor. |
175 | * \param fmutate The mutator function. |
176 | * \return The mutated array, a new copy can be created. |
177 | */ |
178 | template <typename T, typename F> |
179 | static Array<T> MutateArray(StmtMutator* self, const Array<T>& arr, F fmutate) { |
180 | if (self->allow_copy_on_write_ && arr.unique()) { |
181 | // if we allow copy on write, we can directly |
182 | // call the inplace mutate function. |
183 | const_cast<Array<T>&>(arr).MutateByApply(fmutate); |
184 | return arr; |
185 | } else { |
186 | bool allow_cow = false; |
187 | std::swap(allow_cow, self->allow_copy_on_write_); |
188 | Array<T> copy = arr.Map(fmutate); |
189 | std::swap(allow_cow, self->allow_copy_on_write_); |
190 | return copy; |
191 | } |
192 | } |
193 | |
194 | static Array<IterVar> Mutate(StmtMutator* self, const Array<IterVar>& arr) { |
195 | auto fmutate = [self](const IterVar& iter_var) { |
196 | PrimExpr min = self->VisitExpr(iter_var->dom->min); |
197 | PrimExpr extent = self->VisitExpr(iter_var->dom->extent); |
198 | if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) { |
199 | return iter_var; |
200 | } else { |
201 | return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type, |
202 | iter_var->thread_tag); |
203 | } |
204 | }; |
205 | return MutateArray(self, arr, fmutate); |
206 | } |
207 | |
208 | static Array<PrimExpr> Mutate(StmtMutator* self, const Array<PrimExpr>& arr) { |
209 | auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; |
210 | return MutateArray(self, arr, fmutate); |
211 | } |
212 | |
213 | static Array<Stmt> Mutate(StmtMutator* self, const Array<Stmt>& arr) { |
214 | auto fmutate = [self](const Stmt& s) { return self->VisitStmt(s); }; |
215 | return MutateArray(self, arr, fmutate); |
216 | } |
217 | |
218 | static Array<Range> Mutate(StmtMutator* self, const Array<Range>& arr) { |
219 | auto fmutate = [self](const Range& r) { |
220 | PrimExpr min = self->VisitExpr(r->min); |
221 | PrimExpr extent = self->VisitExpr(r->extent); |
222 | if (min.same_as(r->min) && extent.same_as(r->extent)) { |
223 | return r; |
224 | } else { |
225 | return Range::FromMinExtent(min, extent); |
226 | } |
227 | }; |
228 | return MutateArray(self, arr, fmutate); |
229 | } |
230 | |
231 | static Array<BufferRegion> Mutate(StmtMutator* self, const Array<BufferRegion>& arr) { |
232 | auto fmutate = [self](const BufferRegion& buffer_region) { |
233 | Array<Range> region = Mutate(self, buffer_region->region); |
234 | if (region.same_as(buffer_region->region)) { |
235 | return buffer_region; |
236 | } else { |
237 | return BufferRegion(buffer_region->buffer, region); |
238 | } |
239 | }; |
240 | return MutateArray(self, arr, fmutate); |
241 | } |
242 | |
243 | static Array<MatchBufferRegion> Mutate(StmtMutator* self, const Array<MatchBufferRegion>& arr) { |
244 | auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { |
245 | Array<Range> region = Mutate(self, match_buffer_region->source->region); |
246 | if (region.same_as(match_buffer_region->source->region)) { |
247 | return match_buffer_region; |
248 | } else { |
249 | return MatchBufferRegion(match_buffer_region->buffer, |
250 | BufferRegion(match_buffer_region->source->buffer, region)); |
251 | } |
252 | }; |
253 | return MutateArray(self, arr, fmutate); |
254 | } |
255 | }; |
256 | |
257 | Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { |
258 | PrimExpr value = this->VisitExpr(op->value); |
259 | Stmt body = this->VisitStmt(op->body); |
260 | if (value.same_as(op->value) && body.same_as(op->body)) { |
261 | return GetRef<Stmt>(op); |
262 | } else { |
263 | auto n = CopyOnWrite(op); |
264 | n->value = std::move(value); |
265 | n->body = std::move(body); |
266 | return Stmt(n); |
267 | } |
268 | } |
269 | |
270 | Stmt StmtMutator::VisitStmt_(const LetStmtNode* op) { |
271 | PrimExpr value = this->VisitExpr(op->value); |
272 | Stmt body = this->VisitStmt(op->body); |
273 | if (value.same_as(op->value) && body.same_as(op->body)) { |
274 | return GetRef<Stmt>(op); |
275 | } else { |
276 | auto n = CopyOnWrite(op); |
277 | n->value = std::move(value); |
278 | n->body = std::move(body); |
279 | return Stmt(n); |
280 | } |
281 | } |
282 | |
283 | Stmt StmtMutator::VisitStmt_(const ForNode* op) { |
284 | PrimExpr min = this->VisitExpr(op->min); |
285 | PrimExpr extent = this->VisitExpr(op->extent); |
286 | Stmt body = this->VisitStmt(op->body); |
287 | if (min.same_as(op->min) && extent.same_as(op->extent) && body.same_as(op->body)) { |
288 | return GetRef<Stmt>(op); |
289 | } else { |
290 | auto n = CopyOnWrite(op); |
291 | n->min = std::move(min); |
292 | n->extent = std::move(extent); |
293 | n->body = std::move(body); |
294 | return Stmt(n); |
295 | } |
296 | } |
297 | |
298 | Stmt StmtMutator::VisitStmt_(const WhileNode* op) { |
299 | PrimExpr condition = this->VisitExpr(op->condition); |
300 | Stmt body = this->VisitStmt(op->body); |
301 | if (condition.same_as(op->condition) && body.same_as(op->body)) { |
302 | return GetRef<Stmt>(op); |
303 | } else { |
304 | auto n = CopyOnWrite(op); |
305 | n->condition = std::move(condition); |
306 | n->body = std::move(body); |
307 | return Stmt(n); |
308 | } |
309 | } |
310 | |
311 | Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { |
312 | Array<PrimExpr> extents = Internal::Mutate(this, op->extents); |
313 | Stmt body = this->VisitStmt(op->body); |
314 | PrimExpr condition = this->VisitExpr(op->condition); |
315 | |
316 | if (extents.same_as(op->extents) && body.same_as(op->body) && condition.same_as(op->condition)) { |
317 | return GetRef<Stmt>(op); |
318 | } else { |
319 | auto n = CopyOnWrite(op); |
320 | n->extents = std::move(extents); |
321 | n->body = std::move(body); |
322 | n->condition = std::move(condition); |
323 | return Stmt(n); |
324 | } |
325 | } |
326 | |
327 | Stmt StmtMutator::VisitStmt_(const AllocateConstNode* op) { |
328 | Array<PrimExpr> extents = Internal::Mutate(this, op->extents); |
329 | Stmt body = this->VisitStmt(op->body); |
330 | |
331 | if (extents.same_as(op->extents) && body.same_as(op->body)) { |
332 | return GetRef<Stmt>(op); |
333 | } else { |
334 | auto n = CopyOnWrite(op); |
335 | n->extents = std::move(extents); |
336 | n->body = std::move(body); |
337 | return Stmt(n); |
338 | } |
339 | } |
340 | |
341 | Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { |
342 | Stmt body = this->VisitStmt(op->body); |
343 | |
344 | if (body.same_as(op->body)) { |
345 | return GetRef<Stmt>(op); |
346 | } else { |
347 | auto n = CopyOnWrite(op); |
348 | n->body = std::move(body); |
349 | return Stmt(n); |
350 | } |
351 | } |
352 | |
353 | Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { |
354 | PrimExpr condition = this->VisitExpr(op->condition); |
355 | Stmt then_case = this->VisitStmt(op->then_case); |
356 | Optional<Stmt> else_case = NullOpt; |
357 | if (op->else_case) { |
358 | else_case = this->VisitStmt(op->else_case.value()); |
359 | } |
360 | if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && |
361 | else_case.same_as(op->else_case)) { |
362 | return GetRef<Stmt>(op); |
363 | } else { |
364 | auto n = CopyOnWrite(op); |
365 | n->condition = std::move(condition); |
366 | n->then_case = std::move(then_case); |
367 | n->else_case = std::move(else_case); |
368 | return Stmt(n); |
369 | } |
370 | } |
371 | |
372 | Stmt StmtMutator::VisitStmt_(const StoreNode* op) { |
373 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
374 | } |
375 | |
376 | Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) { |
377 | PrimExpr value = this->VisitExpr(op->value); |
378 | Array<PrimExpr> indices = Internal::Mutate(this, op->indices); |
379 | |
380 | if (value.same_as(op->value) && indices.same_as(op->indices)) { |
381 | return GetRef<Stmt>(op); |
382 | } else { |
383 | auto n = CopyOnWrite(op); |
384 | n->value = std::move(value); |
385 | n->indices = std::move(indices); |
386 | return Stmt(n); |
387 | } |
388 | } |
389 | |
390 | Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { |
391 | Region bounds = Internal::Mutate(this, op->bounds); |
392 | PrimExpr condition = this->VisitExpr(op->condition); |
393 | Stmt body = this->VisitStmt(op->body); |
394 | |
395 | if (bounds.same_as(op->bounds) && condition.same_as(op->condition) && body.same_as(op->body)) { |
396 | return GetRef<Stmt>(op); |
397 | } else { |
398 | auto n = CopyOnWrite(op); |
399 | n->bounds = std::move(bounds); |
400 | n->condition = std::move(condition); |
401 | n->body = std::move(body); |
402 | return Stmt(n); |
403 | } |
404 | } |
405 | |
406 | Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) { |
407 | Array<PrimExpr> indices = Internal::Mutate(this, op->indices); |
408 | PrimExpr value = this->VisitExpr(op->value); |
409 | if (indices.same_as(op->indices) && value.same_as(op->value)) { |
410 | return GetRef<Stmt>(op); |
411 | } else { |
412 | auto n = CopyOnWrite(op); |
413 | n->indices = std::move(indices); |
414 | n->value = std::move(value); |
415 | return Stmt(n); |
416 | } |
417 | } |
418 | |
419 | Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { |
420 | Region bounds = Internal::Mutate(this, op->bounds); |
421 | Stmt body = this->VisitStmt(op->body); |
422 | PrimExpr condition = this->VisitExpr(op->condition); |
423 | if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { |
424 | return GetRef<Stmt>(op); |
425 | } else { |
426 | auto n = CopyOnWrite(op); |
427 | n->bounds = std::move(bounds); |
428 | n->body = std::move(body); |
429 | n->condition = std::move(condition); |
430 | return Stmt(n); |
431 | } |
432 | } |
433 | |
434 | Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { |
435 | Region bounds = Internal::Mutate(this, op->bounds); |
436 | if (bounds.same_as(op->bounds)) { |
437 | return GetRef<Stmt>(op); |
438 | } else { |
439 | auto n = CopyOnWrite(op); |
440 | n->bounds = std::move(bounds); |
441 | return Stmt(n); |
442 | } |
443 | } |
444 | |
445 | Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { |
446 | Array<Stmt> seq = Internal::Mutate(this, op->seq); |
447 | if (seq.same_as(op->seq)) { |
448 | return GetRef<Stmt>(op); |
449 | } else { |
450 | auto n = CopyOnWrite(op); |
451 | n->seq = std::move(seq); |
452 | return Stmt(n); |
453 | } |
454 | } |
455 | |
/*!
 * \brief Advanced visit function for SeqStmt, with optional flattening.
 *
 * \param op The sequence node to mutate.
 * \param flatten_before_visit Whether nested sequences should be flattened
 *        before the elements are mutated. It is reset to false below when no
 *        direct element of `op` is a SeqStmtNode.
 * \param fmutate Per-element mutation function. When it compares equal to
 *        nullptr, the elements are mutated through Internal::Mutate instead,
 *        which calls this->VisitStmt on each of them.
 * \return A statement for the visited node itself when mutation leaves its
 *         `seq` array unchanged, otherwise a statement whose `seq` is the
 *         mutated array.
 */
Stmt StmtMutator::VisitSeqStmt_(const SeqStmtNode* op, bool flatten_before_visit,
                                std::function<Stmt(const Stmt&)> fmutate) {
  if (flatten_before_visit) {
    // Pass 1, check if we need to flatten: only direct elements are inspected.
    bool need_flatten = false;
    for (size_t i = 0; i < op->seq.size(); ++i) {
      Stmt tmp = (*op)[i];
      if (tmp.as<SeqStmtNode>()) need_flatten = true;
    }
    flatten_before_visit = need_flatten;
  }
  // Function to run the visit: mutates the elements of the given node and
  // rebuilds the node only when the element array changed.
  auto frunvisit = [&](const SeqStmtNode* op) {
    Array<Stmt> seq = fmutate != nullptr ? Internal::MutateArray(this, op->seq, fmutate)
                                         : Internal::Mutate(this, op->seq);
    if (seq.same_as(op->seq)) {
      return GetRef<Stmt>(op);
    } else {
      auto n = CopyOnWrite(op);
      n->seq = std::move(seq);
      return Stmt(n);
    }
  };
  if (flatten_before_visit) {
    // Collect the flattened elements into a fresh array first.
    Array<Stmt> seq;
    SeqStmt::Flattener flattener(&seq);
    flattener(0, op->seq);
    // NOTE: If copy on write is allowed
    // the assignment to seq below will
    // destruct the original seq.
    //
    // Such destruction removes duplicated reference
    // count to children and still enables COW for
    // child Stmt.
    ObjectPtr<SeqStmtNode> n = CopyOnWrite(op);
    n->seq = std::move(seq);
    // The flattened node is kept alive by `n` while it is visited.
    return frunvisit(n.operator->());
  } else {
    return frunvisit(op);
  }
}
498 | |
499 | Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { |
500 | PrimExpr condition = this->VisitExpr(op->condition); |
501 | PrimExpr message = this->VisitExpr(op->message); |
502 | Stmt body = this->VisitStmt(op->body); |
503 | |
504 | if (condition.same_as(op->condition) && message.same_as(op->message) && body.same_as(op->body)) { |
505 | return GetRef<Stmt>(op); |
506 | } else { |
507 | auto n = CopyOnWrite(op); |
508 | n->condition = std::move(condition); |
509 | n->message = std::move(message); |
510 | n->body = std::move(body); |
511 | return Stmt(n); |
512 | } |
513 | } |
514 | |
515 | Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { |
516 | PrimExpr value = this->VisitExpr(op->value); |
517 | if (value.same_as(op->value)) { |
518 | return GetRef<Stmt>(op); |
519 | } else { |
520 | auto n = CopyOnWrite(op); |
521 | n->value = std::move(value); |
522 | return Stmt(n); |
523 | } |
524 | } |
525 | |
526 | Stmt StmtMutator::VisitStmt_(const BlockNode* op) { |
527 | Array<IterVar> iter_vars = Internal::Mutate(this, op->iter_vars); |
528 | Array<BufferRegion> reads = Internal::Mutate(this, op->reads); |
529 | Array<BufferRegion> writes = Internal::Mutate(this, op->writes); |
530 | Array<MatchBufferRegion> match_buffers = Internal::Mutate(this, op->match_buffers); |
531 | Optional<Stmt> init = NullOpt; |
532 | if (op->init.defined()) { |
533 | init = VisitStmt(op->init.value()); |
534 | } |
535 | Stmt body = VisitStmt(op->body); |
536 | if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) && |
537 | body.same_as(op->body) && init.same_as(op->init) && |
538 | match_buffers.same_as(op->match_buffers)) { |
539 | return GetRef<Block>(op); |
540 | } else { |
541 | auto n = CopyOnWrite(op); |
542 | n->iter_vars = std::move(iter_vars); |
543 | n->reads = std::move(reads); |
544 | n->writes = std::move(writes); |
545 | n->body = std::move(body); |
546 | n->init = std::move(init); |
547 | n->match_buffers = std::move(match_buffers); |
548 | return Stmt(n); |
549 | } |
550 | } |
551 | |
552 | Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { |
553 | Array<PrimExpr> v = Internal::Mutate(this, op->iter_values); |
554 | PrimExpr pred = this->VisitExpr(op->predicate); |
555 | Stmt block = this->VisitStmt(op->block); |
556 | if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) { |
557 | return GetRef<Stmt>(op); |
558 | } else { |
559 | auto n = CopyOnWrite(op); |
560 | n->iter_values = std::move(v); |
561 | n->predicate = std::move(pred); |
562 | n->block = Downcast<Block>(block); |
563 | return Stmt(n); |
564 | } |
565 | } |
566 | |
567 | // Implementations of IRTransform, PostOrderVisit and Substitute |
568 | class IRApplyVisit : public StmtExprVisitor { |
569 | public: |
570 | explicit IRApplyVisit(std::function<void(const ObjectRef&)> f) : f_(f) {} |
571 | |
572 | void VisitExpr(const PrimExpr& node) final { |
573 | if (visited_.count(node.get()) != 0) return; |
574 | visited_.insert(node.get()); |
575 | ExprVisitor::VisitExpr(node); |
576 | f_(node); |
577 | } |
578 | |
579 | void VisitStmt(const Stmt& node) final { |
580 | if (visited_.count(node.get()) != 0) return; |
581 | visited_.insert(node.get()); |
582 | StmtVisitor::VisitStmt(node); |
583 | f_(node); |
584 | } |
585 | |
586 | private: |
587 | std::function<void(const ObjectRef&)> f_; |
588 | std::unordered_set<const Object*> visited_; |
589 | }; |
590 | |
591 | void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit) { |
592 | if (node.as<StmtNode>()) { |
593 | IRApplyVisit visitor(fvisit); |
594 | visitor(Downcast<Stmt>(node)); |
595 | } else { |
596 | IRApplyVisit visitor(fvisit); |
597 | visitor(Downcast<PrimExpr>(node)); |
598 | } |
599 | } |
600 | |
601 | class IRTransformer final : public StmtExprMutator { |
602 | public: |
603 | IRTransformer(const runtime::PackedFunc& f_preorder, const runtime::PackedFunc& f_postorder, |
604 | const std::unordered_set<uint32_t>& only_enable) |
605 | : f_preorder_(f_preorder), f_postorder_(f_postorder), only_enable_(only_enable) {} |
606 | |
607 | Stmt VisitStmt(const Stmt& stmt) final { |
608 | return MutateInternal<Stmt>(stmt, [this](const Stmt& s) { return this->BaseVisitStmt(s); }); |
609 | } |
610 | PrimExpr VisitExpr(const PrimExpr& expr) final { |
611 | return MutateInternal<PrimExpr>(expr, |
612 | [this](const PrimExpr& e) { return this->BaseVisitExpr(e); }); |
613 | } |
614 | |
615 | private: |
616 | // NOTE: redirect to parent's call |
617 | // This is used to get around limitation of gcc-4.8 |
618 | Stmt BaseVisitStmt(const Stmt& s) { return StmtMutator::VisitStmt(s); } |
619 | PrimExpr BaseVisitExpr(const PrimExpr& e) { return ExprMutator::VisitExpr(e); } |
620 | |
621 | template <typename T, typename F> |
622 | T MutateInternal(const T& node, F fmutate) { |
623 | if (only_enable_.size() && !only_enable_.count(node->type_index())) { |
624 | return fmutate(node); |
625 | } |
626 | if (f_preorder_ != nullptr) { |
627 | T pre = f_preorder_(node); |
628 | if (pre.defined()) return pre; |
629 | } |
630 | T new_node = fmutate(node); |
631 | if (f_postorder_ != nullptr) { |
632 | T post = f_postorder_(new_node); |
633 | if (post.defined()) return post; |
634 | } |
635 | return new_node; |
636 | } |
637 | // The functions |
638 | const runtime::PackedFunc& f_preorder_; |
639 | const runtime::PackedFunc& f_postorder_; |
640 | // type indices enabled. |
641 | const std::unordered_set<uint32_t>& only_enable_; |
642 | }; |
643 | |
644 | Stmt IRTransform(Stmt ir_node, const runtime::PackedFunc& f_preorder, |
645 | const runtime::PackedFunc& f_postorder, Optional<Array<String>> only_enable) { |
646 | std::unordered_set<uint32_t> only_type_index; |
647 | if (only_enable.defined()) { |
648 | for (auto s : only_enable.value()) { |
649 | only_type_index.insert(Object::TypeKey2Index(s.c_str())); |
650 | } |
651 | } |
652 | IRTransformer transform(f_preorder, f_postorder, only_type_index); |
653 | return transform(std::move(ir_node)); |
654 | } |
655 | |
656 | class IRSubstitute : public StmtExprMutator { |
657 | public: |
658 | explicit IRSubstitute(std::function<Optional<PrimExpr>(const Var&)> vmap) : vmap_(vmap) {} |
659 | |
660 | PrimExpr VisitExpr_(const VarNode* op) final { |
661 | Var var = GetRef<Var>(op); |
662 | auto ret = vmap_(var); |
663 | if (ret.defined()) { |
664 | // Allow substitution of void variables with any expression. The TVM script parser |
665 | // uses void variables for lambda parameters (since exact types are not known yet). |
666 | if (!var.dtype().is_void()) { |
667 | PrimExpr ret_ex = Downcast<PrimExpr>(ret.value()); |
668 | ICHECK(ret_ex.dtype() == var.dtype()) << "substituting " << var << ":" << var.dtype() |
669 | << " -> " << ret_ex << ":" << ret_ex.dtype(); |
670 | } |
671 | return ret.value(); |
672 | } |
673 | return std::move(var); |
674 | } |
675 | |
676 | PrimExpr VisitExpr_(const LoadNode* op) final { |
677 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
678 | } |
679 | |
680 | Stmt VisitStmt_(const StoreNode* op) final { |
681 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
682 | } |
683 | |
684 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
685 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
686 | return VisitBufferAccess(std::move(node)); |
687 | } |
688 | |
689 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
690 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
691 | return VisitBufferAccess(std::move(node)); |
692 | } |
693 | |
694 | template <typename Node> |
695 | Node VisitBufferAccess(Node node) { |
696 | Buffer new_buf = GetRemappedBuffer(node->buffer); |
697 | |
698 | if (!new_buf.same_as(node->buffer)) { |
699 | auto writer = node.CopyOnWrite(); |
700 | writer->buffer = new_buf; |
701 | } |
702 | |
703 | return node; |
704 | } |
705 | |
706 | Buffer GetRemappedBuffer(Buffer buf) { |
707 | auto key = buf.get(); |
708 | auto it = buf_remap_.find(key); |
709 | if (it != buf_remap_.end()) { |
710 | return it->second; |
711 | } |
712 | |
713 | auto new_buffer_var = vmap_(buf->data); |
714 | if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) { |
715 | auto writer = buf.CopyOnWrite(); |
716 | writer->data = Downcast<Var>(new_buffer_var); |
717 | } |
718 | |
719 | buf_remap_[key] = buf; |
720 | return buf; |
721 | } |
722 | |
723 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
724 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
725 | op = ret.as<AttrStmtNode>(); |
726 | // remap var node in attr |
727 | if (const auto* var_node = op->node.as<VarNode>()) { |
728 | if (auto mapped_var = vmap_(GetRef<Var>(var_node))) { |
729 | return AttrStmt(mapped_var, op->attr_key, op->value, op->body); |
730 | } |
731 | } |
732 | return ret; |
733 | } |
734 | |
735 | private: |
736 | // Caller provided function that defines the variables to be remapped. |
737 | std::function<Optional<PrimExpr>(const Var&)> vmap_; |
738 | |
739 | /* \brief Generated map to track buffers being remapped. |
740 | * |
741 | * If a `Var BufferNode::data` is remapped, then all buffers |
742 | * containing that data pointer should also be remapped. This map |
743 | * is used to track buffer modifications, and ensure all instances |
744 | * of a buffer are replaced by the same modified buffer object. |
745 | */ |
746 | std::unordered_map<const BufferNode*, Buffer> buf_remap_; |
747 | }; |
748 | |
749 | Stmt Substitute(Stmt stmt, std::function<Optional<PrimExpr>(const Var&)> vmap) { |
750 | return IRSubstitute(vmap)(std::move(stmt)); |
751 | } |
752 | |
753 | PrimExpr Substitute(PrimExpr expr, std::function<Optional<PrimExpr>(const Var&)> vmap) { |
754 | return IRSubstitute(vmap)(std::move(expr)); |
755 | } |
756 | |
757 | Array<Range> Substitute(const Array<Range>& region, const Map<Var, PrimExpr>& vmap) { |
758 | Array<Range> result; |
759 | result.reserve(region.size()); |
760 | for (const Range& range : region) { |
761 | PrimExpr min = Substitute(range->min, vmap); |
762 | PrimExpr extent = Substitute(range->extent, vmap); |
763 | result.push_back(Range::FromMinExtent(std::move(min), std::move(extent))); |
764 | } |
765 | return result; |
766 | } |
767 | |
768 | void PreOrderVisit(const ObjectRef& stmt_or_expr, |
769 | const std::function<bool(const ObjectRef&)>& fvisit) { |
770 | class PreOrderVisitor : public StmtExprVisitor { |
771 | public: |
772 | explicit PreOrderVisitor(const std::function<bool(const ObjectRef&)>& f) : f_(f) {} |
773 | |
774 | private: |
775 | void VisitExpr(const PrimExpr& expr) final { |
776 | const PrimExprNode* p_expr = expr.get(); |
777 | if (visited_.count(p_expr) == 0) { |
778 | visited_.insert(p_expr); |
779 | if (f_(expr)) { |
780 | ExprVisitor::VisitExpr(expr); |
781 | } |
782 | } |
783 | } |
784 | |
785 | void VisitStmt(const Stmt& stmt) final { |
786 | const StmtNode* p_stmt = stmt.get(); |
787 | if (visited_.count(p_stmt) == 0) { |
788 | visited_.insert(p_stmt); |
789 | if (f_(stmt)) { |
790 | StmtVisitor::VisitStmt(stmt); |
791 | } |
792 | } |
793 | } |
794 | |
795 | const std::function<bool(const ObjectRef&)>& f_; |
796 | std::unordered_set<const Object*> visited_; |
797 | }; |
798 | |
799 | PreOrderVisitor visitor(fvisit); |
800 | if (const auto* stmt = stmt_or_expr.as<StmtNode>()) { |
801 | visitor(GetRef<Stmt>(stmt)); |
802 | } else if (const auto* expr = stmt_or_expr.as<PrimExprNode>()) { |
803 | visitor(GetRef<PrimExpr>(expr)); |
804 | } else { |
805 | LOG(FATAL) << "InternalError: PreOrderVisit does not accept object with type: " |
806 | << stmt_or_expr->GetTypeKey(); |
807 | } |
808 | } |
809 | |
810 | class IRSubstituteWithDataTypeLegalization : public DataTypeLegalizer { |
811 | public: |
812 | explicit IRSubstituteWithDataTypeLegalization(std::function<Optional<PrimExpr>(const Var&)> vmap) |
813 | : vmap_(vmap) {} |
814 | |
815 | using DataTypeLegalizer::VisitExpr_; |
816 | using DataTypeLegalizer::VisitStmt_; |
817 | |
818 | PrimExpr VisitExpr_(const VarNode* op) final { |
819 | Var var = GetRef<Var>(op); |
820 | auto ret = vmap_(var); |
821 | if (ret.defined()) { |
822 | return ret.value(); |
823 | } |
824 | return std::move(var); |
825 | } |
826 | |
827 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
828 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
829 | return VisitBufferAccess(std::move(node)); |
830 | } |
831 | |
832 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
833 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
834 | return VisitBufferAccess(std::move(node)); |
835 | } |
836 | |
837 | template <typename Node> |
838 | Node VisitBufferAccess(Node node) { |
839 | Buffer new_buf = GetRemappedBuffer(node->buffer); |
840 | |
841 | if (!new_buf.same_as(node->buffer)) { |
842 | auto writer = node.CopyOnWrite(); |
843 | writer->buffer = new_buf; |
844 | } |
845 | |
846 | return node; |
847 | } |
848 | |
849 | Buffer GetRemappedBuffer(Buffer buf) { |
850 | auto key = buf.get(); |
851 | auto it = buf_remap_.find(key); |
852 | if (it != buf_remap_.end()) { |
853 | return it->second; |
854 | } |
855 | |
856 | auto new_buffer_var = vmap_(buf->data); |
857 | if (new_buffer_var.defined() && !new_buffer_var.value().same_as(buf->data)) { |
858 | auto writer = buf.CopyOnWrite(); |
859 | writer->data = Downcast<Var>(new_buffer_var); |
860 | } |
861 | |
862 | buf_remap_[key] = buf; |
863 | return buf; |
864 | } |
865 | |
866 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
867 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
868 | op = ret.as<AttrStmtNode>(); |
869 | // remap var node in attr |
870 | if (const auto* var_node = op->node.as<VarNode>()) { |
871 | if (auto mapped_var = vmap_(GetRef<Var>(var_node))) { |
872 | return AttrStmt(mapped_var, op->attr_key, op->value, op->body); |
873 | } |
874 | } |
875 | return ret; |
876 | } |
877 | |
878 | private: |
879 | // Caller provided function that defines the variables to be remapped. |
880 | std::function<Optional<PrimExpr>(const Var&)> vmap_; |
881 | |
882 | /* \brief Generated map to track buffers being remapped. |
883 | * |
884 | * If a `Var BufferNode::data` is remapped, then all buffers |
885 | * containing that data pointer should also be remapped. This map |
886 | * is used to track buffer modifications, and ensure all instances |
887 | * of a buffer are replaced by the same modified buffer object. |
888 | */ |
889 | std::unordered_map<const BufferNode*, Buffer> buf_remap_; |
890 | }; |
891 | |
892 | Stmt SubstituteWithDataTypeLegalization(Stmt stmt, |
893 | std::function<Optional<PrimExpr>(const Var&)> vmap) { |
894 | return IRSubstituteWithDataTypeLegalization(vmap)(std::move(stmt)); |
895 | } |
896 | |
897 | PrimExpr SubstituteWithDataTypeLegalization(PrimExpr expr, |
898 | std::function<Optional<PrimExpr>(const Var&)> vmap) { |
899 | return IRSubstituteWithDataTypeLegalization(vmap)(std::move(expr)); |
900 | } |
901 | |
902 | TVM_REGISTER_GLOBAL("tir.IRTransform" ).set_body_typed(IRTransform); |
903 | |
904 | TVM_REGISTER_GLOBAL("tir.PostOrderVisit" ).set_body_typed([](ObjectRef node, PackedFunc f) { |
905 | tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); |
906 | }); |
907 | |
908 | TVM_REGISTER_GLOBAL("tir.PreOrderVisit" ).set_body_typed([](ObjectRef node, PackedFunc f) { |
909 | tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); }); |
910 | }); |
911 | |
912 | TVM_REGISTER_GLOBAL("tir.Substitute" ) |
913 | .set_body_typed([](ObjectRef node, Map<Var, PrimExpr> vmap) -> ObjectRef { |
914 | if (node->IsInstance<StmtNode>()) { |
915 | return Substitute(Downcast<Stmt>(node), vmap); |
916 | } else { |
917 | return Substitute(Downcast<PrimExpr>(node), vmap); |
918 | } |
919 | }); |
920 | |
921 | } // namespace tir |
922 | } // namespace tvm |
923 | |