1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/ir/tensor_type.h> |
20 | |
21 | #include "./utils.h" |
22 | |
23 | namespace tvm { |
24 | namespace script { |
25 | namespace printer { |
26 | |
// Register IRFrameNode as a TVM node type; IRFrame is the frame pushed while
// an IRModule is being printed (see the IRModule dispatch below).
TVM_REGISTER_NODE_TYPE(IRFrameNode);
28 | |
29 | struct SortableFunction { |
30 | int priority; |
31 | GlobalVar gv; |
32 | BaseFunc func; |
33 | |
34 | explicit SortableFunction(const std::pair<GlobalVar, BaseFunc>& obj) |
35 | : priority(0), gv(obj.first), func(obj.second) { |
36 | if (gv->name_hint == "main" ) { |
37 | priority = 1000; |
38 | } else if (obj.second->GetTypeKey() == "tir.PrimFunc" ) { |
39 | priority = 1; |
40 | } else if (obj.second->GetTypeKey() == "relax.expr.ExternFunc" ) { |
41 | priority = 2; |
42 | } else if (obj.second->GetTypeKey() == "relax.expr.Function" ) { |
43 | priority = 3; |
44 | } else { |
45 | LOG(FATAL) << "TypeError: TVMScript cannot print functions of type: " |
46 | << obj.second->GetTypeKey(); |
47 | } |
48 | } |
49 | |
50 | bool operator<(const SortableFunction& other) const { |
51 | if (this->priority != other.priority) { |
52 | return this->priority < other.priority; |
53 | } |
54 | return this->gv->name_hint < other.gv->name_hint; |
55 | } |
56 | }; |
57 | |
58 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
59 | .set_dispatch<IRModule>("" , [](IRModule mod, ObjectPath p, IRDocsifier d) -> Doc { |
60 | std::vector<SortableFunction> functions; |
61 | for (const auto& kv : mod->functions) { |
62 | functions.push_back(SortableFunction(kv)); |
63 | } |
64 | std::sort(functions.begin(), functions.end()); |
65 | With<IRFrame> f(d); |
66 | (*f)->AddDispatchToken(d, "ir" ); |
67 | for (const auto& entry : functions) { |
68 | const GlobalVar& gv = entry.gv; |
69 | const BaseFunc& func = entry.func; |
70 | d->cfg->binding_names.push_back(gv->name_hint); |
71 | Doc doc = d->AsDoc(func, p->Attr("functions" )->MapValue(gv)); |
72 | d->cfg->binding_names.pop_back(); |
73 | if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) { |
74 | (*f)->stmts.push_back(stmt_block->stmts.back()); |
75 | } else if (const auto* stmt = doc.as<StmtDocNode>()) { |
76 | (*f)->stmts.push_back(GetRef<StmtDoc>(stmt)); |
77 | } else { |
78 | (*f)->stmts.push_back(Downcast<FunctionDoc>(doc)); |
79 | } |
80 | } |
81 | return HeaderWrapper(d, ClassDoc(IdDoc(GetBindingName(d).value_or("Module" )), |
82 | {IR(d, "ir_module" )}, (*f)->stmts)); |
83 | }); |
84 | |
85 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
86 | .set_dispatch<DictAttrs>("" , [](DictAttrs attrs, ObjectPath p, IRDocsifier d) -> Doc { |
87 | return d->AsDoc(attrs->dict, p->Attr("dict" )); |
88 | }); |
89 | |
90 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
91 | .set_dispatch<GlobalVar>("" , [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { |
92 | return IR(d, "GlobalVar" )->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint" ))}); |
93 | }); |
94 | |
95 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
96 | .set_dispatch<Op>("" , [](Op op, ObjectPath p, IRDocsifier d) -> Doc { |
97 | return IR(d, "Op" )->Call({LiteralDoc::Str(op->name, p->Attr("name" ))}); |
98 | }); |
99 | |
100 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
101 | .set_dispatch<TypeVar>("" , [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc { |
102 | return IR(d, "TypeVar" ) |
103 | ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint" )), // |
104 | LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind" ))}); |
105 | }); |
106 | |
107 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
108 | .set_dispatch<GlobalTypeVar>( // |
109 | "" , [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc { |
110 | return IR(d, "GlobalTypeVar" ) |
111 | ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint" )), |
112 | LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind" ))}); |
113 | }); |
114 | |
115 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
116 | .set_dispatch<RelayRefType>("" , [](RelayRefType ref, ObjectPath p, IRDocsifier d) -> Doc { |
117 | return IR(d, "RelayRef" )->Call({d->AsDoc<ExprDoc>(ref->value, p->Attr("value" ))}); |
118 | }); |
119 | |
120 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
121 | .set_dispatch<TensorType>("" , [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc { |
122 | return IR(d, "TensorType" ) |
123 | ->Call({d->AsDoc<ExprDoc>(type->shape, p->Attr("shape" )), |
124 | LiteralDoc::DataType(type->dtype, p->Attr("dtype" ))}); |
125 | }); |
126 | |
127 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
128 | .set_dispatch<FuncType>("" , [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { |
129 | return IR(d, "FuncType" ) |
130 | ->Call({ |
131 | d->AsDoc<ExprDoc>(func_type->type_params, p->Attr("type_params" )), |
132 | d->AsDoc<ExprDoc>(func_type->arg_types, p->Attr("arg_types" )), |
133 | d->AsDoc<ExprDoc>(func_type->ret_type, p->Attr("ret_type" )), |
134 | }); |
135 | }); |
136 | |
137 | TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) |
138 | .set_dispatch<IncompleteType>("" , [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { |
139 | return IR(d, "IncompleteType" )->Call({}); |
140 | }); |
141 | |
142 | std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) { |
143 | if (const auto* f = runtime::Registry::Get("relay.ir.PrintRelayModule" )) { |
144 | if (Optional<String> s = (*f)(mod)) { |
145 | return s.value(); |
146 | } |
147 | } |
148 | return ReprPrintIR(mod, cfg); |
149 | } |
150 | |
// Install repr printers for the IR node types handled in this file. All use
// ReprPrintIR except IRModuleNode, which goes through ReprPrintIRModule (it
// prefers a registered Relay module printer when one is available).
// NOTE(review): TensorTypeNode and OpNode have docsifier dispatches above but
// no repr registration here; presumably their repr is registered elsewhere --
// confirm before adding one (a duplicate registration may be rejected).
TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(GlobalTypeVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(GlobalVarNode, ReprPrintIR);
TVM_SCRIPT_REPR(DictAttrsNode, ReprPrintIR);
TVM_SCRIPT_REPR(RelayRefTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(FuncTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(IncompleteTypeNode, ReprPrintIR);
TVM_SCRIPT_REPR(IRModuleNode, ReprPrintIRModule);
159 | |
160 | } // namespace printer |
161 | } // namespace script |
162 | } // namespace tvm |
163 | |