1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file renew_defs.cc
 * \brief Renew the definition nodes (Var, Buffer and IterVar) of a TIR PrimFunc.
23 */
24
25#include <tvm/tir/stmt_functor.h>
26#include <tvm/tir/transform.h>
27
28#include "../ir/functor_common.h"
29
30namespace tvm {
31namespace tir {
32
/*!
 * \brief Emit a `VisitStmt_` override for a statement node NODE that defines the Var in FIELD.
 *
 * The defined Var is renewed before the node's children are visited, so references to it inside
 * the node are rewritten to the new Var by the mutator's VisitExpr. The mutated node is then
 * copied and its FIELD replaced with the renewed Var.
 */
#define STMT_REGENERATE_VAR_DEF(NODE, FIELD)            \
  Stmt VisitStmt_(const NODE* op) final {               \
    Var renewed = this->ReDefineVar(op->FIELD);         \
    Stmt mutated = StmtExprMutator::VisitStmt_(op);     \
    op = mutated.as<NODE>();                            \
    ICHECK(op != nullptr);                              \
    auto copy = make_object<NODE>(*op);                 \
    copy->FIELD = std::move(renewed);                   \
    return Stmt(copy);                                  \
  }
43
44class RenewDefMutator : public StmtExprMutator {
45 public:
46 static PrimFunc Transform(const PrimFunc& func) {
47 RenewDefMutator generator;
48 // Redefine params
49 Array<Var> params;
50 for (const auto& param : func->params) {
51 params.push_back(generator.ReDefineVar(param));
52 }
53 // Redefine buffers in order
54 // TODO(Siyuan Feng): checking var is used after define
55 Map<tir::Var, Buffer> buffer_map;
56 for (const auto& param : func->params) {
57 if (param->dtype.is_handle()) {
58 const Buffer& buffer = func->buffer_map.at(param);
59 Var new_param = Downcast<Var>(generator.VisitExpr(param));
60 Buffer new_buffer = generator.VisitBuffer(buffer, true);
61 buffer_map.Set(new_param, new_buffer);
62 }
63 }
64 // Visit body
65 Stmt body = generator(func->body);
66 // Recreate function
67 auto n = make_object<PrimFuncNode>(*func.get());
68 n->params = std::move(params);
69 n->buffer_map = std::move(buffer_map);
70 n->body = std::move(body);
71 return PrimFunc(n);
72 }
73
74 private:
75 Stmt operator()(Stmt stmt) {
76 // override StmtMutator::operator() to disable copy_on_write
77 // Since this pass tries to explict create a new function rather than update the existing one
78 allow_copy_on_write_ = false;
79 return VisitStmt(stmt);
80 }
81
82 PrimExpr VisitExpr(const PrimExpr& expr) final {
83 auto it = remap_.find(expr);
84 if (it != remap_.end()) {
85 return Downcast<PrimExpr>((*it).second);
86 } else {
87 return ExprMutator::VisitExpr(expr);
88 }
89 }
90
91 private:
92 STMT_REGENERATE_VAR_DEF(LetStmtNode, var);
93 STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var);
94 STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var);
95 STMT_REGENERATE_VAR_DEF(ForNode, loop_var);
96
97 Stmt VisitStmt_(const BlockNode* op) final {
98 // Step 0. Re-define Itervars
99 Array<IterVar> iter_vars =
100 op->iter_vars.Map(std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1));
101
102 // Step 1. Re-define buffers allocate under the block
103 Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
104 std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true));
105
106 // Step 2. Re-define match_buffers
107 Array<MatchBufferRegion> match_buffers = op->match_buffers.Map(
108 std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1));
109
110 // Step 3. Visit body
111 Stmt stmt = StmtExprMutator::VisitStmt_(op);
112 op = stmt.as<BlockNode>();
113 ICHECK(op);
114
115 // Step 4. Revisit access region
116 Array<BufferRegion> reads =
117 op->reads.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1));
118 Array<BufferRegion> writes =
119 op->writes.Map(std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1));
120
121 // Step 5. Regenerate block. Since the defs are changed, we need to create a new block
122 auto n = make_object<BlockNode>(*op);
123 n->iter_vars = std::move(iter_vars);
124 n->alloc_buffers = std::move(alloc_buffers);
125 n->match_buffers = std::move(match_buffers);
126 n->reads = std::move(reads);
127 n->writes = std::move(writes);
128
129 return Stmt(n);
130 }
131
132 Stmt VisitStmt_(const BufferStoreNode* op) final {
133 Stmt stmt = StmtExprMutator::VisitStmt_(op);
134 op = stmt.as<BufferStoreNode>();
135 ICHECK(op != nullptr);
136 Buffer buffer = VisitDeclOrRemapBuffer(op->buffer);
137 if (buffer.same_as(op->buffer)) {
138 return stmt;
139 } else {
140 auto n = make_object<BufferStoreNode>(*op);
141 n->buffer = std::move(buffer);
142 return BufferStore(n);
143 }
144 }
145
146 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
147 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
148 op = expr.as<BufferLoadNode>();
149 ICHECK(op != nullptr);
150 Buffer buffer = VisitDeclOrRemapBuffer(op->buffer);
151 if (buffer.same_as(op->buffer)) {
152 return expr;
153 } else {
154 auto n = make_object<BufferLoadNode>(*op);
155 n->buffer = std::move(buffer);
156 return BufferLoad(n);
157 }
158 }
159
160 PrimExpr VisitExpr_(const LoadNode* op) final {
161 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
162 }
163
164 Stmt VisitStmt_(const StoreNode* op) final {
165 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
166 }
167
168 private:
169 Var ReDefineVar(const Var& var) {
170 Var new_var = Var(make_object<VarNode>(*var.get()));
171 this->AddDefRemap(var, new_var);
172 return new_var;
173 }
174
175 template <typename T>
176 void AddDefRemap(const T& source, const T& target) {
177 ICHECK(remap_.count(source) == 0);
178 remap_.Set(source, target);
179 }
180
181 Buffer VisitBuffer(const Buffer& buffer, bool define = false) {
182 auto it = remap_.find(buffer);
183 if (it != remap_.end()) {
184 return Downcast<Buffer>((*it).second);
185 }
186 ICHECK(define);
187
188 auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr {
189 auto it = remap_.find(expr);
190 if (it != remap_.end()) {
191 return Downcast<PrimExpr>((*it).second);
192 } else if (const VarNode* var = expr.as<VarNode>()) {
193 return this->ReDefineVar(GetRef<Var>(var));
194 } else {
195 return ExprMutator::VisitExpr(expr);
196 }
197 };
198
199 // update data
200 Var data = Downcast<Var>(redefine_if_is_var(buffer->data));
201 // update shape
202 Array<PrimExpr> shape = buffer->shape.Map(redefine_if_is_var);
203 // update strides
204 Array<PrimExpr> strides = buffer->strides.Map(redefine_if_is_var);
205 // update elem_offset
206 PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset);
207
208 auto n = make_object<BufferNode>(*buffer.get());
209 n->data = std::move(data);
210 n->shape = std::move(shape);
211 n->strides = std::move(strides);
212 n->elem_offset = std::move(elem_offset);
213 Buffer new_buffer(n);
214 this->AddDefRemap(buffer, new_buffer);
215 return new_buffer;
216 }
217
218 IterVar VisitIterVar(const IterVar& iter_var) {
219 auto it = remap_.find(iter_var);
220 if (it != remap_.end()) {
221 return Downcast<IterVar>((*it).second);
222 }
223 PrimExpr min = VisitExpr(iter_var->dom->min);
224 PrimExpr extent = VisitExpr(iter_var->dom->extent);
225 IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), iter_var->iter_type,
226 iter_var->thread_tag);
227 this->AddDefRemap(iter_var, new_iter_var);
228 return new_iter_var;
229 }
230
231 Buffer VisitDeclOrRemapBuffer(const Buffer& buffer) {
232 // If the buffer has been remapped, return the remapped buffer, otherwise,
233 // return the declared one.
234 // Due to a recent PR, we can allow undefined buffer appearing in BufferLoad/Store. We need
235 // to remap them but will not create new var
236 auto it = remap_.find(buffer);
237 if (it != remap_.end()) {
238 return Downcast<Buffer>((*it).second);
239 }
240 Var data = Downcast<Var>(VisitExpr(buffer->data));
241 Array<PrimExpr> shape =
242 buffer->shape.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1));
243 Array<PrimExpr> strides =
244 buffer->strides.Map(std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1));
245 PrimExpr elem_offset = VisitExpr(buffer->elem_offset);
246
247 auto n = make_object<BufferNode>(*buffer.get());
248 n->data = std::move(data);
249 n->shape = std::move(shape);
250 n->strides = std::move(strides);
251 n->elem_offset = std::move(elem_offset);
252 Buffer new_buffer(n);
253 this->AddDefRemap(buffer, new_buffer);
254 return new_buffer;
255 }
256
257 MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) {
258 Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true);
259 BufferRegion region = VisitBufferRegion(match_buffer->source);
260 return MatchBufferRegion(std::move(buffer), std::move(region));
261 }
262
263 Range VisitRange(const Range& range) {
264 PrimExpr min = VisitExpr(range->min);
265 PrimExpr extent = VisitExpr(range->extent);
266 if (min.same_as(range->min) && extent.same_as(range->extent)) {
267 return range;
268 } else {
269 return Range::FromMinExtent(std::move(min), std::move(extent));
270 }
271 }
272
273 BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) {
274 Buffer buffer = VisitBuffer(buffer_region->buffer);
275 Array<Range> region = buffer_region->region.Map(
276 std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1));
277 if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) {
278 return buffer_region;
279 } else {
280 return BufferRegion(std::move(buffer), std::move(region));
281 }
282 }
283
284 Map<ObjectRef, ObjectRef> remap_;
285};
286
287PrimFunc RenewDefs(const PrimFunc& func) { return RenewDefMutator::Transform(func); }
288
289TVM_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs);
290
291} // namespace tir
292} // namespace tvm
293