/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/target/target_kind.cc
 * \brief Target kind registry
 */
#include <tvm/ir/expr.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/target/target_kind.h>

#include <algorithm>

#include "../node/attr_registry.h"
#include "./parsers/cpu.h"

namespace tvm {

TVM_REGISTER_NODE_TYPE(TargetKindNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<TargetKindNode>([](const ObjectRef& obj, ReprPrinter* p) {
      const TargetKind& kind = Downcast<TargetKind>(obj);
      p->stream << kind->name;
    });

/********** Registry-related code **********/

using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>;

Array<String> TargetKindRegEntry::ListTargetKinds() {
  return TargetKindRegistry::Global()->ListAllNames();
}

Map<String, String> TargetKindRegEntry::ListTargetKindOptions(const TargetKind& target_kind) {
  Map<String, String> options;
  for (const auto& kv : target_kind->key2vtype_) {
    options.Set(kv.first, kv.second.type_key);
  }
  return options;
}

TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) {
  return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name);
}

void TargetKindRegEntry::UpdateAttr(const String& key, TVMRetValue value, int plevel) {
  TargetKindRegistry::Global()->UpdateAttr(key, kind_, value, plevel);
}

const AttrRegistryMapContainerMap<TargetKind>& TargetKind::GetAttrMapContainer(
    const String& attr_name) {
  return TargetKindRegistry::Global()->GetAttrMap(attr_name);
}

Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {
  const TargetKindRegEntry* reg = TargetKindRegistry::Global()->Get(target_kind_name);
  if (reg == nullptr) {
    return NullOpt;
  }
  return reg->kind_;
}

/********** Utility functions **********/

84/*!
85 * \brief Extract a number from the string with the given prefix.
86 * For example, when `str` is "sm_20" and `prefix` is "sm_".
87 * This function first checks if `str` starts with `prefix`,
88 * then return the integer 20 after the `prefix`
89 * \param str The string to be extracted
90 * \param prefix The prefix to be checked
91 * \return An integer, the extracted number. -1 if the check fails
92 */
static int ExtractIntWithPrefix(const std::string& str, const std::string& prefix) {
  if (str.substr(0, prefix.size()) != prefix) {
    return -1;
  }
  int result = 0;
  for (size_t i = prefix.size(); i < str.size(); ++i) {
    char c = str[i];
    if (!isdigit(c)) {
      return -1;
    }
    result = result * 10 + c - '0';
  }
  return result;
}

108/*!
109 * \brief Using TVM DeviceAPI to detect the device flag
110 * \param device The device to be detected
111 * \param flag The device flag to be detected
112 * \param val The detected value
113 * \return A boolean indicating if detection succeeds
114 */
static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, TVMRetValue* val) {
  using runtime::DeviceAPI;
  DeviceAPI* api = DeviceAPI::Get(device, true);
  // Check if compiled with the corresponding device api
  if (api == nullptr) {
    return false;
  }
  // Check if the device exists
  api->GetAttr(device, runtime::kExist, val);
  int exists = *val;
  if (!exists) {
    return false;
  }
  // Get the arch of the device
  DeviceAPI::Get(device)->GetAttr(device, flag, val);
  return true;
}

void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const String& value) {
  auto iter = attrs->find(name);
  if (iter == attrs->end()) {
    attrs->Set(name, value);
  } else {
    const auto* str = (*iter).second.as<StringObj>();
    ICHECK(str != nullptr && GetRef<String>(str) == value)
        << "ValueError: Expects \"" << name << "\" to be \"" << value
        << "\", but gets: " << (*iter).second;
  }
}

/********** Target kind attribute updaters **********/

147/*!
148 * \brief Update the attributes in the CUDA target.
149 * \param target The Target to update
150 * \return The updated attributes
151 */
152TargetJSON UpdateCUDAAttrs(TargetJSON target) {
153 // Update -arch=sm_xx
154 int archInt;
155 if (target.count("arch")) {
156 // If -arch has been specified, validate the correctness
157 String archStr = Downcast<String>(target.at("arch"));
158 archInt = ExtractIntWithPrefix(archStr, "sm_");
159 ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
160 } else {
161 // Use the compute version of the first CUDA GPU instead
162 TVMRetValue version;
163 if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
164 LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_20\" instead";
165 archInt = 20;
166 } else {
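      // The detected compute version is a string such as "8.6"; scaling by 10 yields the
      // integer arch (86). The extra 0.1 guards against floating-point truncation
      // (e.g. 85.999... would otherwise truncate to 85).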
      archInt = std::stod(version.operator std::string()) * 10 + 0.1;
    }
    target.Set("arch", String("sm_") + std::to_string(archInt));
  }
  return target;
}

174/*!
175 * \brief Update the attributes in the LLVM NVPTX target.
176 * \param target The Target to update
177 * \return The updated attributes
178 */
179TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
180 CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
181 // Update -mcpu=sm_xx
182 int arch;
183 if (target.count("mcpu")) {
184 // If -mcpu has been specified, validate the correctness
185 String mcpu = Downcast<String>(target.at("mcpu"));
186 arch = ExtractIntWithPrefix(mcpu, "sm_");
187 ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
188 } else {
189 // Use the compute version of the first CUDA GPU instead
190 TVMRetValue version;
191 if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) {
192 LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_20\" instead";
193 arch = 20;
194 } else {
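      // Same conversion as in UpdateCUDAAttrs: "8.6" -> 86, with 0.1 added to avoid
      // floating-point truncation.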
      arch = std::stod(version.operator std::string()) * 10 + 0.1;
    }
    target.Set("mcpu", String("sm_") + std::to_string(arch));
  }
  return target;
}

202/*!
203 * \brief Update the attributes in the LLVM ROCm target.
204 * \param target The Target to update
205 * \return The updated attributes
206 */
207TargetJSON UpdateROCmAttrs(TargetJSON target) {
208 CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
209 // Update -mcpu=gfx
210 int arch;
211 if (target.count("mcpu")) {
212 String mcpu = Downcast<String>(target.at("mcpu"));
213 arch = ExtractIntWithPrefix(mcpu, "gfx");
214 ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
215 } else {
216 TVMRetValue val;
217 if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kGcnArch, &val)) {
218 LOG(WARNING) << "Unable to detect ROCm compute arch, default to \"-mcpu=gfx900\" instead";
219 arch = 900;
220 } else {
221 arch = val.operator int();
222 }
223 target.Set("mcpu", String("gfx") + std::to_string(arch));
224 }
  // Update -mattr for ROCm versions before 3.5:
  // ROCm before 3.5 requires code object v2, while 3.5 and later use v3
  // (the "-code-object-v3" feature added below disables v3).

  TVMRetValue val;
  int version;
  if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) {
    LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5";
    version = 305;
  } else {
    version = val.operator int();
  }
  if (version < 305) {
    Array<String> mattr;
    if (target.count("mattr")) {
      mattr = Downcast<Array<String>>(target.at("mattr"));
    }
    mattr.push_back("-code-object-v3");
    target.Set("mattr", mattr);
  }
  return target;
}

248/*!
249 * \brief Test Target Parser
250 * \param target The Target to update
251 * \return The updated attributes
252 */
253TargetJSON TestTargetParser(TargetJSON target) {
254 Map<String, ObjectRef> features = {{"is_test", Bool(true)}};
255 target.Set("features", features);
256 return target;
257}
258
259/********** Register Target kinds and attributes **********/
260
TVM_REGISTER_TARGET_KIND("llvm", kDLCPU)
    .add_attr_option<Array<String>>("mattr")
    .add_attr_option<String>("mcpu")
    .add_attr_option<String>("mtriple")
    .add_attr_option<String>("mfloat-abi")
    .add_attr_option<String>("mabi")
    .add_attr_option<Integer>("num-cores")
    // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags
    .add_attr_option<Bool>("fast-math")  // implies all the below
    .add_attr_option<Bool>("fast-math-nnan")
    .add_attr_option<Bool>("fast-math-ninf")
    .add_attr_option<Bool>("fast-math-nsz")
    .add_attr_option<Bool>("fast-math-arcp")
    .add_attr_option<Bool>("fast-math-contract")
    .add_attr_option<Bool>("fast-math-reassoc")
    .add_attr_option<Integer>("opt-level")
    // LLVM command line flags, see below
    .add_attr_option<Array<String>>("cl-opt")
    .set_default_keys({"cpu"})
    // Force the external codegen kind attribute to be registered, even if no external
    // codegen targets are enabled by the TVM build.
    .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(false))
    .set_target_parser(tvm::target::parsers::cpu::ParseTarget);

// Note regarding the "cl-opt" attribute:
// Each string in the array has the format
//   -optionname[[:type]=value]
// where
// * optionname is the actual LLVM option (e.g. "unroll-threshold")
// * type is one of "bool", "int", "uint", or "string"
// * value is the corresponding option value (for the "bool" type, it can be 0 or "false"
//   for a false value, or 1 or "true" for a true value)
// If type is omitted, it is assumed to be "bool". If value is omitted, it is assumed
// to be "true".
//
// The type must match the option type in LLVM. To find the type, search the LLVM
// repository (https://github.com/llvm/llvm-project) for optionname, and look for
// its definition: it will be a declaration of a variable of type cl::opt<T> with
// optionname being an argument to the constructor. The T in the declaration is
// the type.
// For example, for unroll-threshold, we get the following declaration:
//   static cl::opt<unsigned>
//       UnrollThreshold("unroll-threshold", cl::Hidden,
//                       cl::desc("The cost threshold for loop unrolling"));
// Hence the type is "uint".
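//
// Illustrative sketch (not part of this file's logic): such a flag could be passed through
// a target string; the array-option syntax and the threshold value are assumptions here:
//   Target target("llvm -cl-opt=-unroll-threshold:uint=100");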

TVM_REGISTER_TARGET_KIND("c", kDLCPU)
    .add_attr_option<String>("mcpu")
    .add_attr_option<String>("march")
    .add_attr_option<Integer>("workspace-byte-alignment")
    .add_attr_option<Integer>("constants-byte-alignment")
    .set_default_keys({"cpu"})
    .set_target_parser(tvm::target::parsers::cpu::ParseTarget);

TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
    .add_attr_option<String>("mcpu")
    .add_attr_option<String>("arch")
    .add_attr_option<Integer>("max_shared_memory_per_block")
    .add_attr_option<Integer>("max_threads_per_block")
    .add_attr_option<Integer>("thread_warp_size", Integer(32))
    .add_attr_option<Integer>("registers_per_block")
    .add_attr_option<Integer>("max_num_threads", Integer(1024))  // TODO(@zxybazh): deprecate it
    .set_default_keys({"cuda", "gpu"})
    .set_target_parser(UpdateCUDAAttrs);

TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
    .add_attr_option<String>("mcpu")
    .add_attr_option<String>("mtriple")
    .add_attr_option<Integer>("max_num_threads", Integer(1024))
    .add_attr_option<Integer>("thread_warp_size", Integer(32))
    .set_default_keys({"cuda", "gpu"})
    .set_target_parser(UpdateNVPTXAttrs);

TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
    .add_attr_option<String>("mcpu")
    .add_attr_option<String>("mtriple")
    .add_attr_option<Array<String>>("mattr")
    // TODO(masahi): Support querying from a target device
    // On RDNA cards, thread_warp_size should be 32
    .add_attr_option<Integer>("max_num_threads", Integer(256))
    .add_attr_option<Integer>("max_threads_per_block", Integer(256))
    .add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
    .add_attr_option<Integer>("thread_warp_size", Integer(64))
    .set_default_keys({"rocm", "gpu"})
    .set_target_parser(UpdateROCmAttrs);

TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
    .add_attr_option<Integer>("max_num_threads", Integer(256))
    .add_attr_option<Integer>("thread_warp_size", Integer(1))
    .add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
    .set_default_keys({"opencl", "gpu"});

// Metal has some limitations on the number of input parameters. This is why the attribute
// `max_function_args` was introduced. It specifies the maximum number of kernel arguments. More
// information about this limitation can be found here:
// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc
// See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
TVM_REGISTER_TARGET_KIND("metal", kDLMetal)
    .add_attr_option<Integer>("max_num_threads", Integer(256))
    .add_attr_option<Integer>("max_threads_per_block", Integer(256))
    .add_attr_option<Integer>("max_shared_memory_per_block", Integer(32768))
    .add_attr_option<Integer>("thread_warp_size", Integer(16))
    .add_attr_option<Integer>("max_function_args", Integer(31))
    .set_default_keys({"metal", "gpu"});

TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
    .add_attr_option<Array<String>>("mattr")
    // Feature support
    .add_attr_option<Bool>("supports_float16")
    .add_attr_option<Bool>("supports_float32", Bool(true))
    .add_attr_option<Bool>("supports_float64")
    .add_attr_option<Bool>("supports_int8")
    .add_attr_option<Bool>("supports_int16")
    .add_attr_option<Bool>("supports_int32", Bool(true))
    .add_attr_option<Bool>("supports_int64")
    .add_attr_option<Bool>("supports_8bit_buffer")
    .add_attr_option<Bool>("supports_16bit_buffer")
    .add_attr_option<Bool>("supports_storage_buffer_storage_class")
    .add_attr_option<Bool>("supports_push_descriptor")
    .add_attr_option<Bool>("supports_dedicated_allocation")
    .add_attr_option<Bool>("supports_integer_dot_product")
    .add_attr_option<Integer>("supported_subgroup_operations")
    // Physical device limits
    .add_attr_option<Integer>("max_num_threads", Integer(256))
    .add_attr_option<Integer>("max_threads_per_block", Integer(256))
    .add_attr_option<Integer>("thread_warp_size", Integer(1))
    .add_attr_option<Integer>("max_block_size_x")
    .add_attr_option<Integer>("max_block_size_y")
    .add_attr_option<Integer>("max_block_size_z")
    .add_attr_option<Integer>("max_push_constants_size")
    .add_attr_option<Integer>("max_uniform_buffer_range")
    .add_attr_option<Integer>("max_storage_buffer_range")
    .add_attr_option<Integer>("max_per_stage_descriptor_storage_buffer")
    .add_attr_option<Integer>("max_shared_memory_per_block")
    // Other device properties
    .add_attr_option<String>("device_type")
    .add_attr_option<String>("device_name")
    .add_attr_option<String>("driver_name")
    .add_attr_option<Integer>("driver_version")
    .add_attr_option<Integer>("vulkan_api_version")
    .add_attr_option<Integer>("max_spirv_version")
    // Tags
    .set_default_keys({"vulkan", "gpu"});

TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
    .add_attr_option<Integer>("max_num_threads", Integer(256))
    .set_default_keys({"webgpu", "gpu"});

TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL)  // line break
    .set_default_keys({"sdaccel", "hls"});

TVM_REGISTER_TARGET_KIND("aocl", kDLAOCL)  // line break
    .set_default_keys({"aocl", "hls"});

TVM_REGISTER_TARGET_KIND("aocl_sw_emu", kDLAOCL)  // line break
    .set_default_keys({"aocl", "hls"});

TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
    .add_attr_option<Array<String>>("mattr")
    .add_attr_option<String>("mcpu")
    .add_attr_option<String>("mtriple")
    .add_attr_option<Array<String>>("llvm-options")
    .add_attr_option<Integer>("num-cores")
    .add_attr_option<Integer>("vtcm-capacity")
    .set_default_keys({"hexagon"});

TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU);

TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev);

TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU);

TVM_REGISTER_TARGET_KIND("composite", kDLCPU)  // line break
    .add_attr_option<Array<Target>>("devices");

TVM_REGISTER_TARGET_KIND("test", kDLCPU)  // line break
    .set_target_parser(TestTargetParser);

/********** Registry **********/

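// A minimal sketch (illustrative only) of how the registry helpers exported below can be
// exercised directly from C++, using only APIs defined in this file:
//   Array<String> kinds = TargetKindRegEntry::ListTargetKinds();
//   Map<String, String> cuda_opts =
//       TargetKindRegEntry::ListTargetKindOptions(TargetKind::Get("cuda").value());
// The packed functions registered here expose the same functionality through the FFI.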
TVM_REGISTER_GLOBAL("target.TargetKindGetAttr")
    .set_body_typed([](TargetKind kind, String attr_name) -> TVMRetValue {
      auto target_attr_map = TargetKind::GetAttrMap<TVMRetValue>(attr_name);
      TVMRetValue rv;
      if (target_attr_map.count(kind)) {
        rv = target_attr_map[kind];
      }
      return rv;
    });
TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
TVM_REGISTER_GLOBAL("target.ListTargetKindOptions")
    .set_body_typed(TargetKindRegEntry::ListTargetKindOptions);
TVM_REGISTER_GLOBAL("target.ListTargetKindOptionsFromName")
    .set_body_typed([](String target_kind_name) {
      TargetKind kind = TargetKind::Get(target_kind_name).value();
      return TargetKindRegEntry::ListTargetKindOptions(kind);
    });

}  // namespace tvm