1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/tir/builtin.h> |
20 | |
21 | #include "./utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace script { |
25 | namespace printer { |
26 | |
27 | Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) { |
28 | if (!d->IsVarDefined(var)) { |
29 | if (Optional<Frame> opt_f = FindLowestVarDef(var, d)) { |
30 | ExprDoc lhs = DefineVar(var, opt_f.value(), d); |
31 | Type type = var->type_annotation; |
32 | if (const auto* ptr_type = type.as<PointerTypeNode>()) { |
33 | ICHECK(ptr_type->element_type->IsInstance<PrimTypeNode>()); |
34 | ExprDoc rhs = d->AsDoc<ExprDoc>(type, var_p->Attr("type_annotation" )); |
35 | opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); |
36 | } else { |
37 | ExprDoc rhs = TIR(d, "var" )->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype" ))}); |
38 | opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); |
39 | } |
40 | } else { |
41 | LOG(WARNING) << "Didn't find variable definition for: " << var->name_hint; |
42 | } |
43 | } |
44 | if (Optional<ExprDoc> doc = d->GetVarDoc(var)) { |
45 | return doc.value(); |
46 | } |
47 | LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << var->name_hint; |
48 | } |
49 | |
50 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // |
51 | .set_dispatch<tir::Var>("" , [](tir::Var var, ObjectPath p, IRDocsifier d) -> Doc { |
52 | return PrintVar(var, p, d); |
53 | }); |
54 | |
55 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // |
56 | .set_dispatch<tir::SizeVar>("" , [](tir::SizeVar var, ObjectPath p, IRDocsifier d) -> Doc { |
57 | return PrintVar(var, p, d); |
58 | }); |
59 | |
60 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
61 | .set_dispatch<tir::IterVar>("" , [](tir::IterVar var, ObjectPath var_p, IRDocsifier d) -> Doc { |
62 | return TIR(d, "iter_var" ) |
63 | ->Call({ |
64 | d->AsDoc<ExprDoc>(var->var, var_p->Attr("var" )), |
65 | d->AsDoc<ExprDoc>(var->dom, var_p->Attr("dom" )), |
66 | LiteralDoc::Str(IterVarType2String(var->iter_type), var_p->Attr("iter_type" )), |
67 | LiteralDoc::Str(var->thread_tag, var_p->Attr("thread_tag" )), |
68 | }); |
69 | }); |
70 | |
71 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
72 | .set_dispatch<tir::Not>("" , [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc { |
73 | ExprDoc a = d->AsDoc<ExprDoc>(node->a, p->Attr("a" )); |
74 | if (a->IsInstance<LiteralDocNode>()) { |
75 | return TIR(d, "Not" )->Call({a}); |
76 | } |
77 | return OperationDoc(OperationDocNode::Kind::kNot, {a}); |
78 | }); |
79 | |
80 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
81 | .set_dispatch<tir::StringImm>("" , [](tir::StringImm s, ObjectPath p, IRDocsifier d) -> Doc { |
82 | if (HasMultipleLines(s->value)) { |
83 | return d->AddMetadata(s); |
84 | } else { |
85 | return d->AsDoc<ExprDoc>(s->value, p->Attr("value" )); |
86 | } |
87 | }); |
88 | |
89 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
90 | .set_dispatch<tir::Cast>("" , [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc { |
91 | ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype" )); |
92 | ExprDoc value = d->AsDoc<ExprDoc>(cast->value, p->Attr("value" )); |
93 | return TIR(d, "Cast" )->Call({dtype, value}); |
94 | }); |
95 | |
96 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
97 | .set_dispatch<tir::Select>("" , [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc { |
98 | return TIR(d, "Select" ) |
99 | ->Call({ |
100 | d->AsDoc<ExprDoc>(select->condition, p->Attr("condition" )), |
101 | d->AsDoc<ExprDoc>(select->true_value, p->Attr("true_value" )), |
102 | d->AsDoc<ExprDoc>(select->false_value, p->Attr("false_value" )), |
103 | }); |
104 | }); |
105 | |
106 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
107 | .set_dispatch<tir::Ramp>("" , [](tir::Ramp ramp, ObjectPath ramp_p, IRDocsifier d) -> Doc { |
108 | return TIR(d, "Ramp" )->Call({ |
109 | d->AsDoc<ExprDoc>(ramp->base, ramp_p->Attr("base" )), |
110 | d->AsDoc<ExprDoc>(ramp->stride, ramp_p->Attr("stride" )), |
111 | LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes" )), |
112 | }); |
113 | }); |
114 | |
115 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
116 | .set_dispatch<tir::Broadcast>("" , [](tir::Broadcast bc, ObjectPath bc_p, IRDocsifier d) -> Doc { |
117 | return TIR(d, "Broadcast" ) |
118 | ->Call({ |
119 | d->AsDoc<ExprDoc>(bc->value, bc_p->Attr("value" )), |
120 | LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes" )), |
121 | }); |
122 | }); |
123 | |
124 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
125 | .set_dispatch<tir::Shuffle>( // |
126 | "" , [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc { |
127 | return TIR(d, "Shuffle" ) |
128 | ->Call({ |
129 | d->AsDoc<ExprDoc>(shuffle->vectors, p->Attr("vectors" )), |
130 | d->AsDoc<ExprDoc>(shuffle->indices, p->Attr("indices" )), |
131 | }); |
132 | }); |
133 | |
134 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
135 | .set_dispatch<tir::CommReducer>( // |
136 | "" , [](tir::CommReducer r, ObjectPath p, IRDocsifier d) -> Doc { |
137 | ICHECK_EQ(r->lhs.size(), r->rhs.size()); |
138 | LambdaDoc lambda{nullptr}; |
139 | { |
140 | With<TIRFrame> f(d, r); |
141 | int n_vars = r->lhs.size(); |
142 | Array<IdDoc> vars; |
143 | vars.reserve(n_vars + n_vars); |
144 | for (int i = 0; i < n_vars; ++i) { |
145 | vars.push_back(Downcast<IdDoc>(DefineVar(r->lhs[i], *f, d))); |
146 | } |
147 | for (int i = 0; i < n_vars; ++i) { |
148 | vars.push_back(Downcast<IdDoc>(DefineVar(r->rhs[i], *f, d))); |
149 | } |
150 | int n_results = r->result.size(); |
151 | Array<ExprDoc> results; |
152 | results.reserve(n_results); |
153 | for (int i = 0; i < n_results; ++i) { |
154 | results.push_back(d->AsDoc<ExprDoc>(r->result[i], p->Attr("result" )->ArrayIndex(i))); |
155 | } |
156 | if (results.size() == 1) { |
157 | lambda = LambdaDoc(vars, results[0]); |
158 | } else { |
159 | lambda = LambdaDoc(vars, TupleDoc(results)); |
160 | } |
161 | } |
162 | ExprDoc id = d->AsDoc<ExprDoc>(r->identity_element, p->Attr("identity_element" )); |
163 | return TIR(d, "comm_reducer" )->Call({lambda, id}); |
164 | }); |
165 | |
166 | LambdaDoc PrintIndexMap(const ObjectRef& map, const Array<tir::Var>& vs, const ObjectPath& vs_p, |
167 | const Array<PrimExpr>& es, const ObjectPath& es_p, const IRDocsifier& d) { |
168 | With<TIRFrame> f(d, map); |
169 | Array<IdDoc> vars; |
170 | for (int i = 0, l = vs.size(); i < l; ++i) { |
171 | vars.push_back(Downcast<IdDoc>(DefineVar(vs[i], *f, d))); |
172 | } |
173 | Array<ExprDoc> exprs; |
174 | for (int i = 0, l = es.size(); i < l; ++i) { |
175 | exprs.push_back(d->AsDoc<ExprDoc>(es[i], es_p->ArrayIndex(i))); |
176 | } |
177 | return LambdaDoc(vars, TupleDoc(exprs)); |
178 | } |
179 | |
180 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
181 | .set_dispatch<tir::IndexMap>( // |
182 | "" , [](tir::IndexMap m, ObjectPath m_p, IRDocsifier d) -> Doc { |
183 | LambdaDoc map = PrintIndexMap(m, m->initial_indices, m_p->Attr("initial_indices" ), |
184 | m->final_indices, m_p->Attr("final_indices" ), d); |
185 | if (m->inverse_index_map.defined()) { |
186 | tir::IndexMap inverse = Downcast<tir::IndexMap>(m->inverse_index_map); |
187 | LambdaDoc inv = PrintIndexMap(inverse, inverse->initial_indices, |
188 | m_p->Attr("inverse_index_map" )->Attr("initial_indices" ), |
189 | inverse->final_indices, |
190 | m_p->Attr("inverse_index_map" )->Attr("final_indices" ), d); |
191 | return TIR(d, "index_map" )->Call({map}, {"inverse_index_map" }, {inv}); |
192 | } else { |
193 | return TIR(d, "index_map" )->Call({map}); |
194 | } |
195 | }); |
196 | |
197 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
198 | .set_dispatch<tir::Let>("" , [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { |
199 | return TIR(d, "let" )->Call({ |
200 | d->AsDoc<ExprDoc>(let->var, p->Attr("var" )), |
201 | d->AsDoc<ExprDoc>(let->value, p->Attr("value" )), |
202 | d->AsDoc<ExprDoc>(let->body, p->Attr("body" )), |
203 | }); |
204 | }); |
205 | |
206 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
207 | .set_dispatch<tir::Call>("" , [](tir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc { |
208 | static const OpAttrMap<tir::TScriptPrinterName>& op_names = |
209 | Op::GetAttrMap<tir::TScriptPrinterName>("TScriptPrinterName" ); |
210 | static const std::unordered_set<const Object*> dtype_first_arg = { |
211 | tir::builtin::reinterpret().get(), |
212 | tir::builtin::call_extern().get(), |
213 | tir::builtin::call_llvm_intrin().get(), // |
214 | tir::builtin::call_llvm_pure_intrin().get(), // |
215 | tir::builtin::call_pure_extern().get(), // |
216 | tir::builtin::ptx_mma().get(), |
217 | tir::builtin::ptx_mma_sp().get(), |
218 | tir::builtin::ptx_ldmatrix().get(), |
219 | tir::builtin::ptx_cp_async().get(), |
220 | tir::builtin::mma_store().get(), |
221 | tir::builtin::mma_fill().get(), |
222 | tir::builtin::vectorlow().get(), |
223 | tir::builtin::vectorhigh().get(), |
224 | tir::builtin::vectorcombine().get(), |
225 | Op::Get("tir.type_annotation" ).get(), |
226 | }; |
227 | static const std::unordered_set<const Object*> dtype_last_arg = { |
228 | tir::builtin::tvm_struct_get().get(), |
229 | }; |
230 | ExprDoc prefix{nullptr}; |
231 | if (const auto* op = call->op.as<OpNode>()) { |
232 | String name = op_names.get(GetRef<Op>(op), op->name); |
233 | if (op_names.count(GetRef<Op>(op)) == 0) { |
234 | LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; |
235 | } |
236 | prefix = TIR(d, name); |
237 | } else if (const auto* gv = call->op.as<GlobalVarNode>()) { |
238 | prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op" )); |
239 | } else { |
240 | LOG(FATAL) << "call: " << call; |
241 | } |
242 | Array<ExprDoc> args; |
243 | int n_args = call->args.size(); |
244 | args.reserve(n_args + 1); |
245 | if (dtype_first_arg.count(call->op.get())) { |
246 | args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype" ))); |
247 | } |
248 | for (int i = 0; i < n_args; ++i) { |
249 | args.push_back(d->AsDoc<ExprDoc>(call->args[i], call_p->Attr("args" )->ArrayIndex(i))); |
250 | } |
251 | if (dtype_last_arg.count(call->op.get())) { |
252 | args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype" ))); |
253 | } |
254 | return prefix->Call(args); |
255 | }); |
256 | |
257 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
258 | .set_dispatch<tir::Any>("" , [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc { |
259 | return TIR(d, "Any" )->Call({}); |
260 | }); |
261 | |
262 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
263 | .set_dispatch<tir::Reduce>("" , [](tir::Reduce r, ObjectPath p, IRDocsifier d) -> Doc { |
264 | ExprDoc combiner = d->AsDoc<ExprDoc>(r->combiner, p->Attr("combiner" )); |
265 | ExprDoc source = d->AsDoc<ExprDoc>(r->source, p->Attr("source" )); |
266 | ExprDoc init = d->AsDoc<ExprDoc>(r->init, p->Attr("init" )); |
267 | ExprDoc axis = d->AsDoc<ExprDoc>(r->axis, p->Attr("axis" )); |
268 | ExprDoc condition = d->AsDoc<ExprDoc>(r->condition, p->Attr("condition" )); |
269 | ExprDoc value_index = LiteralDoc::Int(r->value_index, p->Attr("value_index" )); |
270 | return TIR(d, "reduce" ) |
271 | ->Call({combiner}, {"source" , "init" , "axis" , "condition" , "value_index" }, |
272 | {source, init, axis, condition, value_index}); |
273 | LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r; |
274 | }); |
275 | |
276 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
277 | .set_dispatch<tir::Load>("" , [](tir::Load load, ObjectPath p, IRDocsifier d) -> Doc { |
278 | LOG(FATAL) << "ValueError: Load has been deprecated for BufferLoad: " << load; |
279 | }); |
280 | |
// Register a printer for the binary node tir::NodeType that always prints as an
// explicit call to the TIR builtin named OpString with operands (a, b); no operator
// sugar is used.
#define TVM_SCRIPT_PRINTER_DEF_BINARY(NodeType, OpString)                             \
  TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)                                          \
      .set_dispatch<tir::NodeType>("", [](tir::NodeType bin, ObjectPath bin_p,        \
                                          IRDocsifier d) -> Doc {                     \
        return TIR(d, OpString)                                                       \
            ->Call({d->AsDoc<ExprDoc>(bin->a, bin_p->Attr("a")),                      \
                    d->AsDoc<ExprDoc>(bin->b, bin_p->Attr("b"))});                    \
      });
289 | |
290 | bool IsNumber(const ExprDoc& e) { |
291 | if (const auto* n = e.as<LiteralDocNode>()) { |
292 | if (n->value.defined()) { |
293 | return n->value->IsInstance<IntImmNode>() || n->value->IsInstance<FloatImmNode>(); |
294 | } |
295 | } |
296 | return false; |
297 | } |
298 | |
// Register a printer for the binary node tir::NodeType that prints with operator sugar
// (an OperationDoc of kind OpKind), except when both operands print as numeric literals
// (see IsNumber); in that case it prints the explicit call to the TIR builtin OpString.
#define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind)          \
  TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)                                          \
      .set_dispatch<tir::NodeType>("", [](tir::NodeType bin, ObjectPath bin_p,        \
                                          IRDocsifier d) -> Doc {                     \
        ExprDoc lhs = d->AsDoc<ExprDoc>(bin->a, bin_p->Attr("a"));                    \
        ExprDoc rhs = d->AsDoc<ExprDoc>(bin->b, bin_p->Attr("b"));                    \
        bool both_numeric = IsNumber(lhs) && IsNumber(rhs);                           \
        if (!both_numeric) {                                                          \
          return OperationDoc(OperationDocNode::Kind::OpKind, {lhs, rhs});            \
        }                                                                             \
        return TIR(d, OpString)->Call({lhs, rhs});                                    \
      });
310 | |
311 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Add, "Add" , kAdd); |
312 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Sub, "Sub" , kSub); |
313 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Mul, "Mul" , kMult); |
314 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Div, "Div" , kDiv); |
315 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorDiv, "FloorDiv" , kFloorDiv); |
316 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(FloorMod, "FloorMod" , kMod); |
317 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LT, "LT" , kLt); |
318 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(LE, "LE" , kLtE); |
319 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(EQ, "EQ" , kEq); |
320 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NE, "NE" , kNotEq); |
321 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GT, "GT" , kGt); |
322 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(GE, "GE" , kGtE); |
323 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(And, "And" , kAnd); |
324 | TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(Or, "Or" , kOr); |
325 | |
326 | TVM_SCRIPT_PRINTER_DEF_BINARY(Mod, "truncmod" ); |
327 | TVM_SCRIPT_PRINTER_DEF_BINARY(Min, "min" ); |
328 | TVM_SCRIPT_PRINTER_DEF_BINARY(Max, "max" ); |
329 | |
330 | #undef TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR |
331 | #undef TVM_SCRIPT_PRINTER_DEF_BINARY |
332 | |
333 | TVM_SCRIPT_REPR(tir::VarNode, ReprPrintTIR); |
334 | TVM_SCRIPT_REPR(tir::SizeVarNode, ReprPrintTIR); |
335 | TVM_SCRIPT_REPR(tir::IterVarNode, ReprPrintTIR); |
336 | TVM_SCRIPT_REPR(tir::StringImmNode, ReprPrintTIR); |
337 | TVM_SCRIPT_REPR(tir::CastNode, ReprPrintTIR); |
338 | TVM_SCRIPT_REPR(tir::AddNode, ReprPrintTIR); |
339 | TVM_SCRIPT_REPR(tir::SubNode, ReprPrintTIR); |
340 | TVM_SCRIPT_REPR(tir::MulNode, ReprPrintTIR); |
341 | TVM_SCRIPT_REPR(tir::DivNode, ReprPrintTIR); |
342 | TVM_SCRIPT_REPR(tir::ModNode, ReprPrintTIR); |
343 | TVM_SCRIPT_REPR(tir::FloorDivNode, ReprPrintTIR); |
344 | TVM_SCRIPT_REPR(tir::FloorModNode, ReprPrintTIR); |
345 | TVM_SCRIPT_REPR(tir::MinNode, ReprPrintTIR); |
346 | TVM_SCRIPT_REPR(tir::MaxNode, ReprPrintTIR); |
347 | TVM_SCRIPT_REPR(tir::LTNode, ReprPrintTIR); |
348 | TVM_SCRIPT_REPR(tir::LENode, ReprPrintTIR); |
349 | TVM_SCRIPT_REPR(tir::EQNode, ReprPrintTIR); |
350 | TVM_SCRIPT_REPR(tir::NENode, ReprPrintTIR); |
351 | TVM_SCRIPT_REPR(tir::GTNode, ReprPrintTIR); |
352 | TVM_SCRIPT_REPR(tir::GENode, ReprPrintTIR); |
353 | TVM_SCRIPT_REPR(tir::AndNode, ReprPrintTIR); |
354 | TVM_SCRIPT_REPR(tir::OrNode, ReprPrintTIR); |
355 | TVM_SCRIPT_REPR(tir::NotNode, ReprPrintTIR); |
356 | TVM_SCRIPT_REPR(tir::SelectNode, ReprPrintTIR); |
357 | TVM_SCRIPT_REPR(tir::RampNode, ReprPrintTIR); |
358 | TVM_SCRIPT_REPR(tir::BroadcastNode, ReprPrintTIR); |
359 | TVM_SCRIPT_REPR(tir::LetNode, ReprPrintTIR); |
360 | TVM_SCRIPT_REPR(tir::CallNode, ReprPrintTIR); |
361 | TVM_SCRIPT_REPR(tir::ShuffleNode, ReprPrintTIR); |
362 | TVM_SCRIPT_REPR(tir::CommReducerNode, ReprPrintTIR); |
363 | TVM_SCRIPT_REPR(tir::IndexMapNode, ReprPrintTIR); |
364 | TVM_SCRIPT_REPR(tir::AnyNode, ReprPrintTIR); |
365 | TVM_SCRIPT_REPR(tir::ReduceNode, ReprPrintTIR); |
366 | TVM_SCRIPT_REPR(tir::LoadNode, ReprPrintTIR); |
367 | |
368 | } // namespace printer |
369 | } // namespace script |
370 | } // namespace tvm |
371 | |