1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/target/compilation_config.cc
22 * \brief Implementation of \p CompilationConfig for collecting \p Targets.
23 */
24
#include <tvm/runtime/device_api.h>
#include <tvm/target/compilation_config.h>

#include <algorithm>
#include <sstream>
#include <string>
#include <unordered_set>
27
28namespace tvm {
29
// Register CompilationConfigNode with TVM's object/reflection machinery so
// instances can be created, inspected and round-tripped through the FFI.
TVM_REGISTER_NODE_TYPE(CompilationConfigNode);
31
// Expose the config's fields to TVM's attribute visitor (used for FFI access,
// printing and structural equality/hashing).
void CompilationConfigNode::VisitAttrs(AttrVisitor* v) {
  v->Visit("host_target", &host_target);
  v->Visit("primitive_targets", &primitive_targets);
  v->Visit("default_primitive_virtual_device", &default_primitive_virtual_device);
  v->Visit("host_virtual_device", &host_virtual_device);
  // NOTE(review): the key spelling "homogenous" differs from the member name
  // "homogeneous". Presumably kept for backwards compatibility with existing
  // FFI users -- renaming the key would break them; confirm before changing.
  v->Visit("optional_homogenous_target", &optional_homogeneous_target);
  // NOTE: The virtual_device_cache_ is not accessible via FFI.
}
40
41Target CompilationConfigNode::FindPrimitiveTargetForDeviceOrFail(DLDeviceType device_type) const {
42 ICHECK_GT(device_type, 0) << "Invalid device type";
43 auto itr = std::find_if(
44 primitive_targets.begin(), primitive_targets.end(),
45 [device_type](const Target& target) { return target->GetTargetDeviceType() == device_type; });
46 if (itr == primitive_targets.end()) {
47 std::stringstream msg;
48 msg << "No target is specified for device type " << device_type
49 << ". The available device types and targets are:" << std::endl;
50 for (const auto& target : primitive_targets) {
51 msg << " " << target->GetTargetDeviceType() << "-> " << target->ToDebugString() << std::endl;
52 }
53 LOG(FATAL) << msg.str();
54 }
55 return *itr;
56}
57
58Optional<Target> CompilationConfigNode::FindPrimitiveTargetForKind(
59 const std::string& kind_name) const {
60 Optional<TargetKind> opt_kind = TargetKind::Get(kind_name);
61 if (!opt_kind.defined()) {
62 VLOG(1) << "No such target kind for '" << kind_name << "'";
63 return {};
64 }
65 auto itr =
66 std::find_if(primitive_targets.begin(), primitive_targets.end(),
67 [kind_name](const Target& target) { return target->kind->name == kind_name; });
68 if (itr == primitive_targets.end()) {
69 VLOG(1) << "No target available matching kind '" << kind_name << "'";
70 return {};
71 }
72 return *itr;
73}
74
75Target CompilationConfigNode::CanonicalTarget(const Target& target) const {
76 // Fast path -- object identity.
77 if (target == host_target) {
78 return target;
79 }
80 for (const auto& primitive_target : primitive_targets) {
81 if (target == primitive_target) {
82 return target;
83 }
84 }
85 // Slow path -- structural equality. We have so few targets it does not seem worth building an
86 // index.
87 if (StructuralEqual()(target, host_target)) {
88 return host_target;
89 }
90 for (const auto& primitive_target : primitive_targets) {
91 if (StructuralEqual()(target, primitive_target)) {
92 return primitive_target;
93 }
94 }
95 // No match.
96 return target;
97}
98
// Returns the unique representative for \p virtual_device, canonicalizing its
// target against the config's known targets (or defaulting one from the device
// type) and interning the result in virtual_device_cache_.
VirtualDevice CompilationConfigNode::CanonicalVirtualDevice(
    const VirtualDevice& virtual_device) const {
  // Targets need special handling.
  Target target = virtual_device->target;
  if (target.defined()) {
    // It is possible the given target object was constructed by the user, but was then
    // rewritten on the way into the CompilationConfig. So 'canonicalize' it by replacing
    // the given target with one structurally equal to one already known in the config if
    // possible.
    Target canon_target = CanonicalTarget(target);
    if (canon_target != target) {
      VLOG(1) << "Canonicalized target " << canon_target->ToDebugString();
    }
    target = canon_target;
  } else if (virtual_device->device_type() != kInvalidDeviceType) {
    // Since no target was given, choose one with a matching device type.
    // This is the one place where we allow device types to imply targets.
    // Fails hard if no primitive target exists for the device type.
    target = FindPrimitiveTargetForDeviceOrFail(virtual_device->device_type());
    VLOG(1) << "Defaulted to target " << target->ToDebugString();
  }
  // else: the target will remain unknown.

  // Redirect to an existing structurally equal virtual device.
  // Note the memory_scope and virtual_device_id are carried over verbatim;
  // only the target field may have been rewritten above.
  return virtual_device_cache_.Unique(VirtualDevice(virtual_device->device_type(),
                                                    virtual_device->virtual_device_id, target,
                                                    virtual_device->memory_scope));
}
126
127void CompilationConfigNode::Init(const transform::PassContext& pass_ctx,
128 const Array<Target>& raw_targets) {
129 VLOG_CONTEXT << "CompilationConfig";
130 CHECK_GT(raw_targets.size(), 0U) << "Require at least one target";
131
132 //
133 // Decide on the host target.
134 //
135
136 // Any targets which could act as a host?
137 auto hosting_itr = std::find_if(raw_targets.begin(), raw_targets.end(), [](const Target& target) {
138 // TODO(tvm-team): The kDLHexagon device can act as a host. We can remove kDLHexagon
139 // here once we refactored kDLHexagon to kDLCPU.
140 return target->GetTargetDeviceType() == kDLCPU || target->GetTargetDeviceType() == kDLHexagon;
141 });
142
143 // Any targets with their host field set?
144 auto has_host_itr = std::find_if(raw_targets.begin(), raw_targets.end(),
145 [](const Target& target) { return target->host.defined(); });
146
147 if (has_host_itr != raw_targets.end()) {
148 // RULE A: If any raw target has a host, use the first such host for all the primitive
149 // targets.
150 host_target = Target((*has_host_itr)->GetHost().value(), /*host=*/Target());
151 VLOG(1) << "The target " << (*has_host_itr)->ToDebugString() << " supplies a host target "
152 << host_target->ToDebugString() << " of device type "
153 << host_target->GetTargetDeviceType();
154 } else if (hosting_itr != raw_targets.end()) {
155 // RULE B: If any raw target is for a device which could be a host then use the first such as
156 // the host.
157 host_target = Target(*hosting_itr, /*host=*/Target());
158 VLOG(1) << "Using target " << host_target->ToDebugString() << " of CPU-like device type "
159 << host_target->GetTargetDeviceType() << " as the host target";
160 } else {
161 // RULE C: Otherwise, create a default CPU host target.
162 host_target = MakeDefaultCPUTarget();
163 VLOG(1) << "Created a default target " << host_target->ToDebugString() << " of device type "
164 << host_target->GetTargetDeviceType() << " for the host target";
165 }
166 ICHECK(host_target.defined());
167 ICHECK(!host_target->host.defined());
168
169 if (host_target->GetTargetDeviceType() != kDLCPU) {
170 // I think we're on thin ice here until we've audited the code base for assumed CPU hosts.
171 VLOG(1) << "The host target is not a CPU. This is probably not going to work.";
172 }
173
174 //
175 // Establish the host VirtualDevice.
176 //
177 host_virtual_device = virtual_device_cache_.Unique(
178 VirtualDevice(static_cast<DLDeviceType>(host_target->GetTargetDeviceType()),
179 /*virtual_device_id=*/0, host_target));
180 ICHECK(host_virtual_device.defined());
181 ICHECK(host_virtual_device->target.defined());
182
183 //
184 // Now that we've settled on a host, we can set it as the host on all the raw targets.
185 //
186 primitive_targets.clear();
187 primitive_targets.reserve(raw_targets.size());
188 for (const auto& raw_target : raw_targets) {
189 if (raw_target->host.defined() && !StructuralEqual()(raw_target->host, host_target)) {
190 VLOG(1) << "The target " << raw_target->ToDebugString()
191 << " already has a host which disagrees with the desired host target. It "
192 << "will be overridden.";
193 }
194 primitive_targets.push_back(Target(raw_target, host_target));
195 }
196 ICHECK_GT(primitive_targets.size(), 0U);
197
198 //
199 // Check the primitive_targets are ordered correctly re Target::IsExternalCodegenFor,
200 // and make sure no two targets share a kind name.
201 //
202
203 // TODO(mbs): We could just sort the list, but given all the implicit defaulting for backwards
204 // compat it seems we should avoid making this any more magical than necessary. But revisit
205 // if usability suffers.
206 std::unordered_set<DLDeviceType> primitive_target_device_types;
207 std::unordered_set<std::string> kind_names;
208 for (const auto& target : primitive_targets) {
209 primitive_target_device_types.emplace(static_cast<DLDeviceType>(target->GetTargetDeviceType()));
210 CHECK(kind_names.emplace(target->kind->name).second) << "Multiple targets have been given"
211 "for the same device kind '"
212 << target->kind->name << "'";
213 }
214 for (DLDeviceType device_type : primitive_target_device_types) {
215 Target first_primitive_target;
216 for (const auto& current_primitive_target : primitive_targets) {
217 if (current_primitive_target->GetTargetDeviceType() != device_type) {
218 continue;
219 }
220 if (!first_primitive_target.defined()) {
221 first_primitive_target = current_primitive_target;
222 // Note it is valid to have only one external codegen target.
223 } else {
224 CHECK(current_primitive_target.IsExternalCodegenFor(first_primitive_target))
225 << "When given multiple targets for the device type " << device_type
226 << " the first must be for non external codegen, and all subsequent must be for "
227 "external codegen. However have been given first "
228 << first_primitive_target->ToDebugString() << " and subsequent "
229 << current_primitive_target->ToDebugString();
230 }
231 }
232 }
233
234 //
235 // Decide on the default device type for primitives.
236 //
237 DLDeviceType default_primitive_device_type;
238 Optional<Integer> opt_fallback_dev = pass_ctx->GetConfig<Integer>("relay.fallback_device_type");
239 if (opt_fallback_dev) {
240 // RULE D: Respect the PassContext setting if given.
241 const int64_t v = opt_fallback_dev.value()->value;
242 CHECK_GT(v, 0)
243 << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " << v;
244 default_primitive_device_type = static_cast<DLDeviceType>(v);
245 VLOG(1) << "Using the 'relay.fallback_device_type' pass attribute "
246 << default_primitive_device_type
247 << " as the default device type for all primitive operations";
248 } else if (primitive_target_device_types.size() == 1) {
249 // RULE E: Since only one device in use there's no choice to make.
250 default_primitive_device_type = *primitive_target_device_types.begin();
251 VLOG(1) << "All primitive targets have the device type " << default_primitive_device_type
252 << " so that is also the default device type for all primitive operations.";
253 } else {
254 // RULE F: Fallback to CPU.
255 default_primitive_device_type = kDLCPU;
256 VLOG(1) << "Using " << default_primitive_device_type
257 << " as the default device type for all primitive operations";
258 }
259
260 //
261 // Establish the default primitive VirtualDevice, choosing a known Target to match the device
262 // type. We do not create a default target, it must already exist as a primitive target.
263 //
264 default_primitive_virtual_device = CanonicalVirtualDevice(
265 VirtualDevice::ForDeviceType(default_primitive_device_type, /*virtual_device_id=*/0));
266
267 ICHECK(default_primitive_virtual_device.defined());
268 ICHECK(default_primitive_virtual_device->target.defined());
269
270 // Legacy: Some passes only support homogenous compilation and expect the target to be
271 // given by the global target context. Make this easy to detect.
272 optional_homogeneous_target =
273 primitive_targets.size() == 1 ? *primitive_targets.begin() : Target();
274}
275
276/* static */ Target CompilationConfigNode::MakeDefaultCPUTarget() {
277 if (runtime::Registry::Get("codegen.LLVMModuleCreate")) {
278 // LLVM is available.
279 // TODO(mbs): More robust extension mechanism?
280 return Target("llvm");
281 } else {
282 // LLVM is not available.
283 // TODO(mbs): Already deprecated?
284 return Target("stackvm");
285 }
286}
287
// Pretty-printer registration: renders the config as the list of primitive
// targets (device type |-> target) plus the two derived virtual devices.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<CompilationConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = ref.as<CompilationConfigNode>();
      p->stream << "Primitive targets:";
      for (const auto& target : node->primitive_targets) {
        p->stream << std::endl
                  << "  " << target->GetTargetDeviceType() << " |-> " << target->ToDebugString();
      }
      p->stream << std::endl
                << "Default primitive virtual device: " << node->default_primitive_virtual_device;
      p->stream << std::endl << "Host virtual device: " << node->host_virtual_device;
    });
300
301CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx,
302 const Array<Target>& raw_targets) {
303 auto node = make_object<CompilationConfigNode>();
304 node->Init(pass_ctx, raw_targets);
305 data_ = std::move(node);
306}
307
// FFI entry point so Python (and other frontends) can build a
// CompilationConfig from a PassContext and a list of raw targets.
TVM_REGISTER_GLOBAL("target.MakeCompilationConfig")
    .set_body_typed([](const transform::PassContext& pass_ctx,
                       const Array<Target>& raw_targets) -> CompilationConfig {
      return CompilationConfig(pass_ctx, raw_targets);
    });
313
314} // namespace tvm
315