1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! |
25 | * \brief Calculate the strides of the buffer |
26 | * \param buffer The buffer |
27 | * \return The strides |
28 | */ |
29 | Array<PrimExpr> GetStrides(const Buffer& buffer) { |
30 | if (!buffer->strides.empty()) { |
31 | ICHECK_EQ(buffer->strides.size(), buffer->shape.size()); |
32 | return buffer->strides; |
33 | } |
34 | int ndim = buffer->shape.size(); |
35 | if (ndim == 0) { |
36 | return {}; |
37 | } |
38 | Array<PrimExpr> strides(ndim, PrimExpr{nullptr}); |
39 | PrimExpr stride = make_const(buffer->DefaultIndexType(), 1); |
40 | for (int i = ndim - 1; i >= 0; --i) { |
41 | strides.Set(i, stride); |
42 | stride = stride * buffer->shape[i]; |
43 | } |
44 | return strides; |
45 | } |
46 | |
47 | /*! |
48 | * \brief Auxiliary class that collects the IterSplitExpr in the indexing pattern |
49 | * to help decision making in layout transformation |
50 | */ |
51 | class SplitExprCollector { |
52 | public: |
53 | /*! |
54 | * \brief The corresponding IterSplitExpr, simplified for our case |
55 | * The pattern is `source // lower_factor % extent * scale` |
56 | */ |
57 | struct SplitExpr { |
58 | /*! \brief The source variable */ |
59 | Var source; |
60 | /*! \brief The lower factor of the split expression */ |
61 | int64_t lower_factor; |
62 | /*! \brief The extent of the split expression */ |
63 | int64_t extent; |
64 | }; |
65 | |
66 | /*! |
67 | * \brief Collect the split expressions in the indexing pattern |
68 | * \param index The indexing pattern |
69 | * \param input_iters The input iterators' domain |
70 | * \param predicate The predicate of the affine map |
71 | * \param check_level The iter mapping checking level |
72 | * \param analyzer The analyzer |
73 | * \return The collected split expressions |
74 | */ |
75 | static std::vector<SplitExpr> Collect(const PrimExpr& index, |
76 | const Map<Var, Range>& input_iters, // |
77 | const PrimExpr& predicate, // |
78 | arith::IterMapLevel check_level, // |
79 | arith::Analyzer* analyzer) { |
80 | arith::IterMapResult res = arith::DetectIterMap({analyzer->Simplify(index)}, input_iters, |
81 | predicate, check_level, analyzer); |
82 | const auto& iter_sum_exprs = res->indices; |
83 | if (iter_sum_exprs.empty()) { |
84 | return {}; |
85 | } |
86 | ICHECK_EQ(iter_sum_exprs.size(), 1); |
87 | if (iter_sum_exprs[0]->args.size() == 0) { |
88 | return {}; |
89 | } |
90 | SplitExprCollector collector; |
91 | collector.Visit(iter_sum_exprs[0]); |
92 | if (collector.failed_) { |
93 | return {}; |
94 | } |
95 | return std::move(collector.exprs_); |
96 | } |
97 | |
98 | private: |
99 | void Visit(const arith::IterSplitExpr& expr) { |
100 | if (const auto* var = expr->source->source.as<tir::VarNode>()) { |
101 | const int64_t* lower_factor = as_const_int(expr->lower_factor); |
102 | const int64_t* extent = as_const_int(expr->extent); |
103 | if (lower_factor == nullptr || extent == nullptr) { |
104 | failed_ = true; |
105 | return; |
106 | } |
107 | exprs_.push_back(SplitExpr{GetRef<Var>(var), *lower_factor, *extent}); |
108 | } else if (const auto* iter_sum_expr = expr->source->source.as<arith::IterSumExprNode>()) { |
109 | Visit(GetRef<arith::IterSumExpr>(iter_sum_expr)); |
110 | } else { |
111 | ICHECK(false) << "Unexpected type: " << expr->source->source->GetTypeKey(); |
112 | } |
113 | } |
114 | |
115 | void Visit(const arith::IterSumExpr& expr) { |
116 | for (const arith::IterSplitExpr& arg : expr->args) { |
117 | Visit(arg); |
118 | } |
119 | } |
120 | |
121 | /*! \brief Whether the analysis failed */ |
122 | bool failed_ = false; |
123 | /*! \brief The collected split expressions */ |
124 | std::vector<SplitExpr> exprs_; |
125 | }; |
126 | |
127 | Optional<IndexMap> SuggestIndexMap(const Buffer& buffer, const Array<PrimExpr>& indices, |
128 | const Array<For>& loops, const PrimExpr& predicate, |
129 | arith::Analyzer* analyzer) { |
130 | int ndim = buffer->shape.size(); |
131 | int n_loops = loops.size(); |
132 | // Step 1. Collect the domains and indices of loop variables |
133 | Map<Var, Range> input_iters; |
134 | std::unordered_map<const VarNode*, int> var2id; |
135 | var2id.reserve(n_loops); |
136 | for (int i = 0; i < n_loops; ++i) { |
137 | const For& loop = loops[i]; |
138 | input_iters.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); |
139 | var2id.emplace(loop->loop_var.get(), i); |
140 | } |
141 | // Step 2. Calculate a functor that flattens a multi-dimensional index |
142 | auto f_flatten_index = [ndim, strides = GetStrides(buffer), dtype = buffer->DefaultIndexType()]( |
143 | const Array<PrimExpr>& indices) -> PrimExpr { |
144 | PrimExpr flatten_index = make_const(dtype, 0); |
145 | for (int i = 0; i < ndim; ++i) { |
146 | flatten_index = flatten_index + strides[i] * indices[i]; |
147 | } |
148 | return flatten_index; |
149 | }; |
150 | // Step 3. Detect the IterSplitExpr of the indexing pattern |
151 | std::vector<SplitExprCollector::SplitExpr> split_exprs = SplitExprCollector::Collect( |
152 | /*index=*/f_flatten_index(indices), input_iters, predicate, |
153 | /*check_level=*/arith::IterMapLevel::Surjective, analyzer); |
154 | if (split_exprs.empty()) { |
155 | return NullOpt; |
156 | } |
157 | // Step 4. Sort the order of the split expressions |
158 | std::vector<int> order(split_exprs.size(), 0); |
159 | std::generate(order.begin(), order.end(), [n = 0]() mutable { return n++; }); |
160 | std::sort(order.begin(), order.end(), [&split_exprs, &var2id](int _a, int _b) -> bool { |
161 | const SplitExprCollector::SplitExpr& a = split_exprs[_a]; |
162 | const SplitExprCollector::SplitExpr& b = split_exprs[_b]; |
163 | int a_var_id = var2id.at(a.source.get()); |
164 | int b_var_id = var2id.at(b.source.get()); |
165 | if (a_var_id != b_var_id) { |
166 | return a_var_id < b_var_id; |
167 | } |
168 | return a.lower_factor > b.lower_factor; |
169 | }); |
170 | // Compute the inverse permutation by argsort |
171 | std::vector<int> inverse_order = order; |
172 | std::sort(inverse_order.begin(), inverse_order.end(), |
173 | [&order](int _a, int _b) -> bool { return order[_a] < order[_b]; }); |
174 | // Step 5. Create the indexing mapping |
175 | auto f_alter_layout = [f_flatten_index = std::move(f_flatten_index), // |
176 | &split_exprs, // |
177 | &order, // |
178 | & shape = buffer->shape, // |
179 | analyzer // |
180 | ](Array<Var> indices) -> Array<PrimExpr> { |
181 | ICHECK_EQ(indices.size(), shape.size()); |
182 | for (int i = 0, n = indices.size(); i < n; ++i) { |
183 | analyzer->Bind(indices[i], Range::FromMinExtent(0, shape[i])); |
184 | } |
185 | // Step 5.1: Fuse all indices into a flattened one |
186 | PrimExpr index = f_flatten_index({indices.begin(), indices.end()}); |
187 | int ndim = split_exprs.size(); |
188 | // Step 5.2. Split the flattened index according to `split_exprs` |
189 | std::vector<PrimExpr> split; |
190 | split.reserve(ndim); |
191 | for (int i = ndim - 1; i >= 0; --i) { |
192 | index = analyzer->Simplify(index); |
193 | int64_t extent = split_exprs[i].extent; |
194 | split.push_back(analyzer->Simplify(floormod(index, extent))); |
195 | index = floordiv(index, extent); |
196 | } |
197 | std::reverse(split.begin(), split.end()); |
198 | // Step 5.3. Reorder the indexing pattern according to `order` |
199 | Array<PrimExpr> results; |
200 | results.reserve(ndim); |
201 | for (int i = 0; i < ndim; ++i) { |
202 | results.push_back(split[order[i]]); |
203 | } |
204 | return results; |
205 | }; |
206 | // Step 6: Create the inverse index mapping. |
207 | auto f_inverse = [&inverse_order, &split_exprs, &shape = buffer->shape, |
208 | analyzer](Array<Var> indices) -> Array<PrimExpr> { |
209 | ICHECK_EQ(indices.size(), split_exprs.size()); |
210 | // Step 6.1: Reorder the indices according to `inverse_order`. This is the inverse of Step 5.3. |
211 | // After the inverse permutation, indices[i] corresponds to split_exprs[i] |
212 | Array<Var> inv_permuted_indices; |
213 | inv_permuted_indices.reserve(indices.size()); |
214 | for (int i = 0, n = indices.size(); i < n; ++i) { |
215 | const Var& index = indices[inverse_order[i]]; |
216 | inv_permuted_indices.push_back(index); |
217 | analyzer->Bind(index, Range::FromMinExtent(0, Integer(split_exprs[i].extent))); |
218 | } |
219 | |
220 | // Step 6.2: Fuse all the indices. This is the inverse of Step 5.2. |
221 | PrimExpr flattened_index = make_const(indices[0]->dtype, 0); |
222 | int64_t stride = 1; |
223 | for (int i = static_cast<int>(split_exprs.size()) - 1; i >= 0; --i) { |
224 | flattened_index = inv_permuted_indices[i] * Integer(stride) + flattened_index; |
225 | stride *= split_exprs[i].extent; |
226 | } |
227 | // Step 6.3: Split the flattened index into multiple indices. This is the inverse of Step 5.1. |
228 | Array<PrimExpr> result; |
229 | result.reserve(shape.size()); |
230 | for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) { |
231 | PrimExpr index = analyzer->Simplify(floormod(flattened_index, shape[i])); |
232 | flattened_index = floordiv(flattened_index, shape[i]); |
233 | result.push_back(index); |
234 | } |
235 | return Array<PrimExpr>(result.rbegin(), result.rend()); |
236 | }; |
237 | IndexMap inverse_index_map = IndexMap::FromFunc(split_exprs.size(), f_inverse); |
238 | return IndexMap::FromFunc(ndim, f_alter_layout, inverse_index_map); |
239 | } |
240 | |
241 | TVM_REGISTER_GLOBAL("tir.schedule.SuggestIndexMap" ) |
242 | .set_body_typed([](Buffer buffer, Array<PrimExpr> indices, Array<For> loops, |
243 | PrimExpr predicate) { |
244 | arith::Analyzer analyzer; |
245 | return SuggestIndexMap(buffer, indices, loops, predicate, &analyzer); |
246 | }); |
247 | |
248 | } // namespace tir |
249 | } // namespace tvm |
250 | |