1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file upsampling.cc |
22 | * \brief upsampling operator |
23 | */ |
24 | |
25 | #include "upsampling.h" |
26 | |
27 | #include <tvm/relay/attrs/nn.h> |
28 | #include <tvm/relay/op.h> |
29 | #include <tvm/relay/op_attr_types.h> |
30 | #include <tvm/tir/data_layout.h> |
31 | |
32 | #include <utility> |
33 | #include <vector> |
34 | |
35 | #include "../op_common.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | |
// Register the attribute node types that the upsampling operators below declare
// through set_attrs_type<...>().
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
42 | |
43 | bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
44 | const TypeReporter& reporter) { |
45 | ICHECK_EQ(types.size(), 2); |
46 | const auto* data = types[0].as<TensorTypeNode>(); |
47 | if (data == nullptr) return false; |
48 | |
49 | static const Layout kNCHW("NCHW" ); |
50 | |
51 | const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>(); |
52 | ICHECK(param != nullptr); |
53 | const Layout in_layout(param->layout); |
54 | |
55 | auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); |
56 | ICHECK(layout_converter.defined()) |
57 | << "UpSampling only support input layouts that are convertible from NCHW." |
58 | << " But got " << in_layout; |
59 | |
60 | auto oshape = layout_converter.ForwardShape(data->shape); |
61 | oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h))); |
62 | oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w))); |
63 | |
64 | // assign output type |
65 | reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); |
66 | return true; |
67 | } |
68 | |
69 | // Positional relay function to create upsampling operator |
70 | // used by frontend FFI. |
71 | Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method, |
72 | bool align_corners) { |
73 | auto attrs = make_object<UpSamplingAttrs>(); |
74 | attrs->layout = std::move(layout); |
75 | attrs->method = std::move(method); |
76 | attrs->scale_h = scale_h; |
77 | attrs->scale_w = scale_w; |
78 | attrs->align_corners = align_corners; |
79 | static const Op& op = Op::Get("nn.upsampling" ); |
80 | return Call(op, {data}, Attrs(attrs), {}); |
81 | } |
82 | |
// Expose MakeUpSampling under the global function name "relay.op.nn._make.upsampling".
TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling" ).set_body_typed(MakeUpSampling);
84 | |
85 | RELAY_REGISTER_OP("nn.upsampling" ) |
86 | .describe( |
87 | R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. |
88 | |
89 | - **data**: data is 4D array of shape |
90 | (batch_size, channels, in_height, in_width) for NCHW |
91 | (batch_size, in_height, in_width, channels) for NHWC |
92 | |
93 | - **out**: Output is 4D array of shape |
94 | for layout NCHW |
95 | (batch_size, channels, in_height*scale, in_width*scale) |
96 | |
97 | for layout NHWC |
98 | (batch_size, in_height*scale, in_width*scale, channels) |
99 | |
100 | )code" TVM_ADD_FILELINE) |
101 | .set_attrs_type<UpSamplingAttrs>() |
102 | .set_num_inputs(1) |
103 | .add_argument("data" , "Tensor" , "The input tensor." ) |
104 | .set_support_level(2) |
105 | .add_type_rel("UpSampling" , UpSamplingRel) |
106 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , |
107 | UpsamplingInferCorrectLayout<UpSamplingAttrs>) |
108 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
109 | |
110 | // UpSampling3D |
111 | bool UpSampling3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
112 | const TypeReporter& reporter) { |
113 | ICHECK_EQ(types.size(), 2); |
114 | const auto* data = types[0].as<TensorTypeNode>(); |
115 | if (data == nullptr) return false; |
116 | |
117 | static const Layout kNCDHW("NCDHW" ); |
118 | |
119 | const UpSampling3DAttrs* param = attrs.as<UpSampling3DAttrs>(); |
120 | ICHECK(param != nullptr); |
121 | const Layout in_layout(param->layout); |
122 | |
123 | auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW); |
124 | ICHECK(layout_converter.defined()) |
125 | << "UpSampling3D only support input layouts that are convertible from NCDHW." |
126 | << " But got " << in_layout; |
127 | |
128 | auto oshape = layout_converter.ForwardShape(data->shape); |
129 | oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d))); |
130 | oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h))); |
131 | oshape.Set(4, tir::Cast(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w))); |
132 | |
133 | // assign output type |
134 | reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); |
135 | return true; |
136 | } |
137 | |
138 | // Positional relay function to create upsampling3d operator |
139 | // used by frontend FFI. |
140 | Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, String layout, |
141 | String method, String coordinate_transformation_mode) { |
142 | auto attrs = make_object<UpSampling3DAttrs>(); |
143 | attrs->layout = std::move(layout); |
144 | attrs->method = std::move(method); |
145 | attrs->scale_d = scale_d; |
146 | attrs->scale_h = scale_h; |
147 | attrs->scale_w = scale_w; |
148 | attrs->coordinate_transformation_mode = coordinate_transformation_mode; |
149 | static const Op& op = Op::Get("nn.upsampling3d" ); |
150 | return Call(op, {data}, Attrs(attrs), {}); |
151 | } |
152 | |
// Expose MakeUpSampling3D under the global function name "relay.op.nn._make.upsampling3d".
TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d" ).set_body_typed(MakeUpSampling3D);
154 | |
155 | RELAY_REGISTER_OP("nn.upsampling3d" ) |
156 | .describe(R"code(Perform upsampling on input array with nearest neighbour or |
157 | bilinear interpolation. |
158 | |
159 | - **data**: data is 5D array of shape |
160 | (batch_size, channels, in_depth, in_height, in_width) for NCDHW |
161 | (batch_size, in_depth, in_height, in_width, channels) for NDHWC |
162 | |
163 | - **out**: Output is 5D array of shape |
164 | for layout NCDHW |
165 | (batch_size, channels, in_depth*scale, in_height*scale, in_width*scale) |
166 | |
167 | for layout NDHWC |
168 | (batch_size, in_depth*scale, in_height*scale, in_width*scale, channels) |
169 | |
170 | )code" TVM_ADD_FILELINE) |
171 | .set_attrs_type<UpSampling3DAttrs>() |
172 | .set_num_inputs(1) |
173 | .add_argument("data" , "Tensor" , "The input tensor." ) |
174 | .set_support_level(2) |
175 | .add_type_rel("UpSampling3D" , UpSampling3DRel) |
176 | .set_attr<FInferCorrectLayout>("FInferCorrectLayout" , |
177 | UpsamplingInferCorrectLayout<UpSampling3DAttrs>) |
178 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
179 | |
180 | } // namespace relay |
181 | } // namespace tvm |
182 | |