/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file auto_scheduler/search_task.cc
 * \brief Meta information and hardware parameters for a search task.
 */

#include <dlpack/dlpack.h>
#include <tvm/auto_scheduler/search_task.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/threading_backend.h>

#include <algorithm>
#include <utility>

namespace tvm {
namespace auto_scheduler {

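// Register the node types with the TVM object system so they can be
// reflected over and constructed through the FFI.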
TVM_REGISTER_NODE_TYPE(HardwareParamsNode);
TVM_REGISTER_NODE_TYPE(SearchTaskNode);

HardwareParams::HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes,
                               int max_shared_memory_per_block, int max_local_memory_per_block,
                               int max_threads_per_block, int max_vthread_extent, int warp_size) {
  auto node = make_object<HardwareParamsNode>();
  node->num_cores = num_cores;
  node->vector_unit_bytes = vector_unit_bytes;
  node->cache_line_bytes = cache_line_bytes;
  node->max_shared_memory_per_block = max_shared_memory_per_block;
  node->max_local_memory_per_block = max_local_memory_per_block;
  node->max_threads_per_block = max_threads_per_block;
  node->max_vthread_extent = max_vthread_extent;
  node->warp_size = warp_size;
  data_ = std::move(node);
}

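// Derive sensible default hardware parameters for a target when the user does
// not supply them explicitly.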
HardwareParams HardwareParamsNode::GetDefaultHardwareParams(const Target& target,
                                                            const Target& target_host) {
  // target_host is currently unused; the defaults depend only on the device target.
  const auto device_type = target->GetTargetDeviceType();
  if (device_type == kDLCPU) {
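    // CPU defaults: use all available hardware threads and typical 64-byte
    // vector-unit and cache-line widths. The GPU-specific fields are unused.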
    return HardwareParams(tvm::runtime::threading::MaxConcurrency(), 64, 64, 0, 0, 0, 0, 0);
  } else if (device_type == kDLCUDA || device_type == kDLROCM) {
    auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
    auto device_name = device_type == kDLCUDA ? "device_api.cuda" : "device_api.rocm";
    auto func = tvm::runtime::Registry::Get(device_name);
    ICHECK(func != nullptr) << "Cannot find " << device_name << " in registry";
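    // The registered global returns the DeviceAPI instance packed as a void*;
    // unwrap and cast it back to a DeviceAPI pointer.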
    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

    tvm::runtime::TVMRetValue ret;
    device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
    int max_shared_memory_per_block = ret;

    // The CUDA runtime exposes no explicit local memory limit,
    // so we use INT32_MAX to disable the check on local_memory.
    int max_local_memory_per_block = INT32_MAX;

    device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
    int max_threads_per_block = ret;

    device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
    int warp_size = ret;

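    // Heuristic: cap the virtual thread extent at a quarter of the warp size.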
    int max_vthread_extent = warp_size / 4;
    return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                          max_threads_per_block, max_vthread_extent, warp_size);
  } else if (device_type == kDLMetal) {
    // Reference: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
    // These settings appear to work for Metal GPUs newer than the A10.
    int max_shared_memory_per_block = 32 * 1024;
    int max_local_memory_per_block = INT32_MAX;  // skip the check on local memory
    int max_threads_per_block = 1024;
    int warp_size = 8;
    int max_vthread_extent = warp_size / 4;
    return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                          max_threads_per_block, max_vthread_extent, warp_size);
  } else if (device_type == kDLOpenCL) {
    if (target->GetAttr<String>("device", "") == "mali") {
      // Unlike CUDA, we cannot use the device API to query hardware attributes,
      // because targets like Mali are normally on a remote machine.
      int max_shared_memory_per_block = 32768;
      int max_local_memory_per_block = INT32_MAX;  // skip the check on local memory
      int max_threads_per_block = 256;
      int warp_size = 1;
      int max_vthread_extent = 1;
      return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                            max_threads_per_block, max_vthread_extent, warp_size);
    } else if (target->GetAttr<String>("device", "") == "adreno") {
      int max_shared_memory_per_block = 32768;
      int max_local_memory_per_block = 32768;
      int max_threads_per_block = 256;
      int warp_size = 1;
      int max_vthread_extent = 1;
      return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                            max_threads_per_block, max_vthread_extent, warp_size);
    } else {
      // For other OpenCL targets, query the attributes from the local device API.
      auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
      auto device_name = "device_api.opencl";
      auto func = tvm::runtime::Registry::Get(device_name);
      ICHECK(func != nullptr) << "Cannot find OpenCL device_api in registry";
      auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

      tvm::runtime::TVMRetValue ret;
      device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
      int max_shared_memory_per_block = ret;

      int max_local_memory_per_block = INT32_MAX;

      device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
      int max_threads_per_block = ret;

      device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
      int warp_size = ret;

      if (warp_size == 1) {
        LOG(WARNING) << "Warp size 1 is not recommended for OpenCL devices. "
                     << "Tuning might crash or get stuck.";
      }

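      // Heuristic: a quarter of the warp size, clamped to at least 1 since
      // some OpenCL devices report a warp size of 1.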
      int max_vthread_extent = std::max(1, warp_size / 4);
      return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                            max_threads_per_block, max_vthread_extent, warp_size);
    }
  } else if (device_type == kDLVulkan) {
    auto dev = Device{static_cast<DLDeviceType>(device_type), 0};
    auto device_name = "device_api.vulkan";
    auto func = tvm::runtime::Registry::Get(device_name);
    ICHECK(func != nullptr) << "Cannot find Vulkan device_api in registry";
    auto device_api = static_cast<tvm::runtime::DeviceAPI*>(((*func)()).operator void*());

    tvm::runtime::TVMRetValue ret;
    device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxSharedMemoryPerBlock, &ret);
    int max_shared_memory_per_block = ret;

    int max_local_memory_per_block = INT32_MAX;

    device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kMaxThreadsPerBlock, &ret);
    int max_threads_per_block = ret;

    device_api->GetAttr(dev, tvm::runtime::DeviceAttrKind::kWarpSize, &ret);
    int warp_size = ret;

    int max_vthread_extent = std::max(1, warp_size / 4);

    return HardwareParams(-1, 16, 64, max_shared_memory_per_block, max_local_memory_per_block,
                          max_threads_per_block, max_vthread_extent, warp_size);
  } else {
    LOG(FATAL) << "No default hardware parameters for target: " << target;
  }
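  // Unreachable: LOG(FATAL) aborts above. The return silences compiler warnings.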
  return HardwareParams();
}

SearchTask::SearchTask(ComputeDAG compute_dag, String workload_key, Target target,
                       Target target_host, Optional<HardwareParams> hardware_params,
                       LayoutRewriteOption layout_rewrite_option, Array<String> task_input_names,
                       String desc) {
  CheckAndUpdateHostConsistency(&target, &target_host);
  auto node = make_object<SearchTaskNode>();
  node->compute_dag = std::move(compute_dag);
  node->workload_key = std::move(workload_key);
  node->desc = std::move(desc);
  node->target = std::move(target);
  node->target_host = std::move(target_host);
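  // Fall back to target-derived defaults when no hardware parameters are given.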
  if (hardware_params) {
    node->hardware_params = hardware_params.value();
  } else {
    node->hardware_params =
        HardwareParamsNode::GetDefaultHardwareParams(node->target, node->target_host);
  }
  node->layout_rewrite_option = layout_rewrite_option;
  node->task_input_names = std::move(task_input_names);
  data_ = std::move(node);
}

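// Expose the constructors above to the FFI (e.g. the Python frontend).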
TVM_REGISTER_GLOBAL("auto_scheduler.HardwareParams")
    .set_body_typed([](int num_cores, int vector_unit_bytes, int cache_line_bytes,
                       int max_shared_memory_per_block, int max_local_memory_per_block,
                       int max_threads_per_block, int max_vthread_extent, int warp_size) {
      return HardwareParams(num_cores, vector_unit_bytes, cache_line_bytes,
                            max_shared_memory_per_block, max_local_memory_per_block,
                            max_threads_per_block, max_vthread_extent, warp_size);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.GetDefaultHardwareParams")
    .set_body_typed([](Target target, Target target_host) {
      return HardwareParamsNode::GetDefaultHardwareParams(target, target_host);
    });

TVM_REGISTER_GLOBAL("auto_scheduler.SearchTask")
    .set_body_typed([](ComputeDAG compute_dag, String workload_key, Target target,
                       Target target_host, Optional<HardwareParams> hardware_params,
                       int layout_rewrite_option, Array<String> task_input_names, String desc) {
      CheckAndUpdateHostConsistency(&target, &target_host);
      return SearchTask(compute_dag, workload_key, target, target_host, hardware_params,
                        LayoutRewriteOption(layout_rewrite_option), task_input_names, desc);
    });

}  // namespace auto_scheduler
}  // namespace tvm