1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/target/compilation_config.cc |
22 | * \brief Implementation of \p CompilationConfig for collecting \p Targets. |
23 | */ |
24 | |
#include <tvm/runtime/device_api.h>
#include <tvm/target/compilation_config.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
27 | |
28 | namespace tvm { |
29 | |
// Register CompilationConfigNode with the object/reflection machinery so it can
// be created, serialized and inspected via the FFI.
TVM_REGISTER_NODE_TYPE(CompilationConfigNode);
31 | |
// Exposes the config's fields to the attribute visitor used by reflection/FFI.
void CompilationConfigNode::VisitAttrs(AttrVisitor* v) {
  v->Visit("host_target", &host_target);
  v->Visit("primitive_targets", &primitive_targets);
  v->Visit("default_primitive_virtual_device", &default_primitive_virtual_device);
  v->Visit("host_virtual_device", &host_virtual_device);
  // NOTE(review): the FFI key spells "homogenous" while the C++ field spells
  // "homogeneous". The key is part of the external interface, so it is left
  // unchanged here — confirm against Python-side users before renaming.
  v->Visit("optional_homogenous_target", &optional_homogeneous_target);
  // NOTE: The virtual_device_cache_ is not accessible via FFI.
}
40 | |
41 | Target CompilationConfigNode::FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const { |
42 | ICHECK_GT(device_type, 0) << "Invalid device type" ; |
43 | auto itr = std::find_if( |
44 | primitive_targets.begin(), primitive_targets.end(), |
45 | [device_type](const Target& target) { return target->GetTargetDeviceType() == device_type; }); |
46 | if (itr == primitive_targets.end()) { |
47 | std::stringstream msg; |
48 | msg << "No target is specified for device type " << device_type |
49 | << ". The available device types and targets are:" << std::endl; |
50 | for (const auto& target : primitive_targets) { |
51 | msg << " " << target->GetTargetDeviceType() << "-> " << target->ToDebugString() << std::endl; |
52 | } |
53 | LOG(FATAL) << msg.str(); |
54 | } |
55 | return *itr; |
56 | } |
57 | |
58 | Optional<Target> CompilationConfigNode::FindPrimitiveTargetForKind( |
59 | const std::string& kind_name) const { |
60 | Optional<TargetKind> opt_kind = TargetKind::Get(kind_name); |
61 | if (!opt_kind.defined()) { |
62 | VLOG(1) << "No such target kind for '" << kind_name << "'" ; |
63 | return {}; |
64 | } |
65 | auto itr = |
66 | std::find_if(primitive_targets.begin(), primitive_targets.end(), |
67 | [kind_name](const Target& target) { return target->kind->name == kind_name; }); |
68 | if (itr == primitive_targets.end()) { |
69 | VLOG(1) << "No target available matching kind '" << kind_name << "'" ; |
70 | return {}; |
71 | } |
72 | return *itr; |
73 | } |
74 | |
75 | Target CompilationConfigNode::CanonicalTarget(const Target& target) const { |
76 | // Fast path -- object identity. |
77 | if (target == host_target) { |
78 | return target; |
79 | } |
80 | for (const auto& primitive_target : primitive_targets) { |
81 | if (target == primitive_target) { |
82 | return target; |
83 | } |
84 | } |
85 | // Slow path -- structural equality. We have so few targets it does not seem worth building an |
86 | // index. |
87 | if (StructuralEqual()(target, host_target)) { |
88 | return host_target; |
89 | } |
90 | for (const auto& primitive_target : primitive_targets) { |
91 | if (StructuralEqual()(target, primitive_target)) { |
92 | return primitive_target; |
93 | } |
94 | } |
95 | // No match. |
96 | return target; |
97 | } |
98 | |
/*!
 * \brief Returns the unique (cached) VirtualDevice equivalent to \p virtual_device,
 * with its target canonicalized against, or defaulted from, the known targets.
 */
VirtualDevice CompilationConfigNode::CanonicalVirtualDevice(
    const VirtualDevice& virtual_device) const {
  // Targets need special handling.
  Target target = virtual_device->target;
  if (target.defined()) {
    // It is possible the given target object was constructed by the user, but was then
    // rewritten on the way into the CompilationConfig. So 'canonicalize' it by replacing
    // the given target with one structurally equal to one already known in the config if
    // possible.
    Target canon_target = CanonicalTarget(target);
    if (canon_target != target) {
      VLOG(1) << "Canonicalized target " << canon_target->ToDebugString();
    }
    target = canon_target;
  } else if (virtual_device->device_type() != kInvalidDeviceType) {
    // Since no target was given, choose one with a matching device type.
    // This is the one place where we allow device types to imply targets.
    // Fails hard if no primitive target matches the device type.
    target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type());
    VLOG(1) << "Defaulted to target " << target->ToDebugString();
  }
  // else: the target will remain unknown.

  // Redirect to an existing structurally equal virtual device, so that virtual
  // devices can be compared by pointer equality downstream.
  return virtual_device_cache_.Unique(VirtualDevice(virtual_device->device_type(),
                                                    virtual_device->virtual_device_id, target,
                                                    virtual_device->memory_scope));
}
126 | |
127 | void CompilationConfigNode::Init(const transform::PassContext& pass_ctx, |
128 | const Array<Target>& raw_targets) { |
129 | VLOG_CONTEXT << "CompilationConfig" ; |
130 | CHECK_GT(raw_targets.size(), 0U) << "Require at least one target" ; |
131 | |
132 | // |
133 | // Decide on the host target. |
134 | // |
135 | |
136 | // Any targets which could act as a host? |
137 | auto hosting_itr = std::find_if(raw_targets.begin(), raw_targets.end(), [](const Target& target) { |
138 | // TODO(tvm-team): The kDLHexagon device can act as a host. We can remove kDLHexagon |
139 | // here once we refactored kDLHexagon to kDLCPU. |
140 | return target->GetTargetDeviceType() == kDLCPU || target->GetTargetDeviceType() == kDLHexagon; |
141 | }); |
142 | |
143 | // Any targets with their host field set? |
144 | auto has_host_itr = std::find_if(raw_targets.begin(), raw_targets.end(), |
145 | [](const Target& target) { return target->host.defined(); }); |
146 | |
147 | if (has_host_itr != raw_targets.end()) { |
148 | // RULE A: If any raw target has a host, use the first such host for all the primitive |
149 | // targets. |
150 | host_target = Target((*has_host_itr)->GetHost().value(), /*host=*/Target()); |
151 | VLOG(1) << "The target " << (*has_host_itr)->ToDebugString() << " supplies a host target " |
152 | << host_target->ToDebugString() << " of device type " |
153 | << host_target->GetTargetDeviceType(); |
154 | } else if (hosting_itr != raw_targets.end()) { |
155 | // RULE B: If any raw target is for a device which could be a host then use the first such as |
156 | // the host. |
157 | host_target = Target(*hosting_itr, /*host=*/Target()); |
158 | VLOG(1) << "Using target " << host_target->ToDebugString() << " of CPU-like device type " |
159 | << host_target->GetTargetDeviceType() << " as the host target" ; |
160 | } else { |
161 | // RULE C: Otherwise, create a default CPU host target. |
162 | host_target = MakeDefaultCPUTarget(); |
163 | VLOG(1) << "Created a default target " << host_target->ToDebugString() << " of device type " |
164 | << host_target->GetTargetDeviceType() << " for the host target" ; |
165 | } |
166 | ICHECK(host_target.defined()); |
167 | ICHECK(!host_target->host.defined()); |
168 | |
169 | if (host_target->GetTargetDeviceType() != kDLCPU) { |
170 | // I think we're on thin ice here until we've audited the code base for assumed CPU hosts. |
171 | VLOG(1) << "The host target is not a CPU. This is probably not going to work." ; |
172 | } |
173 | |
174 | // |
175 | // Establish the host VirtualDevice. |
176 | // |
177 | host_virtual_device = virtual_device_cache_.Unique( |
178 | VirtualDevice(static_cast<DLDeviceType>(host_target->GetTargetDeviceType()), |
179 | /*virtual_device_id=*/0, host_target)); |
180 | ICHECK(host_virtual_device.defined()); |
181 | ICHECK(host_virtual_device->target.defined()); |
182 | |
183 | // |
184 | // Now that we've settled on a host, we can set it as the host on all the raw targets. |
185 | // |
186 | primitive_targets.clear(); |
187 | primitive_targets.reserve(raw_targets.size()); |
188 | for (const auto& raw_target : raw_targets) { |
189 | if (raw_target->host.defined() && !StructuralEqual()(raw_target->host, host_target)) { |
190 | VLOG(1) << "The target " << raw_target->ToDebugString() |
191 | << " already has a host which disagrees with the desired host target. It " |
192 | << "will be overridden." ; |
193 | } |
194 | primitive_targets.push_back(Target(raw_target, host_target)); |
195 | } |
196 | ICHECK_GT(primitive_targets.size(), 0U); |
197 | |
198 | // |
199 | // Check the primitive_targets are ordered correctly re Target::IsExternalCodegenFor, |
200 | // and make sure no two targets share a kind name. |
201 | // |
202 | |
203 | // TODO(mbs): We could just sort the list, but given all the implicit defaulting for backwards |
204 | // compat it seems we should avoid making this any more magical than necessary. But revisit |
205 | // if usability suffers. |
206 | std::unordered_set<DLDeviceType> primitive_target_device_types; |
207 | std::unordered_set<std::string> kind_names; |
208 | for (const auto& target : primitive_targets) { |
209 | primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->GetTargetDeviceType())); |
210 | CHECK(kind_names.emplace(target->kind->name).second) << "Multiple targets have been given" |
211 | "for the same device kind '" |
212 | << target->kind->name << "'" ; |
213 | } |
214 | for (DLDeviceType device_type : primitive_target_device_types) { |
215 | Target first_primitive_target; |
216 | for (const auto& current_primitive_target : primitive_targets) { |
217 | if (current_primitive_target->GetTargetDeviceType() != device_type) { |
218 | continue; |
219 | } |
220 | if (!first_primitive_target.defined()) { |
221 | first_primitive_target = current_primitive_target; |
222 | // Note it is valid to have only one external codegen target. |
223 | } else { |
224 | CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target)) |
225 | << "When given multiple targets for the device type " << device_type |
226 | << " the first must be for non external codegen, and all subsequent must be for " |
227 | "external codegen. However have been given first " |
228 | << first_primitive_target->ToDebugString() << " and subsequent " |
229 | << current_primitive_target->ToDebugString(); |
230 | } |
231 | } |
232 | } |
233 | |
234 | // |
235 | // Decide on the default device type for primitives. |
236 | // |
237 | DLDeviceType default_primitive_device_type; |
238 | Optional<Integer> opt_fallback_dev = pass_ctx->GetConfig<Integer>("relay.fallback_device_type" ); |
239 | if (opt_fallback_dev) { |
240 | // RULE D: Respect the PassContext setting if given. |
241 | const int64_t v = opt_fallback_dev.value()->value; |
242 | CHECK_GT(v, 0) |
243 | << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " << v; |
244 | default_primitive_device_type = static_cast<DLDeviceType>(v); |
245 | VLOG(1) << "Using the 'relay.fallback_device_type' pass attribute " |
246 | << default_primitive_device_type |
247 | << " as the default device type for all primitive operations" ; |
248 | } else if (primitive_target_device_types.size() == 1) { |
249 | // RULE E: Since only one device in use there's no choice to make. |
250 | default_primitive_device_type = *primitive_target_device_types.begin(); |
251 | VLOG(1) << "All primitive targets have the device type " << default_primitive_device_type |
252 | << " so that is also the default device type for all primitive operations." ; |
253 | } else { |
254 | // RULE F: Fallback to CPU. |
255 | default_primitive_device_type = kDLCPU; |
256 | VLOG(1) << "Using " << default_primitive_device_type |
257 | << " as the default device type for all primitive operations" ; |
258 | } |
259 | |
260 | // |
261 | // Establish the default primitive VirtualDevice, choosing a known Target to match the device |
262 | // type. We do not create a default target, it must already exist as a primitive target. |
263 | // |
264 | default_primitive_virtual_device = CanonicalVirtualDevice( |
265 | VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0)); |
266 | |
267 | ICHECK(default_primitive_virtual_device.defined()); |
268 | ICHECK(default_primitive_virtual_device->target.defined()); |
269 | |
270 | // Legacy: Some passes only support homogenous compilation and expect the target to be |
271 | // given by the global target context. Make this easy to detect. |
272 | optional_homogeneous_target = |
273 | primitive_targets.size() == 1 ? *primitive_targets.begin() : Target(); |
274 | } |
275 | |
276 | /* static */ Target CompilationConfigNode::MakeDefaultCPUTarget() { |
277 | if (runtime::Registry::Get("codegen.LLVMModuleCreate" )) { |
278 | // LLVM is available. |
279 | // TODO(mbs): More robust extension mechanism? |
280 | return Target("llvm" ); |
281 | } else { |
282 | // LLVM is not available. |
283 | // TODO(mbs): Already deprecated? |
284 | return Target("stackvm" ); |
285 | } |
286 | } |
287 | |
// Pretty-printer for CompilationConfigNode: lists each primitive target keyed by
// its device type, then the default primitive and host virtual devices.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<CompilationConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = ref.as<CompilationConfigNode>();
      p->stream << "Primitive targets:";
      for (const auto& target : node->primitive_targets) {
        p->stream << std::endl
                  << "  " << target->GetTargetDeviceType() << " |-> " << target->ToDebugString();
      }
      p->stream << std::endl
                << "Default primitive virtual device: " << node->default_primitive_virtual_device;
      p->stream << std::endl << "Host virtual device: " << node->host_virtual_device;
    });
300 | |
// Constructs and fully initializes a CompilationConfig from the given pass
// context and raw targets; Init() performs all the host/target resolution.
CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx,
                                     const Array<Target>& raw_targets) {
  auto node = make_object<CompilationConfigNode>();
  node->Init(pass_ctx, raw_targets);
  data_ = std::move(node);
}
307 | |
// FFI entry point so Python (and other frontends) can build a CompilationConfig.
TVM_REGISTER_GLOBAL("target.MakeCompilationConfig")
    .set_body_typed([](const transform::PassContext& pass_ctx,
                       const Array<Target>& raw_targets) -> CompilationConfig {
      return CompilationConfig(pass_ctx, raw_targets);
    });
313 | |
314 | } // namespace tvm |
315 | |