1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file printer/tvmscript_printer.cc
22 * \brief Printer class to print Tensor IR to python syntax script
23 */
24
25#include <tvm/arith/analyzer.h>
26#include <tvm/ir/module.h>
27#include <tvm/node/serialization.h>
28#include <tvm/runtime/registry.h>
29#include <tvm/target/target.h>
30#include <tvm/tir/analysis.h>
31#include <tvm/tir/buffer.h>
32#include <tvm/tir/expr.h>
33#include <tvm/tir/expr_functor.h>
34#include <tvm/tir/function.h>
35#include <tvm/tir/op.h>
36#include <tvm/tir/stmt.h>
37#include <tvm/tir/stmt_functor.h>
38
39#include <algorithm>
40#include <utility>
41
42#include "../tir/transforms/ir_utils.h"
43#include "doc.h"
44#include "meta_data.h"
45#include "text_printer.h"
46
47namespace tvm {
48namespace tir {
49
/*!
 * \brief Precedence levels used to decide where the printed Python expression needs
 *  parentheses. A smaller value binds more tightly: a child whose level is larger than its
 *  parent's operator level gets parenthesized.
 */
enum class ExprPrecedence : int {
  /*!
   * \brief Identity (e.g., IntImm, Var) and anything printed as a function call
   *  (e.g., T.truncmod, T.min, T.max, T.Select, T.Cast).
   */
  kIdentity = 0,
  /*!
   * \brief Multiplication(*), division(/), floor division(//) and floor modulo(%)
   * \note FloorDiv and FloorMod are printed as the Python operators `//` and `%`, so they
   *  share this level. Truncated modulo (ModNode) is printed as the function call
   *  `truncmod(...)` and is therefore kIdentity.
   */
  kMultiplicationDivision = 1,
  /*! \brief Addition(+) and subtraction(-) */
  kAdditionSubtraction = 2,
  /*!
   * \brief For relational operators <, <=, > and >=
   * \note Python gives every comparison operator (including == and !=) the same precedence
   *  and chains them (`a < b == c` means `a < b and b == c`). A comparison used directly as an
   *  operand of another comparison must therefore be parenthesized; the numeric order of
   *  kRelational vs. kEquality alone does not guarantee that.
   */
  kRelational = 3,
  /*! \brief For equality operators == and != */
  kEquality = 4,
  /*! \brief And(&&), printed as Python `and` */
  kAnd = 5,
  /*! \brief Or(||), printed as Python `or` */
  kOr = 6,
  /*! \brief Unknown precedence; placeholder before a child expression has been visited */
  kUnknown = 7,
};
71
72/*! \brief Utility used for identifying usage of a buffer_var
73 *
74 * \details Find the Buffer object that corresponds to a variable or
75 * allocation, based on the BufferLoad/BufferStore instances that
76 * occur within the allocation's body.
77 */
78class BufferUsageFinder : public StmtExprVisitor {
79 public:
80 static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) {
81 BufferUsageFinder visitor(std::move(usage));
82 visitor.VisitStmt(body);
83 return std::move(visitor.usage_);
84 }
85
86 void VisitExpr_(const VarNode* op) final {
87 Var var = GetRef<Var>(op);
88 if (!usage_.count(var)) {
89 usage_.Set(var, {});
90 }
91 }
92
93 void VisitExpr_(const BufferLoadNode* op) final {
94 VisitBuffer(op->buffer);
95 StmtExprVisitor::VisitExpr_(op);
96 }
97
98 void VisitStmt_(const BufferStoreNode* op) final {
99 VisitBuffer(op->buffer);
100 StmtExprVisitor::VisitStmt_(op);
101 }
102
103 void VisitStmt_(const DeclBufferNode* op) final {
104 buffers_declared_.insert(op->buffer.get());
105 StmtExprVisitor::VisitStmt_(op);
106 buffers_declared_.erase(op->buffer.get());
107 }
108
109 private:
110 explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {}
111
112 void VisitBuffer(const Buffer& buffer) {
113 if (buffers_visited_.count(buffer.get())) {
114 return;
115 }
116 if (buffers_declared_.count(buffer.get())) {
117 return;
118 }
119 buffers_visited_.insert(buffer.get());
120
121 Array<Buffer> arr = usage_.Get(buffer->data).value_or({});
122 arr.push_back(buffer);
123 usage_.Set(buffer->data, arr);
124 }
125
126 // The search result.
127 Map<Var, Array<Buffer>> usage_;
128 // The buffers that have been visited so far, to avoid duplicate
129 // entries in the search result.
130 std::unordered_set<const BufferNode*> buffers_visited_;
131 // The buffers declared via `DeclBuffer`. These buffers are excluded from the result because
132 // T.buffer_decl shouldn't be printed for them.
133 std::unordered_set<const BufferNode*> buffers_declared_;
134};
135
136/*!
137 * \brief The printer for TVMScript
138 * \details The printer obtain the precedence of the top-level operation when printing each
139 * subexpression to decide whether or not parentheses is needed.
140 */
141class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
142 public ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>,
143 public TypeFunctor<Doc(const Type&)> {
144 public:
145 explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta,
146 runtime::TypedPackedFunc<std::string(Stmt)> annotate = nullptr)
147 : tir_prefix_(tir_prefix),
148 show_meta_(show_meta),
149 annotate_(std::move(annotate)),
150 meta_collector_(&meta_) {}
151
152 /*!
153 * \brief Print the node.
154 * \param node The node to be printed.
155 * \param out_precedence The operator precedence of node if it's a PrimExpr,
156 * so we can simplify the bracket.
157 */
158 TVM_DLL Doc Print(const ObjectRef& node);
159
160 protected:
161 /*! \brief The tir prefix */
162 String tir_prefix_;
163 /*! \brief whether show meta data */
164 bool show_meta_;
165 /*! \brief additional comment function */
166 runtime::TypedPackedFunc<std::string(Stmt)> annotate_;
167 /*! \brief meta data context */
168 TextMetaDataContext meta_;
169 /*! \brief meta collector */
170 MetaCollector meta_collector_;
171 /*! \brief map from Function to GlobalVar */
172 std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_;
173 /*! \brief var collector (var defined by For/Loop/Block) */
174 std::unordered_set<const VarNode*> var_not_in_headers_;
175 /*!
176 * \brief buffer collector
177 * (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion)
178 */
179 std::unordered_set<const BufferNode*> buf_not_in_headers_;
180 /*! \brief Map from Var to thread env name */
181 std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_;
182 /*! \brief Map from Var to Doc */
183 std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
184 /*! \brief Map from Buffer to Doc */
185 std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
186 /*! \brief Map from Buffer to Declaration Doc */
187 std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_decl_;
188 /*! \brief name allocation map */
189 std::unordered_map<std::string, int> name_alloc_map_;
190 /*! \brief number of children of current node's parent */
191 int num_child_;
192 /*! \brief the number of current node */
193 int current_num_;
194 /*! \brief loop stack without annotations */
195 std::vector<For> simple_loop_stack_;
196 /*! \brief the maps from loop_vars to the loops */
197 std::unordered_map<const VarNode*, For> loop_var_map_;
198 /*!
199 * \brief simple block vars remap from loop vars
200 * simple_remap requires:
201 * 1. block var iter type is kDataPar or kCommReduce
202 * 2. value is a single Var, which is a loop_var outside the block
203 * 3. The iter range is equal to loop range
204 */
205 std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_;
206 /*!
207 * \brief Map from variables to the buffers they are used in.
208 *
209 * Used for identifying buffers that should be declared after the
210 * LetStmt or Allocate that generates their data pointer, rather
211 * than in the header.
212 */
213 Map<Var, Array<Buffer>> buffer_var_usage_;
214 /*! \brief Analyzer to simplify some expressions. */
215 arith::Analyzer ana_;
216
217 Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override;
218 Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override;
219 Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override;
220 Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override;
221 Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override;
222 Doc VisitExpr_(const DivNode* op, ExprPrecedence* out_precedence) override;
223 Doc VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) override;
224 Doc VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) override;
225 Doc VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) override;
226 Doc VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) override;
227 Doc VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) override;
228 Doc VisitExpr_(const EQNode* op, ExprPrecedence* out_precedence) override;
229 Doc VisitExpr_(const NENode* op, ExprPrecedence* out_precedence) override;
230 Doc VisitExpr_(const LTNode* op, ExprPrecedence* out_precedence) override;
231 Doc VisitExpr_(const LENode* op, ExprPrecedence* out_precedence) override;
232 Doc VisitExpr_(const GTNode* op, ExprPrecedence* out_precedence) override;
233 Doc VisitExpr_(const GENode* op, ExprPrecedence* out_precedence) override;
234 Doc VisitExpr_(const AndNode* op, ExprPrecedence* out_precedence) override;
235 Doc VisitExpr_(const OrNode* op, ExprPrecedence* out_precedence) override;
236 Doc VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) override;
237 Doc VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) override;
238 Doc VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) override;
239 Doc VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) override;
240 Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override;
241 Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override;
242 Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override;
243 Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override;
244 Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override;
245 Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override;
246 Doc VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) override;
247 Doc VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) override;
248 Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override;
249 Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override;
250 Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override;
251
252 Doc VisitStmt_(const LetStmtNode* op) override;
253 Doc VisitStmt_(const AttrStmtNode* op) override;
254 Doc VisitStmt_(const AssertStmtNode* op) override;
255 Doc VisitStmt_(const StoreNode* op) override;
256 Doc VisitStmt_(const BufferStoreNode* op) override;
257 Doc VisitStmt_(const BufferRealizeNode* op) override;
258 Doc VisitStmt_(const AllocateNode* op) override;
259 Doc VisitStmt_(const AllocateConstNode* op) override;
260 Doc VisitStmt_(const DeclBufferNode* op) override;
261 Doc VisitStmt_(const IfThenElseNode* op) override;
262 Doc VisitStmt_(const SeqStmtNode* op) override;
263 Doc VisitStmt_(const ForNode* op) override;
264 Doc VisitStmt_(const WhileNode* op) override;
265 Doc VisitStmt_(const PrefetchNode* op) override;
266 Doc VisitStmt_(const EvaluateNode* op) override;
267 Doc VisitStmt_(const BlockRealizeNode* op) override;
268 Doc VisitStmtDefault_(const Object* op) override;
269
270 Doc VisitType_(const PrimTypeNode* node) override;
271 Doc VisitType_(const PointerTypeNode* node) override;
272 Doc VisitType_(const TupleTypeNode* node) override;
273
274 Doc PrintBody(const Stmt& body);
275 Doc PrintIRModule(const IRModule& module);
276 Doc PrintPrimFunc(const PrimFunc& primFunc);
277 Doc PrintIterVar(const IterVarNode* op);
278 Doc PrintRange(const RangeNode* op);
279 Doc PrintArray(const ArrayNode* op);
280 Doc PrintBuffer(const BufferNode* op);
281 Doc PrintBufferIndices(const Array<PrimExpr>& indices);
282 Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers);
283 Doc AllocBufferDeclaration(const Buffer& buf);
284 Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value);
285 Doc PrintBlockVarRemaps();
286 Doc PrintBlockPredicate(const BlockRealizeNode* op);
287 Doc PrintBlockVars(const BlockRealizeNode* op);
288 Doc PrintBlockAttr(const BlockRealizeNode* op);
289 Doc PrintExpandedArray(const ArrayNode* op);
290 Doc PrintBlockBody(const BlockNode* op);
291 virtual Doc PrintBlockName(const BlockNode* block_op);
292 Doc PrintBufferRegion(const BufferRegionNode* op);
293 Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op);
294 Doc PrintCommReducer(const CommReducerNode* op);
295 Doc PrintAnnotations(const Map<String, ObjectRef>& annotations);
296 Doc PrintTarget(const TargetNode* target);
297 static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
298
299 Doc GetUniqueName(std::string prefix);
300 Doc AllocVar(const Var& var);
301 Doc AllocBuf(const Buffer& buffer);
302 void TryDeallocVar(const Var& var);
303 bool ContainsOptionalInfo(const Stmt& stmt);
304 /*!
305 * \brief Check if a buffer declaration satisfies:
306 * 1. has only 'shape' and 'dtype' arguments specified,
307 * 2. the shape and strides are not dynamic.
308 * \param buffer The match buffer to be checked
309 */
310 bool IsSimpleBuffer(const Buffer& buffer);
311 Doc PrintInlineBufferBind(const Buffer& buffer);
312 Doc PrintTuple(const ArrayNode* op);
313
314 /*! Helper functions for loop printing. */
315 /*!
316 * \brief Print a single for loop
317 * \param loop The for loop to be printed
318 */
319 virtual Doc PrintLoop(const For& loop);
320 /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */
321 Doc PrintLoopStack();
322 /*!
323 * \brief Check whether a loop satisfies:
324 * 1. the loop is serial;
325 * 2. the loop has no annotation;
326 * 3. the loop starts from 0;
327 * 4. there is no optional information.
328 * \param for_op the for node to be checked
329 * \return A boolean indicating whether the input loop satisfies the above conditions
330 */
331 bool IsSimpleLoop(const ForNode* for_op) {
332 return for_op->kind == ForKind::kSerial && for_op->annotations.empty() &&
333 is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op));
334 }
335 /*!
336 * \brief Check whether the `min` or `extent` of a loop depends on previous loops
337 * \param for_op The loop to be checked
338 * \return A boolean indicating whether the input loop depends on previous loops
339 */
340 bool DependOnPrevLoops(const ForNode* for_op) {
341 auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); };
342 return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check);
343 }
344
345 /*!
346 * \brief Print additional info about expr in comment.
347 * \param expr The expression.
348 */
349 Doc PrintOptionalInfo(const Stmt& stmt) {
350 Doc doc;
351 // default annotations
352 if (ContainsOptionalInfo(stmt)) {
353 std::string annotated_stmt = annotate_(stmt);
354 doc << "# " << annotated_stmt << Doc::NewLine();
355 }
356 return doc;
357 }
358
359 /*!
360 * \brief special method to render vectors of docs with a separator
361 * \param vec vector of docs
362 * \param sep separator
363 */
364 static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
365 Doc seq;
366 if (vec.size() != 0) {
367 seq = vec[0];
368 for (size_t i = 1; i < vec.size(); i++) {
369 seq << sep << vec[i];
370 }
371 }
372 return seq;
373 }
374
375 /*!
376 * \brief dump meta info
377 * \return Doc with meta info
378 */
379 Doc DumpMeta() {
380 if (show_meta_) {
381 return Doc::Text("__tvm_meta__ = ")
382 << (meta_.empty() ? Doc::Text("None") : meta_.GetMetaSection());
383 } else {
384 return Doc::Text("");
385 }
386 }
387
388 /*!
389 * \brief special method to print out data type
390 * \param dtype The data type
391 */
392 static Doc PrintDType(DataType dtype) {
393 return Doc::StrLiteral(runtime::DLDataType2String(dtype));
394 }
395
396 /*!
397 * \brief special method to print out const int64_t scalar
398 * \param dtype The data type
399 * \param data The pointer to hold the data.
400 */
401 Doc PrintConstScalar(DataType dtype, const int64_t* data) const {
402 Doc doc;
403 std::ostringstream os;
404
405 os << data[0];
406
407 if (dtype == DataType::Int(32)) {
408 doc << Doc::Text(os.str());
409 } else if (dtype == DataType::Bool()) {
410 doc << Doc::Text(data[0] ? "True" : "False");
411 } else {
412 doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str())
413 << ")";
414 }
415 return doc;
416 }
417
418 /*!
419 * \brief special method to print out const double scalar
420 * \param dtype The data type
421 * \param data The pointer to hold the data.
422 * \note this overriden function is created as std::isnan of msvc will complain about int64_t
423 */
424 Doc PrintConstScalar(DataType dtype, const double* data) const {
425 Doc doc;
426 std::ostringstream os;
427
428 os.precision(17);
429 if (std::isinf(data[0]) || std::isnan(data[0])) {
430 os << "\"" << data[0] << "\"";
431 } else {
432 os << data[0];
433 }
434
435 doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str())
436 << ")";
437
438 return doc;
439 }
440
441 public:
442 static Doc PrintHeader(const std::string& tir_prefix) {
443 Doc header;
444 if (tir_prefix != "tir") {
445 header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine();
446 } else {
447 header << "# from tvm.script import tir" << Doc::NewLine();
448 }
449 return header;
450 }
451};
452
453/*!
454 * \brief special method to print NDArray in TIR
455 * \param arr the NDArray to be printed
456 * \param os the output stream where the NDArray will be printed to
457 */
458template <typename T>
459void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
460 if ((arr.DataType().code() == runtime::DataType::kInt ||
461 arr.DataType().code() == runtime::DataType::kUInt) &&
462 arr.DataType().bits() == 8) {
463 // Printing int8 NDArrays causes "UnicodeDecodeError: 'utf-8' codec can't decode byte"
464 // error during MetaSchedule tuning on int8 models.
465 return;
466 }
467 int ndim = arr->ndim;
468 int tot_dim = 1;
469 for (int i = 0; i < ndim; i++) {
470 tot_dim *= arr->shape[i];
471 }
472 T* data_ptr = reinterpret_cast<T*>(arr->data);
473 constexpr int NUM_PRINT = 20;
474 os << "[";
475 for (int i = 0; i < tot_dim; i++) {
476 os << (i != 0 ? ", " : "") << data_ptr[i];
477 if (i == NUM_PRINT) {
478 os << "...";
479 break;
480 }
481 }
482 os << "]";
483}
484
485Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
486 std::replace(prefix.begin(), prefix.end(), '.', '_');
487 std::string unique_prefix = prefix;
488 auto it = name_alloc_map_.find(prefix);
489 if (it != name_alloc_map_.end() && it->second >= 0) {
490 while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) {
491 }
492 }
493 name_alloc_map_[unique_prefix] = 0;
494 return Doc::Text(unique_prefix);
495}
496
497Doc TVMScriptPrinter::AllocVar(const Var& var) {
498 const auto& it = memo_var_.find(var);
499 if (it != memo_var_.end()) {
500 return it->second;
501 }
502 std::string name = var->name_hint.operator std::string();
503 if (name.length() == 0 || !std::isalpha(name[0])) {
504 name = "v" + name;
505 }
506 Doc val = GetUniqueName(name);
507 memo_var_[var] = val;
508 return val;
509}
510
/*!
 * \brief Build the argument list used to declare \p buf, e.g. the part after the source region
 *  in `T.match_buffer(<source>, <this doc>)`.
 *
 * The shape and `dtype=` are always emitted. Every other argument is emitted only when it
 * differs from its default. As a side effect, a `data` var and a Var `elem_offset` that have
 * no printed name yet are bound in memo_var_ to `<buffer name>.data` /
 * `<buffer name>.elem_offset` and added to var_not_in_headers_.
 *
 * \param buf The buffer to describe. Its name is read through memo_buf_[buf], so it should
 *  already be registered (AllocBuf does that before calling here); otherwise a default Doc is
 *  inserted and used.
 * \return The argument-list doc, without surrounding parentheses.
 */
Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) {
  Doc doc = Print(buf->shape);
  // Set when elem_offset is bound implicitly below; offset_factor is then always printed.
  bool print_factor_explicitly = false;
  doc << ", dtype=" << PrintDType(buf->dtype);
  if (memo_var_.find(buf->data) != memo_var_.end()) {
    // The data var already has a name, so it must be passed explicitly.
    doc << ", data=" << Print(buf->data);
  } else {
    // implicitly define data
    memo_var_[buf->data] = Doc::Text(memo_buf_[buf].str() + ".data");
    var_not_in_headers_.insert(buf->data.get());
  }
  if (!buf->strides.empty()) {
    doc << ", strides=" << Print(buf->strides);
  }
  if (buf->elem_offset->IsInstance<VarNode>()) {
    Var elem_offset = Downcast<Var>(buf->elem_offset);
    if (memo_var_.find(elem_offset) != memo_var_.end()) {
      doc << ", elem_offset=" << Print(buf->elem_offset);
    } else {
      // implicitly define elem_offset
      memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset");
      var_not_in_headers_.insert(elem_offset.get());
      print_factor_explicitly = true;
    }
  } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
    // A constant offset is printed only when non-zero.
    IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
    if (elem_offset->value != 0) {
      doc << ", elem_offset=" << Print(buf->elem_offset);
    }
  }
  if (buf.scope() != "global") {
    doc << ", scope=" << Doc::StrLiteral(buf.scope());
  }
  if (buf->data_alignment != runtime::kAllocAlignment) {
    doc << ", align=" << buf->data_alignment;
  }
  if (buf->offset_factor != 1 || print_factor_explicitly) {
    doc << ", offset_factor=" << buf->offset_factor;
  }
  // NOTE(review): any non-default buffer_type is printed as "auto"; this presumably assumes
  // auto-broadcast is the only other BufferType -- confirm if new types are added.
  if (buf->buffer_type != BufferType::kDefault) {
    doc << ", type=" << Doc::StrLiteral("auto");
  }
  if (buf->axis_separators.size()) {
    doc << ", axis_separators=" << Print(buf->axis_separators);
  }
  return doc;
}
558
559Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) {
560 const auto& it = memo_buf_.find(buffer);
561 if (it != memo_buf_.end()) {
562 return it->second;
563 }
564 std::string name = buffer->name;
565 if (name.length() == 0 || !std::isalpha(name[0])) {
566 name = "buf_" + name;
567 }
568 Doc val = GetUniqueName(name);
569 memo_buf_[buffer] = val;
570 memo_buf_decl_[buffer] = AllocBufferDeclaration(buffer);
571 return val;
572}
573
574/*!
575 * \brief Check if any optional information exists in annotate_ for
576 * a given Stmt.
577 * \param stmt The statement.
578 */
579bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) {
580 if (annotate_ == nullptr) return false;
581 return !annotate_(stmt).empty();
582}
583
584/*!
585 * \brief Try to dealloc vars out of space and leave the index to coming vars.
586 * \note It is not a necessary step.
587 */
588void TVMScriptPrinter::TryDeallocVar(const Var& var) {
589 auto it = memo_var_.find(var);
590 ICHECK(it != memo_var_.end());
591 std::string print_name = it->second.str();
592
593 std::string name_hint = var->name_hint.operator std::string();
594 if (name_hint.length() == 0 || !std::isalpha(name_hint[0])) {
595 name_hint = "v" + name_hint;
596 }
597 std::replace(name_hint.begin(), name_hint.end(), '.', '_');
598
599 auto it2 = name_alloc_map_.find(name_hint);
600 // Skip it if we can not find the name_hint in name_alloc_map_.
601 if (it2 == name_alloc_map_.end()) return;
602 if (it2->second > 0) {
603 name_hint = name_hint + '_' + std::to_string(it2->second);
604 }
605 // Skip it if the name_hint is not equal to how it should be printed.
606 if (name_hint != print_name) return;
607 // Free the conresponding name_alloc_map_ index
608 --it2->second;
609}
610
611Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) {
612 const Buffer& buf = op->buffer;
613 buf_not_in_headers_.insert(buf.get());
614
615 Doc doc = Print(op->buffer) << " = " << tir_prefix_ << ".match_buffer(" << Print(op->source)
616 << ", " << memo_buf_decl_[op->buffer] << ")";
617 return doc;
618}
619
620// check if all arguments, except the first two, are specified for T.match_buffer
621// if not, then this match buffer is printed out as T.buffer in prim_func arguments
622// and check whether there are undefined variables in the shape/strides.
623bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) {
624 if (memo_var_.find(buf->data) != memo_var_.end()) {
625 return false;
626 }
627 if (!buf->strides.empty()) {
628 return false;
629 }
630 for (const PrimExpr& shp_i : buf->shape) {
631 if (!UndefinedVars(shp_i).empty()) {
632 return false;
633 }
634 }
635 for (const PrimExpr& stride_i : buf->strides) {
636 if (!UndefinedVars(stride_i).empty()) {
637 return false;
638 }
639 }
640 if (!UndefinedVars(buf->elem_offset).empty()) {
641 return false;
642 } else if (buf->elem_offset->IsInstance<IntImmNode>()) {
643 IntImm elem_offset = Downcast<IntImm>(buf->elem_offset);
644 if (elem_offset->value != 0) {
645 return false;
646 }
647 }
648 if (buf.scope() != "global") {
649 return false;
650 }
651 if (buf->data_alignment != runtime::kAllocAlignment) {
652 return false;
653 }
654 if (buf->offset_factor != 1) {
655 return false;
656 }
657 if (buf->buffer_type != BufferType::kDefault) {
658 return false;
659 }
660 if (buf->axis_separators.size()) {
661 return false;
662 }
663 return true;
664}
665
666Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) {
667 Doc doc;
668 doc << tir_prefix_ << ".Buffer[";
669 if (buffer->shape.size() == 1) {
670 doc << Print(buffer->shape[0]);
671 } else {
672 doc << PrintTuple(buffer->shape.as<ArrayNode>());
673 }
674 doc << ", " << PrintDType(buffer->dtype) << "]";
675 return doc;
676}
677
678// print array out as tuple with parentheses
679Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) {
680 Doc doc;
681 doc << '(';
682 for (size_t i = 0; i < op->size(); ++i) {
683 if (i != 0) {
684 doc << ", ";
685 }
686 doc << Print(op->at(i));
687 }
688 if (op->size() == 1) doc << ",";
689 doc << ')';
690 return doc;
691}
692
693Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) {
694 Doc doc;
695 int n_var = static_cast<int>(op->rhs.size());
696
697 doc << tir_prefix_ << ".comm_reducer(lambda ";
698 for (const Var& v_lhs : op->lhs) {
699 doc << Print(v_lhs) << ", ";
700 }
701 for (int i = 0; i < n_var; ++i) {
702 doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", ");
703 }
704 if (n_var == 1) {
705 doc << Print(op->result[0]) << ", ";
706 } else {
707 doc << "(";
708 for (int i = 0; i < n_var; ++i) {
709 doc << Print(op->result[i]);
710 if (i != n_var - 1) {
711 doc << ", ";
712 }
713 }
714 doc << "), ";
715 }
716 doc << Print(op->identity_element) << ")";
717
718 // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda.
719 for (int i = 0; i < n_var; ++i) {
720 memo_var_.erase(op->lhs[i]);
721 memo_var_.erase(op->rhs[i]);
722 }
723 return doc;
724}
725
726Doc TVMScriptPrinter::Print(const ObjectRef& node) {
727 if (!node.defined()) return Doc::Text("None");
728 if (node->IsInstance<StmtNode>()) {
729 return PrintOptionalInfo(Downcast<Stmt>(node)) << VisitStmt(Downcast<Stmt>(node));
730 } else if (node->IsInstance<PrimExprNode>()) {
731 ExprPrecedence t = ExprPrecedence::kUnknown;
732 return VisitExpr(Downcast<PrimExpr>(node), &t);
733 } else if (node->IsInstance<TypeNode>()) {
734 return VisitType(Downcast<Type>(node));
735 } else if (node->IsInstance<PrimFuncNode>()) {
736 return PrintPrimFunc(Downcast<PrimFunc>(node));
737 } else if (node->IsInstance<IRModuleNode>()) {
738 return PrintIRModule(Downcast<IRModule>(node));
739 } else if (node->IsInstance<ArrayNode>()) {
740 return PrintArray(node.as<ArrayNode>());
741 } else if (node->IsInstance<BufferNode>()) {
742 return PrintBuffer(node.as<BufferNode>());
743 } else if (node->IsInstance<StringObj>()) {
744 return PrintString(node.as<StringObj>());
745 } else if (node->IsInstance<IterVarNode>()) {
746 return PrintIterVar(node.as<IterVarNode>());
747 } else if (node->IsInstance<RangeNode>()) {
748 return PrintRange(node.as<RangeNode>());
749 } else if (node->IsInstance<BufferRegionNode>()) {
750 return PrintBufferRegion(node.as<BufferRegionNode>());
751 } else if (node->IsInstance<MatchBufferRegionNode>()) {
752 return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>());
753 } else if (node->IsInstance<CommReducerNode>()) {
754 return PrintCommReducer(node.as<CommReducerNode>());
755 } else if (node->IsInstance<TargetNode>()) {
756 return PrintTarget(node.as<TargetNode>());
757 } else {
758 LOG(FATAL) << "Do not know how to print " << node->GetTypeKey();
759 }
760}
761
762Doc TVMScriptPrinter::VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) {
763 LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
764}
765
766Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) {
767 LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
768}
769
770Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) {
771 *out_precedence = ExprPrecedence::kIdentity;
772 return PrintConstScalar(op->dtype, &(op->value));
773}
774
775Doc TVMScriptPrinter::VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) {
776 *out_precedence = ExprPrecedence::kIdentity;
777 return PrintConstScalar(op->dtype, &(op->value));
778}
779
780Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) {
781 *out_precedence = ExprPrecedence::kIdentity;
782 return Doc::StrLiteral(op->value);
783}
784
785Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) {
786 *out_precedence = ExprPrecedence::kIdentity;
787 Doc doc;
788 doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")";
789 return doc;
790}
791
792Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) {
793 *out_precedence = ExprPrecedence::kIdentity;
794 const Var& var = GetRef<Var>(op);
795 return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op));
796}
797
798bool WillPrintConstScalar(const PrimExpr& expr) {
799 if (const auto* imm = expr.as<IntImmNode>()) {
800 DataType dtype = imm->dtype;
801 return dtype == DataType::Int(32) || dtype == DataType::Bool();
802 }
803 return false;
804}
805
806#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \
807 Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
808 Doc doc; \
809 if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \
810 *out_precedence = ExprPrecedence::kIdentity; \
811 doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
812 return doc; \
813 } \
814 ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
815 ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
816 /* Get children expr out_precedence */ \
817 Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
818 Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
819 ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
820 ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
821 /* Update out_precedence of current node. */ \
822 *out_precedence = OpPrecedence; \
823 if (lhs_precedence > OpPrecedence || \
824 (lhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \
825 doc << "(" << lhs_doc << ")"; \
826 } else { \
827 doc << lhs_doc; \
828 } \
829 doc << OpString; \
830 if (rhs_precedence >= OpPrecedence || \
831 (rhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \
832 doc << "(" << rhs_doc << ")"; \
833 } else { \
834 doc << rhs_doc; \
835 } \
836 return doc; \
837 }
838
// Infix binary operators. FloorDiv and FloorMod map to Python's `//` and `%`. Truncated
// modulo, min and max are instead printed as function calls by their own visitors below.
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * ", "Mul", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / ", "Div", ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // ", "FloorDiv",
                                    ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % ", "FloorMod",
                                    ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + ", "Add", ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - ", "Sub", ExprPrecedence::kAdditionSubtraction)
// Comparisons.
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < ", "LT", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= ", "LE", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > ", "GT", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= ", "GE", ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == ", "EQ", ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != ", "NE", ExprPrecedence::kEquality)
// Logical connectives, printed with Python's keyword spelling.
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and ", "And", ExprPrecedence::kAnd)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or ", "Or", ExprPrecedence::kOr)
855
856Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) {
857 *out_precedence = ExprPrecedence::kIdentity;
858 Doc doc;
859 doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")";
860 return doc;
861}
862
863Doc TVMScriptPrinter::VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) {
864 *out_precedence = ExprPrecedence::kIdentity;
865 Doc doc;
866 doc << tir_prefix_ << ".min(" << Print(op->a) << ", " << Print(op->b) << ")";
867 return doc;
868}
869
870Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) {
871 *out_precedence = ExprPrecedence::kIdentity;
872 Doc doc;
873 doc << tir_prefix_ << ".max(" << Print(op->a) << ", " << Print(op->b) << ")";
874 return doc;
875}
876
877Doc TVMScriptPrinter::VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) {
878 *out_precedence = ExprPrecedence::kIdentity;
879 Doc doc;
880 doc << "not(" << Print(op->a) << ")";
881 return doc;
882}
883
884Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) {
885 *out_precedence = ExprPrecedence::kIdentity;
886 Doc doc;
887 doc << tir_prefix_ << ".Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", "
888 << Print(op->false_value) << ")";
889 return doc;
890}
891
892Doc TVMScriptPrinter::VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) {
893 LOG(FATAL) << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to "
894 "lower this function first.";
895 return Doc();
896}
897
898Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) {
899 *out_precedence = ExprPrecedence::kIdentity;
900 Doc doc;
901 if (op->indices.size() == 0) {
902 doc << Print(op->buffer) << "[()]";
903 } else {
904 doc << Print(op->buffer) << PrintBufferIndices(op->indices);
905 }
906 return doc;
907}
908
909Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) {
910 *out_precedence = ExprPrecedence::kIdentity;
911 Doc doc;
912 if (op->dtype == DataType::Float(32) && is_one(op->predicate) &&
913 op->buffer_var->dtype == DataType::Float(32)) {
914 doc << Print(op->buffer_var) << "[" << Print(op->index) << "]";
915 } else {
916 doc << tir_prefix_ << ".load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", "
917 << Print(op->index);
918 if (!is_one(op->predicate) || op->dtype.lanes() != 1) {
919 doc << ", " << Print(op->predicate);
920 }
921 doc << ")";
922 }
923 return doc;
924}
925
926Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) {
927 *out_precedence = ExprPrecedence::kIdentity;
928 Doc doc;
929 doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", "
930 << op->lanes << ")";
931 return doc;
932}
933
934Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) {
935 *out_precedence = ExprPrecedence::kIdentity;
936 Doc doc;
937 doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")";
938 return doc;
939}
940
941Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) {
942 *out_precedence = ExprPrecedence::kIdentity;
943 Doc doc;
944 doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", "
945 << Print(op->body) << ")";
946 return doc;
947}
948
949Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) {
950 *out_precedence = ExprPrecedence::kIdentity;
951 Doc doc;
952 if (auto* ptr_op = op->op.as<OpNode>()) {
953 std::string name = ptr_op->name;
954 if (name.find("tir.") == 0) {
955 name = tir_prefix_ + "." + name.substr(4);
956 }
957 doc << name << "(";
958 } else {
959 auto* op_gvar = op->op.as<GlobalVarNode>();
960 ICHECK(op_gvar != nullptr);
961 doc << Doc::Text(op_gvar->name_hint) << "(";
962 }
963 std::vector<Doc> args;
964 for (const auto& arg : op->args) {
965 args.push_back(Print(arg));
966 }
967 args.push_back(Doc::Text("dtype=") << PrintDType(op->dtype));
968 doc << PrintSep(args, Doc::Text(", ")) << ")";
969 return doc;
970}
971
972Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) {
973 *out_precedence = ExprPrecedence::kIdentity;
974 Doc doc;
975 doc << tir_prefix_ << ".shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")";
976 return doc;
977}
978
979Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) {
980 *out_precedence = ExprPrecedence::kIdentity;
981 Doc doc;
982 doc << tir_prefix_ << ".reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", "
983 << Print(op->axis) << ", " << op->value_index << ")";
984 return doc;
985}
986
987Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) {
988 if (!buffer_var_usage_.count(op->var)) {
989 buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
990 }
991 Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->var).value_or({});
992
993 Doc doc;
994 if (current_num_ != num_child_ - 1) {
995 doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):";
996 doc << Doc::Indent(
997 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
998 } else {
999 if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get());
1000 doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value)
1001 << Doc::NewLine();
1002 doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
1003 }
1004 return doc;
1005}
1006
1007Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
1008 Doc doc;
1009 if (op->node.defined()) {
1010 // merge attr with realize when possible
1011 if (op->node->IsInstance<BufferNode>() && op->attr_key == "realize_scope" &&
1012 op->body->IsInstance<BufferRealizeNode>()) {
1013 const auto* realize = Downcast<BufferRealize>(op->body).get();
1014 if (realize->buffer.same_as(op->node)) {
1015 if (current_num_ != num_child_ - 1) {
1016 doc << "with " << tir_prefix_ << ".realize(" << Print(realize->buffer)
1017 << Print(realize->bounds) << ", " << Print(op->value);
1018 if (!is_one(realize->condition)) {
1019 doc << ", " << Print(realize->condition);
1020 }
1021 doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body));
1022 } else {
1023 doc << tir_prefix_ << ".realize(" << Print(realize->buffer) << Print(realize->bounds)
1024 << ", " << Print(op->value);
1025 if (!is_one(realize->condition)) {
1026 doc << ", " << Print(realize->condition);
1027 }
1028 doc << ")" << Doc::NewLine() << PrintBody(realize->body);
1029 }
1030 return doc;
1031 }
1032 }
1033 // concise thread env
1034 if (op->node->IsInstance<IterVarNode>() &&
1035 (op->attr_key == "thread_extent" || op->attr_key == "virtual_thread")) {
1036 const auto* iter_var = Downcast<IterVar>(op->node).get();
1037 var_not_in_headers_.insert(iter_var->var.get());
1038 var_env_map_[iter_var->var] = iter_var->thread_tag;
1039 if (current_num_ != num_child_ - 1) {
1040 doc << "with " << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", "
1041 << Print(op->value) << "):";
1042 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1043 } else {
1044 doc << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " << Print(op->value)
1045 << ")";
1046 doc << Doc::NewLine() << PrintBody(op->body);
1047 }
1048 return doc;
1049 }
1050 }
1051 // default
1052 if (current_num_ != num_child_ - 1) {
1053 doc << "with " << tir_prefix_ << ".attr(" << Print(op->node) << ", "
1054 << Doc::StrLiteral(op->attr_key) << ", " << Print(op->value) << "):";
1055 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1056 } else {
1057 doc << tir_prefix_ << ".attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key)
1058 << ", " << Print(op->value) << ")";
1059 doc << Doc::NewLine() << PrintBody(op->body);
1060 }
1061 return doc;
1062}
1063
1064Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) {
1065 Doc doc;
1066 if (current_num_ != num_child_ - 1) {
1067 doc << "with " << tir_prefix_ << ".Assert(" << Print(op->condition) << ", "
1068 << Print(op->message) << "):";
1069 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1070 } else {
1071 doc << "assert " << Print(op->condition) << ", " << Print(op->message);
1072 doc << Doc::NewLine() << PrintBody(op->body);
1073 }
1074 return doc;
1075}
1076
1077Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) {
1078 Doc doc;
1079 doc << tir_prefix_ << ".store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", "
1080 << Print(op->value) << ", " << Print(op->predicate) << ")";
1081 return doc;
1082}
1083
1084Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) {
1085 LOG(FATAL)
1086 << "TVM Script Printer Internal Error: All the BufferRealize should be folded with Attr";
1087 return Doc();
1088}
1089
1090namespace {
1091
1092bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) {
1093 const Var& buffer_var = allocate->buffer_var;
1094 const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>();
1095 if (!decl_buffer) {
1096 return false;
1097 }
1098 const Buffer& buffer = decl_buffer->buffer;
1099 if (!buffer_var.same_as(buffer->data)) {
1100 return false;
1101 }
1102 if (allocate->dtype != buffer->dtype) {
1103 return false;
1104 }
1105 if (!is_one(allocate->condition)) {
1106 return false;
1107 }
1108 if (allocate->annotations.size()) {
1109 return false;
1110 }
1111 if (allocate->extents.size() != buffer->shape.size()) {
1112 return false;
1113 }
1114 tir::ExprDeepEqual expr_equal;
1115 for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) {
1116 if (!expr_equal(allocate->extents[i], buffer->shape[i])) {
1117 return false;
1118 }
1119 }
1120 return true;
1121}
1122
1123} // namespace
1124
1125Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
1126 var_not_in_headers_.insert(op->buffer_var.get());
1127
1128 if (!buffer_var_usage_.count(op->buffer_var)) {
1129 buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body);
1130 }
1131 Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({});
1132
1133 if (buffer_usage.empty()) {
1134 if (IsAllocateDeclBufferPattern(op)) {
1135 // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single
1136 // DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to
1137 // delegate the printing of the current node to `DeclBufferNode` while maintaining the
1138 // same value of `current_num_` and `num_child_`.
1139 return Print(op->body);
1140 }
1141 }
1142
1143 auto storage_scope = GetPtrStorageScope(op->buffer_var);
1144 Doc func_call;
1145 func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype)
1146 << ", " << Print(storage_scope);
1147 if (!is_one(op->condition)) {
1148 func_call << ", " << Print(op->condition);
1149 }
1150 if (!op->annotations.empty()) {
1151 func_call << ", annotations={";
1152 func_call << PrintAnnotations(op->annotations);
1153 func_call << "}";
1154 }
1155 func_call << ")";
1156
1157 Doc doc;
1158 if (current_num_ != num_child_ - 1) {
1159 doc << "with " << func_call << " as " << Print(op->buffer_var) << ":";
1160 doc << Doc::Indent(
1161 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body));
1162 } else {
1163 doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine();
1164 doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body);
1165 }
1166 TryDeallocVar(op->buffer_var);
1167 return doc;
1168}
1169
1170Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
1171 std::stringstream ss;
1172 ICHECK(alloc->data) << "Should be presented";
1173 const auto& data = alloc->data.value();
1174
1175 if (alloc->dtype.is_int()) {
1176 if (alloc->dtype.bits() == 8) {
1177 NDArrayToTIR<int8_t>(data, ss);
1178 } else if (alloc->dtype.bits() == 16) {
1179 NDArrayToTIR<int16_t>(data, ss);
1180 } else if (alloc->dtype.bits() == 32) {
1181 NDArrayToTIR<int32_t>(data, ss);
1182 } else if (alloc->dtype.bits() == 64) {
1183 NDArrayToTIR<int64_t>(data, ss);
1184 } else {
1185 LOG(FATAL) << "DataType not supported";
1186 }
1187 } else if (alloc->dtype.is_uint()) {
1188 if (alloc->dtype.bits() == 8) {
1189 NDArrayToTIR<uint8_t>(data, ss);
1190 } else if (alloc->dtype.bits() == 16) {
1191 NDArrayToTIR<uint16_t>(data, ss);
1192 } else if (alloc->dtype.bits() == 32) {
1193 NDArrayToTIR<uint32_t>(data, ss);
1194 } else if (alloc->dtype.bits() == 64) {
1195 NDArrayToTIR<int64_t>(data, ss);
1196 } else {
1197 LOG(FATAL) << "DataType not supported";
1198 }
1199 } else if (alloc->dtype.is_float()) {
1200 if (alloc->dtype.bits() == 16) {
1201 NDArrayToTIR<int16_t>(data, ss);
1202 } else if (alloc->dtype.bits() == 32) {
1203 NDArrayToTIR<float>(data, ss);
1204 } else if (alloc->dtype.bits() == 64) {
1205 NDArrayToTIR<double>(data, ss);
1206 } else {
1207 LOG(FATAL) << "DataType not supported";
1208 }
1209 } else {
1210 LOG(FATAL) << "DataType not supported";
1211 }
1212 auto ndarray_str = ss.str();
1213
1214 var_not_in_headers_.insert(alloc->buffer_var.get());
1215
1216 if (!buffer_var_usage_.count(alloc->buffer_var)) {
1217 buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body);
1218 }
1219 Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({});
1220
1221 Doc func_call;
1222 func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype)
1223 << ", " << Print(alloc->extents) << ")";
1224
1225 Doc doc;
1226 var_not_in_headers_.insert(alloc->buffer_var.get());
1227 if (current_num_ != num_child_ - 1) {
1228 doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":";
1229 doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage)
1230 << PrintBody(alloc->body));
1231 } else {
1232 doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine();
1233 doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body);
1234 }
1235 return doc;
1236}
1237
1238Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) {
1239 const Buffer& buffer = op->buffer;
1240 buf_not_in_headers_.insert(buffer.get());
1241 Doc buffer_name = Print(op->buffer);
1242 Doc func_call;
1243 func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")";
1244
1245 Doc doc;
1246 if (current_num_ != num_child_ - 1) {
1247 doc << "with " << func_call << " as " << buffer_name << ":";
1248 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1249 } else {
1250 doc << buffer_name << " = " << func_call << Doc::NewLine();
1251 doc << PrintBody(op->body);
1252 }
1253 return doc;
1254}
1255
1256Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
1257 Doc doc;
1258 doc << "if " << Print(op->condition) << ":";
1259 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
1260
1261 Optional<Stmt> else_case = op->else_case;
1262 while (else_case) {
1263 if (auto* else_if = else_case.value().as<IfThenElseNode>()) {
1264 doc << Doc::NewLine();
1265 doc << "elif " << Print(else_if->condition) << ":";
1266 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(else_if->then_case));
1267
1268 else_case = else_if->else_case;
1269 } else {
1270 doc << Doc::NewLine();
1271 doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(else_case.value()));
1272 break;
1273 }
1274 }
1275
1276 return doc;
1277}
1278
1279Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
1280 std::vector<Doc> stmts;
1281 for (Stmt stmt : op->seq) {
1282 stmts.push_back(Print(stmt));
1283 }
1284 return PrintSep(stmts, Doc::NewLine());
1285}
1286
1287Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
1288 // When parsing TVMScript, a PrimExpr that occurs as a statement is
1289 // automatically wrapped in `tir::Evaluate`. Therefore, when
1290 // printing, it's only necessary to print the value. For
1291 // readability, though, we still print T.evaluate() when the
1292 // expression is something other than a call node.
1293 Doc doc;
1294 if (op->value.as<CallNode>()) {
1295 doc << Print(op->value);
1296 } else {
1297 doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
1298 }
1299 return doc;
1300}
1301
/*!
 * \brief Print a For loop, compressing runs of nested "simple" loops.
 *
 * Each simple loop (per IsSimpleLoop) is pushed onto simple_loop_stack_. If its
 * body is another simple loop that does not depend on earlier loops (per
 * DependOnPrevLoops), printing is delegated to that inner loop, so a whole
 * nest can be emitted by PrintLoopStack at once. The innermost printed loop
 * flushes and clears the stack. The loop var is registered in loop_var_map_
 * while the body is printed (PrintBlockVars reads it to detect remaps) and is
 * removed again afterwards.
 */
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
  Doc doc;
  // The loop header defines the variable, so it is excluded from the function header.
  var_not_in_headers_.insert(op->loop_var.get());
  loop_var_map_[op->loop_var.get()] = GetRef<For>(op);
  const auto* body = op->body.as<ForNode>();
  bool simple_loop = IsSimpleLoop(op);
  if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
  // It is a loop that can be compressed, let the loops below print it out
  if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
    doc << Print(GetRef<For>(body));
    TryDeallocVar(op->loop_var);
    loop_var_map_.erase(op->loop_var.get());
    return doc;
  }
  // It is a loop that can not be compressed
  bool print_above = !simple_loop_stack_.empty();
  // print loops above if needed
  // (PrintLoopStack emits the accumulated simple loops; presumably as a single
  // compressed header — see PrintLoopStack to confirm.)
  if (print_above) {
    doc << PrintLoopStack();
    simple_loop_stack_.clear();
  }
  if (!simple_loop) {
    // print current loop if needed
    // (A non-simple loop gets its own header, nested one level under the flushed
    // stack when one was printed.)
    Doc current_loop;
    current_loop << PrintLoop(GetRef<For>(op));
    current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
    doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop);
  } else {
    // This simple loop was already included in the flushed stack; only its body remains.
    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  }
  TryDeallocVar(op->loop_var);
  loop_var_map_.erase(op->loop_var.get());
  return doc;
}
1336
1337Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) {
1338 Doc doc;
1339 doc << tir_prefix_ << ".prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")";
1340 return doc;
1341}
1342
1343Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) {
1344 Doc doc;
1345 doc << "while " << Print(op->condition) << ":";
1346 doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
1347 return doc;
1348}
1349
1350Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) {
1351 Doc doc;
1352 doc << tir_prefix_ << ".";
1353 if (node->dtype.is_void()) {
1354 doc << "void";
1355 } else {
1356 doc << runtime::DLDataType2String(node->dtype);
1357 }
1358 return doc;
1359}
1360
1361Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) {
1362 Doc doc;
1363 doc << tir_prefix_ << ".Ptr[";
1364 doc << Print(node->element_type);
1365 if (!node->storage_scope.empty()) {
1366 doc << ", " << Doc::StrLiteral(node->storage_scope);
1367 }
1368 doc << "]";
1369 return doc;
1370}
1371
1372Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) {
1373 if (node->fields.empty()) {
1374 return Doc::Text("None");
1375 } else {
1376 std::vector<Doc> fields;
1377 for (Type field : node->fields) {
1378 fields.push_back(Print(field));
1379 }
1380 return Doc::Text(tir_prefix_ + ".Tuple[") << Doc::Concat(fields) << "]";
1381 }
1382}
1383
1384Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) {
1385 Doc doc;
1386 if (op->indices.size() == 0) {
1387 doc << Print(op->buffer) << "[()] = " << Print(op->value);
1388 } else {
1389 doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value);
1390 }
1391 return doc;
1392}
1393
1394/*! Helper functions for block printing. */
1395Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) {
1396 Doc doc;
1397 doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis.";
1398 switch (iter_var->iter_type) {
1399 case kDataPar:
1400 doc << "spatial";
1401 break;
1402 case kCommReduce:
1403 doc << "reduce";
1404 break;
1405 case kOrdered:
1406 doc << "scan";
1407 break;
1408 case kOpaque:
1409 doc << "opaque";
1410 break;
1411 default:
1412 LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type;
1413 break;
1414 }
1415 doc << "(";
1416 const Range& dom = iter_var->dom;
1417 if (is_zero(dom->min)) {
1418 doc << Print(dom->extent);
1419 } else {
1420 doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")";
1421 }
1422 doc << ", " << Print(value) << ")";
1423 return doc;
1424}
1425
1426Doc TVMScriptPrinter::PrintBlockVarRemaps() {
1427 ICHECK(!block_var_remaps_.empty());
1428 if (block_var_remaps_.size() == 1) {
1429 const IterVar& iter_var = block_var_remaps_[0].first;
1430 const PrimExpr& value = block_var_remaps_[0].second;
1431 return PrintBlockVar(iter_var, value);
1432 }
1433 Doc doc;
1434 std::vector<Doc> iter_vars, iter_values;
1435 std::string iter_type;
1436 for (const auto& pair : block_var_remaps_) {
1437 const IterVar& iter_var = pair.first;
1438 const PrimExpr& value = pair.second;
1439 iter_vars.push_back(Print(iter_var->var));
1440 iter_values.push_back(Print(value));
1441 if (iter_var->iter_type == kDataPar) {
1442 iter_type += "S";
1443 } else if (iter_var->iter_type == kCommReduce) {
1444 iter_type += "R";
1445 } else {
1446 ICHECK(false);
1447 }
1448 }
1449 doc << PrintSep(iter_vars, Doc::Text(", ")) << " = " << tir_prefix_ << ".axis.remap("
1450 << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", ")) << "])";
1451 return doc;
1452}
1453
1454Doc TVMScriptPrinter::PrintBlockPredicate(const BlockRealizeNode* op) {
1455 Doc doc;
1456 if (!is_one(op->predicate)) {
1457 doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")";
1458 }
1459 return doc;
1460}
1461
1462Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) {
1463 Doc doc;
1464 const auto* block_op = op->block.as<BlockNode>();
1465 ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size());
1466 tir::ExprDeepEqual expr_equal;
1467
1468 auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var,
1469 const PrimExpr& value) -> bool {
1470 if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false;
1471 if (!value->IsInstance<VarNode>()) return false;
1472 const Var& var = Downcast<Var>(value);
1473 auto it = loop_var_map_.find(var.get());
1474 return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) &&
1475 expr_equal(it->second->extent, iter_var->dom->extent);
1476 };
1477
1478 for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
1479 const IterVar& iter_var = block_op->iter_vars[i];
1480 const PrimExpr& value = op->iter_values[i];
1481 var_not_in_headers_.insert(iter_var->var.get());
1482 if (is_simple_remap(iter_var, value)) {
1483 block_var_remaps_.push_back(std::make_pair(iter_var, value));
1484 } else {
1485 if (!block_var_remaps_.empty()) {
1486 doc << Doc::NewLine() << PrintBlockVarRemaps();
1487 block_var_remaps_.clear();
1488 }
1489 doc << Doc::NewLine() << PrintBlockVar(iter_var, value);
1490 }
1491 }
1492 if (!block_var_remaps_.empty()) {
1493 doc << Doc::NewLine() << PrintBlockVarRemaps();
1494 block_var_remaps_.clear();
1495 }
1496 return doc;
1497}
1498
1499Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) {
1500 const auto* block_op = op->block.as<BlockNode>();
1501 Doc block_attr_doc;
1502 // print binding, read/write tensor region, annotations
1503 block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads("
1504 << PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")";
1505 block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes("
1506 << PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")";
1507 if (!block_op->annotations.empty()) {
1508 block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({";
1509 block_attr_doc << PrintAnnotations(block_op->annotations);
1510 block_attr_doc << "})";
1511 }
1512 return block_attr_doc;
1513}
1514
1515// This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a
1516// List. Therefore the brackets are removed before and after printing arguments out
1517Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) {
1518 Doc doc;
1519 for (size_t i = 0; i < op->size(); ++i) {
1520 if (i != 0) {
1521 doc << ", ";
1522 }
1523 doc << Print(op->at(i));
1524 }
1525 return doc;
1526}
1527
1528Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) {
1529 Doc body;
1530 for (const auto& alloc_buf : op->alloc_buffers) {
1531 buf_not_in_headers_.insert(alloc_buf.get());
1532 body << Print(alloc_buf) << " = " << tir_prefix_ << ".alloc_buffer("
1533 << memo_buf_decl_[alloc_buf] << ")" << Doc::NewLine();
1534 }
1535 for (const auto& match_buf : op->match_buffers) {
1536 body << Print(match_buf) << Doc::NewLine();
1537 }
1538 if (op->init.defined()) {
1539 Doc init_block;
1540 init_block << "with " << tir_prefix_ << ".init():";
1541 init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value()));
1542 body << init_block << Doc::NewLine();
1543 }
1544 body << PrintBody(op->body);
1545 return body;
1546}
1547
1548/*!
1549 * \brief Print the name of a block
1550 * \param block_op The block node to be printed
1551 */
1552Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) {
1553 Doc doc;
1554 doc << "with " << tir_prefix_ << ".block(";
1555 if (!block_op->name_hint.empty()) {
1556 doc << Doc::StrLiteral(block_op->name_hint);
1557 }
1558 doc << "):";
1559 return doc;
1560}
1561
1562Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) {
1563 const auto* block_op = op->block.as<BlockNode>();
1564 Doc doc = PrintOptionalInfo(GetRef<Stmt>(block_op));
1565 // print block name
1566 doc << PrintBlockName(block_op);
1567 // Print block predicate.
1568 Doc block_predicate = PrintBlockPredicate(op);
1569 // Print the variable bindings, valid to use in block attributes and
1570 // body
1571 Doc block_var = PrintBlockVars(op);
1572 // print read/write tensor region, annotations
1573 Doc block_attr_doc = PrintBlockAttr(op);
1574 // print body
1575 Doc body = PrintBlockBody(block_op);
1576 doc << Doc::Indent(4, block_predicate << block_var << block_attr_doc << Doc::NewLine() << body);
1577 for (const auto& iter_var : block_op->iter_vars) {
1578 TryDeallocVar(iter_var->var);
1579 }
1580 return doc;
1581}
1582
1583Doc TVMScriptPrinter::PrintBody(const Stmt& body) {
1584 int memo_num_child, memo_current_num;
1585 std::swap(memo_num_child, num_child_);
1586 std::swap(memo_current_num, current_num_);
1587
1588 Doc doc;
1589 if (body->IsInstance<SeqStmtNode>()) {
1590 const auto& op = Downcast<SeqStmt>(body);
1591 num_child_ = op->seq.size();
1592 current_num_ = 0;
1593 std::vector<Doc> stmts;
1594 for (Stmt stmt : op->seq) {
1595 stmts.push_back(Print(stmt));
1596 current_num_++;
1597 }
1598 doc = PrintSep(stmts, Doc::NewLine());
1599 } else {
1600 num_child_ = 1;
1601 current_num_ = 0;
1602 doc = Print(body);
1603 }
1604
1605 std::swap(memo_num_child, num_child_);
1606 std::swap(memo_current_num, current_num_);
1607 return doc;
1608}
1609
1610Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) {
1611 auto* op = module.operator->();
1612 Doc doc;
1613 doc << "@tvm.script.ir_module" << Doc::NewLine();
1614 doc << "class Module:";
1615 for (const auto& x : op->functions) {
1616 func2var_[x.second.operator->()] = x.first;
1617 }
1618 Doc body = Doc::NewLine();
1619 std::vector<Doc> functions;
1620 for (auto it = op->functions.begin(); it != op->functions.end(); ++it) {
1621 if ((*it).second.as<PrimFuncNode>()) {
1622 functions.push_back(Print((*it).second));
1623 }
1624 }
1625 body << TVMScriptPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine());
1626 body << Doc::NewLine() << DumpMeta();
1627 doc << Doc::Indent(4, body);
1628 return doc;
1629}
1630
1631Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
1632 auto* op = primFunc.operator->();
1633 // clear renaming map
1634 memo_var_.clear();
1635 memo_buf_.clear();
1636 memo_buf_decl_.clear();
1637 var_not_in_headers_.clear();
1638 buf_not_in_headers_.clear();
1639 // print signature
1640 Doc doc;
1641 doc << "@" << tir_prefix_ << ".prim_func" << Doc::NewLine();
1642 doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint)
1643 << "(";
1644 std::vector<Doc> params;
1645 std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf;
1646 for (const auto& param : op->params) {
1647 var_not_in_headers_.insert(param.get());
1648 auto it = op->buffer_map.find(param);
1649 // check if this param is a T.handle
1650 if (it != op->buffer_map.end()) {
1651 // check if this match_buffer has only the first two arguments specified
1652 // and whether the match_buffer is a dynamic buffer.
1653 const Buffer& buf = (*it).second;
1654 if (IsSimpleBuffer(buf)) {
1655 simple_buf.insert(buf);
1656 buf_not_in_headers_.insert(buf.get());
1657 params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf));
1658 continue;
1659 }
1660 }
1661 params.push_back(Print(param) << ": " << Print(GetType(param)));
1662 }
1663 doc << PrintSep(params, Doc::Text(", ")) << ")";
1664 if (primFunc->ret_type.defined()) {
1665 auto as_tuple = primFunc->ret_type.as<TupleTypeNode>();
1666 if (!as_tuple || as_tuple->fields.size()) {
1667 doc << " -> " << Print(primFunc->ret_type);
1668 }
1669 }
1670 doc << ":";
1671
1672 Doc body = Doc::NewLine();
1673 // print buffer_bind
1674 for (const auto& param : op->params) {
1675 auto it = op->buffer_map.find(param);
1676 if (it == op->buffer_map.end()) continue;
1677 const Buffer& buf = (*it).second;
1678 if (simple_buf.count(buf)) continue;
1679 buf_not_in_headers_.insert(buf.get());
1680 body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(";
1681 ICHECK(memo_buf_decl_.count(buf));
1682 body << Print((*it).first) << ", " << memo_buf_decl_[buf];
1683 body << ")" << Doc::NewLine();
1684 }
1685 // print body
1686 body << "# body" << Doc::NewLine();
1687
1688 Optional<Block> elided_root_block_body = [&]() -> Optional<Block> {
1689 auto block_realize = op->body.as<BlockRealizeNode>();
1690 if (!block_realize || block_realize->iter_values.size()) {
1691 return NullOpt;
1692 }
1693
1694 const auto& block = block_realize->block;
1695 if (block->annotations.size() || ContainsOptionalInfo(block)) {
1696 return NullOpt;
1697 }
1698
1699 // The autocomplete might recognize the body itself as being a
1700 // root block, and fail to insert it.
1701 bool autocomplete_would_insert_root_block = [&]() -> bool {
1702 if (block->alloc_buffers.size()) {
1703 return true;
1704 }
1705
1706 auto* block_realize = block->body.as<BlockRealizeNode>();
1707 if (block_realize && block_realize->block->iter_vars.size()) {
1708 return true;
1709 }
1710 if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) {
1711 return true;
1712 }
1713 return false;
1714 }();
1715
1716 if (autocomplete_would_insert_root_block) {
1717 return block;
1718 } else {
1719 return NullOpt;
1720 }
1721 }();
1722
1723 if (elided_root_block_body) {
1724 // Skip printing of root block in cases where tvm::tir::ScriptComplete
1725 // would re-insert it.
1726 body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine();
1727 body << PrintBlockBody(elided_root_block_body.value().get());
1728 } else {
1729 // If this is a non-root block, or is an unskippable root block,
1730 // just print it without skipping.
1731 body << PrintBody(op->body);
1732 }
1733
1734 // print func attrs
1735 Doc header_attr;
1736 if (primFunc->attrs.defined()) {
1737 header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << tir_prefix_
1738 << ".func_attr({";
1739 std::vector<Doc> attrs;
1740 for (const auto& it : op->attrs->dict) {
1741 attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second));
1742 }
1743 header_attr << PrintSep(attrs, Doc::Text(", ")) << "})";
1744 }
1745 // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate)
1746 Doc header_buf;
1747 std::vector<const BufferNode*> bufs;
1748 for (const auto& it : memo_buf_) {
1749 if (buf_not_in_headers_.find(it.first.get()) == buf_not_in_headers_.end()) {
1750 bufs.push_back(it.first.get());
1751 }
1752 }
1753 if (!bufs.empty()) {
1754 header_buf << Doc::NewLine() << "# buffer definition";
1755 std::sort(bufs.begin(), bufs.end(), [&](const BufferNode* a, const BufferNode* b) {
1756 return memo_buf_[GetRef<Buffer>(a)].str() < memo_buf_[GetRef<Buffer>(b)].str();
1757 });
1758 for (const auto& buf : bufs) {
1759 header_buf << Doc::NewLine() << Print(GetRef<Buffer>(buf)) << " = " << tir_prefix_
1760 << ".buffer_decl(";
1761 header_buf << memo_buf_decl_[GetRef<Buffer>(buf)] << ")";
1762 }
1763 }
1764 // print var declaration
1765 Doc header_var;
1766 std::vector<const VarNode*> vars;
1767 for (const auto& it : memo_var_) {
1768 if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) {
1769 vars.push_back(it.first.get());
1770 }
1771 }
1772 if (!var_env_map_.empty()) {
1773 header_var << Doc::NewLine() << "# var definition";
1774 for (const auto& it : var_env_map_) {
1775 header_var << Doc::NewLine() << Print(it.first) << " = " << tir_prefix_ << ".env_thread("
1776 << Doc::StrLiteral(it.second) << ")";
1777 }
1778 }
1779 if (!vars.empty()) {
1780 std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) {
1781 return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str();
1782 });
1783 for (const auto& var : vars) {
1784 auto type = GetRef<Var>(var)->type_annotation;
1785 if (auto* ptr_type = type.as<PointerTypeNode>()) {
1786 auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
1787 ICHECK(prim_type);
1788 header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_
1789 << ".buffer_var(";
1790 header_var << PrintDType(prim_type->dtype) << ", "
1791 << Doc::StrLiteral(ptr_type->storage_scope) << ")";
1792 } else {
1793 header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_ << ".var(";
1794 header_var << PrintDType(var->dtype) << ")";
1795 }
1796 }
1797 }
1798 doc << Doc::Indent(4, header_attr << header_var << header_buf << body);
1799 return doc;
1800}
1801
1802Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) {
1803 Doc doc;
1804 doc << '[';
1805 for (size_t i = 0; i < op->size(); ++i) {
1806 if (i != 0) {
1807 doc << ", ";
1808 }
1809 doc << Print(op->at(i));
1810 }
1811 doc << ']';
1812 return doc;
1813}
1814
1815Doc TVMScriptPrinter::PrintIterVar(const IterVarNode* op) {
1816 Doc doc;
1817 doc << tir_prefix_ << ".iter_var(" << Print(op->var);
1818 if (op->dom.defined()) {
1819 doc << ", [" << Print(op->dom) << "], ";
1820 } else {
1821 doc << ", None, ";
1822 }
1823 doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", ";
1824 doc << Doc::StrLiteral(op->thread_tag) << ")";
1825 return doc;
1826}
1827
1828Doc TVMScriptPrinter::PrintRange(const RangeNode* op) {
1829 return Print(op->min) << ":" << Print(op->min + op->extent);
1830}
1831
1832Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) {
1833 const Buffer& buffer = GetRef<Buffer>(op);
1834 return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer);
1835}
1836
1837Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) {
1838 Doc doc;
1839 doc << '[';
1840 for (size_t i = 0; i < indices.size(); ++i) {
1841 if (i != 0) {
1842 doc << ", ";
1843 }
1844 PrimExpr index = indices[i];
1845 if (const RampNode* ramp = index.as<RampNode>()) {
1846 // specify ramp printing as python index slice
1847 if (auto* stride_imm = ramp->stride.as<IntImmNode>()) {
1848 doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride);
1849 if (stride_imm->value != 1) {
1850 doc << ":" << Print(ramp->stride);
1851 }
1852 continue;
1853 }
1854 }
1855 doc << Print(index);
1856 }
1857 doc << ']';
1858 return doc;
1859}
1860
1861Doc TVMScriptPrinter::PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers) {
1862 Doc decls;
1863 for (const auto& buf_usage : aliasing_buffers) {
1864 decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
1865 << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
1866 buf_not_in_headers_.insert(buf_usage.get());
1867 }
1868 return decls;
1869}
1870
1871Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) {
1872 Doc doc;
1873 if (op->region.size() == 0) {
1874 doc << Print(op->buffer) << "[()]";
1875 } else {
1876 doc << Print(op->buffer) << "[";
1877 for (size_t i = 0; i < op->region.size(); ++i) {
1878 if (i != 0) doc << ", ";
1879 const auto& range = op->region[i];
1880 if (!is_one(range->extent)) {
1881 doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent));
1882 } else {
1883 doc << Print(range->min);
1884 }
1885 }
1886 doc << "]";
1887 }
1888 return doc;
1889}
1890
1891Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations) {
1892 Doc res;
1893 std::vector<std::pair<String, ObjectRef>> anno_list;
1894 anno_list.reserve(annotations.size());
1895 for (const auto& pair : annotations) {
1896 anno_list.emplace_back(pair);
1897 }
1898 sort(anno_list.begin(), anno_list.end());
1899 for (size_t i = 0; i < anno_list.size(); ++i) {
1900 if (i != 0) {
1901 res << ", ";
1902 }
1903 res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second);
1904 }
1905 return res;
1906}
1907
1908Doc TVMScriptPrinter::PrintLoop(const For& loop) {
1909 Doc res;
1910 res << "for " << Print(loop->loop_var) << " in " << tir_prefix_
1911 << "." + std::string(ForKind2String(loop->kind)) + "(";
1912 if (is_zero(loop->min)) {
1913 res << Print(loop->extent);
1914 } else {
1915 res << Print(loop->min) << ", " << Print(ana_.Simplify(loop->min + loop->extent));
1916 }
1917 if (loop->thread_binding.defined()) {
1918 res << ", thread=";
1919 res << Print(loop->thread_binding.value()->thread_tag);
1920 }
1921 if (!loop->annotations.empty()) {
1922 res << ", annotations={";
1923 res << PrintAnnotations(loop->annotations);
1924 res << "}";
1925 }
1926 res << "):";
1927 return res;
1928}
1929
1930Doc TVMScriptPrinter::PrintLoopStack() {
1931 Doc res;
1932 if (simple_loop_stack_.size() == 1) {
1933 res << PrintLoop(simple_loop_stack_[0]);
1934 } else if (simple_loop_stack_.size() > 1) {
1935 std::vector<Doc> vars, extents;
1936 for (const auto& loop : simple_loop_stack_) {
1937 vars.push_back(Print(loop->loop_var));
1938 extents.push_back(Print(loop->extent));
1939 }
1940 res << "for " << PrintSep(vars, Doc::Text(", ")) << " in " << tir_prefix_ << ".grid("
1941 << PrintSep(extents, Doc::Text(", ")) << "):";
1942 }
1943 return res;
1944}
1945
1946Doc TVMScriptPrinter::PrintTarget(const TargetNode* target) {
1947 Doc res;
1948 res << tir_prefix_ << ".target({";
1949 Map<String, ObjectRef> config = target->Export();
1950 for (auto it = config.begin(); it != config.end(); ++it) {
1951 if (it != config.begin()) {
1952 res << ", ";
1953 }
1954 res << "\"" << (*it).first << "\":";
1955 if ((*it).first == "host") {
1956 ICHECK(target->host.defined());
1957 res << PrintTarget(target->GetHost().value().get());
1958 } else {
1959 res << Print((*it).second);
1960 }
1961 }
1962 res << "})";
1963 return res;
1964}
1965
1966/*!
1967 * \brief The printer for TVMScript with diagnostic
1968 * \details The printer obtain the precedence of the top-level operation when printing each
1969 * subexpression to decide whether or not parentheses is needed.
1970 */
1971class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter {
1972 public:
1973 explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta,
1974 runtime::TypedPackedFunc<std::string(Stmt)> annotate)
1975 : TVMScriptPrinter(tir_prefix, show_meta, annotate) {}
1976
1977 protected:
1978 Doc PrintBlockName(const BlockNode* block_op) override;
1979 Doc PrintUnderline(const Stmt& stmt, int length);
1980 Doc PrintLoop(const For& loop) override;
1981};
1982
1983Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) {
1984 Doc doc = TVMScriptPrinter::PrintBlockName(block_op);
1985 doc << PrintUnderline(GetRef<Stmt>(block_op), doc.str().size());
1986 return doc;
1987}
1988
1989Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) {
1990 Doc doc;
1991 // annotation
1992 if (ContainsOptionalInfo(stmt)) {
1993 String underline = std::string(length, '^');
1994 doc << Doc::NewLine() << underline;
1995 }
1996 return doc;
1997}
1998
1999Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) {
2000 Doc res = TVMScriptPrinter::PrintLoop(loop);
2001 res << PrintUnderline(loop, res.str().size());
2002 return res;
2003}
2004
2005String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) {
2006 ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
2007 Doc doc;
2008 doc << TVMScriptPrinter::PrintHeader(tir_prefix)
2009 << TVMScriptPrinter(tir_prefix, show_meta).Print(mod);
2010 return doc.str() + "\n";
2011}
2012
2013TVM_REGISTER_GLOBAL("script.AsTVMScript").set_body_typed(AsTVMScript);
2014
2015String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta,
2016 runtime::TypedPackedFunc<std::string(Stmt)> annotate) {
2017 ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>());
2018 Doc doc;
2019 doc << TVMScriptPrinter::PrintHeader(tir_prefix)
2020 << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod);
2021 return doc.str() + "\n";
2022}
2023
2024TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
2025
2026} // namespace tir
2027} // namespace tvm
2028