1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include "../utils.h"
20
21namespace tvm {
22namespace tir {
23
24/*!
25 * \brief Calculate the strides of the buffer
26 * \param buffer The buffer
27 * \return The strides
28 */
29Array<PrimExpr> GetStrides(const Buffer& buffer) {
30 if (!buffer->strides.empty()) {
31 ICHECK_EQ(buffer->strides.size(), buffer->shape.size());
32 return buffer->strides;
33 }
34 int ndim = buffer->shape.size();
35 if (ndim == 0) {
36 return {};
37 }
38 Array<PrimExpr> strides(ndim, PrimExpr{nullptr});
39 PrimExpr stride = make_const(buffer->DefaultIndexType(), 1);
40 for (int i = ndim - 1; i >= 0; --i) {
41 strides.Set(i, stride);
42 stride = stride * buffer->shape[i];
43 }
44 return strides;
45}
46
47/*!
48 * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern
49 * to help decision making in layout transformation
50 */
51class SplitExprCollector {
52 public:
53 /*!
54 * \brief The corresponding IterSplitExpr, simplified for our case
55 * The pattern is `source // lower_factor % extent * scale`
56 */
57 struct SplitExpr {
58 /*! \brief The source variable */
59 Var source;
60 /*! \brief The lower factor of the split expression */
61 int64_t lower_factor;
62 /*! \brief The extent of the split expression */
63 int64_t extent;
64 };
65
66 /*!
67 * \brief Collect the split expressions in the indexing pattern
68 * \param index The indexing pattern
69 * \param input_iters The input iterators' domain
70 * \param predicate The predicate of the affine map
71 * \param check_level The iter mapping checking level
72 * \param analyzer The analyzer
73 * \return The collected split expressions
74 */
75 static std::vector<SplitExpr> Collect(const PrimExpr& index,
76 const Map<Var, Range>& input_iters, //
77 const PrimExpr& predicate, //
78 arith::IterMapLevel check_level, //
79 arith::Analyzer* analyzer) {
80 arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters,
81 predicate, check_level, analyzer);
82 const auto& iter_sum_exprs = res->indices;
83 if (iter_sum_exprs.empty()) {
84 return {};
85 }
86 ICHECK_EQ(iter_sum_exprs.size(), 1);
87 if (iter_sum_exprs[0]->args.size() == 0) {
88 return {};
89 }
90 SplitExprCollector collector;
91 collector.Visit(iter_sum_exprs[0]);
92 if (collector.failed_) {
93 return {};
94 }
95 return std::move(collector.exprs_);
96 }
97
98 private:
99 void Visit(const arith::IterSplitExpr& expr) {
100 if (const auto* var = expr->source->source.as<tir::VarNode>()) {
101 const int64_t* lower_factor = as_const_int(expr->lower_factor);
102 const int64_t* extent = as_const_int(expr->extent);
103 if (lower_factor == nullptr || extent == nullptr) {
104 failed_ = true;
105 return;
106 }
107 exprs_.push_back(SplitExpr{GetRef<Var>(var), *lower_factor, *extent});
108 } else if (const auto* iter_sum_expr = expr->source->source.as<arith::IterSumExprNode>()) {
109 Visit(GetRef<arith::IterSumExpr>(iter_sum_expr));
110 } else {
111 ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey();
112 }
113 }
114
115 void Visit(const arith::IterSumExpr& expr) {
116 for (const arith::IterSplitExpr& arg : expr->args) {
117 Visit(arg);
118 }
119 }
120
121 /*! \brief Whether the analysis failed */
122 bool failed_ = false;
123 /*! \brief The collected split expressions */
124 std::vector<SplitExpr> exprs_;
125};
126
127Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices,
128 const Array<For>& loops, const PrimExpr& predicate,
129 arith::Analyzer* analyzer) {
130 int ndim = buffer->shape.size();
131 int n_loops = loops.size();
132 // Step 1. Collect the domains and indices of loop variables
133 Map<Var, Range> input_iters;
134 std::unordered_map<const VarNode*, int> var2id;
135 var2id.reserve(n_loops);
136 for (int i = 0; i < n_loops; ++i) {
137 const For& loop = loops[i];
138 input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
139 var2id.emplace(loop->loop_var.get(), i);
140 }
141 // Step 2. Calculate a functor that flattens a multi-dimensional index
142 auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()](
143 const Array<PrimExpr>& indices) -> PrimExpr {
144 PrimExpr flatten_index = make_const(dtype, 0);
145 for (int i = 0; i < ndim; ++i) {
146 flatten_index = flatten_index + strides[i] * indices[i];
147 }
148 return flatten_index;
149 };
150 // Step 3. Detect the IterSplitExpr of the indexing pattern
151 std::vector<SplitExprCollector::SplitExpr> split_exprs = SplitExprCollector::Collect(
152 /*index=*/f_flatten_index(indices), input_iters, predicate,
153 /*check_level=*/arith::IterMapLevel::Surjective, analyzer);
154 if (split_exprs.empty()) {
155 return NullOpt;
156 }
157 // Step 4. Sort the order of the split expressions
158 std::vector<int> order(split_exprs.size(), 0);
159 std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; });
160 std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool {
161 const SplitExprCollector::SplitExpr& a = split_exprs[_a];
162 const SplitExprCollector::SplitExpr& b = split_exprs[_b];
163 int a_var_id = var2id.at(a.source.get());
164 int b_var_id = var2id.at(b.source.get());
165 if (a_var_id != b_var_id) {
166 return a_var_id < b_var_id;
167 }
168 return a.lower_factor > b.lower_factor;
169 });
170 // Compute the inverse permutation by argsort
171 std::vector<int> inverse_order = order;
172 std::sort(inverse_order.begin(), inverse_order.end(),
173 [&order](int _a, int _b) -> bool { return order[_a] < order[_b]; });
174 // Step 5. Create the indexing mapping
175 auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), //
176 &split_exprs, //
177 &order, //
178 & shape = buffer->shape, //
179 analyzer //
180 ](Array<Var> indices) -> Array<PrimExpr> {
181 ICHECK_EQ(indices.size(), shape.size());
182 for (int i = 0, n = indices.size(); i < n; ++i) {
183 analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i]));
184 }
185 // Step 5.1: Fuse all indices into a flattened one
186 PrimExpr index = f_flatten_index({indices.begin(), indices.end()});
187 int ndim = split_exprs.size();
188 // Step 5.2. Split the flattened index according to `split_exprs`
189 std::vector<PrimExpr> split;
190 split.reserve(ndim);
191 for (int i = ndim - 1; i >= 0; --i) {
192 index = analyzer->Simplify(index);
193 int64_t extent = split_exprs[i].extent;
194 split.push_back(analyzer->Simplify(floormod(index, extent)));
195 index = floordiv(index, extent);
196 }
197 std::reverse(split.begin(), split.end());
198 // Step 5.3. Reorder the indexing pattern according to `order`
199 Array<PrimExpr> results;
200 results.reserve(ndim);
201 for (int i = 0; i < ndim; ++i) {
202 results.push_back(split[order[i]]);
203 }
204 return results;
205 };
206 // Step 6: Create the inverse index mapping.
207 auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape,
208 analyzer](Array<Var> indices) -> Array<PrimExpr> {
209 ICHECK_EQ(indices.size(), split_exprs.size());
210 // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3.
211 // After the inverse permutation, indices[i] corresponds to split_exprs[i]
212 Array<Var> inv_permuted_indices;
213 inv_permuted_indices.reserve(indices.size());
214 for (int i = 0, n = indices.size(); i < n; ++i) {
215 const Var& index = indices[inverse_order[i]];
216 inv_permuted_indices.push_back(index);
217 analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent)));
218 }
219
220 // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2.
221 PrimExpr flattened_index = make_const(indices[0]->dtype, 0);
222 int64_t stride = 1;
223 for (int i = static_cast<int>(split_exprs.size()) - 1; i >= 0; --i) {
224 flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index;
225 stride *= split_exprs[i].extent;
226 }
227 // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1.
228 Array<PrimExpr> result;
229 result.reserve(shape.size());
230 for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
231 PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i]));
232 flattened_index = floordiv(flattened_index, shape[i]);
233 result.push_back(index);
234 }
235 return Array<PrimExpr>(result.rbegin(), result.rend());
236 };
237 IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse);
238 return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map);
239}
240
241TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap")
242 .set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Array<For> loops,
243 PrimExpr predicate) {
244 arith::Analyzer analyzer;
245 return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer);
246 });
247
248} // namespace tir
249} // namespace tvm
250