1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file upsampling.cc
22 * \brief upsampling operator
23 */
24
25#include "upsampling.h"
26
27#include <tvm/relay/attrs/nn.h>
28#include <tvm/relay/op.h>
29#include <tvm/relay/op_attr_types.h>
30#include <tvm/tir/data_layout.h>
31
32#include <utility>
33#include <vector>
34
35#include "../op_common.h"
36
37namespace tvm {
38namespace relay {
39
// Register the attribute node types used by the two operators in this file
// (nn.upsampling uses UpSamplingAttrs, nn.upsampling3d uses UpSampling3DAttrs).
TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);
TVM_REGISTER_NODE_TYPE(UpSampling3DAttrs);
42
43bool UpSamplingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
44 const TypeReporter& reporter) {
45 ICHECK_EQ(types.size(), 2);
46 const auto* data = types[0].as<TensorTypeNode>();
47 if (data == nullptr) return false;
48
49 static const Layout kNCHW("NCHW");
50
51 const UpSamplingAttrs* param = attrs.as<UpSamplingAttrs>();
52 ICHECK(param != nullptr);
53 const Layout in_layout(param->layout);
54
55 auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
56 ICHECK(layout_converter.defined())
57 << "UpSampling only support input layouts that are convertible from NCHW."
58 << " But got " << in_layout;
59
60 auto oshape = layout_converter.ForwardShape(data->shape);
61 oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_h)));
62 oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_w)));
63
64 // assign output type
65 reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
66 return true;
67}
68
69// Positional relay function to create upsampling operator
70// used by frontend FFI.
71Expr MakeUpSampling(Expr data, double scale_h, double scale_w, String layout, String method,
72 bool align_corners) {
73 auto attrs = make_object<UpSamplingAttrs>();
74 attrs->layout = std::move(layout);
75 attrs->method = std::move(method);
76 attrs->scale_h = scale_h;
77 attrs->scale_w = scale_w;
78 attrs->align_corners = align_corners;
79 static const Op& op = Op::Get("nn.upsampling");
80 return Call(op, {data}, Attrs(attrs), {});
81}
82
83TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling").set_body_typed(MakeUpSampling);
84
85RELAY_REGISTER_OP("nn.upsampling")
86 .describe(
87 R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation.
88
89- **data**: data is 4D array of shape
90 (batch_size, channels, in_height, in_width) for NCHW
91 (batch_size, in_height, in_width, channels) for NHWC
92
93- **out**: Output is 4D array of shape
94 for layout NCHW
95 (batch_size, channels, in_height*scale, in_width*scale)
96
97 for layout NHWC
98 (batch_size, in_height*scale, in_width*scale, channels)
99
100)code" TVM_ADD_FILELINE)
101 .set_attrs_type<UpSamplingAttrs>()
102 .set_num_inputs(1)
103 .add_argument("data", "Tensor", "The input tensor.")
104 .set_support_level(2)
105 .add_type_rel("UpSampling", UpSamplingRel)
106 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
107 UpsamplingInferCorrectLayout<UpSamplingAttrs>)
108 .set_attr<TOpPattern>("TOpPattern", kInjective);
109
110// UpSampling3D
111bool UpSampling3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
112 const TypeReporter& reporter) {
113 ICHECK_EQ(types.size(), 2);
114 const auto* data = types[0].as<TensorTypeNode>();
115 if (data == nullptr) return false;
116
117 static const Layout kNCDHW("NCDHW");
118
119 const UpSampling3DAttrs* param = attrs.as<UpSampling3DAttrs>();
120 ICHECK(param != nullptr);
121 const Layout in_layout(param->layout);
122
123 auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
124 ICHECK(layout_converter.defined())
125 << "UpSampling3D only support input layouts that are convertible from NCDHW."
126 << " But got " << in_layout;
127
128 auto oshape = layout_converter.ForwardShape(data->shape);
129 oshape.Set(2, tir::Cast(oshape[2].dtype(), tvm::round(oshape[2] * param->scale_d)));
130 oshape.Set(3, tir::Cast(oshape[3].dtype(), tvm::round(oshape[3] * param->scale_h)));
131 oshape.Set(4, tir::Cast(oshape[4].dtype(), tvm::round(oshape[4] * param->scale_w)));
132
133 // assign output type
134 reporter->Assign(types[1], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
135 return true;
136}
137
138// Positional relay function to create upsampling3d operator
139// used by frontend FFI.
140Expr MakeUpSampling3D(Expr data, double scale_d, double scale_h, double scale_w, String layout,
141 String method, String coordinate_transformation_mode) {
142 auto attrs = make_object<UpSampling3DAttrs>();
143 attrs->layout = std::move(layout);
144 attrs->method = std::move(method);
145 attrs->scale_d = scale_d;
146 attrs->scale_h = scale_h;
147 attrs->scale_w = scale_w;
148 attrs->coordinate_transformation_mode = coordinate_transformation_mode;
149 static const Op& op = Op::Get("nn.upsampling3d");
150 return Call(op, {data}, Attrs(attrs), {});
151}
152
153TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d").set_body_typed(MakeUpSampling3D);
154
155RELAY_REGISTER_OP("nn.upsampling3d")
156 .describe(R"code(Perform upsampling on input array with nearest neighbour or
157bilinear interpolation.
158
159- **data**: data is 5D array of shape
160 (batch_size, channels, in_depth, in_height, in_width) for NCDHW
161 (batch_size, in_depth, in_height, in_width, channels) for NDHWC
162
163- **out**: Output is 5D array of shape
164 for layout NCDHW
165 (batch_size, channels, in_depth*scale, in_height*scale, in_width*scale)
166
167 for layout NDHWC
168 (batch_size, in_depth*scale, in_height*scale, in_width*scale, channels)
169
170)code" TVM_ADD_FILELINE)
171 .set_attrs_type<UpSampling3DAttrs>()
172 .set_num_inputs(1)
173 .add_argument("data", "Tensor", "The input tensor.")
174 .set_support_level(2)
175 .add_type_rel("UpSampling3D", UpSampling3DRel)
176 .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
177 UpsamplingInferCorrectLayout<UpSampling3DAttrs>)
178 .set_attr<TOpPattern>("TOpPattern", kInjective);
179
180} // namespace relay
181} // namespace tvm
182