1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/runtime/logging.h>
20#include <tvm/runtime/registry.h>
21#include <tvm/script/printer/doc.h>
22
23#include <algorithm>
24#include <cmath>
25#include <string>
26
27#include "../../../support/str_escape.h"
28#include "../../../support/utils.h"
29#include "./base_doc_printer.h"
30
31namespace tvm {
32namespace script {
33namespace printer {
34
35/*!
36 * \brief Operator precedence
37 *
38 * This is based on
39 * https://docs.python.org/3/reference/expressions.html#operator-precedence
40 */
enum class ExprPrecedence : int32_t {
  kUnkown = 0,      //!< Precedence is not known
  kLambda = 1,      //!< Lambda expression
  kIfThenElse = 2,  //!< Conditional expression
  kBooleanOr = 3,   //!< Boolean OR
  kBooleanAnd = 4,  //!< Boolean AND
  kBooleanNot = 5,  //!< Boolean NOT
  kComparison = 6,  //!< Comparison operators
  kBitwiseOr = 7,   //!< Bitwise OR
  kBitwiseXor = 8,  //!< Bitwise XOR
  kBitwiseAnd = 9,  //!< Bitwise AND
  kShift = 10,      //!< Left and right shift
  kAdd = 11,        //!< Addition and subtraction
  kMult = 12,       //!< Multiplication, division, floor division and remainder
  kUnary = 13,      //!< Unary plus/minus and bitwise NOT
  kExp = 14,        //!< Exponentiation
  kIdentity = 15,   //!< Index access, attribute access, call and atom expressions
};
75
76ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
77 // Key is the value of OperationDocNode::Kind
78 static const std::vector<ExprPrecedence> op_kind_precedence = []() {
79 using OpKind = OperationDocNode::Kind;
80 std::map<OpKind, ExprPrecedence> raw_table = {
81 {OpKind::kUSub, ExprPrecedence::kUnary},
82 {OpKind::kInvert, ExprPrecedence::kUnary},
83 {OpKind::kNot, ExprPrecedence::kBooleanNot},
84 {OpKind::kAdd, ExprPrecedence::kAdd},
85 {OpKind::kSub, ExprPrecedence::kAdd},
86 {OpKind::kMult, ExprPrecedence::kMult},
87 {OpKind::kDiv, ExprPrecedence::kMult},
88 {OpKind::kFloorDiv, ExprPrecedence::kMult},
89 {OpKind::kMod, ExprPrecedence::kMult},
90 {OpKind::kPow, ExprPrecedence::kExp},
91 {OpKind::kLShift, ExprPrecedence::kShift},
92 {OpKind::kRShift, ExprPrecedence::kShift},
93 {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd},
94 {OpKind::kBitOr, ExprPrecedence::kBitwiseOr},
95 {OpKind::kBitXor, ExprPrecedence::kBitwiseXor},
96 {OpKind::kLt, ExprPrecedence::kComparison},
97 {OpKind::kLtE, ExprPrecedence::kComparison},
98 {OpKind::kEq, ExprPrecedence::kComparison},
99 {OpKind::kNotEq, ExprPrecedence::kComparison},
100 {OpKind::kGt, ExprPrecedence::kComparison},
101 {OpKind::kGtE, ExprPrecedence::kComparison},
102 {OpKind::kAnd, ExprPrecedence::kBooleanAnd},
103 {OpKind::kOr, ExprPrecedence::kBooleanOr},
104 {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse},
105 };
106 int n = static_cast<int>(OpKind::kSpecialEnd);
107 std::vector<ExprPrecedence> table(n + 1, ExprPrecedence::kUnkown);
108 for (const auto& kv : raw_table) {
109 table[static_cast<int>(kv.first)] = kv.second;
110 }
111 return table;
112 }();
113
114 // Key is the type index of Doc
115 static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = {
116 {LiteralDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
117 {IdDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
118 {AttrAccessDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
119 {IndexDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
120 {CallDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
121 {LambdaDocNode::RuntimeTypeIndex(), ExprPrecedence::kLambda},
122 {TupleDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
123 {ListDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
124 {DictDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
125 };
126
127 if (const auto* op_doc = doc.as<OperationDocNode>()) {
128 size_t kind = static_cast<int>(op_doc->kind);
129 ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid operation: " << kind;
130 ExprPrecedence precedence = op_kind_precedence[kind];
131 ICHECK(precedence != ExprPrecedence::kUnkown)
132 << "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown";
133 return precedence;
134 }
135 auto it = doc_type_precedence.find(doc->type_index());
136 if (it != doc_type_precedence.end()) {
137 return it->second;
138 }
139 ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown";
140 throw;
141}
142
143class PythonDocPrinter : public DocPrinter {
144 public:
145 explicit PythonDocPrinter(const PrinterConfig& options) : DocPrinter(options) {}
146
147 protected:
148 using DocPrinter::PrintDoc;
149
150 void PrintTypedDoc(const LiteralDoc& doc) final;
151 void PrintTypedDoc(const IdDoc& doc) final;
152 void PrintTypedDoc(const AttrAccessDoc& doc) final;
153 void PrintTypedDoc(const IndexDoc& doc) final;
154 void PrintTypedDoc(const OperationDoc& doc) final;
155 void PrintTypedDoc(const CallDoc& doc) final;
156 void PrintTypedDoc(const LambdaDoc& doc) final;
157 void PrintTypedDoc(const ListDoc& doc) final;
158 void PrintTypedDoc(const DictDoc& doc) final;
159 void PrintTypedDoc(const TupleDoc& doc) final;
160 void PrintTypedDoc(const SliceDoc& doc) final;
161 void PrintTypedDoc(const StmtBlockDoc& doc) final;
162 void PrintTypedDoc(const AssignDoc& doc) final;
163 void PrintTypedDoc(const IfDoc& doc) final;
164 void PrintTypedDoc(const WhileDoc& doc) final;
165 void PrintTypedDoc(const ForDoc& doc) final;
166 void PrintTypedDoc(const ExprStmtDoc& doc) final;
167 void PrintTypedDoc(const AssertDoc& doc) final;
168 void PrintTypedDoc(const ReturnDoc& doc) final;
169 void PrintTypedDoc(const ScopeDoc& doc) final;
170 void PrintTypedDoc(const FunctionDoc& doc) final;
171 void PrintTypedDoc(const ClassDoc& doc) final;
172 void PrintTypedDoc(const CommentDoc& doc) final;
173 void PrintTypedDoc(const DocStringDoc& doc) final;
174
175 private:
176 void NewLineWithoutIndent() {
177 size_t start_pos = output_.tellp();
178 output_ << "\n";
179 size_t end_pos = output_.tellp();
180 underlines_exempted_.push_back({start_pos, end_pos});
181 }
182
183 template <typename DocType>
184 void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) {
185 bool is_first = true;
186 for (auto& doc : docs) {
187 if (is_first) {
188 is_first = false;
189 } else {
190 output_ << separator;
191 }
192 PrintDoc(doc);
193 }
194 }
195
196 void PrintIndentedBlock(const Array<StmtDoc>& docs) {
197 IncreaseIndent();
198 for (const StmtDoc& d : docs) {
199 NewLine();
200 PrintDoc(d);
201 }
202 if (docs.empty()) {
203 NewLine();
204 output_ << "pass";
205 }
206 DecreaseIndent();
207 }
208
209 void PrintDecorators(const Array<ExprDoc>& decorators) {
210 for (const ExprDoc& decorator : decorators) {
211 output_ << "@";
212 PrintDoc(decorator);
213 NewLine();
214 }
215 }
216
217 /*!
218 * \brief Print expression and add parenthesis if needed.
219 */
220 void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
221 bool parenthesis_for_same_precedence = false) {
222 ExprPrecedence doc_precedence = GetExprPrecedence(doc);
223 if (doc_precedence < parent_precedence ||
224 (parenthesis_for_same_precedence && doc_precedence == parent_precedence)) {
225 output_ << "(";
226 PrintDoc(doc);
227 output_ << ")";
228 } else {
229 PrintDoc(doc);
230 }
231 }
232
233 /*!
234 * \brief Print expression and add parenthesis if doc has lower precedence than parent.
235 */
236 void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
237 bool parenthesis_for_same_precedence = false) {
238 ExprPrecedence parent_precedence = GetExprPrecedence(parent);
239 return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence);
240 }
241
242 /*!
243 * \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent.
244 *
245 * This function should be used to print an child expression that needs to be wrapped
246 * by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b`
247 * and the `b` and `c` in `a if b else c`.
248 */
249 void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) {
250 PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true);
251 }
252
253 void MaybePrintCommentInline(const StmtDoc& stmt) {
254 if (stmt->comment.defined()) {
255 const std::string& comment = stmt->comment.value();
256 bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end();
257 CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey()
258 << " cannot have newline.";
259 size_t start_pos = output_.tellp();
260 output_ << " # " << comment;
261 size_t end_pos = output_.tellp();
262 underlines_exempted_.push_back({start_pos, end_pos});
263 }
264 }
265
266 void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) {
267 if (stmt->comment.defined()) {
268 std::vector<std::string> comment_lines = support::Split(stmt->comment.value(), '\n');
269 bool first_line = true;
270 size_t start_pos = output_.tellp();
271 for (const std::string& line : comment_lines) {
272 if (first_line) {
273 output_ << "# " << line;
274 first_line = false;
275 } else {
276 NewLine() << "# " << line;
277 }
278 }
279 size_t end_pos = output_.tellp();
280 underlines_exempted_.push_back({start_pos, end_pos});
281 if (new_line) {
282 NewLine();
283 }
284 }
285 }
286
287 void PrintDocString(const String& comment) {
288 size_t start_pos = output_.tellp();
289 output_ << "\"\"\"";
290
291 std::vector<std::string> comment_lines = support::Split(comment, '\n');
292 for (const std::string& line : comment_lines) {
293 if (line.empty()) {
294 // No indentation on empty line
295 output_ << "\n";
296 } else {
297 NewLine() << line;
298 }
299 }
300
301 NewLine() << "\"\"\"";
302 size_t end_pos = output_.tellp();
303 underlines_exempted_.push_back({start_pos, end_pos});
304 }
305
306 void PrintBlockComment(const String& comment) {
307 IncreaseIndent();
308 NewLine();
309 PrintDocString(comment);
310 DecreaseIndent();
311 }
312};
313
314void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
315 const ObjectRef& value = doc->value;
316 if (!value.defined()) {
317 output_ << "None";
318 } else if (const auto* int_imm = value.as<IntImmNode>()) {
319 if (int_imm->dtype.is_bool()) {
320 output_ << (int_imm->value ? "True" : "False");
321 } else {
322 output_ << int_imm->value;
323 }
324 } else if (const auto* float_imm = value.as<FloatImmNode>()) {
325 // TODO(yelite): Make float number printing roundtrippable
326 output_.precision(17);
327 if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) {
328 output_ << '"' << float_imm->value << '"';
329 } else {
330 output_ << float_imm->value;
331 }
332 } else if (const auto* string_obj = value.as<StringObj>()) {
333 output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\"";
334 } else {
335 LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey();
336 }
337}
338
339void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; }
340
341void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
342 PrintChildExpr(doc->value, doc);
343 output_ << "." << doc->name;
344}
345
346void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
347 PrintChildExpr(doc->value, doc);
348 if (doc->indices.size() == 0) {
349 output_ << "[()]";
350 } else {
351 output_ << "[";
352 PrintJoinedDocs(doc->indices, ", ");
353 output_ << "]";
354 }
355}
356
357const std::string OperatorToString(OperationDocNode::Kind operation_kind) {
358 static const std::vector<std::string> op_kind2str = []() {
359 using OpKind = OperationDocNode::Kind;
360 std::map<OpKind, std::string> raw_table = {
361 {OpKind::kUSub, "-"}, //
362 {OpKind::kInvert, "~"}, //
363 {OpKind::kNot, "not "}, //
364 {OpKind::kAdd, "+"}, //
365 {OpKind::kSub, "-"}, //
366 {OpKind::kMult, "*"}, //
367 {OpKind::kDiv, "/"}, //
368 {OpKind::kFloorDiv, "//"}, //
369 {OpKind::kMod, "%"}, //
370 {OpKind::kPow, "**"}, //
371 {OpKind::kLShift, "<<"}, //
372 {OpKind::kRShift, ">>"}, //
373 {OpKind::kBitAnd, "&"}, //
374 {OpKind::kBitOr, "|"}, //
375 {OpKind::kBitXor, "^"}, //
376 {OpKind::kLt, "<"}, //
377 {OpKind::kLtE, "<="}, //
378 {OpKind::kEq, "=="}, //
379 {OpKind::kNotEq, "!="}, //
380 {OpKind::kGt, ">"}, //
381 {OpKind::kGtE, ">="}, //
382 {OpKind::kAnd, "and"}, //
383 {OpKind::kOr, "or"}, //
384 };
385
386 std::vector<std::string> table;
387 table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1);
388
389 for (const auto& kv : raw_table) {
390 table[static_cast<int>(kv.first)] = kv.second;
391 }
392
393 return table;
394 }();
395
396 auto op_index = static_cast<int>(operation_kind);
397 ICHECK_LT(op_index, op_kind2str.size());
398 const std::string str = op_kind2str[op_index];
399 ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast<int>(operation_kind)
400 << " cannot be converted to operator token in Python directly.";
401 return str;
402}
403
404void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) {
405 using OpKind = OperationDocNode::Kind;
406 if (doc->kind < OpKind::kUnaryEnd) {
407 // Unary Operators
408 ICHECK_EQ(doc->operands.size(), 1);
409 output_ << OperatorToString(doc->kind);
410 PrintChildExpr(doc->operands[0], doc);
411 } else if (doc->kind == OpKind::kPow) {
412 // Power operator is different than other binary operators
413 // It's right-associative and binds less tightly than unary operator on its right.
414 // https://docs.python.org/3/reference/expressions.html#the-power-operator
415 // https://docs.python.org/3/reference/expressions.html#operator-precedence
416 ICHECK_EQ(doc->operands.size(), 2);
417 PrintChildExprConservatively(doc->operands[0], doc);
418 output_ << " ** ";
419 PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
420 } else if (doc->kind < OpKind::kBinaryEnd) {
421 // Binary Operator
422 ICHECK_EQ(doc->operands.size(), 2);
423 PrintChildExpr(doc->operands[0], doc);
424 output_ << " " << OperatorToString(doc->kind) << " ";
425 PrintChildExprConservatively(doc->operands[1], doc);
426 } else if (doc->kind == OpKind::kIfThenElse) {
427 ICHECK_EQ(doc->operands.size(), 3)
428 << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size();
429 PrintChildExpr(doc->operands[1], doc);
430 output_ << " if ";
431 PrintChildExprConservatively(doc->operands[0], doc);
432 output_ << " else ";
433 PrintChildExprConservatively(doc->operands[2], doc);
434 } else {
435 LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind);
436 throw;
437 }
438}
439
440void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
441 PrintChildExpr(doc->callee, doc);
442
443 output_ << "(";
444
445 // Print positional args
446 bool is_first = true;
447 for (const ExprDoc& arg : doc->args) {
448 if (is_first) {
449 is_first = false;
450 } else {
451 output_ << ", ";
452 }
453 PrintDoc(arg);
454 }
455
456 // Print keyword args
457 ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size())
458 << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values.";
459 for (size_t i = 0; i < doc->kwargs_keys.size(); i++) {
460 if (is_first) {
461 is_first = false;
462 } else {
463 output_ << ", ";
464 }
465 const String& keyword = doc->kwargs_keys[i];
466 output_ << keyword;
467 output_ << "=";
468 PrintDoc(doc->kwargs_values[i]);
469 }
470
471 output_ << ")";
472}
473
474void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
475 output_ << "lambda ";
476 PrintJoinedDocs(doc->args, ", ");
477 output_ << ": ";
478 PrintChildExpr(doc->body, doc);
479}
480
481void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
482 output_ << "[";
483 PrintJoinedDocs(doc->elements, ", ");
484 output_ << "]";
485}
486
487void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) {
488 output_ << "(";
489 if (doc->elements.size() == 1) {
490 PrintDoc(doc->elements[0]);
491 output_ << ",";
492 } else {
493 PrintJoinedDocs(doc->elements, ", ");
494 }
495 output_ << ")";
496}
497
498void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) {
499 ICHECK_EQ(doc->keys.size(), doc->values.size())
500 << "DictDoc should have equal number of elements in keys and values.";
501 output_ << "{";
502 size_t idx = 0;
503 for (const ExprDoc& key : doc->keys) {
504 if (idx > 0) {
505 output_ << ", ";
506 }
507 PrintDoc(key);
508 output_ << ": ";
509 PrintDoc(doc->values[idx]);
510 idx++;
511 }
512 output_ << "}";
513}
514
515void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
516 if (doc->start != nullptr) {
517 PrintDoc(doc->start.value());
518 }
519 output_ << ":";
520 if (doc->stop != nullptr) {
521 PrintDoc(doc->stop.value());
522 }
523 if (doc->step != nullptr) {
524 output_ << ":";
525 PrintDoc(doc->step.value());
526 }
527}
528
529void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
530 for (const StmtDoc& stmt : doc->stmts) {
531 PrintDoc(stmt);
532 if (stmt != doc->stmts.back()) {
533 NewLine();
534 }
535 }
536}
537
538void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) {
539 if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) {
540 PrintJoinedDocs(tuple_doc->elements, ", ");
541 } else {
542 PrintDoc(doc->lhs);
543 }
544
545 if (doc->annotation) {
546 output_ << ": ";
547 PrintDoc(doc->annotation.value());
548 }
549 if (doc->rhs) {
550 output_ << " = ";
551 PrintDoc(doc->rhs.value());
552 }
553 MaybePrintCommentInline(doc);
554}
555
556void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) {
557 MaybePrintCommenMultiLines(doc, true);
558 output_ << "if ";
559 PrintDoc(doc->predicate);
560 output_ << ":";
561
562 PrintIndentedBlock(doc->then_branch);
563
564 if (!doc->else_branch.empty()) {
565 NewLine();
566 output_ << "else:";
567 PrintIndentedBlock(doc->else_branch);
568 }
569}
570
571void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) {
572 MaybePrintCommenMultiLines(doc, true);
573 output_ << "while ";
574 PrintDoc(doc->predicate);
575 output_ << ":";
576
577 PrintIndentedBlock(doc->body);
578}
579
580void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) {
581 MaybePrintCommenMultiLines(doc, true);
582 output_ << "for ";
583 if (const auto* tuple = doc->lhs.as<TupleDocNode>()) {
584 if (tuple->elements.size() == 1) {
585 PrintDoc(tuple->elements[0]);
586 output_ << ",";
587 } else {
588 PrintJoinedDocs(tuple->elements, ", ");
589 }
590 } else {
591 PrintDoc(doc->lhs);
592 }
593 output_ << " in ";
594 PrintDoc(doc->rhs);
595 output_ << ":";
596
597 PrintIndentedBlock(doc->body);
598}
599
600void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) {
601 MaybePrintCommenMultiLines(doc, true);
602 output_ << "with ";
603 PrintDoc(doc->rhs);
604 if (doc->lhs != nullptr) {
605 output_ << " as ";
606 PrintDoc(doc->lhs.value());
607 }
608 output_ << ":";
609
610 PrintIndentedBlock(doc->body);
611}
612
613void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) {
614 PrintDoc(doc->expr);
615 MaybePrintCommentInline(doc);
616}
617
618void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) {
619 output_ << "assert ";
620 PrintDoc(doc->test);
621 if (doc->msg.defined()) {
622 output_ << ", ";
623 PrintDoc(doc->msg.value());
624 }
625 MaybePrintCommentInline(doc);
626}
627
628void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) {
629 output_ << "return ";
630 PrintDoc(doc->value);
631 MaybePrintCommentInline(doc);
632}
633
634void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
635 for (const AssignDoc& arg_doc : doc->args) {
636 ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them.";
637 }
638
639 PrintDecorators(doc->decorators);
640
641 output_ << "def ";
642 PrintDoc(doc->name);
643
644 output_ << "(";
645 PrintJoinedDocs(doc->args, ", ");
646 output_ << ")";
647
648 if (doc->return_type.defined()) {
649 output_ << " -> ";
650 PrintDoc(doc->return_type.value());
651 }
652
653 output_ << ":";
654
655 if (doc->comment.defined()) {
656 PrintBlockComment(doc->comment.value());
657 }
658 PrintIndentedBlock(doc->body);
659 NewLineWithoutIndent();
660}
661
662void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
663 PrintDecorators(doc->decorators);
664
665 output_ << "class ";
666 PrintDoc(doc->name);
667 output_ << ":";
668
669 if (doc->comment.defined()) {
670 PrintBlockComment(doc->comment.value());
671 }
672 PrintIndentedBlock(doc->body);
673 NewLineWithoutIndent();
674}
675
676void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
677 if (doc->comment.defined()) {
678 MaybePrintCommenMultiLines(doc, false);
679 }
680}
681
682void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) {
683 if (doc->comment.defined() && !doc->comment.value().empty()) {
684 PrintDocString(doc->comment.value());
685 }
686}
687
688String DocToPythonScript(Doc doc, const PrinterConfig& cfg) {
689 if (cfg->num_context_lines < 0) {
690 cfg->num_context_lines = std::numeric_limits<int32_t>::max();
691 }
692 PythonDocPrinter printer(cfg);
693 printer.Append(doc, cfg);
694 std::string result = printer.GetString();
695 int last_space = result.size();
696 while (last_space > 0 && std::isspace(result[last_space - 1])) {
697 last_space--;
698 }
699 return result.substr(0, last_space);
700}
701
702TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript);
703
704} // namespace printer
705} // namespace script
706} // namespace tvm
707