1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/search_task.cc |
22 | * \brief Meta information and hardware parameters for a search task. |
23 | */ |
24 | |
25 | #include <dlpack/dlpack.h> |
26 | #include <tvm/auto_scheduler/search_task.h> |
27 | #include <tvm/runtime/device_api.h> |
28 | #include <tvm/runtime/registry.h> |
29 | #include <tvm/runtime/threading_backend.h> |
30 | |
31 | #include <utility> |
32 | |
33 | namespace tvm { |
34 | namespace auto_scheduler { |
35 | |
// Register the node types defined for search tasks with TVM's node type
// registry (see TVM_REGISTER_NODE_TYPE for what registration provides).
TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
TVM_REGISTER_NODE_TYPE(SearchTaskNode);
38 | |
39 | HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, |
40 | int max_shared_memory_per_block, int max_local_memory_per_block, |
41 | int max_threads_per_block, int max_vthread_extent, int warp_size) { |
42 | auto node = make_object<HardwareParamsNode>(); |
43 | node->num_cores = num_cores; |
44 | node->vector_unit_bytes = vector_unit_bytes; |
45 | node->cache_line_bytes = cache_line_bytes; |
46 | node->max_shared_memory_per_block = max_shared_memory_per_block; |
47 | node->max_local_memory_per_block = max_local_memory_per_block; |
48 | node->max_threads_per_block = max_threads_per_block; |
49 | node->max_vthread_extent = max_vthread_extent; |
50 | node->warp_size = warp_size; |
51 | data_ = std::move(node); |
52 | } |
53 | |
54 | HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target, |
55 | const Target& target_host) { |
56 | // There is no use of target_host so no updates here in the function. |
57 | const auto device_type = target->GetTargetDeviceType(); |
58 | if (device_type == kDLCPU) { |
59 | return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0); |
60 | } else if (device_type == kDLCUDA || device_type == kDLROCM) { |
61 | auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; |
62 | auto device_name = device_type == kDLCUDA ? "device_api.cuda" : "device_api.rocm" ; |
63 | auto func = tvm::runtime::Registry::Get(device_name); |
64 | ICHECK(func != nullptr) << "Cannot find CUDA device_api in registry" ; |
65 | auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*()); |
66 | |
67 | tvm::runtime::TVMRetValue ret; |
68 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); |
69 | int max_shared_memory_per_block = ret; |
70 | |
71 | // There is no explicit local memory limition in CUDA runtime, |
72 | // so we can use INT32_MAX to disalbe the check on local_memory. |
73 | int max_local_memory_per_block = INT32_MAX; |
74 | |
75 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); |
76 | int max_threads_per_block = ret; |
77 | |
78 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); |
79 | int warp_size = ret; |
80 | |
81 | int max_vthread_extent = warp_size / 4; |
82 | return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, |
83 | max_threads_per_block, max_vthread_extent, warp_size); |
84 | } else if (device_type == kDLMetal) { |
85 | // Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf |
86 | // This setting looks working for Metal GPUs later than A10 |
87 | int max_shared_memory_per_block = 32 * 1024; |
88 | int max_local_memory_per_block = INT32_MAX; // skip the check on local memory |
89 | int max_threads_per_block = 1024; |
90 | int warp_size = 8; |
91 | int max_vthread_extent = warp_size / 4; |
92 | return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, |
93 | max_threads_per_block, max_vthread_extent, warp_size); |
94 | } else if (target->GetTargetDeviceType() == kDLOpenCL) { |
95 | if (target->GetAttr<String>("device" , "" ) == "mali" ) { |
96 | // We cannot use device API to get hardware attributes like CUDA, |
97 | // because like Mali target is normally on the remote machine. |
98 | int max_shared_memory_per_block = 32768; |
99 | int max_local_memory_per_block = INT32_MAX; // skip the check on local memory |
100 | int max_threads_per_block = 256; |
101 | int warp_size = 1; |
102 | int max_vthread_extent = 1; |
103 | return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, |
104 | max_threads_per_block, max_vthread_extent, warp_size); |
105 | } else if (target->GetAttr<String>("device" , "" ) == "adreno" ) { |
106 | int max_shared_memory_per_block = 32768; |
107 | int max_local_memory_per_block = 32768; |
108 | int max_threads_per_block = 256; |
109 | int warp_size = 1; |
110 | int max_vthread_extent = 1; |
111 | return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, |
112 | max_threads_per_block, max_vthread_extent, warp_size); |
113 | } else { |
114 | // add other opencl target |
115 | auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; |
116 | auto device_name = "device_api.opencl" ; |
117 | auto func = tvm::runtime::Registry::Get(device_name); |
118 | ICHECK(func != nullptr) << "Cannot find OpenCL device_api in registry" ; |
119 | auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*()); |
120 | |
121 | tvm::runtime::TVMRetValue ret; |
122 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); |
123 | int max_shared_memory_per_block = ret; |
124 | |
125 | int max_local_memory_per_block = INT32_MAX; |
126 | |
127 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); |
128 | int max_threads_per_block = ret; |
129 | |
130 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); |
131 | int warp_size = ret; |
132 | |
133 | if (warp_size == 1) { |
134 | LOG(WARNING) |
135 | << "Warp size 1 is not recommended for OpenCL devices. Tuning might crash or stuck" ; |
136 | } |
137 | |
138 | int max_vthread_extent = std::max(1, warp_size / 4); |
139 | return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, |
140 | max_threads_per_block, max_vthread_extent, warp_size); |
141 | } |
142 | } else if (device_type == kDLVulkan) { |
143 | auto dev = Device{static_cast<DLDeviceType>(device_type), 0}; |
144 | auto device_name = "device_api.vulkan" ; |
145 | auto func = tvm::runtime::Registry::Get(device_name); |
146 | ICHECK(func != nullptr) << "Cannot find Vulkan device_api in registry" ; |
147 | auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*()); |
148 | |
149 | tvm::runtime::TVMRetValue ret; |
150 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret); |
151 | int max_shared_memory_per_block = ret; |
152 | |
153 | int max_local_memory_per_block = INT32_MAX; |
154 | |
155 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret); |
156 | int max_threads_per_block = ret; |
157 | |
158 | device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret); |
159 | int warp_size = ret; |
160 | |
161 | int max_vthread_extent = std::max(1, warp_size / 4); |
162 | |
163 | return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block, |
164 | max_threads_per_block, max_vthread_extent, warp_size); |
165 | } else { |
166 | LOG(FATAL) << "No default hardware parameters for target: " << target; |
167 | } |
168 | return HardwareParams(); |
169 | } |
170 | |
171 | SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target, |
172 | Target target_host, Optional<HardwareParams> hardware_params, |
173 | LayoutRewriteOption layout_rewrite_option, Array<String> task_input_names, |
174 | String desc) { |
175 | CheckAndUpdateHostConsistency(&target, &target_host); |
176 | auto node = make_object<SearchTaskNode>(); |
177 | node->compute_dag = std::move(compute_dag); |
178 | node->workload_key = std::move(workload_key); |
179 | node->desc = std::move(desc); |
180 | node->target = std::move(target); |
181 | node->target_host = std::move(target_host); |
182 | if (hardware_params) { |
183 | node->hardware_params = hardware_params.value(); |
184 | } else { |
185 | node->hardware_params = |
186 | HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host); |
187 | } |
188 | node->layout_rewrite_option = layout_rewrite_option; |
189 | node->task_input_names = std::move(task_input_names); |
190 | data_ = std::move(node); |
191 | } |
192 | |
193 | TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams" ) |
194 | .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes, |
195 | int max_shared_memory_per_block, int max_local_memory_per_block, |
196 | int max_threads_per_block, int max_vthread_extent, int warp_size) { |
197 | return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes, |
198 | max_shared_memory_per_block, max_local_memory_per_block, |
199 | max_threads_per_block, max_vthread_extent, warp_size); |
200 | }); |
201 | |
202 | TVM_REGISTER_GLOBAL("auto_scheduler.GetDefaultHardwareParams" ) |
203 | .set_body_typed([](Target target, Target target_host) { |
204 | return HardwareParamsNode::GetDefaultHardwareParams(target, target_host); |
205 | }); |
206 | |
207 | TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask" ) |
208 | .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target, |
209 | Target target_host, Optional<HardwareParams> hardware_params, |
210 | int layout_rewrite_option, Array<String> task_input_names, String desc) { |
211 | CheckAndUpdateHostConsistency(&target, &target_host); |
212 | return SearchTask(compute_dag, workload_key, target, target_host, hardware_params, |
213 | LayoutRewriteOption(layout_rewrite_option), task_input_names, desc); |
214 | }); |
215 | |
216 | } // namespace auto_scheduler |
217 | } // namespace tvm |
218 | |