1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | /*! |
20 | * \file tvm/ir/attrs.h |
21 | * \brief Helpers for attribute objects. |
22 | * |
23 | * This module enables declaration of named attributes |
24 | * which support default value setup and bound checking. |
25 | * |
26 | * \code |
27 | * struct MyAttrs : public tvm::AttrsNode<MyAttrs> { |
 *   double learning_rate;
 *   int num_hidden;
 *   String name;
 *   // declare attribute fields in header file
 *   TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
 *     TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
 *     TVM_ATTR_FIELD(learning_rate).set_default(0.01);
35 | * TVM_ATTR_FIELD(name).set_default("hello"); |
36 | * } |
37 | * }; |
38 | * // register it in cc file |
39 | * TVM_REGISTER_NODE_TYPE(MyAttrs); |
40 | * \endcode |
41 | * |
42 | * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD |
43 | */ |
44 | #ifndef TVM_IR_ATTRS_H_ |
45 | #define TVM_IR_ATTRS_H_ |
46 | |
47 | #include <dmlc/common.h> |
48 | #include <tvm/ir/expr.h> |
49 | #include <tvm/node/structural_equal.h> |
50 | #include <tvm/node/structural_hash.h> |
51 | #include <tvm/runtime/packed_func.h> |
52 | |
53 | #include <functional> |
54 | #include <string> |
55 | #include <type_traits> |
56 | #include <unordered_map> |
57 | #include <utility> |
58 | #include <vector> |
59 | |
60 | namespace tvm { |
/*!
 * \brief Declare the type information and field-visiting function of an attrs node.
 *
 * Expands to the node's `_type_key`, its final object type info (parent
 * ::tvm::BaseAttrsNode), and the header of the templated `_tvm_VisitAttrs`
 * member. The user writes that member's body, made of TVM_ATTR_FIELD
 * statements, immediately after the macro.
 *
 * \param AttrsClassName The name of the attrs class being declared.
 * \param AttrsTypeKey The type key to be used by the TVM node system.
 */
#define TVM_DECLARE_ATTRS(AttrsClassName, AttrsTypeKey)               \
  static constexpr const char* _type_key = AttrsTypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrsClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                          \
  void _tvm_VisitAttrs(FVisit& _tvm_fvisit)  // NOLINT(*)
71 | |
/*!
 * \brief Register one member field with the visitor inside a TVM_DECLARE_ATTRS body.
 *
 * Expands to `_tvm_fvisit("<member>", &<member>)`, so it can only be used where
 * `_tvm_fvisit` is in scope. Chained calls such as set_default / set_lower_bound /
 * describe are applied to whatever entry the visitor returns.
 *
 * \param AttrMember The member variable to register.
 */
#define TVM_ATTR_FIELD(AttrMember) _tvm_fvisit(#AttrMember, &AttrMember)
77 | |
78 | /*! |
79 | * \brief Create a NodeRef type that represents null. |
80 | * \tparam TNodeRef the type to be created. |
81 | * \return A instance that will represent None. |
82 | */ |
83 | template <typename TObjectRef> |
84 | inline TObjectRef NullValue() { |
85 | static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types" ); |
86 | return TObjectRef(ObjectPtr<Object>(nullptr)); |
87 | } |
88 | |
89 | template <> |
90 | inline DataType NullValue<DataType>() { |
91 | return DataType(DataType::kHandle, 0, 0); |
92 | } |
93 | |
94 | /*! \brief Error thrown during attribute checking. */ |
95 | struct AttrError : public Error { |
96 | /*! |
97 | * \brief constructor |
98 | * \param msg error message |
99 | */ |
100 | explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {} |
101 | }; |
102 | |
103 | /*! |
104 | * \brief Information about attribute fields in string representations. |
105 | */ |
106 | class AttrFieldInfoNode : public Object { |
107 | public: |
108 | /*! \brief name of the field */ |
109 | String name; |
110 | /*! \brief type docstring information in str. */ |
111 | String type_info; |
112 | /*! \brief detailed description of the type */ |
113 | String description; |
114 | |
115 | void VisitAttrs(AttrVisitor* v) { |
116 | v->Visit("name" , &name); |
117 | v->Visit("type_info" , &type_info); |
118 | v->Visit("description" , &description); |
119 | } |
120 | |
121 | static constexpr const char* _type_key = "AttrFieldInfo" ; |
122 | static constexpr bool _type_has_method_sequal_reduce = false; |
123 | static constexpr bool _type_has_method_shash_reduce = false; |
124 | TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); |
125 | }; |
126 | |
127 | /*! \brief AttrFieldInfo */ |
128 | class AttrFieldInfo : public ObjectRef { |
129 | public: |
130 | TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode); |
131 | }; |
132 | |
133 | /*! |
134 | * \brief Base class of all attribute class |
135 | * \note Do not subclass AttrBaseNode directly, |
136 | * subclass AttrsNode instead. |
137 | * \sa AttrsNode |
138 | */ |
139 | class BaseAttrsNode : public Object { |
140 | public: |
141 | using TVMArgs = runtime::TVMArgs; |
142 | using TVMRetValue = runtime::TVMRetValue; |
143 | /*! \brief virtual destructor */ |
144 | virtual ~BaseAttrsNode() {} |
145 | // visit function |
146 | virtual void VisitAttrs(AttrVisitor* v) {} |
147 | /*! |
148 | * \brief Initialize the attributes by sequence of arguments |
149 | * \param args The positional arguments in the form |
150 | * [key0, value0, key1, value1, ..., key_n, value_n] |
151 | */ |
152 | template <typename... Args> |
153 | inline void InitBySeq(Args&&... args); |
154 | /*! |
155 | * \brief Print readible docstring to ostream, add newline. |
156 | * \param os the stream to print the docstring to. |
157 | */ |
158 | inline void PrintDocString(std::ostream& os) const; // NOLINT(*) |
159 | /*! |
160 | * \brief Visit attributes that do not equal the default value. |
161 | * |
162 | * \note This is useful to extract fields for concise printing. |
163 | * \param v The visitor |
164 | */ |
165 | TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0; |
166 | /*! |
167 | * \brief Get the field information |
168 | * \return The fields in the Attrs. |
169 | */ |
170 | TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0; |
171 | /*! |
172 | * \brief Initialize the attributes by arguments. |
173 | * \param kwargs The key value pairs for initialization. |
174 | * [key0, value0, key1, value1, ..., key_n, value_n] |
175 | * \param allow_unknown Whether allow additional unknown fields. |
176 | * \note This function throws when the required field is not present. |
177 | */ |
178 | TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0; |
179 | |
180 | static constexpr const bool _type_has_method_sequal_reduce = true; |
181 | static constexpr const bool _type_has_method_shash_reduce = true; |
182 | static constexpr const char* _type_key = "Attrs" ; |
183 | TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); |
184 | }; |
185 | |
186 | /*! |
187 | * \brief Managed reference to BaseAttrsNode. |
188 | * \sa AttrsNode, BaseAttrsNode |
189 | */ |
190 | class Attrs : public ObjectRef { |
191 | public: |
192 | TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode); |
193 | }; |
194 | |
195 | /*! |
196 | * \brief Specialized attribute type that is backed by a map. |
197 | * The DictAttrsNode implements the Attrs behavior, |
198 | * its fields are directly accessible via object.field_name |
199 | * like other normal nodes. |
200 | */ |
201 | class DictAttrsNode : public BaseAttrsNode { |
202 | public: |
203 | /*! \brief internal attrs map */ |
204 | Map<String, ObjectRef> dict; |
205 | |
206 | bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const { |
207 | return equal(dict, other->dict); |
208 | } |
209 | |
210 | void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); } |
211 | |
212 | // implementations |
213 | void VisitAttrs(AttrVisitor* v) final; |
214 | void VisitNonDefaultAttrs(AttrVisitor* v) final; |
215 | void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; |
216 | Array<AttrFieldInfo> ListFieldInfo() const final; |
217 | |
218 | // type info |
219 | static constexpr const char* _type_key = "DictAttrs" ; |
220 | TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); |
221 | }; |
222 | |
223 | /*! |
224 | * \brief Managed reference to DictAttrsNode |
225 | * \sa DictAttrsNode. |
226 | */ |
227 | class DictAttrs : public Attrs { |
228 | public: |
229 | /*! |
230 | * \brief Consruct a Attrs backed by DictAttrsNode. |
231 | * \param dict The attributes. |
232 | * \return The dict attributes. |
233 | */ |
234 | TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict); |
235 | |
236 | // Utils for accessing attributes |
237 | // This needs to be on DictAttrs, not DictAttrsNode because we return the default |
238 | // value if DictAttrsNode is not defined. |
239 | /*! |
240 | * \brief Get a function attribute. |
241 | * |
242 | * \param attr_key The attribute key. |
243 | * \param default_value The default value if the key does not exist, defaults to nullptr. |
244 | * |
245 | * \return The result |
246 | * |
247 | * \tparam TOBjectRef the expected object type. |
248 | * \throw Error if the key exists but the value does not match TObjectRef |
249 | * |
250 | * \code |
251 | * |
252 | * void GetAttrExample(const BaseFunc& f) { |
253 | * auto value = f->attrs.GetAttr<Integer>("AttrKey", 0); |
254 | * } |
255 | * |
256 | * \endcode |
257 | */ |
258 | template <typename TObjectRef> |
259 | Optional<TObjectRef> GetAttr( |
260 | const std::string& attr_key, |
261 | Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const { |
262 | static_assert(std::is_base_of<ObjectRef, TObjectRef>::value, |
263 | "Can only call GetAttr with ObjectRef types." ); |
264 | if (!defined()) return default_value; |
265 | const DictAttrsNode* node = this->as<DictAttrsNode>(); |
266 | |
267 | auto it = node->dict.find(attr_key); |
268 | if (it != node->dict.end()) { |
269 | return Downcast<Optional<TObjectRef>>((*it).second); |
270 | } else { |
271 | return default_value; |
272 | } |
273 | } |
274 | // variant that uses TObjectRef to enable implicit conversion to default value. |
275 | template <typename TObjectRef> |
276 | Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const { |
277 | return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value)); |
278 | } |
279 | /*! |
280 | * \brief Check whether the function has an non-zero integer attr. |
281 | * |
282 | * This function can be used to check whether an optional |
283 | * attribute mark(e.g. inline) exists. |
284 | * |
285 | * \param attr_key The key to the attribute. |
286 | * \return The check result. |
287 | * |
288 | * \code |
289 | * |
290 | * void HasNonzeroAttrExample(const BaseFunc& f) { |
291 | * if (f->HasNonzeroAttr(attr::kInline)) { |
292 | * // inline the function. |
293 | * } |
294 | * } |
295 | * |
296 | * \endcode |
297 | */ |
298 | bool HasNonzeroAttr(const std::string& attr_key) const { |
299 | return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0; |
300 | } |
301 | |
302 | TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); |
303 | TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); |
304 | }; |
305 | |
306 | /*! |
307 | * \brief Create an Attr object with all default values. |
308 | * \tparam TAttrNode the type to be created. |
309 | * \return A instance that will represent None. |
310 | */ |
311 | template <typename TAttrs> |
312 | inline TAttrs AttrsWithDefaultValues() { |
313 | static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes" ); |
314 | auto n = make_object<typename TAttrs::ContainerType>(); |
315 | n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false); |
316 | return TAttrs(n); |
317 | } |
318 | |
319 | /*! |
320 | * \brief Copy the function or module, but overrides |
321 | * the attribute value key with the value. |
322 | * |
323 | * \param input The thing to annotate (BaseFunc or IRModule) |
324 | * \param attr_key The attribute key. |
325 | * \param attr_value The value attribute value. |
326 | * |
327 | * \tparam TFunc The corresponding function or module type. |
328 | * |
329 | * \returns The new function or module with updated attributes. |
330 | * |
331 | * \note This function performs copy on write optimization for func and module. |
332 | * If we move a uniquely referenced func or module into WithAttr, |
333 | * then no additional copy will be performed. |
334 | * |
335 | * This is also why we make it as a function instead of a member function |
336 | * and why we pass by value in the first argument. |
337 | * |
338 | * \code |
339 | * |
340 | * // Recommended way to trigger copy on write |
341 | * func = WithAttr(std::move(func), "key1", value1); |
342 | * func = WithAttr(std::move(func), "key2", value2); |
343 | * |
344 | * \endcode |
345 | */ |
346 | template <typename TFunc> |
347 | inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { |
348 | using TNode = typename TFunc::ContainerType; |
349 | static_assert(TNode::_type_final, "Can only operate on the leaf nodes" ); |
350 | TNode* node = input.CopyOnWrite(); |
351 | if (node->attrs.defined()) { |
352 | node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); |
353 | } else { |
354 | Map<String, ObjectRef> dict = {{attr_key, attr_value}}; |
355 | node->attrs = DictAttrs(dict); |
356 | } |
357 | return input; |
358 | } |
359 | |
360 | /*! |
361 | * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs. |
362 | * |
363 | * \param input The thing to annotate (BaseFunc or IRModule) |
364 | * \param attrs Key/values attributes to add to \p input. |
365 | * |
366 | * \tparam TFunc The corresponding function or module type. |
367 | * |
368 | * \returns The new function or module with updated attributes. |
369 | */ |
370 | template <typename TFunc> |
371 | inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) { |
372 | using TNode = typename TFunc::ContainerType; |
373 | static_assert(TNode::_type_final, "Can only operate on the leaf nodes" ); |
374 | TNode* node = input.CopyOnWrite(); |
375 | if (node->attrs.defined()) { |
376 | for (const auto& pair : attrs) { |
377 | node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); |
378 | } |
379 | } else { |
380 | node->attrs = DictAttrs(std::move(attrs)); |
381 | } |
382 | return input; |
383 | } |
384 | |
385 | /*! |
386 | * \brief Copy the function or module, but removes the specified |
387 | * attribute. |
388 | * |
389 | * \param input The thing to annotate (BaseFunc or IRModule) |
390 | * \param attr_key The attribute key. |
391 | * |
392 | * \tparam TFunc The corresponding function or module type. |
393 | * |
394 | * \returns The new function or module with removed attribute. |
395 | * |
396 | * \note This function performs copy on write optimization for func and module. |
397 | * If we move a uniquely referenced func or module into WithoutAttr, |
398 | * then no additional copy will be performed. |
399 | * |
400 | * This is also why we make it as a function instead of a member function |
401 | * and why we pass by value in the first argument. |
402 | * |
403 | * \code |
404 | * |
405 | * // Recommended way to trigger copy on write |
406 | * func = WithoutAttr(std::move(func), "key1"); |
407 | * func = WithoutAttr(std::move(func), "key2"); |
408 | * |
409 | * \endcode |
410 | */ |
411 | template <typename TFunc> |
412 | inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { |
413 | using TNode = typename TFunc::ContainerType; |
414 | static_assert(TNode::_type_final, "Can only operate on the leaf nodes" ); |
415 | |
416 | if (input->attrs.defined()) { |
417 | TNode* node = input.CopyOnWrite(); |
418 | node->attrs.CopyOnWrite()->dict.erase(attr_key); |
419 | if (node->attrs->dict.size() == 0) { |
420 | node->attrs = NullValue<DictAttrs>(); |
421 | } |
422 | } |
423 | return input; |
424 | } |
425 | |
426 | // Namespace containing detail implementations |
427 | namespace detail { |
428 | using runtime::TVMArgValue; |
429 | |
// Entry returned by visitors that ignore field metadata. Every describe /
// set_default / set_lower_bound / set_upper_bound call simply returns the
// entry itself, so chained TVM_ATTR_FIELD(...) calls still compile.
struct AttrNopEntry {
  using TSelf = AttrNopEntry;

  TSelf& describe(const char* /*str*/) { return *this; }

  template <typename T>
  TSelf& set_default(const T& /*value*/) {
    return *this;
  }

  template <typename T>
  TSelf& set_lower_bound(const T& /*begin*/) {
    return *this;
  }

  template <typename T>
  TSelf& set_upper_bound(const T& /*end*/) {
    return *this;
  }
};
448 | |
449 | // Wrapper for normal visitor. |
450 | class AttrNormalVisitor { |
451 | public: |
452 | explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {} |
453 | template <typename T> |
454 | AttrNopEntry operator()(const char* key, T* value) { |
455 | visitor_->Visit(key, value); |
456 | return AttrNopEntry(); |
457 | } |
458 | |
459 | private: |
460 | AttrVisitor* visitor_; |
461 | }; |
462 | |
463 | class { |
464 | public: |
465 | bool {true}; |
466 | // constructor |
467 | (const Object* lhs, const Object* rhs, const SEqualReducer& equal) |
468 | : lhs_(lhs), rhs_(rhs), equal_(equal) {} |
469 | template <typename T> |
470 | AttrNopEntry (const char* key, T* lhs_value) { |
471 | if (!result_) return AttrNopEntry(); |
472 | const T* rhs_value = reinterpret_cast<const T*>( |
473 | reinterpret_cast<const char*>(rhs_) + |
474 | (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_))); |
475 | if (!equal_(*lhs_value, *rhs_value)) { |
476 | result_ = false; |
477 | } |
478 | return AttrNopEntry(); |
479 | } |
480 | |
481 | private: |
482 | const Object* ; |
483 | const Object* ; |
484 | const SEqualReducer& ; |
485 | }; |
486 | |
487 | class { |
488 | public: |
489 | explicit (const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {} |
490 | |
491 | template <typename T> |
492 | AttrNopEntry (const char* key, T* value) { |
493 | hash_reducer_(*value); |
494 | return AttrNopEntry(); |
495 | } |
496 | |
497 | private: |
498 | const SHashReducer& ; |
499 | }; |
500 | |
501 | // helper entry that does initialization, set default. |
502 | template <typename T> |
503 | struct AttrInitEntry { |
504 | // The attributes |
505 | using TSelf = AttrInitEntry<T>; |
506 | // The type key |
507 | const char* type_key_; |
508 | // field name |
509 | const char* key_; |
510 | // internal value. |
511 | T* value_; |
512 | // whether the value is missing. |
513 | // NOTE: initialize to false so that the destructor does not throw unless |
514 | // AttrInitVisitor::operator() is committed to returning an instance of this class. |
515 | // It is expected not to set this to true until that is true. |
516 | bool value_missing_{false}; |
517 | |
518 | AttrInitEntry() = default; |
519 | |
520 | AttrInitEntry(AttrInitEntry&& other) { |
521 | type_key_ = other.type_key_; |
522 | key_ = other.key_; |
523 | value_ = other.value_; |
524 | value_missing_ = other.value_missing_; |
525 | // avoid unexpected throw |
526 | other.value_missing_ = false; |
527 | } |
528 | |
529 | // If the value is still missing in destruction time throw an error. |
530 | ~AttrInitEntry() DMLC_THROW_EXCEPTION { |
531 | if (value_missing_) { |
532 | std::ostringstream os; |
533 | os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. " |
534 | << "If the key is defined check that its type matches the declared type." ; |
535 | throw AttrError(os.str()); |
536 | } |
537 | } |
538 | // override fields. |
539 | // This function sets the lower bound of the attribute |
540 | TSelf& set_lower_bound(const T& begin) { |
541 | if (this->value_missing_) return *this; |
542 | const T& val = *value_; |
543 | if (begin > val) { |
544 | std::ostringstream os; |
545 | os << type_key_ << "." << key_ << ": " |
546 | << "value " << val << " is smaller than the lower bound " << begin; |
547 | throw AttrError(os.str()); |
548 | } |
549 | return *this; |
550 | } |
551 | // This function sets the upper bound of the attribute |
552 | TSelf& set_upper_bound(const T& end) { |
553 | if (this->value_missing_) return *this; |
554 | const T& val = *value_; |
555 | if (val > end) { |
556 | std::ostringstream os; |
557 | os << type_key_ << "." << key_ << ": " |
558 | << "value " << val << " is bigger than the upper bound " << end; |
559 | throw AttrError(os.str()); |
560 | } |
561 | return *this; |
562 | } |
563 | // set default when |
564 | TSelf& set_default(const T& value) { |
565 | if (!value_missing_) return *this; |
566 | *value_ = value; |
567 | value_missing_ = false; |
568 | return *this; |
569 | } |
570 | TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } |
571 | }; |
572 | |
573 | // Template function to allow smart conversion |
574 | // from Expr types into the constants. |
575 | template <typename T> |
576 | inline void SetValue(T* ptr, const TVMArgValue& val) { |
577 | *ptr = val.operator T(); |
578 | } |
579 | |
580 | template <typename T> |
581 | inline void SetIntValue(T* ptr, const TVMArgValue& val) { |
582 | if (val.type_code() == kDLInt) { |
583 | *ptr = static_cast<T>(val.value().v_int64); |
584 | } else { |
585 | IntImm expr = val; |
586 | *ptr = static_cast<T>(expr->value); |
587 | } |
588 | } |
589 | |
590 | // Workaround for GCC8.1 / GCC8.2 |
591 | template <> |
592 | inline void SetValue<DataType>(DataType* ptr, const TVMArgValue& val) { |
593 | *ptr = val.operator DataType(); |
594 | } |
595 | |
596 | template <> |
597 | inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { |
598 | if (String::CanConvertFrom(val)) { |
599 | *ptr = val.operator std::string(); |
600 | } else { |
601 | LOG(FATAL) << "Expect str" ; |
602 | } |
603 | } |
604 | |
605 | template <> |
606 | inline void SetValue<double>(double* ptr, const TVMArgValue& val) { |
607 | if (val.type_code() == kDLFloat || val.type_code() == kDLInt) { |
608 | *ptr = val.operator double(); |
609 | } else { |
610 | ObjectRef expr = val; |
611 | ICHECK(expr.defined()); |
612 | if (const IntImmNode* op = expr.as<IntImmNode>()) { |
613 | *ptr = static_cast<double>(op->value); |
614 | } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) { |
615 | *ptr = static_cast<double>(op->value); |
616 | } else { |
617 | LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey(); |
618 | } |
619 | } |
620 | } |
621 | template <> |
622 | inline void SetValue<int>(int* ptr, const TVMArgValue& val) { |
623 | SetIntValue(ptr, val); |
624 | } |
625 | template <> |
626 | inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) { |
627 | SetIntValue(ptr, val); |
628 | } |
629 | template <> |
630 | inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) { |
631 | SetIntValue(ptr, val); |
632 | } |
633 | template <> |
634 | inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) { |
635 | SetIntValue(ptr, val); |
636 | } |
637 | |
638 | // Visitor for value initialization |
639 | template <typename FFind> |
640 | class AttrInitVisitor { |
641 | public: |
642 | // Counter of number of matched attributes during visit. |
643 | // This is used to decide if there is additional unmatched attributes. |
644 | size_t hit_count_{0}; |
645 | // constructor |
646 | AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {} |
647 | |
648 | template <typename T> |
649 | AttrInitEntry<T> operator()(const char* key, T* value) { |
650 | TVMArgValue val; |
651 | AttrInitEntry<T> opt; |
652 | opt.type_key_ = type_key_; |
653 | opt.key_ = key; |
654 | opt.value_ = value; |
655 | if (ffind_(key, &val)) { |
656 | SetValue(value, val); |
657 | opt.value_missing_ = false; |
658 | ++hit_count_; |
659 | } else { |
660 | opt.value_missing_ = true; |
661 | } |
662 | #if defined(__GNUC__) |
663 | #pragma GCC diagnostic ignored "-Wpragmas" |
664 | #pragma GCC diagnostic ignored "-Wpessimizing-move" |
665 | #endif |
666 | return std::move(opt); |
667 | } |
668 | |
669 | private: |
670 | // the type key |
671 | const char* type_key_; |
672 | FFind ffind_; |
673 | }; |
674 | |
675 | template <typename FFind> |
676 | inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) { |
677 | return AttrInitVisitor<FFind>(type_key, ffind); |
678 | } |
679 | |
680 | /*! |
681 | * \brief Helper struct to get the type name known to tvm. |
682 | * \tparam T the type we are interested in. |
683 | */ |
684 | template <typename T> |
685 | struct TypeName { |
686 | static constexpr const char* value = T::ContainerType::_type_key; |
687 | }; |
688 | |
689 | template <> |
690 | struct TypeName<int> { |
691 | static constexpr const char* value = "int" ; |
692 | }; |
693 | |
694 | template <> |
695 | struct TypeName<int64_t> { |
696 | static constexpr const char* value = "int64" ; |
697 | }; |
698 | |
699 | template <> |
700 | struct TypeName<uint64_t> { |
701 | static constexpr const char* value = "uint64_t" ; |
702 | }; |
703 | |
704 | template <> |
705 | struct TypeName<DataType> { |
706 | static constexpr const char* value = "DataType" ; |
707 | }; |
708 | |
709 | template <> |
710 | struct TypeName<std::string> { |
711 | static constexpr const char* value = "str" ; |
712 | }; |
713 | |
714 | template <> |
715 | struct TypeName<bool> { |
716 | static constexpr const char* value = "bool" ; |
717 | }; |
718 | |
719 | template <> |
720 | struct TypeName<void*> { |
721 | static constexpr const char* value = "handle" ; |
722 | }; |
723 | |
724 | template <> |
725 | struct TypeName<double> { |
726 | static constexpr const char* value = "double" ; |
727 | }; |
728 | |
729 | class AttrDocEntry { |
730 | public: |
731 | using TSelf = AttrDocEntry; |
732 | |
733 | explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {} |
734 | TSelf& describe(const char* str) { |
735 | info_->description = str; |
736 | return *this; |
737 | } |
738 | template <typename T> |
739 | TSelf& set_default(const T& value) { |
740 | std::ostringstream os; |
741 | os << info_->type_info << ", default=" << value; |
742 | info_->type_info = os.str(); |
743 | return *this; |
744 | } |
745 | template <typename T> |
746 | TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) { |
747 | return *this; |
748 | } |
749 | template <typename T> |
750 | TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) { |
751 | return *this; |
752 | } |
753 | |
754 | private: |
755 | ObjectPtr<AttrFieldInfoNode> info_; |
756 | }; |
757 | |
758 | class AttrDocVisitor { |
759 | public: |
760 | template <typename T> |
761 | AttrDocEntry operator()(const char* key, T* v) { |
762 | ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>(); |
763 | info->name = key; |
764 | info->type_info = TypeName<T>::value; |
765 | fields_.push_back(AttrFieldInfo(info)); |
766 | return AttrDocEntry(info); |
767 | } |
768 | |
769 | Array<AttrFieldInfo> fields_; |
770 | }; |
771 | |
772 | class AttrExistVisitor { |
773 | public: |
774 | std::string key_; |
775 | bool exist_{false}; |
776 | |
777 | template <typename T> |
778 | AttrNopEntry operator()(const char* key, T* v) { |
779 | if (exist_) return AttrNopEntry(); |
780 | if (key == key_) exist_ = true; |
781 | return AttrNopEntry(); |
782 | } |
783 | }; |
784 | |
785 | template <typename T> |
786 | struct AttrTriggerNonDefaultEntry { |
787 | using TSelf = AttrTriggerNonDefaultEntry<T>; |
788 | // constructor |
789 | AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data) |
790 | : visitor_(visitor), key_(key), data_(data) {} |
791 | |
792 | ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION { |
793 | if (trigger_) { |
794 | visitor_->Visit(key_, data_); |
795 | } |
796 | } |
797 | TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; } |
798 | TSelf& set_default(const T& value) { |
799 | if (tvm::StructuralEqual()(value, *data_)) { |
800 | trigger_ = false; |
801 | } |
802 | return *this; |
803 | } |
804 | TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; } |
805 | TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; } |
806 | |
807 | private: |
808 | AttrVisitor* visitor_; |
809 | const char* key_; |
810 | T* data_; |
811 | bool trigger_{true}; |
812 | }; |
813 | |
814 | class AttrNonDefaultVisitor { |
815 | public: |
816 | explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {} |
817 | template <typename T> |
818 | AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) { |
819 | return AttrTriggerNonDefaultEntry<T>(visitor_, key, value); |
820 | } |
821 | |
822 | private: |
823 | AttrVisitor* visitor_; |
824 | }; |
825 | } // namespace detail |
826 | |
827 | /*! |
828 | * \brief The base class of the all the |
829 | * Use "curiously recurring template pattern". |
830 | * |
831 | * \tparam DerivedType The final attribute type. |
832 | */ |
833 | template <typename DerivedType> |
834 | class AttrsNode : public BaseAttrsNode { |
835 | public: |
836 | void VisitAttrs(AttrVisitor* v) { |
837 | ::tvm::detail::AttrNormalVisitor vis(v); |
838 | self()->_tvm_VisitAttrs(vis); |
839 | } |
840 | |
841 | void VisitNonDefaultAttrs(AttrVisitor* v) { |
842 | ::tvm::detail::AttrNonDefaultVisitor vis(v); |
843 | self()->_tvm_VisitAttrs(vis); |
844 | } |
845 | |
846 | void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final { |
847 | ICHECK_EQ(args.size() % 2, 0); |
848 | const int kLinearSearchBound = 16; |
849 | int hit_count = 0; |
850 | // applies two strategies to lookup |
851 | if (args.size() < kLinearSearchBound) { |
852 | // linear search. |
853 | auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { |
854 | for (int i = 0; i < args.size(); i += 2) { |
855 | ICHECK_EQ(args.type_codes[i], kTVMStr); |
856 | if (!std::strcmp(key, args.values[i].v_str)) { |
857 | *val = args[i + 1]; |
858 | return true; |
859 | } |
860 | } |
861 | return false; |
862 | }; |
863 | auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind); |
864 | self()->_tvm_VisitAttrs(vis); |
865 | hit_count = vis.hit_count_; |
866 | } else { |
867 | // construct a map then do lookup. |
868 | std::unordered_map<std::string, runtime::TVMArgValue> kwargs; |
869 | for (int i = 0; i < args.size(); i += 2) { |
870 | ICHECK_EQ(args.type_codes[i], kTVMStr); |
871 | kwargs[args[i].operator std::string()] = args[i + 1]; |
872 | } |
873 | auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) { |
874 | auto it = kwargs.find(key); |
875 | if (it != kwargs.end()) { |
876 | *val = it->second; |
877 | return true; |
878 | } |
879 | return false; |
880 | }; |
881 | auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind); |
882 | self()->_tvm_VisitAttrs(vis); |
883 | hit_count = vis.hit_count_; |
884 | } |
885 | // error handling, slow path |
886 | if (hit_count * 2 != args.size() && !allow_unknown) { |
887 | for (int i = 0; i < args.size(); i += 2) { |
888 | ::tvm::detail::AttrExistVisitor visitor; |
889 | visitor.key_ = args[i].operator std::string(); |
890 | self()->_tvm_VisitAttrs(visitor); |
891 | if (!visitor.exist_) { |
892 | std::ostringstream os; |
893 | os << DerivedType::_type_key << ": does not have field \'" << visitor.key_ |
894 | << "\', Possible fields:\n" ; |
895 | os << "----------------\n" ; |
896 | this->PrintDocString(os); |
897 | throw AttrError(os.str()); |
898 | } |
899 | } |
900 | } |
901 | } |
902 | |
903 | bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const { |
904 | DerivedType* pself = self(); |
905 | ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal); |
906 | self()->_tvm_VisitAttrs(visitor); |
907 | return visitor.result_; |
908 | } |
909 | |
910 | void SHashReduce(SHashReducer hash_reducer) const { |
911 | ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer); |
912 | self()->_tvm_VisitAttrs(visitor); |
913 | } |
914 | |
915 | Array<AttrFieldInfo> ListFieldInfo() const final { |
916 | ::tvm::detail::AttrDocVisitor visitor; |
917 | self()->_tvm_VisitAttrs(visitor); |
918 | return visitor.fields_; |
919 | } |
920 | |
921 | private: |
922 | DerivedType* self() const { |
923 | return const_cast<DerivedType*>(static_cast<const DerivedType*>(this)); |
924 | } |
925 | }; |
926 | |
927 | template <typename... Args> |
928 | inline void BaseAttrsNode::InitBySeq(Args&&... args) { |
929 | runtime::PackedFunc pf( |
930 | [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); }); |
931 | pf(std::forward<Args>(args)...); |
932 | } |
933 | |
934 | inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*) |
935 | Array<AttrFieldInfo> entry = this->ListFieldInfo(); |
936 | for (AttrFieldInfo info : entry) { |
937 | os << info->name << " : " << info->type_info << '\n'; |
938 | if (info->description.length() != 0) { |
939 | os << " " << info->description << '\n'; |
940 | } |
941 | } |
942 | } |
943 | |
944 | } // namespace tvm |
945 | #endif // TVM_IR_ATTRS_H_ |
946 | |