1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/runtime/device_api.h> |
20 | |
21 | #include "./utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace script { |
25 | namespace printer { |
26 | |
27 | bool IsSimpleBuffer(const tir::Buffer& buf) { |
28 | if (!buf->strides.empty()) { |
29 | return false; |
30 | } |
31 | for (const PrimExpr& shp_i : buf->shape) { |
32 | if (!tir::UndefinedVars(shp_i).empty()) { |
33 | return false; |
34 | } |
35 | } |
36 | for (const PrimExpr& stride_i : buf->strides) { |
37 | if (!tir::UndefinedVars(stride_i).empty()) { |
38 | return false; |
39 | } |
40 | } |
41 | if (!tir::UndefinedVars(buf->elem_offset).empty()) { |
42 | return false; |
43 | } else if (buf->elem_offset->IsInstance<IntImmNode>()) { |
44 | IntImm elem_offset = Downcast<IntImm>(buf->elem_offset); |
45 | if (elem_offset->value != 0) { |
46 | return false; |
47 | } |
48 | } |
49 | return buf.scope() == "global" && buf->data_alignment == runtime::kAllocAlignment && |
50 | buf->offset_factor == 1 && buf->buffer_type == tir::BufferType::kDefault && |
51 | !buf->axis_separators.size(); |
52 | } |
53 | |
54 | int CountVarOccurrence(const tir::PrimFunc& f, const tir::Var& v) { |
55 | OccurrenceCounter counter(v.get()); |
56 | counter(f->body); |
57 | for (const tir::Var& v : f->params) { |
58 | counter(v); |
59 | } |
60 | for (const auto& pair : f->buffer_map) { |
61 | counter(pair.first); |
62 | counter.VisitBuffer(pair.second.get()); |
63 | } |
64 | return counter.count; |
65 | } |
66 | |
67 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
68 | .set_dispatch<tir::PrimFunc>("" , [](tir::PrimFunc func, ObjectPath p, IRDocsifier d) -> Doc { |
69 | With<TIRFrame> f(d, func); |
70 | (*f)->AddDispatchToken(d, "tir" ); |
71 | d->SetCommonPrefix(func, [](const ObjectRef& obj) { |
72 | return obj->IsInstance<tir::VarNode>() || obj->IsInstance<tir::BufferNode>(); |
73 | }); |
74 | int n_args = func->params.size(); |
75 | std::unordered_map<const tir::VarNode*, int> buffer_data_counter; |
76 | for (const auto& pair : func->buffer_map) { |
77 | const tir::VarNode* data_var = pair.second->data.get(); |
78 | if (!buffer_data_counter.count(data_var)) { |
79 | buffer_data_counter.insert({data_var, 0}); |
80 | } |
81 | ++buffer_data_counter.at(data_var); |
82 | } |
83 | // Step 1. Handle `func->params` |
84 | Array<AssignDoc> args; |
85 | args.reserve(n_args); |
86 | std::unordered_set<const tir::BufferNode*> buffer_inlined; |
87 | for (int i = 0; i < n_args; ++i) { |
88 | tir::Var var = func->params[i]; |
89 | ObjectPath var_p = p->Attr("params" )->ArrayIndex(i); |
90 | if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && |
91 | func->buffer_map.count(var)) { |
92 | tir::Buffer buffer = func->buffer_map[var]; |
93 | if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { |
94 | ObjectPath buffer_p = p->Attr("buffer_map" )->MapValue(var); |
95 | args.push_back(AssignDoc(DefineBuffer(buffer, *f, d), NullOpt, |
96 | BufferAttn(buffer, buffer_p, *f, d))); |
97 | buffer_inlined.insert(buffer.get()); |
98 | continue; |
99 | } |
100 | } |
101 | ExprDoc a = d->AsDoc<ExprDoc>(var->type_annotation, var_p->Attr("type_annotation" )); |
102 | args.push_back(AssignDoc(DefineVar(var, *f, d), NullOpt, a)); |
103 | } |
104 | // Step 2. Handle `func->attrs` |
105 | if (func->attrs.defined() && !func->attrs->dict.empty()) { |
106 | (*f)->stmts.push_back( |
107 | ExprStmtDoc(TIR(d, "func_attr" ) // |
108 | ->Call({d->AsDoc<ExprDoc>(func->attrs, p->Attr("attrs" ))}))); |
109 | } |
110 | // Step 3. Handle `func->buffer_map` |
111 | for (int i = 0; i < n_args; ++i) { |
112 | tir::Var param = func->params[i]; |
113 | if (func->buffer_map.count(param)) { |
114 | tir::Buffer buffer = func->buffer_map[param]; |
115 | if (buffer_inlined.count(buffer.get())) { |
116 | continue; |
117 | } |
118 | ExprDoc param_doc = args[i]->lhs; |
119 | ObjectPath buffer_p = p->Attr("buffer_map" )->MapValue(param); |
120 | ExprDoc lhs = DefineBuffer(buffer, *f, d); // TODO(@junrushao): switch `lhs` and `rhs` |
121 | ExprDoc rhs = BufferDecl(buffer, "match_buffer" , {param_doc}, buffer_p, *f, d); |
122 | (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); |
123 | } |
124 | } |
125 | // Step 4. Handle `func->body` |
126 | Optional<tir::Block> implicit_root_block = [&]() -> Optional<tir::Block> { |
127 | const tir::BlockRealizeNode* root_block_realize = func->body.as<tir::BlockRealizeNode>(); |
128 | if (root_block_realize && !root_block_realize->iter_values.size() && |
129 | tir::is_one(root_block_realize->predicate)) { |
130 | tir::Block root_block = root_block_realize->block; |
131 | if (!root_block->annotations.size() && !root_block->match_buffers.size() && |
132 | !root_block->reads.size() && !root_block->writes.size() && |
133 | !root_block->init.defined()) { |
134 | const tir::BlockRealizeNode* block_realize = |
135 | root_block->body.as<tir::BlockRealizeNode>(); |
136 | if (root_block->alloc_buffers.size() || |
137 | (block_realize && block_realize->block->iter_vars.size()) || |
138 | (!block_realize && tir::ContainsNode<tir::BlockRealizeNode>(root_block->body))) { |
139 | return root_block; |
140 | } |
141 | } |
142 | } |
143 | return NullOpt; |
144 | }(); |
145 | if (d->cfg->syntax_sugar && implicit_root_block) { |
146 | tir::Block root_block = implicit_root_block.value(); |
147 | ObjectPath root_block_p = p->Attr("body" )->Attr("block" ); |
148 | (*f)->stmts.push_back(CommentDoc("with T.block(\"root\"):" )); |
149 | // Handle root block `alloc_buffer` |
150 | for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { |
151 | tir::Buffer buffer = root_block->alloc_buffers[i]; |
152 | ObjectPath buffer_p = root_block_p->Attr("alloc_buffers" )->ArrayIndex(i); |
153 | IdDoc lhs = DefineBuffer(buffer, *f, d); |
154 | ExprDoc rhs = BufferDecl(buffer, "alloc_buffer" , {}, buffer_p, *f, d); |
155 | (*f)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); |
156 | } |
157 | AsDocBody(root_block->body, root_block_p->Attr("body" ), f->get(), d); |
158 | } else { |
159 | AsDocBody(func->body, p->Attr("body" ), f->get(), d); |
160 | } |
161 | Optional<ExprDoc> ret_type = NullOpt; |
162 | if (func->ret_type.defined()) { |
163 | const auto* as_tuple = func->ret_type.as<TupleTypeNode>(); |
164 | if (!as_tuple || as_tuple->fields.size()) { |
165 | ret_type = d->AsDoc<ExprDoc>(func->ret_type, p->Attr("ret_type" )); |
166 | } |
167 | } |
168 | return HeaderWrapper(d, FunctionDoc( |
169 | /*name=*/IdDoc(FindFunctionName(d, func).value_or("main" )), |
170 | /*args=*/args, |
171 | /*decorators=*/{TIR(d, "prim_func" )}, |
172 | /*return_type=*/ret_type, |
173 | /*body=*/(*f)->stmts)); |
174 | }); |
175 | |
// NOTE(review): presumably registers ReprPrintTIR as the repr/printing hook for
// tir::PrimFuncNode; TVM_SCRIPT_REPR and ReprPrintTIR are defined outside this file — confirm.
TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintTIR);
177 | |
178 | } // namespace printer |
179 | } // namespace script |
180 | } // namespace tvm |
181 | |