1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file printer/tvmscript_printer.cc
22 * \brief Printer class to print Tensor IR to python syntax script
23 */
24
25#include <tvm/arith/analyzer.h>
26#include <tvm/ir/module.h>
27#include <tvm/node/serialization.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/target/target.h>
30#include <tvm/tir/analysis.h>
31#include <tvm/tir/buffer.h>
32#include <tvm/tir/expr.h>
33#include <tvm/tir/expr_functor.h>
34#include <tvm/tir/function.h>
35#include <tvm/tir/op.h>
36#include <tvm/tir/stmt.h>
37#include <tvm/tir/stmt_functor.h>
38
39#include <algorithm>
40#include <utility>
41
42#include "../../tir/transforms/ir_utils.h"
43#include "doc.h"
44#include "meta_data.h"
45#include "text_printer.h"
46
47namespace tvm {
48namespace relay {
49
50using namespace tvm::tir;
51
/*!
 * \brief Binding strength of a printed expression, used to decide where parentheses are needed.
 * \details Smaller values bind more tightly. The binary-operator printer compares an operand's
 *  level with the level of its parent operator and adds parentheses when the operand binds more
 *  loosely (plus a few readability/associativity cases).
 */
enum class ExprPrecedence : int {
  /*! \brief Atoms (e.g. IntImm, Var) and anything printed as a function call (e.g. min). */
  kIdentity = 0,
  /*!
   * \brief Multiplication (*), division (/), floor division (//) and floor modulo (%).
   * \note Mod (truncmod) is printed as a function call and therefore uses kIdentity.
   */
  kMultiplicationDivision = 1,
  /*! \brief Addition (+) and subtraction (-). */
  kAdditionSubtraction = 2,
  /*! \brief Relational operators <, <=, > and >=. */
  kRelational = 3,
  /*! \brief Equality operators == and !=. */
  kEquality = 4,
  /*! \brief Logical `and`. */
  kAnd = 5,
  /*! \brief Logical `or`. */
  kOr = 6,
  /*! \brief Precedence not (yet) determined. */
  kUnknown = 7,
};
73
74/*! \brief Utility used for identifying usage of a buffer_var
75 *
76 * \details Find the Buffer object that corresponds to a variable or
77 * allocation, based on the BufferLoad/BufferStore instances that
78 * occur within the allocation's body.
79 */
80class BufferUsageFinder : public StmtExprVisitor {
81 public:
82 static Map<tir::Var, Array<Buffer>> FindUsage(Map<tir::Var, Array<Buffer>> usage, Stmt body) {
83 BufferUsageFinder visitor(std::move(usage));
84 visitor.VisitStmt(body);
85 return std::move(visitor.usage_);
86 }
87
88 void VisitExpr_(const tir::VarNode* op) final {
89 tir::Var var = GetRef<tir::Var>(op);
90 if (!usage_.count(var)) {
91 usage_.Set(var, {});
92 }
93 }
94
95 void VisitExpr_(const BufferLoadNode* op) final {
96 VisitBuffer(op->buffer);
97 StmtExprVisitor::VisitExpr_(op);
98 }
99
100 void VisitStmt_(const BufferStoreNode* op) final {
101 VisitBuffer(op->buffer);
102 StmtExprVisitor::VisitStmt_(op);
103 }
104
105 void VisitStmt_(const DeclBufferNode* op) final {
106 buffers_declared_.insert(op->buffer.get());
107 StmtExprVisitor::VisitStmt_(op);
108 buffers_declared_.erase(op->buffer.get());
109 }
110
111 private:
112 explicit BufferUsageFinder(Map<tir::Var, Array<Buffer>> usage) : usage_(usage) {}
113
114 void VisitBuffer(const Buffer& buffer) {
115 if (buffers_visited_.count(buffer.get())) {
116 return;
117 }
118 if (buffers_declared_.count(buffer.get())) {
119 return;
120 }
121 buffers_visited_.insert(buffer.get());
122
123 Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
124 arr.push_back(buffer);
125 usage_.Set(buffer->data, arr);
126 }
127
128 // The search result.
129 Map<tir::Var, Array<Buffer>> usage_;
130 // The buffers that have been visited so far, to avoid duplicate
131 // entries in the search result.
132 std::unordered_set<const BufferNode*> buffers_visited_;
133 // The buffers declared via `DeclBuffer`. These buffers are excluded from the result because
134 // T.buffer_decl shouldn't be printed for them.
135 std::unordered_set<const BufferNode*> buffers_declared_;
136};
137
138/*!
139 * \brief The printer for TVMScript
140 * \details The printer obtain the precedence of the top-level operation when printing each
141 * subexpression to decide whether or not parentheses is needed.
142 */
143class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
144 public tir::ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>,
145 public TypeFunctor<Doc(const Type&)> {
146 public:
147 explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta,
148 runtime::TypedPackedFunc<std::string(Stmt)> annotate = nullptr)
149 : tir_prefix_(tir_prefix),
150 show_meta_(show_meta),
151 annotate_(std::move(annotate)),
152 meta_collector_(&meta_) {}
153
154 /*!
155 * \brief Print the node.
156 * \param node The node to be printed.
157 * \param out_precedence The operator precedence of node if it's a PrimExpr,
158 * so we can simplify the bracket.
159 */
160 TVM_DLL Doc Print(const ObjectRef& node);
161
162 protected:
163 /*! \brief The tir prefix */
164 String tir_prefix_;
165 /*! \brief whether show meta data */
166 bool show_meta_;
167 /*! \brief additional comment function */
168 runtime::TypedPackedFunc<std::string(Stmt)> annotate_;
169 /*! \brief meta data context */
170 TextMetaDataContext meta_;
171 /*! \brief meta collector */
172 relay::MetaCollector meta_collector_;
173 /*! \brief map from Function to GlobalVar */
174 std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
175 /*! \brief var collector (var defined by For/Loop/Block) */
176 std::unordered_set<const tir::VarNode*> var_not_in_headers_;
177 /*!
178 * \brief buffer collector
179 * (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion)
180 */
181 std::unordered_set<const BufferNode*> buf_not_in_headers_;
182 /*! \brief Map from Var to thread env name */
183 std::unordered_map<tir::Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
184 /*! \brief Map from Var to Doc */
185 std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
186 /*! \brief Map from Buffer to Doc */
187 std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
188 /*! \brief Map from Buffer to Declaration Doc */
189 std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_decl_;
190 /*! \brief name allocation map */
191 std::unordered_map<std::string, int> name_alloc_map_;
192 /*! \brief number of children of current node's parent */
193 int num_child_;
194 /*! \brief the number of current node */
195 int current_num_;
196 /*! \brief loop stack without annotations */
197 std::vector<For> simple_loop_stack_;
198 /*! \brief the maps from loop_vars to the loops */
199 std::unordered_map<const tir::VarNode*, For> loop_var_map_;
200 /*!
201 * \brief simple block vars remap from loop vars
202 * simple_remap requires:
203 * 1. block var iter type is kDataPar or kCommReduce
204 * 2. value is a single Var, which is a loop_var outside the block
205 * 3. The iter range is equal to loop range
206 */
207 std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
208 /*!
209 * \brief Map from variables to the buffers they are used in.
210 *
211 * Used for identifying buffers that should be declared after the
212 * LetStmt or Allocate that generates their data pointer, rather
213 * than in the header.
214 */
215 Map<tir::Var, Array<Buffer>> buffer_var_usage_;
216 /*! \brief Analyzer to simplify some expressions. */
217 arith::Analyzer ana_;
218
219 Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
220 Doc VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) override;
221 Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override;
222 Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override;
223 Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override;
224 Doc VisitExpr_(const DivNode* op, ExprPrecedence* out_precedence) override;
225 Doc VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) override;
226 Doc VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) override;
227 Doc VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) override;
228 Doc VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) override;
229 Doc VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) override;
230 Doc VisitExpr_(const EQNode* op, ExprPrecedence* out_precedence) override;
231 Doc VisitExpr_(const NENode* op, ExprPrecedence* out_precedence) override;
232 Doc VisitExpr_(const LTNode* op, ExprPrecedence* out_precedence) override;
233 Doc VisitExpr_(const LENode* op, ExprPrecedence* out_precedence) override;
234 Doc VisitExpr_(const GTNode* op, ExprPrecedence* out_precedence) override;
235 Doc VisitExpr_(const GENode* op, ExprPrecedence* out_precedence) override;
236 Doc VisitExpr_(const AndNode* op, ExprPrecedence* out_precedence) override;
237 Doc VisitExpr_(const OrNode* op, ExprPrecedence* out_precedence) override;
238 Doc VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) override;
239 Doc VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) override;
240 Doc VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) override;
241 Doc VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) override;
242 Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override;
243 Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override;
244 Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override;
245 Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override;
246 Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override;
247 Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override;
248 Doc VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) override;
249 Doc VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) override;
250 Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override;
251 Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override;
252 Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override;
253
254 Doc VisitStmt_(const LetStmtNode* op) override;
255 Doc VisitStmt_(const AttrStmtNode* op) override;
256 Doc VisitStmt_(const AssertStmtNode* op) override;
257 Doc VisitStmt_(const StoreNode* op) override;
258 Doc VisitStmt_(const BufferStoreNode* op) override;
259 Doc VisitStmt_(const BufferRealizeNode* op) override;
260 Doc VisitStmt_(const AllocateNode* op) override;
261 Doc VisitStmt_(const AllocateConstNode* op) override;
262 Doc VisitStmt_(const DeclBufferNode* op) override;
263 Doc VisitStmt_(const IfThenElseNode* op) override;
264 Doc VisitStmt_(const SeqStmtNode* op) override;
265 Doc VisitStmt_(const ForNode* op) override;
266 Doc VisitStmt_(const WhileNode* op) override;
267 Doc VisitStmt_(const PrefetchNode* op) override;
268 Doc VisitStmt_(const EvaluateNode* op) override;
269 Doc VisitStmt_(const BlockRealizeNode* op) override;
270 Doc VisitStmtDefault_(const Object* op) override;
271
272 Doc VisitType_(const PrimTypeNode* node) override;
273 Doc VisitType_(const PointerTypeNode* node) override;
274 Doc VisitType_(const TupleTypeNode* node) override;
275
276 Doc PrintBody(const Stmt& body);
277 Doc PrintIRModule(const IRModule& module);
278 Doc PrintPrimFunc(const PrimFunc& primFunc);
279 Doc PrintIterVar(const IterVarNode* op);
280 Doc PrintRange(const RangeNode* op);
281 Doc PrintArray(const ArrayNode* op);
282 Doc PrintBuffer(const BufferNode* op);
283 Doc PrintBufferIndices(const Array<PrimExpr>& indices);
284 Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers);
285 Doc AllocBufferDeclaration(const Buffer& buf);
286 Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
287 Doc PrintBlockVarRemaps();
288 Doc PrintBlockPredicate(const BlockRealizeNode* op);
289 Doc PrintBlockVars(const BlockRealizeNode* op);
290 Doc PrintBlockAttr(const BlockRealizeNode* op);
291 Doc PrintExpandedArray(const ArrayNode* op);
292 Doc PrintBlockBody(const BlockNode* op);
293 virtual Doc PrintBlockName(const BlockNode* block_op);
294 Doc PrintBufferRegion(const BufferRegionNode* op);
295 Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op);
296 Doc PrintCommReducer(const CommReducerNode* op);
297 Doc PrintAnnotations(const Map<String, ObjectRef>& annotations);
298 Doc PrintTarget(const TargetNode* target);
299 static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
300
301 Doc GetUniqueName(std::string prefix);
302 Doc AllocVar(const tir::Var& var);
303 Doc AllocBuf(const Buffer& buffer);
304 void TryDeallocVar(const tir::Var& var);
305 bool ContainsOptionalInfo(const Stmt& stmt);
306 /*!
307 * \brief Check if a buffer declaration satisfies:
308 * 1. has only 'shape' and 'dtype' arguments specified,
309 * 2. the shape and strides are not dynamic.
310 * \param buffer The match buffer to be checked
311 */
312 bool IsSimpleBuffer(const Buffer& buffer);
313 Doc PrintInlineBufferBind(const Buffer& buffer);
314 Doc PrintTuple(const ArrayNode* op);
315
316 /*! Helper functions for loop printing. */
317 /*!
318 * \brief Print a single for loop
319 * \param loop The for loop to be printed
320 */
321 virtual Doc PrintLoop(const For& loop);
322 /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
323 Doc PrintLoopStack();
324 /*!
325 * \brief Check whether a loop satisfies:
326 * 1. the loop is serial;
327 * 2. the loop has no annotation;
328 * 3. the loop starts from 0;
329 * 4. there is no optional information.
330 * \param for_op the for node to be checked
331 * \return A boolean indicating whether the input loop satisfies the above conditions
332 */
333 bool IsSimpleLoop(const ForNode* for_op) {
334 return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
335 is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
336 }
337 /*!
338 * \brief Check whether the `min` or `extent` of a loop depends on previous loops
339 * \param for_op The loop to be checked
340 * \return A boolean indicating whether the input loop depends on previous loops
341 */
342 bool DependOnPrevLoops(const ForNode* for_op) {
343 auto f_check = [&var_map = this->loop_var_map_](const tir::VarNode* v) {
344 return var_map.count(v);
345 };
346 return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
347 }
348
349 /*!
350 * \brief Print additional info about expr in comment.
351 * \param expr The expression.
352 */
353 Doc PrintOptionalInfo(const Stmt& stmt) {
354 Doc doc;
355 // default annotations
356 if (ContainsOptionalInfo(stmt)) {
357 std::string annotated_stmt = annotate_(stmt);
358 doc << "# " << annotated_stmt << Doc::NewLine();
359 }
360 return doc;
361 }
362
363 /*!
364 * \brief special method to render vectors of docs with a separator
365 * \param vec vector of docs
366 * \param sep separator
367 */
368 static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
369 Doc seq;
370 if (vec.size() != 0) {
371 seq = vec[0];
372 for (size_t i = 1; i < vec.size(); i++) {
373 seq << sep << vec[i];
374 }
375 }
376 return seq;
377 }
378
379 /*!
380 * \brief dump meta info
381 * \return Doc with meta info
382 */
383 Doc DumpMeta() {
384 if (show_meta_) {
385 return Doc::Text("__tvm_meta__ = ")
386 << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
387 } else {
388 return Doc::Text("");
389 }
390 }
391
392 /*!
393 * \brief special method to print out data type
394 * \param dtype The data type
395 */
396 static Doc PrintDType(DataType dtype) {
397 return Doc::StrLiteral(runtime::DLDataType2String(dtype));
398 }
399
400 /*!
401 * \brief special method to print out const int64_t scalar
402 * \param dtype The data type
403 * \param data The pointer to hold the data.
404 */
405 Doc PrintConstScalar(DataType dtype, const int64_t* data) const {
406 Doc doc;
407 std::ostringstream os;
408
409 os << data[0];
410
411 if (dtype == DataType::Int(32)) {
412 doc << Doc::Text(os.str());
413 } else if (dtype == DataType::Bool()) {
414 doc << Doc::Text(data[0] ? "True" : "False");
415 } else {
416 doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str())
417 << ")";
418 }
419 return doc;
420 }
421
422 /*!
423 * \brief special method to print out const double scalar
424 * \param dtype The data type
425 * \param data The pointer to hold the data.
426 * \note this overriden function is created as std::isnan of msvc will complain about int64_t
427 */
428 Doc PrintConstScalar(DataType dtype, const double* data) const {
429 Doc doc;
430 std::ostringstream os;
431
432 os.precision(17);
433 if (std::isinf(data[0]) || std::isnan(data[0])) {
434 os << "\"" << data[0] << "\"";
435 } else {
436 os << data[0];
437 }
438
439 doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str())
440 << ")";
441
442 return doc;
443 }
444
445 public:
446 static Doc PrintHeader(const std::string& tir_prefix) {
447 Doc header;
448 if (tir_prefix != "tir") {
449 header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine();
450 } else {
451 header << "# from tvm.script import tir" << Doc::NewLine();
452 }
453 return header;
454 }
455};
456
457/*!
458 * \brief special method to print NDArray in TIR
459 * \param arr the NDArray to be printed
460 * \param os the output stream where the NDArray will be printed to
461 */
462template <typename T>
463void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
464 if ((arr.DataType().code() == runtime::DataType::kInt ||
465 arr.DataType().code() == runtime::DataType::kUInt) &&
466 arr.DataType().bits() == 8) {
467 // Printing int8 NDArrays causes "UnicodeDecodeError: 'utf-8' codec can't decode byte"
468 // error during MetaSchedule tuning on int8 models.
469 return;
470 }
471 int ndim = arr->ndim;
472 int tot_dim = 1;
473 for (int i = 0; i < ndim; i++) {
474 tot_dim *= arr->shape[i];
475 }
476 T* data_ptr = reinterpret_cast<T*>(arr->data);
477 constexpr int NUM_PRINT = 20;
478 os << "[";
479 for (int i = 0; i < tot_dim; i++) {
480 os << (i != 0 ? ", " : "") << data_ptr[i];
481 if (i == NUM_PRINT) {
482 os << "...";
483 break;
484 }
485 }
486 os << "]";
487}
488
489Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
490 std::replace(prefix.begin(), prefix.end(), '.', '_');
491 std::string unique_prefix = prefix;
492 auto it = name_alloc_map_.find(prefix);
493 if (it != name_alloc_map_.end() && it->second >= 0) {
494 while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {
495 }
496 }
497 name_alloc_map_[unique_prefix] = 0;
498 return Doc::Text(unique_prefix);
499}
500
501Doc TVMScriptPrinter::AllocVar(const tir::Var& var) {
502 const auto& it = memo_var_.find(var);
503 if (it != memo_var_.end()) {
504 return it->second;
505 }
506 std::string name = var->name_hint.operator std::string();
507 if (name.length() == 0 || !std::isalpha(name[0])) {
508 name = "v" + name;
509 }
510 Doc val = GetUniqueName(name);
511 memo_var_[var] = val;
512 return val;
513}
514
515Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
516 Doc doc = Print(buf->shape);
517 bool print_factor_explicitly = false;
518 doc << ", dtype=" << PrintDType(buf->dtype);
519 if (memo_var_.find(buf->data) != memo_var_.end()) {
520 doc << ", data=" << Print(buf->data);
521 } else {
522 // implicitly define data
523 memo_var_[buf->data] = Doc::Text(memo_buf_[buf].str() + ".data");
524 var_not_in_headers_.insert(buf->data.get());
525 }
526 if (!buf->strides.empty()) {
527 doc << ", strides=" << Print(buf->strides);
528 }
529 if (buf->elem_offset->IsInstance<tir::VarNode>()) {
530 tir::Var elem_offset = Downcast<tir::Var>(buf->elem_offset);
531 if (memo_var_.find(elem_offset) != memo_var_.end()) {
532 doc << ", elem_offset=" << Print(buf->elem_offset);
533 } else {
534 // implicitly define elem_offset
535 memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset");
536 var_not_in_headers_.insert(elem_offset.get());
537 print_factor_explicitly = true;
538 }
539 } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
540 IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
541 if (elem_offset->value != 0) {
542 doc << ", elem_offset=" << Print(buf->elem_offset);
543 }
544 }
545 if (buf.scope() != "global") {
546 doc << ", scope=" << Doc::StrLiteral(buf.scope());
547 }
548 if (buf->data_alignment != runtime::kAllocAlignment) {
549 doc << ", align=" << buf->data_alignment;
550 }
551 if (buf->offset_factor != 1 || print_factor_explicitly) {
552 doc << ", offset_factor=" << buf->offset_factor;
553 }
554 if (buf->buffer_type != BufferType::kDefault) {
555 doc << ", type=" << Doc::StrLiteral("auto");
556 }
557 if (buf->axis_separators.size()) {
558 doc << ", axis_separators=" << Print(buf->axis_separators);
559 }
560 return doc;
561}
562
563Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
564 const auto& it = memo_buf_.find(buffer);
565 if (it != memo_buf_.end()) {
566 return it->second;
567 }
568 std::string name = buffer->name;
569 if (name.length() == 0 || !std::isalpha(name[0])) {
570 name = "buf_" + name;
571 }
572 Doc val = GetUniqueName(name);
573 memo_buf_[buffer] = val;
574 memo_buf_decl_[buffer] = AllocBufferDeclaration(buffer);
575 return val;
576}
577
578/*!
579 * \brief Check if any optional information exists in annotate_ for
580 * a given Stmt.
581 * \param stmt The statement.
582 */
583bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) {
584 if (annotate_ == nullptr) return false;
585 return !annotate_(stmt).empty();
586}
587
588/*!
589 * \brief Try to dealloc vars out of space and leave the index to coming vars.
590 * \note It is not a necessary step.
591 */
592void TVMScriptPrinter::TryDeallocVar(const tir::Var& var) {
593 auto it = memo_var_.find(var);
594 ICHECK(it != memo_var_.end());
595 std::string print_name = it->second.str();
596
597 std::string name_hint = var->name_hint.operator std::string();
598 if (name_hint.length() == 0 || !std::isalpha(name_hint[0])) {
599 name_hint = "v" + name_hint;
600 }
601 std::replace(name_hint.begin(), name_hint.end(), '.', '_');
602
603 auto it2 = name_alloc_map_.find(name_hint);
604 // Skip it if we can not find the name_hint in name_alloc_map_.
605 if (it2 == name_alloc_map_.end()) return;
606 if (it2->second > 0) {
607 name_hint = name_hint + '_' + std::to_string(it2->second);
608 }
609 // Skip it if the name_hint is not equal to how it should be printed.
610 if (name_hint != print_name) return;
611 // Free the conresponding name_alloc_map_ index
612 --it2->second;
613}
614
615Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
616 const Buffer& buf = op->buffer;
617 buf_not_in_headers_.insert(buf.get());
618
619 Doc doc = Print(op->buffer) << " = " << tir_prefix_ << ".match_buffer(" << Print(op->source)
620 << ", " << memo_buf_decl_[op->buffer] << ")";
621 return doc;
622}
623
624// check if all arguments, except the first two, are specified for T.match_buffer
625// if not, then this match buffer is printed out as T.buffer in prim_func arguments
626// and check whether there are undefined variables in the shape/strides.
627bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
628 if (memo_var_.find(buf->data) != memo_var_.end()) {
629 return false;
630 }
631 if (!buf->strides.empty()) {
632 return false;
633 }
634 for (const PrimExpr& shp_i : buf->shape) {
635 if (!UndefinedVars(shp_i).empty()) {
636 return false;
637 }
638 }
639 for (const PrimExpr& stride_i : buf->strides) {
640 if (!UndefinedVars(stride_i).empty()) {
641 return false;
642 }
643 }
644 if (!UndefinedVars(buf->elem_offset).empty()) {
645 return false;
646 } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
647 IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
648 if (elem_offset->value != 0) {
649 return false;
650 }
651 }
652 if (buf.scope() != "global") {
653 return false;
654 }
655 if (buf->data_alignment != runtime::kAllocAlignment) {
656 return false;
657 }
658 if (buf->offset_factor != 1) {
659 return false;
660 }
661 if (buf->buffer_type != BufferType::kDefault) {
662 return false;
663 }
664 if (buf->axis_separators.size()) {
665 return false;
666 }
667 return true;
668}
669
670Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) {
671 Doc doc;
672 doc << tir_prefix_ << ".Buffer[";
673 if (buffer->shape.size() == 1) {
674 doc << Print(buffer->shape[0]);
675 } else {
676 doc << PrintTuple(buffer->shape.as<ArrayNode>());
677 }
678 doc << ", " << PrintDType(buffer->dtype) << "]";
679 return doc;
680}
681
682// print array out as tuple with parentheses
683Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) {
684 Doc doc;
685 doc << '(';
686 for (size_t i = 0; i < op->size(); ++i) {
687 if (i != 0) {
688 doc << ", ";
689 }
690 doc << Print(op->at(i));
691 }
692 if (op->size() == 1) doc << ",";
693 doc << ')';
694 return doc;
695}
696
697Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
698 Doc doc;
699 int n_var = static_cast<int>(op->rhs.size());
700
701 doc << tir_prefix_ << ".comm_reducer(lambda ";
702 for (const tir::Var& v_lhs : op->lhs) {
703 doc << Print(v_lhs) << ", ";
704 }
705 for (int i = 0; i < n_var; ++i) {
706 doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", ");
707 }
708 if (n_var == 1) {
709 doc << Print(op->result[0]) << ", ";
710 } else {
711 doc << "(";
712 for (int i = 0; i < n_var; ++i) {
713 doc << Print(op->result[i]);
714 if (i != n_var - 1) {
715 doc << ", ";
716 }
717 }
718 doc << "), ";
719 }
720 doc << Print(op->identity_element) << ")";
721
722 // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda.
723 for (int i = 0; i < n_var; ++i) {
724 memo_var_.erase(op->lhs[i]);
725 memo_var_.erase(op->rhs[i]);
726 }
727 return doc;
728}
729
730Doc TVMScriptPrinter::Print(const ObjectRef& node) {
731 if (!node.defined()) return Doc::Text("None");
732 if (node->IsInstance<StmtNode>()) {
733 return PrintOptionalInfo(Downcast<Stmt>(node)) << VisitStmt(Downcast<Stmt>(node));
734 } else if (node->IsInstance<PrimExprNode>()) {
735 ExprPrecedence t = ExprPrecedence::kUnknown;
736 return VisitExpr(Downcast<PrimExpr>(node), &t);
737 } else if (node->IsInstance<TypeNode>()) {
738 return VisitType(Downcast<Type>(node));
739 } else if (node->IsInstance<PrimFuncNode>()) {
740 return PrintPrimFunc(Downcast<PrimFunc>(node));
741 } else if (node->IsInstance<IRModuleNode>()) {
742 return PrintIRModule(Downcast<IRModule>(node));
743 } else if (node->IsInstance<ArrayNode>()) {
744 return PrintArray(node.as<ArrayNode>());
745 } else if (node->IsInstance<BufferNode>()) {
746 return PrintBuffer(node.as<BufferNode>());
747 } else if (node->IsInstance<StringObj>()) {
748 return PrintString(node.as<StringObj>());
749 } else if (node->IsInstance<IterVarNode>()) {
750 return PrintIterVar(node.as<IterVarNode>());
751 } else if (node->IsInstance<RangeNode>()) {
752 return PrintRange(node.as<RangeNode>());
753 } else if (node->IsInstance<BufferRegionNode>()) {
754 return PrintBufferRegion(node.as<BufferRegionNode>());
755 } else if (node->IsInstance<MatchBufferRegionNode>()) {
756 return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>());
757 } else if (node->IsInstance<CommReducerNode>()) {
758 return PrintCommReducer(node.as<CommReducerNode>());
759 } else if (node->IsInstance<TargetNode>()) {
760 return PrintTarget(node.as<TargetNode>());
761 } else {
762 LOG(FATAL) << "Do not know how to print " << node->GetTypeKey();
763 }
764}
765
766Doc TVMScriptPrinter::VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) {
767 LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
768}
769
770Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) {
771 LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
772}
773
774Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) {
775 *out_precedence = ExprPrecedence::kIdentity;
776 return PrintConstScalar(op->dtype, &(op->value));
777}
778
779Doc TVMScriptPrinter::VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) {
780 *out_precedence = ExprPrecedence::kIdentity;
781 return PrintConstScalar(op->dtype, &(op->value));
782}
783
784Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) {
785 *out_precedence = ExprPrecedence::kIdentity;
786 return Doc::StrLiteral(op->value);
787}
788
789Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) {
790 *out_precedence = ExprPrecedence::kIdentity;
791 Doc doc;
792 doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
793 return doc;
794}
795
796Doc TVMScriptPrinter::VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) {
797 *out_precedence = ExprPrecedence::kIdentity;
798 const tir::Var& var = GetRef<tir::Var>(op);
799 return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op));
800}
801
802bool WillPrintConstScalar(const PrimExpr& expr) {
803 if (const auto* imm = expr.as<IntImmNode>()) {
804 DataType dtype = imm->dtype;
805 return dtype == DataType::Int(32) || dtype == DataType::Bool();
806 }
807 return false;
808}
809
810#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \
811 Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
812 Doc doc; \
813 if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \
814 *out_precedence = ExprPrecedence::kIdentity; \
815 doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
816 return doc; \
817 } \
818 ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
819 ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
820 /* Get children expr out_precedence */ \
821 Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
822 Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
823 ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
824 ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
825 /* Update out_precedence of current node. */ \
826 *out_precedence = OpPrecedence; \
827 if (lhs_precedence > OpPrecedence || \
828 (lhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \
829 doc << "(" << lhs_doc << ")"; \
830 } else { \
831 doc << lhs_doc; \
832 } \
833 doc << OpString; \
834 if (rhs_precedence >= OpPrecedence || \
835 (rhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \
836 doc << "(" << rhs_doc << ")"; \
837 } else { \
838 doc << rhs_doc; \
839 } \
840 return doc; \
841 }
842
// Instantiate VisitExpr_ for each binary operator node via the macro above. Each overload prints
// `<tir_prefix_>.<OpClass>(a, b)` when both operands will be printed as constant scalars; otherwise
// it prints the infix form `a<OpString>b`. The left operand is parenthesized when it binds more
// loosely than the operator. The right operand is parenthesized when it binds more loosely or
// equally, which keeps e.g. `a - (b - c)` grouped. Either operand is also parenthesized when it is
// an `and` under an `or`.
// NOTE(review): the ExprPrecedence enum comment says FloorDiv/FloorMod are marked kIdentity because
// they print as function calls, but here they are printed infix (" // ", " % ") with
// kMultiplicationDivision precedence; one of the two should be brought in line with the other.
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
                                    ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
                                    ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", ExprPrecedence::kAnd)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)
859
860Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
861 *out_precedence = ExprPrecedence::kIdentity;
862 Doc doc;
863 doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")";
864 return doc;
865}
866
867Doc TVMScriptPrinter::VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) {
868 *out_precedence = ExprPrecedence::kIdentity;
869 Doc doc;
870 doc << tir_prefix_ << ".min(" << Print(op->a) << ", " << Print(op->b) << ")";
871 return doc;
872}
873
874Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) {
875 *out_precedence = ExprPrecedence::kIdentity;
876 Doc doc;
877 doc << tir_prefix_ << ".max(" << Print(op->a) << ", " << Print(op->b) << ")";
878 return doc;
879}
880
881Doc TVMScriptPrinter::VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) {
882 *out_precedence = ExprPrecedence::kIdentity;
883 Doc doc;
884 doc << "not(" << Print(op->a) << ")";
885 return doc;
886}
887
888Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) {
889 *out_precedence = ExprPrecedence::kIdentity;
890 Doc doc;
891 doc << tir_prefix_ << ".Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
892 << Print(op->false_value) << ")";
893 return doc;
894}
895
896Doc TVMScriptPrinter::VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) {
897 LOG(FATAL) << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to "
898 "lower this function first.";
899 return Doc();
900}
901
902Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) {
903 *out_precedence = ExprPrecedence::kIdentity;
904 Doc doc;
905 if (op->indices.size() == 0) {
906 doc << Print(op->buffer) << "[()]";
907 } else {
908 doc << Print(op->buffer) << PrintBufferIndices(op->indices);
909 }
910 return doc;
911}
912
913Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) {
914 *out_precedence = ExprPrecedence::kIdentity;
915 Doc doc;
916 if (op->dtype == DataType::Float(32) && is_one(op->predicate) &&
917 op->buffer_var->dtype == DataType::Float(32)) {
918 doc << Print(op->buffer_var) << "[" << Print(op->index) << "]";
919 } else {
920 doc << tir_prefix_ << ".load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", "
921 << Print(op->index);
922 if (!is_one(op->predicate) || op->dtype.lanes() != 1) {
923 doc << ", " << Print(op->predicate);
924 }
925 doc << ")";
926 }
927 return doc;
928}
929
930Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) {
931 *out_precedence = ExprPrecedence::kIdentity;
932 Doc doc;
933 doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", "
934 << op->lanes << ")";
935 return doc;
936}
937
938Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) {
939 *out_precedence = ExprPrecedence::kIdentity;
940 Doc doc;
941 doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")";
942 return doc;
943}
944
945Doc TVMScriptPrinter::VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) {
946 *out_precedence = ExprPrecedence::kIdentity;
947 Doc doc;
948 doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", "
949 << Print(op->body) << ")";
950 return doc;
951}
952
953Doc TVMScriptPrinter::VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) {
954 *out_precedence = ExprPrecedence::kIdentity;
955 Doc doc;
956 if (auto* ptr_op = op->op.as<OpNode>()) {
957 std::string name = ptr_op->name;
958 if (name.find("tir.") == 0) {
959 name = tir_prefix_ + "." + name.substr(4);
960 }
961 doc << name << "(";
962 } else {
963 auto* op_gvar = op->op.as<GlobalVarNode>();
964 ICHECK(op_gvar != nullptr);
965 doc << Doc::Text(op_gvar->name_hint) << "(";
966 }
967 std::vector<Doc> args;
968 for (const auto& arg : op->args) {
969 args.push_back(Print(arg));
970 }
971 args.push_back(Doc::Text("dtype=") << PrintDType(op->dtype));
972 doc << PrintSep(args, Doc::Text(", ")) << ")";
973 return doc;
974}
975
976Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) {
977 *out_precedence = ExprPrecedence::kIdentity;
978 Doc doc;
979 doc << tir_prefix_ << ".shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")";
980 return doc;
981}
982
983Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) {
984 *out_precedence = ExprPrecedence::kIdentity;
985 Doc doc;
986 doc << tir_prefix_ << ".reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", "
987 << Print(op->axis) << ", " << op->value_index << ")";
988 return doc;
989}
990
991Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
992 if (!buffer_var_usage_.count(op->var)) {
993 buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
994 }
995 Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->var).value_or({});
996
997 Doc doc;
998 if (current_num_ != num_child_ - 1) {
999 doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
1000 doc << Doc::Indent(
1001 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
1002 } else {
1003 if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
1004 doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
1005 << Doc::NewLine();
1006 doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
1007 }
1008 return doc;
1009}
1010
1011Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
1012 Doc doc;
1013 if (op->node.defined()) {
1014 // merge attr with realize when possible
1015 if (op->node->IsInstance<BufferNode>() && op->attr_key == "realize_scope" &&
1016 op->body->IsInstance<BufferRealizeNode>()) {
1017 const auto* realize = Downcast<BufferRealize>(op->body).get();
1018 if (realize->buffer.same_as(op->node)) {
1019 if (current_num_ != num_child_ - 1) {
1020 doc << "with " << tir_prefix_ << ".realize(" << Print(realize->buffer)
1021 << Print(realize->bounds) << ", " << Print(op->value);
1022 if (!is_one(realize->condition)) {
1023 doc << ", " << Print(realize->condition);
1024 }
1025 doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body));
1026 } else {
1027 doc << tir_prefix_ << ".realize(" << Print(realize->buffer) << Print(realize->bounds)
1028 << ", " << Print(op->value);
1029 if (!is_one(realize->condition)) {
1030 doc << ", " << Print(realize->condition);
1031 }
1032 doc << ")" << Doc::NewLine() << PrintBody(realize->body);
1033 }
1034 return doc;
1035 }
1036 }
1037 // concise thread env
1038 if (op->node->IsInstance<IterVarNode>() &&
1039 (op->attr_key == "thread_extent" || op->attr_key == "virtual_thread")) {
1040 const auto* iter_var = Downcast<IterVar>(op->node).get();
1041 var_not_in_headers_.insert(iter_var->var.get());
1042 var_env_map_[iter_var->var] = iter_var->thread_tag;
1043 if (current_num_ != num_child_ - 1) {
1044 doc << "with " << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", "
1045 << Print(op->value) << "):";
1046 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1047 } else {
1048 doc << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " << Print(op->value)
1049 << ")";
1050 doc << Doc::NewLine() << PrintBody(op->body);
1051 }
1052 return doc;
1053 }
1054 }
1055 // default
1056 if (current_num_ != num_child_ - 1) {
1057 doc << "with " << tir_prefix_ << ".attr(" << Print(op->node) << ", "
1058 << Doc::StrLiteral(op->attr_key) << ", " << Print(op->value) << "):";
1059 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1060 } else {
1061 doc << tir_prefix_ << ".attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key)
1062 << ", " << Print(op->value) << ")";
1063 doc << Doc::NewLine() << PrintBody(op->body);
1064 }
1065 return doc;
1066}
1067
1068Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) {
1069 Doc doc;
1070 if (current_num_ != num_child_ - 1) {
1071 doc << "with " << tir_prefix_ << ".Assert(" << Print(op->condition) << ", "
1072 << Print(op->message) << "):";
1073 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1074 } else {
1075 doc << "assert " << Print(op->condition) << ", " << Print(op->message);
1076 doc << Doc::NewLine() << PrintBody(op->body);
1077 }
1078 return doc;
1079}
1080
1081Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) {
1082 Doc doc;
1083 doc << tir_prefix_ << ".store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", "
1084 << Print(op->value) << ", " << Print(op->predicate) << ")";
1085 return doc;
1086}
1087
1088Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
1089 LOG(FATAL)
1090 << "TVM Script Printer Internal Error: All the BufferRealize should be folded with Attr";
1091 return Doc();
1092}
1093
1094namespace {
1095
1096bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
1097 const tir::Var& buffer_var = allocate->buffer_var;
1098 const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
1099 if (!decl_buffer) {
1100 return false;
1101 }
1102 const Buffer& buffer = decl_buffer->buffer;
1103 if (!buffer_var.same_as(buffer->data)) {
1104 return false;
1105 }
1106 if (allocate->dtype != buffer->dtype) {
1107 return false;
1108 }
1109 if (!is_one(allocate->condition)) {
1110 return false;
1111 }
1112 if (allocate->annotations.size()) {
1113 return false;
1114 }
1115 if (allocate->extents.size() != buffer->shape.size()) {
1116 return false;
1117 }
1118 tir::ExprDeepEqual expr_equal;
1119 for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
1120 if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
1121 return false;
1122 }
1123 }
1124 return true;
1125}
1126
1127} // namespace
1128
1129Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
1130 var_not_in_headers_.insert(op->buffer_var.get());
1131
1132 if (!buffer_var_usage_.count(op->buffer_var)) {
1133 buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
1134 }
1135 Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});
1136
1137 if (buffer_usage.empty()) {
1138 if (IsAllocateDeclBufferPattern(op)) {
1139 // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single
1140 // DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to
1141 // delegate the printing of the current node to `DeclBufferNode` while maintaining the
1142 // same value of `current_num_` and `num_child_`.
1143 return Print(op->body);
1144 }
1145 }
1146
1147 auto storage_scope = GetPtrStorageScope(op->buffer_var);
1148 Doc func_call;
1149 func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
1150 << ", " << Print(storage_scope);
1151 if (!is_one(op->condition)) {
1152 func_call << ", " << Print(op->condition);
1153 }
1154 if (!op->annotations.empty()) {
1155 func_call << ", annotations={";
1156 func_call << PrintAnnotations(op->annotations);
1157 func_call << "}";
1158 }
1159 func_call << ")";
1160
1161 Doc doc;
1162 if (current_num_ != num_child_ - 1) {
1163 doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
1164 doc << Doc::Indent(
1165 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
1166 } else {
1167 doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
1168 doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
1169 }
1170 TryDeallocVar(op->buffer_var);
1171 return doc;
1172}
1173
1174Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
1175 std::stringstream ss;
1176 ICHECK(alloc->data) << "Should be presented";
1177 const auto& data = alloc->data.value();
1178
1179 if (alloc->dtype.is_int()) {
1180 if (alloc->dtype.bits() == 8) {
1181 NDArrayToTIR<int8_t>(data, ss);
1182 } else if (alloc->dtype.bits() == 16) {
1183 NDArrayToTIR<int16_t>(data, ss);
1184 } else if (alloc->dtype.bits() == 32) {
1185 NDArrayToTIR<int32_t>(data, ss);
1186 } else if (alloc->dtype.bits() == 64) {
1187 NDArrayToTIR<int64_t>(data, ss);
1188 } else {
1189 LOG(FATAL) << "DataType not supported";
1190 }
1191 } else if (alloc->dtype.is_uint()) {
1192 if (alloc->dtype.bits() == 8) {
1193 NDArrayToTIR<uint8_t>(data, ss);
1194 } else if (alloc->dtype.bits() == 16) {
1195 NDArrayToTIR<uint16_t>(data, ss);
1196 } else if (alloc->dtype.bits() == 32) {
1197 NDArrayToTIR<uint32_t>(data, ss);
1198 } else if (alloc->dtype.bits() == 64) {
1199 NDArrayToTIR<int64_t>(data, ss);
1200 } else {
1201 LOG(FATAL) << "DataType not supported";
1202 }
1203 } else if (alloc->dtype.is_float()) {
1204 if (alloc->dtype.bits() == 16) {
1205 NDArrayToTIR<int16_t>(data, ss);
1206 } else if (alloc->dtype.bits() == 32) {
1207 NDArrayToTIR<float>(data, ss);
1208 } else if (alloc->dtype.bits() == 64) {
1209 NDArrayToTIR<double>(data, ss);
1210 } else {
1211 LOG(FATAL) << "DataType not supported";
1212 }
1213 } else {
1214 LOG(FATAL) << "DataType not supported";
1215 }
1216 auto ndarray_str = ss.str();
1217
1218 var_not_in_headers_.insert(alloc->buffer_var.get());
1219
1220 if (!buffer_var_usage_.count(alloc->buffer_var)) {
1221 buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body);
1222 }
1223 Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({});
1224
1225 Doc func_call;
1226 func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype)
1227 << ", " << Print(alloc->extents) << ")";
1228
1229 Doc doc;
1230 var_not_in_headers_.insert(alloc->buffer_var.get());
1231 if (current_num_ != num_child_ - 1) {
1232 doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":";
1233 doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage)
1234 << PrintBody(alloc->body));
1235 } else {
1236 doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine();
1237 doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body);
1238 }
1239 return doc;
1240}
1241
1242Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) {
1243 const Buffer& buffer = op->buffer;
1244 buf_not_in_headers_.insert(buffer.get());
1245 Doc buffer_name = Print(op->buffer);
1246 Doc func_call;
1247 func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")";
1248
1249 Doc doc;
1250 if (current_num_ != num_child_ - 1) {
1251 doc << "with " << func_call << " as " << buffer_name << ":";
1252 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1253 } else {
1254 doc << buffer_name << " = " << func_call << Doc::NewLine();
1255 doc << PrintBody(op->body);
1256 }
1257 return doc;
1258}
1259
1260Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
1261 Doc doc;
1262 doc << "if " << Print(op->condition) << ":";
1263 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
1264
1265 Optional<Stmt> else_case = op->else_case;
1266 while (else_case) {
1267 if (auto* else_if = else_case.value().as<IfThenElseNode>()) {
1268 doc << Doc::NewLine();
1269 doc << "elif " << Print(else_if->condition) << ":";
1270 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(else_if->then_case));
1271
1272 else_case = else_if->else_case;
1273 } else {
1274 doc << Doc::NewLine();
1275 doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(else_case.value()));
1276 break;
1277 }
1278 }
1279
1280 return doc;
1281}
1282
1283Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
1284 std::vector<Doc> stmts;
1285 for (Stmt stmt : op->seq) {
1286 stmts.push_back(Print(stmt));
1287 }
1288 return PrintSep(stmts, Doc::NewLine());
1289}
1290
1291Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
1292 // When parsing TVMScript, a PrimExpr that occurs as a statement is
1293 // automatically wrapped in `tir::Evaluate`. Therefore, when
1294 // printing, it's only necessary to print the value. For
1295 // readability, though, we still print T.evaluate() when the
1296 // expression is something other than a call node.
1297 Doc doc;
1298 if (op->value.as<CallNode>()) {
1299 doc << Print(op->value);
1300 } else {
1301 doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
1302 }
1303 return doc;
1304}
1305
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
  Doc doc;
  // The loop var is bound by the loop header, so keep it out of the function-header var
  // declarations. Also record the loop in loop_var_map_, which PrintBlockVars consults to detect
  // block bindings that simply remap a loop var.
  var_not_in_headers_.insert(op->loop_var.get());
  loop_var_map_[op->loop_var.get()] = GetRef<For>(op);
  const auto* body = op->body.as<ForNode>();
  // Simple loops (per IsSimpleLoop, defined elsewhere) are accumulated on simple_loop_stack_ and
  // emitted together later by PrintLoopStack().
  bool simple_loop = IsSimpleLoop(op);
  if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
  // It is a loop that can be compressed, let the loops below print it out.
  // The body is itself a simple loop that does not depend on the enclosing loops (per
  // DependOnPrevLoops), so this loop stays on the stack and a deeper visit prints the stack.
  if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
    doc << Print(GetRef<For>(body));
    // This loop's scope ends here: TryDeallocVar (defined elsewhere) is called on the loop var and
    // the loop is dropped from loop_var_map_.
    TryDeallocVar(op->loop_var);
    loop_var_map_.erase(op->loop_var.get());
    return doc;
  }
  // It is a loop that can not be compressed
  bool print_above = !simple_loop_stack_.empty();
  // print loops above if needed
  // This flushes every pending simple loop, including the current one when it is simple.
  if (print_above) {
    doc << PrintLoopStack();
    simple_loop_stack_.clear();
  }
  if (!simple_loop) {
    // print current loop if needed
    // The loop gets its own header; it is nested one level under the stack header if one was
    // printed above.
    Doc current_loop;
    current_loop << PrintLoop(GetRef<For>(op));
    current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
    doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop);
  } else {
    // The current loop was already emitted as part of the stack; only its body remains.
    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  }
  TryDeallocVar(op->loop_var);
  loop_var_map_.erase(op->loop_var.get());
  return doc;
}
1340
1341Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) {
1342 Doc doc;
1343 doc << tir_prefix_ << ".prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
1344 return doc;
1345}
1346
1347Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) {
1348 Doc doc;
1349 doc << "while " << Print(op->condition) << ":";
1350 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1351 return doc;
1352}
1353
1354Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
1355 Doc doc;
1356 doc << tir_prefix_ << ".";
1357 if (node->dtype.is_void()) {
1358 doc << "void";
1359 } else {
1360 doc << runtime::DLDataType2String(node->dtype);
1361 }
1362 return doc;
1363}
1364
1365Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
1366 Doc doc;
1367 doc << tir_prefix_ << ".Ptr[";
1368 doc << Print(node->element_type);
1369 if (!node->storage_scope.empty()) {
1370 doc << ", " << Doc::StrLiteral(node->storage_scope);
1371 }
1372 doc << "]";
1373 return doc;
1374}
1375
1376Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {
1377 if (node->fields.empty()) {
1378 return Doc::Text("None");
1379 } else {
1380 std::vector<Doc> fields;
1381 for (Type field : node->fields) {
1382 fields.push_back(Print(field));
1383 }
1384 return Doc::Text(tir_prefix_ + ".Tuple[") << Doc::Concat(fields) << "]";
1385 }
1386}
1387
1388Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
1389 Doc doc;
1390 if (op->indices.size() == 0) {
1391 doc << Print(op->buffer) << "[()] = " << Print(op->value);
1392 } else {
1393 doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value);
1394 }
1395 return doc;
1396}
1397
1398/*! Helper functions for block printing. */
1399Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) {
1400 Doc doc;
1401 doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis.";
1402 switch (iter_var->iter_type) {
1403 case kDataPar:
1404 doc << "spatial";
1405 break;
1406 case kCommReduce:
1407 doc << "reduce";
1408 break;
1409 case kOrdered:
1410 doc << "scan";
1411 break;
1412 case kOpaque:
1413 doc << "opaque";
1414 break;
1415 default:
1416 LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type;
1417 break;
1418 }
1419 doc << "(";
1420 const Range& dom = iter_var->dom;
1421 if (is_zero(dom->min)) {
1422 doc << Print(dom->extent);
1423 } else {
1424 doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")";
1425 }
1426 doc << ", " << Print(value) << ")";
1427 return doc;
1428}
1429
1430Doc TVMScriptPrinter::PrintBlockVarRemaps() {
1431 ICHECK(!block_var_remaps_.empty());
1432 if (block_var_remaps_.size() == 1) {
1433 const IterVar& iter_var = block_var_remaps_[0].first;
1434 const PrimExpr& value = block_var_remaps_[0].second;
1435 return PrintBlockVar(iter_var, value);
1436 }
1437 Doc doc;
1438 std::vector<Doc> iter_vars, iter_values;
1439 std::string iter_type;
1440 for (const auto& pair : block_var_remaps_) {
1441 const IterVar& iter_var = pair.first;
1442 const PrimExpr& value = pair.second;
1443 iter_vars.push_back(Print(iter_var->var));
1444 iter_values.push_back(Print(value));
1445 if (iter_var->iter_type == kDataPar) {
1446 iter_type += "S";
1447 } else if (iter_var->iter_type == kCommReduce) {
1448 iter_type += "R";
1449 } else {
1450 ICHECK(false);
1451 }
1452 }
1453 doc << PrintSep(iter_vars, Doc::Text(", ")) << " = " << tir_prefix_ << ".axis.remap("
1454 << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", ")) << "])";
1455 return doc;
1456}
1457
1458Doc TVMScriptPrinter::PrintBlockPredicate(const BlockRealizeNode* op) {
1459 Doc doc;
1460 if (!is_one(op->predicate)) {
1461 doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")";
1462 }
1463 return doc;
1464}
1465
1466Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) {
1467 Doc doc;
1468 const auto* block_op = op->block.as<BlockNode>();
1469 ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size());
1470 tir::ExprDeepEqual expr_equal;
1471
1472 auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var,
1473 const PrimExpr& value) -> bool {
1474 if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false;
1475 if (!value->IsInstance<tir::VarNode>()) return false;
1476 const tir::Var& var = Downcast<tir::Var>(value);
1477 auto it = loop_var_map_.find(var.get());
1478 return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) &&
1479 expr_equal(it->second->extent, iter_var->dom->extent);
1480 };
1481
1482 for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
1483 const IterVar& iter_var = block_op->iter_vars[i];
1484 const PrimExpr& value = op->iter_values[i];
1485 var_not_in_headers_.insert(iter_var->var.get());
1486 if (is_simple_remap(iter_var, value)) {
1487 block_var_remaps_.push_back(std::make_pair(iter_var, value));
1488 } else {
1489 if (!block_var_remaps_.empty()) {
1490 doc << Doc::NewLine() << PrintBlockVarRemaps();
1491 block_var_remaps_.clear();
1492 }
1493 doc << Doc::NewLine() << PrintBlockVar(iter_var, value);
1494 }
1495 }
1496 if (!block_var_remaps_.empty()) {
1497 doc << Doc::NewLine() << PrintBlockVarRemaps();
1498 block_var_remaps_.clear();
1499 }
1500 return doc;
1501}
1502
1503Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
1504 const auto* block_op = op->block.as<BlockNode>();
1505 Doc block_attr_doc;
1506 // print binding, read/write tensor region, annotations
1507 block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads("
1508 << PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")";
1509 block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes("
1510 << PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")";
1511 if (!block_op->annotations.empty()) {
1512 block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({";
1513 block_attr_doc << PrintAnnotations(block_op->annotations);
1514 block_attr_doc << "})";
1515 }
1516 return block_attr_doc;
1517}
1518
1519// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
1520// List. Therefore the brackets are removed before and after printing arguments out
1521Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) {
1522 Doc doc;
1523 for (size_t i = 0; i < op->size(); ++i) {
1524 if (i != 0) {
1525 doc << ", ";
1526 }
1527 doc << Print(op->at(i));
1528 }
1529 return doc;
1530}
1531
1532Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
1533 Doc body;
1534 for (const auto& alloc_buf : op->alloc_buffers) {
1535 buf_not_in_headers_.insert(alloc_buf.get());
1536 body << Print(alloc_buf) << " = " << tir_prefix_ << ".alloc_buffer("
1537 << memo_buf_decl_[alloc_buf] << ")" << Doc::NewLine();
1538 }
1539 for (const auto& match_buf : op->match_buffers) {
1540 body << Print(match_buf) << Doc::NewLine();
1541 }
1542 if (op->init.defined()) {
1543 Doc init_block;
1544 init_block << "with " << tir_prefix_ << ".init():";
1545 init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value()));
1546 body << init_block << Doc::NewLine();
1547 }
1548 body << PrintBody(op->body);
1549 return body;
1550}
1551
1552/*!
1553 * \brief Print the name of a block
1554 * \param block_op The block node to be printed
1555 */
1556Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) {
1557 Doc doc;
1558 doc << "with " << tir_prefix_ << ".block(";
1559 if (!block_op->name_hint.empty()) {
1560 doc << Doc::StrLiteral(block_op->name_hint);
1561 }
1562 doc << "):";
1563 return doc;
1564}
1565
1566Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
1567 const auto* block_op = op->block.as<BlockNode>();
1568 Doc doc = PrintOptionalInfo(GetRef<Stmt>(block_op));
1569 // print block name
1570 doc << PrintBlockName(block_op);
1571 // Print block predicate.
1572 Doc block_predicate = PrintBlockPredicate(op);
1573 // Print the variable bindings, valid to use in block attributes and
1574 // body
1575 Doc block_var = PrintBlockVars(op);
1576 // print read/write tensor region, annotations
1577 Doc block_attr_doc = PrintBlockAttr(op);
1578 // print body
1579 Doc body = PrintBlockBody(block_op);
1580 doc << Doc::Indent(4, block_predicate << block_var << block_attr_doc << Doc::NewLine() << body);
1581 for (const auto& iter_var : block_op->iter_vars) {
1582 TryDeallocVar(iter_var->var);
1583 }
1584 return doc;
1585}
1586
1587Doc TVMScriptPrinter::PrintBody(const Stmt& body) {
1588 int memo_num_child, memo_current_num;
1589 std::swap(memo_num_child, num_child_);
1590 std::swap(memo_current_num, current_num_);
1591
1592 Doc doc;
1593 if (body->IsInstance<SeqStmtNode>()) {
1594 const auto& op = Downcast<SeqStmt>(body);
1595 num_child_ = op->seq.size();
1596 current_num_ = 0;
1597 std::vector<Doc> stmts;
1598 for (Stmt stmt : op->seq) {
1599 stmts.push_back(Print(stmt));
1600 current_num_++;
1601 }
1602 doc = PrintSep(stmts, Doc::NewLine());
1603 } else {
1604 num_child_ = 1;
1605 current_num_ = 0;
1606 doc = Print(body);
1607 }
1608
1609 std::swap(memo_num_child, num_child_);
1610 std::swap(memo_current_num, current_num_);
1611 return doc;
1612}
1613
1614Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) {
1615 auto* op = module.operator->();
1616 Doc doc;
1617 doc << "@tvm.script.ir_module" << Doc::NewLine();
1618 doc << "class Module:";
1619 for (const auto& x : op->functions) {
1620 func2var_[x.second.operator->()] = x.first;
1621 }
1622 Doc body = Doc::NewLine();
1623 std::vector<Doc> functions;
1624 for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
1625 if ((*it).second.as<PrimFuncNode>()) {
1626 functions.push_back(Print((*it).second));
1627 }
1628 }
1629 body << TVMScriptPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
1630 body << Doc::NewLine() << DumpMeta();
1631 doc << Doc::Indent(4, body);
1632 return doc;
1633}
1634
1635Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
1636 auto* op = primFunc.operator->();
1637 // clear renaming map
1638 memo_var_.clear();
1639 memo_buf_.clear();
1640 memo_buf_decl_.clear();
1641 var_not_in_headers_.clear();
1642 buf_not_in_headers_.clear();
1643 // print signature
1644 Doc doc;
1645 doc << "@" << tir_prefix_ << ".prim_func" << Doc::NewLine();
1646 doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint)
1647 << "(";
1648 std::vector<Doc> params;
1649 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf;
1650 for (const auto& param : op->params) {
1651 var_not_in_headers_.insert(param.get());
1652 auto it = op->buffer_map.find(param);
1653 // check if this param is a T.handle
1654 if (it != op->buffer_map.end()) {
1655 // check if this match_buffer has only the first two arguments specified
1656 // and whether the match_buffer is a dynamic buffer.
1657 const Buffer& buf = (*it).second;
1658 if (IsSimpleBuffer(buf)) {
1659 simple_buf.insert(buf);
1660 buf_not_in_headers_.insert(buf.get());
1661 params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf));
1662 continue;
1663 }
1664 }
1665 params.push_back(Print(param) << ": " << Print(GetType(param)));
1666 }
1667 doc << PrintSep(params, Doc::Text(", ")) << ")";
1668 if (primFunc->ret_type.defined()) {
1669 auto as_tuple = primFunc->ret_type.as<TupleTypeNode>();
1670 if (!as_tuple || as_tuple->fields.size()) {
1671 doc << " -> " << Print(primFunc->ret_type);
1672 }
1673 }
1674 doc << ":";
1675
1676 Doc body = Doc::NewLine();
1677 // print buffer_bind
1678 for (const auto& param : op->params) {
1679 auto it = op->buffer_map.find(param);
1680 if (it == op->buffer_map.end()) continue;
1681 const Buffer& buf = (*it).second;
1682 if (simple_buf.count(buf)) continue;
1683 buf_not_in_headers_.insert(buf.get());
1684 body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(";
1685 ICHECK(memo_buf_decl_.count(buf));
1686 body << Print((*it).first) << ", " << memo_buf_decl_[buf];
1687 body << ")" << Doc::NewLine();
1688 }
1689 // print body
1690 body << "# body" << Doc::NewLine();
1691
1692 Optional<Block> elided_root_block_body = [&]() -> Optional<Block> {
1693 auto block_realize = op->body.as<BlockRealizeNode>();
1694 if (!block_realize || block_realize->iter_values.size()) {
1695 return NullOpt;
1696 }
1697
1698 const auto& block = block_realize->block;
1699 if (block->annotations.size() || ContainsOptionalInfo(block)) {
1700 return NullOpt;
1701 }
1702
1703 // The autocomplete might recognize the body itself as being a
1704 // root block, and fail to insert it.
1705 bool autocomplete_would_insert_root_block = [&]() -> bool {
1706 if (block->alloc_buffers.size()) {
1707 return true;
1708 }
1709
1710 auto* block_realize = block->body.as<BlockRealizeNode>();
1711 if (block_realize && block_realize->block->iter_vars.size()) {
1712 return true;
1713 }
1714 if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) {
1715 return true;
1716 }
1717 return false;
1718 }();
1719
1720 if (autocomplete_would_insert_root_block) {
1721 return block;
1722 } else {
1723 return NullOpt;
1724 }
1725 }();
1726
1727 if (elided_root_block_body) {
1728 // Skip printing of root block in cases where tvm::tir::ScriptComplete
1729 // would re-insert it.
1730 body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
1731 body << PrintBlockBody(elided_root_block_body.value().get());
1732 } else {
1733 // If this is a non-root block, or is an unskippable root block,
1734 // just print it without skipping.
1735 body << PrintBody(op->body);
1736 }
1737
1738 // print func attrs
1739 Doc header_attr;
1740 if (primFunc->attrs.defined()) {
1741 header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << tir_prefix_
1742 << ".func_attr({";
1743 std::vector<Doc> attrs;
1744 for (const auto& it : op->attrs->dict) {
1745 attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
1746 }
1747 header_attr << PrintSep(attrs, Doc::Text(", ")) << "})";
1748 }
1749 // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate)
1750 Doc header_buf;
1751 std::vector<const BufferNode*> bufs;
1752 for (const auto& it : memo_buf_) {
1753 if (buf_not_in_headers_.find(it.first.get()) == buf_not_in_headers_.end()) {
1754 bufs.push_back(it.first.get());
1755 }
1756 }
1757 if (!bufs.empty()) {
1758 header_buf << Doc::NewLine() << "# buffer definition";
1759 std::sort(bufs.begin(), bufs.end(), [&](const BufferNode* a, const BufferNode* b) {
1760 return memo_buf_[GetRef<Buffer>(a)].str() < memo_buf_[GetRef<Buffer>(b)].str();
1761 });
1762 for (const auto& buf : bufs) {
1763 header_buf << Doc::NewLine() << Print(GetRef<Buffer>(buf)) << " = " << tir_prefix_
1764 << ".buffer_decl(";
1765 header_buf << memo_buf_decl_[GetRef<Buffer>(buf)] << ")";
1766 }
1767 }
1768 // print var declaration
1769 Doc header_var;
1770 std::vector<const tir::VarNode*> vars;
1771 for (const auto& it : memo_var_) {
1772 if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) {
1773 vars.push_back(it.first.get());
1774 }
1775 }
1776 if (!var_env_map_.empty()) {
1777 header_var << Doc::NewLine() << "# var definition";
1778 for (const auto& it : var_env_map_) {
1779 header_var << Doc::NewLine() << Print(it.first) << " = " << tir_prefix_ << ".env_thread("
1780 << Doc::StrLiteral(it.second) << ")";
1781 }
1782 }
1783 if (!vars.empty()) {
1784 std::sort(vars.begin(), vars.end(), [&](const tir::VarNode* a, const tir::VarNode* b) {
1785 return memo_var_[GetRef<tir::Var>(a)].str() < memo_var_[GetRef<tir::Var>(b)].str();
1786 });
1787 for (const auto& var : vars) {
1788 auto type = GetRef<tir::Var>(var)->type_annotation;
1789 if (auto* ptr_type = type.as<PointerTypeNode>()) {
1790 auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
1791 ICHECK(prim_type);
1792 header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_
1793 << ".buffer_var(";
1794 header_var << PrintDType(prim_type->dtype) << ", "
1795 << Doc::StrLiteral(ptr_type->storage_scope) << ")";
1796 } else {
1797 header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_
1798 << ".var(";
1799 header_var << PrintDType(var->dtype) << ")";
1800 }
1801 }
1802 }
1803 doc << Doc::Indent(4, header_attr << header_var << header_buf << body);
1804 return doc;
1805}
1806
1807Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) {
1808 Doc doc;
1809 doc << '[';
1810 for (size_t i = 0; i < op->size(); ++i) {
1811 if (i != 0) {
1812 doc << ", ";
1813 }
1814 doc << Print(op->at(i));
1815 }
1816 doc << ']';
1817 return doc;
1818}
1819
1820Doc TVMScriptPrinter::PrintIterVar(const IterVarNode* op) {
1821 Doc doc;
1822 doc << tir_prefix_ << ".iter_var(" << Print(op->var);
1823 if (op->dom.defined()) {
1824 doc << ", [" << Print(op->dom) << "], ";
1825 } else {
1826 doc << ", None, ";
1827 }
1828 doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
1829 doc << Doc::StrLiteral(op->thread_tag) << ")";
1830 return doc;
1831}
1832
1833Doc TVMScriptPrinter::PrintRange(const RangeNode* op) {
1834 return Print(op->min) << ":" << Print(op->min + op->extent);
1835}
1836
1837Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
1838 const Buffer& buffer = GetRef<Buffer>(op);
1839 return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
1840}
1841
1842Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) {
1843 Doc doc;
1844 doc << '[';
1845 for (size_t i = 0; i < indices.size(); ++i) {
1846 if (i != 0) {
1847 doc << ", ";
1848 }
1849 PrimExpr index = indices[i];
1850 if (const RampNode* ramp = index.as<RampNode>()) {
1851 // specify ramp printing as python index slice
1852 if (auto* stride_imm = ramp->stride.as<IntImmNode>()) {
1853 doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride);
1854 if (stride_imm->value != 1) {
1855 doc << ":" << Print(ramp->stride);
1856 }
1857 continue;
1858 }
1859 }
1860 doc << Print(index);
1861 }
1862 doc << ']';
1863 return doc;
1864}
1865
1866Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers) {
1867 Doc decls;
1868 for (const auto& buf_usage : aliasing_buffers) {
1869 decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
1870 << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
1871 buf_not_in_headers_.insert(buf_usage.get());
1872 }
1873 return decls;
1874}
1875
1876Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
1877 Doc doc;
1878 if (op->region.size() == 0) {
1879 doc << Print(op->buffer) << "[()]";
1880 } else {
1881 doc << Print(op->buffer) << "[";
1882 for (size_t i = 0; i < op->region.size(); ++i) {
1883 if (i != 0) doc << ", ";
1884 const auto& range = op->region[i];
1885 if (!is_one(range->extent)) {
1886 doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent));
1887 } else {
1888 doc << Print(range->min);
1889 }
1890 }
1891 doc << "]";
1892 }
1893 return doc;
1894}
1895
1896Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations) {
1897 Doc res;
1898 std::vector<std::pair<String, ObjectRef>> anno_list;
1899 anno_list.reserve(annotations.size());
1900 for (const auto& pair : annotations) {
1901 anno_list.emplace_back(pair);
1902 }
1903 sort(anno_list.begin(), anno_list.end());
1904 for (size_t i = 0; i < anno_list.size(); ++i) {
1905 if (i != 0) {
1906 res << ", ";
1907 }
1908 res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second);
1909 }
1910 return res;
1911}
1912
1913Doc TVMScriptPrinter::PrintLoop(const For& loop) {
1914 Doc res;
1915 res << "for " << Print(loop->loop_var) << " in " << tir_prefix_
1916 << "." + std::string(ForKind2String(loop->kind)) + "(";
1917 if (is_zero(loop->min)) {
1918 res << Print(loop->extent);
1919 } else {
1920 res << Print(loop->min) << ", " << Print(ana_.Simplify(loop->min + loop->extent));
1921 }
1922 if (loop->thread_binding.defined()) {
1923 res << ", thread=";
1924 res << Print(loop->thread_binding.value()->thread_tag);
1925 }
1926 if (!loop->annotations.empty()) {
1927 res << ", annotations={";
1928 res << PrintAnnotations(loop->annotations);
1929 res << "}";
1930 }
1931 res << "):";
1932 return res;
1933}
1934
1935Doc TVMScriptPrinter::PrintLoopStack() {
1936 Doc res;
1937 if (simple_loop_stack_.size() == 1) {
1938 res << PrintLoop(simple_loop_stack_[0]);
1939 } else if (simple_loop_stack_.size() > 1) {
1940 std::vector<Doc> vars, extents;
1941 for (const auto& loop : simple_loop_stack_) {
1942 vars.push_back(Print(loop->loop_var));
1943 extents.push_back(Print(loop->extent));
1944 }
1945 res << "for " << PrintSep(vars, Doc::Text(", ")) << " in " << tir_prefix_ << ".grid("
1946 << PrintSep(extents, Doc::Text(", ")) << "):";
1947 }
1948 return res;
1949}
1950
1951Doc TVMScriptPrinter::PrintTarget(const TargetNode* target) {
1952 Doc res;
1953 res << tir_prefix_ << ".target({";
1954 Map<String, ObjectRef> config = target->Export();
1955 for (auto it = config.begin(); it != config.end(); ++it) {
1956 if (it != config.begin()) {
1957 res << ", ";
1958 }
1959 res << "\"" << (*it).first << "\":";
1960 if ((*it).first == "host") {
1961 ICHECK(target->host.defined());
1962 res << PrintTarget(target->GetHost().value().get());
1963 } else {
1964 res << Print((*it).second);
1965 }
1966 }
1967 res << "})";
1968 return res;
1969}
1970
1971/*!
1972 * \brief The printer for TVMScript with diagnostic
1973 * \details The printer obtain the precedence of the top-level operation when printing each
1974 * subexpression to decide whether or not parentheses is needed.
1975 */
1976class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter {
1977 public:
1978 explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta,
1979 runtime::TypedPackedFunc<std::string(Stmt)> annotate)
1980 : TVMScriptPrinter(tir_prefix, show_meta, annotate) {}
1981
1982 protected:
1983 Doc PrintBlockName(const BlockNode* block_op) override;
1984 Doc PrintUnderline(const Stmt& stmt, int length);
1985 Doc PrintLoop(const For& loop) override;
1986};
1987
1988Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) {
1989 Doc doc = TVMScriptPrinter::PrintBlockName(block_op);
1990 doc << PrintUnderline(GetRef<Stmt>(block_op), doc.str().size());
1991 return doc;
1992}
1993
1994Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) {
1995 Doc doc;
1996 // annotation
1997 if (ContainsOptionalInfo(stmt)) {
1998 String underline = std::string(length, '^');
1999 doc << Doc::NewLine() << underline;
2000 }
2001 return doc;
2002}
2003
2004Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {
2005 Doc res = TVMScriptPrinter::PrintLoop(loop);
2006 res << PrintUnderline(loop, res.str().size());
2007 return res;
2008}
2009
2010String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
2011 runtime::TypedPackedFunc<std::string(Stmt)> annotate) {
2012 ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
2013 Doc doc;
2014 doc << TVMScriptPrinter::PrintHeader(tir_prefix)
2015 << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod);
2016 return doc.str() + "\n";
2017}
2018
2019TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
2020
2021} // namespace relay
2022} // namespace tvm
2023