1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/tir/ir/specialize.cc |
22 | * \brief Specialize parameters of PrimFunc. |
23 | */ |
24 | #include <tvm/runtime/registry.h> |
25 | #include <tvm/tir/analysis.h> |
26 | #include <tvm/tir/function.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/stmt_functor.h> |
29 | |
30 | #include <functional> |
31 | |
32 | #include "functor_common.h" |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
/*! \brief Table associating a Var with the PrimExpr that is substituted for it. */
using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
38 | |
39 | /**************** Helper functions ****************/ |
40 | |
41 | /*! \brief Helper function to check whether the given var is in function parameter list. */ |
42 | inline bool IsParam(const PrimFunc& func, const Var& param) { |
43 | return std::any_of(func->params.begin(), func->params.end(), |
44 | [&](const Var& var) { return var.same_as(param); }); |
45 | } |
46 | |
47 | /**************** Specializer ****************/ |
48 | |
// Visitor for a binary node: both operands are visited first. If neither changed, the
// original node is returned untouched; otherwise the node is rebuilt through BinaryFunc,
// which gives it the chance to fold constants when the operands got specialized to constants.
#define DEFINE_SPECIALIZER_BINARY_OP_MUTATE(BinaryNode, BinaryFunc)  \
  PrimExpr VisitExpr_(const BinaryNode* op) final {                  \
    PrimExpr lhs = VisitExpr(op->a);                                 \
    PrimExpr rhs = VisitExpr(op->b);                                 \
    const bool untouched = lhs.same_as(op->a) && rhs.same_as(op->b); \
    if (untouched) return GetRef<PrimExpr>(op);                      \
    return BinaryFunc(lhs, rhs);                                     \
  }
// Visitor for a unary node: the operand is visited first. If it did not change, the original
// node is returned untouched; otherwise the node is rebuilt through UnaryFunc.
#define DEFINE_SPECIALIZER_UNARY_OP_MUTATE(UnaryNode, UnaryFunc) \
  PrimExpr VisitExpr_(const UnaryNode* op) final {               \
    PrimExpr operand = VisitExpr(op->a);                         \
    if (operand.same_as(op->a)) return GetRef<PrimExpr>(op);     \
    return UnaryFunc(operand);                                   \
  }
69 | |
70 | /*! \brief Mutator to specialize function and remove const parameters */ |
71 | class PrimFuncSpecializer : public StmtExprMutator { |
72 | public: |
73 | explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {} |
74 | |
75 | static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) { |
76 | PrimFuncSpecializer specializer(var_map); |
77 | // Updating Buffer map |
78 | Map<Var, Buffer> buffer_map; |
79 | bool buffer_map_updated = false; |
80 | for (const auto& it : f->buffer_map) { |
81 | const Var& var = it.first; |
82 | const Buffer& buffer = it.second; |
83 | Buffer new_buffer = specializer.MutateBuffer(buffer); |
84 | buffer_map.Set(var, new_buffer); |
85 | if (!new_buffer.same_as(buffer)) { |
86 | buffer_map_updated = true; |
87 | specializer.buffer_map_[buffer] = new_buffer; |
88 | } |
89 | } |
90 | |
91 | // Updating parmeters |
92 | Array<Var> params; |
93 | bool param_updated = false; |
94 | for (const auto& var : f->params) { |
95 | // Remove parmeters which has been specialized. |
96 | if (var_map.find(var) == var_map.end()) { |
97 | params.push_back(var); |
98 | } else { |
99 | param_updated = true; |
100 | } |
101 | } |
102 | |
103 | // Updating function body |
104 | Stmt body = specializer(f->body); |
105 | |
106 | if (param_updated || buffer_map_updated || !f->body.same_as(body)) { |
107 | PrimFuncNode* f_ptr = f.CopyOnWrite(); |
108 | f_ptr->params = std::move(params); |
109 | f_ptr->buffer_map = std::move(buffer_map); |
110 | f_ptr->body = std::move(body); |
111 | } |
112 | return f; |
113 | } |
114 | |
115 | private: |
116 | Stmt VisitStmt_(const BlockNode* op) final { |
117 | // Step.0. Define buffer mappings which is allocated inside the block |
118 | Array<Buffer> alloc_buffers = op->alloc_buffers.Map( |
119 | std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1)); |
120 | |
121 | // Step.1. Recursively visit block body |
122 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
123 | op = stmt.as<BlockNode>(); |
124 | ICHECK(op != nullptr); |
125 | |
126 | Array<BufferRegion> reads = op->reads.Map( |
127 | std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); |
128 | Array<BufferRegion> writes = op->writes.Map( |
129 | std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1)); |
130 | |
131 | if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) && |
132 | writes.same_as(op->writes)) { |
133 | return GetRef<Block>(op); |
134 | } else { |
135 | ObjectPtr<BlockNode> n = CopyOnWrite(op); |
136 | n->alloc_buffers = std::move(alloc_buffers); |
137 | n->reads = std::move(reads); |
138 | n->writes = std::move(writes); |
139 | return Stmt(n); |
140 | } |
141 | } |
142 | |
143 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
144 | Stmt stmt = StmtExprMutator::VisitStmt_(op); |
145 | op = stmt.as<BufferStoreNode>(); |
146 | ICHECK(op != nullptr); |
147 | auto it = buffer_map_.find(op->buffer); |
148 | if (it == buffer_map_.end()) { |
149 | return GetRef<BufferStore>(op); |
150 | } else { |
151 | auto n = CopyOnWrite(op); |
152 | n->buffer = it->second; |
153 | return Stmt(n); |
154 | } |
155 | } |
156 | |
157 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
158 | PrimExpr expr = StmtExprMutator::VisitExpr_(op); |
159 | op = expr.as<BufferLoadNode>(); |
160 | ICHECK(op != nullptr); |
161 | auto it = buffer_map_.find(op->buffer); |
162 | if (it == buffer_map_.end()) { |
163 | return GetRef<BufferLoad>(op); |
164 | } else { |
165 | auto n = make_object<BufferLoadNode>(*op); |
166 | n->buffer = it->second; |
167 | return PrimExpr(n); |
168 | } |
169 | } |
170 | |
171 | PrimExpr VisitExpr_(const VarNode* op) final { |
172 | auto it = var_map_.find(GetRef<Var>(op)); |
173 | if (it == var_map_.end()) { |
174 | return GetRef<PrimExpr>(op); |
175 | } else { |
176 | return it->second; |
177 | } |
178 | } |
179 | |
180 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AddNode, add); |
181 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(SubNode, sub); |
182 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MulNode, mul); |
183 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(DivNode, div); |
184 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(ModNode, truncmod); |
185 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorDivNode, floordiv); |
186 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorModNode, floormod); |
187 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MaxNode, max); |
188 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MinNode, min); |
189 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(EQNode, equal); |
190 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(NENode, not_equal); |
191 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LTNode, less); |
192 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LENode, less_equal); |
193 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GTNode, greater); |
194 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GENode, greater_equal); |
195 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AndNode, logical_and); |
196 | DEFINE_SPECIALIZER_BINARY_OP_MUTATE(OrNode, logical_or); |
197 | DEFINE_SPECIALIZER_UNARY_OP_MUTATE(NotNode, logical_not); |
198 | |
199 | private: |
200 | Buffer MutateBuffer(const Buffer& buffer) { |
201 | Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); }); |
202 | Array<PrimExpr> strides = |
203 | buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); }); |
204 | |
205 | PrimExpr elem_offset = VisitExpr(buffer->elem_offset); |
206 | |
207 | if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) && |
208 | buffer->strides.same_as(strides)) { |
209 | return buffer; |
210 | } else { |
211 | auto n = make_object<BufferNode>(*buffer.get()); |
212 | n->elem_offset = std::move(elem_offset); |
213 | n->shape = std::move(shape); |
214 | n->strides = std::move(strides); |
215 | return Buffer(n); |
216 | } |
217 | } |
218 | |
219 | Range MutateRange(const Range& range) { |
220 | PrimExpr min = this->VisitExpr(range->min); |
221 | PrimExpr extent = this->VisitExpr(range->extent); |
222 | if (min.same_as(range->min) && extent.same_as(range->extent)) { |
223 | return range; |
224 | } else { |
225 | return Range::FromMinExtent(std::move(min), std::move(extent)); |
226 | } |
227 | } |
228 | |
229 | Buffer MutateAllocBuffer(const Buffer& alloc_buf) { |
230 | Buffer buf = MutateBuffer(alloc_buf); |
231 | if (buf.same_as(alloc_buf)) { |
232 | return alloc_buf; |
233 | } else { |
234 | ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end()); |
235 | buffer_map_[alloc_buf] = buf; |
236 | return buf; |
237 | } |
238 | } |
239 | |
240 | BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) { |
241 | auto it = buffer_map_.find(buffer_region->buffer); |
242 | const Buffer& buffer = it != buffer_map_.end() ? it->second : buffer_region->buffer; |
243 | Array<Range> region = buffer_region->region.Map( |
244 | std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1)); |
245 | if (it == buffer_map_.end() && region.same_as(buffer_region->region)) { |
246 | return buffer_region; |
247 | } else { |
248 | return BufferRegion(buffer, std::move(region)); |
249 | } |
250 | } |
251 | |
252 | private: |
253 | /*! \brief The vars to be substitute and their values */ |
254 | const VarMap& var_map_; |
255 | /*! \brief map from old buffer to mutated buffer */ |
256 | std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_; |
257 | }; |
258 | |
259 | /*! |
260 | * \brief Update Specialize var map with buffer matching. |
261 | * \param func The function to be specialized. |
262 | * \param param The given function parameter |
263 | * \param specific_buf The matching buffer. |
264 | * \param var_map The var mapping to be updated. |
265 | * \note This function will match target buffer's shape, strides and element_offset |
266 | * For example, we define a buffer in PrimFunc: |
267 | * A = T.match_buffer(a, [m, n]) |
268 | * |
269 | * Then we match it with a buffer B = tir.decl_buffer((8, 16)) |
270 | * |
271 | * It means we have two var mappings here: m = 8 and n = 16 |
272 | * |
273 | * If the buffer signature is not a Var, the mapping will fail. |
274 | * e.g. A = T.match_buffer(a, [m * 2, n + 1]) |
275 | */ |
276 | void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf, |
277 | VarMap* var_map) { |
278 | // preliminaries |
279 | tir::ExprDeepEqual equal; |
280 | |
281 | auto it = func->buffer_map.find(param); |
282 | CHECK(it != func->buffer_map.end()) |
283 | << "ValueError: specialize expects param to be in PrimFunc's buffer_map" ; |
284 | const Buffer& buf_to_specialize = (*it).second; |
285 | |
286 | // build var mapping using specific_buf's parameters |
287 | auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) { |
288 | if (!equal(new_expr, old_expr)) { |
289 | CHECK(old_expr->IsInstance<VarNode>()) |
290 | << "TypeError: The signature of target buffer exprected an independent Var, but got " |
291 | << old_expr << "." ; |
292 | const Var& var = Downcast<Var>(old_expr); |
293 | auto it = var_map->find(var); |
294 | if (it != var_map->end()) { |
295 | CHECK(equal(it->second, new_expr)) |
296 | << "ValueError: The assigned value of var " << var << " mismatched. " << it->second |
297 | << " vs. " << new_expr << "." ; |
298 | } else { |
299 | (*var_map)[var] = new_expr; |
300 | } |
301 | } |
302 | }; |
303 | |
304 | // Check buffer dimensions |
305 | CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size()) |
306 | << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size() |
307 | << " vs. " << specific_buf->shape.size() << "." ; |
308 | |
309 | CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size()) |
310 | << "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size() |
311 | << " vs. " << specific_buf->strides.size() << "." ; |
312 | |
313 | // Updating var mapping using specific_expr |
314 | for (size_t i = 0; i < specific_buf->shape.size(); ++i) { |
315 | build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]); |
316 | } |
317 | for (size_t i = 0; i < specific_buf->strides.size(); ++i) { |
318 | build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]); |
319 | } |
320 | build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset); |
321 | |
322 | // Check data_alignment and offset_factor. |
323 | // These two signatures are int, so we do not need map them. |
324 | CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment) |
325 | << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment |
326 | << " vs. " << specific_buf->data_alignment << "." ; |
327 | |
328 | CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor) |
329 | << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor |
330 | << " vs. " << specific_buf->offset_factor << "." ; |
331 | } |
332 | |
333 | /*! |
334 | * \brief Update Specialize var map with parameter value. |
335 | * \param func The function to be specialized. |
336 | * \param param The given function parameter |
337 | * \param specific_expr The parameter value. |
338 | * \param var_map The var mapping to be updated. |
339 | */ |
340 | void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr, |
341 | VarMap* var_map) { |
342 | // check param is in PrimFunc's parameters |
343 | CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params" ; |
344 | // specialize a param not in buffer_map |
345 | CHECK_EQ(func->buffer_map.count(param), 0) |
346 | << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map" ; |
347 | // build var mapping using specific_expr |
348 | (*var_map)[param] = specific_expr; |
349 | } |
350 | |
351 | /**************** Implementation ****************/ |
352 | |
353 | PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) { |
354 | VarMap var_map; |
355 | for (const auto& kv : param_map) { |
356 | const Var& param = kv.first; |
357 | const ObjectRef& instance = kv.second; |
358 | if (instance->IsInstance<BufferNode>()) { |
359 | UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), &var_map); |
360 | } else if (instance->IsInstance<PrimExprNode>()) { |
361 | UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map); |
362 | } else { |
363 | CHECK(instance.defined()) << "Specialize instance is not defined for param " << param; |
364 | LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got " |
365 | << instance->GetTypeKey(); |
366 | } |
367 | } |
368 | return PrimFuncSpecializer::Specialize(func, std::move(var_map)); |
369 | } |
370 | |
371 | /**************** FFI ****************/ |
372 | |
// Register Specialize as a global function under the name "tir.Specialize".
TVM_REGISTER_GLOBAL("tir.Specialize" ).set_body_typed(Specialize);
374 | |
375 | } // namespace tir |
376 | } // namespace tvm |
377 | |