1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/node/attr_registry.h |
22 | * \brief Common global registry for objects that also have additional attrs. |
23 | */ |
24 | #ifndef TVM_NODE_ATTR_REGISTRY_H_ |
25 | #define TVM_NODE_ATTR_REGISTRY_H_ |
26 | |
27 | #include <tvm/node/attr_registry_map.h> |
28 | #include <tvm/runtime/packed_func.h> |
29 | |
30 | #include <memory> |
31 | #include <mutex> |
32 | #include <unordered_map> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | namespace tvm { |
37 | |
38 | /*! |
39 | * \brief Implementation of registry with attributes. |
40 | * |
41 | * \tparam EntryType The type of the registry entry. |
42 | * \tparam KeyType The actual key that is used to lookup the attributes. |
43 | * each entry has a corresponding key by default. |
44 | */ |
45 | template <typename EntryType, typename KeyType> |
46 | class AttrRegistry { |
47 | public: |
48 | using TSelf = AttrRegistry<EntryType, KeyType>; |
49 | /*! |
50 | * \brief Get an entry from the registry. |
51 | * \param name The name of the item. |
52 | * \return The corresponding entry. |
53 | */ |
54 | const EntryType* Get(const String& name) const { |
55 | auto it = entry_map_.find(name); |
56 | if (it != entry_map_.end()) return it->second; |
57 | return nullptr; |
58 | } |
59 | |
60 | /*! |
61 | * \brief Get an entry or register a new one. |
62 | * \param name The name of the item. |
63 | * \return The corresponding entry. |
64 | */ |
65 | EntryType& RegisterOrGet(const String& name) { |
66 | auto it = entry_map_.find(name); |
67 | if (it != entry_map_.end()) return *it->second; |
68 | uint32_t registry_index = static_cast<uint32_t>(entries_.size()); |
69 | auto entry = std::unique_ptr<EntryType>(new EntryType(registry_index)); |
70 | auto* eptr = entry.get(); |
71 | eptr->name = name; |
72 | entry_map_[name] = eptr; |
73 | entries_.emplace_back(std::move(entry)); |
74 | return *eptr; |
75 | } |
76 | |
77 | /*! |
78 | * \brief List all the entry names in the registry. |
79 | * \return The entry names. |
80 | */ |
81 | Array<String> ListAllNames() const { |
82 | Array<String> names; |
83 | for (const auto& kv : entry_map_) { |
84 | names.push_back(kv.first); |
85 | } |
86 | return names; |
87 | } |
88 | |
89 | /*! |
90 | * \brief Update the attribute stable. |
91 | * \param attr_name The name of the attribute. |
92 | * \param key The key to the attribute table. |
93 | * \param value The value to be set. |
94 | * \param plevel The support level. |
95 | */ |
96 | void UpdateAttr(const String& attr_name, const KeyType& key, runtime::TVMRetValue value, |
97 | int plevel) { |
98 | using runtime::TVMRetValue; |
99 | std::lock_guard<std::mutex> lock(mutex_); |
100 | auto& op_map = attrs_[attr_name]; |
101 | if (op_map == nullptr) { |
102 | op_map.reset(new AttrRegistryMapContainerMap<KeyType>()); |
103 | op_map->attr_name_ = attr_name; |
104 | } |
105 | |
106 | uint32_t index = key->AttrRegistryIndex(); |
107 | if (op_map->data_.size() <= index) { |
108 | op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0)); |
109 | } |
110 | std::pair<TVMRetValue, int>& p = op_map->data_[index]; |
111 | ICHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName() |
112 | << " is already registered with same plevel=" << plevel; |
113 | ICHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name |
114 | << " of operator " << key->AttrRegistryName(); |
115 | if (p.second < plevel && value.type_code() != kTVMNullptr) { |
116 | op_map->data_[index] = std::make_pair(value, plevel); |
117 | } |
118 | } |
119 | |
120 | /*! |
121 | * \brief Reset an attribute table entry. |
122 | * \param attr_name The name of the attribute. |
123 | * \param key The key to the attribute table. |
124 | */ |
125 | void ResetAttr(const String& attr_name, const KeyType& key) { |
126 | std::lock_guard<std::mutex> lock(mutex_); |
127 | auto& op_map = attrs_[attr_name]; |
128 | if (op_map == nullptr) { |
129 | return; |
130 | } |
131 | uint32_t index = key->AttrRegistryIndex(); |
132 | if (op_map->data_.size() > index) { |
133 | op_map->data_[index] = std::make_pair(TVMRetValue(), 0); |
134 | } |
135 | } |
136 | |
137 | /*! |
138 | * \brief Get an internal attribute map. |
139 | * \param attr_name The name of the attribute. |
140 | * \return The result attribute map. |
141 | */ |
142 | const AttrRegistryMapContainerMap<KeyType>& GetAttrMap(const String& attr_name) { |
143 | std::lock_guard<std::mutex> lock(mutex_); |
144 | auto it = attrs_.find(attr_name); |
145 | if (it == attrs_.end()) { |
146 | LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered" ; |
147 | } |
148 | return *it->second.get(); |
149 | } |
150 | |
151 | /*! |
152 | * \brief Check of attribute has been registered. |
153 | * \param attr_name The name of the attribute. |
154 | * \return The check result. |
155 | */ |
156 | bool HasAttrMap(const String& attr_name) { |
157 | std::lock_guard<std::mutex> lock(mutex_); |
158 | return attrs_.count(attr_name); |
159 | } |
160 | |
161 | /*! |
162 | * \return a global singleton of the registry. |
163 | */ |
164 | static TSelf* Global() { |
165 | static TSelf* inst = new TSelf(); |
166 | return inst; |
167 | } |
168 | |
169 | private: |
170 | // mutex to avoid registration from multiple threads. |
171 | std::mutex mutex_; |
172 | // entries in the registry |
173 | std::vector<std::unique_ptr<EntryType>> entries_; |
174 | // map from name to entries. |
175 | std::unordered_map<String, EntryType*> entry_map_; |
176 | // storage of additional attribute table. |
177 | std::unordered_map<String, std::unique_ptr<AttrRegistryMapContainerMap<KeyType>>> attrs_; |
178 | }; |
179 | |
180 | } // namespace tvm |
181 | #endif // TVM_NODE_ATTR_REGISTRY_H_ |
182 | |