/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | |
/*!
 * \file tvm/target/target_kind.h
 * \brief Target kind registry
 */
24 | #ifndef TVM_TARGET_TARGET_KIND_H_ |
25 | #define TVM_TARGET_TARGET_KIND_H_ |
26 | |
27 | #include <tvm/ir/transform.h> |
28 | #include <tvm/node/attr_registry_map.h> |
29 | #include <tvm/node/node.h> |
30 | |
31 | #include <memory> |
32 | #include <unordered_map> |
33 | #include <utility> |
34 | #include <vector> |
35 | |
36 | namespace tvm { |
37 | |
38 | class Target; |
39 | |
40 | /*! |
41 | * \brief Map containing parsed features of a specific Target |
42 | */ |
43 | using TargetFeatures = Map<String, ObjectRef>; |
44 | |
45 | /*! |
46 | * \brief TargetParser to apply on instantiation of a given TargetKind |
47 | * |
48 | * \param target_json Target in JSON format to be transformed during parsing. |
49 | * |
50 | * \return The transformed Target JSON object. |
51 | */ |
52 | using TargetJSON = Map<String, ObjectRef>; |
53 | using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>; |
54 | |
55 | /*! |
56 | * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind |
57 | * |
58 | * Called before the default lowering passes. |
59 | * |
60 | * \param mod The module that an optimization pass runs on. |
61 | * \param pass_ctx The pass context that can provide information for the optimization. |
62 | * |
63 | * \return The transformed module. |
64 | */ |
65 | using FTVMRelayToTIR = transform::Pass; |
66 | |
67 | /*! |
68 | * \brief TIRToRuntime conversion specific to a TargetKind |
69 | * |
70 | * This function is responsible for scanning an IRModule for appropriate Target-specific functions |
71 | and generating a Runtime module representing the compiled output |
72 | * |
73 | * \param ir_module Unified IRModule |
74 | * \param target Target to filter on or retrieve arguments from |
75 | * \return Runtime Module containing compiled functions |
76 | */ |
77 | using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>; |
78 | |
namespace detail {
// Forward declaration; the primary template and its specializations are defined later in
// this header.
template <typename, typename, typename>
struct ValueTypeInfoMaker;
}  // namespace detail

// Forward declarations of classes referenced before their definitions.
class TargetInternal;

template <typename>
class TargetKindAttrMap;
88 | |
89 | /*! \brief Target kind, specifies the kind of the target */ |
90 | class TargetKindNode : public Object { |
91 | public: |
92 | /*! \brief Name of the target kind */ |
93 | String name; |
94 | /*! \brief Device type of target kind */ |
95 | int default_device_type; |
96 | /*! \brief Default keys of the target */ |
97 | Array<String> default_keys; |
98 | /*! \brief Function used to preprocess on target creation */ |
99 | PackedFunc preprocessor; |
100 | /*! \brief Function used to parse a JSON target during creation */ |
101 | FTVMTargetParser target_parser; |
102 | |
103 | void VisitAttrs(AttrVisitor* v) { |
104 | v->Visit("name" , &name); |
105 | v->Visit("default_device_type" , &default_device_type); |
106 | v->Visit("default_keys" , &default_keys); |
107 | } |
108 | |
109 | static constexpr const char* _type_key = "TargetKind" ; |
110 | TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); |
111 | |
112 | private: |
113 | /*! \brief Return the index stored in attr registry */ |
114 | uint32_t AttrRegistryIndex() const { return index_; } |
115 | /*! \brief Return the name stored in attr registry */ |
116 | String AttrRegistryName() const { return name; } |
117 | /*! \brief Stores the required type_key and type_index of a specific attr of a target */ |
118 | struct ValueTypeInfo { |
119 | String type_key; |
120 | uint32_t type_index; |
121 | std::unique_ptr<ValueTypeInfo> key; |
122 | std::unique_ptr<ValueTypeInfo> val; |
123 | }; |
124 | /*! \brief A hash table that stores the type information of each attr of the target key */ |
125 | std::unordered_map<String, ValueTypeInfo> key2vtype_; |
126 | /*! \brief A hash table that stores the default value of each attr of the target key */ |
127 | std::unordered_map<String, ObjectRef> key2default_; |
128 | /*! \brief Index used for internal lookup of attribute registry */ |
129 | uint32_t index_; |
130 | |
131 | template <typename, typename, typename> |
132 | friend struct detail::ValueTypeInfoMaker; |
133 | template <typename, typename> |
134 | friend class AttrRegistry; |
135 | template <typename> |
136 | friend class AttrRegistryMapContainerMap; |
137 | friend class TargetKindRegEntry; |
138 | friend class TargetInternal; |
139 | }; |
140 | |
141 | /*! |
142 | * \brief Managed reference class to TargetKindNode |
143 | * \sa TargetKindNode |
144 | */ |
145 | class TargetKind : public ObjectRef { |
146 | public: |
147 | TargetKind() = default; |
148 | /*! \brief Get the attribute map given the attribute name */ |
149 | template <typename ValueType> |
150 | static inline TargetKindAttrMap<ValueType> GetAttrMap(const String& attr_name); |
151 | /*! |
152 | * \brief Retrieve the TargetKind given its name |
153 | * \param target_kind_name Name of the target kind |
154 | * \return The TargetKind requested |
155 | */ |
156 | TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name); |
157 | TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode); |
158 | |
159 | private: |
160 | /*! \brief Mutable access to the container class */ |
161 | TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); } |
162 | TVM_DLL static const AttrRegistryMapContainerMap<TargetKind>& GetAttrMapContainer( |
163 | const String& attr_name); |
164 | friend class TargetKindRegEntry; |
165 | friend class TargetInternal; |
166 | }; |
167 | |
168 | /*! |
169 | * \brief Map<TargetKind, ValueType> used to store meta-information about TargetKind |
170 | * \tparam ValueType The type of the value stored in map |
171 | */ |
172 | template <typename ValueType> |
173 | class TargetKindAttrMap : public AttrRegistryMap<TargetKind, ValueType> { |
174 | public: |
175 | using TParent = AttrRegistryMap<TargetKind, ValueType>; |
176 | using TParent::count; |
177 | using TParent::get; |
178 | using TParent::operator[]; |
179 | explicit TargetKindAttrMap(const AttrRegistryMapContainerMap<TargetKind>& map) : TParent(map) {} |
180 | }; |
181 | |
/*!
 * \brief Value passed with --runtime in a target spec to select the C++ runtime.
 */
static constexpr const char* const kTvmRuntimeCpp = "c++";

/*!
 * \brief Value passed with --runtime in a target spec to select the C runtime.
 */
static constexpr const char* const kTvmRuntimeCrt = "c";
187 | |
188 | /*! |
189 | * \brief Helper structure to register TargetKind |
190 | * \sa TVM_REGISTER_TARGET_KIND |
191 | */ |
192 | class TargetKindRegEntry { |
193 | public: |
194 | /*! |
195 | * \brief Register additional attributes to target_kind. |
196 | * \param attr_name The name of the attribute. |
197 | * \param value The value to be set. |
198 | * \param plevel The priority level of this attribute, |
199 | * an higher priority level attribute |
200 | * will replace lower priority level attribute. |
201 | * Must be bigger than 0. |
202 | * |
203 | * Cannot set with same plevel twice in the code. |
204 | * |
205 | * \tparam ValueType The type of the value to be set. |
206 | */ |
207 | template <typename ValueType> |
208 | inline TargetKindRegEntry& set_attr(const String& attr_name, const ValueType& value, |
209 | int plevel = 10); |
210 | /*! |
211 | * \brief Set DLPack's device_type the target |
212 | * \param device_type Device type |
213 | */ |
214 | inline TargetKindRegEntry& set_default_device_type(int device_type); |
215 | /*! |
216 | * \brief Set DLPack's device_type the target |
217 | * \param keys The default keys |
218 | */ |
219 | inline TargetKindRegEntry& set_default_keys(std::vector<String> keys); |
220 | /*! |
221 | * \brief Set the pre-processing function applied upon target creation |
222 | * \tparam FLambda Type of the function |
223 | * \param f The pre-processing function |
224 | */ |
225 | template <typename FLambda> |
226 | inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f); |
227 | /*! |
228 | * \brief Set the parsing function applied upon target creation |
229 | * \param parser The Target parsing function |
230 | */ |
231 | inline TargetKindRegEntry& set_target_parser(FTVMTargetParser parser); |
232 | /*! |
233 | * \brief Register a valid configuration option and its ValueType for validation |
234 | * \param key The configuration key |
235 | * \tparam ValueType The value type to be registered |
236 | */ |
237 | template <typename ValueType> |
238 | inline TargetKindRegEntry& add_attr_option(const String& key); |
239 | /*! |
240 | * \brief Register a valid configuration option and its ValueType for validation |
241 | * \param key The configuration key |
242 | * \param default_value The default value of the key |
243 | * \tparam ValueType The value type to be registered |
244 | */ |
245 | template <typename ValueType> |
246 | inline TargetKindRegEntry& add_attr_option(const String& key, ObjectRef default_value); |
247 | /*! \brief Set name of the TargetKind to be the same as registry if it is empty */ |
248 | inline TargetKindRegEntry& set_name(); |
249 | /*! |
250 | * \brief List all the entry names in the registry. |
251 | * \return The entry names. |
252 | */ |
253 | TVM_DLL static Array<String> ListTargetKinds(); |
254 | /*! |
255 | * \brief Get all supported option names and types for a given Target kind. |
256 | * \return Map of option name to type |
257 | */ |
258 | TVM_DLL static Map<String, String> ListTargetKindOptions(const TargetKind& kind); |
259 | |
260 | /*! |
261 | * \brief Register or get a new entry. |
262 | * \param target_kind_name The name of the TargetKind. |
263 | * \return the corresponding entry. |
264 | */ |
265 | TVM_DLL static TargetKindRegEntry& RegisterOrGet(const String& target_kind_name); |
266 | |
267 | private: |
268 | TargetKind kind_; |
269 | String name; |
270 | |
271 | /*! \brief private constructor */ |
272 | explicit TargetKindRegEntry(uint32_t reg_index) : kind_(make_object<TargetKindNode>()) { |
273 | kind_->index_ = reg_index; |
274 | } |
275 | /*! |
276 | * \brief update the attribute TargetKindAttrMap |
277 | * \param key The name of the attribute |
278 | * \param value The value to be set |
279 | * \param plevel The priority level |
280 | */ |
281 | TVM_DLL void UpdateAttr(const String& key, TVMRetValue value, int plevel); |
282 | template <typename, typename> |
283 | friend class AttrRegistry; |
284 | friend class TargetKind; |
285 | }; |
286 | |
287 | namespace detail { |
288 | template <typename Type, template <typename...> class Container> |
289 | struct is_specialized : std::false_type { |
290 | using type = std::false_type; |
291 | }; |
292 | |
293 | template <template <typename...> class Container, typename... Args> |
294 | struct is_specialized<Container<Args...>, Container> : std::true_type { |
295 | using type = std::true_type; |
296 | }; |
297 | |
298 | template <typename ValueType, typename IsArray = typename is_specialized<ValueType, Array>::type, |
299 | typename IsMap = typename is_specialized<ValueType, Map>::type> |
300 | struct ValueTypeInfoMaker {}; |
301 | |
302 | template <typename ValueType> |
303 | struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> { |
304 | using ValueTypeInfo = TargetKindNode::ValueTypeInfo; |
305 | |
306 | ValueTypeInfo operator()() const { |
307 | uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex(); |
308 | ValueTypeInfo info; |
309 | info.type_index = tindex; |
310 | info.type_key = runtime::Object::TypeIndex2Key(tindex); |
311 | info.key = nullptr; |
312 | info.val = nullptr; |
313 | return info; |
314 | } |
315 | }; |
316 | |
317 | template <typename ValueType> |
318 | struct ValueTypeInfoMaker<ValueType, std::true_type, std::false_type> { |
319 | using ValueTypeInfo = TargetKindNode::ValueTypeInfo; |
320 | |
321 | ValueTypeInfo operator()() const { |
322 | using key_type = ValueTypeInfoMaker<typename ValueType::value_type>; |
323 | uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex(); |
324 | ValueTypeInfo info; |
325 | info.type_index = tindex; |
326 | info.type_key = runtime::Object::TypeIndex2Key(tindex); |
327 | info.key = std::make_unique<ValueTypeInfo>(key_type()()); |
328 | info.val = nullptr; |
329 | return info; |
330 | } |
331 | }; |
332 | |
333 | template <typename ValueType> |
334 | struct ValueTypeInfoMaker<ValueType, std::false_type, std::true_type> { |
335 | using ValueTypeInfo = TargetKindNode::ValueTypeInfo; |
336 | ValueTypeInfo operator()() const { |
337 | using key_type = ValueTypeInfoMaker<typename ValueType::key_type>; |
338 | using val_type = ValueTypeInfoMaker<typename ValueType::mapped_type>; |
339 | uint32_t tindex = ValueType::ContainerType::_GetOrAllocRuntimeTypeIndex(); |
340 | ValueTypeInfo info; |
341 | info.type_index = tindex; |
342 | info.type_key = runtime::Object::TypeIndex2Key(tindex); |
343 | info.key = std::make_unique<ValueTypeInfo>(key_type()()); |
344 | info.val = std::make_unique<ValueTypeInfo>(val_type()()); |
345 | return info; |
346 | } |
347 | }; |
348 | |
349 | } // namespace detail |
350 | |
351 | template <typename ValueType> |
352 | inline TargetKindAttrMap<ValueType> TargetKind::GetAttrMap(const String& attr_name) { |
353 | return TargetKindAttrMap<ValueType>(GetAttrMapContainer(attr_name)); |
354 | } |
355 | |
356 | template <typename ValueType> |
357 | inline TargetKindRegEntry& TargetKindRegEntry::set_attr(const String& attr_name, |
358 | const ValueType& value, int plevel) { |
359 | ICHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0" ; |
360 | runtime::TVMRetValue rv; |
361 | rv = value; |
362 | UpdateAttr(attr_name, rv, plevel); |
363 | return *this; |
364 | } |
365 | |
366 | inline TargetKindRegEntry& TargetKindRegEntry::set_default_device_type(int device_type) { |
367 | kind_->default_device_type = device_type; |
368 | return *this; |
369 | } |
370 | |
371 | inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<String> keys) { |
372 | kind_->default_keys = keys; |
373 | return *this; |
374 | } |
375 | |
376 | template <typename FLambda> |
377 | inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) { |
378 | LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead" ; |
379 | using FType = typename tvm::runtime::detail::function_signature<FLambda>::FType; |
380 | kind_->preprocessor = tvm::runtime::TypedPackedFunc<FType>(std::move(f)).packed(); |
381 | return *this; |
382 | } |
383 | |
384 | inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParser parser) { |
385 | kind_->target_parser = parser; |
386 | return *this; |
387 | } |
388 | |
389 | template <typename ValueType> |
390 | inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) { |
391 | ICHECK(!kind_->key2vtype_.count(key)) |
392 | << "AttributeError: add_attr_option failed because '" << key << "' has been set once" ; |
393 | kind_->key2vtype_[key] = detail::ValueTypeInfoMaker<ValueType>()(); |
394 | return *this; |
395 | } |
396 | |
397 | template <typename ValueType> |
398 | inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key, |
399 | ObjectRef default_value) { |
400 | add_attr_option<ValueType>(key); |
401 | kind_->key2default_[key] = default_value; |
402 | return *this; |
403 | } |
404 | |
405 | inline TargetKindRegEntry& TargetKindRegEntry::set_name() { |
406 | if (kind_->name.empty()) { |
407 | kind_->name = name; |
408 | } |
409 | return *this; |
410 | } |
411 | |
// Declaration prefix for the static TargetKindRegEntry& created by TVM_REGISTER_TARGET_KIND.
// That macro combines it with __COUNTER__ via TVM_STR_CONCAT, presumably so that every
// registration variable gets a distinct name.
#define TVM_TARGET_KIND_REGISTER_VAR_DEF \
  static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetKindRegEntry& __make_##TargetKind
414 | |
namespace attr {
// Well-known attribute names that may be attached to a TargetKind.

/*!
 * \brief A \p TargetKind attribute of type \p Bool. If true, then the target kind name also
 * corresponds to an external codegen 'compiler' name. That name may be used:
 *  - To retrieve partitioning rules using \p get_partition_table.
 *  - To attach to Relay Functions under the \p attr::kCompiler attribute to indicate
 *    the function is to be compiled by the external codegen path.
 *
 * The \p CollagePartition pass uses this attribute to guide its search over candidate partitions
 * using external codegen.
 *
 * See also \p Target::IsExternalCodegenFor
 */
constexpr const char* const kIsExternalCodegen = "is_external_codegen";

/*!
 * \brief A \p TargetKind attribute of type \p FTVMRelayToTIR. If set, then the target kind name
 * also corresponds to an external codegen 'compiler' name, and the bound value is a \p Pass
 * to apply before the TVM lowering.
 *
 * See also \p Target::IsExternalCodegenFor
 */
constexpr const char* const kRelayToTIR = "RelayToTIR";

}  // namespace attr
444 | |
/*!
 * \def TVM_REGISTER_TARGET_KIND
 * \brief Register a new target kind.
 *
 * Besides setting the name and default device type, the macro unconditionally adds the
 * common attribute options "keys", "tag", "device", "model", "libs", "host", "from_device"
 * and "target_device_type". Because add_attr_option rejects a key that is already
 * registered, expanding this macro twice for the same kind name fails. To attach further
 * attributes to an existing kind, chain on TargetKindRegEntry::RegisterOrGet directly.
 *
 * \param TargetKindName The name of target kind
 * \param DeviceType The DLDeviceType of the target kind
 *
 * \code
 *
 * TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
 *     .set_attr<TPreCodegenPass>("TPreCodegenPass", a-pre-codegen-pass)
 *     .add_attr_option<Bool>("system_lib")
 *     .add_attr_option<String>("mtriple")
 *     .add_attr_option<String>("mattr");
 *
 * \endcode
 */
#define TVM_REGISTER_TARGET_KIND(TargetKindName, DeviceType)      \
  TVM_STR_CONCAT(TVM_TARGET_KIND_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::TargetKindRegEntry::RegisterOrGet(TargetKindName)    \
          .set_name()                                             \
          .set_default_device_type(DeviceType)                    \
          .add_attr_option<Array<String>>("keys")                 \
          .add_attr_option<String>("tag")                         \
          .add_attr_option<String>("device")                      \
          .add_attr_option<String>("model")                       \
          .add_attr_option<Array<String>>("libs")                 \
          .add_attr_option<Target>("host")                        \
          .add_attr_option<Integer>("from_device")                \
          .add_attr_option<Integer>("target_device_type")
475 | |
476 | } // namespace tvm |
477 | |
478 | #endif // TVM_TARGET_TARGET_KIND_H_ |
479 | |