1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/analysis/device_domains.cc |
22 | * \brief Unification domain for the device planner. |
23 | */ |
24 | |
25 | #include "./device_domains.h" |
26 | |
27 | #include <tvm/relay/attrs/call.h> |
28 | #include <tvm/relay/attrs/memory.h> |
29 | |
30 | #include "../op/annotation/annotation.h" |
31 | #include "../op/call/call.h" |
32 | #include "../op/memory/device_copy.h" |
33 | #include "../op/memory/on_device.h" |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | namespace transform { |
38 | |
39 | DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) { |
40 | host_domain_ = MakeFirstOrderDomain(config_->host_virtual_device); |
41 | } |
42 | |
43 | DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const VirtualDevice& virtual_device) { |
44 | if (virtual_device->IsFullyConstrained()) { |
45 | auto itr = fully_constrained_virtual_device_to_domain_.find(virtual_device); |
46 | if (itr != fully_constrained_virtual_device_to_domain_.end()) { |
47 | return itr->second; |
48 | } |
49 | DeviceDomainPtr domain = std::make_shared<DeviceDomain>(virtual_device); |
50 | fully_constrained_virtual_device_to_domain_.emplace(virtual_device, domain); |
51 | return domain; |
52 | } else { |
53 | return std::make_shared<DeviceDomain>(virtual_device); |
54 | } |
55 | } |
56 | |
57 | DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const VirtualDevice& virtual_device) { |
58 | if (const auto* func_type_node = type.as<FuncTypeNode>()) { |
59 | std::vector<DeviceDomainPtr> args_and_result; |
60 | args_and_result.reserve(func_type_node->arg_types.size() + 1); |
61 | for (const auto& arg_type : func_type_node->arg_types) { |
62 | args_and_result.emplace_back(MakeDomain(arg_type, VirtualDevice::FullyUnconstrained())); |
63 | } |
64 | args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, virtual_device)); |
65 | return std::make_shared<DeviceDomain>(std::move(args_and_result)); |
66 | } else { |
67 | return MakeFirstOrderDomain(virtual_device); |
68 | } |
69 | } |
70 | |
71 | DeviceDomainPtr DeviceDomains::ForVirtualDevice(const Type& type, |
72 | const VirtualDevice& non_canonical_virtual_device) { |
73 | // Generally the virtual device will have come from an annotation so resolve it to ensure we have |
74 | // its canonical representation. |
75 | VirtualDevice virtual_device = config_->CanonicalVirtualDevice(non_canonical_virtual_device); |
76 | ICHECK(!virtual_device->IsFullyUnconstrained()); |
77 | return MakeDomain(type, virtual_device); |
78 | } |
79 | |
80 | DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { |
81 | DeviceDomainPtr root = domain; |
82 | while (true) { |
83 | auto itr = domain_to_equiv_.find(root); |
84 | if (itr == domain_to_equiv_.end()) { |
85 | break; |
86 | } |
87 | ICHECK_NE(itr->second, root); |
88 | root = itr->second; |
89 | ICHECK_NOTNULL(root); |
90 | } |
91 | // Path compression. |
92 | while (domain != root) { |
93 | auto itr = domain_to_equiv_.find(domain); |
94 | ICHECK(itr != domain_to_equiv_.end()); |
95 | domain = itr->second; |
96 | ICHECK_NOTNULL(domain); |
97 | itr->second = root; |
98 | } |
99 | return root; |
100 | } |
101 | |
102 | DeviceDomainPtr DeviceDomains::JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { |
103 | if (lhs == rhs) { |
104 | return lhs; |
105 | } |
106 | ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) |
107 | << "Device domains:" << std::endl |
108 | << ToString(lhs) << std::endl |
109 | << "and" << std::endl |
110 | << ToString(rhs) << std::endl |
111 | << "do not have the same kind and can't be unified." ; |
112 | if (lhs->args_and_result_.empty()) { |
113 | // Directly compare first-order. |
114 | if (rhs->virtual_device_->IsFullyUnconstrained()) { |
115 | return lhs; |
116 | } |
117 | if (lhs->virtual_device_->IsFullyUnconstrained()) { |
118 | return rhs; |
119 | } |
120 | Optional<VirtualDevice> joined_virtual_device = |
121 | VirtualDevice::Join(lhs->virtual_device_, rhs->virtual_device_); |
122 | if (!joined_virtual_device) { |
123 | return nullptr; |
124 | } |
125 | return MakeFirstOrderDomain(config_->CanonicalVirtualDevice(joined_virtual_device.value())); |
126 | } else { |
127 | // Recurse for higher-order. |
128 | std::vector<DeviceDomainPtr> args_and_result; |
129 | args_and_result.reserve(lhs->args_and_result_.size()); |
130 | for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { |
131 | DeviceDomainPtr joined_domain = |
132 | UnifyOrNull(lhs->args_and_result_[i], rhs->args_and_result_[i]); |
133 | if (joined_domain == nullptr) { |
134 | return nullptr; |
135 | } |
136 | args_and_result.emplace_back(std::move(joined_domain)); |
137 | } |
138 | return MakeHigherOrderDomain(std::move(args_and_result)); |
139 | } |
140 | } |
141 | |
142 | DeviceDomainPtr DeviceDomains::UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { |
143 | ICHECK_NOTNULL(lhs); |
144 | ICHECK_NOTNULL(rhs); |
145 | lhs = Lookup(lhs); |
146 | rhs = Lookup(rhs); |
147 | DeviceDomainPtr joined_domain = JoinOrNull(lhs, rhs); |
148 | if (joined_domain == nullptr) { |
149 | return nullptr; |
150 | } |
151 | if (lhs != joined_domain) { |
152 | domain_to_equiv_.emplace(lhs, joined_domain); |
153 | } |
154 | if (rhs != joined_domain) { |
155 | domain_to_equiv_.emplace(rhs, joined_domain); |
156 | } |
157 | return joined_domain; |
158 | } |
159 | |
160 | bool DeviceDomains::CollapseOrFalse(const DeviceDomainPtr& first_order_domain, |
161 | const DeviceDomainPtr& higher_order_domain) { |
162 | ICHECK(!first_order_domain->is_higher_order()); |
163 | ICHECK(higher_order_domain->is_higher_order()); |
164 | for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { |
165 | if (UnifyOrNull(higher_order_domain->function_param(i), first_order_domain) == nullptr) { |
166 | return false; |
167 | } |
168 | } |
169 | return UnifyOrNull(higher_order_domain->function_result(), first_order_domain) != nullptr; |
170 | } |
171 | |
172 | bool DeviceDomains::UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order, |
173 | const DeviceDomainPtr& rhs_maybe_higher_order) { |
174 | ICHECK(!lhs_first_order->is_higher_order()); |
175 | if (rhs_maybe_higher_order->is_higher_order()) { |
176 | return CollapseOrFalse(lhs_first_order, rhs_maybe_higher_order); |
177 | } else { |
178 | return UnifyOrNull(lhs_first_order, rhs_maybe_higher_order) != nullptr; |
179 | } |
180 | } |
181 | |
182 | DeviceDomainPtr DeviceDomains::DomainFor(const Expr& expr) { |
183 | ICHECK(expr.defined()); |
184 | auto itr = expr_to_domain_.find(expr.get()); |
185 | if (itr != expr_to_domain_.end()) { |
186 | return Lookup(itr->second); |
187 | } |
188 | auto domain = Free(expr->checked_type()); |
189 | expr_to_domain_.emplace(expr.get(), domain); |
190 | return domain; |
191 | } |
192 | |
193 | DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { |
194 | auto itr = call_to_callee_domain_.find(call.get()); |
195 | if (itr != call_to_callee_domain_.end()) { |
196 | return Lookup(itr->second); |
197 | } |
198 | std::vector<DeviceDomainPtr> args_and_result; |
199 | |
200 | OnDeviceProps on_device_props = GetOnDeviceProps(call.get()); |
201 | DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); |
202 | CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); |
203 | |
204 | if (call_lowered_props.lowered_func.defined()) { |
205 | // Presumably we've already seen the call to the "primitive" Function from which this lowered |
206 | // function was derived in an earlier PlanDevices pass. Thus we've already established that |
207 | // all the argument and result devices domains must be equal, ignoring memory scopes. |
208 | // So at this point we'll let all the arguments and result be free so that memory scopes can |
209 | // differ. |
210 | // TODO(mbs): As per header comments, need to revisit when can setup sub-virtual device |
211 | // constraints. |
212 | return DomainFor(call_lowered_props.lowered_func); |
213 | } else if (on_device_props.body.defined()) { |
214 | // By default: |
215 | // on_device(expr, virtual_device=<t>) |
216 | // on_device : fn(<t>):?x? |
217 | // However we'll interpret the constrain_body and constrain_result fields to decide |
218 | // on free vs constrained domains for the argument and result respectively. |
219 | if (on_device_props.constrain_body) { |
220 | args_and_result.emplace_back( |
221 | ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device)); |
222 | } else { |
223 | args_and_result.emplace_back(Free(on_device_props.body->checked_type())); |
224 | } |
225 | if (on_device_props.constrain_result) { |
226 | args_and_result.emplace_back( |
227 | ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device)); |
228 | } else { |
229 | args_and_result.emplace_back(Free(on_device_props.body->checked_type())); |
230 | } |
231 | } else if (device_copy_props.body.defined()) { |
232 | // device_copy(expr, src_virtual_device=<s>, dst_virtual_device=<d>) |
233 | // device_copy: fn(<s>):<d> |
234 | args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(), |
235 | device_copy_props.src_virtual_device)); |
236 | args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(), |
237 | device_copy_props.dst_virtual_device)); |
238 | } else if (call->op == alloc_storage_op) { |
239 | ICHECK_EQ(call->args.size(), 2U); |
240 | // alloc_storage(size, alignment, virtual_device=<t>) |
241 | // alloc_storage: fn(<cpu>, <cpu>):<t> |
242 | const auto* attrs = call->attrs.as<AllocStorageAttrs>(); |
243 | args_and_result.emplace_back(host_domain_); |
244 | args_and_result.emplace_back(host_domain_); |
245 | args_and_result.emplace_back(ForVirtualDevice(call->checked_type(), attrs->virtual_device)); |
246 | } else if (call->op == alloc_tensor_op) { |
247 | ICHECK_EQ(call->args.size(), 3U); |
248 | // alloc_tensor(storage, offset, shape) |
249 | // alloc_tensor: fn(?x?, <cpu>, <cpu>):?x? |
250 | auto free_domain = Free(call->checked_type()); |
251 | args_and_result.emplace_back(free_domain); |
252 | args_and_result.emplace_back(host_domain_); |
253 | args_and_result.emplace_back(host_domain_); |
254 | args_and_result.emplace_back(free_domain); |
255 | } else if (call->op == shape_of_op) { |
256 | ICHECK_EQ(call->args.size(), 1U); |
257 | // shape_of(tensor) |
258 | // shape_of: fn(?x?):<cpu> |
259 | args_and_result.emplace_back(Free(call->args[0]->checked_type())); |
260 | args_and_result.emplace_back(host_domain_); |
261 | } else if (call->op == invoke_tvm_op) { |
262 | ICHECK_EQ(call->args.size(), 3U); |
263 | // invoke_tvm_op(op, inputs, outputs) |
264 | // invoke_tvm_op: fn(..., ?x?, ?x?):?x? |
265 | // where ... is a free domain appropriate for op's type |
266 | auto free_domain = Free(call->checked_type()); |
267 | args_and_result.emplace_back(Free(call->args[0]->checked_type())); |
268 | args_and_result.emplace_back(free_domain); |
269 | args_and_result.emplace_back(free_domain); |
270 | args_and_result.emplace_back(free_domain); |
271 | } else if (call->op == reshape_tensor_op) { |
272 | ICHECK_EQ(call->args.size(), 2U); |
273 | // reshape_tensor(data, shape) |
274 | // reshape_tensor: fn(?x?, <cpu>):?x? |
275 | auto free_domain = Free(call->checked_type()); |
276 | args_and_result.emplace_back(free_domain); |
277 | args_and_result.emplace_back(host_domain_); |
278 | args_and_result.emplace_back(free_domain); |
279 | } else if (call->op->IsInstance<OpNode>()) { |
280 | // <primitive>(arg1, ..., argn) |
281 | // <primitive>: fn(?x?, ..., ?x?):?x? |
282 | // (all args and result must be first-order). |
283 | auto free_domain = MakeFirstOrderDomain(VirtualDevice::FullyUnconstrained()); |
284 | for (size_t i = 0; i < call->args.size(); ++i) { |
285 | args_and_result.emplace_back(free_domain); |
286 | } |
287 | args_and_result.emplace_back(free_domain); |
288 | } else if (call->op->IsInstance<ConstructorNode>()) { |
289 | // <constructor>(arg1, ..., argn) |
290 | // <constructor>: fn(?x1?, ..., ?xn?):?xr? |
291 | // where we force all possibly higher-order ?xi? to be collapsed to the first-order ?xr?. |
292 | // TODO(mbs): This assumes we've eta-expanded constructors, thus all constructors appear |
293 | // in callee positions. |
294 | const auto* func_type_node = call->op->checked_type().as<FuncTypeNode>(); |
295 | ICHECK_NOTNULL(func_type_node); |
296 | ICHECK_EQ(func_type_node->arg_types.size(), call->args.size()); |
297 | auto result_domain = Free(func_type_node->ret_type); // first-order |
298 | for (const auto& arg_type : func_type_node->arg_types) { |
299 | auto param_domain = Free(arg_type); // possibly higher-order |
300 | bool success = UnifyCollapsedOrFalse(result_domain, param_domain); // collapse if required |
301 | ICHECK(success); |
302 | args_and_result.emplace_back(param_domain); |
303 | } |
304 | args_and_result.emplace_back(result_domain); |
305 | } else { |
306 | // We still need to handle the case where the function / op is not lowered |
307 | // because the device planner runs both before and after lowering. |
308 | return DomainFor(call->op); |
309 | } |
310 | auto domain = MakeHigherOrderDomain(std::move(args_and_result)); |
311 | call_to_callee_domain_.emplace(call.get(), domain); |
312 | return domain; |
313 | } |
314 | |
315 | void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { |
316 | auto lhs_domain = DomainFor(lhs); |
317 | auto rhs_domain = DomainFor(rhs); |
318 | if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { |
319 | // TODO(mbs): Proper diagnostics. |
320 | LOG(FATAL) << "Incompatible virtual devices for expressions:" << std::endl |
321 | << PrettyPrint(lhs) << std::endl |
322 | << "with virtual device:" << std::endl |
323 | << ToString(lhs_domain) << "and:" << std::endl |
324 | << PrettyPrint(rhs) << std::endl |
325 | << "with virtual device:" << std::endl |
326 | << ToString(rhs_domain); |
327 | } |
328 | } |
329 | |
330 | void DeviceDomains::OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs) { |
331 | auto lhs_domain = DomainFor(lhs); |
332 | auto rhs_domain = DomainFor(rhs); |
333 | // Snapshot |
334 | std::unordered_map<DeviceDomainPtr, DeviceDomainPtr> domain_to_equiv_snapshot = domain_to_equiv_; |
335 | if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { |
336 | // Rollback |
337 | domain_to_equiv_ = domain_to_equiv_snapshot; |
338 | VLOG(2) << "Unable to unify virtual devices for expression:" << std::endl |
339 | << PrettyPrint(lhs) << std::endl |
340 | << "with virtual device:" << std::endl |
341 | << ToString(lhs_domain) << std::endl |
342 | << "and expression:" << std::endl |
343 | << PrettyPrint(rhs) << std::endl |
344 | << "with virtual device:" << std::endl |
345 | << ToString(rhs_domain) << std::endl |
346 | << ". Leaving virtual devices non-unified." ; |
347 | } else { |
348 | VLOG(2) << "Unified virtual devices for expression:" << std::endl |
349 | << PrettyPrint(lhs) << std::endl |
350 | << "and expression:" << std::endl |
351 | << PrettyPrint(rhs) << std::endl |
352 | << "to virtual devices:" << std::endl |
353 | << ToString(lhs_domain); |
354 | } |
355 | } |
356 | |
357 | void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { |
358 | auto actual_domain = DomainFor(expr); |
359 | if (UnifyOrNull(actual_domain, expected_domain) == nullptr) { |
360 | // TODO(mbs): Proper diagnostics. |
361 | LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl |
362 | << PrettyPrint(expr) << std::endl |
363 | << "with actual virtual device:" << std::endl |
364 | << ToString(actual_domain) << std::endl |
365 | << "and expected virtual device:" << std::endl |
366 | << ToString(expected_domain); |
367 | } |
368 | } |
369 | |
370 | void DeviceDomains::UnifyExprCollapsed(const Expr& expr_first_order, |
371 | const DeviceDomainPtr& expected_domain_maybe_higher_order) { |
372 | auto actual_domain_first_order = DomainFor(expr_first_order); |
373 | if (!UnifyCollapsedOrFalse(actual_domain_first_order, expected_domain_maybe_higher_order)) { |
374 | // TODO(mbs): Proper diagnostics. |
375 | LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl |
376 | << PrettyPrint(expr_first_order) << std::endl |
377 | << "with actual virtual devices:" << std::endl |
378 | << ToString(actual_domain_first_order) << std::endl |
379 | << "and expected virtual device:" << std::endl |
380 | << ToString(expected_domain_maybe_higher_order); |
381 | } |
382 | } |
383 | |
384 | bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) { |
385 | domain = Lookup(domain); |
386 | if (domain->args_and_result_.empty()) { |
387 | // First-order. |
388 | return domain->virtual_device_->IsFullyConstrained(); |
389 | } else { |
390 | // Higher-order. |
391 | return std::all_of( |
392 | domain->args_and_result_.begin(), domain->args_and_result_.end(), |
393 | [this](const DeviceDomainPtr& sub_domain) { return IsFullyConstrained(sub_domain); }); |
394 | } |
395 | } |
396 | |
397 | void DeviceDomains::SetDefault(DeviceDomainPtr domain, |
398 | const VirtualDevice& default_virtual_device) { |
399 | ICHECK(!default_virtual_device->IsFullyUnconstrained()); |
400 | domain = Lookup(domain); |
401 | if (domain->args_and_result_.empty()) { |
402 | DeviceDomainPtr default_domain = MakeFirstOrderDomain(config_->CanonicalVirtualDevice( |
403 | VirtualDevice::Default(domain->virtual_device_, default_virtual_device))); |
404 | DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(domain, default_domain); |
405 | ICHECK(defaulted_domain_ptr != nullptr) << "domain:" << std::endl |
406 | << ToString(domain) << std::endl |
407 | << "default domain:" << std::endl |
408 | << ToString(default_domain); |
409 | } else { |
410 | for (const auto& sub_domain : domain->args_and_result_) { |
411 | SetDefault(sub_domain, default_virtual_device); |
412 | } |
413 | } |
414 | } |
415 | |
416 | void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, |
417 | const VirtualDevice& default_virtual_device) { |
418 | if (domain_maybe_higher_order->args_and_result_.empty()) { |
419 | SetDefault(domain_maybe_higher_order, default_virtual_device); |
420 | } else { |
421 | // First set default for result domain. |
422 | SetDefault(ResultDomain(domain_maybe_higher_order), default_virtual_device); |
423 | // Then use current result domain as default for everything else. |
424 | SetDefault(domain_maybe_higher_order, ResultVirtualDevice(domain_maybe_higher_order)); |
425 | } |
426 | } |
427 | |
428 | DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { |
429 | domain = Lookup(domain); |
430 | while (!domain->args_and_result_.empty()) { |
431 | domain = Lookup(domain->args_and_result_.back()); |
432 | } |
433 | return domain; |
434 | } |
435 | |
436 | std::string DeviceDomains::ToString(DeviceDomainPtr domain) { |
437 | domain = Lookup(domain); |
438 | std::ostringstream os; |
439 | if (domain->args_and_result_.empty()) { |
440 | // First-order. |
441 | if (!domain->virtual_device_->IsFullyConstrained()) { |
442 | os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?" ; |
443 | } |
444 | if (!domain->virtual_device_->IsFullyUnconstrained()) { |
445 | os << domain->virtual_device_; |
446 | } |
447 | } else { |
448 | // higher-order |
449 | os << "fn(" ; |
450 | for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) { |
451 | if (i > 0) { |
452 | os << "," ; |
453 | } |
454 | os << ToString(domain->args_and_result_[i]); |
455 | } |
456 | os << "):" << ToString(domain->args_and_result_.back()); |
457 | } |
458 | return os.str(); |
459 | } |
460 | |
461 | std::string DeviceDomains::ToString() { |
462 | std::ostringstream os; |
463 | for (const auto& pair : expr_to_domain_) { |
464 | os << "expression:" << std::endl |
465 | << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl |
466 | << "domain:" << std::endl |
467 | << ToString(pair.second) << std::endl |
468 | << std::endl; |
469 | } |
470 | for (const auto& pair : call_to_callee_domain_) { |
471 | os << "call:" << std::endl |
472 | << PrettyPrint(GetRef<Call>(pair.first)) << std::endl |
473 | << "callee domain:" << std::endl |
474 | << ToString(pair.second) << std::endl |
475 | << std::endl; |
476 | } |
477 | return os.str(); |
478 | } |
479 | |
480 | } // namespace transform |
481 | } // namespace relay |
482 | } // namespace tvm |
483 | |