1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19/*!
20 * \file tvm/ir/attrs.h
21 * \brief Helpers for attribute objects.
22 *
23 * This module enables declaration of named attributes
24 * which support default value setup and bound checking.
25 *
26 * \code
27 * struct MyAttrs : public tvm::AttrsNode<MyAttrs> {
28 * float learning_rate;
29 * int num_hidden;
30 * String name;
31 * // declare attribute fields in header file
32 * TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") {
33 * TVM_ATTR_FIELD(num_hidden).set_lower_bound(1);
34 * TVM_ATTR_FIELD(learning_rate).set_default(0.01f);
35 * TVM_ATTR_FIELD(name).set_default("hello");
36 * }
37 * };
38 * // register it in cc file
39 * TVM_REGISTER_NODE_TYPE(MyAttrs);
40 * \endcode
41 *
42 * \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
43 */
44#ifndef TVM_IR_ATTRS_H_
45#define TVM_IR_ATTRS_H_
46
47#include <dmlc/common.h>
48#include <tvm/ir/expr.h>
49#include <tvm/node/structural_equal.h>
50#include <tvm/node/structural_hash.h>
51#include <tvm/runtime/packed_func.h>
52
53#include <functional>
54#include <string>
55#include <type_traits>
56#include <unordered_map>
57#include <utility>
58#include <vector>
59
60namespace tvm {
/*!
 * \brief Declare an attribute function.
 *
 * Expands, inside a class body, to three things: the `_type_key` constant,
 * the TVM_DECLARE_FINAL_OBJECT_INFO declaration (with ::tvm::BaseAttrsNode as
 * parent) and the signature of the templated member
 * `_tvm_VisitAttrs(FVisit& _tvm_fvisit)`. The macro stops right after that
 * signature, so the brace block written after the macro invocation becomes the
 * body of `_tvm_VisitAttrs`.
 *
 * \param ClassName The name of the class.
 * \param TypeKey The type key to be used by the TVM node system.
 */
#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template <typename FVisit>                                     \
  void _tvm_VisitAttrs(FVisit& _tvm_fvisit)  // NOLINT(*)
71
/*!
 * \brief Declare an attribute field.
 *
 * Expands to a call of the visitor named `_tvm_fvisit` (the parameter name used
 * by TVM_DECLARE_ATTRS) with the stringified field name and the address of the
 * field. The value of the expression is whatever the visitor returns, which is
 * what allows chaining calls such as `.set_default(...)` after the macro.
 *
 * \param FieldName The field name.
 */
#define TVM_ATTR_FIELD(FieldName) _tvm_fvisit(#FieldName, &FieldName)
77
/*!
 * \brief Create an ObjectRef-like value that represents null.
 *
 * The result is constructed from a null ObjectPtr<Object>. Instantiation fails
 * at compile time (static_assert) when TObjectRef::_type_is_nullable is false.
 *
 * \tparam TObjectRef the type to be created.
 * \return An instance that will represent None.
 */
template <typename TObjectRef>
inline TObjectRef NullValue() {
  static_assert(TObjectRef::_type_is_nullable, "Can only get NullValue for nullable types");
  return TObjectRef(ObjectPtr<Object>(nullptr));
}
88
/*!
 * \brief NullValue specialization for DataType.
 * \return The value DataType(DataType::kHandle, 0, 0).
 */
template <>
inline DataType NullValue<DataType>() {
  return DataType(DataType::kHandle, 0, 0);
}
93
/*! \brief Error thrown during attribute checking. */
struct AttrError : public Error {
  /*!
   * \brief constructor
   * \param msg error message; it is appended to the fixed prefix
   *        "AttributeError:" before being passed to the Error base class.
   */
  explicit AttrError(std::string msg) : Error("AttributeError:" + msg) {}
};
102
/*!
 * \brief Information about attribute fields in string representations.
 *
 * Holds three strings per field (name, type information, description) and
 * exposes them to an AttrVisitor under the keys "name", "type_info" and
 * "description".
 */
class AttrFieldInfoNode : public Object {
 public:
  /*! \brief name of the field */
  String name;
  /*! \brief type docstring information in str. */
  String type_info;
  /*! \brief detailed description of the type */
  String description;

  /*!
   * \brief Visit the three string members.
   * \param v The visitor that receives each (key, member pointer) pair.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }

  static constexpr const char* _type_key = "AttrFieldInfo";
  // This node declares that it provides neither SEqualReduce nor SHashReduce.
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
126
/*!
 * \brief Managed reference to AttrFieldInfoNode.
 * \sa AttrFieldInfoNode
 */
class AttrFieldInfo : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(AttrFieldInfo, ObjectRef, AttrFieldInfoNode);
};
132
/*!
 * \brief Base class of all attribute classes.
 * \note Do not subclass BaseAttrsNode directly,
 *  subclass AttrsNode instead.
 * \sa AttrsNode
 */
class BaseAttrsNode : public Object {
 public:
  using TVMArgs = runtime::TVMArgs;
  using TVMRetValue = runtime::TVMRetValue;
  /*! \brief virtual destructor */
  virtual ~BaseAttrsNode() {}
  /*!
   * \brief Visit the attribute fields.
   *
   * The base implementation is intentionally empty.
   *
   * \param v The visitor.
   */
  virtual void VisitAttrs(AttrVisitor* v) {}
  /*!
   * \brief Initialize the attributes by sequence of arguments
   *
   * The definition near the end of this header forwards the arguments, through
   * a PackedFunc call, to InitByPackedArgs.
   *
   * \param args The positional arguments in the form
   *        [key0, value0, key1, value1, ..., key_n, value_n]
   */
  template <typename... Args>
  inline void InitBySeq(Args&&... args);
  /*!
   * \brief Print readable docstring to ostream, add newline.
   * \param os the stream to print the docstring to.
   */
  inline void PrintDocString(std::ostream& os) const;  // NOLINT(*)
  /*!
   * \brief Visit attributes that do not equal the default value.
   *
   * \note This is useful to extract fields for concise printing.
   * \param v The visitor
   */
  TVM_DLL virtual void VisitNonDefaultAttrs(AttrVisitor* v) = 0;
  /*!
   * \brief Get the field information
   * \return The fields in the Attrs.
   */
  TVM_DLL virtual Array<AttrFieldInfo> ListFieldInfo() const = 0;
  /*!
   * \brief Initialize the attributes by arguments.
   * \param kwargs The key value pairs for initialization.
   *        [key0, value0, key1, value1, ..., key_n, value_n]
   * \param allow_unknown Whether allow additional unknown fields.
   * \note This function throws when the required field is not present.
   */
  TVM_DLL virtual void InitByPackedArgs(const TVMArgs& kwargs, bool allow_unknown = false) = 0;

  // Attribute nodes declare that they provide SEqualReduce / SHashReduce.
  static constexpr const bool _type_has_method_sequal_reduce = true;
  static constexpr const bool _type_has_method_shash_reduce = true;
  static constexpr const char* _type_key = "Attrs";
  TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};
185
/*!
 * \brief Managed reference to BaseAttrsNode.
 *
 * The reference methods are generated by TVM_DEFINE_OBJECT_REF_METHODS with
 * ObjectRef as the parent reference type.
 *
 * \sa AttrsNode, BaseAttrsNode
 */
class Attrs : public ObjectRef {
 public:
  TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};
194
/*!
 * \brief Specialized attribute type that is backed by a map.
 *  The DictAttrsNode implements the Attrs behavior,
 *  its fields are directly accessible via object.field_name
 *  like other normal nodes.
 */
class DictAttrsNode : public BaseAttrsNode {
 public:
  /*! \brief internal attrs map */
  Map<String, ObjectRef> dict;

  /*!
   * \brief Structural equality: two nodes are equal when their maps are.
   * \param other The node to compare against.
   * \param equal The reducer used to compare the two `dict` members.
   * \return The result of `equal(dict, other->dict)`.
   */
  bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
    return equal(dict, other->dict);
  }

  /*!
   * \brief Structural hash: feeds the map into the reducer.
   * \param hash_reduce The hash reducer.
   */
  void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dict); }

  // implementations of the BaseAttrsNode interface; they are only declared
  // here, the definitions are not part of this header.
  void VisitAttrs(AttrVisitor* v) final;
  void VisitNonDefaultAttrs(AttrVisitor* v) final;
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
  Array<AttrFieldInfo> ListFieldInfo() const final;

  // type info
  static constexpr const char* _type_key = "DictAttrs";
  TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};
222
223/*!
224 * \brief Managed reference to DictAttrsNode
225 * \sa DictAttrsNode.
226 */
227class DictAttrs : public Attrs {
228 public:
229 /*!
230 * \brief Consruct a Attrs backed by DictAttrsNode.
231 * \param dict The attributes.
232 * \return The dict attributes.
233 */
234 TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
235
236 // Utils for accessing attributes
237 // This needs to be on DictAttrs, not DictAttrsNode because we return the default
238 // value if DictAttrsNode is not defined.
239 /*!
240 * \brief Get a function attribute.
241 *
242 * \param attr_key The attribute key.
243 * \param default_value The default value if the key does not exist, defaults to nullptr.
244 *
245 * \return The result
246 *
247 * \tparam TOBjectRef the expected object type.
248 * \throw Error if the key exists but the value does not match TObjectRef
249 *
250 * \code
251 *
252 * void GetAttrExample(const BaseFunc& f) {
253 * auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
254 * }
255 *
256 * \endcode
257 */
258 template <typename TObjectRef>
259 Optional<TObjectRef> GetAttr(
260 const std::string& attr_key,
261 Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
262 static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
263 "Can only call GetAttr with ObjectRef types.");
264 if (!defined()) return default_value;
265 const DictAttrsNode* node = this->as<DictAttrsNode>();
266
267 auto it = node->dict.find(attr_key);
268 if (it != node->dict.end()) {
269 return Downcast<Optional<TObjectRef>>((*it).second);
270 } else {
271 return default_value;
272 }
273 }
274 // variant that uses TObjectRef to enable implicit conversion to default value.
275 template <typename TObjectRef>
276 Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
277 return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
278 }
279 /*!
280 * \brief Check whether the function has an non-zero integer attr.
281 *
282 * This function can be used to check whether an optional
283 * attribute mark(e.g. inline) exists.
284 *
285 * \param attr_key The key to the attribute.
286 * \return The check result.
287 *
288 * \code
289 *
290 * void HasNonzeroAttrExample(const BaseFunc& f) {
291 * if (f->HasNonzeroAttr(attr::kInline)) {
292 * // inline the function.
293 * }
294 * }
295 *
296 * \endcode
297 */
298 bool HasNonzeroAttr(const std::string& attr_key) const {
299 return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
300 }
301
302 TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
303 TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
304};
305
306/*!
307 * \brief Create an Attr object with all default values.
308 * \tparam TAttrNode the type to be created.
309 * \return A instance that will represent None.
310 */
311template <typename TAttrs>
312inline TAttrs AttrsWithDefaultValues() {
313 static_assert(std::is_base_of<Attrs, TAttrs>::value, "Can only take attr nodes");
314 auto n = make_object<typename TAttrs::ContainerType>();
315 n->InitByPackedArgs(runtime::TVMArgs(nullptr, nullptr, 0), false);
316 return TAttrs(n);
317}
318
319/*!
320 * \brief Copy the function or module, but overrides
321 * the attribute value key with the value.
322 *
323 * \param input The thing to annotate (BaseFunc or IRModule)
324 * \param attr_key The attribute key.
325 * \param attr_value The value attribute value.
326 *
327 * \tparam TFunc The corresponding function or module type.
328 *
329 * \returns The new function or module with updated attributes.
330 *
331 * \note This function performs copy on write optimization for func and module.
332 * If we move a uniquely referenced func or module into WithAttr,
333 * then no additional copy will be performed.
334 *
335 * This is also why we make it as a function instead of a member function
336 * and why we pass by value in the first argument.
337 *
338 * \code
339 *
340 * // Recommended way to trigger copy on write
341 * func = WithAttr(std::move(func), "key1", value1);
342 * func = WithAttr(std::move(func), "key2", value2);
343 *
344 * \endcode
345 */
346template <typename TFunc>
347inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
348 using TNode = typename TFunc::ContainerType;
349 static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
350 TNode* node = input.CopyOnWrite();
351 if (node->attrs.defined()) {
352 node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
353 } else {
354 Map<String, ObjectRef> dict = {{attr_key, attr_value}};
355 node->attrs = DictAttrs(dict);
356 }
357 return input;
358}
359
360/*!
361 * \brief Copy the function or module, but overrides the attributes with the entries from \p attrs.
362 *
363 * \param input The thing to annotate (BaseFunc or IRModule)
364 * \param attrs Key/values attributes to add to \p input.
365 *
366 * \tparam TFunc The corresponding function or module type.
367 *
368 * \returns The new function or module with updated attributes.
369 */
370template <typename TFunc>
371inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
372 using TNode = typename TFunc::ContainerType;
373 static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
374 TNode* node = input.CopyOnWrite();
375 if (node->attrs.defined()) {
376 for (const auto& pair : attrs) {
377 node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
378 }
379 } else {
380 node->attrs = DictAttrs(std::move(attrs));
381 }
382 return input;
383}
384
385/*!
386 * \brief Copy the function or module, but removes the specified
387 * attribute.
388 *
389 * \param input The thing to annotate (BaseFunc or IRModule)
390 * \param attr_key The attribute key.
391 *
392 * \tparam TFunc The corresponding function or module type.
393 *
394 * \returns The new function or module with removed attribute.
395 *
396 * \note This function performs copy on write optimization for func and module.
397 * If we move a uniquely referenced func or module into WithoutAttr,
398 * then no additional copy will be performed.
399 *
400 * This is also why we make it as a function instead of a member function
401 * and why we pass by value in the first argument.
402 *
403 * \code
404 *
405 * // Recommended way to trigger copy on write
406 * func = WithoutAttr(std::move(func), "key1");
407 * func = WithoutAttr(std::move(func), "key2");
408 *
409 * \endcode
410 */
411template <typename TFunc>
412inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
413 using TNode = typename TFunc::ContainerType;
414 static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
415
416 if (input->attrs.defined()) {
417 TNode* node = input.CopyOnWrite();
418 node->attrs.CopyOnWrite()->dict.erase(attr_key);
419 if (node->attrs->dict.size() == 0) {
420 node->attrs = NullValue<DictAttrs>();
421 }
422 }
423 return input;
424}
425
426// Namespace containing detail implementations
427namespace detail {
428using runtime::TVMArgValue;
429
// Helper entry that does nothing in set_default/bound/describe calls.
// It is returned by visitors that do not care about the per-field modifiers;
// every method ignores its argument and returns *this so calls can be chained.
struct AttrNopEntry {
  using TSelf = AttrNopEntry;

  // Ignores the description string.
  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
  // Ignores the default value.
  template <typename T>
  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {
    return *this;
  }
  // Ignores the lower bound.
  template <typename T>
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {
    return *this;
  }
  // Ignores the upper bound.
  template <typename T>
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {
    return *this;
  }
};
448
// Wrapper for normal visitor.
// Adapts an AttrVisitor to the callable form expected by _tvm_VisitAttrs:
// each (key, value pointer) pair is forwarded to AttrVisitor::Visit and a
// no-op entry is returned so that chained modifiers are ignored.
class AttrNormalVisitor {
 public:
  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
  template <typename T>
  AttrNopEntry operator()(const char* key, T* value) {
    visitor_->Visit(key, value);
    return AttrNopEntry();
  }

 private:
  // The wrapped visitor; stored as a raw pointer, not owned by this class.
  AttrVisitor* visitor_;
};
462
463class AttrsSEqualVisitor {
464 public:
465 bool result_{true};
466 // constructor
467 AttrsSEqualVisitor(const Object* lhs, const Object* rhs, const SEqualReducer& equal)
468 : lhs_(lhs), rhs_(rhs), equal_(equal) {}
469 template <typename T>
470 AttrNopEntry operator()(const char* key, T* lhs_value) {
471 if (!result_) return AttrNopEntry();
472 const T* rhs_value = reinterpret_cast<const T*>(
473 reinterpret_cast<const char*>(rhs_) +
474 (reinterpret_cast<const char*>(lhs_value) - reinterpret_cast<const char*>(lhs_)));
475 if (!equal_(*lhs_value, *rhs_value)) {
476 result_ = false;
477 }
478 return AttrNopEntry();
479 }
480
481 private:
482 const Object* lhs_;
483 const Object* rhs_;
484 const SEqualReducer& equal_;
485};
486
// Visitor that feeds every visited field value into an SHashReducer.
// The field key is not part of what is passed to the reducer.
class AttrsSHashVisitor {
 public:
  explicit AttrsSHashVisitor(const SHashReducer& hash_reducer) : hash_reducer_(hash_reducer) {}

  template <typename T>
  AttrNopEntry operator()(const char* key, T* value) {
    hash_reducer_(*value);
    return AttrNopEntry();
  }

 private:
  // Held by reference: the reducer passed to the constructor must outlive this visitor.
  const SHashReducer& hash_reducer_;
};
500
501// helper entry that does initialization, set default.
502template <typename T>
503struct AttrInitEntry {
504 // The attributes
505 using TSelf = AttrInitEntry<T>;
506 // The type key
507 const char* type_key_;
508 // field name
509 const char* key_;
510 // internal value.
511 T* value_;
512 // whether the value is missing.
513 // NOTE: initialize to false so that the destructor does not throw unless
514 // AttrInitVisitor::operator() is committed to returning an instance of this class.
515 // It is expected not to set this to true until that is true.
516 bool value_missing_{false};
517
518 AttrInitEntry() = default;
519
520 AttrInitEntry(AttrInitEntry&& other) {
521 type_key_ = other.type_key_;
522 key_ = other.key_;
523 value_ = other.value_;
524 value_missing_ = other.value_missing_;
525 // avoid unexpected throw
526 other.value_missing_ = false;
527 }
528
529 // If the value is still missing in destruction time throw an error.
530 ~AttrInitEntry() DMLC_THROW_EXCEPTION {
531 if (value_missing_) {
532 std::ostringstream os;
533 os << type_key_ << ": Cannot find required field \'" << key_ << "\' during initialization. "
534 << "If the key is defined check that its type matches the declared type.";
535 throw AttrError(os.str());
536 }
537 }
538 // override fields.
539 // This function sets the lower bound of the attribute
540 TSelf& set_lower_bound(const T& begin) {
541 if (this->value_missing_) return *this;
542 const T& val = *value_;
543 if (begin > val) {
544 std::ostringstream os;
545 os << type_key_ << "." << key_ << ": "
546 << "value " << val << " is smaller than the lower bound " << begin;
547 throw AttrError(os.str());
548 }
549 return *this;
550 }
551 // This function sets the upper bound of the attribute
552 TSelf& set_upper_bound(const T& end) {
553 if (this->value_missing_) return *this;
554 const T& val = *value_;
555 if (val > end) {
556 std::ostringstream os;
557 os << type_key_ << "." << key_ << ": "
558 << "value " << val << " is bigger than the upper bound " << end;
559 throw AttrError(os.str());
560 }
561 return *this;
562 }
563 // set default when
564 TSelf& set_default(const T& value) {
565 if (!value_missing_) return *this;
566 *value_ = value;
567 value_missing_ = false;
568 return *this;
569 }
570 TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
571};
572
// Template function to allow smart conversion
// from Expr types into the constants.
// The primary template simply uses TVMArgValue's conversion operator to T;
// the specializations below customize the conversion for specific types.
template <typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) {
  *ptr = val.operator T();
}
579
580template <typename T>
581inline void SetIntValue(T* ptr, const TVMArgValue& val) {
582 if (val.type_code() == kDLInt) {
583 *ptr = static_cast<T>(val.value().v_int64);
584 } else {
585 IntImm expr = val;
586 *ptr = static_cast<T>(expr->value);
587 }
588}
589
// Workaround for GCC8.1 / GCC8.2
// Explicit DataType specialization; it performs the same conversion-operator
// call as the primary template.
template <>
inline void SetValue<DataType>(DataType* ptr, const TVMArgValue& val) {
  *ptr = val.operator DataType();
}
595
596template <>
597inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
598 if (String::CanConvertFrom(val)) {
599 *ptr = val.operator std::string();
600 } else {
601 LOG(FATAL) << "Expect str";
602 }
603}
604
605template <>
606inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
607 if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
608 *ptr = val.operator double();
609 } else {
610 ObjectRef expr = val;
611 ICHECK(expr.defined());
612 if (const IntImmNode* op = expr.as<IntImmNode>()) {
613 *ptr = static_cast<double>(op->value);
614 } else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
615 *ptr = static_cast<double>(op->value);
616 } else {
617 LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
618 }
619 }
620}
// Integral specializations: int, int64_t, uint64_t and bool all delegate to
// SetIntValue, which handles both raw kDLInt arguments and IntImm objects.
template <>
inline void SetValue<int>(int* ptr, const TVMArgValue& val) {
  SetIntValue(ptr, val);
}
template <>
inline void SetValue<int64_t>(int64_t* ptr, const TVMArgValue& val) {
  SetIntValue(ptr, val);
}
template <>
inline void SetValue<uint64_t>(uint64_t* ptr, const TVMArgValue& val) {
  SetIntValue(ptr, val);
}
template <>
inline void SetValue<bool>(bool* ptr, const TVMArgValue& val) {
  SetIntValue(ptr, val);
}
637
638// Visitor for value initialization
639template <typename FFind>
640class AttrInitVisitor {
641 public:
642 // Counter of number of matched attributes during visit.
643 // This is used to decide if there is additional unmatched attributes.
644 size_t hit_count_{0};
645 // constructor
646 AttrInitVisitor(const char* type_key, FFind ffind) : type_key_(type_key), ffind_(ffind) {}
647
648 template <typename T>
649 AttrInitEntry<T> operator()(const char* key, T* value) {
650 TVMArgValue val;
651 AttrInitEntry<T> opt;
652 opt.type_key_ = type_key_;
653 opt.key_ = key;
654 opt.value_ = value;
655 if (ffind_(key, &val)) {
656 SetValue(value, val);
657 opt.value_missing_ = false;
658 ++hit_count_;
659 } else {
660 opt.value_missing_ = true;
661 }
662#if defined(__GNUC__)
663#pragma GCC diagnostic ignored "-Wpragmas"
664#pragma GCC diagnostic ignored "-Wpessimizing-move"
665#endif
666 return std::move(opt);
667 }
668
669 private:
670 // the type key
671 const char* type_key_;
672 FFind ffind_;
673};
674
// Factory that deduces FFind from its argument, so that callers can build an
// AttrInitVisitor from a lambda without naming the lambda's type.
template <typename FFind>
inline AttrInitVisitor<FFind> CreateInitVisitor(const char* type_key, FFind ffind) {
  return AttrInitVisitor<FFind>(type_key, ffind);
}
679
/*!
 * \brief Helper struct to get the type name known to tvm.
 *
 * The primary template reads the type key of the container type of T; the
 * specializations below provide fixed names for types that have no
 * ContainerType.
 *
 * \tparam T the type we are interested in.
 */
template <typename T>
struct TypeName {
  static constexpr const char* value = T::ContainerType::_type_key;
};

template <>
struct TypeName<int> {
  static constexpr const char* value = "int";
};

// NOTE(review): this entry reads "int64" while the uint64_t entry below reads
// "uint64_t"; confirm whether the differing suffix is intended.
template <>
struct TypeName<int64_t> {
  static constexpr const char* value = "int64";
};

template <>
struct TypeName<uint64_t> {
  static constexpr const char* value = "uint64_t";
};

template <>
struct TypeName<DataType> {
  static constexpr const char* value = "DataType";
};

template <>
struct TypeName<std::string> {
  static constexpr const char* value = "str";
};

template <>
struct TypeName<bool> {
  static constexpr const char* value = "bool";
};

template <>
struct TypeName<void*> {
  static constexpr const char* value = "handle";
};

template <>
struct TypeName<double> {
  static constexpr const char* value = "double";
};
728
// Entry returned while collecting field documentation. It writes the
// description and the default value into the AttrFieldInfoNode it holds;
// bounds are not recorded.
class AttrDocEntry {
 public:
  using TSelf = AttrDocEntry;

  explicit AttrDocEntry(ObjectPtr<AttrFieldInfoNode> info) : info_(info) {}
  // Store the description text of the field.
  TSelf& describe(const char* str) {
    info_->description = str;
    return *this;
  }
  // Append ", default=<value>" to the type information of the field.
  // The value is formatted with operator<< on an std::ostringstream.
  template <typename T>
  TSelf& set_default(const T& value) {
    std::ostringstream os;
    os << info_->type_info << ", default=" << value;
    info_->type_info = os.str();
    return *this;
  }
  // Bounds are ignored for documentation purposes.
  template <typename T>
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED T begin) {
    return *this;
  }
  template <typename T>
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED T end) {
    return *this;
  }

 private:
  // The field information being filled in.
  ObjectPtr<AttrFieldInfoNode> info_;
};
757
// Visitor that builds the list of AttrFieldInfo for an attrs class.
// For each field it creates an info node with the field name and the
// TypeName<T> string, appends it to fields_, and returns an AttrDocEntry bound
// to the same node so chained describe()/set_default() calls can extend it.
class AttrDocVisitor {
 public:
  template <typename T>
  AttrDocEntry operator()(const char* key, T* v) {
    ObjectPtr<AttrFieldInfoNode> info = make_object<AttrFieldInfoNode>();
    info->name = key;
    info->type_info = TypeName<T>::value;
    fields_.push_back(AttrFieldInfo(info));
    return AttrDocEntry(info);
  }

  // The collected field information, in visiting order.
  Array<AttrFieldInfo> fields_;
};
771
772class AttrExistVisitor {
773 public:
774 std::string key_;
775 bool exist_{false};
776
777 template <typename T>
778 AttrNopEntry operator()(const char* key, T* v) {
779 if (exist_) return AttrNopEntry();
780 if (key == key_) exist_ = true;
781 return AttrNopEntry();
782 }
783};
784
// Entry used to visit only fields whose value differs from the default.
// The visit is deferred to the destructor: it happens unless set_default()
// was called with a value that is structurally equal to the current field
// value. A field for which set_default() is never called is always visited.
template <typename T>
struct AttrTriggerNonDefaultEntry {
  using TSelf = AttrTriggerNonDefaultEntry<T>;
  // constructor
  AttrTriggerNonDefaultEntry(AttrVisitor* visitor, const char* key, T* data)
      : visitor_(visitor), key_(key), data_(data) {}

  // Performs the deferred visit; declared with DMLC_THROW_EXCEPTION.
  ~AttrTriggerNonDefaultEntry() DMLC_THROW_EXCEPTION {
    if (trigger_) {
      visitor_->Visit(key_, data_);
    }
  }
  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
  // Cancel the visit when the field currently holds the default value.
  TSelf& set_default(const T& value) {
    if (tvm::StructuralEqual()(value, *data_)) {
      trigger_ = false;
    }
    return *this;
  }
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) { return *this; }
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) { return *this; }

 private:
  AttrVisitor* visitor_;
  const char* key_;
  T* data_;
  // Whether the destructor should visit the field.
  bool trigger_{true};
};
813
// Visitor adapter that returns an AttrTriggerNonDefaultEntry for every field,
// so that the wrapped AttrVisitor only sees fields that are not at their
// declared default.
class AttrNonDefaultVisitor {
 public:
  explicit AttrNonDefaultVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
  template <typename T>
  AttrTriggerNonDefaultEntry<T> operator()(const char* key, T* value) {
    return AttrTriggerNonDefaultEntry<T>(visitor_, key, value);
  }

 private:
  // The wrapped visitor; stored as a raw pointer, not owned by this class.
  AttrVisitor* visitor_;
};
825} // namespace detail
826
827/*!
828 * \brief The base class of the all the
829 * Use "curiously recurring template pattern".
830 *
831 * \tparam DerivedType The final attribute type.
832 */
833template <typename DerivedType>
834class AttrsNode : public BaseAttrsNode {
835 public:
836 void VisitAttrs(AttrVisitor* v) {
837 ::tvm::detail::AttrNormalVisitor vis(v);
838 self()->_tvm_VisitAttrs(vis);
839 }
840
841 void VisitNonDefaultAttrs(AttrVisitor* v) {
842 ::tvm::detail::AttrNonDefaultVisitor vis(v);
843 self()->_tvm_VisitAttrs(vis);
844 }
845
846 void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {
847 ICHECK_EQ(args.size() % 2, 0);
848 const int kLinearSearchBound = 16;
849 int hit_count = 0;
850 // applies two strategies to lookup
851 if (args.size() < kLinearSearchBound) {
852 // linear search.
853 auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
854 for (int i = 0; i < args.size(); i += 2) {
855 ICHECK_EQ(args.type_codes[i], kTVMStr);
856 if (!std::strcmp(key, args.values[i].v_str)) {
857 *val = args[i + 1];
858 return true;
859 }
860 }
861 return false;
862 };
863 auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
864 self()->_tvm_VisitAttrs(vis);
865 hit_count = vis.hit_count_;
866 } else {
867 // construct a map then do lookup.
868 std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
869 for (int i = 0; i < args.size(); i += 2) {
870 ICHECK_EQ(args.type_codes[i], kTVMStr);
871 kwargs[args[i].operator std::string()] = args[i + 1];
872 }
873 auto ffind = [&kwargs](const char* key, runtime::TVMArgValue* val) {
874 auto it = kwargs.find(key);
875 if (it != kwargs.end()) {
876 *val = it->second;
877 return true;
878 }
879 return false;
880 };
881 auto vis = ::tvm::detail::CreateInitVisitor(DerivedType::_type_key, ffind);
882 self()->_tvm_VisitAttrs(vis);
883 hit_count = vis.hit_count_;
884 }
885 // error handling, slow path
886 if (hit_count * 2 != args.size() && !allow_unknown) {
887 for (int i = 0; i < args.size(); i += 2) {
888 ::tvm::detail::AttrExistVisitor visitor;
889 visitor.key_ = args[i].operator std::string();
890 self()->_tvm_VisitAttrs(visitor);
891 if (!visitor.exist_) {
892 std::ostringstream os;
893 os << DerivedType::_type_key << ": does not have field \'" << visitor.key_
894 << "\', Possible fields:\n";
895 os << "----------------\n";
896 this->PrintDocString(os);
897 throw AttrError(os.str());
898 }
899 }
900 }
901 }
902
903 bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
904 DerivedType* pself = self();
905 ::tvm::detail::AttrsSEqualVisitor visitor(pself, other, equal);
906 self()->_tvm_VisitAttrs(visitor);
907 return visitor.result_;
908 }
909
910 void SHashReduce(SHashReducer hash_reducer) const {
911 ::tvm::detail::AttrsSHashVisitor visitor(hash_reducer);
912 self()->_tvm_VisitAttrs(visitor);
913 }
914
915 Array<AttrFieldInfo> ListFieldInfo() const final {
916 ::tvm::detail::AttrDocVisitor visitor;
917 self()->_tvm_VisitAttrs(visitor);
918 return visitor.fields_;
919 }
920
921 private:
922 DerivedType* self() const {
923 return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
924 }
925};
926
// Definition of BaseAttrsNode::InitBySeq.
// The arguments are passed through a PackedFunc call so that they arrive as
// TVMArgs, which is then forwarded to InitByPackedArgs (with its default
// allow_unknown value).
template <typename... Args>
inline void BaseAttrsNode::InitBySeq(Args&&... args) {
  runtime::PackedFunc pf(
      [this](const TVMArgs& args, TVMRetValue* rv) { this->InitByPackedArgs(args); });
  pf(std::forward<Args>(args)...);
}
933
934inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*)
935 Array<AttrFieldInfo> entry = this->ListFieldInfo();
936 for (AttrFieldInfo info : entry) {
937 os << info->name << " : " << info->type_info << '\n';
938 if (info->description.length() != 0) {
939 os << " " << info->description << '\n';
940 }
941 }
942}
943
944} // namespace tvm
945#endif // TVM_IR_ATTRS_H_
946