/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/transforms/device_planner.cc
 * \brief Determines a unique \p VirtualDevice to hold the result of every Relay sub-expression.
 * This pass can be run multiple times, and can be run both before and after lowering.
 *
 * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
 * We represent D by a \p VirtualDevice, which means we can track anywhere from an arbitrary
 * device of some \p DLDeviceType to a specific memory scope on a specific (virtual) \p Device
 * whose code is compiled with a specific \p Target.
 *
 * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
 * see below.
 *
 * This pass works by collecting and solving device constraints, using defaulting heuristics to
 * resolve any remaining undetermined devices, and encoding the results on the output in a form
 * that's reasonably friendly to downstream passes.
 *
 * Specific \p VirtualDevices flow into the constraints from five places:
 *  - Existing "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a
 *    'src_virtual_device' and 'dst_virtual_device' \p VirtualDevice. Those constrain the
 *    argument and context of the call respectively. It is ok if source and destination devices
 *    are the same; such no-op copies will be removed after accounting for the device preference.
 *  - Existing "on_device" CallNodes (with an \p OnDeviceAttrs attribute) specify a
 *    'virtual_device', which constrains the argument of the call, but (usually, see below)
 *    leaves the context unconstrained. These are called 'annotations' in the rest of the code.
 *    They have no operational significance by themselves, but may trigger the insertion of a
 *    new "device_copy" call by this pass. In two situations the result of an "on_device"
 *    CallNode may also be constrained to the given 'virtual_device':
 *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
 *       immediately let-bound expression. In this situation the extra degree of freedom in
 *       the function result and let-binding leads to surprising device copies, so we simply
 *       force the function result or let-bound variable to the given device.
 *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
 *       it ourselves during an earlier invocation of this pass. This helps make this pass
 *       idempotent.
 *  - Some special operators require their arguments or results to be on the 'host' (typically
 *    a CPU) \p VirtualDevice, see below.
 *  - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their
 *    argument buffers to have a specific memory scope, which is part of \p VirtualDevice.
 *  - Annotations left over from a previous run of this pass, such as the 'param_virtual_devices'
 *    and 'result_virtual_device' function attributes we introduce below. This is so the pass is
 *    idempotent and can be re-run to flow additional memory scope constraints.
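 *
 * For example (an illustrative sketch only; 'd1' and 'd2' stand for specific
 * \p VirtualDevices), the first two sources can appear together as:
 * \code
 *   def @main(%x, %y) {
 *     // Constrains %y (but, usually, not the call context) to d1.
 *     let %a = add(%x, on_device(%y, virtual_device=d1));
 *     // Constrains %a to d1 and the copy's result to d2.
 *     device_copy(%a, src_virtual_device=d1, dst_virtual_device=d2)
 *   }
 * \endcode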
 *
 * We proceed in four phases:
 *
 * Phase 0
 * -------
 * We rewrite the program to handle some special cases:
 *  - "on_device" calls at the top-level of a function body or immediately let-bound are
 *    rewritten to have \code is_fixed=true \endcode.
 *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
 *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection
 *    from the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
 *  - We are prepared to insert device_copies on the arguments and result of calls to PrimFuncs,
 *    on the assumption that a) we already ran PlanDevices before lowering so we are not allowing
 *    any new cross-device copies, but b) after lowering we may have new memory scope constraints
 *    to deal with.
 *
 * Phase 1
 * -------
 * We flow constraints from the "on_device" and "device_copy" calls, PrimFunc buffer memory
 * scopes, and some special ops, to all other Relay sub-expressions.
 *
 * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
 * same device. However each call site can use a different device. In other words primitives
 * are 'device polymorphic' since we compile and execute them for each required device. ADT
 * constructors are similarly polymorphic, but require all constructor args to be on the same
 * device.
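 *
 * For example (illustrative only), the same primitive may be instantiated for several devices
 * within one program:
 * \code
 *   let %r1 = add(on_device(%a, virtual_device=CPU), on_device(%b, virtual_device=CPU));
 *   let %r2 = add(on_device(%c, virtual_device=GPU), on_device(%e, virtual_device=GPU));
 *   ...
 * \endcode
 * Here the first \p add is compiled for (and its result stored on) the CPU, while the second
 * is compiled for the GPU.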
 *
 * For most Relay expressions the device for the overall expression is the same as the device
 * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
 * itself, the condition and arms of an \p if must all be on the same device as the overall
 * \p if, and so on.
 *
 * Some special ops (or 'dialects') are handled:
 *  - Relay supports computing the shapes of tensors at runtime, and operating on those shapes,
 *    using "shape_of" and "reshape_tensor". Shapes must only be held on the CPU, but the
 *    tensors they describe may reside on any device.
 *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators.
 *    Again shapes reside on the CPU, but the allocated tensors may reside on any device.
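 *
 * For example (a sketch only; operand names are illustrative and attributes are elided), after
 * memory planning we may see fragments such as:
 * \code
 *   // %t may reside on any device, but its shape %s is always held on the CPU.
 *   let %s = shape_of(%t);
 *   // %s stays on the CPU even though the tensor is allocated out of %storage on
 *   // some other device.
 *   let %u = alloc_tensor(%storage, %offset, %s);
 * \endcode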
 *
 * Two Relay expressions have special handling:
 *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
 *    overall let. However the result of \p e1 may be on a different device.
 *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
 *    same device as \p body. However parameters \p x and \p y may be on different devices, even
 *    different from each other. Every call to the function must use the same choice of parameter
 *    and result devices -- there is no 'device polymorphism' for Relay functions.
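 *
 * For example (illustrative), the following is device-consistent even though the let-bound
 * value and the let body are on different devices:
 * \code
 *   let %x = on_device(%gpu_result, virtual_device=GPU);
 *   add(device_copy(%x, src_virtual_device=GPU, dst_virtual_device=CPU), %y)
 * \endcode
 * Here the overall let (and %y) is on the CPU while %x is on the GPU.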
 *
 * Currently \p PrimFuncs and external functions do not carry over their parameter and result
 * devices from their original Relay Function representations. However we know all calls to
 * those functions are device-consistent, thus no information is lost.
 *
 * Phase 2
 * -------
 * After flowing constraints we apply some defaulting heuristics (using a global default
 * \p VirtualDevice) to fix the device for any as-yet unconstrained sub-expressions:
 *  - Unconstrained function result devices default to the global default device.
 *  - Unconstrained function parameter devices default to the device for the function result.
 *  - Unconstrained let-bound expression devices default to the device for the overall let.
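 *
 * For example (illustrative, taking the CPU as the global default device):
 * \code
 *   def @main(%x) { negative(%x) }
 * \endcode
 * Nothing constrains @main's result, so it defaults to the CPU; nothing then constrains %x,
 * so it defaults to the device of the result, i.e. also the CPU.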
 * TODO(mbs): These are very simple-minded heuristics, and ultimately we'd like to treat the
 * assignment of the remaining unconstrained sub-expressions as an optimization problem in
 * itself. This requires a formal notion of 'choicepoint' inside the compiler which can
 * integrate with automation.
 *
 * Phase 3
 * -------
 * Finally, the result of this analysis is reified into the result as:
 *  - Additional "param_virtual_devices" (an \p Array<VirtualDevice>) and "result_virtual_device"
 *    (a \p VirtualDevice) attributes for every function (both top-level and local). These
 *    describe the devices for the function's parameters and the result.
 *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
 *    intent of the original "on_device" CallNodes.
 *  - Additional "on_device" CallNodes where the device type of an expression is not trivially
 *    implied by the lexically enclosing "on_device" CallNode or function attribute. In practice
 *    this means "on_device" CallNodes may appear in two places:
 *     - On let-bound expressions. It is tempting to elide the "on_device" if the let-bound value
 *       has the same device as the overall let expression. However this would mean passes which
 *       inline let-bound values, such as FoldConstant and DeadCodeElimination, would need to use
 *       a DeviceAware visitor, which in turn requires the expression to be in ANF to avoid
 *       deep recursion. To minimize disruption we always include the "on_device" so that it
 *       can follow the inline.
 *     - On a call argument if its device differs from the call result. In particular, the
 *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
 *       seem pedantic but simplifies downstream handling.)
 *    However since we make it easy to track devices for variables we never wrap an "on_device"
 *    around a var or global var. These uses of "on_device" imply both the argument and result
 *    are on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to
 *    true, which helps make this pass idempotent.
 *  - The buffer maps for called PrimFuncs are updated to capture memory scopes.
 *
 * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover
 * the device for any expression for their own use, e.g. during memory planning. All downstream
 * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion
 * to ANF must respect the lexical scoping convention:
 * \code
 *   f(on_device(g(h(a, b), c), virtual_device=CPU))
 *   ==>
 *   let %x0 = on_device(h(a, b), virtual_device=CPU)
 *   let %x1 = on_device(g(%x0), virtual_device=CPU)
 *   f(on_device(%x1, virtual_device=CPU))
 * \endcode
 *
 * This pass can be run before FuseOps so that it can use device-specific fusion rules.
 *
 * 'Stored on' vs 'Executes on'
 * ----------------------------
 * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
 * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
 * primitives.
 *
 * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
 * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
 * devices, so the notion of 'executes on' is moot. AOT backends on the other hand need to
 * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
 * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
 * pass, but we'd like to fold that into device planning here to ensure everything is
 * consistent.
 *
 * Obviously since tensors are passed by pointer it's quite possible to execute a Relay
 * expression (e.g. an \p if expression) on one device even though the tensor data resides on
 * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes
 * on' to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to
 * just compile the function body for the function's result device.
 *
 * This works after conversion to ANF provided the compilation for a let expression is prepared
 * to make a cross-device call. However we leave it to a downstream transformation to
 * heuristically minimize cross-device calls by moving device copies out of functions. E.g.:
 * \code
 *   def @f() {  // execute on CPU
 *     let x = on_device(...GPU computation..., virtual_device=GPU);
 *     device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU)
 *   }
 *   def @main() {
 *     ... call @f() on CPU ...
 *   }
 * \endcode
 * could be rewritten to:
 * \code
 *   def @f() {  // execute on GPU
 *     let x = ...GPU computation...;
 *     ...GPU computation...
 *   }
 *   def @main() {
 *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU)
 *     ... use x on CPU ...
 *   }
 * \endcode
 *
 * Higher-order shenanigans
 * ------------------------
 * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
 * as arguments (even anonymous functions), return functions, evaluate conditional expressions
 * over functions, and so on. We handle this during constraint solving using the domain:
 * \code
 *   D ::= <specific device type>   -- first-order
 *       | fn(D,...,D):D            -- higher-order
 * \endcode
 * In this way we can determine the device for all function parameters and results. E.g. for
 * \code
 *   let f = fn(x, y) { ... }
 *   let g = fn(f, z) { f(z, z) }
 *   g(f, on_device(..., virtual_device=CPU))
 * \endcode
 * the parameters \p x and \p y will be on the CPU.
 *
 * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
 * function. Our analysis must guarantee that the function's parameters and result devices are
 * consistent for \p e2, \p e3, and the context of the call. But:
 *  - Which device holds the closure result of evaluating \p e1?
 *  - If \p e2 is of function type, what does that mean when we say every function parameter
 *    is on a device?
 *  - If \p e1 returns a function, what does that mean when we say every function result is
 *    on a device?
 *
 * Since higher-order aspects are later compiled away (by 'defunctionalization'
 * aka 'firstification') we'd prefer not to have to answer any of those questions. In
 * particular, we really don't want our domain \p D to allow for yet another device for the
 * function closure. So we'll just force the 'device for a function' to be the same as the
 * device for the function's result, using the notion of the 'result domain' for a domain:
 * \code
 *   result_domain(<specific device type>) = <specific device type>
 *   result_domain(fn(D1,...,Dn):Dr)       = result_domain(Dr)
 * \endcode
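 * For example, a curried function whose ultimate result is on the GPU is itself treated as
 * 'on' the GPU:
 * \code
 *   result_domain(fn(CPU):fn(GPU):GPU) = GPU
 * \endcode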
 *
 * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
 * analysis encounters a function inside one of those it simply forces all argument and result
 * devices for the function to match the device for the first-order expression. For example,
 * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
 * parameters and result must similarly be on the GPU.
 *
 * -------
 * | AOR |  This pass supports all of Relay.
 * -------
 *    ^
 *    |
 *    `-- Mark's stamp of completeness :-)
 *
 * TODO(mbs): Proper diagnostics for unification failure using spans.
 * TODO(mbs): We may want some 'device polymorphism' for Relay functions. E.g. it's ok for the
 * function to be called with params/result on different (virtual) device ids provided the
 * target and memory scopes are consistent.
 */

#include <tvm/ir/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>

#include "../../tir/analysis/device_constraint_utils.h"
#include "../op/annotation/annotation.h"
#include "../op/memory/device_copy.h"
#include "../op/memory/on_device.h"
#include "./device_domains.h"

namespace tvm {
namespace relay {
namespace transform {

namespace {

/* =============== Phase 0 =============== */

/*!
 * \brief Rewrites "on_device" calls to handle some special cases.
 *
 * - Don't let the device for %x remain unconstrained:
 *   \code
 *   let %x = on_device(e, virtual_device=d)
 *   ==> let %x = on_device(e, virtual_device=d, constraint=kBoth)
 *   \endcode
 *
 * - Don't let the function result remain unconstrained:
 *   \code
 *   fn(%x) { on_device(e, virtual_device=d) }
 *   ==> fn(%x) { on_device(e, virtual_device=d, constraint=kBoth) }
 *   \endcode
 *
 * - Project-then-copy rather than copy-then-project:
 *   \code
 *   on_device(e).0
 *   ==> on_device(e.0)
 *   \endcode
 *
 * - Be prepared to copy arguments and results on primitive call boundaries in case memory
 *   scopes don't line up. We'll use the 'fully unconstrained' version of on_device so that
 *   we can allow for a device_copy without knowing the specific device for the arguments.
 *   \code
 *   call_lowered(@prim, (a, b))
 *   ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b))))
 *   where
 *     copy_ok(x) = on_device(x, virtual_device=VirtualDevice::FullyUnconstrained,
 *                            constrain_body=False, constrain_result=False)
 *   \endcode
 */
class RewriteOnDevices : public ExprMutator {
 public:
  explicit RewriteOnDevices(IRModule mod) : mod_(std::move(mod)) {}

 private:
  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
    Expr tuple = VisitExpr(tuple_get_item_node->tuple);
    OnDeviceProps props = GetOnDeviceProps(tuple);

    Expr tuple_get_item = WithFields(GetRef<TupleGetItem>(tuple_get_item_node), tuple);
    if (props.body.defined() && props.is_normal()) {
      VLOG(2) << "wrapping tuple get item:" << std::endl
              << PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
              << "with \"on_device\" for VirtualDevice " << props.virtual_device;
      return OnDeviceWithProps(tuple_get_item, props);
    } else {
      return tuple_get_item;
    }
  }

  Expr VisitExpr_(const LetNode* let_node) final {
    auto expr = GetRef<Expr>(let_node);
    std::vector<std::tuple<Let, Expr>> bindings;
    while (const auto* inner_let_node = expr.as<LetNode>()) {
      Let inner_let = GetRef<Let>(inner_let_node);
      Expr value = VisitExpr(inner_let_node->value);
      OnDeviceProps props = GetOnDeviceProps(value);
      if (props.body.defined() && props.is_normal()) {
        VLOG(2) << "revising let-bound expression of let:" << std::endl
                << PrettyPrint(expr) << std::endl
                << "to be fixed to VirtualDevice " << props.virtual_device;
        value = MaybeOnDeviceFixed(props.body, props.virtual_device);
      }
      bindings.emplace_back(inner_let, value);
      expr = inner_let_node->body;
    }
    expr = VisitExpr(expr);
    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
      expr = WithFields(/*let=*/std::get<0>(*itr), /*opt_var=*/{},
                        /*opt_value=*/std::get<1>(*itr), /*opt_body=*/expr);
    }
    return expr;
  }

  Expr VisitExpr_(const FunctionNode* function_node) final {
    Expr body = VisitExpr(function_node->body);
    OnDeviceProps props = GetOnDeviceProps(body);
    if (props.body.defined() && props.is_normal()) {
      VLOG(2) << "revising body of function:" << std::endl
              << PrettyPrint(GetRef<Function>(function_node)) << std::endl
              << "to be fixed to VirtualDevice " << props.virtual_device;
      body = MaybeOnDeviceFixed(props.body, props.virtual_device);
    }
    return WithFields(GetRef<Function>(function_node), function_node->params, body);
  }

  Expr VisitExpr_(const CallNode* call_node) final {
    CallLoweredProps props = GetCallLoweredProps(call_node);
    if (props.lowered_func.defined()) {
      BaseFunc base_func = mod_->Lookup(props.lowered_func);
      if (base_func.as<tir::PrimFuncNode>()) {
        VLOG(2) << "allowing device_copy on PrimFunc arguments and result";
        Array<Expr> new_args;
        new_args.reserve(props.arguments.size());
        for (const auto& arg : props.arguments) {
          Expr new_arg = VisitExpr(arg);
          new_args.push_back(OnDeviceCopyOk(std::move(new_arg)));
        }
        Call new_call = CallLowered(std::move(props.lowered_func), std::move(new_args),
                                    props.attrs, call_node->span);
        return OnDeviceCopyOk(std::move(new_call));
      }
    }
    return ExprMutator::VisitExpr_(call_node);
  }

  /*! \brief Module we are rewriting, so we can lookup global definitions. */
  IRModule mod_;
};

/* =============== Phase 1 =============== */

/*!
 * \brief Collects the system of device constraints for all sub-expressions in a module.
 * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
 *
 * E.g. from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device.
 * Later, from \code on_device(%x, virtual_device=d) \endcode we know \p %x must be on device
 * \p d, and thus so must \p %y.
 *
 * Constraints can flow in interesting ways. E.g. in:
 * \code
 *   let %f = fn(%x, %y) { add(%x, on_device(%y, virtual_device=d)) }
 *   let %g = fn(%f, %x, %y) { %f(%x, %y) }
 *   %g(%f, %a, %b)
 * \endcode
 * we discover \p %b must be on device \p d.
 */
class DeviceAnalyzer : public MixedModeVisitor {
 public:
  DeviceAnalyzer(IRModule mod, CompilationConfig config)
      : mod_(std::move(mod)), domains_(std::make_unique<DeviceDomains>(std::move(config))) {}

  /*!
   * \brief Returns the expression-to-device-domain map for all expressions in all the global
   * function definitions in the module. Expressions may have free domains; these will be
   * resolved by \p DeviceDefaulter below.
   */
  std::unique_ptr<DeviceDomains> Analyze() {
    VLOG_CONTEXT << "DeviceAnalyzer";
    for (const auto& kv : mod_->functions) {
      // The global variable and what it is bound to must obviously agree on domain.
      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
        VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'";
        domains_->UnifyExprExact(kv.first, kv.second);
        VisitExpr(GetRef<Function>(function_node));
      } else if (const auto* prim_func_node = kv.second.as<tir::PrimFuncNode>()) {
        VLOG(2) << "collecting constraints from TIR PrimFunc '" << kv.first->name_hint << "'";
        domains_->UnifyExprExact(
            kv.first, DomainForPrimFunc(kv.first, GetRef<tir::PrimFunc>(prim_func_node)));
      } else {
        VLOG(2) << "skipping '" << kv.first->name_hint << "'";
      }
    }
    return std::move(domains_);
  }

 private:
  /*!
   * \brief Return the domain representing \p prim_func which, before lowering, had
   * the Relay \p type.
   */
  DeviceDomainPtr DomainForPrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) {
    // CAUTION: The prim_func->checked_type() is currently w.r.t. the flattened and DPS form
    // of the prim func, however here we wish to remain within the Relay view of all functions.
    // Thus we'll use the global var, whose checked_type is in Relay form.
    auto func_domain = domains_->DomainFor(global_var);  // higher-order

    // TODO(mbs): We don't visit the body of the function -- there's currently nothing to be done.
    const auto* func_type_node = global_var->checked_type().as<FuncTypeNode>();
    ICHECK(func_type_node);
    ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size());

    Array<VirtualDevice> virtual_devices =
        tir::GetPrimFuncArgAndResultConstraints(prim_func, GetRef<FuncType>(func_type_node));

    // Build the domain (in terms of the function's Relay type) implied by any memory scope
    // constraints in the function's buffers, for both arguments and results.
    std::vector<DeviceDomainPtr> args_and_result_domains;
    args_and_result_domains.reserve(virtual_devices.size());
    for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) {
      const VirtualDevice& param_virtual_device = virtual_devices[i];
      VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device;
      args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_virtual_device));
    }
    const VirtualDevice& ret_virtual_device = virtual_devices.back();
    VLOG(2) << "ret_virtual_device = " << ret_virtual_device;
    args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_virtual_device));

    return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains));
  }

  void VisitExpr_(const CallNode* call_node) final {
    auto call = GetRef<Call>(call_node);

    // We don't care if the call is in pre- or post-lowered form.
    auto vanilla_call = GetAnyCall(call_node);

    // Find the higher-order domain for the callee. See DomainForCallee for the special rules
    // for primitives.
    VisitExpr(vanilla_call->op);
    auto func_domain = domains_->DomainForCallee(call);  // higher-order

    // Build the domain for the function implied by its arguments and call context.
    ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()) << PrettyPrint(call);
    std::vector<DeviceDomainPtr> args_and_result_domains;
    args_and_result_domains.reserve(vanilla_call->args.size() + 1);
    for (const auto& arg : vanilla_call->args) {
      args_and_result_domains.emplace_back(domains_->DomainFor(arg));
    }
    args_and_result_domains.emplace_back(domains_->DomainFor(call));
    auto implied_domain =
        domains_->MakeHigherOrderDomain(std::move(args_and_result_domains));  // higher-order

    VLOG(2) << "initial call function domain:" << std::endl
            << domains_->ToString(func_domain) << std::endl
            << "and implied domain:" << std::endl
            << domains_->ToString(implied_domain) << std::endl
            << "for call:" << std::endl
            << PrettyPrint(call);

    // The above must match.
    if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) {  // higher-order
      // TODO(mbs): Proper diagnostics.
      LOG(FATAL)
          << "Function parameters and result VirtualDevices do not match those of call. Call:"
          << std::endl
          << PrettyPrint(call) << std::endl
          << "with function virtual devices:" << std::endl
          << domains_->ToString(func_domain) << std::endl
          << "and implied call virtual devices:" << std::endl
          << domains_->ToString(implied_domain);
    }

    VLOG(2) << "final call function domain:" << std::endl
            << domains_->ToString(func_domain) << std::endl
            << "for call:" << std::endl
            << PrettyPrint(call);
  }

  void VisitExpr_(const LetNode* let_node) final {
    Expr expr = GetRef<Let>(let_node);
    // Iteratively visit let nodes to avoid stack overflow.
    while (expr->IsInstance<LetNode>()) {
      Let let = Downcast<Let>(expr);
      // Let var must be same device as value it is bound to.
      domains_->UnifyExprExact(let->var, let->value);  // may be higher-order
      // Let body must be same device as overall let.
      domains_->UnifyExprExact(let, let->body);  // may be higher-order

      VisitExpr(let->var);
      VisitExpr(let->value);

      expr = let->body;
    }

    // Visit the last body.
    VisitExpr(expr);
  }

  void VisitExpr_(const FunctionNode* function_node) final {
    auto function = GetRef<Function>(function_node);
    auto func_domain = domains_->DomainFor(function);  // higher-order
    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());

    VLOG(2) << "initial function domain:" << std::endl
            << domains_->ToString(func_domain) << std::endl
            << "and function body domain:" << std::endl
            << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl
            << "for function:" << std::endl
            << PrettyPrint(function);

    // The function body domain must match the function result domain.
    domains_->UnifyExprExact(function_node->body,
                             func_domain->function_result());  // may be higher-order
    if (!function_node->virtual_device()->IsFullyUnconstrained()) {
      // The function body domain must match any existing virtual device annotation.
      domains_->UnifyExprExact(function_node->body,
                               domains_->ForVirtualDevice(function_node->body->checked_type(),
                                                          function_node->virtual_device()));
    }

    for (size_t i = 0; i < function_node->params.size(); ++i) {
      const auto& param = function_node->params[i];
      // The parameter domain must match the function argument domain.
      domains_->UnifyExprExact(param,
                               func_domain->function_param(i));  // may be higher-order
      if (!param->virtual_device()->IsFullyUnconstrained()) {
        // The parameter domain must match any existing virtual device annotation.
        domains_->UnifyExprExact(
            param, domains_->ForVirtualDevice(param->checked_type(), param->virtual_device()));
      }
      VisitExpr(param);
    }

    // No need to step into the body of Primitive functions.
    if (!function_node->HasNonzeroAttr(attr::kPrimitive)) {
      VisitExpr(function_node->body);
    }

    VLOG(2) << "final function domain:" << std::endl
            << domains_->ToString(func_domain) << std::endl
            << "and function body domain:" << std::endl
            << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl
            << "for function:" << std::endl
            << PrettyPrint(function);
  }

  void VisitExpr_(const TupleNode* tuple_node) final {
    Tuple tuple = GetRef<Tuple>(tuple_node);
    for (size_t i = 0; i < tuple->fields.size(); i++) {
      auto domain = domains_->DomainFor(tuple->fields[i]);  // may be higher-order
      domains_->UnifyExprCollapsed(tuple, domain);          // collapse to first-order if needed
    }
  }

  void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
    TupleGetItem tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node);
    auto domain = domains_->DomainFor(tuple_get_item);  // may be higher-order
    domains_->UnifyExprCollapsed(tuple_get_item_node->tuple,
                                 domain);  // collapse to first-order if needed
  }

  class DevicePatternAnalyzer : public PatternVisitor {
   public:
    DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node)
        : domains_(domains), adt_node_(adt_node) {}

   private:
    void VisitPattern_(const PatternVarNode* pattern_var_node) final {
      auto var_domain = domains_->DomainFor(pattern_var_node->var);  // may be higher order
      domains_->UnifyExprCollapsed(GetRef<Expr>(adt_node_),
                                   var_domain);  // collapse to first-order if needed
    }

    /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */
    DeviceDomains* domains_;
    /*! \brief The expression for the ADT we are matching over. */
    const ExprNode* adt_node_;
  };

  void VisitPattern(const Pattern& pattern) final {}

  void VisitExpr_(const MatchNode* match_node) final {
    // For a match node we unify the value being matched and the rhs of each clause.
    Match match = GetRef<Match>(match_node);
    auto match_domain = domains_->DomainFor(match);  // may be higher-order
    DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get());
    domains_->UnifyExprCollapsed(match->data, match_domain);  // collapse to first-order if needed
    for (const auto& clause : match->clauses) {
      pattern_analyzer.VisitPattern(clause->lhs);
      domains_->UnifyExprExact(clause->rhs, match_domain);
      VisitExpr(clause->rhs);
    }
    VisitExpr(match_node->data);
  }

  void VisitExpr_(const GlobalVarNode* global_var_node) final {
    domains_->DomainFor(GetRef<GlobalVar>(global_var_node));
  }

  void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef<Var>(var_node)); }

  void VisitExpr_(const ConstantNode* constant_node) final {
    domains_->DomainFor(GetRef<Constant>(constant_node));
  }

  void VisitExpr_(const ConstructorNode* constructor_node) final {
    // no-op, constructors are handled at their call-sites.
    // TODO(mbs): Assumes eta-expansion
  }

  void VisitExpr_(const IfNode* if_node) final {
    auto ife = GetRef<If>(if_node);
    auto domain = domains_->DomainFor(ife);               // may be higher-order
    domains_->UnifyExprCollapsed(if_node->cond, domain);  // collapse to first-order if needed
    domains_->UnifyExprExact(if_node->true_branch, domain);
    domains_->UnifyExprExact(if_node->false_branch, domain);
    VisitExpr(if_node->cond);
    VisitExpr(if_node->true_branch);
    VisitExpr(if_node->false_branch);
  }

  void VisitExpr_(const OpNode* op) final {
    // no-op, primitive operators are handled at their call-sites.
  }

  void VisitExpr_(const RefCreateNode* ref_create_node) final {
    auto ref_create = GetRef<RefCreate>(ref_create_node);
    auto domain = domains_->DomainFor(ref_create_node->value);  // may be higher-order
    domains_->UnifyExprCollapsed(ref_create, domain);  // collapse to first-order if needed
    VisitExpr(ref_create_node->value);
  }

  void VisitExpr_(const RefReadNode* ref_read_node) final {
    auto ref_read = GetRef<RefRead>(ref_read_node);
    auto domain = domains_->DomainFor(ref_read);  // may be higher-order
    domains_->UnifyExprCollapsed(ref_read_node->ref, domain);  // collapse to first-order if needed
    VisitExpr(ref_read_node->ref);
  }

  void VisitExpr_(const RefWriteNode* ref_write_node) final {
    auto ref_write = GetRef<RefWrite>(ref_write_node);
    auto domain = domains_->DomainFor(ref_write->value);  // may be higher-order
    domains_->UnifyExprCollapsed(ref_write->ref, domain);  // collapse to first-order if needed
    domains_->UnifyExprCollapsed(ref_write, domain);       // collapse to first-order if needed
    VisitExpr(ref_write_node->ref);
    VisitExpr(ref_write_node->value);
  }

  /*! \brief The module we are analyzing. */
  IRModule mod_;
  /*! \brief The domains for all expressions processed so far. */
  std::unique_ptr<DeviceDomains> domains_;
};

/* =============== Phase 2 =============== */

/*!
 * \brief 'Free' "on_device" annotations (i.e. those where both constrain_body=false and
 * constrain_result=false) indicate a device_copy is allowed if required, but no particular
 * device is imposed on the body or the context. At this stage we can attempt to unify the
 * body and context devices. In this way we can prevent the defaulting rules in
 * \p DeviceDefaulter from choosing default devices which would only induce a device copy.
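 *
 * For example (illustrative), given the Phase 0 rewrite
 * \code
 *   copy_ok(call_lowered(@prim, (copy_ok(%a), copy_ok(%b))))
 * \endcode
 * if the call's context is already constrained to some device \p d but %a and %b are still
 * free, we unify %a and %b with \p d here, so no "device_copy" needs to be inserted later.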
 *
 * TODO(mbs): The order in which we encounter the "on_device" calls can influence the final
 * global device assignment. However we visit global functions in hash map order.
 */
class FreeOnDeviceDefaulter : public ExprVisitor {
 public:
  FreeOnDeviceDefaulter(IRModule mod, std::unique_ptr<DeviceDomains> domains)
      : mod_(std::move(mod)), domains_(std::move(domains)) {}

  std::unique_ptr<DeviceDomains> Default() {
    VLOG_CONTEXT << "FreeOnDeviceDefaulter";
    VLOG(0) << "unifying free on_device annotations";
    for (const auto& kv : mod_->functions) {
      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
        VLOG(2) << "unifying for '" << kv.first->name_hint << "'";
        VisitExpr(GetRef<Function>(function_node));
      } else {
        VLOG(2) << "skipping '" << kv.first->name_hint << "'";
      }
    }
    return std::move(domains_);
  }

 private:
  void VisitExpr_(const CallNode* call_node) final {
    auto call = GetRef<Call>(call_node);
    OnDeviceProps props = GetOnDeviceProps(call_node);
    ExprVisitor::VisitExpr_(call_node);
    if (props.body.defined() && !props.constrain_body && !props.constrain_result) {
      domains_->OptionalUnifyExprExact(call, props.body);
    }
  }

  /*! \brief The module we are processing. */
  IRModule mod_;
  /*! \brief The domains for all expressions. */
  std::unique_ptr<DeviceDomains> domains_;
};

/*!
 * \brief Ensures every sub-expression in a module has a device type, using both the global
 * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes.
 *
 * E.g. in:
 * \code
 *   def @main(%x, %y, %z) {
 *     let %a = add(%x, %y);
 *     multiply(%a, on_device(%z, virtual_device=d))
 *   }
 * \endcode
 * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y,
 * and the device for the function result, are still 'free'. The global 'default' device type
 * is first used to 'fix' \p main's result device, which in turn 'fixes' \p %x and \p %y, which
 * in turn 'fixes' the device on which the \p add and \p multiply are executed.
 *
 * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hash map
 * order.
 */
class DeviceDefaulter : public ExprVisitor {
 public:
  DeviceDefaulter(IRModule mod, std::unique_ptr<DeviceDomains> domains)
      : mod_(std::move(mod)), domains_(std::move(domains)) {}

  std::unique_ptr<DeviceDomains> Default() {
    VLOG_CONTEXT << "DeviceDefaulter";
    VLOG(0) << "defaulting to VirtualDevice "
            << domains_->config()->default_primitive_virtual_device;
    for (const auto& kv : mod_->functions) {
      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
        VLOG(2) << "defaulting devices for '" << kv.first->name_hint << "'";
        VisitExpr(GetRef<Function>(function_node));
      } else {
        VLOG(2) << "skipping '" << kv.first->name_hint << "'";
      }
    }
    return std::move(domains_);
  }

 private:
  void VisitExpr_(const FunctionNode* function_node) final {
    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
      return;
    }

    auto function = GetRef<Function>(function_node);
    auto func_domain = domains_->DomainFor(function);  // higher-order
    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
    if (!domains_->IsFullyConstrained(func_domain)) {
      VLOG(2) << "before defaulting function:" << std::endl << domains_->ToString(func_domain);
      domains_->SetResultDefaultThenParams(func_domain,
                                           domains_->config()->default_primitive_virtual_device);
      VLOG(2) << "after defaulting function:" << std::endl << domains_->ToString(func_domain);
    }
    VisitExpr(function_node->body);
  }

  void VisitExpr_(const CallNode* call_node) final {
    auto call = GetRef<Call>(call_node);

    // We don't care if the call is pre- or post-lowered.
    auto vanilla_call = GetAnyCall(call_node);

    auto func_domain = domains_->DomainForCallee(call);  // higher-order
    ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size());
    if (!domains_->IsFullyConstrained(func_domain)) {
      // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*)
      // above. But for calls to primitives we may still need to force free domains to be
      // defaulted.
      VLOG(2) << "before defaulting callee:" << std::endl
              << PrettyPrint(call_node->op) << std::endl
              << "of domain:" << std::endl
              << domains_->ToString(func_domain);
      domains_->SetResultDefaultThenParams(func_domain,
                                           domains_->config()->default_primitive_virtual_device);
      VLOG(2) << "after defaulting callee:" << std::endl
              << PrettyPrint(call_node->op) << std::endl
              << "of domain:" << std::endl
              << domains_->ToString(func_domain);
    }
    return ExprVisitor::VisitExpr_(call_node);
  }

  void VisitExpr_(const LetNode* let_node) final {
    Expr expr = GetRef<Let>(let_node);
    // Iteratively visit let nodes to avoid stack overflow.
    while (expr->IsInstance<LetNode>()) {
      Let let = Downcast<Let>(expr);
      // If the let-var device is still free force it to match the overall let.
      auto let_domain = domains_->DomainFor(let);  // may be higher-order
      VirtualDevice let_virtual_device = domains_->ResultVirtualDevice(let_domain);
      ICHECK(!let_virtual_device->IsFullyUnconstrained());
      auto let_var_domain = domains_->DomainFor(let->var);  // may be higher-order
      if (!domains_->IsFullyConstrained(let_var_domain)) {
        VLOG(2) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
        domains_->SetDefault(let_var_domain, let_virtual_device);
        VLOG(2) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain);
      }
      VisitExpr(let->var);
      VisitExpr(let->value);
      expr = let->body;
    }
    VisitExpr(expr);
  }

  /*! \brief The module we are processing. */
  IRModule mod_;
  /*! \brief The domains for all expressions. */
  std::unique_ptr<DeviceDomains> domains_;
};

/* =============== Phase 3 =============== */

/*!
 * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every
 * sub-expression in a module can be easily recovered by a later transformation using simple
 * lexical scoping rules (e.g. for memory planning).
 *
 * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard
 *   any existing "device_copy" CallNodes which are no-ops.
 *
 * - The result virtual device for a function is stored in the function's virtual_device_ field
 *   and the virtual devices of the function's parameters are stored in the parameters'
 *   virtual_device_ fields.
 *
 * - Additional "device_copy" CallNodes are inserted wherever there's a transition between
 *   storage device types. Since the DeviceAnalyzer phase succeeded this can only happen
 *   where the original program explicitly allowed a transition using an "on_device" CallNode.
 *   That is, we do not try to 'fix' a program with inconsistent devices.
 *
 * - Additional "on_device" CallNodes are inserted so that a later transform can discover
 *   the device for an arbitrary sub-expression by looking only for the lexically enclosing
 *   "on_device" CallNode or "on_device" function attribute. In particular, since function
 *   arguments and let-bound expressions can be on a device different from the function
 *   or let body itself we will insert "on_device" CallNodes to spell out any differences. This
 *   applies even to the argument to a "device_copy" CallNode, which may look pedantic but
 *   keeps downstream processing simple. The "on_device" calls should be removed before code
 *   gen, which is easily done on-the-fly.
 *
 * - Update memory scopes in PrimFunc buffer maps.
 *
 * For example, we'll end up with programs that look like:
 * \code
 *   def @main(%x, %y, param_virtual_devices=[...], result_virtual_device=...) {
 *     let %a = on_device(..., virtual_device=..., is_fixed=True)
 *     @f(%a, device_copy(on_device(..., virtual_device=..., is_fixed=True),
 *                        src_virtual_device=..., dst_virtual_device=...))
 *   }
 * \endcode
 */
class DeviceCapturer : public ExprMutator {
 public:
  DeviceCapturer(IRModule mod, std::unique_ptr<DeviceDomains> domains)
      : mod_(std::move(mod)), domains_(std::move(domains)) {}

  IRModule Capture() {
    VLOG_CONTEXT << "CaptureDevices";
    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map,
                    mod_->attrs);
    for (const auto& kv : mod_->functions) {
      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
        VLOG(2) << "capturing devices for Relay Function '" << kv.first->name_hint << "'";
        result->Add(kv.first, Downcast<Function>(Mutate(GetRef<Function>(function_node))));
      } else if (const auto* prim_func_node = kv.second.as<tir::PrimFuncNode>()) {
        VLOG(2) << "capturing devices for TIR PrimFunc '" << kv.first->name_hint << "'";
        auto prim_func = GetRef<tir::PrimFunc>(prim_func_node);
        tir::PrimFunc new_prim_func = UpdatePrimFunc(kv.first, prim_func);
        VLOG(2) << "Rewritten prim func:" << std::endl
                << PrettyPrint(prim_func) << std::endl
                << "to:" << std::endl
                << PrettyPrint(new_prim_func);
        result->Add(kv.first, std::move(new_prim_func));
      } else {
        VLOG(2) << "skipping '" << kv.first->name_hint << "'";
        result->Add(kv.first, kv.second);
      }
    }
    return result;
  }

 private:
  /*!
   * \brief Returns \p prim_func updated to capture any memory scopes implied by its device
   * domain.
   */
  tir::PrimFunc UpdatePrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) {
    // CAUTION: Same caution as for DeviceAnalyzer::DomainForPrimFunc.
    auto func_domain = domains_->DomainFor(global_var);
    ICHECK(func_domain->is_higher_order());

    const auto* func_type_node = global_var->checked_type().as<FuncTypeNode>();
    ICHECK(func_type_node);
    ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size());

    std::vector<VirtualDevice> arg_and_result_virtual_devices;
    arg_and_result_virtual_devices.reserve(func_type_node->arg_types.size() + 1);
    for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) {
      VirtualDevice param_virtual_device =
          domains_->ResultVirtualDevice(func_domain->function_param(i));
      VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device;
      arg_and_result_virtual_devices.push_back(param_virtual_device);
    }
    VirtualDevice ret_virtual_device =
        domains_->ResultVirtualDevice(func_domain->function_result());
    VLOG(2) << "ret_virtual_device = " << ret_virtual_device;
    arg_and_result_virtual_devices.push_back(ret_virtual_device);

    return tir::ApplyPrimFuncArgAndResultConstraints(prim_func, GetRef<FuncType>(func_type_node),
                                                     arg_and_result_virtual_devices);
  }

  // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode.

  Expr VisitExpr_(const TupleNode* tuple_node) final {
    auto tuple = GetRef<Tuple>(tuple_node);
    Array<Expr> fields;
    fields.reserve(tuple_node->fields.size());
    for (const auto& field : tuple_node->fields) {
      fields.push_back(VisitChild(tuple, field));
    }
    return WithFields(tuple, fields);
  }

  Expr VisitExpr_(const FunctionNode* function_node) final {
    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
      return GetRef<Function>(function_node);
    }

    auto function = GetRef<Function>(function_node);
    auto func_domain = domains_->DomainFor(function);  // higher-order
    VLOG(2) << "capturing function:" << std::endl
            << PrettyPrint(function) << std::endl
            << "with domain:" << std::endl
            << domains_->ToString(func_domain);

    // Gather the parameter and result device types for the function attributes.
    ICHECK_EQ(func_domain->function_arity(), function_node->params.size());
    VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain);
    ICHECK(!result_virtual_device->IsFullyUnconstrained());

    // Map each function parameter to a new variable annotated with its virtual device so
    // we can substitute them later.
    Map<Var, Expr> annotated_bind_map;
    Array<Var> annotated_params;
    annotated_params.reserve(function_node->params.size());
    for (size_t i = 0; i < function_node->params.size(); ++i) {
      VirtualDevice param_virtual_device =
          domains_->ResultVirtualDevice(func_domain->function_param(i));
      VLOG(4) << "Param: " << function_node->params[i];
      Var annotated_var = WithFields(function_node->params[i], {}, {}, param_virtual_device);
      VLOG(4) << "Annotated param: " << annotated_var;
      VLOG(4) << "VirtualDevice: " << annotated_var->virtual_device();
      ICHECK(!param_virtual_device->IsFullyUnconstrained());
      annotated_bind_map.Set(function_node->params[i], annotated_var);
      annotated_params.push_back(annotated_var);
    }
    // Eventually we probably want to bind before visiting, but for now this is causing an issue
    // with the GetVirtualDevice utility, so leaving as is for now.

    // Rewrite the body. Note that the body may have begun with an "on_device" so
    // be prepared to insert a "device_copy".
    Expr body = VisitChild(
        /*lexical_virtual_device=*/result_virtual_device,
        /*expected_virtual_device=*/result_virtual_device,
        /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body);
    VLOG(4) << "Visited body: " << body;
    Function func = WithFields(GetRef<Function>(function_node), function_node->params, body);
    VLOG(4) << "New function: " << func;
    func = SubstituteBoundVars(func, annotated_bind_map);
    VLOG(4) << "Func with bound params: " << func;
    func->virtual_device_ = result_virtual_device;
    VLOG(4) << "Func with bound params & result vid set: " << func;
    return std::move(func);
  }

  Expr VisitExpr_(const CallNode* call_node) final {
    auto call = GetRef<Call>(call_node);

    // We don't care if the call is pre- or post-lowered.
    // (However we'll preserve the form in the result below.)
    auto vanilla_call = GetAnyCall(call_node);

    VirtualDevice call_virtual_device = GetVirtualDevice(call);

    auto on_device_props = GetOnDeviceProps(call_node);
    if (on_device_props.body.defined()) {
      // We're done with the original "on_device" calls and can pinch them out.
      // Note that this step has already been simulated by GetVirtualDevice.
      return VisitExpr(on_device_props.body);
    }

    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    if (device_copy_props.body.defined()) {
      VirtualDevice src_virtual_device =
          domains_->config()->CanonicalVirtualDevice(device_copy_props.src_virtual_device);
      VirtualDevice dst_virtual_device =
          domains_->config()->CanonicalVirtualDevice(device_copy_props.dst_virtual_device);
      ICHECK_EQ(call_virtual_device, dst_virtual_device);
      if (src_virtual_device == dst_virtual_device) {
        // We can pinch out existing "device_copy" CallNodes if their source and destination
        // devices match.
        return VisitExpr(device_copy_props.body);
      } else {
        return VisitChild(/*lexical_virtual_device=*/dst_virtual_device,
                          /*expected_virtual_device=*/dst_virtual_device,
                          /*child_virtual_device=*/src_virtual_device, device_copy_props.body);
      }
    }

    // Generic call.
    auto func_domain = domains_->DomainForCallee(call);  // higher-order
    VLOG(2) << "considering call:" << std::endl
            << PrettyPrint(call) << std::endl
            << "in virtual device " << call_virtual_device
            << " with function virtual devices:" << std::endl
            << domains_->ToString(func_domain);
    VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain);
    ICHECK(!result_virtual_device->IsFullyUnconstrained());

    // The callee is on the current device.
    Expr op = VisitChild(
        /*lexical_virtual_device=*/call_virtual_device,
        /*expected_virtual_device=*/call_virtual_device,
        /*child_virtual_device=*/result_virtual_device, vanilla_call->op);

    // Each argument can be on the device for the corresponding function parameter. However if
    // any of those differ from the overall call device then wrap them in an "on_device" to
    // help downstream transforms track devices lexically.
    Array<Expr> args;
    args.reserve(vanilla_call->args.size());
    ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size());
    for (size_t i = 0; i < vanilla_call->args.size(); ++i) {
      VirtualDevice param_virtual_device =
          domains_->ResultVirtualDevice(func_domain->function_param(i));
      ICHECK(!param_virtual_device->IsFullyUnconstrained())
          << "for parameter " << i << " for call:" << std::endl
          << PrettyPrint(call);
      args.push_back(VisitChild(/*lexical_virtual_device=*/call_virtual_device,
                                /*expected_virtual_device=*/param_virtual_device,
                                /*child_virtual_device=*/GetVirtualDevice(vanilla_call->args[i]),
                                vanilla_call->args[i]));
    }

    if (call_node->op == CallLoweredOp()) {
      Call new_call =
          CallLowered(Downcast<GlobalVar>(op), args, /*call_lowered_attrs=*/{}, /*span=*/{});
      return WithFields(call, new_call->op, new_call->args);
    } else {
      return WithFields(call, op, args);
    }
  }

  Expr VisitExpr_(const LetNode* let_node) final {
    Expr expr = GetRef<Expr>(let_node);
    // Iterate through chained lets, provided they all agree on their device type.
    VirtualDevice let_virtual_device = GetVirtualDevice(expr);
    std::vector<std::tuple<Var, Expr, Span>> bindings;
    while (const auto* inner_let_node = expr.as<LetNode>()) {
      Expr inner_let = GetRef<Let>(inner_let_node);
      if (GetVirtualDevice(inner_let) != let_virtual_device) {
        // We have a device transition which needs to be handled.
        break;
      }
      // The let-bound value can be on a different device than the overall let.
      // By using the fully-unconstrained virtual device for the 'lexical' scope we'll force the
      // let-bound value to *always* be wrapped by an "on_device" (see introductory comment for
      // motivation.)
      Expr value = VisitChild(/*lexical_virtual_device=*/VirtualDevice::FullyUnconstrained(),
                              /*expected_virtual_device=*/GetVirtualDevice(inner_let_node->var),
                              /*child_virtual_device=*/GetVirtualDevice(inner_let_node->value),
                              inner_let_node->value);
      bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
      expr = inner_let_node->body;
    }
    Expr body = VisitChild(/*lexical_virtual_device=*/let_virtual_device,
                           /*expected_virtual_device=*/let_virtual_device,
                           /*child_virtual_device=*/GetVirtualDevice(expr), expr);
    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
      body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body,
                 /*span=*/std::get<2>(*itr));
    }
    return body;
  }

  Expr VisitExpr_(const IfNode* if_node) final {
    auto ife = GetRef<If>(if_node);
    Expr cond = VisitChild(ife, if_node->cond);
    Expr true_branch = VisitChild(ife, if_node->true_branch);
    Expr false_branch = VisitChild(ife, if_node->false_branch);
    return WithFields(ife, cond, true_branch, false_branch);
  }

  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
    auto tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node);
    Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple);
    return WithFields(tuple_get_item, tuple);
  }

  Expr VisitExpr_(const RefCreateNode* ref_create_node) final {
    auto ref_create = GetRef<RefCreate>(ref_create_node);
    Expr value = VisitChild(ref_create, ref_create_node->value);
    return WithFields(ref_create, value);
  }

  Expr VisitExpr_(const RefReadNode* ref_read_node) final {
    auto ref_read = GetRef<RefRead>(ref_read_node);
    Expr ref = VisitChild(ref_read, ref_read_node->ref);
    return WithFields(ref_read, ref);
  }

  Expr VisitExpr_(const RefWriteNode* ref_write_node) final {
    auto ref_write = GetRef<RefWrite>(ref_write_node);
    Expr ref = VisitChild(ref_write, ref_write_node->ref);
    Expr value = VisitChild(ref_write, ref_write_node->value);
    return WithFields(ref_write, ref, value);
  }

  Expr VisitExpr_(const MatchNode* match_node) final {
    auto match = GetRef<Match>(match_node);
    Expr data = VisitChild(match, match_node->data);
    Array<Clause> clauses;
    clauses.reserve(match_node->clauses.size());
    for (const auto& clause : match_node->clauses) {
      Pattern lhs = VisitPattern(clause->lhs);  // actually a no-op, so we're not checking vars
      Expr rhs = VisitChild(match, clause->rhs);
      clauses.push_back(Clause(lhs, rhs));
    }
    return WithFields(match, data, clauses);
  }

  VirtualDevice GetVirtualDevice(const Expr& expr) {
    // Look through any "on_device" CallNodes, to mimic how we will be pinching them out.
    OnDeviceProps props = GetOnDeviceProps(expr);
    Expr true_expr = props.body.defined() ? props.body : expr;
    ICHECK(domains_->contains(true_expr));
    // If expr is higher order we'll return only the result domain's device.
    VirtualDevice virtual_device = domains_->ResultVirtualDevice(domains_->DomainFor(true_expr));
    ICHECK(!virtual_device->IsFullyUnconstrained())
        << "no VirtualDevice was determined for expression:" << std::endl
        << PrettyPrint(true_expr);
    return std::move(virtual_device);
  }

  /*!
   * \brief Reconciles the \p child_virtual_device for \p child with both the
   * \p expected_virtual_device (as required by the expression context the \p child is in) and
   * the \p lexical_virtual_device (as a downstream transform would infer based only on lexically
   * enclosing "on_device" CallNodes and function attributes). Generally
   * \p lexical_virtual_device and \p expected_virtual_device are the same by definition, but may
   * differ in arguments to functions and let-bound expressions.
   *
   * If \p child_virtual_device differs from \p expected_virtual_device, wrap it as:
   * \code
   *   device_copy(on_device(child', virtual_device=child_virtual_device),
   *               src_dev_type=child_virtual_device, dst_dev_type=expected_virtual_device)
   * \endcode
   * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the
   * child.
   *
   * If \p expected_virtual_device differs from \p lexical_virtual_device, then (also) wrap
   * the expression as:
   * \code
   *   on_device(..., virtual_device=expected_virtual_device)
   * \endcode
   *
   * TODO(mbs): There's no attempt at sharing here. Each usage of a child node may end up
   * wrapped by its own "device_copy", even though those copies will generally all be to the
   * same destination device.
   */
  Expr VisitChild(const VirtualDevice& lexical_virtual_device,
                  const VirtualDevice& expected_virtual_device,
                  const VirtualDevice& child_virtual_device, const Expr& child) {
    ICHECK(!expected_virtual_device->IsFullyUnconstrained());
    if (child->IsInstance<OpNode>() || child->IsInstance<ConstructorNode>()) {
      // Primitive operators and constructors don't need to be rewritten and can have a
      // different domain at each call site.
      return child;
    }
    Expr result = VisitExpr(child);
    if (child_virtual_device != expected_virtual_device) {
      VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device "
              << child_virtual_device << " to virtual device " << expected_virtual_device
              << " for:" << std::endl
              << PrettyPrint(result);
      // Also wrap the child in an "on_device" so downstream transforms can track devices
      // lexically.
      result = MaybeOnDeviceFixed(result, child_virtual_device);
      result = DeviceCopy(result, child_virtual_device, expected_virtual_device);
    }
    if (expected_virtual_device != lexical_virtual_device) {
      VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device "
              << expected_virtual_device << " for:" << std::endl
              << PrettyPrint(result);
      result = MaybeOnDeviceFixed(result, expected_virtual_device);
    }
    return result;
  }

  /*!
   * \brief Common case of visiting a direct \p child of \p parent where by default the \p child
   * is expected to be on the same device as the \p parent.
   */
  Expr VisitChild(const Expr& parent, const Expr& child) {
    VirtualDevice expected_virtual_device = GetVirtualDevice(parent);
    VirtualDevice child_virtual_device = GetVirtualDevice(child);
    return VisitChild(expected_virtual_device, expected_virtual_device, child_virtual_device,
                      child);
  }

  /*! \brief Module we are rewriting, so we can lookup global variables. */
  IRModule mod_;
  /*! \brief Device domain for every expression from DeviceAnalyzer. */
  std::unique_ptr<DeviceDomains> domains_;
};

/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */
tvm::transform::Pass Rewrite() {
  auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) {
    auto attrs = m->attrs;
    auto r = Downcast<Function>(RewriteOnDevices(std::move(m)).Mutate(f));
    return attrs.defined() ? WithAttrs(r, {attrs->dict}) : r;
  };
  return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {});
}

/*! \brief Run the remaining phases. */
tvm::transform::Pass PlanDevicesCore(CompilationConfig config) {
  return tvm::transform::CreateModulePass(
      [config = std::move(config)](IRModule mod,
                                   tvm::transform::PassContext pass_cnxt) -> IRModule {
        // Collect the system of constraints for every sub-expression using existing "on_device"
        // and "device_copy" calls.
        std::unique_ptr<DeviceDomains> domains = DeviceAnalyzer(mod, config).Analyze();
        VLOG(3) << "Domains after analysis:" << std::endl << domains->ToString();

        // Choose sensible default devices for every sub-expression if otherwise unconstrained
        // by existing "on_device" or "device_copy" calls.
        domains = FreeOnDeviceDefaulter(mod, std::move(domains)).Default();
        domains = DeviceDefaulter(mod, std::move(domains)).Default();
        VLOG(3) << "Domains after defaulting: " << std::endl << domains->ToString();

        // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture
        // the above map, and attach additional "param_virtual_devices" and
        // "result_virtual_device" attributes to all function definitions.
        return DeviceCapturer(mod, std::move(domains)).Capture();
      },
      /*opt_level=*/0, "PlanDevicesCore", {});
}

}  // namespace

/* =============== Driver =============== */

// This function is declared in the public <tvm/relay/transform.h>.
tvm::transform::Pass PlanDevices(CompilationConfig config) {
  std::vector<Pass> passes;
  passes.emplace_back(Rewrite());
  passes.emplace_back(PlanDevicesCore(std::move(config)));
  return tvm::transform::Sequential(passes, "PlanDevices");
}

TVM_REGISTER_GLOBAL("relay._transform.PlanDevices").set_body_typed(PlanDevices);
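
// Example usage (an illustrative sketch only; how the CompilationConfig is built depends on
// the client, e.g. on the targets established in the enclosing PassContext):
// \code
//   IRModule mod = ...;
//   CompilationConfig config = ...;  // describes available targets and the default device
//   mod = PlanDevices(config)(mod);  // a Pass is applied by calling it on an IRModule
// \endcode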

}  // namespace transform
}  // namespace relay
}  // namespace tvm