1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/runtime/logging.h> |
20 | #include <tvm/runtime/registry.h> |
21 | #include <tvm/script/printer/doc.h> |
22 | |
23 | #include <algorithm> |
24 | #include <cmath> |
25 | #include <string> |
26 | |
27 | #include "../../../support/str_escape.h" |
28 | #include "../../../support/utils.h" |
29 | #include "./base_doc_printer.h" |
30 | |
31 | namespace tvm { |
32 | namespace script { |
33 | namespace printer { |
34 | |
/*!
 * \brief Binding strength of Python expressions, from loosest to tightest.
 *
 * A larger value means the expression binds more tightly. The printer compares
 * the precedence of a child expression against that of its parent to decide
 * whether the child has to be wrapped in parentheses.
 *
 * The ordering follows
 * https://docs.python.org/3/reference/expressions.html#operator-precedence
 */
enum class ExprPrecedence : int32_t {
  /*! \brief Placeholder for "no precedence assigned" (spelling kept for existing users) */
  kUnkown = 0,
  /*! \brief `lambda` expression */
  kLambda = 1,
  /*! \brief Conditional expression: `x if cond else y` */
  kIfThenElse = 2,
  /*! \brief `or` */
  kBooleanOr = 3,
  /*! \brief `and` */
  kBooleanAnd = 4,
  /*! \brief `not` */
  kBooleanNot = 5,
  /*! \brief Comparison operators */
  kComparison = 6,
  /*! \brief `|` */
  kBitwiseOr = 7,
  /*! \brief `^` */
  kBitwiseXor = 8,
  /*! \brief `&` */
  kBitwiseAnd = 9,
  /*! \brief `<<` and `>>` */
  kShift = 10,
  /*! \brief `+` and `-` (binary) */
  kAdd = 11,
  /*! \brief `*`, `/`, `//` and `%` */
  kMult = 12,
  /*! \brief Unary `+`, unary `-` and `~` */
  kUnary = 13,
  /*! \brief `**` */
  kExp = 14,
  /*! \brief Subscription, attribute access, calls and atoms */
  kIdentity = 15,
};
75 | |
76 | ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { |
77 | // Key is the value of OperationDocNode::Kind |
78 | static const std::vector<ExprPrecedence> op_kind_precedence = []() { |
79 | using OpKind = OperationDocNode::Kind; |
80 | std::map<OpKind, ExprPrecedence> raw_table = { |
81 | {OpKind::kUSub, ExprPrecedence::kUnary}, |
82 | {OpKind::kInvert, ExprPrecedence::kUnary}, |
83 | {OpKind::kNot, ExprPrecedence::kBooleanNot}, |
84 | {OpKind::kAdd, ExprPrecedence::kAdd}, |
85 | {OpKind::kSub, ExprPrecedence::kAdd}, |
86 | {OpKind::kMult, ExprPrecedence::kMult}, |
87 | {OpKind::kDiv, ExprPrecedence::kMult}, |
88 | {OpKind::kFloorDiv, ExprPrecedence::kMult}, |
89 | {OpKind::kMod, ExprPrecedence::kMult}, |
90 | {OpKind::kPow, ExprPrecedence::kExp}, |
91 | {OpKind::kLShift, ExprPrecedence::kShift}, |
92 | {OpKind::kRShift, ExprPrecedence::kShift}, |
93 | {OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd}, |
94 | {OpKind::kBitOr, ExprPrecedence::kBitwiseOr}, |
95 | {OpKind::kBitXor, ExprPrecedence::kBitwiseXor}, |
96 | {OpKind::kLt, ExprPrecedence::kComparison}, |
97 | {OpKind::kLtE, ExprPrecedence::kComparison}, |
98 | {OpKind::kEq, ExprPrecedence::kComparison}, |
99 | {OpKind::kNotEq, ExprPrecedence::kComparison}, |
100 | {OpKind::kGt, ExprPrecedence::kComparison}, |
101 | {OpKind::kGtE, ExprPrecedence::kComparison}, |
102 | {OpKind::kAnd, ExprPrecedence::kBooleanAnd}, |
103 | {OpKind::kOr, ExprPrecedence::kBooleanOr}, |
104 | {OpKind::kIfThenElse, ExprPrecedence::kIfThenElse}, |
105 | }; |
106 | int n = static_cast<int>(OpKind::kSpecialEnd); |
107 | std::vector<ExprPrecedence> table(n + 1, ExprPrecedence::kUnkown); |
108 | for (const auto& kv : raw_table) { |
109 | table[static_cast<int>(kv.first)] = kv.second; |
110 | } |
111 | return table; |
112 | }(); |
113 | |
114 | // Key is the type index of Doc |
115 | static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = { |
116 | {LiteralDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
117 | {IdDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
118 | {AttrAccessDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
119 | {IndexDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
120 | {CallDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
121 | {LambdaDocNode::RuntimeTypeIndex(), ExprPrecedence::kLambda}, |
122 | {TupleDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
123 | {ListDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
124 | {DictDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity}, |
125 | }; |
126 | |
127 | if (const auto* op_doc = doc.as<OperationDocNode>()) { |
128 | size_t kind = static_cast<int>(op_doc->kind); |
129 | ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid operation: " << kind; |
130 | ExprPrecedence precedence = op_kind_precedence[kind]; |
131 | ICHECK(precedence != ExprPrecedence::kUnkown) |
132 | << "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown" ; |
133 | return precedence; |
134 | } |
135 | auto it = doc_type_precedence.find(doc->type_index()); |
136 | if (it != doc_type_precedence.end()) { |
137 | return it->second; |
138 | } |
139 | ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown" ; |
140 | throw; |
141 | } |
142 | |
143 | class PythonDocPrinter : public DocPrinter { |
144 | public: |
145 | explicit PythonDocPrinter(const PrinterConfig& options) : DocPrinter(options) {} |
146 | |
147 | protected: |
148 | using DocPrinter::PrintDoc; |
149 | |
150 | void PrintTypedDoc(const LiteralDoc& doc) final; |
151 | void PrintTypedDoc(const IdDoc& doc) final; |
152 | void PrintTypedDoc(const AttrAccessDoc& doc) final; |
153 | void PrintTypedDoc(const IndexDoc& doc) final; |
154 | void PrintTypedDoc(const OperationDoc& doc) final; |
155 | void PrintTypedDoc(const CallDoc& doc) final; |
156 | void PrintTypedDoc(const LambdaDoc& doc) final; |
157 | void PrintTypedDoc(const ListDoc& doc) final; |
158 | void PrintTypedDoc(const DictDoc& doc) final; |
159 | void PrintTypedDoc(const TupleDoc& doc) final; |
160 | void PrintTypedDoc(const SliceDoc& doc) final; |
161 | void PrintTypedDoc(const StmtBlockDoc& doc) final; |
162 | void PrintTypedDoc(const AssignDoc& doc) final; |
163 | void PrintTypedDoc(const IfDoc& doc) final; |
164 | void PrintTypedDoc(const WhileDoc& doc) final; |
165 | void PrintTypedDoc(const ForDoc& doc) final; |
166 | void PrintTypedDoc(const ExprStmtDoc& doc) final; |
167 | void PrintTypedDoc(const AssertDoc& doc) final; |
168 | void PrintTypedDoc(const ReturnDoc& doc) final; |
169 | void PrintTypedDoc(const ScopeDoc& doc) final; |
170 | void PrintTypedDoc(const FunctionDoc& doc) final; |
171 | void PrintTypedDoc(const ClassDoc& doc) final; |
172 | void PrintTypedDoc(const CommentDoc& doc) final; |
173 | void PrintTypedDoc(const DocStringDoc& doc) final; |
174 | |
175 | private: |
176 | void NewLineWithoutIndent() { |
177 | size_t start_pos = output_.tellp(); |
178 | output_ << "\n" ; |
179 | size_t end_pos = output_.tellp(); |
180 | underlines_exempted_.push_back({start_pos, end_pos}); |
181 | } |
182 | |
183 | template <typename DocType> |
184 | void PrintJoinedDocs(const Array<DocType>& docs, const std::string& separator) { |
185 | bool is_first = true; |
186 | for (auto& doc : docs) { |
187 | if (is_first) { |
188 | is_first = false; |
189 | } else { |
190 | output_ << separator; |
191 | } |
192 | PrintDoc(doc); |
193 | } |
194 | } |
195 | |
196 | void PrintIndentedBlock(const Array<StmtDoc>& docs) { |
197 | IncreaseIndent(); |
198 | for (const StmtDoc& d : docs) { |
199 | NewLine(); |
200 | PrintDoc(d); |
201 | } |
202 | if (docs.empty()) { |
203 | NewLine(); |
204 | output_ << "pass" ; |
205 | } |
206 | DecreaseIndent(); |
207 | } |
208 | |
209 | void PrintDecorators(const Array<ExprDoc>& decorators) { |
210 | for (const ExprDoc& decorator : decorators) { |
211 | output_ << "@" ; |
212 | PrintDoc(decorator); |
213 | NewLine(); |
214 | } |
215 | } |
216 | |
217 | /*! |
218 | * \brief Print expression and add parenthesis if needed. |
219 | */ |
220 | void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence, |
221 | bool parenthesis_for_same_precedence = false) { |
222 | ExprPrecedence doc_precedence = GetExprPrecedence(doc); |
223 | if (doc_precedence < parent_precedence || |
224 | (parenthesis_for_same_precedence && doc_precedence == parent_precedence)) { |
225 | output_ << "(" ; |
226 | PrintDoc(doc); |
227 | output_ << ")" ; |
228 | } else { |
229 | PrintDoc(doc); |
230 | } |
231 | } |
232 | |
233 | /*! |
234 | * \brief Print expression and add parenthesis if doc has lower precedence than parent. |
235 | */ |
236 | void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent, |
237 | bool parenthesis_for_same_precedence = false) { |
238 | ExprPrecedence parent_precedence = GetExprPrecedence(parent); |
239 | return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence); |
240 | } |
241 | |
242 | /*! |
243 | * \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent. |
244 | * |
245 | * This function should be used to print an child expression that needs to be wrapped |
246 | * by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b` |
247 | * and the `b` and `c` in `a if b else c`. |
248 | */ |
249 | void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) { |
250 | PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true); |
251 | } |
252 | |
253 | void (const StmtDoc& stmt) { |
254 | if (stmt->comment.defined()) { |
255 | const std::string& = stmt->comment.value(); |
256 | bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end(); |
257 | CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey() |
258 | << " cannot have newline." ; |
259 | size_t start_pos = output_.tellp(); |
260 | output_ << " # " << comment; |
261 | size_t end_pos = output_.tellp(); |
262 | underlines_exempted_.push_back({start_pos, end_pos}); |
263 | } |
264 | } |
265 | |
266 | void MaybePrintCommenMultiLines(const StmtDoc& stmt, bool new_line = false) { |
267 | if (stmt->comment.defined()) { |
268 | std::vector<std::string> = support::Split(stmt->comment.value(), '\n'); |
269 | bool first_line = true; |
270 | size_t start_pos = output_.tellp(); |
271 | for (const std::string& line : comment_lines) { |
272 | if (first_line) { |
273 | output_ << "# " << line; |
274 | first_line = false; |
275 | } else { |
276 | NewLine() << "# " << line; |
277 | } |
278 | } |
279 | size_t end_pos = output_.tellp(); |
280 | underlines_exempted_.push_back({start_pos, end_pos}); |
281 | if (new_line) { |
282 | NewLine(); |
283 | } |
284 | } |
285 | } |
286 | |
287 | void PrintDocString(const String& ) { |
288 | size_t start_pos = output_.tellp(); |
289 | output_ << "\"\"\"" ; |
290 | |
291 | std::vector<std::string> = support::Split(comment, '\n'); |
292 | for (const std::string& line : comment_lines) { |
293 | if (line.empty()) { |
294 | // No indentation on empty line |
295 | output_ << "\n" ; |
296 | } else { |
297 | NewLine() << line; |
298 | } |
299 | } |
300 | |
301 | NewLine() << "\"\"\"" ; |
302 | size_t end_pos = output_.tellp(); |
303 | underlines_exempted_.push_back({start_pos, end_pos}); |
304 | } |
305 | |
306 | void (const String& ) { |
307 | IncreaseIndent(); |
308 | NewLine(); |
309 | PrintDocString(comment); |
310 | DecreaseIndent(); |
311 | } |
312 | }; |
313 | |
314 | void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { |
315 | const ObjectRef& value = doc->value; |
316 | if (!value.defined()) { |
317 | output_ << "None" ; |
318 | } else if (const auto* int_imm = value.as<IntImmNode>()) { |
319 | if (int_imm->dtype.is_bool()) { |
320 | output_ << (int_imm->value ? "True" : "False" ); |
321 | } else { |
322 | output_ << int_imm->value; |
323 | } |
324 | } else if (const auto* float_imm = value.as<FloatImmNode>()) { |
325 | // TODO(yelite): Make float number printing roundtrippable |
326 | output_.precision(17); |
327 | if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { |
328 | output_ << '"' << float_imm->value << '"'; |
329 | } else { |
330 | output_ << float_imm->value; |
331 | } |
332 | } else if (const auto* string_obj = value.as<StringObj>()) { |
333 | output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\"" ; |
334 | } else { |
335 | LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey(); |
336 | } |
337 | } |
338 | |
339 | void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } |
340 | |
341 | void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { |
342 | PrintChildExpr(doc->value, doc); |
343 | output_ << "." << doc->name; |
344 | } |
345 | |
346 | void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { |
347 | PrintChildExpr(doc->value, doc); |
348 | if (doc->indices.size() == 0) { |
349 | output_ << "[()]" ; |
350 | } else { |
351 | output_ << "[" ; |
352 | PrintJoinedDocs(doc->indices, ", " ); |
353 | output_ << "]" ; |
354 | } |
355 | } |
356 | |
357 | const std::string OperatorToString(OperationDocNode::Kind operation_kind) { |
358 | static const std::vector<std::string> op_kind2str = []() { |
359 | using OpKind = OperationDocNode::Kind; |
360 | std::map<OpKind, std::string> raw_table = { |
361 | {OpKind::kUSub, "-" }, // |
362 | {OpKind::kInvert, "~" }, // |
363 | {OpKind::kNot, "not " }, // |
364 | {OpKind::kAdd, "+" }, // |
365 | {OpKind::kSub, "-" }, // |
366 | {OpKind::kMult, "*" }, // |
367 | {OpKind::kDiv, "/" }, // |
368 | {OpKind::kFloorDiv, "//" }, // |
369 | {OpKind::kMod, "%" }, // |
370 | {OpKind::kPow, "**" }, // |
371 | {OpKind::kLShift, "<<" }, // |
372 | {OpKind::kRShift, ">>" }, // |
373 | {OpKind::kBitAnd, "&" }, // |
374 | {OpKind::kBitOr, "|" }, // |
375 | {OpKind::kBitXor, "^" }, // |
376 | {OpKind::kLt, "<" }, // |
377 | {OpKind::kLtE, "<=" }, // |
378 | {OpKind::kEq, "==" }, // |
379 | {OpKind::kNotEq, "!=" }, // |
380 | {OpKind::kGt, ">" }, // |
381 | {OpKind::kGtE, ">=" }, // |
382 | {OpKind::kAnd, "and" }, // |
383 | {OpKind::kOr, "or" }, // |
384 | }; |
385 | |
386 | std::vector<std::string> table; |
387 | table.resize(static_cast<int>(OperationDocNode::Kind::kSpecialEnd) + 1); |
388 | |
389 | for (const auto& kv : raw_table) { |
390 | table[static_cast<int>(kv.first)] = kv.second; |
391 | } |
392 | |
393 | return table; |
394 | }(); |
395 | |
396 | auto op_index = static_cast<int>(operation_kind); |
397 | ICHECK_LT(op_index, op_kind2str.size()); |
398 | const std::string str = op_kind2str[op_index]; |
399 | ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast<int>(operation_kind) |
400 | << " cannot be converted to operator token in Python directly." ; |
401 | return str; |
402 | } |
403 | |
404 | void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { |
405 | using OpKind = OperationDocNode::Kind; |
406 | if (doc->kind < OpKind::kUnaryEnd) { |
407 | // Unary Operators |
408 | ICHECK_EQ(doc->operands.size(), 1); |
409 | output_ << OperatorToString(doc->kind); |
410 | PrintChildExpr(doc->operands[0], doc); |
411 | } else if (doc->kind == OpKind::kPow) { |
412 | // Power operator is different than other binary operators |
413 | // It's right-associative and binds less tightly than unary operator on its right. |
414 | // https://docs.python.org/3/reference/expressions.html#the-power-operator |
415 | // https://docs.python.org/3/reference/expressions.html#operator-precedence |
416 | ICHECK_EQ(doc->operands.size(), 2); |
417 | PrintChildExprConservatively(doc->operands[0], doc); |
418 | output_ << " ** " ; |
419 | PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary); |
420 | } else if (doc->kind < OpKind::kBinaryEnd) { |
421 | // Binary Operator |
422 | ICHECK_EQ(doc->operands.size(), 2); |
423 | PrintChildExpr(doc->operands[0], doc); |
424 | output_ << " " << OperatorToString(doc->kind) << " " ; |
425 | PrintChildExprConservatively(doc->operands[1], doc); |
426 | } else if (doc->kind == OpKind::kIfThenElse) { |
427 | ICHECK_EQ(doc->operands.size(), 3) |
428 | << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); |
429 | PrintChildExpr(doc->operands[1], doc); |
430 | output_ << " if " ; |
431 | PrintChildExprConservatively(doc->operands[0], doc); |
432 | output_ << " else " ; |
433 | PrintChildExprConservatively(doc->operands[2], doc); |
434 | } else { |
435 | LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind); |
436 | throw; |
437 | } |
438 | } |
439 | |
440 | void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { |
441 | PrintChildExpr(doc->callee, doc); |
442 | |
443 | output_ << "(" ; |
444 | |
445 | // Print positional args |
446 | bool is_first = true; |
447 | for (const ExprDoc& arg : doc->args) { |
448 | if (is_first) { |
449 | is_first = false; |
450 | } else { |
451 | output_ << ", " ; |
452 | } |
453 | PrintDoc(arg); |
454 | } |
455 | |
456 | // Print keyword args |
457 | ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) |
458 | << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values." ; |
459 | for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { |
460 | if (is_first) { |
461 | is_first = false; |
462 | } else { |
463 | output_ << ", " ; |
464 | } |
465 | const String& keyword = doc->kwargs_keys[i]; |
466 | output_ << keyword; |
467 | output_ << "=" ; |
468 | PrintDoc(doc->kwargs_values[i]); |
469 | } |
470 | |
471 | output_ << ")" ; |
472 | } |
473 | |
474 | void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) { |
475 | output_ << "lambda " ; |
476 | PrintJoinedDocs(doc->args, ", " ); |
477 | output_ << ": " ; |
478 | PrintChildExpr(doc->body, doc); |
479 | } |
480 | |
481 | void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) { |
482 | output_ << "[" ; |
483 | PrintJoinedDocs(doc->elements, ", " ); |
484 | output_ << "]" ; |
485 | } |
486 | |
487 | void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { |
488 | output_ << "(" ; |
489 | if (doc->elements.size() == 1) { |
490 | PrintDoc(doc->elements[0]); |
491 | output_ << "," ; |
492 | } else { |
493 | PrintJoinedDocs(doc->elements, ", " ); |
494 | } |
495 | output_ << ")" ; |
496 | } |
497 | |
498 | void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { |
499 | ICHECK_EQ(doc->keys.size(), doc->values.size()) |
500 | << "DictDoc should have equal number of elements in keys and values." ; |
501 | output_ << "{" ; |
502 | size_t idx = 0; |
503 | for (const ExprDoc& key : doc->keys) { |
504 | if (idx > 0) { |
505 | output_ << ", " ; |
506 | } |
507 | PrintDoc(key); |
508 | output_ << ": " ; |
509 | PrintDoc(doc->values[idx]); |
510 | idx++; |
511 | } |
512 | output_ << "}" ; |
513 | } |
514 | |
515 | void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { |
516 | if (doc->start != nullptr) { |
517 | PrintDoc(doc->start.value()); |
518 | } |
519 | output_ << ":" ; |
520 | if (doc->stop != nullptr) { |
521 | PrintDoc(doc->stop.value()); |
522 | } |
523 | if (doc->step != nullptr) { |
524 | output_ << ":" ; |
525 | PrintDoc(doc->step.value()); |
526 | } |
527 | } |
528 | |
529 | void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) { |
530 | for (const StmtDoc& stmt : doc->stmts) { |
531 | PrintDoc(stmt); |
532 | if (stmt != doc->stmts.back()) { |
533 | NewLine(); |
534 | } |
535 | } |
536 | } |
537 | |
538 | void PythonDocPrinter::PrintTypedDoc(const AssignDoc& doc) { |
539 | if (const auto* tuple_doc = doc->lhs.as<TupleDocNode>()) { |
540 | PrintJoinedDocs(tuple_doc->elements, ", " ); |
541 | } else { |
542 | PrintDoc(doc->lhs); |
543 | } |
544 | |
545 | if (doc->annotation) { |
546 | output_ << ": " ; |
547 | PrintDoc(doc->annotation.value()); |
548 | } |
549 | if (doc->rhs) { |
550 | output_ << " = " ; |
551 | PrintDoc(doc->rhs.value()); |
552 | } |
553 | MaybePrintCommentInline(doc); |
554 | } |
555 | |
556 | void PythonDocPrinter::PrintTypedDoc(const IfDoc& doc) { |
557 | MaybePrintCommenMultiLines(doc, true); |
558 | output_ << "if " ; |
559 | PrintDoc(doc->predicate); |
560 | output_ << ":" ; |
561 | |
562 | PrintIndentedBlock(doc->then_branch); |
563 | |
564 | if (!doc->else_branch.empty()) { |
565 | NewLine(); |
566 | output_ << "else:" ; |
567 | PrintIndentedBlock(doc->else_branch); |
568 | } |
569 | } |
570 | |
571 | void PythonDocPrinter::PrintTypedDoc(const WhileDoc& doc) { |
572 | MaybePrintCommenMultiLines(doc, true); |
573 | output_ << "while " ; |
574 | PrintDoc(doc->predicate); |
575 | output_ << ":" ; |
576 | |
577 | PrintIndentedBlock(doc->body); |
578 | } |
579 | |
580 | void PythonDocPrinter::PrintTypedDoc(const ForDoc& doc) { |
581 | MaybePrintCommenMultiLines(doc, true); |
582 | output_ << "for " ; |
583 | if (const auto* tuple = doc->lhs.as<TupleDocNode>()) { |
584 | if (tuple->elements.size() == 1) { |
585 | PrintDoc(tuple->elements[0]); |
586 | output_ << "," ; |
587 | } else { |
588 | PrintJoinedDocs(tuple->elements, ", " ); |
589 | } |
590 | } else { |
591 | PrintDoc(doc->lhs); |
592 | } |
593 | output_ << " in " ; |
594 | PrintDoc(doc->rhs); |
595 | output_ << ":" ; |
596 | |
597 | PrintIndentedBlock(doc->body); |
598 | } |
599 | |
600 | void PythonDocPrinter::PrintTypedDoc(const ScopeDoc& doc) { |
601 | MaybePrintCommenMultiLines(doc, true); |
602 | output_ << "with " ; |
603 | PrintDoc(doc->rhs); |
604 | if (doc->lhs != nullptr) { |
605 | output_ << " as " ; |
606 | PrintDoc(doc->lhs.value()); |
607 | } |
608 | output_ << ":" ; |
609 | |
610 | PrintIndentedBlock(doc->body); |
611 | } |
612 | |
613 | void PythonDocPrinter::PrintTypedDoc(const ExprStmtDoc& doc) { |
614 | PrintDoc(doc->expr); |
615 | MaybePrintCommentInline(doc); |
616 | } |
617 | |
618 | void PythonDocPrinter::PrintTypedDoc(const AssertDoc& doc) { |
619 | output_ << "assert " ; |
620 | PrintDoc(doc->test); |
621 | if (doc->msg.defined()) { |
622 | output_ << ", " ; |
623 | PrintDoc(doc->msg.value()); |
624 | } |
625 | MaybePrintCommentInline(doc); |
626 | } |
627 | |
628 | void PythonDocPrinter::PrintTypedDoc(const ReturnDoc& doc) { |
629 | output_ << "return " ; |
630 | PrintDoc(doc->value); |
631 | MaybePrintCommentInline(doc); |
632 | } |
633 | |
634 | void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) { |
635 | for (const AssignDoc& arg_doc : doc->args) { |
636 | ICHECK(arg_doc->comment == nullptr) << "Function arg cannot have comment attached to them." ; |
637 | } |
638 | |
639 | PrintDecorators(doc->decorators); |
640 | |
641 | output_ << "def " ; |
642 | PrintDoc(doc->name); |
643 | |
644 | output_ << "(" ; |
645 | PrintJoinedDocs(doc->args, ", " ); |
646 | output_ << ")" ; |
647 | |
648 | if (doc->return_type.defined()) { |
649 | output_ << " -> " ; |
650 | PrintDoc(doc->return_type.value()); |
651 | } |
652 | |
653 | output_ << ":" ; |
654 | |
655 | if (doc->comment.defined()) { |
656 | PrintBlockComment(doc->comment.value()); |
657 | } |
658 | PrintIndentedBlock(doc->body); |
659 | NewLineWithoutIndent(); |
660 | } |
661 | |
662 | void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { |
663 | PrintDecorators(doc->decorators); |
664 | |
665 | output_ << "class " ; |
666 | PrintDoc(doc->name); |
667 | output_ << ":" ; |
668 | |
669 | if (doc->comment.defined()) { |
670 | PrintBlockComment(doc->comment.value()); |
671 | } |
672 | PrintIndentedBlock(doc->body); |
673 | NewLineWithoutIndent(); |
674 | } |
675 | |
676 | void PythonDocPrinter::(const CommentDoc& doc) { |
677 | if (doc->comment.defined()) { |
678 | MaybePrintCommenMultiLines(doc, false); |
679 | } |
680 | } |
681 | |
682 | void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { |
683 | if (doc->comment.defined() && !doc->comment.value().empty()) { |
684 | PrintDocString(doc->comment.value()); |
685 | } |
686 | } |
687 | |
688 | String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { |
689 | if (cfg->num_context_lines < 0) { |
690 | cfg->num_context_lines = std::numeric_limits<int32_t>::max(); |
691 | } |
692 | PythonDocPrinter printer(cfg); |
693 | printer.Append(doc, cfg); |
694 | std::string result = printer.GetString(); |
695 | int last_space = result.size(); |
696 | while (last_space > 0 && std::isspace(result[last_space - 1])) { |
697 | last_space--; |
698 | } |
699 | return result.substr(0, last_space); |
700 | } |
701 | |
// Register DocToPythonScript in the global function registry under the name
// "script.printer.DocToPythonScript".
TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript);
703 | |
704 | } // namespace printer |
705 | } // namespace script |
706 | } // namespace tvm |
707 | |