1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/op/memory/device_copy.cc |
 * \brief Registration of the "device_copy" operator and helpers for building and inspecting
 * device_copy calls.
23 | */ |
24 | |
25 | #include "./device_copy.h" |
26 | |
27 | #include <tvm/relay/attrs/annotation.h> |
28 | #include <tvm/relay/attrs/call.h> |
29 | #include <tvm/relay/attrs/device_copy.h> |
30 | #include <tvm/relay/expr.h> |
31 | #include <tvm/relay/op.h> |
32 | #include <tvm/relay/op_attr_types.h> |
33 | #include <tvm/topi/elemwise.h> |
34 | |
35 | #include <utility> |
36 | |
37 | #include "../../transforms/infer_layout_utils.h" |
38 | #include "../annotation/annotation.h" |
39 | #include "../call/call.h" |
40 | #include "../type_relations.h" |
41 | |
42 | namespace tvm { |
43 | namespace relay { |
44 | |
// relay.device_copy
// Registers DeviceCopyAttrs (the attributes carried by every device_copy call) as a node type
// with TVM's object system.
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
47 | |
48 | const Op& DeviceCopyOp() { |
49 | static const Op& op = Op::Get("device_copy" ); |
50 | return op; |
51 | } |
52 | |
53 | Expr DeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device) { |
54 | ICHECK(!src_virtual_device->IsFullyUnconstrained()); |
55 | ICHECK(!dst_virtual_device->IsFullyUnconstrained()); |
56 | auto attrs = make_object<DeviceCopyAttrs>(); |
57 | attrs->src_virtual_device = std::move(src_virtual_device); |
58 | attrs->dst_virtual_device = std::move(dst_virtual_device); |
59 | Span span = expr->span; |
60 | return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, |
61 | std::move(span)); |
62 | } |
63 | |
// Publishes DeviceCopy in TVM's global function registry under "relay.op._make.DeviceCopy",
// so it can be invoked through the packed-function interface (presumably from the Python
// frontend -- confirm against callers).
TVM_REGISTER_GLOBAL("relay.op._make.DeviceCopy").set_body_typed(DeviceCopy);
65 | |
66 | Expr MaybeDeviceCopy(Expr expr, VirtualDevice src_virtual_device, |
67 | VirtualDevice dst_virtual_device) { |
68 | if (src_virtual_device == dst_virtual_device) { |
69 | // No copy needed. |
70 | return expr; |
71 | } |
72 | return DeviceCopy(std::move(expr), std::move(src_virtual_device), std::move(dst_virtual_device)); |
73 | } |
74 | |
// Registers the "device_copy" operator: a single-input op whose attributes are
// DeviceCopyAttrs (the source and destination virtual devices).
RELAY_REGISTER_OP("device_copy")
    .describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input data.")
    .set_support_level(10)
    // Type relation: IdentityRel -- presumably the output type equals the input type (confirm in
    // type_relations.h).
    .add_type_rel("Identity", IdentityRel)
    .set_attrs_type_key("relay.attrs.DeviceCopyAttrs")
    // Marked opaque and non-stateful for the op-attribute consumers that read these keys.
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    // Layout inference treats the op like an elementwise op that accepts any layout.
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    // The registered compute is just topi::identity on the single input; the attrs and out_dtype
    // arguments are unused. NOTE(review): the actual cross-device transfer is presumably
    // implemented elsewhere, not by this compute -- confirm.
    .set_attr<FTVMCompute>("FTVMCompute",
                           [](const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_dtype) -> Array<te::Tensor> {
                             return {topi::identity(inputs[0])};
                           });
93 | |
94 | // Get device copy props for original device copy op |
95 | DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { |
96 | if (call_node->op == DeviceCopyOp()) { |
97 | ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument" ; |
98 | ICHECK(call_node->attrs.defined()) << "device_copy requires attributes" ; |
99 | const auto* device_copy_attrs = call_node->attrs.as<DeviceCopyAttrs>(); |
100 | ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs" ; |
101 | // Follow nesting: |
102 | // device_copy(device_copy(expr, src_virtual_device=S, dst_virtual_device=T), |
103 | // src_virtual_device=T, dst_virtual_device=U) ==> {expr, S, U} |
104 | auto inner = GetDeviceCopyProps(call_node->args[0]); |
105 | if (inner.body.defined()) { |
106 | return {inner.body, inner.src_virtual_device, device_copy_attrs->dst_virtual_device}; |
107 | } else { |
108 | return {call_node->args[0], device_copy_attrs->src_virtual_device, |
109 | device_copy_attrs->dst_virtual_device}; |
110 | } |
111 | } |
112 | return {}; |
113 | } |
114 | |
115 | DeviceCopyProps GetDeviceCopyProps(const Expr& expr) { |
116 | if (const auto* call_node = expr.as<CallNode>()) { |
117 | return GetDeviceCopyProps(call_node); |
118 | } |
119 | return {}; |
120 | } |
121 | |
122 | } // namespace relay |
123 | } // namespace tvm |
124 | |