1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/ir/tensor_type.h>
20
21#include "./utils.h"
22
23namespace tvm {
24namespace script {
25namespace printer {
26
// Register IRFrameNode as a TVM node type (the frame pushed by the IRModule
// printer below via With<IRFrame>).
TVM_REGISTER_NODE_TYPE(IRFrameNode);
28
29struct SortableFunction {
30 int priority;
31 GlobalVar gv;
32 BaseFunc func;
33
34 explicit SortableFunction(const std::pair<GlobalVar, BaseFunc>& obj)
35 : priority(0), gv(obj.first), func(obj.second) {
36 if (gv->name_hint == "main") {
37 priority = 1000;
38 } else if (obj.second->GetTypeKey() == "tir.PrimFunc") {
39 priority = 1;
40 } else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc") {
41 priority = 2;
42 } else if (obj.second->GetTypeKey() == "relax.expr.Function") {
43 priority = 3;
44 } else {
45 LOG(FATAL) << "TypeError: TVMScript cannot print functions of type: "
46 << obj.second->GetTypeKey();
47 }
48 }
49
50 bool operator<(const SortableFunction& other) const {
51 if (this->priority != other.priority) {
52 return this->priority < other.priority;
53 }
54 return this->gv->name_hint < other.gv->name_hint;
55 }
56};
57
58TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
59 .set_dispatch<IRModule>("", [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc {
60 std::vector<SortableFunction> functions;
61 for (const auto& kv : mod->functions) {
62 functions.push_back(SortableFunction(kv));
63 }
64 std::sort(functions.begin(), functions.end());
65 With<IRFrame> f(d);
66 (*f)->AddDispatchToken(d, "ir");
67 for (const auto& entry : functions) {
68 const GlobalVar& gv = entry.gv;
69 const BaseFunc& func = entry.func;
70 d->cfg->binding_names.push_back(gv->name_hint);
71 Doc doc = d->AsDoc(func, p->Attr("functions")->MapValue(gv));
72 d->cfg->binding_names.pop_back();
73 if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
74 (*f)->stmts.push_back(stmt_block->stmts.back());
75 } else if (const auto* stmt = doc.as<StmtDocNode>()) {
76 (*f)->stmts.push_back(GetRef<StmtDoc>(stmt));
77 } else {
78 (*f)->stmts.push_back(Downcast<FunctionDoc>(doc));
79 }
80 }
81 return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module")),
82 {IR(d, "ir_module")}, (*f)->stmts));
83 });
84
85TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
86 .set_dispatch<DictAttrs>("", [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc {
87 return d->AsDoc(attrs->dict, p->Attr("dict"));
88 });
89
90TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
91 .set_dispatch<GlobalVar>("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc {
92 return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))});
93 });
94
95TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
96 .set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
97 return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
98 });
99
100TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
101 .set_dispatch<TypeVar>("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc {
102 return IR(d, "TypeVar")
103 ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), //
104 LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))});
105 });
106
107TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
108 .set_dispatch<GlobalTypeVar>( //
109 "", [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc {
110 return IR(d, "GlobalTypeVar")
111 ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")),
112 LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))});
113 });
114
115TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
116 .set_dispatch<RelayRefType>("", [](RelayRefType ref, ObjectPath p, IRDocsifier d) -> Doc {
117 return IR(d, "RelayRef")->Call({d->AsDoc<ExprDoc>(ref->value, p->Attr("value"))});
118 });
119
120TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
121 .set_dispatch<TensorType>("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc {
122 return IR(d, "TensorType")
123 ->Call({d->AsDoc<ExprDoc>(type->shape, p->Attr("shape")),
124 LiteralDoc::DataType(type->dtype, p->Attr("dtype"))});
125 });
126
127TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
128 .set_dispatch<FuncType>("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc {
129 return IR(d, "FuncType")
130 ->Call({
131 d->AsDoc<ExprDoc>(func_type->type_params, p->Attr("type_params")),
132 d->AsDoc<ExprDoc>(func_type->arg_types, p->Attr("arg_types")),
133 d->AsDoc<ExprDoc>(func_type->ret_type, p->Attr("ret_type")),
134 });
135 });
136
137TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
138 .set_dispatch<IncompleteType>("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc {
139 return IR(d, "IncompleteType")->Call({});
140 });
141
142std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) {
143 if (const auto* f = runtime::Registry::Get("relay.ir.PrintRelayModule")) {
144 if (Optional<String> s = (*f)(mod)) {
145 return s.value();
146 }
147 }
148 return ReprPrintIR(mod, cfg);
149}
150
// Route the repr of these IR node types through the TVMScript printer.
// IRModule uses ReprPrintIRModule, which tries the Relay module printer first.
TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(GlobalTypeVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(DictAttrsNode, ReprPrintIR);
TVM_SCRIPT_REPR(RelayRefTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(FuncTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(IncompleteTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(IRModuleNode, ReprPrintIRModule);
159
160} // namespace printer
161} // namespace script
162} // namespace tvm
163