1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/runtime/container/base.h>
20#include <tvm/runtime/logging.h>
21#include <tvm/runtime/registry.h>
22#include <tvm/script/printer/ir_docsifier.h>
23
24#include "./utils.h"
25
26namespace tvm {
27namespace script {
28namespace printer {
29
30IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) {
31 ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
32 String name = GenerateUniqueName(name_hint, this->defined_names);
33 this->defined_names.insert(name);
34 DocCreator doc_factory = [name]() { return IdDoc(name); };
35 obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}});
36 IdDoc def_doc(name);
37 frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
38 return def_doc;
39}
40
41void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) {
42 ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
43 obj2info.insert({obj, VariableInfo{std::move(doc_factory), NullOpt}});
44 frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
45}
46
47Optional<ExprDoc> IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const {
48 auto it = obj2info.find(obj);
49 if (it == obj2info.end()) {
50 return NullOpt;
51 }
52 return it->second.creator();
53}
54
55ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
56 ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata";
57 String key = obj->GetTypeKey();
58 Array<ObjectRef>& array = metadata[key];
59 int index = array.size();
60 array.push_back(obj);
61 return IdDoc("metadata") //
62 [{LiteralDoc::Str(key, NullOpt)}] //
63 [{LiteralDoc::Int(index, NullOpt)}];
64}
65
66bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }
67
68void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
69 auto it = obj2info.find(obj);
70 ICHECK(it != obj2info.end()) << "No such object: " << obj;
71 if (it->second.name.defined()) {
72 defined_names.erase(it->second.name.value());
73 }
74 obj2info.erase(it);
75}
76
77void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root,
78 runtime::TypedPackedFunc<bool(ObjectRef)> is_var) {
79 class Visitor : public AttrVisitor {
80 public:
81 inline void operator()(ObjectRef obj) { Visit("", &obj); }
82
83 private:
84 void Visit(const char* key, double* value) final {}
85 void Visit(const char* key, int64_t* value) final {}
86 void Visit(const char* key, uint64_t* value) final {}
87 void Visit(const char* key, int* value) final {}
88 void Visit(const char* key, bool* value) final {}
89 void Visit(const char* key, std::string* value) final {}
90 void Visit(const char* key, void** value) final {}
91 void Visit(const char* key, DataType* value) final {}
92 void Visit(const char* key, runtime::NDArray* value) final {}
93 void Visit(const char* key, ObjectRef* value) final {
94 const Object* obj = value->get();
95 if (obj == nullptr) {
96 return;
97 }
98 stack_.push_back(obj);
99 if (obj->IsInstance<ArrayNode>()) {
100 const ArrayNode* array = static_cast<const ArrayNode*>(obj);
101 for (ObjectRef element : *array) {
102 this->Visit("", &element);
103 }
104 } else if (obj->IsInstance<MapNode>()) {
105 const MapNode* map = static_cast<const MapNode*>(obj);
106 for (std::pair<ObjectRef, ObjectRef> kv : *map) {
107 this->Visit("", &kv.first);
108 this->Visit("", &kv.second);
109 }
110 } else {
111 vtable_->VisitAttrs(const_cast<Object*>(obj), this);
112 }
113 if (is_var(GetRef<ObjectRef>(obj))) {
114 HandleVar(obj);
115 }
116 stack_.pop_back();
117 }
118
119 void HandleVar(const Object* var) {
120 if (common_prefix.count(var) == 0) {
121 common_prefix[var] = stack_;
122 return;
123 }
124 std::vector<const Object*>& a = common_prefix[var];
125 std::vector<const Object*>& b = stack_;
126 int n = std::min(a.size(), b.size());
127 for (int i = 0; i < n; ++i) {
128 if (a[i] != b[i]) {
129 a.resize(i);
130 break;
131 }
132 }
133 }
134
135 ReflectionVTable* vtable_ = ReflectionVTable::Global();
136 std::vector<const Object*> stack_;
137
138 public:
139 runtime::TypedPackedFunc<bool(ObjectRef)> is_var;
140 std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
141 };
142 Visitor visitor;
143 visitor.is_var = is_var;
144 visitor(root);
145 this->common_prefix = std::move(visitor.common_prefix);
146}
147
148IRDocsifier::IRDocsifier(const PrinterConfig& cfg) {
149 auto n = make_object<IRDocsifierNode>();
150 n->cfg = cfg;
151 n->dispatch_tokens.push_back("");
152 data_ = std::move(n);
153}
154
155IRDocsifier::FType& IRDocsifier::vtable() {
156 static IRDocsifier::FType inst;
157 return inst;
158}
159
// Node-type registrations for the printer's FrameNode and IRDocsifierNode.
TVM_REGISTER_NODE_TYPE(FrameNode);
TVM_REGISTER_NODE_TYPE(IRDocsifierNode);
162
163} // namespace printer
164} // namespace script
165} // namespace tvm
166