1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file schedule_postproc_to_primfunc.cc |
22 | * |
* \brief Translate the function body generated by ScheduleOps,
* which uses te-related dialects that incorporate Tensor
* into the Stmts, into a PrimFunc.
26 | * |
27 | * Perform this translation before running any TIR optimizations. |
28 | * |
29 | * Rationale: The body generated by ScheduleOps is not |
30 | * a formal PrimFunc and cannot be used for further optimization. |
* This function canonicalizes that body and creates a formal PrimFunc.
32 | * |
33 | * List of actions taken by the function: |
34 | * - Remove occurrences of te::Tensor, te::Operation in the IR |
35 | * and replace them by corresponding IR nodes via tir::Buffer. |
36 | * - Add annotation of extern buffers using the buffer_map field |
37 | * in the PrimFunc type. |
38 | */ |
39 | #include <tvm/runtime/registry.h> |
40 | #include <tvm/te/operation.h> |
41 | #include <tvm/tir/expr.h> |
42 | #include <tvm/tir/function.h> |
43 | #include <tvm/tir/stmt_functor.h> |
44 | |
45 | #include <functional> |
46 | #include <unordered_map> |
47 | #include <utility> |
48 | |
49 | namespace tvm { |
50 | namespace te { |
51 | |
52 | // create a buffer for tensor. |
53 | Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "" ) { |
54 | std::string name = tensor->op->name; |
55 | if (tensor->op->num_outputs() != 1) { |
56 | name += ".v" + std::to_string(tensor->value_index); |
57 | } |
58 | Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); |
59 | |
60 | return buffer; |
61 | } |
62 | |
63 | // A remapper that maps tensor to buffer |
64 | class TensorToBufferMapper : public StmtExprMutator { |
65 | public: |
66 | explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map) |
67 | : buffer_map_(buffer_map) {} |
68 | |
69 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
70 | auto ret = StmtExprMutator::VisitStmt_(op); |
71 | op = ret.as<AttrStmtNode>(); |
72 | if (op->attr_key == tir::attr::double_buffer_scope || |
73 | op->attr_key == tir::attr::rolling_buffer_scope) { |
74 | Stmt body = op->body; |
75 | Operation operation = Downcast<Operation>(op->node); |
76 | for (int i = operation->num_outputs(); i != 0; --i) { |
77 | Buffer buffer = GetOrAllocBuffer(operation.output(i - 1)); |
78 | body = AttrStmt(buffer, op->attr_key, op->value, body); |
79 | } |
80 | return body; |
81 | } else if (op->attr_key == tir::attr::buffer_bind_scope) { |
82 | Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node); |
83 | Tensor tensor = Downcast<Tensor>(tuple[1]); |
84 | return AttrStmt(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value, |
85 | op->body); |
86 | } else if (op->attr_key == tir::attr::buffer_dim_align || |
87 | op->attr_key == tir::attr::prefetch_scope) { |
88 | Tensor tensor = Downcast<Tensor>(op->node); |
89 | Buffer buffer = GetOrAllocBuffer(tensor); |
90 | return AttrStmt(buffer, op->attr_key, op->value, op->body); |
91 | } else if (op->attr_key == tir::attr::layout_transforms || |
92 | op->attr_key == tir::attr::axis_separators) { |
93 | auto arr = Downcast<Array<ObjectRef>>(op->node); |
94 | ICHECK_EQ(arr.size(), 2); |
95 | |
96 | Stmt body = op->body; |
97 | |
98 | Tensor tensor = Downcast<Tensor>(arr[0]); |
99 | Buffer buffer = GetBuffer(tensor); |
100 | |
101 | return AttrStmt(Array<ObjectRef>{buffer, arr[1]}, op->attr_key, 1, body); |
102 | } else { |
103 | return ret; |
104 | } |
105 | } |
106 | |
107 | Stmt VisitStmt_(const ProducerRealizeNode* op) final { |
108 | Tensor tensor = Downcast<Tensor>(op->producer); |
109 | Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); |
110 | |
111 | auto ret = StmtExprMutator::VisitStmt_(op); |
112 | op = ret.as<ProducerRealizeNode>(); |
113 | |
114 | return BufferRealize(buffer, op->bounds, op->condition, op->body); |
115 | } |
116 | |
117 | Stmt VisitStmt_(const ProducerStoreNode* op) final { |
118 | Tensor tensor = Downcast<Tensor>(op->producer); |
119 | Buffer buffer = GetBuffer(tensor); |
120 | |
121 | auto ret = StmtExprMutator::VisitStmt_(op); |
122 | op = ret.as<ProducerStoreNode>(); |
123 | |
124 | return BufferStore(buffer, op->value, GetIndices(op->indices, buffer->shape)); |
125 | } |
126 | |
127 | PrimExpr VisitExpr_(const ProducerLoadNode* op) final { |
128 | auto ret = StmtExprMutator::VisitExpr_(op); |
129 | op = ret.as<ProducerLoadNode>(); |
130 | Tensor tensor = Downcast<Tensor>(op->producer); |
131 | Buffer buffer = GetBuffer(tensor); |
132 | return tir::BufferLoad(buffer, GetIndices(op->indices, buffer->shape)); |
133 | } |
134 | |
135 | private: |
136 | Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "" ) { |
137 | return GetBuffer(tensor, storage_scope, true); |
138 | } |
139 | |
140 | Buffer GetBuffer(const Tensor& tensor, String storage_scope = "" , bool allow_alloc = false) { |
141 | auto it = buffer_map_.find(tensor); |
142 | if (it != buffer_map_.end()) return it->second; |
143 | ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; |
144 | |
145 | auto buffer = CreateBufferFor(tensor, storage_scope); |
146 | buffer_map_[tensor] = buffer; |
147 | return buffer; |
148 | } |
149 | |
150 | Array<PrimExpr> GetIndices(const Array<PrimExpr>& tensor_indices, |
151 | const Array<PrimExpr>& buffer_shape) { |
152 | if (tensor_indices.size() == buffer_shape.size()) { |
153 | return tensor_indices; |
154 | } else if (tensor_indices.size() == 1) { |
155 | // Workaround to support previous behavior of tensor indexing by |
156 | // a single index, treating the tensor as if were already |
157 | // flattened by a row-major traversal. |
158 | PrimExpr unravel = tensor_indices[0]; |
159 | Array<PrimExpr> rev_indices; |
160 | for (size_t i = buffer_shape.size(); i > 0; i--) { |
161 | PrimExpr dim = buffer_shape[i - 1]; |
162 | rev_indices.push_back(indexmod(unravel, dim)); |
163 | unravel = indexdiv(unravel, dim); |
164 | } |
165 | return Array<PrimExpr>(rev_indices.rbegin(), rev_indices.rend()); |
166 | } else { |
167 | LOG(FATAL) << "Cannot produce indices for " << buffer_shape.size() |
168 | << "-dimensional TIR buffer using " << tensor_indices.size() |
169 | << "-dimensional tensor indices." ; |
170 | return {}; |
171 | } |
172 | } |
173 | |
174 | // Maps tensor to buffer. |
175 | std::unordered_map<Tensor, Buffer> buffer_map_; |
176 | }; |
177 | |
178 | /*! Collect the physical layout map of all tensors in the statement. */ |
179 | class LayoutTransformAttrUnwrapper : StmtExprMutator { |
180 | public: |
181 | static tir::PrimFunc Apply(tir::PrimFunc func) { |
182 | // Collect the physical layout annotations in the body, which may |
183 | // refer to input arguments. |
184 | auto layout_map = Collector::Collect(func->body); |
185 | |
186 | if (layout_map.size()) { |
187 | func = WithAttr(std::move(func), "layout_transform_map" , layout_map); |
188 | |
189 | auto write_ptr = func.CopyOnWrite(); |
190 | write_ptr->body = LayoutTransformAttrUnwrapper()(func->body); |
191 | } |
192 | |
193 | return func; |
194 | } |
195 | |
196 | LayoutTransformAttrUnwrapper() {} |
197 | |
198 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
199 | auto ret = StmtExprMutator::VisitStmt_(op); |
200 | op = ret.as<AttrStmtNode>(); |
201 | |
202 | if (op->attr_key == tir::attr::layout_transforms) { |
203 | return op->body; |
204 | } else { |
205 | return ret; |
206 | } |
207 | } |
208 | |
209 | private: |
210 | /*! Collect the physical layout information of all tensors in the statement. |
211 | * |
212 | * Must be done before constructing the buffers, since the |
213 | * attributes could either apply to the external buffers or to |
214 | * internal allocations. |
215 | */ |
216 | class Collector : StmtExprVisitor { |
217 | public: |
218 | static Map<Buffer, Array<IndexMap>> Collect(Stmt stmt) { |
219 | Collector collector; |
220 | collector(std::move(stmt)); |
221 | return std::move(collector.layout_map_); |
222 | } |
223 | |
224 | Collector() {} |
225 | |
226 | void VisitStmt_(const AttrStmtNode* op) final { |
227 | if (op->attr_key == tir::attr::layout_transforms) { |
228 | auto arr = Downcast<Array<ObjectRef>>(op->node); |
229 | ICHECK_EQ(arr.size(), 2); |
230 | |
231 | auto buffer = Downcast<Buffer>(arr[0]); |
232 | auto layout_transforms = Downcast<Array<IndexMap>>(arr[1]); |
233 | layout_map_.Set(buffer, layout_transforms); |
234 | } |
235 | StmtExprVisitor::VisitStmt_(op); |
236 | } |
237 | |
238 | Map<Buffer, Array<IndexMap>> layout_map_; |
239 | }; |
240 | |
241 | std::unordered_map<const BufferNode*, Buffer> buffer_remap_; |
242 | |
243 | Map<Buffer, Array<IndexMap>> layout_map_; |
244 | }; |
245 | |
246 | /*! Move axis_separators from an attribute to a buffer property. */ |
247 | class AxisSeparatorsAttrUnwrapper : StmtExprMutator { |
248 | public: |
249 | static tir::PrimFunc Apply(tir::PrimFunc func) { |
250 | // Collect the physical layout annotations in the body, which may |
251 | // refer to input arguments. |
252 | auto axis_separators_map = Collector::Collect(func->body); |
253 | |
254 | if (axis_separators_map.size()) { |
255 | auto write_ptr = func.CopyOnWrite(); |
256 | auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map); |
257 | write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map); |
258 | write_ptr->body = pass(func->body); |
259 | if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map" )) { |
260 | func = WithAttr(std::move(func), "layout_transform_map" , pass.UpdateIndexMap(map.value())); |
261 | } |
262 | } |
263 | |
264 | return func; |
265 | } |
266 | |
267 | explicit AxisSeparatorsAttrUnwrapper(Map<Buffer, Array<IntImm>> axis_separators_map) |
268 | : axis_separators_map_(axis_separators_map) {} |
269 | |
270 | Map<Var, Buffer> UpdateExternBufferMap(const Map<Var, Buffer>& orig) { |
271 | Map<Var, Buffer> output; |
272 | for (const auto& kv : orig) { |
273 | output.Set(kv.first, GetRemappedBuffer(kv.second)); |
274 | } |
275 | return output; |
276 | } |
277 | |
278 | Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) { |
279 | Map<Buffer, Array<IndexMap>> output; |
280 | for (const auto& kv : orig) { |
281 | output.Set(GetRemappedBuffer(kv.first), kv.second); |
282 | } |
283 | return output; |
284 | } |
285 | |
286 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
287 | auto ret = StmtExprMutator::VisitStmt_(op); |
288 | op = ret.as<AttrStmtNode>(); |
289 | |
290 | if (op->attr_key == tir::attr::axis_separators) { |
291 | return op->body; |
292 | } else if (op->attr_key == tir::attr::buffer_bind_scope) { |
293 | Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node); |
294 | Buffer view_buffer = Downcast<Buffer>(tuple[0]); |
295 | Buffer source_buffer = Downcast<Buffer>(tuple[1]); |
296 | return AttrStmt( |
297 | Array<ObjectRef>{GetRemappedBuffer(view_buffer), GetRemappedBuffer(source_buffer)}, |
298 | op->attr_key, op->value, op->body); |
299 | } else { |
300 | return ret; |
301 | } |
302 | } |
303 | |
304 | Stmt VisitStmt_(const BufferRealizeNode* op) final { |
305 | auto node = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(op)); |
306 | return VisitBufferAccess(std::move(node)); |
307 | } |
308 | |
309 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
310 | auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op)); |
311 | return VisitBufferAccess(std::move(node)); |
312 | } |
313 | |
314 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
315 | auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op)); |
316 | return VisitBufferAccess(std::move(node)); |
317 | } |
318 | |
319 | private: |
320 | template <typename Node> |
321 | Node VisitBufferAccess(Node node) { |
322 | Buffer new_buf = GetRemappedBuffer(node->buffer); |
323 | if (!node->buffer.same_as(new_buf)) { |
324 | auto writer = node.CopyOnWrite(); |
325 | writer->buffer = new_buf; |
326 | } |
327 | return node; |
328 | } |
329 | |
330 | Buffer GetRemappedBuffer(Buffer buf) { |
331 | // If this buffer has already been remapped, then return the |
332 | // previous value. |
333 | auto key = buf.get(); |
334 | { |
335 | auto it = buffer_remap_.find(key); |
336 | if (it != buffer_remap_.end()) { |
337 | return it->second; |
338 | } |
339 | } |
340 | |
341 | // Otherwise, check if we need to add axis_separators to this |
342 | // buffer. |
343 | auto lookup = axis_separators_map_.Get(buf); |
344 | if (lookup) { |
345 | Array<IntImm> axis_separators = lookup.value(); |
346 | if (axis_separators.size()) { |
347 | auto write_ptr = buf.CopyOnWrite(); |
348 | write_ptr->axis_separators = axis_separators; |
349 | } |
350 | } |
351 | |
352 | // And cache the result for next time. |
353 | buffer_remap_[key] = buf; |
354 | |
355 | return buf; |
356 | } |
357 | |
358 | /*! Collect the axis separator information of all tensors in the statement. |
359 | * |
360 | * Must be done before constructing the buffers, since the |
361 | * attributes could either apply to the external buffers or to |
362 | * internal allocations. |
363 | */ |
364 | class Collector : StmtExprVisitor { |
365 | public: |
366 | static Map<Buffer, Array<IntImm>> Collect(Stmt stmt) { |
367 | Collector collector; |
368 | collector(std::move(stmt)); |
369 | return std::move(collector.axis_separators_map_); |
370 | } |
371 | |
372 | Collector() {} |
373 | |
374 | void VisitStmt_(const AttrStmtNode* op) final { |
375 | if (op->attr_key == tir::attr::axis_separators) { |
376 | auto arr = Downcast<Array<ObjectRef>>(op->node); |
377 | ICHECK_EQ(arr.size(), 2); |
378 | |
379 | auto buffer = Downcast<Buffer>(arr[0]); |
380 | auto axis_separators = Downcast<Array<IntImm>>(arr[1]); |
381 | axis_separators_map_.Set(buffer, axis_separators); |
382 | } |
383 | StmtExprVisitor::VisitStmt_(op); |
384 | } |
385 | |
386 | Map<Buffer, Array<IntImm>> axis_separators_map_; |
387 | }; |
388 | |
389 | std::unordered_map<const BufferNode*, Buffer> buffer_remap_; |
390 | |
391 | Map<Buffer, Array<IntImm>> axis_separators_map_; |
392 | }; |
393 | |
394 | PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body, |
395 | Optional<Map<Tensor, Buffer>> extern_buffer_opt) { |
396 | std::unordered_map<Tensor, Buffer> extern_tensor_map; |
397 | |
398 | if (extern_buffer_opt.defined()) { |
399 | auto v = extern_buffer_opt.value(); |
400 | extern_tensor_map = std::unordered_map<Tensor, Buffer>(v.begin(), v.end()); |
401 | } |
402 | |
403 | Array<tir::Var> params; |
404 | Map<tir::Var, tir::Buffer> buffer_map; |
405 | |
406 | for (auto arg : arg_list) { |
407 | if (auto* n = arg.as<tir::VarNode>()) { |
408 | tir::Var var = GetRef<tir::Var>(n); |
409 | params.push_back(GetRef<tir::Var>(n)); |
410 | } else if (auto* n = arg.as<te::TensorNode>()) { |
411 | te::Tensor tensor = GetRef<te::Tensor>(n); |
412 | ICHECK(!extern_tensor_map.count(tensor)); |
413 | |
414 | tir::Buffer buffer = CreateBufferFor(tensor); |
415 | tir::Var bptr(buffer->name, PrimType(DataType::Handle())); |
416 | params.push_back(bptr); |
417 | buffer_map.Set(bptr, buffer); |
418 | extern_tensor_map[tensor] = buffer; |
419 | } else if (auto* n = arg.as<tir::BufferNode>()) { |
420 | tir::Buffer buffer = GetRef<tir::Buffer>(n); |
421 | tir::Var bptr(buffer->name, PrimType(DataType::Handle())); |
422 | params.push_back(bptr); |
423 | buffer_map.Set(bptr, buffer); |
424 | } else { |
425 | LOG(FATAL) << "Expected argument to be Var, Tensor, or Buffer, but received " |
426 | << arg->GetTypeKey(); |
427 | } |
428 | } |
429 | |
430 | body = TensorToBufferMapper(std::move(extern_tensor_map))(std::move(body)); |
431 | |
432 | PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map); |
433 | |
434 | func = LayoutTransformAttrUnwrapper::Apply(std::move(func)); |
435 | func = AxisSeparatorsAttrUnwrapper::Apply(std::move(func)); |
436 | |
437 | // We mark this PrimFunc as coming from a TE schedule |
438 | func = WithAttr(func, "from_legacy_te_schedule" , Bool(true)); |
439 | |
440 | return func; |
441 | } |
442 | |
443 | TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc" ) |
444 | .set_body_typed(SchedulePostProcToPrimFunc); |
445 | |
446 | } // namespace te |
447 | } // namespace tvm |
448 | |