1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file printer/tvmscript_printer.cc |
22 | * \brief Printer class to print Tensor IR to python syntax script |
23 | */ |
24 | |
25 | #include <tvm/arith/analyzer.h> |
26 | #include <tvm/ir/module.h> |
27 | #include <tvm/node/serialization.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/target/target.h> |
30 | #include <tvm/tir/analysis.h> |
31 | #include <tvm/tir/buffer.h> |
32 | #include <tvm/tir/expr.h> |
33 | #include <tvm/tir/expr_functor.h> |
34 | #include <tvm/tir/function.h> |
35 | #include <tvm/tir/op.h> |
36 | #include <tvm/tir/stmt.h> |
37 | #include <tvm/tir/stmt_functor.h> |
38 | |
39 | #include <algorithm> |
40 | #include <utility> |
41 | |
42 | #include "../tir/transforms/ir_utils.h" |
43 | #include "doc.h" |
44 | #include "meta_data.h" |
45 | #include "text_printer.h" |
46 | |
47 | namespace tvm { |
48 | namespace tir { |
49 | |
/*!
 * \brief Precedence levels of printed expressions, used to decide where parentheses are needed.
 * \note Smaller values bind more tightly; kUnknown marks a precedence that has not been
 *       computed yet.
 */
enum class ExprPrecedence : int {
  /*! \brief Identity(e.g., IntImm, Var) and function call(e.g., floordiv, min) */
  kIdentity = 0,
  /*!
   * \brief Multiplication(*), division(/), and remainder(%)
   * \note FloorDiv and FloorMod are marked as kIdentity since they are function calls.
   */
  kMultiplicationDivision = 1,
  /*! \brief Addition(+) and subtraction(-) */
  kAdditionSubtraction = 2,
  /*! \brief Relational operators <, <=, > and >= */
  kRelational = 3,
  /*! \brief Equality operators == and != */
  kEquality = 4,
  /*! \brief And(&&) */
  kAnd = 5,
  /*! \brief Or(||) */
  kOr = 6,
  /*! \brief Unknown precedence */
  kUnknown = 7,
};
71 | |
72 | /*! \brief Utility used for identifying usage of a buffer_var |
73 | * |
74 | * \details Find the Buffer object that corresponds to a variable or |
75 | * allocation, based on the BufferLoad/BufferStore instances that |
76 | * occur within the allocation's body. |
77 | */ |
78 | class BufferUsageFinder : public StmtExprVisitor { |
79 | public: |
80 | static Map<Var, Array<Buffer>> FindUsage(Map<Var, Array<Buffer>> usage, Stmt body) { |
81 | BufferUsageFinder visitor(std::move(usage)); |
82 | visitor.VisitStmt(body); |
83 | return std::move(visitor.usage_); |
84 | } |
85 | |
86 | void VisitExpr_(const VarNode* op) final { |
87 | Var var = GetRef<Var>(op); |
88 | if (!usage_.count(var)) { |
89 | usage_.Set(var, {}); |
90 | } |
91 | } |
92 | |
93 | void VisitExpr_(const BufferLoadNode* op) final { |
94 | VisitBuffer(op->buffer); |
95 | StmtExprVisitor::VisitExpr_(op); |
96 | } |
97 | |
98 | void VisitStmt_(const BufferStoreNode* op) final { |
99 | VisitBuffer(op->buffer); |
100 | StmtExprVisitor::VisitStmt_(op); |
101 | } |
102 | |
103 | void VisitStmt_(const DeclBufferNode* op) final { |
104 | buffers_declared_.insert(op->buffer.get()); |
105 | StmtExprVisitor::VisitStmt_(op); |
106 | buffers_declared_.erase(op->buffer.get()); |
107 | } |
108 | |
109 | private: |
110 | explicit BufferUsageFinder(Map<Var, Array<Buffer>> usage) : usage_(usage) {} |
111 | |
112 | void VisitBuffer(const Buffer& buffer) { |
113 | if (buffers_visited_.count(buffer.get())) { |
114 | return; |
115 | } |
116 | if (buffers_declared_.count(buffer.get())) { |
117 | return; |
118 | } |
119 | buffers_visited_.insert(buffer.get()); |
120 | |
121 | Array<Buffer> arr = usage_.Get(buffer->data).value_or({}); |
122 | arr.push_back(buffer); |
123 | usage_.Set(buffer->data, arr); |
124 | } |
125 | |
126 | // The search result. |
127 | Map<Var, Array<Buffer>> usage_; |
128 | // The buffers that have been visited so far, to avoid duplicate |
129 | // entries in the search result. |
130 | std::unordered_set<const BufferNode*> buffers_visited_; |
131 | // The buffers declared via `DeclBuffer`. These buffers are excluded from the result because |
132 | // T.buffer_decl shouldn't be printed for them. |
133 | std::unordered_set<const BufferNode*> buffers_declared_; |
134 | }; |
135 | |
136 | /*! |
137 | * \brief The printer for TVMScript |
138 | * \details The printer obtain the precedence of the top-level operation when printing each |
139 | * subexpression to decide whether or not parentheses is needed. |
140 | */ |
141 | class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>, |
142 | public ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>, |
143 | public TypeFunctor<Doc(const Type&)> { |
144 | public: |
145 | explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta, |
146 | runtime::TypedPackedFunc<std::string(Stmt)> annotate = nullptr) |
147 | : tir_prefix_(tir_prefix), |
148 | show_meta_(show_meta), |
149 | annotate_(std::move(annotate)), |
150 | meta_collector_(&meta_) {} |
151 | |
152 | /*! |
153 | * \brief Print the node. |
154 | * \param node The node to be printed. |
155 | * \param out_precedence The operator precedence of node if it's a PrimExpr, |
156 | * so we can simplify the bracket. |
157 | */ |
158 | TVM_DLL Doc Print(const ObjectRef& node); |
159 | |
160 | protected: |
161 | /*! \brief The tir prefix */ |
162 | String tir_prefix_; |
163 | /*! \brief whether show meta data */ |
164 | bool show_meta_; |
165 | /*! \brief additional comment function */ |
166 | runtime::TypedPackedFunc<std::string(Stmt)> annotate_; |
167 | /*! \brief meta data context */ |
168 | TextMetaDataContext meta_; |
169 | /*! \brief meta collector */ |
170 | MetaCollector meta_collector_; |
171 | /*! \brief map from Function to GlobalVar */ |
172 | std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_; |
173 | /*! \brief var collector (var defined by For/Loop/Block) */ |
174 | std::unordered_set<const VarNode*> ; |
175 | /*! |
176 | * \brief buffer collector |
177 | * (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion) |
178 | */ |
179 | std::unordered_set<const BufferNode*> ; |
180 | /*! \brief Map from Var to thread env name */ |
181 | std::unordered_map<Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_; |
182 | /*! \brief Map from Var to Doc */ |
183 | std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_; |
184 | /*! \brief Map from Buffer to Doc */ |
185 | std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_; |
186 | /*! \brief Map from Buffer to Declaration Doc */ |
187 | std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_decl_; |
188 | /*! \brief name allocation map */ |
189 | std::unordered_map<std::string, int> name_alloc_map_; |
190 | /*! \brief number of children of current node's parent */ |
191 | int num_child_; |
192 | /*! \brief the number of current node */ |
193 | int current_num_; |
194 | /*! \brief loop stack without annotations */ |
195 | std::vector<For> simple_loop_stack_; |
196 | /*! \brief the maps from loop_vars to the loops */ |
197 | std::unordered_map<const VarNode*, For> loop_var_map_; |
198 | /*! |
199 | * \brief simple block vars remap from loop vars |
200 | * simple_remap requires: |
201 | * 1. block var iter type is kDataPar or kCommReduce |
202 | * 2. value is a single Var, which is a loop_var outside the block |
203 | * 3. The iter range is equal to loop range |
204 | */ |
205 | std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_; |
206 | /*! |
207 | * \brief Map from variables to the buffers they are used in. |
208 | * |
209 | * Used for identifying buffers that should be declared after the |
210 | * LetStmt or Allocate that generates their data pointer, rather |
211 | * than in the header. |
212 | */ |
213 | Map<Var, Array<Buffer>> buffer_var_usage_; |
214 | /*! \brief Analyzer to simplify some expressions. */ |
215 | arith::Analyzer ana_; |
216 | |
217 | Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; |
218 | Doc VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) override; |
219 | Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override; |
220 | Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override; |
221 | Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override; |
222 | Doc VisitExpr_(const DivNode* op, ExprPrecedence* out_precedence) override; |
223 | Doc VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) override; |
224 | Doc VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) override; |
225 | Doc VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) override; |
226 | Doc VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) override; |
227 | Doc VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) override; |
228 | Doc VisitExpr_(const EQNode* op, ExprPrecedence* out_precedence) override; |
229 | Doc VisitExpr_(const NENode* op, ExprPrecedence* out_precedence) override; |
230 | Doc VisitExpr_(const LTNode* op, ExprPrecedence* out_precedence) override; |
231 | Doc VisitExpr_(const LENode* op, ExprPrecedence* out_precedence) override; |
232 | Doc VisitExpr_(const GTNode* op, ExprPrecedence* out_precedence) override; |
233 | Doc VisitExpr_(const GENode* op, ExprPrecedence* out_precedence) override; |
234 | Doc VisitExpr_(const AndNode* op, ExprPrecedence* out_precedence) override; |
235 | Doc VisitExpr_(const OrNode* op, ExprPrecedence* out_precedence) override; |
236 | Doc VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) override; |
237 | Doc VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) override; |
238 | Doc VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) override; |
239 | Doc VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) override; |
240 | Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override; |
241 | Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override; |
242 | Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override; |
243 | Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override; |
244 | Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override; |
245 | Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override; |
246 | Doc VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) override; |
247 | Doc VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) override; |
248 | Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override; |
249 | Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override; |
250 | Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override; |
251 | |
252 | Doc VisitStmt_(const LetStmtNode* op) override; |
253 | Doc VisitStmt_(const AttrStmtNode* op) override; |
254 | Doc VisitStmt_(const AssertStmtNode* op) override; |
255 | Doc VisitStmt_(const StoreNode* op) override; |
256 | Doc VisitStmt_(const BufferStoreNode* op) override; |
257 | Doc VisitStmt_(const BufferRealizeNode* op) override; |
258 | Doc VisitStmt_(const AllocateNode* op) override; |
259 | Doc VisitStmt_(const AllocateConstNode* op) override; |
260 | Doc VisitStmt_(const DeclBufferNode* op) override; |
261 | Doc VisitStmt_(const IfThenElseNode* op) override; |
262 | Doc VisitStmt_(const SeqStmtNode* op) override; |
263 | Doc VisitStmt_(const ForNode* op) override; |
264 | Doc VisitStmt_(const WhileNode* op) override; |
265 | Doc VisitStmt_(const PrefetchNode* op) override; |
266 | Doc VisitStmt_(const EvaluateNode* op) override; |
267 | Doc VisitStmt_(const BlockRealizeNode* op) override; |
268 | Doc VisitStmtDefault_(const Object* op) override; |
269 | |
270 | Doc VisitType_(const PrimTypeNode* node) override; |
271 | Doc VisitType_(const PointerTypeNode* node) override; |
272 | Doc VisitType_(const TupleTypeNode* node) override; |
273 | |
274 | Doc PrintBody(const Stmt& body); |
275 | Doc PrintIRModule(const IRModule& module); |
276 | Doc PrintPrimFunc(const PrimFunc& primFunc); |
277 | Doc PrintIterVar(const IterVarNode* op); |
278 | Doc PrintRange(const RangeNode* op); |
279 | Doc PrintArray(const ArrayNode* op); |
280 | Doc PrintBuffer(const BufferNode* op); |
281 | Doc PrintBufferIndices(const Array<PrimExpr>& indices); |
282 | Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers); |
283 | Doc AllocBufferDeclaration(const Buffer& buf); |
284 | Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); |
285 | Doc PrintBlockVarRemaps(); |
286 | Doc PrintBlockPredicate(const BlockRealizeNode* op); |
287 | Doc PrintBlockVars(const BlockRealizeNode* op); |
288 | Doc PrintBlockAttr(const BlockRealizeNode* op); |
289 | Doc PrintExpandedArray(const ArrayNode* op); |
290 | Doc PrintBlockBody(const BlockNode* op); |
291 | virtual Doc PrintBlockName(const BlockNode* block_op); |
292 | Doc PrintBufferRegion(const BufferRegionNode* op); |
293 | Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); |
294 | Doc PrintCommReducer(const CommReducerNode* op); |
295 | Doc PrintAnnotations(const Map<String, ObjectRef>& annotations); |
296 | Doc PrintTarget(const TargetNode* target); |
297 | static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } |
298 | |
299 | Doc GetUniqueName(std::string prefix); |
300 | Doc AllocVar(const Var& var); |
301 | Doc AllocBuf(const Buffer& buffer); |
302 | void TryDeallocVar(const Var& var); |
303 | bool ContainsOptionalInfo(const Stmt& stmt); |
304 | /*! |
305 | * \brief Check if a buffer declaration satisfies: |
306 | * 1. has only 'shape' and 'dtype' arguments specified, |
307 | * 2. the shape and strides are not dynamic. |
308 | * \param buffer The match buffer to be checked |
309 | */ |
310 | bool IsSimpleBuffer(const Buffer& buffer); |
311 | Doc PrintInlineBufferBind(const Buffer& buffer); |
312 | Doc PrintTuple(const ArrayNode* op); |
313 | |
314 | /*! Helper functions for loop printing. */ |
315 | /*! |
316 | * \brief Print a single for loop |
317 | * \param loop The for loop to be printed |
318 | */ |
319 | virtual Doc PrintLoop(const For& loop); |
320 | /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ |
321 | Doc PrintLoopStack(); |
322 | /*! |
323 | * \brief Check whether a loop satisfies: |
324 | * 1. the loop is serial; |
325 | * 2. the loop has no annotation; |
326 | * 3. the loop starts from 0; |
327 | * 4. there is no optional information. |
328 | * \param for_op the for node to be checked |
329 | * \return A boolean indicating whether the input loop satisfies the above conditions |
330 | */ |
331 | bool IsSimpleLoop(const ForNode* for_op) { |
332 | return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && |
333 | is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op)); |
334 | } |
335 | /*! |
336 | * \brief Check whether the `min` or `extent` of a loop depends on previous loops |
337 | * \param for_op The loop to be checked |
338 | * \return A boolean indicating whether the input loop depends on previous loops |
339 | */ |
340 | bool DependOnPrevLoops(const ForNode* for_op) { |
341 | auto f_check = [&var_map = this->loop_var_map_](const VarNode* v) { return var_map.count(v); }; |
342 | return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check); |
343 | } |
344 | |
345 | /*! |
346 | * \brief Print additional info about expr in comment. |
347 | * \param expr The expression. |
348 | */ |
349 | Doc PrintOptionalInfo(const Stmt& stmt) { |
350 | Doc doc; |
351 | // default annotations |
352 | if (ContainsOptionalInfo(stmt)) { |
353 | std::string annotated_stmt = annotate_(stmt); |
354 | doc << "# " << annotated_stmt << Doc::NewLine(); |
355 | } |
356 | return doc; |
357 | } |
358 | |
359 | /*! |
360 | * \brief special method to render vectors of docs with a separator |
361 | * \param vec vector of docs |
362 | * \param sep separator |
363 | */ |
364 | static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) { |
365 | Doc seq; |
366 | if (vec.size() != 0) { |
367 | seq = vec[0]; |
368 | for (size_t i = 1; i < vec.size(); i++) { |
369 | seq << sep << vec[i]; |
370 | } |
371 | } |
372 | return seq; |
373 | } |
374 | |
375 | /*! |
376 | * \brief dump meta info |
377 | * \return Doc with meta info |
378 | */ |
379 | Doc DumpMeta() { |
380 | if (show_meta_) { |
381 | return Doc::Text("__tvm_meta__ = " ) |
382 | << (meta_.empty() ? Doc::Text("None" ) : meta_.GetMetaSection()); |
383 | } else { |
384 | return Doc::Text("" ); |
385 | } |
386 | } |
387 | |
388 | /*! |
389 | * \brief special method to print out data type |
390 | * \param dtype The data type |
391 | */ |
392 | static Doc PrintDType(DataType dtype) { |
393 | return Doc::StrLiteral(runtime::DLDataType2String(dtype)); |
394 | } |
395 | |
396 | /*! |
397 | * \brief special method to print out const int64_t scalar |
398 | * \param dtype The data type |
399 | * \param data The pointer to hold the data. |
400 | */ |
401 | Doc PrintConstScalar(DataType dtype, const int64_t* data) const { |
402 | Doc doc; |
403 | std::ostringstream os; |
404 | |
405 | os << data[0]; |
406 | |
407 | if (dtype == DataType::Int(32)) { |
408 | doc << Doc::Text(os.str()); |
409 | } else if (dtype == DataType::Bool()) { |
410 | doc << Doc::Text(data[0] ? "True" : "False" ); |
411 | } else { |
412 | doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) |
413 | << ")" ; |
414 | } |
415 | return doc; |
416 | } |
417 | |
418 | /*! |
419 | * \brief special method to print out const double scalar |
420 | * \param dtype The data type |
421 | * \param data The pointer to hold the data. |
422 | * \note this overriden function is created as std::isnan of msvc will complain about int64_t |
423 | */ |
424 | Doc PrintConstScalar(DataType dtype, const double* data) const { |
425 | Doc doc; |
426 | std::ostringstream os; |
427 | |
428 | os.precision(17); |
429 | if (std::isinf(data[0]) || std::isnan(data[0])) { |
430 | os << "\"" << data[0] << "\"" ; |
431 | } else { |
432 | os << data[0]; |
433 | } |
434 | |
435 | doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) |
436 | << ")" ; |
437 | |
438 | return doc; |
439 | } |
440 | |
441 | public: |
442 | static Doc (const std::string& tir_prefix) { |
443 | Doc ; |
444 | if (tir_prefix != "tir" ) { |
445 | header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine(); |
446 | } else { |
447 | header << "# from tvm.script import tir" << Doc::NewLine(); |
448 | } |
449 | return header; |
450 | } |
451 | }; |
452 | |
453 | /*! |
454 | * \brief special method to print NDArray in TIR |
455 | * \param arr the NDArray to be printed |
456 | * \param os the output stream where the NDArray will be printed to |
457 | */ |
458 | template <typename T> |
459 | void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { |
460 | if ((arr.DataType().code() == runtime::DataType::kInt || |
461 | arr.DataType().code() == runtime::DataType::kUInt) && |
462 | arr.DataType().bits() == 8) { |
463 | // Printing int8 NDArrays causes "UnicodeDecodeError: 'utf-8' codec can't decode byte" |
464 | // error during MetaSchedule tuning on int8 models. |
465 | return; |
466 | } |
467 | int ndim = arr->ndim; |
468 | int tot_dim = 1; |
469 | for (int i = 0; i < ndim; i++) { |
470 | tot_dim *= arr->shape[i]; |
471 | } |
472 | T* data_ptr = reinterpret_cast<T*>(arr->data); |
473 | constexpr int NUM_PRINT = 20; |
474 | os << "[" ; |
475 | for (int i = 0; i < tot_dim; i++) { |
476 | os << (i != 0 ? ", " : "" ) << data_ptr[i]; |
477 | if (i == NUM_PRINT) { |
478 | os << "..." ; |
479 | break; |
480 | } |
481 | } |
482 | os << "]" ; |
483 | } |
484 | |
485 | Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { |
486 | std::replace(prefix.begin(), prefix.end(), '.', '_'); |
487 | std::string unique_prefix = prefix; |
488 | auto it = name_alloc_map_.find(prefix); |
489 | if (it != name_alloc_map_.end() && it->second >= 0) { |
490 | while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { |
491 | } |
492 | } |
493 | name_alloc_map_[unique_prefix] = 0; |
494 | return Doc::Text(unique_prefix); |
495 | } |
496 | |
497 | Doc TVMScriptPrinter::AllocVar(const Var& var) { |
498 | const auto& it = memo_var_.find(var); |
499 | if (it != memo_var_.end()) { |
500 | return it->second; |
501 | } |
502 | std::string name = var->name_hint.operator std::string(); |
503 | if (name.length() == 0 || !std::isalpha(name[0])) { |
504 | name = "v" + name; |
505 | } |
506 | Doc val = GetUniqueName(name); |
507 | memo_var_[var] = val; |
508 | return val; |
509 | } |
510 | |
511 | Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { |
512 | Doc doc = Print(buf->shape); |
513 | bool print_factor_explicitly = false; |
514 | doc << ", dtype=" << PrintDType(buf->dtype); |
515 | if (memo_var_.find(buf->data) != memo_var_.end()) { |
516 | doc << ", data=" << Print(buf->data); |
517 | } else { |
518 | // implicitly define data |
519 | memo_var_[buf->data] = Doc::Text(memo_buf_[buf].str() + ".data" ); |
520 | var_not_in_headers_.insert(buf->data.get()); |
521 | } |
522 | if (!buf->strides.empty()) { |
523 | doc << ", strides=" << Print(buf->strides); |
524 | } |
525 | if (buf->elem_offset->IsInstance<VarNode>()) { |
526 | Var elem_offset = Downcast<Var>(buf->elem_offset); |
527 | if (memo_var_.find(elem_offset) != memo_var_.end()) { |
528 | doc << ", elem_offset=" << Print(buf->elem_offset); |
529 | } else { |
530 | // implicitly define elem_offset |
531 | memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset" ); |
532 | var_not_in_headers_.insert(elem_offset.get()); |
533 | print_factor_explicitly = true; |
534 | } |
535 | } else if (buf->elem_offset->IsInstance<IntImmNode>()) { |
536 | IntImm elem_offset = Downcast<IntImm>(buf->elem_offset); |
537 | if (elem_offset->value != 0) { |
538 | doc << ", elem_offset=" << Print(buf->elem_offset); |
539 | } |
540 | } |
541 | if (buf.scope() != "global" ) { |
542 | doc << ", scope=" << Doc::StrLiteral(buf.scope()); |
543 | } |
544 | if (buf->data_alignment != runtime::kAllocAlignment) { |
545 | doc << ", align=" << buf->data_alignment; |
546 | } |
547 | if (buf->offset_factor != 1 || print_factor_explicitly) { |
548 | doc << ", offset_factor=" << buf->offset_factor; |
549 | } |
550 | if (buf->buffer_type != BufferType::kDefault) { |
551 | doc << ", type=" << Doc::StrLiteral("auto" ); |
552 | } |
553 | if (buf->axis_separators.size()) { |
554 | doc << ", axis_separators=" << Print(buf->axis_separators); |
555 | } |
556 | return doc; |
557 | } |
558 | |
559 | Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { |
560 | const auto& it = memo_buf_.find(buffer); |
561 | if (it != memo_buf_.end()) { |
562 | return it->second; |
563 | } |
564 | std::string name = buffer->name; |
565 | if (name.length() == 0 || !std::isalpha(name[0])) { |
566 | name = "buf_" + name; |
567 | } |
568 | Doc val = GetUniqueName(name); |
569 | memo_buf_[buffer] = val; |
570 | memo_buf_decl_[buffer] = AllocBufferDeclaration(buffer); |
571 | return val; |
572 | } |
573 | |
574 | /*! |
575 | * \brief Check if any optional information exists in annotate_ for |
576 | * a given Stmt. |
577 | * \param stmt The statement. |
578 | */ |
579 | bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { |
580 | if (annotate_ == nullptr) return false; |
581 | return !annotate_(stmt).empty(); |
582 | } |
583 | |
584 | /*! |
585 | * \brief Try to dealloc vars out of space and leave the index to coming vars. |
586 | * \note It is not a necessary step. |
587 | */ |
588 | void TVMScriptPrinter::TryDeallocVar(const Var& var) { |
589 | auto it = memo_var_.find(var); |
590 | ICHECK(it != memo_var_.end()); |
591 | std::string print_name = it->second.str(); |
592 | |
593 | std::string name_hint = var->name_hint.operator std::string(); |
594 | if (name_hint.length() == 0 || !std::isalpha(name_hint[0])) { |
595 | name_hint = "v" + name_hint; |
596 | } |
597 | std::replace(name_hint.begin(), name_hint.end(), '.', '_'); |
598 | |
599 | auto it2 = name_alloc_map_.find(name_hint); |
600 | // Skip it if we can not find the name_hint in name_alloc_map_. |
601 | if (it2 == name_alloc_map_.end()) return; |
602 | if (it2->second > 0) { |
603 | name_hint = name_hint + '_' + std::to_string(it2->second); |
604 | } |
605 | // Skip it if the name_hint is not equal to how it should be printed. |
606 | if (name_hint != print_name) return; |
607 | // Free the conresponding name_alloc_map_ index |
608 | --it2->second; |
609 | } |
610 | |
611 | Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { |
612 | const Buffer& buf = op->buffer; |
613 | buf_not_in_headers_.insert(buf.get()); |
614 | |
615 | Doc doc = Print(op->buffer) << " = " << tir_prefix_ << ".match_buffer(" << Print(op->source) |
616 | << ", " << memo_buf_decl_[op->buffer] << ")" ; |
617 | return doc; |
618 | } |
619 | |
620 | // check if all arguments, except the first two, are specified for T.match_buffer |
621 | // if not, then this match buffer is printed out as T.buffer in prim_func arguments |
622 | // and check whether there are undefined variables in the shape/strides. |
623 | bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) { |
624 | if (memo_var_.find(buf->data) != memo_var_.end()) { |
625 | return false; |
626 | } |
627 | if (!buf->strides.empty()) { |
628 | return false; |
629 | } |
630 | for (const PrimExpr& shp_i : buf->shape) { |
631 | if (!UndefinedVars(shp_i).empty()) { |
632 | return false; |
633 | } |
634 | } |
635 | for (const PrimExpr& stride_i : buf->strides) { |
636 | if (!UndefinedVars(stride_i).empty()) { |
637 | return false; |
638 | } |
639 | } |
640 | if (!UndefinedVars(buf->elem_offset).empty()) { |
641 | return false; |
642 | } else if (buf->elem_offset->IsInstance<IntImmNode>()) { |
643 | IntImm elem_offset = Downcast<IntImm>(buf->elem_offset); |
644 | if (elem_offset->value != 0) { |
645 | return false; |
646 | } |
647 | } |
648 | if (buf.scope() != "global" ) { |
649 | return false; |
650 | } |
651 | if (buf->data_alignment != runtime::kAllocAlignment) { |
652 | return false; |
653 | } |
654 | if (buf->offset_factor != 1) { |
655 | return false; |
656 | } |
657 | if (buf->buffer_type != BufferType::kDefault) { |
658 | return false; |
659 | } |
660 | if (buf->axis_separators.size()) { |
661 | return false; |
662 | } |
663 | return true; |
664 | } |
665 | |
666 | Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) { |
667 | Doc doc; |
668 | doc << tir_prefix_ << ".Buffer[" ; |
669 | if (buffer->shape.size() == 1) { |
670 | doc << Print(buffer->shape[0]); |
671 | } else { |
672 | doc << PrintTuple(buffer->shape.as<ArrayNode>()); |
673 | } |
674 | doc << ", " << PrintDType(buffer->dtype) << "]" ; |
675 | return doc; |
676 | } |
677 | |
678 | // print array out as tuple with parentheses |
679 | Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) { |
680 | Doc doc; |
681 | doc << '('; |
682 | for (size_t i = 0; i < op->size(); ++i) { |
683 | if (i != 0) { |
684 | doc << ", " ; |
685 | } |
686 | doc << Print(op->at(i)); |
687 | } |
688 | if (op->size() == 1) doc << "," ; |
689 | doc << ')'; |
690 | return doc; |
691 | } |
692 | |
693 | Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) { |
694 | Doc doc; |
695 | int n_var = static_cast<int>(op->rhs.size()); |
696 | |
697 | doc << tir_prefix_ << ".comm_reducer(lambda " ; |
698 | for (const Var& v_lhs : op->lhs) { |
699 | doc << Print(v_lhs) << ", " ; |
700 | } |
701 | for (int i = 0; i < n_var; ++i) { |
702 | doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", " ); |
703 | } |
704 | if (n_var == 1) { |
705 | doc << Print(op->result[0]) << ", " ; |
706 | } else { |
707 | doc << "(" ; |
708 | for (int i = 0; i < n_var; ++i) { |
709 | doc << Print(op->result[i]); |
710 | if (i != n_var - 1) { |
711 | doc << ", " ; |
712 | } |
713 | } |
714 | doc << "), " ; |
715 | } |
716 | doc << Print(op->identity_element) << ")" ; |
717 | |
718 | // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda. |
719 | for (int i = 0; i < n_var; ++i) { |
720 | memo_var_.erase(op->lhs[i]); |
721 | memo_var_.erase(op->rhs[i]); |
722 | } |
723 | return doc; |
724 | } |
725 | |
726 | Doc TVMScriptPrinter::Print(const ObjectRef& node) { |
727 | if (!node.defined()) return Doc::Text("None" ); |
728 | if (node->IsInstance<StmtNode>()) { |
729 | return PrintOptionalInfo(Downcast<Stmt>(node)) << VisitStmt(Downcast<Stmt>(node)); |
730 | } else if (node->IsInstance<PrimExprNode>()) { |
731 | ExprPrecedence t = ExprPrecedence::kUnknown; |
732 | return VisitExpr(Downcast<PrimExpr>(node), &t); |
733 | } else if (node->IsInstance<TypeNode>()) { |
734 | return VisitType(Downcast<Type>(node)); |
735 | } else if (node->IsInstance<PrimFuncNode>()) { |
736 | return PrintPrimFunc(Downcast<PrimFunc>(node)); |
737 | } else if (node->IsInstance<IRModuleNode>()) { |
738 | return PrintIRModule(Downcast<IRModule>(node)); |
739 | } else if (node->IsInstance<ArrayNode>()) { |
740 | return PrintArray(node.as<ArrayNode>()); |
741 | } else if (node->IsInstance<BufferNode>()) { |
742 | return PrintBuffer(node.as<BufferNode>()); |
743 | } else if (node->IsInstance<StringObj>()) { |
744 | return PrintString(node.as<StringObj>()); |
745 | } else if (node->IsInstance<IterVarNode>()) { |
746 | return PrintIterVar(node.as<IterVarNode>()); |
747 | } else if (node->IsInstance<RangeNode>()) { |
748 | return PrintRange(node.as<RangeNode>()); |
749 | } else if (node->IsInstance<BufferRegionNode>()) { |
750 | return PrintBufferRegion(node.as<BufferRegionNode>()); |
751 | } else if (node->IsInstance<MatchBufferRegionNode>()) { |
752 | return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>()); |
753 | } else if (node->IsInstance<CommReducerNode>()) { |
754 | return PrintCommReducer(node.as<CommReducerNode>()); |
755 | } else if (node->IsInstance<TargetNode>()) { |
756 | return PrintTarget(node.as<TargetNode>()); |
757 | } else { |
758 | LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); |
759 | } |
760 | } |
761 | |
/*!
 * \brief Fallback for expression nodes that have no dedicated VisitExpr_ overload.
 * \note Reports the node's type key through LOG(FATAL); `out_precedence` is not written.
 */
Doc TVMScriptPrinter::VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) {
  LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
}
765 | |
/*!
 * \brief Fallback for statement nodes that have no dedicated VisitStmt_ overload.
 * \note Reports the node's type key through LOG(FATAL).
 */
Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) {
  LOG(FATAL) << "Do not know how to print " << op->GetTypeKey();
}
769 | |
770 | Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) { |
771 | *out_precedence = ExprPrecedence::kIdentity; |
772 | return PrintConstScalar(op->dtype, &(op->value)); |
773 | } |
774 | |
775 | Doc TVMScriptPrinter::VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) { |
776 | *out_precedence = ExprPrecedence::kIdentity; |
777 | return PrintConstScalar(op->dtype, &(op->value)); |
778 | } |
779 | |
780 | Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) { |
781 | *out_precedence = ExprPrecedence::kIdentity; |
782 | return Doc::StrLiteral(op->value); |
783 | } |
784 | |
785 | Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) { |
786 | *out_precedence = ExprPrecedence::kIdentity; |
787 | Doc doc; |
788 | doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")" ; |
789 | return doc; |
790 | } |
791 | |
792 | Doc TVMScriptPrinter::VisitExpr_(const VarNode* op, ExprPrecedence* out_precedence) { |
793 | *out_precedence = ExprPrecedence::kIdentity; |
794 | const Var& var = GetRef<Var>(op); |
795 | return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<Var>(op)); |
796 | } |
797 | |
798 | bool WillPrintConstScalar(const PrimExpr& expr) { |
799 | if (const auto* imm = expr.as<IntImmNode>()) { |
800 | DataType dtype = imm->dtype; |
801 | return dtype == DataType::Int(32) || dtype == DataType::Bool(); |
802 | } |
803 | return false; |
804 | } |
805 | |
/*!
 * \brief Generate TVMScriptPrinter::VisitExpr_ for one binary operator node type.
 *
 * \param OpName       Node type being visited (e.g. MulNode).
 * \param OpString     Infix spelling, including the surrounding spaces (e.g. " * ").
 * \param OpClass      Name used for the call form `<tir_prefix>.<OpClass>(a, b)`.
 * \param OpPrecedence ExprPrecedence reported for the infix form.
 *
 * When WillPrintConstScalar() holds for both operands, the call form is emitted and the
 * result is reported as kIdentity. Otherwise the infix form is emitted; an operand is
 * wrapped in parentheses when its precedence value is greater than (left operand) or
 * greater than or equal to (right operand) OpPrecedence, or when a kAnd operand appears
 * directly under a kOr operator.
 *
 * NOTE(review): a left operand of equal precedence is printed without parentheses, so
 * nested comparisons can come out as `a < b < c`, which Python reads as a chained
 * comparison. Confirm the TVMScript parser round-trips this as intended.
 */
#define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \
  Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \
    Doc doc; \
    if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \
      *out_precedence = ExprPrecedence::kIdentity; \
      doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \
      return doc; \
    } \
    ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \
    ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \
    /* Visit both operands to obtain their docs together with their precedences. */ \
    Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \
    Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \
    ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \
    ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \
    /* Report the precedence of this node to the caller. */ \
    *out_precedence = OpPrecedence; \
    if (lhs_precedence > OpPrecedence || \
        (lhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \
      doc << "(" << lhs_doc << ")"; \
    } else { \
      doc << lhs_doc; \
    } \
    doc << OpString; \
    if (rhs_precedence >= OpPrecedence || \
        (rhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \
      doc << "(" << rhs_doc << ")"; \
    } else { \
      doc << rhs_doc; \
    } \
    return doc; \
  }
838 | |
// Instantiate the binary-operator visitors. Each entry gives the node type, the infix
// spelling, the call-form name and the precedence class reported for the infix form.
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * " , "Mul" , ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / " , "Div" , ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // " , "FloorDiv" ,
                                    ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % " , "FloorMod" ,
                                    ExprPrecedence::kMultiplicationDivision)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + " , "Add" , ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - " , "Sub" , ExprPrecedence::kAdditionSubtraction)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < " , "LT" , ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= " , "LE" , ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > " , "GT" , ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= " , "GE" , ExprPrecedence::kRelational)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == " , "EQ" , ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != " , "NE" , ExprPrecedence::kEquality)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and " , "And" , ExprPrecedence::kAnd)
TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or " , "Or" , ExprPrecedence::kOr)
855 | |
856 | Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) { |
857 | *out_precedence = ExprPrecedence::kIdentity; |
858 | Doc doc; |
859 | doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
860 | return doc; |
861 | } |
862 | |
863 | Doc TVMScriptPrinter::VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) { |
864 | *out_precedence = ExprPrecedence::kIdentity; |
865 | Doc doc; |
866 | doc << tir_prefix_ << ".min(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
867 | return doc; |
868 | } |
869 | |
870 | Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) { |
871 | *out_precedence = ExprPrecedence::kIdentity; |
872 | Doc doc; |
873 | doc << tir_prefix_ << ".max(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
874 | return doc; |
875 | } |
876 | |
877 | Doc TVMScriptPrinter::VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) { |
878 | *out_precedence = ExprPrecedence::kIdentity; |
879 | Doc doc; |
880 | doc << "not(" << Print(op->a) << ")" ; |
881 | return doc; |
882 | } |
883 | |
884 | Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) { |
885 | *out_precedence = ExprPrecedence::kIdentity; |
886 | Doc doc; |
887 | doc << tir_prefix_ << ".Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " |
888 | << Print(op->false_value) << ")" ; |
889 | return doc; |
890 | } |
891 | |
892 | Doc TVMScriptPrinter::VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) { |
893 | LOG(FATAL) << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to " |
894 | "lower this function first." ; |
895 | return Doc(); |
896 | } |
897 | |
898 | Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) { |
899 | *out_precedence = ExprPrecedence::kIdentity; |
900 | Doc doc; |
901 | if (op->indices.size() == 0) { |
902 | doc << Print(op->buffer) << "[()]" ; |
903 | } else { |
904 | doc << Print(op->buffer) << PrintBufferIndices(op->indices); |
905 | } |
906 | return doc; |
907 | } |
908 | |
909 | Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) { |
910 | *out_precedence = ExprPrecedence::kIdentity; |
911 | Doc doc; |
912 | if (op->dtype == DataType::Float(32) && is_one(op->predicate) && |
913 | op->buffer_var->dtype == DataType::Float(32)) { |
914 | doc << Print(op->buffer_var) << "[" << Print(op->index) << "]" ; |
915 | } else { |
916 | doc << tir_prefix_ << ".load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", " |
917 | << Print(op->index); |
918 | if (!is_one(op->predicate) || op->dtype.lanes() != 1) { |
919 | doc << ", " << Print(op->predicate); |
920 | } |
921 | doc << ")" ; |
922 | } |
923 | return doc; |
924 | } |
925 | |
926 | Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) { |
927 | *out_precedence = ExprPrecedence::kIdentity; |
928 | Doc doc; |
929 | doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " |
930 | << op->lanes << ")" ; |
931 | return doc; |
932 | } |
933 | |
934 | Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) { |
935 | *out_precedence = ExprPrecedence::kIdentity; |
936 | Doc doc; |
937 | doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")" ; |
938 | return doc; |
939 | } |
940 | |
941 | Doc TVMScriptPrinter::VisitExpr_(const LetNode* op, ExprPrecedence* out_precedence) { |
942 | *out_precedence = ExprPrecedence::kIdentity; |
943 | Doc doc; |
944 | doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", " |
945 | << Print(op->body) << ")" ; |
946 | return doc; |
947 | } |
948 | |
949 | Doc TVMScriptPrinter::VisitExpr_(const CallNode* op, ExprPrecedence* out_precedence) { |
950 | *out_precedence = ExprPrecedence::kIdentity; |
951 | Doc doc; |
952 | if (auto* ptr_op = op->op.as<OpNode>()) { |
953 | std::string name = ptr_op->name; |
954 | if (name.find("tir." ) == 0) { |
955 | name = tir_prefix_ + "." + name.substr(4); |
956 | } |
957 | doc << name << "(" ; |
958 | } else { |
959 | auto* op_gvar = op->op.as<GlobalVarNode>(); |
960 | ICHECK(op_gvar != nullptr); |
961 | doc << Doc::Text(op_gvar->name_hint) << "(" ; |
962 | } |
963 | std::vector<Doc> args; |
964 | for (const auto& arg : op->args) { |
965 | args.push_back(Print(arg)); |
966 | } |
967 | args.push_back(Doc::Text("dtype=" ) << PrintDType(op->dtype)); |
968 | doc << PrintSep(args, Doc::Text(", " )) << ")" ; |
969 | return doc; |
970 | } |
971 | |
972 | Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) { |
973 | *out_precedence = ExprPrecedence::kIdentity; |
974 | Doc doc; |
975 | doc << tir_prefix_ << ".shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")" ; |
976 | return doc; |
977 | } |
978 | |
979 | Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) { |
980 | *out_precedence = ExprPrecedence::kIdentity; |
981 | Doc doc; |
982 | doc << tir_prefix_ << ".reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " |
983 | << Print(op->axis) << ", " << op->value_index << ")" ; |
984 | return doc; |
985 | } |
986 | |
987 | Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { |
988 | if (!buffer_var_usage_.count(op->var)) { |
989 | buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); |
990 | } |
991 | Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->var).value_or({}); |
992 | |
993 | Doc doc; |
994 | if (current_num_ != num_child_ - 1) { |
995 | doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):" ; |
996 | doc << Doc::Indent( |
997 | 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); |
998 | } else { |
999 | if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); |
1000 | doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) |
1001 | << Doc::NewLine(); |
1002 | doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); |
1003 | } |
1004 | return doc; |
1005 | } |
1006 | |
1007 | Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { |
1008 | Doc doc; |
1009 | if (op->node.defined()) { |
1010 | // merge attr with realize when possible |
1011 | if (op->node->IsInstance<BufferNode>() && op->attr_key == "realize_scope" && |
1012 | op->body->IsInstance<BufferRealizeNode>()) { |
1013 | const auto* realize = Downcast<BufferRealize>(op->body).get(); |
1014 | if (realize->buffer.same_as(op->node)) { |
1015 | if (current_num_ != num_child_ - 1) { |
1016 | doc << "with " << tir_prefix_ << ".realize(" << Print(realize->buffer) |
1017 | << Print(realize->bounds) << ", " << Print(op->value); |
1018 | if (!is_one(realize->condition)) { |
1019 | doc << ", " << Print(realize->condition); |
1020 | } |
1021 | doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body)); |
1022 | } else { |
1023 | doc << tir_prefix_ << ".realize(" << Print(realize->buffer) << Print(realize->bounds) |
1024 | << ", " << Print(op->value); |
1025 | if (!is_one(realize->condition)) { |
1026 | doc << ", " << Print(realize->condition); |
1027 | } |
1028 | doc << ")" << Doc::NewLine() << PrintBody(realize->body); |
1029 | } |
1030 | return doc; |
1031 | } |
1032 | } |
1033 | // concise thread env |
1034 | if (op->node->IsInstance<IterVarNode>() && |
1035 | (op->attr_key == "thread_extent" || op->attr_key == "virtual_thread" )) { |
1036 | const auto* iter_var = Downcast<IterVar>(op->node).get(); |
1037 | var_not_in_headers_.insert(iter_var->var.get()); |
1038 | var_env_map_[iter_var->var] = iter_var->thread_tag; |
1039 | if (current_num_ != num_child_ - 1) { |
1040 | doc << "with " << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " |
1041 | << Print(op->value) << "):" ; |
1042 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1043 | } else { |
1044 | doc << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " << Print(op->value) |
1045 | << ")" ; |
1046 | doc << Doc::NewLine() << PrintBody(op->body); |
1047 | } |
1048 | return doc; |
1049 | } |
1050 | } |
1051 | // default |
1052 | if (current_num_ != num_child_ - 1) { |
1053 | doc << "with " << tir_prefix_ << ".attr(" << Print(op->node) << ", " |
1054 | << Doc::StrLiteral(op->attr_key) << ", " << Print(op->value) << "):" ; |
1055 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1056 | } else { |
1057 | doc << tir_prefix_ << ".attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key) |
1058 | << ", " << Print(op->value) << ")" ; |
1059 | doc << Doc::NewLine() << PrintBody(op->body); |
1060 | } |
1061 | return doc; |
1062 | } |
1063 | |
1064 | Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) { |
1065 | Doc doc; |
1066 | if (current_num_ != num_child_ - 1) { |
1067 | doc << "with " << tir_prefix_ << ".Assert(" << Print(op->condition) << ", " |
1068 | << Print(op->message) << "):" ; |
1069 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1070 | } else { |
1071 | doc << "assert " << Print(op->condition) << ", " << Print(op->message); |
1072 | doc << Doc::NewLine() << PrintBody(op->body); |
1073 | } |
1074 | return doc; |
1075 | } |
1076 | |
1077 | Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) { |
1078 | Doc doc; |
1079 | doc << tir_prefix_ << ".store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " |
1080 | << Print(op->value) << ", " << Print(op->predicate) << ")" ; |
1081 | return doc; |
1082 | } |
1083 | |
1084 | Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { |
1085 | LOG(FATAL) |
1086 | << "TVM Script Printer Internal Error: All the BufferRealize should be folded with Attr" ; |
1087 | return Doc(); |
1088 | } |
1089 | |
1090 | namespace { |
1091 | |
1092 | bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) { |
1093 | const Var& buffer_var = allocate->buffer_var; |
1094 | const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>(); |
1095 | if (!decl_buffer) { |
1096 | return false; |
1097 | } |
1098 | const Buffer& buffer = decl_buffer->buffer; |
1099 | if (!buffer_var.same_as(buffer->data)) { |
1100 | return false; |
1101 | } |
1102 | if (allocate->dtype != buffer->dtype) { |
1103 | return false; |
1104 | } |
1105 | if (!is_one(allocate->condition)) { |
1106 | return false; |
1107 | } |
1108 | if (allocate->annotations.size()) { |
1109 | return false; |
1110 | } |
1111 | if (allocate->extents.size() != buffer->shape.size()) { |
1112 | return false; |
1113 | } |
1114 | tir::ExprDeepEqual expr_equal; |
1115 | for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) { |
1116 | if (!expr_equal(allocate->extents[i], buffer->shape[i])) { |
1117 | return false; |
1118 | } |
1119 | } |
1120 | return true; |
1121 | } |
1122 | |
1123 | } // namespace |
1124 | |
1125 | Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { |
1126 | var_not_in_headers_.insert(op->buffer_var.get()); |
1127 | |
1128 | if (!buffer_var_usage_.count(op->buffer_var)) { |
1129 | buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); |
1130 | } |
1131 | Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({}); |
1132 | |
1133 | if (buffer_usage.empty()) { |
1134 | if (IsAllocateDeclBufferPattern(op)) { |
1135 | // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single |
1136 | // DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to |
1137 | // delegate the printing of the current node to `DeclBufferNode` while maintaining the |
1138 | // same value of `current_num_` and `num_child_`. |
1139 | return Print(op->body); |
1140 | } |
1141 | } |
1142 | |
1143 | auto storage_scope = GetPtrStorageScope(op->buffer_var); |
1144 | Doc func_call; |
1145 | func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) |
1146 | << ", " << Print(storage_scope); |
1147 | if (!is_one(op->condition)) { |
1148 | func_call << ", " << Print(op->condition); |
1149 | } |
1150 | if (!op->annotations.empty()) { |
1151 | func_call << ", annotations={" ; |
1152 | func_call << PrintAnnotations(op->annotations); |
1153 | func_call << "}" ; |
1154 | } |
1155 | func_call << ")" ; |
1156 | |
1157 | Doc doc; |
1158 | if (current_num_ != num_child_ - 1) { |
1159 | doc << "with " << func_call << " as " << Print(op->buffer_var) << ":" ; |
1160 | doc << Doc::Indent( |
1161 | 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); |
1162 | } else { |
1163 | doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); |
1164 | doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); |
1165 | } |
1166 | TryDeallocVar(op->buffer_var); |
1167 | return doc; |
1168 | } |
1169 | |
1170 | Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { |
1171 | std::stringstream ss; |
1172 | ICHECK(alloc->data) << "Should be presented" ; |
1173 | const auto& data = alloc->data.value(); |
1174 | |
1175 | if (alloc->dtype.is_int()) { |
1176 | if (alloc->dtype.bits() == 8) { |
1177 | NDArrayToTIR<int8_t>(data, ss); |
1178 | } else if (alloc->dtype.bits() == 16) { |
1179 | NDArrayToTIR<int16_t>(data, ss); |
1180 | } else if (alloc->dtype.bits() == 32) { |
1181 | NDArrayToTIR<int32_t>(data, ss); |
1182 | } else if (alloc->dtype.bits() == 64) { |
1183 | NDArrayToTIR<int64_t>(data, ss); |
1184 | } else { |
1185 | LOG(FATAL) << "DataType not supported" ; |
1186 | } |
1187 | } else if (alloc->dtype.is_uint()) { |
1188 | if (alloc->dtype.bits() == 8) { |
1189 | NDArrayToTIR<uint8_t>(data, ss); |
1190 | } else if (alloc->dtype.bits() == 16) { |
1191 | NDArrayToTIR<uint16_t>(data, ss); |
1192 | } else if (alloc->dtype.bits() == 32) { |
1193 | NDArrayToTIR<uint32_t>(data, ss); |
1194 | } else if (alloc->dtype.bits() == 64) { |
1195 | NDArrayToTIR<int64_t>(data, ss); |
1196 | } else { |
1197 | LOG(FATAL) << "DataType not supported" ; |
1198 | } |
1199 | } else if (alloc->dtype.is_float()) { |
1200 | if (alloc->dtype.bits() == 16) { |
1201 | NDArrayToTIR<int16_t>(data, ss); |
1202 | } else if (alloc->dtype.bits() == 32) { |
1203 | NDArrayToTIR<float>(data, ss); |
1204 | } else if (alloc->dtype.bits() == 64) { |
1205 | NDArrayToTIR<double>(data, ss); |
1206 | } else { |
1207 | LOG(FATAL) << "DataType not supported" ; |
1208 | } |
1209 | } else { |
1210 | LOG(FATAL) << "DataType not supported" ; |
1211 | } |
1212 | auto ndarray_str = ss.str(); |
1213 | |
1214 | var_not_in_headers_.insert(alloc->buffer_var.get()); |
1215 | |
1216 | if (!buffer_var_usage_.count(alloc->buffer_var)) { |
1217 | buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body); |
1218 | } |
1219 | Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({}); |
1220 | |
1221 | Doc func_call; |
1222 | func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) |
1223 | << ", " << Print(alloc->extents) << ")" ; |
1224 | |
1225 | Doc doc; |
1226 | var_not_in_headers_.insert(alloc->buffer_var.get()); |
1227 | if (current_num_ != num_child_ - 1) { |
1228 | doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":" ; |
1229 | doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) |
1230 | << PrintBody(alloc->body)); |
1231 | } else { |
1232 | doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine(); |
1233 | doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body); |
1234 | } |
1235 | return doc; |
1236 | } |
1237 | |
1238 | Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) { |
1239 | const Buffer& buffer = op->buffer; |
1240 | buf_not_in_headers_.insert(buffer.get()); |
1241 | Doc buffer_name = Print(op->buffer); |
1242 | Doc func_call; |
1243 | func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")" ; |
1244 | |
1245 | Doc doc; |
1246 | if (current_num_ != num_child_ - 1) { |
1247 | doc << "with " << func_call << " as " << buffer_name << ":" ; |
1248 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1249 | } else { |
1250 | doc << buffer_name << " = " << func_call << Doc::NewLine(); |
1251 | doc << PrintBody(op->body); |
1252 | } |
1253 | return doc; |
1254 | } |
1255 | |
1256 | Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { |
1257 | Doc doc; |
1258 | doc << "if " << Print(op->condition) << ":" ; |
1259 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case)); |
1260 | |
1261 | Optional<Stmt> else_case = op->else_case; |
1262 | while (else_case) { |
1263 | if (auto* else_if = else_case.value().as<IfThenElseNode>()) { |
1264 | doc << Doc::NewLine(); |
1265 | doc << "elif " << Print(else_if->condition) << ":" ; |
1266 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(else_if->then_case)); |
1267 | |
1268 | else_case = else_if->else_case; |
1269 | } else { |
1270 | doc << Doc::NewLine(); |
1271 | doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(else_case.value())); |
1272 | break; |
1273 | } |
1274 | } |
1275 | |
1276 | return doc; |
1277 | } |
1278 | |
1279 | Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) { |
1280 | std::vector<Doc> stmts; |
1281 | for (Stmt stmt : op->seq) { |
1282 | stmts.push_back(Print(stmt)); |
1283 | } |
1284 | return PrintSep(stmts, Doc::NewLine()); |
1285 | } |
1286 | |
1287 | Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { |
1288 | // When parsing TVMScript, a PrimExpr that occurs as a statement is |
1289 | // automatically wrapped in `tir::Evaluate`. Therefore, when |
1290 | // printing, it's only necessary to print the value. For |
1291 | // readability, though, we still print T.evaluate() when the |
1292 | // expression is something other than a call node. |
1293 | Doc doc; |
1294 | if (op->value.as<CallNode>()) { |
1295 | doc << Print(op->value); |
1296 | } else { |
1297 | doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")" ; |
1298 | } |
1299 | return doc; |
1300 | } |
1301 | |
/*!
 * \brief Print a For loop.
 *
 * Loops accepted by IsSimpleLoop() are pushed onto simple_loop_stack_ and their headers
 * are emitted later, together, by PrintLoopStack(). Loops that are not simple get their
 * own header from PrintLoop(). The loop variable is registered in loop_var_map_ for the
 * duration of the visit and released (TryDeallocVar + erase) on every return path.
 *
 * NOTE(review): the exact criteria of IsSimpleLoop() / DependOnPrevLoops() and the layout
 * produced by PrintLoopStack() are defined outside this block; see their definitions.
 */
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
  Doc doc;
  // Mark the loop variable as defined outside the function header and make the loop
  // retrievable by its variable while the body is being printed.
  var_not_in_headers_.insert(op->loop_var.get());
  loop_var_map_[op->loop_var.get()] = GetRef<For>(op);
  // Non-null only when the body is directly another For node.
  const auto* body = op->body.as<ForNode>();
  bool simple_loop = IsSimpleLoop(op);
  if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
  // This loop can be merged with the next one: it is simple, its body is directly another
  // simple loop, and DependOnPrevLoops() does not flag that inner loop. Emit nothing for
  // this loop here; the visit of the inner loop prints the accumulated stack.
  if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
    doc << Print(GetRef<For>(body));
    TryDeallocVar(op->loop_var);
    loop_var_map_.erase(op->loop_var.get());
    return doc;
  }
  // The chain of mergeable loops ends here.
  bool print_above = !simple_loop_stack_.empty();
  // Flush the headers of all pending simple loops (including this one, if it is simple).
  if (print_above) {
    doc << PrintLoopStack();
    simple_loop_stack_.clear();
  }
  if (!simple_loop) {
    // This loop was not on the stack, so print its own header and body. If pending loops
    // were flushed above, nest it one level (4 spaces) beneath them.
    Doc current_loop;
    current_loop << PrintLoop(GetRef<For>(op));
    current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
    doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop);
  } else {
    // The header of this loop was part of the flushed stack; only the body remains.
    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  }
  TryDeallocVar(op->loop_var);
  loop_var_map_.erase(op->loop_var.get());
  return doc;
}
1336 | |
1337 | Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) { |
1338 | Doc doc; |
1339 | doc << tir_prefix_ << ".prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")" ; |
1340 | return doc; |
1341 | } |
1342 | |
1343 | Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) { |
1344 | Doc doc; |
1345 | doc << "while " << Print(op->condition) << ":" ; |
1346 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1347 | return doc; |
1348 | } |
1349 | |
1350 | Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) { |
1351 | Doc doc; |
1352 | doc << tir_prefix_ << "." ; |
1353 | if (node->dtype.is_void()) { |
1354 | doc << "void" ; |
1355 | } else { |
1356 | doc << runtime::DLDataType2String(node->dtype); |
1357 | } |
1358 | return doc; |
1359 | } |
1360 | |
1361 | Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) { |
1362 | Doc doc; |
1363 | doc << tir_prefix_ << ".Ptr[" ; |
1364 | doc << Print(node->element_type); |
1365 | if (!node->storage_scope.empty()) { |
1366 | doc << ", " << Doc::StrLiteral(node->storage_scope); |
1367 | } |
1368 | doc << "]" ; |
1369 | return doc; |
1370 | } |
1371 | |
1372 | Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) { |
1373 | if (node->fields.empty()) { |
1374 | return Doc::Text("None" ); |
1375 | } else { |
1376 | std::vector<Doc> fields; |
1377 | for (Type field : node->fields) { |
1378 | fields.push_back(Print(field)); |
1379 | } |
1380 | return Doc::Text(tir_prefix_ + ".Tuple[" ) << Doc::Concat(fields) << "]" ; |
1381 | } |
1382 | } |
1383 | |
1384 | Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { |
1385 | Doc doc; |
1386 | if (op->indices.size() == 0) { |
1387 | doc << Print(op->buffer) << "[()] = " << Print(op->value); |
1388 | } else { |
1389 | doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value); |
1390 | } |
1391 | return doc; |
1392 | } |
1393 | |
1394 | /*! Helper functions for block printing. */ |
1395 | Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { |
1396 | Doc doc; |
1397 | doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis." ; |
1398 | switch (iter_var->iter_type) { |
1399 | case kDataPar: |
1400 | doc << "spatial" ; |
1401 | break; |
1402 | case kCommReduce: |
1403 | doc << "reduce" ; |
1404 | break; |
1405 | case kOrdered: |
1406 | doc << "scan" ; |
1407 | break; |
1408 | case kOpaque: |
1409 | doc << "opaque" ; |
1410 | break; |
1411 | default: |
1412 | LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; |
1413 | break; |
1414 | } |
1415 | doc << "(" ; |
1416 | const Range& dom = iter_var->dom; |
1417 | if (is_zero(dom->min)) { |
1418 | doc << Print(dom->extent); |
1419 | } else { |
1420 | doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")" ; |
1421 | } |
1422 | doc << ", " << Print(value) << ")" ; |
1423 | return doc; |
1424 | } |
1425 | |
1426 | Doc TVMScriptPrinter::PrintBlockVarRemaps() { |
1427 | ICHECK(!block_var_remaps_.empty()); |
1428 | if (block_var_remaps_.size() == 1) { |
1429 | const IterVar& iter_var = block_var_remaps_[0].first; |
1430 | const PrimExpr& value = block_var_remaps_[0].second; |
1431 | return PrintBlockVar(iter_var, value); |
1432 | } |
1433 | Doc doc; |
1434 | std::vector<Doc> iter_vars, iter_values; |
1435 | std::string iter_type; |
1436 | for (const auto& pair : block_var_remaps_) { |
1437 | const IterVar& iter_var = pair.first; |
1438 | const PrimExpr& value = pair.second; |
1439 | iter_vars.push_back(Print(iter_var->var)); |
1440 | iter_values.push_back(Print(value)); |
1441 | if (iter_var->iter_type == kDataPar) { |
1442 | iter_type += "S" ; |
1443 | } else if (iter_var->iter_type == kCommReduce) { |
1444 | iter_type += "R" ; |
1445 | } else { |
1446 | ICHECK(false); |
1447 | } |
1448 | } |
1449 | doc << PrintSep(iter_vars, Doc::Text(", " )) << " = " << tir_prefix_ << ".axis.remap(" |
1450 | << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", " )) << "])" ; |
1451 | return doc; |
1452 | } |
1453 | |
1454 | Doc TVMScriptPrinter::PrintBlockPredicate(const BlockRealizeNode* op) { |
1455 | Doc doc; |
1456 | if (!is_one(op->predicate)) { |
1457 | doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")" ; |
1458 | } |
1459 | return doc; |
1460 | } |
1461 | |
1462 | Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) { |
1463 | Doc doc; |
1464 | const auto* block_op = op->block.as<BlockNode>(); |
1465 | ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size()); |
1466 | tir::ExprDeepEqual expr_equal; |
1467 | |
1468 | auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var, |
1469 | const PrimExpr& value) -> bool { |
1470 | if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false; |
1471 | if (!value->IsInstance<VarNode>()) return false; |
1472 | const Var& var = Downcast<Var>(value); |
1473 | auto it = loop_var_map_.find(var.get()); |
1474 | return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) && |
1475 | expr_equal(it->second->extent, iter_var->dom->extent); |
1476 | }; |
1477 | |
1478 | for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { |
1479 | const IterVar& iter_var = block_op->iter_vars[i]; |
1480 | const PrimExpr& value = op->iter_values[i]; |
1481 | var_not_in_headers_.insert(iter_var->var.get()); |
1482 | if (is_simple_remap(iter_var, value)) { |
1483 | block_var_remaps_.push_back(std::make_pair(iter_var, value)); |
1484 | } else { |
1485 | if (!block_var_remaps_.empty()) { |
1486 | doc << Doc::NewLine() << PrintBlockVarRemaps(); |
1487 | block_var_remaps_.clear(); |
1488 | } |
1489 | doc << Doc::NewLine() << PrintBlockVar(iter_var, value); |
1490 | } |
1491 | } |
1492 | if (!block_var_remaps_.empty()) { |
1493 | doc << Doc::NewLine() << PrintBlockVarRemaps(); |
1494 | block_var_remaps_.clear(); |
1495 | } |
1496 | return doc; |
1497 | } |
1498 | |
1499 | Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) { |
1500 | const auto* block_op = op->block.as<BlockNode>(); |
1501 | Doc block_attr_doc; |
1502 | // print binding, read/write tensor region, annotations |
1503 | block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" |
1504 | << PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")" ; |
1505 | block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" |
1506 | << PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")" ; |
1507 | if (!block_op->annotations.empty()) { |
1508 | block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({" ; |
1509 | block_attr_doc << PrintAnnotations(block_op->annotations); |
1510 | block_attr_doc << "})" ; |
1511 | } |
1512 | return block_attr_doc; |
1513 | } |
1514 | |
1515 | // This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a |
1516 | // List. Therefore the brackets are removed before and after printing arguments out |
1517 | Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) { |
1518 | Doc doc; |
1519 | for (size_t i = 0; i < op->size(); ++i) { |
1520 | if (i != 0) { |
1521 | doc << ", " ; |
1522 | } |
1523 | doc << Print(op->at(i)); |
1524 | } |
1525 | return doc; |
1526 | } |
1527 | |
1528 | Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { |
1529 | Doc body; |
1530 | for (const auto& alloc_buf : op->alloc_buffers) { |
1531 | buf_not_in_headers_.insert(alloc_buf.get()); |
1532 | body << Print(alloc_buf) << " = " << tir_prefix_ << ".alloc_buffer(" |
1533 | << memo_buf_decl_[alloc_buf] << ")" << Doc::NewLine(); |
1534 | } |
1535 | for (const auto& match_buf : op->match_buffers) { |
1536 | body << Print(match_buf) << Doc::NewLine(); |
1537 | } |
1538 | if (op->init.defined()) { |
1539 | Doc init_block; |
1540 | init_block << "with " << tir_prefix_ << ".init():" ; |
1541 | init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value())); |
1542 | body << init_block << Doc::NewLine(); |
1543 | } |
1544 | body << PrintBody(op->body); |
1545 | return body; |
1546 | } |
1547 | |
1548 | /*! |
1549 | * \brief Print the name of a block |
1550 | * \param block_op The block node to be printed |
1551 | */ |
1552 | Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { |
1553 | Doc doc; |
1554 | doc << "with " << tir_prefix_ << ".block(" ; |
1555 | if (!block_op->name_hint.empty()) { |
1556 | doc << Doc::StrLiteral(block_op->name_hint); |
1557 | } |
1558 | doc << "):" ; |
1559 | return doc; |
1560 | } |
1561 | |
1562 | Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { |
1563 | const auto* block_op = op->block.as<BlockNode>(); |
1564 | Doc doc = PrintOptionalInfo(GetRef<Stmt>(block_op)); |
1565 | // print block name |
1566 | doc << PrintBlockName(block_op); |
1567 | // Print block predicate. |
1568 | Doc block_predicate = PrintBlockPredicate(op); |
1569 | // Print the variable bindings, valid to use in block attributes and |
1570 | // body |
1571 | Doc block_var = PrintBlockVars(op); |
1572 | // print read/write tensor region, annotations |
1573 | Doc block_attr_doc = PrintBlockAttr(op); |
1574 | // print body |
1575 | Doc body = PrintBlockBody(block_op); |
1576 | doc << Doc::Indent(4, block_predicate << block_var << block_attr_doc << Doc::NewLine() << body); |
1577 | for (const auto& iter_var : block_op->iter_vars) { |
1578 | TryDeallocVar(iter_var->var); |
1579 | } |
1580 | return doc; |
1581 | } |
1582 | |
1583 | Doc TVMScriptPrinter::PrintBody(const Stmt& body) { |
1584 | int memo_num_child, memo_current_num; |
1585 | std::swap(memo_num_child, num_child_); |
1586 | std::swap(memo_current_num, current_num_); |
1587 | |
1588 | Doc doc; |
1589 | if (body->IsInstance<SeqStmtNode>()) { |
1590 | const auto& op = Downcast<SeqStmt>(body); |
1591 | num_child_ = op->seq.size(); |
1592 | current_num_ = 0; |
1593 | std::vector<Doc> stmts; |
1594 | for (Stmt stmt : op->seq) { |
1595 | stmts.push_back(Print(stmt)); |
1596 | current_num_++; |
1597 | } |
1598 | doc = PrintSep(stmts, Doc::NewLine()); |
1599 | } else { |
1600 | num_child_ = 1; |
1601 | current_num_ = 0; |
1602 | doc = Print(body); |
1603 | } |
1604 | |
1605 | std::swap(memo_num_child, num_child_); |
1606 | std::swap(memo_current_num, current_num_); |
1607 | return doc; |
1608 | } |
1609 | |
1610 | Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) { |
1611 | auto* op = module.operator->(); |
1612 | Doc doc; |
1613 | doc << "@tvm.script.ir_module" << Doc::NewLine(); |
1614 | doc << "class Module:" ; |
1615 | for (const auto& x : op->functions) { |
1616 | func2var_[x.second.operator->()] = x.first; |
1617 | } |
1618 | Doc body = Doc::NewLine(); |
1619 | std::vector<Doc> functions; |
1620 | for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { |
1621 | if ((*it).second.as<PrimFuncNode>()) { |
1622 | functions.push_back(Print((*it).second)); |
1623 | } |
1624 | } |
1625 | body << TVMScriptPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); |
1626 | body << Doc::NewLine() << DumpMeta(); |
1627 | doc << Doc::Indent(4, body); |
1628 | return doc; |
1629 | } |
1630 | |
1631 | Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { |
1632 | auto* op = primFunc.operator->(); |
1633 | // clear renaming map |
1634 | memo_var_.clear(); |
1635 | memo_buf_.clear(); |
1636 | memo_buf_decl_.clear(); |
1637 | var_not_in_headers_.clear(); |
1638 | buf_not_in_headers_.clear(); |
1639 | // print signature |
1640 | Doc doc; |
1641 | doc << "@" << tir_prefix_ << ".prim_func" << Doc::NewLine(); |
1642 | doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint) |
1643 | << "(" ; |
1644 | std::vector<Doc> params; |
1645 | std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf; |
1646 | for (const auto& param : op->params) { |
1647 | var_not_in_headers_.insert(param.get()); |
1648 | auto it = op->buffer_map.find(param); |
1649 | // check if this param is a T.handle |
1650 | if (it != op->buffer_map.end()) { |
1651 | // check if this match_buffer has only the first two arguments specified |
1652 | // and whether the match_buffer is a dynamic buffer. |
1653 | const Buffer& buf = (*it).second; |
1654 | if (IsSimpleBuffer(buf)) { |
1655 | simple_buf.insert(buf); |
1656 | buf_not_in_headers_.insert(buf.get()); |
1657 | params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf)); |
1658 | continue; |
1659 | } |
1660 | } |
1661 | params.push_back(Print(param) << ": " << Print(GetType(param))); |
1662 | } |
1663 | doc << PrintSep(params, Doc::Text(", " )) << ")" ; |
1664 | if (primFunc->ret_type.defined()) { |
1665 | auto as_tuple = primFunc->ret_type.as<TupleTypeNode>(); |
1666 | if (!as_tuple || as_tuple->fields.size()) { |
1667 | doc << " -> " << Print(primFunc->ret_type); |
1668 | } |
1669 | } |
1670 | doc << ":" ; |
1671 | |
1672 | Doc body = Doc::NewLine(); |
1673 | // print buffer_bind |
1674 | for (const auto& param : op->params) { |
1675 | auto it = op->buffer_map.find(param); |
1676 | if (it == op->buffer_map.end()) continue; |
1677 | const Buffer& buf = (*it).second; |
1678 | if (simple_buf.count(buf)) continue; |
1679 | buf_not_in_headers_.insert(buf.get()); |
1680 | body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(" ; |
1681 | ICHECK(memo_buf_decl_.count(buf)); |
1682 | body << Print((*it).first) << ", " << memo_buf_decl_[buf]; |
1683 | body << ")" << Doc::NewLine(); |
1684 | } |
1685 | // print body |
1686 | body << "# body" << Doc::NewLine(); |
1687 | |
1688 | Optional<Block> elided_root_block_body = [&]() -> Optional<Block> { |
1689 | auto block_realize = op->body.as<BlockRealizeNode>(); |
1690 | if (!block_realize || block_realize->iter_values.size()) { |
1691 | return NullOpt; |
1692 | } |
1693 | |
1694 | const auto& block = block_realize->block; |
1695 | if (block->annotations.size() || ContainsOptionalInfo(block)) { |
1696 | return NullOpt; |
1697 | } |
1698 | |
1699 | // The autocomplete might recognize the body itself as being a |
1700 | // root block, and fail to insert it. |
1701 | bool autocomplete_would_insert_root_block = [&]() -> bool { |
1702 | if (block->alloc_buffers.size()) { |
1703 | return true; |
1704 | } |
1705 | |
1706 | auto* block_realize = block->body.as<BlockRealizeNode>(); |
1707 | if (block_realize && block_realize->block->iter_vars.size()) { |
1708 | return true; |
1709 | } |
1710 | if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) { |
1711 | return true; |
1712 | } |
1713 | return false; |
1714 | }(); |
1715 | |
1716 | if (autocomplete_would_insert_root_block) { |
1717 | return block; |
1718 | } else { |
1719 | return NullOpt; |
1720 | } |
1721 | }(); |
1722 | |
1723 | if (elided_root_block_body) { |
1724 | // Skip printing of root block in cases where tvm::tir::ScriptComplete |
1725 | // would re-insert it. |
1726 | body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine(); |
1727 | body << PrintBlockBody(elided_root_block_body.value().get()); |
1728 | } else { |
1729 | // If this is a non-root block, or is an unskippable root block, |
1730 | // just print it without skipping. |
1731 | body << PrintBody(op->body); |
1732 | } |
1733 | |
1734 | // print func attrs |
1735 | Doc ; |
1736 | if (primFunc->attrs.defined()) { |
1737 | header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << tir_prefix_ |
1738 | << ".func_attr({" ; |
1739 | std::vector<Doc> attrs; |
1740 | for (const auto& it : op->attrs->dict) { |
1741 | attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); |
1742 | } |
1743 | header_attr << PrintSep(attrs, Doc::Text(", " )) << "})" ; |
1744 | } |
1745 | // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate) |
1746 | Doc ; |
1747 | std::vector<const BufferNode*> bufs; |
1748 | for (const auto& it : memo_buf_) { |
1749 | if (buf_not_in_headers_.find(it.first.get()) == buf_not_in_headers_.end()) { |
1750 | bufs.push_back(it.first.get()); |
1751 | } |
1752 | } |
1753 | if (!bufs.empty()) { |
1754 | header_buf << Doc::NewLine() << "# buffer definition" ; |
1755 | std::sort(bufs.begin(), bufs.end(), [&](const BufferNode* a, const BufferNode* b) { |
1756 | return memo_buf_[GetRef<Buffer>(a)].str() < memo_buf_[GetRef<Buffer>(b)].str(); |
1757 | }); |
1758 | for (const auto& buf : bufs) { |
1759 | header_buf << Doc::NewLine() << Print(GetRef<Buffer>(buf)) << " = " << tir_prefix_ |
1760 | << ".buffer_decl(" ; |
1761 | header_buf << memo_buf_decl_[GetRef<Buffer>(buf)] << ")" ; |
1762 | } |
1763 | } |
1764 | // print var declaration |
1765 | Doc ; |
1766 | std::vector<const VarNode*> vars; |
1767 | for (const auto& it : memo_var_) { |
1768 | if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) { |
1769 | vars.push_back(it.first.get()); |
1770 | } |
1771 | } |
1772 | if (!var_env_map_.empty()) { |
1773 | header_var << Doc::NewLine() << "# var definition" ; |
1774 | for (const auto& it : var_env_map_) { |
1775 | header_var << Doc::NewLine() << Print(it.first) << " = " << tir_prefix_ << ".env_thread(" |
1776 | << Doc::StrLiteral(it.second) << ")" ; |
1777 | } |
1778 | } |
1779 | if (!vars.empty()) { |
1780 | std::sort(vars.begin(), vars.end(), [&](const VarNode* a, const VarNode* b) { |
1781 | return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str(); |
1782 | }); |
1783 | for (const auto& var : vars) { |
1784 | auto type = GetRef<Var>(var)->type_annotation; |
1785 | if (auto* ptr_type = type.as<PointerTypeNode>()) { |
1786 | auto* prim_type = ptr_type->element_type.as<PrimTypeNode>(); |
1787 | ICHECK(prim_type); |
1788 | header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_ |
1789 | << ".buffer_var(" ; |
1790 | header_var << PrintDType(prim_type->dtype) << ", " |
1791 | << Doc::StrLiteral(ptr_type->storage_scope) << ")" ; |
1792 | } else { |
1793 | header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = " << tir_prefix_ << ".var(" ; |
1794 | header_var << PrintDType(var->dtype) << ")" ; |
1795 | } |
1796 | } |
1797 | } |
1798 | doc << Doc::Indent(4, header_attr << header_var << header_buf << body); |
1799 | return doc; |
1800 | } |
1801 | |
1802 | Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) { |
1803 | Doc doc; |
1804 | doc << '['; |
1805 | for (size_t i = 0; i < op->size(); ++i) { |
1806 | if (i != 0) { |
1807 | doc << ", " ; |
1808 | } |
1809 | doc << Print(op->at(i)); |
1810 | } |
1811 | doc << ']'; |
1812 | return doc; |
1813 | } |
1814 | |
1815 | Doc TVMScriptPrinter::PrintIterVar(const IterVarNode* op) { |
1816 | Doc doc; |
1817 | doc << tir_prefix_ << ".iter_var(" << Print(op->var); |
1818 | if (op->dom.defined()) { |
1819 | doc << ", [" << Print(op->dom) << "], " ; |
1820 | } else { |
1821 | doc << ", None, " ; |
1822 | } |
1823 | doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", " ; |
1824 | doc << Doc::StrLiteral(op->thread_tag) << ")" ; |
1825 | return doc; |
1826 | } |
1827 | |
1828 | Doc TVMScriptPrinter::PrintRange(const RangeNode* op) { |
1829 | return Print(op->min) << ":" << Print(op->min + op->extent); |
1830 | } |
1831 | |
1832 | Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { |
1833 | const Buffer& buffer = GetRef<Buffer>(op); |
1834 | return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); |
1835 | } |
1836 | |
1837 | Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) { |
1838 | Doc doc; |
1839 | doc << '['; |
1840 | for (size_t i = 0; i < indices.size(); ++i) { |
1841 | if (i != 0) { |
1842 | doc << ", " ; |
1843 | } |
1844 | PrimExpr index = indices[i]; |
1845 | if (const RampNode* ramp = index.as<RampNode>()) { |
1846 | // specify ramp printing as python index slice |
1847 | if (auto* stride_imm = ramp->stride.as<IntImmNode>()) { |
1848 | doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride); |
1849 | if (stride_imm->value != 1) { |
1850 | doc << ":" << Print(ramp->stride); |
1851 | } |
1852 | continue; |
1853 | } |
1854 | } |
1855 | doc << Print(index); |
1856 | } |
1857 | doc << ']'; |
1858 | return doc; |
1859 | } |
1860 | |
/*!
 * \brief Print a buffer_decl statement for each of the given buffers.
 * \details Every buffer is emitted as `<buffer> = <prefix>.buffer_decl(<declaration>)`
 *  followed by a new line, and is recorded in buf_not_in_headers_.
 *
 *  NOTE(review): the method name is missing between `TVMScriptPrinter::` and the parameter
 *  list of this definition. Restore it from the declaration in the class before building.
 * \param aliasing_buffers The buffers to be declared
 * \return The doc holding one declaration line per buffer
 */
Doc TVMScriptPrinter::(const Array<Buffer>& aliasing_buffers) {
  Doc decls;
  for (const auto& buf_usage : aliasing_buffers) {
    decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl("
          << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine();
    // The buffer is declared here, so PrintPrimFunc must not declare it in the header again.
    buf_not_in_headers_.insert(buf_usage.get());
  }
  return decls;
}
1870 | |
1871 | Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { |
1872 | Doc doc; |
1873 | if (op->region.size() == 0) { |
1874 | doc << Print(op->buffer) << "[()]" ; |
1875 | } else { |
1876 | doc << Print(op->buffer) << "[" ; |
1877 | for (size_t i = 0; i < op->region.size(); ++i) { |
1878 | if (i != 0) doc << ", " ; |
1879 | const auto& range = op->region[i]; |
1880 | if (!is_one(range->extent)) { |
1881 | doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent)); |
1882 | } else { |
1883 | doc << Print(range->min); |
1884 | } |
1885 | } |
1886 | doc << "]" ; |
1887 | } |
1888 | return doc; |
1889 | } |
1890 | |
1891 | Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations) { |
1892 | Doc res; |
1893 | std::vector<std::pair<String, ObjectRef>> anno_list; |
1894 | anno_list.reserve(annotations.size()); |
1895 | for (const auto& pair : annotations) { |
1896 | anno_list.emplace_back(pair); |
1897 | } |
1898 | sort(anno_list.begin(), anno_list.end()); |
1899 | for (size_t i = 0; i < anno_list.size(); ++i) { |
1900 | if (i != 0) { |
1901 | res << ", " ; |
1902 | } |
1903 | res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second); |
1904 | } |
1905 | return res; |
1906 | } |
1907 | |
1908 | Doc TVMScriptPrinter::PrintLoop(const For& loop) { |
1909 | Doc res; |
1910 | res << "for " << Print(loop->loop_var) << " in " << tir_prefix_ |
1911 | << "." + std::string(ForKind2String(loop->kind)) + "(" ; |
1912 | if (is_zero(loop->min)) { |
1913 | res << Print(loop->extent); |
1914 | } else { |
1915 | res << Print(loop->min) << ", " << Print(ana_.Simplify(loop->min + loop->extent)); |
1916 | } |
1917 | if (loop->thread_binding.defined()) { |
1918 | res << ", thread=" ; |
1919 | res << Print(loop->thread_binding.value()->thread_tag); |
1920 | } |
1921 | if (!loop->annotations.empty()) { |
1922 | res << ", annotations={" ; |
1923 | res << PrintAnnotations(loop->annotations); |
1924 | res << "}" ; |
1925 | } |
1926 | res << "):" ; |
1927 | return res; |
1928 | } |
1929 | |
1930 | Doc TVMScriptPrinter::PrintLoopStack() { |
1931 | Doc res; |
1932 | if (simple_loop_stack_.size() == 1) { |
1933 | res << PrintLoop(simple_loop_stack_[0]); |
1934 | } else if (simple_loop_stack_.size() > 1) { |
1935 | std::vector<Doc> vars, extents; |
1936 | for (const auto& loop : simple_loop_stack_) { |
1937 | vars.push_back(Print(loop->loop_var)); |
1938 | extents.push_back(Print(loop->extent)); |
1939 | } |
1940 | res << "for " << PrintSep(vars, Doc::Text(", " )) << " in " << tir_prefix_ << ".grid(" |
1941 | << PrintSep(extents, Doc::Text(", " )) << "):" ; |
1942 | } |
1943 | return res; |
1944 | } |
1945 | |
1946 | Doc TVMScriptPrinter::PrintTarget(const TargetNode* target) { |
1947 | Doc res; |
1948 | res << tir_prefix_ << ".target({" ; |
1949 | Map<String, ObjectRef> config = target->Export(); |
1950 | for (auto it = config.begin(); it != config.end(); ++it) { |
1951 | if (it != config.begin()) { |
1952 | res << ", " ; |
1953 | } |
1954 | res << "\"" << (*it).first << "\":" ; |
1955 | if ((*it).first == "host" ) { |
1956 | ICHECK(target->host.defined()); |
1957 | res << PrintTarget(target->GetHost().value().get()); |
1958 | } else { |
1959 | res << Print((*it).second); |
1960 | } |
1961 | } |
1962 | res << "})" ; |
1963 | return res; |
1964 | } |
1965 | |
1966 | /*! |
1967 | * \brief The printer for TVMScript with diagnostic |
1968 | * \details The printer obtain the precedence of the top-level operation when printing each |
1969 | * subexpression to decide whether or not parentheses is needed. |
1970 | */ |
1971 | class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { |
1972 | public: |
1973 | explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, |
1974 | runtime::TypedPackedFunc<std::string(Stmt)> annotate) |
1975 | : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} |
1976 | |
1977 | protected: |
1978 | Doc PrintBlockName(const BlockNode* block_op) override; |
1979 | Doc PrintUnderline(const Stmt& stmt, int length); |
1980 | Doc PrintLoop(const For& loop) override; |
1981 | }; |
1982 | |
1983 | Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { |
1984 | Doc doc = TVMScriptPrinter::PrintBlockName(block_op); |
1985 | doc << PrintUnderline(GetRef<Stmt>(block_op), doc.str().size()); |
1986 | return doc; |
1987 | } |
1988 | |
1989 | Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { |
1990 | Doc doc; |
1991 | // annotation |
1992 | if (ContainsOptionalInfo(stmt)) { |
1993 | String underline = std::string(length, '^'); |
1994 | doc << Doc::NewLine() << underline; |
1995 | } |
1996 | return doc; |
1997 | } |
1998 | |
1999 | Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { |
2000 | Doc res = TVMScriptPrinter::PrintLoop(loop); |
2001 | res << PrintUnderline(loop, res.str().size()); |
2002 | return res; |
2003 | } |
2004 | |
2005 | String AsTVMScript(const ObjectRef& mod, const String& tir_prefix, bool show_meta) { |
2006 | ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>()); |
2007 | Doc doc; |
2008 | doc << TVMScriptPrinter::PrintHeader(tir_prefix) |
2009 | << TVMScriptPrinter(tir_prefix, show_meta).Print(mod); |
2010 | return doc.str() + "\n" ; |
2011 | } |
2012 | |
2013 | TVM_REGISTER_GLOBAL("script.AsTVMScript" ).set_body_typed(AsTVMScript); |
2014 | |
2015 | String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, |
2016 | runtime::TypedPackedFunc<std::string(Stmt)> annotate) { |
2017 | ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>()); |
2018 | Doc doc; |
2019 | doc << TVMScriptPrinter::PrintHeader(tir_prefix) |
2020 | << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod); |
2021 | return doc.str() + "\n" ; |
2022 | } |
2023 | |
2024 | TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic" ).set_body_typed(AsTVMScriptWithDiagnostic); |
2025 | |
2026 | } // namespace tir |
2027 | } // namespace tvm |
2028 | |