1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/target/target_kind.cc |
22 | * \brief Target kind registry |
23 | */ |
#include <tvm/ir/expr.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/target/target_kind.h>

#include <algorithm>
#include <cctype>
#include <limits>
#include <string>

#include "../node/attr_registry.h"
#include "./parsers/cpu.h"
34 | |
35 | namespace tvm { |
36 | |
37 | TVM_REGISTER_NODE_TYPE(TargetKindNode); |
38 | |
39 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
40 | .set_dispatch<TargetKindNode>([](const ObjectRef& obj, ReprPrinter* p) { |
41 | const TargetKind& kind = Downcast<TargetKind>(obj); |
42 | p->stream << kind->name; |
43 | }); |
44 | |
/********** Registry-related code **********/

// Shorthand for the attribute registry instantiated for target kinds. The
// member functions below reach it through TargetKindRegistry::Global().
using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>;
48 | |
49 | Array<String> TargetKindRegEntry::ListTargetKinds() { |
50 | return TargetKindRegistry::Global()->ListAllNames(); |
51 | } |
52 | |
53 | Map<String, String> TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) { |
54 | Map<String, String> options; |
55 | for (const auto& kv : target_kind->key2vtype_) { |
56 | options.Set(kv.first, kv.second.type_key); |
57 | } |
58 | return options; |
59 | } |
60 | |
61 | TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) { |
62 | return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name); |
63 | } |
64 | |
65 | void TargetKindRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) { |
66 | TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel); |
67 | } |
68 | |
69 | const AttrRegistryMapContainerMap<TargetKind>& TargetKind::GetAttrMapContainer( |
70 | const String& attr_name) { |
71 | return TargetKindRegistry::Global()->GetAttrMap(attr_name); |
72 | } |
73 | |
74 | Optional<TargetKind> TargetKind::Get(const String& target_kind_name) { |
75 | const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name); |
76 | if (reg == nullptr) { |
77 | return NullOpt; |
78 | } |
79 | return reg->kind_; |
80 | } |
81 | |
82 | /********** Utility functions **********/ |
83 | |
/*!
 * \brief Extract a non-negative decimal number that follows a given prefix.
 *  For example, when `str` is "sm_20" and `prefix` is "sm_", this function
 *  first checks that `str` starts with `prefix` and then returns the integer
 *  20 that follows it.
 * \param str The string to be parsed
 * \param prefix The prefix that `str` must start with
 * \return The extracted number. -1 if `str` does not start with `prefix`, if
 *  any character after the prefix is not a decimal digit, or if the number
 *  does not fit in an `int`. When nothing follows the prefix the result is 0
 *  (kept for backward compatibility).
 */
static int ExtractIntWithPrefix(const std::string& str, const std::string& prefix) {
  // compare() checks the prefix in place, without building a temporary substring.
  if (str.compare(0, prefix.size(), prefix) != 0) {
    return -1;
  }
  constexpr int kMaxInt = std::numeric_limits<int>::max();
  int result = 0;
  for (size_t i = prefix.size(); i < str.size(); ++i) {
    // std::isdigit has undefined behaviour for negative arguments, so the
    // character must be converted to unsigned char first.
    const unsigned char c = static_cast<unsigned char>(str[i]);
    if (!std::isdigit(c)) {
      return -1;
    }
    const int digit = c - '0';
    // Reject values that would overflow `int` instead of invoking UB.
    if (result > (kMaxInt - digit) / 10) {
      return -1;
    }
    result = result * 10 + digit;
  }
  return result;
}
107 | |
108 | /*! |
109 | * \brief Using TVM DeviceAPI to detect the device flag |
110 | * \param device The device to be detected |
111 | * \param flag The device flag to be detected |
112 | * \param val The detected value |
113 | * \return A boolean indicating if detection succeeds |
114 | */ |
115 | static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, TVMRetValue* val) { |
116 | using runtime::DeviceAPI; |
117 | DeviceAPI* api = DeviceAPI::Get(device, true); |
118 | // Check if compiled with the corresponding device api |
119 | if (api == nullptr) { |
120 | return false; |
121 | } |
122 | // Check if the device exists |
123 | api->GetAttr(device, runtime::kExist, val); |
124 | int exists = *val; |
125 | if (!exists) { |
126 | return false; |
127 | } |
128 | // Get the arch of the device |
129 | DeviceAPI::Get(device)->GetAttr(device, flag, val); |
130 | return true; |
131 | } |
132 | |
133 | void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const String& value) { |
134 | auto iter = attrs->find(name); |
135 | if (iter == attrs->end()) { |
136 | attrs->Set(name, value); |
137 | } else { |
138 | const auto* str = (*iter).second.as<StringObj>(); |
139 | ICHECK(str != nullptr && GetRef<String>(str) == value) |
140 | << "ValueError: Expects \"" << name << "\" to be \"" << value |
141 | << "\", but gets: " << (*iter).second; |
142 | } |
143 | } |
144 | |
145 | /********** Target kind attribute updaters **********/ |
146 | |
147 | /*! |
148 | * \brief Update the attributes in the CUDA target. |
149 | * \param target The Target to update |
150 | * \return The updated attributes |
151 | */ |
152 | TargetJSON UpdateCUDAAttrs(TargetJSON target) { |
153 | // Update -arch=sm_xx |
154 | int archInt; |
155 | if (target.count("arch" )) { |
156 | // If -arch has been specified, validate the correctness |
157 | String archStr = Downcast<String>(target.at("arch" )); |
158 | archInt = ExtractIntWithPrefix(archStr, "sm_" ); |
159 | ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr; |
160 | } else { |
161 | // Use the compute version of the first CUDA GPU instead |
162 | TVMRetValue version; |
163 | if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { |
164 | LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_20\" instead" ; |
165 | archInt = 20; |
166 | } else { |
167 | archInt = std::stod(version.operator std::string()) * 10 + 0.1; |
168 | } |
169 | target.Set("arch" , String("sm_" ) + std::to_string(archInt)); |
170 | } |
171 | return target; |
172 | } |
173 | |
174 | /*! |
175 | * \brief Update the attributes in the LLVM NVPTX target. |
176 | * \param target The Target to update |
177 | * \return The updated attributes |
178 | */ |
179 | TargetJSON UpdateNVPTXAttrs(TargetJSON target) { |
180 | CheckOrSetAttr(&target, "mtriple" , "nvptx64-nvidia-cuda" ); |
181 | // Update -mcpu=sm_xx |
182 | int arch; |
183 | if (target.count("mcpu" )) { |
184 | // If -mcpu has been specified, validate the correctness |
185 | String mcpu = Downcast<String>(target.at("mcpu" )); |
186 | arch = ExtractIntWithPrefix(mcpu, "sm_" ); |
187 | ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; |
188 | } else { |
189 | // Use the compute version of the first CUDA GPU instead |
190 | TVMRetValue version; |
191 | if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { |
192 | LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_20\" instead" ; |
193 | arch = 20; |
194 | } else { |
195 | arch = std::stod(version.operator std::string()) * 10 + 0.1; |
196 | } |
197 | target.Set("mcpu" , String("sm_" ) + std::to_string(arch)); |
198 | } |
199 | return target; |
200 | } |
201 | |
202 | /*! |
203 | * \brief Update the attributes in the LLVM ROCm target. |
204 | * \param target The Target to update |
205 | * \return The updated attributes |
206 | */ |
207 | TargetJSON UpdateROCmAttrs(TargetJSON target) { |
208 | CheckOrSetAttr(&target, "mtriple" , "amdgcn-amd-amdhsa-hcc" ); |
209 | // Update -mcpu=gfx |
210 | int arch; |
211 | if (target.count("mcpu" )) { |
212 | String mcpu = Downcast<String>(target.at("mcpu" )); |
213 | arch = ExtractIntWithPrefix(mcpu, "gfx" ); |
214 | ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu; |
215 | } else { |
216 | TVMRetValue val; |
217 | if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) { |
218 | LOG(WARNING) << "Unable to detect ROCm compute arch, default to \"-mcpu=gfx900\" instead" ; |
219 | arch = 900; |
220 | } else { |
221 | arch = val.operator int(); |
222 | } |
223 | target.Set("mcpu" , String("gfx" ) + std::to_string(arch)); |
224 | } |
225 | // Update -mattr before ROCm 3.5: |
226 | // Before ROCm 3.5 we needed code object v2, starting |
227 | // with 3.5 we need v3 (this argument disables v3) |
228 | |
229 | TVMRetValue val; |
230 | int version; |
231 | if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) { |
232 | LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5" ; |
233 | version = 305; |
234 | } else { |
235 | version = val.operator int(); |
236 | } |
237 | if (version < 305) { |
238 | Array<String> mattr; |
239 | if (target.count("mattr" )) { |
240 | mattr = Downcast<Array<String>>(target.at("mattr" )); |
241 | } |
242 | mattr.push_back("-code-object-v3" ); |
243 | target.Set("mattr" , mattr); |
244 | } |
245 | return target; |
246 | } |
247 | |
248 | /*! |
249 | * \brief Test Target Parser |
250 | * \param target The Target to update |
251 | * \return The updated attributes |
252 | */ |
253 | TargetJSON TestTargetParser(TargetJSON target) { |
254 | Map<String, ObjectRef> features = {{"is_test" , Bool(true)}}; |
255 | target.Set("features" , features); |
256 | return target; |
257 | } |
258 | |
259 | /********** Register Target kinds and attributes **********/ |
260 | |
// "llvm" target kind, registered on device type kDLCPU. Declares the options
// accepted by this kind, tags it with the default key "cpu", and routes
// attribute parsing through tvm::target::parsers::cpu::ParseTarget.
TVM_REGISTER_TARGET_KIND("llvm" , kDLCPU)
    .add_attr_option<Array<String>>("mattr" )
    .add_attr_option<String>("mcpu" )
    .add_attr_option<String>("mtriple" )
    .add_attr_option<String>("mfloat-abi" )
    .add_attr_option<String>("mabi" )
    .add_attr_option<Integer>("num-cores" )
    // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
    .add_attr_option<Bool>("fast-math" ) // implies all the below
    .add_attr_option<Bool>("fast-math-nnan" )
    .add_attr_option<Bool>("fast-math-ninf" )
    .add_attr_option<Bool>("fast-math-nsz" )
    .add_attr_option<Bool>("fast-math-arcp" )
    .add_attr_option<Bool>("fast-math-contract" )
    .add_attr_option<Bool>("fast-math-reassoc" )
    .add_attr_option<Integer>("opt-level" )
    // LLVM command line flags; the string format is described in the note
    // that follows this registration.
    .add_attr_option<Array<String>>("cl-opt" )
    .set_default_keys({"cpu" })
    // Force the external codegen kind attribute to be registered, even if no external
    // codegen targets are enabled by the TVM build.
    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(false))
    .set_target_parser(tvm::target::parsers::cpu::ParseTarget);
284 | |
285 | // Note regarding the "cl-opt" attribute: |
286 | // Each string in the array has the format |
287 | // -optionname[[:type]=value] |
288 | // where |
289 | // * optionname is the actual LLVM option (e.g. "unroll-threshold") |
290 | // * type is one of "bool", "int", "uint", or "string" |
//   * value is the corresponding option value (for "bool" type it can be 0 or "false"
//     for false value, or 1 or "true" for true value)
293 | // If type is omitted, it is assumed to be "bool". If value is omitted, it is assumed |
294 | // to be "true". |
295 | // |
296 | // The type must match the option type in LLVM. To find the type, search the LLVM |
297 | // repository (https://github.com/llvm/llvm-project) for optionname, and look for |
298 | // its definition: it will be a declaration of a variable of type cl::opt<T> with |
299 | // optionname being an argument to the constructor. The T in the declaration is |
300 | // the type. |
301 | // For example, for unroll-threshold, we get the following declaration: |
302 | // static cl::opt<unsigned> |
303 | // UnrollThreshold("unroll-threshold", cl::Hidden, |
304 | // cl::desc("The cost threshold for loop unrolling")); |
305 | // Hence the type is "uint". |
306 | |
// "c" target kind, registered on device type kDLCPU. Shares the default key
// "cpu" and the CPU target parser with the "llvm" kind.
TVM_REGISTER_TARGET_KIND("c" , kDLCPU)
    .add_attr_option<String>("mcpu" )
    .add_attr_option<String>("march" )
    .add_attr_option<Integer>("workspace-byte-alignment" )
    .add_attr_option<Integer>("constants-byte-alignment" )
    .set_default_keys({"cpu" })
    .set_target_parser(tvm::target::parsers::cpu::ParseTarget);
314 | |
// "cuda" target kind, registered on device type kDLCUDA. Attributes are
// post-processed by UpdateCUDAAttrs, which validates or fills in "arch".
// NOTE(review): the second argument of add_attr_option is presumably the
// option's default value - confirm against the TargetKindRegEntry declaration.
TVM_REGISTER_TARGET_KIND("cuda" , kDLCUDA)
    .add_attr_option<String>("mcpu" )
    .add_attr_option<String>("arch" )
    .add_attr_option<Integer>("max_shared_memory_per_block" )
    .add_attr_option<Integer>("max_threads_per_block" )
    .add_attr_option<Integer>("thread_warp_size" , Integer(32))
    .add_attr_option<Integer>("registers_per_block" )
    .add_attr_option<Integer>("max_num_threads" , Integer(1024)) // TODO(@zxybazh): deprecate it
    .set_default_keys({"cuda" , "gpu" })
    .set_target_parser(UpdateCUDAAttrs);
325 | |
// "nvptx" target kind, registered on device type kDLCUDA with the same
// default keys as "cuda". Attributes are post-processed by UpdateNVPTXAttrs,
// which checks or sets "mtriple" and validates or fills in "mcpu".
TVM_REGISTER_TARGET_KIND("nvptx" , kDLCUDA)
    .add_attr_option<String>("mcpu" )
    .add_attr_option<String>("mtriple" )
    .add_attr_option<Integer>("max_num_threads" , Integer(1024))
    .add_attr_option<Integer>("thread_warp_size" , Integer(32))
    .set_default_keys({"cuda" , "gpu" })
    .set_target_parser(UpdateNVPTXAttrs);
333 | |
// "rocm" target kind, registered on device type kDLROCM. Attributes are
// post-processed by UpdateROCmAttrs, which checks or sets "mtriple", validates
// or fills in "mcpu", and may extend "mattr".
TVM_REGISTER_TARGET_KIND("rocm" , kDLROCM)
    .add_attr_option<String>("mcpu" )
    .add_attr_option<String>("mtriple" )
    .add_attr_option<Array<String>>("mattr" )
    // TODO(masahi): Support querying from a target device
    // On RDNA cards, thread_warp_size should be 32
    .add_attr_option<Integer>("max_num_threads" , Integer(256))
    .add_attr_option<Integer>("max_threads_per_block" , Integer(256))
    .add_attr_option<Integer>("max_shared_memory_per_block" , Integer(65536))
    .add_attr_option<Integer>("thread_warp_size" , Integer(64))
    .set_default_keys({"rocm" , "gpu" })
    .set_target_parser(UpdateROCmAttrs);
346 | |
// "opencl" target kind, registered on device type kDLOpenCL. No target parser
// is attached to this kind.
TVM_REGISTER_TARGET_KIND("opencl" , kDLOpenCL)
    .add_attr_option<Integer>("max_num_threads" , Integer(256))
    .add_attr_option<Integer>("thread_warp_size" , Integer(1))
    .add_attr_option<Integer>("texture_spatial_limit" , Integer(16384))
    .set_default_keys({"opencl" , "gpu" });
352 | |
// Metal limits the number of input parameters a kernel may take, which is why the
// `max_function_args` attribute was introduced: it specifies the maximum number of
// kernel arguments. More information about this limitation can be found here:
// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
// See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
//
// "metal" target kind, registered on device type kDLMetal.
TVM_REGISTER_TARGET_KIND("metal" , kDLMetal)
    .add_attr_option<Integer>("max_num_threads" , Integer(256))
    .add_attr_option<Integer>("max_threads_per_block" , Integer(256))
    .add_attr_option<Integer>("max_shared_memory_per_block" , Integer(32768))
    .add_attr_option<Integer>("thread_warp_size" , Integer(16))
    .add_attr_option<Integer>("max_function_args" , Integer(31))
    .set_default_keys({"metal" , "gpu" });
365 | |
// "vulkan" target kind, registered on device type kDLVulkan. The options are
// grouped into feature-support flags, physical device limits and other
// device properties. No target parser is attached to this kind.
TVM_REGISTER_TARGET_KIND("vulkan" , kDLVulkan)
    .add_attr_option<Array<String>>("mattr" )
    // Feature support
    .add_attr_option<Bool>("supports_float16" )
    .add_attr_option<Bool>("supports_float32" , Bool(true))
    .add_attr_option<Bool>("supports_float64" )
    .add_attr_option<Bool>("supports_int8" )
    .add_attr_option<Bool>("supports_int16" )
    .add_attr_option<Bool>("supports_int32" , Bool(true))
    .add_attr_option<Bool>("supports_int64" )
    .add_attr_option<Bool>("supports_8bit_buffer" )
    .add_attr_option<Bool>("supports_16bit_buffer" )
    .add_attr_option<Bool>("supports_storage_buffer_storage_class" )
    .add_attr_option<Bool>("supports_push_descriptor" )
    .add_attr_option<Bool>("supports_dedicated_allocation" )
    .add_attr_option<Bool>("supports_integer_dot_product" )
    .add_attr_option<Integer>("supported_subgroup_operations" )
    // Physical device limits
    .add_attr_option<Integer>("max_num_threads" , Integer(256))
    .add_attr_option<Integer>("max_threads_per_block" , Integer(256))
    .add_attr_option<Integer>("thread_warp_size" , Integer(1))
    .add_attr_option<Integer>("max_block_size_x" )
    .add_attr_option<Integer>("max_block_size_y" )
    .add_attr_option<Integer>("max_block_size_z" )
    .add_attr_option<Integer>("max_push_constants_size" )
    .add_attr_option<Integer>("max_uniform_buffer_range" )
    .add_attr_option<Integer>("max_storage_buffer_range" )
    .add_attr_option<Integer>("max_per_stage_descriptor_storage_buffer" )
    .add_attr_option<Integer>("max_shared_memory_per_block" )
    // Other device properties
    .add_attr_option<String>("device_type" )
    .add_attr_option<String>("device_name" )
    .add_attr_option<String>("driver_name" )
    .add_attr_option<Integer>("driver_version" )
    .add_attr_option<Integer>("vulkan_api_version" )
    .add_attr_option<Integer>("max_spirv_version" )
    // Tags
    .set_default_keys({"vulkan" , "gpu" });
404 | |
// "webgpu" target kind, registered on device type kDLWebGPU.
TVM_REGISTER_TARGET_KIND("webgpu" , kDLWebGPU)
    .add_attr_option<Integer>("max_num_threads" , Integer(256))
    .set_default_keys({"webgpu" , "gpu" });

// The three kinds below all carry the "hls" default key and declare no
// options. The trailing "// line break" comments appear to exist only to
// keep each chained call on its own line.

// "sdaccel" is registered on device type kDLOpenCL.
TVM_REGISTER_TARGET_KIND("sdaccel" , kDLOpenCL) // line break
    .set_default_keys({"sdaccel" , "hls" });

// "aocl" is registered on device type kDLAOCL.
TVM_REGISTER_TARGET_KIND("aocl" , kDLAOCL) // line break
    .set_default_keys({"aocl" , "hls" });

// "aocl_sw_emu" is also registered on kDLAOCL and uses the same default keys
// as "aocl" (there is no "aocl_sw_emu" key).
TVM_REGISTER_TARGET_KIND("aocl_sw_emu" , kDLAOCL) // line break
    .set_default_keys({"aocl" , "hls" });
417 | |
// "hexagon" target kind, registered on device type kDLHexagon. No target
// parser is attached to this kind.
TVM_REGISTER_TARGET_KIND("hexagon" , kDLHexagon)
    .add_attr_option<Array<String>>("mattr" )
    .add_attr_option<String>("mcpu" )
    .add_attr_option<String>("mtriple" )
    .add_attr_option<Array<String>>("llvm-options" )
    .add_attr_option<Integer>("num-cores" )
    .add_attr_option<Integer>("vtcm-capacity" )
    .set_default_keys({"hexagon" });
426 | |
// Kinds registered with no kind-specific options, default keys or parser.
TVM_REGISTER_TARGET_KIND("stackvm" , kDLCPU);

TVM_REGISTER_TARGET_KIND("ext_dev" , kDLExtDev);

TVM_REGISTER_TARGET_KIND("hybrid" , kDLCPU);

// "composite" accepts a single option, "devices", holding an array of Targets.
TVM_REGISTER_TARGET_KIND("composite" , kDLCPU) // line break
    .add_attr_option<Array<Target>>("devices" );

// "test" kind: its attributes are processed by TestTargetParser, which adds a
// `features` map with `is_test = true`.
TVM_REGISTER_TARGET_KIND("test" , kDLCPU) // line break
    .set_target_parser(TestTargetParser);
438 | |
439 | /********** Registry **********/ |
440 | |
441 | TVM_REGISTER_GLOBAL("target.TargetKindGetAttr" ) |
442 | .set_body_typed([](TargetKind kind, String attr_name) -> TVMRetValue { |
443 | auto target_attr_map = TargetKind::GetAttrMap<TVMRetValue>(attr_name); |
444 | TVMRetValue rv; |
445 | if (target_attr_map.count(kind)) { |
446 | rv = target_attr_map[kind]; |
447 | } |
448 | return rv; |
449 | }); |
450 | TVM_REGISTER_GLOBAL("target.ListTargetKinds" ).set_body_typed(TargetKindRegEntry::ListTargetKinds); |
451 | TVM_REGISTER_GLOBAL("target.ListTargetKindOptions" ) |
452 | .set_body_typed(TargetKindRegEntry::ListTargetKindOptions); |
453 | TVM_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName" ) |
454 | .set_body_typed([](String target_kind_name) { |
455 | TargetKind kind = TargetKind::Get(target_kind_name).value(); |
456 | return TargetKindRegEntry::ListTargetKindOptions(kind); |
457 | }); |
458 | |
459 | } // namespace tvm |
460 | |