1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file text_printer.cc |
22 | * \brief Printer to print out the unified IR text format |
23 | * that can be parsed by a parser. |
24 | */ |
25 | |
26 | #include "./text_printer.h" |
27 | |
28 | #include <tvm/tir/function.h> |
29 | |
30 | #include <algorithm> |
31 | #include <string> |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
/*!
 * \brief Version string of the text format, written into the
 *        `#[version = "..."]` header line emitted by AsText.
 *
 * Declared constexpr so the pointer itself is immutable (the previous
 * `static const char*` only made the pointee const, leaving the pointer
 * re-assignable at runtime).
 */
static constexpr const char* kSemVer = "0.0.5";
37 | |
38 | Doc TextPrinter::PrintMod(const IRModule& mod) { |
39 | Doc doc; |
40 | int counter = 0; |
41 | |
42 | // We'll print in alphabetical order to make a/b diffs easier to work with. |
43 | |
44 | // type definitions |
45 | std::vector<GlobalTypeVar> tyvars; |
46 | for (const auto& kv : mod->type_definitions) { |
47 | tyvars.emplace_back(kv.first); |
48 | } |
49 | std::sort(tyvars.begin(), tyvars.end(), |
50 | [](const GlobalTypeVar& left, const GlobalTypeVar& right) { |
51 | return left->name_hint < right->name_hint; |
52 | }); |
53 | for (const auto& tyvar : tyvars) { |
54 | if (counter++ != 0) { |
55 | doc << Doc::NewLine(); |
56 | } |
57 | doc << relay_text_printer_.Print(mod->type_definitions[tyvar]); |
58 | doc << Doc::NewLine(); |
59 | } |
60 | |
61 | // functions |
62 | std::vector<GlobalVar> vars; |
63 | for (const auto& kv : mod->functions) { |
64 | vars.emplace_back(kv.first); |
65 | } |
66 | std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) { |
67 | return left->name_hint < right->name_hint; |
68 | }); |
69 | for (const auto& var : vars) { |
70 | const BaseFunc& base_func = mod->functions[var]; |
71 | if (base_func.as<relay::FunctionNode>()) { |
72 | relay_text_printer_.dg_ = |
73 | relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func); |
74 | } |
75 | if (counter++ != 0) { |
76 | doc << Doc::NewLine(); |
77 | } |
78 | if (base_func.as<relay::FunctionNode>()) { |
79 | std::ostringstream os; |
80 | os << "def @" << var->name_hint; |
81 | doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func); |
82 | } else if (base_func.as<tir::PrimFuncNode>()) { |
83 | doc << "@" << var->name_hint; |
84 | doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func)); |
85 | } |
86 | doc << Doc::NewLine(); |
87 | } |
88 | |
89 | #if TVM_LOG_DEBUG |
90 | // attributes |
91 | // TODO(mbs): Make this official, including support from parser. |
92 | if (mod->attrs.defined() && !mod->attrs->dict.empty()) { |
93 | std::vector<String> keys; |
94 | for (const auto& kv : mod->attrs->dict) { |
95 | keys.emplace_back(kv.first); |
96 | } |
97 | std::sort(keys.begin(), keys.end()); |
98 | doc << "attributes {" << Doc::NewLine(); |
99 | for (const auto& key : keys) { |
100 | doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine(); |
101 | } |
102 | doc << "}" << Doc::NewLine(); |
103 | } |
104 | #endif |
105 | |
106 | return doc; |
107 | } |
108 | |
109 | String PrettyPrint(const ObjectRef& node) { |
110 | Doc doc; |
111 | doc << TextPrinter(/*show_meta_data=*/false, nullptr, false).PrintFinal(node); |
112 | return doc.str(); |
113 | } |
114 | |
115 | String AsText(const ObjectRef& node, bool show_meta_data, |
116 | runtime::TypedPackedFunc<String(ObjectRef)> annotate) { |
117 | Doc doc; |
118 | doc << "#[version = \"" << kSemVer << "\"]" << Doc::NewLine(); |
119 | runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr; |
120 | if (annotate != nullptr) { |
121 | ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>( |
122 | [&annotate](const ObjectRef& expr) -> std::string { return annotate(expr); }); |
123 | } |
124 | doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node); |
125 | return doc.str(); |
126 | } |
127 | |
128 | TVM_REGISTER_GLOBAL("relay.ir.PrettyPrint" ).set_body_typed(PrettyPrint); |
129 | TVM_REGISTER_GLOBAL("relay.ir.AsText" ).set_body_typed(AsText); |
130 | |
131 | } // namespace relay |
132 | } // namespace tvm |
133 | |