1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/op/memory/on_device.h
22 * \brief Helpers for working with the "on_device" 'annotation' call.
23 */
24#ifndef TVM_RELAY_OP_MEMORY_ON_DEVICE_H_
25#define TVM_RELAY_OP_MEMORY_ON_DEVICE_H_
26
27#include <tvm/relay/attrs/on_device.h>
28#include <tvm/relay/expr.h>
29#include <tvm/relay/function.h>
30#include <tvm/runtime/ndarray.h>
31
32#include <utility>
33#include <vector>
34
35namespace tvm {
36namespace relay {
37
/*!
 * \brief Returns the "on_device" operator.
 *
 * Calls whose op is this operator are the "on_device" annotations handled by the helpers in this
 * header. See \p OnDeviceAttrs for the attributes such calls carry.
 */
const Op& OnDeviceOp();
40
/*!
 * \brief Wraps \p body in an "on_device" CallNode for \p virtual_device.
 *
 * Unlike \p MaybeOnDevice, this always constructs a new call.
 *
 * \param body The expression to wrap.
 * \param virtual_device The virtual device recorded on the annotation.
 * \param constrain_result Whether the result of the call is constrained to \p virtual_device.
 * \param constrain_body Whether \p body is constrained to \p virtual_device.
 * \return The new "on_device" call.
 *
 * See \p OnDeviceAttrs for an overview.
 */
Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false,
              bool constrain_body = true);
48
49/*! \brief Result of \p GetOnDeviceProps. */
50struct OnDeviceProps {
51 Expr body; // = null
52 VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained();
53 bool constrain_result = false;
54 bool constrain_body = false;
55
56 OnDeviceProps() = default;
57
58 OnDeviceProps(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body)
59 : body(std::move(body)),
60 virtual_device(std::move(virtual_device)),
61 constrain_result(constrain_result),
62 constrain_body(constrain_body) {}
63
64 bool is_fixed() const { return constrain_result && constrain_body; }
65 bool is_normal() const { return !constrain_result && constrain_body; }
66};
67
68/*!
69 * \brief Wraps \p body in an "on_device" CallNode, taking all fields other than \p body from \p
70 * props.
71 */
72inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) {
73 return OnDevice(std::move(body), props.virtual_device, props.constrain_result,
74 props.constrain_body);
75}
76
77/*!
78 * \brief Wraps \p body in an "on_device" CallNode, but don't constrain the body or result to
79 * any particular virtual device. This allows a "device_copy" to be inserted by PlanDevices
80 * where required, while at the same time not introducing unnecessary freedom in the device
81 * choices.
82 */
83inline Call OnDeviceCopyOk(Expr body) {
84 return OnDevice(std::move(body), VirtualDevice::FullyUnconstrained(),
85 /*constrain_result=*/false, /*constrain_body=*/false);
86}
87
/*!
 * \brief Wraps \p body in an "on_device" CallNode for \p virtual_device, with the result and body
 * constraints given by \p constrain_result and \p constrain_body. This only happens if the
 * \p VirtualDevice for \p body cannot otherwise be recovered by the lexical scoping convention.
 * This means we will NOT wrap if:
 *  - \p virtual_device is fully unconstrained, which signals there are no device annotations
 *    already in play.
 *  - \p body is an operator or primitive function literal. These are device polymorphic.
 *  - \p body is a non-primitive function literal. The device is captured by the
 *    "result_virtual_device" attribute on the function itself.
 *  - \p body is a global var. The device is on the function attributes the global is bound to.
 *  - \p body is a local var. The device is tracked by the device aware visitors for us.
 *  - \p body is a constructor. These are device polymorphic.
 * Nested on_device calls will never be constructed, they are instead merged on-the-fly.
 *
 * \return Either \p body itself (when no annotation is needed) or an annotated expression.
 */
Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false,
                   bool constrain_body = true);
104
105/*! \brief As for MaybeOnDevice, but with both body and result constrained. */
106inline Expr MaybeOnDeviceFixed(Expr body, VirtualDevice virtual_device) {
107 return MaybeOnDevice(std::move(body), std::move(virtual_device), /*constrain_result=*/true,
108 /*constrain_body=*/true);
109}
110
111/*! \brief As for MaybeOnDevice, but with fields other than body taken from \p props. */
112inline Expr MaybeOnDeviceWithProps(Expr body, const OnDeviceProps& props) {
113 return MaybeOnDevice(std::move(body), props.virtual_device, props.constrain_result,
114 props.constrain_body);
115}
116
/*!
 * \brief Returns the body expression, \p VirtualDevice and constraint flags
 * (\p constrain_result, \p constrain_body) for \p call_node if it is an "on_device" CallNode.
 * Otherwise returns an \p OnDeviceProps whose \p body is null.
 *
 * Callers detect the "not an on_device call" case by testing \p body.defined().
 * NOTE(review): the non-"on_device" result is presumably a default-constructed
 * \p OnDeviceProps (fully unconstrained \p VirtualDevice, both flags false); confirm against
 * the definition.
 */
OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
123
/*!
 * \brief Returns the body expression, \p VirtualDevice and constraint flags
 * (\p constrain_result, \p constrain_body) for \p expr if it is an "on_device" CallNode.
 * Otherwise returns an \p OnDeviceProps whose \p body is null.
 *
 * Callers detect the "not an on_device call" case by testing \p body.defined().
 * NOTE(review): the non-"on_device" result is presumably a default-constructed
 * \p OnDeviceProps (fully unconstrained \p VirtualDevice, both flags false); confirm against
 * the definition.
 */
OnDeviceProps GetOnDeviceProps(const Expr& expr);
130
131/*!
132 * \brief Returns the body of \p expr if it is an "on_device" annotation, otherwise returns
133 * \p expr directly.
134 */
135inline Expr IgnoreOnDevice(const Expr& expr) {
136 OnDeviceProps props = GetOnDeviceProps(expr);
137 return props.body.defined() ? props.body : expr;
138}
139
140/*!
141 * \brief Returns \p expr as \p NodeType, or null if it is not of that type. Looks through
142 * any "on_device" annotations.
143 */
144template <typename NodeType>
145const NodeType* AsIgnoringOnDevice(const Expr& expr) {
146 const auto* node = expr.as<NodeType>();
147 if (node != nullptr) {
148 return node;
149 }
150 OnDeviceProps props = GetOnDeviceProps(expr);
151 if (!props.body.defined()) {
152 return nullptr;
153 }
154 return props.body.as<NodeType>();
155}
156
157} // namespace relay
158} // namespace tvm
159
160#endif // TVM_RELAY_OP_MEMORY_ON_DEVICE_H_
161