1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file text_printer.h |
22 | * \brief Printer to print out the unified IR text format |
23 | * that can be parsed by a parser. |
24 | */ |
25 | |
26 | #ifndef TVM_PRINTER_TEXT_PRINTER_H_ |
27 | #define TVM_PRINTER_TEXT_PRINTER_H_ |
28 | |
29 | #include <tvm/ir/module.h> |
30 | #include <tvm/ir/type_functor.h> |
31 | #include <tvm/relay/expr_functor.h> |
32 | #include <tvm/relay/pattern_functor.h> |
33 | #include <tvm/tir/expr_functor.h> |
34 | #include <tvm/tir/function.h> |
35 | #include <tvm/tir/op.h> |
36 | #include <tvm/tir/stmt_functor.h> |
37 | #include <tvm/tir/var.h> |
38 | |
39 | #include <string> |
40 | #include <unordered_map> |
41 | #include <unordered_set> |
42 | #include <vector> |
43 | |
44 | #include "../ir/attr_functor.h" |
45 | #include "../relay/analysis/dependency_graph.h" |
46 | #include "doc.h" |
47 | #include "meta_data.h" |
48 | #include "text_printer.h" |
49 | |
// Forward declaration of tvm::TextPrinter. Both RelayTextPrinter and
// TIRTextPrinter below declare it as a friend, while the class itself is
// only defined at the end of this header.
namespace tvm {
class TextPrinter;
}  // namespace tvm
53 | |
54 | namespace tvm { |
55 | namespace relay { |
56 | |
57 | class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>, |
58 | public PatternFunctor<Doc(const Pattern&)>, |
59 | public TypeFunctor<Doc(const Type&)>, |
60 | public AttrFunctor<Doc(const ObjectRef&)> { |
61 | public: |
62 | explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, |
63 | runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) |
64 | : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} |
65 | Doc VisitExpr(const Expr& expr) override; |
66 | virtual Doc VisitLeaf(const Expr& expr); |
67 | virtual bool CheckVisited(const Expr& expr); |
68 | |
69 | /*! |
70 | * \brief Print additional info about expr in comment. |
71 | * \param expr The expression. |
72 | */ |
73 | Doc PrintOptionalInfo(const Expr& expr); |
74 | // indent a new body |
75 | Doc PrintBody(const ObjectRef& node, int indent = 2); |
76 | // create a new scope by creating a new printer object. This allows temp var |
77 | // numbers to be reused and prevents hoisted vars from escaping too far |
78 | Doc PrintScope(const ObjectRef& node); |
79 | Doc PrintFinal(const ObjectRef& node); |
80 | |
81 | /*! |
82 | * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence |
83 | * of key=value entries, if any. |
84 | */ |
85 | void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key); |
86 | |
87 | /*! |
88 | * \brief Returns \p attrs printed as a sequence of key=value entries, if any. |
89 | * This is used for call attributes. |
90 | */ |
91 | std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op); |
92 | |
93 | /*! |
94 | * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any. |
95 | * This is used for function definition attributes. |
96 | */ |
97 | std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs); |
98 | std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs); |
99 | |
100 | /*! |
101 | * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta |
102 | * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag. |
103 | */ |
104 | Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false); |
105 | |
106 | /*! |
107 | * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces. |
108 | */ |
109 | Doc PrintAttrsAsAttributeValue(const Attrs& attrs); |
110 | |
111 | /*! |
112 | * \brief Returns \p map printed as a self-contained value, ie wrapped in braces. |
113 | */ |
114 | Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map); |
115 | |
116 | Doc PrintSpan(const Span& span); |
117 | |
118 | Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); |
119 | |
120 | Doc TempVar(int n); |
121 | Doc AllocTemp(); |
122 | /*! |
123 | * \brief get a unique name with the corresponding prefix |
124 | * \param prefix The prefix of the name |
125 | * \return The returned name. |
126 | */ |
127 | Doc GetUniqueName(const std::string& prefix); |
128 | Doc Print(Kind k); |
129 | /*! |
130 | * \brief Allocate name to a type variable. |
131 | * \param var The input type variable. |
132 | * \return The corresponding name. |
133 | */ |
134 | Doc AllocTypeVar(const TypeVar& var); |
135 | /*! |
136 | * \brief Allocate name to a variable. |
137 | * \param var The input variable. |
138 | * \return The corresponding name. |
139 | */ |
140 | Doc AllocVar(const Var& var); |
141 | bool IsUnique(const Expr& expr); |
142 | bool AlwaysInline(const Expr& expr); |
143 | |
144 | Doc PrintFunc(const Doc& prefix, const relay::Function& fn); |
145 | Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func); |
146 | Doc PrintMod(const IRModule& mod); |
147 | |
148 | //------------------------------------ |
149 | // Overload of Expr printing functions |
150 | //------------------------------------ |
151 | Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true); |
152 | // Should only be triggered when op is a free variable being visited for the |
153 | // first time. |
154 | Doc VisitExpr_(const VarNode* op) final; |
155 | Doc VisitExpr_(const ConstantNode* op) final; |
156 | Doc VisitExpr_(const TupleNode* op) final; |
157 | Doc VisitExpr_(const TupleGetItemNode* op) final; |
158 | Doc VisitExpr_(const IfNode* op) final; |
159 | Doc VisitExpr_(const LetNode* op) final; |
160 | Doc VisitExpr_(const FunctionNode* op) final; |
161 | Doc VisitExpr_(const GlobalVarNode* op) final; |
162 | Doc VisitExpr_(const OpNode* op) final; |
163 | Doc VisitExpr_(const CallNode* op) final; |
164 | Doc VisitExpr_(const RefCreateNode* op) final; |
165 | Doc VisitExpr_(const RefReadNode* op) final; |
166 | Doc VisitExpr_(const RefWriteNode* op) final; |
167 | Doc VisitExpr_(const MatchNode* op) final; |
168 | Doc PrintPattern(const Pattern& pattern, bool meta); |
169 | Doc VisitPattern_(const PatternConstructorNode* p) final; |
170 | Doc VisitPattern_(const PatternTupleNode* pt) final; |
171 | Doc VisitPattern_(const PatternWildcardNode* pw) final; |
172 | Doc VisitPattern_(const PatternVarNode* pv) final; |
173 | Doc VisitExpr_(const ConstructorNode* n) final; |
174 | //------------------------------------ |
175 | // Overload of Type printing functions |
176 | //------------------------------------ |
177 | Doc PrintType(const Type& type, bool meta); |
178 | Doc VisitTypeDefault_(const Object* node) final; |
179 | Doc VisitType_(const TypeVarNode* node) final; |
180 | Doc VisitType_(const GlobalTypeVarNode* node) final; |
181 | Doc VisitType_(const TypeCallNode* node) final; |
182 | Doc PrintDType(DataType dtype); |
183 | Doc VisitType_(const TensorTypeNode* node) final; |
184 | Doc VisitType_(const TupleTypeNode* node) final; |
185 | Doc VisitType_(const FuncTypeNode* node) final; |
186 | Doc VisitType_(const RelayRefTypeNode* node) final; |
187 | Doc VisitType_(const TypeDataNode* node) final; |
188 | //------------------------------------ |
189 | // Overload of Attr printing functions |
190 | //------------------------------------ |
191 | Doc VisitAttrDefault_(const Object* op) final; |
192 | Doc VisitAttr_(const ArrayNode* op) final; |
193 | Doc VisitAttr_(const tir::IntImmNode* op) final; |
194 | Doc VisitAttr_(const tir::FloatImmNode* op) final; |
195 | Doc VisitAttr_(const tir::StringImmNode* op) final; |
196 | |
197 | private: |
198 | /*! \brief Whether to print meta data. */ |
199 | bool show_meta_data_; |
200 | /*! \brief additional comment function */ |
201 | runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_; |
202 | /*! \brief Stack of docs to implement scoped GNFing. */ |
203 | std::vector<Doc> doc_stack_{}; |
204 | /*! \brief Set for introduced vars */ |
205 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_; |
206 | /*! \brief Set for exprs have been printed optional information */ |
207 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> opt_info_memo_; |
208 | /*! \brief Map for result and memo_ diffs for visited expression */ |
209 | std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_; |
210 | /*! \brief Map from Expr to Doc */ |
211 | std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_; |
212 | /*! \brief Map from Type to Doc */ |
213 | std::unordered_map<Type, Doc, ObjectPtrHash, ObjectPtrEqual> memo_type_; |
214 | /*! \brief Map from Type to Doc */ |
215 | std::unordered_map<Pattern, Doc, ObjectPtrHash, ObjectPtrEqual> memo_pattern_; |
216 | /*! \brief name allocation map */ |
217 | std::unordered_map<std::string, int> name_alloc_map_; |
218 | /*! \brief meta data context */ |
219 | TextMetaDataContext* meta_; |
220 | /*! \brief counter of temporary variable */ |
221 | size_t temp_var_counter_{0}; |
222 | /*! \brief whether the printer is currently in an ADT definition */ |
223 | bool in_adt_def_; |
224 | /*! \brief arena for dependency graph */ |
225 | support::Arena arena_; |
226 | /*! \brief dependency graph of the expr */ |
227 | DependencyGraph dg_; |
228 | class AttrPrinter; |
229 | friend class AttrPrinter; |
230 | friend class tvm::TextPrinter; |
231 | }; |
232 | |
233 | } // namespace relay |
234 | } // namespace tvm |
235 | |
236 | namespace tvm { |
237 | namespace tir { |
238 | |
239 | /*! |
240 | * \brief Meta node collector |
241 | * If we decide to put some node into meta, then all the sub-nodes inside |
242 | * it need to be put in meta as well, since when parsing we need to know |
243 | * whether two refs are the same |
244 | */ |
245 | class MetaCollector : public StmtExprVisitor { |
246 | public: |
247 | explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {} |
248 | |
249 | void Collect(const ObjectRef& n) { |
250 | // these nodes can be print directly(StringLiteral or use identifier to identify) |
251 | if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() || |
252 | n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) { |
253 | return; |
254 | } |
255 | if (n->IsInstance<StmtNode>()) { |
256 | VisitStmt(Downcast<Stmt>(n)); |
257 | } else if (n->IsInstance<PrimExprNode>()) { |
258 | VisitExpr(Downcast<PrimExpr>(n)); |
259 | } |
260 | } |
261 | |
262 | void VisitStmt(const Stmt& n) override { |
263 | meta_->GetMetaNode(n); |
264 | StmtVisitor::VisitStmt(n); |
265 | } |
266 | |
267 | void VisitExpr(const PrimExpr& n) override { |
268 | meta_->GetMetaNode(n); |
269 | ExprVisitor::VisitExpr(n); |
270 | } |
271 | |
272 | private: |
273 | TextMetaDataContext* meta_; |
274 | }; |
275 | |
276 | class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>, |
277 | public ExprFunctor<Doc(const PrimExpr&)>, |
278 | public TypeFunctor<Doc(const Type&)> { |
279 | public: |
280 | explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) |
281 | : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} |
282 | |
283 | /*! \brief Output a newline */ |
284 | virtual Doc NewLine(); |
285 | |
286 | /*! \brief Print the node */ |
287 | Doc Print(const ObjectRef& node); |
288 | |
289 | /*! \brief Place into `s` the name used in the preceding Print call for `v`. |
290 | * \param v Var instance to check. Must point to a VarNode visited by Print. |
291 | * \param s String to receive the name. |
292 | * \return true when a name re-mapping was found. |
293 | */ |
294 | bool GetVarName(::tvm::tir::Var v, std::string* s); |
295 | |
296 | protected: |
297 | Doc VisitExpr_(const IntImmNode* op) override; |
298 | Doc VisitExpr_(const FloatImmNode* op) override; |
299 | Doc VisitExpr_(const StringImmNode* op) override; |
300 | Doc VisitExpr_(const CastNode* op) override; |
301 | Doc VisitExpr_(const VarNode* op) override; |
302 | Doc VisitExpr_(const AddNode* op) override; |
303 | Doc VisitExpr_(const SubNode* op) override; |
304 | Doc VisitExpr_(const MulNode* op) override; |
305 | Doc VisitExpr_(const DivNode* op) override; |
306 | Doc VisitExpr_(const ModNode* op) override; |
307 | Doc VisitExpr_(const FloorDivNode* op) override; |
308 | Doc VisitExpr_(const FloorModNode* op) override; |
309 | Doc VisitExpr_(const MinNode* op) override; |
310 | Doc VisitExpr_(const MaxNode* op) override; |
311 | Doc VisitExpr_(const EQNode* op) override; |
312 | Doc VisitExpr_(const NENode* op) override; |
313 | Doc VisitExpr_(const LTNode* op) override; |
314 | Doc VisitExpr_(const LENode* op) override; |
315 | Doc VisitExpr_(const GTNode* op) override; |
316 | Doc VisitExpr_(const GENode* op) override; |
317 | Doc VisitExpr_(const AndNode* op) override; |
318 | Doc VisitExpr_(const OrNode* op) override; |
319 | Doc VisitExpr_(const NotNode* op) override; |
320 | Doc VisitExpr_(const SelectNode* op) override; |
321 | Doc VisitExpr_(const BufferLoadNode* op) override; |
322 | Doc VisitExpr_(const ProducerLoadNode* op) override; |
323 | Doc VisitExpr_(const LoadNode* op) override; |
324 | Doc VisitExpr_(const RampNode* op) override; |
325 | Doc VisitExpr_(const BroadcastNode* op) override; |
326 | Doc VisitExpr_(const LetNode* op) override; |
327 | Doc VisitExpr_(const CallNode* op) override; |
328 | Doc VisitExpr_(const ShuffleNode* op) override; |
329 | Doc VisitExpr_(const ReduceNode* op) override; |
330 | Doc VisitExprDefault_(const Object* op) override; |
331 | |
332 | Doc VisitStmt_(const LetStmtNode* op) override; |
333 | Doc VisitStmt_(const AttrStmtNode* op) override; |
334 | Doc VisitStmt_(const AssertStmtNode* op) override; |
335 | Doc VisitStmt_(const StoreNode* op) override; |
336 | Doc VisitStmt_(const BufferStoreNode* op) override; |
337 | Doc VisitStmt_(const ProducerStoreNode* op) override; |
338 | Doc VisitStmt_(const BufferRealizeNode* op) override; |
339 | Doc VisitStmt_(const ProducerRealizeNode* op) override; |
340 | Doc VisitStmt_(const AllocateNode* op) override; |
341 | Doc VisitStmt_(const AllocateConstNode* op) override; |
342 | Doc VisitStmt_(const DeclBufferNode* op) override; |
343 | Doc VisitStmt_(const IfThenElseNode* op) override; |
344 | Doc VisitStmt_(const SeqStmtNode* op) override; |
345 | Doc VisitStmt_(const EvaluateNode* op) override; |
346 | Doc VisitStmt_(const ForNode* op) override; |
347 | Doc VisitStmt_(const WhileNode* op) override; |
348 | Doc VisitStmt_(const PrefetchNode* op) override; |
349 | Doc VisitStmt_(const BlockRealizeNode* op) override; |
350 | Doc VisitStmtDefault_(const Object* op) override; |
351 | |
352 | private: |
353 | /*! \brief whether show meta data */ |
354 | bool show_meta_; |
355 | /*! \brief meta data context */ |
356 | TextMetaDataContext* meta_; |
357 | /*! \brief meta collector */ |
358 | MetaCollector meta_collector_; |
359 | /*! \brief Map from Var to Doc */ |
360 | std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_; |
361 | /*! \brief Map from Buffer to Doc */ |
362 | std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_; |
363 | /*! \brief Map from Buffer to Doc */ |
364 | std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_; |
365 | /*! \brief name allocation map */ |
366 | std::unordered_map<std::string, int> name_alloc_map_; |
367 | |
368 | friend class tvm::TextPrinter; |
369 | |
370 | Doc VisitType_(const PrimTypeNode* node) override; |
371 | Doc VisitType_(const PointerTypeNode* node) override; |
372 | Doc VisitType_(const TupleTypeNode* node) override; |
373 | |
374 | Doc PrintIRModule(const IRModule& module); |
375 | Doc PrintPrimFunc(const PrimFunc& primFunc); |
376 | Doc PrintArray(const ArrayNode* op); |
377 | Doc PrintIterVar(const IterVarNode* op); |
378 | Doc PrintRange(const RangeNode* op); |
379 | Doc PrintBuffer(const BufferNode* op); |
380 | Doc PrintProducer(const DataProducerNode* op); |
381 | Doc BufferNode2Doc(const BufferNode* op, Doc doc); |
382 | Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc); |
383 | Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } |
384 | Doc PrintBufferRegion(const BufferRegionNode* op); |
385 | |
386 | /*! |
387 | * \brief special method to print out data type |
388 | * \param dtype The data type |
389 | */ |
390 | static Doc PrintDType(DataType dtype); |
391 | /*! |
392 | * \brief special method to print out const scalar |
393 | * \param dtype The data type |
394 | * \param data The pointer to hold the data. |
395 | */ |
396 | template <typename T> |
397 | static Doc PrintConstScalar(DataType dtype, const T& data); |
398 | Doc GetUniqueName(std::string prefix); |
399 | Doc AllocVar(const Var& var); |
400 | Doc AllocConst(const AllocateConst& var); |
401 | Doc AllocBuf(const Buffer& buffer); |
402 | Doc AllocProducer(const DataProducer& buffer); |
403 | /*! |
404 | * \brief special method to render vectors of docs with a separator |
405 | * \param vec vector of docs |
406 | * \param sep separator |
407 | */ |
408 | static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep); |
409 | Doc PrintBody(const Stmt& body, bool indent = true); |
410 | }; |
411 | |
412 | String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T" , bool show_meta = false); |
413 | |
414 | String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, |
415 | runtime::TypedPackedFunc<std::string(Stmt)> annotate); |
416 | |
417 | } // namespace tir |
418 | } // namespace tvm |
419 | |
420 | namespace tvm { |
421 | |
422 | class TextPrinter { |
423 | public: |
424 | explicit TextPrinter(bool show_meta_data, |
425 | const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate, |
426 | bool show_warning = true) |
427 | : show_meta_data_(show_meta_data), |
428 | show_warning_(show_warning), |
429 | annotate_(annotate), |
430 | relay_text_printer_(show_meta_data, &meta_, annotate), |
431 | tir_text_printer_(show_meta_data, &meta_) {} |
432 | |
433 | /*! \brief whether show meta data */ |
434 | bool show_meta_data_; |
435 | |
436 | /*! \brief whether show the meta data warning message */ |
437 | bool show_warning_; |
438 | |
439 | /*! \brief meta data context */ |
440 | TextMetaDataContext meta_; |
441 | /*! \brief additional comment function */ |
442 | runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_; |
443 | /*! \brief Relay Text Printer */ |
444 | relay::RelayTextPrinter relay_text_printer_; |
445 | /*! \brief TIR Text Printer */ |
446 | tir::TIRTextPrinter tir_text_printer_; |
447 | |
448 | bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } |
449 | |
450 | Doc PrintFinal(const ObjectRef& node) { |
451 | Doc doc; |
452 | if (node.defined() && node->IsInstance<IRModuleNode>()) { |
453 | doc << PrintMod(Downcast<IRModule>(node)); |
454 | } else if (node.defined() && |
455 | (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() || |
456 | node->IsInstance<tir::StmtNode>())) { |
457 | doc << tir_text_printer_.Print(node); |
458 | } else { |
459 | doc << relay_text_printer_.PrintFinal(node); |
460 | } |
461 | if (!meta_.empty()) { |
462 | doc << Doc::NewLine(); |
463 | if (show_meta_data_) { |
464 | doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); |
465 | } else if (show_warning_) { |
466 | doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() |
467 | << " * If you would like to see the full metadata section you can set the " |
468 | << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine() |
469 | << " */" ; |
470 | } |
471 | } |
472 | return doc; |
473 | } |
474 | |
475 | Doc PrintMod(const IRModule& mod); |
476 | }; |
477 | } // namespace tvm |
478 | |
479 | #endif // TVM_PRINTER_TEXT_PRINTER_H_ |
480 | |