1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/target/tag.h |
22 | * \brief Target tag registry |
23 | */ |
24 | #ifndef TVM_TARGET_TAG_H_ |
25 | #define TVM_TARGET_TAG_H_ |
26 | |
27 | #include <tvm/node/attr_registry_map.h> |
28 | #include <tvm/node/node.h> |
29 | #include <tvm/target/target.h> |
30 | |
31 | #include <utility> |
32 | |
33 | namespace tvm { |
34 | |
35 | /*! \brief A target tag */ |
36 | class TargetTagNode : public Object { |
37 | public: |
38 | /*! \brief Name of the target */ |
39 | String name; |
40 | /*! \brief Config map to generate the target */ |
41 | Map<String, ObjectRef> config; |
42 | |
43 | void VisitAttrs(AttrVisitor* v) { |
44 | v->Visit("name" , &name); |
45 | v->Visit("config" , &config); |
46 | } |
47 | |
48 | static constexpr const char* _type_key = "TargetTag" ; |
49 | TVM_DECLARE_FINAL_OBJECT_INFO(TargetTagNode, Object); |
50 | |
51 | private: |
52 | /*! \brief Return the index stored in attr registry */ |
53 | uint32_t AttrRegistryIndex() const { return index_; } |
54 | /*! \brief Return the name stored in attr registry */ |
55 | String AttrRegistryName() const { return name; } |
56 | /*! \brief Index used for internal lookup of attribute registry */ |
57 | uint32_t index_; |
58 | |
59 | template <typename, typename> |
60 | friend class AttrRegistry; |
61 | template <typename> |
62 | friend class AttrRegistryMapContainerMap; |
63 | friend class TargetTagRegEntry; |
64 | }; |
65 | |
66 | /*! |
67 | * \brief Managed reference class to TargetTagNode |
68 | * \sa TargetTagNode |
69 | */ |
70 | class TargetTag : public ObjectRef { |
71 | public: |
72 | /*! |
73 | * \brief Retrieve the Target given it the name of target tag |
74 | * \param target_tag_name Name of the target tag |
75 | * \return The Target requested |
76 | */ |
77 | TVM_DLL static Optional<Target> Get(const String& target_tag_name); |
78 | /*! |
79 | * \brief List all names of the existing target tags |
80 | * \return A dictionary that maps tag name to the concrete target it corresponds to |
81 | */ |
82 | TVM_DLL static Map<String, Target> ListTags(); |
83 | /*! |
84 | * \brief Add a tag into the registry |
85 | * \param name Name of the tag |
86 | * \param config The target config corresponding to the tag |
87 | * \param override Allow overriding existing tags |
88 | * \return Target created with the tag |
89 | */ |
90 | TVM_DLL static Target AddTag(String name, Map<String, ObjectRef> config, bool override); |
91 | |
92 | TVM_DEFINE_OBJECT_REF_METHODS(TargetTag, ObjectRef, TargetTagNode); |
93 | |
94 | private: |
95 | /*! \brief Mutable access to the container class */ |
96 | TargetTagNode* operator->() { return static_cast<TargetTagNode*>(data_.get()); } |
97 | friend class TargetTagRegEntry; |
98 | }; |
99 | |
100 | class TargetTagRegEntry { |
101 | public: |
102 | /*! |
103 | * \brief Set the config dict corresponding to the target tag |
104 | * \param config The config dict for target creation |
105 | */ |
106 | inline TargetTagRegEntry& set_config(Map<String, ObjectRef> config); |
107 | /*! \brief Set name of the TargetTag to be the same as registry if it is empty */ |
108 | inline TargetTagRegEntry& set_name(); |
109 | /*! |
110 | * \brief Register or get a new entry. |
111 | * \param target_tag_name The name of the TargetTag. |
112 | * \return the corresponding entry. |
113 | */ |
114 | TVM_DLL static TargetTagRegEntry& RegisterOrGet(const String& target_tag_name); |
115 | |
116 | private: |
117 | TargetTag tag_; |
118 | String name; |
119 | |
120 | /*! \brief private constructor */ |
121 | explicit TargetTagRegEntry(uint32_t reg_index) : tag_(make_object<TargetTagNode>()) { |
122 | tag_->index_ = reg_index; |
123 | } |
124 | template <typename, typename> |
125 | friend class AttrRegistry; |
126 | friend class TargetTag; |
127 | }; |
128 | |
129 | inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map<String, ObjectRef> config) { |
130 | tag_->config = std::move(config); |
131 | return *this; |
132 | } |
133 | |
134 | inline TargetTagRegEntry& TargetTagRegEntry::set_name() { |
135 | if (tag_->name.empty()) { |
136 | tag_->name = name; |
137 | } |
138 | return *this; |
139 | } |
140 | |
/*!
 * \brief Declaration prefix used by TVM_REGISTER_TARGET_TAG.
 *
 * Expands to the beginning of a `static` declaration of a reference to
 * ::tvm::TargetTagRegEntry. The variable name ends in the token `__make_TargetTag`,
 * formed by pasting `__make_` with `TargetTag`.
 *
 * TVM_REGISTER_TARGET_TAG hands this prefix and __COUNTER__ to TVM_STR_CONCAT. That
 * presumably pastes the counter value onto the trailing token so that each
 * registration gets a distinct variable name.
 *
 * DMLC_ATTRIBUTE_UNUSED presumably suppresses unused-variable warnings, since the
 * variable exists only for its initializer's side effect. Confirm against its definition.
 *
 * NOTE(review): C++ reserves identifiers beginning with `__` for the implementation.
 */
#define TVM_TARGET_TAG_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetTagRegEntry& __make_##TargetTag
143 | |
/*!
 * \def TVM_REGISTER_TARGET_TAG
 * \brief Register a new target tag, or set attribute of the corresponding target tag.
 *
 * The macro performs three steps:
 * 1. Look up or create the entry via TargetTagRegEntry::RegisterOrGet.
 * 2. Call set_name() on it.
 * 3. Bind the resulting TargetTagRegEntry& to a static variable whose name is derived
 *    from __COUNTER__ (see TVM_TARGET_TAG_REGISTER_VAR_DEF).
 *
 * The expansion ends in a TargetTagRegEntry& expression, so further setters can be
 * chained onto it, and the caller supplies the terminating semicolon:
 *
 * \code
 *   TVM_REGISTER_TARGET_TAG("some-tag").set_config(config);  // config: Map<String, ObjectRef>
 * \endcode
 *
 * \param TargetTagName The name of target tag
 */
#define TVM_REGISTER_TARGET_TAG(TargetTagName)                   \
  TVM_STR_CONCAT(TVM_TARGET_TAG_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::TargetTagRegEntry::RegisterOrGet(TargetTagName).set_name()
152 | |
153 | } // namespace tvm |
154 | |
155 | #endif // TVM_TARGET_TAG_H_ |
156 | |