1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/tir/expr.h>
20#include <tvm/tir/op.h>
21#include <tvm/tir/stmt.h>
22#include <tvm/tir/stmt_functor.h>
23
24#include <sstream>
25
26#include "../../support/str_escape.h"
27
28namespace tvm {
29
30#define TVM_LEGACY_REPR_PRINTER_DEF_OP(Type) \
31 ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, Type value) { \
32 p.Stream() << value; \
33 return p; \
34 }
35
36TVM_LEGACY_REPR_PRINTER_DEF_OP(int);
37TVM_LEGACY_REPR_PRINTER_DEF_OP(int64_t);
38TVM_LEGACY_REPR_PRINTER_DEF_OP(float);
39TVM_LEGACY_REPR_PRINTER_DEF_OP(double);
40TVM_LEGACY_REPR_PRINTER_DEF_OP(char);
41TVM_LEGACY_REPR_PRINTER_DEF_OP(const char*);
42TVM_LEGACY_REPR_PRINTER_DEF_OP(const std::string&);
43TVM_LEGACY_REPR_PRINTER_DEF_OP(runtime::DataType);
44TVM_LEGACY_REPR_PRINTER_DEF_OP(const void*);
45TVM_LEGACY_REPR_PRINTER_DEF_OP(const String&);
46
47std::ostream& ReprLegacyPrinter::Stream() const { return stream; }
48
49ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, const ObjectRef& value) {
50 p.Stream() << AsLegacyRepr(value);
51 return p;
52}
53
54ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // NOLINT(*)
55 using tvm::tir::ForKind;
56 switch (type) {
57 case ForKind::kSerial:
58 out << "for";
59 break;
60 case ForKind::kParallel:
61 out << "parallel";
62 break;
63 case ForKind::kUnrolled:
64 out << "unrolled";
65 break;
66 case ForKind::kVectorized:
67 out << "vectorized";
68 break;
69 case ForKind::kThreadBinding:
70 out << "launch_thread";
71 break;
72 }
73 return out;
74}
75
76TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
77 .set_dispatch<ArrayNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
78 auto* op = static_cast<const ArrayNode*>(node.get());
79 (*p) << '[';
80 for (size_t i = 0; i < op->size(); ++i) {
81 if (i != 0) {
82 (*p) << ", ";
83 }
84 p->Print(op->at(i));
85 }
86 (*p) << ']';
87 });
88
89TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
90 .set_dispatch<MapNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
91 auto* op = static_cast<const MapNode*>(node.get());
92 (*p) << '{';
93 for (auto it = op->begin(); it != op->end(); ++it) {
94 if (it != op->begin()) {
95 (*p) << ", ";
96 }
97 if (it->first->IsInstance<StringObj>()) {
98 (*p) << '\"' << Downcast<String>(it->first) << "\": ";
99 } else {
100 p->Print(it->first);
101 (*p) << ": ";
102 }
103 p->Print(it->second);
104 }
105 (*p) << '}';
106 });
107
108TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
109 .set_dispatch<ShapeTupleObj>([](const ObjectRef& node, ReprLegacyPrinter* p) {
110 auto* op = static_cast<const ShapeTupleObj*>(node.get());
111 (*p) << '[';
112 for (size_t i = 0; i < op->size; ++i) {
113 if (i != 0) {
114 (*p) << ", ";
115 }
116 (*p) << op->data[i];
117 }
118 (*p) << ']';
119 });
120
121TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
122 .set_dispatch<IntImmNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
123 auto* op = static_cast<const IntImmNode*>(node.get());
124 if (op->dtype == DataType::Int(32)) {
125 (*p) << op->value;
126 } else {
127 (*p) << "(" << op->dtype << ")" << op->value;
128 }
129 });
130
131TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
132 .set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
133 auto* op = static_cast<const FloatImmNode*>(node.get());
134 switch (op->dtype.bits()) {
135 case 64:
136 (*p) << op->value;
137 break;
138 case 32:
139 (*p) << op->value << 'f';
140 break;
141 case 16:
142 (*p) << op->value << 'h';
143 break;
144 default:
145 LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
146 }
147 });
148
149TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
150 .set_dispatch<RangeNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
151 auto* op = static_cast<const RangeNode*>(node.get());
152 (*p) << "range(min=" << op->min << ", ext=" << op->extent << ')';
153 });
154
155TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
156 .set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
157 auto* node = static_cast<const PrimTypeNode*>(ref.get());
158 (*p) << node->dtype;
159 });
160
161TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
162 .set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
163 auto* node = static_cast<const PointerTypeNode*>(ref.get());
164 if (!node->storage_scope.empty()) {
165 (*p) << node->storage_scope << " ";
166 }
167 p->Print(node->element_type);
168 (*p) << '*';
169 });
170
171TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
172 .set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
173 auto* node = static_cast<const TupleTypeNode*>(ref.get());
174 (*p) << "TupleTypeNode(" << node->fields << ")";
175 });
176
177TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
178 .set_dispatch<IncompleteTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
179 auto* node = static_cast<const IncompleteTypeNode*>(ref.get());
180 (*p) << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
181 });
182
183TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
184 .set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
185 auto* op = static_cast<const DictAttrsNode*>(node.get());
186 (*p) << op->dict;
187 });
188
189TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
190 .set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
191 auto* node = static_cast<const GlobalVarNode*>(ref.get());
192 (*p) << "GlobalVar(" << node->name_hint << ")";
193 });
194
195TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
196 .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
197 auto* node = static_cast<const IRModuleNode*>(ref.get());
198 (*p) << "IRModule(" << node->functions << ")";
199 });
200
201TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
202 .set_dispatch<TypeVarNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
203 auto* node = static_cast<const TypeVarNode*>(ref.get());
204 (*p) << "TypeVar(" << node->name_hint << ", " << node->kind << ")";
205 });
206
207TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
208 .set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
209 auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
210 (*p) << "GlobalTypeVar(" << node->name_hint << ", " << node->kind << ")";
211 });
212
213TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
214 .set_dispatch<FuncTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
215 auto* node = static_cast<const FuncTypeNode*>(ref.get());
216 (*p) << "FuncType(" << node->type_params << ", " << node->arg_types << ", " << node->ret_type
217 << ", " << node->type_constraints << ")";
218 });
219
220TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
221 .set_dispatch<RelayRefTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
222 auto* node = static_cast<const RelayRefTypeNode*>(ref.get());
223 (*p) << "RelayRefTypeNode(" << node->value << ")";
224 });
225
226} // namespace tvm
227
228namespace tvm {
229namespace tir {
230
231TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
232 .set_dispatch<BufferNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
233 auto* op = static_cast<const BufferNode*>(node.get());
234 (*p) << "buffer(" << op->name << ", " << op << ")";
235 });
236
237TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
238 .set_dispatch<VarNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
239 auto* op = static_cast<const VarNode*>(node.get());
240 // omit the type
241 // stream << op->name << "." << op->type;
242 (*p) << op->name_hint;
243 });
244
245TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
246 .set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
247 auto* op = static_cast<const SizeVarNode*>(node.get());
248 (*p) << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
249 });
250
251TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
252 .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
253 auto* op = static_cast<const IterVarNode*>(node.get());
254 (*p) << "iter_var(";
255 if (op->var->name_hint.length() != 0) {
256 (*p) << op->var->name_hint << ", ";
257 }
258 if (op->dom.defined()) {
259 (*p) << op->dom;
260 }
261 if (op->thread_tag.length() != 0) {
262 (*p) << ", " << op->thread_tag;
263 }
264 (*p) << ")";
265 });
266
267TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
268 .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
269 auto* op = static_cast<const StringImmNode*>(node.get());
270 (*p) << '\"' << support::StrEscape(op->value) << '\"';
271 });
272
273TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
274 .set_dispatch<CastNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
275 auto* op = static_cast<const CastNode*>(node.get());
276 (*p) << op->dtype << '(';
277 p->Print(op->value);
278 (*p) << ')';
279 });
280
281TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
282 .set_dispatch<AddNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
283 auto* op = static_cast<const AddNode*>(node.get());
284 (*p) << '(';
285 p->Print(op->a);
286 (*p) << " + ";
287 p->Print(op->b);
288 (*p) << ')';
289 });
290
291TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
292 .set_dispatch<SubNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
293 auto* op = static_cast<const SubNode*>(node.get());
294 (*p) << '(';
295 p->Print(op->a);
296 (*p) << " - ";
297 p->Print(op->b);
298 (*p) << ')';
299 });
300
301TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
302 .set_dispatch<MulNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
303 auto* op = static_cast<const MulNode*>(node.get());
304 (*p) << '(';
305 p->Print(op->a);
306 (*p) << "*";
307 p->Print(op->b);
308 (*p) << ')';
309 });
310
311TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
312 .set_dispatch<DivNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
313 auto* op = static_cast<const DivNode*>(node.get());
314 (*p) << '(';
315 p->Print(op->a);
316 (*p) << "/";
317 p->Print(op->b);
318 (*p) << ')';
319 });
320
321TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
322 .set_dispatch<ModNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
323 auto* op = static_cast<const ModNode*>(node.get());
324 (*p) << '(';
325 p->Print(op->a);
326 (*p) << " % ";
327 p->Print(op->b);
328 (*p) << ')';
329 });
330
331TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
332 .set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
333 auto* op = static_cast<const FloorDivNode*>(node.get());
334 (*p) << "floordiv(" << op->a << ", " << op->b << ")";
335 });
336
337TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
338 .set_dispatch<FloorModNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
339 auto* op = static_cast<const FloorModNode*>(node.get());
340 (*p) << "floormod(" << op->a << ", " << op->b << ")";
341 });
342
343TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
344 .set_dispatch<MinNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
345 auto* op = static_cast<const MinNode*>(node.get());
346 (*p) << "min(";
347 p->Print(op->a);
348 (*p) << ", ";
349 p->Print(op->b);
350 (*p) << ")";
351 });
352
353TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
354 .set_dispatch<MaxNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
355 auto* op = static_cast<const MaxNode*>(node.get());
356 (*p) << "max(";
357 p->Print(op->a);
358 (*p) << ", ";
359 p->Print(op->b);
360 (*p) << ")";
361 });
362
363TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
364 .set_dispatch<EQNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
365 auto* op = static_cast<const EQNode*>(node.get());
366 (*p) << '(';
367 p->Print(op->a);
368 (*p) << " == ";
369 p->Print(op->b);
370 (*p) << ')';
371 });
372
373TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
374 .set_dispatch<NENode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
375 auto* op = static_cast<const NENode*>(node.get());
376 (*p) << '(';
377 p->Print(op->a);
378 (*p) << " != ";
379 p->Print(op->b);
380 (*p) << ')';
381 });
382
383TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
384 .set_dispatch<LTNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
385 auto* op = static_cast<const LTNode*>(node.get());
386 (*p) << '(';
387 p->Print(op->a);
388 (*p) << " < ";
389 p->Print(op->b);
390 (*p) << ')';
391 });
392
393TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
394 .set_dispatch<LENode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
395 auto* op = static_cast<const LENode*>(node.get());
396 (*p) << '(';
397 p->Print(op->a);
398 (*p) << " <= ";
399 p->Print(op->b);
400 (*p) << ')';
401 });
402
403TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
404 .set_dispatch<GTNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
405 auto* op = static_cast<const GTNode*>(node.get());
406 (*p) << '(';
407 p->Print(op->a);
408 (*p) << " > ";
409 p->Print(op->b);
410 (*p) << ')';
411 });
412
413TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
414 .set_dispatch<GENode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
415 auto* op = static_cast<const GENode*>(node.get());
416 (*p) << '(';
417 p->Print(op->a);
418 (*p) << " >= ";
419 p->Print(op->b);
420 (*p) << ')';
421 });
422
423TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
424 .set_dispatch<AndNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
425 auto* op = static_cast<const AndNode*>(node.get());
426 (*p) << '(';
427 p->Print(op->a);
428 (*p) << " && ";
429 p->Print(op->b);
430 (*p) << ')';
431 });
432
433TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
434 .set_dispatch<OrNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
435 auto* op = static_cast<const OrNode*>(node.get());
436 (*p) << '(';
437 p->Print(op->a);
438 (*p) << " || ";
439 p->Print(op->b);
440 (*p) << ')';
441 });
442
443TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
444 .set_dispatch<NotNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
445 auto* op = static_cast<const NotNode*>(node.get());
446 (*p) << '!';
447 p->Print(op->a);
448 });
449
450TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
451 .set_dispatch<SelectNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
452 auto* op = static_cast<const SelectNode*>(node.get());
453 (*p) << "select(";
454 p->Print(op->condition);
455 (*p) << ", ";
456 p->Print(op->true_value);
457 (*p) << ", ";
458 p->Print(op->false_value);
459 (*p) << ")";
460 });
461
462TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
463 .set_dispatch<LoadNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
464 auto* op = static_cast<const LoadNode*>(node.get());
465 (*p) << op->buffer_var << "[";
466 p->Print(op->index);
467 (*p) << "]";
468 if (!is_one(op->predicate)) {
469 (*p) << " if ";
470 p->Print(op->predicate);
471 }
472 });
473
474TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
475 .set_dispatch<RampNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
476 auto* op = static_cast<const RampNode*>(node.get());
477 (*p) << "ramp(";
478 p->Print(op->base);
479 (*p) << ", ";
480 p->Print(op->stride);
481 (*p) << ", " << op->lanes << ")";
482 });
483
484TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
485 .set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
486 auto* op = static_cast<const BroadcastNode*>(node.get());
487 (*p) << "x" << op->lanes << "(";
488 p->Print(op->value);
489 (*p) << ")";
490 });
491
492TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
493 .set_dispatch<LetNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
494 auto* op = static_cast<const LetNode*>(node.get());
495 (*p) << "(let " << op->var << " = ";
496 p->Print(op->value);
497 (*p) << " in ";
498 p->Print(op->body);
499 (*p) << ")";
500 });
501
502TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
503 .set_dispatch<CallNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
504 auto* op = static_cast<const CallNode*>(node.get());
505 if (auto* ptr_op = op->op.as<OpNode>()) {
506 (*p) << ptr_op->name << "(";
507 } else {
508 auto* ptr_gvar = op->op.as<GlobalVarNode>();
509 ICHECK(ptr_gvar != nullptr);
510 (*p) << "@" << ptr_gvar->name_hint << "(";
511 }
512 for (size_t i = 0; i < op->args.size(); ++i) {
513 p->Print(op->args[i]);
514 if (i < op->args.size() - 1) {
515 (*p) << ", ";
516 }
517 }
518 (*p) << ")";
519 });
520
521template <typename T>
522void PrintList(const Array<T>& exprs, ReprLegacyPrinter* p) {
523 for (size_t i = 0; i < exprs.size(); ++i) {
524 p->Print(exprs[i]);
525 if (i < exprs.size() - 1) {
526 (*p) << ", ";
527 }
528 }
529}
530
531TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
532 .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
533 auto* op = static_cast<const ShuffleNode*>(node.get());
534 (*p) << "shuffle(";
535 PrintList(op->vectors, p);
536 (*p) << ", ";
537 PrintList(op->indices, p);
538 (*p) << ")";
539 });
540
541TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
542 .set_dispatch<CommReducerNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
543 auto* op = static_cast<const CommReducerNode*>(node.get());
544 (*p) << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs << ", rhs=" << op->rhs
545 << ", identity_element=" << op->identity_element << ")";
546 });
547
548TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
549 .set_dispatch<ReduceNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
550 auto* op = static_cast<const ReduceNode*>(node.get());
551 (*p) << "reduce(combiner=" << op->combiner;
552 (*p) << ", source=" << op->source;
553 (*p) << ", init=" << op->init;
554 (*p) << ", axis=" << op->axis;
555 (*p) << ", where=" << op->condition;
556 (*p) << ", value_index=" << op->value_index;
557 (*p) << ")";
558 });
559
560TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
561 .set_dispatch<AnyNode>([](const ObjectRef& node, ReprLegacyPrinter* p) { (*p) << "?"; });
562
563TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
564 .set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
565 auto* op = static_cast<const BufferLoadNode*>(node.get());
566 (*p) << op->buffer->name << "[";
567 for (size_t i = 0; i < op->indices.size(); ++i) {
568 p->Print(op->indices[i]);
569 if (i < op->indices.size() - 1) {
570 (*p) << ", ";
571 }
572 }
573 (*p) << "]";
574 });
575
576TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
577 .set_dispatch<ProducerLoadNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
578 auto* op = static_cast<const ProducerLoadNode*>(node.get());
579 (*p) << op->producer->GetNameHint() << "[";
580 for (size_t i = 0; i < op->indices.size(); ++i) {
581 p->Print(op->indices[i]);
582 if (i < op->indices.size() - 1) {
583 (*p) << ", ";
584 }
585 }
586 (*p) << "]";
587 });
588
589TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
590 .set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
591 auto* node = static_cast<const PrimFuncNode*>(ref.get());
592 (*p) << "PrimFunc(" << node->params << ") ";
593 if (node->attrs.defined()) {
594 (*p) << "attrs=" << node->attrs;
595 }
596 (*p) << " {\n";
597 p->indent += 2;
598 p->Print(node->body);
599 p->indent -= 2;
600 (*p) << "}\n";
601 });
602
603TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
604 .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
605 auto* op = static_cast<const LetStmtNode*>(node.get());
606 p->PrintIndent();
607 (*p) << "let " << op->var << " = ";
608 p->Print(op->value);
609 (*p) << '\n';
610 p->Print(op->body);
611 });
612
613TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
614 .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
615 auto* op = static_cast<const AttrStmtNode*>(node.get());
616 p->PrintIndent();
617 (*p) << "// attr [";
618 p->Print(op->node);
619 (*p) << "] " << op->attr_key << " = ";
620 p->Print(op->value);
621 (*p) << '\n';
622 p->Print(op->body);
623 });
624
625TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
626 .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
627 auto* op = static_cast<const AssertStmtNode*>(node.get());
628 p->PrintIndent();
629 (*p) << "assert(";
630 p->Print(op->condition);
631 (*p) << ", ";
632 p->Print(op->message);
633 (*p) << ")\n";
634 p->Print(op->body);
635 });
636
637TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
638 .set_dispatch<ForNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
639 auto* op = static_cast<const ForNode*>(node.get());
640 p->PrintIndent();
641 (*p) << op->kind << " (" << op->loop_var << ", ";
642 p->Print(op->min);
643 (*p) << ", ";
644 p->Print(op->extent);
645 (*p) << ") {\n";
646
647 p->indent += 2;
648 p->Print(op->body);
649 p->indent -= 2;
650
651 p->PrintIndent();
652 (*p) << "}\n";
653 });
654
655TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
656 .set_dispatch<WhileNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
657 auto* op = static_cast<const WhileNode*>(node.get());
658 p->PrintIndent();
659 (*p) << "while(" << op->condition << ") {\n";
660 p->indent += 2;
661 p->Print(op->body);
662 p->indent -= 2;
663 p->PrintIndent();
664 (*p) << "}\n";
665 });
666
667TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
668 .set_dispatch<StoreNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
669 auto* op = static_cast<const StoreNode*>(node.get());
670 p->PrintIndent();
671 (*p) << op->buffer_var << "[";
672 p->Print(op->index);
673 (*p) << "] = ";
674 p->Print(op->value);
675 if (!is_one(op->predicate)) {
676 (*p) << " if ";
677 p->Print(op->predicate);
678 }
679 (*p) << '\n';
680 });
681
682TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
683 .set_dispatch<ProducerStoreNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
684 auto* op = static_cast<const ProducerStoreNode*>(node.get());
685 p->PrintIndent();
686 (*p) << op->producer->GetNameHint() << "[";
687 for (size_t i = 0; i < op->indices.size(); ++i) {
688 p->Print(op->indices[i]);
689 if (i < op->indices.size() - 1) (*p) << ", ";
690 }
691 (*p) << "]";
692 (*p) << " =";
693 p->Print(op->value);
694 (*p) << '\n';
695 });
696
697TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
698 .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
699 auto* op = static_cast<const AllocateNode*>(node.get());
700 const auto* ptr_type = op->buffer_var->type_annotation.as<PointerTypeNode>();
701 ICHECK(ptr_type) << "The provided variable is not of pointer type";
702 p->PrintIndent();
703 (*p) << "allocate " << op->buffer_var << "[" << op->dtype;
704 for (size_t i = 0; i < op->extents.size(); ++i) {
705 (*p) << " * ";
706 p->Print(op->extents[i]);
707 }
708 (*p) << "], storage_scope = " << ptr_type->storage_scope;
709 if (!is_one(op->condition)) {
710 (*p) << " if ";
711 p->Print(op->condition);
712 }
713 (*p) << "\n";
714 p->Print(op->body);
715 });
716
717TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
718 .set_dispatch<AllocateConstNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
719 auto* op = static_cast<const AllocateConstNode*>(node.get());
720 p->PrintIndent();
721 (*p) << "constant " << op->buffer_var << "[" << op->dtype;
722 for (size_t i = 0; i < op->extents.size(); ++i) {
723 (*p) << " * ";
724 p->Print(op->extents[i]);
725 }
726 (*p) << "]";
727 (*p) << "\n";
728 p->Print(op->body);
729 });
730
731TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
732 .set_dispatch<DeclBufferNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
733 auto* op = static_cast<const DeclBufferNode*>(node.get());
734 p->PrintIndent();
735 (*p) << "decl_buffer " << op->buffer << "\n";
736 (*p) << op->body;
737 });
738
739TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
740 .set_dispatch<ProducerRealizeNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
741 auto* op = static_cast<const ProducerRealizeNode*>(node.get());
742 p->PrintIndent();
743 (*p) << "producer_realize " << op->producer->GetNameHint() << "(";
744 for (size_t i = 0; i < op->bounds.size(); ++i) {
745 (*p) << "[";
746 p->Print(op->bounds[i]->min);
747 (*p) << ", ";
748 p->Print(op->bounds[i]->extent);
749 (*p) << "]";
750 if (i < op->bounds.size() - 1) (*p) << ", ";
751 }
752 (*p) << ")";
753 if (!is_one(op->condition)) {
754 (*p) << " if ";
755 p->Print(op->condition);
756 }
757 (*p) << " {\n";
758
759 p->indent += 2;
760 p->Print(op->body);
761 p->indent -= 2;
762
763 p->PrintIndent();
764 (*p) << "}\n";
765 });
766
767TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
768 .set_dispatch<PrefetchNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
769 auto* op = static_cast<const PrefetchNode*>(node.get());
770 p->PrintIndent();
771 (*p) << "prefetch " << op->buffer << "(";
772 for (size_t i = 0; i < op->bounds.size(); ++i) {
773 (*p) << "[";
774 p->Print(op->bounds[i]->min);
775 (*p) << ", ";
776 p->Print(op->bounds[i]->extent);
777 (*p) << "]";
778 if (i < op->bounds.size() - 1) (*p) << ", ";
779 }
780 (*p) << ")";
781 });
782
783TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
784 .set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
785 auto* op = static_cast<const SeqStmtNode*>(node.get());
786 for (Stmt stmt : op->seq) {
787 p->Print(stmt);
788 }
789 });
790
791TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
792 .set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
793 auto* op = static_cast<const IfThenElseNode*>(node.get());
794 p->PrintIndent();
795 while (true) {
796 (*p) << "if (" << op->condition << ") {\n";
797 p->indent += 2;
798 p->Print(op->then_case);
799 p->indent -= 2;
800
801 if (!op->else_case) {
802 break;
803 }
804
805 if (const IfThenElseNode* nested_if = op->else_case.as<IfThenElseNode>()) {
806 p->PrintIndent();
807 (*p) << "} else ";
808 op = nested_if;
809 } else {
810 p->PrintIndent();
811 (*p) << "} else {\n";
812 p->indent += 2;
813 p->Print(op->else_case);
814 p->indent -= 2;
815 break;
816 }
817 }
818 p->PrintIndent();
819 (*p) << "}\n";
820 });
821
822TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
823 .set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
824 auto* op = static_cast<const EvaluateNode*>(node.get());
825 p->PrintIndent();
826 p->Print(op->value);
827 (*p) << "\n";
828 });
829
830TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
831 .set_dispatch<BufferStoreNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
832 auto* op = static_cast<const BufferStoreNode*>(node.get());
833 p->PrintIndent();
834 (*p) << op->buffer->name << "[";
835 for (size_t i = 0; i < op->indices.size(); ++i) {
836 p->Print(op->indices[i]);
837 if (i < op->indices.size() - 1) (*p) << ", ";
838 }
839 (*p) << "]";
840 (*p) << " = ";
841 p->Print(op->value);
842 (*p) << '\n';
843 });
844
845TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
846 .set_dispatch<BufferRealizeNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
847 auto* op = static_cast<const BufferRealizeNode*>(node.get());
848 p->PrintIndent();
849 (*p) << "buffer_realize " << op->buffer->name << "(";
850 for (size_t i = 0; i < op->bounds.size(); ++i) {
851 (*p) << "[";
852 p->Print(op->bounds[i]->min);
853 (*p) << ", ";
854 p->Print(op->bounds[i]->extent);
855 (*p) << "]";
856 if (i < op->bounds.size() - 1) (*p) << ", ";
857 }
858 (*p) << ")";
859 if (!is_one(op->condition)) {
860 (*p) << " if ";
861 p->Print(op->condition);
862 }
863 (*p) << " {\n";
864
865 p->indent += 2;
866 p->Print(op->body);
867 p->indent -= 2;
868
869 p->PrintIndent();
870 (*p) << "}\n";
871 });
872
873TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
874 .set_dispatch<BufferRegionNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
875 auto* op = static_cast<const BufferRegionNode*>(node.get());
876 (*p) << op->buffer->name;
877 (*p) << "[";
878 for (size_t i = 0; i < op->region.size(); ++i) {
879 const auto& range = op->region[i];
880 p->Print(range->min);
881 if (!is_one(range->extent)) {
882 (*p) << ":";
883 p->Print(range->min + range->extent);
884 }
885 if (i != op->region.size() - 1) (*p) << ", ";
886 }
887 (*p) << "]";
888 });
889
890TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
891 .set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
892 auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
893 p->PrintIndent();
894 (*p) << op->buffer->name << " = match_buffer(";
895 p->Print(op->source);
896 (*p) << ")\n";
897 });
898
899void PrintBlockTitle(const BlockNode* op, ReprLegacyPrinter* p) {
900 (*p) << "block " << op->name_hint << "(";
901 for (size_t i = 0; i < op->iter_vars.size(); i++) {
902 p->Print(op->iter_vars[i]);
903 if (i < op->iter_vars.size() - 1) (*p) << ", ";
904 }
905 (*p) << ")";
906}
907
908void PrintBlockSignature(const BlockNode* op, ReprLegacyPrinter* p) {
909 // print read/write regions
910 p->PrintIndent();
911 (*p) << "reads(";
912 p->Print(op->reads);
913 (*p) << ")\n";
914 p->PrintIndent();
915 (*p) << "writes(";
916 p->Print(op->writes);
917 (*p) << ")\n";
918 // Print alloc_buffers
919 for (const auto& alloc_buf : op->alloc_buffers) {
920 p->PrintIndent();
921 (*p) << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "[";
922 for (size_t i = 0; i < alloc_buf->shape.size(); ++i) {
923 if (i > 0) (*p) << ", ";
924 p->Print(alloc_buf->shape[i]);
925 }
926 (*p) << "])\n";
927 }
928 // Print match_buffer_regions
929 for (const auto& match_buf : op->match_buffers) {
930 p->Print(match_buf);
931 }
932 if (!op->annotations.empty()) {
933 p->PrintIndent();
934 (*p) << "annotations(" << op->annotations << ")\n";
935 }
936}
937
938void PrintBlockBody(const BlockNode* op, ReprLegacyPrinter* p) {
939 // Print init
940 if (op->init.defined()) {
941 p->PrintIndent();
942 (*p) << "with init() {\n";
943 p->indent += 2;
944 p->Print(op->init.value());
945 p->indent -= 2;
946 p->PrintIndent();
947 (*p) << "}\n";
948 }
949 // Print body
950 p->Print(op->body);
951}
952
953TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
954 .set_dispatch<BlockNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
955 auto* op = static_cast<const BlockNode*>(node.get());
956 p->PrintIndent();
957 PrintBlockTitle(op, p);
958 (*p) << " {\n";
959 p->indent += 2;
960
961 // Print block elements (e.g. reads/writes, etc)
962 PrintBlockSignature(op, p);
963 // Print block init and body
964 PrintBlockBody(op, p);
965
966 p->indent -= 2;
967 p->PrintIndent();
968 (*p) << "}\n";
969 });
970
971TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
972 .set_dispatch<BlockRealizeNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
973 auto* op = static_cast<const BlockRealizeNode*>(node.get());
974 auto* block_op = op->block.get();
975 p->PrintIndent();
976 PrintBlockTitle(block_op, p);
977 (*p) << " {\n";
978 p->indent += 2;
979
980 // Print binding iter_values
981 for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
982 p->PrintIndent();
983 (*p) << "bind(";
984 p->Print(block_op->iter_vars[i]->var);
985 (*p) << ", ";
986 p->Print(op->iter_values[i]);
987 (*p) << ")\n";
988 }
989 // Print predicate
990 if (!is_one(op->predicate)) {
991 p->PrintIndent();
992 (*p) << "where(";
993 p->Print(op->predicate);
994 (*p) << ")\n";
995 }
996 // Print block elements (e.g. reads/writes, etc)
997 PrintBlockSignature(block_op, p);
998 // Print block init and body
999 PrintBlockBody(block_op, p);
1000
1001 p->indent -= 2;
1002 p->PrintIndent();
1003 (*p) << "}\n";
1004 });
1005
1006} // namespace tir
1007} // namespace tvm
1008