1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file text_printer.h
22 * \brief Printer to print out the unified IR text format
23 * that can be parsed by a parser.
24 */
25
26#ifndef TVM_RELAY_PRINTER_TEXT_PRINTER_H_
27#define TVM_RELAY_PRINTER_TEXT_PRINTER_H_
28
29#include <tvm/ir/module.h>
30#include <tvm/ir/type_functor.h>
31#include <tvm/relay/expr_functor.h>
32#include <tvm/relay/pattern_functor.h>
33#include <tvm/tir/expr_functor.h>
34#include <tvm/tir/function.h>
35#include <tvm/tir/op.h>
36#include <tvm/tir/stmt_functor.h>
37#include <tvm/tir/var.h>
38
39#include <string>
40#include <unordered_map>
41#include <unordered_set>
42#include <vector>
43
44#include "../../ir/attr_functor.h"
45#include "../analysis/dependency_graph.h"
46#include "doc.h"
47#include "meta_data.h"
48
49namespace tvm {
50namespace relay {
51
52class TextPrinter;
53
54class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
55 public PatternFunctor<Doc(const Pattern&)>,
56 public TypeFunctor<Doc(const Type&)>,
57 public AttrFunctor<Doc(const ObjectRef&)> {
58 public:
59 explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
60 runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
61 : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
62 Doc VisitExpr(const Expr& expr) override;
63 virtual Doc VisitLeaf(const Expr& expr);
64 virtual bool CheckVisited(const Expr& expr);
65
66 /*!
67 * \brief Print additional info about expr in comment.
68 * \param expr The expression.
69 */
70 Doc PrintOptionalInfo(const Expr& expr);
71 // indent a new body
72 Doc PrintBody(const ObjectRef& node, int indent = 2);
73 // create a new scope by creating a new printer object. This allows temp var
74 // numbers to be reused and prevents hoisted vars from escaping too far
75 Doc PrintScope(const ObjectRef& node);
76 Doc PrintFinal(const ObjectRef& node);
77
78 /*!
79 * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence
80 * of key=value entries, if any.
81 */
82 void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key);
83
84 /*!
85 * \brief Returns \p attrs printed as a sequence of key=value entries, if any.
86 * This is used for call attributes.
87 */
88 std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
89
90 /*!
91 * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any.
92 * This is used for function definition attributes.
93 */
94 std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs);
95 std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs);
96
97 /*!
98 * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta
99 * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag.
100 */
101 Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false);
102
103 /*!
104 * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces.
105 */
106 Doc PrintAttrsAsAttributeValue(const Attrs& attrs);
107
108 /*!
109 * \brief Returns \p map printed as a self-contained value, ie wrapped in braces.
110 */
111 Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
112
113 Doc PrintSpan(const Span& span);
114
115 Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
116
117 Doc TempVar(int n);
118 Doc AllocTemp();
119 /*!
120 * \brief get a unique name with the corresponding prefix
121 * \param prefix The prefix of the name
122 * \return The returned name.
123 */
124 Doc GetUniqueName(const std::string& prefix);
125 Doc Print(Kind k);
126 /*!
127 * \brief Allocate name to a type variable.
128 * \param var The input type variable.
129 * \return The corresponding name.
130 */
131 Doc AllocTypeVar(const TypeVar& var);
132 /*!
133 * \brief Allocate name to a variable.
134 * \param var The input variable.
135 * \return The corresponding name.
136 */
137 Doc AllocVar(const Var& var);
138 bool IsUnique(const Expr& expr);
139 bool AlwaysInline(const Expr& expr);
140
141 Doc PrintFunc(const Doc& prefix, const relay::Function& fn);
142 Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func);
143 Doc PrintMod(const IRModule& mod);
144
145 //------------------------------------
146 // Overload of Expr printing functions
147 //------------------------------------
148 Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true);
149 // Should only be triggered when op is a free variable being visited for the
150 // first time.
151 Doc VisitExpr_(const VarNode* op) final;
152 Doc VisitExpr_(const ConstantNode* op) final;
153 Doc VisitExpr_(const TupleNode* op) final;
154 Doc VisitExpr_(const TupleGetItemNode* op) final;
155 Doc VisitExpr_(const IfNode* op) final;
156 Doc VisitExpr_(const LetNode* op) final;
157 Doc VisitExpr_(const FunctionNode* op) final;
158 Doc VisitExpr_(const GlobalVarNode* op) final;
159 Doc VisitExpr_(const OpNode* op) final;
160 Doc VisitExpr_(const CallNode* op) final;
161 Doc VisitExpr_(const RefCreateNode* op) final;
162 Doc VisitExpr_(const RefReadNode* op) final;
163 Doc VisitExpr_(const RefWriteNode* op) final;
164 Doc VisitExpr_(const MatchNode* op) final;
165 Doc PrintPattern(const Pattern& pattern, bool meta);
166 Doc VisitPattern_(const PatternConstructorNode* p) final;
167 Doc VisitPattern_(const PatternTupleNode* pt) final;
168 Doc VisitPattern_(const PatternWildcardNode* pw) final;
169 Doc VisitPattern_(const PatternVarNode* pv) final;
170 Doc VisitExpr_(const ConstructorNode* n) final;
171 //------------------------------------
172 // Overload of Type printing functions
173 //------------------------------------
174 Doc PrintType(const Type& type, bool meta);
175 Doc VisitTypeDefault_(const Object* node) final;
176 Doc VisitType_(const TypeVarNode* node) final;
177 Doc VisitType_(const GlobalTypeVarNode* node) final;
178 Doc VisitType_(const TypeCallNode* node) final;
179 Doc PrintDType(DataType dtype);
180 Doc VisitType_(const TensorTypeNode* node) final;
181 Doc VisitType_(const TupleTypeNode* node) final;
182 Doc VisitType_(const FuncTypeNode* node) final;
183 Doc VisitType_(const RelayRefTypeNode* node) final;
184 Doc VisitType_(const TypeDataNode* node) final;
185 //------------------------------------
186 // Overload of Attr printing functions
187 //------------------------------------
188 Doc VisitAttrDefault_(const Object* op) final;
189 Doc VisitAttr_(const ArrayNode* op) final;
190 Doc VisitAttr_(const tir::IntImmNode* op) final;
191 Doc VisitAttr_(const tir::FloatImmNode* op) final;
192 Doc VisitAttr_(const tir::StringImmNode* op) final;
193
194 private:
195 /*! \brief Whether to print meta data. */
196 bool show_meta_data_;
197 /*! \brief additional comment function */
198 runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
199 /*! \brief Stack of docs to implement scoped GNFing. */
200 std::vector<Doc> doc_stack_{};
201 /*! \brief Set for introduced vars */
202 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
203 /*! \brief Set for exprs have been printed optional information */
204 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> opt_info_memo_;
205 /*! \brief Map for result and memo_ diffs for visited expression */
206 std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
207 /*! \brief Map from Expr to Doc */
208 std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
209 /*! \brief Map from Type to Doc */
210 std::unordered_map<Type, Doc, ObjectPtrHash, ObjectPtrEqual> memo_type_;
211 /*! \brief Map from Type to Doc */
212 std::unordered_map<Pattern, Doc, ObjectPtrHash, ObjectPtrEqual> memo_pattern_;
213 /*! \brief name allocation map */
214 std::unordered_map<std::string, int> name_alloc_map_;
215 /*! \brief meta data context */
216 TextMetaDataContext* meta_;
217 /*! \brief counter of temporary variable */
218 size_t temp_var_counter_{0};
219 /*! \brief whether the printer is currently in an ADT definition */
220 bool in_adt_def_;
221 /*! \brief arena for dependency graph */
222 support::Arena arena_;
223 /*! \brief dependency graph of the expr */
224 DependencyGraph dg_;
225 class AttrPrinter;
226 friend class AttrPrinter;
227 friend class tvm::relay::TextPrinter;
228};
229
230using namespace ::tvm::tir;
231
232/*!
233 * \brief Meta node collector
234 * If we decide to put some node into meta, then all the sub-nodes inside
235 * it need to be put in meta as well, since when parsing we need to know
236 * whether two refs are the same
237 */
238class MetaCollector : public StmtExprVisitor {
239 public:
240 explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
241
242 void Collect(const ObjectRef& n) {
243 // these nodes can be print directly(StringLiteral or use identifier to identify)
244 if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() ||
245 n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
246 return;
247 }
248 if (n->IsInstance<StmtNode>()) {
249 VisitStmt(Downcast<Stmt>(n));
250 } else if (n->IsInstance<PrimExprNode>()) {
251 VisitExpr(Downcast<PrimExpr>(n));
252 }
253 }
254
255 void VisitStmt(const Stmt& n) override {
256 meta_->GetMetaNode(n);
257 StmtVisitor::VisitStmt(n);
258 }
259
260 void VisitExpr(const PrimExpr& n) override {
261 meta_->GetMetaNode(n);
262 ExprVisitor::VisitExpr(n);
263 }
264
265 private:
266 TextMetaDataContext* meta_;
267};
268
269class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
270 public tir::ExprFunctor<Doc(const PrimExpr&)>,
271 public TypeFunctor<Doc(const Type&)> {
272 public:
273 explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
274 : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
275
276 /*! \brief Output a newline */
277 virtual Doc NewLine();
278
279 /*! \brief Print the node */
280 Doc Print(const ObjectRef& node);
281
282 /*! \brief Place into `s` the name used in the preceding Print call for `v`.
283 * \param v Var instance to check. Must point to a VarNode visited by Print.
284 * \param s String to receive the name.
285 * \return true when a name re-mapping was found.
286 */
287 bool GetVarName(::tvm::tir::Var v, std::string* s);
288
289 protected:
290 Doc VisitExpr_(const IntImmNode* op) override;
291 Doc VisitExpr_(const FloatImmNode* op) override;
292 Doc VisitExpr_(const StringImmNode* op) override;
293 Doc VisitExpr_(const CastNode* op) override;
294 Doc VisitExpr_(const tir::VarNode* op) override;
295 Doc VisitExpr_(const AddNode* op) override;
296 Doc VisitExpr_(const SubNode* op) override;
297 Doc VisitExpr_(const MulNode* op) override;
298 Doc VisitExpr_(const DivNode* op) override;
299 Doc VisitExpr_(const ModNode* op) override;
300 Doc VisitExpr_(const FloorDivNode* op) override;
301 Doc VisitExpr_(const FloorModNode* op) override;
302 Doc VisitExpr_(const MinNode* op) override;
303 Doc VisitExpr_(const MaxNode* op) override;
304 Doc VisitExpr_(const EQNode* op) override;
305 Doc VisitExpr_(const NENode* op) override;
306 Doc VisitExpr_(const LTNode* op) override;
307 Doc VisitExpr_(const LENode* op) override;
308 Doc VisitExpr_(const GTNode* op) override;
309 Doc VisitExpr_(const GENode* op) override;
310 Doc VisitExpr_(const AndNode* op) override;
311 Doc VisitExpr_(const OrNode* op) override;
312 Doc VisitExpr_(const NotNode* op) override;
313 Doc VisitExpr_(const SelectNode* op) override;
314 Doc VisitExpr_(const BufferLoadNode* op) override;
315 Doc VisitExpr_(const ProducerLoadNode* op) override;
316 Doc VisitExpr_(const LoadNode* op) override;
317 Doc VisitExpr_(const RampNode* op) override;
318 Doc VisitExpr_(const BroadcastNode* op) override;
319 Doc VisitExpr_(const tir::LetNode* op) override;
320 Doc VisitExpr_(const tir::CallNode* op) override;
321 Doc VisitExpr_(const ShuffleNode* op) override;
322 Doc VisitExpr_(const ReduceNode* op) override;
323 Doc VisitExprDefault_(const Object* op) override;
324
325 Doc VisitStmt_(const LetStmtNode* op) override;
326 Doc VisitStmt_(const AttrStmtNode* op) override;
327 Doc VisitStmt_(const AssertStmtNode* op) override;
328 Doc VisitStmt_(const StoreNode* op) override;
329 Doc VisitStmt_(const BufferStoreNode* op) override;
330 Doc VisitStmt_(const ProducerStoreNode* op) override;
331 Doc VisitStmt_(const BufferRealizeNode* op) override;
332 Doc VisitStmt_(const ProducerRealizeNode* op) override;
333 Doc VisitStmt_(const AllocateNode* op) override;
334 Doc VisitStmt_(const AllocateConstNode* op) override;
335 Doc VisitStmt_(const DeclBufferNode* op) override;
336 Doc VisitStmt_(const IfThenElseNode* op) override;
337 Doc VisitStmt_(const SeqStmtNode* op) override;
338 Doc VisitStmt_(const EvaluateNode* op) override;
339 Doc VisitStmt_(const ForNode* op) override;
340 Doc VisitStmt_(const WhileNode* op) override;
341 Doc VisitStmt_(const PrefetchNode* op) override;
342 Doc VisitStmt_(const BlockRealizeNode* op) override;
343 Doc VisitStmtDefault_(const Object* op) override;
344
345 private:
346 /*! \brief whether show meta data */
347 bool show_meta_;
348 /*! \brief meta data context */
349 TextMetaDataContext* meta_;
350 /*! \brief meta collector */
351 MetaCollector meta_collector_;
352 /*! \brief Map from Var to Doc */
353 std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
354 /*! \brief Map from Buffer to Doc */
355 std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
356 /*! \brief Map from Buffer to Doc */
357 std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
358 /*! \brief name allocation map */
359 std::unordered_map<std::string, int> name_alloc_map_;
360
361 friend class TextPrinter;
362
363 Doc VisitType_(const PrimTypeNode* node) override;
364 Doc VisitType_(const PointerTypeNode* node) override;
365 Doc VisitType_(const TupleTypeNode* node) override;
366
367 Doc PrintIRModule(const IRModule& module);
368 Doc PrintPrimFunc(const PrimFunc& primFunc);
369 Doc PrintArray(const ArrayNode* op);
370 Doc PrintIterVar(const IterVarNode* op);
371 Doc PrintRange(const RangeNode* op);
372 Doc PrintBuffer(const BufferNode* op);
373 Doc PrintProducer(const DataProducerNode* op);
374 Doc BufferNode2Doc(const BufferNode* op, Doc doc);
375 Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc);
376 Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
377 Doc PrintBufferRegion(const BufferRegionNode* op);
378
379 /*!
380 * \brief special method to print out data type
381 * \param dtype The data type
382 */
383 static Doc PrintDType(DataType dtype);
384 /*!
385 * \brief special method to print out const scalar
386 * \param dtype The data type
387 * \param data The pointer to hold the data.
388 */
389 template <typename T>
390 static Doc PrintConstScalar(DataType dtype, const T& data);
391 Doc GetUniqueName(std::string prefix);
392 Doc AllocVar(const tir::Var& var);
393 Doc AllocConst(const AllocateConst& var);
394 Doc AllocBuf(const Buffer& buffer);
395 Doc AllocProducer(const DataProducer& buffer);
396 /*!
397 * \brief special method to render vectors of docs with a separator
398 * \param vec vector of docs
399 * \param sep separator
400 */
401 static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep);
402 Doc PrintBody(const Stmt& body, bool indent = true);
403};
404
405String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
406 runtime::TypedPackedFunc<std::string(Stmt)> annotate);
407
408class TextPrinter {
409 public:
410 explicit TextPrinter(bool show_meta_data,
411 const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate,
412 bool show_warning = true)
413 : show_meta_data_(show_meta_data),
414 show_warning_(show_warning),
415 annotate_(annotate),
416 relay_text_printer_(show_meta_data, &meta_, annotate),
417 tir_text_printer_(show_meta_data, &meta_) {}
418
419 /*! \brief whether show meta data */
420 bool show_meta_data_;
421
422 /*! \brief whether show the meta data warning message */
423 bool show_warning_;
424
425 /*! \brief meta data context */
426 TextMetaDataContext meta_;
427 /*! \brief additional comment function */
428 runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
429 /*! \brief Relay Text Printer */
430 relay::RelayTextPrinter relay_text_printer_;
431 /*! \brief TIR Text Printer */
432 TIRTextPrinter tir_text_printer_;
433
434 bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); }
435
436 Doc PrintFinal(const ObjectRef& node) {
437 Doc doc;
438 if (node.defined() && node->IsInstance<IRModuleNode>()) {
439 doc << PrintMod(Downcast<IRModule>(node));
440 } else if (node.defined() &&
441 (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
442 node->IsInstance<tir::StmtNode>())) {
443 doc << tir_text_printer_.Print(node);
444 } else {
445 doc << relay_text_printer_.PrintFinal(node);
446 }
447 if (!meta_.empty()) {
448 doc << Doc::NewLine();
449 if (show_meta_data_) {
450 doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection();
451 } else if (show_warning_) {
452 doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine()
453 << " * If you would like to see the full metadata section you can set the "
454 << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine()
455 << " */";
456 }
457 }
458 return doc;
459 }
460
461 Doc PrintMod(const IRModule& mod);
462};
463} // namespace relay
464} // namespace tvm
465
466#endif // TVM_RELAY_PRINTER_TEXT_PRINTER_H_
467