1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \file src/relay/transforms/device_domains.h
22 | * \brief Unification domain for the device planner. |
23 | */ |
24 | |
25 | #ifndef TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ |
26 | #define TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ |
27 | |
28 | #include <dlpack/dlpack.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/type.h> |
31 | #include <tvm/runtime/ndarray.h> |
32 | #include <tvm/target/compilation_config.h> |
33 | #include <tvm/target/virtual_device.h> |
34 | |
35 | #include <memory> |
36 | #include <string> |
37 | #include <unordered_map> |
38 | #include <utility> |
39 | #include <vector> |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | namespace transform { |
44 | |
class DeviceDomain;
class DeviceDomains;

/*! \brief Shared-ownership handle through which all \p DeviceDomain instances are referenced. */
using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
48 | |
49 | /*! |
50 | * \brief Represents the domain over which we collect equality constraints. |
51 | * |
52 | * \code |
53 | * D ::= ?x? -- first order, free |
54 | * | <virtual_device> -- first order, bound to specific virtual device |
55 | * | fn(D1, ..., Dn):Dr -- higher order |
56 | * \endcode |
57 | * |
58 | * We require a function value to be on the same device as its result. To support that we need |
59 | * a notion of the 'result domain' of a domain: |
60 | * \code |
61 | * result_domain(?x?) = ?x? |
62 | * result_domain(<virtual_device>) = <virtual_device> |
63 | * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) |
64 | * \endcode |
65 | * |
66 | * TODO(mbs): We currently don't allow sub-VirtualDevice constraints. Eg for a function we can |
67 | * express that the argument and result VirtualDevices must be exactly equal, but we cannot express |
68 | * that though the devices and targets for arguments and results must be equal, it is ok for |
69 | * memory scopes to differ. At the moment we can get away with this since we run PlanDevices |
70 | * twice: once with all memory scopes unconstrained, then again with just memory scopes as |
71 | * the new property to flow. However we're on thin ice here and better would be to allow |
72 | * constraints on VirtualDevices to be exploded into their device/target component and their |
73 | * memory scope component. Should we fold layout constraints into VirtualDevices then they would |
74 | * probably be grouped with memory scopes. |
75 | */ |
76 | class DeviceDomain { |
77 | public: |
78 | /*! |
79 | * \brief Constructs a first-order domain for \p virtual_device, which may be |
80 | * fully free (ie virtual_device is unconstrained), partially free (ie virtual_device has at |
81 | * least on of its target, device id or memory scopes known), or fully fixed (ie virtual_device |
82 | * has its target, device id and memory scopes set). |
83 | * |
84 | * CAUTION: Use DeviceDomains::MakeFirstOrderDomain instead of this ctor. |
85 | */ |
86 | explicit DeviceDomain(VirtualDevice virtual_device) |
87 | : virtual_device_(std::move(virtual_device)) {} |
88 | |
89 | /*! |
90 | * \brief Constructs a higher-order domain, where \p args_and_result contain the |
91 | * function argument and result domains in order. |
92 | * |
93 | * CAUTION: Use DeviceDomains::MakeHigherOrderDomain instead of this ctor. |
94 | */ |
95 | explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result) |
96 | : virtual_device_(VirtualDevice::FullyUnconstrained()), |
97 | args_and_result_(std::move(args_and_result)) {} |
98 | |
99 | bool is_higher_order() const { return !args_and_result_.empty(); } |
100 | |
101 | VirtualDevice first_order_virtual_device() const { |
102 | ICHECK(args_and_result_.empty()) << "expecting domain to be first-order" ; |
103 | return virtual_device_; |
104 | } |
105 | |
106 | size_t function_arity() const { |
107 | ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order" ; |
108 | return args_and_result_.size() - 1UL; |
109 | } |
110 | |
111 | DeviceDomainPtr function_param(size_t i) const { |
112 | ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order" ; |
113 | ICHECK_LT(i + 1, args_and_result_.size()) << "parameter index is out of range" ; |
114 | return args_and_result_[i]; |
115 | } |
116 | |
117 | DeviceDomainPtr function_result() const { |
118 | ICHECK(!args_and_result_.empty()); |
119 | return args_and_result_.back(); |
120 | } |
121 | |
122 | private: |
123 | /*! |
124 | * \brief If this is a function domain then always fully unconstrained. Otherwise will be |
125 | * fully unconstrained (the domain is still completely free), partially constrained |
126 | * (for example, the \p target and \p device_type are constrained but the \p virtual_device_id and |
127 | * \p memory_scope are still unconstrained), or fully constrained (everything is known). |
128 | */ |
129 | const VirtualDevice virtual_device_; |
130 | |
131 | /*! |
132 | * \brief If this is a function domain then the sub-domains for each of the function's |
133 | * arguments, and the domain for its result. Otherwise empty. |
134 | */ |
135 | const std::vector<DeviceDomainPtr> args_and_result_; |
136 | |
137 | friend class DeviceDomains; |
138 | }; |
139 | |
140 | /*! |
141 | * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation |
142 | * built up by calls to \p UnifyOrNull. |
143 | */ |
144 | class DeviceDomains { |
145 | public: |
146 | explicit DeviceDomains(CompilationConfig config); |
147 | |
148 | const CompilationConfig& config() const { return config_; } |
149 | |
150 | /*! |
151 | * \brief Returns the domain representing \p virtual_device. If \p virtual_device is fully |
152 | * constrained then the domain will be unique that \p virtual_device. |
153 | */ |
154 | DeviceDomainPtr MakeFirstOrderDomain(const VirtualDevice& virtual_device); |
155 | |
156 | /*! |
157 | * \brief Returns a higher-order domain with \p args_and_results. |
158 | */ |
159 | DeviceDomainPtr MakeHigherOrderDomain(std::vector<DeviceDomainPtr> arg_and_results) { |
160 | return std::make_shared<DeviceDomain>(std::move(arg_and_results)); |
161 | } |
162 | |
163 | /*! |
164 | * \brief Returns a domain appropriate for \p type who's result domain is bound to \p |
165 | * virtual_device. If \p type is a function then all parameter domains will be completely free. It |
166 | * is valid for \p virtual_device to be fully unconstrained. |
167 | */ |
168 | DeviceDomainPtr MakeDomain(const Type& type, const VirtualDevice& virtual_device); |
169 | |
170 | /*! |
171 | * \brief Returns a domain with the given result appropriate \p non_canonical_virtual_device, |
172 | * which cannot be fully unconstrained. We first canonicalize the virtual device to unsure it has |
173 | * a target and is unique. |
174 | */ |
175 | DeviceDomainPtr ForVirtualDevice(const Type& type, |
176 | const VirtualDevice& non_canonical_virtual_device); |
177 | |
178 | /*! \brief Returns a free domain appropriate for \p type. */ |
179 | DeviceDomainPtr Free(const Type& type) { |
180 | return MakeDomain(type, VirtualDevice::FullyUnconstrained()); |
181 | } |
182 | |
183 | /*! \brief Returns the domain representing the equivalence class containing \p domain. */ |
184 | DeviceDomainPtr Lookup(DeviceDomainPtr domain); |
185 | |
186 | /*! |
187 | * \brief Returns the most constrained domain which agrees with both \p lhs and \p rhs. Returns |
188 | * null if no such domain exists, ie some first-order component of \p lhs is constrained |
189 | * differently than the corresponding component of \p rhs. |
190 | */ |
191 | DeviceDomainPtr JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); |
192 | |
193 | /*! |
194 | * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Returns null if |
195 | * \p lhs and \p rhs are not unifiable, in which case the constraint system may be left in |
196 | * a partially modified state. |
197 | */ |
198 | // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but |
199 | // given we have refs to functions I'm prepared to be surprised. |
200 | DeviceDomainPtr UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs); |
201 | |
202 | /* |
203 | * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. |
204 | * This can be used to handle functions within tuples, references and ADTs since we don't |
205 | * attempt to track anything beyond 'the device' for expressions of those first-order types. |
206 | * |
207 | * Returns false if any unification fails. |
208 | */ |
209 | bool CollapseOrFalse(const DeviceDomainPtr& first_order_domain, |
210 | const DeviceDomainPtr& higher_order_domain); |
211 | |
212 | /*! |
213 | * \brief Unifies \p lhs_first_order and \p rhs_maybe_higher_order. If \p rhs_maybe_higher_order |
214 | * is indeed higher-order, require all of its arguments and result to unify with |
215 | * \p lhs_first_order. Otherwise same as \p Unify. Returns false if unification is not possible. |
216 | * |
217 | * In an expression such as: |
218 | * \code |
219 | * (fn(...) {...}, ...).0 |
220 | * \endcode |
221 | * we need to force all the devices of the inner function to be the same as the device for the |
222 | * overall tuple since the device domain does not understand tuples. Similarly for references |
223 | * and ADTs. |
224 | */ |
225 | bool UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order, |
226 | const DeviceDomainPtr& rhs_maybe_higher_order); |
227 | |
228 | /*! \brief Returns true if a domain is known for \p expr. */ |
229 | bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } |
230 | |
231 | /*! \brief Returns the domain representing \p expr. */ |
232 | DeviceDomainPtr DomainFor(const Expr& expr); |
233 | |
234 | /*! |
235 | * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the |
236 | * callee is a primitive or special operation we handle it specially. Otherwise defers to \p |
237 | * DomainFor(call->op). |
238 | * |
239 | * This special handling is needed: |
240 | * - To handle the "on_device" and "device_copy" ops which constrain devices to the given |
241 | * devices. |
242 | * - To handle some special ops which constrain devices to the CPU. |
243 | * - To allow the same primitive to be called on different devices at different call sites. |
244 | * Since each call to the op can have a different domain we index the ops by the call expression |
245 | * rather than the op itself. |
246 | */ |
247 | DeviceDomainPtr DomainForCallee(const Call& call); |
248 | |
249 | /*! |
250 | * \brief Unifies the domains for expressions \p lhs and \p rhs. |
251 | * |
252 | * Aborts if unification fails. |
253 | */ |
254 | void UnifyExprExact(const Expr& lhs, const Expr& rhs); |
255 | |
256 | /*! |
257 | * \brief Attempts to unify the domains for expressions \p lhs and \p rhs, however if they |
258 | * cannot be unified then returns with no change to the unification system. |
259 | */ |
260 | void OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs); |
261 | |
262 | /*! |
263 | * \brief Unifies the domain for \p expr with \p expected_domain. |
264 | * |
265 | * Aborts if unification fails. |
266 | */ |
267 | void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain); |
268 | |
269 | /*! |
270 | * \brief Unifies the domain for \p expr with \p expected_domain. |
271 | * If \p expected_domain is higher-order but \p expr is first-order, require all arguments |
272 | * and the result of \p expected_domain to have the same domain as for \p expr. |
273 | * |
274 | * Aborts if unification fails. |
275 | */ |
276 | void UnifyExprCollapsed(const Expr& expr_first_order, |
277 | const DeviceDomainPtr& expected_domain_maybe_higher_order); |
278 | |
279 | /*! \brief Returns true if \p domain is fully constrainted. */ |
280 | bool IsFullyConstrained(DeviceDomainPtr domain); |
281 | |
282 | /*! \brief Force all \p VirtualDevices in \p domain to default to \p default_virtual_device. */ |
283 | void SetDefault(DeviceDomainPtr domain, const VirtualDevice& default_virtual_device); |
284 | |
285 | /*! |
286 | * \brief If \p domain is higher-order default it's result domain to \p default_virtual_device. |
287 | * Then force all remaining \p VirtualDevices to the result domain (freshly defaulted or |
288 | * original). If \p domain is first-order same as \p SetDefault. |
289 | */ |
290 | void SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, |
291 | const VirtualDevice& default_virtual_device); |
292 | |
293 | /*! |
294 | * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). |
295 | */ |
296 | DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); |
297 | |
298 | /*! |
299 | * \brief Returns the result \p VirtualDevice (possibly unconstrained) for \p domain |
300 | * (see defn in DeviceDomain comment). |
301 | */ |
302 | VirtualDevice ResultVirtualDevice(const DeviceDomainPtr& domain) { |
303 | return ResultDomain(domain)->first_order_virtual_device(); |
304 | } |
305 | |
306 | /*! \brief Returns one-line description of \p domain for debugging. */ |
307 | std::string ToString(DeviceDomainPtr domain); |
308 | |
309 | /*! \brief Returns description of entire system of constraints for debugging */ |
310 | std::string ToString(); |
311 | |
312 | private: |
313 | /*! \brief Intrinsics we need to handle specially. */ |
314 | const Op& alloc_storage_op = Op::Get("memory.alloc_storage" ); |
315 | const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor" ); |
316 | const Op& shape_of_op = Op::Get("vm.shape_of" ); |
317 | const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op" ); |
318 | const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor" ); |
319 | |
320 | CompilationConfig config_; |
321 | |
322 | /*! |
323 | * \brief The domain for first-order expressions of non-tensor type, such as shapes and |
324 | * buffer dimensions. Generally this will be a CPU. |
325 | */ |
326 | DeviceDomainPtr host_domain_; |
327 | |
328 | /*! \brief Maps expressions to their domains as determined during analysis. */ |
329 | std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_; |
330 | |
331 | /*! |
332 | * \brief Maps call expressions to the domains for their callee where the callee is a primitive. |
333 | */ |
334 | std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_; |
335 | |
336 | /*! \brief Maps device domains to their equivalent domains as determined during unification. */ |
337 | std::unordered_map<DeviceDomainPtr, DeviceDomainPtr> domain_to_equiv_; |
338 | |
339 | /*! |
340 | * \brief Maps fully constrained \p VirtualDevices to their corresponding domains. By sharing |
341 | * those domains we can ensure: |
342 | * |
343 | * \code |
344 | * domain0 != domain1 && domain0 fully constrained && domain1 fully constrained |
345 | * ==> domain0 and domain1 are incompatible |
346 | * \endcode |
347 | */ |
348 | std::unordered_map<VirtualDevice, DeviceDomainPtr, runtime::ObjectPtrHash, |
349 | runtime::ObjectPtrEqual> |
350 | fully_constrained_virtual_device_to_domain_; |
351 | }; |
352 | |
353 | } // namespace transform |
354 | } // namespace relay |
355 | } // namespace tvm |
356 | |
357 | #endif // TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ |
358 | |