1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file src/relay/transforms/device_domains.h
22 * \brief Unification domain for the device planner.
23 */
24
25#ifndef TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_
26#define TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_
27
28#include <dlpack/dlpack.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/type.h>
31#include <tvm/runtime/ndarray.h>
32#include <tvm/target/compilation_config.h>
33#include <tvm/target/virtual_device.h>
34
35#include <memory>
36#include <string>
37#include <unordered_map>
38#include <utility>
39#include <vector>
40
41namespace tvm {
42namespace relay {
43namespace transform {
44
/*! \brief Forward declaration so that \p DeviceDomainPtr can be named before the class body. */
class DeviceDomain;
/*! \brief Shared-ownership handle by which device domains are passed around and stored. */
using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
/*! \brief Forward declaration; \p DeviceDomain declares this class as a friend. */
class DeviceDomains;
48
49/*!
50 * \brief Represents the domain over which we collect equality constraints.
51 *
52 * \code
53 * D ::= ?x? -- first order, free
54 * | <virtual_device> -- first order, bound to specific virtual device
55 * | fn(D1, ..., Dn):Dr -- higher order
56 * \endcode
57 *
58 * We require a function value to be on the same device as its result. To support that we need
59 * a notion of the 'result domain' of a domain:
60 * \code
61 * result_domain(?x?) = ?x?
62 * result_domain(<virtual_device>) = <virtual_device>
63 * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
64 * \endcode
65 *
66 * TODO(mbs): We currently don't allow sub-VirtualDevice constraints. Eg for a function we can
67 * express that the argument and result VirtualDevices must be exactly equal, but we cannot express
68 * that though the devices and targets for arguments and results must be equal, it is ok for
69 * memory scopes to differ. At the moment we can get away with this since we run PlanDevices
70 * twice: once with all memory scopes unconstrained, then again with just memory scopes as
71 * the new property to flow. However we're on thin ice here and better would be to allow
72 * constraints on VirtualDevices to be exploded into their device/target component and their
73 * memory scope component. Should we fold layout constraints into VirtualDevices then they would
74 * probably be grouped with memory scopes.
75 */
76class DeviceDomain {
77 public:
78 /*!
79 * \brief Constructs a first-order domain for \p virtual_device, which may be
80 * fully free (ie virtual_device is unconstrained), partially free (ie virtual_device has at
81 * least on of its target, device id or memory scopes known), or fully fixed (ie virtual_device
82 * has its target, device id and memory scopes set).
83 *
84 * CAUTION: Use DeviceDomains::MakeFirstOrderDomain instead of this ctor.
85 */
86 explicit DeviceDomain(VirtualDevice virtual_device)
87 : virtual_device_(std::move(virtual_device)) {}
88
89 /*!
90 * \brief Constructs a higher-order domain, where \p args_and_result contain the
91 * function argument and result domains in order.
92 *
93 * CAUTION: Use DeviceDomains::MakeHigherOrderDomain instead of this ctor.
94 */
95 explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
96 : virtual_device_(VirtualDevice::FullyUnconstrained()),
97 args_and_result_(std::move(args_and_result)) {}
98
99 bool is_higher_order() const { return !args_and_result_.empty(); }
100
101 VirtualDevice first_order_virtual_device() const {
102 ICHECK(args_and_result_.empty()) << "expecting domain to be first-order";
103 return virtual_device_;
104 }
105
106 size_t function_arity() const {
107 ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order";
108 return args_and_result_.size() - 1UL;
109 }
110
111 DeviceDomainPtr function_param(size_t i) const {
112 ICHECK(!args_and_result_.empty()) << "expecting domain to be higher-order";
113 ICHECK_LT(i + 1, args_and_result_.size()) << "parameter index is out of range";
114 return args_and_result_[i];
115 }
116
117 DeviceDomainPtr function_result() const {
118 ICHECK(!args_and_result_.empty());
119 return args_and_result_.back();
120 }
121
122 private:
123 /*!
124 * \brief If this is a function domain then always fully unconstrained. Otherwise will be
125 * fully unconstrained (the domain is still completely free), partially constrained
126 * (for example, the \p target and \p device_type are constrained but the \p virtual_device_id and
127 * \p memory_scope are still unconstrained), or fully constrained (everything is known).
128 */
129 const VirtualDevice virtual_device_;
130
131 /*!
132 * \brief If this is a function domain then the sub-domains for each of the function's
133 * arguments, and the domain for its result. Otherwise empty.
134 */
135 const std::vector<DeviceDomainPtr> args_and_result_;
136
137 friend class DeviceDomains;
138};
139
140/*!
141 * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
142 * built up by calls to \p UnifyOrNull.
143 */
144class DeviceDomains {
145 public:
146 explicit DeviceDomains(CompilationConfig config);
147
148 const CompilationConfig& config() const { return config_; }
149
150 /*!
151 * \brief Returns the domain representing \p virtual_device. If \p virtual_device is fully
152 * constrained then the domain will be unique that \p virtual_device.
153 */
154 DeviceDomainPtr MakeFirstOrderDomain(const VirtualDevice& virtual_device);
155
156 /*!
157 * \brief Returns a higher-order domain with \p args_and_results.
158 */
159 DeviceDomainPtr MakeHigherOrderDomain(std::vector<DeviceDomainPtr> arg_and_results) {
160 return std::make_shared<DeviceDomain>(std::move(arg_and_results));
161 }
162
163 /*!
164 * \brief Returns a domain appropriate for \p type who's result domain is bound to \p
165 * virtual_device. If \p type is a function then all parameter domains will be completely free. It
166 * is valid for \p virtual_device to be fully unconstrained.
167 */
168 DeviceDomainPtr MakeDomain(const Type& type, const VirtualDevice& virtual_device);
169
170 /*!
171 * \brief Returns a domain with the given result appropriate \p non_canonical_virtual_device,
172 * which cannot be fully unconstrained. We first canonicalize the virtual device to unsure it has
173 * a target and is unique.
174 */
175 DeviceDomainPtr ForVirtualDevice(const Type& type,
176 const VirtualDevice& non_canonical_virtual_device);
177
178 /*! \brief Returns a free domain appropriate for \p type. */
179 DeviceDomainPtr Free(const Type& type) {
180 return MakeDomain(type, VirtualDevice::FullyUnconstrained());
181 }
182
183 /*! \brief Returns the domain representing the equivalence class containing \p domain. */
184 DeviceDomainPtr Lookup(DeviceDomainPtr domain);
185
186 /*!
187 * \brief Returns the most constrained domain which agrees with both \p lhs and \p rhs. Returns
188 * null if no such domain exists, ie some first-order component of \p lhs is constrained
189 * differently than the corresponding component of \p rhs.
190 */
191 DeviceDomainPtr JoinOrNull(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs);
192
193 /*!
194 * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Returns null if
195 * \p lhs and \p rhs are not unifiable, in which case the constraint system may be left in
196 * a partially modified state.
197 */
198 // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
199 // given we have refs to functions I'm prepared to be surprised.
200 DeviceDomainPtr UnifyOrNull(DeviceDomainPtr lhs, DeviceDomainPtr rhs);
201
202 /*
203 * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain.
204 * This can be used to handle functions within tuples, references and ADTs since we don't
205 * attempt to track anything beyond 'the device' for expressions of those first-order types.
206 *
207 * Returns false if any unification fails.
208 */
209 bool CollapseOrFalse(const DeviceDomainPtr& first_order_domain,
210 const DeviceDomainPtr& higher_order_domain);
211
212 /*!
213 * \brief Unifies \p lhs_first_order and \p rhs_maybe_higher_order. If \p rhs_maybe_higher_order
214 * is indeed higher-order, require all of its arguments and result to unify with
215 * \p lhs_first_order. Otherwise same as \p Unify. Returns false if unification is not possible.
216 *
217 * In an expression such as:
218 * \code
219 * (fn(...) {...}, ...).0
220 * \endcode
221 * we need to force all the devices of the inner function to be the same as the device for the
222 * overall tuple since the device domain does not understand tuples. Similarly for references
223 * and ADTs.
224 */
225 bool UnifyCollapsedOrFalse(const DeviceDomainPtr& lhs_first_order,
226 const DeviceDomainPtr& rhs_maybe_higher_order);
227
228 /*! \brief Returns true if a domain is known for \p expr. */
229 bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
230
231 /*! \brief Returns the domain representing \p expr. */
232 DeviceDomainPtr DomainFor(const Expr& expr);
233
234 /*!
235 * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the
236 * callee is a primitive or special operation we handle it specially. Otherwise defers to \p
237 * DomainFor(call->op).
238 *
239 * This special handling is needed:
240 * - To handle the "on_device" and "device_copy" ops which constrain devices to the given
241 * devices.
242 * - To handle some special ops which constrain devices to the CPU.
243 * - To allow the same primitive to be called on different devices at different call sites.
244 * Since each call to the op can have a different domain we index the ops by the call expression
245 * rather than the op itself.
246 */
247 DeviceDomainPtr DomainForCallee(const Call& call);
248
249 /*!
250 * \brief Unifies the domains for expressions \p lhs and \p rhs.
251 *
252 * Aborts if unification fails.
253 */
254 void UnifyExprExact(const Expr& lhs, const Expr& rhs);
255
256 /*!
257 * \brief Attempts to unify the domains for expressions \p lhs and \p rhs, however if they
258 * cannot be unified then returns with no change to the unification system.
259 */
260 void OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs);
261
262 /*!
263 * \brief Unifies the domain for \p expr with \p expected_domain.
264 *
265 * Aborts if unification fails.
266 */
267 void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain);
268
269 /*!
270 * \brief Unifies the domain for \p expr with \p expected_domain.
271 * If \p expected_domain is higher-order but \p expr is first-order, require all arguments
272 * and the result of \p expected_domain to have the same domain as for \p expr.
273 *
274 * Aborts if unification fails.
275 */
276 void UnifyExprCollapsed(const Expr& expr_first_order,
277 const DeviceDomainPtr& expected_domain_maybe_higher_order);
278
279 /*! \brief Returns true if \p domain is fully constrainted. */
280 bool IsFullyConstrained(DeviceDomainPtr domain);
281
282 /*! \brief Force all \p VirtualDevices in \p domain to default to \p default_virtual_device. */
283 void SetDefault(DeviceDomainPtr domain, const VirtualDevice& default_virtual_device);
284
285 /*!
286 * \brief If \p domain is higher-order default it's result domain to \p default_virtual_device.
287 * Then force all remaining \p VirtualDevices to the result domain (freshly defaulted or
288 * original). If \p domain is first-order same as \p SetDefault.
289 */
290 void SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order,
291 const VirtualDevice& default_virtual_device);
292
293 /*!
294 * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
295 */
296 DeviceDomainPtr ResultDomain(DeviceDomainPtr domain);
297
298 /*!
299 * \brief Returns the result \p VirtualDevice (possibly unconstrained) for \p domain
300 * (see defn in DeviceDomain comment).
301 */
302 VirtualDevice ResultVirtualDevice(const DeviceDomainPtr& domain) {
303 return ResultDomain(domain)->first_order_virtual_device();
304 }
305
306 /*! \brief Returns one-line description of \p domain for debugging. */
307 std::string ToString(DeviceDomainPtr domain);
308
309 /*! \brief Returns description of entire system of constraints for debugging */
310 std::string ToString();
311
312 private:
313 /*! \brief Intrinsics we need to handle specially. */
314 const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
315 const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
316 const Op& shape_of_op = Op::Get("vm.shape_of");
317 const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
318 const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
319
320 CompilationConfig config_;
321
322 /*!
323 * \brief The domain for first-order expressions of non-tensor type, such as shapes and
324 * buffer dimensions. Generally this will be a CPU.
325 */
326 DeviceDomainPtr host_domain_;
327
328 /*! \brief Maps expressions to their domains as determined during analysis. */
329 std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_;
330
331 /*!
332 * \brief Maps call expressions to the domains for their callee where the callee is a primitive.
333 */
334 std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_;
335
336 /*! \brief Maps device domains to their equivalent domains as determined during unification. */
337 std::unordered_map<DeviceDomainPtr, DeviceDomainPtr> domain_to_equiv_;
338
339 /*!
340 * \brief Maps fully constrained \p VirtualDevices to their corresponding domains. By sharing
341 * those domains we can ensure:
342 *
343 * \code
344 * domain0 != domain1 && domain0 fully constrained && domain1 fully constrained
345 * ==> domain0 and domain1 are incompatible
346 * \endcode
347 */
348 std::unordered_map<VirtualDevice, DeviceDomainPtr, runtime::ObjectPtrHash,
349 runtime::ObjectPtrEqual>
350 fully_constrained_virtual_device_to_domain_;
351};
352
353} // namespace transform
354} // namespace relay
355} // namespace tvm
356
357#endif // TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_
358