1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file src/relay/op/memory/on_device.cc
23 * \brief Helpers for working with the "on_device" 'annotation' call.
24 */
25
26#include "./on_device.h"
27
28#include <tvm/relay/attrs/annotation.h>
29#include <tvm/relay/expr.h>
30#include <tvm/relay/op.h>
31#include <tvm/relay/op_attr_types.h>
32#include <tvm/relay/transform.h>
33
34#include "../../transforms/infer_layout_utils.h"
35#include "../type_relations.h"
36
37namespace tvm {
38namespace relay {
39
// Registers OnDeviceAttrs as a TVM node type. OnDeviceAttrs is the attribute object carried by
// every "on_device" call built in this file (see OnDevice / GetOnDeviceProps).
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
41
42const Op& OnDeviceOp() {
43 static const Op& op = Op::Get("on_device");
44 return op;
45}
46
47Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body) {
48 ICHECK((!constrain_result && !constrain_body) || !virtual_device->IsFullyUnconstrained());
49 auto attrs = make_object<OnDeviceAttrs>();
50 attrs->virtual_device = (constrain_result || constrain_body)
51 ? std::move(virtual_device)
52 : VirtualDevice::FullyUnconstrained();
53 attrs->constrain_result = constrain_result;
54 attrs->constrain_body = constrain_body;
55 Span span = body->span; // about to be moved
56 return Call(OnDeviceOp(), {std::move(body)}, Attrs(std::move(attrs)), /*type_args=*/{},
57 std::move(span));
58}
59
// Exposes OnDevice through TVM's global function registry under the name
// "relay.op.annotation._make.OnDevice" (presumably consumed by the Python frontend — confirm).
TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice);
61
62Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result,
63 bool constrain_body) {
64 if (virtual_device->IsFullyUnconstrained()) {
65 // Nothing to annotate with.
66 return body;
67 }
68 if (body->IsInstance<OpNode>() || body->IsInstance<ConstructorNode>()) {
69 // These operators are device polymorphic so no annotation is required.
70 return body;
71 }
72 if (body->IsInstance<GlobalVarNode>() || body->IsInstance<VarNode>()) {
73 // The device can be recovered from the binding site of the global or local variable.
74 return body;
75 }
76 if (body->IsInstance<FunctionNode>()) {
77 // If a primitive function then it is device polymorphic. Otherwise the device is captured
78 // by the function's "result_virtual_device" attribute.
79 return body;
80 }
81 OnDeviceProps props = GetOnDeviceProps(body);
82 if (props.body.defined()) {
83 // The user is asking for
84 // on_device(on_device(body, virtual_device=inner), virtual_device=outer)
85 // ^ ^ ^
86 // outer middle inner
87 // First recover the implied constraints (if any) for outer and inner, and check they don't
88 // contradict.
89 const VirtualDevice& inner = props.virtual_device;
90 const VirtualDevice& outer = virtual_device;
91 bool constrain_outer = constrain_result;
92 bool constrain_inner = props.constrain_body;
93 if (constrain_outer && constrain_inner) {
94 ICHECK(inner == outer) << "Cannot constrain result and body of nested on_device calls to "
95 "different virtual devices";
96 }
97 // There are two possible ways the middle sub-expression may be constrained, check they don't
98 // contradict.
99 bool constrain_middle_via_outer = constrain_body;
100 bool constrain_middle_via_inner = props.constrain_result;
101 if (constrain_middle_via_outer && constrain_middle_via_inner) {
102 ICHECK(inner == outer) << "Cannot constrain intermediate result of nested on_device calls to "
103 "different virtual devices";
104 }
105 // We can now ignore the middle constraint.
106 // If the outer on_device has any constraint then use virtual_device given for it.
107 // Otherwise we can use the existing inner virtual_device.
108 return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner,
109 constrain_outer, constrain_inner);
110 } else {
111 return OnDevice(body, std::move(virtual_device), constrain_result, constrain_body);
112 }
113}
114
// Registers the "on_device" operator: a one-argument annotation whose result type is related to
// its argument by the "Identity" type relation (IdentityRel). It uses OnDeviceAttrs
// ("relay.attrs.OnDeviceAttrs") and is marked kOpaque, non-stateful, non-computational, and
// layout-agnostic via ElemwiseArbitraryLayout.
RELAY_REGISTER_OP("on_device")
    .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("body", "Expr", "The sub-expression to be annotated.")
    .set_support_level(10)
    .add_type_rel("Identity", IdentityRel)
    .set_attrs_type_key("relay.attrs.OnDeviceAttrs")
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    .set_attr<TNonComputational>("TNonComputational", true);
126
127OnDeviceProps GetOnDeviceProps(const CallNode* call_node) {
128 if (call_node->op == OnDeviceOp()) {
129 ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument";
130 ICHECK(call_node->attrs.defined()) << "on_device requires attributes";
131 const auto* on_device_attrs = call_node->attrs.as<OnDeviceAttrs>();
132 ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs";
133 return {call_node->args[0], on_device_attrs->virtual_device, on_device_attrs->constrain_result,
134 on_device_attrs->constrain_body};
135 }
136 return {};
137}
138
139OnDeviceProps GetOnDeviceProps(const Expr& expr) {
140 if (const auto* call_node = expr.as<CallNode>()) {
141 return GetOnDeviceProps(call_node);
142 }
143 return {};
144}
145
146} // namespace relay
147} // namespace tvm
148