1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SCRIPT_PRINTER_UTILS_H_
20#define TVM_SCRIPT_PRINTER_UTILS_H_
21
#include <tvm/node/serialization.h>
#include <tvm/script/printer/ir_docsifier.h>

#include <cctype>
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../support/str_escape.h"
31
32namespace tvm {
33namespace script {
34namespace printer {
35
/*!
 * \brief Register the printers for an IR node type.
 *
 * Registers RedirectedReprPrinterMethod as the ReprPrinter dispatch for \p ObjectType and
 * \p Method as its TVMScriptPrinter dispatch.
 *
 * \param ObjectType The node type to register.
 * \param Method The TVMScriptPrinter dispatch function for \p ObjectType.
 */
#define TVM_SCRIPT_REPR(ObjectType, Method)                   \
  TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)                  \
      .set_dispatch<ObjectType>(RedirectedReprPrinterMethod); \
  TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch<ObjectType>(Method);
40
41inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) {
42 try {
43 p->stream << TVMScriptPrinter::Script(obj, NullOpt);
44 } catch (const tvm::Error& e) {
45 if (ReprLegacyPrinter::CanDispatch(obj)) {
46 LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n"
47 << e.what();
48 try {
49 p->stream << AsLegacyRepr(obj);
50 } catch (const tvm::Error& e) {
51 LOG(WARNING) << "AsLegacyRepr fails. Falling back to the basic address printer";
52 }
53 } else {
54 LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n"
55 << e.what();
56 }
57 p->stream << obj->GetTypeKey() << '(' << obj.get() << ')';
58 }
59}
60
61inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f,
62 const PrinterConfig& cfg) {
63 Doc doc = d->AsDoc(obj, ObjectPath::Root());
64 if (const auto* expr_doc = doc.as<ExprDocNode>()) {
65 if (!cfg->verbose_expr) {
66 f->stmts.clear();
67 }
68 f->stmts.push_back(ExprStmtDoc(GetRef<ExprDoc>(expr_doc)));
69 } else if (const auto* stmt_doc = doc.as<StmtDocNode>()) {
70 f->stmts.push_back(GetRef<StmtDoc>(stmt_doc));
71 } else if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
72 for (const StmtDoc& d : stmt_block->stmts) {
73 f->stmts.push_back(d);
74 }
75 } else {
76 LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey();
77 }
78 std::ostringstream os;
79 if (!d->metadata.empty()) {
80 if (d->cfg->show_meta) {
81 os << "metadata = tvm.ir.load_json(\""
82 << support::StrEscape(
83 SaveJSON(Map<String, ObjectRef>(d->metadata.begin(), d->metadata.end())))
84 << "\")\n";
85 } else {
86 f->stmts.push_back(
87 CommentDoc("Metadata omitted. Use show_meta=True in script() method to show it."));
88 }
89 }
90 os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg);
91 return os.str();
92}
93
94/*! \brief Creates the IR common prefix, which is by default `I` */
95inline ExprDoc IR(const IRDocsifier& d, const String& attr) {
96 d->ir_usage.insert("ir");
97 return IdDoc(d->cfg->ir_prefix)->Attr(attr);
98}
99
100/*! \brief Creates the TIR common prefix, which is by default `T` */
101inline ExprDoc TIR(const IRDocsifier& d, const String& attr) {
102 d->ir_usage.insert("tir");
103 return IdDoc(d->cfg->tir_prefix)->Attr(attr);
104}
105
106/*! \brief Creates the TIR common prefix, which is by default `T` */
107inline ExprDoc Relax(const IRDocsifier& d, const String& attr) {
108 d->ir_usage.insert("relax");
109 return IdDoc(d->cfg->relax_prefix)->Attr(attr);
110}
111
112inline std::string DType2Str(const runtime::DataType& dtype) {
113 return dtype.is_void() ? "void" : runtime::DLDataType2String(dtype);
114}
115
116/*! \brief Add headers as comments to doc if needed */
117inline Doc HeaderWrapper(const IRDocsifier& d, const Doc& doc) {
118 if (d->ir_usage.size()) {
119 Array<StmtDoc> stmts;
120 if (d->ir_usage.count("ir")) {
121 stmts.push_back(CommentDoc("from tvm.script import ir as " + d->cfg->ir_prefix));
122 }
123 if (d->ir_usage.count("tir")) {
124 stmts.push_back(CommentDoc("from tvm.script import tir as " + d->cfg->tir_prefix));
125 }
126 if (d->ir_usage.count("relax")) {
127 stmts.push_back(CommentDoc("from tvm.script import relax as " + d->cfg->relax_prefix));
128 }
129 stmts.push_back(CommentDoc(""));
130 stmts.push_back(Downcast<StmtDoc>(doc));
131 return StmtBlockDoc(stmts);
132 }
133 return doc;
134}
135
/*!
 * \brief Check whether a string spans more than one line.
 * \param str The text to inspect.
 * \return true iff \p str contains at least one '\n' character.
 */
inline bool HasMultipleLines(const std::string& str) {
  const std::size_t newline_pos = str.find('\n');
  return newline_pos != std::string::npos;
}
140
141inline Optional<String> GetBindingName(const IRDocsifier& d) {
142 return d->cfg->binding_names.empty() ? Optional<String>(NullOpt) : d->cfg->binding_names.back();
143}
144
145inline Optional<String> FindFunctionName(const IRDocsifier& d, const BaseFunc& f) {
146 if (Optional<String> name = GetBindingName(d)) {
147 return name.value();
148 }
149 if (Optional<String> sym = f->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
150 return sym.value();
151 }
152 return NullOpt;
153}
154
155inline String GenerateUniqueName(std::string name_hint,
156 const std::unordered_set<String>& defined_names) {
157 for (char& c : name_hint) {
158 if (c != '_' && !std::isalnum(c)) {
159 c = '_';
160 }
161 }
162 std::string name = name_hint;
163 for (int i = 1; defined_names.count(name) > 0; ++i) {
164 name = name_hint + "_" + std::to_string(i);
165 }
166 return name;
167}
168
169} // namespace printer
170} // namespace script
171} // namespace tvm
172
173#endif // TVM_SCRIPT_PRINTER_UTILS_H_
174