1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/runtime/container/array.h> |
20 | #include <tvm/runtime/logging.h> |
21 | #include <tvm/runtime/registry.h> |
22 | #include <tvm/script/printer/doc.h> |
23 | |
24 | namespace tvm { |
25 | namespace script { |
26 | namespace printer { |
27 | |
28 | ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); } |
29 | |
30 | ExprDoc ExprDocNode::operator[](Array<Doc> indices) const { |
31 | return IndexDoc(GetRef<ExprDoc>(this), indices); |
32 | } |
33 | |
34 | ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args) const { |
35 | return CallDoc(GetRef<ExprDoc>(this), args, Array<String>(), Array<ExprDoc>()); |
36 | } |
37 | |
38 | ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_keys, |
39 | Array<ExprDoc, void> kwargs_values) const { |
40 | return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values); |
41 | } |
42 | |
43 | ExprDoc ExprDoc::operator[](Array<Doc> indices) const { return (*get())[indices]; } |
44 | |
45 | StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) { |
46 | ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>(); |
47 | n->stmts = stmts; |
48 | this->data_ = std::move(n); |
49 | } |
50 | |
51 | LiteralDoc::LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path) { |
52 | ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>(); |
53 | n->value = value; |
54 | if (object_path.defined()) { |
55 | n->source_paths.push_back(object_path.value()); |
56 | } |
57 | this->data_ = std::move(n); |
58 | } |
59 | |
60 | IdDoc::IdDoc(String name) { |
61 | ObjectPtr<IdDocNode> n = make_object<IdDocNode>(); |
62 | n->name = name; |
63 | this->data_ = std::move(n); |
64 | } |
65 | |
66 | AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) { |
67 | ObjectPtr<AttrAccessDocNode> n = make_object<AttrAccessDocNode>(); |
68 | n->value = value; |
69 | n->name = name; |
70 | this->data_ = std::move(n); |
71 | } |
72 | |
73 | IndexDoc::IndexDoc(ExprDoc value, Array<Doc> indices) { |
74 | ObjectPtr<IndexDocNode> n = make_object<IndexDocNode>(); |
75 | n->value = value; |
76 | n->indices = indices; |
77 | this->data_ = std::move(n); |
78 | } |
79 | |
80 | CallDoc::CallDoc(ExprDoc callee, Array<ExprDoc> args, Array<String> kwargs_keys, |
81 | Array<ExprDoc> kwargs_values) { |
82 | ObjectPtr<CallDocNode> n = make_object<CallDocNode>(); |
83 | n->callee = callee; |
84 | n->args = args; |
85 | n->kwargs_keys = kwargs_keys; |
86 | n->kwargs_values = kwargs_values; |
87 | this->data_ = std::move(n); |
88 | } |
89 | |
90 | OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array<ExprDoc> operands) { |
91 | ObjectPtr<OperationDocNode> n = make_object<OperationDocNode>(); |
92 | n->kind = kind; |
93 | n->operands = operands; |
94 | this->data_ = std::move(n); |
95 | } |
96 | |
97 | LambdaDoc::LambdaDoc(Array<IdDoc> args, ExprDoc body) { |
98 | ObjectPtr<LambdaDocNode> n = make_object<LambdaDocNode>(); |
99 | n->args = args; |
100 | n->body = body; |
101 | this->data_ = std::move(n); |
102 | } |
103 | |
104 | TupleDoc::TupleDoc(Array<ExprDoc> elements) { |
105 | ObjectPtr<TupleDocNode> n = make_object<TupleDocNode>(); |
106 | n->elements = elements; |
107 | this->data_ = std::move(n); |
108 | } |
109 | |
110 | ListDoc::ListDoc(Array<ExprDoc> elements) { |
111 | ObjectPtr<ListDocNode> n = make_object<ListDocNode>(); |
112 | n->elements = elements; |
113 | this->data_ = std::move(n); |
114 | } |
115 | |
116 | DictDoc::DictDoc(Array<ExprDoc> keys, Array<ExprDoc> values) { |
117 | ObjectPtr<DictDocNode> n = make_object<DictDocNode>(); |
118 | n->keys = keys; |
119 | n->values = values; |
120 | this->data_ = std::move(n); |
121 | } |
122 | |
123 | SliceDoc::SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) { |
124 | ObjectPtr<SliceDocNode> n = make_object<SliceDocNode>(); |
125 | n->start = start; |
126 | n->stop = stop; |
127 | n->step = step; |
128 | this->data_ = std::move(n); |
129 | } |
130 | |
131 | AssignDoc::AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) { |
132 | CHECK(rhs.defined() || annotation.defined()) |
133 | << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc." ; |
134 | CHECK(lhs->IsInstance<IdDocNode>() || annotation == nullptr) |
135 | << "ValueError: annotation can only be nonnull if lhs is an identifier." ; |
136 | |
137 | ObjectPtr<AssignDocNode> n = make_object<AssignDocNode>(); |
138 | n->lhs = lhs; |
139 | n->rhs = rhs; |
140 | n->annotation = annotation; |
141 | this->data_ = std::move(n); |
142 | } |
143 | |
144 | IfDoc::IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) { |
145 | CHECK(!then_branch.empty() || !else_branch.empty()) |
146 | << "ValueError: At least one of the then branch or else branch needs to be non-empty." ; |
147 | |
148 | ObjectPtr<IfDocNode> n = make_object<IfDocNode>(); |
149 | n->predicate = predicate; |
150 | n->then_branch = then_branch; |
151 | n->else_branch = else_branch; |
152 | this->data_ = std::move(n); |
153 | } |
154 | |
155 | WhileDoc::WhileDoc(ExprDoc predicate, Array<StmtDoc> body) { |
156 | ObjectPtr<WhileDocNode> n = make_object<WhileDocNode>(); |
157 | n->predicate = predicate; |
158 | n->body = body; |
159 | this->data_ = std::move(n); |
160 | } |
161 | |
162 | ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) { |
163 | ObjectPtr<ForDocNode> n = make_object<ForDocNode>(); |
164 | n->lhs = lhs; |
165 | n->rhs = rhs; |
166 | n->body = body; |
167 | this->data_ = std::move(n); |
168 | } |
169 | |
170 | ScopeDoc::ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) { |
171 | ObjectPtr<ScopeDocNode> n = make_object<ScopeDocNode>(); |
172 | n->lhs = lhs; |
173 | n->rhs = rhs; |
174 | n->body = body; |
175 | this->data_ = std::move(n); |
176 | } |
177 | |
178 | ScopeDoc::ScopeDoc(ExprDoc rhs, Array<StmtDoc> body) { |
179 | ObjectPtr<ScopeDocNode> n = make_object<ScopeDocNode>(); |
180 | n->lhs = NullOpt; |
181 | n->rhs = rhs; |
182 | n->body = body; |
183 | this->data_ = std::move(n); |
184 | } |
185 | |
186 | ExprStmtDoc::ExprStmtDoc(ExprDoc expr) { |
187 | ObjectPtr<ExprStmtDocNode> n = make_object<ExprStmtDocNode>(); |
188 | n->expr = expr; |
189 | this->data_ = std::move(n); |
190 | } |
191 | |
192 | AssertDoc::AssertDoc(ExprDoc test, Optional<ExprDoc> msg) { |
193 | ObjectPtr<AssertDocNode> n = make_object<AssertDocNode>(); |
194 | n->test = test; |
195 | n->msg = msg; |
196 | this->data_ = std::move(n); |
197 | } |
198 | |
199 | ReturnDoc::ReturnDoc(ExprDoc value) { |
200 | ObjectPtr<ReturnDocNode> n = make_object<ReturnDocNode>(); |
201 | n->value = value; |
202 | this->data_ = std::move(n); |
203 | } |
204 | |
205 | FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators, |
206 | Optional<ExprDoc> return_type, Array<StmtDoc> body) { |
207 | ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>(); |
208 | n->name = name; |
209 | n->args = args; |
210 | n->decorators = decorators; |
211 | n->return_type = return_type; |
212 | n->body = body; |
213 | this->data_ = std::move(n); |
214 | } |
215 | |
216 | ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) { |
217 | ObjectPtr<ClassDocNode> n = make_object<ClassDocNode>(); |
218 | n->name = name; |
219 | n->decorators = decorators; |
220 | n->body = body; |
221 | this->data_ = std::move(n); |
222 | } |
223 | |
224 | CommentDoc::(String ) { |
225 | ObjectPtr<CommentDocNode> n = make_object<CommentDocNode>(); |
226 | n->comment = comment; |
227 | this->data_ = std::move(n); |
228 | } |
229 | |
230 | DocStringDoc::DocStringDoc(String docs) { |
231 | ObjectPtr<DocStringDocNode> n = make_object<DocStringDocNode>(); |
232 | n->comment = docs; |
233 | this->data_ = std::move(n); |
234 | } |
235 | |
236 | TVM_REGISTER_NODE_TYPE(DocNode); |
237 | TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths" ) |
238 | .set_body_typed([](Doc doc, Array<ObjectPath> source_paths) { |
239 | doc->source_paths = source_paths; |
240 | }); |
241 | |
242 | TVM_REGISTER_NODE_TYPE(ExprDocNode); |
243 | TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr" ) |
244 | .set_body_method<ExprDoc, ExprDocNode, ExprDoc, String>(&ExprDocNode::Attr); |
245 | TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex" ) |
246 | .set_body_method<ExprDoc>(&ExprDocNode::operator[]); |
247 | TVM_REGISTER_GLOBAL("script.printer.ExprDocCall" ) |
248 | .set_body_method<ExprDoc, ExprDocNode, ExprDoc, Array<ExprDoc>, Array<String>, Array<ExprDoc>>( |
249 | &ExprDocNode::Call); |
250 | |
251 | TVM_REGISTER_NODE_TYPE(StmtDocNode); |
252 | TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment" ) |
253 | .set_body_typed([](StmtDoc doc, Optional<String> ) { doc->comment = comment; }); |
254 | |
255 | TVM_REGISTER_NODE_TYPE(StmtBlockDocNode); |
256 | TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc" ).set_body_typed([](Array<StmtDoc> stmts) { |
257 | return StmtBlockDoc(stmts); |
258 | }); |
259 | |
260 | TVM_REGISTER_NODE_TYPE(LiteralDocNode); |
261 | TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone" ).set_body_typed(LiteralDoc::None); |
262 | TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt" ).set_body_typed(LiteralDoc::Int); |
263 | TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean" ).set_body_typed(LiteralDoc::Boolean); |
264 | TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat" ).set_body_typed(LiteralDoc::Float); |
265 | TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr" ).set_body_typed(LiteralDoc::Str); |
266 | |
267 | TVM_REGISTER_NODE_TYPE(IdDocNode); |
268 | TVM_REGISTER_GLOBAL("script.printer.IdDoc" ).set_body_typed([](String name) { return IdDoc(name); }); |
269 | |
270 | TVM_REGISTER_NODE_TYPE(AttrAccessDocNode); |
271 | TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc" ).set_body_typed([](ExprDoc value, String attr) { |
272 | return AttrAccessDoc(value, attr); |
273 | }); |
274 | |
275 | TVM_REGISTER_NODE_TYPE(IndexDocNode); |
276 | TVM_REGISTER_GLOBAL("script.printer.IndexDoc" ) |
277 | .set_body_typed([](ExprDoc value, Array<Doc> indices) { return IndexDoc(value, indices); }); |
278 | |
279 | TVM_REGISTER_NODE_TYPE(CallDocNode); |
280 | TVM_REGISTER_GLOBAL("script.printer.CallDoc" ) |
281 | .set_body_typed([](ExprDoc callee, // |
282 | Array<ExprDoc> args, // |
283 | Array<String> kwargs_keys, // |
284 | Array<ExprDoc> kwargs_values) { |
285 | return CallDoc(callee, args, kwargs_keys, kwargs_values); |
286 | }); |
287 | |
288 | TVM_REGISTER_NODE_TYPE(OperationDocNode); |
289 | TVM_REGISTER_GLOBAL("script.printer.OperationDoc" ) |
290 | .set_body_typed([](int32_t kind, Array<ExprDoc> operands) { |
291 | return OperationDoc(OperationDocNode::Kind(kind), operands); |
292 | }); |
293 | |
294 | TVM_REGISTER_NODE_TYPE(LambdaDocNode); |
295 | TVM_REGISTER_GLOBAL("script.printer.LambdaDoc" ).set_body_typed([](Array<IdDoc> args, ExprDoc body) { |
296 | return LambdaDoc(args, body); |
297 | }); |
298 | |
299 | TVM_REGISTER_NODE_TYPE(TupleDocNode); |
300 | TVM_REGISTER_GLOBAL("script.printer.TupleDoc" ).set_body_typed([](Array<ExprDoc> elements) { |
301 | return TupleDoc(elements); |
302 | }); |
303 | |
304 | TVM_REGISTER_NODE_TYPE(ListDocNode); |
305 | TVM_REGISTER_GLOBAL("script.printer.ListDoc" ).set_body_typed([](Array<ExprDoc> elements) { |
306 | return ListDoc(elements); |
307 | }); |
308 | |
309 | TVM_REGISTER_NODE_TYPE(DictDocNode); |
310 | TVM_REGISTER_GLOBAL("script.printer.DictDoc" ) |
311 | .set_body_typed([](Array<ExprDoc> keys, Array<ExprDoc> values) { |
312 | return DictDoc(keys, values); |
313 | }); |
314 | |
315 | TVM_REGISTER_NODE_TYPE(SliceDocNode); |
316 | TVM_REGISTER_GLOBAL("script.printer.SliceDoc" ) |
317 | .set_body_typed([](Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) { |
318 | return SliceDoc(start, stop, step); |
319 | }); |
320 | |
321 | TVM_REGISTER_NODE_TYPE(AssignDocNode); |
322 | TVM_REGISTER_GLOBAL("script.printer.AssignDoc" ) |
323 | .set_body_typed([](ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) { |
324 | return AssignDoc(lhs, rhs, annotation); |
325 | }); |
326 | |
327 | TVM_REGISTER_NODE_TYPE(IfDocNode); |
328 | TVM_REGISTER_GLOBAL("script.printer.IfDoc" ) |
329 | .set_body_typed([](ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) { |
330 | return IfDoc(predicate, then_branch, else_branch); |
331 | }); |
332 | |
333 | TVM_REGISTER_NODE_TYPE(WhileDocNode); |
334 | TVM_REGISTER_GLOBAL("script.printer.WhileDoc" ) |
335 | .set_body_typed([](ExprDoc predicate, Array<StmtDoc> body) { |
336 | return WhileDoc(predicate, body); |
337 | }); |
338 | |
339 | TVM_REGISTER_NODE_TYPE(ForDocNode); |
340 | TVM_REGISTER_GLOBAL("script.printer.ForDoc" ) |
341 | .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) { |
342 | return ForDoc(lhs, rhs, body); |
343 | }); |
344 | |
345 | TVM_REGISTER_NODE_TYPE(ScopeDocNode); |
346 | TVM_REGISTER_GLOBAL("script.printer.ScopeDoc" ) |
347 | .set_body_typed([](Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) { |
348 | return ScopeDoc(lhs, rhs, body); |
349 | }); |
350 | |
351 | TVM_REGISTER_NODE_TYPE(ExprStmtDocNode); |
352 | TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc" ).set_body_typed([](ExprDoc expr) { |
353 | return ExprStmtDoc(expr); |
354 | }); |
355 | |
356 | TVM_REGISTER_NODE_TYPE(AssertDocNode); |
357 | TVM_REGISTER_GLOBAL("script.printer.AssertDoc" ) |
358 | .set_body_typed([](ExprDoc test, Optional<ExprDoc> msg = NullOpt) { |
359 | return AssertDoc(test, msg); |
360 | }); |
361 | |
362 | TVM_REGISTER_NODE_TYPE(ReturnDocNode); |
363 | TVM_REGISTER_GLOBAL("script.printer.ReturnDoc" ).set_body_typed([](ExprDoc value) { |
364 | return ReturnDoc(value); |
365 | }); |
366 | |
367 | TVM_REGISTER_NODE_TYPE(FunctionDocNode); |
368 | TVM_REGISTER_GLOBAL("script.printer.FunctionDoc" ) |
369 | .set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators, |
370 | Optional<ExprDoc> return_type, Array<StmtDoc> body) { |
371 | return FunctionDoc(name, args, decorators, return_type, body); |
372 | }); |
373 | |
374 | TVM_REGISTER_NODE_TYPE(ClassDocNode); |
375 | TVM_REGISTER_GLOBAL("script.printer.ClassDoc" ) |
376 | .set_body_typed([](IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) { |
377 | return ClassDoc(name, decorators, body); |
378 | }); |
379 | |
380 | TVM_REGISTER_NODE_TYPE(CommentDocNode); |
381 | TVM_REGISTER_GLOBAL("script.printer.CommentDoc" ).set_body_typed([](String ) { |
382 | return CommentDoc(comment); |
383 | }); |
384 | |
385 | TVM_REGISTER_NODE_TYPE(DocStringDocNode); |
386 | TVM_REGISTER_GLOBAL("script.printer.DocStringDoc" ).set_body_typed([](String docs) { |
387 | return DocStringDoc(docs); |
388 | }); |
389 | |
390 | } // namespace printer |
391 | } // namespace script |
392 | } // namespace tvm |
393 | |