/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Pooling op constructions
 * \file nn/pooling.h
 */
#ifndef TVM_TOPI_NN_POOLING_H_
#define TVM_TOPI_NN_POOLING_H_

#include <tvm/arith/analyzer.h>
#include <tvm/topi/detail/pad_utils.h>
#include <tvm/topi/nn.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <string>
#include <vector>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*! \brief Pooling type */
enum PoolType : int {
  kAvgPool,
  kMaxPool,
};

inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                             const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
                             const Array<PrimExpr>& padding_size, PoolType pool_type,
                             bool ceil_mode, const size_t height_axis, const size_t width_axis,
                             bool count_include_pad) {
  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must be >= 2-D (H, W)";
  ICHECK(x->shape.size() >= 2) << "Pooling input must be >= 2-D (H, W)";
  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

  auto kernel_height = cast(DataType::Int(32), kernel_size[0]);
  auto kernel_width = cast(DataType::Int(32), kernel_size[1]);
  auto stride_height = cast(DataType::Int(32), stride_size[0]);
  auto stride_width = cast(DataType::Int(32), stride_size[1]);

  auto height = cast(DataType::Int(32), x->shape[height_axis]);
  auto width = cast(DataType::Int(32), x->shape[width_axis]);

  auto pad_top = cast(DataType::Int(32), padding_size[0]);
  auto pad_left = cast(DataType::Int(32), padding_size[1]);
  auto pad_bottom = cast(DataType::Int(32), padding_size[2]);
  auto pad_right = cast(DataType::Int(32), padding_size[3]);

  if (ceil_mode) {
    // Additional padding to ensure we do ceil instead of floor when
    // dividing by stride.
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }

  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);
  arith::Analyzer analyzer;
  auto out_height =
      analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
  auto out_width =
      analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);

  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");

  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }

  Array<PrimExpr> out_shape = data_shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);

  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
    Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
    ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");

    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;

    auto mp_argmax = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> window_inds{inds.begin(), inds.end()};
          window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
          window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
          auto idx = detail::RavelIndex(window_inds, ravel_shape);
          return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
        },
        "maxpool_grad_argmax", kCommReduceIdx);

    auto mp_inds = mp_argmax[0];

    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
          pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
          pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
          auto idx = detail::RavelIndex(pad_inds, ravel_shape);

          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

          PrimExpr out_idx_lower_h =
              tir::Select(pad_inds[height_axis] < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_inds[width_axis] < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_inds[width_axis] - kernel_width) / stride_width + 1);

          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[width_axis] >= out_idx_lower_w),
                                         mp_inds(out_idx) == idx),
                                out_grad(out_idx), make_const(x->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
          PrimExpr pad_w_idx = inds[width_axis] + pad_left;

          // output indices whose pooling windows cover current input element (can be out-of-bound)
          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
          out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));

          PrimExpr out_idx_lower_h =
              tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
                          (pad_w_idx - kernel_width) / stride_width + 1);

          PrimExpr divide_factor;  // number of pooled elements
          if (count_include_pad) {
            divide_factor = kernel_height * kernel_width;
          } else {
            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(DataType::Int(32), 0));
            w_start = max(w_start, make_const(DataType::Int(32), 0));
            divide_factor =
                max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
          }
          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[height_axis] < out_height),
                                         tir::And(out_idx[width_axis] >= out_idx_lower_w,
                                                  out_idx[width_axis] < out_width)),
                                out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_avg");
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return Tensor();
  }
}

inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  *depth_axis = -1;
  *height_axis = -1;
  *width_axis = -1;
  int curr_idx = 0;
  for (size_t i = 0; i < layout.size(); ++i) {
    if ((layout[i] >= 'A' && layout[i] <= 'Z') || (layout[i] >= 'a' && layout[i] <= 'z')) {
      if (layout[i] == 'D') {
        if (*depth_axis != -1) return false;
        *depth_axis = curr_idx;
      } else if (layout[i] == 'H') {
        if (*height_axis != -1) return false;
        *height_axis = curr_idx;
      } else if (layout[i] == 'W') {
        if (*width_axis != -1) return false;
        *width_axis = curr_idx;
      } else if (layout[i] == 'd' || layout[i] == 'h' || layout[i] == 'w') {
        // do not support split on height or width, e.g., NCHW16w
        return false;
      }
      ++curr_idx;
    }
  }
  if (*depth_axis == -1 || *height_axis == -1 || *width_axis == -1) return false;
  return true;
}
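// For example (illustrative; `d`, `h`, `w` are hypothetical int variables):
//   find_depth_height_width("NCDHW16c", &d, &h, &w)  // true, d = 2, h = 3, w = 4
//   find_depth_height_width("NCDHW16w", &d, &h, &w)  // false: the 'W' axis is split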

inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
  int dummy;
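  // The nested call is expected to return false: a valid height/width layout
  // contains no 'D' axis, so a complete D/H/W triple can never be found here.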
  ICHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
  if (*height_axis != -1 && *width_axis != -1) {
    return true;
  }
  return false;
}

inline bool find_width(const std::string& layout, int* width_axis) {
  int dummy;
  ICHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
  if (*width_axis != -1) {
    return true;
  }
  return false;
}

/*!
 * \brief Calculate gradient of pooling on height and width dimension of data.
 * It decides the height and width dimension according to the layout string,
 * in which 'W' and 'H' means width and height respectively.
 * Width and height dimension cannot be split.
 * For example, NCHW, NCHW16c, etc. are valid for pool,
 * while NCHW16w, NCHW16h are not.
 * See \a layout for more information of the layout string convention.
 * \param out_grad The output gradient tensor.
 * \param x The input tensor.
 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {padding_top, padding_left, padding_bottom,
 * padding_right}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 * where upper case indicates a dimension and
 * the corresponding lower case (with factor size) indicates the split dimension.
 * For example, NCHW16c can describe a 5-D tensor of
 * [batch_size, channel, height, width, channel_block].
 * (in which factor size `16` will not be used in pooling but for other operators,
 * it can be used to decide the output shape).
 * Since pooling does not care about the factor size of dimensions
 * other than `H` and `W`, one can pass `NCHWc` as well.
 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
 *
 * \return The output tensor in the same layout
 */
inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
                        const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                        PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
                        bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
                        height_axis, width_axis, count_include_pad);
}
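// A minimal usage sketch (tensor names and shapes here are illustrative
// assumptions, not part of this header):
//
//   te::Tensor x = te::placeholder({1, 1, 4, 4}, DataType::Float(32), "x");
//   te::Tensor out_grad = te::placeholder({1, 1, 2, 2}, DataType::Float(32), "out_grad");
//   // Gradient of a 2x2, stride-2, unpadded max pool w.r.t. its input:
//   te::Tensor dx = pool_grad(out_grad, x, {2, 2}, {2, 2}, {0, 0, 0, 0},
//                             PoolType::kMaxPool, /*ceil_mode=*/false, "NCHW");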

inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}

inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
}
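// For example, adaptively mapping an input extent idim = 5 onto an output
// extent odim = 3 yields the half-open windows
//   out 0: [start_index(0, 3, 5), end_index(0, 3, 5)) = [0, 2)
//   out 1: [start_index(1, 3, 5), end_index(1, 3, 5)) = [1, 4)
//   out 2: [start_index(2, 3, 5), end_index(2, 3, 5)) = [3, 5)
// which together cover every input element.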

/*!
 * \brief Perform adaptive pooling on N dimensional data
 *
 * \param x The input tensor
 * \param output_size int vector of output size for each pooled dimension
 * \param pool_type The type of pooling operator
 * \param axes indices of each pooled dimension
 *
 * \return The output tensor in same layout order
 */
inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
                                 PoolType pool_type, const std::vector<int>& axes) {
  const auto n_dim = output_size.size();
  ICHECK_EQ(axes.size(), n_dim) << "The number of axes must equal the number of output sizes";

  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }
  Array<PrimExpr> out_shape = data_shape;
  Array<PrimExpr> in_size, out_size;
  for (size_t i = 0; i < n_dim; ++i) {
    in_size.push_back(data_shape[axes[i]]);
    out_size.push_back(cast(DataType::Int(32), output_size[i]));
    out_shape.Set(axes[i], out_size[i]);
  }

  auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
    Array<PrimExpr> indices;
    for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
    Array<tir::IterVar> reduce_axes;
    for (size_t i = 0; i < n_dim; ++i) {
      auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
      auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
      auto rv_name = "rv" + std::to_string(i);
      auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
      reduce_axes.push_back(rv_axis);
      if (reduce_indices) {
        indices.Set(axes[i], i_start + rv_axis);
      }
    }
    return std::make_tuple(indices, reduce_axes);
  };

  Map<String, ObjectRef> attrs;
  if (pool_type == kMaxPool) {
    attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_max"));
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::max(x(indices), reduce_axes);  // NOLINT(*)
        },
        "adaptive_pool_max", "adaptive_pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_avg"));
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, true);
          return tvm::sum(x(indices), reduce_axes);
        },
        "adaptive_pool_sum", "adaptive_pool_sum");

    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          Array<tir::IterVar> reduce_axes;
          std::tie(indices, reduce_axes) = get_iter_vars(output, false);

          PrimExpr divide_factor = tvm::cast(x->dtype, 1);
          for (size_t i = 0; i < n_dim; ++i) {
            divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
          }

          return div(pool_sum(indices), divide_factor);
        },
        "adaptive_pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}

/*!
 * \brief Adaptively perform pooling on height and width dimension of data.
 * The pooling kernel and stride sizes are automatically chosen for desired output sizes.
 * It decides the height and width dimension according to the layout string,
 * in which 'W' and 'H' means width and height respectively.
 * Width and height dimension cannot be split.
 * For example, NCHW, NCHW16c, etc. are valid for pool,
 * while NCHW16w, NCHW16h are not.
 * See \a layout for more information of the layout string convention.
 *
 * \param x The input tensor
 * \param output_size Vector of two ints: {output_height, output_width}
 * \param pool_type The type of pooling operator
 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 * where upper case indicates a dimension and
 * the corresponding lower case (with factor size) indicates the split dimension.
 * For example, NCHW16c can describe a 5-D tensor of
 * [batch_size, channel, height, width, channel_block].
 * (in which factor size `16` will not be used in pooling but for other operators,
 * it can be used to decide the output shape).
 * Since pooling does not care about the factor size of dimensions
 * other than `H` and `W`, one can pass `NCHWc` as well.
 *
 * \return The output tensor in same layout order
 */
inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
                            const std::string& layout = "NCHW") {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
}
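// A minimal usage sketch (the placeholder and its shape are illustrative
// assumptions):
//
//   te::Tensor data = te::placeholder({1, 64, 7, 9}, DataType::Float(32), "data");
//   // Adaptive average pooling to a fixed 2x3 spatial output:
//   te::Tensor out = adaptive_pool(data, {2, 3}, PoolType::kAvgPool, "NCHW");
//   // out has shape [1, 64, 2, 3] regardless of the input H and W.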

/*!
 * \brief Adaptively perform pooling on three dimensional data.
 * See the two dimensional version above for details.
 * \param x The input tensor
 * \param output_size Vector of three ints: {output_depth, output_height, output_width}
 * \param pool_type The type of pooling operator
 * \param layout The input layout. The default is "NCDHW".
 */
inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCDHW") {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}
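// Sketch (assuming `vol` is a [1, 4, 8, 8, 8] NCDHW placeholder): pool each
// volume down to a fixed 2x2x2 spatial output.
//
//   te::Tensor out = adaptive_pool3d(vol, {2, 2, 2}, PoolType::kAvgPool, "NCDHW");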

/*!
 * \brief Adaptively perform pooling on one dimensional data.
 * See the two dimensional version above for details.
 * \param x The input tensor
 * \param output_size Vector of one int: {output_width}
 * \param pool_type The type of pooling operator
 * \param layout The input layout. The default is "NCW".
 */
inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCW") {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
}
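// Sketch (assuming `seq` is a [1, 8, 100] NCW placeholder): reduce the width
// axis to exactly 10 elements.
//
//   te::Tensor out = adaptive_pool1d(seq, {10}, PoolType::kMaxPool, "NCW");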

/*!
 * \brief Perform global pooling on height and width dimension of data.
 * It decides the height and width dimension according to the layout string,
 * in which 'W' and 'H' means width and height respectively.
 * Width and height dimension cannot be split.
 * For example, NCHW, NCHW16c, ... are valid for global_pool,
 * while NCHW16w, NCHW16h are not.
 * See \a layout for more information of the layout string convention.
 *
 * \param x The input tensor in the given layout
 * \param pool_type The type of pooling operator
 * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 * where upper case indicates a dimension and
 * the corresponding lower case (with factor size) indicates the split dimension.
 * For example, `NCHW16c` can describe a 5-D tensor of
 * [batch_size, channel, height, width, channel_block].
 * (in which factor size `16` will not be used in pooling but for other operators,
 * it can be used to decide the output shape).
 * Since pooling does not care about the factor size of
 * dimensions other than `H` and `W`, one can pass `NCHWc` as well.
 *
 * \return The output tensor in same layout with height and width dimension size of 1.
 * e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
 */
inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
  return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
}
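// Sketch: global average pooling collapses all spatial positions into one.
// For an NCHW placeholder `data` of shape [N, C, H, W] (illustrative name):
//
//   te::Tensor out = global_pool(data, PoolType::kAvgPool, "NCHW");
//   // out has shape [N, C, 1, 1].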

/*!
 * \brief Perform pooling on N-dimension of data.
 *
 * \param x The input tensor
 * \param kernel_size Vector of N ints
 * \param stride_size Vector of N ints
 * \param dilation_size Vector of N ints
 * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
 * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param axis Vector of indices for the N dimensions
 * \param count_include_pad Whether include padding in the calculation
 *
 * \return The output tensor in same layout order
 */
inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
                           const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                           const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                           const std::vector<int>& axis, bool count_include_pad) {
  int k_size = kernel_size.size();
  int x_size = x->shape.size();
  ICHECK_EQ(stride_size.size(), k_size)
      << "Pooling stride_size must have the same number of elements as kernel_size";
  ICHECK_EQ(padding_size.size(), k_size * 2)
      << "Pooling padding_size must have twice as many elements as kernel_size";
  ICHECK_EQ(axis.size(), k_size) << "axis must have the same number of elements as kernel_size";

  Array<IterVar> daxis;
  std::vector<PrimExpr> kernel(k_size);
  std::vector<PrimExpr> stride(k_size);
  std::vector<PrimExpr> dilation(k_size);
  std::vector<PrimExpr> pad_head(k_size);
  std::vector<PrimExpr> pad_tail(k_size);
  std::vector<PrimExpr> offset(k_size, 0);
  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
  Array<PrimExpr> data_shape = x->shape;
  for (size_t i = 0; i < data_shape.size(); ++i) {
    data_shape.Set(i, cast(DataType::Int(32), data_shape[i]));
  }
  Array<PrimExpr> out_shape = data_shape;

  bool do_pad = false;
  for (int i = 0; i < k_size; i++) {
    int ii = axis[i];
    kernel[i] = cast(DataType::Int(32), kernel_size[i]);
    stride[i] = cast(DataType::Int(32), stride_size[i]);
    dilation[i] = cast(DataType::Int(32), dilation_size[i]);
    pad_head[i] = cast(DataType::Int(32), padding_size[i]);
    pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);

    if (ceil_mode) {
      // The offset[i] is an additional padding to ensure we do ceil instead of floor when
      // dividing by stride.
      // In the case of ceil_mode=True and count_include_pad=True,
      // in order to obtain the correct boundary,
      // we also need to use the offset[i] to eliminate this extra padding.
      offset[i] = stride[i] - 1;
      pad_tail[i] += offset[i];
    }

    const int64_t* padding0 = as_const_int(pad_head[i]);
    const int64_t* padding1 = as_const_int(pad_tail[i]);
    do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);

    daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));

    pad_before.Set(ii, pad_head[i]);
    pad_after.Set(ii, pad_tail[i]);

    arith::Analyzer analyzer;

    PrimExpr numerator =
        data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
    auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
    out_shape.Set(ii, out_dim);
  }

  Map<String, ObjectRef> attrs;
  if (pool_type == kMaxPool) {
    auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
    attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_max"));
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);

          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::max(temp(indices), daxis);
        },
        "pool_max", "pool_max", attrs);
  } else if (pool_type == kAvgPool) {
    attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_avg"));
    // Pad the inputs
    auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;

    // TVM compute for summing the pooling window.
    auto pool_sum = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);

          for (int i = 0; i < k_size; i++) {
            int ii = axis[i];
            indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
          }
          return tvm::sum(temp(indices), daxis);
        },
        "pool_sum", "pool_sum");

    // TVM compute for dividing the reduced window sum by kernel size.
    return tvm::te::compute(
        out_shape,
        [&](const Array<Var>& output) {
          Array<PrimExpr> indices;
          for (const Var& var : output) indices.push_back(var);
          if (count_include_pad) {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];
              start[i] = output[ii] * stride[i] - pad_head[i];
              // When computing the output shape in ceil_mode,
              // we have added the extra padding of offset[i],
              // so now in order to calculate the correct boundary,
              // we need to subtract the offset[i].
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];
              end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
              num_el *= (end[i] - start[i]) / dilation[i] + 1;
            }
            return div(pool_sum(indices), num_el);
          } else {
            std::vector<PrimExpr> start(k_size);
            std::vector<PrimExpr> end(k_size);
            auto num_el = make_const(DataType::Int(32), 1);
            for (int i = 0; i < k_size; i++) {
              int ii = axis[i];

              // Let start and end contain the first and last index of our Tensor
              // along the relevant dimension we use in our calculation.
              // Assume indices -1, -2 represent the padding before (head) and
              // len(arr), len(arr) + 1 represent the padding after (tail).
              start[i] = output[ii] * stride[i] - pad_head[i];
              end[i] = start[i] + (kernel[i] - 1) * dilation[i];

              // If start[i] < 0, i.e. we start on a head-padded element, this will be a
              // positive number that represents the number of steps along the dilated
              // kernel needed to reach a non-padded value. Otherwise this should be 0.
              PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
              jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));

              end[i] = min(end[i], data_shape[ii] - 1);
              num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
            }

            PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
            return div(pool_sum(indices), divide_factor);
          }
        },
        "pool_avg", kElementWise, attrs);
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return x;
  }
}
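// Worked example of the output-extent arithmetic in pool_impl_nd: for an input
// extent of 32 with kernel 3, stride 2, dilation 1, and pad_head = pad_tail = 1,
//   out = (32 - (3 - 1) * 1 - 1 + 1 + 1) / 2 + 1 = 31 / 2 + 1 = 16 (floor div),
// and with ceil_mode = true, pad_tail gains stride - 1 = 1 extra, giving
//   out = 32 / 2 + 1 = 17.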

/*!
 * \brief Perform pooling on the width dimension of data.
 * Width axis is determined by the layout string
 * in which 'W' means width.
 * Width dimension cannot be split.
 * For example, NCW, NCW16c, etc. are valid for pool,
 * while NCW16w is not.
 * See \a layout for more information of the layout string convention.
 * \param x The input tensor.
 * \param kernel_size Vector of one int: {kernel_width}
 * \param stride_size Vector of one int: {stride_width}
 * \param dilation_size Vector of one int: {dilation_width}
 * \param padding_size Vector of two ints: {head_pad_width, tail_pad_width}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param layout The input layout. Pooling supports any layout as long as 'W' appears.
 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 * where upper case indicates a dimension and
 * the corresponding lower case (with factor size) indicates the split dimension.
 * For example, NCW16c can describe a 4-D tensor of
 * [batch_size, channel, width, channel_block].
 * (in which factor size `16` will not be used in pooling but for other operators,
 * it can be used to decide the output shape).
 * Since pooling does not care about the factor size of dimensions
 * other than `W`, one can pass `NCWc` as well.
 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
 *
 * \return The output tensor in the same layout
 */
inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCW", bool count_include_pad = true) {
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
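// A minimal usage sketch (the placeholder and shapes are illustrative
// assumptions):
//
//   te::Tensor seq = te::placeholder({1, 16, 100}, DataType::Float(32), "seq");
//   // Window 3, stride 2, one element of padding on each side:
//   te::Tensor out = pool1d(seq, {3}, {2}, {1}, {1, 1},
//                           PoolType::kAvgPool, /*ceil_mode=*/false, "NCW");
//   // Output width: (100 - 3 + 1 + 1) / 2 + 1 = 50.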

/*!
 * \brief Perform pooling on height and width dimension of data.
 * It decides the height and width dimension according to the layout string,
 * in which 'W' and 'H' means width and height respectively.
 * Width and height dimension cannot be split.
 * For example, NCHW, NCHW16c, etc. are valid for pool,
 * while NCHW16w, NCHW16h are not.
 * See \a layout for more information of the layout string convention.
 * \param x The input tensor.
 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param dilation_size Vector of two ints: {dilation_height, dilation_width}
 * \param padding_size Vector of four ints: {head_pad_height, head_pad_width,
 * tail_pad_height, tail_pad_width}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 * where upper case indicates a dimension and
 * the corresponding lower case (with factor size) indicates the split dimension.
 * For example, NCHW16c can describe a 5-D tensor of
 * [batch_size, channel, height, width, channel_block].
 * (in which factor size `16` will not be used in pooling but for other operators,
 * it can be used to decide the output shape).
 * Since pooling does not care about the factor size of dimensions
 * other than `H` and `W`, one can pass `NCHWc` as well.
 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
 *
 * \return The output tensor in the same layout
 */
inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCHW", bool count_include_pad = true) {
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  std::vector<int> axis = {height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
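// A minimal usage sketch (the placeholder and shapes are illustrative
// assumptions):
//
//   te::Tensor data = te::placeholder({1, 3, 32, 32}, DataType::Float(32), "data");
//   // 2x2 max pooling with stride 2 and no padding:
//   te::Tensor out = pool2d(data, {2, 2}, {2, 2}, {1, 1}, {0, 0, 0, 0},
//                           PoolType::kMaxPool, /*ceil_mode=*/false, "NCHW");
//   // out has shape [1, 3, 16, 16].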

/*!
 * \brief Perform pooling on depth, height and width dimension of data.
 * It decides the depth, height and width dimension according to the layout string,
 * in which 'D', 'W' and 'H' means depth, width and height respectively.
 * Depth, width and height dimension cannot be split.
 * For example, NCDHW, NCDHW16c, etc. are valid for pool,
 * while NCDHW16d, NCDHW16w or NCDHW16h are not.
 * See \a layout for more information of the layout string convention.
 * \param x The input tensor.
 * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
 * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
 * \param dilation_size Vector of three ints: {dilation_depth, dilation_height, dilation_width}
 * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
 * tail_pad_depth, tail_pad_height, tail_pad_width}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 * where upper case indicates a dimension and
 * the corresponding lower case (with factor size) indicates the split dimension.
 * For example, NCDHW16c can describe a 6-D tensor of
 * [batch_size, channel, depth, height, width, channel_block].
 * (in which factor size `16` will not be used in pooling but for other operators,
 * it can be used to decide the output shape).
 * Since pooling does not care about the factor size of dimensions
 * other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
 *
 * \return The output tensor in the same layout
 */
inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
                     const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
                     const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
                     const std::string& layout = "NCDHW", bool count_include_pad = true) {
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  std::vector<int> axis = {depth_axis, height_axis, width_axis};
  return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
                      ceil_mode, axis, count_include_pad);
}
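// A minimal usage sketch (the placeholder and shapes are illustrative
// assumptions):
//
//   te::Tensor vol = te::placeholder({1, 4, 16, 16, 16}, DataType::Float(32), "vol");
//   te::Tensor out = pool3d(vol, {2, 2, 2}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0, 0, 0, 0},
//                           PoolType::kMaxPool, /*ceil_mode=*/false, "NCDHW");
//   // out has shape [1, 4, 8, 8, 8].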

}  // namespace nn
}  // namespace topi
}  // namespace tvm
#endif  // TVM_TOPI_NN_POOLING_H_