1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/target_kind.h
22 * \brief Target kind registry
23 */
24#ifndef TVM_TARGET_TARGET_KIND_H_
25#define TVM_TARGET_TARGET_KIND_H_
26
27#include <tvm/ir/transform.h>
28#include <tvm/node/attr_registry_map.h>
29#include <tvm/node/node.h>
30
31#include <memory>
32#include <unordered_map>
33#include <utility>
34#include <vector>
35
36namespace tvm {
37
38class Target;
39
40/*!
41 * \brief Map containing parsed features of a specific Target
42 */
43using TargetFeatures = Map<String, ObjectRef>;
44
45/*!
46 * \brief TargetParser to apply on instantiation of a given TargetKind
47 *
48 * \param target_json Target in JSON format to be transformed during parsing.
49 *
50 * \return The transformed Target JSON object.
51 */
52using TargetJSON = Map<String, ObjectRef>;
53using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;
54
55/*!
56 * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
57 *
58 * Called before the default lowering passes.
59 *
60 * \param mod The module that an optimization pass runs on.
61 * \param pass_ctx The pass context that can provide information for the optimization.
62 *
63 * \return The transformed module.
64 */
65using FTVMRelayToTIR = transform::Pass;
66
67/*!
68 * \brief TIRToRuntime conversion specific to a TargetKind
69 *
70 * This function is responsible for scanning an IRModule for appropriate Target-specific functions
71 and generating a Runtime module representing the compiled output
72 *
73 * \param ir_module Unified IRModule
74 * \param target Target to filter on or retrieve arguments from
75 * \return Runtime Module containing compiled functions
76 */
77using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;
78
namespace detail {
/*! \brief Forward declaration of the functor that builds TargetKindNode::ValueTypeInfo */
template <typename ValueType, typename IsArray, typename IsMap>
struct ValueTypeInfoMaker;
}  // namespace detail

/*! \brief Forward declaration; named as a friend by TargetKindNode and TargetKind */
class TargetInternal;

/*! \brief Forward declaration of the attribute map keyed by TargetKind */
template <typename ValueType>
class TargetKindAttrMap;
88
89/*! \brief Target kind, specifies the kind of the target */
90class TargetKindNode : public Object {
91 public:
92 /*! \brief Name of the target kind */
93 String name;
94 /*! \brief Device type of target kind */
95 int default_device_type;
96 /*! \brief Default keys of the target */
97 Array<String> default_keys;
98 /*! \brief Function used to preprocess on target creation */
99 PackedFunc preprocessor;
100 /*! \brief Function used to parse a JSON target during creation */
101 FTVMTargetParser target_parser;
102
103 void VisitAttrs(AttrVisitor* v) {
104 v->Visit("name", &name);
105 v->Visit("default_device_type", &default_device_type);
106 v->Visit("default_keys", &default_keys);
107 }
108
109 static constexpr const char* _type_key = "TargetKind";
110 TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object);
111
112 private:
113 /*! \brief Return the index stored in attr registry */
114 uint32_t AttrRegistryIndex() const { return index_; }
115 /*! \brief Return the name stored in attr registry */
116 String AttrRegistryName() const { return name; }
117 /*! \brief Stores the required type_key and type_index of a specific attr of a target */
118 struct ValueTypeInfo {
119 String type_key;
120 uint32_t type_index;
121 std::unique_ptr<ValueTypeInfo> key;
122 std::unique_ptr<ValueTypeInfo> val;
123 };
124 /*! \brief A hash table that stores the type information of each attr of the target key */
125 std::unordered_map<String, ValueTypeInfo> key2vtype_;
126 /*! \brief A hash table that stores the default value of each attr of the target key */
127 std::unordered_map<String, ObjectRef> key2default_;
128 /*! \brief Index used for internal lookup of attribute registry */
129 uint32_t index_;
130
131 template <typename, typename, typename>
132 friend struct detail::ValueTypeInfoMaker;
133 template <typename, typename>
134 friend class AttrRegistry;
135 template <typename>
136 friend class AttrRegistryMapContainerMap;
137 friend class TargetKindRegEntry;
138 friend class TargetInternal;
139};
140
141/*!
142 * \brief Managed reference class to TargetKindNode
143 * \sa TargetKindNode
144 */
145class TargetKind : public ObjectRef {
146 public:
147 TargetKind() = default;
148 /*! \brief Get the attribute map given the attribute name */
149 template <typename ValueType>
150 static inline TargetKindAttrMap<ValueType> GetAttrMap(const String& attr_name);
151 /*!
152 * \brief Retrieve the TargetKind given its name
153 * \param target_kind_name Name of the target kind
154 * \return The TargetKind requested
155 */
156 TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
157 TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode);
158
159 private:
160 /*! \brief Mutable access to the container class */
161 TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }
162 TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>& GetAttrMapContainer(
163 const String& attr_name);
164 friend class TargetKindRegEntry;
165 friend class TargetInternal;
166};
167
168/*!
169 * \brief Map<TargetKind, ValueType> used to store meta-information about TargetKind
170 * \tparam ValueType The type of the value stored in map
171 */
172template <typename ValueType>
173class TargetKindAttrMap : public AttrRegistryMap<TargetKind, ValueType> {
174 public:
175 using TParent = AttrRegistryMap<TargetKind, ValueType>;
176 using TParent::count;
177 using TParent::get;
178 using TParent::operator[];
179 explicit TargetKindAttrMap(const AttrRegistryMapContainerMap<TargetKind>& map) : TParent(map) {}
180};
181
/*! \brief The --runtime value in a target spec that selects the C++ runtime. */
static constexpr char const* kTvmRuntimeCpp = "c++";

/*! \brief The --runtime value in a target spec that selects the C runtime. */
static constexpr char const* kTvmRuntimeCrt = "c";
187
188/*!
189 * \brief Helper structure to register TargetKind
190 * \sa TVM_REGISTER_TARGET_KIND
191 */
192class TargetKindRegEntry {
193 public:
194 /*!
195 * \brief Register additional attributes to target_kind.
196 * \param attr_name The name of the attribute.
197 * \param value The value to be set.
198 * \param plevel The priority level of this attribute,
199 * an higher priority level attribute
200 * will replace lower priority level attribute.
201 * Must be bigger than 0.
202 *
203 * Cannot set with same plevel twice in the code.
204 *
205 * \tparam ValueType The type of the value to be set.
206 */
207 template <typename ValueType>
208 inline TargetKindRegEntry& set_attr(const String& attr_name, const ValueType& value,
209 int plevel = 10);
210 /*!
211 * \brief Set DLPack's device_type the target
212 * \param device_type Device type
213 */
214 inline TargetKindRegEntry& set_default_device_type(int device_type);
215 /*!
216 * \brief Set DLPack's device_type the target
217 * \param keys The default keys
218 */
219 inline TargetKindRegEntry& set_default_keys(std::vector<String> keys);
220 /*!
221 * \brief Set the pre-processing function applied upon target creation
222 * \tparam FLambda Type of the function
223 * \param f The pre-processing function
224 */
225 template <typename FLambda>
226 inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f);
227 /*!
228 * \brief Set the parsing function applied upon target creation
229 * \param parser The Target parsing function
230 */
231 inline TargetKindRegEntry& set_target_parser(FTVMTargetParser parser);
232 /*!
233 * \brief Register a valid configuration option and its ValueType for validation
234 * \param key The configuration key
235 * \tparam ValueType The value type to be registered
236 */
237 template <typename ValueType>
238 inline TargetKindRegEntry& add_attr_option(const String& key);
239 /*!
240 * \brief Register a valid configuration option and its ValueType for validation
241 * \param key The configuration key
242 * \param default_value The default value of the key
243 * \tparam ValueType The value type to be registered
244 */
245 template <typename ValueType>
246 inline TargetKindRegEntry& add_attr_option(const String& key, ObjectRef default_value);
247 /*! \brief Set name of the TargetKind to be the same as registry if it is empty */
248 inline TargetKindRegEntry& set_name();
249 /*!
250 * \brief List all the entry names in the registry.
251 * \return The entry names.
252 */
253 TVM_DLL static Array<String> ListTargetKinds();
254 /*!
255 * \brief Get all supported option names and types for a given Target kind.
256 * \return Map of option name to type
257 */
258 TVM_DLL static Map<String, String> ListTargetKindOptions(const TargetKind& kind);
259
260 /*!
261 * \brief Register or get a new entry.
262 * \param target_kind_name The name of the TargetKind.
263 * \return the corresponding entry.
264 */
265 TVM_DLL static TargetKindRegEntry& RegisterOrGet(const String& target_kind_name);
266
267 private:
268 TargetKind kind_;
269 String name;
270
271 /*! \brief private constructor */
272 explicit TargetKindRegEntry(uint32_t reg_index) : kind_(make_object<TargetKindNode>()) {
273 kind_->index_ = reg_index;
274 }
275 /*!
276 * \brief update the attribute TargetKindAttrMap
277 * \param key The name of the attribute
278 * \param value The value to be set
279 * \param plevel The priority level
280 */
281 TVM_DLL void UpdateAttr(const String& key, TVMRetValue value, int plevel);
282 template <typename, typename>
283 friend class AttrRegistry;
284 friend class TargetKind;
285};
286
287namespace detail {
/*!
 * \brief Type trait telling whether a type is an instantiation of a given class template.
 *
 * The primary template covers every type that is not an instantiation of Template.
 *
 * \tparam Candidate The type to inspect
 * \tparam Template The class template to test against
 */
template <typename Candidate, template <typename...> class Template>
struct is_specialized : std::integral_constant<bool, false> {
  using type = std::integral_constant<bool, false>;
};

/*! \brief Matches when the inspected type is Template<Params...> for some Params. */
template <template <typename...> class Template, typename... Params>
struct is_specialized<Template<Params...>, Template> : std::integral_constant<bool, true> {
  using type = std::integral_constant<bool, true>;
};
297
298template <typename ValueType, typename IsArray = typename is_specialized<ValueType, Array>::type,
299 typename IsMap = typename is_specialized<ValueType, Map>::type>
300struct ValueTypeInfoMaker {};
301
302template <typename ValueType>
303struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> {
304 using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
305
306 ValueTypeInfo operator()() const {
307 uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex();
308 ValueTypeInfo info;
309 info.type_index = tindex;
310 info.type_key = runtime::Object::TypeIndex2Key(tindex);
311 info.key = nullptr;
312 info.val = nullptr;
313 return info;
314 }
315};
316
317template <typename ValueType>
318struct ValueTypeInfoMaker<ValueType, std::true_type, std::false_type> {
319 using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
320
321 ValueTypeInfo operator()() const {
322 using key_type = ValueTypeInfoMaker<typename ValueType::value_type>;
323 uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex();
324 ValueTypeInfo info;
325 info.type_index = tindex;
326 info.type_key = runtime::Object::TypeIndex2Key(tindex);
327 info.key = std::make_unique<ValueTypeInfo>(key_type()());
328 info.val = nullptr;
329 return info;
330 }
331};
332
333template <typename ValueType>
334struct ValueTypeInfoMaker<ValueType, std::false_type, std::true_type> {
335 using ValueTypeInfo = TargetKindNode::ValueTypeInfo;
336 ValueTypeInfo operator()() const {
337 using key_type = ValueTypeInfoMaker<typename ValueType::key_type>;
338 using val_type = ValueTypeInfoMaker<typename ValueType::mapped_type>;
339 uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex();
340 ValueTypeInfo info;
341 info.type_index = tindex;
342 info.type_key = runtime::Object::TypeIndex2Key(tindex);
343 info.key = std::make_unique<ValueTypeInfo>(key_type()());
344 info.val = std::make_unique<ValueTypeInfo>(val_type()());
345 return info;
346 }
347};
348
349} // namespace detail
350
351template <typename ValueType>
352inline TargetKindAttrMap<ValueType> TargetKind::GetAttrMap(const String& attr_name) {
353 return TargetKindAttrMap<ValueType>(GetAttrMapContainer(attr_name));
354}
355
356template <typename ValueType>
357inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name,
358 const ValueType& value, int plevel) {
359 ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
360 runtime::TVMRetValue rv;
361 rv = value;
362 UpdateAttr(attr_name, rv, plevel);
363 return *this;
364}
365
366inline TargetKindRegEntry& TargetKindRegEntry::set_default_device_type(int device_type) {
367 kind_->default_device_type = device_type;
368 return *this;
369}
370
371inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<String> keys) {
372 kind_->default_keys = keys;
373 return *this;
374}
375
376template <typename FLambda>
377inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) {
378 LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead";
379 using FType = typename tvm::runtime::detail::function_signature<FLambda>::FType;
380 kind_->preprocessor = tvm::runtime::TypedPackedFunc<FType>(std::move(f)).packed();
381 return *this;
382}
383
384inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParser parser) {
385 kind_->target_parser = parser;
386 return *this;
387}
388
389template <typename ValueType>
390inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) {
391 ICHECK(!kind_->key2vtype_.count(key))
392 << "AttributeError: add_attr_option failed because '" << key << "' has been set once";
393 kind_->key2vtype_[key] = detail::ValueTypeInfoMaker<ValueType>()();
394 return *this;
395}
396
397template <typename ValueType>
398inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key,
399 ObjectRef default_value) {
400 add_attr_option<ValueType>(key);
401 kind_->key2default_[key] = default_value;
402 return *this;
403}
404
405inline TargetKindRegEntry& TargetKindRegEntry::set_name() {
406 if (kind_->name.empty()) {
407 kind_->name = name;
408 }
409 return *this;
410}
411
/*!
 * \def TVM_TARGET_KIND_REGISTER_VAR_DEF
 * \brief Declaration prefix of the static TargetKindRegEntry reference defined by a
 *  registration; TVM_REGISTER_TARGET_KIND appends __COUNTER__ to the variable name.
 */
#define TVM_TARGET_KIND_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetKindRegEntry& __make_##TargetKind
414
namespace attr {
//
// Distinguished TargetKind attribute names.
//

/*!
 * \brief A \p TargetKind attribute of type \p Bool. If true, then the target kind name also
 * corresponds to an external codegen 'compiler' name. That name may be used:
 *  - To retrieve partitioning rules using \p get_partition_table.
 *  - To attach to Relay Functions under the \p attr::kCompiler attribute to indicate
 *    the function is to be compiled by the external codegen path.
 *
 * The \p CollagePartition pass uses this attribute to guide its search over candidate partitions
 * using external codegen.
 *
 * See also \p Target::IsExternalCodegenFor
 */
constexpr char const* kIsExternalCodegen = "is_external_codegen";

/*!
 * \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name
 * also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass
 * to apply before the TVM lowering.
 *
 * See also \p Target::IsExternalCodegenFor
 */
constexpr char const* kRelayToTIR = "RelayToTIR";

}  // namespace attr
444
/*!
 * \def TVM_REGISTER_TARGET_KIND
 * \brief Register a new target kind, or set attribute of the corresponding target kind.
 *
 * The expansion also sets the kind name and the default device type, and registers the
 * options "keys", "tag", "device", "model", "libs", "host", "from_device" and
 * "target_device_type". Further setters can be chained after the macro invocation.
 *
 * Every type named in the expansion is fully qualified, so the macro does not depend on the
 * namespace it is expanded in.
 *
 * \param TargetKindName The name of target kind
 * \param DeviceType The DLDeviceType of the target kind
 *
 * \code
 *
 *  TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
 *  .set_attr<TPreCodegenPass>("TPreCodegenPass", a-pre-codegen-pass)
 *  .add_attr_option<Bool>("system_lib")
 *  .add_attr_option<String>("mtriple")
 *  .add_attr_option<String>("mattr");
 *
 * \endcode
 */
#define TVM_REGISTER_TARGET_KIND(TargetKindName, DeviceType)          \
  TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) =     \
      ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName)        \
          .set_name()                                                 \
          .set_default_device_type(DeviceType)                        \
          .add_attr_option<::tvm::Array<::tvm::String>>("keys")       \
          .add_attr_option<::tvm::String>("tag")                      \
          .add_attr_option<::tvm::String>("device")                   \
          .add_attr_option<::tvm::String>("model")                    \
          .add_attr_option<::tvm::Array<::tvm::String>>("libs")       \
          .add_attr_option<::tvm::Target>("host")                     \
          .add_attr_option<::tvm::Integer>("from_device")             \
          .add_attr_option<::tvm::Integer>("target_device_type")
475
476} // namespace tvm
477
478#endif // TVM_TARGET_TARGET_KIND_H_
479