1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir_text_printer.cc
22 * \brief Printer to print out the IR text format
23 * that can be parsed by a parser.
24 */
25
26#include <tvm/ir/module.h>
27#include <tvm/ir/type.h>
28#include <tvm/ir/type_functor.h>
29#include <tvm/node/serialization.h>
30#include <tvm/target/target.h>
31#include <tvm/tir/expr.h>
32#include <tvm/tir/function.h>
33#include <tvm/tir/op.h>
34#include <tvm/tir/stmt.h>
35
36#include <algorithm>
37#include <string>
38
39#include "../../tir/transforms/ir_utils.h"
40#include "doc.h"
41#include "meta_data.h"
42#include "text_printer.h"
43
44namespace tvm {
45namespace relay {
46
47Doc TIRTextPrinter::Print(const ObjectRef& node) {
48 if (!node.defined()) return Doc::Text("(nullptr)");
49 if (node->IsInstance<StmtNode>()) {
50 return VisitStmt(Downcast<Stmt>(node));
51 } else if (node->IsInstance<AnyNode>()) {
52 return Doc::Text("?");
53 } else if (node->IsInstance<PrimExprNode>()) {
54 return VisitExpr(Downcast<PrimExpr>(node));
55 } else if (node->IsInstance<TypeNode>()) {
56 return VisitType(Downcast<Type>(node));
57 } else if (node->IsInstance<PrimFuncNode>()) {
58 return PrintPrimFunc(Downcast<PrimFunc>(node));
59 } else if (node->IsInstance<IRModuleNode>()) {
60 return PrintIRModule(Downcast<IRModule>(node));
61 } else if (node->IsInstance<ArrayNode>()) {
62 return PrintArray(node.as<ArrayNode>());
63 } else if (node->IsInstance<IterVarNode>()) {
64 return PrintIterVar(node.as<IterVarNode>());
65 } else if (node->IsInstance<RangeNode>()) {
66 return PrintRange(node.as<RangeNode>());
67 } else if (node->IsInstance<BufferNode>()) {
68 return PrintBuffer(node.as<BufferNode>());
69 } else if (node->IsInstance<DataProducerNode>()) {
70 return PrintProducer(node.as<DataProducerNode>());
71 } else if (node->IsInstance<StringObj>()) {
72 return PrintString(node.as<StringObj>());
73 } else if (node->IsInstance<BufferRegionNode>()) {
74 return PrintBufferRegion(node.as<BufferRegionNode>());
75 } else if (node->IsInstance<TargetNode>()) {
76 return Doc::Text(node.as<TargetNode>()->ToDebugString());
77 } else {
78 return this->meta_->GetMetaNode(node);
79 }
80}
81
82Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) {
83 const auto* op = prim_func.operator->();
84 const auto& signature = op->func_type_annotation();
85 // collect Meta in DictAttr
86 if (prim_func->attrs.defined()) {
87 for (const auto& it : prim_func->attrs->dict) {
88 meta_collector_.Collect(it.second);
89 }
90 }
91 // collect buffers in buffer_map
92 memo_var_.clear();
93 memo_buf_.clear();
94
95 // ordered vars associated with buffers, for consistent printing
96 std::vector<tir::Var> buffer_vars_ordered;
97
98 for (tir::Var v : op->params) {
99 auto buffer_map_find = op->buffer_map.find(v);
100 if (buffer_map_find != op->buffer_map.end()) {
101 auto map_data = *buffer_map_find;
102 buffer_vars_ordered.push_back(map_data.first);
103 memo_buf_[map_data.second] = AllocBuf(map_data.second);
104 }
105 }
106
107 // print PrimFunc
108 Doc doc;
109 doc << "primfn"
110 << "(";
111 // print params and its type annotation
112 std::vector<Doc> params;
113 for (const auto& param : op->params) {
114 params.push_back(Print(param));
115 }
116 Doc sep;
117 doc << PrintSep(params, Doc::Indent(9, Doc::Text(", "))) << ")";
118 // print return type
119 doc << " -> " << Print(signature->ret_type);
120 // print attr
121 Doc attr_doc;
122 std::vector<Doc> attr_docs;
123 if (prim_func->attrs.defined()) {
124 for (const auto& it : op->attrs->dict) {
125 attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
126 }
127 attr_doc << NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}";
128 doc << Doc::Indent(2, attr_doc);
129 }
130
131 // print all the buffers in the tree
132 if (memo_buf_.size() != 0) {
133 Doc buffer_doc;
134 std::vector<Doc> buffer_docs;
135 for (const tir::Var& v : buffer_vars_ordered) {
136 const Buffer buf = op->buffer_map[v];
137 buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf)));
138 }
139 buffer_doc << NewLine() << "buffers = {";
140 buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << NewLine()));
141 doc << Doc::Indent(2, buffer_doc) << "}";
142 }
143
144 if (op->buffer_map.size() != 0) {
145 // print buffer_map
146 std::vector<Doc> buffer_map_doc;
147 for (const tir::Var& v : buffer_vars_ordered) {
148 const Buffer buf = op->buffer_map[v];
149 buffer_map_doc.push_back(Print(v) << ": " << Print(buf));
150 }
151 doc << Doc::Indent(
152 2, NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}");
153 }
154
155 doc << PrintBody(op->body);
156 return doc;
157}
158
159Doc TIRTextPrinter::NewLine() { return Doc::NewLine(); }
160
161Doc TIRTextPrinter::PrintIRModule(const IRModule& module) {
162 const auto* op = module.operator->();
163 Doc doc;
164
165 Doc body;
166 body << NewLine();
167 std::vector<Doc> functions;
168 for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
169 if ((*it).second.as<PrimFuncNode>()) {
170 functions.push_back(Print((*it).second));
171 }
172 }
173 body << TIRTextPrinter::PrintSep(functions, NewLine() << NewLine());
174 doc << Doc::Indent(0, body);
175 return doc;
176}
177
178Doc TIRTextPrinter::PrintArray(const ArrayNode* op) {
179 Doc doc;
180 doc << '[';
181 for (size_t i = 0; i < op->size(); ++i) {
182 if (i != 0) {
183 doc << ", ";
184 }
185 doc << Print(op->at(i));
186 }
187 doc << ']';
188 return doc;
189}
190
191Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) {
192 Doc doc;
193 doc << "IterVar(" << Print(op->var);
194 if (op->dom.defined()) {
195 doc << ", [" << Print(op->dom) << "], ";
196 } else {
197 doc << ", " << Print(op->dom) << ", ";
198 }
199 doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
200 doc << Doc::StrLiteral(op->thread_tag) << ")";
201 return doc;
202}
203
204Doc TIRTextPrinter::PrintRange(const RangeNode* op) {
205 return Print(op->min) << ":" << Print(op->min + op->extent);
206}
207
208Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
209 const Buffer& buffer = GetRef<Buffer>(op);
210
211 if (meta_->InMeta(buffer)) {
212 return meta_->GetMetaNode(buffer);
213 } else if (memo_buf_.count(buffer)) {
214 return memo_buf_[buffer];
215 } else {
216 memo_buf_[buffer] = AllocBuf(buffer);
217 return BufferNode2Doc(op, memo_buf_[buffer]);
218 }
219}
220
221Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) {
222 const DataProducer& prod = GetRef<DataProducer>(op);
223
224 if (meta_->InMeta(prod)) {
225 return meta_->GetMetaNode(prod);
226 } else if (memo_producer_.count(prod)) {
227 return memo_producer_[prod];
228 } else {
229 memo_producer_[prod] = AllocProducer(prod);
230 return DataProducerNode2Doc(op, memo_producer_[prod]);
231 }
232}
233
234Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
235 doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", "
236 << Print(buf->shape) << ", " << Print(buf->strides);
237 if (!is_zero(buf->elem_offset)) {
238 doc << ", elem_offset=" << Print(buf->elem_offset);
239 }
240 if (buf->axis_separators.size()) {
241 doc << ", axis_separators=" << Print(buf->axis_separators);
242 }
243 if (GetRef<Buffer>(buf).scope() != "global") {
244 doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope());
245 }
246 if (buf->data_alignment != runtime::kAllocAlignment) {
247 doc << ", align=" << buf->data_alignment;
248 }
249 if (buf->offset_factor != 1) {
250 doc << ", offset_factor=" << buf->offset_factor;
251 }
252 if (buf->buffer_type != 1) {
253 doc << ", type=" << Doc::StrLiteral("auto");
254 }
255 return doc << ")";
256}
257
258Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) {
259 return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", "
260 << PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")";
261}
262
263Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) {
264 Doc doc;
265 doc << Print(op->buffer) << "[";
266 for (size_t i = 0; i < op->region.size(); ++i) {
267 if (i != 0) {
268 doc << ", ";
269 }
270 const auto& range = op->region[i];
271 if (!is_one(range->extent)) {
272 doc << Print(range->min) << ":" << Print(range->min + range->extent);
273 } else {
274 doc << Print(range->min);
275 }
276 }
277 doc << "]";
278 return doc;
279}
280
281Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {
282 return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
283}
284
285Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) {
286 return this->meta_->GetMetaNode(GetRef<ObjectRef>(op));
287}
288
289Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) {
290 return PrintConstScalar<int64_t>(op->dtype, op->value);
291}
292
293Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) {
294 return PrintConstScalar<double>(op->dtype, op->value);
295}
296
297Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); }
298
299Doc TIRTextPrinter::VisitExpr_(const CastNode* op) {
300 Doc doc;
301 doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
302 return doc;
303}
304
305Doc TIRTextPrinter::VisitExpr_(const tir::VarNode* op) {
306 const tir::Var& var = GetRef<tir::Var>(op);
307 return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op));
308}
309
310#define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \
311 Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \
312 Doc doc; \
313 doc << "(" << Print(op->a) << OpString; \
314 doc << Print(op->b) << ")"; \
315 return doc; \
316 }
317
318TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AddNode, " + ")
319TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(SubNode, " - ")
320TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(MulNode, "*")
321TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(DivNode, " / ")
322TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(ModNode, " % ")
323TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(EQNode, " == ")
324TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(NENode, " != ")
325TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LTNode, " < ")
326TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LENode, " <= ")
327TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GTNode, " > ")
328TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GENode, " >= ")
329TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AndNode, " && ")
330TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OrNode, " || ")
331
332Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) {
333 Doc doc;
334 doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")";
335 return doc;
336}
337
338Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) {
339 Doc doc;
340 doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")";
341 return doc;
342}
343
344Doc TIRTextPrinter::VisitExpr_(const MinNode* op) {
345 Doc doc;
346 doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")";
347 return doc;
348}
349
350Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) {
351 Doc doc;
352 doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")";
353 return doc;
354}
355
356Doc TIRTextPrinter::VisitExpr_(const NotNode* op) {
357 Doc doc;
358 doc << "!" << Print(op->a);
359 return doc;
360}
361
362Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) {
363 Doc doc;
364 doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
365 << Print(op->false_value) << ")";
366 return doc;
367}
368
369Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) {
370 Doc doc;
371 doc << Print(op->buffer) << Print(op->indices);
372 return doc;
373}
374
375Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) {
376 // TODO(tvm-team): consider make a better text format for producer.
377 Doc doc;
378 doc << op->producer->GetNameHint() << Print(op->indices);
379 return doc;
380}
381
382Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) {
383 Doc doc;
384 doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index)
385 << "]";
386 if (!is_one(op->predicate)) {
387 doc << " if " << Print(op->predicate);
388 }
389 return doc;
390}
391
392Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
393 Doc doc;
394 doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")";
395 return doc;
396}
397
398Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
399 Doc doc;
400 doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
401 return doc;
402}
403
404Doc TIRTextPrinter::VisitExpr_(const tir::LetNode* op) {
405 Doc doc;
406 doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body);
407 return doc;
408}
409
410Doc TIRTextPrinter::VisitExpr_(const tir::CallNode* op) {
411 Doc doc;
412 std::vector<Doc> func_args;
413 if (auto* ptr_op = op->op.as<OpNode>()) {
414 doc << "@" << Doc::Text(ptr_op->name) << "(";
415 if (ptr_op->name == "tir.call_llvm_pure_intrin") {
416 auto f = tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name");
417 ICHECK(f != nullptr)
418 << "Cannot find target.llvm_get_intrinsic_name. Compile with USE_LLVM=On";
419 func_args.push_back(Print((*f)(Downcast<IntImm>(op->args[0])->value)));
420 for (size_t i = 1; i < op->args.size(); i++) {
421 func_args.push_back(Print(op->args[i]));
422 }
423 } else {
424 for (const auto& arg : op->args) {
425 func_args.push_back(Print(arg));
426 }
427 }
428 } else {
429 // TODO(bohan): Print out the name by he global var in the module.
430 auto* op_gvar = op->op.as<GlobalVarNode>();
431 ICHECK(op_gvar != nullptr);
432 doc << "@" << Doc::Text(op_gvar->name_hint) << "(";
433 for (const auto& arg : op->args) {
434 func_args.push_back(Print(arg));
435 }
436 }
437 doc << PrintSep(func_args, Doc::Text(", ")) << ", dtype=" << PrintDType(op->dtype) << ")";
438 return doc;
439}
440
441Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) {
442 Doc doc;
443 doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")";
444 return doc;
445}
446
447Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) {
448 Doc doc;
449 doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis)
450 << ", " << op->value_index << ", " << Print(op->init) << ")";
451 return doc;
452}
453
454Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) {
455 Doc doc;
456 doc << "let " << Print(op->var) << " = " << Print(op->value) << NewLine() << Print(op->body);
457 return doc;
458}
459
460Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) {
461 Doc doc;
462 meta_collector_.Collect(op->node);
463 doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = "
464 << Print(op->value);
465 if (op->body->IsInstance<SeqStmtNode>()) {
466 doc << PrintBody(op->body);
467 } else {
468 doc << ";" << NewLine() << Print(op->body);
469 }
470 return doc;
471}
472
473Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) {
474 Doc doc;
475 doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << NewLine()
476 << Print(op->body);
477 return doc;
478}
479
480Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) {
481 Doc doc;
482 doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value);
483 if (!is_one(op->predicate)) {
484 doc << " if " << Print(op->predicate);
485 }
486 return doc;
487}
488
489Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) {
490 Doc doc;
491 doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value);
492 return doc;
493}
494
495Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) {
496 Doc doc;
497 doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value);
498 return doc;
499}
500
501Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
502 Doc doc;
503 doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", "
504 << Print(op->condition) << PrintBody(op->body) << ")";
505 return doc;
506}
507
508Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) {
509 Doc doc;
510 doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", "
511 << Print(op->condition) << ", " << PrintBody(op->body) << ")";
512 return doc;
513}
514
515Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
516 Doc doc;
517 auto scope = GetPtrStorageScope(op->buffer_var);
518 doc << "allocate(" << Print(op->buffer_var) << ", ";
519 doc << PrintDType(op->dtype) << ", ";
520 doc << Print(op->extents) << "), storage_scope = " << scope;
521 if (!op->annotations.empty()) {
522 std::vector<Doc> attr_docs;
523 for (const auto& it : op->annotations) {
524 attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
525 }
526 doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
527 }
528 if (!is_one(op->condition)) {
529 doc << " if " << Print(op->condition);
530 }
531 if (op->body->IsInstance<SeqStmtNode>()) {
532 doc << PrintBody(op->body);
533 } else {
534 doc << ";" << NewLine() << Print(op->body);
535 }
536 return doc;
537}
538
539Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
540 Doc doc;
541 doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
542 << Print(op->extents) << ")";
543
544 if (op->body->IsInstance<SeqStmtNode>()) {
545 doc << PrintBody(op->body);
546 } else {
547 doc << ";" << NewLine() << Print(op->body);
548 }
549 return doc;
550}
551
552Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
553 Doc doc;
554 doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", "
555 << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << NewLine();
556 if (op->body->IsInstance<SeqStmtNode>()) {
557 doc << PrintBody(op->body);
558 } else {
559 doc << ";" << NewLine() << Print(op->body);
560 }
561 return doc;
562}
563
564Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
565 Doc doc;
566 doc << "if " << Print(op->condition) << PrintBody(op->then_case);
567 if (!is_one(op->condition) && op->else_case) {
568 doc << " else" << PrintBody(op->else_case.value());
569 }
570 return doc;
571}
572
573Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) {
574 std::vector<Doc> stmts;
575 Doc seq_doc, doc;
576 for (Stmt stmt : op->seq) {
577 seq_doc << NewLine() << Print(stmt);
578 }
579 doc << " {" << Doc::Indent(2, seq_doc) << NewLine() << "}";
580 return doc;
581}
582
583Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
584 Doc doc;
585 doc << Print(op->value);
586 return doc;
587}
588
589Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
590 Doc doc;
591 doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
592 << Print(op->min + op->extent) << ")";
593 if (op->kind != ForKind::kSerial) {
594 doc << " " << Doc::StrLiteral(ForKind2String(op->kind));
595 }
596 doc << PrintBody(op->body);
597 return doc;
598}
599
600Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) {
601 Doc doc;
602 doc << "while (" << Print(op->condition) << ")";
603 doc << PrintBody(op->body);
604 return doc;
605}
606
607Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) {
608 Doc doc;
609 doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
610 return doc;
611}
612
613Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) {
614 const auto* block_op = op->block.as<BlockNode>();
615 // print block name and block vars
616 Doc doc;
617 doc << "block([";
618 std::vector<Doc> block_var_docs;
619 for (const auto& iter_var : block_op->iter_vars) {
620 Doc block_var_doc;
621 if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) {
622 block_var_doc << Print(iter_var->dom->extent);
623 } else {
624 block_var_doc << "tir.";
625 switch (iter_var->iter_type) {
626 case kDataPar:
627 block_var_doc << "range";
628 break;
629 case kCommReduce:
630 block_var_doc << "reduce_axis";
631 break;
632 case kOrdered:
633 block_var_doc << "scan_axis";
634 break;
635 case kOpaque:
636 block_var_doc << "opaque_axis";
637 break;
638 default:
639 LOG(FATAL) << "Unknown block var iter type";
640 break;
641 }
642 block_var_doc << "(" << Print(iter_var->dom->min) << ", "
643 << Print(iter_var->dom->min + iter_var->dom->extent) << ")";
644 }
645 block_var_docs.push_back(block_var_doc);
646 }
647 doc << PrintSep(block_var_docs, Doc::Text(", ")) << "], ";
648 doc << Doc::StrLiteral(block_op->name_hint) << ")";
649 std::vector<Doc> block_var_names;
650 for (const auto& iter_var : block_op->iter_vars) {
651 Doc block_var_name;
652 AllocVar(iter_var->var);
653 block_var_names.push_back(Print(iter_var->var));
654 }
655 if (!block_var_names.empty()) {
656 doc << " as [" << PrintSep(block_var_names, Doc::Text(", ")) << "]";
657 }
658 doc << " {";
659 Doc block_attr_doc;
660 // print predicate, binding, read/write tensor region, annotations
661 if (!is_one(op->predicate)) {
662 block_attr_doc << NewLine() << "where(" << Print(op->predicate) << ")";
663 }
664 for (size_t i = 0; i < block_op->iter_vars.size(); ++i)
665 block_attr_doc << NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", "
666 << Print(op->iter_values[i]) << ")";
667 block_attr_doc << NewLine() << "tir.reads(" << Print(block_op->reads) << ")";
668 block_attr_doc << NewLine() << "tir.writes(" << Print(block_op->writes) << ")";
669 if (!block_op->annotations.empty()) {
670 std::vector<Doc> attr_docs;
671 for (const auto& it : block_op->annotations) {
672 attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
673 }
674 block_attr_doc << NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) << "})";
675 }
676 // print body
677 Doc body;
678 body << NewLine();
679 for (const auto& alloc_buf : block_op->alloc_buffers) {
680 body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype)
681 << Print(alloc_buf->shape) << ")" << NewLine();
682 }
683 for (const auto& match_buf : block_op->match_buffers) {
684 body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")"
685 << NewLine();
686 }
687 if (block_op->init.defined()) {
688 Doc init_block;
689 init_block << "with init()";
690 init_block << PrintBody(block_op->init.value());
691 body << init_block << NewLine();
692 }
693 body << Print(block_op->body);
694 doc << Doc::Indent(2, block_attr_doc << body);
695 return doc;
696}
697
698Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) {
699 Doc doc;
700 doc << PrintDType(node->dtype);
701 return doc;
702}
703
704Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) {
705 Doc doc;
706 doc << "Pointer(";
707 if (!node->storage_scope.empty()) {
708 doc << node->storage_scope << " ";
709 }
710 doc << Print(node->element_type) << ")";
711 return doc;
712}
713
714Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) {
715 std::vector<Doc> fields;
716 for (Type field : node->fields) {
717 fields.push_back(Print(field));
718 }
719 Doc doc;
720 doc << "(" << Doc::Concat(fields);
721 // conform to python tuple format (1,)
722 if (node->fields.size() == 1) {
723 doc << ",";
724 }
725 return doc << ")";
726}
727
728Doc TIRTextPrinter::PrintDType(DataType dtype) {
729 return Doc::Text(runtime::DLDataType2String(dtype));
730}
731
732template <typename T>
733Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) {
734 Doc doc;
735 std::ostringstream os;
736 os << data;
737 if (dtype == DataType::Int(32)) {
738 doc << Doc::Text(os.str());
739 } else {
740 if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) {
741 doc << ((data == 1) ? "True" : "False");
742 return doc;
743 }
744 doc << Doc::Text(os.str());
745 switch (dtype.code()) {
746 case kDLInt:
747 doc << "i";
748 break;
749 case kDLUInt:
750 doc << "u";
751 break;
752 case kDLFloat:
753 doc << "f";
754 break;
755 }
756 doc << Doc::Text(std::to_string(dtype.bits()));
757 if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes()));
758 }
759 return doc;
760}
761
762Doc TIRTextPrinter::GetUniqueName(std::string prefix) {
763 // std::replace(prefix.begin(), prefix.end(), '.', '_');
764 std::string unique_prefix = prefix;
765 auto it = name_alloc_map_.find(prefix);
766 if (it != name_alloc_map_.end()) {
767 while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {
768 }
769 }
770 name_alloc_map_[unique_prefix] = 0;
771 return Doc::Text(unique_prefix);
772}
773
774Doc TIRTextPrinter::AllocVar(const tir::Var& var) {
775 const auto& it = memo_var_.find(var);
776 if (it != memo_var_.end()) {
777 return it->second;
778 }
779 std::string name = var->name_hint.operator std::string();
780 if (name.length() == 0 || !std::isalpha(name[0])) {
781 name = "v" + name;
782 }
783 Doc val = GetUniqueName(name);
784 memo_var_[var] = val;
785 return val << ": " << Print(GetType(var));
786}
787
788Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) {
789 const auto& it = memo_buf_.find(buffer);
790 if (it != memo_buf_.end()) {
791 return it->second;
792 }
793 std::string name = buffer->name;
794 if (name.length() == 0 || !std::isalpha(name[0])) {
795 name = "buf_" + name;
796 }
797 Doc val = GetUniqueName(name);
798 memo_buf_[buffer] = val;
799 return val;
800}
801
802Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) {
803 const auto& it = memo_producer_.find(producer);
804 if (it != memo_producer_.end()) {
805 return it->second;
806 }
807 std::string name = producer->GetNameHint();
808 if (name.length() == 0 || !std::isalpha(name[0])) {
809 name = "tensor_" + name;
810 }
811 Doc val = GetUniqueName(name);
812 memo_producer_[producer] = val;
813 return val;
814}
815
816Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
817 Doc seq;
818 if (vec.size() != 0) {
819 seq = vec[0];
820 for (size_t i = 1; i < vec.size(); i++) {
821 seq << sep << vec[i];
822 }
823 }
824 return seq;
825}
826
827Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) {
828 Doc doc;
829 if (body->IsInstance<SeqStmtNode>()) return Print(body);
830 doc << " {" << Doc::Indent(2, NewLine() << Print(body)) << NewLine() << "}";
831 return doc;
832}
833
834bool TIRTextPrinter::GetVarName(tir::Var v, std::string* s) {
835 auto it = memo_var_.find(v);
836 if (it == memo_var_.end()) {
837 return false;
838 }
839
840 *s = it->second.str();
841 return true;
842}
843
844} // namespace relay
845} // namespace tvm
846