1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file text_printer.h |
22 | * \brief Printer to print out the unified IR text format |
23 | * that can be parsed by a parser. |
24 | */ |
25 | |
26 | #ifndef TVM_RELAY_PRINTER_TEXT_PRINTER_H_ |
27 | #define TVM_RELAY_PRINTER_TEXT_PRINTER_H_ |
28 | |
29 | #include <tvm/ir/module.h> |
30 | #include <tvm/ir/type_functor.h> |
31 | #include <tvm/relay/expr_functor.h> |
32 | #include <tvm/relay/pattern_functor.h> |
33 | #include <tvm/tir/expr_functor.h> |
34 | #include <tvm/tir/function.h> |
35 | #include <tvm/tir/op.h> |
36 | #include <tvm/tir/stmt_functor.h> |
37 | #include <tvm/tir/var.h> |
38 | |
39 | #include <string> |
40 | #include <unordered_map> |
41 | #include <unordered_set> |
42 | #include <vector> |
43 | |
44 | #include "../../ir/attr_functor.h" |
45 | #include "../analysis/dependency_graph.h" |
46 | #include "doc.h" |
47 | #include "meta_data.h" |
48 | |
49 | namespace tvm { |
50 | namespace relay { |
51 | |
52 | class TextPrinter; |
53 | |
54 | class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>, |
55 | public PatternFunctor<Doc(const Pattern&)>, |
56 | public TypeFunctor<Doc(const Type&)>, |
57 | public AttrFunctor<Doc(const ObjectRef&)> { |
58 | public: |
59 | explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, |
60 | runtime::TypedPackedFunc<std::string(ObjectRef)> annotate) |
61 | : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} |
62 | Doc VisitExpr(const Expr& expr) override; |
63 | virtual Doc VisitLeaf(const Expr& expr); |
64 | virtual bool CheckVisited(const Expr& expr); |
65 | |
66 | /*! |
67 | * \brief Print additional info about expr in comment. |
68 | * \param expr The expression. |
69 | */ |
70 | Doc PrintOptionalInfo(const Expr& expr); |
71 | // indent a new body |
72 | Doc PrintBody(const ObjectRef& node, int indent = 2); |
73 | // create a new scope by creating a new printer object. This allows temp var |
74 | // numbers to be reused and prevents hoisted vars from escaping too far |
75 | Doc PrintScope(const ObjectRef& node); |
76 | Doc PrintFinal(const ObjectRef& node); |
77 | |
78 | /*! |
79 | * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence |
80 | * of key=value entries, if any. |
81 | */ |
82 | void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key); |
83 | |
84 | /*! |
85 | * \brief Returns \p attrs printed as a sequence of key=value entries, if any. |
86 | * This is used for call attributes. |
87 | */ |
88 | std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op); |
89 | |
90 | /*! |
91 | * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any. |
92 | * This is used for function definition attributes. |
93 | */ |
94 | std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs); |
95 | std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs); |
96 | |
97 | /*! |
98 | * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta |
99 | * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag. |
100 | */ |
101 | Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false); |
102 | |
103 | /*! |
104 | * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces. |
105 | */ |
106 | Doc PrintAttrsAsAttributeValue(const Attrs& attrs); |
107 | |
108 | /*! |
109 | * \brief Returns \p map printed as a self-contained value, ie wrapped in braces. |
110 | */ |
111 | Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map); |
112 | |
113 | Doc PrintSpan(const Span& span); |
114 | |
115 | Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); |
116 | |
117 | Doc TempVar(int n); |
118 | Doc AllocTemp(); |
119 | /*! |
120 | * \brief get a unique name with the corresponding prefix |
121 | * \param prefix The prefix of the name |
122 | * \return The returned name. |
123 | */ |
124 | Doc GetUniqueName(const std::string& prefix); |
125 | Doc Print(Kind k); |
126 | /*! |
127 | * \brief Allocate name to a type variable. |
128 | * \param var The input type variable. |
129 | * \return The corresponding name. |
130 | */ |
131 | Doc AllocTypeVar(const TypeVar& var); |
132 | /*! |
133 | * \brief Allocate name to a variable. |
134 | * \param var The input variable. |
135 | * \return The corresponding name. |
136 | */ |
137 | Doc AllocVar(const Var& var); |
138 | bool IsUnique(const Expr& expr); |
139 | bool AlwaysInline(const Expr& expr); |
140 | |
141 | Doc PrintFunc(const Doc& prefix, const relay::Function& fn); |
142 | Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func); |
143 | Doc PrintMod(const IRModule& mod); |
144 | |
145 | //------------------------------------ |
146 | // Overload of Expr printing functions |
147 | //------------------------------------ |
148 | Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true); |
149 | // Should only be triggered when op is a free variable being visited for the |
150 | // first time. |
151 | Doc VisitExpr_(const VarNode* op) final; |
152 | Doc VisitExpr_(const ConstantNode* op) final; |
153 | Doc VisitExpr_(const TupleNode* op) final; |
154 | Doc VisitExpr_(const TupleGetItemNode* op) final; |
155 | Doc VisitExpr_(const IfNode* op) final; |
156 | Doc VisitExpr_(const LetNode* op) final; |
157 | Doc VisitExpr_(const FunctionNode* op) final; |
158 | Doc VisitExpr_(const GlobalVarNode* op) final; |
159 | Doc VisitExpr_(const OpNode* op) final; |
160 | Doc VisitExpr_(const CallNode* op) final; |
161 | Doc VisitExpr_(const RefCreateNode* op) final; |
162 | Doc VisitExpr_(const RefReadNode* op) final; |
163 | Doc VisitExpr_(const RefWriteNode* op) final; |
164 | Doc VisitExpr_(const MatchNode* op) final; |
165 | Doc PrintPattern(const Pattern& pattern, bool meta); |
166 | Doc VisitPattern_(const PatternConstructorNode* p) final; |
167 | Doc VisitPattern_(const PatternTupleNode* pt) final; |
168 | Doc VisitPattern_(const PatternWildcardNode* pw) final; |
169 | Doc VisitPattern_(const PatternVarNode* pv) final; |
170 | Doc VisitExpr_(const ConstructorNode* n) final; |
171 | //------------------------------------ |
172 | // Overload of Type printing functions |
173 | //------------------------------------ |
174 | Doc PrintType(const Type& type, bool meta); |
175 | Doc VisitTypeDefault_(const Object* node) final; |
176 | Doc VisitType_(const TypeVarNode* node) final; |
177 | Doc VisitType_(const GlobalTypeVarNode* node) final; |
178 | Doc VisitType_(const TypeCallNode* node) final; |
179 | Doc PrintDType(DataType dtype); |
180 | Doc VisitType_(const TensorTypeNode* node) final; |
181 | Doc VisitType_(const TupleTypeNode* node) final; |
182 | Doc VisitType_(const FuncTypeNode* node) final; |
183 | Doc VisitType_(const RelayRefTypeNode* node) final; |
184 | Doc VisitType_(const TypeDataNode* node) final; |
185 | //------------------------------------ |
186 | // Overload of Attr printing functions |
187 | //------------------------------------ |
188 | Doc VisitAttrDefault_(const Object* op) final; |
189 | Doc VisitAttr_(const ArrayNode* op) final; |
190 | Doc VisitAttr_(const tir::IntImmNode* op) final; |
191 | Doc VisitAttr_(const tir::FloatImmNode* op) final; |
192 | Doc VisitAttr_(const tir::StringImmNode* op) final; |
193 | |
194 | private: |
195 | /*! \brief Whether to print meta data. */ |
196 | bool show_meta_data_; |
197 | /*! \brief additional comment function */ |
198 | runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_; |
199 | /*! \brief Stack of docs to implement scoped GNFing. */ |
200 | std::vector<Doc> doc_stack_{}; |
201 | /*! \brief Set for introduced vars */ |
202 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_; |
203 | /*! \brief Set for exprs have been printed optional information */ |
204 | std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> opt_info_memo_; |
205 | /*! \brief Map for result and memo_ diffs for visited expression */ |
206 | std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_; |
207 | /*! \brief Map from Expr to Doc */ |
208 | std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_; |
209 | /*! \brief Map from Type to Doc */ |
210 | std::unordered_map<Type, Doc, ObjectPtrHash, ObjectPtrEqual> memo_type_; |
211 | /*! \brief Map from Type to Doc */ |
212 | std::unordered_map<Pattern, Doc, ObjectPtrHash, ObjectPtrEqual> memo_pattern_; |
213 | /*! \brief name allocation map */ |
214 | std::unordered_map<std::string, int> name_alloc_map_; |
215 | /*! \brief meta data context */ |
216 | TextMetaDataContext* meta_; |
217 | /*! \brief counter of temporary variable */ |
218 | size_t temp_var_counter_{0}; |
219 | /*! \brief whether the printer is currently in an ADT definition */ |
220 | bool in_adt_def_; |
221 | /*! \brief arena for dependency graph */ |
222 | support::Arena arena_; |
223 | /*! \brief dependency graph of the expr */ |
224 | DependencyGraph dg_; |
225 | class AttrPrinter; |
226 | friend class AttrPrinter; |
227 | friend class tvm::relay::TextPrinter; |
228 | }; |
229 | |
// Lets the TIR declarations below (StmtExprVisitor, StmtFunctor, IntImmNode, ...) be
// written without the tir:: prefix.
// NOTE(review): this directive sits in a header, so it also applies inside tvm::relay
// for every file that includes it. Names that relay declares itself (VarNode, LetNode,
// CallNode are used unqualified by RelayTextPrinter above) still win unqualified
// lookup, so the TIR versions of those must be spelled tir::VarNode etc.
using namespace ::tvm::tir;
231 | |
232 | /*! |
233 | * \brief Meta node collector |
234 | * If we decide to put some node into meta, then all the sub-nodes inside |
235 | * it need to be put in meta as well, since when parsing we need to know |
236 | * whether two refs are the same |
237 | */ |
238 | class MetaCollector : public StmtExprVisitor { |
239 | public: |
240 | explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {} |
241 | |
242 | void Collect(const ObjectRef& n) { |
243 | // these nodes can be print directly(StringLiteral or use identifier to identify) |
244 | if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() || |
245 | n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) { |
246 | return; |
247 | } |
248 | if (n->IsInstance<StmtNode>()) { |
249 | VisitStmt(Downcast<Stmt>(n)); |
250 | } else if (n->IsInstance<PrimExprNode>()) { |
251 | VisitExpr(Downcast<PrimExpr>(n)); |
252 | } |
253 | } |
254 | |
255 | void VisitStmt(const Stmt& n) override { |
256 | meta_->GetMetaNode(n); |
257 | StmtVisitor::VisitStmt(n); |
258 | } |
259 | |
260 | void VisitExpr(const PrimExpr& n) override { |
261 | meta_->GetMetaNode(n); |
262 | ExprVisitor::VisitExpr(n); |
263 | } |
264 | |
265 | private: |
266 | TextMetaDataContext* meta_; |
267 | }; |
268 | |
269 | class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>, |
270 | public tir::ExprFunctor<Doc(const PrimExpr&)>, |
271 | public TypeFunctor<Doc(const Type&)> { |
272 | public: |
273 | explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) |
274 | : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} |
275 | |
276 | /*! \brief Output a newline */ |
277 | virtual Doc NewLine(); |
278 | |
279 | /*! \brief Print the node */ |
280 | Doc Print(const ObjectRef& node); |
281 | |
282 | /*! \brief Place into `s` the name used in the preceding Print call for `v`. |
283 | * \param v Var instance to check. Must point to a VarNode visited by Print. |
284 | * \param s String to receive the name. |
285 | * \return true when a name re-mapping was found. |
286 | */ |
287 | bool GetVarName(::tvm::tir::Var v, std::string* s); |
288 | |
289 | protected: |
290 | Doc VisitExpr_(const IntImmNode* op) override; |
291 | Doc VisitExpr_(const FloatImmNode* op) override; |
292 | Doc VisitExpr_(const StringImmNode* op) override; |
293 | Doc VisitExpr_(const CastNode* op) override; |
294 | Doc VisitExpr_(const tir::VarNode* op) override; |
295 | Doc VisitExpr_(const AddNode* op) override; |
296 | Doc VisitExpr_(const SubNode* op) override; |
297 | Doc VisitExpr_(const MulNode* op) override; |
298 | Doc VisitExpr_(const DivNode* op) override; |
299 | Doc VisitExpr_(const ModNode* op) override; |
300 | Doc VisitExpr_(const FloorDivNode* op) override; |
301 | Doc VisitExpr_(const FloorModNode* op) override; |
302 | Doc VisitExpr_(const MinNode* op) override; |
303 | Doc VisitExpr_(const MaxNode* op) override; |
304 | Doc VisitExpr_(const EQNode* op) override; |
305 | Doc VisitExpr_(const NENode* op) override; |
306 | Doc VisitExpr_(const LTNode* op) override; |
307 | Doc VisitExpr_(const LENode* op) override; |
308 | Doc VisitExpr_(const GTNode* op) override; |
309 | Doc VisitExpr_(const GENode* op) override; |
310 | Doc VisitExpr_(const AndNode* op) override; |
311 | Doc VisitExpr_(const OrNode* op) override; |
312 | Doc VisitExpr_(const NotNode* op) override; |
313 | Doc VisitExpr_(const SelectNode* op) override; |
314 | Doc VisitExpr_(const BufferLoadNode* op) override; |
315 | Doc VisitExpr_(const ProducerLoadNode* op) override; |
316 | Doc VisitExpr_(const LoadNode* op) override; |
317 | Doc VisitExpr_(const RampNode* op) override; |
318 | Doc VisitExpr_(const BroadcastNode* op) override; |
319 | Doc VisitExpr_(const tir::LetNode* op) override; |
320 | Doc VisitExpr_(const tir::CallNode* op) override; |
321 | Doc VisitExpr_(const ShuffleNode* op) override; |
322 | Doc VisitExpr_(const ReduceNode* op) override; |
323 | Doc VisitExprDefault_(const Object* op) override; |
324 | |
325 | Doc VisitStmt_(const LetStmtNode* op) override; |
326 | Doc VisitStmt_(const AttrStmtNode* op) override; |
327 | Doc VisitStmt_(const AssertStmtNode* op) override; |
328 | Doc VisitStmt_(const StoreNode* op) override; |
329 | Doc VisitStmt_(const BufferStoreNode* op) override; |
330 | Doc VisitStmt_(const ProducerStoreNode* op) override; |
331 | Doc VisitStmt_(const BufferRealizeNode* op) override; |
332 | Doc VisitStmt_(const ProducerRealizeNode* op) override; |
333 | Doc VisitStmt_(const AllocateNode* op) override; |
334 | Doc VisitStmt_(const AllocateConstNode* op) override; |
335 | Doc VisitStmt_(const DeclBufferNode* op) override; |
336 | Doc VisitStmt_(const IfThenElseNode* op) override; |
337 | Doc VisitStmt_(const SeqStmtNode* op) override; |
338 | Doc VisitStmt_(const EvaluateNode* op) override; |
339 | Doc VisitStmt_(const ForNode* op) override; |
340 | Doc VisitStmt_(const WhileNode* op) override; |
341 | Doc VisitStmt_(const PrefetchNode* op) override; |
342 | Doc VisitStmt_(const BlockRealizeNode* op) override; |
343 | Doc VisitStmtDefault_(const Object* op) override; |
344 | |
345 | private: |
346 | /*! \brief whether show meta data */ |
347 | bool show_meta_; |
348 | /*! \brief meta data context */ |
349 | TextMetaDataContext* meta_; |
350 | /*! \brief meta collector */ |
351 | MetaCollector meta_collector_; |
352 | /*! \brief Map from Var to Doc */ |
353 | std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_; |
354 | /*! \brief Map from Buffer to Doc */ |
355 | std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_; |
356 | /*! \brief Map from Buffer to Doc */ |
357 | std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_; |
358 | /*! \brief name allocation map */ |
359 | std::unordered_map<std::string, int> name_alloc_map_; |
360 | |
361 | friend class TextPrinter; |
362 | |
363 | Doc VisitType_(const PrimTypeNode* node) override; |
364 | Doc VisitType_(const PointerTypeNode* node) override; |
365 | Doc VisitType_(const TupleTypeNode* node) override; |
366 | |
367 | Doc PrintIRModule(const IRModule& module); |
368 | Doc PrintPrimFunc(const PrimFunc& primFunc); |
369 | Doc PrintArray(const ArrayNode* op); |
370 | Doc PrintIterVar(const IterVarNode* op); |
371 | Doc PrintRange(const RangeNode* op); |
372 | Doc PrintBuffer(const BufferNode* op); |
373 | Doc PrintProducer(const DataProducerNode* op); |
374 | Doc BufferNode2Doc(const BufferNode* op, Doc doc); |
375 | Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc); |
376 | Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } |
377 | Doc PrintBufferRegion(const BufferRegionNode* op); |
378 | |
379 | /*! |
380 | * \brief special method to print out data type |
381 | * \param dtype The data type |
382 | */ |
383 | static Doc PrintDType(DataType dtype); |
384 | /*! |
385 | * \brief special method to print out const scalar |
386 | * \param dtype The data type |
387 | * \param data The pointer to hold the data. |
388 | */ |
389 | template <typename T> |
390 | static Doc PrintConstScalar(DataType dtype, const T& data); |
391 | Doc GetUniqueName(std::string prefix); |
392 | Doc AllocVar(const tir::Var& var); |
393 | Doc AllocConst(const AllocateConst& var); |
394 | Doc AllocBuf(const Buffer& buffer); |
395 | Doc AllocProducer(const DataProducer& buffer); |
396 | /*! |
397 | * \brief special method to render vectors of docs with a separator |
398 | * \param vec vector of docs |
399 | * \param sep separator |
400 | */ |
401 | static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep); |
402 | Doc PrintBody(const Stmt& body, bool indent = true); |
403 | }; |
404 | |
/*!
 * \brief Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the name and signature, this renders \p mod as
 * TVMScript text with per-statement diagnostics attached - confirm against
 * the definition before relying on it.
 *
 * \param mod The object to render.
 * \param tir_prefix NOTE(review): presumably the prefix used for tir names in the output - confirm.
 * \param show_meta NOTE(review): presumably whether the meta data section is shown - confirm.
 * \param annotate Callback that maps a Stmt to a std::string.
 * \return The resulting text as a String.
 */
String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
                                 runtime::TypedPackedFunc<std::string(Stmt)> annotate);
407 | |
408 | class TextPrinter { |
409 | public: |
410 | explicit TextPrinter(bool show_meta_data, |
411 | const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate, |
412 | bool show_warning = true) |
413 | : show_meta_data_(show_meta_data), |
414 | show_warning_(show_warning), |
415 | annotate_(annotate), |
416 | relay_text_printer_(show_meta_data, &meta_, annotate), |
417 | tir_text_printer_(show_meta_data, &meta_) {} |
418 | |
419 | /*! \brief whether show meta data */ |
420 | bool show_meta_data_; |
421 | |
422 | /*! \brief whether show the meta data warning message */ |
423 | bool show_warning_; |
424 | |
425 | /*! \brief meta data context */ |
426 | TextMetaDataContext meta_; |
427 | /*! \brief additional comment function */ |
428 | runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_; |
429 | /*! \brief Relay Text Printer */ |
430 | relay::RelayTextPrinter relay_text_printer_; |
431 | /*! \brief TIR Text Printer */ |
432 | TIRTextPrinter tir_text_printer_; |
433 | |
434 | bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); } |
435 | |
436 | Doc PrintFinal(const ObjectRef& node) { |
437 | Doc doc; |
438 | if (node.defined() && node->IsInstance<IRModuleNode>()) { |
439 | doc << PrintMod(Downcast<IRModule>(node)); |
440 | } else if (node.defined() && |
441 | (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() || |
442 | node->IsInstance<tir::StmtNode>())) { |
443 | doc << tir_text_printer_.Print(node); |
444 | } else { |
445 | doc << relay_text_printer_.PrintFinal(node); |
446 | } |
447 | if (!meta_.empty()) { |
448 | doc << Doc::NewLine(); |
449 | if (show_meta_data_) { |
450 | doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection(); |
451 | } else if (show_warning_) { |
452 | doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine() |
453 | << " * If you would like to see the full metadata section you can set the " |
454 | << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine() |
455 | << " */" ; |
456 | } |
457 | } |
458 | return doc; |
459 | } |
460 | |
461 | Doc PrintMod(const IRModule& mod); |
462 | }; |
463 | } // namespace relay |
464 | } // namespace tvm |
465 | |
466 | #endif // TVM_RELAY_PRINTER_TEXT_PRINTER_H_ |
467 | |