1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tensor_utils.h
 * \brief Utility functions for handling tensors
23 */
24#ifndef TVM_TOPI_DETAIL_TENSOR_UTILS_H_
25#define TVM_TOPI_DETAIL_TENSOR_UTILS_H_
26
27#include <tvm/te/operation.h>
28
29#include <vector>
30namespace tvm {
31namespace topi {
32namespace detail {
33
34using namespace tvm::te;
35
36/*!
37 * \brief Check whether input shape has dimension of size 0;
38 *
39 * \param x Input shape
40 *
41 * \return True if the input shape is empty.
42 */
43inline bool is_empty_shape(const Array<PrimExpr>& x) {
44 bool is_empty = false;
45 for (const auto& dim : x) {
46 if (auto int_dim = dim.as<IntImmNode>()) {
47 if (int_dim->value == 0) {
48 is_empty = true;
49 break;
50 }
51 }
52 }
53 return is_empty;
54}
55
56/*!
57 * \brief Sample a point in a tensor using bilinear interpolation.
58 *
59 * \param input The input tensor.
60 * \param indices The index of the target point, which can be fractional
61 * \param max_y The maximum of y dimension
62 * \param max_x The maximum of x dimension
63 *
64 * \return The interpolated value in the given index.
65 */
66inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices,
67 const PrimExpr max_y, const PrimExpr max_x) {
68 auto batch_id = indices[0];
69 auto channel_id = indices[1];
70 auto in_y = indices[2];
71 auto in_x = indices[3];
72
73 auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
74 auto y_high = y_low + 1;
75
76 auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
77 auto x_high = x_low + 1;
78
79 auto wy_h = in_y - y_low;
80 auto wx_h = in_x - x_low;
81 auto wy_l = 1 - wy_h;
82 auto wx_l = 1 - wx_h;
83
84 PrimExpr val = 0;
85 std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
86 std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
87 for (auto wx_xp_ele : wx_xp) {
88 for (auto wy_yp_ele : wy_yp) {
89 auto wx = wx_xp_ele[0];
90 auto xp = wx_xp_ele[1];
91 auto wy = wy_yp_ele[0];
92 auto yp = wy_yp_ele[1];
93 val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
94 wx * wy * input(batch_id, channel_id, yp, xp), 0);
95 }
96 }
97 return val;
98}
99
100/*!
101 * \brief Sample a point in a tensor using bilinear interpolation.
102 *
103 * \param input The input tensor.
104 * \param indices The index of the target point, which can be fractional
105 * \param max_y The maximum of y dimension
106 * \param max_x The maximum of x dimension
107 *
108 * \return The interpolated value in the given index.
109 */
110inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices,
111 const PrimExpr max_y, const PrimExpr max_x) {
112 auto batch_id = indices[0];
113 auto channel_id = indices[3];
114 auto in_y = indices[1];
115 auto in_x = indices[2];
116
117 auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y));
118 auto y_high = y_low + 1;
119
120 auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x));
121 auto x_high = x_low + 1;
122
123 auto wy_h = in_y - y_low;
124 auto wx_h = in_x - x_low;
125 auto wy_l = 1 - wy_h;
126 auto wx_l = 1 - wx_h;
127
128 PrimExpr val = 0;
129 std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}};
130 std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}};
131 for (auto wx_xp_ele : wx_xp) {
132 for (auto wy_yp_ele : wy_yp) {
133 auto wx = wx_xp_ele[0];
134 auto xp = wx_xp_ele[1];
135 auto wy = wy_yp_ele[0];
136 auto yp = wy_yp_ele[1];
137 val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x,
138 wx * wy * input(batch_id, yp, xp, channel_id), 0);
139 }
140 }
141 return val;
142}
143
144} // namespace detail
145} // namespace topi
146} // namespace tvm
147#endif // TVM_TOPI_DETAIL_TENSOR_UTILS_H_
148