1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir_text_printer.cc |
22 | * \brief Printer to print out the IR text format |
23 | * that can be parsed by a parser. |
24 | */ |
25 | |
26 | #include <tvm/ir/module.h> |
27 | #include <tvm/ir/type.h> |
28 | #include <tvm/ir/type_functor.h> |
29 | #include <tvm/node/serialization.h> |
30 | #include <tvm/target/target.h> |
31 | #include <tvm/tir/expr.h> |
32 | #include <tvm/tir/function.h> |
33 | #include <tvm/tir/op.h> |
34 | #include <tvm/tir/stmt.h> |
35 | |
36 | #include <algorithm> |
37 | #include <string> |
38 | |
39 | #include "../../tir/transforms/ir_utils.h" |
40 | #include "doc.h" |
41 | #include "meta_data.h" |
42 | #include "text_printer.h" |
43 | |
44 | namespace tvm { |
45 | namespace relay { |
46 | |
47 | Doc TIRTextPrinter::Print(const ObjectRef& node) { |
48 | if (!node.defined()) return Doc::Text("(nullptr)" ); |
49 | if (node->IsInstance<StmtNode>()) { |
50 | return VisitStmt(Downcast<Stmt>(node)); |
51 | } else if (node->IsInstance<AnyNode>()) { |
52 | return Doc::Text("?" ); |
53 | } else if (node->IsInstance<PrimExprNode>()) { |
54 | return VisitExpr(Downcast<PrimExpr>(node)); |
55 | } else if (node->IsInstance<TypeNode>()) { |
56 | return VisitType(Downcast<Type>(node)); |
57 | } else if (node->IsInstance<PrimFuncNode>()) { |
58 | return PrintPrimFunc(Downcast<PrimFunc>(node)); |
59 | } else if (node->IsInstance<IRModuleNode>()) { |
60 | return PrintIRModule(Downcast<IRModule>(node)); |
61 | } else if (node->IsInstance<ArrayNode>()) { |
62 | return PrintArray(node.as<ArrayNode>()); |
63 | } else if (node->IsInstance<IterVarNode>()) { |
64 | return PrintIterVar(node.as<IterVarNode>()); |
65 | } else if (node->IsInstance<RangeNode>()) { |
66 | return PrintRange(node.as<RangeNode>()); |
67 | } else if (node->IsInstance<BufferNode>()) { |
68 | return PrintBuffer(node.as<BufferNode>()); |
69 | } else if (node->IsInstance<DataProducerNode>()) { |
70 | return PrintProducer(node.as<DataProducerNode>()); |
71 | } else if (node->IsInstance<StringObj>()) { |
72 | return PrintString(node.as<StringObj>()); |
73 | } else if (node->IsInstance<BufferRegionNode>()) { |
74 | return PrintBufferRegion(node.as<BufferRegionNode>()); |
75 | } else if (node->IsInstance<TargetNode>()) { |
76 | return Doc::Text(node.as<TargetNode>()->ToDebugString()); |
77 | } else { |
78 | return this->meta_->GetMetaNode(node); |
79 | } |
80 | } |
81 | |
82 | Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { |
83 | const auto* op = prim_func.operator->(); |
84 | const auto& signature = op->func_type_annotation(); |
85 | // collect Meta in DictAttr |
86 | if (prim_func->attrs.defined()) { |
87 | for (const auto& it : prim_func->attrs->dict) { |
88 | meta_collector_.Collect(it.second); |
89 | } |
90 | } |
91 | // collect buffers in buffer_map |
92 | memo_var_.clear(); |
93 | memo_buf_.clear(); |
94 | |
95 | // ordered vars associated with buffers, for consistent printing |
96 | std::vector<tir::Var> buffer_vars_ordered; |
97 | |
98 | for (tir::Var v : op->params) { |
99 | auto buffer_map_find = op->buffer_map.find(v); |
100 | if (buffer_map_find != op->buffer_map.end()) { |
101 | auto map_data = *buffer_map_find; |
102 | buffer_vars_ordered.push_back(map_data.first); |
103 | memo_buf_[map_data.second] = AllocBuf(map_data.second); |
104 | } |
105 | } |
106 | |
107 | // print PrimFunc |
108 | Doc doc; |
109 | doc << "primfn" |
110 | << "(" ; |
111 | // print params and its type annotation |
112 | std::vector<Doc> params; |
113 | for (const auto& param : op->params) { |
114 | params.push_back(Print(param)); |
115 | } |
116 | Doc sep; |
117 | doc << PrintSep(params, Doc::Indent(9, Doc::Text(", " ))) << ")" ; |
118 | // print return type |
119 | doc << " -> " << Print(signature->ret_type); |
120 | // print attr |
121 | Doc attr_doc; |
122 | std::vector<Doc> attr_docs; |
123 | if (prim_func->attrs.defined()) { |
124 | for (const auto& it : op->attrs->dict) { |
125 | attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); |
126 | } |
127 | attr_doc << NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", " )) << "}" ; |
128 | doc << Doc::Indent(2, attr_doc); |
129 | } |
130 | |
131 | // print all the buffers in the tree |
132 | if (memo_buf_.size() != 0) { |
133 | Doc buffer_doc; |
134 | std::vector<Doc> buffer_docs; |
135 | for (const tir::Var& v : buffer_vars_ordered) { |
136 | const Buffer buf = op->buffer_map[v]; |
137 | buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf))); |
138 | } |
139 | buffer_doc << NewLine() << "buffers = {" ; |
140 | buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text("," ) << NewLine())); |
141 | doc << Doc::Indent(2, buffer_doc) << "}" ; |
142 | } |
143 | |
144 | if (op->buffer_map.size() != 0) { |
145 | // print buffer_map |
146 | std::vector<Doc> buffer_map_doc; |
147 | for (const tir::Var& v : buffer_vars_ordered) { |
148 | const Buffer buf = op->buffer_map[v]; |
149 | buffer_map_doc.push_back(Print(v) << ": " << Print(buf)); |
150 | } |
151 | doc << Doc::Indent( |
152 | 2, NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", " )) << "}" ); |
153 | } |
154 | |
155 | doc << PrintBody(op->body); |
156 | return doc; |
157 | } |
158 | |
159 | Doc TIRTextPrinter::NewLine() { return Doc::NewLine(); } |
160 | |
161 | Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { |
162 | const auto* op = module.operator->(); |
163 | Doc doc; |
164 | |
165 | Doc body; |
166 | body << NewLine(); |
167 | std::vector<Doc> functions; |
168 | for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { |
169 | if ((*it).second.as<PrimFuncNode>()) { |
170 | functions.push_back(Print((*it).second)); |
171 | } |
172 | } |
173 | body << TIRTextPrinter::PrintSep(functions, NewLine() << NewLine()); |
174 | doc << Doc::Indent(0, body); |
175 | return doc; |
176 | } |
177 | |
178 | Doc TIRTextPrinter::PrintArray(const ArrayNode* op) { |
179 | Doc doc; |
180 | doc << '['; |
181 | for (size_t i = 0; i < op->size(); ++i) { |
182 | if (i != 0) { |
183 | doc << ", " ; |
184 | } |
185 | doc << Print(op->at(i)); |
186 | } |
187 | doc << ']'; |
188 | return doc; |
189 | } |
190 | |
191 | Doc TIRTextPrinter::PrintIterVar(const IterVarNode* op) { |
192 | Doc doc; |
193 | doc << "IterVar(" << Print(op->var); |
194 | if (op->dom.defined()) { |
195 | doc << ", [" << Print(op->dom) << "], " ; |
196 | } else { |
197 | doc << ", " << Print(op->dom) << ", " ; |
198 | } |
199 | doc << Doc::StrLiteral(IterVarType2String(op->iter_type)) << ", " ; |
200 | doc << Doc::StrLiteral(op->thread_tag) << ")" ; |
201 | return doc; |
202 | } |
203 | |
204 | Doc TIRTextPrinter::PrintRange(const RangeNode* op) { |
205 | return Print(op->min) << ":" << Print(op->min + op->extent); |
206 | } |
207 | |
208 | Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { |
209 | const Buffer& buffer = GetRef<Buffer>(op); |
210 | |
211 | if (meta_->InMeta(buffer)) { |
212 | return meta_->GetMetaNode(buffer); |
213 | } else if (memo_buf_.count(buffer)) { |
214 | return memo_buf_[buffer]; |
215 | } else { |
216 | memo_buf_[buffer] = AllocBuf(buffer); |
217 | return BufferNode2Doc(op, memo_buf_[buffer]); |
218 | } |
219 | } |
220 | |
221 | Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) { |
222 | const DataProducer& prod = GetRef<DataProducer>(op); |
223 | |
224 | if (meta_->InMeta(prod)) { |
225 | return meta_->GetMetaNode(prod); |
226 | } else if (memo_producer_.count(prod)) { |
227 | return memo_producer_[prod]; |
228 | } else { |
229 | memo_producer_[prod] = AllocProducer(prod); |
230 | return DataProducerNode2Doc(op, memo_producer_[prod]); |
231 | } |
232 | } |
233 | |
234 | Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { |
235 | doc << Doc::Text(": Buffer(" ) << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", " |
236 | << Print(buf->shape) << ", " << Print(buf->strides); |
237 | if (!is_zero(buf->elem_offset)) { |
238 | doc << ", elem_offset=" << Print(buf->elem_offset); |
239 | } |
240 | if (buf->axis_separators.size()) { |
241 | doc << ", axis_separators=" << Print(buf->axis_separators); |
242 | } |
243 | if (GetRef<Buffer>(buf).scope() != "global" ) { |
244 | doc << ", scope=" << Doc::StrLiteral(GetRef<Buffer>(buf).scope()); |
245 | } |
246 | if (buf->data_alignment != runtime::kAllocAlignment) { |
247 | doc << ", align=" << buf->data_alignment; |
248 | } |
249 | if (buf->offset_factor != 1) { |
250 | doc << ", offset_factor=" << buf->offset_factor; |
251 | } |
252 | if (buf->buffer_type != 1) { |
253 | doc << ", type=" << Doc::StrLiteral("auto" ); |
254 | } |
255 | return doc << ")" ; |
256 | } |
257 | |
258 | Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) { |
259 | return doc << Doc::Text(": DataProducer(" ) << Print(prod->GetNameHint()) << ", " |
260 | << PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")" ; |
261 | } |
262 | |
263 | Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) { |
264 | Doc doc; |
265 | doc << Print(op->buffer) << "[" ; |
266 | for (size_t i = 0; i < op->region.size(); ++i) { |
267 | if (i != 0) { |
268 | doc << ", " ; |
269 | } |
270 | const auto& range = op->region[i]; |
271 | if (!is_one(range->extent)) { |
272 | doc << Print(range->min) << ":" << Print(range->min + range->extent); |
273 | } else { |
274 | doc << Print(range->min); |
275 | } |
276 | } |
277 | doc << "]" ; |
278 | return doc; |
279 | } |
280 | |
281 | Doc TIRTextPrinter::VisitExprDefault_(const Object* op) { |
282 | return this->meta_->GetMetaNode(GetRef<ObjectRef>(op)); |
283 | } |
284 | |
285 | Doc TIRTextPrinter::VisitStmtDefault_(const Object* op) { |
286 | return this->meta_->GetMetaNode(GetRef<ObjectRef>(op)); |
287 | } |
288 | |
289 | Doc TIRTextPrinter::VisitExpr_(const IntImmNode* op) { |
290 | return PrintConstScalar<int64_t>(op->dtype, op->value); |
291 | } |
292 | |
293 | Doc TIRTextPrinter::VisitExpr_(const FloatImmNode* op) { |
294 | return PrintConstScalar<double>(op->dtype, op->value); |
295 | } |
296 | |
297 | Doc TIRTextPrinter::VisitExpr_(const StringImmNode* op) { return Doc::StrLiteral(op->value); } |
298 | |
299 | Doc TIRTextPrinter::VisitExpr_(const CastNode* op) { |
300 | Doc doc; |
301 | doc << "cast(" << PrintDType(op->dtype) << ", " << Print(op->value) << ")" ; |
302 | return doc; |
303 | } |
304 | |
305 | Doc TIRTextPrinter::VisitExpr_(const tir::VarNode* op) { |
306 | const tir::Var& var = GetRef<tir::Var>(op); |
307 | return meta_->InMeta(var) ? meta_->GetMetaNode(var) : AllocVar(GetRef<tir::Var>(op)); |
308 | } |
309 | |
310 | #define TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OpName, OpString) \ |
311 | Doc TIRTextPrinter::VisitExpr_(const OpName* op) { \ |
312 | Doc doc; \ |
313 | doc << "(" << Print(op->a) << OpString; \ |
314 | doc << Print(op->b) << ")"; \ |
315 | return doc; \ |
316 | } |
317 | |
318 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AddNode, " + " ) |
319 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(SubNode, " - " ) |
320 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(MulNode, "*" ) |
321 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(DivNode, " / " ) |
322 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(ModNode, " % " ) |
323 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(EQNode, " == " ) |
324 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(NENode, " != " ) |
325 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LTNode, " < " ) |
326 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(LENode, " <= " ) |
327 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GTNode, " > " ) |
328 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(GENode, " >= " ) |
329 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(AndNode, " && " ) |
330 | TVM_DECLARE_TIR_TEXT_PRINTER_BINOP(OrNode, " || " ) |
331 | |
332 | Doc TIRTextPrinter::VisitExpr_(const FloorDivNode* op) { |
333 | Doc doc; |
334 | doc << "floordiv(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
335 | return doc; |
336 | } |
337 | |
338 | Doc TIRTextPrinter::VisitExpr_(const FloorModNode* op) { |
339 | Doc doc; |
340 | doc << "floormod(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
341 | return doc; |
342 | } |
343 | |
344 | Doc TIRTextPrinter::VisitExpr_(const MinNode* op) { |
345 | Doc doc; |
346 | doc << "min(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
347 | return doc; |
348 | } |
349 | |
350 | Doc TIRTextPrinter::VisitExpr_(const MaxNode* op) { |
351 | Doc doc; |
352 | doc << "max(" << Print(op->a) << ", " << Print(op->b) << ")" ; |
353 | return doc; |
354 | } |
355 | |
356 | Doc TIRTextPrinter::VisitExpr_(const NotNode* op) { |
357 | Doc doc; |
358 | doc << "!" << Print(op->a); |
359 | return doc; |
360 | } |
361 | |
362 | Doc TIRTextPrinter::VisitExpr_(const SelectNode* op) { |
363 | Doc doc; |
364 | doc << "select(" << Print(op->condition) << ", " << Print(op->true_value) << ", " |
365 | << Print(op->false_value) << ")" ; |
366 | return doc; |
367 | } |
368 | |
369 | Doc TIRTextPrinter::VisitExpr_(const BufferLoadNode* op) { |
370 | Doc doc; |
371 | doc << Print(op->buffer) << Print(op->indices); |
372 | return doc; |
373 | } |
374 | |
375 | Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode* op) { |
376 | // TODO(tvm-team): consider make a better text format for producer. |
377 | Doc doc; |
378 | doc << op->producer->GetNameHint() << Print(op->indices); |
379 | return doc; |
380 | } |
381 | |
382 | Doc TIRTextPrinter::VisitExpr_(const LoadNode* op) { |
383 | Doc doc; |
384 | doc << "(" << PrintDType(op->dtype) << "*)" << Print(op->buffer_var) << "[" << Print(op->index) |
385 | << "]" ; |
386 | if (!is_one(op->predicate)) { |
387 | doc << " if " << Print(op->predicate); |
388 | } |
389 | return doc; |
390 | } |
391 | |
392 | Doc TIRTextPrinter::VisitExpr_(const RampNode* op) { |
393 | Doc doc; |
394 | doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " << op->lanes << ")" ; |
395 | return doc; |
396 | } |
397 | |
398 | Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) { |
399 | Doc doc; |
400 | doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")" ; |
401 | return doc; |
402 | } |
403 | |
404 | Doc TIRTextPrinter::VisitExpr_(const tir::LetNode* op) { |
405 | Doc doc; |
406 | doc << "let " << Print(op->var) << " = " << Print(op->value) << " in " << Print(op->body); |
407 | return doc; |
408 | } |
409 | |
410 | Doc TIRTextPrinter::VisitExpr_(const tir::CallNode* op) { |
411 | Doc doc; |
412 | std::vector<Doc> func_args; |
413 | if (auto* ptr_op = op->op.as<OpNode>()) { |
414 | doc << "@" << Doc::Text(ptr_op->name) << "(" ; |
415 | if (ptr_op->name == "tir.call_llvm_pure_intrin" ) { |
416 | auto f = tvm::runtime::Registry::Get("target.llvm_get_intrinsic_name" ); |
417 | ICHECK(f != nullptr) |
418 | << "Cannot find target.llvm_get_intrinsic_name. Compile with USE_LLVM=On" ; |
419 | func_args.push_back(Print((*f)(Downcast<IntImm>(op->args[0])->value))); |
420 | for (size_t i = 1; i < op->args.size(); i++) { |
421 | func_args.push_back(Print(op->args[i])); |
422 | } |
423 | } else { |
424 | for (const auto& arg : op->args) { |
425 | func_args.push_back(Print(arg)); |
426 | } |
427 | } |
428 | } else { |
429 | // TODO(bohan): Print out the name by he global var in the module. |
430 | auto* op_gvar = op->op.as<GlobalVarNode>(); |
431 | ICHECK(op_gvar != nullptr); |
432 | doc << "@" << Doc::Text(op_gvar->name_hint) << "(" ; |
433 | for (const auto& arg : op->args) { |
434 | func_args.push_back(Print(arg)); |
435 | } |
436 | } |
437 | doc << PrintSep(func_args, Doc::Text(", " )) << ", dtype=" << PrintDType(op->dtype) << ")" ; |
438 | return doc; |
439 | } |
440 | |
441 | Doc TIRTextPrinter::VisitExpr_(const ShuffleNode* op) { |
442 | Doc doc; |
443 | doc << "shuffle(" << Print(op->vectors) << ", " << Print(op->indices) << ")" ; |
444 | return doc; |
445 | } |
446 | |
447 | Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { |
448 | Doc doc; |
449 | doc << "reduce(" << Print(op->combiner) << ", " << Print(op->source) << ", " << Print(op->axis) |
450 | << ", " << op->value_index << ", " << Print(op->init) << ")" ; |
451 | return doc; |
452 | } |
453 | |
454 | Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { |
455 | Doc doc; |
456 | doc << "let " << Print(op->var) << " = " << Print(op->value) << NewLine() << Print(op->body); |
457 | return doc; |
458 | } |
459 | |
460 | Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { |
461 | Doc doc; |
462 | meta_collector_.Collect(op->node); |
463 | doc << "attr [" << Print(op->node) << "] " << Doc::StrLiteral(op->attr_key) << " = " |
464 | << Print(op->value); |
465 | if (op->body->IsInstance<SeqStmtNode>()) { |
466 | doc << PrintBody(op->body); |
467 | } else { |
468 | doc << ";" << NewLine() << Print(op->body); |
469 | } |
470 | return doc; |
471 | } |
472 | |
473 | Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { |
474 | Doc doc; |
475 | doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << NewLine() |
476 | << Print(op->body); |
477 | return doc; |
478 | } |
479 | |
480 | Doc TIRTextPrinter::VisitStmt_(const StoreNode* op) { |
481 | Doc doc; |
482 | doc << Print(op->buffer_var) << "[" << Print(op->index) << "] = " << Print(op->value); |
483 | if (!is_one(op->predicate)) { |
484 | doc << " if " << Print(op->predicate); |
485 | } |
486 | return doc; |
487 | } |
488 | |
489 | Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { |
490 | Doc doc; |
491 | doc << Print(op->buffer) << Print(op->indices) << " = " << Print(op->value); |
492 | return doc; |
493 | } |
494 | |
495 | Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) { |
496 | Doc doc; |
497 | doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value); |
498 | return doc; |
499 | } |
500 | |
501 | Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { |
502 | Doc doc; |
503 | doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " |
504 | << Print(op->condition) << PrintBody(op->body) << ")" ; |
505 | return doc; |
506 | } |
507 | |
508 | Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) { |
509 | Doc doc; |
510 | doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", " |
511 | << Print(op->condition) << ", " << PrintBody(op->body) << ")" ; |
512 | return doc; |
513 | } |
514 | |
515 | Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { |
516 | Doc doc; |
517 | auto scope = GetPtrStorageScope(op->buffer_var); |
518 | doc << "allocate(" << Print(op->buffer_var) << ", " ; |
519 | doc << PrintDType(op->dtype) << ", " ; |
520 | doc << Print(op->extents) << "), storage_scope = " << scope; |
521 | if (!op->annotations.empty()) { |
522 | std::vector<Doc> attr_docs; |
523 | for (const auto& it : op->annotations) { |
524 | attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); |
525 | } |
526 | doc << ", annotations = {" << PrintSep(attr_docs, Doc::Text(", " )) << "})" ; |
527 | } |
528 | if (!is_one(op->condition)) { |
529 | doc << " if " << Print(op->condition); |
530 | } |
531 | if (op->body->IsInstance<SeqStmtNode>()) { |
532 | doc << PrintBody(op->body); |
533 | } else { |
534 | doc << ";" << NewLine() << Print(op->body); |
535 | } |
536 | return doc; |
537 | } |
538 | |
539 | Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { |
540 | Doc doc; |
541 | doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", " |
542 | << Print(op->extents) << ")" ; |
543 | |
544 | if (op->body->IsInstance<SeqStmtNode>()) { |
545 | doc << PrintBody(op->body); |
546 | } else { |
547 | doc << ";" << NewLine() << Print(op->body); |
548 | } |
549 | return doc; |
550 | } |
551 | |
552 | Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { |
553 | Doc doc; |
554 | doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", " |
555 | << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << NewLine(); |
556 | if (op->body->IsInstance<SeqStmtNode>()) { |
557 | doc << PrintBody(op->body); |
558 | } else { |
559 | doc << ";" << NewLine() << Print(op->body); |
560 | } |
561 | return doc; |
562 | } |
563 | |
564 | Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { |
565 | Doc doc; |
566 | doc << "if " << Print(op->condition) << PrintBody(op->then_case); |
567 | if (!is_one(op->condition) && op->else_case) { |
568 | doc << " else" << PrintBody(op->else_case.value()); |
569 | } |
570 | return doc; |
571 | } |
572 | |
573 | Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) { |
574 | std::vector<Doc> stmts; |
575 | Doc seq_doc, doc; |
576 | for (Stmt stmt : op->seq) { |
577 | seq_doc << NewLine() << Print(stmt); |
578 | } |
579 | doc << " {" << Doc::Indent(2, seq_doc) << NewLine() << "}" ; |
580 | return doc; |
581 | } |
582 | |
583 | Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { |
584 | Doc doc; |
585 | doc << Print(op->value); |
586 | return doc; |
587 | } |
588 | |
589 | Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { |
590 | Doc doc; |
591 | doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " |
592 | << Print(op->min + op->extent) << ")" ; |
593 | if (op->kind != ForKind::kSerial) { |
594 | doc << " " << Doc::StrLiteral(ForKind2String(op->kind)); |
595 | } |
596 | doc << PrintBody(op->body); |
597 | return doc; |
598 | } |
599 | |
600 | Doc TIRTextPrinter::VisitStmt_(const WhileNode* op) { |
601 | Doc doc; |
602 | doc << "while (" << Print(op->condition) << ")" ; |
603 | doc << PrintBody(op->body); |
604 | return doc; |
605 | } |
606 | |
607 | Doc TIRTextPrinter::VisitStmt_(const PrefetchNode* op) { |
608 | Doc doc; |
609 | doc << "prefetch(" << Print(op->buffer) << ", " << Print(op->bounds) << ")" ; |
610 | return doc; |
611 | } |
612 | |
613 | Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { |
614 | const auto* block_op = op->block.as<BlockNode>(); |
615 | // print block name and block vars |
616 | Doc doc; |
617 | doc << "block([" ; |
618 | std::vector<Doc> block_var_docs; |
619 | for (const auto& iter_var : block_op->iter_vars) { |
620 | Doc block_var_doc; |
621 | if (is_zero(iter_var->dom->min) && iter_var->iter_type == kDataPar) { |
622 | block_var_doc << Print(iter_var->dom->extent); |
623 | } else { |
624 | block_var_doc << "tir." ; |
625 | switch (iter_var->iter_type) { |
626 | case kDataPar: |
627 | block_var_doc << "range" ; |
628 | break; |
629 | case kCommReduce: |
630 | block_var_doc << "reduce_axis" ; |
631 | break; |
632 | case kOrdered: |
633 | block_var_doc << "scan_axis" ; |
634 | break; |
635 | case kOpaque: |
636 | block_var_doc << "opaque_axis" ; |
637 | break; |
638 | default: |
639 | LOG(FATAL) << "Unknown block var iter type" ; |
640 | break; |
641 | } |
642 | block_var_doc << "(" << Print(iter_var->dom->min) << ", " |
643 | << Print(iter_var->dom->min + iter_var->dom->extent) << ")" ; |
644 | } |
645 | block_var_docs.push_back(block_var_doc); |
646 | } |
647 | doc << PrintSep(block_var_docs, Doc::Text(", " )) << "], " ; |
648 | doc << Doc::StrLiteral(block_op->name_hint) << ")" ; |
649 | std::vector<Doc> block_var_names; |
650 | for (const auto& iter_var : block_op->iter_vars) { |
651 | Doc block_var_name; |
652 | AllocVar(iter_var->var); |
653 | block_var_names.push_back(Print(iter_var->var)); |
654 | } |
655 | if (!block_var_names.empty()) { |
656 | doc << " as [" << PrintSep(block_var_names, Doc::Text(", " )) << "]" ; |
657 | } |
658 | doc << " {" ; |
659 | Doc block_attr_doc; |
660 | // print predicate, binding, read/write tensor region, annotations |
661 | if (!is_one(op->predicate)) { |
662 | block_attr_doc << NewLine() << "where(" << Print(op->predicate) << ")" ; |
663 | } |
664 | for (size_t i = 0; i < block_op->iter_vars.size(); ++i) |
665 | block_attr_doc << NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", " |
666 | << Print(op->iter_values[i]) << ")" ; |
667 | block_attr_doc << NewLine() << "tir.reads(" << Print(block_op->reads) << ")" ; |
668 | block_attr_doc << NewLine() << "tir.writes(" << Print(block_op->writes) << ")" ; |
669 | if (!block_op->annotations.empty()) { |
670 | std::vector<Doc> attr_docs; |
671 | for (const auto& it : block_op->annotations) { |
672 | attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); |
673 | } |
674 | block_attr_doc << NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", " )) << "})" ; |
675 | } |
676 | // print body |
677 | Doc body; |
678 | body << NewLine(); |
679 | for (const auto& alloc_buf : block_op->alloc_buffers) { |
680 | body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype) |
681 | << Print(alloc_buf->shape) << ")" << NewLine(); |
682 | } |
683 | for (const auto& match_buf : block_op->match_buffers) { |
684 | body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" |
685 | << NewLine(); |
686 | } |
687 | if (block_op->init.defined()) { |
688 | Doc init_block; |
689 | init_block << "with init()" ; |
690 | init_block << PrintBody(block_op->init.value()); |
691 | body << init_block << NewLine(); |
692 | } |
693 | body << Print(block_op->body); |
694 | doc << Doc::Indent(2, block_attr_doc << body); |
695 | return doc; |
696 | } |
697 | |
698 | Doc TIRTextPrinter::VisitType_(const PrimTypeNode* node) { |
699 | Doc doc; |
700 | doc << PrintDType(node->dtype); |
701 | return doc; |
702 | } |
703 | |
704 | Doc TIRTextPrinter::VisitType_(const PointerTypeNode* node) { |
705 | Doc doc; |
706 | doc << "Pointer(" ; |
707 | if (!node->storage_scope.empty()) { |
708 | doc << node->storage_scope << " " ; |
709 | } |
710 | doc << Print(node->element_type) << ")" ; |
711 | return doc; |
712 | } |
713 | |
714 | Doc TIRTextPrinter::VisitType_(const TupleTypeNode* node) { |
715 | std::vector<Doc> fields; |
716 | for (Type field : node->fields) { |
717 | fields.push_back(Print(field)); |
718 | } |
719 | Doc doc; |
720 | doc << "(" << Doc::Concat(fields); |
721 | // conform to python tuple format (1,) |
722 | if (node->fields.size() == 1) { |
723 | doc << "," ; |
724 | } |
725 | return doc << ")" ; |
726 | } |
727 | |
728 | Doc TIRTextPrinter::PrintDType(DataType dtype) { |
729 | return Doc::Text(runtime::DLDataType2String(dtype)); |
730 | } |
731 | |
732 | template <typename T> |
733 | Doc TIRTextPrinter::PrintConstScalar(DataType dtype, const T& data) { |
734 | Doc doc; |
735 | std::ostringstream os; |
736 | os << data; |
737 | if (dtype == DataType::Int(32)) { |
738 | doc << Doc::Text(os.str()); |
739 | } else { |
740 | if (dtype.bits() == 1 && dtype.lanes() == 1 && dtype.code() == kDLUInt) { |
741 | doc << ((data == 1) ? "True" : "False" ); |
742 | return doc; |
743 | } |
744 | doc << Doc::Text(os.str()); |
745 | switch (dtype.code()) { |
746 | case kDLInt: |
747 | doc << "i" ; |
748 | break; |
749 | case kDLUInt: |
750 | doc << "u" ; |
751 | break; |
752 | case kDLFloat: |
753 | doc << "f" ; |
754 | break; |
755 | } |
756 | doc << Doc::Text(std::to_string(dtype.bits())); |
757 | if (dtype.lanes() != 1) doc << "x" << Doc::Text(std::to_string(dtype.lanes())); |
758 | } |
759 | return doc; |
760 | } |
761 | |
762 | Doc TIRTextPrinter::GetUniqueName(std::string prefix) { |
763 | // std::replace(prefix.begin(), prefix.end(), '.', '_'); |
764 | std::string unique_prefix = prefix; |
765 | auto it = name_alloc_map_.find(prefix); |
766 | if (it != name_alloc_map_.end()) { |
767 | while (name_alloc_map_.count(unique_prefix = prefix + "_" + std::to_string(++it->second)) > 0) { |
768 | } |
769 | } |
770 | name_alloc_map_[unique_prefix] = 0; |
771 | return Doc::Text(unique_prefix); |
772 | } |
773 | |
774 | Doc TIRTextPrinter::AllocVar(const tir::Var& var) { |
775 | const auto& it = memo_var_.find(var); |
776 | if (it != memo_var_.end()) { |
777 | return it->second; |
778 | } |
779 | std::string name = var->name_hint.operator std::string(); |
780 | if (name.length() == 0 || !std::isalpha(name[0])) { |
781 | name = "v" + name; |
782 | } |
783 | Doc val = GetUniqueName(name); |
784 | memo_var_[var] = val; |
785 | return val << ": " << Print(GetType(var)); |
786 | } |
787 | |
788 | Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { |
789 | const auto& it = memo_buf_.find(buffer); |
790 | if (it != memo_buf_.end()) { |
791 | return it->second; |
792 | } |
793 | std::string name = buffer->name; |
794 | if (name.length() == 0 || !std::isalpha(name[0])) { |
795 | name = "buf_" + name; |
796 | } |
797 | Doc val = GetUniqueName(name); |
798 | memo_buf_[buffer] = val; |
799 | return val; |
800 | } |
801 | |
802 | Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) { |
803 | const auto& it = memo_producer_.find(producer); |
804 | if (it != memo_producer_.end()) { |
805 | return it->second; |
806 | } |
807 | std::string name = producer->GetNameHint(); |
808 | if (name.length() == 0 || !std::isalpha(name[0])) { |
809 | name = "tensor_" + name; |
810 | } |
811 | Doc val = GetUniqueName(name); |
812 | memo_producer_[producer] = val; |
813 | return val; |
814 | } |
815 | |
816 | Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) { |
817 | Doc seq; |
818 | if (vec.size() != 0) { |
819 | seq = vec[0]; |
820 | for (size_t i = 1; i < vec.size(); i++) { |
821 | seq << sep << vec[i]; |
822 | } |
823 | } |
824 | return seq; |
825 | } |
826 | |
827 | Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { |
828 | Doc doc; |
829 | if (body->IsInstance<SeqStmtNode>()) return Print(body); |
830 | doc << " {" << Doc::Indent(2, NewLine() << Print(body)) << NewLine() << "}" ; |
831 | return doc; |
832 | } |
833 | |
834 | bool TIRTextPrinter::GetVarName(tir::Var v, std::string* s) { |
835 | auto it = memo_var_.find(v); |
836 | if (it == memo_var_.end()) { |
837 | return false; |
838 | } |
839 | |
840 | *s = it->second.str(); |
841 | return true; |
842 | } |
843 | |
844 | } // namespace relay |
845 | } // namespace tvm |
846 | |