1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay_text_printer.cc |
22 | * \brief Printer to print out the IR text format |
23 | * that can be parsed by a parser. |
24 | * |
25 | * Supports ANF, GNF in relay and metadata. |
26 | * |
27 | * Inlining heuristics: |
28 | * - Always inline: |
29 | * - GlobalVar |
30 | * - Constant |
31 | * - Op |
32 | * - Var |
33 | * - Otherwise, inline if the node is at the end of a scope and is used at most once. |
34 | */ |
35 | #include <tvm/ir/module.h> |
36 | #include <tvm/ir/type_functor.h> |
37 | #include <tvm/relay/attrs/annotation.h> |
38 | #include <tvm/relay/expr_functor.h> |
39 | #include <tvm/relay/pattern_functor.h> |
40 | #include <tvm/target/virtual_device.h> |
41 | #include <tvm/tir/function.h> |
42 | |
43 | #include "../../ir/attr_functor.h" |
44 | #include "../../support/scalars.h" |
45 | #include "../analysis/dependency_graph.h" |
46 | #include "../parser/meta_ref.h" |
47 | #include "doc.h" |
48 | #include "meta_data.h" |
49 | #include "text_printer.h" |
50 | #include "tvm/runtime/builtin_fp16.h" |
51 | |
52 | namespace tvm { |
53 | namespace relay { |
54 | |
55 | /*! |
56 | * \brief Print additional info about expr in comment. |
57 | * \param expr The expression. |
58 | */ |
59 | Doc RelayTextPrinter::PrintOptionalInfo(const Expr& expr) { |
60 | Doc doc; |
61 | if (!opt_info_memo_.insert(expr).second) { |
62 | return doc; |
63 | } |
64 | // default annotations |
65 | if (annotate_ == nullptr) { |
66 | if ((expr.as<ConstantNode>() || expr.as<CallNode>() || expr.as<VarNode>() || |
67 | expr.as<FunctionNode>() || expr.as<TupleNode>() || expr.as<TupleGetItemNode>()) && |
68 | (expr->checked_type_.defined() || expr->span.defined())) { |
69 | doc << " /*" ; |
70 | if (expr->checked_type_.defined()) { |
71 | doc << " ty=" << Print(expr->checked_type()); |
72 | } |
73 | if (expr->span.defined()) { |
74 | doc << " span=" << PrintSpan(expr->span); |
75 | } |
76 | doc << " */" ; |
77 | } |
78 | } else { |
79 | std::string annotated_expr = annotate_(expr); |
80 | if (annotated_expr != "" ) { |
81 | doc << annotated_expr; |
82 | } |
83 | } |
84 | return doc; |
85 | } |
86 | |
87 | // indent a new body |
88 | Doc RelayTextPrinter::PrintBody(const ObjectRef& node, int indent) { |
89 | Doc doc; |
90 | Doc body; |
91 | doc << "{" ; |
92 | doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); |
93 | doc << "}" ; |
94 | return doc; |
95 | } |
96 | |
97 | // create a new scope by creating a new printer object. This allows temp var |
98 | // numbers to be reused and prevents hoisted vars from escaping too far |
99 | Doc RelayTextPrinter::PrintScope(const ObjectRef& node) { |
100 | // print in a new scope |
101 | doc_stack_.push_back(Doc()); |
102 | // must print first so doc_stack_.back() reference doesn't become stale |
103 | Doc doc = Print(node, false, true); |
104 | doc = doc_stack_.back() << doc; |
105 | doc_stack_.pop_back(); |
106 | return doc; |
107 | } |
108 | |
109 | Doc RelayTextPrinter::PrintFinal(const ObjectRef& node) { |
110 | if (node.defined() && node->IsInstance<BaseFuncNode>() && |
111 | !node->IsInstance<relay::FunctionNode>()) { |
112 | // Temporarily skip non-relay functions. |
113 | // TODO(tvm-team) enhance the code to work for all functions |
114 | } else if (node.as<ExprNode>()) { |
115 | Expr expr = Downcast<Expr>(node); |
116 | dg_ = DependencyGraph::Create(&arena_, expr); |
117 | } |
118 | |
119 | Doc doc; |
120 | doc << PrintScope(node); |
121 | return doc; |
122 | } |
123 | |
124 | Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { |
125 | bool is_non_relay_func = node.defined() && node->IsInstance<BaseFuncNode>() && |
126 | !node->IsInstance<relay::FunctionNode>(); |
127 | if (node.as<ExprNode>() && !is_non_relay_func) { |
128 | return PrintExpr(Downcast<Expr>(node), meta, try_inline); |
129 | } else if (node.as<TypeNode>()) { |
130 | return PrintType(Downcast<Type>(node), meta); |
131 | } else if (node.as<PatternNode>()) { |
132 | return PrintPattern(Downcast<Pattern>(node), meta); |
133 | } else if (node.as<IRModuleNode>()) { |
134 | return PrintMod(Downcast<IRModule>(node)); |
135 | } else { |
136 | // default module. |
137 | std::ostringstream os; |
138 | os << node; |
139 | return Doc::RawText(os.str()); |
140 | } |
141 | } |
142 | |
143 | Doc RelayTextPrinter::TempVar(int n) { |
144 | Doc doc; |
145 | return doc << "%" << n; |
146 | } |
147 | |
148 | Doc RelayTextPrinter::AllocTemp() { return TempVar(temp_var_counter_++); } |
149 | |
150 | /*! |
151 | * \brief get a unique name with the corresponding prefix |
152 | * \param prefix The prefix of the name |
153 | * \return The returned name. |
154 | */ |
155 | Doc RelayTextPrinter::GetUniqueName(const std::string& prefix) { |
156 | std::string unique_prefix = prefix; |
157 | auto it = name_alloc_map_.find(prefix); |
158 | if (it != name_alloc_map_.end()) { |
159 | while (true) { |
160 | std::ostringstream os; |
161 | os << prefix << (++it->second); |
162 | std::string name = os.str(); |
163 | if (name_alloc_map_.count(name) == 0) { |
164 | unique_prefix = name; |
165 | break; |
166 | } |
167 | } |
168 | } |
169 | name_alloc_map_[unique_prefix] = 0; |
170 | return Doc::Text(unique_prefix); |
171 | } |
172 | |
173 | Doc RelayTextPrinter::Print(Kind k) { |
174 | switch (k) { |
175 | case kType: |
176 | return Doc::Text("Type" ); |
177 | case kShapeVar: |
178 | return Doc::Text("Shape" ); |
179 | case kBaseType: |
180 | return Doc::Text("BaseType" ); |
181 | case kConstraint: |
182 | return Doc::Text("Constraint" ); |
183 | case kAdtHandle: |
184 | return Doc::Text("AdtHandle" ); |
185 | case kTypeData: |
186 | return Doc::Text("TypeData" ); |
187 | default: |
188 | LOG(ERROR) << "Unknown Kind" ; |
189 | throw; |
190 | } |
191 | } |
192 | /*! |
193 | * \brief Allocate name to a type variable. |
194 | * \param var The input type variable. |
195 | * \return The corresponding name. |
196 | */ |
197 | Doc RelayTextPrinter::AllocTypeVar(const TypeVar& var) { |
198 | if (memo_type_.count(var)) { |
199 | Doc val = memo_type_[var]; |
200 | val << "-malformed-ir" ; |
201 | return val; |
202 | } |
203 | std::string name = var->name_hint; |
204 | if (name.length() == 0 || !std::isalpha(name[0])) { |
205 | name = "t" + name; |
206 | } |
207 | Doc val = GetUniqueName(name); |
208 | memo_type_[var] = val; |
209 | if (var->kind != kType) { |
210 | val << ": " << Print(var->kind); |
211 | } |
212 | return val; |
213 | } |
214 | |
215 | /*! |
216 | * \brief Allocate name to a variable. |
217 | * \param var The input variable. |
218 | * \return The corresponding name. |
219 | */ |
220 | Doc RelayTextPrinter::AllocVar(const Var& var) { |
221 | // still print if ir is malformed, but show the error. |
222 | if (memo_.count(var)) { |
223 | Doc val = memo_[var]; |
224 | val << "-malformed-ir" ; |
225 | return val; |
226 | } |
227 | std::string name = var->name_hint(); |
228 | // always make sure first name is alpha |
229 | if (name.length() == 0 || !std::isalpha(name[0])) { |
230 | name = "v" + name; |
231 | } |
232 | Doc val = GetUniqueName("%" + name); |
233 | memo_[var] = val; // Referential occurrences will not include the following. |
234 | if (!var->virtual_device()->IsFullyUnconstrained()) { |
235 | val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}" ; |
236 | } |
237 | if (var->type_annotation.defined()) { |
238 | val << ": " << Print(var->type_annotation); |
239 | } |
240 | |
241 | val << PrintOptionalInfo(var); |
242 | return val; |
243 | } |
244 | |
245 | bool RelayTextPrinter::IsUnique(const Expr& expr) { |
246 | auto it = dg_.expr_node.find(expr); |
247 | if (it == dg_.expr_node.end()) { |
248 | return true; |
249 | } else { |
250 | return !(it->second->parents.head && it->second->parents.head->next); |
251 | } |
252 | } |
253 | |
254 | bool RelayTextPrinter::AlwaysInline(const Expr& expr) { |
255 | return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() || expr.as<OpNode>() || |
256 | expr.as<VarNode>() || expr.as<ConstructorNode>(); |
257 | } |
258 | |
/*!
 * \brief Leaf callback passed to ExpandDataflow by VisitExpr: print one expression.
 *
 * If \p expr has no memo_ entry, it is visited through the ExprFunctor dispatcher.
 * - If the visit did not record a memo_ entry itself, the result is stored in memo_.
 * - If the visit did record one (AllocVar, for example, stores a variable's bare name
 *   in memo_), the full visit result is kept in result_memo_ instead. PrintExpr reads
 *   result_memo_ when emitting free_var declarations.
 * \param expr The expression to visit.
 * \return The doc produced by the visit, or the existing memo_ entry.
 */
Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
  if (!CheckVisited(expr)) {
    Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
    // Add if not added after visiting
    if (!CheckVisited(expr)) {
      memo_[expr] = result;
    } else {
      // The visit memoized its own (reference) form; keep the full result separately.
      result_memo_[expr] = result;
    }
    return result;
  }
  return memo_[expr];
}
272 | |
273 | bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); } |
274 | |
275 | Doc RelayTextPrinter::VisitExpr(const Expr& expr) { |
276 | auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; |
277 | auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); }; |
278 | |
279 | if (fcheck_visited(expr)) { |
280 | return memo_[expr]; |
281 | } else { |
282 | ExpandDataflow(expr, fcheck_visited, fvisit_leaf); |
283 | return memo_[expr]; |
284 | } |
285 | } |
286 | |
287 | //------------------------------------ |
288 | // Overload of Expr printing functions |
289 | //------------------------------------ |
/*!
 * \brief Print an expression, either inline or bound to a temporary in the current scope.
 *
 * Expressions are inlined when AlwaysInline holds, or, if \p try_inline is set, when
 * IsUnique holds. Otherwise the first occurrence is emitted as "%N = <expr>;" into
 * doc_stack_.back(), and %N is returned. Later occurrences reuse %N through memo_.
 * Vars always return their memoized name. A var printed by the visitor itself (one
 * with an entry in result_memo_) is additionally declared once as "free_var ...;".
 *
 * \param expr The expression to print.
 * \param meta If true, print \p expr as a metadata reference via meta_->GetMetaNode.
 * \param try_inline If true, also inline expressions that IsUnique reports as unique.
 * \param optional_info If true, append PrintOptionalInfo's annotation to the printed form.
 * \return The doc to place at the use site.
 */
Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info) {
  // Exploit memoization to print GNF.
  // The first time we visit an expression, we need to allocate a temp var
  // for it. Every subsequent time we can just use its assigned variable.
  // This works since hashing uses pointer equality.

  // determine whether to inline
  bool inline_expr = AlwaysInline(expr);

  if (try_inline) {
    // Expressions with at most one parent in the dependency graph may be inlined too.
    inline_expr |= IsUnique(expr);
  }

  Doc printed_expr;

  if (meta) {
    printed_expr = meta_->GetMetaNode(GetRef<ObjectRef>(expr.get()));
  } else if (!inline_expr && expr.as<LetNode>()) {
    // wrap GNFed let in brackets
    Doc body;
    printed_expr << "(";
    printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine();
    printed_expr << ")";
  } else {
    printed_expr = VisitExpr(expr);
  }

  if (optional_info) {
    printed_expr << PrintOptionalInfo(expr);
  }

  // add expr to doc
  if (expr.as<VarNode>()) {
    // This is our first time visiting the var and we hit the VarNode case
    // in the visitor. Thus the variable is free.
    // (result_memo_ only holds a var when VisitLeaf allocated it, i.e. no binding
    // site had allocated it before; var_memo_ ensures the declaration is emitted once.)
    if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
      doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
    }
    // Memoization is done in AllocVar.
    return memo_[expr];
  } else if (inline_expr) {
    memo_[expr] = printed_expr;
    return printed_expr;
  } else {
    // Already exists. Reuse
    if (!var_memo_.insert(expr).second) {
      return memo_[expr];
    }
    // First non-inlined occurrence: bind it to a fresh temporary in the current scope.
    Doc temp_var = AllocTemp();
    memo_[expr] = temp_var;
    doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
    return temp_var;
  }
}
344 | |
345 | // Should only be triggered when op is a free variable being visited for the |
346 | // first time. |
347 | Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef<Var>(op)); } |
348 | |
349 | Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) { |
350 | // Print out simple scalars directly. |
351 | if (support::IsSimpleScalar(op)) { |
352 | return Doc::Text(support::NDArrayScalarToString(op->data)); |
353 | } |
354 | // Fallbock: record it as a meta node. |
355 | Doc doc; |
356 | // Don't append optional_info. Because the entry function is Print, |
357 | // and it will append the optional_info afterwards. |
358 | return doc << PrintExpr(GetRef<Expr>(op), /*meta=*/true, /*try_inline=*/false, |
359 | /*optional_info=*/false); |
360 | } |
361 | |
362 | Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { |
363 | std::vector<Doc> fields; |
364 | for (Expr field : op->fields) { |
365 | fields.push_back(Print(field)); |
366 | } |
367 | Doc doc; |
368 | doc << "(" << Doc::Concat(fields); |
369 | // conform to python tuple format (1,) |
370 | if (op->fields.size() == 1) { |
371 | doc << "," ; |
372 | } |
373 | return doc << ")" ; |
374 | } |
375 | |
376 | Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { |
377 | Doc doc; |
378 | return doc << Print(op->tuple) << "." << op->index; |
379 | } |
380 | |
381 | Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { |
382 | Doc doc; |
383 | doc << "if (" << Print(op->cond) << ") " ; |
384 | doc << PrintBody(op->true_branch); |
385 | doc << " else " ; |
386 | doc << PrintBody(op->false_branch); |
387 | return doc; |
388 | } |
389 | |
/*!
 * \brief Print a chain of nested let bindings without recursing on the body.
 *
 * Each "let %x = <value>;" line is pushed onto doc_stack_ as its own doc. Temporaries
 * hoisted while printing a binding's value go to whatever doc is on top of doc_stack_
 * at that moment: the enclosing scope for the first binding, the previous binding's
 * doc otherwise. The final non-let body is printed with PrintScope. The pushed binding
 * docs and the body are then concatenated, and the pushed docs are popped.
 * \param op The outermost let node.
 * \return The printed chain of bindings followed by the body.
 */
Doc RelayTextPrinter::VisitExpr_(const LetNode* op) {
  int n = 0;                      // number of binding docs pushed onto doc_stack_
  size_t l = doc_stack_.size();   // index of the first pushed binding doc
  Expr let = GetRef<Let>(op);
  while (auto let_node = let.as<LetNode>()) {
    Doc doc;
    // The var must be allocated here so the value and body can reference it.
    doc << "let " << AllocVar(let_node->var) << " = " << Print(let_node->value, false, true) << ";"
        << Doc::NewLine();
    doc_stack_.push_back(doc);
    let = let_node->body;
    ++n;
  }
  Doc doc = PrintScope(let);
  Doc doc_last;
  for (int i = 0; i < n; ++i) {
    doc_last << doc_stack_[l + i];
  }
  doc_last << doc;
  for (int i = 0; i < n; ++i) {
    doc_stack_.pop_back();
  }
  return doc_last;
}
413 | |
414 | Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { |
415 | Doc doc; |
416 | doc << prefix; |
417 | if (fn->type_params.size() > 0) { |
418 | doc << "[" ; |
419 | std::vector<Doc> type_params; |
420 | for (const TypeVar& tv : fn->type_params) { |
421 | type_params.push_back(Doc::Text(tv->name_hint)); |
422 | } |
423 | doc << Doc::Concat(type_params); |
424 | doc << "]" ; |
425 | } |
426 | doc << "(" ; |
427 | std::vector<Doc> params; |
428 | for (Var param : fn->params) { |
429 | params.push_back(AllocVar(param)); |
430 | } |
431 | for (const Doc& d : PrintDictAttrs(fn->attrs)) { |
432 | params.push_back(d); |
433 | } |
434 | if (!fn->virtual_device()->IsFullyUnconstrained()) { |
435 | Doc vid_doc; |
436 | vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device()); |
437 | params.push_back(vid_doc); |
438 | } |
439 | doc << Doc::Concat(params) << ") " ; |
440 | if (fn->ret_type.defined()) { |
441 | doc << "-> " << Print(fn->ret_type) << " " ; |
442 | } |
443 | doc << PrintBody(fn->body); |
444 | return doc; |
445 | } |
446 | |
447 | Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const BaseFunc& base_func) { |
448 | if (auto* n = base_func.as<relay::FunctionNode>()) { |
449 | return PrintFunc(prefix, GetRef<relay::Function>(n)); |
450 | } else if (auto* n = base_func.as<tir::PrimFuncNode>()) { |
451 | std::ostringstream os; |
452 | os << GetRef<tir::PrimFunc>(n); |
453 | return Doc::RawText(os.str()); |
454 | } else { |
455 | // def @xyz = meta['ExternalFunc'][id] |
456 | Doc doc; |
457 | doc << prefix << " = " << meta_->GetMetaNode(base_func); |
458 | return doc; |
459 | } |
460 | } |
461 | |
462 | Doc RelayTextPrinter::PrintMod(const IRModule& mod) { |
463 | Doc doc; |
464 | int counter = 0; |
465 | // type definitions |
466 | for (const auto& kv : mod->type_definitions) { |
467 | if (counter++ != 0) { |
468 | doc << Doc::NewLine(); |
469 | } |
470 | doc << Print(kv.second); |
471 | doc << Doc::NewLine(); |
472 | } |
473 | // functions |
474 | for (const auto& kv : mod->functions) { |
475 | if (kv.second.as<relay::FunctionNode>()) { |
476 | dg_ = DependencyGraph::Create(&arena_, kv.second); |
477 | } |
478 | if (counter++ != 0) { |
479 | doc << Doc::NewLine(); |
480 | } |
481 | std::ostringstream os; |
482 | os << "def @" << kv.first->name_hint; |
483 | doc << PrintFunc(Doc::Text(os.str()), kv.second); |
484 | doc << Doc::NewLine(); |
485 | } |
486 | return doc; |
487 | } |
488 | |
489 | Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { |
490 | return PrintFunc(Doc::Text("fn " ), GetRef<Function>(op)); |
491 | } |
492 | |
493 | Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { |
494 | Doc doc; |
495 | doc << "@" << op->name_hint; |
496 | return doc; |
497 | } |
498 | |
499 | Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); } |
500 | |
501 | Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { |
502 | Doc doc; |
503 | // visit args first so they are lifted before the op |
504 | // this places op closer to its call site |
505 | std::vector<Doc> args; |
506 | for (const Expr& arg : op->args) { |
507 | args.push_back(Print(arg)); |
508 | } |
509 | |
510 | for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { |
511 | args.push_back(d); |
512 | } |
513 | const auto* cons_node = op->op.as<ConstructorNode>(); |
514 | if (cons_node) { |
515 | doc << cons_node->name_hint; |
516 | } else { |
517 | doc << Print(op->op); |
518 | } |
519 | |
520 | if (cons_node && cons_node->inputs.size() == 0) { |
521 | // don't print as a call if it's a 0-arity cons |
522 | return doc; |
523 | } else { |
524 | doc << "(" << Doc::Concat(args) << ")" ; |
525 | return doc; |
526 | } |
527 | } |
528 | |
529 | Doc RelayTextPrinter::VisitExpr_(const RefCreateNode* op) { |
530 | Doc doc; |
531 | return doc << "ref(" << Print(op->value) << ")" ; |
532 | } |
533 | |
534 | Doc RelayTextPrinter::VisitExpr_(const RefReadNode* op) { |
535 | Doc doc; |
536 | return doc << "ref_read(" << Print(op->ref) << ")" ; |
537 | } |
538 | |
539 | Doc RelayTextPrinter::VisitExpr_(const RefWriteNode* op) { |
540 | Doc doc; |
541 | return doc << "ref_write(" << Print(op->ref) << ", " << Print(op->value) << ")" ; |
542 | } |
543 | |
544 | Doc RelayTextPrinter::VisitExpr_(const MatchNode* op) { |
545 | // TODO(jmp): Lots of code duplication here because PrintBody and PrintScope don't accept Docs. |
546 | Doc doc; |
547 | Doc body; |
548 | doc << "match" ; |
549 | if (!op->complete) { |
550 | doc << "?" ; |
551 | } |
552 | doc << " (" << Print(op->data) << ") {" ; |
553 | std::vector<Doc> clause_docs; |
554 | for (const auto& clause : op->clauses) { |
555 | Doc clause_doc; |
556 | clause_doc << PrintPattern(clause->lhs, false) << " => " ; |
557 | Doc rhs_doc = PrintScope(clause->rhs); |
558 | // TODO(@jroesch): This is unsound right now, and we need to revist it. |
559 | // if (clause->rhs.as<LetNode>()) { |
560 | // only add braces if there are multiple lines on the rhs |
561 | rhs_doc = Doc::Brace("{" , rhs_doc, "}" ); |
562 | // } |
563 | clause_doc << rhs_doc << "," ; |
564 | clause_docs.push_back(clause_doc); |
565 | } |
566 | doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) |
567 | << Doc::NewLine() << "}" ; |
568 | return doc; |
569 | } |
570 | |
571 | Doc RelayTextPrinter::PrintPattern(const Pattern& pattern, bool meta) { |
572 | auto it = memo_pattern_.find(pattern); |
573 | if (it != memo_pattern_.end()) return it->second; |
574 | Doc printed_pattern; |
575 | if (meta) { |
576 | printed_pattern = meta_->GetMetaNode(GetRef<ObjectRef>(pattern.get())); |
577 | } else { |
578 | printed_pattern = VisitPattern(pattern); |
579 | } |
580 | memo_pattern_[pattern] = printed_pattern; |
581 | return printed_pattern; |
582 | } |
583 | |
584 | Doc RelayTextPrinter::VisitPattern_(const PatternConstructorNode* p) { |
585 | Doc doc; |
586 | doc << p->constructor->name_hint; |
587 | if (!p->patterns.empty()) { |
588 | doc << "(" ; |
589 | std::vector<Doc> pats; |
590 | for (const auto& pat : p->patterns) { |
591 | pats.push_back(Print(pat)); |
592 | } |
593 | doc << Doc::Concat(pats) << ")" ; |
594 | } |
595 | return doc; |
596 | } |
597 | |
598 | Doc RelayTextPrinter::VisitPattern_(const PatternTupleNode* pt) { |
599 | Doc doc; |
600 | doc << "(" ; |
601 | std::vector<Doc> pats; |
602 | for (const auto& pat : pt->patterns) { |
603 | pats.push_back(Print(pat)); |
604 | } |
605 | doc << Doc::Concat(pats) << ")" ; |
606 | return doc; |
607 | } |
608 | |
609 | Doc RelayTextPrinter::VisitPattern_(const PatternWildcardNode* pw) { return Doc::Text("_" ); } |
610 | |
611 | Doc RelayTextPrinter::VisitPattern_(const PatternVarNode* pv) { return AllocVar(pv->var); } |
612 | |
613 | Doc RelayTextPrinter::VisitExpr_(const ConstructorNode* n) { |
614 | Doc doc; |
615 | doc << n->name_hint; |
616 | if (in_adt_def_ && n->inputs.size() != 0) { |
617 | doc << "(" ; |
618 | std::vector<Doc> inputs; |
619 | for (Type input : n->inputs) { |
620 | inputs.push_back(Print(input)); |
621 | } |
622 | doc << Doc::Concat(inputs) << ")" ; |
623 | } |
624 | return doc; |
625 | } |
626 | |
627 | //------------------------------------ |
628 | // Overload of Type printing functions |
629 | //------------------------------------ |
630 | Doc RelayTextPrinter::PrintType(const Type& type, bool meta) { |
631 | auto it = memo_type_.find(type); |
632 | if (it != memo_type_.end()) return it->second; |
633 | Doc printed_type; |
634 | if (meta) { |
635 | printed_type = meta_->GetMetaNode(GetRef<ObjectRef>(type.get())); |
636 | } else { |
637 | printed_type = VisitType(type); |
638 | } |
639 | memo_type_[type] = printed_type; |
640 | return printed_type; |
641 | } |
642 | |
643 | Doc RelayTextPrinter::VisitTypeDefault_(const Object* node) { |
644 | // by default always print as meta data |
645 | return Print(GetRef<ObjectRef>(node), true); |
646 | } |
647 | |
648 | Doc RelayTextPrinter::VisitType_(const TypeVarNode* node) { return Doc::Text(node->name_hint); } |
649 | |
650 | Doc RelayTextPrinter::VisitType_(const GlobalTypeVarNode* node) { |
651 | return Doc::Text(node->name_hint); |
652 | } |
653 | |
654 | Doc RelayTextPrinter::VisitType_(const TypeCallNode* node) { |
655 | Doc doc = PrintType(node->func, false); |
656 | std::vector<Doc> args; |
657 | for (const Type& t : node->args) { |
658 | args.push_back(PrintType(t, false)); |
659 | } |
660 | doc << "[" ; |
661 | doc << Doc::Concat(args); |
662 | doc << "]" ; |
663 | return doc; |
664 | } |
665 | |
666 | Doc RelayTextPrinter::PrintDType(DataType dtype) { |
667 | return Doc::Text(runtime::DLDataType2String(dtype)); |
668 | } |
669 | |
670 | Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { |
671 | // scalar type |
672 | if (node->shape.size() == 0) { |
673 | return PrintDType(node->dtype); |
674 | } |
675 | Doc doc; |
676 | doc << "Tensor[(" ; |
677 | std::vector<Doc> shapes; |
678 | for (const PrimExpr& prim_expr : node->shape) { |
679 | // Though not bound within an attribute the attribute visitor will handle the PrimExprs we |
680 | // care about. |
681 | shapes.push_back(PrintAttributeValue(prim_expr)); |
682 | } |
683 | doc << Doc::Concat(shapes); |
684 | return doc << "), " << PrintDType(node->dtype) << "]" ; |
685 | } |
686 | |
687 | Doc RelayTextPrinter::VisitType_(const TupleTypeNode* node) { |
688 | std::vector<Doc> fields; |
689 | for (Type field : node->fields) { |
690 | fields.push_back(Print(field)); |
691 | } |
692 | Doc doc; |
693 | doc << "(" << Doc::Concat(fields); |
694 | // conform to python tuple format (1,) |
695 | if (node->fields.size() == 1) { |
696 | doc << "," ; |
697 | } |
698 | return doc << ")" ; |
699 | } |
700 | |
701 | Doc RelayTextPrinter::VisitType_(const FuncTypeNode* node) { |
702 | Doc doc; |
703 | doc << "fn " ; |
704 | if (node->type_params.size() != 0) { |
705 | doc << "[" ; |
706 | std::vector<Doc> type_params; |
707 | for (Type type_param : node->type_params) { |
708 | type_params.push_back(Print(type_param)); |
709 | } |
710 | doc << Doc::Concat(type_params); |
711 | doc << "]" ; |
712 | } |
713 | std::vector<Doc> arg_types; |
714 | for (Type arg_type : node->arg_types) { |
715 | arg_types.push_back(Print(arg_type)); |
716 | } |
717 | return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); |
718 | } |
719 | |
720 | Doc RelayTextPrinter::VisitType_(const RelayRefTypeNode* node) { |
721 | Doc doc; |
722 | return doc << "ref(" << Print(node->value) << ")" ; |
723 | } |
724 | |
725 | Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) { |
726 | in_adt_def_ = true; |
727 | Doc doc; |
728 | doc << "type " << Print(node->header); |
729 | |
730 | // type vars |
731 | if (node->type_vars.size() != 0) { |
732 | doc << "[" ; |
733 | std::vector<Doc> type_vars; |
734 | for (Type type_var : node->type_vars) { |
735 | type_vars.push_back(Print(type_var)); |
736 | } |
737 | doc << Doc::Concat(type_vars) << "]" ; |
738 | } |
739 | doc << " " ; |
740 | |
741 | std::vector<Doc> constructor_docs; |
742 | for (Constructor constructor : node->constructors) { |
743 | constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); |
744 | } |
745 | Doc separator; |
746 | separator << "," << Doc::NewLine(); |
747 | Doc adt_body; |
748 | adt_body << Doc::Concat(constructor_docs, separator); |
749 | // add trailing comma if there are any constructors |
750 | if (!constructor_docs.empty()) { |
751 | adt_body << "," ; |
752 | } |
753 | doc << Doc::Brace("{" , adt_body, "}" ); |
754 | in_adt_def_ = false; |
755 | return doc; |
756 | } |
757 | |
758 | //------------------------------------ |
759 | // Overload of Attr printing functions |
760 | //------------------------------------ |
761 | |
762 | Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { |
763 | // Since we don't have any overload for a specific attribute type we'll need to force |
764 | // the meta[...] representation to avoid infinite regress. |
765 | return PrintAttributeValue(GetRef<ObjectRef>(op), /*force_meta=*/true); |
766 | } |
767 | |
768 | Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { |
769 | Doc doc; |
770 | doc << "[" ; |
771 | std::vector<Doc> arr_vals; |
772 | for (const auto& val : *op) { |
773 | arr_vals.push_back(PrintAttributeValue(val)); |
774 | } |
775 | doc << Doc::Concat(arr_vals); |
776 | doc << "]" ; |
777 | return doc; |
778 | } |
779 | |
780 | Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) { |
781 | if (support::IsSimpleScalarDtype(op->dtype)) { |
782 | return Doc::Text(support::IntImmToString(GetRef<IntImm>(op))); |
783 | } else { |
784 | // Fallback: Print int64_t without width suffix. |
785 | return Doc::Text(std::to_string(op->value)); |
786 | } |
787 | } |
788 | |
789 | Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) { |
790 | if (support::IsSimpleScalarDtype(op->dtype)) { |
791 | return Doc::Text(support::FloatImmToString(GetRef<FloatImm>(op))); |
792 | } else { |
793 | // Fallbock: Print double without width suffix. |
794 | return Doc::Text(std::to_string(op->value)); |
795 | } |
796 | } |
797 | |
798 | Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) { |
799 | return Doc::StrLiteral(op->value); |
800 | } |
801 | |
802 | /*! |
803 | * \brief Attribute printer which prints the attributes in the call. |
804 | */ |
805 | class RelayTextPrinter::AttrPrinter : public AttrVisitor { |
806 | public: |
807 | AttrPrinter(std::vector<Doc>* doc, RelayTextPrinter* parent) : docs(doc), parent_(parent) {} |
808 | |
809 | template <typename T> |
810 | void PrintKV(const char* key, const T& value) { |
811 | Doc doc; |
812 | doc << key << "=" << value; |
813 | docs->push_back(doc); |
814 | } |
815 | |
816 | void Visit(const char* key, double* value) final { |
817 | Doc doc; |
818 | doc << key << "=" << *value << "f" ; |
819 | docs->push_back(doc); |
820 | } |
821 | |
822 | void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } |
823 | void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } |
824 | void Visit(const char* key, int* value) final { PrintKV(key, *value); } |
825 | void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } |
826 | void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } |
827 | void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument" ; } |
828 | void Visit(const char* key, DataType* value) final { |
829 | PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); |
830 | } |
831 | void Visit(const char* key, runtime::NDArray* value) final { |
832 | LOG(FATAL) << "do not allow NDarray as argument" ; |
833 | } |
834 | void Visit(const char* key, runtime::ObjectRef* obj) final { |
835 | PrintKV(key, parent_->PrintAttributeValue(*obj)); |
836 | } |
837 | |
838 | private: |
839 | std::vector<Doc>* docs; |
840 | RelayTextPrinter* parent_; |
841 | }; |
842 | |
843 | void RelayTextPrinter::AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, |
844 | bool include_type_key) { |
845 | if (!attrs.defined()) { |
846 | return; |
847 | } |
848 | AttrPrinter printer(docs, this); |
849 | // Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this |
850 | // case we are read-only. |
851 | const_cast<BaseAttrsNode*>(attrs.get())->VisitNonDefaultAttrs(&printer); |
852 | if (include_type_key) { |
853 | std::string s = attrs->GetTypeKey(); |
854 | printer.Visit("attrs_type_key" , &s); |
855 | } |
856 | } |
857 | |
858 | std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { |
859 | std::vector<Doc> docs; |
860 | if (!attrs.defined()) { |
861 | return docs; |
862 | } |
863 | const auto* op_node = op.as<OpNode>(); |
864 | if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) { |
865 | // The parser can only understand calls with attributes if they match the operator's |
866 | // declared attribute type. If that's not the case fall back to the meta[...] representation. |
867 | docs.push_back(meta_->GetMetaNode(attrs)); |
868 | } else { |
869 | AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node); |
870 | } |
871 | return docs; |
872 | } |
873 | |
874 | std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) { |
875 | if (!dict_attrs.defined()) { |
876 | return {}; |
877 | } |
878 | return PrintDictAttrs(dict_attrs->dict); |
879 | } |
880 | |
881 | std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs) { |
882 | std::vector<Doc> docs; |
883 | if (!dict_attrs.defined()) { |
884 | return docs; |
885 | } |
886 | for (const auto& k : dict_attrs) { |
887 | Doc doc; |
888 | doc << k.first << "=" << PrintAttributeValue(k.second); |
889 | docs.push_back(doc); |
890 | } |
891 | return docs; |
892 | } |
893 | |
894 | Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) { |
895 | if (value.defined()) { |
896 | Doc printed_attr; |
897 | if (value.as<tvm::tir::AnyNode>()) { |
898 | printed_attr << "?" ; |
899 | } else if (auto str_obj = value.as<tvm::StringObj>()) { |
900 | printed_attr << Doc::StrLiteral(GetRef<String>(str_obj)); |
901 | } else if (force_meta) { |
902 | printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value)); |
903 | } else if (const auto* virtual_device_node = value.as<VirtualDeviceNode>()) { |
904 | if (show_meta_data_) { |
905 | printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(virtual_device_node)); |
906 | } else { |
907 | // Special case: The ReprPrinter for VirtualDeviceNodes is much easier to work with while |
908 | // debugging. |
909 | std::ostringstream os; |
910 | os << GetRef<VirtualDevice>(virtual_device_node); |
911 | return Doc::Text(os.str()); |
912 | } |
913 | } else if (const auto* base_attr_node = value.as<BaseAttrsNode>()) { |
914 | if (show_meta_data_) { |
915 | printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_attr_node)); |
916 | } else { |
917 | // Special case: The non-meta form for attributes are much easier to work with while |
918 | // debugging. |
919 | printed_attr = PrintAttrsAsAttributeValue(GetRef<Attrs>(base_attr_node)); |
920 | } |
921 | } else if (const auto* base_map_node = value.as<MapNode>()) { |
922 | if (show_meta_data_) { |
923 | printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_map_node)); |
924 | } else { |
925 | // Special case: Show maps fields as key=value pairs to help debugging. |
926 | printed_attr << PrintMapAsAttributeValue(GetRef<Map<ObjectRef, ObjectRef>>(base_map_node)); |
927 | } |
928 | } else if (const auto* global_var_node = value.as<GlobalVarNode>()) { |
929 | if (show_meta_data_) { |
930 | printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(global_var_node)); |
931 | } else { |
932 | printed_attr << "'" << global_var_node->name_hint << "'" ; |
933 | } |
934 | } else { |
935 | printed_attr = VisitAttr(value); |
936 | } |
937 | return printed_attr; |
938 | } else { |
939 | return Doc::Text("None" ); |
940 | } |
941 | } |
942 | |
943 | Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) { |
944 | std::vector<Doc> docs; |
945 | AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false); |
946 | Doc doc; |
947 | doc << "{" << Doc::Concat(docs) << "}" ; |
948 | return doc; |
949 | } |
950 | |
951 | Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map) { |
952 | std::vector<Doc> docs; |
953 | for (const auto& k : map) { |
954 | Doc doc; |
955 | doc << PrintAttributeValue(k.first); |
956 | doc << "=" ; |
957 | doc << PrintAttributeValue(k.second); |
958 | docs.push_back(doc); |
959 | } |
960 | Doc doc; |
961 | doc << "{" << Doc::Concat(docs) << "}" ; |
962 | return doc; |
963 | } |
964 | |
965 | Doc RelayTextPrinter::PrintSpan(const Span& span) { |
966 | Doc doc; |
967 | const auto* span_node = span.as<SpanNode>(); |
968 | ICHECK(span_node); |
969 | doc << span_node->source_name->name << ":" << span_node->line << ":" << span_node->column; |
970 | return doc; |
971 | } |
972 | |
973 | } // namespace relay |
974 | } // namespace tvm |
975 | |