1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file text_printer.cc
22 * \brief Printer to print out the unified IR text format
23 * that can be parsed by a parser.
24 */
25
26#include "./text_printer.h"
27
28#include <tvm/tir/function.h>
29
30#include <algorithm>
31#include <string>
32
33namespace tvm {
34namespace relay {
35
// Text-format version string; AsText emits it in the leading "#[version = ...]" header.
// constexpr (rather than a plain `const char*`) so the pointer itself cannot be reassigned.
static constexpr const char* kSemVer = "0.0.5";
37
38Doc TextPrinter::PrintMod(const IRModule& mod) {
39 Doc doc;
40 int counter = 0;
41
42 // We'll print in alphabetical order to make a/b diffs easier to work with.
43
44 // type definitions
45 std::vector<GlobalTypeVar> tyvars;
46 for (const auto& kv : mod->type_definitions) {
47 tyvars.emplace_back(kv.first);
48 }
49 std::sort(tyvars.begin(), tyvars.end(),
50 [](const GlobalTypeVar& left, const GlobalTypeVar& right) {
51 return left->name_hint < right->name_hint;
52 });
53 for (const auto& tyvar : tyvars) {
54 if (counter++ != 0) {
55 doc << Doc::NewLine();
56 }
57 doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
58 doc << Doc::NewLine();
59 }
60
61 // functions
62 std::vector<GlobalVar> vars;
63 for (const auto& kv : mod->functions) {
64 vars.emplace_back(kv.first);
65 }
66 std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
67 return left->name_hint < right->name_hint;
68 });
69 for (const auto& var : vars) {
70 const BaseFunc& base_func = mod->functions[var];
71 if (base_func.as<relay::FunctionNode>()) {
72 relay_text_printer_.dg_ =
73 relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
74 }
75 if (counter++ != 0) {
76 doc << Doc::NewLine();
77 }
78 if (base_func.as<relay::FunctionNode>()) {
79 std::ostringstream os;
80 os << "def @" << var->name_hint;
81 doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
82 } else if (base_func.as<tir::PrimFuncNode>()) {
83 doc << "@" << var->name_hint;
84 doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
85 }
86 doc << Doc::NewLine();
87 }
88
89#if TVM_LOG_DEBUG
90 // attributes
91 // TODO(mbs): Make this official, including support from parser.
92 if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
93 std::vector<String> keys;
94 for (const auto& kv : mod->attrs->dict) {
95 keys.emplace_back(kv.first);
96 }
97 std::sort(keys.begin(), keys.end());
98 doc << "attributes {" << Doc::NewLine();
99 for (const auto& key : keys) {
100 doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
101 }
102 doc << "}" << Doc::NewLine();
103 }
104#endif
105
106 return doc;
107}
108
109String PrettyPrint(const ObjectRef& node) {
110 Doc doc;
111 doc << TextPrinter(/*show_meta_data=*/false, nullptr, false).PrintFinal(node);
112 return doc.str();
113}
114
115String AsText(const ObjectRef& node, bool show_meta_data,
116 runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
117 Doc doc;
118 doc << "#[version = \"" << kSemVer << "\"]" << Doc::NewLine();
119 runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
120 if (annotate != nullptr) {
121 ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
122 [&annotate](const ObjectRef& expr) -> std::string { return annotate(expr); });
123 }
124 doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node);
125 return doc.str();
126}
127
// Register PrettyPrint and AsText under "relay.ir.PrettyPrint" / "relay.ir.AsText" via
// TVM_REGISTER_GLOBAL (presumably so they can be looked up by name, e.g. from the Python FFI).
TVM_REGISTER_GLOBAL("relay.ir.PrettyPrint").set_body_typed(PrettyPrint);
TVM_REGISTER_GLOBAL("relay.ir.AsText").set_body_typed(AsText);
130
131} // namespace relay
132} // namespace tvm
133