1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file src/relay/op/memory/on_device.cc |
23 | * \brief Helpers for working with the "on_device" 'annotation' call. |
24 | */ |
25 | |
26 | #include "./on_device.h" |
27 | |
28 | #include <tvm/relay/attrs/annotation.h> |
29 | #include <tvm/relay/expr.h> |
30 | #include <tvm/relay/op.h> |
31 | #include <tvm/relay/op_attr_types.h> |
32 | #include <tvm/relay/transform.h> |
33 | |
34 | #include "../../transforms/infer_layout_utils.h" |
35 | #include "../type_relations.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
// Register the node type of the attributes object attached to every "on_device" call
// (constructed in OnDevice and decoded in GetOnDeviceProps).
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
41 | |
42 | const Op& OnDeviceOp() { |
43 | static const Op& op = Op::Get("on_device" ); |
44 | return op; |
45 | } |
46 | |
47 | Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body) { |
48 | ICHECK((!constrain_result && !constrain_body) || !virtual_device->IsFullyUnconstrained()); |
49 | auto attrs = make_object<OnDeviceAttrs>(); |
50 | attrs->virtual_device = (constrain_result || constrain_body) |
51 | ? std::move(virtual_device) |
52 | : VirtualDevice::FullyUnconstrained(); |
53 | attrs->constrain_result = constrain_result; |
54 | attrs->constrain_body = constrain_body; |
55 | Span span = body->span; // about to be moved |
56 | return Call(OnDeviceOp(), {std::move(body)}, Attrs(std::move(attrs)), /*type_args=*/{}, |
57 | std::move(span)); |
58 | } |
59 | |
// Expose OnDevice as the global packed function "relay.op.annotation._make.OnDevice"
// (presumably invoked from the Python frontend -- confirm against callers).
TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice);
61 | |
62 | Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result, |
63 | bool constrain_body) { |
64 | if (virtual_device->IsFullyUnconstrained()) { |
65 | // Nothing to annotate with. |
66 | return body; |
67 | } |
68 | if (body->IsInstance<OpNode>() || body->IsInstance<ConstructorNode>()) { |
69 | // These operators are device polymorphic so no annotation is required. |
70 | return body; |
71 | } |
72 | if (body->IsInstance<GlobalVarNode>() || body->IsInstance<VarNode>()) { |
73 | // The device can be recovered from the binding site of the global or local variable. |
74 | return body; |
75 | } |
76 | if (body->IsInstance<FunctionNode>()) { |
77 | // If a primitive function then it is device polymorphic. Otherwise the device is captured |
78 | // by the function's "result_virtual_device" attribute. |
79 | return body; |
80 | } |
81 | OnDeviceProps props = GetOnDeviceProps(body); |
82 | if (props.body.defined()) { |
83 | // The user is asking for |
84 | // on_device(on_device(body, virtual_device=inner), virtual_device=outer) |
85 | // ^ ^ ^ |
86 | // outer middle inner |
87 | // First recover the implied constraints (if any) for outer and inner, and check they don't |
88 | // contradict. |
89 | const VirtualDevice& inner = props.virtual_device; |
90 | const VirtualDevice& outer = virtual_device; |
91 | bool constrain_outer = constrain_result; |
92 | bool constrain_inner = props.constrain_body; |
93 | if (constrain_outer && constrain_inner) { |
94 | ICHECK(inner == outer) << "Cannot constrain result and body of nested on_device calls to " |
95 | "different virtual devices" ; |
96 | } |
97 | // There are two possible ways the middle sub-expression may be constrained, check they don't |
98 | // contradict. |
99 | bool constrain_middle_via_outer = constrain_body; |
100 | bool constrain_middle_via_inner = props.constrain_result; |
101 | if (constrain_middle_via_outer && constrain_middle_via_inner) { |
102 | ICHECK(inner == outer) << "Cannot constrain intermediate result of nested on_device calls to " |
103 | "different virtual devices" ; |
104 | } |
105 | // We can now ignore the middle constraint. |
106 | // If the outer on_device has any constraint then use virtual_device given for it. |
107 | // Otherwise we can use the existing inner virtual_device. |
108 | return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner, |
109 | constrain_outer, constrain_inner); |
110 | } else { |
111 | return OnDevice(body, std::move(virtual_device), constrain_result, constrain_body); |
112 | } |
113 | } |
114 | |
// Registers the "on_device" operator: a single-argument annotation ("body") whose attributes are
// OnDeviceAttrs (see OnDevice above for how they are populated).
RELAY_REGISTER_OP("on_device")
    .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("body", "Expr", "The sub-expression to be annotated.")
    .set_support_level(10)
    // Type relation "Identity": presumably the result type equals the argument type -- see
    // IdentityRel in ../type_relations.h to confirm.
    .add_type_rel("Identity", IdentityRel)
    .set_attrs_type_key("relay.attrs.OnDeviceAttrs")
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    // Layout inference treats the annotation like an element-wise op over an arbitrary layout.
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    // Flagged non-computational: the annotation carries placement information only.
    .set_attr<TNonComputational>("TNonComputational", true);
126 | |
127 | OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { |
128 | if (call_node->op == OnDeviceOp()) { |
129 | ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument" ; |
130 | ICHECK(call_node->attrs.defined()) << "on_device requires attributes" ; |
131 | const auto* on_device_attrs = call_node->attrs.as<OnDeviceAttrs>(); |
132 | ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs" ; |
133 | return {call_node->args[0], on_device_attrs->virtual_device, on_device_attrs->constrain_result, |
134 | on_device_attrs->constrain_body}; |
135 | } |
136 | return {}; |
137 | } |
138 | |
139 | OnDeviceProps GetOnDeviceProps(const Expr& expr) { |
140 | if (const auto* call_node = expr.as<CallNode>()) { |
141 | return GetOnDeviceProps(call_node); |
142 | } |
143 | return {}; |
144 | } |
145 | |
146 | } // namespace relay |
147 | } // namespace tvm |
148 | |