1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/tir/ir/specialize.cc
22 * \brief Specialize parameters of PrimFunc.
23 */
24#include <tvm/runtime/registry.h>
25#include <tvm/tir/analysis.h>
26#include <tvm/tir/function.h>
27#include <tvm/tir/op.h>
28#include <tvm/tir/stmt_functor.h>
29
30#include <functional>
31
32#include "functor_common.h"
33
34namespace tvm {
35namespace tir {
36
/*!
 * \brief Substitution used by specialization: maps a Var to the PrimExpr that replaces it.
 * \note Keys are hashed/compared with ObjectPtrHash / ObjectPtrEqual, i.e. by Var object
 *       identity rather than by name.
 */
using VarMap = std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
38
39/**************** Helper functions ****************/
40
41/*! \brief Helper function to check whether the given var is in function parameter list. */
42inline bool IsParam(const PrimFunc& func, const Var& param) {
43 return std::any_of(func->params.begin(), func->params.end(),
44 [&](const Var& var) { return var.same_as(param); });
45}
46
47/**************** Specializer ****************/
48
// Visitor generators for arithmetic/logical nodes. The operands are visited first (left
// before right for binary nodes). If no operand changed, the original node is returned so
// that sharing is preserved; otherwise the node is rebuilt through the given TIR operator
// function (e.g. `add`, `floordiv`), which gives it the chance to fold constants when an
// operand has been specialized to a constant.
#define DEFINE_SPECIALIZER_BINARY_OP_MUTATE(BinaryNode, BinaryFunc) \
  PrimExpr VisitExpr_(const BinaryNode* op) final {                 \
    PrimExpr lhs = VisitExpr(op->a);                                \
    PrimExpr rhs = VisitExpr(op->b);                                \
    if (!lhs.same_as(op->a) || !rhs.same_as(op->b)) {               \
      return BinaryFunc(lhs, rhs);                                  \
    }                                                               \
    return GetRef<PrimExpr>(op);                                    \
  }
#define DEFINE_SPECIALIZER_UNARY_OP_MUTATE(UnaryNode, UnaryFunc)                  \
  PrimExpr VisitExpr_(const UnaryNode* op) final {                                \
    PrimExpr operand = VisitExpr(op->a);                                          \
    return operand.same_as(op->a) ? GetRef<PrimExpr>(op) : UnaryFunc(operand);   \
  }
69
70/*! \brief Mutator to specialize function and remove const parameters */
71class PrimFuncSpecializer : public StmtExprMutator {
72 public:
73 explicit PrimFuncSpecializer(const VarMap& var_map) : var_map_(var_map) {}
74
75 static PrimFunc Specialize(PrimFunc f, const VarMap& var_map) {
76 PrimFuncSpecializer specializer(var_map);
77 // Updating Buffer map
78 Map<Var, Buffer> buffer_map;
79 bool buffer_map_updated = false;
80 for (const auto& it : f->buffer_map) {
81 const Var& var = it.first;
82 const Buffer& buffer = it.second;
83 Buffer new_buffer = specializer.MutateBuffer(buffer);
84 buffer_map.Set(var, new_buffer);
85 if (!new_buffer.same_as(buffer)) {
86 buffer_map_updated = true;
87 specializer.buffer_map_[buffer] = new_buffer;
88 }
89 }
90
91 // Updating parmeters
92 Array<Var> params;
93 bool param_updated = false;
94 for (const auto& var : f->params) {
95 // Remove parmeters which has been specialized.
96 if (var_map.find(var) == var_map.end()) {
97 params.push_back(var);
98 } else {
99 param_updated = true;
100 }
101 }
102
103 // Updating function body
104 Stmt body = specializer(f->body);
105
106 if (param_updated || buffer_map_updated || !f->body.same_as(body)) {
107 PrimFuncNode* f_ptr = f.CopyOnWrite();
108 f_ptr->params = std::move(params);
109 f_ptr->buffer_map = std::move(buffer_map);
110 f_ptr->body = std::move(body);
111 }
112 return f;
113 }
114
115 private:
116 Stmt VisitStmt_(const BlockNode* op) final {
117 // Step.0. Define buffer mappings which is allocated inside the block
118 Array<Buffer> alloc_buffers = op->alloc_buffers.Map(
119 std::bind(&PrimFuncSpecializer::MutateAllocBuffer, this, std::placeholders::_1));
120
121 // Step.1. Recursively visit block body
122 Stmt stmt = StmtExprMutator::VisitStmt_(op);
123 op = stmt.as<BlockNode>();
124 ICHECK(op != nullptr);
125
126 Array<BufferRegion> reads = op->reads.Map(
127 std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
128 Array<BufferRegion> writes = op->writes.Map(
129 std::bind(&PrimFuncSpecializer::MutateBufferRegion, this, std::placeholders::_1));
130
131 if (alloc_buffers.same_as(op->alloc_buffers) && reads.same_as(op->reads) &&
132 writes.same_as(op->writes)) {
133 return GetRef<Block>(op);
134 } else {
135 ObjectPtr<BlockNode> n = CopyOnWrite(op);
136 n->alloc_buffers = std::move(alloc_buffers);
137 n->reads = std::move(reads);
138 n->writes = std::move(writes);
139 return Stmt(n);
140 }
141 }
142
143 Stmt VisitStmt_(const BufferStoreNode* op) final {
144 Stmt stmt = StmtExprMutator::VisitStmt_(op);
145 op = stmt.as<BufferStoreNode>();
146 ICHECK(op != nullptr);
147 auto it = buffer_map_.find(op->buffer);
148 if (it == buffer_map_.end()) {
149 return GetRef<BufferStore>(op);
150 } else {
151 auto n = CopyOnWrite(op);
152 n->buffer = it->second;
153 return Stmt(n);
154 }
155 }
156
157 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
158 PrimExpr expr = StmtExprMutator::VisitExpr_(op);
159 op = expr.as<BufferLoadNode>();
160 ICHECK(op != nullptr);
161 auto it = buffer_map_.find(op->buffer);
162 if (it == buffer_map_.end()) {
163 return GetRef<BufferLoad>(op);
164 } else {
165 auto n = make_object<BufferLoadNode>(*op);
166 n->buffer = it->second;
167 return PrimExpr(n);
168 }
169 }
170
171 PrimExpr VisitExpr_(const VarNode* op) final {
172 auto it = var_map_.find(GetRef<Var>(op));
173 if (it == var_map_.end()) {
174 return GetRef<PrimExpr>(op);
175 } else {
176 return it->second;
177 }
178 }
179
180 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AddNode, add);
181 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(SubNode, sub);
182 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MulNode, mul);
183 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(DivNode, div);
184 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(ModNode, truncmod);
185 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorDivNode, floordiv);
186 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(FloorModNode, floormod);
187 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MaxNode, max);
188 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(MinNode, min);
189 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(EQNode, equal);
190 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(NENode, not_equal);
191 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LTNode, less);
192 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(LENode, less_equal);
193 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GTNode, greater);
194 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(GENode, greater_equal);
195 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(AndNode, logical_and);
196 DEFINE_SPECIALIZER_BINARY_OP_MUTATE(OrNode, logical_or);
197 DEFINE_SPECIALIZER_UNARY_OP_MUTATE(NotNode, logical_not);
198
199 private:
200 Buffer MutateBuffer(const Buffer& buffer) {
201 Array<PrimExpr> shape = buffer->shape.Map([this](const PrimExpr& e) { return VisitExpr(e); });
202 Array<PrimExpr> strides =
203 buffer->strides.Map([this](const PrimExpr& e) { return VisitExpr(e); });
204
205 PrimExpr elem_offset = VisitExpr(buffer->elem_offset);
206
207 if (buffer->elem_offset.same_as(elem_offset) && buffer->shape.same_as(shape) &&
208 buffer->strides.same_as(strides)) {
209 return buffer;
210 } else {
211 auto n = make_object<BufferNode>(*buffer.get());
212 n->elem_offset = std::move(elem_offset);
213 n->shape = std::move(shape);
214 n->strides = std::move(strides);
215 return Buffer(n);
216 }
217 }
218
219 Range MutateRange(const Range& range) {
220 PrimExpr min = this->VisitExpr(range->min);
221 PrimExpr extent = this->VisitExpr(range->extent);
222 if (min.same_as(range->min) && extent.same_as(range->extent)) {
223 return range;
224 } else {
225 return Range::FromMinExtent(std::move(min), std::move(extent));
226 }
227 }
228
229 Buffer MutateAllocBuffer(const Buffer& alloc_buf) {
230 Buffer buf = MutateBuffer(alloc_buf);
231 if (buf.same_as(alloc_buf)) {
232 return alloc_buf;
233 } else {
234 ICHECK(buffer_map_.find(alloc_buf) == buffer_map_.end());
235 buffer_map_[alloc_buf] = buf;
236 return buf;
237 }
238 }
239
240 BufferRegion MutateBufferRegion(const BufferRegion& buffer_region) {
241 auto it = buffer_map_.find(buffer_region->buffer);
242 const Buffer& buffer = it != buffer_map_.end() ? it->second : buffer_region->buffer;
243 Array<Range> region = buffer_region->region.Map(
244 std::bind(&PrimFuncSpecializer::MutateRange, this, std::placeholders::_1));
245 if (it == buffer_map_.end() && region.same_as(buffer_region->region)) {
246 return buffer_region;
247 } else {
248 return BufferRegion(buffer, std::move(region));
249 }
250 }
251
252 private:
253 /*! \brief The vars to be substitute and their values */
254 const VarMap& var_map_;
255 /*! \brief map from old buffer to mutated buffer */
256 std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
257};
258
259/*!
260 * \brief Update Specialize var map with buffer matching.
261 * \param func The function to be specialized.
262 * \param param The given function parameter
263 * \param specific_buf The matching buffer.
264 * \param var_map The var mapping to be updated.
265 * \note This function will match target buffer's shape, strides and element_offset
266 * For example, we define a buffer in PrimFunc:
267 * A = T.match_buffer(a, [m, n])
268 *
269 * Then we match it with a buffer B = tir.decl_buffer((8, 16))
270 *
271 * It means we have two var mappings here: m = 8 and n = 16
272 *
273 * If the buffer signature is not a Var, the mapping will fail.
274 * e.g. A = T.match_buffer(a, [m * 2, n + 1])
275 */
276void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const Buffer& specific_buf,
277 VarMap* var_map) {
278 // preliminaries
279 tir::ExprDeepEqual equal;
280
281 auto it = func->buffer_map.find(param);
282 CHECK(it != func->buffer_map.end())
283 << "ValueError: specialize expects param to be in PrimFunc's buffer_map";
284 const Buffer& buf_to_specialize = (*it).second;
285
286 // build var mapping using specific_buf's parameters
287 auto build_var_mapping = [&](const PrimExpr& new_expr, const PrimExpr& old_expr) {
288 if (!equal(new_expr, old_expr)) {
289 CHECK(old_expr->IsInstance<VarNode>())
290 << "TypeError: The signature of target buffer exprected an independent Var, but got "
291 << old_expr << ".";
292 const Var& var = Downcast<Var>(old_expr);
293 auto it = var_map->find(var);
294 if (it != var_map->end()) {
295 CHECK(equal(it->second, new_expr))
296 << "ValueError: The assigned value of var " << var << " mismatched. " << it->second
297 << " vs. " << new_expr << ".";
298 } else {
299 (*var_map)[var] = new_expr;
300 }
301 }
302 };
303
304 // Check buffer dimensions
305 CHECK(specific_buf->shape.size() == buf_to_specialize->shape.size())
306 << "ValueError: The buffer dimensions mismatched" << buf_to_specialize->shape.size()
307 << " vs. " << specific_buf->shape.size() << ".";
308
309 CHECK(specific_buf->strides.size() == buf_to_specialize->strides.size())
310 << "ValueError: The buffer strides dimensions mismatched" << buf_to_specialize->strides.size()
311 << " vs. " << specific_buf->strides.size() << ".";
312
313 // Updating var mapping using specific_expr
314 for (size_t i = 0; i < specific_buf->shape.size(); ++i) {
315 build_var_mapping(specific_buf->shape[i], buf_to_specialize->shape[i]);
316 }
317 for (size_t i = 0; i < specific_buf->strides.size(); ++i) {
318 build_var_mapping(specific_buf->strides[i], buf_to_specialize->strides[i]);
319 }
320 build_var_mapping(specific_buf->elem_offset, buf_to_specialize->elem_offset);
321
322 // Check data_alignment and offset_factor.
323 // These two signatures are int, so we do not need map them.
324 CHECK_EQ(specific_buf->data_alignment, buf_to_specialize->data_alignment)
325 << "ValueError: The buffer data_alignment mismatched" << buf_to_specialize->data_alignment
326 << " vs. " << specific_buf->data_alignment << ".";
327
328 CHECK_EQ(specific_buf->offset_factor, buf_to_specialize->offset_factor)
329 << "ValueError: The buffer offset_factor mismatched" << buf_to_specialize->offset_factor
330 << " vs. " << specific_buf->offset_factor << ".";
331}
332
333/*!
334 * \brief Update Specialize var map with parameter value.
335 * \param func The function to be specialized.
336 * \param param The given function parameter
337 * \param specific_expr The parameter value.
338 * \param var_map The var mapping to be updated.
339 */
340void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimExpr& specific_expr,
341 VarMap* var_map) {
342 // check param is in PrimFunc's parameters
343 CHECK(IsParam(func, param)) << "ValueError: Specialize expects param to be in PrimFunc's params";
344 // specialize a param not in buffer_map
345 CHECK_EQ(func->buffer_map.count(param), 0)
346 << "ValueError: Specialize expects param to not be in PrimFunc's buffer_map";
347 // build var mapping using specific_expr
348 (*var_map)[param] = specific_expr;
349}
350
351/**************** Implementation ****************/
352
353PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
354 VarMap var_map;
355 for (const auto& kv : param_map) {
356 const Var& param = kv.first;
357 const ObjectRef& instance = kv.second;
358 if (instance->IsInstance<BufferNode>()) {
359 UpdateSpecializeVarMap(func, param, Downcast<Buffer>(instance), &var_map);
360 } else if (instance->IsInstance<PrimExprNode>()) {
361 UpdateSpecializeVarMap(func, param, Downcast<PrimExpr>(instance), &var_map);
362 } else {
363 CHECK(instance.defined()) << "Specialize instance is not defined for param " << param;
364 LOG(FATAL) << "TypeError: specialize expected instance to be Buffer or PrimExpr, but got "
365 << instance->GetTypeKey();
366 }
367 }
368 return PrimFuncSpecializer::Specialize(func, std::move(var_map));
369}
370
371/**************** FFI ****************/
372
// Expose Specialize through the FFI registry under the global name "tir.Specialize".
TVM_REGISTER_GLOBAL("tir.Specialize").set_body_typed(Specialize);
374
375} // namespace tir
376} // namespace tvm
377