1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/relay/executor.h |
22 | * \brief Object representation of Executor configuration and registry |
23 | */ |
24 | #ifndef TVM_RELAY_EXECUTOR_H_ |
25 | #define TVM_RELAY_EXECUTOR_H_ |
26 | |
27 | #include <dmlc/registry.h> |
28 | #include <tvm/ir/attrs.h> |
29 | #include <tvm/ir/expr.h> |
30 | #include <tvm/ir/type.h> |
31 | #include <tvm/ir/type_relation.h> |
32 | #include <tvm/node/attr_registry_map.h> |
33 | #include <tvm/runtime/registry.h> |
34 | |
35 | #include <string> |
36 | #include <unordered_map> |
37 | #include <utility> |
38 | #include <vector> |
39 | |
40 | namespace tvm { |
41 | |
/*!
 * \brief Forward declaration of tvm::AttrRegistry so that ExecutorRegEntry
 *        (declared further down in this header) can name it as a friend.
 */
template <typename, typename>
class AttrRegistry;
44 | |
45 | namespace relay { |
46 | |
47 | /*! |
48 | * \brief Executor information. |
49 | * |
50 | * This data structure stores the meta-data |
51 | * about executors which can be used to pass around information. |
52 | * |
53 | * \sa Executor |
54 | */ |
55 | class ExecutorNode : public Object { |
56 | public: |
57 | /*! \brief name of the Executor */ |
58 | String name; |
59 | /* \brief Additional attributes storing meta-data about the Executor. */ |
60 | DictAttrs attrs; |
61 | |
62 | /*! |
63 | * \brief Should Link Parameters into the module |
64 | * \return Whether the Executor is configured to execute modules with linked parameters |
65 | */ |
66 | Bool ShouldLinkParameters() const { |
67 | return name == "aot" || GetAttr<Bool>("link-params" ).value_or(Bool(false)); |
68 | } |
69 | |
70 | /*! |
71 | * \brief Get an attribute. |
72 | * |
73 | * \param attr_key The attribute key. |
74 | * \param default_value The default value if the key does not exist, defaults to nullptr. |
75 | * |
76 | * \return The result |
77 | * |
78 | * \tparam TObjectRef the expected object type. |
79 | * \throw Error if the key exists but the value does not match TObjectRef |
80 | * |
81 | * \code |
82 | * |
83 | * void GetAttrExample(const Executor& executor) { |
84 | * auto value = executor->GetAttr<Integer>("AttrKey", 0); |
85 | * } |
86 | * |
87 | * \endcode |
88 | */ |
89 | template <typename TObjectRef> |
90 | Optional<TObjectRef> GetAttr( |
91 | const std::string& attr_key, |
92 | Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { |
93 | return attrs.GetAttr(attr_key, default_value); |
94 | } |
95 | // variant that uses TObjectRef to enable implicit conversion to default value. |
96 | template <typename TObjectRef> |
97 | Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { |
98 | return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); |
99 | } |
100 | |
101 | void VisitAttrs(AttrVisitor* v) { |
102 | v->Visit("name" , &name); |
103 | v->Visit("attrs" , &attrs); |
104 | } |
105 | |
106 | bool SEqualReduce(const ExecutorNode* other, SEqualReducer equal) const { |
107 | return name == other->name && equal.DefEqual(attrs, other->attrs); |
108 | } |
109 | |
110 | void SHashReduce(SHashReducer hash_reduce) const { |
111 | hash_reduce(name); |
112 | hash_reduce(attrs); |
113 | } |
114 | |
115 | static constexpr const char* _type_key = "Executor" ; |
116 | static constexpr const bool _type_has_method_sequal_reduce = true; |
117 | static constexpr const bool _type_has_method_shash_reduce = true; |
118 | TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorNode, Object); |
119 | }; |
120 | |
121 | /*! |
122 | * \brief Managed reference class to ExecutorNode. |
123 | * \sa ExecutorNode |
124 | */ |
125 | class Executor : public ObjectRef { |
126 | public: |
127 | /*! |
128 | * \brief Create a new Executor object using the registry |
129 | * \throws Error if name is not registered |
130 | * \param name The name of the executor. |
131 | * \param attrs Attributes for the executor. |
132 | * \return the new Executor object. |
133 | */ |
134 | TVM_DLL static Executor Create(String name, Map<String, ObjectRef> attrs = {}); |
135 | |
136 | /*! |
137 | * \brief List all registered Executors |
138 | * \return the list of Executors |
139 | */ |
140 | TVM_DLL static Array<String> ListExecutors(); |
141 | |
142 | /*! |
143 | * \brief List all options for a specific Executor |
144 | * \param name The name of the Executor |
145 | * \return Map of option name to type |
146 | */ |
147 | TVM_DLL static Map<String, String> ListExecutorOptions(const String& name); |
148 | |
149 | /*! \brief specify container node */ |
150 | TVM_DEFINE_OBJECT_REF_METHODS(Executor, ObjectRef, ExecutorNode); |
151 | TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorNode) |
152 | |
153 | private: |
154 | /*! |
155 | * \brief Private Constructor |
156 | * \param name The executor name |
157 | * \param attrs Attributes to apply to this Executor node |
158 | */ |
159 | TVM_DLL Executor(String name, DictAttrs attrs) { |
160 | auto n = make_object<ExecutorNode>(); |
161 | n->name = std::move(name); |
162 | n->attrs = std::move(attrs); |
163 | data_ = std::move(n); |
164 | } |
165 | }; |
166 | |
167 | /*! |
168 | * \brief Helper structure to register Executors |
169 | * \sa TVM_REGISTER_EXECUTOR |
170 | */ |
171 | class ExecutorRegEntry { |
172 | public: |
173 | /*! |
174 | * \brief Register a valid configuration option and its ValueType for validation |
175 | * \param key The configuration key |
176 | * \tparam ValueType The value type to be registered |
177 | */ |
178 | template <typename ValueType> |
179 | inline ExecutorRegEntry& add_attr_option(const String& key); |
180 | |
181 | /*! |
182 | * \brief Register a valid configuration option and its ValueType for validation |
183 | * \param key The configuration key |
184 | * \param default_value The default value of the key |
185 | * \tparam ValueType The value type to be registered |
186 | */ |
187 | template <typename ValueType> |
188 | inline ExecutorRegEntry& add_attr_option(const String& key, ObjectRef default_value); |
189 | |
190 | /*! |
191 | * \brief Register or get a new entry. |
192 | * \param name The name of the operator. |
193 | * \return the corresponding entry. |
194 | */ |
195 | TVM_DLL static ExecutorRegEntry& RegisterOrGet(const String& name); |
196 | |
197 | private: |
198 | /*! \brief Internal storage of value types */ |
199 | struct ValueTypeInfo { |
200 | std::string type_key; |
201 | uint32_t type_index; |
202 | }; |
203 | std::unordered_map<std::string, ValueTypeInfo> key2vtype_; |
204 | /*! \brief A hash table that stores the default value of each attr */ |
205 | std::unordered_map<String, ObjectRef> key2default_; |
206 | |
207 | /*! \brief Index used for internal lookup of attribute registry */ |
208 | uint32_t index_; |
209 | |
210 | // the name |
211 | std::string name; |
212 | |
213 | /*! \brief Return the index stored in attr registry */ |
214 | uint32_t AttrRegistryIndex() const { return index_; } |
215 | /*! \brief Return the name stored in attr registry */ |
216 | String AttrRegistryName() const { return name; } |
217 | |
218 | /*! \brief private constructor */ |
219 | explicit ExecutorRegEntry(uint32_t reg_index) : index_(reg_index) {} |
220 | |
221 | // friend class |
222 | template <typename> |
223 | friend class AttrRegistryMapContainerMap; |
224 | template <typename, typename> |
225 | friend class tvm::AttrRegistry; |
226 | friend class Executor; |
227 | }; |
228 | |
229 | template <typename ValueType> |
230 | inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key) { |
231 | ICHECK(!key2vtype_.count(key)) << "AttributeError: add_attr_option failed because '" << key |
232 | << "' has been set once" ; |
233 | |
234 | using ValueNodeType = typename ValueType::ContainerType; |
235 | // NOTE: we could further update the function later. |
236 | uint32_t value_type_index = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); |
237 | |
238 | ValueTypeInfo info; |
239 | info.type_index = value_type_index; |
240 | info.type_key = runtime::Object::TypeIndex2Key(value_type_index); |
241 | key2vtype_[key] = info; |
242 | return *this; |
243 | } |
244 | |
245 | template <typename ValueType> |
246 | inline ExecutorRegEntry& ExecutorRegEntry::add_attr_option(const String& key, |
247 | ObjectRef default_value) { |
248 | add_attr_option<ValueType>(key); |
249 | key2default_[key] = default_value; |
250 | return *this; |
251 | } |
252 | |
// Internal helper: expands to the declaration prefix of a static ExecutorRegEntry
// reference marked DMLC_ATTRIBUTE_UNUSED. TVM_REGISTER_EXECUTOR concatenates __COUNTER__
// onto it so each registration site declares a distinct variable.
#define TVM_EXECUTOR_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::relay::ExecutorRegEntry& __make_##Executor

/*!
 * \def TVM_REGISTER_EXECUTOR
 * \brief Register a new executor, or get the existing entry so more options can be added.
 *
 * \param ExecutorName The name under which the executor is registered.
 *
 * \code
 *
 * TVM_REGISTER_EXECUTOR("aot")
 *     .add_attr_option<String>("my_option")
 *     .add_attr_option<String>("my_option_default", String("default"));
 *
 * \endcode
 */
#define TVM_REGISTER_EXECUTOR(ExecutorName)                    \
  TVM_STR_CONCAT(TVM_EXECUTOR_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::relay::ExecutorRegEntry::RegisterOrGet(ExecutorName)
274 | } // namespace relay |
275 | } // namespace tvm |
276 | |
277 | #endif // TVM_RELAY_EXECUTOR_H_ |
278 | |