1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/op/memory/on_device.h |
22 | * \brief Helpers for working with the "on_device" 'annotation' call. |
23 | */ |
24 | #ifndef TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ |
25 | #define TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ |
26 | |
27 | #include <tvm/relay/attrs/on_device.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/function.h> |
30 | #include <tvm/runtime/ndarray.h> |
31 | |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | |
/*!
 * \brief Returns the "on_device" operator.
 *
 * \return A reference to the "on_device" \p Op. NOTE(review): the definition lives outside this
 * header; presumably this is the registered singleton op — confirm there.
 */
const Op& OnDeviceOp();
40 | |
/*!
 * \brief Wraps \p body in an "on_device" CallNode for \p virtual_device.
 *
 * See \p OnDeviceAttrs for an overview.
 *
 * \param body The expression to annotate.
 * \param virtual_device The \p VirtualDevice recorded on the annotation.
 * \param constrain_result Whether the result of the "on_device" call is constrained to
 * \p virtual_device. Defaults to false.
 * \param constrain_body Whether \p body is constrained to \p virtual_device. Defaults to true.
 * \return The "on_device" call wrapping \p body.
 */
Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false,
              bool constrain_body = true);
48 | |
49 | /*! \brief Result of \p GetOnDeviceProps. */ |
50 | struct OnDeviceProps { |
51 | Expr body; // = null |
52 | VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained(); |
53 | bool constrain_result = false; |
54 | bool constrain_body = false; |
55 | |
56 | OnDeviceProps() = default; |
57 | |
58 | OnDeviceProps(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body) |
59 | : body(std::move(body)), |
60 | virtual_device(std::move(virtual_device)), |
61 | constrain_result(constrain_result), |
62 | constrain_body(constrain_body) {} |
63 | |
64 | bool is_fixed() const { return constrain_result && constrain_body; } |
65 | bool is_normal() const { return !constrain_result && constrain_body; } |
66 | }; |
67 | |
68 | /*! |
69 | * \brief Wraps \p body in an "on_device" CallNode, taking all fields other than \p body from \p |
70 | * props. |
71 | */ |
72 | inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { |
73 | return OnDevice(std::move(body), props.virtual_device, props.constrain_result, |
74 | props.constrain_body); |
75 | } |
76 | |
77 | /*! |
78 | * \brief Wraps \p body in an "on_device" CallNode, but don't constrain the body or result to |
79 | * any particular virtual device. This allows a "device_copy" to be inserted by PlanDevices |
80 | * where required, while at the same time not introducing unnecessary freedom in the device |
81 | * choices. |
82 | */ |
83 | inline Call OnDeviceCopyOk(Expr body) { |
84 | return OnDevice(std::move(body), VirtualDevice::FullyUnconstrained(), |
85 | /*constrain_result=*/false, /*constrain_body=*/false); |
86 | } |
87 | |
/*!
 * \brief Wraps \p body in an "on_device" CallNode for \p virtual_device, \p constrain_result and
 * \p constrain_body, but only if the \p VirtualDevice for \p body cannot otherwise be recovered by
 * the lexical scoping convention. This means we will NOT wrap if:
 *  - \p virtual_device is fully unconstrained, which signals there are no device annotations
 *    already in play.
 *  - \p body is an operator or primitive function literal. These are device polymorphic.
 *  - \p body is a non-primitive function literal. The device is captured by the
 *    "result_virtual_device" attribute on the function itself.
 *  - \p body is a global var. The device is on the function attributes the global is bound to.
 *  - \p body is a local var. The device is tracked by the device aware visitors for us.
 *  - \p body is a constructor. These are device polymorphic.
 * Nested on_device calls will never be constructed; they are instead merged on-the-fly.
 *
 * \param body The expression to (maybe) annotate.
 * \param virtual_device The \p VirtualDevice to record on the annotation.
 * \param constrain_result Whether the result is constrained to \p virtual_device.
 * \param constrain_body Whether \p body is constrained to \p virtual_device.
 * \return Either the "on_device" call wrapping \p body, or an expression not wrapped in a new
 * "on_device" call when one of the cases above applies.
 */
Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false,
                   bool constrain_body = true);
104 | |
105 | /*! \brief As for MaybeOnDevice, but with both body and result constrained. */ |
106 | inline Expr MaybeOnDeviceFixed(Expr body, VirtualDevice virtual_device) { |
107 | return MaybeOnDevice(std::move(body), std::move(virtual_device), /*constrain_result=*/true, |
108 | /*constrain_body=*/true); |
109 | } |
110 | |
111 | /*! \brief As for MaybeOnDevice, but with fields other than body taken from \p props. */ |
112 | inline Expr MaybeOnDeviceWithProps(Expr body, const OnDeviceProps& props) { |
113 | return MaybeOnDevice(std::move(body), props.virtual_device, props.constrain_result, |
114 | props.constrain_body); |
115 | } |
116 | |
/*!
 * \brief Decomposes \p call_node if it is an "on_device" CallNode.
 *
 * \return The body expression, \p VirtualDevice, and the \p constrain_result and
 * \p constrain_body flags of \p call_node if it is an "on_device" CallNode. Otherwise returns
 * props with the null expression as \p body and the fully unconstrained \p VirtualDevice; the two
 * flags are then presumably false (the \p OnDeviceProps defaults) — confirm in the definition.
 * Callers detect the "not an on_device call" case by testing \p body.defined().
 */
OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
123 | |
/*!
 * \brief Decomposes \p expr if it is an "on_device" CallNode.
 *
 * \return The body expression, \p VirtualDevice, and the \p constrain_result and
 * \p constrain_body flags of \p expr if it is an "on_device" CallNode. Otherwise returns props
 * with the null expression as \p body and the fully unconstrained \p VirtualDevice; the two flags
 * are then presumably false (the \p OnDeviceProps defaults) — confirm in the definition.
 * Callers detect the "not an on_device call" case by testing \p body.defined().
 */
OnDeviceProps GetOnDeviceProps(const Expr& expr);
130 | |
131 | /*! |
132 | * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns |
133 | * \p expr directly. |
134 | */ |
135 | inline Expr IgnoreOnDevice(const Expr& expr) { |
136 | OnDeviceProps props = GetOnDeviceProps(expr); |
137 | return props.body.defined() ? props.body : expr; |
138 | } |
139 | |
140 | /*! |
141 | * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through |
142 | * any "on_device" annotations. |
143 | */ |
144 | template <typename NodeType> |
145 | const NodeType* AsIgnoringOnDevice(const Expr& expr) { |
146 | const auto* node = expr.as<NodeType>(); |
147 | if (node != nullptr) { |
148 | return node; |
149 | } |
150 | OnDeviceProps props = GetOnDeviceProps(expr); |
151 | if (!props.body.defined()) { |
152 | return nullptr; |
153 | } |
154 | return props.body.as<NodeType>(); |
155 | } |
156 | |
157 | } // namespace relay |
158 | } // namespace tvm |
159 | |
160 | #endif // TVM_RELAY_OP_MEMORY_ON_DEVICE_H_ |
161 | |