1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tensor_utils.h |
 * \brief Utility functions for handling tensors.
23 | */ |
24 | #ifndef TVM_TOPI_DETAIL_TENSOR_UTILS_H_ |
25 | #define TVM_TOPI_DETAIL_TENSOR_UTILS_H_ |
26 | |
27 | #include <tvm/te/operation.h> |
28 | |
29 | #include <vector> |
30 | namespace tvm { |
31 | namespace topi { |
32 | namespace detail { |
33 | |
34 | using namespace tvm::te; |
35 | |
36 | /*! |
37 | * \brief Check whether input shape has dimension of size 0; |
38 | * |
39 | * \param x Input shape |
40 | * |
41 | * \return True if the input shape is empty. |
42 | */ |
43 | inline bool is_empty_shape(const Array<PrimExpr>& x) { |
44 | bool is_empty = false; |
45 | for (const auto& dim : x) { |
46 | if (auto int_dim = dim.as<IntImmNode>()) { |
47 | if (int_dim->value == 0) { |
48 | is_empty = true; |
49 | break; |
50 | } |
51 | } |
52 | } |
53 | return is_empty; |
54 | } |
55 | |
56 | /*! |
57 | * \brief Sample a point in a tensor using bilinear interpolation. |
58 | * |
59 | * \param input The input tensor. |
60 | * \param indices The index of the target point, which can be fractional |
61 | * \param max_y The maximum of y dimension |
62 | * \param max_x The maximum of x dimension |
63 | * |
64 | * \return The interpolated value in the given index. |
65 | */ |
66 | inline PrimExpr bilinear_sample_nchw(const Tensor& input, const Array<PrimExpr>& indices, |
67 | const PrimExpr max_y, const PrimExpr max_x) { |
68 | auto batch_id = indices[0]; |
69 | auto channel_id = indices[1]; |
70 | auto in_y = indices[2]; |
71 | auto in_x = indices[3]; |
72 | |
73 | auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y)); |
74 | auto y_high = y_low + 1; |
75 | |
76 | auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x)); |
77 | auto x_high = x_low + 1; |
78 | |
79 | auto wy_h = in_y - y_low; |
80 | auto wx_h = in_x - x_low; |
81 | auto wy_l = 1 - wy_h; |
82 | auto wx_l = 1 - wx_h; |
83 | |
84 | PrimExpr val = 0; |
85 | std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}}; |
86 | std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}}; |
87 | for (auto wx_xp_ele : wx_xp) { |
88 | for (auto wy_yp_ele : wy_yp) { |
89 | auto wx = wx_xp_ele[0]; |
90 | auto xp = wx_xp_ele[1]; |
91 | auto wy = wy_yp_ele[0]; |
92 | auto yp = wy_yp_ele[1]; |
93 | val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x, |
94 | wx * wy * input(batch_id, channel_id, yp, xp), 0); |
95 | } |
96 | } |
97 | return val; |
98 | } |
99 | |
100 | /*! |
101 | * \brief Sample a point in a tensor using bilinear interpolation. |
102 | * |
103 | * \param input The input tensor. |
104 | * \param indices The index of the target point, which can be fractional |
105 | * \param max_y The maximum of y dimension |
106 | * \param max_x The maximum of x dimension |
107 | * |
108 | * \return The interpolated value in the given index. |
109 | */ |
110 | inline PrimExpr bilinear_sample_nhwc(const Tensor& input, const Array<PrimExpr>& indices, |
111 | const PrimExpr max_y, const PrimExpr max_x) { |
112 | auto batch_id = indices[0]; |
113 | auto channel_id = indices[3]; |
114 | auto in_y = indices[1]; |
115 | auto in_x = indices[2]; |
116 | |
117 | auto y_low = tvm::cast(DataType::Int(32), tvm::floor(in_y)); |
118 | auto y_high = y_low + 1; |
119 | |
120 | auto x_low = tvm::cast(DataType::Int(32), tvm::floor(in_x)); |
121 | auto x_high = x_low + 1; |
122 | |
123 | auto wy_h = in_y - y_low; |
124 | auto wx_h = in_x - x_low; |
125 | auto wy_l = 1 - wy_h; |
126 | auto wx_l = 1 - wx_h; |
127 | |
128 | PrimExpr val = 0; |
129 | std::vector<std::vector<PrimExpr>> wx_xp{{wx_l, x_low}, {wx_h, x_high}}; |
130 | std::vector<std::vector<PrimExpr>> wy_yp{{wy_l, y_low}, {wy_h, y_high}}; |
131 | for (auto wx_xp_ele : wx_xp) { |
132 | for (auto wy_yp_ele : wy_yp) { |
133 | auto wx = wx_xp_ele[0]; |
134 | auto xp = wx_xp_ele[1]; |
135 | auto wy = wy_yp_ele[0]; |
136 | auto yp = wy_yp_ele[1]; |
137 | val += tvm::if_then_else(0 <= yp && yp <= max_y && 0 <= xp && xp <= max_x, |
138 | wx * wy * input(batch_id, yp, xp, channel_id), 0); |
139 | } |
140 | } |
141 | return val; |
142 | } |
143 | |
144 | } // namespace detail |
145 | } // namespace topi |
146 | } // namespace tvm |
147 | #endif // TVM_TOPI_DETAIL_TENSOR_UTILS_H_ |
148 | |