1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/runtime/device_api.h>
20
21#include "./utils.h"
22
23namespace tvm {
24namespace script {
25namespace printer {
26
27bool IsSimpleBuffer(const tir::Buffer& buf) {
28 if (!buf->strides.empty()) {
29 return false;
30 }
31 for (const PrimExpr& shp_i : buf->shape) {
32 if (!tir::UndefinedVars(shp_i).empty()) {
33 return false;
34 }
35 }
36 for (const PrimExpr& stride_i : buf->strides) {
37 if (!tir::UndefinedVars(stride_i).empty()) {
38 return false;
39 }
40 }
41 if (!tir::UndefinedVars(buf->elem_offset).empty()) {
42 return false;
43 } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
44 IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
45 if (elem_offset->value != 0) {
46 return false;
47 }
48 }
49 return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment &&
50 buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault &&
51 !buf->axis_separators.size();
52}
53
54int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) {
55 OccurrenceCounter counter(v.get());
56 counter(f->body);
57 for (const tir::Var& v : f->params) {
58 counter(v);
59 }
60 for (const auto& pair : f->buffer_map) {
61 counter(pair.first);
62 counter.VisitBuffer(pair.second.get());
63 }
64 return counter.count;
65}
66
67TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
68 .set_dispatch<tir::PrimFunc>("", [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc {
69 With<TIRFrame> f(d, func);
70 (*f)->AddDispatchToken(d, "tir");
71 d->SetCommonPrefix(func, [](const ObjectRef& obj) {
72 return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>();
73 });
74 int n_args = func->params.size();
75 std::unordered_map<const tir::VarNode*, int> buffer_data_counter;
76 for (const auto& pair : func->buffer_map) {
77 const tir::VarNode* data_var = pair.second->data.get();
78 if (!buffer_data_counter.count(data_var)) {
79 buffer_data_counter.insert({data_var, 0});
80 }
81 ++buffer_data_counter.at(data_var);
82 }
83 // Step 1. Handle `func->params`
84 Array<AssignDoc> args;
85 args.reserve(n_args);
86 std::unordered_set<const tir::BufferNode*> buffer_inlined;
87 for (int i = 0; i < n_args; ++i) {
88 tir::Var var = func->params[i];
89 ObjectPath var_p = p->Attr("params")->ArrayIndex(i);
90 if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 &&
91 func->buffer_map.count(var)) {
92 tir::Buffer buffer = func->buffer_map[var];
93 if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) {
94 ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var);
95 args.push_back(AssignDoc(DefineBuffer(buffer, *f, d), NullOpt,
96 BufferAttn(buffer, buffer_p, *f, d)));
97 buffer_inlined.insert(buffer.get());
98 continue;
99 }
100 }
101 ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation"));
102 args.push_back(AssignDoc(DefineVar(var, *f, d), NullOpt, a));
103 }
104 // Step 2. Handle `func->attrs`
105 if (func->attrs.defined() && !func->attrs->dict.empty()) {
106 (*f)->stmts.push_back(
107 ExprStmtDoc(TIR(d, "func_attr") //
108 ->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs"))})));
109 }
110 // Step 3. Handle `func->buffer_map`
111 for (int i = 0; i < n_args; ++i) {
112 tir::Var param = func->params[i];
113 if (func->buffer_map.count(param)) {
114 tir::Buffer buffer = func->buffer_map[param];
115 if (buffer_inlined.count(buffer.get())) {
116 continue;
117 }
118 ExprDoc param_doc = args[i]->lhs;
119 ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param);
120 ExprDoc lhs = DefineBuffer(buffer, *f, d); // TODO(@junrushao): switch `lhs` and `rhs`
121 ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *f, d);
122 (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
123 }
124 }
125 // Step 4. Handle `func->body`
126 Optional<tir::Block> implicit_root_block = [&]() -> Optional<tir::Block> {
127 const tir::BlockRealizeNode* root_block_realize = func->body.as<tir::BlockRealizeNode>();
128 if (root_block_realize && !root_block_realize->iter_values.size() &&
129 tir::is_one(root_block_realize->predicate)) {
130 tir::Block root_block = root_block_realize->block;
131 if (!root_block->annotations.size() && !root_block->match_buffers.size() &&
132 !root_block->reads.size() && !root_block->writes.size() &&
133 !root_block->init.defined()) {
134 const tir::BlockRealizeNode* block_realize =
135 root_block->body.as<tir::BlockRealizeNode>();
136 if (root_block->alloc_buffers.size() ||
137 (block_realize && block_realize->block->iter_vars.size()) ||
138 (!block_realize && tir::ContainsNode<tir::BlockRealizeNode>(root_block->body))) {
139 return root_block;
140 }
141 }
142 }
143 return NullOpt;
144 }();
145 if (d->cfg->syntax_sugar && implicit_root_block) {
146 tir::Block root_block = implicit_root_block.value();
147 ObjectPath root_block_p = p->Attr("body")->Attr("block");
148 (*f)->stmts.push_back(CommentDoc("with T.block(\"root\"):"));
149 // Handle root block `alloc_buffer`
150 for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) {
151 tir::Buffer buffer = root_block->alloc_buffers[i];
152 ObjectPath buffer_p = root_block_p->Attr("alloc_buffers")->ArrayIndex(i);
153 IdDoc lhs = DefineBuffer(buffer, *f, d);
154 ExprDoc rhs = BufferDecl(buffer, "alloc_buffer", {}, buffer_p, *f, d);
155 (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
156 }
157 AsDocBody(root_block->body, root_block_p->Attr("body"), f->get(), d);
158 } else {
159 AsDocBody(func->body, p->Attr("body"), f->get(), d);
160 }
161 Optional<ExprDoc> ret_type = NullOpt;
162 if (func->ret_type.defined()) {
163 const auto* as_tuple = func->ret_type.as<TupleTypeNode>();
164 if (!as_tuple || as_tuple->fields.size()) {
165 ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type"));
166 }
167 }
168 return HeaderWrapper(d, FunctionDoc(
169 /*name=*/IdDoc(FindFunctionName(d, func).value_or("main")),
170 /*args=*/args,
171 /*decorators=*/{TIR(d, "prim_func")},
172 /*return_type=*/ret_type,
173 /*body=*/(*f)->stmts));
174 });
175
// Associates tir::PrimFuncNode with ReprPrintTIR through the TVM_SCRIPT_REPR
// macro. NOTE(review): the macro is defined outside this file; presumably it
// registers the repr printer for the node type -- confirm in the utils header.
TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR);
177
178} // namespace printer
179} // namespace script
180} // namespace tvm
181