1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/ir/expr.h>
20#include <tvm/node/repr_printer.h>
21#include <tvm/node/script_printer.h>
22#include <tvm/runtime/registry.h>
23
24namespace tvm {
25
26TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
27 static FType inst;
28 return inst;
29}
30
31std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<PrinterConfig>& cfg) {
32 if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
33 return AsLegacyRepr(node);
34 }
35 return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig()));
36}
37
38PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
39 runtime::ObjectPtr<PrinterConfigNode> n = make_object<PrinterConfigNode>();
40 if (auto v = config_dict.Get("name")) {
41 n->binding_names.push_back(Downcast<String>(v));
42 }
43 if (auto v = config_dict.Get("show_meta")) {
44 n->show_meta = Downcast<IntImm>(v)->value;
45 }
46 if (auto v = config_dict.Get("ir_prefix")) {
47 n->ir_prefix = Downcast<String>(v);
48 }
49 if (auto v = config_dict.Get("tir_prefix")) {
50 n->tir_prefix = Downcast<String>(v);
51 }
52 if (auto v = config_dict.Get("relax_prefix")) {
53 n->relax_prefix = Downcast<String>(v);
54 }
55 if (auto v = config_dict.Get("buffer_dtype")) {
56 n->buffer_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v)));
57 }
58 if (auto v = config_dict.Get("int_dtype")) {
59 n->int_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v)));
60 }
61 if (auto v = config_dict.Get("float_dtype")) {
62 n->float_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v)));
63 }
64 if (auto v = config_dict.Get("verbose_expr")) {
65 n->verbose_expr = Downcast<IntImm>(v)->value;
66 }
67 if (auto v = config_dict.Get("indent_spaces")) {
68 n->indent_spaces = Downcast<IntImm>(v)->value;
69 }
70 if (auto v = config_dict.Get("print_line_numbers")) {
71 n->print_line_numbers = Downcast<IntImm>(v)->value;
72 }
73 if (auto v = config_dict.Get("num_context_lines")) {
74 n->num_context_lines = Downcast<IntImm>(v)->value;
75 }
76 if (auto v = config_dict.Get("path_to_underline")) {
77 n->path_to_underline = Downcast<Optional<Array<ObjectPath>>>(v).value_or(Array<ObjectPath>());
78 }
79 if (auto v = config_dict.Get("path_to_annotate")) {
80 n->path_to_annotate =
81 Downcast<Optional<Map<ObjectPath, String>>>(v).value_or(Map<ObjectPath, String>());
82 }
83 if (auto v = config_dict.Get("obj_to_underline")) {
84 n->obj_to_underline = Downcast<Optional<Array<ObjectRef>>>(v).value_or(Array<ObjectRef>());
85 }
86 if (auto v = config_dict.Get("obj_to_annotate")) {
87 n->obj_to_annotate =
88 Downcast<Optional<Map<ObjectRef, String>>>(v).value_or(Map<ObjectRef, String>());
89 }
90 if (auto v = config_dict.Get("syntax_sugar")) {
91 n->syntax_sugar = Downcast<IntImm>(v)->value;
92 }
93 this->data_ = std::move(n);
94}
95
96TVM_REGISTER_NODE_TYPE(PrinterConfigNode);
97TVM_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map<String, ObjectRef> config_dict) {
98 return PrinterConfig(config_dict);
99});
100TVM_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script);
101
102} // namespace tvm
103