1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay_text_printer.cc
22 * \brief Printer to print out the IR text format
23 * that can be parsed by a parser.
24 *
25 * Supports ANF, GNF in relay and metadata.
26 *
27 * Inlining heuristics:
28 * - Always inline:
29 * - GlobalVar
30 * - Constant
31 * - Op
32 * - Var
33 * - Otherwise, inline if the node is at the end of a scope and is used at most once.
34 */
35#include <tvm/ir/module.h>
36#include <tvm/ir/type_functor.h>
37#include <tvm/relay/attrs/annotation.h>
38#include <tvm/relay/expr_functor.h>
39#include <tvm/relay/pattern_functor.h>
40#include <tvm/target/virtual_device.h>
41#include <tvm/tir/function.h>
42
43#include "../../ir/attr_functor.h"
44#include "../../support/scalars.h"
45#include "../analysis/dependency_graph.h"
46#include "../parser/meta_ref.h"
47#include "doc.h"
48#include "meta_data.h"
49#include "text_printer.h"
50#include "tvm/runtime/builtin_fp16.h"
51
52namespace tvm {
53namespace relay {
54
55/*!
56 * \brief Print additional info about expr in comment.
57 * \param expr The expression.
58 */
59Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) {
60 Doc doc;
61 if (!opt_info_memo_.insert(expr).second) {
62 return doc;
63 }
64 // default annotations
65 if (annotate_ == nullptr) {
66 if ((expr.as<ConstantNode>() || expr.as<CallNode>() || expr.as<VarNode>() ||
67 expr.as<FunctionNode>() || expr.as<TupleNode>() || expr.as<TupleGetItemNode>()) &&
68 (expr->checked_type_.defined() || expr->span.defined())) {
69 doc << " /*";
70 if (expr->checked_type_.defined()) {
71 doc << " ty=" << Print(expr->checked_type());
72 }
73 if (expr->span.defined()) {
74 doc << " span=" << PrintSpan(expr->span);
75 }
76 doc << " */";
77 }
78 } else {
79 std::string annotated_expr = annotate_(expr);
80 if (annotated_expr != "") {
81 doc << annotated_expr;
82 }
83 }
84 return doc;
85}
86
87// indent a new body
88Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) {
89 Doc doc;
90 Doc body;
91 doc << "{";
92 doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine();
93 doc << "}";
94 return doc;
95}
96
97// create a new scope by creating a new printer object. This allows temp var
98// numbers to be reused and prevents hoisted vars from escaping too far
99Doc RelayTextPrinter::PrintScope(const ObjectRef& node) {
100 // print in a new scope
101 doc_stack_.push_back(Doc());
102 // must print first so doc_stack_.back() reference doesn't become stale
103 Doc doc = Print(node, false, true);
104 doc = doc_stack_.back() << doc;
105 doc_stack_.pop_back();
106 return doc;
107}
108
109Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) {
110 if (node.defined() && node->IsInstance<BaseFuncNode>() &&
111 !node->IsInstance<relay::FunctionNode>()) {
112 // Temporarily skip non-relay functions.
113 // TODO(tvm-team) enhance the code to work for all functions
114 } else if (node.as<ExprNode>()) {
115 Expr expr = Downcast<Expr>(node);
116 dg_ = DependencyGraph::Create(&arena_, expr);
117 }
118
119 Doc doc;
120 doc << PrintScope(node);
121 return doc;
122}
123
124Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
125 bool is_non_relay_func = node.defined() && node->IsInstance<BaseFuncNode>() &&
126 !node->IsInstance<relay::FunctionNode>();
127 if (node.as<ExprNode>() && !is_non_relay_func) {
128 return PrintExpr(Downcast<Expr>(node), meta, try_inline);
129 } else if (node.as<TypeNode>()) {
130 return PrintType(Downcast<Type>(node), meta);
131 } else if (node.as<PatternNode>()) {
132 return PrintPattern(Downcast<Pattern>(node), meta);
133 } else if (node.as<IRModuleNode>()) {
134 return PrintMod(Downcast<IRModule>(node));
135 } else {
136 // default module.
137 std::ostringstream os;
138 os << node;
139 return Doc::RawText(os.str());
140 }
141}
142
143Doc RelayTextPrinter::TempVar(int n) {
144 Doc doc;
145 return doc << "%" << n;
146}
147
148Doc RelayTextPrinter::AllocTemp() { return TempVar(temp_var_counter_++); }
149
150/*!
151 * \brief get a unique name with the corresponding prefix
152 * \param prefix The prefix of the name
153 * \return The returned name.
154 */
155Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) {
156 std::string unique_prefix = prefix;
157 auto it = name_alloc_map_.find(prefix);
158 if (it != name_alloc_map_.end()) {
159 while (true) {
160 std::ostringstream os;
161 os << prefix << (++it->second);
162 std::string name = os.str();
163 if (name_alloc_map_.count(name) == 0) {
164 unique_prefix = name;
165 break;
166 }
167 }
168 }
169 name_alloc_map_[unique_prefix] = 0;
170 return Doc::Text(unique_prefix);
171}
172
173Doc RelayTextPrinter::Print(Kind k) {
174 switch (k) {
175 case kType:
176 return Doc::Text("Type");
177 case kShapeVar:
178 return Doc::Text("Shape");
179 case kBaseType:
180 return Doc::Text("BaseType");
181 case kConstraint:
182 return Doc::Text("Constraint");
183 case kAdtHandle:
184 return Doc::Text("AdtHandle");
185 case kTypeData:
186 return Doc::Text("TypeData");
187 default:
188 LOG(ERROR) << "Unknown Kind";
189 throw;
190 }
191}
192/*!
193 * \brief Allocate name to a type variable.
194 * \param var The input type variable.
195 * \return The corresponding name.
196 */
197Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) {
198 if (memo_type_.count(var)) {
199 Doc val = memo_type_[var];
200 val << "-malformed-ir";
201 return val;
202 }
203 std::string name = var->name_hint;
204 if (name.length() == 0 || !std::isalpha(name[0])) {
205 name = "t" + name;
206 }
207 Doc val = GetUniqueName(name);
208 memo_type_[var] = val;
209 if (var->kind != kType) {
210 val << ": " << Print(var->kind);
211 }
212 return val;
213}
214
215/*!
216 * \brief Allocate name to a variable.
217 * \param var The input variable.
218 * \return The corresponding name.
219 */
220Doc RelayTextPrinter::AllocVar(const Var& var) {
221 // still print if ir is malformed, but show the error.
222 if (memo_.count(var)) {
223 Doc val = memo_[var];
224 val << "-malformed-ir";
225 return val;
226 }
227 std::string name = var->name_hint();
228 // always make sure first name is alpha
229 if (name.length() == 0 || !std::isalpha(name[0])) {
230 name = "v" + name;
231 }
232 Doc val = GetUniqueName("%" + name);
233 memo_[var] = val; // Referential occurrences will not include the following.
234 if (!var->virtual_device()->IsFullyUnconstrained()) {
235 val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}";
236 }
237 if (var->type_annotation.defined()) {
238 val << ": " << Print(var->type_annotation);
239 }
240
241 val << PrintOptionalInfo(var);
242 return val;
243}
244
245bool RelayTextPrinter::IsUnique(const Expr& expr) {
246 auto it = dg_.expr_node.find(expr);
247 if (it == dg_.expr_node.end()) {
248 return true;
249 } else {
250 return !(it->second->parents.head && it->second->parents.head->next);
251 }
252}
253
254bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
255 return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() ||
256 expr.as<VarNode>() || expr.as<ConstructorNode>();
257}
258
259Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
260 if (!CheckVisited(expr)) {
261 Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
262 // Add if not added after visiting
263 if (!CheckVisited(expr)) {
264 memo_[expr] = result;
265 } else {
266 result_memo_[expr] = result;
267 }
268 return result;
269 }
270 return memo_[expr];
271}
272
273bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); }
274
275Doc RelayTextPrinter::VisitExpr(const Expr& expr) {
276 auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
277 auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };
278
279 if (fcheck_visited(expr)) {
280 return memo_[expr];
281 } else {
282 ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
283 return memo_[expr];
284 }
285}
286
287//------------------------------------
288// Overload of Expr printing functions
289//------------------------------------
/*!
 * \brief Print an expression, deciding whether it is inlined at its use site or bound to a
 * temporary.
 *
 * A non-inlined expression is emitted once as `%n = <expr>;` into the current scope frame
 * (doc_stack_.back()). Later uses reuse `%n` through memo_.
 * \param expr The expression to print.
 * \param meta If true, print the expression as a metadata reference.
 * \param try_inline If true, also inline the expression when IsUnique reports it is used at
 *        most once.
 * \param optional_info If true, append the annotation from PrintOptionalInfo.
 * \return The doc to place at the expression's use site.
 */
Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info) {
  // Exploit memoization to print GNF.
  // The first time we visit an expression, we need to allocate a temp var
  // for it. Every subsequent time we can just use its assigned variable.
  // This works since hashing uses pointer equality.

  // determine whether to inline
  bool inline_expr = AlwaysInline(expr);

  if (try_inline) {
    inline_expr |= IsUnique(expr);
  }

  Doc printed_expr;

  if (meta) {
    printed_expr = meta_->GetMetaNode(GetRef<ObjectRef>(expr.get()));
  } else if (!inline_expr && expr.as<LetNode>()) {
    // wrap GNFed let in brackets (parentheses) with its body indented on new lines
    Doc body;
    printed_expr << "(";
    printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
    printed_expr << ")";
  } else {
    printed_expr = VisitExpr(expr);
  }

  if (optional_info) {
    printed_expr << PrintOptionalInfo(expr);
  }

  // add expr to doc
  if (expr.as<VarNode>()) {
    // This is our first time visiting the var and we hit the VarNode case
    // in the visitor. Thus the variable is free.
    // result_memo_ holds the full declaration returned by AllocVar (name plus annotations),
    // stored by VisitLeaf because AllocVar had already filled memo_.
    if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
      doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
    }
    // Memoization is done in AllocVar.
    // Uses of a var always print its bare memoized name; printed_expr is discarded here.
    return memo_[expr];
  } else if (inline_expr) {
    memo_[expr] = printed_expr;
    return printed_expr;
  } else {
    // Already exists. Reuse
    // (var_memo_ records which non-inlined expressions already have a temporary binding.)
    if (!var_memo_.insert(expr).second) {
      return memo_[expr];
    }
    // First occurrence: bind the expression to a fresh temporary in the current scope frame.
    Doc temp_var = AllocTemp();
    memo_[expr] = temp_var;
    doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
    return temp_var;
  }
}
344
345// Should only be triggered when op is a free variable being visited for the
346// first time.
347Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef<Var>(op)); }
348
349Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
350 // Print out simple scalars directly.
351 if (support::IsSimpleScalar(op)) {
352 return Doc::Text(support::NDArrayScalarToString(op->data));
353 }
354 // Fallbock: record it as a meta node.
355 Doc doc;
356 // Don't append optional_info. Because the entry function is Print,
357 // and it will append the optional_info afterwards.
358 return doc << PrintExpr(GetRef<Expr>(op), /*meta=*/true, /*try_inline=*/false,
359 /*optional_info=*/false);
360}
361
362Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
363 std::vector<Doc> fields;
364 for (Expr field : op->fields) {
365 fields.push_back(Print(field));
366 }
367 Doc doc;
368 doc << "(" << Doc::Concat(fields);
369 // conform to python tuple format (1,)
370 if (op->fields.size() == 1) {
371 doc << ",";
372 }
373 return doc << ")";
374}
375
376Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
377 Doc doc;
378 return doc << Print(op->tuple) << "." << op->index;
379}
380
381Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
382 Doc doc;
383 doc << "if (" << Print(op->cond) << ") ";
384 doc << PrintBody(op->true_branch);
385 doc << " else ";
386 doc << PrintBody(op->false_branch);
387 return doc;
388}
389
/*!
 * \brief Print a chain of nested lets as a flat sequence of `let %x = value;` lines followed by
 * the innermost body.
 *
 * Each let line is pushed onto doc_stack_ as its own frame. Bindings hoisted while printing a
 * let's value therefore land in the frame of the previous let (or in the enclosing scope for
 * the first one), just before the let line that uses them. Once the body is printed, the frames
 * are concatenated in order and popped.
 */
Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
  // n: number of let frames pushed; l: index of the first of them in doc_stack_.
  int n = 0;
  size_t l = doc_stack_.size();
  Expr let = GetRef<Let>(op);
  // Walk the chain of directly nested lets iteratively.
  while (auto let_node = let.as<LetNode>()) {
    Doc doc;
    doc << "let " << AllocVar(let_node->var) << " = " << Print(let_node->value, false, true) << ";"
        << Doc::NewLine();
    doc_stack_.push_back(doc);
    let = let_node->body;
    ++n;
  }
  // Print the final (non-let) body in its own scope.
  Doc doc = PrintScope(let);
  Doc doc_last;
  for (int i = 0; i < n; ++i) {
    doc_last << doc_stack_[l + i];
  }
  doc_last << doc;
  // Discard the frames pushed above.
  for (int i = 0; i < n; ++i) {
    doc_stack_.pop_back();
  }
  return doc_last;
}
413
414Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
415 Doc doc;
416 doc << prefix;
417 if (fn->type_params.size() > 0) {
418 doc << "[";
419 std::vector<Doc> type_params;
420 for (const TypeVar& tv : fn->type_params) {
421 type_params.push_back(Doc::Text(tv->name_hint));
422 }
423 doc << Doc::Concat(type_params);
424 doc << "]";
425 }
426 doc << "(";
427 std::vector<Doc> params;
428 for (Var param : fn->params) {
429 params.push_back(AllocVar(param));
430 }
431 for (const Doc& d : PrintDictAttrs(fn->attrs)) {
432 params.push_back(d);
433 }
434 if (!fn->virtual_device()->IsFullyUnconstrained()) {
435 Doc vid_doc;
436 vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device());
437 params.push_back(vid_doc);
438 }
439 doc << Doc::Concat(params) << ") ";
440 if (fn->ret_type.defined()) {
441 doc << "-> " << Print(fn->ret_type) << " ";
442 }
443 doc << PrintBody(fn->body);
444 return doc;
445}
446
447Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) {
448 if (auto* n = base_func.as<relay::FunctionNode>()) {
449 return PrintFunc(prefix, GetRef<relay::Function>(n));
450 } else if (auto* n = base_func.as<tir::PrimFuncNode>()) {
451 std::ostringstream os;
452 os << GetRef<tir::PrimFunc>(n);
453 return Doc::RawText(os.str());
454 } else {
455 // def @xyz = meta['ExternalFunc'][id]
456 Doc doc;
457 doc << prefix << " = " << meta_->GetMetaNode(base_func);
458 return doc;
459 }
460}
461
462Doc RelayTextPrinter::PrintMod(const IRModule& mod) {
463 Doc doc;
464 int counter = 0;
465 // type definitions
466 for (const auto& kv : mod->type_definitions) {
467 if (counter++ != 0) {
468 doc << Doc::NewLine();
469 }
470 doc << Print(kv.second);
471 doc << Doc::NewLine();
472 }
473 // functions
474 for (const auto& kv : mod->functions) {
475 if (kv.second.as<relay::FunctionNode>()) {
476 dg_ = DependencyGraph::Create(&arena_, kv.second);
477 }
478 if (counter++ != 0) {
479 doc << Doc::NewLine();
480 }
481 std::ostringstream os;
482 os << "def @" << kv.first->name_hint;
483 doc << PrintFunc(Doc::Text(os.str()), kv.second);
484 doc << Doc::NewLine();
485 }
486 return doc;
487}
488
489Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
490 return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
491}
492
493Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
494 Doc doc;
495 doc << "@" << op->name_hint;
496 return doc;
497}
498
499Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }
500
501Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
502 Doc doc;
503 // visit args first so they are lifted before the op
504 // this places op closer to its call site
505 std::vector<Doc> args;
506 for (const Expr& arg : op->args) {
507 args.push_back(Print(arg));
508 }
509
510 for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
511 args.push_back(d);
512 }
513 const auto* cons_node = op->op.as<ConstructorNode>();
514 if (cons_node) {
515 doc << cons_node->name_hint;
516 } else {
517 doc << Print(op->op);
518 }
519
520 if (cons_node && cons_node->inputs.size() == 0) {
521 // don't print as a call if it's a 0-arity cons
522 return doc;
523 } else {
524 doc << "(" << Doc::Concat(args) << ")";
525 return doc;
526 }
527}
528
529Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) {
530 Doc doc;
531 return doc << "ref(" << Print(op->value) << ")";
532}
533
534Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) {
535 Doc doc;
536 return doc << "ref_read(" << Print(op->ref) << ")";
537}
538
539Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) {
540 Doc doc;
541 return doc << "ref_write(" << Print(op->ref) << ", " << Print(op->value) << ")";
542}
543
544Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) {
545 // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs.
546 Doc doc;
547 Doc body;
548 doc << "match";
549 if (!op->complete) {
550 doc << "?";
551 }
552 doc << " (" << Print(op->data) << ") {";
553 std::vector<Doc> clause_docs;
554 for (const auto& clause : op->clauses) {
555 Doc clause_doc;
556 clause_doc << PrintPattern(clause->lhs, false) << " => ";
557 Doc rhs_doc = PrintScope(clause->rhs);
558 // TODO(@jroesch): This is unsound right now, and we need to revist it.
559 // if (clause->rhs.as<LetNode>()) {
560 // only add braces if there are multiple lines on the rhs
561 rhs_doc = Doc::Brace("{", rhs_doc, "}");
562 // }
563 clause_doc << rhs_doc << ",";
564 clause_docs.push_back(clause_doc);
565 }
566 doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine()))
567 << Doc::NewLine() << "}";
568 return doc;
569}
570
571Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) {
572 auto it = memo_pattern_.find(pattern);
573 if (it != memo_pattern_.end()) return it->second;
574 Doc printed_pattern;
575 if (meta) {
576 printed_pattern = meta_->GetMetaNode(GetRef<ObjectRef>(pattern.get()));
577 } else {
578 printed_pattern = VisitPattern(pattern);
579 }
580 memo_pattern_[pattern] = printed_pattern;
581 return printed_pattern;
582}
583
584Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) {
585 Doc doc;
586 doc << p->constructor->name_hint;
587 if (!p->patterns.empty()) {
588 doc << "(";
589 std::vector<Doc> pats;
590 for (const auto& pat : p->patterns) {
591 pats.push_back(Print(pat));
592 }
593 doc << Doc::Concat(pats) << ")";
594 }
595 return doc;
596}
597
598Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) {
599 Doc doc;
600 doc << "(";
601 std::vector<Doc> pats;
602 for (const auto& pat : pt->patterns) {
603 pats.push_back(Print(pat));
604 }
605 doc << Doc::Concat(pats) << ")";
606 return doc;
607}
608
609Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { return Doc::Text("_"); }
610
611Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { return AllocVar(pv->var); }
612
613Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) {
614 Doc doc;
615 doc << n->name_hint;
616 if (in_adt_def_ && n->inputs.size() != 0) {
617 doc << "(";
618 std::vector<Doc> inputs;
619 for (Type input : n->inputs) {
620 inputs.push_back(Print(input));
621 }
622 doc << Doc::Concat(inputs) << ")";
623 }
624 return doc;
625}
626
627//------------------------------------
628// Overload of Type printing functions
629//------------------------------------
630Doc RelayTextPrinter::PrintType(const Type& type, bool meta) {
631 auto it = memo_type_.find(type);
632 if (it != memo_type_.end()) return it->second;
633 Doc printed_type;
634 if (meta) {
635 printed_type = meta_->GetMetaNode(GetRef<ObjectRef>(type.get()));
636 } else {
637 printed_type = VisitType(type);
638 }
639 memo_type_[type] = printed_type;
640 return printed_type;
641}
642
643Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) {
644 // by default always print as meta data
645 return Print(GetRef<ObjectRef>(node), true);
646}
647
648Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { return Doc::Text(node->name_hint); }
649
650Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) {
651 return Doc::Text(node->name_hint);
652}
653
654Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) {
655 Doc doc = PrintType(node->func, false);
656 std::vector<Doc> args;
657 for (const Type& t : node->args) {
658 args.push_back(PrintType(t, false));
659 }
660 doc << "[";
661 doc << Doc::Concat(args);
662 doc << "]";
663 return doc;
664}
665
666Doc RelayTextPrinter::PrintDType(DataType dtype) {
667 return Doc::Text(runtime::DLDataType2String(dtype));
668}
669
670Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
671 // scalar type
672 if (node->shape.size() == 0) {
673 return PrintDType(node->dtype);
674 }
675 Doc doc;
676 doc << "Tensor[(";
677 std::vector<Doc> shapes;
678 for (const PrimExpr& prim_expr : node->shape) {
679 // Though not bound within an attribute the attribute visitor will handle the PrimExprs we
680 // care about.
681 shapes.push_back(PrintAttributeValue(prim_expr));
682 }
683 doc << Doc::Concat(shapes);
684 return doc << "), " << PrintDType(node->dtype) << "]";
685}
686
687Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) {
688 std::vector<Doc> fields;
689 for (Type field : node->fields) {
690 fields.push_back(Print(field));
691 }
692 Doc doc;
693 doc << "(" << Doc::Concat(fields);
694 // conform to python tuple format (1,)
695 if (node->fields.size() == 1) {
696 doc << ",";
697 }
698 return doc << ")";
699}
700
701Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) {
702 Doc doc;
703 doc << "fn ";
704 if (node->type_params.size() != 0) {
705 doc << "[";
706 std::vector<Doc> type_params;
707 for (Type type_param : node->type_params) {
708 type_params.push_back(Print(type_param));
709 }
710 doc << Doc::Concat(type_params);
711 doc << "]";
712 }
713 std::vector<Doc> arg_types;
714 for (Type arg_type : node->arg_types) {
715 arg_types.push_back(Print(arg_type));
716 }
717 return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type);
718}
719
720Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) {
721 Doc doc;
722 return doc << "ref(" << Print(node->value) << ")";
723}
724
725Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
726 in_adt_def_ = true;
727 Doc doc;
728 doc << "type " << Print(node->header);
729
730 // type vars
731 if (node->type_vars.size() != 0) {
732 doc << "[";
733 std::vector<Doc> type_vars;
734 for (Type type_var : node->type_vars) {
735 type_vars.push_back(Print(type_var));
736 }
737 doc << Doc::Concat(type_vars) << "]";
738 }
739 doc << " ";
740
741 std::vector<Doc> constructor_docs;
742 for (Constructor constructor : node->constructors) {
743 constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true));
744 }
745 Doc separator;
746 separator << "," << Doc::NewLine();
747 Doc adt_body;
748 adt_body << Doc::Concat(constructor_docs, separator);
749 // add trailing comma if there are any constructors
750 if (!constructor_docs.empty()) {
751 adt_body << ",";
752 }
753 doc << Doc::Brace("{", adt_body, "}");
754 in_adt_def_ = false;
755 return doc;
756}
757
758//------------------------------------
759// Overload of Attr printing functions
760//------------------------------------
761
762Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
763 // Since we don't have any overload for a specific attribute type we'll need to force
764 // the meta[...] representation to avoid infinite regress.
765 return PrintAttributeValue(GetRef<ObjectRef>(op), /*force_meta=*/true);
766}
767
768Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
769 Doc doc;
770 doc << "[";
771 std::vector<Doc> arr_vals;
772 for (const auto& val : *op) {
773 arr_vals.push_back(PrintAttributeValue(val));
774 }
775 doc << Doc::Concat(arr_vals);
776 doc << "]";
777 return doc;
778}
779
780Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
781 if (support::IsSimpleScalarDtype(op->dtype)) {
782 return Doc::Text(support::IntImmToString(GetRef<IntImm>(op)));
783 } else {
784 // Fallback: Print int64_t without width suffix.
785 return Doc::Text(std::to_string(op->value));
786 }
787}
788
789Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
790 if (support::IsSimpleScalarDtype(op->dtype)) {
791 return Doc::Text(support::FloatImmToString(GetRef<FloatImm>(op)));
792 } else {
793 // Fallbock: Print double without width suffix.
794 return Doc::Text(std::to_string(op->value));
795 }
796}
797
798Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {
799 return Doc::StrLiteral(op->value);
800}
801
802/*!
803 * \brief Attribute printer which prints the attributes in the call.
804 */
805class RelayTextPrinter::AttrPrinter : public AttrVisitor {
806 public:
807 AttrPrinter(std::vector<Doc>* doc, RelayTextPrinter* parent) : docs(doc), parent_(parent) {}
808
809 template <typename T>
810 void PrintKV(const char* key, const T& value) {
811 Doc doc;
812 doc << key << "=" << value;
813 docs->push_back(doc);
814 }
815
816 void Visit(const char* key, double* value) final {
817 Doc doc;
818 doc << key << "=" << *value << "f";
819 docs->push_back(doc);
820 }
821
822 void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); }
823 void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); }
824 void Visit(const char* key, int* value) final { PrintKV(key, *value); }
825 void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); }
826 void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); }
827 void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; }
828 void Visit(const char* key, DataType* value) final {
829 PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value)));
830 }
831 void Visit(const char* key, runtime::NDArray* value) final {
832 LOG(FATAL) << "do not allow NDarray as argument";
833 }
834 void Visit(const char* key, runtime::ObjectRef* obj) final {
835 PrintKV(key, parent_->PrintAttributeValue(*obj));
836 }
837
838 private:
839 std::vector<Doc>* docs;
840 RelayTextPrinter* parent_;
841};
842
843void RelayTextPrinter::AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs,
844 bool include_type_key) {
845 if (!attrs.defined()) {
846 return;
847 }
848 AttrPrinter printer(docs, this);
849 // Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this
850 // case we are read-only.
851 const_cast<BaseAttrsNode*>(attrs.get())->VisitNonDefaultAttrs(&printer);
852 if (include_type_key) {
853 std::string s = attrs->GetTypeKey();
854 printer.Visit("attrs_type_key", &s);
855 }
856}
857
858std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
859 std::vector<Doc> docs;
860 if (!attrs.defined()) {
861 return docs;
862 }
863 const auto* op_node = op.as<OpNode>();
864 if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) {
865 // The parser can only understand calls with attributes if they match the operator's
866 // declared attribute type. If that's not the case fall back to the meta[...] representation.
867 docs.push_back(meta_->GetMetaNode(attrs));
868 } else {
869 AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node);
870 }
871 return docs;
872}
873
874std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) {
875 if (!dict_attrs.defined()) {
876 return {};
877 }
878 return PrintDictAttrs(dict_attrs->dict);
879}
880
881std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs) {
882 std::vector<Doc> docs;
883 if (!dict_attrs.defined()) {
884 return docs;
885 }
886 for (const auto& k : dict_attrs) {
887 Doc doc;
888 doc << k.first << "=" << PrintAttributeValue(k.second);
889 docs.push_back(doc);
890 }
891 return docs;
892}
893
894Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) {
895 if (value.defined()) {
896 Doc printed_attr;
897 if (value.as<tvm::tir::AnyNode>()) {
898 printed_attr << "?";
899 } else if (auto str_obj = value.as<tvm::StringObj>()) {
900 printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
901 } else if (force_meta) {
902 printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
903 } else if (const auto* virtual_device_node = value.as<VirtualDeviceNode>()) {
904 if (show_meta_data_) {
905 printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(virtual_device_node));
906 } else {
907 // Special case: The ReprPrinter for VirtualDeviceNodes is much easier to work with while
908 // debugging.
909 std::ostringstream os;
910 os << GetRef<VirtualDevice>(virtual_device_node);
911 return Doc::Text(os.str());
912 }
913 } else if (const auto* base_attr_node = value.as<BaseAttrsNode>()) {
914 if (show_meta_data_) {
915 printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_attr_node));
916 } else {
917 // Special case: The non-meta form for attributes are much easier to work with while
918 // debugging.
919 printed_attr = PrintAttrsAsAttributeValue(GetRef<Attrs>(base_attr_node));
920 }
921 } else if (const auto* base_map_node = value.as<MapNode>()) {
922 if (show_meta_data_) {
923 printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_map_node));
924 } else {
925 // Special case: Show maps fields as key=value pairs to help debugging.
926 printed_attr << PrintMapAsAttributeValue(GetRef<Map<ObjectRef, ObjectRef>>(base_map_node));
927 }
928 } else if (const auto* global_var_node = value.as<GlobalVarNode>()) {
929 if (show_meta_data_) {
930 printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(global_var_node));
931 } else {
932 printed_attr << "'" << global_var_node->name_hint << "'";
933 }
934 } else {
935 printed_attr = VisitAttr(value);
936 }
937 return printed_attr;
938 } else {
939 return Doc::Text("None");
940 }
941}
942
943Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) {
944 std::vector<Doc> docs;
945 AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false);
946 Doc doc;
947 doc << "{" << Doc::Concat(docs) << "}";
948 return doc;
949}
950
951Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map) {
952 std::vector<Doc> docs;
953 for (const auto& k : map) {
954 Doc doc;
955 doc << PrintAttributeValue(k.first);
956 doc << "=";
957 doc << PrintAttributeValue(k.second);
958 docs.push_back(doc);
959 }
960 Doc doc;
961 doc << "{" << Doc::Concat(docs) << "}";
962 return doc;
963}
964
965Doc RelayTextPrinter::PrintSpan(const Span& span) {
966 Doc doc;
967 const auto* span_node = span.as<SpanNode>();
968 ICHECK(span_node);
969 doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column;
970 return doc;
971}
972
973} // namespace relay
974} // namespace tvm
975