1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/ir/expr.h> |
20 | #include <tvm/node/repr_printer.h> |
21 | #include <tvm/node/script_printer.h> |
22 | #include <tvm/runtime/registry.h> |
23 | |
24 | namespace tvm { |
25 | |
26 | TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { |
27 | static FType inst; |
28 | return inst; |
29 | } |
30 | |
31 | std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<PrinterConfig>& cfg) { |
32 | if (!TVMScriptPrinter::vtable().can_dispatch(node)) { |
33 | return AsLegacyRepr(node); |
34 | } |
35 | return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig())); |
36 | } |
37 | |
38 | PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) { |
39 | runtime::ObjectPtr<PrinterConfigNode> n = make_object<PrinterConfigNode>(); |
40 | if (auto v = config_dict.Get("name" )) { |
41 | n->binding_names.push_back(Downcast<String>(v)); |
42 | } |
43 | if (auto v = config_dict.Get("show_meta" )) { |
44 | n->show_meta = Downcast<IntImm>(v)->value; |
45 | } |
46 | if (auto v = config_dict.Get("ir_prefix" )) { |
47 | n->ir_prefix = Downcast<String>(v); |
48 | } |
49 | if (auto v = config_dict.Get("tir_prefix" )) { |
50 | n->tir_prefix = Downcast<String>(v); |
51 | } |
52 | if (auto v = config_dict.Get("relax_prefix" )) { |
53 | n->relax_prefix = Downcast<String>(v); |
54 | } |
55 | if (auto v = config_dict.Get("buffer_dtype" )) { |
56 | n->buffer_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v))); |
57 | } |
58 | if (auto v = config_dict.Get("int_dtype" )) { |
59 | n->int_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v))); |
60 | } |
61 | if (auto v = config_dict.Get("float_dtype" )) { |
62 | n->float_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v))); |
63 | } |
64 | if (auto v = config_dict.Get("verbose_expr" )) { |
65 | n->verbose_expr = Downcast<IntImm>(v)->value; |
66 | } |
67 | if (auto v = config_dict.Get("indent_spaces" )) { |
68 | n->indent_spaces = Downcast<IntImm>(v)->value; |
69 | } |
70 | if (auto v = config_dict.Get("print_line_numbers" )) { |
71 | n->print_line_numbers = Downcast<IntImm>(v)->value; |
72 | } |
73 | if (auto v = config_dict.Get("num_context_lines" )) { |
74 | n->num_context_lines = Downcast<IntImm>(v)->value; |
75 | } |
76 | if (auto v = config_dict.Get("path_to_underline" )) { |
77 | n->path_to_underline = Downcast<Optional<Array<ObjectPath>>>(v).value_or(Array<ObjectPath>()); |
78 | } |
79 | if (auto v = config_dict.Get("path_to_annotate" )) { |
80 | n->path_to_annotate = |
81 | Downcast<Optional<Map<ObjectPath, String>>>(v).value_or(Map<ObjectPath, String>()); |
82 | } |
83 | if (auto v = config_dict.Get("obj_to_underline" )) { |
84 | n->obj_to_underline = Downcast<Optional<Array<ObjectRef>>>(v).value_or(Array<ObjectRef>()); |
85 | } |
86 | if (auto v = config_dict.Get("obj_to_annotate" )) { |
87 | n->obj_to_annotate = |
88 | Downcast<Optional<Map<ObjectRef, String>>>(v).value_or(Map<ObjectRef, String>()); |
89 | } |
90 | if (auto v = config_dict.Get("syntax_sugar" )) { |
91 | n->syntax_sugar = Downcast<IntImm>(v)->value; |
92 | } |
93 | this->data_ = std::move(n); |
94 | } |
95 | |
96 | TVM_REGISTER_NODE_TYPE(PrinterConfigNode); |
97 | TVM_REGISTER_GLOBAL("node.PrinterConfig" ).set_body_typed([](Map<String, ObjectRef> config_dict) { |
98 | return PrinterConfig(config_dict); |
99 | }); |
100 | TVM_REGISTER_GLOBAL("node.TVMScriptPrinterScript" ).set_body_typed(TVMScriptPrinter::Script); |
101 | |
102 | } // namespace tvm |
103 | |