1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../../../tir/transforms/ir_utils.h" // For `GetPtrStorageScope`
20#include "./utils.h"
21
22namespace tvm {
23namespace script {
24namespace printer {
25
26Doc DoConciseScoping(const Optional<ExprDoc>& lhs, const ExprDoc& rhs, Array<StmtDoc>* stmts,
27 bool concise_scoping) {
28 if (concise_scoping) {
29 if (lhs.defined()) {
30 stmts->insert(stmts->begin(), AssignDoc(lhs.value(), rhs, NullOpt));
31 } else {
32 stmts->insert(stmts->begin(), ExprStmtDoc(rhs));
33 }
34 return StmtBlockDoc(*stmts);
35 } else {
36 return ScopeDoc(lhs, rhs, *stmts);
37 }
38}
39
40bool AllowConciseScoping(const IRDocsifier& d) {
41 ICHECK(!d->frames.empty());
42 if (const auto* f = d->frames.back().as<TIRFrameNode>()) {
43 return f->allow_concise_scoping;
44 }
45 LOG(FATAL) << "NotImplementedError: fragment printing";
46}
47
48TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
49 .set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc {
50 ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
51 if (eval->value->IsInstance<tir::CallNode>()) {
52 return ExprStmtDoc(value);
53 }
54 return ExprStmtDoc(TIR(d, "evaluate")->Call({value}));
55 });
56
57TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
58 .set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
59 bool concise = AllowConciseScoping(d);
60 if (concise && !d->IsVarDefined(stmt->var)) {
61 ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
62 With<TIRFrame> f(d, stmt);
63 ExprDoc lhs = DefineVar(stmt->var, *f, d);
64 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
65 Array<StmtDoc>* stmts = &(*f)->stmts;
66 Type type = stmt->var->type_annotation;
67 Optional<ExprDoc> type_doc =
68 d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
69 if (const auto* tuple_type = type.as<TupleTypeNode>()) {
70 if (tuple_type->fields.empty()) {
71 type_doc = NullOpt;
72 }
73 }
74 stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
75 return StmtBlockDoc(*stmts);
76 } else {
77 ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
78 ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
79 With<TIRFrame> f(d, stmt);
80 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
81 Array<StmtDoc>* stmts = &(*f)->stmts;
82 rhs = TIR(d, "let")->Call({lhs, rhs});
83 return ScopeDoc(NullOpt, rhs, *stmts);
84 }
85 });
86
87TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
88 .set_dispatch<tir::AssertStmt>(
89 "", [](tir::AssertStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
90 bool concise = AllowConciseScoping(d);
91 ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
92 ExprDoc msg = d->AsDoc<ExprDoc>(stmt->message, p->Attr("message"));
93 With<TIRFrame> f(d, stmt);
94 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
95 if (concise) {
96 Array<StmtDoc>* stmts = &(*f)->stmts;
97 stmts->insert(stmts->begin(), AssertDoc(cond, msg));
98 return StmtBlockDoc(*stmts);
99 }
100 return ScopeDoc(NullOpt, TIR(d, "Assert")->Call({cond, msg}), (*f)->stmts);
101 });
102
103TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
104 .set_dispatch<tir::While>("", [](tir::While stmt, ObjectPath p, IRDocsifier d) -> Doc {
105 ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
106 With<TIRFrame> f(d, stmt);
107 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
108 return WhileDoc(cond, (*f)->stmts);
109 });
110
111TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
112 .set_dispatch<tir::DeclBuffer>( //
113 "", [](tir::DeclBuffer stmt, ObjectPath p, IRDocsifier d) -> Doc {
114 bool concise = AllowConciseScoping(d);
115 ExprDoc rhs =
116 BufferDecl(stmt->buffer, "decl_buffer", {}, p->Attr("buffer"), d->frames.back(), d);
117 With<TIRFrame> f(d, stmt);
118 ExprDoc lhs = DefineBuffer(stmt->buffer, *f, d);
119 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
120 return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
121 });
122
123TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
124 .set_dispatch<tir::IfThenElse>( //
125 "", [](tir::IfThenElse stmt, ObjectPath p, IRDocsifier d) -> Doc {
126 ExprDoc cond = d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition"));
127 Array<StmtDoc> then_branch;
128 Array<StmtDoc> else_branch;
129 if (stmt->then_case.defined()) {
130 With<TIRFrame> f(d, stmt->then_case);
131 AsDocBody(stmt->then_case, p->Attr("then_case"), f->get(), d);
132 then_branch = (*f)->stmts;
133 }
134 if (stmt->else_case.defined()) {
135 With<TIRFrame> f(d, stmt->else_case);
136 AsDocBody(stmt->else_case.value(), p->Attr("else_case"), f->get(), d);
137 else_branch = (*f)->stmts;
138 }
139 return IfDoc(cond, then_branch, else_branch);
140 });
141
142TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
143 .set_dispatch<tir::SeqStmt>("", [](tir::SeqStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
144 With<TIRFrame> f(d, stmt);
145 AsDocBody(stmt, p, f->get(), d);
146 return StmtBlockDoc((*f)->stmts);
147 });
148
149TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
150 .set_dispatch<tir::Prefetch>( //
151 "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc {
152 return ExprStmtDoc(TIR(d, "prefetch")
153 ->Call({
154 d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer")),
155 d->AsDoc<ExprDoc>(stmt->bounds, p->Attr("bounds")),
156 }));
157 });
158
159bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) {
160 const tir::Var& buffer_var = allocate->buffer_var;
161 if (const tir::DeclBufferNode* decl_buffer = allocate->body.as<tir::DeclBufferNode>()) {
162 const tir::Buffer& buffer = decl_buffer->buffer;
163 if (buffer_var.same_as(buffer->data) && allocate->dtype == buffer->dtype &&
164 tir::is_one(allocate->condition) && !allocate->annotations.size() &&
165 allocate->extents.size() == buffer->shape.size()) {
166 tir::ExprDeepEqual expr_equal;
167 for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
168 if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
169 return false;
170 }
171 }
172 return true;
173 }
174 }
175 return false;
176}
177
178TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
179 .set_dispatch<tir::Allocate>( //
180 "", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
181 bool concise = AllowConciseScoping(d);
182 if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) {
183 return d->AsDoc(stmt->body, stmt_p->Attr("body"));
184 }
185 Array<ExprDoc> args;
186 Array<String> kwargs_keys;
187 Array<ExprDoc> kwargs_values;
188 args.push_back(d->AsDoc<ExprDoc>(stmt->extents, stmt_p->Attr("extents")));
189 args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype")));
190 args.push_back(LiteralDoc::Str(tir::GetPtrStorageScope(stmt->buffer_var),
191 stmt_p
192 ->Attr("buffer_var") //
193 ->Attr("type_annotation")
194 ->Attr("storage_scope")));
195 if (!tir::is_one(stmt->condition)) {
196 args.push_back(d->AsDoc<ExprDoc>(stmt->condition, stmt_p->Attr("condition")));
197 }
198 if (!stmt->annotations.empty()) {
199 kwargs_keys.push_back("annotations");
200 kwargs_values.push_back(
201 d->AsDoc<ExprDoc>(stmt->annotations, stmt_p->Attr("annotations")));
202 }
203 ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d);
204 With<TIRFrame> f(d, stmt);
205 ExprDoc rhs = TIR(d, "allocate")->Call(args, kwargs_keys, kwargs_values);
206 AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d);
207 return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
208 });
209
210template <typename T>
211ExprDoc PrintNDArray(::tvm::runtime::NDArray arr) {
212 // FIXME(@junrushao): this is a hack and can be wrong in most of the cases
213 constexpr int NUM_PRINT = 200;
214 int ndim = arr->ndim;
215 int tot_dim = 1;
216 for (int i = 0; i < ndim; i++) {
217 tot_dim *= arr->shape[i];
218 }
219 Array<ExprDoc> result;
220 T* data_ptr = reinterpret_cast<T*>(arr->data);
221 runtime::DataType dtype = arr.DataType();
222 for (int i = 0; i < tot_dim; i++) {
223 if (dtype.is_float()) {
224 result.push_back(LiteralDoc::Float(data_ptr[i], NullOpt));
225 } else {
226 result.push_back(LiteralDoc::Int(data_ptr[i], NullOpt));
227 }
228 if (i == NUM_PRINT) {
229 break;
230 }
231 }
232 return ListDoc(result);
233}
234
235TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
236 .set_dispatch<tir::AllocateConst>(
237 "", [](tir::AllocateConst stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
238 bool concise = AllowConciseScoping(d);
239 String storage_scope = tir::GetPtrStorageScope(stmt->buffer_var);
240 Array<ExprDoc> args;
241 Array<String> kwargs_keys;
242 Array<ExprDoc> kwargs_values;
243 ExprDoc data_doc{nullptr};
244 if (stmt->dtype.is_int()) {
245 if (stmt->dtype.bits() == 8) {
246 data_doc = PrintNDArray<int8_t>(stmt->data.value());
247 } else if (stmt->dtype.bits() == 16) {
248 data_doc = PrintNDArray<int16_t>(stmt->data.value());
249 } else if (stmt->dtype.bits() == 32) {
250 data_doc = PrintNDArray<int32_t>(stmt->data.value());
251 } else if (stmt->dtype.bits() == 64) {
252 data_doc = PrintNDArray<int64_t>(stmt->data.value());
253 } else {
254 LOG(FATAL) << "DataType not supported";
255 }
256 } else if (stmt->dtype.is_uint()) {
257 if (stmt->dtype.bits() == 8) {
258 data_doc = PrintNDArray<uint8_t>(stmt->data.value());
259 } else if (stmt->dtype.bits() == 16) {
260 data_doc = PrintNDArray<uint16_t>(stmt->data.value());
261 } else if (stmt->dtype.bits() == 32) {
262 data_doc = PrintNDArray<uint32_t>(stmt->data.value());
263 } else if (stmt->dtype.bits() == 64) {
264 data_doc = PrintNDArray<uint64_t>(stmt->data.value());
265 } else {
266 LOG(FATAL) << "DataType not supported";
267 }
268 } else if (stmt->dtype.is_float()) {
269 if (stmt->dtype.bits() == 16) {
270 data_doc = PrintNDArray<int16_t>(stmt->data.value());
271 } else if (stmt->dtype.bits() == 32) {
272 data_doc = PrintNDArray<float>(stmt->data.value());
273 } else if (stmt->dtype.bits() == 64) {
274 data_doc = PrintNDArray<double>(stmt->data.value());
275 } else {
276 LOG(FATAL) << "DataType not supported";
277 }
278 } else {
279 LOG(FATAL) << "DataType not supported";
280 }
281 args.push_back(data_doc);
282 args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype")));
283 args.push_back(d->AsDoc<ExprDoc>(stmt->extents, stmt_p->Attr("extents")));
284 ExprDoc rhs = TIR(d, "allocate_const")->Call(args, kwargs_keys, kwargs_values);
285 With<TIRFrame> f(d, stmt);
286 ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d);
287 AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d);
288 return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise);
289 });
290
291ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDoc> value, //
292 ObjectPath p, IRDocsifier d) {
293 ExprDoc buffer = d->AsDoc<ExprDoc>(stmt->buffer, p->Attr("buffer"));
294 {
295 Array<Doc> bounds;
296 bounds.reserve(stmt->bounds.size());
297 for (int i = 0, n = stmt->bounds.size(); i < n; ++i) {
298 Range range = stmt->bounds[i];
299 ObjectPath range_p = p->Attr("bounds")->ArrayIndex(i);
300 bounds.push_back(
301 SliceDoc(d->AsDoc<ExprDoc>(range->min, range_p->Attr("min")),
302 d->AsDoc<ExprDoc>(range->min + range->extent, range_p->Attr("extent")), //
303 NullOpt));
304 }
305 buffer = buffer[bounds];
306 }
307 Array<ExprDoc> args{buffer};
308 Array<String> kwargs_keys;
309 Array<ExprDoc> kwargs_values;
310 if (value.defined()) {
311 args.push_back(value.value());
312 }
313 if (!tir::is_one(stmt->condition)) {
314 kwargs_keys.push_back("condition");
315 kwargs_values.push_back(d->AsDoc<ExprDoc>(stmt->condition, p->Attr("condition")));
316 }
317 return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
318}
319
320TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
321 .set_dispatch<tir::BufferRealize>( //
322 "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
323 bool concise = AllowConciseScoping(d);
324 ExprDoc rhs = DocsifyBufferRealize(stmt.get(), NullOpt, p, d);
325 With<TIRFrame> f(d, stmt);
326 AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
327 return DoConciseScoping(NullOpt, rhs, &(*f)->stmts, concise);
328 });
329
330TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
331 .set_dispatch<tir::AttrStmt>( //
332 "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
333 bool concise = AllowConciseScoping(d);
334 Optional<ExprDoc> rhs = NullOpt;
335 tir::Stmt body = stmt->body;
336 ObjectPath body_p = stmt_p->Attr("body");
337 if (stmt->attr_key == "realize_scope") {
338 if (const auto* realize = stmt->body.as<tir::BufferRealizeNode>()) {
339 if (realize->buffer.same_as(stmt->node)) {
340 rhs = DocsifyBufferRealize(
341 realize,
342 /*value=*/d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
343 /*p=*/stmt_p->Attr("body"), d);
344 body = realize->body;
345 body_p = body_p->Attr("body");
346 }
347 }
348 }
349 if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
350 if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
351 if (!d->IsVarDefined(iter_var->var)) {
352 // `DefineVar` is not used here because a more specific name is desirable
353 ObjectPath iter_var_p = stmt_p->Attr("node");
354 Frame f = FindLowestVarDef(iter_var->var, d).value();
355 DefineVar(iter_var->var, f, d);
356 f->stmts.push_back(
357 AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var")),
358 TIR(d, "env_thread")
359 ->Call({LiteralDoc::Str(iter_var->thread_tag,
360 iter_var_p->Attr("thread_tag"))}), //
361 NullOpt));
362 }
363 rhs = TIR(d, "launch_thread")
364 ->Call({
365 d->AsDoc<ExprDoc>(iter_var->var, stmt_p->Attr("node")),
366 d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
367 });
368 }
369 }
370 if (!rhs.defined()) {
371 rhs = TIR(d, "attr")->Call({
372 d->AsDoc<ExprDoc>(stmt->node, stmt_p->Attr("node")),
373 LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")),
374 d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
375 });
376 }
377 With<TIRFrame> f(d, stmt);
378 AsDocBody(body, body_p, f->get(), d);
379 return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
380 });
381
382TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
383 .set_dispatch<tir::Store>( //
384 "", [](tir::Store stmt, ObjectPath p, IRDocsifier d) -> Doc {
385 LOG(FATAL) << "ValueError: Store has been deprecated for BufferStore: " << stmt;
386 });
387
388TVM_SCRIPT_REPR(tir::LetStmtNode, ReprPrintTIR);
389TVM_SCRIPT_REPR(tir::AttrStmtNode, ReprPrintTIR);
390TVM_SCRIPT_REPR(tir::AssertStmtNode, ReprPrintTIR);
391TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR);
392TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR);
393TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR);
394TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR);
395TVM_SCRIPT_REPR(tir::PrefetchNode, ReprPrintTIR);
396TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR);
397TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR);
398TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR);
399TVM_SCRIPT_REPR(tir::BufferRealizeNode, ReprPrintTIR);
400TVM_SCRIPT_REPR(tir::StoreNode, ReprPrintTIR);
401
402} // namespace printer
403} // namespace script
404} // namespace tvm
405