1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/runtime.h |
22 | * \brief Object representation of Runtime configuration and registry |
23 | */ |
24 | #ifndef TVM_RELAY_RUNTIME_H_ |
25 | #define TVM_RELAY_RUNTIME_H_ |
26 | |
27 | #include <dmlc/registry.h> |
28 | #include <tvm/ir/attrs.h> |
29 | #include <tvm/ir/expr.h> |
30 | #include <tvm/ir/type.h> |
31 | #include <tvm/ir/type_relation.h> |
32 | #include <tvm/node/attr_registry_map.h> |
33 | #include <tvm/runtime/registry.h> |
34 | |
35 | #include <string> |
36 | #include <unordered_map> |
37 | #include <utility> |
38 | #include <vector> |
39 | |
40 | namespace tvm { |
41 | |
/*!
 * \brief Forward declaration of the generic attribute registry.
 *
 * Only declared here so that RuntimeRegEntry can name it as a friend;
 * the two template parameters are the entry type and the key type.
 */
template <typename EntryType, typename KeyType>
class AttrRegistry;
44 | |
45 | namespace relay { |
46 | |
/*! \brief Runtime name (the value stored in `RuntimeNode::name`) selecting the C++ runtime. */
static constexpr const char* const kTvmRuntimeCpp = "cpp";

/*! \brief Runtime name (the value stored in `RuntimeNode::name`) selecting the C runtime. */
static constexpr const char* const kTvmRuntimeCrt = "crt";
52 | |
53 | /*! |
54 | * \brief Runtime information. |
55 | * |
56 | * This data structure stores the meta-data |
57 | * about Runtimes which can be used to pass around information. |
58 | * |
59 | * \sa Runtime |
60 | */ |
61 | class RuntimeNode : public Object { |
62 | public: |
63 | /*! \brief name of the Runtime */ |
64 | String name; |
65 | /* \brief Additional attributes storing meta-data about the Runtime. */ |
66 | DictAttrs attrs; |
67 | |
68 | /*! |
69 | * \brief Get an attribute. |
70 | * |
71 | * \param attr_key The attribute key. |
72 | * \param default_value The default value if the key does not exist, defaults to nullptr. |
73 | * |
74 | * \return The result |
75 | * |
76 | * \tparam TObjectRef the expected object type. |
77 | * \throw Error if the key exists but the value does not match TObjectRef |
78 | * |
79 | * \code |
80 | * |
81 | * void GetAttrExample(const Runtime& runtime) { |
82 | * auto value = runtime->GetAttr<Integer>("AttrKey", 0); |
83 | * } |
84 | * |
85 | * \endcode |
86 | */ |
87 | template <typename TObjectRef> |
88 | Optional<TObjectRef> GetAttr( |
89 | const std::string& attr_key, |
90 | Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { |
91 | return attrs.GetAttr(attr_key, default_value); |
92 | } |
93 | // variant that uses TObjectRef to enable implicit conversion to default value. |
94 | template <typename TObjectRef> |
95 | Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { |
96 | return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); |
97 | } |
98 | |
99 | void VisitAttrs(AttrVisitor* v) { |
100 | v->Visit("name" , &name); |
101 | v->Visit("attrs" , &attrs); |
102 | } |
103 | |
104 | bool SEqualReduce(const RuntimeNode* other, SEqualReducer equal) const { |
105 | return name == other->name && equal.DefEqual(attrs, other->attrs); |
106 | } |
107 | |
108 | void SHashReduce(SHashReducer hash_reduce) const { |
109 | hash_reduce(name); |
110 | hash_reduce(attrs); |
111 | } |
112 | |
113 | static constexpr const char* _type_key = "Runtime" ; |
114 | static constexpr const bool _type_has_method_sequal_reduce = true; |
115 | static constexpr const bool _type_has_method_shash_reduce = true; |
116 | TVM_DECLARE_FINAL_OBJECT_INFO(RuntimeNode, Object); |
117 | }; |
118 | |
119 | /*! |
120 | * \brief Managed reference class to RuntimeNode. |
121 | * \sa RuntimeNode |
122 | */ |
123 | class Runtime : public ObjectRef { |
124 | public: |
125 | Runtime() = default; |
126 | |
127 | /*! |
128 | * \brief Create a new Runtime object using the registry |
129 | * \throws Error if name is not registered |
130 | * \param name The name of the Runtime. |
131 | * \param attrs Attributes for the Runtime. |
132 | * \return the new Runtime object. |
133 | */ |
134 | TVM_DLL static Runtime Create(String name, Map<String, ObjectRef> attrs = {}); |
135 | |
136 | /*! |
137 | * \brief List all registered Runtimes |
138 | * \return the list of Runtimes |
139 | */ |
140 | TVM_DLL static Array<String> ListRuntimes(); |
141 | |
142 | /*! |
143 | * \brief List all options for a specific Runtime |
144 | * \param name The name of the Runtime |
145 | * \return Map of option name to type |
146 | */ |
147 | TVM_DLL static Map<String, String> ListRuntimeOptions(const String& name); |
148 | |
149 | /*! \brief specify container node */ |
150 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Runtime, ObjectRef, RuntimeNode); |
151 | |
152 | private: |
153 | /*! |
154 | * \brief Private Constructor |
155 | * \param name The Runtime name |
156 | * \param attrs Attributes to apply to this Runtime node |
157 | */ |
158 | TVM_DLL Runtime(String name, DictAttrs attrs) { |
159 | auto n = make_object<RuntimeNode>(); |
160 | n->name = std::move(name); |
161 | n->attrs = std::move(attrs); |
162 | data_ = std::move(n); |
163 | } |
164 | }; |
165 | |
166 | /*! |
167 | * \brief Helper structure to register Runtimes |
168 | * \sa TVM_REGISTER_Runtime |
169 | */ |
170 | class RuntimeRegEntry { |
171 | public: |
172 | /*! |
173 | * \brief Register a valid configuration option and its ValueType for validation |
174 | * \param key The configuration key |
175 | * \tparam ValueType The value type to be registered |
176 | */ |
177 | template <typename ValueType> |
178 | inline RuntimeRegEntry& add_attr_option(const String& key); |
179 | |
180 | /*! |
181 | * \brief Register a valid configuration option and its ValueType for validation |
182 | * \param key The configuration key |
183 | * \param default_value The default value of the key |
184 | * \tparam ValueType The value type to be registered |
185 | */ |
186 | template <typename ValueType> |
187 | inline RuntimeRegEntry& add_attr_option(const String& key, ObjectRef default_value); |
188 | |
189 | /*! |
190 | * \brief Register or get a new entry. |
191 | * \param name The name of the operator. |
192 | * \return the corresponding entry. |
193 | */ |
194 | TVM_DLL static RuntimeRegEntry& RegisterOrGet(const String& name); |
195 | |
196 | private: |
197 | /*! \brief Internal storage of value types */ |
198 | struct ValueTypeInfo { |
199 | std::string type_key; |
200 | uint32_t type_index; |
201 | }; |
202 | std::unordered_map<std::string, ValueTypeInfo> key2vtype_; |
203 | /*! \brief A hash table that stores the default value of each attr */ |
204 | std::unordered_map<String, ObjectRef> key2default_; |
205 | |
206 | /*! \brief Index used for internal lookup of attribute registry */ |
207 | uint32_t index_; |
208 | |
209 | // the name |
210 | std::string name; |
211 | |
212 | /*! \brief Return the index stored in attr registry */ |
213 | uint32_t AttrRegistryIndex() const { return index_; } |
214 | /*! \brief Return the name stored in attr registry */ |
215 | String AttrRegistryName() const { return name; } |
216 | |
217 | /*! \brief private constructor */ |
218 | explicit RuntimeRegEntry(uint32_t reg_index) : index_(reg_index) {} |
219 | |
220 | // friend class |
221 | template <typename> |
222 | friend class AttrRegistryMapContainerMap; |
223 | template <typename, typename> |
224 | friend class tvm::AttrRegistry; |
225 | friend class Runtime; |
226 | }; |
227 | |
228 | template <typename ValueType> |
229 | inline RuntimeRegEntry& RuntimeRegEntry::add_attr_option(const String& key) { |
230 | ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key |
231 | << "' has been set once" ; |
232 | |
233 | using ValueNodeType = typename ValueType::ContainerType; |
234 | // NOTE: we could further update the function later. |
235 | uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); |
236 | |
237 | ValueTypeInfo info; |
238 | info.type_index = value_type_index; |
239 | info.type_key = runtime::Object::TypeIndex2Key(value_type_index); |
240 | key2vtype_[key] = info; |
241 | return *this; |
242 | } |
243 | |
244 | template <typename ValueType> |
245 | inline RuntimeRegEntry& RuntimeRegEntry::add_attr_option(const String& key, |
246 | ObjectRef default_value) { |
247 | add_attr_option<ValueType>(key); |
248 | key2default_[key] = default_value; |
249 | return *this; |
250 | } |
251 | |
// Internal macro: declaration prefix of the static reference each TVM_REGISTER_RUNTIME
// use binds its registry entry to. TVM_REGISTER_RUNTIME combines it with __COUNTER__
// via TVM_STR_CONCAT (assumed to token-paste) so every registration gets its own name.
#define TVM_RUNTIME_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::RuntimeRegEntry& __make_##Runtime

/*!
 * \def TVM_REGISTER_RUNTIME
 * \brief Register a new Runtime, or set attribute of the corresponding Runtime.
 *
 * \param RuntimeName The name of the Runtime (anything convertible to String,
 *        e.g. a string literal).
 *
 * \code
 *
 *  TVM_REGISTER_RUNTIME("c")
 *  .add_attr_option<String>("my_option")
 *  .add_attr_option<String>("my_option_default", String("default"));
 *
 * \endcode
 */
#define TVM_REGISTER_RUNTIME(RuntimeName)                     \
  TVM_STR_CONCAT(TVM_RUNTIME_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::relay::RuntimeRegEntry::RegisterOrGet(RuntimeName)
273 | } // namespace relay |
274 | } // namespace tvm |
275 | |
276 | #endif // TVM_RELAY_RUNTIME_H_ |
277 | |