1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*
20 * \file src/runtime/object.cc
21 * \brief Object type management system.
22 */
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/registry.h>

#include <atomic>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <limits>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "object_internal.h"
#include "runtime_base.h"
36
37namespace tvm {
38namespace runtime {
39
/*!
 * \brief Registry record describing a single runtime type.
 *
 * The type owns the index range [index, index + num_slots): the first slot is
 * the type itself, the remaining slots are reserved for its child classes.
 */
struct TypeInfo {
  /*! \brief Type index assigned to this type. */
  uint32_t index = 0;
  /*! \brief Type index of the direct parent in the type hierarchy. */
  uint32_t parent_index = 0;
  /*! \brief Total number of slots reserved for the type itself and its children. */
  uint32_t num_slots = 0;
  /*! \brief Number of slots in the reserved range already handed out (own slot included). */
  uint32_t allocated_slots = 0;
  /*! \brief Whether children may be placed outside the reserved range once it is full. */
  bool child_slots_can_overflow = true;
  /*! \brief Type key (name) of the type. */
  std::string name;
  /*! \brief std::hash of the type key. */
  size_t name_hash = 0;
};
59
60/*!
61 * \brief Type context that manages the type hierarchy information.
62 */
63class TypeContext {
64 public:
65 // NOTE: this is a relatively slow path for child checking
66 // Most types are already checked by the fast-path via reserved slot checking.
67 bool DerivedFrom(uint32_t child_tindex, uint32_t parent_tindex) {
68 // invariance: child's type index is always bigger than its parent.
69 if (child_tindex < parent_tindex) return false;
70 if (child_tindex == parent_tindex) return true;
71 {
72 std::lock_guard<std::mutex> lock(mutex_);
73 ICHECK_LT(child_tindex, type_table_.size());
74 while (child_tindex > parent_tindex) {
75 child_tindex = type_table_[child_tindex].parent_index;
76 }
77 }
78 return child_tindex == parent_tindex;
79 }
80
81 uint32_t GetOrAllocRuntimeTypeIndex(const std::string& skey, uint32_t static_tindex,
82 uint32_t parent_tindex, uint32_t num_child_slots,
83 bool child_slots_can_overflow) {
84 std::lock_guard<std::mutex> lock(mutex_);
85 auto it = type_key2index_.find(skey);
86 if (it != type_key2index_.end()) {
87 return it->second;
88 }
89 // try to allocate from parent's type table.
90 ICHECK_LT(parent_tindex, type_table_.size())
91 << " skey=" << skey << ", static_index=" << static_tindex;
92 TypeInfo& pinfo = type_table_[parent_tindex];
93 ICHECK_EQ(pinfo.index, parent_tindex);
94
95 // if parent cannot overflow, then this class cannot.
96 if (!pinfo.child_slots_can_overflow) {
97 child_slots_can_overflow = false;
98 }
99
100 // total number of slots include the type itself.
101 uint32_t num_slots = num_child_slots + 1;
102 uint32_t allocated_tindex;
103
104 if (static_tindex != TypeIndex::kDynamic) {
105 // statically assigned type
106 VLOG(3) << "TypeIndex[" << static_tindex << "]: static: " << skey << ", parent "
107 << type_table_[parent_tindex].name;
108 allocated_tindex = static_tindex;
109 ICHECK_LT(static_tindex, type_table_.size());
110 ICHECK_EQ(type_table_[allocated_tindex].allocated_slots, 0U)
111 << "Conflicting static index " << static_tindex << " between "
112 << type_table_[allocated_tindex].name << " and " << skey;
113 } else if (pinfo.allocated_slots + num_slots <= pinfo.num_slots) {
114 // allocate the slot from parent's reserved pool
115 allocated_tindex = parent_tindex + pinfo.allocated_slots;
116 VLOG(3) << "TypeIndex[" << allocated_tindex << "]: dynamic: " << skey << ", parent "
117 << type_table_[parent_tindex].name;
118 // update parent's state
119 pinfo.allocated_slots += num_slots;
120 } else {
121 VLOG(3) << "TypeIndex[" << type_counter_ << "]: dynamic (overflow): " << skey << ", parent "
122 << type_table_[parent_tindex].name;
123 ICHECK(pinfo.child_slots_can_overflow)
124 << "Reach maximum number of sub-classes for " << pinfo.name;
125 // allocate new entries.
126 allocated_tindex = type_counter_;
127 type_counter_ += num_slots;
128 ICHECK_LE(type_table_.size(), type_counter_);
129 type_table_.resize(type_counter_, TypeInfo());
130 }
131 ICHECK_GT(allocated_tindex, parent_tindex);
132 // initialize the slot.
133 type_table_[allocated_tindex].index = allocated_tindex;
134 type_table_[allocated_tindex].parent_index = parent_tindex;
135 type_table_[allocated_tindex].num_slots = num_slots;
136 type_table_[allocated_tindex].allocated_slots = 1;
137 type_table_[allocated_tindex].child_slots_can_overflow = child_slots_can_overflow;
138 type_table_[allocated_tindex].name = skey;
139 type_table_[allocated_tindex].name_hash = std::hash<std::string>()(skey);
140 // update the key2index mapping.
141 type_key2index_[skey] = allocated_tindex;
142 return allocated_tindex;
143 }
144
145 std::string TypeIndex2Key(uint32_t tindex) {
146 std::lock_guard<std::mutex> lock(mutex_);
147 if (tindex != 0) {
148 // always return the right type key for root
149 // for non-root type nodes, allocated slots should not equal 0
150 ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
151 << "Unknown type index " << tindex;
152 }
153 return type_table_[tindex].name;
154 }
155
156 size_t TypeIndex2KeyHash(uint32_t tindex) {
157 std::lock_guard<std::mutex> lock(mutex_);
158 ICHECK(tindex < type_table_.size() && type_table_[tindex].allocated_slots != 0)
159 << "Unknown type index " << tindex;
160 return type_table_[tindex].name_hash;
161 }
162
163 uint32_t TypeKey2Index(const std::string& skey) {
164 auto it = type_key2index_.find(skey);
165 ICHECK(it != type_key2index_.end())
166 << "Cannot find type " << skey
167 << ". Did you forget to register the node by TVM_REGISTER_NODE_TYPE ?";
168 return it->second;
169 }
170
171 void Dump(int min_children_count) {
172 std::vector<int> num_children(type_table_.size(), 0);
173 // reverse accumulation so we can get total counts in a bottom-up manner.
174 for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
175 if (it->index != 0) {
176 num_children[it->parent_index] += num_children[it->index] + 1;
177 }
178 }
179
180 for (const auto& info : type_table_) {
181 if (info.index != 0 && num_children[info.index] >= min_children_count) {
182 std::cerr << '[' << info.index << "] " << info.name
183 << "\tparent=" << type_table_[info.parent_index].name
184 << "\tnum_child_slots=" << info.num_slots - 1
185 << "\tnum_children=" << num_children[info.index] << std::endl;
186 }
187 }
188 }
189
190 static TypeContext* Global() {
191 static TypeContext inst;
192 return &inst;
193 }
194
195 private:
196 TypeContext() {
197 type_table_.resize(TypeIndex::kStaticIndexEnd, TypeInfo());
198 type_table_[0].name = "runtime.Object";
199 }
200 // mutex to avoid registration from multiple threads.
201 std::mutex mutex_;
202 std::atomic<uint32_t> type_counter_{TypeIndex::kStaticIndexEnd};
203 std::vector<TypeInfo> type_table_;
204 std::unordered_map<std::string, uint32_t> type_key2index_;
205};
206
207uint32_t Object::GetOrAllocRuntimeTypeIndex(const std::string& key, uint32_t static_tindex,
208 uint32_t parent_tindex, uint32_t num_child_slots,
209 bool child_slots_can_overflow) {
210 return TypeContext::Global()->GetOrAllocRuntimeTypeIndex(
211 key, static_tindex, parent_tindex, num_child_slots, child_slots_can_overflow);
212}
213
214bool Object::DerivedFrom(uint32_t parent_tindex) const {
215 return TypeContext::Global()->DerivedFrom(this->type_index_, parent_tindex);
216}
217
218std::string Object::TypeIndex2Key(uint32_t tindex) {
219 return TypeContext::Global()->TypeIndex2Key(tindex);
220}
221
222size_t Object::TypeIndex2KeyHash(uint32_t tindex) {
223 return TypeContext::Global()->TypeIndex2KeyHash(tindex);
224}
225
226uint32_t Object::TypeKey2Index(const std::string& key) {
227 return TypeContext::Global()->TypeKey2Index(key);
228}
229
230TVM_REGISTER_GLOBAL("runtime.ObjectPtrHash").set_body_typed([](ObjectRef obj) {
231 return static_cast<int64_t>(ObjectPtrHash()(obj));
232});
233
234TVM_REGISTER_GLOBAL("runtime.DumpTypeTable").set_body_typed([](int min_child_count) {
235 TypeContext::Global()->Dump(min_child_count);
236});
237} // namespace runtime
238} // namespace tvm
239
240int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
241 API_BEGIN();
242 ICHECK(obj != nullptr);
243 out_tindex[0] = static_cast<tvm::runtime::Object*>(obj)->type_index();
244 API_END();
245}
246
247int TVMObjectRetain(TVMObjectHandle obj) {
248 API_BEGIN();
249 tvm::runtime::ObjectInternal::ObjectRetain(obj);
250 API_END();
251}
252
253int TVMObjectFree(TVMObjectHandle obj) {
254 API_BEGIN();
255 tvm::runtime::ObjectInternal::ObjectFree(obj);
256 API_END();
257}
258
259int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) {
260 API_BEGIN();
261 *is_derived =
262 tvm::runtime::TypeContext::Global()->DerivedFrom(child_type_index, parent_type_index);
263 API_END();
264}
265
266int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
267 API_BEGIN();
268 out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(type_key);
269 API_END();
270}
271
272int TVMObjectTypeIndex2Key(unsigned tindex, char** out_type_key) {
273 API_BEGIN();
274 auto key = tvm::runtime::Object::TypeIndex2Key(tindex);
275 *out_type_key = static_cast<char*>(malloc(key.size() + 1));
276 strncpy(*out_type_key, key.c_str(), key.size() + 1);
277 API_END();
278}
279