1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/transforms/device_planner.cc |
22 | * \brief Determines a unique \p VirtualDevice to hold the result of every Relay sub-expression. |
23 | * This pass can be run multiple times, and can be run both before and after lowering. |
24 | * |
25 | * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. |
 * We represent D by a \p VirtualDevice, which means we can track anywhere from an arbitrary
 * device of some \p DLDeviceType to a specific memory scope on a specific (virtual) \p Device
 * whose code is compiled with a specific \p Target.
29 | * |
30 | * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', |
31 | * see below. |
32 | * |
33 | * This pass works by collecting and solving device constraints, using defaulting heuristics to |
34 | * resolve any remaining undetermined devices, and encoding the results on the output in a form |
35 | * that's reasonably friendly to downstream passes. |
36 | * |
 * Specific \p VirtualDevices flow into the constraints from five places (the first two are
 * sketched just after this list):
38 | * - Existing "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a |
39 | * 'src_virtual_device' and 'dst_virtual_device' \p VirtualDevice. Those constrain the argument |
 *   and context of the call respectively. It is ok if the source and destination devices are the
 *   same; such no-op copies will be removed after accounting for the device preference.
 * - Existing "on_device" CallNodes (with an \p OnDeviceAttrs attribute) specify a
 *   'virtual_device', which constrains the argument of the call, but (usually, see below) leaves
 *   the context unconstrained. These are called 'annotations' in the rest of the code; they have
 *   no operational significance by themselves, but may trigger the insertion of a new
 *   "device_copy" call by this pass. In two situations the result of an "on_device" CallNode may
 *   also be constrained to the given 'virtual_device':
48 | * - The "on_device" call occurs at the top-level of a function body, or occurs as an |
49 | * immediately let-bound expression. In this situation the extra degree of freedom in |
50 | * the function result and let-binding leads to surprising device copies, so we simply |
51 | * force the function result or let-bound variable to the given device. |
52 | * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted |
53 | * it ourselves during an earlier invocation of this pass. This helps make this pass |
54 | * idempotent. |
 * - Some special operators require their arguments or results to be on the 'host' (typically
 *   a CPU) \p VirtualDevice; see below.
57 | * - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their |
58 | * argument buffers to have a specific memory scope, which is part of \p VirtualDevice. |
59 | * - Annotations left over from a previous run of this pass, such as 'param_virtual_devices' and |
60 | * 'result_virtual_device' function attributes we introduce below. This is so the pass is |
61 | * idempotent and can be re-run to flow additional memory scope constraints. |
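 * For example, the first two sources give rise to constraints like the following (an
 * illustrative sketch, not taken from a real program):
 * \code
 *   device_copy(%x, src_virtual_device=CPU, dst_virtual_device=GPU)
 *     // constrains %x to the CPU and the call's context to the GPU
 *   on_device(%y, virtual_device=GPU)
 *     // constrains %y to the GPU but (usually) leaves the call's context unconstrained
 * \endcode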
62 | * |
63 | * We proceed in four phases: |
64 | * |
65 | * Phase 0 |
66 | * ------- |
67 | * We rewrite the programs to handle some special cases: |
68 | * - "on_device" calls at the top-level of function or immediately let-bound are rewritten |
69 | * to have \code is_fixed=true \endcode. |
70 | * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written |
 *   \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
72 | * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. |
73 | * - We are prepared to insert device_copies on the arguments and result of calls to PrimFuncs, |
 *   on the assumption that a) we already ran PlanDevices before lowering so we are not allowing
 *   any new cross-device copies, but b) after lowering we may have new memory scope constraints
76 | * to deal with. |
77 | * |
78 | * Phase 1 |
79 | * ------- |
80 | * We flow constraints from the "on_device" and "device_copy" calls, PrimFunc buffer memory scopes, |
81 | * and some special ops, to all other Relay sub-expressions. |
82 | * |
83 | * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the |
84 | * same device. However each call site can use a different device. In other words primitives are |
85 | * 'device polymorphic' since we compile and execute them for each required device. ADT constructors |
86 | * are similarly polymorphic, but require all constructor args to be on the same device. |
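 * For example (an illustrative sketch), an annotation on one argument fixes the device for the
 * whole call, yet two call sites may still use two different devices:
 * \code
 *   add(%a, on_device(%b, virtual_device=CPU))   // %a and this call's result are also on the CPU
 *   ...
 *   add(%c, on_device(%d, virtual_device=GPU))   // %c and this call's result are also on the GPU
 * \endcode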
87 | * |
88 | * For most Relay expressions the device for the overall expression is the same as the device |
89 | * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple |
90 | * itself, the condition and arms of an \p if must all be on the same device as the overall \p if, |
91 | * and so on. |
92 | * |
93 | * Some special ops (or 'dialects') are handled: |
94 | * - Relay supports computing the shape of tensors and operators at runtime using "shape_of" |
95 | * and "reshape_tensor". Shapes must only be held on the CPU, but the tensors they describe |
96 | * may reside on any device. |
 * - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" operators.
 *   Again shapes reside on the CPU, but the allocated tensors may reside on any device, as
 *   sketched below.
99 | * |
 * Two Relay expressions have special handling, sketched after this list:
101 | * - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the |
102 | * overall let. However the result of \p e1 may be on a different device. |
103 | * - For a function \code fn(x, y) { body } \endcode the result of the function must be on the |
 *   same device as \p body. However parameters \p x and \p y may be on different devices, even
105 | * different from each other. Every call to the function must use the same choice of parameter |
106 | * and result devices -- there is no 'device polymorphism' for Relay functions. |
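 * For example (an illustrative sketch):
 * \code
 *   let %x = on_device(add(%a, %b), virtual_device=GPU);
 *   multiply(device_copy(%x, src_virtual_device=GPU, dst_virtual_device=CPU), %c)
 * \endcode
 * Here the let-bound value is on the GPU while the let body, and hence the overall let, is on
 * the CPU.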
107 | * |
108 | * Currently \p PrimFuncs and external functions do not carry over their parameter and result |
109 | * devices from their original Relay Function representations. However we know all calls to those |
110 | * functions are device-consistent, thus no information is lost. |
111 | * |
112 | * Phase 2 |
113 | * ------- |
114 | * After flowing constraints we apply some defaulting heuristics (using a global default \p |
115 | * VirtualDevice) to fix the device for any as-yet unconstrained sub-expressions. |
116 | * - Unconstrained function result devices default to the global default device. |
 *  - Unconstrained function parameter devices default to the device for the function result.
118 | * - Unconstrained let-bound expression devices default to the device for the overall let. |
 * TODO(mbs): These are very simple-minded heuristics, and ultimately we'd like to treat the
 * assignment of the remaining unconstrained sub-expressions as an optimization problem in itself.
121 | * This requires a formal notion of 'choicepoint' inside the compiler which can integrate with |
122 | * automation. |
123 | * |
124 | * Phase 3 |
125 | * ------- |
126 | * Finally, the result of this analysis is reified into the result as: |
127 | * - Additional "param_virtual_devices" (an \p Array<VirtualDevice>) and "result_virtual_device" |
 *   (a \p VirtualDevice) attributes for every function (both top-level and local). These describe
129 | * the devices for the function's parameters and the result. |
130 | * - Additional "device_copy" CallNodes where a copy is required in order to respect the |
131 | * intent of the original "on_device" CallNodes. |
132 | * - Additional "on_device" CallNodes where the device type of an expression is not trivially |
133 | * implied by the lexically enclosing "on_device" CallNode or function attribute. In practice |
134 | * this means "on_device" CallNodes may appear in two places: |
135 | * - On let-bound expressions. It is tempting to elide the "on_device" if the let-bound value |
136 | * has the same device as the overall let expression. However this would mean passes which |
 *     inline let-bound values, such as FoldConstant and DeadCodeElimination, would need to use
138 | * a DeviceAware visitor which in turn requires the expression to be in ANF to avoid |
139 | * deep recursion. To minimize disruption we always include the "on_device" so that it |
140 | * can follow the inline. |
141 | * - On a call argument if its device differs from the call result. In particular, the |
142 | * argument to a "device_copy" call will always be wrapped in an "on_device". (That may |
143 | * seem pedantic but simplifies downstream handling.) |
144 | * However since we make it easy to track devices for variables we never wrap an "on_device" |
145 | * around a var or global var. These uses of "on_device" imply both the argument and result are |
146 | * on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, |
147 | * which helps make this pass idempotent. |
148 | * - The buffer maps for called PrimFuncs are updated to capture memory scopes. |
149 | * |
150 | * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover |
151 | * the device for any expression for their own use, e.g. during memory planning. All downstream |
152 | * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion |
153 | * to ANF must respect the lexical scoping convention: |
154 | * \code |
155 | * f(on_device(g(h(a, b), c), virtual_device=CPU)) |
156 | * ==> |
157 | * let %x0 = on_device(h(a, b), virtual_device=CPU) |
158 | * let %x1 = on_device(g(%x0), virtual_device=CPU) |
159 | * f(on_device(%x1, virtual_device=CPU)) |
160 | * \endcode |
161 | * |
162 | * This pass can be run before FuseOps so that it can use device-specific fusion rules. |
163 | * |
164 | * 'Stored on' vs 'Executes on' |
165 | * ---------------------------- |
166 | * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the |
167 | * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for |
168 | * primitives. |
169 | * |
170 | * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are |
171 | * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific |
 * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
173 | * know exactly which device (possibly one of a number of available 'CPU'-like devices) is |
174 | * responsible for execution. Currently that's handled independently by the \p AnnotateTargets |
175 | * pass, but we'd like to fold that into device planning here to ensure everything is consistent. |
176 | * |
177 | * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay |
178 | * expression (eg an \p if expression) on one device even though the tensor data resides on |
179 | * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on' |
180 | * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just |
181 | * compile the function body for the function's result device. |
182 | * |
183 | * This works after conversion to ANF provided the compilation for a let expression is prepared |
184 | * to make a cross-device call. However we leave it to a downstream transformation to heuristically |
185 | * minimize cross-device calls by moving device copies out of functions. E.g.: |
186 | * \code |
187 | * def @f() { // execute on CPU |
188 | * let x = on_device(...GPU computation..., virtual_device=GPU); |
189 | * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) |
190 | * } |
191 | * def @main() { |
192 | * ... call @f() on CPU ... |
193 | * } |
194 | * \endcode |
195 | * could be rewritten to: |
196 | * \code |
197 | * def @f() { // execute on GPU |
198 | * let x = ...GPU computation...; |
199 | * ...GPU computation... |
200 | * } |
201 | * def @main() { |
202 | * let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU) |
203 | * ... use x on CPU ... |
204 | * } |
205 | * \endcode |
206 | * |
207 | * Higher-order shenanigans |
208 | * ------------------------ |
209 | * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions |
210 | * as arguments (even anonymous functions), return functions, evaluate conditional expressions |
211 | * over functions, and so on. We handle this during constraint solving using the domain: |
212 | * \code |
213 | * D ::= <specific device type> -- first-order |
214 | * | fn(D,...,D):D -- higher-order |
215 | * \endcode |
216 | * In this way we can determine the device for all function parameters and results. E.g. for |
217 | * \code |
218 | * let f = fn(x, y) { ... } |
219 | * let g = fn(f, z) { f(z, z) } |
220 | * g(f, on_device(..., virtual_device=CPU)) |
221 | * \endcode |
222 | * the parameters \p x and \p y will be on the CPU. |
223 | * |
224 | * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a |
225 | * function. Our analysis must guarantee that the function's parameters and result devices are |
226 | * consistent for \p e2, \p e3, and the context of the call. But: |
227 | * - Which device holds the closure result of evaluating \p e1 ? |
228 | * - If \p e2 is of function type, what does that mean when we say every function parameter |
229 | * is on a device? |
230 | * - If \p e1 returns a function, what does that mean when we say every function result is |
231 | * on a device? |
232 | * |
233 | * Since higher-order aspects are later compiled away (by 'defunctionalization' |
234 | * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular, |
235 | * we really don't want our domain \p D to allow for yet another device for the function closure. |
236 | * So we'll just force the 'device for a function' to be the same as the device for the function's |
237 | * result using the notion of the 'result domain' for a domain: |
238 | * \code |
239 | * result_domain(<specific device type>) = <specific device type> |
240 | * result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr) |
241 | * \endcode |
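 * For example (a worked instance of the definition, using schematic device names):
 * \code
 *   result_domain(fn(CPU, GPU):fn(CPU):GPU) = GPU
 * \endcode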
242 | * |
243 | * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the |
244 | * analysis encounters a function inside one of those it simply forces all argument and result |
245 | * devices for the function to match the device for the first-order expression. For example, |
246 | * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function |
247 | * parameters and result must similarly be on the GPU. |
248 | * |
249 | * ------- |
250 | * | AOR | This pass supports all of Relay. |
251 | * ------- |
252 | * ^ |
253 | * | |
254 | * `-- Mark's stamp of completeness :-) |
255 | * |
256 | * TODO(mbs): Proper diagnostics for unification failure using spans. |
257 | * TODO(mbs): We may want some 'device polymorphism' for Relay functions. Eg it's ok for the |
258 | * function to be called with params/result on different (virtual) device ids provided the target |
259 | * and memory scopes are consistent. |
260 | */ |
261 | |
262 | #include <tvm/ir/transform.h> |
263 | #include <tvm/relay/analysis.h> |
264 | #include <tvm/relay/attrs/annotation.h> |
265 | #include <tvm/relay/attrs/device_copy.h> |
266 | #include <tvm/relay/attrs/memory.h> |
267 | #include <tvm/relay/expr_functor.h> |
268 | #include <tvm/relay/op.h> |
269 | #include <tvm/relay/pattern_functor.h> |
270 | #include <tvm/relay/transform.h> |
271 | #include <tvm/relay/type.h> |
272 | #include <tvm/runtime/c_runtime_api.h> |
273 | #include <tvm/runtime/object.h> |
274 | #include <tvm/tir/function.h> |
275 | #include <tvm/tir/stmt_functor.h> |
276 | |
277 | #include <unordered_map> |
278 | |
279 | #include "../../tir/analysis/device_constraint_utils.h" |
280 | #include "../op/annotation/annotation.h" |
281 | #include "../op/memory/device_copy.h" |
282 | #include "../op/memory/on_device.h" |
283 | #include "./device_domains.h" |
284 | |
285 | namespace tvm { |
286 | namespace relay { |
287 | namespace transform { |
288 | |
289 | namespace { |
290 | |
291 | /* =============== Phase 0 =============== */ |
292 | |
293 | /*! |
294 | * \brief Rewrites "on_device" calls to handle some special cases. |
295 | * |
296 | * - Don't let the device for %x remain unconstrained: |
297 | * \code |
298 | * let %x = on_device(e, virtual_device=d) |
299 | * ==> let %x = on_device(e, virtual_device=d, constraint=kBoth) |
300 | * \endcode |
301 | * |
302 | * - Don't let the function result remain unconstrained: |
303 | * \code |
304 | * fn(%x) { on_device(e, virtual_device=d) } |
 *    ==> fn(%x) { on_device(e, virtual_device=d, constraint=kBoth) }
306 | * \endcode |
307 | * |
308 | * - Project-then-copy rather than copy-then-project: |
309 | * \code |
310 | * on_device(e).0 |
311 | * ==> on_device(e.0) |
312 | * \endcode |
313 | * |
314 | * - Be prepared to copy arguments and results on primitive call boundaries in case memory |
315 | * scopes don't line up. We'll use the 'fully unconstrained' version of on_device so that |
316 | * we can allow for a device_copy without knowing the specific device for the arguments. |
317 | * \code |
318 | * call_lowered(@prim, (a, b)) |
319 | * ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b)))) |
320 | * where |
321 | * copy_ok(x) = on_device(x, virtual_device=VirtualDevice::FullyUnconstrained, |
322 | * constrain_body=False, constrain_result=False) |
323 | * \endcode |
324 | */ |
325 | class RewriteOnDevices : public ExprMutator { |
326 | public: |
327 | explicit RewriteOnDevices(IRModule mod) : mod_(std::move(mod)) {} |
328 | |
329 | private: |
330 | Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { |
331 | Expr tuple = VisitExpr(tuple_get_item_node->tuple); |
332 | OnDeviceProps props = GetOnDeviceProps(tuple); |
333 | |
334 | Expr tuple_get_item = WithFields(GetRef<TupleGetItem>(tuple_get_item_node), tuple); |
335 | if (props.body.defined() && props.is_normal()) { |
336 | VLOG(2) << "wrapping tuple get item:" << std::endl |
337 | << PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl |
338 | << "with \"on_device\" for VirtualDevice " << props.virtual_device; |
339 | return OnDeviceWithProps(tuple_get_item, props); |
340 | } else { |
341 | return tuple_get_item; |
342 | } |
343 | } |
344 | |
345 | Expr VisitExpr_(const LetNode* let_node) final { |
346 | auto expr = GetRef<Expr>(let_node); |
347 | std::vector<std::tuple<Let, Expr>> bindings; |
348 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
349 | Let inner_let = GetRef<Let>(inner_let_node); |
350 | Expr value = VisitExpr(inner_let_node->value); |
351 | OnDeviceProps props = GetOnDeviceProps(value); |
352 | if (props.body.defined() && props.is_normal()) { |
353 | VLOG(2) << "revising let-bound expression of let:" << std::endl |
354 | << PrettyPrint(expr) << std::endl |
355 | << "to be fixed to VirtualDevice " << props.virtual_device; |
356 | value = MaybeOnDeviceFixed(props.body, props.virtual_device); |
357 | } |
358 | bindings.emplace_back(inner_let, value); |
359 | expr = inner_let_node->body; |
360 | } |
361 | expr = VisitExpr(expr); |
362 | for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { |
363 | expr = WithFields(/*let=*/std::get<0>(*itr), /*opt_var=*/{}, |
364 | /*opt_value=*/std::get<1>(*itr), /*opt_body=*/expr); |
365 | } |
366 | return expr; |
367 | } |
368 | |
369 | Expr VisitExpr_(const FunctionNode* function_node) final { |
370 | Expr body = VisitExpr(function_node->body); |
371 | OnDeviceProps props = GetOnDeviceProps(body); |
372 | if (props.body.defined() && props.is_normal()) { |
373 | VLOG(2) << "revising body of function:" << std::endl |
374 | << PrettyPrint(GetRef<Function>(function_node)) << std::endl |
375 | << "to be fixed to VirtualDevice " << props.virtual_device; |
376 | body = MaybeOnDeviceFixed(props.body, props.virtual_device); |
377 | } |
378 | return WithFields(GetRef<Function>(function_node), function_node->params, body); |
379 | } |
380 | |
381 | Expr VisitExpr_(const CallNode* call_node) final { |
382 | CallLoweredProps props = GetCallLoweredProps(call_node); |
383 | if (props.lowered_func.defined()) { |
384 | BaseFunc base_func = mod_->Lookup(props.lowered_func); |
385 | if (base_func.as<tir::PrimFuncNode>()) { |
386 | VLOG(2) << "allowing device_copy on PrimFunc arguments and result" ; |
387 | Array<Expr> new_args; |
388 | new_args.reserve(props.arguments.size()); |
389 | for (const auto& arg : props.arguments) { |
390 | Expr new_arg = VisitExpr(arg); |
391 | new_args.push_back(OnDeviceCopyOk(std::move(new_arg))); |
392 | } |
393 | Call new_call = CallLowered(std::move(props.lowered_func), std::move(new_args), props.attrs, |
394 | call_node->span); |
395 | return OnDeviceCopyOk(std::move(new_call)); |
396 | } |
397 | } |
398 | return ExprMutator::VisitExpr_(call_node); |
399 | } |
400 | |
401 | /*! \brief Module we are rewriting, so we can lookup global definitions. */ |
402 | IRModule mod_; |
403 | }; |
404 | |
405 | /* =============== Phase 1 =============== */ |
406 | |
/*!
408 | * \brief Collects the system of device constraints for all sub-expressions in a module. |
409 | * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. |
410 | * |
411 | * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, |
412 | * from \code on_device(%x, virtual_device=d) \endcode we know \p %x must be on device \p d, and |
413 | * thus so must \p %y. |
414 | * |
415 | * Constraints can flow in interesting ways. E.g. in: |
416 | * \code |
417 | * let %f = fn(%x, %y) { add(%x, on_device(%y, virtual_device=d)) } |
418 | * let %g = fn(%f, %x, %y) { %f(%x, %y) } |
419 | * %g(%f, %a, %b) |
420 | * \endcode |
421 | * we discover \p %b must be on device \p d. |
422 | */ |
423 | class DeviceAnalyzer : public MixedModeVisitor { |
424 | public: |
425 | DeviceAnalyzer(IRModule mod, CompilationConfig config) |
426 | : mod_(std::move(mod)), domains_(std::make_unique<DeviceDomains>(std::move(config))) {} |
427 | |
428 | /*! |
429 | * \brief Returns the expression-to-device-domain map for all expressions in all the global |
   * function definitions in the module. Expressions may have free domains; these will be resolved
431 | * by \p DeviceDefaulter below. |
432 | */ |
433 | std::unique_ptr<DeviceDomains> Analyze() { |
    VLOG_CONTEXT << "DeviceAnalyzer";
435 | for (const auto& kv : mod_->functions) { |
436 | // The global variable and what it is bound to must obviously agree on domain. |
437 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
438 | VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'" ; |
439 | domains_->UnifyExprExact(kv.first, kv.second); |
440 | VisitExpr(GetRef<Function>(function_node)); |
441 | } else if (const auto* prim_func_node = kv.second.as<tir::PrimFuncNode>()) { |
442 | VLOG(2) << "collecting constraints from TIR PrimFunc '" << kv.first->name_hint << "'" ; |
443 | domains_->UnifyExprExact( |
444 | kv.first, DomainForPrimFunc(kv.first, GetRef<tir::PrimFunc>(prim_func_node))); |
445 | } else { |
446 | VLOG(2) << "skipping '" << kv.first->name_hint << "'" ; |
447 | } |
448 | } |
449 | return std::move(domains_); |
450 | } |
451 | |
452 | private: |
453 | /*! |
454 | * \brief Return the domain representing \p prim_func which, before lowering, had |
455 | * the Relay \p type. |
456 | */ |
457 | DeviceDomainPtr DomainForPrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) { |
458 | // CAUTION: The prim_func->checked_type() is currently w.r.t. the flattened and DPS form |
459 | // of the prim func, however here we wish to remain within the Relay view of all functions. |
    // Thus we'll use the global var whose checked_type is in Relay form.
461 | auto func_domain = domains_->DomainFor(global_var); // higher-order |
462 | |
463 | // TODO(mbs): We don't visit the body of the function -- there's currently nothing to be done. |
464 | const auto* func_type_node = global_var->checked_type().as<FuncTypeNode>(); |
465 | ICHECK(func_type_node); |
466 | ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); |
467 | |
468 | Array<VirtualDevice> virtual_devices = |
469 | tir::GetPrimFuncArgAndResultConstraints(prim_func, GetRef<FuncType>(func_type_node)); |
470 | |
    // Build the domain (in terms of the function's Relay type) implied by any memory scope
    // constraints on the function's buffers, for both arguments and results.
473 | std::vector<DeviceDomainPtr> args_and_result_domains; |
474 | args_and_result_domains.reserve(virtual_devices.size()); |
475 | for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { |
476 | const VirtualDevice& param_virtual_device = virtual_devices[i]; |
477 | VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device; |
478 | args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_virtual_device)); |
479 | } |
480 | const VirtualDevice& ret_virtual_device = virtual_devices.back(); |
481 | VLOG(2) << "ret_virtual_device = " << ret_virtual_device; |
482 | args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_virtual_device)); |
483 | |
484 | return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); |
485 | } |
486 | |
487 | void VisitExpr_(const CallNode* call_node) final { |
488 | auto call = GetRef<Call>(call_node); |
489 | |
490 | // We don't care if the call is in pre- or post-lowered form. |
491 | auto vanilla_call = GetAnyCall(call_node); |
492 | |
493 | // Find the higher-order domain for the callee. See DomainForCallee for the special rules |
494 | // for primitives. |
495 | VisitExpr(vanilla_call->op); |
496 | auto func_domain = domains_->DomainForCallee(call); // higher-order |
497 | |
498 | // Build the domain for the function implied by its arguments and call context. |
499 | ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()) << PrettyPrint(call); |
500 | std::vector<DeviceDomainPtr> args_and_result_domains; |
501 | args_and_result_domains.reserve(vanilla_call->args.size() + 1); |
502 | for (const auto& arg : vanilla_call->args) { |
503 | args_and_result_domains.emplace_back(domains_->DomainFor(arg)); |
504 | } |
505 | args_and_result_domains.emplace_back(domains_->DomainFor(call)); |
506 | auto implied_domain = |
507 | domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); // higher-order |
508 | |
509 | VLOG(2) << "initial call function domain:" << std::endl |
510 | << domains_->ToString(func_domain) << std::endl |
511 | << "and implied domain:" << std::endl |
512 | << domains_->ToString(implied_domain) << std::endl |
513 | << "for call:" << std::endl |
514 | << PrettyPrint(call); |
515 | |
516 | // The above must match. |
517 | if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) { // higher-order |
518 | // TODO(mbs): Proper diagnostics. |
519 | LOG(FATAL) |
520 | << "Function parameters and result VirtualDevices do not match those of call. Call:" |
521 | << std::endl |
522 | << PrettyPrint(call) << std::endl |
523 | << "with function virtual devices:" << std::endl |
524 | << domains_->ToString(func_domain) << std::endl |
525 | << "and implied call virtual devices:" << std::endl |
526 | << domains_->ToString(implied_domain); |
527 | } |
528 | |
529 | VLOG(2) << "final call function domain:" << std::endl |
530 | << domains_->ToString(func_domain) << std::endl |
531 | << "for call:" << std::endl |
532 | << PrettyPrint(call); |
533 | } |
534 | |
535 | void VisitExpr_(const LetNode* let_node) final { |
536 | Expr expr = GetRef<Let>(let_node); |
537 | // Iteratively visit let nodes to avoid stack overflow. |
538 | while (expr->IsInstance<LetNode>()) { |
539 | Let let = Downcast<Let>(expr); |
540 | // Let var must be same device as value it is bound to. |
541 | domains_->UnifyExprExact(let->var, let->value); // may be higher-order |
542 | // Let body must be same device as overall let. |
543 | domains_->UnifyExprExact(let, let->body); // may be higher-order |
544 | |
545 | VisitExpr(let->var); |
546 | VisitExpr(let->value); |
547 | |
548 | expr = let->body; |
549 | } |
550 | |
551 | // Visit the last body |
552 | VisitExpr(expr); |
553 | } |
554 | |
555 | void VisitExpr_(const FunctionNode* function_node) final { |
556 | auto function = GetRef<Function>(function_node); |
557 | auto func_domain = domains_->DomainFor(function); // higher-order |
558 | ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); |
559 | |
560 | VLOG(2) << "initial function domain:" << std::endl |
561 | << domains_->ToString(func_domain) << std::endl |
562 | << "and function body domain:" << std::endl |
563 | << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl |
564 | << "for function:" << std::endl |
565 | << PrettyPrint(function); |
566 | |
567 | // The function body domain must match the function result domain. |
568 | domains_->UnifyExprExact(function_node->body, |
569 | func_domain->function_result()); // may be higher-order |
570 | if (!function_node->virtual_device()->IsFullyUnconstrained()) { |
571 | // The function body domain must match any existing virtual device annotation. |
572 | domains_->UnifyExprExact(function_node->body, |
573 | domains_->ForVirtualDevice(function_node->body->checked_type(), |
574 | function_node->virtual_device())); |
575 | } |
576 | |
577 | for (size_t i = 0; i < function_node->params.size(); ++i) { |
578 | const auto& param = function_node->params[i]; |
579 | // The parameter domain must match the function argument domain. |
580 | domains_->UnifyExprExact(param, |
581 | func_domain->function_param(i)); // may be higher-order |
582 | if (!param->virtual_device()->IsFullyUnconstrained()) { |
583 | // The parameter domain must match any existing virtual device annotation. |
584 | domains_->UnifyExprExact( |
585 | param, domains_->ForVirtualDevice(param->checked_type(), param->virtual_device())); |
586 | } |
587 | VisitExpr(param); |
588 | } |
589 | |
590 | // No need to step into the body of Primitive functions. |
591 | if (!function_node->HasNonzeroAttr(attr::kPrimitive)) { |
592 | VisitExpr(function_node->body); |
593 | } |
594 | |
595 | VLOG(2) << "final function domain:" << std::endl |
596 | << domains_->ToString(func_domain) << std::endl |
597 | << "and function body domain:" << std::endl |
598 | << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl |
599 | << "for function:" << std::endl |
600 | << PrettyPrint(function); |
601 | } |
602 | |
603 | void VisitExpr_(const TupleNode* tuple_node) final { |
604 | Tuple tuple = GetRef<Tuple>(tuple_node); |
605 | for (size_t i = 0; i < tuple->fields.size(); i++) { |
606 | auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order |
607 | domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed |
608 | } |
609 | } |
610 | |
611 | void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { |
612 | TupleGetItem tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node); |
613 | auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order |
614 | domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, |
615 | domain); // collapse to first-order if needed |
616 | } |
617 | |
618 | class DevicePatternAnalyzer : public PatternVisitor { |
619 | public: |
620 | DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) |
621 | : domains_(domains), adt_node_(adt_node) {} |
622 | |
623 | private: |
624 | void VisitPattern_(const PatternVarNode* pattern_var_node) final { |
625 | auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order |
626 | domains_->UnifyExprCollapsed(GetRef<Expr>(adt_node_), |
627 | var_domain); // collapse to first-order if needed |
628 | } |
629 | |
630 | /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */ |
631 | DeviceDomains* domains_; |
632 | /*! \brief The expression for the ADT we are matching over. */ |
633 | const ExprNode* adt_node_; |
634 | }; |
635 | |
636 | void VisitPattern(const Pattern& pattern) final {} |
637 | |
638 | void VisitExpr_(const MatchNode* match_node) final { |
639 | // For match node, we unify the value and the rhs of each clause |
640 | Match match = GetRef<Match>(match_node); |
641 | auto match_domain = domains_->DomainFor(match); // may be higher-order |
642 | DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); |
643 | domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed |
644 | for (const auto& clause : match->clauses) { |
645 | pattern_analyzer.VisitPattern(clause->lhs); |
646 | domains_->UnifyExprExact(clause->rhs, match_domain); |
647 | VisitExpr(clause->rhs); |
648 | } |
649 | VisitExpr(match_node->data); |
650 | } |
651 | |
652 | void VisitExpr_(const GlobalVarNode* global_var_node) final { |
653 | domains_->DomainFor(GetRef<GlobalVar>(global_var_node)); |
654 | } |
655 | |
656 | void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef<Var>(var_node)); } |
657 | |
658 | void VisitExpr_(const ConstantNode* constant_node) final { |
659 | domains_->DomainFor(GetRef<Constant>(constant_node)); |
660 | } |
661 | |
662 | void VisitExpr_(const ConstructorNode* constructor_node) final { |
663 | // no-op, constructors are handled at their call-sites. |
664 | // TODO(mbs): Assumes eta-expansion |
665 | } |
666 | |
667 | void VisitExpr_(const IfNode* if_node) final { |
668 | auto ife = GetRef<If>(if_node); |
669 | auto domain = domains_->DomainFor(ife); // may be higher-order |
670 | domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed |
671 | domains_->UnifyExprExact(if_node->true_branch, domain); |
672 | domains_->UnifyExprExact(if_node->false_branch, domain); |
673 | VisitExpr(if_node->cond); |
674 | VisitExpr(if_node->true_branch); |
675 | VisitExpr(if_node->false_branch); |
676 | } |
677 | |
678 | void VisitExpr_(const OpNode* op) final { |
679 | // no-op, primitive operators are handled at their call-sites. |
680 | } |
681 | |
682 | void VisitExpr_(const RefCreateNode* ref_create_node) final { |
683 | auto ref_create = GetRef<RefCreate>(ref_create_node); |
684 | auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order |
685 | domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed |
686 | VisitExpr(ref_create_node->value); |
687 | } |
688 | |
689 | void VisitExpr_(const RefReadNode* ref_read_node) final { |
690 | auto ref_read = GetRef<RefRead>(ref_read_node); |
691 | auto domain = domains_->DomainFor(ref_read); // may be higher-order |
692 | domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed |
693 | VisitExpr(ref_read_node->ref); |
694 | } |
695 | |
696 | void VisitExpr_(const RefWriteNode* ref_write_node) final { |
697 | auto ref_write = GetRef<RefWrite>(ref_write_node); |
698 | auto domain = domains_->DomainFor(ref_write->value); // may be higher-order |
699 | domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed |
700 | domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed |
701 | VisitExpr(ref_write_node->ref); |
702 | VisitExpr(ref_write_node->value); |
703 | } |
704 | |
705 | /*! \brief The module we are analyzing. */ |
706 | IRModule mod_; |
707 | /*! \brief The domains for all expressions processed so far. */ |
708 | std::unique_ptr<DeviceDomains> domains_; |
709 | }; |
710 | |
711 | /* =============== Phase 2 =============== */ |
712 | |
713 | /*! |
714 | * \brief Calls to 'free' "on_device" annotations (ie where both constrain_body=false and |
715 | * constrain_result=false) indicate a device_copy is allowed if required, but no particular |
716 | * device is imposed on the body or the context. At this stage we can attempt to unify the |
717 | * body and device contexts. In this way we can avoid the defaulting rules in \p DeviceDefaulter |
718 | * from choosing default devices which are only going to induce a device copy. |
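 * For example (a sketch using the copy_ok shorthand introduced for \p RewriteOnDevices above),
 * given
 * \code
 *   copy_ok(call_lowered(@prim, (copy_ok(%a), copy_ok(%b))))
 * \endcode
 * we try to unify each copy_ok call with its body (%a, %b and the lowered call respectively), so
 * that where unification succeeds the later phases need not insert a device_copy at that point.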
719 | * |
720 | * TODO(mbs): The order in which we encounter the "on_device" calls can influence the final global |
721 | * device assignment. However we visit global functions in hash map order. |
722 | */ |
723 | class FreeOnDeviceDefaulter : public ExprVisitor { |
724 | public: |
725 | FreeOnDeviceDefaulter(IRModule mod, std::unique_ptr<DeviceDomains> domains) |
726 | : mod_(std::move(mod)), domains_(std::move(domains)) {} |
727 | |
728 | std::unique_ptr<DeviceDomains> Default() { |
    VLOG_CONTEXT << "FreeOnDeviceDefaulter";
    VLOG(0) << "unifying free on_device annotations";
731 | for (const auto& kv : mod_->functions) { |
732 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
733 | VLOG(2) << "unifying for '" << kv.first->name_hint << "'" ; |
734 | VisitExpr(GetRef<Function>(function_node)); |
735 | } else { |
736 | VLOG(2) << "skipping '" << kv.first->name_hint << "'" ; |
737 | } |
738 | } |
739 | return std::move(domains_); |
740 | } |
741 | |
742 | private: |
743 | void VisitExpr_(const CallNode* call_node) final { |
744 | auto call = GetRef<Call>(call_node); |
745 | OnDeviceProps props = GetOnDeviceProps(call_node); |
746 | ExprVisitor::VisitExpr_(call_node); |
747 | if (props.body.defined() && !props.constrain_body && !props.constrain_result) { |
748 | domains_->OptionalUnifyExprExact(call, props.body); |
749 | } |
750 | } |
751 | |
752 | /*! \brief The module we are processing. */ |
753 | IRModule mod_; |
754 | /*! \brief The domains for all expressions. */ |
755 | std::unique_ptr<DeviceDomains> domains_; |
756 | }; |
757 | |
758 | /*! |
759 | * \brief Ensures every sub-expression in a module has a device type, using both the global |
760 | * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. |
761 | * |
762 | * E.g. in: |
763 | * \code |
764 | * def @main(%x, %y, %z) { |
765 | * let %a = add(%x, %y); |
766 | * multiply(%a, on_device(%z, virtual_device=d)) |
767 | * } |
768 | * \endcode |
769 | * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, |
770 | * and the device for the function result, are still 'free'. The global 'default' device type |
771 | * is first used to 'fix' \p main's result type, which in turn 'fixes' \p %x and \p %y, which |
772 | * in turn 'fixes' the device on which the \p add and \p multiply are executed. |
773 | * |
774 | * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap |
775 | * order. |
776 | */ |
777 | class DeviceDefaulter : public ExprVisitor { |
778 | public: |
779 | DeviceDefaulter(IRModule mod, std::unique_ptr<DeviceDomains> domains) |
780 | : mod_(std::move(mod)), domains_(std::move(domains)) {} |
781 | |
782 | std::unique_ptr<DeviceDomains> Default() { |
    VLOG_CONTEXT << "DeviceDefaulter";
784 | VLOG(0) << "defaulting to VirtualDevice " |
785 | << domains_->config()->default_primitive_virtual_device; |
786 | for (const auto& kv : mod_->functions) { |
787 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
788 | VLOG(2) << "defaulting devices for '" << kv.first->name_hint << "'" ; |
789 | VisitExpr(GetRef<Function>(function_node)); |
790 | } else { |
791 | VLOG(2) << "skipping '" << kv.first->name_hint << "'" ; |
792 | } |
793 | } |
794 | return std::move(domains_); |
795 | } |
796 | |
797 | private: |
798 | void VisitExpr_(const FunctionNode* function_node) final { |
799 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
800 | return; |
801 | } |
802 | |
803 | auto function = GetRef<Function>(function_node); |
804 | auto func_domain = domains_->DomainFor(function); // higher-order |
805 | ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); |
806 | if (!domains_->IsFullyConstrained(func_domain)) { |
807 | VLOG(2) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); |
808 | domains_->SetResultDefaultThenParams(func_domain, |
809 | domains_->config()->default_primitive_virtual_device); |
810 | VLOG(2) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); |
811 | } |
812 | VisitExpr(function_node->body); |
813 | } |
814 | |
815 | void VisitExpr_(const CallNode* call_node) final { |
816 | auto call = GetRef<Call>(call_node); |
817 | |
818 | // We don't care if the call is pre- or post-lowered. |
819 | auto vanilla_call = GetAnyCall(call_node); |
820 | |
821 | auto func_domain = domains_->DomainForCallee(call); // higher-order |
822 | ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); |
823 | if (!domains_->IsFullyConstrained(func_domain)) { |
824 | // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) |
825 | // above. But for calls to primitives we may still need to force free domains to be |
826 | // defaulted. |
827 | VLOG(2) << "before defaulting callee:" << std::endl |
828 | << PrettyPrint(call_node->op) << std::endl |
829 | << "of domain:" << std::endl |
830 | << domains_->ToString(func_domain); |
831 | domains_->SetResultDefaultThenParams(func_domain, |
832 | domains_->config()->default_primitive_virtual_device); |
833 | VLOG(2) << "after defaulting callee:" << std::endl |
834 | << PrettyPrint(call_node->op) << std::endl |
835 | << "of domain:" << std::endl |
836 | << domains_->ToString(func_domain); |
837 | } |
838 | return ExprVisitor::VisitExpr_(call_node); |
839 | } |
840 | |
841 | void VisitExpr_(const LetNode* let_node) final { |
842 | Expr expr = GetRef<Let>(let_node); |
843 | // Iteratively visit let nodes to avoid stack overflow. |
844 | while (expr->IsInstance<LetNode>()) { |
845 | Let let = Downcast<Let>(expr); |
846 | // If the let-var device is still free force it to match the overall let. |
847 | auto let_domain = domains_->DomainFor(let); // may be higher-order |
848 | VirtualDevice let_virtual_device = domains_->ResultVirtualDevice(let_domain); |
849 | ICHECK(!let_virtual_device->IsFullyUnconstrained()); |
850 | auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order |
851 | if (!domains_->IsFullyConstrained(let_var_domain)) { |
852 | VLOG(2) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); |
853 | domains_->SetDefault(let_var_domain, let_virtual_device); |
854 | VLOG(2) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); |
855 | } |
856 | VisitExpr(let->var); |
857 | VisitExpr(let->value); |
858 | expr = let->body; |
859 | } |
860 | VisitExpr(expr); |
861 | } |
862 | |
863 | /*! \brief The module we are processing. */ |
864 | IRModule mod_; |
865 | /*! \brief The domains for all expressions. */ |
866 | std::unique_ptr<DeviceDomains> domains_; |
867 | }; |
868 | |
869 | /* =============== Phase 3 =============== */ |
870 | /*! |
871 | * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every |
872 | * sub-expression in a module can be easily recovered by a later transformation using simple |
873 | * lexical scoping rules (e.g. for memory planning). |
874 | * |
875 | * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard |
876 | * any existing "device_copy" CallNodes which are no-ops. |
877 | * |
878 | * - The result virtual device for a function is stored in the function's virtual_device_ field |
879 | * and the virtual devices of the function's parameters are stored in the parameter's |
880 | * virtual_device_ field. |
881 | * |
882 | * - Additional "device_copy" CallNodes are inserted wherever there's a transition between |
883 | * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen |
884 | * where the original program explicitly allowed a transition using an "on_device" CallNode. |
 *   That is, we do not try to 'fix' a program with inconsistent devices.
886 | * |
887 | * - Additional "on_device" CallNodes are inserted so that a later transform can discover |
888 | * the device for an arbitrary sub-expression by looking only for the lexically enclosing |
889 | * "on_device" CallNode or "on_device" function attribute. In particular, since function |
890 | * arguments and let-bound expressions can be on a device different from the function |
891 | * or let body itself we will insert "on_device" CallNodes to spell out any differences. This |
892 | * applies even to the argument to a "device_copy" CallNode, which may look pedantic but |
893 | * keeps downstream processing simple. The "on_device" calls should be removed before code gen, |
894 | * which is easily done on-the-fly. |
895 | * |
896 | * - Update memory scopes in PrimFunc buffer maps. |
897 | * |
898 | * For example, we'll end up with programs that look like: |
899 | * \code |
900 | * def @main(%x, %y, param_virtual_devices=[...], result_virtual_device=...) { |
901 | * let %a = on_device(..., virtual_device=..., is_fixed=True) |
902 | * @f(%a, device_copy(on_device(..., virtual_device=..., is_fixed=True), |
903 | * src_virtual_device=..., dst_virtual_device=...)) |
904 | * } |
905 | * \endcode |
906 | */ |
907 | class DeviceCapturer : public ExprMutator { |
908 | public: |
909 | DeviceCapturer(IRModule mod, std::unique_ptr<DeviceDomains> domains) |
910 | : mod_(std::move(mod)), domains_(std::move(domains)) {} |
911 | |
912 | IRModule Capture() { |
    VLOG_CONTEXT << "CaptureDevices";
914 | IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map, |
915 | mod_->attrs); |
916 | for (const auto& kv : mod_->functions) { |
917 | if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { |
918 | VLOG(2) << "capturing devices for Relay Function '" << kv.first->name_hint << "'" ; |
919 | result->Add(kv.first, Downcast<Function>(Mutate(GetRef<Function>(function_node)))); |
920 | } else if (const auto* prim_func_node = kv.second.as<tir::PrimFuncNode>()) { |
921 | VLOG(2) << "capturing devices for TIR PrimFunc '" << kv.first->name_hint << "'" ; |
922 | auto prim_func = GetRef<tir::PrimFunc>(prim_func_node); |
923 | tir::PrimFunc new_prim_func = UpdatePrimFunc(kv.first, prim_func); |
924 | VLOG(2) << "Rewritten prim func:" << std::endl |
925 | << PrettyPrint(prim_func) << std::endl |
926 | << "to:" << std::endl |
927 | << PrettyPrint(new_prim_func); |
928 | result->Add(kv.first, std::move(new_prim_func)); |
929 | } else { |
930 | VLOG(2) << "skipping '" << kv.first->name_hint << "'" ; |
931 | result->Add(kv.first, kv.second); |
932 | } |
933 | } |
934 | return result; |
935 | } |
936 | |
937 | private: |
938 | /*! |
   * \brief Returns \p prim_func updated to capture any memory scopes implied by its device
940 | * domain. |
941 | */ |
942 | tir::PrimFunc UpdatePrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) { |
943 | // CAUTION: Same caution as for DeviceAnalyzer::DomainForPrimFunc. |
944 | auto func_domain = domains_->DomainFor(global_var); |
945 | ICHECK(func_domain->is_higher_order()); |
946 | |
947 | const auto* func_type_node = global_var->checked_type().as<FuncTypeNode>(); |
948 | ICHECK(func_type_node); |
949 | ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); |
950 | |
951 | std::vector<VirtualDevice> arg_and_result_virtual_devices; |
952 | arg_and_result_virtual_devices.reserve(func_type_node->arg_types.size() + 1); |
953 | for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { |
954 | VirtualDevice param_virtual_device = |
955 | domains_->ResultVirtualDevice(func_domain->function_param(i)); |
956 | VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device; |
957 | arg_and_result_virtual_devices.push_back(param_virtual_device); |
958 | } |
959 | VirtualDevice ret_virtual_device = |
960 | domains_->ResultVirtualDevice(func_domain->function_result()); |
961 | VLOG(2) << "ret_virtual_device = " << ret_virtual_device; |
962 | arg_and_result_virtual_devices.push_back(ret_virtual_device); |
963 | |
964 | return tir::ApplyPrimFuncArgAndResultConstraints(prim_func, GetRef<FuncType>(func_type_node), |
965 | arg_and_result_virtual_devices); |
966 | } |
967 | |
968 | // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode |
969 | |
970 | Expr VisitExpr_(const TupleNode* tuple_node) final { |
971 | auto tuple = GetRef<Tuple>(tuple_node); |
972 | Array<Expr> fields; |
973 | fields.reserve(tuple_node->fields.size()); |
974 | for (const auto& field : tuple_node->fields) { |
975 | fields.push_back(VisitChild(tuple, field)); |
976 | } |
977 | return WithFields(tuple, fields); |
978 | } |
979 | |
980 | Expr VisitExpr_(const FunctionNode* function_node) final { |
981 | if (function_node->HasNonzeroAttr(attr::kPrimitive)) { |
982 | return GetRef<Function>(function_node); |
983 | } |
984 | |
985 | auto function = GetRef<Function>(function_node); |
986 | auto func_domain = domains_->DomainFor(function); // higher-order |
987 | VLOG(2) << "capturing function:" << std::endl |
988 | << PrettyPrint(function) << std::endl |
989 | << "with domain:" << std::endl |
990 | << domains_->ToString(func_domain); |
991 | |
992 | // Gather the parameter and result device types for the function attributes. |
993 | ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); |
994 | VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain); |
995 | ICHECK(!result_virtual_device->IsFullyUnconstrained()); |
996 | |
    // Map each function parameter to a new variable annotated with its virtual device so
    // we can substitute them later.
999 | Map<Var, Expr> annotated_bind_map; |
1000 | Array<Var> annotated_params; |
1001 | annotated_params.reserve(function_node->params.size()); |
1002 | for (size_t i = 0; i < function_node->params.size(); ++i) { |
1003 | VirtualDevice param_virtual_device = |
1004 | domains_->ResultVirtualDevice(func_domain->function_param(i)); |
1005 | VLOG(4) << "Param: " << function_node->params[i]; |
1006 | Var annotated_var = WithFields(function_node->params[i], {}, {}, param_virtual_device); |
1007 | VLOG(4) << "Annotated param: " << annotated_var; |
1008 | VLOG(4) << "VirtualDevice: " << annotated_var->virtual_device(); |
1009 | ICHECK(!param_virtual_device->IsFullyUnconstrained()); |
1010 | annotated_bind_map.Set(function_node->params[i], annotated_var); |
1011 | annotated_params.push_back(annotated_var); |
1012 | } |
    // Eventually we probably want to bind before visiting, but for now this is causing an issue
    // with the GetVirtualDevice utility, so leaving as is.
1015 | |
1016 | // Rewrite the body. Note that the body may have begun with an "on_device" so |
1017 | // be prepared to insert a "device_copy". |
1018 | Expr body = VisitChild( |
1019 | /*lexical_virtual_device=*/result_virtual_device, |
1020 | /*expected_virtual_device=*/result_virtual_device, |
1021 | /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body); |
1022 | VLOG(4) << "Visited body: " << body; |
1023 | Function func = WithFields(GetRef<Function>(function_node), function_node->params, body); |
1024 | VLOG(4) << "New function: " << func; |
1025 | func = SubstituteBoundVars(func, annotated_bind_map); |
1026 | VLOG(4) << "Func with bound params: " << func; |
1027 | func->virtual_device_ = result_virtual_device; |
1028 | VLOG(4) << "Func with bound params & result vid set: " << func; |
1029 | return std::move(func); |
1030 | } |
1031 | |
1032 | Expr VisitExpr_(const CallNode* call_node) final { |
1033 | auto call = GetRef<Call>(call_node); |
1034 | |
1035 | // We don't care if the call is pre- or post-lowered |
1036 | // (However we'll preserve the form in the result below.) |
1037 | auto vanilla_call = GetAnyCall(call_node); |
1038 | |
1039 | VirtualDevice call_virtual_device = GetVirtualDevice(call); |
1040 | |
1041 | auto on_device_props = GetOnDeviceProps(call_node); |
1042 | if (on_device_props.body.defined()) { |
1043 | // We're done with the original "on_device" calls and can pinch them out. |
      // Note that this step has already been simulated by GetVirtualDevice.
1045 | return VisitExpr(on_device_props.body); |
1046 | } |
1047 | |
1048 | DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); |
1049 | if (device_copy_props.body.defined()) { |
1050 | VirtualDevice src_virtual_device = |
1051 | domains_->config()->CanonicalVirtualDevice(device_copy_props.src_virtual_device); |
1052 | VirtualDevice dst_virtual_device = |
1053 | domains_->config()->CanonicalVirtualDevice(device_copy_props.dst_virtual_device); |
1054 | ICHECK_EQ(call_virtual_device, dst_virtual_device); |
1055 | if (src_virtual_device == dst_virtual_device) { |
1056 | // We can pinch out existing "device_copy" CallNodes if their source and destinations |
1057 | // match. |
1058 | return VisitExpr(device_copy_props.body); |
1059 | } else { |
1060 | return VisitChild(/*lexical_virtual_device=*/dst_virtual_device, |
1061 | /*expected_virtual_device=*/dst_virtual_device, |
1062 | /*child_virtual_device=*/src_virtual_device, device_copy_props.body); |
1063 | } |
1064 | } |
1065 | |
1066 | // Generic call. |
1067 | auto func_domain = domains_->DomainForCallee(call); // higher-order |
1068 | VLOG(2) << "considering call:" << std::endl |
1069 | << PrettyPrint(call) << std::endl |
1070 | << "in virtual device " << call_virtual_device |
1071 | << " with function virtual devices:" << std::endl |
1072 | << domains_->ToString(func_domain); |
1073 | VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain); |
1074 | ICHECK(!result_virtual_device->IsFullyUnconstrained()); |
1075 | |
1076 | // The callee is on the current device. |
1077 | Expr op = VisitChild( |
1078 | /*lexical_virtual_device=*/call_virtual_device, |
1079 | /*expected_virtual_device=*/call_virtual_device, |
1080 | /*child_virtual_device=*/result_virtual_device, vanilla_call->op); |
1081 | |
1082 | // Each argument can be on the device for the corresponding function parameter. However if |
1083 | // any of those differ from the overall call device then wrap them in an "on_device" to |
1084 | // help downstream transforms track devices lexically. |
1085 | Array<Expr> args; |
1086 | args.reserve(vanilla_call->args.size()); |
1087 | ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); |
1088 | for (size_t i = 0; i < vanilla_call->args.size(); ++i) { |
1089 | VirtualDevice param_virtual_device = |
1090 | domains_->ResultVirtualDevice(func_domain->function_param(i)); |
1091 | ICHECK(!param_virtual_device->IsFullyUnconstrained()) |
1092 | << "for parameter " << i << " for call:" << std::endl |
1093 | << PrettyPrint(call); |
1094 | args.push_back(VisitChild(/*lexical_virtual_device=*/call_virtual_device, |
1095 | /*expected_virtual_device=*/param_virtual_device, |
1096 | /*child_virtual_device=*/GetVirtualDevice(vanilla_call->args[i]), |
1097 | vanilla_call->args[i])); |
1098 | } |
1099 | |
1100 | if (call_node->op == CallLoweredOp()) { |
1101 | Call new_call = |
1102 | CallLowered(Downcast<GlobalVar>(op), args, /*call_lowered_attrs=*/{}, /*span=*/{}); |
1103 | return WithFields(call, new_call->op, new_call->args); |
1104 | } else { |
1105 | return WithFields(call, op, args); |
1106 | } |
1107 | } |
1108 | |
1109 | Expr VisitExpr_(const LetNode* let_node) final { |
1110 | Expr expr = GetRef<Expr>(let_node); |
    // Iterate through chained lets, provided they all agree on their virtual device.
1112 | VirtualDevice let_virtual_device = GetVirtualDevice(expr); |
1113 | std::vector<std::tuple<Var, Expr, Span>> bindings; |
1114 | while (const auto* inner_let_node = expr.as<LetNode>()) { |
1115 | Expr inner_let = GetRef<Let>(inner_let_node); |
1116 | if (GetVirtualDevice(inner_let) != let_virtual_device) { |
1117 | // We have a device transition which needs to be handled. |
1118 | break; |
1119 | } |
      // The let-bound value can be on a different device than the overall let.
      // By using the fully-unconstrained virtual device for the 'lexical' scope we'll force the
      // let-bound value to *always* be wrapped by an "on_device" (see the introductory comment
      // for motivation).
1124 | Expr value = VisitChild(/*lexical_virtual_device=*/VirtualDevice::FullyUnconstrained(), |
1125 | /*expected_virtual_device=*/GetVirtualDevice(inner_let_node->var), |
1126 | /*child_virtual_device=*/GetVirtualDevice(inner_let_node->value), |
1127 | inner_let_node->value); |
1128 | bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); |
1129 | expr = inner_let_node->body; |
1130 | } |
1131 | Expr body = VisitChild(/*lexical_virtual_device=*/let_virtual_device, |
1132 | /*expected_virtual_device=*/let_virtual_device, |
1133 | /*child_virtual_device=*/GetVirtualDevice(expr), expr); |
1134 | for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { |
1135 | body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, |
1136 | /*span=*/std::get<2>(*itr)); |
1137 | } |
1138 | return body; |
1139 | } |
1140 | |
1141 | Expr VisitExpr_(const IfNode* if_node) final { |
1142 | auto ife = GetRef<If>(if_node); |
1143 | Expr cond = VisitChild(ife, if_node->cond); |
1144 | Expr true_branch = VisitChild(ife, if_node->true_branch); |
1145 | Expr false_branch = VisitChild(ife, if_node->false_branch); |
1146 | return WithFields(ife, cond, true_branch, false_branch); |
1147 | } |
1148 | |
1149 | Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { |
1150 | auto tuple_get_item = GetRef<TupleGetItem>(tuple_get_item_node); |
1151 | Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); |
1152 | return WithFields(tuple_get_item, tuple); |
1153 | } |
1154 | |
1155 | Expr VisitExpr_(const RefCreateNode* ref_create_node) final { |
1156 | auto ref_create = GetRef<RefCreate>(ref_create_node); |
1157 | Expr value = VisitChild(ref_create, ref_create_node->value); |
1158 | return WithFields(ref_create, value); |
1159 | } |
1160 | |
1161 | Expr VisitExpr_(const RefReadNode* ref_read_node) final { |
1162 | auto ref_read = GetRef<RefRead>(ref_read_node); |
1163 | Expr ref = VisitChild(ref_read, ref_read_node->ref); |
1164 | return WithFields(ref_read, ref); |
1165 | } |
1166 | |
1167 | Expr VisitExpr_(const RefWriteNode* ref_write_node) final { |
1168 | auto ref_write = GetRef<RefWrite>(ref_write_node); |
1169 | Expr ref = VisitChild(ref_write, ref_write_node->ref); |
1170 | Expr value = VisitChild(ref_write, ref_write_node->value); |
1171 | return WithFields(ref_write, ref, value); |
1172 | } |
1173 | |
1174 | Expr VisitExpr_(const MatchNode* match_node) final { |
1175 | auto match = GetRef<Match>(match_node); |
1176 | Expr data = VisitChild(match, match_node->data); |
1177 | Array<Clause> clauses; |
1178 | clauses.reserve(match_node->clauses.size()); |
1179 | for (const auto& clause : match_node->clauses) { |
1180 | Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars |
1181 | Expr rhs = VisitChild(match, clause->rhs); |
1182 | clauses.push_back(Clause(lhs, rhs)); |
1183 | } |
1184 | return WithFields(match, data, clauses); |
1185 | } |
1186 | |
1187 | VirtualDevice GetVirtualDevice(const Expr& expr) { |
1188 | // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. |
1189 | OnDeviceProps props = GetOnDeviceProps(expr); |
1190 | Expr true_expr = props.body.defined() ? props.body : expr; |
1191 | ICHECK(domains_->contains(true_expr)); |
1192 | // If expr is higher order we'll return only the result domain's device. |
1193 | VirtualDevice virtual_device = domains_->ResultVirtualDevice(domains_->DomainFor(true_expr)); |
1194 | ICHECK(!virtual_device->IsFullyUnconstrained()) |
1195 | << "no VirtualDevice was determined for expression:" << std::endl |
1196 | << PrettyPrint(true_expr); |
1197 | return std::move(virtual_device); |
1198 | } |
1199 | |
1200 | /*! |
1201 | * \brief Reconcile the \p child_virtual_device for \p child with both the \p |
1202 | * expected_virtual_device (as required by the expression context the \p child is in) and the \p |
1203 | * lexical_virtual_device (as a downstream transform would infer based only on lexically enclosing |
1204 | * "on_device" CallNodes and function attributes.) Generally \p lexical_virtual_device and \p |
1205 | * expected_virtual_device are the same by definition, but may differ in arguments to functions |
1206 | * and let-bound expressions. |
1207 | * |
1208 | * If \p child_virtual_device differs from \p expected_virtual_device, wrap it as: |
1209 | * \code |
1210 | * device_copy(on_device(child', virtual_device=child_virtual_device), |
1211 | * src_dev_type=child_virtual_device, dst_dev_type=expected_virtual_device) |
1212 | * \endcode |
1213 | * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the |
1214 | * child. |
1215 | * |
1216 | * If \p expected_virtual_device differs from \p lexical_virtual_device, then (also) wrap |
1217 | * the expression as: |
1218 | * \code |
1219 | * on_device(..., virtual_device=expected_virtual_device) |
1220 | * \endcode |
1221 | * |
   * TODO(mbs): There's no attempt at sharing here. Each usage of the child's node could end up
   * wrapped by its own "device_copy", even though those copies will generally all be to the
   * same destination device.
1225 | */ |
1226 | Expr VisitChild(const VirtualDevice& lexical_virtual_device, |
1227 | const VirtualDevice& expected_virtual_device, |
1228 | const VirtualDevice& child_virtual_device, const Expr& child) { |
1229 | ICHECK(!expected_virtual_device->IsFullyUnconstrained()); |
1230 | if (child->IsInstance<OpNode>() || child->IsInstance<ConstructorNode>()) { |
      // Primitive operators and constructors don't need to be rewritten and can have a
      // different domain at each call site.
1233 | return child; |
1234 | } |
1235 | Expr result = VisitExpr(child); |
1236 | if (child_virtual_device != expected_virtual_device) { |
1237 | VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device " |
1238 | << child_virtual_device << " to virtual device " << expected_virtual_device |
1239 | << " for:" << std::endl |
1240 | << PrettyPrint(result); |
1241 | // Also wrap the child in an "on_device" so downstream transforms can track devices |
1242 | // lexically. |
1243 | result = MaybeOnDeviceFixed(result, child_virtual_device); |
1244 | result = DeviceCopy(result, child_virtual_device, expected_virtual_device); |
1245 | } |
1246 | if (expected_virtual_device != lexical_virtual_device) { |
1247 | VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device " |
1248 | << expected_virtual_device << " for:" << std::endl |
1249 | << PrettyPrint(result); |
1250 | result = MaybeOnDeviceFixed(result, expected_virtual_device); |
1251 | } |
1252 | return result; |
1253 | } |
1254 | |
1255 | /*! |
1256 | * Common case of visiting a direct \p child of \p parent where by default the \p child |
1257 | * is expected to be on the same device as the \p parent. |
1258 | */ |
1259 | Expr VisitChild(const Expr& parent, const Expr& child) { |
1260 | VirtualDevice expected_virtual_device = GetVirtualDevice(parent); |
1261 | VirtualDevice child_virtual_device = GetVirtualDevice(child); |
1262 | return VisitChild(expected_virtual_device, expected_virtual_device, child_virtual_device, |
1263 | child); |
1264 | } |
1265 | |
1266 | /*! \brief Module we are rewriting, so we can lookup global variables. */ |
1267 | IRModule mod_; |
1268 | /*! \brief Device domain for every expression from DeviceAnalyzer. */ |
1269 | std::unique_ptr<DeviceDomains> domains_; |
1270 | }; |
1271 | |
1272 | /*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ |
1273 | tvm::transform::Pass Rewrite() { |
1274 | auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { |
1275 | auto attrs = m->attrs; |
1276 | auto r = Downcast<Function>(RewriteOnDevices(std::move(m)).Mutate(f)); |
1277 | return attrs.defined() ? WithAttrs(r, {attrs->dict}) : r; |
1278 | }; |
  return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {});
1280 | } |
1281 | |
/*! \brief Run the remaining phases: constraint analysis, defaulting, and capture. */
1283 | tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { |
1284 | return tvm::transform::CreateModulePass( |
1285 | [config = std::move(config)](IRModule mod, |
1286 | tvm::transform::PassContext pass_cnxt) -> IRModule { |
1287 | // Collect the system of constraints for every sub-expression using existing "on_device" |
1288 | // and "device_copy" calls. |
1289 | std::unique_ptr<DeviceDomains> domains = DeviceAnalyzer(mod, config).Analyze(); |
1290 | VLOG(3) << "Domains after analysis:" << std::endl << domains->ToString(); |
1291 | |
1292 | // Choose sensible default devices for every sub-expression if otherwise unconstrained |
1293 | // by existing "on_device" or "device_copy" calls. |
1294 | domains = FreeOnDeviceDefaulter(mod, std::move(domains)).Default(); |
1295 | domains = DeviceDefaulter(mod, std::move(domains)).Default(); |
        VLOG(3) << "Domains after defaulting:" << std::endl << domains->ToString();
1297 | |
1298 | // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture |
1299 | // the above map, and attach additional "param_virtual_devices" and "result_virtual_device" |
1300 | // attributes to all function definitions. |
1301 | return DeviceCapturer(mod, std::move(domains)).Capture(); |
1302 | }, |
      /*opt_level=*/0, "PlanDevicesCore", {});
1304 | } |
1305 | |
1306 | } // namespace |
1307 | |
1308 | /* =============== Driver =============== */ |
1309 | |
1310 | // This function is declared in the public <tvm/relay/transform.h>. |
1311 | tvm::transform::Pass PlanDevices(CompilationConfig config) { |
1312 | std::vector<Pass> passes; |
1313 | passes.emplace_back(Rewrite()); |
1314 | passes.emplace_back(PlanDevicesCore(std::move(config))); |
  return tvm::transform::Sequential(passes, "PlanDevices");
1316 | } |
1317 | |
TVM_REGISTER_GLOBAL("relay._transform.PlanDevices").set_body_typed(PlanDevices);
1319 | |
1320 | } // namespace transform |
1321 | } // namespace relay |
1322 | } // namespace tvm |
1323 | |