1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file strided_slice.h
22 * \brief Utility functions for strided_slice op
23 */
24#ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_
25#define TVM_TOPI_DETAIL_STRIDED_SLICE_H_
26
#include <tvm/tir/expr.h>

#include <algorithm>
#include <limits>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "constant_utils.h"
36
37namespace tvm {
38namespace topi {
39namespace detail {
40
41using namespace tvm::te;
42
/*!
 * \brief Wrap a possibly-negative slice index and clamp it to the legal range for a stride.
 *
 * A negative index is first shifted by \p extent (so -1 refers to the last element). The
 * result is then clamped from below and above:
 *  - stride >= 0: lower bound 0, upper bound extent;
 *  - stride <  0: lower bound -1, upper bound extent - 1.
 * The -1 lower bound for negative strides presumably lets an end index sit one position
 * before element 0 when walking backwards.
 *
 * \param index The raw (possibly negative or out-of-range) index.
 * \param extent Length of the dimension being sliced.
 * \param stride Slice stride; only its sign is inspected.
 * \return The canonicalized index.
 */
inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) {
  const bool backward = stride < 0;
  const int64_t lower = backward ? -1 : 0;
  const int64_t upper = backward ? extent - 1 : extent;
  const int64_t wrapped = (index < 0) ? index + extent : index;
  // Apply the lower bound first, then the upper bound.
  const int64_t floored = (wrapped < lower) ? lower : wrapped;
  return (upper < floored) ? upper : floored;
}
51
52inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ConvertToVec(
53 const Array<Integer>& begin, const Array<Integer>& end, const Array<Integer>& strides,
54 std::string slice_mode) {
55 std::vector<int64_t> stride_vec(strides.size(), 1);
56 if (slice_mode == "end") {
57 for (size_t i = 0; i < strides.size(); ++i) {
58 ICHECK(strides[i].defined());
59 stride_vec[i] = GetConstInt(strides[i]);
60 }
61 }
62 const int64_t max_range = std::numeric_limits<int64_t>::max();
63 std::vector<int64_t> begin_vec;
64 for (size_t i = 0; i < begin.size(); ++i) {
65 if (!begin[i].defined()) {
66 // value=None
67 begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
68 } else {
69 begin_vec.push_back(GetConstInt(begin[i]));
70 }
71 }
72 std::vector<int64_t> end_vec;
73 for (size_t i = 0; i < end.size(); ++i) {
74 // allow end to be None
75 if (!end[i].defined()) {
76 end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
77 } else if (slice_mode == "size") {
78 int64_t end_val = GetConstInt(end[i]);
79 if (end_val < 0) {
80 end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
81 } else {
82 end_vec.push_back(begin_vec[i] + end_val);
83 }
84 } else {
85 end_vec.push_back(GetConstInt(end[i]));
86 }
87 }
88 return std::make_tuple(begin_vec, end_vec, stride_vec);
89}
90
91inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Array<PrimExpr>& ishape,
92 const std::vector<int64_t>& begin,
93 const std::vector<int64_t>& strides,
94 const Array<Integer>& axes, DataType dtype,
95 std::string slice_mode = "end") {
96 Array<PrimExpr> begin_expr;
97 for (size_t i = 0; i < axes.size(); ++i) {
98 if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
99 int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
100 int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]);
101 begin_expr.push_back(make_const(dtype, begin_i));
102 } else {
103 auto idim = ishape[axes[i].IntValue()];
104 auto b_expr = make_const(dtype, begin[i]);
105 PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
106 auto s = strides[i];
107 if (s < 0) {
108 b = tvm::min(b, idim - 1);
109 } else {
110 b = tvm::if_then_else(b < 0, 0, b);
111 }
112 begin_expr.push_back(b);
113 }
114 }
115 return begin_expr;
116}
117
118inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape,
119 const std::vector<int64_t>& begin,
120 const std::vector<int64_t>& end,
121 const std::vector<int64_t>& strides,
122 const Array<Integer>& axes, std::string slice_mode,
123 const Array<PrimExpr>& begin_canonicalized,
124 bool use_any = false) {
125 const size_t src_tensor_dim = ishape.size();
126 Array<PrimExpr> out_shape;
127 for (size_t i = 0; i < src_tensor_dim; ++i) {
128 out_shape.push_back(ishape[i]);
129 }
130
131 for (size_t i = 0; i < axes.size(); ++i) {
132 if (ishape[axes[i].IntValue()]->IsInstance<tvm::IntImmNode>()) {
133 const int64_t dim_i = GetConstInt(ishape[axes[i].IntValue()]);
134 ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
135 int64_t begin_i = GetConstInt(begin_canonicalized[i]);
136 int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
137 int interval = std::abs(end_i - begin_i);
138 int slice_size =
139 static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
140 ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
141 << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
142 out_shape.Set(axes[i].IntValue(), cast(out_shape[i].dtype(), PrimExpr(slice_size)));
143 } else if (use_any) {
144 out_shape.Set(axes[i].IntValue(), tvm::tir::Any());
145 } else {
146 out_shape.Set(axes[i].IntValue(), tvm::tir::Var("dim", out_shape[i]->dtype));
147 }
148 }
149
150 return out_shape;
151}
152
153} // namespace detail
154} // namespace topi
155} // namespace tvm
156#endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_
157