/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Check whether a loop carries any annotation or is bound to a thread axis
 * \param loop The loop to be checked
 * \return Whether the loop has any annotation or thread binding
 */
inline bool HasAnnOrBinding(const ForNode* loop) {
  return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty();
}

/*! \brief The visitor for extracting the stride of a var in a PrimExpr. */
class StrideExtractor : public ExprVisitor {
 public:
  /*!
   * \brief Extract the stride of a var in a PrimExpr,
   * e.g. the stride of `x` in `(x * 2 + 1) * 3 + 1` is 6.
   * \param expr The given PrimExpr.
   * \param var The target var.
   * \return The stride of the var, or 0 if the var does not occur in the expression.
   */
  static int64_t Extract(const PrimExpr& expr, const Var& var) {
    StrideExtractor extractor(var);
    extractor.VisitExpr(expr);
    return extractor.strides_[expr.get()];
  }

 private:
  explicit StrideExtractor(const Var& var) : var_(var) {}
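
  // Stride propagation rules, applied bottom-up over the expression tree:
  //   - the target var itself has stride 1;
  //   - multiplying a sub-expression of known stride by an integer constant c scales it by c;
  //   - adding two sub-expressions keeps the smaller known stride.
  // For the example above: stride(x) = 1, stride(x * 2) = 2, stride(x * 2 + 1) = 2,
  // stride((x * 2 + 1) * 3) = 6, and the trailing `+ 1` leaves it at 6.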
  void VisitExpr_(const MulNode* node) final {
    ExprVisitor::VisitExpr_(node);

    if (const auto* a = node->a.as<IntImmNode>()) {
      if (strides_.count(node->b.get())) {
        strides_[node] = strides_[node->b.get()] * a->value;
      }
    } else if (const auto* b = node->b.as<IntImmNode>()) {
      if (strides_.count(node->a.get())) {
        strides_[node] = strides_[node->a.get()] * b->value;
      }
    }
  }

  void VisitExpr_(const AddNode* node) final {
    ExprVisitor::VisitExpr_(node);
    int64_t stride_a, stride_b;
    if (strides_.count(node->a.get())) {
      stride_a = strides_[node->a.get()];
    } else {
      stride_a = INT64_MAX;
    }
    if (strides_.count(node->b.get())) {
      stride_b = strides_[node->b.get()];
    } else {
      stride_b = INT64_MAX;
    }
    if (stride_a != INT64_MAX || stride_b != INT64_MAX) {
      strides_[node] = std::min(stride_a, stride_b);
    }
  }

  void VisitExpr_(const VarNode* node) final {
    if (node == var_.get()) {
      strides_[node] = 1;
    }
  }

  const Var& var_;
  std::unordered_map<const PrimExprNode*, int64_t> strides_;
};
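
/*!
 * \brief Annotations parsed from the root block. Each field is -1 when the corresponding
 * annotation is absent; num_parallel_loops and num_vectorize_loops are decided later by
 * AdjustParallelVectorize.
 */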
struct ParsedAnnotation {
  int max_parallel_extent;
  int max_vectorize_extent;
  int unroll_explicit;
  int unroll_implicit;
  int num_parallel_loops;
  int num_vectorize_loops;
};

bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) {
  bool found = false;
  *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1};
  for (const auto& ann : block->annotations) {
    if (ann.first == attr::meta_schedule_parallel) {
      found = true;
      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
        parsed->max_parallel_extent = imm->value;
      }
    } else if (ann.first == attr::meta_schedule_vectorize) {
      found = true;
      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
        parsed->max_vectorize_extent = imm->value;
      }
    } else if (ann.first == attr::meta_schedule_unroll_explicit) {
      found = true;
      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
        parsed->unroll_explicit = imm->value;
      }
    } else if (ann.first == attr::meta_schedule_unroll_implicit) {
      found = true;
      if (const auto* imm = ann.second.as<tir::IntImmNode>()) {
        parsed->unroll_implicit = imm->value;
      }
    }
  }
  return found;
}
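
// For reference, these annotations live on the root block of a scheduled PrimFunc. In
// TVMScript they would render roughly as
//   T.block_attr({"meta_schedule.parallel": 256, "meta_schedule.vectorize": 32})
// (an illustrative sketch; the authoritative keys are the attr:: constants matched above).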

void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) {
  if (parsed.max_parallel_extent != -1) {
    sch->Unannotate(block_rv, attr::meta_schedule_parallel);
  }
  if (parsed.max_vectorize_extent != -1) {
    sch->Unannotate(block_rv, attr::meta_schedule_vectorize);
  }
  if (parsed.unroll_explicit != -1) {
    sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit);
  }
  if (parsed.unroll_implicit != -1) {
    sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit);
  }
}
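
/*!
 * \brief Decide how many outer loops can be fused and parallelized and how many inner loops
 * can be fused and vectorized, based on loop kinds, extents, and memory-access contiguity.
 * The results are written to parsed->num_parallel_loops / num_vectorize_loops, where -1
 * means the corresponding rewrite should be skipped.
 */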
void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv,
                             const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) {
  StmtSRef block_sref = sch->GetSRef(block_rv);
  if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) {
    return;
  }
  int n_loops = loop_rvs.size();
  if (n_loops == 0) {
    parsed->max_parallel_extent = -1;
    parsed->max_vectorize_extent = -1;
    return;
  }
  // Extract loop_srefs, and calculate the iterator types
  Array<StmtSRef> loop_srefs;
  std::vector<int> loop_types;
  {
    loop_srefs.reserve(n_loops);
    loop_types.reserve(n_loops);
    for (const LoopRV& loop_rv : loop_rvs) {
      loop_srefs.push_back(sch->GetSRef(loop_rv));
      loop_types.push_back(GetLoopIterType(loop_srefs.back()));
    }
  }
  // Check the maximal number of axes that are vectorizable (contiguous memory access)
  BlockRealize realize = GetBlockRealize(sch->state(), block_sref);
  Array<BufferRegion> buffer_access(realize->block->reads);
  buffer_access.insert(buffer_access.end(), realize->block->writes.begin(),
                       realize->block->writes.end());
  std::unordered_map<const VarNode*, PrimExpr> binding_map;
  for (size_t i = 0; i < realize->iter_values.size(); i++) {
    binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i];
  }
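  // Substituting block iter vars with the expressions they are bound to lets the strides of
  // the buffer indices be measured directly in terms of the surrounding loop vars.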
  int max_fusible = INT32_MAX;
  // For each block read/write, get the strides of the loop vars and find the fusible
  // (vectorizable) axes
  for (const BufferRegion& access : buffer_access) {
    int fusible = 0;
    std::vector<int64_t> strides;
    // Get strides for each loop var
    for (const StmtSRef& loop_sref : loop_srefs) {
      int64_t stride = 0, buffer_stride = 1;
      const auto* var = loop_sref->StmtAs<ForNode>();
      arith::Analyzer analyzer;
      for (int i = access->region.size() - 1; i >= 0; i--) {
        PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map));
        int64_t coef = StrideExtractor::Extract(idx, var->loop_var);
        if (coef != 0) {
          stride = coef * buffer_stride;
          break;
        }
        buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
      }
      strides.push_back(stride);
    }
    int prev_used_iter = -1;
    // Check the number of fusible loops
    for (int i = strides.size() - 1; i >= 0; i--) {
      if (strides[i] == 0) {
        // Not used in the buffer access, safe to fuse
        fusible++;
        continue;
      } else if (prev_used_iter == -1) {
        // A non-unit stride on the innermost used axis means the access is not contiguous
        if (strides[i] != 1) {
          break;
        }
        fusible++;
        prev_used_iter = i;
      } else {
        // Contiguous memory access: this axis's stride must equal the stride of the
        // previously used axis times that axis's extent
        const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
        int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value;
        if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
          fusible++;
          prev_used_iter = i;
        } else {
          break;
        }
      }
    }
    max_fusible = std::min(max_fusible, fusible);
  }
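  // Worked example (illustrative): for loops (i, j) with extents (128, 64) accessing
  // A[i, j] where A has shape (128, 64), the computed strides are (64, 1): j has stride 1,
  // and stride(i) == stride(j) * extent(j), so both loops are fusible and max_fusible is 2.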
  // Calculate how many outer data-parallel loops can be fused and parallelized
  if (parsed->max_parallel_extent != -1) {
    int max_extent = parsed->max_parallel_extent;
    int& num_fusible = parsed->num_parallel_loops = 0;
    int64_t prod_extent = 1;
    for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) {
      const StmtSRef& loop_sref = loop_srefs[i];
      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
      if (HasAnnOrBinding(loop)) {
        break;
      }
      // Check if the loop extent is valid
      const int64_t* extent = GetLoopIntExtent(loop_sref);
      if (extent == nullptr) {
        break;
      }
      // Then we can fuse it in
      ++num_fusible;
      // Check if we need to break
      prod_extent *= *extent;
      if (prod_extent > max_extent || !IsSingleStmt(loop->body)) {
        break;
      }
    }
    if (prod_extent == 1) {
      num_fusible = -1;
    }
  }
  // Calculate how many inner loops can be fused and vectorized
  if (parsed->max_vectorize_extent != -1) {
    int max_extent = parsed->max_vectorize_extent;
    int& num_fusible = parsed->num_vectorize_loops = 0;
    int64_t prod_extent = 1;
    for (int i = n_loops - 1;
         i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) {
      const StmtSRef& loop_sref = loop_srefs[i];
      const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
      if (HasAnnOrBinding(loop)) {
        break;
      }
      // Cannot vectorize a reduction axis
      if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) {
        break;
      }
      // Cannot fuse with a loop that has multiple children
      if (!IsSingleStmt(loop->body)) {
        break;
      }
      // Check if the loop extent is valid
      const int64_t* extent = GetLoopIntExtent(loop_sref);
      if (extent == nullptr) {
        break;
      }
      // Check if the extent is still in a good range
      prod_extent *= *extent;
      if (prod_extent > max_extent) {
        break;
      }
      ++num_fusible;
    }
    if (prod_extent == 1) {
      num_fusible = -1;
    }
  }
  // Prefer num_vectorize to num_parallel
  if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) {
    parsed->num_parallel_loops = std::min(parsed->num_parallel_loops,  //
                                          n_loops - parsed->num_vectorize_loops);
  }
}

bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) {
  IRModule mod = sch->mod();
  for (const auto& kv : mod->functions) {
    const GlobalVar& g_var = kv.first;
    const BaseFunc& base_func = kv.second;
    if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
      const BlockRealizeNode* block_realize = prim_func->body.as<BlockRealizeNode>();
      if (block_realize != nullptr) {
        Block block = block_realize->block;
        if (ParseAnnotation(block, parsed)) {
          *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint);
          RemoveParsedAnn(sch, *root_rv, *parsed);
          return true;
        }
      }
    }
  }
  return false;
}
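
// Only PrimFuncs whose body is directly the root BlockRealize are inspected, and the parsed
// annotations are removed as soon as they are found, which guarantees that the `while` loop
// in Apply() below terminates.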

void RewriteParallel(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) {
  ICHECK_LE(n, loop_rvs->size());
  LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n});
  sch->Parallel(fused);
  for (size_t i = 0; i < n; ++i) {
    loop_rvs->Set(i, fused);
  }
}

void RewriteVectorize(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) {
  size_t n_loops = loop_rvs->size();
  ICHECK_LE(n, n_loops);
  LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()});
  sch->Vectorize(fused);
  for (size_t i = n_loops - n; i < n_loops; ++i) {
    loop_rvs->Set(i, fused);
  }
}
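
// Both rewrites above overwrite every fused position in `loop_rvs` with the fused LoopRV,
// so subsequent indexing (e.g. `loop_rvs[0]` for unrolling) still refers to a live loop.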

void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) {
  if (max_step > 0) {
    sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step));
    sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit));
  }
}
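
// `pragma_auto_unroll_max_step` caps how many iterations the UnrollLoop pass may unroll;
// `pragma_unroll_explicit` selects between materializing the unrolled body in TIR and
// leaving an unroll hint to the backend compiler (a summary of the pragmas' intent; see
// the UnrollLoop pass for the authoritative semantics).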

} // namespace tir

namespace meta_schedule {

using tir::Schedule;

class RewriteParallelVectorizeUnrollNode : public PostprocNode {
 public:
  void InitializeWithTuneContext(const TuneContext& context) final {}

  bool Apply(const Schedule& sch) final {
    tir::ParsedAnnotation parsed_root;
    tir::BlockRV root_rv{nullptr};
    while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) {
      for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) {
        Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
        if (loop_rvs.empty()) {
          continue;
        }
        tir::ParsedAnnotation parsed = parsed_root;
        tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed);
        // Parallel
        if (parsed.num_parallel_loops > 0) {
          tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs);
        }
        // Vectorize
        if (parsed.num_vectorize_loops > 0) {
          tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs);
        }
        // AutoUnroll
        if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) {
          ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1);
          // Exactly one of the two is set; the unset one contributes its -1 sentinel, so
          // adding both plus 1 recovers the configured max step.
          int unroll_explicit = parsed.unroll_explicit != -1;
          int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1;
          tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]);
        }
      }
    }
    return true;
  }

  Postproc Clone() const {
    ObjectPtr<RewriteParallelVectorizeUnrollNode> n =
        make_object<RewriteParallelVectorizeUnrollNode>(*this);
    return Postproc(n);
  }

  static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll";
  TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode);
};

Postproc Postproc::RewriteParallelVectorizeUnroll() {
  ObjectPtr<RewriteParallelVectorizeUnrollNode> n =
      make_object<RewriteParallelVectorizeUnrollNode>();
  return Postproc(n);
}

TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode);
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll")
    .set_body_typed(Postproc::RewriteParallelVectorizeUnroll);

} // namespace meta_schedule
} // namespace tvm