1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Pooling op constructions
22 * \file nn/pooling.h
23 */
24#ifndef TVM_TOPI_NN_POOLING_H_
25#define TVM_TOPI_NN_POOLING_H_
26
27#include <tvm/arith/analyzer.h>
28#include <tvm/topi/detail/pad_utils.h>
29#include <tvm/topi/nn.h>
30#include <tvm/topi/reduction.h>
31#include <tvm/topi/tags.h>
32
33#include <algorithm>
34#include <string>
35#include <vector>
36
37namespace tvm {
38namespace topi {
39namespace nn {
40
41using namespace tvm::te;
42
/*! \brief Pooling type */
enum PoolType : int {
  kAvgPool,  // average pooling: mean of the window (padding handling controlled by callers'
             // count_include_pad flag)
  kMaxPool,  // max pooling: maximum of the window
};
48
49inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
50 const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
51 const Array<PrimExpr>& padding_size, PoolType pool_type,
52 bool ceil_mode, const size_t height_axis, const size_t width_axis,
53 bool count_include_pad) {
54 ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
55 ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
56 ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
57 ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
58 ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
59
60 auto kernel_height = cast(DataType::DataType::Int(32), kernel_size[0]);
61 auto kernel_width = cast(DataType::DataType::Int(32), kernel_size[1]);
62 auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
63 auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
64
65 auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
66 auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);
67
68 auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
69 auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
70 auto pad_bottom = cast(DataType::DataType::Int(32), padding_size[2]);
71 auto pad_right = cast(DataType::DataType::Int(32), padding_size[3]);
72
73 if (ceil_mode) {
74 // Additional padding to ensure we do ceil instead of floor when
75 // dividing by stride.
76 pad_bottom += stride_height - 1;
77 pad_right += stride_width - 1;
78 }
79
80 Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
81 pad_before.Set(height_axis, pad_top);
82 pad_before.Set(width_axis, pad_left);
83
84 Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
85 pad_after.Set(height_axis, pad_bottom);
86 pad_after.Set(width_axis, pad_right);
87 arith::Analyzer analyzer;
88 auto out_height =
89 analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
90 auto out_width =
91 analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
92
93 auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
94 auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
95
96 Array<PrimExpr> data_shape = x->shape;
97 for (size_t i = 0; i < data_shape.size(); ++i) {
98 data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
99 }
100
101 Array<PrimExpr> out_shape = data_shape;
102 out_shape.Set(height_axis, out_height);
103 out_shape.Set(width_axis, out_width);
104
105 const int64_t* padding_h0 = as_const_int(pad_top);
106 const int64_t* padding_w0 = as_const_int(pad_left);
107 const int64_t* padding_h1 = as_const_int(pad_bottom);
108 const int64_t* padding_w1 = as_const_int(pad_right);
109 const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
110 ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
111
112 if (pool_type == kMaxPool) {
113 Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
114 ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
115 ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
116
117 auto windowh =
118 tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
119 auto windoww =
120 tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
121
122 auto argmax = MakeArgmaxReducer();
123 auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
124
125 auto mp_argmax = tvm::te::compute(
126 out_shape,
127 [&](const Array<Var>& inds) {
128 Array<PrimExpr> window_inds{inds.begin(), inds.end()};
129 window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
130 window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
131 auto idx = detail::RavelIndex(window_inds, ravel_shape);
132 return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
133 },
134 "maxpool_grad_argmax", kCommReduceIdx);
135
136 auto mp_inds = mp_argmax[0];
137
138 return tvm::te::compute(
139 data_shape,
140 [&](const Array<Var>& inds) {
141 Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
142 pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
143 pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
144 auto idx = detail::RavelIndex(pad_inds, ravel_shape);
145
146 Array<PrimExpr> out_idx{inds.begin(), inds.end()};
147 out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
148 out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
149
150 PrimExpr out_idx_lower_h = tir::Select(
151 pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
152 (pad_inds[height_axis] - kernel_height) / stride_height + 1);
153 PrimExpr out_idx_lower_w = tir::Select(
154 pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
155 (pad_inds[width_axis] - kernel_width) / stride_width + 1);
156
157 return tvm::sum(
158 tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
159 out_idx[width_axis] >= out_idx_lower_w),
160 mp_inds(out_idx) == idx),
161 out_grad(out_idx), make_const(x->dtype, 0)),
162 {windowh, windoww});
163 },
164 "T_pool_grad", "pool_grad_max");
165 } else if (pool_type == kAvgPool) {
166 auto windowh =
167 tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
168 auto windoww =
169 tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
170 return tvm::te::compute(
171 data_shape,
172 [&](const Array<Var>& inds) {
173 PrimExpr pad_h_idx = inds[height_axis] + pad_top;
174 PrimExpr pad_w_idx = inds[width_axis] + pad_left;
175
176 // output indices whose pooling windows cover current input element (can be out-of-bound)
177 Array<PrimExpr> out_idx{inds.begin(), inds.end()};
178 out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
179 out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
180
181 PrimExpr out_idx_lower_h =
182 tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
183 (pad_h_idx - kernel_height) / stride_height + 1);
184 PrimExpr out_idx_lower_w =
185 tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
186 (pad_w_idx - kernel_width) / stride_width + 1);
187
188 PrimExpr divide_factor; // number of pooled elements
189 if (count_include_pad) {
190 divide_factor = kernel_height * kernel_width;
191 } else {
192 PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
193 PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
194
195 PrimExpr h_end = min(h_start + kernel_height, height);
196 PrimExpr w_end = min(w_start + kernel_width, width);
197 h_start = max(h_start, make_const(DataType::Int(32), 0));
198 w_start = max(w_start, make_const(DataType::Int(32), 0));
199 divide_factor =
200 max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
201 }
202 return tvm::sum(
203 tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
204 out_idx[height_axis] < out_height),
205 tir::And(out_idx[width_axis] >= out_idx_lower_w,
206 out_idx[width_axis] < out_width)),
207 out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
208 {windowh, windoww});
209 },
210 "T_pool_grad", "pool_grad_avg");
211 } else {
212 LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
213 return Tensor();
214 }
215}
216
/*!
 * \brief Locate the 'D', 'H' and 'W' dimensions in a layout string.
 *
 * Upper-case letters denote dimensions; lower-case letters denote split
 * sub-dimensions and digits denote split factors (neither occupies a
 * dimension slot of its own). Splitting depth/height/width (e.g. NCHW16w)
 * is rejected, as are duplicated 'D'/'H'/'W'.
 *
 * \param layout The layout string, e.g. "NCDHW" or "NCDHW16c".
 * \param depth_axis Output: dimension index of 'D', or -1 if absent.
 * \param height_axis Output: dimension index of 'H', or -1 if absent.
 * \param width_axis Output: dimension index of 'W', or -1 if absent.
 * \return true iff all three of 'D', 'H', 'W' were found exactly once.
 */
inline bool find_depth_height_width(const std::string& layout, int* depth_axis, int* height_axis,
                                    int* width_axis) {
  *depth_axis = -1;
  *height_axis = -1;
  *width_axis = -1;
  int dim_pos = 0;
  for (char c : layout) {
    const bool is_upper = (c >= 'A' && c <= 'Z');
    const bool is_lower = (c >= 'a' && c <= 'z');
    if (!is_upper && !is_lower) continue;  // digits (split factors) are not dimensions
    switch (c) {
      case 'D':
        if (*depth_axis != -1) return false;  // duplicate depth dimension
        *depth_axis = dim_pos;
        break;
      case 'H':
        if (*height_axis != -1) return false;  // duplicate height dimension
        *height_axis = dim_pos;
        break;
      case 'W':
        if (*width_axis != -1) return false;  // duplicate width dimension
        *width_axis = dim_pos;
        break;
      case 'd':
      case 'h':
      case 'w':
        // do not support split on depth, height or width, e.g., NCHW16w
        return false;
      default:
        break;  // any other letter is an unrelated dimension
    }
    ++dim_pos;
  }
  return *depth_axis != -1 && *height_axis != -1 && *width_axis != -1;
}
244
245inline bool find_height_width(const std::string& layout, int* height_axis, int* width_axis) {
246 int dummy;
247 ICHECK_EQ(find_depth_height_width(layout, &dummy, height_axis, width_axis), false);
248 if (*height_axis != -1 && *width_axis != -1) {
249 return true;
250 }
251 return false;
252}
253
254inline bool find_width(const std::string& layout, int* width_axis) {
255 int dummy;
256 ICHECK_EQ(find_depth_height_width(layout, &dummy, &dummy, width_axis), false);
257 if (*width_axis != -1) {
258 return true;
259 }
260 return false;
261}
262
/*!
 * \brief Calculate gradient of pooling on height and width dimension of data.
 *        It decides the height and width dimension according to the layout string,
 *        in which 'W' and 'H' means width and height respectively.
 *        Width and height dimension cannot be split.
 *        For example, NCHW, NCHW16c, etc. are valid for pool,
 *        while NCHW16w, NCHW16h are not.
 *        See \a layout for more information of the layout string convention.
 * \param out_grad The output gradient tensor.
 * \param x The input tensor.
 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
 * \param pool_type The type of pooling operator
 * \param ceil_mode Whether to use ceil when calculating the output size
 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
 *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 *        where upper case indicates a dimension and
 *        the corresponding lower case (with factor size) indicates the split dimension.
 *        For example, NCHW16c can describe a 5-D tensor of
 *        [batch_size, channel, height, width, channel_block].
 *        (in which factor size `16` will not be used in pooling but for other operators,
 *        it can be used to decide the output shape).
 *        Since pooling does not care about the factor size of dimensions
 *        other than `H` and `W`, one can pass `NCHWc` as well.
 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
 *
 *
 * \return The output tensor in the same layout
 */
inline Tensor pool_grad(const Tensor& out_grad, const Tensor& x, const Array<PrimExpr>& kernel_size,
                        const Array<PrimExpr>& stride_size, const Array<PrimExpr>& padding_size,
                        PoolType pool_type, bool ceil_mode, const std::string& layout = "NCHW",
                        bool count_include_pad = true) {
  // Resolve 'H'/'W' positions from the layout, then dispatch to the shared impl.
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return pool_grad_impl(out_grad, x, kernel_size, stride_size, padding_size, pool_type, ceil_mode,
                        height_axis, width_axis, count_include_pad);
}
302
/*!
 * \brief First input index covered by adaptive-pooling output element \p out_index,
 *        i.e. floor(out_index * idim / odim).
 */
inline PrimExpr start_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  return indexdiv(out_index * idim, odim);
}
306
/*!
 * \brief One-past-last input index covered by adaptive-pooling output element \p out_index,
 *        i.e. ceil((out_index + 1) * idim / odim), expressed via floor-div plus a Select
 *        that adds 1 whenever the division has a remainder.
 */
inline PrimExpr end_index(const Var& out_index, const PrimExpr& odim, const PrimExpr& idim) {
  PrimExpr tmp = indexdiv((out_index + 1) * idim, odim);
  return tvm::tir::Select(indexmod((out_index + 1) * idim, odim) == 0, tmp, tmp + 1);
}
311
312/*!
313 * \brief Perform adaptive pooling on N dimensional data
314 *
315 * \param x The input tensor
316 * \param output_size int vector of size in each dimension
317 * \param pool_type The type of pooling operator
318 * \param axes indices of each dimension
319 *
320 * \return The output tensor in same layout order
321 */
322inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_size,
323 PoolType pool_type, const std::vector<int>& axes) {
324 const auto n_dim = output_size.size();
325 ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";
326
327 Array<PrimExpr> data_shape = x->shape;
328 for (size_t i = 0; i < data_shape.size(); ++i) {
329 data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
330 }
331 Array<PrimExpr> out_shape = data_shape;
332 Array<PrimExpr> in_size, out_size;
333 for (size_t i = 0; i < n_dim; ++i) {
334 in_size.push_back(data_shape[axes[i]]);
335 out_size.push_back(cast(DataType::Int(32), output_size[i]));
336 out_shape.Set(axes[i], out_size[i]);
337 }
338
339 auto get_iter_vars = [=](const Array<Var>& output, bool reduce_indices) {
340 Array<PrimExpr> indices;
341 for (size_t i = 0; i < output.size(); ++i) indices.push_back(output[i]);
342 Array<tir::IterVar> reduce_axes;
343 for (size_t i = 0; i < n_dim; ++i) {
344 auto i_start = start_index(output[axes[i]], out_size[i], in_size[i]);
345 auto i_end = end_index(output[axes[i]], out_size[i], in_size[i]);
346 auto rv_name = "rv" + std::to_string(i);
347 auto rv_axis = tvm::te::reduce_axis(Range(0, i_end - i_start), rv_name);
348 reduce_axes.push_back(rv_axis);
349 if (reduce_indices) {
350 indices.Set(axes[i], i_start + rv_axis);
351 }
352 }
353 return std::make_tuple(indices, reduce_axes);
354 };
355
356 Map<String, ObjectRef> attrs;
357 if (pool_type == kMaxPool) {
358 attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_max"));
359 return tvm::te::compute(
360 out_shape,
361 [&](const Array<Var>& output) {
362 Array<PrimExpr> indices;
363 Array<tir::IterVar> reduce_axes;
364 std::tie(indices, reduce_axes) = get_iter_vars(output, true);
365 return tvm::max(x(indices), reduce_axes); // NOLINT(*)
366 },
367 "adaptive_pool_max", "adaptive_pool_max", attrs);
368 } else if (pool_type == kAvgPool) {
369 attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.adaptive_pool_avg"));
370 auto pool_sum = tvm::te::compute(
371 out_shape,
372 [&](const Array<Var>& output) {
373 Array<PrimExpr> indices;
374 Array<tir::IterVar> reduce_axes;
375 std::tie(indices, reduce_axes) = get_iter_vars(output, true);
376 return tvm::sum(x(indices), reduce_axes);
377 },
378 "adaptive_pool_sum", "adaptive_pool_sum");
379
380 return tvm::te::compute(
381 out_shape,
382 [&](const Array<Var>& output) {
383 Array<PrimExpr> indices;
384 Array<tir::IterVar> reduce_axes;
385 std::tie(indices, reduce_axes) = get_iter_vars(output, false);
386
387 PrimExpr divide_factor = tvm::cast(x->dtype, 1);
388 for (size_t i = 0; i < n_dim; ++i) {
389 divide_factor *= tvm::cast(x->dtype, reduce_axes[i]->dom->extent);
390 }
391
392 return div(pool_sum(indices), divide_factor);
393 },
394 "adaptive_pool_avg", kElementWise, attrs);
395 } else {
396 LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
397 return x;
398 }
399}
400
/*!
 * \brief Adaptively perform pooling on height and width dimension of data.
 *        The pooling kernel and stride sizes are automatically chosen for desired output sizes.
 *        It decides the height and width dimension according to the layout string,
 *        in which 'W' and 'H' means width and height respectively.
 *        Width and height dimension cannot be split.
 *        For example, NCHW, NCHW16c, etc. are valid for pool,
 *        while NCHW16w, NCHW16h are not.
 *        See \a layout for more information of the layout string convention.
 *
 * \param x The input tensor
 * \param output_size Vector of two ints: {output_height, output_width}
 * \param pool_type The type of pooling operator
 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
 *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 *        where upper case indicates a dimension and
 *        the corresponding lower case (with factor size) indicates the split dimension.
 *        For example, NCHW16c can describe a 5-D tensor of
 *        [batch_size, channel, height, width, channel_block].
 *        (in which factor size `16` will not be used in pooling but for other operators,
 *        it can be used to decide the output shape).
 *        Since pooling does not care about the factor size of dimensions
 *        other than `H` and `W`, one can pass `NCHWc` as well.
 *
 * \return The output tensor in same layout order
 */
inline Tensor adaptive_pool(const Tensor& x, const Array<PrimExpr>& output_size, PoolType pool_type,
                            const std::string& layout = "NCHW") {
  // Resolve 'H'/'W' positions from the layout, then pool over those two axes.
  int height_axis = -1, width_axis = -1;
  ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {height_axis, width_axis});
}
433
/*!
 * \brief Adaptively perform pooling on three dimensional data.
 *        See the two dimensional version above for details.
 * \param x The input tensor
 * \param output_size Vector of three ints: {output_depth, output_height, output_width}
 * \param pool_type The type of pooling operator
 * \param layout The input layout. The default is "NCDHW".
 *
 * \return The output tensor in same layout order
 */
inline Tensor adaptive_pool3d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCDHW") {
  // Resolve 'D'/'H'/'W' positions from the layout, then pool over those three axes.
  int depth_axis = -1, height_axis = -1, width_axis = -1;
  ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
      << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {depth_axis, height_axis, width_axis});
}
449
/*!
 * \brief Adaptively perform pooling on one dimensional data.
 *        See the two dimensional version above for details.
 * \param x The input tensor
 * \param output_size Vector of one int: {output_width}
 * \param pool_type The type of pooling operator
 * \param layout The input layout. The default is "NCW".
 *
 * \return The output tensor in same layout order
 */
inline Tensor adaptive_pool1d(const Tensor& x, const Array<PrimExpr>& output_size,
                              PoolType pool_type, const std::string& layout = "NCW") {
  // Resolve the 'W' position from the layout, then pool over that single axis.
  int width_axis = -1;
  ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
  return adaptive_pool_impl(x, output_size, pool_type, {width_axis});
}
464
/*!
 * \brief Perform global pooling on height and width dimension of data.
 *        It decides the height and width dimension according to the layout string,
 *        in which 'W' and 'H' means width and height respectively.
 *        Width and height dimension cannot be split.
 *        For example, NCHW, NCHW16c, ... are valid for global_pool,
 *        while NCHW16w, NCHW16h are not.
 *        See \a layout for more information of the layout string convention.
 *
 * \param x The input tensor represent as layout
 * \param pool_type The type of pooling operator
 * \param layout The input layout. global-pooling supports any layout as long as 'H' and 'W' appear.
 *        The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
 *        where upper case indicates a dimension and
 *        the corresponding lower case (with factor size) indicates the sub-dimension.
 *        For example, `NCHW16c` can describe a 5-D tensor of
 *        [batch_size, channel, height, width, channel_block].
 *        (in which factor size `16` will not be used in pooling but for other operators,
 *        it can be used to decide the output shape).
 *        Since pooling does not care about the factor size of
 *        dimensions other than `H` and `W`, one can pass `NCHWc` as well.
 *
 * \return The output tensor in same layout with height and width dimension size of 1.
 *         e.g., for NCHW, the output shape will be [batch, channel, 1, 1]
 */
inline Tensor global_pool(const Tensor& x, PoolType pool_type, const std::string& layout = "NCHW") {
  // Global pooling is adaptive pooling with a 1x1 target output.
  return adaptive_pool(x, Array<PrimExpr>{1, 1}, pool_type, layout);
}
493
494/*!
495 * \brief Perform pooling on N-dimension of data.
496 *
497 * \param x The input tensor
498 * \param kernel_size Vector of N ints
499 * \param stride_size Vector of N ints
500 * \param dilation_size Vector of N ints
501 * \param padding_size Vector of N*2 ints [head_pad_d1, head_pad_d2, ...,
502 * head_pad_dN, tail_pad_d1, tail_pad_d2, ..., tail_pad_dN]
503 * \param pool_type The type of pooling operator
504 * \param ceil_mode Whether to use ceil when calculating the output size
505 * \param axis Vector of indices for the N dimensions
506 * \param count_include_pad Whether include padding in the calculation
507 *
508 * \return The output tensor in same layout order
509 */
510inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
511 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
512 const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
513 const std::vector<int>& axis, bool count_include_pad) {
514 int k_size = kernel_size.size();
515 int x_size = x->shape.size();
516 ICHECK_EQ(stride_size.size(), k_size) << "Pooling stride_size must have same elements as kernel";
517 ICHECK_EQ(padding_size.size(), k_size * 2) << "Pooling padding_size must has double elements of"
518 " kernel";
519 ICHECK_EQ(axis.size(), k_size) << "axis must have same elements as kernel";
520
521 Array<IterVar> daxis;
522 std::vector<PrimExpr> kernel(k_size);
523 std::vector<PrimExpr> stride(k_size);
524 std::vector<PrimExpr> dilation(k_size);
525 std::vector<PrimExpr> pad_head(k_size);
526 std::vector<PrimExpr> pad_tail(k_size);
527 std::vector<PrimExpr> offset(k_size, 0);
528 Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
529 Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
530 Array<PrimExpr> data_shape = x->shape;
531 for (size_t i = 0; i < data_shape.size(); ++i) {
532 data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
533 }
534 Array<PrimExpr> out_shape = data_shape;
535
536 bool do_pad = false;
537 for (int i = 0; i < k_size; i++) {
538 int ii = axis[i];
539 kernel[i] = cast(DataType::Int(32), kernel_size[i]);
540 stride[i] = cast(DataType::Int(32), stride_size[i]);
541 dilation[i] = cast(DataType::Int(32), dilation_size[i]);
542 pad_head[i] = cast(DataType::Int(32), padding_size[i]);
543 pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
544
545 if (ceil_mode) {
546 // The offset[i] is an additional padding to ensure we do ceil instead of floor when
547 // dividing by stride.
548 // In the case of ceil_mode=True and count_include_pad=True,
549 // in order to obtain the correct boundary,
550 // we also need to use the offset[i] to eliminate this extra padding.
551 offset[i] = stride[i] - 1;
552 pad_tail[i] += offset[i];
553 }
554
555 const int64_t* padding0 = as_const_int(pad_head[i]);
556 const int64_t* padding1 = as_const_int(pad_tail[i]);
557 do_pad = do_pad || (padding0 && *padding0) || (padding1 && *padding1);
558
559 daxis.push_back(tvm::te::reduce_axis(Range(0, kernel[i]), "rv" + std::to_string(i)));
560
561 pad_before.Set(ii, pad_head[i]);
562 pad_after.Set(ii, pad_tail[i]);
563
564 arith::Analyzer analyzer;
565
566 PrimExpr numerator =
567 data_shape[ii] - (kernel[i] - 1) * dilation[i] - 1 + pad_head[i] + pad_tail[i];
568 auto out_dim = analyzer.Simplify(indexdiv(numerator, stride[i]) + 1);
569 out_shape.Set(ii, out_dim);
570 }
571
572 Map<String, ObjectRef> attrs;
573 if (pool_type == kMaxPool) {
574 auto temp = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
575 attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_max"));
576 return tvm::te::compute(
577 out_shape,
578 [&](const Array<Var>& output) {
579 Array<PrimExpr> indices;
580 for (const Var& var : output) indices.push_back(var);
581
582 for (int i = 0; i < k_size; i++) {
583 int ii = axis[i];
584 indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
585 }
586 return tvm::max(temp(indices), daxis);
587 },
588 "pool_max", "pool_max", attrs);
589 } else if (pool_type == kAvgPool) {
590 attrs.Set("schedule_rule", tvm::runtime::String("meta_schedule.pool_avg"));
591 // Pad the inputs
592 auto temp = do_pad ? pad(x, pad_before, pad_after, 0, "pad_temp") : x;
593
594 // TVM compute for summing the pooling window.
595 auto pool_sum = tvm::te::compute(
596 out_shape,
597 [&](const Array<Var>& output) {
598 Array<PrimExpr> indices;
599 for (const Var& var : output) indices.push_back(var);
600
601 for (int i = 0; i < k_size; i++) {
602 int ii = axis[i];
603 indices.Set(ii, output[ii] * stride[i] + daxis[i] * dilation[i]);
604 }
605 return tvm::sum(temp(indices), daxis);
606 },
607 "pool_sum", "pool_sum");
608
609 // TVM compute for dividing the reduced window sum by kernel size.
610 return tvm::te::compute(
611 out_shape,
612 [&](const Array<Var>& output) {
613 Array<PrimExpr> indices;
614 for (const Var& var : output) indices.push_back(var);
615 if (count_include_pad) {
616 std::vector<PrimExpr> start(k_size);
617 std::vector<PrimExpr> end(k_size);
618 auto num_el = make_const(DataType::Int(32), 1);
619 for (int i = 0; i < k_size; i++) {
620 int ii = axis[i];
621 start[i] = output[ii] * stride[i] - pad_head[i];
622 // When computing the output shape in ceil_mode,
623 // we have added the extra padding of offset[i],
624 // so now in order to calculate the correct boundary ,
625 // we need to substract the offset[i].
626 end[i] = start[i] + (kernel[i] - 1) * dilation[i];
627 end[i] = min(end[i], data_shape[ii] + pad_tail[i] - 1 - offset[i]);
628 num_el *= (end[i] - start[i]) / dilation[i] + 1;
629 }
630 return div(pool_sum(indices), num_el);
631 } else {
632 std::vector<PrimExpr> start(k_size);
633 std::vector<PrimExpr> end(k_size);
634 auto num_el = make_const(DataType::Int(32), 1);
635 for (int i = 0; i < k_size; i++) {
636 int ii = axis[i];
637
638 // Let start and end contain the first and last index of our Tensor
639 // along the relevant dimension we use in our calculation.
640 // Assume indices -1, -2 represent the padding before (tail) and
641 // len(arr), len(arr) + 1 represent the padding after (head).
642 start[i] = output[ii] * stride[i] - pad_head[i];
643 end[i] = start[i] + (kernel[i] - 1) * dilation[i];
644
645 // if start[i] < 0, e.g. we start on a tail padded number this will be a positive
646 // number that represents the number of steps along the dilated kernel to reach a
647 // non-padded value. Otherwise this should be 0.
648 PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
649 jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));
650
651 end[i] = min(end[i], data_shape[ii] - 1);
652 num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
653 }
654
655 PrimExpr divide_factor = max(num_el, make_const(DataType::Int(32), 1));
656 return div(pool_sum(indices), divide_factor);
657 }
658 },
659 "pool_avg", kElementWise, attrs);
660 } else {
661 LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
662 return x;
663 }
664}
665
666/*!
667 * \brief Perform pooling on the width dimension of data.
668 * Width axis is determined by the layout string
669 * in which 'W' means width.
670 * Width dimension cannot be split.
671 * For example, NCW, NCW16c, etc. are valid for pool,
672 * while NCW16w is not.
673 * See \a layout for more information of the layout string convention.
674 * \param x The input tensor.
675 * \param kernel_size Vector of one int: {kernel_width}
676 * \param stride_size Vector of one int: {stride_width}
677 * \param dilation_size Vector of one int: {dilation_width}
678 * \param padding_size Vector of two ints: {head_pad_width, tail_pad_width}
679 * \param pool_type The type of pooling operator
680 * \param ceil_mode Whether to use ceil when calculating the output size
681 * \param layout The input layout. Pooling supports any layout as long as 'W' appears.
682 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
683 * where upper case indicates a dimension and
684 * the corresponding lower case (with factor size) indicates the split dimension.
685 * For example, NCW16c can describe a 4-D tensor of
686 * [batch_size, channel, width, channel_block].
687 * (in which factor size `16` will not be used in pooling but for other operators,
688 * it can be used to decide the output shape).
689 * Since pooling does not care about the factor size of dimensions
690 * other than `W`, one can pass `NCWc` as well.
691 * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg'
692 *
693 *
694 * \return The output tensor in the same layout
695 */
696inline Tensor pool1d(const Tensor& x, const Array<PrimExpr>& kernel_size,
697 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
698 const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
699 const std::string& layout = "NCW", bool count_include_pad = true) {
700 int width_axis = -1;
701 ICHECK(find_width(layout, &width_axis)) << "Unsupported layout " << layout;
702 std::vector<int> axis = {width_axis};
703 return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
704 ceil_mode, axis, count_include_pad);
705}
706
707/*!
708 * \brief Perform pooling on height and width dimension of data.
709 * It decides the height and width dimension according to the layout string,
710 * in which 'W' and 'H' means width and height respectively.
711 * Width and height dimension cannot be split.
712 * For example, NCHW, NCHW16c, etc. are valid for pool,
713 * while NCHW16w, NCHW16h are not.
714 * See \a layout for more information of the layout string convention.
715 * \param x The input tensor.
716 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
717 * \param stride_size Vector of two ints: {stride_height, stride_width}
718 * \param dilation_size Vector of two ints: {dilation_height, dilation_width}
719 * \param padding_size Vector of two ints: {padding_height, padding_width}
720 * \param pool_type The type of pooling operator
721 * \param ceil_mode Whether to use ceil when calculating the output size
722 * \param layout The input layout. Pooling supports any layout as long as 'H' and 'W' appear.
723 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
724 * where upper case indicates a dimension and
725 * the corresponding lower case (with factor size) indicates the split dimension.
726 * For example, NCHW16c can describe a 5-D tensor of
727 * [batch_size, channel, height, width, channel_block].
728 * (in which factor size `16` will not be used in pooling but for other operators,
729 * it can be used to decide the output shape).
730 * Since pooling does not care about the factor size of dimensions
731 * other than `H` and `W`, one can pass `NCHWc` as well.
 * \param count_include_pad Whether to include padding in the calculation when pool_type is 'avg'
733 *
734 *
735 * \return The output tensor in the same layout
736 */
737inline Tensor pool2d(const Tensor& x, const Array<PrimExpr>& kernel_size,
738 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
739 const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
740 const std::string& layout = "NCHW", bool count_include_pad = true) {
741 int height_axis = -1, width_axis = -1;
742 ICHECK(find_height_width(layout, &height_axis, &width_axis)) << "Unsupported layout " << layout;
743 std::vector<int> axis = {height_axis, width_axis};
744 return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
745 ceil_mode, axis, count_include_pad);
746}
747
748/*!
749 * \brief Perform pooling on depth, height and width dimension of data.
750 * It decides the depth, height and width dimension according to the layout string,
751 * in which 'D', 'W' and 'H' means depth, width and height respectively.
 * Depth, width and height dimensions cannot be split.
753 * For example, NCDHW, NCDHW16c, etc. are valid for pool,
754 * while NCDHW16d, NCDHW16w or NCDHW16h are not.
755 * See \a layout for more information of the layout string convention.
756 * \param x The input tensor.
757 * \param kernel_size Vector of three ints: {kernel_depth, kernel_height, kernel_width}
758 * \param stride_size Vector of three ints: {stride_depth, stride_height, stride_width}
759 * \param dilation_size Vector of three ints: {dilation_depth, dilation_height, dilation_width}
760 * \param padding_size Vector of six ints: {head_pad_depth, head_pad_height, head_pad_width,
761 * tail_pad_depth, tail_pad_height, tail_pad_width}
762 * \param pool_type The type of pooling operator
763 * \param ceil_mode Whether to use ceil when calculating the output size
764 * \param layout The input layout. Pooling supports any layout as long as 'D', 'H' and 'W' appear.
765 * The layout is supposed to be composed of upper cases, lower cases and (optional) numbers,
766 * where upper case indicates a dimension and
767 * the corresponding lower case (with factor size) indicates the split dimension.
768 * For example, NCDHW16c can describe a 6-D tensor of
769 * [batch_size, channel, depth, height, width, channel_block].
770 * (in which factor size `16` will not be used in pooling but for other operators,
771 * it can be used to decide the output shape).
772 * Since pooling does not care about the factor size of dimensions
773 * other than `D`, `H` and `W`, one can pass `NCDHWc` as well.
 * \param count_include_pad Whether to include padding in the calculation when pool_type is 'avg'
775 *
776 *
777 * \return The output tensor in the same layout
778 */
779inline Tensor pool3d(const Tensor& x, const Array<PrimExpr>& kernel_size,
780 const Array<PrimExpr>& stride_size, const Array<PrimExpr>& dilation_size,
781 const Array<PrimExpr>& padding_size, PoolType pool_type, bool ceil_mode,
782 const std::string& layout = "NCDHW", bool count_include_pad = true) {
783 int depth_axis = -1, height_axis = -1, width_axis = -1;
784 ICHECK(find_depth_height_width(layout, &depth_axis, &height_axis, &width_axis))
785 << "Unsupported layout " << layout;
786 std::vector<int> axis = {depth_axis, height_axis, width_axis};
787 return pool_impl_nd(x, kernel_size, stride_size, dilation_size, padding_size, pool_type,
788 ceil_mode, axis, count_include_pad);
789}
790
791} // namespace nn
792} // namespace topi
793} // namespace tvm
794#endif // TVM_TOPI_NN_POOLING_H_
795