1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/runtime/container/array.h>
20#include <tvm/runtime/logging.h>
21#include <tvm/runtime/registry.h>
22#include <tvm/script/printer/doc.h>
23
24namespace tvm {
25namespace script {
26namespace printer {
27
28ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef<ExprDoc>(this), attr); }
29
30ExprDoc ExprDocNode::operator[](Array<Doc> indices) const {
31 return IndexDoc(GetRef<ExprDoc>(this), indices);
32}
33
34ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args) const {
35 return CallDoc(GetRef<ExprDoc>(this), args, Array<String>(), Array<ExprDoc>());
36}
37
38ExprDoc ExprDocNode::Call(Array<ExprDoc, void> args, Array<String, void> kwargs_keys,
39 Array<ExprDoc, void> kwargs_values) const {
40 return CallDoc(GetRef<ExprDoc>(this), args, kwargs_keys, kwargs_values);
41}
42
43ExprDoc ExprDoc::operator[](Array<Doc> indices) const { return (*get())[indices]; }
44
45StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
46 ObjectPtr<StmtBlockDocNode> n = make_object<StmtBlockDocNode>();
47 n->stmts = stmts;
48 this->data_ = std::move(n);
49}
50
51LiteralDoc::LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path) {
52 ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
53 n->value = value;
54 if (object_path.defined()) {
55 n->source_paths.push_back(object_path.value());
56 }
57 this->data_ = std::move(n);
58}
59
60IdDoc::IdDoc(String name) {
61 ObjectPtr<IdDocNode> n = make_object<IdDocNode>();
62 n->name = name;
63 this->data_ = std::move(n);
64}
65
66AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) {
67 ObjectPtr<AttrAccessDocNode> n = make_object<AttrAccessDocNode>();
68 n->value = value;
69 n->name = name;
70 this->data_ = std::move(n);
71}
72
73IndexDoc::IndexDoc(ExprDoc value, Array<Doc> indices) {
74 ObjectPtr<IndexDocNode> n = make_object<IndexDocNode>();
75 n->value = value;
76 n->indices = indices;
77 this->data_ = std::move(n);
78}
79
80CallDoc::CallDoc(ExprDoc callee, Array<ExprDoc> args, Array<String> kwargs_keys,
81 Array<ExprDoc> kwargs_values) {
82 ObjectPtr<CallDocNode> n = make_object<CallDocNode>();
83 n->callee = callee;
84 n->args = args;
85 n->kwargs_keys = kwargs_keys;
86 n->kwargs_values = kwargs_values;
87 this->data_ = std::move(n);
88}
89
90OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array<ExprDoc> operands) {
91 ObjectPtr<OperationDocNode> n = make_object<OperationDocNode>();
92 n->kind = kind;
93 n->operands = operands;
94 this->data_ = std::move(n);
95}
96
97LambdaDoc::LambdaDoc(Array<IdDoc> args, ExprDoc body) {
98 ObjectPtr<LambdaDocNode> n = make_object<LambdaDocNode>();
99 n->args = args;
100 n->body = body;
101 this->data_ = std::move(n);
102}
103
104TupleDoc::TupleDoc(Array<ExprDoc> elements) {
105 ObjectPtr<TupleDocNode> n = make_object<TupleDocNode>();
106 n->elements = elements;
107 this->data_ = std::move(n);
108}
109
110ListDoc::ListDoc(Array<ExprDoc> elements) {
111 ObjectPtr<ListDocNode> n = make_object<ListDocNode>();
112 n->elements = elements;
113 this->data_ = std::move(n);
114}
115
116DictDoc::DictDoc(Array<ExprDoc> keys, Array<ExprDoc> values) {
117 ObjectPtr<DictDocNode> n = make_object<DictDocNode>();
118 n->keys = keys;
119 n->values = values;
120 this->data_ = std::move(n);
121}
122
123SliceDoc::SliceDoc(Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) {
124 ObjectPtr<SliceDocNode> n = make_object<SliceDocNode>();
125 n->start = start;
126 n->stop = stop;
127 n->step = step;
128 this->data_ = std::move(n);
129}
130
131AssignDoc::AssignDoc(ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) {
132 CHECK(rhs.defined() || annotation.defined())
133 << "ValueError: At least one of rhs and annotation needs to be non-null for AssignDoc.";
134 CHECK(lhs->IsInstance<IdDocNode>() || annotation == nullptr)
135 << "ValueError: annotation can only be nonnull if lhs is an identifier.";
136
137 ObjectPtr<AssignDocNode> n = make_object<AssignDocNode>();
138 n->lhs = lhs;
139 n->rhs = rhs;
140 n->annotation = annotation;
141 this->data_ = std::move(n);
142}
143
144IfDoc::IfDoc(ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) {
145 CHECK(!then_branch.empty() || !else_branch.empty())
146 << "ValueError: At least one of the then branch or else branch needs to be non-empty.";
147
148 ObjectPtr<IfDocNode> n = make_object<IfDocNode>();
149 n->predicate = predicate;
150 n->then_branch = then_branch;
151 n->else_branch = else_branch;
152 this->data_ = std::move(n);
153}
154
155WhileDoc::WhileDoc(ExprDoc predicate, Array<StmtDoc> body) {
156 ObjectPtr<WhileDocNode> n = make_object<WhileDocNode>();
157 n->predicate = predicate;
158 n->body = body;
159 this->data_ = std::move(n);
160}
161
162ForDoc::ForDoc(ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) {
163 ObjectPtr<ForDocNode> n = make_object<ForDocNode>();
164 n->lhs = lhs;
165 n->rhs = rhs;
166 n->body = body;
167 this->data_ = std::move(n);
168}
169
170ScopeDoc::ScopeDoc(Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) {
171 ObjectPtr<ScopeDocNode> n = make_object<ScopeDocNode>();
172 n->lhs = lhs;
173 n->rhs = rhs;
174 n->body = body;
175 this->data_ = std::move(n);
176}
177
178ScopeDoc::ScopeDoc(ExprDoc rhs, Array<StmtDoc> body) {
179 ObjectPtr<ScopeDocNode> n = make_object<ScopeDocNode>();
180 n->lhs = NullOpt;
181 n->rhs = rhs;
182 n->body = body;
183 this->data_ = std::move(n);
184}
185
186ExprStmtDoc::ExprStmtDoc(ExprDoc expr) {
187 ObjectPtr<ExprStmtDocNode> n = make_object<ExprStmtDocNode>();
188 n->expr = expr;
189 this->data_ = std::move(n);
190}
191
192AssertDoc::AssertDoc(ExprDoc test, Optional<ExprDoc> msg) {
193 ObjectPtr<AssertDocNode> n = make_object<AssertDocNode>();
194 n->test = test;
195 n->msg = msg;
196 this->data_ = std::move(n);
197}
198
199ReturnDoc::ReturnDoc(ExprDoc value) {
200 ObjectPtr<ReturnDocNode> n = make_object<ReturnDocNode>();
201 n->value = value;
202 this->data_ = std::move(n);
203}
204
205FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
206 Optional<ExprDoc> return_type, Array<StmtDoc> body) {
207 ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>();
208 n->name = name;
209 n->args = args;
210 n->decorators = decorators;
211 n->return_type = return_type;
212 n->body = body;
213 this->data_ = std::move(n);
214}
215
216ClassDoc::ClassDoc(IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
217 ObjectPtr<ClassDocNode> n = make_object<ClassDocNode>();
218 n->name = name;
219 n->decorators = decorators;
220 n->body = body;
221 this->data_ = std::move(n);
222}
223
224CommentDoc::CommentDoc(String comment) {
225 ObjectPtr<CommentDocNode> n = make_object<CommentDocNode>();
226 n->comment = comment;
227 this->data_ = std::move(n);
228}
229
230DocStringDoc::DocStringDoc(String docs) {
231 ObjectPtr<DocStringDocNode> n = make_object<DocStringDocNode>();
232 n->comment = docs;
233 this->data_ = std::move(n);
234}
235
236TVM_REGISTER_NODE_TYPE(DocNode);
237TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths")
238 .set_body_typed([](Doc doc, Array<ObjectPath> source_paths) {
239 doc->source_paths = source_paths;
240 });
241
// Registers ExprDocNode as a node type, and three global functions bound to
// ExprDocNode member functions: Attr, operator[] and Call.
TVM_REGISTER_NODE_TYPE(ExprDocNode);
TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr")
    .set_body_method<ExprDoc, ExprDocNode, ExprDoc, String>(&ExprDocNode::Attr);
TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex")
    .set_body_method<ExprDoc>(&ExprDocNode::operator[]);
// Call is overloaded; the explicit template arguments pick the overload that
// takes positional args plus keyword keys and keyword values.
TVM_REGISTER_GLOBAL("script.printer.ExprDocCall")
    .set_body_method<ExprDoc, ExprDocNode, ExprDoc, Array<ExprDoc>, Array<String>, Array<ExprDoc>>(
        &ExprDocNode::Call);
250
251TVM_REGISTER_NODE_TYPE(StmtDocNode);
252TVM_REGISTER_GLOBAL("script.printer.StmtDocSetComment")
253 .set_body_typed([](StmtDoc doc, Optional<String> comment) { doc->comment = comment; });
254
255TVM_REGISTER_NODE_TYPE(StmtBlockDocNode);
256TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtDoc> stmts) {
257 return StmtBlockDoc(stmts);
258});
259
// Registers LiteralDocNode as a node type, and one global function per literal
// kind, each bound directly to the corresponding LiteralDoc function
// (None, Int, Boolean, Float, Str).
TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);
266
267TVM_REGISTER_NODE_TYPE(IdDocNode);
268TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
269
270TVM_REGISTER_NODE_TYPE(AttrAccessDocNode);
271TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc").set_body_typed([](ExprDoc value, String attr) {
272 return AttrAccessDoc(value, attr);
273});
274
275TVM_REGISTER_NODE_TYPE(IndexDocNode);
276TVM_REGISTER_GLOBAL("script.printer.IndexDoc")
277 .set_body_typed([](ExprDoc value, Array<Doc> indices) { return IndexDoc(value, indices); });
278
279TVM_REGISTER_NODE_TYPE(CallDocNode);
280TVM_REGISTER_GLOBAL("script.printer.CallDoc")
281 .set_body_typed([](ExprDoc callee, //
282 Array<ExprDoc> args, //
283 Array<String> kwargs_keys, //
284 Array<ExprDoc> kwargs_values) {
285 return CallDoc(callee, args, kwargs_keys, kwargs_values);
286 });
287
288TVM_REGISTER_NODE_TYPE(OperationDocNode);
289TVM_REGISTER_GLOBAL("script.printer.OperationDoc")
290 .set_body_typed([](int32_t kind, Array<ExprDoc> operands) {
291 return OperationDoc(OperationDocNode::Kind(kind), operands);
292 });
293
294TVM_REGISTER_NODE_TYPE(LambdaDocNode);
295TVM_REGISTER_GLOBAL("script.printer.LambdaDoc").set_body_typed([](Array<IdDoc> args, ExprDoc body) {
296 return LambdaDoc(args, body);
297});
298
299TVM_REGISTER_NODE_TYPE(TupleDocNode);
300TVM_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array<ExprDoc> elements) {
301 return TupleDoc(elements);
302});
303
304TVM_REGISTER_NODE_TYPE(ListDocNode);
305TVM_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array<ExprDoc> elements) {
306 return ListDoc(elements);
307});
308
309TVM_REGISTER_NODE_TYPE(DictDocNode);
310TVM_REGISTER_GLOBAL("script.printer.DictDoc")
311 .set_body_typed([](Array<ExprDoc> keys, Array<ExprDoc> values) {
312 return DictDoc(keys, values);
313 });
314
315TVM_REGISTER_NODE_TYPE(SliceDocNode);
316TVM_REGISTER_GLOBAL("script.printer.SliceDoc")
317 .set_body_typed([](Optional<ExprDoc> start, Optional<ExprDoc> stop, Optional<ExprDoc> step) {
318 return SliceDoc(start, stop, step);
319 });
320
321TVM_REGISTER_NODE_TYPE(AssignDocNode);
322TVM_REGISTER_GLOBAL("script.printer.AssignDoc")
323 .set_body_typed([](ExprDoc lhs, Optional<ExprDoc> rhs, Optional<ExprDoc> annotation) {
324 return AssignDoc(lhs, rhs, annotation);
325 });
326
327TVM_REGISTER_NODE_TYPE(IfDocNode);
328TVM_REGISTER_GLOBAL("script.printer.IfDoc")
329 .set_body_typed([](ExprDoc predicate, Array<StmtDoc> then_branch, Array<StmtDoc> else_branch) {
330 return IfDoc(predicate, then_branch, else_branch);
331 });
332
333TVM_REGISTER_NODE_TYPE(WhileDocNode);
334TVM_REGISTER_GLOBAL("script.printer.WhileDoc")
335 .set_body_typed([](ExprDoc predicate, Array<StmtDoc> body) {
336 return WhileDoc(predicate, body);
337 });
338
339TVM_REGISTER_NODE_TYPE(ForDocNode);
340TVM_REGISTER_GLOBAL("script.printer.ForDoc")
341 .set_body_typed([](ExprDoc lhs, ExprDoc rhs, Array<StmtDoc> body) {
342 return ForDoc(lhs, rhs, body);
343 });
344
345TVM_REGISTER_NODE_TYPE(ScopeDocNode);
346TVM_REGISTER_GLOBAL("script.printer.ScopeDoc")
347 .set_body_typed([](Optional<ExprDoc> lhs, ExprDoc rhs, Array<StmtDoc> body) {
348 return ScopeDoc(lhs, rhs, body);
349 });
350
351TVM_REGISTER_NODE_TYPE(ExprStmtDocNode);
352TVM_REGISTER_GLOBAL("script.printer.ExprStmtDoc").set_body_typed([](ExprDoc expr) {
353 return ExprStmtDoc(expr);
354});
355
356TVM_REGISTER_NODE_TYPE(AssertDocNode);
357TVM_REGISTER_GLOBAL("script.printer.AssertDoc")
358 .set_body_typed([](ExprDoc test, Optional<ExprDoc> msg = NullOpt) {
359 return AssertDoc(test, msg);
360 });
361
362TVM_REGISTER_NODE_TYPE(ReturnDocNode);
363TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value) {
364 return ReturnDoc(value);
365});
366
367TVM_REGISTER_NODE_TYPE(FunctionDocNode);
368TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
369 .set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
370 Optional<ExprDoc> return_type, Array<StmtDoc> body) {
371 return FunctionDoc(name, args, decorators, return_type, body);
372 });
373
374TVM_REGISTER_NODE_TYPE(ClassDocNode);
375TVM_REGISTER_GLOBAL("script.printer.ClassDoc")
376 .set_body_typed([](IdDoc name, Array<ExprDoc> decorators, Array<StmtDoc> body) {
377 return ClassDoc(name, decorators, body);
378 });
379
380TVM_REGISTER_NODE_TYPE(CommentDocNode);
381TVM_REGISTER_GLOBAL("script.printer.CommentDoc").set_body_typed([](String comment) {
382 return CommentDoc(comment);
383});
384
385TVM_REGISTER_NODE_TYPE(DocStringDocNode);
386TVM_REGISTER_GLOBAL("script.printer.DocStringDoc").set_body_typed([](String docs) {
387 return DocStringDoc(docs);
388});
389
390} // namespace printer
391} // namespace script
392} // namespace tvm
393