/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/analysis/device_domains.cc
 * \brief Unification domain for the device planner.
 */

#include "./device_domains.h"

#include <tvm/relay/attrs/call.h>
#include <tvm/relay/attrs/memory.h>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../op/memory/on_device.h"

namespace tvm {
namespace relay {
namespace transform {

DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) {
  host_domain_ = MakeFirstOrderDomain(config_->host_virtual_device);
}

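// Returns a first-order domain for \p virtual_device. Fully constrained virtual devices are
// interned so that every occurrence of the same constraint shares a single domain object;
// partially constrained virtual devices each get a fresh domain so they can be unified
// independently.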
DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const VirtualDevice& virtual_device) {
  if (virtual_device->IsFullyConstrained()) {
    auto itr = fully_constrained_virtual_device_to_domain_.find(virtual_device);
    if (itr != fully_constrained_virtual_device_to_domain_.end()) {
      return itr->second;
    }
    DeviceDomainPtr domain = std::make_shared<DeviceDomain>(virtual_device);
    fully_constrained_virtual_device_to_domain_.emplace(virtual_device, domain);
    return domain;
  } else {
    return std::make_shared<DeviceDomain>(virtual_device);
  }
}

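// Returns a domain appropriate for \p type. Function types yield a higher-order domain whose
// argument domains start fully unconstrained and whose result domain carries \p virtual_device;
// all other types yield a first-order domain for \p virtual_device.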
DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const VirtualDevice& virtual_device) {
  if (const auto* func_type_node = type.as<FuncTypeNode>()) {
    std::vector<DeviceDomainPtr> args_and_result;
    args_and_result.reserve(func_type_node->arg_types.size() + 1);
    for (const auto& arg_type : func_type_node->arg_types) {
      args_and_result.emplace_back(MakeDomain(arg_type, VirtualDevice::FullyUnconstrained()));
    }
    args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, virtual_device));
    return std::make_shared<DeviceDomain>(std::move(args_and_result));
  } else {
    return MakeFirstOrderDomain(virtual_device);
  }
}

DeviceDomainPtr DeviceDomains::ForVirtualDevice(const Type& type,
                                                const VirtualDevice& non_canonical_virtual_device) {
  // Generally the virtual device will have come from an annotation, so resolve it to ensure we
  // have its canonical representation.
  VirtualDevice virtual_device = config_->CanonicalVirtualDevice(non_canonical_virtual_device);
  ICHECK(!virtual_device->IsFullyUnconstrained());
  return MakeDomain(type, virtual_device);
}

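// Returns the representative of the union-find equivalence class containing \p domain, following
// links in domain_to_equiv_ to the root and then compressing the path so that later lookups reach
// the root directly.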
DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) {
  DeviceDomainPtr root = domain;
  while (true) {
    auto itr = domain_to_equiv_.find(root);
    if (itr == domain_to_equiv_.end()) {
      break;
    }
    ICHECK_NE(itr->second, root);
    root = itr->second;
    ICHECK_NOTNULL(root);
  }
  // Path compression.
  while (domain != root) {
    auto itr = domain_to_equiv_.find(domain);
    ICHECK(itr != domain_to_equiv_.end());
    domain = itr->second;
    ICHECK_NOTNULL(domain);
    itr->second = root;
  }
  return root;
}

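// Returns a domain carrying the combined constraints of \p lhs and \p rhs, or nullptr if those
// constraints are incompatible. First-order domains are joined via VirtualDevice::Join;
// higher-order domains of the same arity are joined pointwise by unifying corresponding argument
// and result domains.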
DeviceDomainPtr DeviceDomains::JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
  if (lhs == rhs) {
    return lhs;
  }
  ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
      << "Device domains:" << std::endl
      << ToString(lhs) << std::endl
      << "and" << std::endl
      << ToString(rhs) << std::endl
      << "do not have the same kind and can't be unified.";
  if (lhs->args_and_result_.empty()) {
    // Directly compare first-order.
    if (rhs->virtual_device_->IsFullyUnconstrained()) {
      return lhs;
    }
    if (lhs->virtual_device_->IsFullyUnconstrained()) {
      return rhs;
    }
    Optional<VirtualDevice> joined_virtual_device =
        VirtualDevice::Join(lhs->virtual_device_, rhs->virtual_device_);
    if (!joined_virtual_device) {
      return nullptr;
    }
    return MakeFirstOrderDomain(config_->CanonicalVirtualDevice(joined_virtual_device.value()));
  } else {
    // Recurse for higher-order.
    std::vector<DeviceDomainPtr> args_and_result;
    args_and_result.reserve(lhs->args_and_result_.size());
    for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
      DeviceDomainPtr joined_domain =
          UnifyOrNull(lhs->args_and_result_[i], rhs->args_and_result_[i]);
      if (joined_domain == nullptr) {
        return nullptr;
      }
      args_and_result.emplace_back(std::move(joined_domain));
    }
    return MakeHigherOrderDomain(std::move(args_and_result));
  }
}

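// The 'union' operation of the union-find: finds the representatives of \p lhs and \p rhs, joins
// them, and records both representatives as equivalent to the joined domain. Returns the joined
// domain, or nullptr if the two domains cannot be unified.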
DeviceDomainPtr DeviceDomains::UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
  ICHECK_NOTNULL(lhs);
  ICHECK_NOTNULL(rhs);
  lhs = Lookup(lhs);
  rhs = Lookup(rhs);
  DeviceDomainPtr joined_domain = JoinOrNull(lhs, rhs);
  if (joined_domain == nullptr) {
    return nullptr;
  }
  if (lhs != joined_domain) {
    domain_to_equiv_.emplace(lhs, joined_domain);
  }
  if (rhs != joined_domain) {
    domain_to_equiv_.emplace(rhs, joined_domain);
  }
  return joined_domain;
}

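// Collapses \p higher_order_domain onto \p first_order_domain by unifying every parameter domain
// and the result domain with the first-order domain. Returns false if any unification fails.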
bool DeviceDomains::CollapseOrFalse(const DeviceDomainPtr& first_order_domain,
                                    const DeviceDomainPtr& higher_order_domain) {
  ICHECK(!first_order_domain->is_higher_order());
  ICHECK(higher_order_domain->is_higher_order());
  for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) {
    if (UnifyOrNull(higher_order_domain->function_param(i), first_order_domain) == nullptr) {
      return false;
    }
  }
  return UnifyOrNull(higher_order_domain->function_result(), first_order_domain) != nullptr;
}

bool DeviceDomains::UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order,
                                          const DeviceDomainPtr& rhs_maybe_higher_order) {
  ICHECK(!lhs_first_order->is_higher_order());
  if (rhs_maybe_higher_order->is_higher_order()) {
    return CollapseOrFalse(lhs_first_order, rhs_maybe_higher_order);
  } else {
    return UnifyOrNull(lhs_first_order, rhs_maybe_higher_order) != nullptr;
  }
}

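// Returns the domain representing the virtual device of \p expr, creating and memoizing a free
// domain from the expression's checked type on first use.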
DeviceDomainPtr DeviceDomains::DomainFor(const Expr& expr) {
  ICHECK(expr.defined());
  auto itr = expr_to_domain_.find(expr.get());
  if (itr != expr_to_domain_.end()) {
    return Lookup(itr->second);
  }
  auto domain = Free(expr->checked_type());
  expr_to_domain_.emplace(expr.get(), domain);
  return domain;
}

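// Returns the domain to use for the callee of \p call. Operators with special device semantics
// (on_device, device_copy, the memory ops, generic primitives and constructors) get a
// call-specific function domain encoding their device signature; other callees (e.g. Relay
// functions bound to variables) share the domain of the callee expression via DomainFor. The
// call-specific domains are memoized in call_to_callee_domain_.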
DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
  auto itr = call_to_callee_domain_.find(call.get());
  if (itr != call_to_callee_domain_.end()) {
    return Lookup(itr->second);
  }
  std::vector<DeviceDomainPtr> args_and_result;

  OnDeviceProps on_device_props = GetOnDeviceProps(call.get());
  DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get());
  CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get());

  if (call_lowered_props.lowered_func.defined()) {
    // Presumably we've already seen the call to the "primitive" Function from which this lowered
    // function was derived in an earlier PlanDevices pass. Thus we've already established that
    // all the argument and result device domains must be equal, ignoring memory scopes.
    // So at this point we'll let all the arguments and result be free so that memory scopes can
    // differ.
    // TODO(mbs): As per header comments, need to revisit when we can set up sub-virtual-device
    // constraints.
    return DomainFor(call_lowered_props.lowered_func);
  } else if (on_device_props.body.defined()) {
    // By default:
    //   on_device(expr, virtual_device=<t>)
    //   on_device: fn(<t>):?x?
    // However we'll interpret the constrain_body and constrain_result fields to decide
    // on free vs constrained domains for the argument and result respectively.
    if (on_device_props.constrain_body) {
      args_and_result.emplace_back(
          ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device));
    } else {
      args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
    }
    if (on_device_props.constrain_result) {
      args_and_result.emplace_back(
          ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device));
    } else {
      args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
    }
  } else if (device_copy_props.body.defined()) {
    //   device_copy(expr, src_virtual_device=<s>, dst_virtual_device=<d>)
    //   device_copy: fn(<s>):<d>
    args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(),
                                                  device_copy_props.src_virtual_device));
    args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(),
                                                  device_copy_props.dst_virtual_device));
  } else if (call->op == alloc_storage_op) {
    ICHECK_EQ(call->args.size(), 2U);
    //   alloc_storage(size, alignment, virtual_device=<t>)
    //   alloc_storage: fn(<cpu>, <cpu>):<t>
    const auto* attrs = call->attrs.as<AllocStorageAttrs>();
    args_and_result.emplace_back(host_domain_);
    args_and_result.emplace_back(host_domain_);
    args_and_result.emplace_back(ForVirtualDevice(call->checked_type(), attrs->virtual_device));
  } else if (call->op == alloc_tensor_op) {
    ICHECK_EQ(call->args.size(), 3U);
    //   alloc_tensor(storage, offset, shape)
    //   alloc_tensor: fn(?x?, <cpu>, <cpu>):?x?
    auto free_domain = Free(call->checked_type());
    args_and_result.emplace_back(free_domain);
    args_and_result.emplace_back(host_domain_);
    args_and_result.emplace_back(host_domain_);
    args_and_result.emplace_back(free_domain);
  } else if (call->op == shape_of_op) {
    ICHECK_EQ(call->args.size(), 1U);
    //   shape_of(tensor)
    //   shape_of: fn(?x?):<cpu>
    args_and_result.emplace_back(Free(call->args[0]->checked_type()));
    args_and_result.emplace_back(host_domain_);
  } else if (call->op == invoke_tvm_op) {
    ICHECK_EQ(call->args.size(), 3U);
    //   invoke_tvm_op(op, inputs, outputs)
    //   invoke_tvm_op: fn(..., ?x?, ?x?):?x?
    // where ... is a free domain appropriate for op's type
    auto free_domain = Free(call->checked_type());
    args_and_result.emplace_back(Free(call->args[0]->checked_type()));
    args_and_result.emplace_back(free_domain);
    args_and_result.emplace_back(free_domain);
    args_and_result.emplace_back(free_domain);
  } else if (call->op == reshape_tensor_op) {
    ICHECK_EQ(call->args.size(), 2U);
    //   reshape_tensor(data, shape)
    //   reshape_tensor: fn(?x?, <cpu>):?x?
    auto free_domain = Free(call->checked_type());
    args_and_result.emplace_back(free_domain);
    args_and_result.emplace_back(host_domain_);
    args_and_result.emplace_back(free_domain);
  } else if (call->op->IsInstance<OpNode>()) {
    //   <primitive>(arg1, ..., argn)
    //   <primitive>: fn(?x?, ..., ?x?):?x?
    // (all args and result must be first-order).
    auto free_domain = MakeFirstOrderDomain(VirtualDevice::FullyUnconstrained());
    for (size_t i = 0; i < call->args.size(); ++i) {
      args_and_result.emplace_back(free_domain);
    }
    args_and_result.emplace_back(free_domain);
  } else if (call->op->IsInstance<ConstructorNode>()) {
    //   <constructor>(arg1, ..., argn)
    //   <constructor>: fn(?x1?, ..., ?xn?):?xr?
    // where we force all possibly higher-order ?xi? to be collapsed to the first-order ?xr?.
    // TODO(mbs): This assumes we've eta-expanded constructors, thus all constructors appear
    // in callee positions.
    const auto* func_type_node = call->op->checked_type().as<FuncTypeNode>();
    ICHECK_NOTNULL(func_type_node);
    ICHECK_EQ(func_type_node->arg_types.size(), call->args.size());
    auto result_domain = Free(func_type_node->ret_type);  // first-order
    for (const auto& arg_type : func_type_node->arg_types) {
      auto param_domain = Free(arg_type);                                 // possibly higher-order
      bool success = UnifyCollapsedOrFalse(result_domain, param_domain);  // collapse if required
      ICHECK(success);
      args_and_result.emplace_back(param_domain);
    }
    args_and_result.emplace_back(result_domain);
  } else {
    // We still need to handle the case where the function / op is not lowered
    // because the device planner runs both before and after lowering.
    return DomainFor(call->op);
  }
  auto domain = MakeHigherOrderDomain(std::move(args_and_result));
  call_to_callee_domain_.emplace(call.get(), domain);
  return domain;
}

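// Unifies the domains of \p lhs and \p rhs, aborting with a fatal error if their virtual devices
// are incompatible.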
void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) {
  auto lhs_domain = DomainFor(lhs);
  auto rhs_domain = DomainFor(rhs);
  if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) {
    // TODO(mbs): Proper diagnostics.
    LOG(FATAL) << "Incompatible virtual devices for expressions:" << std::endl
               << PrettyPrint(lhs) << std::endl
               << "with virtual device:" << std::endl
               << ToString(lhs_domain) << std::endl
               << "and:" << std::endl
               << PrettyPrint(rhs) << std::endl
               << "with virtual device:" << std::endl
               << ToString(rhs_domain);
  }
}

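// Best-effort version of the above: attempts to unify the domains of \p lhs and \p rhs, but if
// they are incompatible restores domain_to_equiv_ from a snapshot so the failed attempt leaves no
// partial constraints behind, and merely logs the failure.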
void DeviceDomains::OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs) {
  auto lhs_domain = DomainFor(lhs);
  auto rhs_domain = DomainFor(rhs);
  // Snapshot
  std::unordered_map<DeviceDomainPtr, DeviceDomainPtr> domain_to_equiv_snapshot = domain_to_equiv_;
  if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) {
    // Rollback
    domain_to_equiv_ = domain_to_equiv_snapshot;
    VLOG(2) << "Unable to unify virtual devices for expression:" << std::endl
            << PrettyPrint(lhs) << std::endl
            << "with virtual device:" << std::endl
            << ToString(lhs_domain) << std::endl
            << "and expression:" << std::endl
            << PrettyPrint(rhs) << std::endl
            << "with virtual device:" << std::endl
            << ToString(rhs_domain) << std::endl
            << ". Leaving virtual devices non-unified.";
  } else {
    VLOG(2) << "Unified virtual devices for expression:" << std::endl
            << PrettyPrint(lhs) << std::endl
            << "and expression:" << std::endl
            << PrettyPrint(rhs) << std::endl
            << "to virtual devices:" << std::endl
            << ToString(lhs_domain);
  }
}

void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) {
  auto actual_domain = DomainFor(expr);
  if (UnifyOrNull(actual_domain, expected_domain) == nullptr) {
    // TODO(mbs): Proper diagnostics.
    LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl
               << PrettyPrint(expr) << std::endl
               << "with actual virtual device:" << std::endl
               << ToString(actual_domain) << std::endl
               << "and expected virtual device:" << std::endl
               << ToString(expected_domain);
  }
}

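// As above, but \p expr_first_order must have a first-order domain; if the expected domain is
// higher-order it is collapsed onto the expression's domain.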
void DeviceDomains::UnifyExprCollapsed(const Expr& expr_first_order,
                                       const DeviceDomainPtr& expected_domain_maybe_higher_order) {
  auto actual_domain_first_order = DomainFor(expr_first_order);
  if (!UnifyCollapsedOrFalse(actual_domain_first_order, expected_domain_maybe_higher_order)) {
    // TODO(mbs): Proper diagnostics.
    LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl
               << PrettyPrint(expr_first_order) << std::endl
               << "with actual virtual devices:" << std::endl
               << ToString(actual_domain_first_order) << std::endl
               << "and expected virtual device:" << std::endl
               << ToString(expected_domain_maybe_higher_order);
  }
}

bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) {
  domain = Lookup(domain);
  if (domain->args_and_result_.empty()) {
    // First-order.
    return domain->virtual_device_->IsFullyConstrained();
  } else {
    // Higher-order.
    return std::all_of(
        domain->args_and_result_.begin(), domain->args_and_result_.end(),
        [this](const DeviceDomainPtr& sub_domain) { return IsFullyConstrained(sub_domain); });
  }
}

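// Fills in any unconstrained fields of a first-order \p domain from \p default_virtual_device
// (via VirtualDevice::Default). For higher-order domains, recurses into every argument and result
// sub-domain.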
void DeviceDomains::SetDefault(DeviceDomainPtr domain,
                               const VirtualDevice& default_virtual_device) {
  ICHECK(!default_virtual_device->IsFullyUnconstrained());
  domain = Lookup(domain);
  if (domain->args_and_result_.empty()) {
    DeviceDomainPtr default_domain = MakeFirstOrderDomain(config_->CanonicalVirtualDevice(
        VirtualDevice::Default(domain->virtual_device_, default_virtual_device)));
    DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull(domain, default_domain);
    ICHECK(defaulted_domain_ptr != nullptr) << "domain:" << std::endl
                                            << ToString(domain) << std::endl
                                            << "default domain:" << std::endl
                                            << ToString(default_domain);
  } else {
    for (const auto& sub_domain : domain->args_and_result_) {
      SetDefault(sub_domain, default_virtual_device);
    }
  }
}

void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order,
                                               const VirtualDevice& default_virtual_device) {
  if (domain_maybe_higher_order->args_and_result_.empty()) {
    SetDefault(domain_maybe_higher_order, default_virtual_device);
  } else {
    // First set default for result domain.
    SetDefault(ResultDomain(domain_maybe_higher_order), default_virtual_device);
    // Then use current result domain as default for everything else.
    SetDefault(domain_maybe_higher_order, ResultVirtualDevice(domain_maybe_higher_order));
  }
}

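// Returns the first-order domain for the (possibly nested) result of \p domain, i.e. follows the
// result position of function domains until a first-order domain is reached.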
DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) {
  domain = Lookup(domain);
  while (!domain->args_and_result_.empty()) {
    domain = Lookup(domain->args_and_result_.back());
  }
  return domain;
}

std::string DeviceDomains::ToString(DeviceDomainPtr domain) {
  domain = Lookup(domain);
  std::ostringstream os;
  if (domain->args_and_result_.empty()) {
    // First-order.
    if (!domain->virtual_device_->IsFullyConstrained()) {
      os << "?" << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
    }
    if (!domain->virtual_device_->IsFullyUnconstrained()) {
      os << domain->virtual_device_;
    }
  } else {
    // Higher-order.
    os << "fn(";
    for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) {
      if (i > 0) {
        os << ",";
      }
      os << ToString(domain->args_and_result_[i]);
    }
    os << "):" << ToString(domain->args_and_result_.back());
  }
  return os.str();
}

std::string DeviceDomains::ToString() {
  std::ostringstream os;
  for (const auto& pair : expr_to_domain_) {
    os << "expression:" << std::endl
       << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl
       << "domain:" << std::endl
       << ToString(pair.second) << std::endl
       << std::endl;
  }
  for (const auto& pair : call_to_callee_domain_) {
    os << "call:" << std::endl
       << PrettyPrint(GetRef<Call>(pair.first)) << std::endl
       << "callee domain:" << std::endl
       << ToString(pair.second) << std::endl
       << std::endl;
  }
  return os.str();
}

}  // namespace transform
}  // namespace relay
}  // namespace tvm