1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/op/memory/device_copy.cc
 * \brief Registers the "device_copy" Relay operator and provides helpers for building and
 *        inspecting "device_copy" calls and their attributes.
23 */
24
25#include "./device_copy.h"
26
27#include <tvm/relay/attrs/annotation.h>
28#include <tvm/relay/attrs/call.h>
29#include <tvm/relay/attrs/device_copy.h>
30#include <tvm/relay/expr.h>
31#include <tvm/relay/op.h>
32#include <tvm/relay/op_attr_types.h>
33#include <tvm/topi/elemwise.h>
34
35#include <utility>
36
37#include "../../transforms/infer_layout_utils.h"
38#include "../annotation/annotation.h"
39#include "../call/call.h"
40#include "../type_relations.h"
41
42namespace tvm {
43namespace relay {
44
// relay.device_copy
// Registers the DeviceCopyAttrs node type with TVM's node registry (presumably so the
// attributes can be created/reflected by type key, e.g. "relay.attrs.DeviceCopyAttrs"
// as used by set_attrs_type_key below — confirm against TVM_REGISTER_NODE_TYPE docs).
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
47
48const Op& DeviceCopyOp() {
49 static const Op& op = Op::Get("device_copy");
50 return op;
51}
52
53Expr DeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device) {
54 ICHECK(!src_virtual_device->IsFullyUnconstrained());
55 ICHECK(!dst_virtual_device->IsFullyUnconstrained());
56 auto attrs = make_object<DeviceCopyAttrs>();
57 attrs->src_virtual_device = std::move(src_virtual_device);
58 attrs->dst_virtual_device = std::move(dst_virtual_device);
59 Span span = expr->span;
60 return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{},
61 std::move(span));
62}
63
// Exposes DeviceCopy as the global packed function "relay.op._make.DeviceCopy"
// (presumably called from the Python frontend's op constructors — confirm).
TVM_REGISTER_GLOBAL("relay.op._make.DeviceCopy").set_body_typed(DeviceCopy);
65
66Expr MaybeDeviceCopy(Expr expr, VirtualDevice src_virtual_device,
67 VirtualDevice dst_virtual_device) {
68 if (src_virtual_device == dst_virtual_device) {
69 // No copy needed.
70 return expr;
71 }
72 return DeviceCopy(std::move(expr), std::move(src_virtual_device), std::move(dst_virtual_device));
73}
74
// Registers the "device_copy" operator: a single tensor input whose output type is
// given by the "Identity" type relation (presumably output type == input type — see
// IdentityRel in ../type_relations.h), with attributes of type DeviceCopyAttrs.
RELAY_REGISTER_OP("device_copy")
    .describe(R"code(
Copy data from one tensor to another. The source and destination might be
on different devices.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input data.")
    .set_support_level(10)
    .add_type_rel("Identity", IdentityRel)
    .set_attrs_type_key("relay.attrs.DeviceCopyAttrs")
    // NOTE(review): kOpaque presumably keeps fusion passes from merging the copy with
    // neighbouring ops — confirm against the TOpPattern documentation.
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    // Layout inference reuses the elementwise helper (assumed: the output keeps the
    // input's layout — verify ElemwiseArbitraryLayout).
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
    // The compute is topi::identity on the single input; the attrs and out_dtype are
    // unused. Nothing here models the transfer between devices itself (presumably
    // handled when the call is lowered — TODO confirm).
    .set_attr<FTVMCompute>("FTVMCompute",
                           [](const Attrs& attrs, const Array<te::Tensor>& inputs,
                              const Type& out_dtype) -> Array<te::Tensor> {
                             return {topi::identity(inputs[0])};
                           });
93
94// Get device copy props for original device copy op
95DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) {
96 if (call_node->op == DeviceCopyOp()) {
97 ICHECK_EQ(call_node->args.size(), 1) << "device_copy expects one argument";
98 ICHECK(call_node->attrs.defined()) << "device_copy requires attributes";
99 const auto* device_copy_attrs = call_node->attrs.as<DeviceCopyAttrs>();
100 ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs";
101 // Follow nesting:
102 // device_copy(device_copy(expr, src_virtual_device=S, dst_virtual_device=T),
103 // src_virtual_device=T, dst_virtual_device=U) ==> {expr, S, U}
104 auto inner = GetDeviceCopyProps(call_node->args[0]);
105 if (inner.body.defined()) {
106 return {inner.body, inner.src_virtual_device, device_copy_attrs->dst_virtual_device};
107 } else {
108 return {call_node->args[0], device_copy_attrs->src_virtual_device,
109 device_copy_attrs->dst_virtual_device};
110 }
111 }
112 return {};
113}
114
115DeviceCopyProps GetDeviceCopyProps(const Expr& expr) {
116 if (const auto* call_node = expr.as<CallNode>()) {
117 return GetDeviceCopyProps(call_node);
118 }
119 return {};
120}
121
122} // namespace relay
123} // namespace tvm
124