1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file renew_defs.cc |
22 | * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. |
23 | */ |
24 | |
25 | #include <tvm/tir/stmt_functor.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include "../ir/functor_common.h" |
29 | |
30 | namespace tvm { |
31 | namespace tir { |
32 | |
/*!
 * \brief Define a VisitStmt_ overload for a statement NODE that introduces a variable in its
 *        FIELD member (e.g. LetStmtNode::var, ForNode::loop_var).
 *
 * The variable is renewed via ReDefineVar *before* the node's children are visited, so uses
 * inside the node are rewritten to the new variable; the mutated node is then copied with
 * FIELD replaced by the renewed variable.
 */
#define STMT_REGENERATE_VAR_DEF(NODE, FIELD)           \
  Stmt VisitStmt_(const NODE* op) final {              \
    Var fresh_def = this->ReDefineVar(op->FIELD);      \
    Stmt visited = StmtExprMutator::VisitStmt_(op);    \
    const NODE* visited_node = visited.as<NODE>();     \
    ICHECK(visited_node != nullptr);                   \
    auto rebuilt = make_object<NODE>(*visited_node);   \
    rebuilt->FIELD = std::move(fresh_def);             \
    return Stmt(rebuilt);                              \
  }
43 | |
44 | class RenewDefMutator : public StmtExprMutator { |
45 | public: |
46 | static PrimFunc Transform(const PrimFunc& func) { |
47 | RenewDefMutator generator; |
48 | // Redefine params |
49 | Array<Var> params; |
50 | for (const auto& param : func->params) { |
51 | params.push_back(generator.ReDefineVar(param)); |
52 | } |
53 | // Redefine buffers in order |
54 | // TODO(Siyuan Feng): checking var is used after define |
55 | Map<tir::Var, Buffer> buffer_map; |
56 | for (const auto& param : func->params) { |
57 | if (param->dtype.is_handle()) { |
58 | const Buffer& buffer = func->buffer_map.at(param); |
59 | Var new_param = Downcast<Var>(generator.VisitExpr(param)); |
60 | Buffer new_buffer = generator.VisitBuffer(buffer, true); |
61 | buffer_map.Set(new_param, new_buffer); |
62 | } |
63 | } |
64 | // Visit body |
65 | Stmt body = generator(func->body); |
66 | // Recreate function |
67 | auto n = make_object<PrimFuncNode>(*func.get()); |
68 | n->params = std::move(params); |
69 | n->buffer_map = std::move(buffer_map); |
70 | n->body = std::move(body); |
71 | return PrimFunc(n); |
72 | } |
73 | |
74 | private: |
75 | Stmt operator()(Stmt stmt) { |
76 | // override StmtMutator::operator() to disable copy_on_write |
77 | // Since this pass tries to explict create a new function rather than update the existing one |
78 | allow_copy_on_write_ = false; |
79 | return VisitStmt(stmt); |
80 | } |
81 | |
82 | PrimExpr VisitExpr(const PrimExpr& expr) final { |
83 | auto it = remap_.find(expr); |
84 | if (it != remap_.end()) { |
85 | return Downcast<PrimExpr>((*it).second); |
86 | } else { |
87 | return ExprMutator::VisitExpr(expr); |
88 | } |
89 | } |
90 | |
91 | private: |
92 | STMT_REGENERATE_VAR_DEF(LetStmtNode, var); |
93 | STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var); |
94 | STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var); |
95 | STMT_REGENERATE_VAR_DEF(ForNode, loop_var); |
96 | |
97 | Stmt VisitStmt_(const BlockNode* op) final { |
98 | // Step 0. Re-define Itervars |
99 | Array<IterVar> iter_vars = |
100 | op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); |
101 | |
102 | // Step 1. Re-define buffers allocate under the block |
103 | Array<Buffer> alloc_buffers = op->alloc_buffers.Map( |
104 | std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true)); |
105 | |
106 | // Step 2. Re-define match_buffers |
107 | Array<MatchBufferRegion> match_buffers = op->match_buffers.Map( |
108 | std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); |
109 | |
110 | // Step 3. Visit body |
111 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
112 | op = stmt.as<BlockNode>(); |
113 | ICHECK(op); |
114 | |
115 | // Step 4. Revisit access region |
116 | Array<BufferRegion> reads = |
117 | op->reads.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); |
118 | Array<BufferRegion> writes = |
119 | op->writes.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); |
120 | |
121 | // Step 5. Regenerate block. Since the defs are changed, we need to create a new block |
122 | auto n = make_object<BlockNode>(*op); |
123 | n->iter_vars = std::move(iter_vars); |
124 | n->alloc_buffers = std::move(alloc_buffers); |
125 | n->match_buffers = std::move(match_buffers); |
126 | n->reads = std::move(reads); |
127 | n->writes = std::move(writes); |
128 | |
129 | return Stmt(n); |
130 | } |
131 | |
132 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
133 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
134 | op = stmt.as<BufferStoreNode>(); |
135 | ICHECK(op != nullptr); |
136 | Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); |
137 | if (buffer.same_as(op->buffer)) { |
138 | return stmt; |
139 | } else { |
140 | auto n = make_object<BufferStoreNode>(*op); |
141 | n->buffer = std::move(buffer); |
142 | return BufferStore(n); |
143 | } |
144 | } |
145 | |
146 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
147 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
148 | op = expr.as<BufferLoadNode>(); |
149 | ICHECK(op != nullptr); |
150 | Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); |
151 | if (buffer.same_as(op->buffer)) { |
152 | return expr; |
153 | } else { |
154 | auto n = make_object<BufferLoadNode>(*op); |
155 | n->buffer = std::move(buffer); |
156 | return BufferLoad(n); |
157 | } |
158 | } |
159 | |
160 | PrimExpr VisitExpr_(const LoadNode* op) final { |
161 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
162 | } |
163 | |
164 | Stmt VisitStmt_(const StoreNode* op) final { |
165 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
166 | } |
167 | |
168 | private: |
169 | Var ReDefineVar(const Var& var) { |
170 | Var new_var = Var(make_object<VarNode>(*var.get())); |
171 | this->AddDefRemap(var, new_var); |
172 | return new_var; |
173 | } |
174 | |
175 | template <typename T> |
176 | void AddDefRemap(const T& source, const T& target) { |
177 | ICHECK(remap_.count(source) == 0); |
178 | remap_.Set(source, target); |
179 | } |
180 | |
181 | Buffer VisitBuffer(const Buffer& buffer, bool define = false) { |
182 | auto it = remap_.find(buffer); |
183 | if (it != remap_.end()) { |
184 | return Downcast<Buffer>((*it).second); |
185 | } |
186 | ICHECK(define); |
187 | |
188 | auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr { |
189 | auto it = remap_.find(expr); |
190 | if (it != remap_.end()) { |
191 | return Downcast<PrimExpr>((*it).second); |
192 | } else if (const VarNode* var = expr.as<VarNode>()) { |
193 | return this->ReDefineVar(GetRef<Var>(var)); |
194 | } else { |
195 | return ExprMutator::VisitExpr(expr); |
196 | } |
197 | }; |
198 | |
199 | // update data |
200 | Var data = Downcast<Var>(redefine_if_is_var(buffer->data)); |
201 | // update shape |
202 | Array<PrimExpr> shape = buffer->shape.Map(redefine_if_is_var); |
203 | // update strides |
204 | Array<PrimExpr> strides = buffer->strides.Map(redefine_if_is_var); |
205 | // update elem_offset |
206 | PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset); |
207 | |
208 | auto n = make_object<BufferNode>(*buffer.get()); |
209 | n->data = std::move(data); |
210 | n->shape = std::move(shape); |
211 | n->strides = std::move(strides); |
212 | n->elem_offset = std::move(elem_offset); |
213 | Buffer new_buffer(n); |
214 | this->AddDefRemap(buffer, new_buffer); |
215 | return new_buffer; |
216 | } |
217 | |
218 | IterVar VisitIterVar(const IterVar& iter_var) { |
219 | auto it = remap_.find(iter_var); |
220 | if (it != remap_.end()) { |
221 | return Downcast<IterVar>((*it).second); |
222 | } |
223 | PrimExpr min = VisitExpr(iter_var->dom->min); |
224 | PrimExpr extent = VisitExpr(iter_var->dom->extent); |
225 | IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), iter_var->iter_type, |
226 | iter_var->thread_tag); |
227 | this->AddDefRemap(iter_var, new_iter_var); |
228 | return new_iter_var; |
229 | } |
230 | |
231 | Buffer VisitDeclOrRemapBuffer(const Buffer& buffer) { |
232 | // If the buffer has been remapped, return the remapped buffer, otherwise, |
233 | // return the declared one. |
234 | // Due to a recent PR, we can allow undefined buffer appearing in BufferLoad/Store. We need |
235 | // to remap them but will not create new var |
236 | auto it = remap_.find(buffer); |
237 | if (it != remap_.end()) { |
238 | return Downcast<Buffer>((*it).second); |
239 | } |
240 | Var data = Downcast<Var>(VisitExpr(buffer->data)); |
241 | Array<PrimExpr> shape = |
242 | buffer->shape.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); |
243 | Array<PrimExpr> strides = |
244 | buffer->strides.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); |
245 | PrimExpr elem_offset = VisitExpr(buffer->elem_offset); |
246 | |
247 | auto n = make_object<BufferNode>(*buffer.get()); |
248 | n->data = std::move(data); |
249 | n->shape = std::move(shape); |
250 | n->strides = std::move(strides); |
251 | n->elem_offset = std::move(elem_offset); |
252 | Buffer new_buffer(n); |
253 | this->AddDefRemap(buffer, new_buffer); |
254 | return new_buffer; |
255 | } |
256 | |
257 | MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) { |
258 | Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true); |
259 | BufferRegion region = VisitBufferRegion(match_buffer->source); |
260 | return MatchBufferRegion(std::move(buffer), std::move(region)); |
261 | } |
262 | |
263 | Range VisitRange(const Range& range) { |
264 | PrimExpr min = VisitExpr(range->min); |
265 | PrimExpr extent = VisitExpr(range->extent); |
266 | if (min.same_as(range->min) && extent.same_as(range->extent)) { |
267 | return range; |
268 | } else { |
269 | return Range::FromMinExtent(std::move(min), std::move(extent)); |
270 | } |
271 | } |
272 | |
273 | BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { |
274 | Buffer buffer = VisitBuffer(buffer_region->buffer); |
275 | Array<Range> region = buffer_region->region.Map( |
276 | std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); |
277 | if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { |
278 | return buffer_region; |
279 | } else { |
280 | return BufferRegion(std::move(buffer), std::move(region)); |
281 | } |
282 | } |
283 | |
284 | Map<ObjectRef, ObjectRef> remap_; |
285 | }; |
286 | |
287 | PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); } |
288 | |
289 | TVM_REGISTER_GLOBAL("tir.RenewDefs" ).set_body_typed(RenewDefs); |
290 | |
291 | } // namespace tir |
292 | } // namespace tvm |
293 | |