1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/node/attr_registry.h
22 * \brief Common global registry for objects that also have additional attrs.
23 */
24#ifndef TVM_NODE_ATTR_REGISTRY_H_
25#define TVM_NODE_ATTR_REGISTRY_H_
26
27#include <tvm/node/attr_registry_map.h>
28#include <tvm/runtime/packed_func.h>
29
30#include <memory>
31#include <mutex>
32#include <unordered_map>
33#include <utility>
34#include <vector>
35
36namespace tvm {
37
38/*!
39 * \brief Implementation of registry with attributes.
40 *
41 * \tparam EntryType The type of the registry entry.
42 * \tparam KeyType The actual key that is used to lookup the attributes.
43 * each entry has a corresponding key by default.
44 */
45template <typename EntryType, typename KeyType>
46class AttrRegistry {
47 public:
48 using TSelf = AttrRegistry<EntryType, KeyType>;
49 /*!
50 * \brief Get an entry from the registry.
51 * \param name The name of the item.
52 * \return The corresponding entry.
53 */
54 const EntryType* Get(const String& name) const {
55 auto it = entry_map_.find(name);
56 if (it != entry_map_.end()) return it->second;
57 return nullptr;
58 }
59
60 /*!
61 * \brief Get an entry or register a new one.
62 * \param name The name of the item.
63 * \return The corresponding entry.
64 */
65 EntryType& RegisterOrGet(const String& name) {
66 auto it = entry_map_.find(name);
67 if (it != entry_map_.end()) return *it->second;
68 uint32_t registry_index = static_cast<uint32_t>(entries_.size());
69 auto entry = std::unique_ptr<EntryType>(new EntryType(registry_index));
70 auto* eptr = entry.get();
71 eptr->name = name;
72 entry_map_[name] = eptr;
73 entries_.emplace_back(std::move(entry));
74 return *eptr;
75 }
76
77 /*!
78 * \brief List all the entry names in the registry.
79 * \return The entry names.
80 */
81 Array<String> ListAllNames() const {
82 Array<String> names;
83 for (const auto& kv : entry_map_) {
84 names.push_back(kv.first);
85 }
86 return names;
87 }
88
89 /*!
90 * \brief Update the attribute stable.
91 * \param attr_name The name of the attribute.
92 * \param key The key to the attribute table.
93 * \param value The value to be set.
94 * \param plevel The support level.
95 */
96 void UpdateAttr(const String& attr_name, const KeyType& key, runtime::TVMRetValue value,
97 int plevel) {
98 using runtime::TVMRetValue;
99 std::lock_guard<std::mutex> lock(mutex_);
100 auto& op_map = attrs_[attr_name];
101 if (op_map == nullptr) {
102 op_map.reset(new AttrRegistryMapContainerMap<KeyType>());
103 op_map->attr_name_ = attr_name;
104 }
105
106 uint32_t index = key->AttrRegistryIndex();
107 if (op_map->data_.size() <= index) {
108 op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
109 }
110 std::pair<TVMRetValue, int>& p = op_map->data_[index];
111 ICHECK(p.second != plevel) << "Attribute " << attr_name << " of " << key->AttrRegistryName()
112 << " is already registered with same plevel=" << plevel;
113 ICHECK(value.type_code() != kTVMNullptr) << "Registered packed_func is Null for " << attr_name
114 << " of operator " << key->AttrRegistryName();
115 if (p.second < plevel && value.type_code() != kTVMNullptr) {
116 op_map->data_[index] = std::make_pair(value, plevel);
117 }
118 }
119
120 /*!
121 * \brief Reset an attribute table entry.
122 * \param attr_name The name of the attribute.
123 * \param key The key to the attribute table.
124 */
125 void ResetAttr(const String& attr_name, const KeyType& key) {
126 std::lock_guard<std::mutex> lock(mutex_);
127 auto& op_map = attrs_[attr_name];
128 if (op_map == nullptr) {
129 return;
130 }
131 uint32_t index = key->AttrRegistryIndex();
132 if (op_map->data_.size() > index) {
133 op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
134 }
135 }
136
137 /*!
138 * \brief Get an internal attribute map.
139 * \param attr_name The name of the attribute.
140 * \return The result attribute map.
141 */
142 const AttrRegistryMapContainerMap<KeyType>& GetAttrMap(const String& attr_name) {
143 std::lock_guard<std::mutex> lock(mutex_);
144 auto it = attrs_.find(attr_name);
145 if (it == attrs_.end()) {
146 LOG(FATAL) << "Attribute \'" << attr_name << "\' is not registered";
147 }
148 return *it->second.get();
149 }
150
151 /*!
152 * \brief Check of attribute has been registered.
153 * \param attr_name The name of the attribute.
154 * \return The check result.
155 */
156 bool HasAttrMap(const String& attr_name) {
157 std::lock_guard<std::mutex> lock(mutex_);
158 return attrs_.count(attr_name);
159 }
160
161 /*!
162 * \return a global singleton of the registry.
163 */
164 static TSelf* Global() {
165 static TSelf* inst = new TSelf();
166 return inst;
167 }
168
169 private:
170 // mutex to avoid registration from multiple threads.
171 std::mutex mutex_;
172 // entries in the registry
173 std::vector<std::unique_ptr<EntryType>> entries_;
174 // map from name to entries.
175 std::unordered_map<String, EntryType*> entry_map_;
176 // storage of additional attribute table.
177 std::unordered_map<String, std::unique_ptr<AttrRegistryMapContainerMap<KeyType>>> attrs_;
178};
179
180} // namespace tvm
181#endif // TVM_NODE_ATTR_REGISTRY_H_
182