1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/runtime/container/base.h> |
20 | #include <tvm/runtime/logging.h> |
21 | #include <tvm/runtime/registry.h> |
22 | #include <tvm/script/printer/ir_docsifier.h> |
23 | |
24 | #include "./utils.h" |
25 | |
26 | namespace tvm { |
27 | namespace script { |
28 | namespace printer { |
29 | |
30 | IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { |
31 | ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; |
32 | String name = GenerateUniqueName(name_hint, this->defined_names); |
33 | this->defined_names.insert(name); |
34 | DocCreator doc_factory = [name]() { return IdDoc(name); }; |
35 | obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); |
36 | IdDoc def_doc(name); |
37 | frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); |
38 | return def_doc; |
39 | } |
40 | |
41 | void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) { |
42 | ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; |
43 | obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}}); |
44 | frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); }); |
45 | } |
46 | |
47 | Optional<ExprDoc> IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const { |
48 | auto it = obj2info.find(obj); |
49 | if (it == obj2info.end()) { |
50 | return NullOpt; |
51 | } |
52 | return it->second.creator(); |
53 | } |
54 | |
55 | ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { |
56 | ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata" ; |
57 | String key = obj->GetTypeKey(); |
58 | Array<ObjectRef>& array = metadata[key]; |
59 | int index = array.size(); |
60 | array.push_back(obj); |
61 | return IdDoc("metadata" ) // |
62 | [{LiteralDoc::Str(key, NullOpt)}] // |
63 | [{LiteralDoc::Int(index, NullOpt)}]; |
64 | } |
65 | |
66 | bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } |
67 | |
68 | void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { |
69 | auto it = obj2info.find(obj); |
70 | ICHECK(it != obj2info.end()) << "No such object: " << obj; |
71 | if (it->second.name.defined()) { |
72 | defined_names.erase(it->second.name.value()); |
73 | } |
74 | obj2info.erase(it); |
75 | } |
76 | |
77 | void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, |
78 | runtime::TypedPackedFunc<bool(ObjectRef)> is_var) { |
79 | class Visitor : public AttrVisitor { |
80 | public: |
81 | inline void operator()(ObjectRef obj) { Visit("" , &obj); } |
82 | |
83 | private: |
84 | void Visit(const char* key, double* value) final {} |
85 | void Visit(const char* key, int64_t* value) final {} |
86 | void Visit(const char* key, uint64_t* value) final {} |
87 | void Visit(const char* key, int* value) final {} |
88 | void Visit(const char* key, bool* value) final {} |
89 | void Visit(const char* key, std::string* value) final {} |
90 | void Visit(const char* key, void** value) final {} |
91 | void Visit(const char* key, DataType* value) final {} |
92 | void Visit(const char* key, runtime::NDArray* value) final {} |
93 | void Visit(const char* key, ObjectRef* value) final { |
94 | const Object* obj = value->get(); |
95 | if (obj == nullptr) { |
96 | return; |
97 | } |
98 | stack_.push_back(obj); |
99 | if (obj->IsInstance<ArrayNode>()) { |
100 | const ArrayNode* array = static_cast<const ArrayNode*>(obj); |
101 | for (ObjectRef element : *array) { |
102 | this->Visit("" , &element); |
103 | } |
104 | } else if (obj->IsInstance<MapNode>()) { |
105 | const MapNode* map = static_cast<const MapNode*>(obj); |
106 | for (std::pair<ObjectRef, ObjectRef> kv : *map) { |
107 | this->Visit("" , &kv.first); |
108 | this->Visit("" , &kv.second); |
109 | } |
110 | } else { |
111 | vtable_->VisitAttrs(const_cast<Object*>(obj), this); |
112 | } |
113 | if (is_var(GetRef<ObjectRef>(obj))) { |
114 | HandleVar(obj); |
115 | } |
116 | stack_.pop_back(); |
117 | } |
118 | |
119 | void HandleVar(const Object* var) { |
120 | if (common_prefix.count(var) == 0) { |
121 | common_prefix[var] = stack_; |
122 | return; |
123 | } |
124 | std::vector<const Object*>& a = common_prefix[var]; |
125 | std::vector<const Object*>& b = stack_; |
126 | int n = std::min(a.size(), b.size()); |
127 | for (int i = 0; i < n; ++i) { |
128 | if (a[i] != b[i]) { |
129 | a.resize(i); |
130 | break; |
131 | } |
132 | } |
133 | } |
134 | |
135 | ReflectionVTable* vtable_ = ReflectionVTable::Global(); |
136 | std::vector<const Object*> stack_; |
137 | |
138 | public: |
139 | runtime::TypedPackedFunc<bool(ObjectRef)> is_var; |
140 | std::unordered_map<const Object*, std::vector<const Object*>> common_prefix; |
141 | }; |
142 | Visitor visitor; |
143 | visitor.is_var = is_var; |
144 | visitor(root); |
145 | this->common_prefix = std::move(visitor.common_prefix); |
146 | } |
147 | |
148 | IRDocsifier::IRDocsifier(const PrinterConfig& cfg) { |
149 | auto n = make_object<IRDocsifierNode>(); |
150 | n->cfg = cfg; |
151 | n->dispatch_tokens.push_back("" ); |
152 | data_ = std::move(n); |
153 | } |
154 | |
155 | IRDocsifier::FType& IRDocsifier::vtable() { |
156 | static IRDocsifier::FType inst; |
157 | return inst; |
158 | } |
159 | |
// Register FrameNode and IRDocsifierNode with TVM's node-type registry via
// TVM_REGISTER_NODE_TYPE, so that they take part in the reflection/object system.
TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
162 | |
163 | } // namespace printer |
164 | } // namespace script |
165 | } // namespace tvm |
166 | |