1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file grid_sample.cc
 * \brief affine_grid and grid_sample operators
23 */
24#include <tvm/relay/attrs/image.h>
25#include <tvm/relay/op.h>
26#include <tvm/tir/data_layout.h>
27
28#include "../op_common.h"
29
30namespace tvm {
31namespace relay {
32
// relay.image.affine_grid
// Register AffineGridAttrs with the node type registry so that
// attrs.as<AffineGridAttrs>() and make_object<AffineGridAttrs>() below can use it.
TVM_REGISTER_NODE_TYPE(AffineGridAttrs);
35
36bool AffineGridRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
37 const TypeReporter& reporter) {
38 ICHECK_EQ(types.size(), 2);
39 const auto* data = types[0].as<TensorTypeNode>();
40 if (data == nullptr) return false;
41 auto batch_size = data->shape[0];
42
43 const AffineGridAttrs* param = attrs.as<AffineGridAttrs>();
44 ICHECK(param != nullptr);
45
46 Array<IndexExpr> oshape;
47
48 ICHECK(data->shape.size() == 3U && reporter->AssertEQ(data->shape[1], 2) &&
49 reporter->AssertEQ(data->shape[2], 3))
50 << "data should be an"
51 "affine matrix with shape [batch_size, 2, 3]";
52 ICHECK(param->target_shape.defined() && param->target_shape.size() == 2)
53 << "target_shape should be 2D";
54 oshape.push_back(batch_size);
55 oshape.push_back(2);
56 oshape.push_back(param->target_shape[0]);
57 oshape.push_back(param->target_shape[1]);
58
59 // assign output type
60 reporter->Assign(types[1], TensorType(oshape, data->dtype));
61 return true;
62}
63
64// Positional relay function to create affine_grid operator
65// used by frontend FFI.
66Expr MakeAffineGrid(Expr data, Array<IndexExpr> target_shape) {
67 auto attrs = make_object<AffineGridAttrs>();
68 attrs->target_shape = std::move(target_shape);
69 static const Op& op = Op::Get("image.affine_grid");
70 return Call(op, {data}, Attrs(attrs), {});
71}
72
// Expose MakeAffineGrid to the frontend as "relay.op.image._make.affine_grid".
TVM_REGISTER_GLOBAL("relay.op.image._make.affine_grid").set_body_typed(MakeAffineGrid);

// Register image.affine_grid: a single input (the affine matrix), attributes of
// type AffineGridAttrs, output type inferred by AffineGridRel, and the
// kInjective op pattern.
RELAY_REGISTER_OP("image.affine_grid")
    .describe(R"code(affine_grid operator that generates 2D sampling grid.

This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
transformation is then applied on the sampling grid.

- **data**: data is 3D array of shape [batch, 2, 3], which defines an affine transformation.

- **out**: out is 4D array of shape [batch, 2, height, width], where each vector
 :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`

)code" TVM_ADD_FILELINE)
    .set_attrs_type<AffineGridAttrs>()
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The affine matrix.")
    .set_support_level(5)
    .add_type_rel("AffineGrid", AffineGridRel)
    .set_attr<TOpPattern>("TOpPattern", kInjective);
94
// relay.image.grid_sample
// Register GridSampleAttrs with the node type registry so that
// attrs.as<GridSampleAttrs>() and make_object<GridSampleAttrs>() below can use it.
TVM_REGISTER_NODE_TYPE(GridSampleAttrs);
97
98bool GridSampleRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
99 const TypeReporter& reporter) {
100 ICHECK_EQ(types.size(), 3);
101 const auto* data = types[0].as<TensorTypeNode>();
102 const auto* grid = types[1].as<TensorTypeNode>();
103 if (!data || !grid) return false;
104 const auto* param = attrs.as<GridSampleAttrs>();
105 ICHECK(param);
106 const Layout in_layout(param->layout);
107
108 if (data->shape.size() == 4) {
109 static const Layout kNCHW("NCHW");
110 auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
111 auto oshape = layout_converter.ForwardShape(data->shape);
112 oshape.Set(2, grid->shape[2]);
113 oshape.Set(3, grid->shape[3]);
114
115 // assign output type
116 reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
117 return true;
118 } else if (data->shape.size() == 5) {
119 static const Layout kNDCHW("NCDHW");
120 auto layout_converter = tir::BijectiveLayout(in_layout, kNDCHW);
121 auto oshape = layout_converter.ForwardShape(data->shape);
122 oshape.Set(2, grid->shape[2]);
123 oshape.Set(3, grid->shape[3]);
124 oshape.Set(4, grid->shape[4]);
125
126 // assign output type
127 reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
128 return true;
129 }
130
131 return false;
132}
133
134// Positional relay function to create affine_grid operator
135// used by frontend FFI.
136Expr MakeGridSample(Expr data, Expr grid, String method, String layout, String padding_mode,
137 bool align_corners) {
138 auto attrs = make_object<GridSampleAttrs>();
139 attrs->method = std::move(method);
140 attrs->layout = std::move(layout);
141 attrs->padding_mode = std::move(padding_mode);
142 attrs->align_corners = std::move(align_corners);
143
144 static const Op& op = Op::Get("image.grid_sample");
145 return Call(op, {data, grid}, Attrs(attrs), {});
146}
147
148TVM_REGISTER_GLOBAL("relay.op.image._make.grid_sample").set_body_typed(MakeGridSample);
149
150RELAY_REGISTER_OP("image.grid_sample")
151 .describe(R"code(Applies grid sampling to input feature map.
152
153Given :math:`data` and :math:`grid`, then the output is computed by
154
155.. math::
156
157 x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
158 y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
159 output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}])
160
161For 5-D, the output is computed by
162
163.. math::
164
165 x_{src} = grid[batch, 0, z_{dst}, y_{dst}, x_{dst}] \\
166 y_{src} = grid[batch, 1, z_{dst}, y_{dst}, x_{dst}] \\
167 z_{src} = grid[batch, 2, z_{dst}, y_{dst}, x_{dst}] \\
168 output[batch, channel, z_{src}, y_{dst}, x_{dst}]
169 = G(data[batch, channel, z_{src}, y_{src}, x_{src}])
170
171:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
172:math:`G()` denotes the interpolation function.
173
174The out-boundary points will be padded with zeros if padding_mode is "zeros", or
175border pixel value if padding_mode is "border", or
176inner pixel value if padding_mode is "reflection".
177
178The left-top corner (-1, -1) and right-bottom corner (1, 1) in grid will be map to
179(0, 0) and (h - 1, w - 1) of data if align_corners is "True", or
180(-0.5, -0.5) and (h - 0.5, w - 0.5) of data if align_corners is "False".
181
182The shape of the output will be
1834-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]), or
1845-D (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3], grid.shape[4]).
185
186The operator assumes that :math:`data` and :math:`grid` has been normalized to [-1, 1].
187
188grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
189
190- **data**: data is of 4-D shape (batch_size, channels, in_height, in_width), or
191 of 5-D shape (batch_size, channels, in_depth, in_height, in_width)
192
193- **grid**: grid is of 4-D shape [batch, 2, out_height, out_width]
194 where each vector :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`,
195 or of 5-D of shape [batch, 3, out_depth, out_height, out_width]
196 where each vector :math:`out[b, :, d, h, w]` represents the coordinate
197 :math:`(x, y, z)`
198
199- **out**: out is of 4-D shape (batch, in_channel, out_height, out_width), or
200 of 5-D shape [batch, channel, out_depth, out_height, out_width]
201
202)code" TVM_ADD_FILELINE)
203 .set_num_inputs(2)
204 .set_attrs_type<GridSampleAttrs>()
205 .add_argument("data", "Tensor", "The input tensor.")
206 .add_argument("grid", "Tensor", "The grid tensor.")
207 .set_support_level(5)
208 .add_type_rel("GridSample", GridSampleRel)
209 .set_attr<TOpPattern>("TOpPattern", kInjective);
210
211} // namespace relay
212} // namespace tvm
213