1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file grid_sample.cc |
22 | * \brief affine_grid and grid_sample operator |
23 | */ |
24 | #include <tvm/relay/attrs/image.h> |
25 | #include <tvm/relay/op.h> |
26 | #include <tvm/tir/data_layout.h> |
27 | |
28 | #include "../op_common.h" |
29 | |
30 | namespace tvm { |
31 | namespace relay { |
32 | |
// relay.image.affine_grid
// Register AffineGridAttrs as a TVM node type; it is the attribute type
// carried by the image.affine_grid operator registered in this file.
TVM_REGISTER_NODE_TYPE(AffineGridAttrs);
35 | |
36 | bool AffineGridRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
37 | const TypeReporter& reporter) { |
38 | ICHECK_EQ(types.size(), 2); |
39 | const auto* data = types[0].as<TensorTypeNode>(); |
40 | if (data == nullptr) return false; |
41 | auto batch_size = data->shape[0]; |
42 | |
43 | const AffineGridAttrs* param = attrs.as<AffineGridAttrs>(); |
44 | ICHECK(param != nullptr); |
45 | |
46 | Array<IndexExpr> oshape; |
47 | |
48 | ICHECK(data->shape.size() == 3U && reporter->AssertEQ(data->shape[1], 2) && |
49 | reporter->AssertEQ(data->shape[2], 3)) |
50 | << "data should be an" |
51 | "affine matrix with shape [batch_size, 2, 3]" ; |
52 | ICHECK(param->target_shape.defined() && param->target_shape.size() == 2) |
53 | << "target_shape should be 2D" ; |
54 | oshape.push_back(batch_size); |
55 | oshape.push_back(2); |
56 | oshape.push_back(param->target_shape[0]); |
57 | oshape.push_back(param->target_shape[1]); |
58 | |
59 | // assign output type |
60 | reporter->Assign(types[1], TensorType(oshape, data->dtype)); |
61 | return true; |
62 | } |
63 | |
64 | // Positional relay function to create affine_grid operator |
65 | // used by frontend FFI. |
66 | Expr MakeAffineGrid(Expr data, Array<IndexExpr> target_shape) { |
67 | auto attrs = make_object<AffineGridAttrs>(); |
68 | attrs->target_shape = std::move(target_shape); |
69 | static const Op& op = Op::Get("image.affine_grid" ); |
70 | return Call(op, {data}, Attrs(attrs), {}); |
71 | } |
72 | |
// Expose MakeAffineGrid in the global function registry under the name
// "relay.op.image._make.affine_grid".
TVM_REGISTER_GLOBAL("relay.op.image._make.affine_grid").set_body_typed(MakeAffineGrid);

// Operator registration for image.affine_grid: a single tensor input ("data"),
// AffineGridAttrs as the attribute type, and AffineGridRel as the type relation.
RELAY_REGISTER_OP("image.affine_grid")
    .describe(R"code(affine_grid operator that generates 2D sampling grid.

This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
transformation is then applied on the sampling grid.

- **data**: data is 3D array of shape [batch, 2, 3], which defines an affine transformation.

- **out**: out is 4D array of shape [batch, 2, height, width], where each vector
:math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`

)code" TVM_ADD_FILELINE)
    .set_attrs_type<AffineGridAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The affine matrix.")
    .set_support_level(5)
    // Shape/dtype inference is delegated to AffineGridRel.
    .add_type_rel("AffineGrid", AffineGridRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);
94 | |
// relay.image.grid_sample
// Register GridSampleAttrs as a TVM node type; it is the attribute type
// carried by the image.grid_sample operator registered in this file.
TVM_REGISTER_NODE_TYPE(GridSampleAttrs);
97 | |
98 | bool GridSampleRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, |
99 | const TypeReporter& reporter) { |
100 | ICHECK_EQ(types.size(), 3); |
101 | const auto* data = types[0].as<TensorTypeNode>(); |
102 | const auto* grid = types[1].as<TensorTypeNode>(); |
103 | if (!data || !grid) return false; |
104 | const auto* param = attrs.as<GridSampleAttrs>(); |
105 | ICHECK(param); |
106 | const Layout in_layout(param->layout); |
107 | |
108 | if (data->shape.size() == 4) { |
109 | static const Layout kNCHW("NCHW" ); |
110 | auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW); |
111 | auto oshape = layout_converter.ForwardShape(data->shape); |
112 | oshape.Set(2, grid->shape[2]); |
113 | oshape.Set(3, grid->shape[3]); |
114 | |
115 | // assign output type |
116 | reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); |
117 | return true; |
118 | } else if (data->shape.size() == 5) { |
119 | static const Layout kNDCHW("NCDHW" ); |
120 | auto layout_converter = tir::BijectiveLayout(in_layout, kNDCHW); |
121 | auto oshape = layout_converter.ForwardShape(data->shape); |
122 | oshape.Set(2, grid->shape[2]); |
123 | oshape.Set(3, grid->shape[3]); |
124 | oshape.Set(4, grid->shape[4]); |
125 | |
126 | // assign output type |
127 | reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype)); |
128 | return true; |
129 | } |
130 | |
131 | return false; |
132 | } |
133 | |
134 | // Positional relay function to create affine_grid operator |
135 | // used by frontend FFI. |
136 | Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode, |
137 | bool align_corners) { |
138 | auto attrs = make_object<GridSampleAttrs>(); |
139 | attrs->method = std::move(method); |
140 | attrs->layout = std::move(layout); |
141 | attrs->padding_mode = std::move(padding_mode); |
142 | attrs->align_corners = std::move(align_corners); |
143 | |
144 | static const Op& op = Op::Get("image.grid_sample" ); |
145 | return Call(op, {data, grid}, Attrs(attrs), {}); |
146 | } |
147 | |
148 | TVM_REGISTER_GLOBAL("relay.op.image._make.grid_sample" ).set_body_typed(MakeGridSample); |
149 | |
150 | RELAY_REGISTER_OP("image.grid_sample" ) |
151 | .describe(R"code(Applies grid sampling to input feature map. |
152 | |
153 | Given :math:`data` and :math:`grid`, then the output is computed by |
154 | |
155 | .. math:: |
156 | |
157 | x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\ |
158 | y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\ |
159 | output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}]) |
160 | |
161 | For 5-D, the output is computed by |
162 | |
163 | .. math:: |
164 | |
165 | x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\ |
166 | y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\ |
167 | z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\ |
168 | output[batch, channel, z_{src}, y_{dst}, x_{dst}] |
169 | = G(data[batch, channel, z_{src}, y_{src}, x_{src}]) |
170 | |
171 | :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and |
172 | :math:`G()` denotes the interpolation function. |
173 | |
174 | The out-boundary points will be padded with zeros if padding_mode is "zeros", or |
175 | border pixel value if padding_mode is "border", or |
176 | inner pixel value if padding_mode is "reflection". |
177 | |
178 | The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to |
179 | (0, 0) and (h - 1, w - 1) of data if align_corners is "True", or |
180 | (-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is "False". |
181 | |
182 | The shape of the output will be |
183 | 4-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or |
184 | 5-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]). |
185 | |
186 | The operator assumes that :math:`data` and :math:`grid` has been normalized to [-1, 1]. |
187 | |
188 | grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample. |
189 | |
190 | - **data**: data is of 4-D shape (batch_size, channels, in_height, in_width), or |
191 | of 5-D shape (batch_size, channels, in_depth, in_height, in_width) |
192 | |
193 | - **grid**: grid is of 4-D shape [batch, 2, out_height, out_width] |
194 | where each vector :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`, |
195 | or of 5-D of shape [batch, 3, out_depth, out_height, out_width] |
196 | where each vector :math:`out[b, :, d, h, w]` represents the coordinate |
197 | :math:`(x, y, z)` |
198 | |
199 | - **out**: out is of 4-D shape (batch, in_channel, out_height, out_width), or |
200 | of 5-D shape [batch, channel, out_depth, out_height, out_width] |
201 | |
202 | )code" TVM_ADD_FILELINE) |
203 | .set_num_inputs(2) |
204 | .set_attrs_type<GridSampleAttrs>() |
205 | .add_argument("data" , "Tensor" , "The input tensor." ) |
206 | .add_argument("grid" , "Tensor" , "The grid tensor." ) |
207 | .set_support_level(5) |
208 | .add_type_rel("GridSample" , GridSampleRel) |
209 | .set_attr<TOpPattern>("TOpPattern" , kInjective); |
210 | |
211 | } // namespace relay |
212 | } // namespace tvm |
213 | |