1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace tir { |
23 | |
24 | /*! |
 * \brief Check whether the loop has any annotation or is bound to a thread axis
 * \param loop The loop to be checked
 * \return Whether the loop has any annotation or thread binding
28 | */ |
29 | inline bool HasAnnOrBinding(const ForNode* loop) { |
30 | return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); |
31 | } |
32 | |
33 | /*! \brief The visitor for extracting the stride of a var in a PrimExpr. */ |
class StrideExtractor : public ExprVisitor {
35 | public: |
36 | /*! |
   * \brief Extract the stride of a var in a PrimExpr,
   * e.g. the stride of `x` in `(x * 2 + 1) * 3 + 1` is 6
39 | * \param expr The given PrimExpr. |
40 | * \param var The target var. |
41 | * \return The stride of the var. |
42 | */ |
  static int64_t Extract(const PrimExpr& expr, const Var& var) {
    StrideExtractor extractor(var);
45 | extractor.VisitExpr(expr); |
46 | return extractor.strides_[expr.get()]; |
47 | } |
48 | |
49 | private: |
  explicit StrideExtractor(const Var& var) : var_(var) {}
51 | |
  void VisitExpr_(const MulNode* node) final {
53 | ExprVisitor::VisitExpr_(node); |
54 | |
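    // A constant multiplier scales the stride of the other operand,
    // provided that operand is known to involve the target var.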
55 | if (const auto* a = node->a.as<IntImmNode>()) { |
56 | if (strides_.count(node->b.get())) { |
57 | strides_[node] = strides_[node->b.get()] * a->value; |
58 | } |
59 | } else if (const auto* b = node->b.as<IntImmNode>()) { |
60 | if (strides_.count(node->a.get())) { |
61 | strides_[node] = strides_[node->a.get()] * b->value; |
62 | } |
63 | } |
64 | } |
65 | |
  void VisitExpr_(const AddNode* node) final {
67 | ExprVisitor::VisitExpr_(node); |
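    // The stride of a sum is the minimum stride among the operands that involve the var.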
68 | int64_t stride_a, stride_b; |
69 | if (strides_.count(node->a.get())) { |
70 | stride_a = strides_[node->a.get()]; |
71 | } else { |
72 | stride_a = INT64_MAX; |
73 | } |
74 | if (strides_.count(node->b.get())) { |
75 | stride_b = strides_[node->b.get()]; |
76 | } else { |
77 | stride_b = INT64_MAX; |
78 | } |
79 | if (stride_a != INT64_MAX || stride_b != INT64_MAX) { |
80 | strides_[node] = std::min(stride_a, stride_b); |
81 | } |
82 | } |
83 | |
  void VisitExpr_(const VarNode* node) final {
85 | if (node == var_.get()) { |
86 | strides_[node] = 1; |
87 | } |
88 | } |
89 | |
  const Var& var_;
  std::unordered_map<const PrimExprNode*, int64_t> strides_;
92 | }; |
93 | |
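/*!
 * \brief The annotation hints parsed from a block, where -1 means the corresponding
 * annotation is absent (or, for the loop counts, not yet computed).
 */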
94 | struct ParsedAnnotation { |
95 | int max_parallel_extent; |
96 | int max_vectorize_extent; |
97 | int unroll_explicit; |
98 | int unroll_implicit; |
99 | int num_parallel_loops; |
100 | int num_vectorize_loops; |
101 | }; |
102 | |
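/*!
 * \brief Parse the meta-schedule annotations of a block into a ParsedAnnotation.
 * \param block The block whose annotations are parsed.
 * \param parsed The parsing result.
 * \return Whether any meta-schedule annotation is found on the block.
 */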
103 | bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { |
104 | bool found = false; |
105 | *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; |
106 | for (const auto& ann : block->annotations) { |
107 | if (ann.first == attr::meta_schedule_parallel) { |
108 | found = true; |
109 | if (const auto* imm = ann.second.as<tir::IntImmNode>()) { |
110 | parsed->max_parallel_extent = imm->value; |
111 | } |
112 | } else if (ann.first == attr::meta_schedule_vectorize) { |
113 | found = true; |
114 | if (const auto* imm = ann.second.as<tir::IntImmNode>()) { |
115 | parsed->max_vectorize_extent = imm->value; |
116 | } |
117 | } else if (ann.first == attr::meta_schedule_unroll_explicit) { |
118 | found = true; |
119 | if (const auto* imm = ann.second.as<tir::IntImmNode>()) { |
120 | parsed->unroll_explicit = imm->value; |
121 | } |
122 | } else if (ann.first == attr::meta_schedule_unroll_implicit) { |
123 | found = true; |
124 | if (const auto* imm = ann.second.as<tir::IntImmNode>()) { |
125 | parsed->unroll_implicit = imm->value; |
126 | } |
127 | } |
128 | } |
129 | return found; |
130 | } |
131 | |
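/*!
 * \brief Remove the annotations that have been parsed from a block.
 * \param sch The schedule.
 * \param block_rv The block whose annotations are removed.
 * \param parsed The parsed annotations indicating which keys to remove.
 */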
132 | void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { |
133 | if (parsed.max_parallel_extent != -1) { |
134 | sch->Unannotate(block_rv, attr::meta_schedule_parallel); |
135 | } |
136 | if (parsed.max_vectorize_extent != -1) { |
137 | sch->Unannotate(block_rv, attr::meta_schedule_vectorize); |
138 | } |
139 | if (parsed.unroll_explicit != -1) { |
140 | sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); |
141 | } |
142 | if (parsed.unroll_implicit != -1) { |
143 | sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); |
144 | } |
145 | } |
146 | |
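/*!
 * \brief Adjust the number of parallelizable and vectorizable loops according to the loop
 * structure and the contiguity of buffer accesses, writing the results into
 * `parsed->num_parallel_loops` and `parsed->num_vectorize_loops`.
 * \param sch The schedule.
 * \param block_rv The block whose surrounding loops are analyzed.
 * \param loop_rvs The loops above the block, from outermost to innermost.
 * \param parsed The parsed annotation to be adjusted in place.
 */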
147 | void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, |
148 | const Array<LoopRV>& loop_rvs, ParsedAnnotation* parsed) { |
149 | StmtSRef block_sref = sch->GetSRef(block_rv); |
150 | if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { |
151 | return; |
152 | } |
153 | int n_loops = loop_rvs.size(); |
154 | if (n_loops == 0) { |
155 | parsed->max_parallel_extent = -1; |
156 | parsed->max_vectorize_extent = -1; |
157 | return; |
158 | } |
159 | // Extract loop_srefs, and calculate the iterator types |
160 | Array<StmtSRef> loop_srefs; |
161 | std::vector<int> loop_types; |
162 | { |
163 | loop_srefs.reserve(n_loops); |
164 | loop_types.reserve(n_loops); |
165 | for (const LoopRV& loop_rv : loop_rvs) { |
166 | loop_srefs.push_back(sch->GetSRef(loop_rv)); |
167 | loop_types.push_back(GetLoopIterType(loop_srefs.back())); |
168 | } |
169 | } |
170 | // check the maximal number of axes that are vectorizable (contiguous memory access) |
171 | BlockRealize realize = GetBlockRealize(sch->state(), block_sref); |
172 | Array<BufferRegion> buffer_access(realize->block->reads); |
173 | buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), |
174 | realize->block->writes.end()); |
175 | std::unordered_map<const VarNode*, PrimExpr> binding_map; |
176 | for (size_t i = 0; i < realize->iter_values.size(); i++) { |
177 | binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; |
178 | } |
179 | int max_fusible = INT32_MAX; |
180 | // for each block read/write, get the strides of the loop vars and find the fusible |
181 | // (vectorizable) axes |
182 | for (const BufferRegion& access : buffer_access) { |
183 | int fusible = 0; |
184 | std::vector<int64_t> strides; |
185 | // get strides for each loop var |
186 | for (const StmtSRef& loop_sref : loop_srefs) { |
187 | int64_t stride = 0, buffer_stride = 1; |
188 | const auto* var = loop_sref->StmtAs<ForNode>(); |
189 | arith::Analyzer analyzer; |
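      // Walk the buffer dimensions from innermost to outermost; `buffer_stride` accumulates
      // the sizes of the dimensions passed so far, so the first dimension whose index involves
      // the loop var determines that var's stride in the flattened buffer.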
190 | for (int i = access->region.size() - 1; i >= 0; i--) { |
191 | PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); |
192 | int64_t coef = StrideExtractor::Extract(idx, var->loop_var); |
193 | if (coef != 0) { |
194 | stride = coef * buffer_stride; |
195 | break; |
196 | } |
197 | buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value; |
198 | } |
199 | strides.push_back(stride); |
200 | } |
201 | int prev_used_iter = -1; |
202 | // check the number of fusible loops |
203 | for (int i = strides.size() - 1; i >= 0; i--) { |
204 | if (strides[i] == 0) { |
205 | // not used in the buffer access, safe to fuse |
206 | fusible++; |
207 | continue; |
208 | } else if (prev_used_iter == -1) { |
        // if the stride of the last axis is not 1, the memory access is not contiguous
210 | if (strides[i] != 1 && fusible != 0) { |
211 | break; |
212 | } |
213 | fusible++; |
214 | prev_used_iter = i; |
215 | } else { |
216 | // contiguous memory access |
217 | const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>(); |
218 | int64_t prev_used_iter_extent = prev_loop->extent.as<IntImmNode>()->value; |
219 | if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { |
220 | fusible++; |
221 | prev_used_iter = i; |
222 | } else { |
223 | break; |
224 | } |
225 | } |
226 | } |
227 | max_fusible = std::min(max_fusible, fusible); |
228 | } |
229 | // Calculate the parallelize extent |
230 | if (parsed->max_parallel_extent != -1) { |
231 | int max_extent = parsed->max_parallel_extent; |
232 | int& num_fusible = parsed->num_parallel_loops = 0; |
233 | int64_t prod_extent = 1; |
234 | for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { |
235 | const StmtSRef& loop_sref = loop_srefs[i]; |
236 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
237 | if (HasAnnOrBinding(loop)) { |
238 | break; |
239 | } |
240 | // Check if the loop extent is valid |
241 | const int64_t* extent = GetLoopIntExtent(loop_sref); |
242 | if (extent == nullptr) { |
243 | break; |
244 | } |
245 | // Then we can fuse it in |
246 | ++num_fusible; |
247 | // Check if we need to break |
248 | prod_extent *= *extent; |
249 | if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { |
250 | break; |
251 | } |
252 | } |
253 | if (prod_extent == 1) { |
254 | num_fusible = -1; |
255 | } |
256 | } |
257 | // Calculate the vectorize extent |
258 | if (parsed->max_vectorize_extent != -1) { |
259 | int max_extent = parsed->max_vectorize_extent; |
260 | int& num_fusible = parsed->num_vectorize_loops = 0; |
261 | int64_t prod_extent = 1; |
262 | for (int i = n_loops - 1; |
263 | i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { |
264 | const StmtSRef& loop_sref = loop_srefs[i]; |
265 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
266 | if (HasAnnOrBinding(loop)) { |
267 | break; |
268 | } |
269 | // Cannot vectorize reduce axis |
270 | if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { |
271 | break; |
272 | } |
273 | // Cannot fuse with a loop with multiple children |
274 | if (!IsSingleStmt(loop->body)) { |
275 | break; |
276 | } |
277 | // Check if the loop extent is valid |
278 | const int64_t* extent = GetLoopIntExtent(loop_sref); |
279 | if (extent == nullptr) { |
280 | break; |
281 | } |
282 | // Check if the extent is still in a good range |
283 | prod_extent *= *extent; |
284 | if (prod_extent > max_extent) { |
285 | break; |
286 | } |
287 | ++num_fusible; |
288 | } |
289 | if (prod_extent == 1) { |
290 | num_fusible = -1; |
291 | } |
292 | } |
  // Prefer vectorization over parallelization: cap the number of parallel loops so they
  // do not overlap the vectorized ones
294 | if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { |
295 | parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // |
296 | n_loops - parsed->num_vectorize_loops); |
297 | } |
298 | } |
299 | |
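/*!
 * \brief Find a root block that carries meta-schedule annotations, then parse and remove them.
 * \param sch The schedule.
 * \param parsed The parsing result.
 * \param root_rv The block random variable pointing to the root block found.
 * \return Whether such an annotated root block is found.
 */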
300 | bool FindAnnotatedRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { |
301 | IRModule mod = sch->mod(); |
302 | for (const auto& kv : mod->functions) { |
303 | const GlobalVar& g_var = kv.first; |
304 | const BaseFunc& base_func = kv.second; |
305 | if (const auto* prim_func = base_func.as<PrimFuncNode>()) { |
306 | const BlockRealizeNode* block_realize = prim_func->body.as<BlockRealizeNode>(); |
307 | if (block_realize != nullptr) { |
308 | Block block = block_realize->block; |
309 | if (ParseAnnotation(block, parsed)) { |
310 | *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); |
311 | RemoveParsedAnn(sch, *root_rv, *parsed); |
312 | return true; |
313 | } |
314 | } |
315 | } |
316 | } |
317 | return false; |
318 | } |
319 | |
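/*!
 * \brief Fuse the outermost `n` loops and parallelize the fused loop.
 * \param sch The schedule.
 * \param n The number of outermost loops to be fused and parallelized.
 * \param loop_rvs The loops above the block; the first `n` entries are replaced by the fused loop.
 */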
320 | void RewriteParallel(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) { |
321 | ICHECK_LE(n, loop_rvs->size()); |
322 | LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); |
323 | sch->Parallel(fused); |
324 | for (size_t i = 0; i < n; ++i) { |
325 | loop_rvs->Set(i, fused); |
326 | } |
327 | } |
328 | |
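/*!
 * \brief Fuse the innermost `n` loops and vectorize the fused loop.
 * \param sch The schedule.
 * \param n The number of innermost loops to be fused and vectorized.
 * \param loop_rvs The loops above the block; the last `n` entries are replaced by the fused loop.
 */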
329 | void RewriteVectorize(const Schedule& sch, size_t n, Array<LoopRV>* loop_rvs) { |
330 | size_t n_loops = loop_rvs->size(); |
331 | ICHECK_LE(n, n_loops); |
332 | LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); |
333 | sch->Vectorize(fused); |
334 | for (size_t i = n_loops - n; i < n_loops; ++i) { |
335 | loop_rvs->Set(i, fused); |
336 | } |
337 | } |
338 | |
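/*!
 * \brief Annotate a loop with the auto-unroll pragmas.
 * \param sch The schedule.
 * \param unroll_explicit Whether the unrolling is explicit (1) or implicit (0).
 * \param max_step The maximum number of unrolling steps; nothing is done if it is not positive.
 * \param loop The loop to be annotated.
 */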
339 | void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) { |
340 | if (max_step > 0) { |
341 | sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); |
342 | sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); |
343 | } |
344 | } |
345 | |
346 | } // namespace tir |
347 | |
348 | namespace meta_schedule { |
349 | |
350 | using tir::Schedule; |
351 | |
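/*! \brief A postprocessor that applies parallelization, vectorization and auto-unrolling
 * according to the annotations found on each annotated root block. */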
352 | class RewriteParallelVectorizeUnrollNode : public PostprocNode { |
353 | public: |
354 | void InitializeWithTuneContext(const TuneContext& context) final {} |
355 | |
356 | bool Apply(const Schedule& sch) final { |
357 | tir::ParsedAnnotation parsed_root; |
358 | tir::BlockRV root_rv{nullptr}; |
359 | while (tir::FindAnnotatedRootBlock(sch, &parsed_root, &root_rv)) { |
360 | for (tir::BlockRV block_rv : sch->GetChildBlocks(root_rv)) { |
361 | Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv); |
362 | if (loop_rvs.empty()) { |
363 | continue; |
364 | } |
365 | tir::ParsedAnnotation parsed = parsed_root; |
366 | tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); |
367 | // Parallel |
368 | if (parsed.num_parallel_loops > 0) { |
369 | tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); |
370 | } |
371 | // Vectorize |
372 | if (parsed.num_vectorize_loops > 0) { |
373 | tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); |
374 | } |
375 | // AutoUnroll |
376 | if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { |
377 | ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); |
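          // Exactly one of the two is annotated; the other stays -1, so their sum plus 1
          // recovers the annotated max step.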
378 | int unroll_explicit = parsed.unroll_explicit != -1; |
379 | int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; |
380 | tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]); |
381 | } |
382 | } |
383 | } |
384 | return true; |
385 | } |
386 | |
387 | Postproc Clone() const { |
388 | ObjectPtr<RewriteParallelVectorizeUnrollNode> n = |
389 | make_object<RewriteParallelVectorizeUnrollNode>(*this); |
390 | return Postproc(n); |
391 | } |
392 | |
  static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll";
394 | TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); |
395 | }; |
396 | |
397 | Postproc Postproc::RewriteParallelVectorizeUnroll() { |
398 | ObjectPtr<RewriteParallelVectorizeUnrollNode> n = |
399 | make_object<RewriteParallelVectorizeUnrollNode>(); |
400 | return Postproc(n); |
401 | } |
402 | |
403 | TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); |
TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll")
405 | .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); |
406 | |
407 | } // namespace meta_schedule |
408 | } // namespace tvm |
409 | |