1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include "../../../tir/transforms/ir_utils.h"  // For `GetPtrStorageScope`
#include "./utils.h"
21 | |
22 | namespace tvm { |
23 | namespace script { |
24 | namespace printer { |
25 | |
26 | Doc DoConciseScoping(const Optional<ExprDoc>& lhs, const ExprDoc& rhs, Array<StmtDoc>* stmts, |
27 | bool concise_scoping) { |
28 | if (concise_scoping) { |
29 | if (lhs.defined()) { |
30 | stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, NullOpt)); |
31 | } else { |
32 | stmts->insert(stmts->begin(), ExprStmtDoc(rhs)); |
33 | } |
34 | return StmtBlockDoc(*stmts); |
35 | } else { |
36 | return ScopeDoc(lhs, rhs, *stmts); |
37 | } |
38 | } |
39 | |
40 | bool AllowConciseScoping(const IRDocsifier& d) { |
41 | ICHECK(!d->frames.empty()); |
42 | if (const auto* f = d->frames.back().as<TIRFrameNode>()) { |
43 | return f->allow_concise_scoping; |
44 | } |
45 | LOG(FATAL) << "NotImplementedError: fragment printing" ; |
46 | } |
47 | |
48 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
49 | .set_dispatch<tir::Evaluate>("" , [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc { |
50 | ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value" )); |
51 | if (eval->value->IsInstance<tir::CallNode>()) { |
52 | return ExprStmtDoc(value); |
53 | } |
54 | return ExprStmtDoc(TIR(d, "evaluate" )->Call({value})); |
55 | }); |
56 | |
// Printer for tir::LetStmt.
//
// Two forms are produced:
//  * Concise: when the innermost frame allows concise scoping and `stmt->var` is not yet
//    defined, the binding becomes an (optionally type-annotated) assignment inserted in front
//    of the printed body, and everything is returned as one flat statement block.
//  * Scoped: otherwise the binding is rendered as the TIR `let(var, value)` call, used as the
//    expression of a ScopeDoc that owns the printed body.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
      bool concise = AllowConciseScoping(d);
      if (concise && !d->IsVarDefined(stmt->var)) {
        // The value is printed before `stmt->var` is defined in the new frame.
        ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
        With<TIRFrame> f(d, stmt);
        // Define the bound variable in the new frame so the body can refer to it.
        ExprDoc lhs = DefineVar(stmt->var, *f, d);
        AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
        Array<StmtDoc>* stmts = &(*f)->stmts;
        Type type = stmt->var->type_annotation;
        Optional<ExprDoc> type_doc =
            d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
        // An empty tuple type is not emitted as an annotation.
        if (const auto* tuple_type = type.as<TupleTypeNode>()) {
          if (tuple_type->fields.empty()) {
            type_doc = NullOpt;
          }
        }
        // `lhs: type = rhs` goes first, followed by the body statements.
        stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
        return StmtBlockDoc(*stmts);
      } else {
        // Scoped form: the variable is printed as a regular expression argument of `let`.
        ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
        ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
        With<TIRFrame> f(d, stmt);
        AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
        Array<StmtDoc>* stmts = &(*f)->stmts;
        rhs = TIR(d, "let")->Call({lhs, rhs});
        return ScopeDoc(NullOpt, rhs, *stmts);
      }
    });
86 | |
87 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
88 | .set_dispatch<tir::AssertStmt>( |
89 | "" , [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { |
90 | bool concise = AllowConciseScoping(d); |
91 | ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition" )); |
92 | ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message" )); |
93 | With<TIRFrame> f(d, stmt); |
94 | AsDocBody(stmt->body, p->Attr("body" ), f->get(), d); |
95 | if (concise) { |
96 | Array<StmtDoc>* stmts = &(*f)->stmts; |
97 | stmts->insert(stmts->begin(), AssertDoc(cond, msg)); |
98 | return StmtBlockDoc(*stmts); |
99 | } |
100 | return ScopeDoc(NullOpt, TIR(d, "Assert" )->Call({cond, msg}), (*f)->stmts); |
101 | }); |
102 | |
103 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
104 | .set_dispatch<tir::While>("" , [](tir::While stmt, ObjectPath p, IRDocsifier d) -> Doc { |
105 | ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition" )); |
106 | With<TIRFrame> f(d, stmt); |
107 | AsDocBody(stmt->body, p->Attr("body" ), f->get(), d); |
108 | return WhileDoc(cond, (*f)->stmts); |
109 | }); |
110 | |
111 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
112 | .set_dispatch<tir::DeclBuffer>( // |
113 | "" , [](tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d) -> Doc { |
114 | bool concise = AllowConciseScoping(d); |
115 | ExprDoc rhs = |
116 | BufferDecl(stmt->buffer, "decl_buffer" , {}, p->Attr("buffer" ), d->frames.back(), d); |
117 | With<TIRFrame> f(d, stmt); |
118 | ExprDoc lhs = DefineBuffer(stmt->buffer, *f, d); |
119 | AsDocBody(stmt->body, p->Attr("body" ), f->get(), d); |
120 | return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); |
121 | }); |
122 | |
123 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
124 | .set_dispatch<tir::IfThenElse>( // |
125 | "" , [](tir::IfThenElse stmt, ObjectPath p, IRDocsifier d) -> Doc { |
126 | ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition" )); |
127 | Array<StmtDoc> then_branch; |
128 | Array<StmtDoc> else_branch; |
129 | if (stmt->then_case.defined()) { |
130 | With<TIRFrame> f(d, stmt->then_case); |
131 | AsDocBody(stmt->then_case, p->Attr("then_case" ), f->get(), d); |
132 | then_branch = (*f)->stmts; |
133 | } |
134 | if (stmt->else_case.defined()) { |
135 | With<TIRFrame> f(d, stmt->else_case); |
136 | AsDocBody(stmt->else_case.value(), p->Attr("else_case" ), f->get(), d); |
137 | else_branch = (*f)->stmts; |
138 | } |
139 | return IfDoc(cond, then_branch, else_branch); |
140 | }); |
141 | |
142 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
143 | .set_dispatch<tir::SeqStmt>("" , [](tir::SeqStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { |
144 | With<TIRFrame> f(d, stmt); |
145 | AsDocBody(stmt, p, f->get(), d); |
146 | return StmtBlockDoc((*f)->stmts); |
147 | }); |
148 | |
149 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
150 | .set_dispatch<tir::Prefetch>( // |
151 | "" , [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc { |
152 | return ExprStmtDoc(TIR(d, "prefetch" ) |
153 | ->Call({ |
154 | d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer" )), |
155 | d->AsDoc<ExprDoc>(stmt->bounds, p->Attr("bounds" )), |
156 | })); |
157 | }); |
158 | |
159 | bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { |
160 | const tir::Var& buffer_var = allocate->buffer_var; |
161 | if (const tir::DeclBufferNode* decl_buffer = allocate->body.as<tir::DeclBufferNode>()) { |
162 | const tir::Buffer& buffer = decl_buffer->buffer; |
163 | if (buffer_var.same_as(buffer->data) && allocate->dtype == buffer->dtype && |
164 | tir::is_one(allocate->condition) && !allocate->annotations.size() && |
165 | allocate->extents.size() == buffer->shape.size()) { |
166 | tir::ExprDeepEqual expr_equal; |
167 | for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) { |
168 | if (!expr_equal(allocate->extents[i], buffer->shape[i])) { |
169 | return false; |
170 | } |
171 | } |
172 | return true; |
173 | } |
174 | } |
175 | return false; |
176 | } |
177 | |
178 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
179 | .set_dispatch<tir::Allocate>( // |
180 | "" , [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { |
181 | bool concise = AllowConciseScoping(d); |
182 | if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) { |
183 | return d->AsDoc(stmt->body, stmt_p->Attr("body" )); |
184 | } |
185 | Array<ExprDoc> args; |
186 | Array<String> kwargs_keys; |
187 | Array<ExprDoc> kwargs_values; |
188 | args.push_back(d->AsDoc<ExprDoc>(stmt->extents, stmt_p->Attr("extents" ))); |
189 | args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype" ))); |
190 | args.push_back(LiteralDoc::Str(tir::GetPtrStorageScope(stmt->buffer_var), |
191 | stmt_p |
192 | ->Attr("buffer_var" ) // |
193 | ->Attr("type_annotation" ) |
194 | ->Attr("storage_scope" ))); |
195 | if (!tir::is_one(stmt->condition)) { |
196 | args.push_back(d->AsDoc<ExprDoc>(stmt->condition, stmt_p->Attr("condition" ))); |
197 | } |
198 | if (!stmt->annotations.empty()) { |
199 | kwargs_keys.push_back("annotations" ); |
200 | kwargs_values.push_back( |
201 | d->AsDoc<ExprDoc>(stmt->annotations, stmt_p->Attr("annotations" ))); |
202 | } |
203 | ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d); |
204 | With<TIRFrame> f(d, stmt); |
205 | ExprDoc rhs = TIR(d, "allocate" )->Call(args, kwargs_keys, kwargs_values); |
206 | AsDocBody(stmt->body, stmt_p->Attr("body" ), f->get(), d); |
207 | return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); |
208 | }); |
209 | |
210 | template <typename T> |
211 | ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) { |
212 | // FIXME(@junrushao): this is a hack and can be wrong in most of the cases |
213 | constexpr int NUM_PRINT = 200; |
214 | int ndim = arr->ndim; |
215 | int tot_dim = 1; |
216 | for (int i = 0; i < ndim; i++) { |
217 | tot_dim *= arr->shape[i]; |
218 | } |
219 | Array<ExprDoc> result; |
220 | T* data_ptr = reinterpret_cast<T*>(arr->data); |
221 | runtime::DataType dtype = arr.DataType(); |
222 | for (int i = 0; i < tot_dim; i++) { |
223 | if (dtype.is_float()) { |
224 | result.push_back(LiteralDoc::Float(data_ptr[i], NullOpt)); |
225 | } else { |
226 | result.push_back(LiteralDoc::Int(data_ptr[i], NullOpt)); |
227 | } |
228 | if (i == NUM_PRINT) { |
229 | break; |
230 | } |
231 | } |
232 | return ListDoc(result); |
233 | } |
234 | |
235 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
236 | .set_dispatch<tir::AllocateConst>( |
237 | "" , [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { |
238 | bool concise = AllowConciseScoping(d); |
239 | String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var); |
240 | Array<ExprDoc> args; |
241 | Array<String> kwargs_keys; |
242 | Array<ExprDoc> kwargs_values; |
243 | ExprDoc data_doc{nullptr}; |
244 | if (stmt->dtype.is_int()) { |
245 | if (stmt->dtype.bits() == 8) { |
246 | data_doc = PrintNDArray<int8_t>(stmt->data.value()); |
247 | } else if (stmt->dtype.bits() == 16) { |
248 | data_doc = PrintNDArray<int16_t>(stmt->data.value()); |
249 | } else if (stmt->dtype.bits() == 32) { |
250 | data_doc = PrintNDArray<int32_t>(stmt->data.value()); |
251 | } else if (stmt->dtype.bits() == 64) { |
252 | data_doc = PrintNDArray<int64_t>(stmt->data.value()); |
253 | } else { |
254 | LOG(FATAL) << "DataType not supported" ; |
255 | } |
256 | } else if (stmt->dtype.is_uint()) { |
257 | if (stmt->dtype.bits() == 8) { |
258 | data_doc = PrintNDArray<uint8_t>(stmt->data.value()); |
259 | } else if (stmt->dtype.bits() == 16) { |
260 | data_doc = PrintNDArray<uint16_t>(stmt->data.value()); |
261 | } else if (stmt->dtype.bits() == 32) { |
262 | data_doc = PrintNDArray<uint32_t>(stmt->data.value()); |
263 | } else if (stmt->dtype.bits() == 64) { |
264 | data_doc = PrintNDArray<uint64_t>(stmt->data.value()); |
265 | } else { |
266 | LOG(FATAL) << "DataType not supported" ; |
267 | } |
268 | } else if (stmt->dtype.is_float()) { |
269 | if (stmt->dtype.bits() == 16) { |
270 | data_doc = PrintNDArray<int16_t>(stmt->data.value()); |
271 | } else if (stmt->dtype.bits() == 32) { |
272 | data_doc = PrintNDArray<float>(stmt->data.value()); |
273 | } else if (stmt->dtype.bits() == 64) { |
274 | data_doc = PrintNDArray<double>(stmt->data.value()); |
275 | } else { |
276 | LOG(FATAL) << "DataType not supported" ; |
277 | } |
278 | } else { |
279 | LOG(FATAL) << "DataType not supported" ; |
280 | } |
281 | args.push_back(data_doc); |
282 | args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype" ))); |
283 | args.push_back(d->AsDoc<ExprDoc>(stmt->extents, stmt_p->Attr("extents" ))); |
284 | ExprDoc rhs = TIR(d, "allocate_const" )->Call(args, kwargs_keys, kwargs_values); |
285 | With<TIRFrame> f(d, stmt); |
286 | ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d); |
287 | AsDocBody(stmt->body, stmt_p->Attr("body" ), f->get(), d); |
288 | return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); |
289 | }); |
290 | |
291 | ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDoc> value, // |
292 | ObjectPath p, IRDocsifier d) { |
293 | ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer" )); |
294 | { |
295 | Array<Doc> bounds; |
296 | bounds.reserve(stmt->bounds.size()); |
297 | for (int i = 0, n = stmt->bounds.size(); i < n; ++i) { |
298 | Range range = stmt->bounds[i]; |
299 | ObjectPath range_p = p->Attr("bounds" )->ArrayIndex(i); |
300 | bounds.push_back( |
301 | SliceDoc(d->AsDoc<ExprDoc>(range->min, range_p->Attr("min" )), |
302 | d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent" )), // |
303 | NullOpt)); |
304 | } |
305 | buffer = buffer[bounds]; |
306 | } |
307 | Array<ExprDoc> args{buffer}; |
308 | Array<String> kwargs_keys; |
309 | Array<ExprDoc> kwargs_values; |
310 | if (value.defined()) { |
311 | args.push_back(value.value()); |
312 | } |
313 | if (!tir::is_one(stmt->condition)) { |
314 | kwargs_keys.push_back("condition" ); |
315 | kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition" ))); |
316 | } |
317 | return TIR(d, "realize" )->Call(args, kwargs_keys, kwargs_values); |
318 | } |
319 | |
320 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
321 | .set_dispatch<tir::BufferRealize>( // |
322 | "" , [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { |
323 | bool concise = AllowConciseScoping(d); |
324 | ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d); |
325 | With<TIRFrame> f(d, stmt); |
326 | AsDocBody(stmt->body, p->Attr("body" ), f->get(), d); |
327 | return DoConciseScoping(NullOpt, rhs, &(*f)->stmts, concise); |
328 | }); |
329 | |
// Printer for tir::AttrStmt.
//
// The scope-opening call `rhs` is chosen as follows:
//  1. "realize_scope" whose body is a BufferRealize of the same buffer as `stmt->node`: both
//     nodes are merged into one `realize(...)` call (from DocsifyBufferRealize) that carries
//     `stmt->value` as an extra argument; printing then continues with the BufferRealize body.
//  2. "thread_extent" / "virtual_thread" on an IterVar: `launch_thread(var, value)`. If the
//     IterVar's variable is not defined yet, it is first defined and bound to
//     `env_thread(thread_tag)` in the frame returned by FindLowestVarDef.
//  3. Anything else (or if neither case above produced a call): `attr(node, attr_key, value)`.
// The selected body is printed in a new TIR frame and combined with `rhs` via DoConciseScoping.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_dispatch<tir::AttrStmt>(  //
        "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
          bool concise = AllowConciseScoping(d);
          // Stays NullOpt until one of the cases below builds the scope-opening call.
          Optional<ExprDoc> rhs = NullOpt;
          // The body actually printed, and its path; case 1 skips one nesting level.
          tir::Stmt body = stmt->body;
          ObjectPath body_p = stmt_p->Attr("body");
          if (stmt->attr_key == "realize_scope") {
            if (const auto* realize = stmt->body.as<tir::BufferRealizeNode>()) {
              if (realize->buffer.same_as(stmt->node)) {
                rhs = DocsifyBufferRealize(
                    realize,
                    /*value=*/d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
                    /*p=*/stmt_p->Attr("body"), d);
                body = realize->body;
                body_p = body_p->Attr("body");
              }
            }
          }
          if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
            if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
              if (!d->IsVarDefined(iter_var->var)) {
                // The thread variable is defined in the frame chosen by FindLowestVarDef
                // (not in the new body frame), and its `env_thread(thread_tag)` binding is
                // appended to that frame's statements.
                ObjectPath iter_var_p = stmt_p->Attr("node");
                Frame f = FindLowestVarDef(iter_var->var, d).value();
                DefineVar(iter_var->var, f, d);
                f->stmts.push_back(
                    AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var")),
                              TIR(d, "env_thread")
                                  ->Call({LiteralDoc::Str(iter_var->thread_tag,
                                                          iter_var_p->Attr("thread_tag"))}),  //
                              NullOpt));
              }
              rhs = TIR(d, "launch_thread")
                        ->Call({
                            d->AsDoc<ExprDoc>(iter_var->var, stmt_p->Attr("node")),
                            d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
                        });
            }
          }
          if (!rhs.defined()) {
            // Generic fallback: attr(node, attr_key, value).
            rhs = TIR(d, "attr")->Call({
                d->AsDoc<ExprDoc>(stmt->node, stmt_p->Attr("node")),
                LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")),
                d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
            });
          }
          With<TIRFrame> f(d, stmt);
          AsDocBody(body, body_p, f->get(), d);
          return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
        });
381 | |
382 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
383 | .set_dispatch<tir::Store>( // |
384 | "" , [](tir::Store stmt, ObjectPath p, IRDocsifier d) -> Doc { |
385 | LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << stmt; |
386 | }); |
387 | |
388 | TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR); |
389 | TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR); |
390 | TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR); |
391 | TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); |
392 | TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR); |
393 | TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR); |
394 | TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); |
395 | TVM_SCRIPT_REPR(tir::PrefetchNode, ReprPrintTIR); |
396 | TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); |
397 | TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); |
398 | TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR); |
399 | TVM_SCRIPT_REPR(tir::BufferRealizeNode, ReprPrintTIR); |
400 | TVM_SCRIPT_REPR(tir::StoreNode, ReprPrintTIR); |
401 | |
402 | } // namespace printer |
403 | } // namespace script |
404 | } // namespace tvm |
405 | |