1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/tag.h
22 * \brief Target tag registry
23 */
24#ifndef TVM_TARGET_TAG_H_
25#define TVM_TARGET_TAG_H_
26
27#include <tvm/node/attr_registry_map.h>
28#include <tvm/node/node.h>
29#include <tvm/target/target.h>
30
31#include <utility>
32
33namespace tvm {
34
35/*! \brief A target tag */
36class TargetTagNode : public Object {
37 public:
38 /*! \brief Name of the target */
39 String name;
40 /*! \brief Config map to generate the target */
41 Map<String, ObjectRef> config;
42
43 void VisitAttrs(AttrVisitor* v) {
44 v->Visit("name", &name);
45 v->Visit("config", &config);
46 }
47
48 static constexpr const char* _type_key = "TargetTag";
49 TVM_DECLARE_FINAL_OBJECT_INFO(TargetTagNode, Object);
50
51 private:
52 /*! \brief Return the index stored in attr registry */
53 uint32_t AttrRegistryIndex() const { return index_; }
54 /*! \brief Return the name stored in attr registry */
55 String AttrRegistryName() const { return name; }
56 /*! \brief Index used for internal lookup of attribute registry */
57 uint32_t index_;
58
59 template <typename, typename>
60 friend class AttrRegistry;
61 template <typename>
62 friend class AttrRegistryMapContainerMap;
63 friend class TargetTagRegEntry;
64};
65
66/*!
67 * \brief Managed reference class to TargetTagNode
68 * \sa TargetTagNode
69 */
70class TargetTag : public ObjectRef {
71 public:
72 /*!
73 * \brief Retrieve the Target given it the name of target tag
74 * \param target_tag_name Name of the target tag
75 * \return The Target requested
76 */
77 TVM_DLL static Optional<Target> Get(const String& target_tag_name);
78 /*!
79 * \brief List all names of the existing target tags
80 * \return A dictionary that maps tag name to the concrete target it corresponds to
81 */
82 TVM_DLL static Map<String, Target> ListTags();
83 /*!
84 * \brief Add a tag into the registry
85 * \param name Name of the tag
86 * \param config The target config corresponding to the tag
87 * \param override Allow overriding existing tags
88 * \return Target created with the tag
89 */
90 TVM_DLL static Target AddTag(String name, Map<String, ObjectRef> config, bool override);
91
92 TVM_DEFINE_OBJECT_REF_METHODS(TargetTag, ObjectRef, TargetTagNode);
93
94 private:
95 /*! \brief Mutable access to the container class */
96 TargetTagNode* operator->() { return static_cast<TargetTagNode*>(data_.get()); }
97 friend class TargetTagRegEntry;
98};
99
100class TargetTagRegEntry {
101 public:
102 /*!
103 * \brief Set the config dict corresponding to the target tag
104 * \param config The config dict for target creation
105 */
106 inline TargetTagRegEntry& set_config(Map<String, ObjectRef> config);
107 /*! \brief Set name of the TargetTag to be the same as registry if it is empty */
108 inline TargetTagRegEntry& set_name();
109 /*!
110 * \brief Register or get a new entry.
111 * \param target_tag_name The name of the TargetTag.
112 * \return the corresponding entry.
113 */
114 TVM_DLL static TargetTagRegEntry& RegisterOrGet(const String& target_tag_name);
115
116 private:
117 TargetTag tag_;
118 String name;
119
120 /*! \brief private constructor */
121 explicit TargetTagRegEntry(uint32_t reg_index) : tag_(make_object<TargetTagNode>()) {
122 tag_->index_ = reg_index;
123 }
124 template <typename, typename>
125 friend class AttrRegistry;
126 friend class TargetTag;
127};
128
129inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map<String, ObjectRef> config) {
130 tag_->config = std::move(config);
131 return *this;
132}
133
134inline TargetTagRegEntry& TargetTagRegEntry::set_name() {
135 if (tag_->name.empty()) {
136 tag_->name = name;
137 }
138 return *this;
139}
140
/*!
 * \brief Helper: the declaration prefix of the static variable that holds one
 *  target-tag registration.
 *
 *  Expands to `static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetTagRegEntry& __make_TargetTag`.
 *  The `##` pastes the fixed token `TargetTag`, because this macro takes no parameters.
 *  TVM_REGISTER_TARGET_TAG combines it with __COUNTER__ through TVM_STR_CONCAT
 *  (presumably token pasting) so that each registration gets a distinct variable name.
 *  NOTE(review): DMLC_ATTRIBUTE_UNUSED is assumed to silence unused-variable warnings;
 *  it is defined outside this file.
 */
#define TVM_TARGET_TAG_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetTagRegEntry& __make_##TargetTag

/*!
 * \def TVM_REGISTER_TARGET_TAG
 * \brief Register a new target tag, or fetch the existing entry of that name, so
 *  that it can be configured.
 *
 *  Binds a static TargetTagRegEntry& to `RegisterOrGet(TargetTagName).set_name()`.
 *  As a result, the tag's name defaults to its registry name. Because every setter
 *  returns the entry, further configuration can be chained, e.g.
 *  \code
 *  TVM_REGISTER_TARGET_TAG("my-tag").set_config(config);
 *  \endcode
 * \param TargetTagName The name of the target tag (passed to
 *  TargetTagRegEntry::RegisterOrGet as a `const String&`)
 */
#define TVM_REGISTER_TARGET_TAG(TargetTagName)                     \
  TVM_STR_CONCAT(TVM_TARGET_TAG_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::TargetTagRegEntry::RegisterOrGet(TargetTagName).set_name()
152
153} // namespace tvm
154
155#endif // TVM_TARGET_TAG_H_
156