1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file strided_slice.h |
22 | * \brief Utility functions for strided_slice op |
23 | */ |
24 | #ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_ |
25 | #define TVM_TOPI_DETAIL_STRIDED_SLICE_H_ |
26 | |
27 | #include <tvm/tir/expr.h> |
28 | |
29 | #include <algorithm> |
30 | #include <limits> |
31 | #include <string> |
32 | #include <tuple> |
33 | #include <vector> |
34 | |
35 | #include "constant_utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace topi { |
39 | namespace detail { |
40 | |
41 | using namespace tvm::te; |
42 | |
/*!
 * \brief Wrap a possibly negative slice index and clamp it to the valid range of an axis.
 *
 * A negative index is first offset by the extent. The result is then clamped to
 * [0, extent] when the stride is non-negative, or to [-1, extent - 1] when the stride is
 * negative.
 *
 * \param index The raw begin/end index.
 * \param extent The length of the axis being sliced.
 * \param stride The slice stride; only its sign is inspected.
 * \return The wrapped and clamped index.
 */
inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  const bool reversed = stride < 0;
  // Backward slices may stop one before the first element and start at the last one.
  const int64_t lower_bound = reversed ? -1 : 0;
  const int64_t upper_bound = reversed ? extent - 1 : extent;
  const int64_t wrapped = index < 0 ? index + extent : index;
  return std::min(std::max(wrapped, lower_bound), upper_bound);
}
51 | |
52 | inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec( |
53 | const Array<Integer>& begin, const Array<Integer>& end, const Array<Integer>& strides, |
54 | std::string slice_mode) { |
55 | std::vector<int64_t> stride_vec(strides.size(), 1); |
56 | if (slice_mode == "end" ) { |
57 | for (size_t i = 0; i < strides.size(); ++i) { |
58 | ICHECK(strides[i].defined()); |
59 | stride_vec[i] = GetConstInt(strides[i]); |
60 | } |
61 | } |
62 | const int64_t max_range = std::numeric_limits<int64_t>::max(); |
63 | std::vector<int64_t> begin_vec; |
64 | for (size_t i = 0; i < begin.size(); ++i) { |
65 | if (!begin[i].defined()) { |
66 | // value=None |
67 | begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); |
68 | } else { |
69 | begin_vec.push_back(GetConstInt(begin[i])); |
70 | } |
71 | } |
72 | std::vector<int64_t> end_vec; |
73 | for (size_t i = 0; i < end.size(); ++i) { |
74 | // allow end to be None |
75 | if (!end[i].defined()) { |
76 | end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); |
77 | } else if (slice_mode == "size" ) { |
78 | int64_t end_val = GetConstInt(end[i]); |
79 | if (end_val < 0) { |
80 | end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); |
81 | } else { |
82 | end_vec.push_back(begin_vec[i] + end_val); |
83 | } |
84 | } else { |
85 | end_vec.push_back(GetConstInt(end[i])); |
86 | } |
87 | } |
88 | return std::make_tuple(begin_vec, end_vec, stride_vec); |
89 | } |
90 | |
91 | inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& ishape, |
92 | const std::vector<int64_t>& begin, |
93 | const std::vector<int64_t>& strides, |
94 | const Array<Integer>& axes, DataType dtype, |
95 | std::string slice_mode = "end" ) { |
96 | Array<PrimExpr> begin_expr; |
97 | for (size_t i = 0; i < axes.size(); ++i) { |
98 | if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) { |
99 | int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); |
100 | int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]); |
101 | begin_expr.push_back(make_const(dtype, begin_i)); |
102 | } else { |
103 | auto idim = ishape[axes[i].IntValue()]; |
104 | auto b_expr = make_const(dtype, begin[i]); |
105 | PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr; |
106 | auto s = strides[i]; |
107 | if (s < 0) { |
108 | b = tvm::min(b, idim - 1); |
109 | } else { |
110 | b = tvm::if_then_else(b < 0, 0, b); |
111 | } |
112 | begin_expr.push_back(b); |
113 | } |
114 | } |
115 | return begin_expr; |
116 | } |
117 | |
118 | inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, |
119 | const std::vector<int64_t>& begin, |
120 | const std::vector<int64_t>& end, |
121 | const std::vector<int64_t>& strides, |
122 | const Array<Integer>& axes, std::string slice_mode, |
123 | const Array<PrimExpr>& begin_canonicalized, |
124 | bool use_any = false) { |
125 | const size_t src_tensor_dim = ishape.size(); |
126 | Array<PrimExpr> out_shape; |
127 | for (size_t i = 0; i < src_tensor_dim; ++i) { |
128 | out_shape.push_back(ishape[i]); |
129 | } |
130 | |
131 | for (size_t i = 0; i < axes.size(); ++i) { |
132 | if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) { |
133 | const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]); |
134 | ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>()); |
135 | int64_t begin_i = GetConstInt(begin_canonicalized[i]); |
136 | int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); |
137 | int interval = std::abs(end_i - begin_i); |
138 | int slice_size = |
139 | static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); |
140 | ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) |
141 | << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; |
142 | out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size))); |
143 | } else if (use_any) { |
144 | out_shape.Set(axes[i].IntValue(), tvm::tir::Any()); |
145 | } else { |
146 | out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim" , out_shape[i]->dtype)); |
147 | } |
148 | } |
149 | |
150 | return out_shape; |
151 | } |
152 | |
153 | } // namespace detail |
154 | } // namespace topi |
155 | } // namespace tvm |
156 | #endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_ |
157 | |