1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Reorg op constructions
22 * \file vision/reorg.h
23 */
24#ifndef TVM_TOPI_VISION_REORG_H_
25#define TVM_TOPI_VISION_REORG_H_
26
27#include <tvm/te/operation.h>
28#include <tvm/topi/detail/constant_utils.h>
29#include <tvm/topi/reduction.h>
30#include <tvm/topi/tags.h>
31#include <tvm/topi/transform.h>
32
33#include <algorithm>
34#include <string>
35
36namespace tvm {
37namespace topi {
38namespace vision {
39
40using namespace tvm::te;
41
42/*!
43 * \brief Reorg operation
44 *
45 * \param data The input tensor. Can be any dimension
46 * \param stride The input integer used as stride in reorg operation
47 * \param name The name of the operation
48 * \param tag The tag to mark the operation
49 *
50 * \return A Tensor whose op member is the reorg operation
51 */
52inline Tensor reorg(const Tensor& data, int stride = 1, std::string name = "tensor",
53 std::string tag = "reorg_output") {
54 auto input_shape = data->shape;
55
56 int batch = GetConstInt(input_shape[0]);
57 int c_in = GetConstInt(input_shape[1]);
58 int h_in = GetConstInt(input_shape[2]);
59 int w_in = GetConstInt(input_shape[3]);
60 int out_c = c_in / (stride * stride);
61
62 auto out = tvm::te::compute(
63 input_shape,
64 [&](Var b, Var k, Var j, Var i) {
65 return data(b * stride * stride, indexmod(k, out_c) * stride * stride,
66 (j * stride + indexdiv(indexdiv(k, out_c), stride)) * stride,
67 (i * stride + indexmod(indexdiv(k, out_c), stride)));
68 },
69 name, tag);
70
71 out_c = c_in * stride * stride;
72 int out_h = h_in / stride;
73 int out_w = w_in / stride;
74
75 Array<PrimExpr> out_shape = {batch, out_c, out_h, out_w};
76 return reshape(out, out_shape);
77}
78} // namespace vision
79} // namespace topi
80} // namespace tvm
81#endif // TVM_TOPI_VISION_REORG_H_
82