1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file text_printer.h
22 * \brief Printer to print out the unified IR text format
23 * that can be parsed by a parser.
24 */
25
26#ifndef TVM_PRINTER_TEXT_PRINTER_H_
27#define TVM_PRINTER_TEXT_PRINTER_H_
28
29#include <tvm/ir/module.h>
30#include <tvm/ir/type_functor.h>
31#include <tvm/relay/expr_functor.h>
32#include <tvm/relay/pattern_functor.h>
33#include <tvm/tir/expr_functor.h>
34#include <tvm/tir/function.h>
35#include <tvm/tir/op.h>
36#include <tvm/tir/stmt_functor.h>
37#include <tvm/tir/var.h>
38
39#include <string>
40#include <unordered_map>
41#include <unordered_set>
42#include <vector>
43
44#include "../ir/attr_functor.h"
45#include "../relay/analysis/dependency_graph.h"
46#include "doc.h"
47#include "meta_data.h"
48#include "text_printer.h"
49
// RelayTextPrinter and TIRTextPrinter below befriend `tvm::TextPrinter` using a qualified
// name, which requires the class to be declared first; its definition is at the end of this
// header.
namespace tvm { class TextPrinter; }  // namespace tvm
53
54namespace tvm {
55namespace relay {
56
57class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
58 public PatternFunctor<Doc(const Pattern&)>,
59 public TypeFunctor<Doc(const Type&)>,
60 public AttrFunctor<Doc(const ObjectRef&)> {
61 public:
62 explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
63 runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
64 : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
65 Doc VisitExpr(const Expr& expr) override;
66 virtual Doc VisitLeaf(const Expr& expr);
67 virtual bool CheckVisited(const Expr& expr);
68
69 /*!
70 * \brief Print additional info about expr in comment.
71 * \param expr The expression.
72 */
73 Doc PrintOptionalInfo(const Expr& expr);
74 // indent a new body
75 Doc PrintBody(const ObjectRef& node, int indent = 2);
76 // create a new scope by creating a new printer object. This allows temp var
77 // numbers to be reused and prevents hoisted vars from escaping too far
78 Doc PrintScope(const ObjectRef& node);
79 Doc PrintFinal(const ObjectRef& node);
80
81 /*!
82 * \brief Returns \p attrs printed using the generic attribute visitor, as a sequence
83 * of key=value entries, if any.
84 */
85 void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key);
86
87 /*!
88 * \brief Returns \p attrs printed as a sequence of key=value entries, if any.
89 * This is used for call attributes.
90 */
91 std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
92
93 /*!
94 * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any.
95 * This is used for function definition attributes.
96 */
97 std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs);
98 std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs);
99
100 /*!
101 * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta
102 * is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag.
103 */
104 Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false);
105
106 /*!
107 * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces.
108 */
109 Doc PrintAttrsAsAttributeValue(const Attrs& attrs);
110
111 /*!
112 * \brief Returns \p map printed as a self-contained value, ie wrapped in braces.
113 */
114 Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
115
116 Doc PrintSpan(const Span& span);
117
118 Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
119
120 Doc TempVar(int n);
121 Doc AllocTemp();
122 /*!
123 * \brief get a unique name with the corresponding prefix
124 * \param prefix The prefix of the name
125 * \return The returned name.
126 */
127 Doc GetUniqueName(const std::string& prefix);
128 Doc Print(Kind k);
129 /*!
130 * \brief Allocate name to a type variable.
131 * \param var The input type variable.
132 * \return The corresponding name.
133 */
134 Doc AllocTypeVar(const TypeVar& var);
135 /*!
136 * \brief Allocate name to a variable.
137 * \param var The input variable.
138 * \return The corresponding name.
139 */
140 Doc AllocVar(const Var& var);
141 bool IsUnique(const Expr& expr);
142 bool AlwaysInline(const Expr& expr);
143
144 Doc PrintFunc(const Doc& prefix, const relay::Function& fn);
145 Doc PrintFunc(const Doc& prefix, const BaseFunc& base_func);
146 Doc PrintMod(const IRModule& mod);
147
148 //------------------------------------
149 // Overload of Expr printing functions
150 //------------------------------------
151 Doc PrintExpr(const Expr& expr, bool meta, bool try_inline, bool optional_info = true);
152 // Should only be triggered when op is a free variable being visited for the
153 // first time.
154 Doc VisitExpr_(const VarNode* op) final;
155 Doc VisitExpr_(const ConstantNode* op) final;
156 Doc VisitExpr_(const TupleNode* op) final;
157 Doc VisitExpr_(const TupleGetItemNode* op) final;
158 Doc VisitExpr_(const IfNode* op) final;
159 Doc VisitExpr_(const LetNode* op) final;
160 Doc VisitExpr_(const FunctionNode* op) final;
161 Doc VisitExpr_(const GlobalVarNode* op) final;
162 Doc VisitExpr_(const OpNode* op) final;
163 Doc VisitExpr_(const CallNode* op) final;
164 Doc VisitExpr_(const RefCreateNode* op) final;
165 Doc VisitExpr_(const RefReadNode* op) final;
166 Doc VisitExpr_(const RefWriteNode* op) final;
167 Doc VisitExpr_(const MatchNode* op) final;
168 Doc PrintPattern(const Pattern& pattern, bool meta);
169 Doc VisitPattern_(const PatternConstructorNode* p) final;
170 Doc VisitPattern_(const PatternTupleNode* pt) final;
171 Doc VisitPattern_(const PatternWildcardNode* pw) final;
172 Doc VisitPattern_(const PatternVarNode* pv) final;
173 Doc VisitExpr_(const ConstructorNode* n) final;
174 //------------------------------------
175 // Overload of Type printing functions
176 //------------------------------------
177 Doc PrintType(const Type& type, bool meta);
178 Doc VisitTypeDefault_(const Object* node) final;
179 Doc VisitType_(const TypeVarNode* node) final;
180 Doc VisitType_(const GlobalTypeVarNode* node) final;
181 Doc VisitType_(const TypeCallNode* node) final;
182 Doc PrintDType(DataType dtype);
183 Doc VisitType_(const TensorTypeNode* node) final;
184 Doc VisitType_(const TupleTypeNode* node) final;
185 Doc VisitType_(const FuncTypeNode* node) final;
186 Doc VisitType_(const RelayRefTypeNode* node) final;
187 Doc VisitType_(const TypeDataNode* node) final;
188 //------------------------------------
189 // Overload of Attr printing functions
190 //------------------------------------
191 Doc VisitAttrDefault_(const Object* op) final;
192 Doc VisitAttr_(const ArrayNode* op) final;
193 Doc VisitAttr_(const tir::IntImmNode* op) final;
194 Doc VisitAttr_(const tir::FloatImmNode* op) final;
195 Doc VisitAttr_(const tir::StringImmNode* op) final;
196
197 private:
198 /*! \brief Whether to print meta data. */
199 bool show_meta_data_;
200 /*! \brief additional comment function */
201 runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
202 /*! \brief Stack of docs to implement scoped GNFing. */
203 std::vector<Doc> doc_stack_{};
204 /*! \brief Set for introduced vars */
205 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
206 /*! \brief Set for exprs have been printed optional information */
207 std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> opt_info_memo_;
208 /*! \brief Map for result and memo_ diffs for visited expression */
209 std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
210 /*! \brief Map from Expr to Doc */
211 std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
212 /*! \brief Map from Type to Doc */
213 std::unordered_map<Type, Doc, ObjectPtrHash, ObjectPtrEqual> memo_type_;
214 /*! \brief Map from Type to Doc */
215 std::unordered_map<Pattern, Doc, ObjectPtrHash, ObjectPtrEqual> memo_pattern_;
216 /*! \brief name allocation map */
217 std::unordered_map<std::string, int> name_alloc_map_;
218 /*! \brief meta data context */
219 TextMetaDataContext* meta_;
220 /*! \brief counter of temporary variable */
221 size_t temp_var_counter_{0};
222 /*! \brief whether the printer is currently in an ADT definition */
223 bool in_adt_def_;
224 /*! \brief arena for dependency graph */
225 support::Arena arena_;
226 /*! \brief dependency graph of the expr */
227 DependencyGraph dg_;
228 class AttrPrinter;
229 friend class AttrPrinter;
230 friend class tvm::TextPrinter;
231};
232
233} // namespace relay
234} // namespace tvm
235
236namespace tvm {
237namespace tir {
238
239/*!
240 * \brief Meta node collector
241 * If we decide to put some node into meta, then all the sub-nodes inside
242 * it need to be put in meta as well, since when parsing we need to know
243 * whether two refs are the same
244 */
245class MetaCollector : public StmtExprVisitor {
246 public:
247 explicit MetaCollector(TextMetaDataContext* meta) : meta_(meta) {}
248
249 void Collect(const ObjectRef& n) {
250 // these nodes can be print directly(StringLiteral or use identifier to identify)
251 if (!n.defined() || n.as<StringImmNode>() || n.as<StringObj>() || n.as<SizeVarNode>() ||
252 n.as<VarNode>() || n.as<BufferNode>() || n.as<IterVarNode>()) {
253 return;
254 }
255 if (n->IsInstance<StmtNode>()) {
256 VisitStmt(Downcast<Stmt>(n));
257 } else if (n->IsInstance<PrimExprNode>()) {
258 VisitExpr(Downcast<PrimExpr>(n));
259 }
260 }
261
262 void VisitStmt(const Stmt& n) override {
263 meta_->GetMetaNode(n);
264 StmtVisitor::VisitStmt(n);
265 }
266
267 void VisitExpr(const PrimExpr& n) override {
268 meta_->GetMetaNode(n);
269 ExprVisitor::VisitExpr(n);
270 }
271
272 private:
273 TextMetaDataContext* meta_;
274};
275
276class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
277 public ExprFunctor<Doc(const PrimExpr&)>,
278 public TypeFunctor<Doc(const Type&)> {
279 public:
280 explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta)
281 : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {}
282
283 /*! \brief Output a newline */
284 virtual Doc NewLine();
285
286 /*! \brief Print the node */
287 Doc Print(const ObjectRef& node);
288
289 /*! \brief Place into `s` the name used in the preceding Print call for `v`.
290 * \param v Var instance to check. Must point to a VarNode visited by Print.
291 * \param s String to receive the name.
292 * \return true when a name re-mapping was found.
293 */
294 bool GetVarName(::tvm::tir::Var v, std::string* s);
295
296 protected:
297 Doc VisitExpr_(const IntImmNode* op) override;
298 Doc VisitExpr_(const FloatImmNode* op) override;
299 Doc VisitExpr_(const StringImmNode* op) override;
300 Doc VisitExpr_(const CastNode* op) override;
301 Doc VisitExpr_(const VarNode* op) override;
302 Doc VisitExpr_(const AddNode* op) override;
303 Doc VisitExpr_(const SubNode* op) override;
304 Doc VisitExpr_(const MulNode* op) override;
305 Doc VisitExpr_(const DivNode* op) override;
306 Doc VisitExpr_(const ModNode* op) override;
307 Doc VisitExpr_(const FloorDivNode* op) override;
308 Doc VisitExpr_(const FloorModNode* op) override;
309 Doc VisitExpr_(const MinNode* op) override;
310 Doc VisitExpr_(const MaxNode* op) override;
311 Doc VisitExpr_(const EQNode* op) override;
312 Doc VisitExpr_(const NENode* op) override;
313 Doc VisitExpr_(const LTNode* op) override;
314 Doc VisitExpr_(const LENode* op) override;
315 Doc VisitExpr_(const GTNode* op) override;
316 Doc VisitExpr_(const GENode* op) override;
317 Doc VisitExpr_(const AndNode* op) override;
318 Doc VisitExpr_(const OrNode* op) override;
319 Doc VisitExpr_(const NotNode* op) override;
320 Doc VisitExpr_(const SelectNode* op) override;
321 Doc VisitExpr_(const BufferLoadNode* op) override;
322 Doc VisitExpr_(const ProducerLoadNode* op) override;
323 Doc VisitExpr_(const LoadNode* op) override;
324 Doc VisitExpr_(const RampNode* op) override;
325 Doc VisitExpr_(const BroadcastNode* op) override;
326 Doc VisitExpr_(const LetNode* op) override;
327 Doc VisitExpr_(const CallNode* op) override;
328 Doc VisitExpr_(const ShuffleNode* op) override;
329 Doc VisitExpr_(const ReduceNode* op) override;
330 Doc VisitExprDefault_(const Object* op) override;
331
332 Doc VisitStmt_(const LetStmtNode* op) override;
333 Doc VisitStmt_(const AttrStmtNode* op) override;
334 Doc VisitStmt_(const AssertStmtNode* op) override;
335 Doc VisitStmt_(const StoreNode* op) override;
336 Doc VisitStmt_(const BufferStoreNode* op) override;
337 Doc VisitStmt_(const ProducerStoreNode* op) override;
338 Doc VisitStmt_(const BufferRealizeNode* op) override;
339 Doc VisitStmt_(const ProducerRealizeNode* op) override;
340 Doc VisitStmt_(const AllocateNode* op) override;
341 Doc VisitStmt_(const AllocateConstNode* op) override;
342 Doc VisitStmt_(const DeclBufferNode* op) override;
343 Doc VisitStmt_(const IfThenElseNode* op) override;
344 Doc VisitStmt_(const SeqStmtNode* op) override;
345 Doc VisitStmt_(const EvaluateNode* op) override;
346 Doc VisitStmt_(const ForNode* op) override;
347 Doc VisitStmt_(const WhileNode* op) override;
348 Doc VisitStmt_(const PrefetchNode* op) override;
349 Doc VisitStmt_(const BlockRealizeNode* op) override;
350 Doc VisitStmtDefault_(const Object* op) override;
351
352 private:
353 /*! \brief whether show meta data */
354 bool show_meta_;
355 /*! \brief meta data context */
356 TextMetaDataContext* meta_;
357 /*! \brief meta collector */
358 MetaCollector meta_collector_;
359 /*! \brief Map from Var to Doc */
360 std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
361 /*! \brief Map from Buffer to Doc */
362 std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
363 /*! \brief Map from Buffer to Doc */
364 std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
365 /*! \brief name allocation map */
366 std::unordered_map<std::string, int> name_alloc_map_;
367
368 friend class tvm::TextPrinter;
369
370 Doc VisitType_(const PrimTypeNode* node) override;
371 Doc VisitType_(const PointerTypeNode* node) override;
372 Doc VisitType_(const TupleTypeNode* node) override;
373
374 Doc PrintIRModule(const IRModule& module);
375 Doc PrintPrimFunc(const PrimFunc& primFunc);
376 Doc PrintArray(const ArrayNode* op);
377 Doc PrintIterVar(const IterVarNode* op);
378 Doc PrintRange(const RangeNode* op);
379 Doc PrintBuffer(const BufferNode* op);
380 Doc PrintProducer(const DataProducerNode* op);
381 Doc BufferNode2Doc(const BufferNode* op, Doc doc);
382 Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc);
383 Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
384 Doc PrintBufferRegion(const BufferRegionNode* op);
385
386 /*!
387 * \brief special method to print out data type
388 * \param dtype The data type
389 */
390 static Doc PrintDType(DataType dtype);
391 /*!
392 * \brief special method to print out const scalar
393 * \param dtype The data type
394 * \param data The pointer to hold the data.
395 */
396 template <typename T>
397 static Doc PrintConstScalar(DataType dtype, const T& data);
398 Doc GetUniqueName(std::string prefix);
399 Doc AllocVar(const Var& var);
400 Doc AllocConst(const AllocateConst& var);
401 Doc AllocBuf(const Buffer& buffer);
402 Doc AllocProducer(const DataProducer& buffer);
403 /*!
404 * \brief special method to render vectors of docs with a separator
405 * \param vec vector of docs
406 * \param sep separator
407 */
408 static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep);
409 Doc PrintBody(const Stmt& body, bool indent = true);
410};
411
412String AsTVMScript(const ObjectRef& mod, const String& tir_prefix = "T", bool show_meta = false);
413
414String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
415 runtime::TypedPackedFunc<std::string(Stmt)> annotate);
416
417} // namespace tir
418} // namespace tvm
419
420namespace tvm {
421
422class TextPrinter {
423 public:
424 explicit TextPrinter(bool show_meta_data,
425 const runtime::TypedPackedFunc<std::string(ObjectRef)>& annotate,
426 bool show_warning = true)
427 : show_meta_data_(show_meta_data),
428 show_warning_(show_warning),
429 annotate_(annotate),
430 relay_text_printer_(show_meta_data, &meta_, annotate),
431 tir_text_printer_(show_meta_data, &meta_) {}
432
433 /*! \brief whether show meta data */
434 bool show_meta_data_;
435
436 /*! \brief whether show the meta data warning message */
437 bool show_warning_;
438
439 /*! \brief meta data context */
440 TextMetaDataContext meta_;
441 /*! \brief additional comment function */
442 runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
443 /*! \brief Relay Text Printer */
444 relay::RelayTextPrinter relay_text_printer_;
445 /*! \brief TIR Text Printer */
446 tir::TIRTextPrinter tir_text_printer_;
447
448 bool GetVarName(::tvm::tir::Var v, std::string* s) { return tir_text_printer_.GetVarName(v, s); }
449
450 Doc PrintFinal(const ObjectRef& node) {
451 Doc doc;
452 if (node.defined() && node->IsInstance<IRModuleNode>()) {
453 doc << PrintMod(Downcast<IRModule>(node));
454 } else if (node.defined() &&
455 (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
456 node->IsInstance<tir::StmtNode>())) {
457 doc << tir_text_printer_.Print(node);
458 } else {
459 doc << relay_text_printer_.PrintFinal(node);
460 }
461 if (!meta_.empty()) {
462 doc << Doc::NewLine();
463 if (show_meta_data_) {
464 doc << "#[metadata]" << Doc::NewLine() << meta_.GetMetaSection();
465 } else if (show_warning_) {
466 doc << "/* For debugging purposes the metadata section has been omitted." << Doc::NewLine()
467 << " * If you would like to see the full metadata section you can set the "
468 << Doc::NewLine() << " * option to `True` when invoking `astext`. " << Doc::NewLine()
469 << " */";
470 }
471 }
472 return doc;
473 }
474
475 Doc PrintMod(const IRModule& mod);
476};
477} // namespace tvm
478
479#endif // TVM_PRINTER_TEXT_PRINTER_H_
480