1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/tir/builtin.h>
20
21#include "./utils.h"
22
23namespace tvm {
24namespace script {
25namespace printer {
26
27Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) {
28 if (!d->IsVarDefined(var)) {
29 if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) {
30 ExprDoc lhs = DefineVar(var, opt_f.value(), d);
31 Type type = var->type_annotation;
32 if (const auto* ptr_type = type.as<PointerTypeNode>()) {
33 ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>());
34 ExprDoc rhs = d->AsDoc<ExprDoc>(type, var_p->Attr("type_annotation"));
35 opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
36 } else {
37 ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))});
38 opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt));
39 }
40 } else {
41 LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint;
42 }
43 }
44 if (Optional<ExprDoc> doc = d->GetVarDoc(var)) {
45 return doc.value();
46 }
47 LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var->name_hint;
48}
49
50TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
51 .set_dispatch<tir::Var>("", [](tir::Var var, ObjectPath p, IRDocsifier d) -> Doc {
52 return PrintVar(var, p, d);
53 });
54
55TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) //
56 .set_dispatch<tir::SizeVar>("", [](tir::SizeVar var, ObjectPath p, IRDocsifier d) -> Doc {
57 return PrintVar(var, p, d);
58 });
59
60TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
61 .set_dispatch<tir::IterVar>("", [](tir::IterVar var, ObjectPath var_p, IRDocsifier d) -> Doc {
62 return TIR(d, "iter_var")
63 ->Call({
64 d->AsDoc<ExprDoc>(var->var, var_p->Attr("var")),
65 d->AsDoc<ExprDoc>(var->dom, var_p->Attr("dom")),
66 LiteralDoc::Str(IterVarType2String(var->iter_type), var_p->Attr("iter_type")),
67 LiteralDoc::Str(var->thread_tag, var_p->Attr("thread_tag")),
68 });
69 });
70
71TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
72 .set_dispatch<tir::Not>("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc {
73 ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));
74 if (a->IsInstance<LiteralDocNode>()) {
75 return TIR(d, "Not")->Call({a});
76 }
77 return OperationDoc(OperationDocNode::Kind::kNot, {a});
78 });
79
80TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
81 .set_dispatch<tir::StringImm>("", [](tir::StringImm s, ObjectPath p, IRDocsifier d) -> Doc {
82 if (HasMultipleLines(s->value)) {
83 return d->AddMetadata(s);
84 } else {
85 return d->AsDoc<ExprDoc>(s->value, p->Attr("value"));
86 }
87 });
88
89TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
90 .set_dispatch<tir::Cast>("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc {
91 ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype"));
92 ExprDoc value = d->AsDoc<ExprDoc>(cast->value, p->Attr("value"));
93 return TIR(d, "Cast")->Call({dtype, value});
94 });
95
96TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
97 .set_dispatch<tir::Select>("", [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc {
98 return TIR(d, "Select")
99 ->Call({
100 d->AsDoc<ExprDoc>(select->condition, p->Attr("condition")),
101 d->AsDoc<ExprDoc>(select->true_value, p->Attr("true_value")),
102 d->AsDoc<ExprDoc>(select->false_value, p->Attr("false_value")),
103 });
104 });
105
106TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
107 .set_dispatch<tir::Ramp>("", [](tir::Ramp ramp, ObjectPath ramp_p, IRDocsifier d) -> Doc {
108 return TIR(d, "Ramp")->Call({
109 d->AsDoc<ExprDoc>(ramp->base, ramp_p->Attr("base")),
110 d->AsDoc<ExprDoc>(ramp->stride, ramp_p->Attr("stride")),
111 LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")),
112 });
113 });
114
115TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
116 .set_dispatch<tir::Broadcast>("", [](tir::Broadcast bc, ObjectPath bc_p, IRDocsifier d) -> Doc {
117 return TIR(d, "Broadcast")
118 ->Call({
119 d->AsDoc<ExprDoc>(bc->value, bc_p->Attr("value")),
120 LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")),
121 });
122 });
123
124TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
125 .set_dispatch<tir::Shuffle>( //
126 "", [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc {
127 return TIR(d, "Shuffle")
128 ->Call({
129 d->AsDoc<ExprDoc>(shuffle->vectors, p->Attr("vectors")),
130 d->AsDoc<ExprDoc>(shuffle->indices, p->Attr("indices")),
131 });
132 });
133
134TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
135 .set_dispatch<tir::CommReducer>( //
136 "", [](tir::CommReducer r, ObjectPath p, IRDocsifier d) -> Doc {
137 ICHECK_EQ(r->lhs.size(), r->rhs.size());
138 LambdaDoc lambda{nullptr};
139 {
140 With<TIRFrame> f(d, r);
141 int n_vars = r->lhs.size();
142 Array<IdDoc> vars;
143 vars.reserve(n_vars + n_vars);
144 for (int i = 0; i < n_vars; ++i) {
145 vars.push_back(Downcast<IdDoc>(DefineVar(r->lhs[i], *f, d)));
146 }
147 for (int i = 0; i < n_vars; ++i) {
148 vars.push_back(Downcast<IdDoc>(DefineVar(r->rhs[i], *f, d)));
149 }
150 int n_results = r->result.size();
151 Array<ExprDoc> results;
152 results.reserve(n_results);
153 for (int i = 0; i < n_results; ++i) {
154 results.push_back(d->AsDoc<ExprDoc>(r->result[i], p->Attr("result")->ArrayIndex(i)));
155 }
156 if (results.size() == 1) {
157 lambda = LambdaDoc(vars, results[0]);
158 } else {
159 lambda = LambdaDoc(vars, TupleDoc(results));
160 }
161 }
162 ExprDoc id = d->AsDoc<ExprDoc>(r->identity_element, p->Attr("identity_element"));
163 return TIR(d, "comm_reducer")->Call({lambda, id});
164 });
165
166LambdaDoc PrintIndexMap(const ObjectRef& map, const Array<tir::Var>& vs, const ObjectPath& vs_p,
167 const Array<PrimExpr>& es, const ObjectPath& es_p, const IRDocsifier& d) {
168 With<TIRFrame> f(d, map);
169 Array<IdDoc> vars;
170 for (int i = 0, l = vs.size(); i < l; ++i) {
171 vars.push_back(Downcast<IdDoc>(DefineVar(vs[i], *f, d)));
172 }
173 Array<ExprDoc> exprs;
174 for (int i = 0, l = es.size(); i < l; ++i) {
175 exprs.push_back(d->AsDoc<ExprDoc>(es[i], es_p->ArrayIndex(i)));
176 }
177 return LambdaDoc(vars, TupleDoc(exprs));
178}
179
180TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
181 .set_dispatch<tir::IndexMap>( //
182 "", [](tir::IndexMap m, ObjectPath m_p, IRDocsifier d) -> Doc {
183 LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices"),
184 m->final_indices, m_p->Attr("final_indices"), d);
185 if (m->inverse_index_map.defined()) {
186 tir::IndexMap inverse = Downcast<tir::IndexMap>(m->inverse_index_map);
187 LambdaDoc inv = PrintIndexMap(inverse, inverse->initial_indices,
188 m_p->Attr("inverse_index_map")->Attr("initial_indices"),
189 inverse->final_indices,
190 m_p->Attr("inverse_index_map")->Attr("final_indices"), d);
191 return TIR(d, "index_map")->Call({map}, {"inverse_index_map"}, {inv});
192 } else {
193 return TIR(d, "index_map")->Call({map});
194 }
195 });
196
197TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
198 .set_dispatch<tir::Let>("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc {
199 return TIR(d, "let")->Call({
200 d->AsDoc<ExprDoc>(let->var, p->Attr("var")),
201 d->AsDoc<ExprDoc>(let->value, p->Attr("value")),
202 d->AsDoc<ExprDoc>(let->body, p->Attr("body")),
203 });
204 });
205
206TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
207 .set_dispatch<tir::Call>("", [](tir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc {
208 static const OpAttrMap<tir::TScriptPrinterName>& op_names =
209 Op::GetAttrMap<tir::TScriptPrinterName>("TScriptPrinterName");
210 static const std::unordered_set<const Object*> dtype_first_arg = {
211 tir::builtin::reinterpret().get(),
212 tir::builtin::call_extern().get(),
213 tir::builtin::call_llvm_intrin().get(), //
214 tir::builtin::call_llvm_pure_intrin().get(), //
215 tir::builtin::call_pure_extern().get(), //
216 tir::builtin::ptx_mma().get(),
217 tir::builtin::ptx_mma_sp().get(),
218 tir::builtin::ptx_ldmatrix().get(),
219 tir::builtin::ptx_cp_async().get(),
220 tir::builtin::mma_store().get(),
221 tir::builtin::mma_fill().get(),
222 tir::builtin::vectorlow().get(),
223 tir::builtin::vectorhigh().get(),
224 tir::builtin::vectorcombine().get(),
225 Op::Get("tir.type_annotation").get(),
226 };
227 static const std::unordered_set<const Object*> dtype_last_arg = {
228 tir::builtin::tvm_struct_get().get(),
229 };
230 ExprDoc prefix{nullptr};
231 if (const auto* op = call->op.as<OpNode>()) {
232 String name = op_names.get(GetRef<Op>(op), op->name);
233 if (op_names.count(GetRef<Op>(op)) == 0) {
234 LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name;
235 }
236 prefix = TIR(d, name);
237 } else if (const auto* gv = call->op.as<GlobalVarNode>()) {
238 prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op"));
239 } else {
240 LOG(FATAL) << "call: " << call;
241 }
242 Array<ExprDoc> args;
243 int n_args = call->args.size();
244 args.reserve(n_args + 1);
245 if (dtype_first_arg.count(call->op.get())) {
246 args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
247 }
248 for (int i = 0; i < n_args; ++i) {
249 args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args")->ArrayIndex(i)));
250 }
251 if (dtype_last_arg.count(call->op.get())) {
252 args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype")));
253 }
254 return prefix->Call(args);
255 });
256
257TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
258 .set_dispatch<tir::Any>("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc {
259 return TIR(d, "Any")->Call({});
260 });
261
262TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
263 .set_dispatch<tir::Reduce>("", [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc {
264 ExprDoc combiner = d->AsDoc<ExprDoc>(r->combiner, p->Attr("combiner"));
265 ExprDoc source = d->AsDoc<ExprDoc>(r->source, p->Attr("source"));
266 ExprDoc init = d->AsDoc<ExprDoc>(r->init, p->Attr("init"));
267 ExprDoc axis = d->AsDoc<ExprDoc>(r->axis, p->Attr("axis"));
268 ExprDoc condition = d->AsDoc<ExprDoc>(r->condition, p->Attr("condition"));
269 ExprDoc value_index = LiteralDoc::Int(r->value_index, p->Attr("value_index"));
270 return TIR(d, "reduce")
271 ->Call({combiner}, {"source", "init", "axis", "condition", "value_index"},
272 {source, init, axis, condition, value_index});
273 LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r;
274 });
275
276TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
277 .set_dispatch<tir::Load>("", [](tir::Load load, ObjectPath p, IRDocsifier d) -> Doc {
278 LOG(FATAL) << "ValueError: Load has been deprecated for BufferLoad: " << load;
279 });
280
// Registers a printer that always renders tir::NodeType as the call `T.<OpString>(a, b)`.
// The expansion deliberately ends without a semicolon; every use site supplies exactly one, which
// avoids a stray empty declaration (`;;`) at namespace scope.
#define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString)                    \
  TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)                                 \
      .set_dispatch<tir::NodeType>(                                          \
          "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc {   \
            ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));            \
            ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));            \
            return TIR(d, OpString)->Call({a, b});                           \
          })
289
290bool IsNumber(const ExprDoc& e) {
291 if (const auto* n = e.as<LiteralDocNode>()) {
292 if (n->value.defined()) {
293 return n->value->IsInstance<IntImmNode>() || n->value->IsInstance<FloatImmNode>();
294 }
295 }
296 return false;
297}
298
// Registers a printer for tir::NodeType that normally uses the infix operator
// OperationDocNode::Kind::OpKind. When BOTH operands print as numeric literals (IsNumber), it
// falls back to the explicit call `T.<OpString>(a, b)` instead.
// The expansion deliberately ends without a semicolon; every use site supplies exactly one, which
// avoids a stray empty declaration (`;;`) at namespace scope.
#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \
  TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)                                 \
      .set_dispatch<tir::NodeType>(                                          \
          "", [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc {   \
            ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a"));            \
            ExprDoc b = d->AsDoc<ExprDoc>(node->b, p->Attr("b"));            \
            if (IsNumber(a) && IsNumber(b)) {                                \
              return TIR(d, OpString)->Call({a, b});                         \
            }                                                                \
            return OperationDoc(OperationDocNode::Kind::OpKind, {a, b});     \
          })
310
311TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add", kAdd);
312TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub", kSub);
313TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul", kMult);
314TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div", kDiv);
315TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv", kFloorDiv);
316TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod", kMod);
317TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT", kLt);
318TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE", kLtE);
319TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ", kEq);
320TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE", kNotEq);
321TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT", kGt);
322TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE", kGtE);
323TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And", kAnd);
324TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or", kOr);
325
326TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod");
327TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min");
328TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max");
329
330#undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR
331#undef TVM_SCRIPT_PRINTER_DEF_BINARY
332
333TVM_SCRIPT_REPR(tir::VarNode, ReprPrintTIR);
334TVM_SCRIPT_REPR(tir::SizeVarNode, ReprPrintTIR);
335TVM_SCRIPT_REPR(tir::IterVarNode, ReprPrintTIR);
336TVM_SCRIPT_REPR(tir::StringImmNode, ReprPrintTIR);
337TVM_SCRIPT_REPR(tir::CastNode, ReprPrintTIR);
338TVM_SCRIPT_REPR(tir::AddNode, ReprPrintTIR);
339TVM_SCRIPT_REPR(tir::SubNode, ReprPrintTIR);
340TVM_SCRIPT_REPR(tir::MulNode, ReprPrintTIR);
341TVM_SCRIPT_REPR(tir::DivNode, ReprPrintTIR);
342TVM_SCRIPT_REPR(tir::ModNode, ReprPrintTIR);
343TVM_SCRIPT_REPR(tir::FloorDivNode, ReprPrintTIR);
344TVM_SCRIPT_REPR(tir::FloorModNode, ReprPrintTIR);
345TVM_SCRIPT_REPR(tir::MinNode, ReprPrintTIR);
346TVM_SCRIPT_REPR(tir::MaxNode, ReprPrintTIR);
347TVM_SCRIPT_REPR(tir::LTNode, ReprPrintTIR);
348TVM_SCRIPT_REPR(tir::LENode, ReprPrintTIR);
349TVM_SCRIPT_REPR(tir::EQNode, ReprPrintTIR);
350TVM_SCRIPT_REPR(tir::NENode, ReprPrintTIR);
351TVM_SCRIPT_REPR(tir::GTNode, ReprPrintTIR);
352TVM_SCRIPT_REPR(tir::GENode, ReprPrintTIR);
353TVM_SCRIPT_REPR(tir::AndNode, ReprPrintTIR);
354TVM_SCRIPT_REPR(tir::OrNode, ReprPrintTIR);
355TVM_SCRIPT_REPR(tir::NotNode, ReprPrintTIR);
356TVM_SCRIPT_REPR(tir::SelectNode, ReprPrintTIR);
357TVM_SCRIPT_REPR(tir::RampNode, ReprPrintTIR);
358TVM_SCRIPT_REPR(tir::BroadcastNode, ReprPrintTIR);
359TVM_SCRIPT_REPR(tir::LetNode, ReprPrintTIR);
360TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR);
361TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR);
362TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR);
363TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR);
364TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR);
365TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR);
366TVM_SCRIPT_REPR(tir::LoadNode, ReprPrintTIR);
367
368} // namespace printer
369} // namespace script
370} // namespace tvm
371