1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file schedule_postproc_to_primfunc.cc
22 *
23 * \brief Translate the function body generated by ScheduleOps
24 * with te related dialects that incorporates Tensor
25 * into the Stmts to a PrimFunc.
26 *
27 * Perform this translation before running any TIR optimizations.
28 *
29 * Rationale: The body generated by ScheduleOps is not
30 * a formal PrimFunc and cannot be used for further optimization.
 * This function canonicalizes that body and creates a formal PrimFunc.
32 *
33 * List of actions taken by the function:
34 * - Remove occurrences of te::Tensor, te::Operation in the IR
35 * and replace them by corresponding IR nodes via tir::Buffer.
36 * - Add annotation of extern buffers using the buffer_map field
37 * in the PrimFunc type.
38 */
39#include <tvm/runtime/registry.h>
40#include <tvm/te/operation.h>
41#include <tvm/tir/expr.h>
42#include <tvm/tir/function.h>
43#include <tvm/tir/stmt_functor.h>
44
45#include <functional>
46#include <unordered_map>
47#include <utility>
48
49namespace tvm {
50namespace te {
51
52// create a buffer for tensor.
53Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") {
54 std::string name = tensor->op->name;
55 if (tensor->op->num_outputs() != 1) {
56 name += ".v" + std::to_string(tensor->value_index);
57 }
58 Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope);
59
60 return buffer;
61}
62
63// A remapper that maps tensor to buffer
64class TensorToBufferMapper : public StmtExprMutator {
65 public:
66 explicit TensorToBufferMapper(std::unordered_map<Tensor, Buffer> buffer_map)
67 : buffer_map_(buffer_map) {}
68
69 Stmt VisitStmt_(const AttrStmtNode* op) final {
70 auto ret = StmtExprMutator::VisitStmt_(op);
71 op = ret.as<AttrStmtNode>();
72 if (op->attr_key == tir::attr::double_buffer_scope ||
73 op->attr_key == tir::attr::rolling_buffer_scope) {
74 Stmt body = op->body;
75 Operation operation = Downcast<Operation>(op->node);
76 for (int i = operation->num_outputs(); i != 0; --i) {
77 Buffer buffer = GetOrAllocBuffer(operation.output(i - 1));
78 body = AttrStmt(buffer, op->attr_key, op->value, body);
79 }
80 return body;
81 } else if (op->attr_key == tir::attr::buffer_bind_scope) {
82 Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
83 Tensor tensor = Downcast<Tensor>(tuple[1]);
84 return AttrStmt(Array<ObjectRef>{tuple[0], GetOrAllocBuffer(tensor)}, op->attr_key, op->value,
85 op->body);
86 } else if (op->attr_key == tir::attr::buffer_dim_align ||
87 op->attr_key == tir::attr::prefetch_scope) {
88 Tensor tensor = Downcast<Tensor>(op->node);
89 Buffer buffer = GetOrAllocBuffer(tensor);
90 return AttrStmt(buffer, op->attr_key, op->value, op->body);
91 } else if (op->attr_key == tir::attr::layout_transforms ||
92 op->attr_key == tir::attr::axis_separators) {
93 auto arr = Downcast<Array<ObjectRef>>(op->node);
94 ICHECK_EQ(arr.size(), 2);
95
96 Stmt body = op->body;
97
98 Tensor tensor = Downcast<Tensor>(arr[0]);
99 Buffer buffer = GetBuffer(tensor);
100
101 return AttrStmt(Array<ObjectRef>{buffer, arr[1]}, op->attr_key, 1, body);
102 } else {
103 return ret;
104 }
105 }
106
107 Stmt VisitStmt_(const ProducerRealizeNode* op) final {
108 Tensor tensor = Downcast<Tensor>(op->producer);
109 Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope);
110
111 auto ret = StmtExprMutator::VisitStmt_(op);
112 op = ret.as<ProducerRealizeNode>();
113
114 return BufferRealize(buffer, op->bounds, op->condition, op->body);
115 }
116
117 Stmt VisitStmt_(const ProducerStoreNode* op) final {
118 Tensor tensor = Downcast<Tensor>(op->producer);
119 Buffer buffer = GetBuffer(tensor);
120
121 auto ret = StmtExprMutator::VisitStmt_(op);
122 op = ret.as<ProducerStoreNode>();
123
124 return BufferStore(buffer, op->value, GetIndices(op->indices, buffer->shape));
125 }
126
127 PrimExpr VisitExpr_(const ProducerLoadNode* op) final {
128 auto ret = StmtExprMutator::VisitExpr_(op);
129 op = ret.as<ProducerLoadNode>();
130 Tensor tensor = Downcast<Tensor>(op->producer);
131 Buffer buffer = GetBuffer(tensor);
132 return tir::BufferLoad(buffer, GetIndices(op->indices, buffer->shape));
133 }
134
135 private:
136 Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") {
137 return GetBuffer(tensor, storage_scope, true);
138 }
139
140 Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) {
141 auto it = buffer_map_.find(tensor);
142 if (it != buffer_map_.end()) return it->second;
143 ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
144
145 auto buffer = CreateBufferFor(tensor, storage_scope);
146 buffer_map_[tensor] = buffer;
147 return buffer;
148 }
149
150 Array<PrimExpr> GetIndices(const Array<PrimExpr>& tensor_indices,
151 const Array<PrimExpr>& buffer_shape) {
152 if (tensor_indices.size() == buffer_shape.size()) {
153 return tensor_indices;
154 } else if (tensor_indices.size() == 1) {
155 // Workaround to support previous behavior of tensor indexing by
156 // a single index, treating the tensor as if were already
157 // flattened by a row-major traversal.
158 PrimExpr unravel = tensor_indices[0];
159 Array<PrimExpr> rev_indices;
160 for (size_t i = buffer_shape.size(); i > 0; i--) {
161 PrimExpr dim = buffer_shape[i - 1];
162 rev_indices.push_back(indexmod(unravel, dim));
163 unravel = indexdiv(unravel, dim);
164 }
165 return Array<PrimExpr>(rev_indices.rbegin(), rev_indices.rend());
166 } else {
167 LOG(FATAL) << "Cannot produce indices for " << buffer_shape.size()
168 << "-dimensional TIR buffer using " << tensor_indices.size()
169 << "-dimensional tensor indices.";
170 return {};
171 }
172 }
173
174 // Maps tensor to buffer.
175 std::unordered_map<Tensor, Buffer> buffer_map_;
176};
177
178/*! Collect the physical layout map of all tensors in the statement. */
179class LayoutTransformAttrUnwrapper : StmtExprMutator {
180 public:
181 static tir::PrimFunc Apply(tir::PrimFunc func) {
182 // Collect the physical layout annotations in the body, which may
183 // refer to input arguments.
184 auto layout_map = Collector::Collect(func->body);
185
186 if (layout_map.size()) {
187 func = WithAttr(std::move(func), "layout_transform_map", layout_map);
188
189 auto write_ptr = func.CopyOnWrite();
190 write_ptr->body = LayoutTransformAttrUnwrapper()(func->body);
191 }
192
193 return func;
194 }
195
196 LayoutTransformAttrUnwrapper() {}
197
198 Stmt VisitStmt_(const AttrStmtNode* op) final {
199 auto ret = StmtExprMutator::VisitStmt_(op);
200 op = ret.as<AttrStmtNode>();
201
202 if (op->attr_key == tir::attr::layout_transforms) {
203 return op->body;
204 } else {
205 return ret;
206 }
207 }
208
209 private:
210 /*! Collect the physical layout information of all tensors in the statement.
211 *
212 * Must be done before constructing the buffers, since the
213 * attributes could either apply to the external buffers or to
214 * internal allocations.
215 */
216 class Collector : StmtExprVisitor {
217 public:
218 static Map<Buffer, Array<IndexMap>> Collect(Stmt stmt) {
219 Collector collector;
220 collector(std::move(stmt));
221 return std::move(collector.layout_map_);
222 }
223
224 Collector() {}
225
226 void VisitStmt_(const AttrStmtNode* op) final {
227 if (op->attr_key == tir::attr::layout_transforms) {
228 auto arr = Downcast<Array<ObjectRef>>(op->node);
229 ICHECK_EQ(arr.size(), 2);
230
231 auto buffer = Downcast<Buffer>(arr[0]);
232 auto layout_transforms = Downcast<Array<IndexMap>>(arr[1]);
233 layout_map_.Set(buffer, layout_transforms);
234 }
235 StmtExprVisitor::VisitStmt_(op);
236 }
237
238 Map<Buffer, Array<IndexMap>> layout_map_;
239 };
240
241 std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
242
243 Map<Buffer, Array<IndexMap>> layout_map_;
244};
245
246/*! Move axis_separators from an attribute to a buffer property. */
247class AxisSeparatorsAttrUnwrapper : StmtExprMutator {
248 public:
249 static tir::PrimFunc Apply(tir::PrimFunc func) {
250 // Collect the physical layout annotations in the body, which may
251 // refer to input arguments.
252 auto axis_separators_map = Collector::Collect(func->body);
253
254 if (axis_separators_map.size()) {
255 auto write_ptr = func.CopyOnWrite();
256 auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map);
257 write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map);
258 write_ptr->body = pass(func->body);
259 if (auto map = func->attrs.GetAttr<Map<Buffer, Array<IndexMap>>>("layout_transform_map")) {
260 func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value()));
261 }
262 }
263
264 return func;
265 }
266
267 explicit AxisSeparatorsAttrUnwrapper(Map<Buffer, Array<IntImm>> axis_separators_map)
268 : axis_separators_map_(axis_separators_map) {}
269
270 Map<Var, Buffer> UpdateExternBufferMap(const Map<Var, Buffer>& orig) {
271 Map<Var, Buffer> output;
272 for (const auto& kv : orig) {
273 output.Set(kv.first, GetRemappedBuffer(kv.second));
274 }
275 return output;
276 }
277
278 Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, Array<IndexMap>>& orig) {
279 Map<Buffer, Array<IndexMap>> output;
280 for (const auto& kv : orig) {
281 output.Set(GetRemappedBuffer(kv.first), kv.second);
282 }
283 return output;
284 }
285
286 Stmt VisitStmt_(const AttrStmtNode* op) final {
287 auto ret = StmtExprMutator::VisitStmt_(op);
288 op = ret.as<AttrStmtNode>();
289
290 if (op->attr_key == tir::attr::axis_separators) {
291 return op->body;
292 } else if (op->attr_key == tir::attr::buffer_bind_scope) {
293 Array<ObjectRef> tuple = Downcast<Array<ObjectRef>>(op->node);
294 Buffer view_buffer = Downcast<Buffer>(tuple[0]);
295 Buffer source_buffer = Downcast<Buffer>(tuple[1]);
296 return AttrStmt(
297 Array<ObjectRef>{GetRemappedBuffer(view_buffer), GetRemappedBuffer(source_buffer)},
298 op->attr_key, op->value, op->body);
299 } else {
300 return ret;
301 }
302 }
303
304 Stmt VisitStmt_(const BufferRealizeNode* op) final {
305 auto node = Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(op));
306 return VisitBufferAccess(std::move(node));
307 }
308
309 Stmt VisitStmt_(const BufferStoreNode* op) final {
310 auto node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
311 return VisitBufferAccess(std::move(node));
312 }
313
314 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
315 auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
316 return VisitBufferAccess(std::move(node));
317 }
318
319 private:
320 template <typename Node>
321 Node VisitBufferAccess(Node node) {
322 Buffer new_buf = GetRemappedBuffer(node->buffer);
323 if (!node->buffer.same_as(new_buf)) {
324 auto writer = node.CopyOnWrite();
325 writer->buffer = new_buf;
326 }
327 return node;
328 }
329
330 Buffer GetRemappedBuffer(Buffer buf) {
331 // If this buffer has already been remapped, then return the
332 // previous value.
333 auto key = buf.get();
334 {
335 auto it = buffer_remap_.find(key);
336 if (it != buffer_remap_.end()) {
337 return it->second;
338 }
339 }
340
341 // Otherwise, check if we need to add axis_separators to this
342 // buffer.
343 auto lookup = axis_separators_map_.Get(buf);
344 if (lookup) {
345 Array<IntImm> axis_separators = lookup.value();
346 if (axis_separators.size()) {
347 auto write_ptr = buf.CopyOnWrite();
348 write_ptr->axis_separators = axis_separators;
349 }
350 }
351
352 // And cache the result for next time.
353 buffer_remap_[key] = buf;
354
355 return buf;
356 }
357
358 /*! Collect the axis separator information of all tensors in the statement.
359 *
360 * Must be done before constructing the buffers, since the
361 * attributes could either apply to the external buffers or to
362 * internal allocations.
363 */
364 class Collector : StmtExprVisitor {
365 public:
366 static Map<Buffer, Array<IntImm>> Collect(Stmt stmt) {
367 Collector collector;
368 collector(std::move(stmt));
369 return std::move(collector.axis_separators_map_);
370 }
371
372 Collector() {}
373
374 void VisitStmt_(const AttrStmtNode* op) final {
375 if (op->attr_key == tir::attr::axis_separators) {
376 auto arr = Downcast<Array<ObjectRef>>(op->node);
377 ICHECK_EQ(arr.size(), 2);
378
379 auto buffer = Downcast<Buffer>(arr[0]);
380 auto axis_separators = Downcast<Array<IntImm>>(arr[1]);
381 axis_separators_map_.Set(buffer, axis_separators);
382 }
383 StmtExprVisitor::VisitStmt_(op);
384 }
385
386 Map<Buffer, Array<IntImm>> axis_separators_map_;
387 };
388
389 std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
390
391 Map<Buffer, Array<IntImm>> axis_separators_map_;
392};
393
394PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list, Stmt body,
395 Optional<Map<Tensor, Buffer>> extern_buffer_opt) {
396 std::unordered_map<Tensor, Buffer> extern_tensor_map;
397
398 if (extern_buffer_opt.defined()) {
399 auto v = extern_buffer_opt.value();
400 extern_tensor_map = std::unordered_map<Tensor, Buffer>(v.begin(), v.end());
401 }
402
403 Array<tir::Var> params;
404 Map<tir::Var, tir::Buffer> buffer_map;
405
406 for (auto arg : arg_list) {
407 if (auto* n = arg.as<tir::VarNode>()) {
408 tir::Var var = GetRef<tir::Var>(n);
409 params.push_back(GetRef<tir::Var>(n));
410 } else if (auto* n = arg.as<te::TensorNode>()) {
411 te::Tensor tensor = GetRef<te::Tensor>(n);
412 ICHECK(!extern_tensor_map.count(tensor));
413
414 tir::Buffer buffer = CreateBufferFor(tensor);
415 tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
416 params.push_back(bptr);
417 buffer_map.Set(bptr, buffer);
418 extern_tensor_map[tensor] = buffer;
419 } else if (auto* n = arg.as<tir::BufferNode>()) {
420 tir::Buffer buffer = GetRef<tir::Buffer>(n);
421 tir::Var bptr(buffer->name, PrimType(DataType::Handle()));
422 params.push_back(bptr);
423 buffer_map.Set(bptr, buffer);
424 } else {
425 LOG(FATAL) << "Expected argument to be Var, Tensor, or Buffer, but received "
426 << arg->GetTypeKey();
427 }
428 }
429
430 body = TensorToBufferMapper(std::move(extern_tensor_map))(std::move(body));
431
432 PrimFunc func = tir::PrimFunc(params, body, VoidType(), buffer_map);
433
434 func = LayoutTransformAttrUnwrapper::Apply(std::move(func));
435 func = AxisSeparatorsAttrUnwrapper::Apply(std::move(func));
436
437 // We mark this PrimFunc as coming from a TE schedule
438 func = WithAttr(func, "from_legacy_te_schedule", Bool(true));
439
440 return func;
441}
442
// Expose this lowering step to the FFI under "schedule.SchedulePostProcToPrimFunc".
TVM_REGISTER_GLOBAL("schedule.SchedulePostProcToPrimFunc")
    .set_body_typed(SchedulePostProcToPrimFunc);
445
446} // namespace te
447} // namespace tvm
448