1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file text_printer.cc
 * \brief Printer that emits the unified IR text format,
 * which can be read back by the text parser.
24 */
25
26#include "text_printer.h"
27
28#include <tvm/tir/function.h>
29
30#include <algorithm>
31#include <string>
32
33namespace tvm {
34
/*!
 * \brief Version of the text format, written into the `#[version = "..."]` header by AsText.
 *
 * Declared `constexpr const char*` so that both the pointer and the characters it points to
 * are immutable and the value is fixed at compile time.
 */
static constexpr const char* kSemVer = "0.0.5";
36
37Doc TextPrinter::PrintMod(const IRModule& mod) {
38 Doc doc;
39 int counter = 0;
40
41 // We'll print in alphabetical order to make a/b diffs easier to work with.
42
43 // type definitions
44 std::vector<GlobalTypeVar> tyvars;
45 for (const auto& kv : mod->type_definitions) {
46 tyvars.emplace_back(kv.first);
47 }
48 std::sort(tyvars.begin(), tyvars.end(),
49 [](const GlobalTypeVar& left, const GlobalTypeVar& right) {
50 return left->name_hint < right->name_hint;
51 });
52 for (const auto& tyvar : tyvars) {
53 if (counter++ != 0) {
54 doc << Doc::NewLine();
55 }
56 doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
57 doc << Doc::NewLine();
58 }
59
60 // functions
61 std::vector<GlobalVar> vars;
62 for (const auto& kv : mod->functions) {
63 vars.emplace_back(kv.first);
64 }
65 std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
66 return left->name_hint < right->name_hint;
67 });
68 for (const auto& var : vars) {
69 const BaseFunc& base_func = mod->functions[var];
70 if (base_func.as<relay::FunctionNode>()) {
71 relay_text_printer_.dg_ =
72 relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
73 }
74 if (counter++ != 0) {
75 doc << Doc::NewLine();
76 }
77 if (base_func.as<relay::FunctionNode>()) {
78 std::ostringstream os;
79 os << "def @" << var->name_hint;
80 doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
81 } else if (base_func.as<tir::PrimFuncNode>()) {
82 doc << "@" << var->name_hint;
83 doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
84 }
85 doc << Doc::NewLine();
86 }
87
88#if TVM_LOG_DEBUG
89 // attributes
90 // TODO(mbs): Make this official, including support from parser.
91 if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
92 std::vector<String> keys;
93 for (const auto& kv : mod->attrs->dict) {
94 keys.emplace_back(kv.first);
95 }
96 std::sort(keys.begin(), keys.end());
97 doc << "attributes {" << Doc::NewLine();
98 for (const auto& key : keys) {
99 doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
100 }
101 doc << "}" << Doc::NewLine();
102 }
103#endif
104
105 return doc;
106}
107
108String PrettyPrint(const ObjectRef& node) {
109 Doc doc;
110 doc << TextPrinter(/*show_meta_data=*/false, nullptr, false).PrintFinal(node);
111 return doc.str();
112}
113
114String AsText(const ObjectRef& node, bool show_meta_data,
115 runtime::TypedPackedFunc<String(ObjectRef)> annotate) {
116 Doc doc;
117 doc << "#[version = \"" << kSemVer << "\"]" << Doc::NewLine();
118 runtime::TypedPackedFunc<std::string(ObjectRef)> ftyped = nullptr;
119 if (annotate != nullptr) {
120 ftyped = runtime::TypedPackedFunc<std::string(ObjectRef)>(
121 [&annotate](const ObjectRef& expr) -> std::string { return annotate(expr); });
122 }
123 doc << TextPrinter(show_meta_data, ftyped).PrintFinal(node);
124 return doc.str();
125}
126
// Register PrettyPrint and AsText as global functions named "ir.PrettyPrint" and "ir.AsText",
// with their typed signatures taken from the C++ definitions above.
TVM_REGISTER_GLOBAL("ir.PrettyPrint").set_body_typed(PrettyPrint);

TVM_REGISTER_GLOBAL("ir.AsText").set_body_typed(AsText);
130
131} // namespace tvm
132