1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file printer/tvmscript_printer.cc |
22 | * \brief Printer class to print Tensor IR to python syntax script |
23 | */ |
24 | |
25 | #include <tvm/arith/analyzer.h> |
26 | #include <tvm/ir/module.h> |
27 | #include <tvm/node/serialization.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/target/target.h> |
30 | #include <tvm/tir/analysis.h> |
31 | #include <tvm/tir/buffer.h> |
32 | #include <tvm/tir/expr.h> |
33 | #include <tvm/tir/expr_functor.h> |
34 | #include <tvm/tir/function.h> |
35 | #include <tvm/tir/op.h> |
36 | #include <tvm/tir/stmt.h> |
37 | #include <tvm/tir/stmt_functor.h> |
38 | |
39 | #include <algorithm> |
40 | #include <utility> |
41 | |
42 | #include "../../tir/transforms/ir_utils.h" |
43 | #include "doc.h" |
44 | #include "meta_data.h" |
45 | #include "text_printer.h" |
46 | |
47 | namespace tvm { |
48 | namespace relay { |
49 | |
50 | using namespace tvm::tir; |
51 | |
/*!
 * \brief Precedence levels used when printing PrimExpr operators.
 * \details Lower values correspond to tighter-binding operators (identity/calls bind tightest,
 * then multiplicative, additive, relational, equality, logical and, logical or). The printer
 * compares these levels to decide whether a sub-expression needs parentheses.
 */
enum class ExprPrecedence : int {
  /*! \brief Identity (e.g., IntImm, Var) and function call (e.g., floordiv, min) */
  kIdentity = 0,
  /*!
   * \brief Multiplication(*), division(/), and remainder(%)
   * \note floordiv and floormod are marked as kIdentity since they are printed as function calls.
   */
  kMultiplicationDivision = 1,
  /*! \brief Addition(+) and subtraction(-) */
  kAdditionSubtraction = 2,
  /*! \brief Relational operators <, <=, > and >= */
  kRelational = 3,
  /*! \brief Equality operators == and != */
  kEquality = 4,
  /*! \brief And(&&) */
  kAnd = 5,
  /*! \brief Or(||) */
  kOr = 6,
  /*! \brief Unknown precedence */
  kUnknown = 7,
};
73 | |
74 | /*! \brief Utility used for identifying usage of a buffer_var |
75 | * |
76 | * \details Find the Buffer object that corresponds to a variable or |
77 | * allocation, based on the BufferLoad/BufferStore instances that |
78 | * occur within the allocation's body. |
79 | */ |
80 | class BufferUsageFinder : public StmtExprVisitor { |
81 | public: |
82 | static Map<tir::Var, Array<Buffer>> FindUsage(Map<tir::Var, Array<Buffer>> usage, Stmt body) { |
83 | BufferUsageFinder visitor(std::move(usage)); |
84 | visitor.VisitStmt(body); |
85 | return std::move(visitor.usage_); |
86 | } |
87 | |
88 | void VisitExpr_(const tir::VarNode* op) final { |
89 | tir::Var var = GetRef<tir::Var>(op); |
90 | if (!usage_.count(var)) { |
91 | usage_.Set(var, {}); |
92 | } |
93 | } |
94 | |
95 | void VisitExpr_(const BufferLoadNode* op) final { |
96 | VisitBuffer(op->buffer); |
97 | StmtExprVisitor::VisitExpr_(op); |
98 | } |
99 | |
100 | void VisitStmt_(const BufferStoreNode* op) final { |
101 | VisitBuffer(op->buffer); |
102 | StmtExprVisitor::VisitStmt_(op); |
103 | } |
104 | |
105 | void VisitStmt_(const DeclBufferNode* op) final { |
106 | buffers_declared_.insert(op->buffer.get()); |
107 | StmtExprVisitor::VisitStmt_(op); |
108 | buffers_declared_.erase(op->buffer.get()); |
109 | } |
110 | |
111 | private: |
112 | explicit BufferUsageFinder(Map<tir::Var, Array<Buffer>> usage) : usage_(usage) {} |
113 | |
114 | void VisitBuffer(const Buffer& buffer) { |
115 | if (buffers_visited_.count(buffer.get())) { |
116 | return; |
117 | } |
118 | if (buffers_declared_.count(buffer.get())) { |
119 | return; |
120 | } |
121 | buffers_visited_.insert(buffer.get()); |
122 | |
123 | Array<Buffer> arr = usage_.Get(buffer->data).value_or({}); |
124 | arr.push_back(buffer); |
125 | usage_.Set(buffer->data, arr); |
126 | } |
127 | |
128 | // The search result. |
129 | Map<tir::Var, Array<Buffer>> usage_; |
130 | // The buffers that have been visited so far, to avoid duplicate |
131 | // entries in the search result. |
132 | std::unordered_set<const BufferNode*> buffers_visited_; |
133 | // The buffers declared via `DeclBuffer`. These buffers are excluded from the result because |
134 | // T.buffer_decl shouldn't be printed for them. |
135 | std::unordered_set<const BufferNode*> buffers_declared_; |
136 | }; |
137 | |
138 | /*! |
139 | * \brief The printer for TVMScript |
140 | * \details The printer obtain the precedence of the top-level operation when printing each |
141 | * subexpression to decide whether or not parentheses is needed. |
142 | */ |
143 | class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>, |
144 | public tir::ExprFunctor<Doc(const PrimExpr&, ExprPrecedence*)>, |
145 | public TypeFunctor<Doc(const Type&)> { |
146 | public: |
147 | explicit TVMScriptPrinter(const String& tir_prefix, bool show_meta, |
148 | runtime::TypedPackedFunc<std::string(Stmt)> annotate = nullptr) |
149 | : tir_prefix_(tir_prefix), |
150 | show_meta_(show_meta), |
151 | annotate_(std::move(annotate)), |
152 | meta_collector_(&meta_) {} |
153 | |
154 | /*! |
155 | * \brief Print the node. |
156 | * \param node The node to be printed. |
157 | * \param out_precedence The operator precedence of node if it's a PrimExpr, |
158 | * so we can simplify the bracket. |
159 | */ |
160 | TVM_DLL Doc Print(const ObjectRef& node); |
161 | |
162 | protected: |
163 | /*! \brief The tir prefix */ |
164 | String tir_prefix_; |
165 | /*! \brief whether show meta data */ |
166 | bool show_meta_; |
167 | /*! \brief additional comment function */ |
168 | runtime::TypedPackedFunc<std::string(Stmt)> annotate_; |
169 | /*! \brief meta data context */ |
170 | TextMetaDataContext meta_; |
171 | /*! \brief meta collector */ |
172 | relay::MetaCollector meta_collector_; |
173 | /*! \brief map from Function to GlobalVar */ |
174 | std::unordered_map<const BaseFuncNode*, GlobalVar> func2var_; |
175 | /*! \brief var collector (var defined by For/Loop/Block) */ |
176 | std::unordered_set<const tir::VarNode*> ; |
177 | /*! |
178 | * \brief buffer collector |
179 | * (buffer defined in BufferMap, BufferAllocation and MatchBufferRegion) |
180 | */ |
181 | std::unordered_set<const BufferNode*> ; |
182 | /*! \brief Map from Var to thread env name */ |
183 | std::unordered_map<tir::Var, String, ObjectPtrHash, ObjectPtrEqual> var_env_map_; |
184 | /*! \brief Map from Var to Doc */ |
185 | std::unordered_map<tir::Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_; |
186 | /*! \brief Map from Buffer to Doc */ |
187 | std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_; |
188 | /*! \brief Map from Buffer to Declaration Doc */ |
189 | std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_decl_; |
190 | /*! \brief name allocation map */ |
191 | std::unordered_map<std::string, int> name_alloc_map_; |
192 | /*! \brief number of children of current node's parent */ |
193 | int num_child_; |
194 | /*! \brief the number of current node */ |
195 | int current_num_; |
196 | /*! \brief loop stack without annotations */ |
197 | std::vector<For> simple_loop_stack_; |
198 | /*! \brief the maps from loop_vars to the loops */ |
199 | std::unordered_map<const tir::VarNode*, For> loop_var_map_; |
200 | /*! |
201 | * \brief simple block vars remap from loop vars |
202 | * simple_remap requires: |
203 | * 1. block var iter type is kDataPar or kCommReduce |
204 | * 2. value is a single Var, which is a loop_var outside the block |
205 | * 3. The iter range is equal to loop range |
206 | */ |
207 | std::vector<std::pair<IterVar, PrimExpr>> block_var_remaps_; |
208 | /*! |
209 | * \brief Map from variables to the buffers they are used in. |
210 | * |
211 | * Used for identifying buffers that should be declared after the |
212 | * LetStmt or Allocate that generates their data pointer, rather |
213 | * than in the header. |
214 | */ |
215 | Map<tir::Var, Array<Buffer>> buffer_var_usage_; |
216 | /*! \brief Analyzer to simplify some expressions. */ |
217 | arith::Analyzer ana_; |
218 | |
219 | Doc VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) override; |
220 | Doc VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) override; |
221 | Doc VisitExpr_(const AddNode* op, ExprPrecedence* out_precedence) override; |
222 | Doc VisitExpr_(const SubNode* op, ExprPrecedence* out_precedence) override; |
223 | Doc VisitExpr_(const MulNode* op, ExprPrecedence* out_precedence) override; |
224 | Doc VisitExpr_(const DivNode* op, ExprPrecedence* out_precedence) override; |
225 | Doc VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) override; |
226 | Doc VisitExpr_(const FloorDivNode* op, ExprPrecedence* out_precedence) override; |
227 | Doc VisitExpr_(const FloorModNode* op, ExprPrecedence* out_precedence) override; |
228 | Doc VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) override; |
229 | Doc VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) override; |
230 | Doc VisitExpr_(const EQNode* op, ExprPrecedence* out_precedence) override; |
231 | Doc VisitExpr_(const NENode* op, ExprPrecedence* out_precedence) override; |
232 | Doc VisitExpr_(const LTNode* op, ExprPrecedence* out_precedence) override; |
233 | Doc VisitExpr_(const LENode* op, ExprPrecedence* out_precedence) override; |
234 | Doc VisitExpr_(const GTNode* op, ExprPrecedence* out_precedence) override; |
235 | Doc VisitExpr_(const GENode* op, ExprPrecedence* out_precedence) override; |
236 | Doc VisitExpr_(const AndNode* op, ExprPrecedence* out_precedence) override; |
237 | Doc VisitExpr_(const OrNode* op, ExprPrecedence* out_precedence) override; |
238 | Doc VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) override; |
239 | Doc VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) override; |
240 | Doc VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) override; |
241 | Doc VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) override; |
242 | Doc VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) override; |
243 | Doc VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) override; |
244 | Doc VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) override; |
245 | Doc VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) override; |
246 | Doc VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) override; |
247 | Doc VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) override; |
248 | Doc VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) override; |
249 | Doc VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) override; |
250 | Doc VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) override; |
251 | Doc VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) override; |
252 | Doc VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) override; |
253 | |
254 | Doc VisitStmt_(const LetStmtNode* op) override; |
255 | Doc VisitStmt_(const AttrStmtNode* op) override; |
256 | Doc VisitStmt_(const AssertStmtNode* op) override; |
257 | Doc VisitStmt_(const StoreNode* op) override; |
258 | Doc VisitStmt_(const BufferStoreNode* op) override; |
259 | Doc VisitStmt_(const BufferRealizeNode* op) override; |
260 | Doc VisitStmt_(const AllocateNode* op) override; |
261 | Doc VisitStmt_(const AllocateConstNode* op) override; |
262 | Doc VisitStmt_(const DeclBufferNode* op) override; |
263 | Doc VisitStmt_(const IfThenElseNode* op) override; |
264 | Doc VisitStmt_(const SeqStmtNode* op) override; |
265 | Doc VisitStmt_(const ForNode* op) override; |
266 | Doc VisitStmt_(const WhileNode* op) override; |
267 | Doc VisitStmt_(const PrefetchNode* op) override; |
268 | Doc VisitStmt_(const EvaluateNode* op) override; |
269 | Doc VisitStmt_(const BlockRealizeNode* op) override; |
270 | Doc VisitStmtDefault_(const Object* op) override; |
271 | |
272 | Doc VisitType_(const PrimTypeNode* node) override; |
273 | Doc VisitType_(const PointerTypeNode* node) override; |
274 | Doc VisitType_(const TupleTypeNode* node) override; |
275 | |
276 | Doc PrintBody(const Stmt& body); |
277 | Doc PrintIRModule(const IRModule& module); |
278 | Doc PrintPrimFunc(const PrimFunc& primFunc); |
279 | Doc PrintIterVar(const IterVarNode* op); |
280 | Doc PrintRange(const RangeNode* op); |
281 | Doc PrintArray(const ArrayNode* op); |
282 | Doc PrintBuffer(const BufferNode* op); |
283 | Doc PrintBufferIndices(const Array<PrimExpr>& indices); |
284 | Doc PrintNonHeaderBufferDeclarations(const Array<Buffer>& aliasing_buffers); |
285 | Doc AllocBufferDeclaration(const Buffer& buf); |
286 | Doc PrintBlockVar(const IterVar& iter_var, const PrimExpr& value); |
287 | Doc PrintBlockVarRemaps(); |
288 | Doc PrintBlockPredicate(const BlockRealizeNode* op); |
289 | Doc PrintBlockVars(const BlockRealizeNode* op); |
290 | Doc PrintBlockAttr(const BlockRealizeNode* op); |
291 | Doc PrintExpandedArray(const ArrayNode* op); |
292 | Doc PrintBlockBody(const BlockNode* op); |
293 | virtual Doc PrintBlockName(const BlockNode* block_op); |
294 | Doc PrintBufferRegion(const BufferRegionNode* op); |
295 | Doc PrintMatchBufferRegion(const MatchBufferRegionNode* op); |
296 | Doc PrintCommReducer(const CommReducerNode* op); |
297 | Doc PrintAnnotations(const Map<String, ObjectRef>& annotations); |
298 | Doc PrintTarget(const TargetNode* target); |
299 | static Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } |
300 | |
301 | Doc GetUniqueName(std::string prefix); |
302 | Doc AllocVar(const tir::Var& var); |
303 | Doc AllocBuf(const Buffer& buffer); |
304 | void TryDeallocVar(const tir::Var& var); |
305 | bool ContainsOptionalInfo(const Stmt& stmt); |
306 | /*! |
307 | * \brief Check if a buffer declaration satisfies: |
308 | * 1. has only 'shape' and 'dtype' arguments specified, |
309 | * 2. the shape and strides are not dynamic. |
310 | * \param buffer The match buffer to be checked |
311 | */ |
312 | bool IsSimpleBuffer(const Buffer& buffer); |
313 | Doc PrintInlineBufferBind(const Buffer& buffer); |
314 | Doc PrintTuple(const ArrayNode* op); |
315 | |
316 | /*! Helper functions for loop printing. */ |
317 | /*! |
318 | * \brief Print a single for loop |
319 | * \param loop The for loop to be printed |
320 | */ |
321 | virtual Doc PrintLoop(const For& loop); |
322 | /*! \brief Print all simple loops in stack into one line using tir_prefix_.grid(). */ |
323 | Doc PrintLoopStack(); |
324 | /*! |
325 | * \brief Check whether a loop satisfies: |
326 | * 1. the loop is serial; |
327 | * 2. the loop has no annotation; |
328 | * 3. the loop starts from 0; |
329 | * 4. there is no optional information. |
330 | * \param for_op the for node to be checked |
331 | * \return A boolean indicating whether the input loop satisfies the above conditions |
332 | */ |
333 | bool IsSimpleLoop(const ForNode* for_op) { |
334 | return for_op->kind == ForKind::kSerial && for_op->annotations.empty() && |
335 | is_zero(for_op->min) && !ContainsOptionalInfo(GetRef<Stmt>(for_op)); |
336 | } |
337 | /*! |
338 | * \brief Check whether the `min` or `extent` of a loop depends on previous loops |
339 | * \param for_op The loop to be checked |
340 | * \return A boolean indicating whether the input loop depends on previous loops |
341 | */ |
342 | bool DependOnPrevLoops(const ForNode* for_op) { |
343 | auto f_check = [&var_map = this->loop_var_map_](const tir::VarNode* v) { |
344 | return var_map.count(v); |
345 | }; |
346 | return UsesVar(for_op->min, f_check) || UsesVar(for_op->extent, f_check); |
347 | } |
348 | |
349 | /*! |
350 | * \brief Print additional info about expr in comment. |
351 | * \param expr The expression. |
352 | */ |
353 | Doc PrintOptionalInfo(const Stmt& stmt) { |
354 | Doc doc; |
355 | // default annotations |
356 | if (ContainsOptionalInfo(stmt)) { |
357 | std::string annotated_stmt = annotate_(stmt); |
358 | doc << "# " << annotated_stmt << Doc::NewLine(); |
359 | } |
360 | return doc; |
361 | } |
362 | |
363 | /*! |
364 | * \brief special method to render vectors of docs with a separator |
365 | * \param vec vector of docs |
366 | * \param sep separator |
367 | */ |
368 | static Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) { |
369 | Doc seq; |
370 | if (vec.size() != 0) { |
371 | seq = vec[0]; |
372 | for (size_t i = 1; i < vec.size(); i++) { |
373 | seq << sep << vec[i]; |
374 | } |
375 | } |
376 | return seq; |
377 | } |
378 | |
379 | /*! |
380 | * \brief dump meta info |
381 | * \return Doc with meta info |
382 | */ |
383 | Doc DumpMeta() { |
384 | if (show_meta_) { |
385 | return Doc::Text("__tvm_meta__ = " ) |
386 | << (meta_.empty() ? Doc::Text("None" ) : meta_.GetMetaSection()); |
387 | } else { |
388 | return Doc::Text("" ); |
389 | } |
390 | } |
391 | |
392 | /*! |
393 | * \brief special method to print out data type |
394 | * \param dtype The data type |
395 | */ |
396 | static Doc PrintDType(DataType dtype) { |
397 | return Doc::StrLiteral(runtime::DLDataType2String(dtype)); |
398 | } |
399 | |
400 | /*! |
401 | * \brief special method to print out const int64_t scalar |
402 | * \param dtype The data type |
403 | * \param data The pointer to hold the data. |
404 | */ |
405 | Doc PrintConstScalar(DataType dtype, const int64_t* data) const { |
406 | Doc doc; |
407 | std::ostringstream os; |
408 | |
409 | os << data[0]; |
410 | |
411 | if (dtype == DataType::Int(32)) { |
412 | doc << Doc::Text(os.str()); |
413 | } else if (dtype == DataType::Bool()) { |
414 | doc << Doc::Text(data[0] ? "True" : "False" ); |
415 | } else { |
416 | doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) |
417 | << ")" ; |
418 | } |
419 | return doc; |
420 | } |
421 | |
422 | /*! |
423 | * \brief special method to print out const double scalar |
424 | * \param dtype The data type |
425 | * \param data The pointer to hold the data. |
426 | * \note this overriden function is created as std::isnan of msvc will complain about int64_t |
427 | */ |
428 | Doc PrintConstScalar(DataType dtype, const double* data) const { |
429 | Doc doc; |
430 | std::ostringstream os; |
431 | |
432 | os.precision(17); |
433 | if (std::isinf(data[0]) || std::isnan(data[0])) { |
434 | os << "\"" << data[0] << "\"" ; |
435 | } else { |
436 | os << data[0]; |
437 | } |
438 | |
439 | doc << tir_prefix_ << "." << runtime::DLDataType2String(dtype) << "(" << Doc::Text(os.str()) |
440 | << ")" ; |
441 | |
442 | return doc; |
443 | } |
444 | |
445 | public: |
446 | static Doc (const std::string& tir_prefix) { |
447 | Doc ; |
448 | if (tir_prefix != "tir" ) { |
449 | header << "# from tvm.script import tir as " << tir_prefix << Doc::NewLine(); |
450 | } else { |
451 | header << "# from tvm.script import tir" << Doc::NewLine(); |
452 | } |
453 | return header; |
454 | } |
455 | }; |
456 | |
457 | /*! |
458 | * \brief special method to print NDArray in TIR |
459 | * \param arr the NDArray to be printed |
460 | * \param os the output stream where the NDArray will be printed to |
461 | */ |
462 | template <typename T> |
463 | void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { |
464 | if ((arr.DataType().code() == runtime::DataType::kInt || |
465 | arr.DataType().code() == runtime::DataType::kUInt) && |
466 | arr.DataType().bits() == 8) { |
467 | // Printing int8 NDArrays causes "UnicodeDecodeError: 'utf-8' codec can't decode byte" |
468 | // error during MetaSchedule tuning on int8 models. |
469 | return; |
470 | } |
471 | int ndim = arr->ndim; |
472 | int tot_dim = 1; |
473 | for (int i = 0; i < ndim; i++) { |
474 | tot_dim *= arr->shape[i]; |
475 | } |
476 | T* data_ptr = reinterpret_cast<T*>(arr->data); |
477 | constexpr int NUM_PRINT = 20; |
478 | os << "[" ; |
479 | for (int i = 0; i < tot_dim; i++) { |
480 | os << (i != 0 ? ", " : "" ) << data_ptr[i]; |
481 | if (i == NUM_PRINT) { |
482 | os << "..." ; |
483 | break; |
484 | } |
485 | } |
486 | os << "]" ; |
487 | } |
488 | |
489 | Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { |
490 | std::replace(prefix.begin(), prefix.end(), '.', '_'); |
491 | std::string unique_prefix = prefix; |
492 | auto it = name_alloc_map_.find(prefix); |
493 | if (it != name_alloc_map_.end() && it->second >= 0) { |
494 | while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { |
495 | } |
496 | } |
497 | name_alloc_map_[unique_prefix] = 0; |
498 | return Doc::Text(unique_prefix); |
499 | } |
500 | |
501 | Doc TVMScriptPrinter::AllocVar(const tir::Var& var) { |
502 | const auto& it = memo_var_.find(var); |
503 | if (it != memo_var_.end()) { |
504 | return it->second; |
505 | } |
506 | std::string name = var->name_hint.operator std::string(); |
507 | if (name.length() == 0 || !std::isalpha(name[0])) { |
508 | name = "v" + name; |
509 | } |
510 | Doc val = GetUniqueName(name); |
511 | memo_var_[var] = val; |
512 | return val; |
513 | } |
514 | |
515 | Doc TVMScriptPrinter::AllocBufferDeclaration(const Buffer& buf) { |
516 | Doc doc = Print(buf->shape); |
517 | bool print_factor_explicitly = false; |
518 | doc << ", dtype=" << PrintDType(buf->dtype); |
519 | if (memo_var_.find(buf->data) != memo_var_.end()) { |
520 | doc << ", data=" << Print(buf->data); |
521 | } else { |
522 | // implicitly define data |
523 | memo_var_[buf->data] = Doc::Text(memo_buf_[buf].str() + ".data" ); |
524 | var_not_in_headers_.insert(buf->data.get()); |
525 | } |
526 | if (!buf->strides.empty()) { |
527 | doc << ", strides=" << Print(buf->strides); |
528 | } |
529 | if (buf->elem_offset->IsInstance<tir::VarNode>()) { |
530 | tir::Var elem_offset = Downcast<tir::Var>(buf->elem_offset); |
531 | if (memo_var_.find(elem_offset) != memo_var_.end()) { |
532 | doc << ", elem_offset=" << Print(buf->elem_offset); |
533 | } else { |
534 | // implicitly define elem_offset |
535 | memo_var_[elem_offset] = Doc::Text(memo_buf_[buf].str() + ".elem_offset" ); |
536 | var_not_in_headers_.insert(elem_offset.get()); |
537 | print_factor_explicitly = true; |
538 | } |
539 | } else if (buf->elem_offset->IsInstance<IntImmNode>()) { |
540 | IntImm elem_offset = Downcast<IntImm>(buf->elem_offset); |
541 | if (elem_offset->value != 0) { |
542 | doc << ", elem_offset=" << Print(buf->elem_offset); |
543 | } |
544 | } |
545 | if (buf.scope() != "global" ) { |
546 | doc << ", scope=" << Doc::StrLiteral(buf.scope()); |
547 | } |
548 | if (buf->data_alignment != runtime::kAllocAlignment) { |
549 | doc << ", align=" << buf->data_alignment; |
550 | } |
551 | if (buf->offset_factor != 1 || print_factor_explicitly) { |
552 | doc << ", offset_factor=" << buf->offset_factor; |
553 | } |
554 | if (buf->buffer_type != BufferType::kDefault) { |
555 | doc << ", type=" << Doc::StrLiteral("auto" ); |
556 | } |
557 | if (buf->axis_separators.size()) { |
558 | doc << ", axis_separators=" << Print(buf->axis_separators); |
559 | } |
560 | return doc; |
561 | } |
562 | |
563 | Doc TVMScriptPrinter::AllocBuf(const Buffer& buffer) { |
564 | const auto& it = memo_buf_.find(buffer); |
565 | if (it != memo_buf_.end()) { |
566 | return it->second; |
567 | } |
568 | std::string name = buffer->name; |
569 | if (name.length() == 0 || !std::isalpha(name[0])) { |
570 | name = "buf_" + name; |
571 | } |
572 | Doc val = GetUniqueName(name); |
573 | memo_buf_[buffer] = val; |
574 | memo_buf_decl_[buffer] = AllocBufferDeclaration(buffer); |
575 | return val; |
576 | } |
577 | |
578 | /*! |
579 | * \brief Check if any optional information exists in annotate_ for |
580 | * a given Stmt. |
581 | * \param stmt The statement. |
582 | */ |
583 | bool TVMScriptPrinter::ContainsOptionalInfo(const Stmt& stmt) { |
584 | if (annotate_ == nullptr) return false; |
585 | return !annotate_(stmt).empty(); |
586 | } |
587 | |
588 | /*! |
589 | * \brief Try to dealloc vars out of space and leave the index to coming vars. |
590 | * \note It is not a necessary step. |
591 | */ |
592 | void TVMScriptPrinter::TryDeallocVar(const tir::Var& var) { |
593 | auto it = memo_var_.find(var); |
594 | ICHECK(it != memo_var_.end()); |
595 | std::string print_name = it->second.str(); |
596 | |
597 | std::string name_hint = var->name_hint.operator std::string(); |
598 | if (name_hint.length() == 0 || !std::isalpha(name_hint[0])) { |
599 | name_hint = "v" + name_hint; |
600 | } |
601 | std::replace(name_hint.begin(), name_hint.end(), '.', '_'); |
602 | |
603 | auto it2 = name_alloc_map_.find(name_hint); |
604 | // Skip it if we can not find the name_hint in name_alloc_map_. |
605 | if (it2 == name_alloc_map_.end()) return; |
606 | if (it2->second > 0) { |
607 | name_hint = name_hint + '_' + std::to_string(it2->second); |
608 | } |
609 | // Skip it if the name_hint is not equal to how it should be printed. |
610 | if (name_hint != print_name) return; |
611 | // Free the conresponding name_alloc_map_ index |
612 | --it2->second; |
613 | } |
614 | |
615 | Doc TVMScriptPrinter::PrintMatchBufferRegion(const MatchBufferRegionNode* op) { |
616 | const Buffer& buf = op->buffer; |
617 | buf_not_in_headers_.insert(buf.get()); |
618 | |
619 | Doc doc = Print(op->buffer) << " = " << tir_prefix_ << ".match_buffer(" << Print(op->source) |
620 | << ", " << memo_buf_decl_[op->buffer] << ")" ; |
621 | return doc; |
622 | } |
623 | |
624 | // check if all arguments, except the first two, are specified for T.match_buffer |
625 | // if not, then this match buffer is printed out as T.buffer in prim_func arguments |
626 | // and check whether there are undefined variables in the shape/strides. |
627 | bool TVMScriptPrinter::IsSimpleBuffer(const Buffer& buf) { |
628 | if (memo_var_.find(buf->data) != memo_var_.end()) { |
629 | return false; |
630 | } |
631 | if (!buf->strides.empty()) { |
632 | return false; |
633 | } |
634 | for (const PrimExpr& shp_i : buf->shape) { |
635 | if (!UndefinedVars(shp_i).empty()) { |
636 | return false; |
637 | } |
638 | } |
639 | for (const PrimExpr& stride_i : buf->strides) { |
640 | if (!UndefinedVars(stride_i).empty()) { |
641 | return false; |
642 | } |
643 | } |
644 | if (!UndefinedVars(buf->elem_offset).empty()) { |
645 | return false; |
646 | } else if (buf->elem_offset->IsInstance<IntImmNode>()) { |
647 | IntImm elem_offset = Downcast<IntImm>(buf->elem_offset); |
648 | if (elem_offset->value != 0) { |
649 | return false; |
650 | } |
651 | } |
652 | if (buf.scope() != "global" ) { |
653 | return false; |
654 | } |
655 | if (buf->data_alignment != runtime::kAllocAlignment) { |
656 | return false; |
657 | } |
658 | if (buf->offset_factor != 1) { |
659 | return false; |
660 | } |
661 | if (buf->buffer_type != BufferType::kDefault) { |
662 | return false; |
663 | } |
664 | if (buf->axis_separators.size()) { |
665 | return false; |
666 | } |
667 | return true; |
668 | } |
669 | |
670 | Doc TVMScriptPrinter::PrintInlineBufferBind(const Buffer& buffer) { |
671 | Doc doc; |
672 | doc << tir_prefix_ << ".Buffer[" ; |
673 | if (buffer->shape.size() == 1) { |
674 | doc << Print(buffer->shape[0]); |
675 | } else { |
676 | doc << PrintTuple(buffer->shape.as<ArrayNode>()); |
677 | } |
678 | doc << ", " << PrintDType(buffer->dtype) << "]" ; |
679 | return doc; |
680 | } |
681 | |
682 | // print array out as tuple with parentheses |
683 | Doc TVMScriptPrinter::PrintTuple(const ArrayNode* op) { |
684 | Doc doc; |
685 | doc << '('; |
686 | for (size_t i = 0; i < op->size(); ++i) { |
687 | if (i != 0) { |
688 | doc << ", " ; |
689 | } |
690 | doc << Print(op->at(i)); |
691 | } |
692 | if (op->size() == 1) doc << "," ; |
693 | doc << ')'; |
694 | return doc; |
695 | } |
696 | |
697 | Doc TVMScriptPrinter::PrintCommReducer(const CommReducerNode* op) { |
698 | Doc doc; |
699 | int n_var = static_cast<int>(op->rhs.size()); |
700 | |
701 | doc << tir_prefix_ << ".comm_reducer(lambda " ; |
702 | for (const tir::Var& v_lhs : op->lhs) { |
703 | doc << Print(v_lhs) << ", " ; |
704 | } |
705 | for (int i = 0; i < n_var; ++i) { |
706 | doc << Print(op->rhs[i]) << (i == n_var - 1 ? ": " : ", " ); |
707 | } |
708 | if (n_var == 1) { |
709 | doc << Print(op->result[0]) << ", " ; |
710 | } else { |
711 | doc << "(" ; |
712 | for (int i = 0; i < n_var; ++i) { |
713 | doc << Print(op->result[i]); |
714 | if (i != n_var - 1) { |
715 | doc << ", " ; |
716 | } |
717 | } |
718 | doc << "), " ; |
719 | } |
720 | doc << Print(op->identity_element) << ")" ; |
721 | |
722 | // Remove the vars in `lhs` and `rhs`, because they are the parameters of the printed lambda. |
723 | for (int i = 0; i < n_var; ++i) { |
724 | memo_var_.erase(op->lhs[i]); |
725 | memo_var_.erase(op->rhs[i]); |
726 | } |
727 | return doc; |
728 | } |
729 | |
730 | Doc TVMScriptPrinter::Print(const ObjectRef& node) { |
731 | if (!node.defined()) return Doc::Text("None" ); |
732 | if (node->IsInstance<StmtNode>()) { |
733 | return PrintOptionalInfo(Downcast<Stmt>(node)) << VisitStmt(Downcast<Stmt>(node)); |
734 | } else if (node->IsInstance<PrimExprNode>()) { |
735 | ExprPrecedence t = ExprPrecedence::kUnknown; |
736 | return VisitExpr(Downcast<PrimExpr>(node), &t); |
737 | } else if (node->IsInstance<TypeNode>()) { |
738 | return VisitType(Downcast<Type>(node)); |
739 | } else if (node->IsInstance<PrimFuncNode>()) { |
740 | return PrintPrimFunc(Downcast<PrimFunc>(node)); |
741 | } else if (node->IsInstance<IRModuleNode>()) { |
742 | return PrintIRModule(Downcast<IRModule>(node)); |
743 | } else if (node->IsInstance<ArrayNode>()) { |
744 | return PrintArray(node.as<ArrayNode>()); |
745 | } else if (node->IsInstance<BufferNode>()) { |
746 | return PrintBuffer(node.as<BufferNode>()); |
747 | } else if (node->IsInstance<StringObj>()) { |
748 | return PrintString(node.as<StringObj>()); |
749 | } else if (node->IsInstance<IterVarNode>()) { |
750 | return PrintIterVar(node.as<IterVarNode>()); |
751 | } else if (node->IsInstance<RangeNode>()) { |
752 | return PrintRange(node.as<RangeNode>()); |
753 | } else if (node->IsInstance<BufferRegionNode>()) { |
754 | return PrintBufferRegion(node.as<BufferRegionNode>()); |
755 | } else if (node->IsInstance<MatchBufferRegionNode>()) { |
756 | return PrintMatchBufferRegion(node.as<MatchBufferRegionNode>()); |
757 | } else if (node->IsInstance<CommReducerNode>()) { |
758 | return PrintCommReducer(node.as<CommReducerNode>()); |
759 | } else if (node->IsInstance<TargetNode>()) { |
760 | return PrintTarget(node.as<TargetNode>()); |
761 | } else { |
762 | LOG(FATAL) << "Do not know how to print " << node->GetTypeKey(); |
763 | } |
764 | } |
765 | |
766 | Doc TVMScriptPrinter::VisitExprDefault_(const Object* op, ExprPrecedence* out_precedence) { |
767 | LOG(FATAL) << "Do not know how to print " << op->GetTypeKey(); |
768 | } |
769 | |
770 | Doc TVMScriptPrinter::VisitStmtDefault_(const Object* op) { |
771 | LOG(FATAL) << "Do not know how to print " << op->GetTypeKey(); |
772 | } |
773 | |
774 | Doc TVMScriptPrinter::VisitExpr_(const IntImmNode* op, ExprPrecedence* out_precedence) { |
775 | *out_precedence = ExprPrecedence::kIdentity; |
776 | return PrintConstScalar(op->dtype, &(op->value)); |
777 | } |
778 | |
779 | Doc TVMScriptPrinter::VisitExpr_(const FloatImmNode* op, ExprPrecedence* out_precedence) { |
780 | *out_precedence = ExprPrecedence::kIdentity; |
781 | return PrintConstScalar(op->dtype, &(op->value)); |
782 | } |
783 | |
784 | Doc TVMScriptPrinter::VisitExpr_(const StringImmNode* op, ExprPrecedence* out_precedence) { |
785 | *out_precedence = ExprPrecedence::kIdentity; |
786 | return Doc::StrLiteral(op->value); |
787 | } |
788 | |
789 | Doc TVMScriptPrinter::VisitExpr_(const CastNode* op, ExprPrecedence* out_precedence) { |
790 | *out_precedence = ExprPrecedence::kIdentity; |
791 | Doc doc; |
792 | doc << tir_prefix_ << ".Cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")" ; |
793 | return doc; |
794 | } |
795 | |
796 | Doc TVMScriptPrinter::VisitExpr_(const tir::VarNode* op, ExprPrecedence* out_precedence) { |
797 | *out_precedence = ExprPrecedence::kIdentity; |
798 | const tir::Var& var = GetRef<tir::Var>(op); |
799 | return meta_.InMeta(var) ? meta_.GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op)); |
800 | } |
801 | |
802 | bool WillPrintConstScalar(const PrimExpr& expr) { |
803 | if (const auto* imm = expr.as<IntImmNode>()) { |
804 | DataType dtype = imm->dtype; |
805 | return dtype == DataType::Int(32) || dtype == DataType::Bool(); |
806 | } |
807 | return false; |
808 | } |
809 | |
810 | #define TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OpName, OpString, OpClass, OpPrecedence) \ |
811 | Doc TVMScriptPrinter::VisitExpr_(const OpName* op, ExprPrecedence* out_precedence) { \ |
812 | Doc doc; \ |
813 | if (WillPrintConstScalar(op->a) && WillPrintConstScalar(op->b)) { \ |
814 | *out_precedence = ExprPrecedence::kIdentity; \ |
815 | doc << tir_prefix_ << "." << OpClass << "(" << Print(op->a) << ", " << Print(op->b) << ")"; \ |
816 | return doc; \ |
817 | } \ |
818 | ExprPrecedence lhs_precedence = ExprPrecedence::kUnknown; \ |
819 | ExprPrecedence rhs_precedence = ExprPrecedence::kUnknown; \ |
820 | /* Get children expr out_precedence */ \ |
821 | Doc lhs_doc = VisitExpr(op->a, &lhs_precedence); \ |
822 | Doc rhs_doc = VisitExpr(op->b, &rhs_precedence); \ |
823 | ICHECK(lhs_precedence != ExprPrecedence::kUnknown); \ |
824 | ICHECK(rhs_precedence != ExprPrecedence::kUnknown); \ |
825 | /* Update out_precedence of current node. */ \ |
826 | *out_precedence = OpPrecedence; \ |
827 | if (lhs_precedence > OpPrecedence || \ |
828 | (lhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \ |
829 | doc << "(" << lhs_doc << ")"; \ |
830 | } else { \ |
831 | doc << lhs_doc; \ |
832 | } \ |
833 | doc << OpString; \ |
834 | if (rhs_precedence >= OpPrecedence || \ |
835 | (rhs_precedence == ExprPrecedence::kAnd && OpPrecedence == ExprPrecedence::kOr)) { \ |
836 | doc << "(" << rhs_doc << ")"; \ |
837 | } else { \ |
838 | doc << rhs_doc; \ |
839 | } \ |
840 | return doc; \ |
841 | } |
842 | |
843 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(MulNode, " * " , "Mul" , ExprPrecedence::kMultiplicationDivision) |
844 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(DivNode, " / " , "Div" , ExprPrecedence::kMultiplicationDivision) |
845 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorDivNode, " // " , "FloorDiv" , |
846 | ExprPrecedence::kMultiplicationDivision) |
847 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(FloorModNode, " % " , "FloorMod" , |
848 | ExprPrecedence::kMultiplicationDivision) |
849 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AddNode, " + " , "Add" , ExprPrecedence::kAdditionSubtraction) |
850 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(SubNode, " - " , "Sub" , ExprPrecedence::kAdditionSubtraction) |
851 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LTNode, " < " , "LT" , ExprPrecedence::kRelational) |
852 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(LENode, " <= " , "LE" , ExprPrecedence::kRelational) |
853 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GTNode, " > " , "GT" , ExprPrecedence::kRelational) |
854 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(GENode, " >= " , "GE" , ExprPrecedence::kRelational) |
855 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(EQNode, " == " , "EQ" , ExprPrecedence::kEquality) |
856 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(NENode, " != " , "NE" , ExprPrecedence::kEquality) |
857 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(AndNode, " and " , "And" , ExprPrecedence::kAnd) |
858 | TVM_DECLARE_TVMSCRIPT_PRINTER_BINOP(OrNode, " or " , "Or" , ExprPrecedence::kOr) |
859 | |
860 | Doc TVMScriptPrinter::VisitExpr_(const ModNode* op, ExprPrecedence* out_precedence) { |
861 | *out_precedence = ExprPrecedence::kIdentity; |
862 | Doc doc; |
863 | doc << tir_prefix_ << ".truncmod(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
864 | return doc; |
865 | } |
866 | |
867 | Doc TVMScriptPrinter::VisitExpr_(const MinNode* op, ExprPrecedence* out_precedence) { |
868 | *out_precedence = ExprPrecedence::kIdentity; |
869 | Doc doc; |
870 | doc << tir_prefix_ << ".min(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
871 | return doc; |
872 | } |
873 | |
874 | Doc TVMScriptPrinter::VisitExpr_(const MaxNode* op, ExprPrecedence* out_precedence) { |
875 | *out_precedence = ExprPrecedence::kIdentity; |
876 | Doc doc; |
877 | doc << tir_prefix_ << ".max(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
878 | return doc; |
879 | } |
880 | |
881 | Doc TVMScriptPrinter::VisitExpr_(const NotNode* op, ExprPrecedence* out_precedence) { |
882 | *out_precedence = ExprPrecedence::kIdentity; |
883 | Doc doc; |
884 | doc << "not(" << Print(op->a) << ")" ; |
885 | return doc; |
886 | } |
887 | |
888 | Doc TVMScriptPrinter::VisitExpr_(const SelectNode* op, ExprPrecedence* out_precedence) { |
889 | *out_precedence = ExprPrecedence::kIdentity; |
890 | Doc doc; |
891 | doc << tir_prefix_ << ".Select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " |
892 | << Print(op->false_value) << ")" ; |
893 | return doc; |
894 | } |
895 | |
896 | Doc TVMScriptPrinter::VisitExpr_(const ProducerLoadNode* op, ExprPrecedence* out_precedence) { |
897 | LOG(FATAL) << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to " |
898 | "lower this function first." ; |
899 | return Doc(); |
900 | } |
901 | |
902 | Doc TVMScriptPrinter::VisitExpr_(const BufferLoadNode* op, ExprPrecedence* out_precedence) { |
903 | *out_precedence = ExprPrecedence::kIdentity; |
904 | Doc doc; |
905 | if (op->indices.size() == 0) { |
906 | doc << Print(op->buffer) << "[()]" ; |
907 | } else { |
908 | doc << Print(op->buffer) << PrintBufferIndices(op->indices); |
909 | } |
910 | return doc; |
911 | } |
912 | |
913 | Doc TVMScriptPrinter::VisitExpr_(const LoadNode* op, ExprPrecedence* out_precedence) { |
914 | *out_precedence = ExprPrecedence::kIdentity; |
915 | Doc doc; |
916 | if (op->dtype == DataType::Float(32) && is_one(op->predicate) && |
917 | op->buffer_var->dtype == DataType::Float(32)) { |
918 | doc << Print(op->buffer_var) << "[" << Print(op->index) << "]" ; |
919 | } else { |
920 | doc << tir_prefix_ << ".load(" << PrintDType(op->dtype) << ", " << Print(op->buffer_var) << ", " |
921 | << Print(op->index); |
922 | if (!is_one(op->predicate) || op->dtype.lanes() != 1) { |
923 | doc << ", " << Print(op->predicate); |
924 | } |
925 | doc << ")" ; |
926 | } |
927 | return doc; |
928 | } |
929 | |
930 | Doc TVMScriptPrinter::VisitExpr_(const RampNode* op, ExprPrecedence* out_precedence) { |
931 | *out_precedence = ExprPrecedence::kIdentity; |
932 | Doc doc; |
933 | doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " |
934 | << op->lanes << ")" ; |
935 | return doc; |
936 | } |
937 | |
938 | Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence* out_precedence) { |
939 | *out_precedence = ExprPrecedence::kIdentity; |
940 | Doc doc; |
941 | doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes << ")" ; |
942 | return doc; |
943 | } |
944 | |
945 | Doc TVMScriptPrinter::VisitExpr_(const tir::LetNode* op, ExprPrecedence* out_precedence) { |
946 | *out_precedence = ExprPrecedence::kIdentity; |
947 | Doc doc; |
948 | doc << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << ", " |
949 | << Print(op->body) << ")" ; |
950 | return doc; |
951 | } |
952 | |
953 | Doc TVMScriptPrinter::VisitExpr_(const tir::CallNode* op, ExprPrecedence* out_precedence) { |
954 | *out_precedence = ExprPrecedence::kIdentity; |
955 | Doc doc; |
956 | if (auto* ptr_op = op->op.as<OpNode>()) { |
957 | std::string name = ptr_op->name; |
958 | if (name.find("tir." ) == 0) { |
959 | name = tir_prefix_ + "." + name.substr(4); |
960 | } |
961 | doc << name << "(" ; |
962 | } else { |
963 | auto* op_gvar = op->op.as<GlobalVarNode>(); |
964 | ICHECK(op_gvar != nullptr); |
965 | doc << Doc::Text(op_gvar->name_hint) << "(" ; |
966 | } |
967 | std::vector<Doc> args; |
968 | for (const auto& arg : op->args) { |
969 | args.push_back(Print(arg)); |
970 | } |
971 | args.push_back(Doc::Text("dtype=" ) << PrintDType(op->dtype)); |
972 | doc << PrintSep(args, Doc::Text(", " )) << ")" ; |
973 | return doc; |
974 | } |
975 | |
976 | Doc TVMScriptPrinter::VisitExpr_(const ShuffleNode* op, ExprPrecedence* out_precedence) { |
977 | *out_precedence = ExprPrecedence::kIdentity; |
978 | Doc doc; |
979 | doc << tir_prefix_ << ".shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")" ; |
980 | return doc; |
981 | } |
982 | |
983 | Doc TVMScriptPrinter::VisitExpr_(const ReduceNode* op, ExprPrecedence* out_precedence) { |
984 | *out_precedence = ExprPrecedence::kIdentity; |
985 | Doc doc; |
986 | doc << tir_prefix_ << ".reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " |
987 | << Print(op->axis) << ", " << op->value_index << ")" ; |
988 | return doc; |
989 | } |
990 | |
991 | Doc TVMScriptPrinter::VisitStmt_(const LetStmtNode* op) { |
992 | if (!buffer_var_usage_.count(op->var)) { |
993 | buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); |
994 | } |
995 | Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->var).value_or({}); |
996 | |
997 | Doc doc; |
998 | if (current_num_ != num_child_ - 1) { |
999 | doc << "with " << tir_prefix_ << ".let(" << Print(op->var) << ", " << Print(op->value) << "):" ; |
1000 | doc << Doc::Indent( |
1001 | 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); |
1002 | } else { |
1003 | if (memo_var_.find(op->var) == memo_var_.end()) var_not_in_headers_.insert(op->var.get()); |
1004 | doc << Print(op->var) << ": " << Print(GetType(op->var)) << " = " << Print(op->value) |
1005 | << Doc::NewLine(); |
1006 | doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); |
1007 | } |
1008 | return doc; |
1009 | } |
1010 | |
/*!
 * \brief Print an AttrStmt.
 *
 * Each shape below uses one of two layouts. If the statement is not the last child of the
 * enclosing sequence (current_num_ != num_child_ - 1), it prints `with T.<call>(...):` plus an
 * indented body. If it is the last child, it prints a flat call and then the body at the same
 * level. The shapes are:
 *  - attr_key "realize_scope" on a Buffer whose body is a BufferRealize of that same buffer:
 *    merged into one `T.realize(<buffer><bounds>, <value>[, <condition>])` call, with the
 *    condition only printed when it is not constant one;
 *  - attr_key "thread_extent" or "virtual_thread" on an IterVar:
 *    `T.launch_thread(<iter var>, <value>)`;
 *  - anything else: the generic `T.attr(<node>, "<key>", <value>)`.
 */
Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
  Doc doc;
  if (op->node.defined()) {
    // merge attr with realize when possible
    if (op->node->IsInstance<BufferNode>() && op->attr_key == "realize_scope" &&
        op->body->IsInstance<BufferRealizeNode>()) {
      const auto* realize = Downcast<BufferRealize>(op->body).get();
      // Only merge when the realize targets the very buffer the attribute is attached to;
      // otherwise fall through to the other cases below.
      if (realize->buffer.same_as(op->node)) {
        if (current_num_ != num_child_ - 1) {
          doc << "with " << tir_prefix_ << ".realize(" << Print(realize->buffer)
              << Print(realize->bounds) << ", " << Print(op->value);
          if (!is_one(realize->condition)) {
            doc << ", " << Print(realize->condition);
          }
          doc << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(realize->body));
        } else {
          doc << tir_prefix_ << ".realize(" << Print(realize->buffer) << Print(realize->bounds)
              << ", " << Print(op->value);
          if (!is_one(realize->condition)) {
            doc << ", " << Print(realize->condition);
          }
          doc << ")" << Doc::NewLine() << PrintBody(realize->body);
        }
        return doc;
      }
    }
    // concise thread env
    if (op->node->IsInstance<IterVarNode>() &&
        (op->attr_key == "thread_extent" || op->attr_key == "virtual_thread" )) {
      const auto* iter_var = Downcast<IterVar>(op->node).get();
      // The thread var is defined by the launch_thread statement itself, so it is recorded as
      // not-in-headers; its thread tag is remembered in var_env_map_ (presumably consumed when
      // printing the var's environment binding elsewhere — confirm).
      var_not_in_headers_.insert(iter_var->var.get());
      var_env_map_[iter_var->var] = iter_var->thread_tag;
      if (current_num_ != num_child_ - 1) {
        doc << "with " << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", "
            << Print(op->value) << "):" ;
        doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
      } else {
        doc << tir_prefix_ << ".launch_thread(" << Print(iter_var->var) << ", " << Print(op->value)
            << ")" ;
        doc << Doc::NewLine() << PrintBody(op->body);
      }
      return doc;
    }
  }
  // default
  if (current_num_ != num_child_ - 1) {
    doc << "with " << tir_prefix_ << ".attr(" << Print(op->node) << ", "
        << Doc::StrLiteral(op->attr_key) << ", " << Print(op->value) << "):" ;
    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  } else {
    doc << tir_prefix_ << ".attr(" << Print(op->node) << ", " << Doc::StrLiteral(op->attr_key)
        << ", " << Print(op->value) << ")" ;
    doc << Doc::NewLine() << PrintBody(op->body);
  }
  return doc;
}
1067 | |
1068 | Doc TVMScriptPrinter::VisitStmt_(const AssertStmtNode* op) { |
1069 | Doc doc; |
1070 | if (current_num_ != num_child_ - 1) { |
1071 | doc << "with " << tir_prefix_ << ".Assert(" << Print(op->condition) << ", " |
1072 | << Print(op->message) << "):" ; |
1073 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1074 | } else { |
1075 | doc << "assert " << Print(op->condition) << ", " << Print(op->message); |
1076 | doc << Doc::NewLine() << PrintBody(op->body); |
1077 | } |
1078 | return doc; |
1079 | } |
1080 | |
1081 | Doc TVMScriptPrinter::VisitStmt_(const StoreNode* op) { |
1082 | Doc doc; |
1083 | doc << tir_prefix_ << ".store(" << Print(op->buffer_var) << ", " << Print(op->index) << ", " |
1084 | << Print(op->value) << ", " << Print(op->predicate) << ")" ; |
1085 | return doc; |
1086 | } |
1087 | |
1088 | Doc TVMScriptPrinter::VisitStmt_(const BufferRealizeNode* op) { |
1089 | LOG(FATAL) |
1090 | << "TVM Script Printer Internal Error: All the BufferRealize should be folded with Attr" ; |
1091 | return Doc(); |
1092 | } |
1093 | |
1094 | namespace { |
1095 | |
1096 | bool IsAllocateDeclBufferPattern(const AllocateNode* allocate) { |
1097 | const tir::Var& buffer_var = allocate->buffer_var; |
1098 | const DeclBufferNode* decl_buffer = allocate->body.as<DeclBufferNode>(); |
1099 | if (!decl_buffer) { |
1100 | return false; |
1101 | } |
1102 | const Buffer& buffer = decl_buffer->buffer; |
1103 | if (!buffer_var.same_as(buffer->data)) { |
1104 | return false; |
1105 | } |
1106 | if (allocate->dtype != buffer->dtype) { |
1107 | return false; |
1108 | } |
1109 | if (!is_one(allocate->condition)) { |
1110 | return false; |
1111 | } |
1112 | if (allocate->annotations.size()) { |
1113 | return false; |
1114 | } |
1115 | if (allocate->extents.size() != buffer->shape.size()) { |
1116 | return false; |
1117 | } |
1118 | tir::ExprDeepEqual expr_equal; |
1119 | for (size_t i = 0, n = allocate->extents.size(); i < n; ++i) { |
1120 | if (!expr_equal(allocate->extents[i], buffer->shape[i])) { |
1121 | return false; |
1122 | } |
1123 | } |
1124 | return true; |
1125 | } |
1126 | |
1127 | } // namespace |
1128 | |
1129 | Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { |
1130 | var_not_in_headers_.insert(op->buffer_var.get()); |
1131 | |
1132 | if (!buffer_var_usage_.count(op->buffer_var)) { |
1133 | buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), op->body); |
1134 | } |
1135 | Array<Buffer> buffer_usage = buffer_var_usage_.Get(op->buffer_var).value_or({}); |
1136 | |
1137 | if (buffer_usage.empty()) { |
1138 | if (IsAllocateDeclBufferPattern(op)) { |
1139 | // As a syntax sugar, we identify the pattern of Allocate and DeclBuffer and print a single |
1140 | // DeclBuffer statement. It is intentionally to call `Print` instead of `PrintBody` here to |
1141 | // delegate the printing of the current node to `DeclBufferNode` while maintaining the |
1142 | // same value of `current_num_` and `num_child_`. |
1143 | return Print(op->body); |
1144 | } |
1145 | } |
1146 | |
1147 | auto storage_scope = GetPtrStorageScope(op->buffer_var); |
1148 | Doc func_call; |
1149 | func_call << tir_prefix_ << ".allocate(" << Print(op->extents) << ", " << PrintDType(op->dtype) |
1150 | << ", " << Print(storage_scope); |
1151 | if (!is_one(op->condition)) { |
1152 | func_call << ", " << Print(op->condition); |
1153 | } |
1154 | if (!op->annotations.empty()) { |
1155 | func_call << ", annotations={" ; |
1156 | func_call << PrintAnnotations(op->annotations); |
1157 | func_call << "}" ; |
1158 | } |
1159 | func_call << ")" ; |
1160 | |
1161 | Doc doc; |
1162 | if (current_num_ != num_child_ - 1) { |
1163 | doc << "with " << func_call << " as " << Print(op->buffer_var) << ":" ; |
1164 | doc << Doc::Indent( |
1165 | 4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body)); |
1166 | } else { |
1167 | doc << Print(op->buffer_var) << " = " << func_call << Doc::NewLine(); |
1168 | doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(op->body); |
1169 | } |
1170 | TryDeallocVar(op->buffer_var); |
1171 | return doc; |
1172 | } |
1173 | |
1174 | Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { |
1175 | std::stringstream ss; |
1176 | ICHECK(alloc->data) << "Should be presented" ; |
1177 | const auto& data = alloc->data.value(); |
1178 | |
1179 | if (alloc->dtype.is_int()) { |
1180 | if (alloc->dtype.bits() == 8) { |
1181 | NDArrayToTIR<int8_t>(data, ss); |
1182 | } else if (alloc->dtype.bits() == 16) { |
1183 | NDArrayToTIR<int16_t>(data, ss); |
1184 | } else if (alloc->dtype.bits() == 32) { |
1185 | NDArrayToTIR<int32_t>(data, ss); |
1186 | } else if (alloc->dtype.bits() == 64) { |
1187 | NDArrayToTIR<int64_t>(data, ss); |
1188 | } else { |
1189 | LOG(FATAL) << "DataType not supported" ; |
1190 | } |
1191 | } else if (alloc->dtype.is_uint()) { |
1192 | if (alloc->dtype.bits() == 8) { |
1193 | NDArrayToTIR<uint8_t>(data, ss); |
1194 | } else if (alloc->dtype.bits() == 16) { |
1195 | NDArrayToTIR<uint16_t>(data, ss); |
1196 | } else if (alloc->dtype.bits() == 32) { |
1197 | NDArrayToTIR<uint32_t>(data, ss); |
1198 | } else if (alloc->dtype.bits() == 64) { |
1199 | NDArrayToTIR<int64_t>(data, ss); |
1200 | } else { |
1201 | LOG(FATAL) << "DataType not supported" ; |
1202 | } |
1203 | } else if (alloc->dtype.is_float()) { |
1204 | if (alloc->dtype.bits() == 16) { |
1205 | NDArrayToTIR<int16_t>(data, ss); |
1206 | } else if (alloc->dtype.bits() == 32) { |
1207 | NDArrayToTIR<float>(data, ss); |
1208 | } else if (alloc->dtype.bits() == 64) { |
1209 | NDArrayToTIR<double>(data, ss); |
1210 | } else { |
1211 | LOG(FATAL) << "DataType not supported" ; |
1212 | } |
1213 | } else { |
1214 | LOG(FATAL) << "DataType not supported" ; |
1215 | } |
1216 | auto ndarray_str = ss.str(); |
1217 | |
1218 | var_not_in_headers_.insert(alloc->buffer_var.get()); |
1219 | |
1220 | if (!buffer_var_usage_.count(alloc->buffer_var)) { |
1221 | buffer_var_usage_ = BufferUsageFinder::FindUsage(std::move(buffer_var_usage_), alloc->body); |
1222 | } |
1223 | Array<Buffer> buffer_usage = buffer_var_usage_.Get(alloc->buffer_var).value_or({}); |
1224 | |
1225 | Doc func_call; |
1226 | func_call << tir_prefix_ << ".allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) |
1227 | << ", " << Print(alloc->extents) << ")" ; |
1228 | |
1229 | Doc doc; |
1230 | var_not_in_headers_.insert(alloc->buffer_var.get()); |
1231 | if (current_num_ != num_child_ - 1) { |
1232 | doc << "with " << func_call << " as " << Print(alloc->buffer_var) << ":" ; |
1233 | doc << Doc::Indent(4, Doc::NewLine() << PrintNonHeaderBufferDeclarations(buffer_usage) |
1234 | << PrintBody(alloc->body)); |
1235 | } else { |
1236 | doc << Print(alloc->buffer_var) << " = " << func_call << Doc::NewLine(); |
1237 | doc << PrintNonHeaderBufferDeclarations(buffer_usage) << PrintBody(alloc->body); |
1238 | } |
1239 | return doc; |
1240 | } |
1241 | |
1242 | Doc TVMScriptPrinter::VisitStmt_(const DeclBufferNode* op) { |
1243 | const Buffer& buffer = op->buffer; |
1244 | buf_not_in_headers_.insert(buffer.get()); |
1245 | Doc buffer_name = Print(op->buffer); |
1246 | Doc func_call; |
1247 | func_call << tir_prefix_ << ".decl_buffer(" << memo_buf_decl_.at(buffer) << ")" ; |
1248 | |
1249 | Doc doc; |
1250 | if (current_num_ != num_child_ - 1) { |
1251 | doc << "with " << func_call << " as " << buffer_name << ":" ; |
1252 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1253 | } else { |
1254 | doc << buffer_name << " = " << func_call << Doc::NewLine(); |
1255 | doc << PrintBody(op->body); |
1256 | } |
1257 | return doc; |
1258 | } |
1259 | |
1260 | Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { |
1261 | Doc doc; |
1262 | doc << "if " << Print(op->condition) << ":" ; |
1263 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case)); |
1264 | |
1265 | Optional<Stmt> else_case = op->else_case; |
1266 | while (else_case) { |
1267 | if (auto* else_if = else_case.value().as<IfThenElseNode>()) { |
1268 | doc << Doc::NewLine(); |
1269 | doc << "elif " << Print(else_if->condition) << ":" ; |
1270 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(else_if->then_case)); |
1271 | |
1272 | else_case = else_if->else_case; |
1273 | } else { |
1274 | doc << Doc::NewLine(); |
1275 | doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(else_case.value())); |
1276 | break; |
1277 | } |
1278 | } |
1279 | |
1280 | return doc; |
1281 | } |
1282 | |
1283 | Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) { |
1284 | std::vector<Doc> stmts; |
1285 | for (Stmt stmt : op->seq) { |
1286 | stmts.push_back(Print(stmt)); |
1287 | } |
1288 | return PrintSep(stmts, Doc::NewLine()); |
1289 | } |
1290 | |
1291 | Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { |
1292 | // When parsing TVMScript, a PrimExpr that occurs as a statement is |
1293 | // automatically wrapped in `tir::Evaluate`. Therefore, when |
1294 | // printing, it's only necessary to print the value. For |
1295 | // readability, though, we still print T.evaluate() when the |
1296 | // expression is something other than a call node. |
1297 | Doc doc; |
1298 | if (op->value.as<CallNode>()) { |
1299 | doc << Print(op->value); |
1300 | } else { |
1301 | doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")" ; |
1302 | } |
1303 | return doc; |
1304 | } |
1305 | |
/*!
 * \brief Print a For loop, merging runs of nested "simple" loops.
 *
 * Each loop that satisfies IsSimpleLoop is pushed onto simple_loop_stack_. If its body is
 * another simple loop that does not depend on the outer loop vars (DependOnPrevLoops), printing
 * is delegated to that inner loop so the chain can be emitted together. Once the chain ends,
 * PrintLoopStack() emits the pending loops (presumably as one combined loop header — confirm in
 * PrintLoopStack) and the stack is cleared. A non-simple loop is then printed with PrintLoop,
 * indented under the stacked header if one was just printed.
 *
 * While the loop's body is printed, its var is registered in loop_var_map_ so that block
 * bindings can refer to it; it is removed, and TryDeallocVar is called, before returning.
 */
Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
  Doc doc;
  var_not_in_headers_.insert(op->loop_var.get());
  loop_var_map_[op->loop_var.get()] = GetRef<For>(op);
  const auto* body = op->body.as<ForNode>();
  bool simple_loop = IsSimpleLoop(op);
  if (simple_loop) simple_loop_stack_.push_back(GetRef<For>(op));
  // It is a loop that can be compressed, let the loops below print it out
  if (simple_loop && body != nullptr && IsSimpleLoop(body) && !DependOnPrevLoops(body)) {
    doc << Print(GetRef<For>(body));
    TryDeallocVar(op->loop_var);
    loop_var_map_.erase(op->loop_var.get());
    return doc;
  }
  // It is a loop that can not be compressed
  bool print_above = !simple_loop_stack_.empty();
  // print loops above if needed
  if (print_above) {
    doc << PrintLoopStack();
    simple_loop_stack_.clear();
  }
  if (!simple_loop) {
    // print current loop if needed
    Doc current_loop;
    current_loop << PrintLoop(GetRef<For>(op));
    current_loop << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
    // Nest under the just-printed stacked header, if any.
    doc << (print_above ? Doc::Indent(4, Doc::NewLine() << current_loop) : current_loop);
  } else {
    // This simple loop was already emitted by PrintLoopStack above; only its body remains.
    doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body));
  }
  TryDeallocVar(op->loop_var);
  loop_var_map_.erase(op->loop_var.get());
  return doc;
}
1340 | |
1341 | Doc TVMScriptPrinter::VisitStmt_(const PrefetchNode* op) { |
1342 | Doc doc; |
1343 | doc << tir_prefix_ << ".prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")" ; |
1344 | return doc; |
1345 | } |
1346 | |
1347 | Doc TVMScriptPrinter::VisitStmt_(const WhileNode* op) { |
1348 | Doc doc; |
1349 | doc << "while " << Print(op->condition) << ":" ; |
1350 | doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); |
1351 | return doc; |
1352 | } |
1353 | |
1354 | Doc TVMScriptPrinter::VisitType_(const PrimTypeNode* node) { |
1355 | Doc doc; |
1356 | doc << tir_prefix_ << "." ; |
1357 | if (node->dtype.is_void()) { |
1358 | doc << "void" ; |
1359 | } else { |
1360 | doc << runtime::DLDataType2String(node->dtype); |
1361 | } |
1362 | return doc; |
1363 | } |
1364 | |
1365 | Doc TVMScriptPrinter::VisitType_(const PointerTypeNode* node) { |
1366 | Doc doc; |
1367 | doc << tir_prefix_ << ".Ptr[" ; |
1368 | doc << Print(node->element_type); |
1369 | if (!node->storage_scope.empty()) { |
1370 | doc << ", " << Doc::StrLiteral(node->storage_scope); |
1371 | } |
1372 | doc << "]" ; |
1373 | return doc; |
1374 | } |
1375 | |
1376 | Doc TVMScriptPrinter::VisitType_(const TupleTypeNode* node) { |
1377 | if (node->fields.empty()) { |
1378 | return Doc::Text("None" ); |
1379 | } else { |
1380 | std::vector<Doc> fields; |
1381 | for (Type field : node->fields) { |
1382 | fields.push_back(Print(field)); |
1383 | } |
1384 | return Doc::Text(tir_prefix_ + ".Tuple[" ) << Doc::Concat(fields) << "]" ; |
1385 | } |
1386 | } |
1387 | |
1388 | Doc TVMScriptPrinter::VisitStmt_(const BufferStoreNode* op) { |
1389 | Doc doc; |
1390 | if (op->indices.size() == 0) { |
1391 | doc << Print(op->buffer) << "[()] = " << Print(op->value); |
1392 | } else { |
1393 | doc << Print(op->buffer) << PrintBufferIndices(op->indices) << " = " << Print(op->value); |
1394 | } |
1395 | return doc; |
1396 | } |
1397 | |
1398 | /*! Helper functions for block printing. */ |
1399 | Doc TVMScriptPrinter::PrintBlockVar(const IterVar& iter_var, const PrimExpr& value) { |
1400 | Doc doc; |
1401 | doc << Print(iter_var->var) << " = " << tir_prefix_ << ".axis." ; |
1402 | switch (iter_var->iter_type) { |
1403 | case kDataPar: |
1404 | doc << "spatial" ; |
1405 | break; |
1406 | case kCommReduce: |
1407 | doc << "reduce" ; |
1408 | break; |
1409 | case kOrdered: |
1410 | doc << "scan" ; |
1411 | break; |
1412 | case kOpaque: |
1413 | doc << "opaque" ; |
1414 | break; |
1415 | default: |
1416 | LOG(FATAL) << "Unknown block var iter type: " << iter_var->iter_type; |
1417 | break; |
1418 | } |
1419 | doc << "(" ; |
1420 | const Range& dom = iter_var->dom; |
1421 | if (is_zero(dom->min)) { |
1422 | doc << Print(dom->extent); |
1423 | } else { |
1424 | doc << "(" << Print(dom->min) << ", " << Print(dom->min + dom->extent) << ")" ; |
1425 | } |
1426 | doc << ", " << Print(value) << ")" ; |
1427 | return doc; |
1428 | } |
1429 | |
1430 | Doc TVMScriptPrinter::PrintBlockVarRemaps() { |
1431 | ICHECK(!block_var_remaps_.empty()); |
1432 | if (block_var_remaps_.size() == 1) { |
1433 | const IterVar& iter_var = block_var_remaps_[0].first; |
1434 | const PrimExpr& value = block_var_remaps_[0].second; |
1435 | return PrintBlockVar(iter_var, value); |
1436 | } |
1437 | Doc doc; |
1438 | std::vector<Doc> iter_vars, iter_values; |
1439 | std::string iter_type; |
1440 | for (const auto& pair : block_var_remaps_) { |
1441 | const IterVar& iter_var = pair.first; |
1442 | const PrimExpr& value = pair.second; |
1443 | iter_vars.push_back(Print(iter_var->var)); |
1444 | iter_values.push_back(Print(value)); |
1445 | if (iter_var->iter_type == kDataPar) { |
1446 | iter_type += "S" ; |
1447 | } else if (iter_var->iter_type == kCommReduce) { |
1448 | iter_type += "R" ; |
1449 | } else { |
1450 | ICHECK(false); |
1451 | } |
1452 | } |
1453 | doc << PrintSep(iter_vars, Doc::Text(", " )) << " = " << tir_prefix_ << ".axis.remap(" |
1454 | << Doc::StrLiteral(iter_type) << ", [" << PrintSep(iter_values, Doc::Text(", " )) << "])" ; |
1455 | return doc; |
1456 | } |
1457 | |
1458 | Doc TVMScriptPrinter::PrintBlockPredicate(const BlockRealizeNode* op) { |
1459 | Doc doc; |
1460 | if (!is_one(op->predicate)) { |
1461 | doc << Doc::NewLine() << tir_prefix_ << ".where(" << Print(op->predicate) << ")" ; |
1462 | } |
1463 | return doc; |
1464 | } |
1465 | |
1466 | Doc TVMScriptPrinter::PrintBlockVars(const BlockRealizeNode* op) { |
1467 | Doc doc; |
1468 | const auto* block_op = op->block.as<BlockNode>(); |
1469 | ICHECK_EQ(block_op->iter_vars.size(), op->iter_values.size()); |
1470 | tir::ExprDeepEqual expr_equal; |
1471 | |
1472 | auto is_simple_remap = [this, &expr_equal](const IterVar& iter_var, |
1473 | const PrimExpr& value) -> bool { |
1474 | if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) return false; |
1475 | if (!value->IsInstance<tir::VarNode>()) return false; |
1476 | const tir::Var& var = Downcast<tir::Var>(value); |
1477 | auto it = loop_var_map_.find(var.get()); |
1478 | return it != loop_var_map_.end() && expr_equal(it->second->min, iter_var->dom->min) && |
1479 | expr_equal(it->second->extent, iter_var->dom->extent); |
1480 | }; |
1481 | |
1482 | for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { |
1483 | const IterVar& iter_var = block_op->iter_vars[i]; |
1484 | const PrimExpr& value = op->iter_values[i]; |
1485 | var_not_in_headers_.insert(iter_var->var.get()); |
1486 | if (is_simple_remap(iter_var, value)) { |
1487 | block_var_remaps_.push_back(std::make_pair(iter_var, value)); |
1488 | } else { |
1489 | if (!block_var_remaps_.empty()) { |
1490 | doc << Doc::NewLine() << PrintBlockVarRemaps(); |
1491 | block_var_remaps_.clear(); |
1492 | } |
1493 | doc << Doc::NewLine() << PrintBlockVar(iter_var, value); |
1494 | } |
1495 | } |
1496 | if (!block_var_remaps_.empty()) { |
1497 | doc << Doc::NewLine() << PrintBlockVarRemaps(); |
1498 | block_var_remaps_.clear(); |
1499 | } |
1500 | return doc; |
1501 | } |
1502 | |
1503 | Doc TVMScriptPrinter::PrintBlockAttr(const BlockRealizeNode* op) { |
1504 | const auto* block_op = op->block.as<BlockNode>(); |
1505 | Doc block_attr_doc; |
1506 | // print binding, read/write tensor region, annotations |
1507 | block_attr_doc << Doc::NewLine() << tir_prefix_ << ".reads(" |
1508 | << PrintExpandedArray(block_op->reads.as<ArrayNode>()) << ")" ; |
1509 | block_attr_doc << Doc::NewLine() << tir_prefix_ << ".writes(" |
1510 | << PrintExpandedArray(block_op->writes.as<ArrayNode>()) << ")" ; |
1511 | if (!block_op->annotations.empty()) { |
1512 | block_attr_doc << Doc::NewLine() << tir_prefix_ << ".block_attr({" ; |
1513 | block_attr_doc << PrintAnnotations(block_op->annotations); |
1514 | block_attr_doc << "})" ; |
1515 | } |
1516 | return block_attr_doc; |
1517 | } |
1518 | |
1519 | // This function is to make sure arguments of T.reads() and T.writes() is not parsed by printer as a |
1520 | // List. Therefore the brackets are removed before and after printing arguments out |
1521 | Doc TVMScriptPrinter::PrintExpandedArray(const ArrayNode* op) { |
1522 | Doc doc; |
1523 | for (size_t i = 0; i < op->size(); ++i) { |
1524 | if (i != 0) { |
1525 | doc << ", " ; |
1526 | } |
1527 | doc << Print(op->at(i)); |
1528 | } |
1529 | return doc; |
1530 | } |
1531 | |
1532 | Doc TVMScriptPrinter::PrintBlockBody(const BlockNode* op) { |
1533 | Doc body; |
1534 | for (const auto& alloc_buf : op->alloc_buffers) { |
1535 | buf_not_in_headers_.insert(alloc_buf.get()); |
1536 | body << Print(alloc_buf) << " = " << tir_prefix_ << ".alloc_buffer(" |
1537 | << memo_buf_decl_[alloc_buf] << ")" << Doc::NewLine(); |
1538 | } |
1539 | for (const auto& match_buf : op->match_buffers) { |
1540 | body << Print(match_buf) << Doc::NewLine(); |
1541 | } |
1542 | if (op->init.defined()) { |
1543 | Doc init_block; |
1544 | init_block << "with " << tir_prefix_ << ".init():" ; |
1545 | init_block << Doc::Indent(4, Doc::NewLine() << PrintBody(op->init.value())); |
1546 | body << init_block << Doc::NewLine(); |
1547 | } |
1548 | body << PrintBody(op->body); |
1549 | return body; |
1550 | } |
1551 | |
1552 | /*! |
1553 | * \brief Print the name of a block |
1554 | * \param block_op The block node to be printed |
1555 | */ |
1556 | Doc TVMScriptPrinter::PrintBlockName(const BlockNode* block_op) { |
1557 | Doc doc; |
1558 | doc << "with " << tir_prefix_ << ".block(" ; |
1559 | if (!block_op->name_hint.empty()) { |
1560 | doc << Doc::StrLiteral(block_op->name_hint); |
1561 | } |
1562 | doc << "):" ; |
1563 | return doc; |
1564 | } |
1565 | |
1566 | Doc TVMScriptPrinter::VisitStmt_(const BlockRealizeNode* op) { |
1567 | const auto* block_op = op->block.as<BlockNode>(); |
1568 | Doc doc = PrintOptionalInfo(GetRef<Stmt>(block_op)); |
1569 | // print block name |
1570 | doc << PrintBlockName(block_op); |
1571 | // Print block predicate. |
1572 | Doc block_predicate = PrintBlockPredicate(op); |
1573 | // Print the variable bindings, valid to use in block attributes and |
1574 | // body |
1575 | Doc block_var = PrintBlockVars(op); |
1576 | // print read/write tensor region, annotations |
1577 | Doc block_attr_doc = PrintBlockAttr(op); |
1578 | // print body |
1579 | Doc body = PrintBlockBody(block_op); |
1580 | doc << Doc::Indent(4, block_predicate << block_var << block_attr_doc << Doc::NewLine() << body); |
1581 | for (const auto& iter_var : block_op->iter_vars) { |
1582 | TryDeallocVar(iter_var->var); |
1583 | } |
1584 | return doc; |
1585 | } |
1586 | |
1587 | Doc TVMScriptPrinter::PrintBody(const Stmt& body) { |
1588 | int memo_num_child, memo_current_num; |
1589 | std::swap(memo_num_child, num_child_); |
1590 | std::swap(memo_current_num, current_num_); |
1591 | |
1592 | Doc doc; |
1593 | if (body->IsInstance<SeqStmtNode>()) { |
1594 | const auto& op = Downcast<SeqStmt>(body); |
1595 | num_child_ = op->seq.size(); |
1596 | current_num_ = 0; |
1597 | std::vector<Doc> stmts; |
1598 | for (Stmt stmt : op->seq) { |
1599 | stmts.push_back(Print(stmt)); |
1600 | current_num_++; |
1601 | } |
1602 | doc = PrintSep(stmts, Doc::NewLine()); |
1603 | } else { |
1604 | num_child_ = 1; |
1605 | current_num_ = 0; |
1606 | doc = Print(body); |
1607 | } |
1608 | |
1609 | std::swap(memo_num_child, num_child_); |
1610 | std::swap(memo_current_num, current_num_); |
1611 | return doc; |
1612 | } |
1613 | |
1614 | Doc TVMScriptPrinter::PrintIRModule(const IRModule& module) { |
1615 | auto* op = module.operator->(); |
1616 | Doc doc; |
1617 | doc << "@tvm.script.ir_module" << Doc::NewLine(); |
1618 | doc << "class Module:" ; |
1619 | for (const auto& x : op->functions) { |
1620 | func2var_[x.second.operator->()] = x.first; |
1621 | } |
1622 | Doc body = Doc::NewLine(); |
1623 | std::vector<Doc> functions; |
1624 | for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { |
1625 | if ((*it).second.as<PrimFuncNode>()) { |
1626 | functions.push_back(Print((*it).second)); |
1627 | } |
1628 | } |
1629 | body << TVMScriptPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); |
1630 | body << Doc::NewLine() << DumpMeta(); |
1631 | doc << Doc::Indent(4, body); |
1632 | return doc; |
1633 | } |
1634 | |
1635 | Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { |
1636 | auto* op = primFunc.operator->(); |
1637 | // clear renaming map |
1638 | memo_var_.clear(); |
1639 | memo_buf_.clear(); |
1640 | memo_buf_decl_.clear(); |
1641 | var_not_in_headers_.clear(); |
1642 | buf_not_in_headers_.clear(); |
1643 | // print signature |
1644 | Doc doc; |
1645 | doc << "@" << tir_prefix_ << ".prim_func" << Doc::NewLine(); |
1646 | doc << "def " << (func2var_.find(op) == func2var_.end() ? "func" : func2var_[op]->name_hint) |
1647 | << "(" ; |
1648 | std::vector<Doc> params; |
1649 | std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> simple_buf; |
1650 | for (const auto& param : op->params) { |
1651 | var_not_in_headers_.insert(param.get()); |
1652 | auto it = op->buffer_map.find(param); |
1653 | // check if this param is a T.handle |
1654 | if (it != op->buffer_map.end()) { |
1655 | // check if this match_buffer has only the first two arguments specified |
1656 | // and whether the match_buffer is a dynamic buffer. |
1657 | const Buffer& buf = (*it).second; |
1658 | if (IsSimpleBuffer(buf)) { |
1659 | simple_buf.insert(buf); |
1660 | buf_not_in_headers_.insert(buf.get()); |
1661 | params.push_back(Print(buf) << ": " << PrintInlineBufferBind(buf)); |
1662 | continue; |
1663 | } |
1664 | } |
1665 | params.push_back(Print(param) << ": " << Print(GetType(param))); |
1666 | } |
1667 | doc << PrintSep(params, Doc::Text(", " )) << ")" ; |
1668 | if (primFunc->ret_type.defined()) { |
1669 | auto as_tuple = primFunc->ret_type.as<TupleTypeNode>(); |
1670 | if (!as_tuple || as_tuple->fields.size()) { |
1671 | doc << " -> " << Print(primFunc->ret_type); |
1672 | } |
1673 | } |
1674 | doc << ":" ; |
1675 | |
1676 | Doc body = Doc::NewLine(); |
1677 | // print buffer_bind |
1678 | for (const auto& param : op->params) { |
1679 | auto it = op->buffer_map.find(param); |
1680 | if (it == op->buffer_map.end()) continue; |
1681 | const Buffer& buf = (*it).second; |
1682 | if (simple_buf.count(buf)) continue; |
1683 | buf_not_in_headers_.insert(buf.get()); |
1684 | body << Print(buf) << " = " << tir_prefix_ << ".match_buffer(" ; |
1685 | ICHECK(memo_buf_decl_.count(buf)); |
1686 | body << Print((*it).first) << ", " << memo_buf_decl_[buf]; |
1687 | body << ")" << Doc::NewLine(); |
1688 | } |
1689 | // print body |
1690 | body << "# body" << Doc::NewLine(); |
1691 | |
1692 | Optional<Block> elided_root_block_body = [&]() -> Optional<Block> { |
1693 | auto block_realize = op->body.as<BlockRealizeNode>(); |
1694 | if (!block_realize || block_realize->iter_values.size()) { |
1695 | return NullOpt; |
1696 | } |
1697 | |
1698 | const auto& block = block_realize->block; |
1699 | if (block->annotations.size() || ContainsOptionalInfo(block)) { |
1700 | return NullOpt; |
1701 | } |
1702 | |
1703 | // The autocomplete might recognize the body itself as being a |
1704 | // root block, and fail to insert it. |
1705 | bool autocomplete_would_insert_root_block = [&]() -> bool { |
1706 | if (block->alloc_buffers.size()) { |
1707 | return true; |
1708 | } |
1709 | |
1710 | auto* block_realize = block->body.as<BlockRealizeNode>(); |
1711 | if (block_realize && block_realize->block->iter_vars.size()) { |
1712 | return true; |
1713 | } |
1714 | if (!block_realize && ContainsNode<BlockRealizeNode>(block->body)) { |
1715 | return true; |
1716 | } |
1717 | return false; |
1718 | }(); |
1719 | |
1720 | if (autocomplete_would_insert_root_block) { |
1721 | return block; |
1722 | } else { |
1723 | return NullOpt; |
1724 | } |
1725 | }(); |
1726 | |
1727 | if (elided_root_block_body) { |
1728 | // Skip printing of root block in cases where tvm::tir::ScriptComplete |
1729 | // would re-insert it. |
1730 | body << "# with " << tir_prefix_ << ".block(\"root\")" << Doc::NewLine(); |
1731 | body << PrintBlockBody(elided_root_block_body.value().get()); |
1732 | } else { |
1733 | // If this is a non-root block, or is an unskippable root block, |
1734 | // just print it without skipping. |
1735 | body << PrintBody(op->body); |
1736 | } |
1737 | |
1738 | // print func attrs |
1739 | Doc ; |
1740 | if (primFunc->attrs.defined()) { |
1741 | header_attr << Doc::NewLine() << "# function attr dict" << Doc::NewLine() << tir_prefix_ |
1742 | << ".func_attr({" ; |
1743 | std::vector<Doc> attrs; |
1744 | for (const auto& it : op->attrs->dict) { |
1745 | attrs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); |
1746 | } |
1747 | header_attr << PrintSep(attrs, Doc::Text(", " )) << "})" ; |
1748 | } |
1749 | // print buffer declarations(buffers not defined by buffer_bind or buffer_allocate) |
1750 | Doc ; |
1751 | std::vector<const BufferNode*> bufs; |
1752 | for (const auto& it : memo_buf_) { |
1753 | if (buf_not_in_headers_.find(it.first.get()) == buf_not_in_headers_.end()) { |
1754 | bufs.push_back(it.first.get()); |
1755 | } |
1756 | } |
1757 | if (!bufs.empty()) { |
1758 | header_buf << Doc::NewLine() << "# buffer definition" ; |
1759 | std::sort(bufs.begin(), bufs.end(), [&](const BufferNode* a, const BufferNode* b) { |
1760 | return memo_buf_[GetRef<Buffer>(a)].str() < memo_buf_[GetRef<Buffer>(b)].str(); |
1761 | }); |
1762 | for (const auto& buf : bufs) { |
1763 | header_buf << Doc::NewLine() << Print(GetRef<Buffer>(buf)) << " = " << tir_prefix_ |
1764 | << ".buffer_decl(" ; |
1765 | header_buf << memo_buf_decl_[GetRef<Buffer>(buf)] << ")" ; |
1766 | } |
1767 | } |
1768 | // print var declaration |
1769 | Doc ; |
1770 | std::vector<const tir::VarNode*> vars; |
1771 | for (const auto& it : memo_var_) { |
1772 | if (var_not_in_headers_.find(it.first.get()) == var_not_in_headers_.end()) { |
1773 | vars.push_back(it.first.get()); |
1774 | } |
1775 | } |
1776 | if (!var_env_map_.empty()) { |
1777 | header_var << Doc::NewLine() << "# var definition" ; |
1778 | for (const auto& it : var_env_map_) { |
1779 | header_var << Doc::NewLine() << Print(it.first) << " = " << tir_prefix_ << ".env_thread(" |
1780 | << Doc::StrLiteral(it.second) << ")" ; |
1781 | } |
1782 | } |
1783 | if (!vars.empty()) { |
1784 | std::sort(vars.begin(), vars.end(), [&](const tir::VarNode* a, const tir::VarNode* b) { |
1785 | return memo_var_[GetRef<tir::Var>(a)].str() < memo_var_[GetRef<tir::Var>(b)].str(); |
1786 | }); |
1787 | for (const auto& var : vars) { |
1788 | auto type = GetRef<tir::Var>(var)->type_annotation; |
1789 | if (auto* ptr_type = type.as<PointerTypeNode>()) { |
1790 | auto* prim_type = ptr_type->element_type.as<PrimTypeNode>(); |
1791 | ICHECK(prim_type); |
1792 | header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_ |
1793 | << ".buffer_var(" ; |
1794 | header_var << PrintDType(prim_type->dtype) << ", " |
1795 | << Doc::StrLiteral(ptr_type->storage_scope) << ")" ; |
1796 | } else { |
1797 | header_var << Doc::NewLine() << Print(GetRef<tir::Var>(var)) << " = " << tir_prefix_ |
1798 | << ".var(" ; |
1799 | header_var << PrintDType(var->dtype) << ")" ; |
1800 | } |
1801 | } |
1802 | } |
1803 | doc << Doc::Indent(4, header_attr << header_var << header_buf << body); |
1804 | return doc; |
1805 | } |
1806 | |
1807 | Doc TVMScriptPrinter::PrintArray(const ArrayNode* op) { |
1808 | Doc doc; |
1809 | doc << '['; |
1810 | for (size_t i = 0; i < op->size(); ++i) { |
1811 | if (i != 0) { |
1812 | doc << ", " ; |
1813 | } |
1814 | doc << Print(op->at(i)); |
1815 | } |
1816 | doc << ']'; |
1817 | return doc; |
1818 | } |
1819 | |
1820 | Doc TVMScriptPrinter::PrintIterVar(const IterVarNode* op) { |
1821 | Doc doc; |
1822 | doc << tir_prefix_ << ".iter_var(" << Print(op->var); |
1823 | if (op->dom.defined()) { |
1824 | doc << ", [" << Print(op->dom) << "], " ; |
1825 | } else { |
1826 | doc << ", None, " ; |
1827 | } |
1828 | doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", " ; |
1829 | doc << Doc::StrLiteral(op->thread_tag) << ")" ; |
1830 | return doc; |
1831 | } |
1832 | |
1833 | Doc TVMScriptPrinter::PrintRange(const RangeNode* op) { |
1834 | return Print(op->min) << ":" << Print(op->min + op->extent); |
1835 | } |
1836 | |
1837 | Doc TVMScriptPrinter::PrintBuffer(const BufferNode* op) { |
1838 | const Buffer& buffer = GetRef<Buffer>(op); |
1839 | return meta_.InMeta(buffer) ? meta_.GetMetaNode(buffer) : AllocBuf(buffer); |
1840 | } |
1841 | |
1842 | Doc TVMScriptPrinter::PrintBufferIndices(const Array<PrimExpr>& indices) { |
1843 | Doc doc; |
1844 | doc << '['; |
1845 | for (size_t i = 0; i < indices.size(); ++i) { |
1846 | if (i != 0) { |
1847 | doc << ", " ; |
1848 | } |
1849 | PrimExpr index = indices[i]; |
1850 | if (const RampNode* ramp = index.as<RampNode>()) { |
1851 | // specify ramp printing as python index slice |
1852 | if (auto* stride_imm = ramp->stride.as<IntImmNode>()) { |
1853 | doc << Print(ramp->base) << ":" << Print(ramp->base + ramp->lanes * ramp->stride); |
1854 | if (stride_imm->value != 1) { |
1855 | doc << ":" << Print(ramp->stride); |
1856 | } |
1857 | continue; |
1858 | } |
1859 | } |
1860 | doc << Print(index); |
1861 | } |
1862 | doc << ']'; |
1863 | return doc; |
1864 | } |
1865 | |
1866 | Doc TVMScriptPrinter::(const Array<Buffer>& aliasing_buffers) { |
1867 | Doc decls; |
1868 | for (const auto& buf_usage : aliasing_buffers) { |
1869 | decls << Print(buf_usage) << " = " << tir_prefix_ << ".buffer_decl(" |
1870 | << memo_buf_decl_[buf_usage] << ")" << Doc::NewLine(); |
1871 | buf_not_in_headers_.insert(buf_usage.get()); |
1872 | } |
1873 | return decls; |
1874 | } |
1875 | |
1876 | Doc TVMScriptPrinter::PrintBufferRegion(const BufferRegionNode* op) { |
1877 | Doc doc; |
1878 | if (op->region.size() == 0) { |
1879 | doc << Print(op->buffer) << "[()]" ; |
1880 | } else { |
1881 | doc << Print(op->buffer) << "[" ; |
1882 | for (size_t i = 0; i < op->region.size(); ++i) { |
1883 | if (i != 0) doc << ", " ; |
1884 | const auto& range = op->region[i]; |
1885 | if (!is_one(range->extent)) { |
1886 | doc << Print(range->min) << " : " << Print(ana_.Simplify(range->min + range->extent)); |
1887 | } else { |
1888 | doc << Print(range->min); |
1889 | } |
1890 | } |
1891 | doc << "]" ; |
1892 | } |
1893 | return doc; |
1894 | } |
1895 | |
1896 | Doc TVMScriptPrinter::PrintAnnotations(const Map<String, ObjectRef>& annotations) { |
1897 | Doc res; |
1898 | std::vector<std::pair<String, ObjectRef>> anno_list; |
1899 | anno_list.reserve(annotations.size()); |
1900 | for (const auto& pair : annotations) { |
1901 | anno_list.emplace_back(pair); |
1902 | } |
1903 | sort(anno_list.begin(), anno_list.end()); |
1904 | for (size_t i = 0; i < anno_list.size(); ++i) { |
1905 | if (i != 0) { |
1906 | res << ", " ; |
1907 | } |
1908 | res << "\"" << anno_list[i].first << "\":" << Print(anno_list[i].second); |
1909 | } |
1910 | return res; |
1911 | } |
1912 | |
1913 | Doc TVMScriptPrinter::PrintLoop(const For& loop) { |
1914 | Doc res; |
1915 | res << "for " << Print(loop->loop_var) << " in " << tir_prefix_ |
1916 | << "." + std::string(ForKind2String(loop->kind)) + "(" ; |
1917 | if (is_zero(loop->min)) { |
1918 | res << Print(loop->extent); |
1919 | } else { |
1920 | res << Print(loop->min) << ", " << Print(ana_.Simplify(loop->min + loop->extent)); |
1921 | } |
1922 | if (loop->thread_binding.defined()) { |
1923 | res << ", thread=" ; |
1924 | res << Print(loop->thread_binding.value()->thread_tag); |
1925 | } |
1926 | if (!loop->annotations.empty()) { |
1927 | res << ", annotations={" ; |
1928 | res << PrintAnnotations(loop->annotations); |
1929 | res << "}" ; |
1930 | } |
1931 | res << "):" ; |
1932 | return res; |
1933 | } |
1934 | |
1935 | Doc TVMScriptPrinter::PrintLoopStack() { |
1936 | Doc res; |
1937 | if (simple_loop_stack_.size() == 1) { |
1938 | res << PrintLoop(simple_loop_stack_[0]); |
1939 | } else if (simple_loop_stack_.size() > 1) { |
1940 | std::vector<Doc> vars, extents; |
1941 | for (const auto& loop : simple_loop_stack_) { |
1942 | vars.push_back(Print(loop->loop_var)); |
1943 | extents.push_back(Print(loop->extent)); |
1944 | } |
1945 | res << "for " << PrintSep(vars, Doc::Text(", " )) << " in " << tir_prefix_ << ".grid(" |
1946 | << PrintSep(extents, Doc::Text(", " )) << "):" ; |
1947 | } |
1948 | return res; |
1949 | } |
1950 | |
1951 | Doc TVMScriptPrinter::PrintTarget(const TargetNode* target) { |
1952 | Doc res; |
1953 | res << tir_prefix_ << ".target({" ; |
1954 | Map<String, ObjectRef> config = target->Export(); |
1955 | for (auto it = config.begin(); it != config.end(); ++it) { |
1956 | if (it != config.begin()) { |
1957 | res << ", " ; |
1958 | } |
1959 | res << "\"" << (*it).first << "\":" ; |
1960 | if ((*it).first == "host" ) { |
1961 | ICHECK(target->host.defined()); |
1962 | res << PrintTarget(target->GetHost().value().get()); |
1963 | } else { |
1964 | res << Print((*it).second); |
1965 | } |
1966 | } |
1967 | res << "})" ; |
1968 | return res; |
1969 | } |
1970 | |
1971 | /*! |
1972 | * \brief The printer for TVMScript with diagnostic |
1973 | * \details The printer obtain the precedence of the top-level operation when printing each |
1974 | * subexpression to decide whether or not parentheses is needed. |
1975 | */ |
1976 | class TVMScriptPrinterWithDiagnostic : public TVMScriptPrinter { |
1977 | public: |
1978 | explicit TVMScriptPrinterWithDiagnostic(const String& tir_prefix, bool show_meta, |
1979 | runtime::TypedPackedFunc<std::string(Stmt)> annotate) |
1980 | : TVMScriptPrinter(tir_prefix, show_meta, annotate) {} |
1981 | |
1982 | protected: |
1983 | Doc PrintBlockName(const BlockNode* block_op) override; |
1984 | Doc PrintUnderline(const Stmt& stmt, int length); |
1985 | Doc PrintLoop(const For& loop) override; |
1986 | }; |
1987 | |
1988 | Doc TVMScriptPrinterWithDiagnostic::PrintBlockName(const BlockNode* block_op) { |
1989 | Doc doc = TVMScriptPrinter::PrintBlockName(block_op); |
1990 | doc << PrintUnderline(GetRef<Stmt>(block_op), doc.str().size()); |
1991 | return doc; |
1992 | } |
1993 | |
1994 | Doc TVMScriptPrinterWithDiagnostic::PrintUnderline(const Stmt& stmt, int length) { |
1995 | Doc doc; |
1996 | // annotation |
1997 | if (ContainsOptionalInfo(stmt)) { |
1998 | String underline = std::string(length, '^'); |
1999 | doc << Doc::NewLine() << underline; |
2000 | } |
2001 | return doc; |
2002 | } |
2003 | |
2004 | Doc TVMScriptPrinterWithDiagnostic::PrintLoop(const For& loop) { |
2005 | Doc res = TVMScriptPrinter::PrintLoop(loop); |
2006 | res << PrintUnderline(loop, res.str().size()); |
2007 | return res; |
2008 | } |
2009 | |
2010 | String AsTVMScriptWithDiagnostic(const ObjectRef& mod, const String& tir_prefix, bool show_meta, |
2011 | runtime::TypedPackedFunc<std::string(Stmt)> annotate) { |
2012 | ICHECK(mod->IsInstance<PrimFuncNode>() || mod->IsInstance<IRModuleNode>()); |
2013 | Doc doc; |
2014 | doc << TVMScriptPrinter::PrintHeader(tir_prefix) |
2015 | << TVMScriptPrinterWithDiagnostic(tir_prefix, show_meta, annotate).Print(mod); |
2016 | return doc.str() + "\n" ; |
2017 | } |
2018 | |
// Registers AsTVMScriptWithDiagnostic as the global function "script.AsTVMScriptWithDiagnostic".
TVM_REGISTER_GLOBAL("script.AsTVMScriptWithDiagnostic").set_body_typed(AsTVMScriptWithDiagnostic);
2020 | |
2021 | } // namespace relay |
2022 | } // namespace tvm |
2023 | |