1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include "../ir_comparator.h" |
20 | #include "../utils.h" |
21 | |
22 | namespace tvm { |
23 | namespace tir { |
24 | |
25 | /******** IR Module ********/ |
26 | |
27 | const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, |
28 | GlobalVar* result_g_var) { |
29 | for (const auto& kv : mod->functions) { |
30 | const GlobalVar& g_var = kv.first; |
31 | const BaseFunc& base_func = kv.second; |
32 | if (const auto* func = base_func.as<PrimFuncNode>()) { |
33 | if (const auto* realize = func->body.as<BlockRealizeNode>()) { |
34 | if (realize->block.get() == root_block) { |
35 | if (result_g_var != nullptr) { |
36 | *result_g_var = g_var; |
37 | } |
38 | return func; |
39 | } |
40 | } |
41 | } |
42 | } |
43 | LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " |
44 | "statement:\n" |
45 | << GetRef<Stmt>(root_block); |
46 | throw; |
47 | } |
48 | |
49 | /******** Scope ********/ |
50 | |
51 | StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, |
52 | bool require_stage_pipeline) { |
53 | class RootBlockError : public ScheduleError { |
54 | public: |
55 | explicit RootBlockError(IRModule mod) : mod_(mod) {} |
56 | IRModule mod() const final { return mod_; } |
57 | String FastErrorString() const final { |
58 | return "ScheduleError: The primitive does not operate on the root block" ; |
59 | } |
60 | String DetailRenderTemplate() const final { |
61 | return "The primitive does not operate on the root block" ; |
62 | } |
63 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
64 | IRModule mod_; |
65 | }; |
66 | |
67 | class NotStagePipelineError : public ScheduleError { |
68 | public: |
69 | explicit NotStagePipelineError(IRModule mod, Block block) : mod_(mod), block_(block) {} |
70 | IRModule mod() const final { return mod_; } |
71 | String FastErrorString() const final { |
72 | return "ScheduleError: The scope root is not a stage pipeline" ; |
73 | } |
74 | String DetailRenderTemplate() const final { |
75 | return R"(The scope {0} is not a stage pipeline. |
76 | Definition of a scope that is a stage pipeline: |
77 | - The region cover property holds for every of its child blocks |
78 | - No write-after-read dependency or opaque dependency, |
79 | - only read-after-write and write-after-write are allowed |
80 | - All the statements in the scope are schedulable statements, i.e. Block and For |
81 | )" ; |
82 | } |
83 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
84 | IRModule mod_; |
85 | Block block_; |
86 | }; |
87 | |
88 | StmtSRef scope_root_sref{nullptr}; |
89 | StmtSRef scope_root_subtree{nullptr}; |
90 | // Step 1. Find the scope root and the subtree that the given sref is in |
91 | { |
92 | const StmtSRefNode* p = sref->parent; |
93 | const StmtSRefNode* subtree = sref.get(); |
94 | for (; p != nullptr; subtree = p, p = p->parent) { |
95 | if (p->stmt->IsInstance<BlockNode>()) { |
96 | scope_root_sref = GetRef<StmtSRef>(p); |
97 | scope_root_subtree = GetRef<StmtSRef>(subtree); |
98 | break; |
99 | } |
100 | } |
101 | if (p == nullptr) { |
102 | throw RootBlockError(self->mod); |
103 | } |
104 | } |
105 | // Step 2. Handle `require_stage_pipeline` |
106 | if (require_stage_pipeline) { |
107 | bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; |
108 | if (stage_pipeline == false) { |
109 | const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); |
110 | throw NotStagePipelineError(self->mod, GetRef<Block>(block)); |
111 | } |
112 | } |
113 | return scope_root_sref; |
114 | } |
115 | |
116 | ScopeBlockLoopInfo GetScopeBlockLoopInfo(const Block& scope_block) { |
117 | struct Collector : public StmtVisitor { |
118 | void VisitStmt_(const BlockRealizeNode* realize) final { |
119 | result.realizes.push_back(GetRef<BlockRealize>(realize)); |
120 | const Array<IterVar>& iter_vars = realize->block->iter_vars; |
121 | const Array<PrimExpr>& iter_values = realize->iter_values; |
122 | ICHECK_EQ(iter_vars.size(), iter_values.size()); |
123 | int n = realize->iter_values.size(); |
124 | for (int i = 0; i < n; ++i) { |
125 | const IterVar& iter_var = iter_vars[i]; |
126 | const PrimExpr& iter_value = iter_values[i]; |
127 | std::unordered_set<const VarNode*>* vars = nullptr; |
128 | if (iter_var->iter_type == IterVarType::kDataPar) { |
129 | vars = &result.spatial_vars; |
130 | } else { |
131 | vars = &result.non_spatial_vars; |
132 | } |
133 | PostOrderVisit(iter_value, [vars](const ObjectRef& obj) { |
134 | if (const VarNode* var = obj.as<VarNode>()) { |
135 | vars->insert(var); |
136 | } |
137 | }); |
138 | } |
139 | } |
140 | |
141 | ScopeBlockLoopInfo result; |
142 | } visitor; |
143 | visitor(scope_block->body); |
144 | return std::move(visitor.result); |
145 | } |
146 | |
147 | /*! |
148 | * \brief Check whether the given sref_a is higher than or equal to sref_b. |
149 | */ |
150 | void CheckSRefHigherOrEqual(const StmtSRef& sref_a, const StmtSRef& sref_b) { |
151 | const StmtSRefNode* p = sref_b.get(); |
152 | for (; p != nullptr; p = p->parent) { |
153 | if (p == sref_a.get()) { |
154 | return; |
155 | } |
156 | } |
157 | CHECK(false) << "Expect StmtSRef " << sref_a << "to be higher than or equal to " << sref_b; |
158 | } |
159 | |
160 | /*! |
161 | * \brief Check the dominant property of a block: |
162 | * the block is the only writer of its output, dominating the reader of its output buffers under the |
163 | * given root scope. |
164 | * \param self The schedule state. |
165 | * \param scope_root_sref The StmtSRef corresponding to the root scope. |
166 | * \param block_sref The block whose dominant property is to be checked. |
167 | * \return A boolean indicating if the block is a dominant block. |
168 | */ |
169 | bool IsDominantBlock(const ScheduleState& self, const StmtSRef& scope_root_sref, |
170 | const StmtSRef& block_sref) { |
171 | std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers; |
172 | CheckSRefHigherOrEqual(scope_root_sref, block_sref); |
173 | const BlockNode* maybe_root_block = scope_root_sref->StmtAs<BlockNode>(); |
174 | if (maybe_root_block) { |
175 | BlockScope scope = self->GetBlockScope(scope_root_sref); |
176 | buffer_writers = scope->buffer_writers; |
177 | } else { |
178 | // Collect all child blocks of root sub-tree, and merge their buffer writers. |
179 | Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, scope_root_sref); |
180 | for (const StmtSRef& child_block_sref : child_block_srefs) { |
181 | BlockScope child_scope = self->GetBlockScope(child_block_sref); |
182 | for (const auto& it : child_scope->buffer_writers) { |
183 | buffer_writers.insert(it); |
184 | } |
185 | } |
186 | } |
187 | // Check whether the input block is the only writer of its outputs |
188 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
189 | for (const BufferRegion& write_region : block->writes) { |
190 | if (buffer_writers.count(write_region->buffer)) { |
191 | if (buffer_writers.at(write_region->buffer).size() != 1) { |
192 | return false; |
193 | } |
194 | } |
195 | } |
196 | return true; |
197 | } |
198 | |
199 | /*! |
200 | * \brief A helper function that checks whether a given block is a complete block under the scope, |
201 | * or return the condition it violates if it is not a complete block |
202 | * \param self The schedule state |
203 | * \param block_sref The block to be checked |
204 | * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in |
205 | * \return 0 if the block is a complete block, or a positive integer indicating which condition is |
206 | * first violated |
207 | */ |
208 | int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, |
209 | const StmtSRef& scope_root_sref) { |
210 | // Cond 1. All block vars are data parallel |
211 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
212 | for (const IterVar& iter_var : block->iter_vars) { |
213 | if (iter_var->iter_type != kDataPar) { |
214 | return 1; |
215 | } |
216 | } |
217 | // Cond 2. Dominant: the block is the only writer of its output, |
218 | // dominating the reader of its output buffers |
219 | if (!IsDominantBlock(self, scope_root_sref, block_sref)) { |
220 | return 2; |
221 | } |
222 | // Cond 3. No overlap between the buffers the block reads and writes |
223 | std::unordered_set<const BufferNode*> written_buffers; |
224 | written_buffers.reserve(block->writes.size()); |
225 | for (const BufferRegion& write : block->writes) { |
226 | written_buffers.insert(write->buffer.get()); |
227 | } |
228 | for (const BufferRegion& read : block->reads) { |
229 | if (written_buffers.count(read->buffer.get())) { |
230 | return 3; |
231 | } |
232 | } |
233 | return 0; |
234 | } |
235 | |
236 | static const char* kCompleteBlockDefinition = R"(Definition of a complete block: |
237 | 1) All block vars are data parallel |
238 | 2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers |
239 | 3) No overlap between the buffers the block reads and writes)" ; |
240 | |
241 | static const char* kReductionBlockDefinition = R"(Definition of a reduction block: |
242 | 1) The block has the `init` statement |
243 | 2) All the block bindings are quasi-affine expressions |
244 | 3) All block vars are either data parallel block vars or reduction block vars |
245 | 4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers |
246 | 5) The reduction block vars are not used to index the output buffers)" ; |
247 | |
248 | static const char* kLocalCompleteBlockDefinition = R"(Definition of a local complete block: |
249 | 1) All block vars are data parallel |
250 | 2) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree |
251 | 3) No overlap between the buffers the block reads and writes)" ; |
252 | |
253 | static const char* kLocalReductionBlockDefinition = R"(Definition of a reduction block: |
254 | 1) The block has the `init` statement |
255 | 2) All the block bindings are quasi-affine expressions |
256 | 3) All block vars are either data parallel block vars or reduction block vars |
257 | 4) Local Dominant: the block is the only writer of its output, dominating the reader of its output buffers under a given subtree |
258 | 5) The reduction block vars are not used to index the output buffers)" ; |
259 | |
260 | bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, |
261 | const StmtSRef& scope_root_sref) { |
262 | return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0; |
263 | } |
264 | |
265 | void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, |
266 | const StmtSRef& scope_root_sref) { |
267 | class IncompleteBlockError : public ScheduleError { |
268 | public: |
269 | explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) |
270 | : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} |
271 | String FastErrorString() const final { return "ScheduleError: Incomplete block" ; } |
272 | String DetailRenderTemplate() const final { |
273 | std::ostringstream os; |
274 | os << "The block {0} is not a complete block - it violates condition #" << violated_cond_; |
275 | os << ".\n" << kCompleteBlockDefinition; |
276 | return os.str(); |
277 | } |
278 | IRModule mod() const final { return mod_; } |
279 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
280 | IRModule mod_; |
281 | Block block_; |
282 | int violated_cond_; |
283 | }; |
284 | |
285 | int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); |
286 | if (error_code != 0) { |
287 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
288 | throw IncompleteBlockError(self->mod, GetRef<Block>(block), error_code); |
289 | } |
290 | } |
291 | |
292 | /*! |
293 | * \brief A helper function that checks whether a given block is a reduction block under the scope, |
294 | * or return the condition it violates if it is not a reduction block |
295 | * \param self The schedule state |
296 | * \param block_sref The block to be checked |
297 | * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in |
298 | * \return 0 if the block is a reduction block, or a positive integer indicating which condition is |
299 | * first violated |
300 | */ |
301 | int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, |
302 | const StmtSRef& scope_root_sref) { |
303 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
304 | // Cond 1. The block has the `init` statement. |
305 | if (!block->init.defined()) { |
306 | return 1; |
307 | } |
308 | // Cond 2. All the block bindings are quasi-affine expressions. |
309 | if (!self->IsAffineBlockBinding(block_sref)) { |
310 | return 2; |
311 | } |
312 | // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, |
313 | // we collect all the reduction block vars. |
314 | if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) { |
315 | return 3; |
316 | } |
317 | // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its |
318 | // output buffers. |
319 | if (!IsDominantBlock(self, scope_root_sref, block_sref)) { |
320 | return 4; |
321 | } |
322 | // Cond 5. The reduction block vars are not used to index the output buffers. |
323 | return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block)) ? 0 : 5; |
324 | } |
325 | |
326 | bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, |
327 | const StmtSRef& scope_root_sref) { |
328 | return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; |
329 | } |
330 | |
331 | void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, |
332 | const StmtSRef& scope_root_sref) { |
333 | class NotReductionBlockError : public ScheduleError { |
334 | public: |
335 | explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) |
336 | : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} |
337 | String FastErrorString() const final { return "ScheduleError: Not a reduction block" ; } |
338 | String DetailRenderTemplate() const final { |
339 | std::ostringstream os; |
340 | os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_; |
341 | os << ".\n" << kReductionBlockDefinition; |
342 | return os.str(); |
343 | } |
344 | IRModule mod() const final { return mod_; } |
345 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
346 | IRModule mod_; |
347 | Block block_; |
348 | int violated_cond_; |
349 | }; |
350 | |
351 | int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); |
352 | if (error_code != 0) { |
353 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
354 | throw NotReductionBlockError(self->mod, GetRef<Block>(block), error_code); |
355 | } |
356 | } |
357 | |
358 | void CheckCompleteOrReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, |
359 | const StmtSRef& scope_root_sref) { |
360 | class NotCompleteOrReductionBlockError : public ScheduleError { |
361 | public: |
362 | explicit NotCompleteOrReductionBlockError(IRModule mod, Block block, |
363 | int complete_block_error_code, |
364 | int reduction_block_error_code) |
365 | : mod_(mod), |
366 | block_(block), |
367 | complete_block_error_code_(complete_block_error_code), |
368 | reduction_block_error_code_(reduction_block_error_code) {} |
369 | |
370 | String FastErrorString() const final { |
371 | return "ScheduleError: Not a complete or reduction block" ; |
372 | } |
373 | String DetailRenderTemplate() const final { |
374 | std::ostringstream os; |
375 | os << "The block {0} is not a complete block - it violates condition #" |
376 | << complete_block_error_code_; |
377 | os << ".\n" << kCompleteBlockDefinition; |
378 | os << "\nThe block is not a reduction block either - it violates condition #" |
379 | << reduction_block_error_code_; |
380 | os << ".\n" << kReductionBlockDefinition; |
381 | return os.str(); |
382 | } |
383 | IRModule mod() const final { return mod_; } |
384 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
385 | |
386 | IRModule mod_; |
387 | Block block_; |
388 | int complete_block_error_code_; |
389 | int reduction_block_error_code_; |
390 | }; |
391 | |
392 | int complete_block_error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); |
393 | if (complete_block_error_code == 0) { |
394 | return; |
395 | } |
396 | int reduction_block_error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); |
397 | if (reduction_block_error_code == 0) { |
398 | return; |
399 | } |
400 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
401 | throw NotCompleteOrReductionBlockError(self->mod, GetRef<Block>(block), complete_block_error_code, |
402 | reduction_block_error_code); |
403 | } |
404 | |
405 | void CheckSubtreeCompactDataflow(const ScheduleState& self, const StmtSRef& subtree_root) { |
406 | class NotCompactDataFlowError : public ScheduleError { |
407 | public: |
408 | explicit NotCompactDataFlowError(IRModule mod, Stmt subtree_root, Block violate_block, |
409 | int local_complete_block_code, int local_reduction_block_code) |
410 | : mod_(std::move(mod)), |
411 | subtree_root_(std::move(subtree_root)), |
412 | violate_block_(std::move(violate_block)), |
413 | local_complete_block_code_(local_complete_block_code), |
414 | local_reduction_block_code_(local_reduction_block_code) { |
415 | ICHECK(subtree_root_->IsInstance<BlockNode>() || subtree_root_->IsInstance<ForNode>()); |
416 | } |
417 | String FastErrorString() const final { |
418 | return "ScheduleError: The queried subtree root in SRef tree does not have compact dataflow, " |
419 | "because some of its child block on SRef tree is neither a local complete block nor a " |
420 | "local reduction block." ; |
421 | } |
422 | String DetailRenderTemplate() const final { |
423 | std::ostringstream os; |
424 | os << "The queried subtree root {0} in SRef tree does not have compact dataflow, because " |
425 | "its child block {1} on SRef tree is neither a local complete block nor a local " |
426 | "reduction block.\n" ; |
427 | os << "It violates condition #" << local_complete_block_code_ |
428 | << " as a local complete block.\n" ; |
429 | os << kLocalCompleteBlockDefinition << "\n" ; |
430 | os << "It violates condition #" << local_reduction_block_code_ |
431 | << " as a local reduction block.\n" ; |
432 | os << kLocalReductionBlockDefinition << "\n" ; |
433 | return os.str(); |
434 | } |
435 | IRModule mod() const final { return mod_; } |
436 | Array<ObjectRef> LocationsOfInterest() const final { return {subtree_root_, violate_block_}; } |
437 | |
438 | IRModule mod_; |
439 | Stmt subtree_root_; |
440 | Block violate_block_; |
441 | int local_complete_block_code_; |
442 | int local_reduction_block_code_; |
443 | }; |
444 | |
445 | Array<StmtSRef> child_block_srefs = GetChildBlockSRefOnSRefTree(self, subtree_root); |
446 | for (const StmtSRef& block_sref : child_block_srefs) { |
447 | int local_complete_block_code = CheckCompleteBlockErrorCode(self, block_sref, subtree_root), |
448 | local_reduction_block_code = CheckReductionBlockErrorCode(self, block_sref, subtree_root); |
449 | if (local_complete_block_code != 0 && local_reduction_block_code != 0) { |
450 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
451 | throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(subtree_root->stmt), |
452 | GetRef<Block>(block), local_complete_block_code, |
453 | local_reduction_block_code); |
454 | } |
455 | } |
456 | } |
457 | |
458 | bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, |
459 | const StmtSRef& scope_root_sref) { |
460 | const BlockNode* scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); |
461 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
462 | std::unordered_set<const BufferNode*> scope_allocated; |
463 | scope_allocated.reserve(scope_root->alloc_buffers.size()); |
464 | for (const Buffer& buffer : scope_root->alloc_buffers) { |
465 | scope_allocated.insert(buffer.get()); |
466 | } |
467 | for (const BufferRegion& buffer_region : block->writes) { |
468 | if (!scope_allocated.count(buffer_region->buffer.get())) { |
469 | return true; |
470 | } |
471 | } |
472 | return false; |
473 | } |
474 | |
475 | void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, |
476 | const StmtSRef& scope_root_sref) { |
477 | class OutputBlockError : public ScheduleError { |
478 | public: |
479 | explicit OutputBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} |
480 | String FastErrorString() const final { |
481 | return "ScheduleError: Cannot operate on an output block" ; |
482 | } |
483 | String DetailRenderTemplate() const final { return "The block {0} is an output block" ; } |
484 | IRModule mod() const final { return mod_; } |
485 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
486 | |
487 | IRModule mod_; |
488 | Block block_; |
489 | }; |
490 | if (IsOutputBlock(self, block_sref, scope_root_sref)) { |
491 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
492 | throw OutputBlockError(self->mod, GetRef<Block>(block)); |
493 | } |
494 | } |
495 | |
496 | std::vector<IterVarType> GetBlockVarTypes(const BlockNode* block) { |
497 | std::vector<IterVarType> results; |
498 | results.reserve(block->iter_vars.size()); |
499 | for (const IterVar& iter_var : block->iter_vars) { |
500 | results.push_back(iter_var->iter_type); |
501 | } |
502 | return results; |
503 | } |
504 | |
505 | std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) { |
506 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
507 | return GetBlockVarTypes(block); |
508 | } |
509 | |
510 | bool IsWriteCache(const StmtSRef& block_sref) { |
511 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
512 | if (block->writes.size() != 1) { |
513 | return false; |
514 | } |
515 | const BufferRegion& write_region = block->writes[0]; |
516 | for (const BufferRegion& read_region : block->reads) { |
517 | auto [exists, surjective, injective, ordered, no_const_read, no_shift_read] = |
518 | AnalyzeReadWritePattern(read_region, write_region); |
519 | // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 |
520 | (void)exists; |
521 | (void)surjective; |
522 | (void)no_const_read; |
523 | (void)no_shift_read; |
524 | if (!(injective && ordered)) { |
525 | return false; |
526 | } |
527 | } |
528 | return true; |
529 | } |
530 | |
531 | /******** Binding ********/ |
532 | |
533 | bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges, |
534 | arith::Analyzer* analyzer) { |
535 | if (loop_var_ranges.empty()) { |
536 | return true; |
537 | } |
538 | auto res = arith::DetectIterMap( |
539 | /*indices=*/realize->iter_values, |
540 | /*input_iters=*/loop_var_ranges, |
541 | /*predicate=*/realize->predicate, |
542 | /*check_level=*/arith::IterMapLevel::Surjective, |
543 | /*analyzer=*/analyzer, |
544 | /*simplify_trivial_iterators=*/false); |
545 | if (res->indices.empty()) { |
546 | return false; |
547 | } |
548 | for (const arith::IterSumExpr& sum_expr : res->indices) { |
549 | const Array<arith::IterSplitExpr>& args = sum_expr->args; |
550 | if (!args.empty() && !is_one(args[0]->scale)) { |
551 | return false; |
552 | } |
553 | } |
554 | return true; |
555 | } |
556 | |
557 | void CheckPartialAffineBinding(const ScheduleState& self, Block block, |
558 | const Optional<StmtSRef>& high_exclusive) { |
559 | class NotAffineBindingError : public ScheduleError { |
560 | public: |
561 | explicit NotAffineBindingError(IRModule mod, Block block, Optional<StmtSRef> high_exclusive) |
562 | : mod_(std::move(mod)), block_(std::move(block)) { |
563 | if (high_exclusive.defined()) { |
564 | high_exclusive_loop_ = high_exclusive.value()->StmtAs<ForNode>(); |
565 | } |
566 | } |
567 | String FastErrorString() const final { |
568 | std::ostringstream ss; |
569 | if (high_exclusive_loop_) { |
570 | ss << "ScheduleError: The block is required to have an partial affine binding under " |
571 | << high_exclusive_loop_->loop_var; |
572 | } else { |
573 | ss << "ScheduleError: The block is required to have an affine binding" ; |
574 | } |
575 | return ss.str(); |
576 | } |
577 | String DetailRenderTemplate() const final { |
578 | std::ostringstream ss; |
579 | if (high_exclusive_loop_) { |
580 | ss << "The block {0} is required to have an partial affine binding under " |
581 | << high_exclusive_loop_->loop_var; |
582 | } else { |
583 | ss << "The block {0} is required to have an affine binding" ; |
584 | } |
585 | return ss.str(); |
586 | } |
587 | IRModule mod() const final { return mod_; } |
588 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
589 | IRModule mod_; |
590 | Block block_; |
591 | const ForNode* high_exclusive_loop_{nullptr}; |
592 | }; |
593 | |
594 | StmtSRef block_sref = self->stmt2ref.at(block.get()); |
595 | if (self->IsAffineBlockBinding(block_sref)) { |
596 | // check block cached state for global affineness |
597 | return; |
598 | } |
599 | if (block_sref->parent && high_exclusive.defined()) { |
600 | // if it is not of global affine binding, check affineness under high_exclusive, |
601 | arith::Analyzer analyzer; |
602 | Map<Var, Range> dom_map = |
603 | LoopDomainOfSRefTreePath(GetRef<StmtSRef>(block_sref->parent), high_exclusive); |
604 | if (IsAffineBinding(GetBlockRealize(self, block_sref), dom_map, &analyzer)) { |
605 | return; |
606 | } |
607 | } |
608 | throw NotAffineBindingError(self->mod, std::move(block), high_exclusive); |
609 | } |
610 | |
611 | void CheckAffineBinding(const ScheduleState& self, Block block) { |
612 | CheckPartialAffineBinding(self, std::move(block), NullOpt); |
613 | } |
614 | |
615 | void CheckBlockHasTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { |
616 | class NotTrivialBindingError : public ScheduleError { |
617 | public: |
618 | explicit NotTrivialBindingError(IRModule mod, Block block) |
619 | : mod_(std::move(mod)), block_(std::move(block)) {} |
620 | |
621 | String FastErrorString() const final { |
622 | return "ScheduleError: The binding values of the block are not variables of outer loops." ; |
623 | } |
624 | |
625 | String DetailRenderTemplate() const final { |
626 | std::ostringstream os; |
627 | os << "The binding values of the {0} are not variables of outer loops." ; |
628 | return os.str(); |
629 | } |
630 | |
631 | IRModule mod() const final { return mod_; } |
632 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
633 | |
634 | private: |
635 | IRModule mod_; |
636 | Block block_; |
637 | }; |
638 | |
639 | if (!IsTrivialBinding(self, block_sref)) { |
640 | throw NotTrivialBindingError(self->mod, GetRef<Block>(block_sref->StmtAs<BlockNode>())); |
641 | } |
642 | } |
643 | |
644 | Map<Var, Range> LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, |
645 | const Optional<StmtSRef>& high_exclusive, |
646 | const runtime::StorageScope& ) { |
647 | Map<Var, Range> result; |
648 | const StmtSRefNode* p = low_inclusive.get(); |
649 | const StmtSRefNode* limit = static_cast<const StmtSRefNode*>(high_exclusive.get()); |
650 | for (; p != limit; p = p->parent) { |
651 | const ForNode* loop = p->StmtAs<ForNode>(); |
652 | if (loop == nullptr) { |
653 | break; |
654 | } |
655 | result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); |
656 | } |
657 | if (extra_relax_scope.rank != runtime::StorageRank::kGlobal) { |
658 | for (; p; p = p->parent) { |
659 | if (const ForNode* loop = p->StmtAs<ForNode>()) { |
660 | if (loop->kind == ForKind::kThreadBinding) { |
661 | const String& thread_tag = loop->thread_binding.value()->thread_tag; |
662 | if (CanRelaxStorageUnderThread(extra_relax_scope, |
663 | runtime::ThreadScope::Create(thread_tag))) { |
664 | result.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); |
665 | } |
666 | } |
667 | } |
668 | } |
669 | } |
670 | return result; |
671 | } |
672 | |
673 | Map<Var, PrimExpr> GetBindings(const BlockRealize& realize) { |
674 | const BlockNode* block = realize->block.get(); |
675 | const Array<IterVar>& all_lhs = block->iter_vars; |
676 | const Array<PrimExpr>& all_rhs = realize->iter_values; |
677 | ICHECK_EQ(all_lhs.size(), all_rhs.size()); |
678 | Map<Var, PrimExpr> result; |
679 | for (int i = 0, n = all_lhs.size(); i < n; ++i) { |
680 | const IterVar& lhs = all_lhs[i]; |
681 | const PrimExpr& rhs = all_rhs[i]; |
682 | result.Set(lhs->var, rhs); |
683 | } |
684 | return result; |
685 | } |
686 | |
687 | bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, |
688 | std::unordered_set<const VarNode*>* data_par_vars, |
689 | std::unordered_set<const VarNode*>* reduce_vars) { |
690 | Block block = block_realize->block; |
691 | ICHECK(block_realize->block.same_as(block)) |
692 | << "ValueError: The input `block_realize` is required to be the exact BlockRealize of the " |
693 | "input block" ; |
694 | |
695 | bool has_block_vars_of_other_types = false; |
696 | ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); |
697 | int n = static_cast<int>(block->iter_vars.size()); |
698 | for (int i = 0; i < n; ++i) { |
699 | const IterVar& iter_var = block->iter_vars[i]; |
700 | const PrimExpr& iter_value = block_realize->iter_values[i]; |
701 | std::unordered_set<const VarNode*>* set = nullptr; |
702 | if (iter_var->iter_type == IterVarType::kDataPar) { |
703 | set = data_par_vars; |
704 | } else if (iter_var->iter_type == IterVarType::kCommReduce) { |
705 | set = reduce_vars; |
706 | } else { |
707 | has_block_vars_of_other_types = true; |
708 | } |
709 | if (set == nullptr) { |
710 | continue; |
711 | } |
712 | Array<Var> vars_in_binding = UndefinedVars(iter_value); |
713 | for (const Var& var : vars_in_binding) { |
714 | set->insert(var.get()); |
715 | } |
716 | } |
717 | |
718 | return has_block_vars_of_other_types; |
719 | } |
720 | |
721 | /******** Loop properties ********/ |
722 | |
723 | void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, |
724 | arith::Analyzer* analyzer) { |
725 | class LoopNotStartWithZeroError : public ScheduleError { |
726 | public: |
727 | explicit LoopNotStartWithZeroError(IRModule mod, For loop) |
728 | : mod_(mod), loop_(std::move(loop)) {} |
729 | |
730 | String FastErrorString() const final { |
731 | return "ScheduleError: The primitive only supports loop starting with 0" ; |
732 | } |
733 | |
734 | String DetailRenderTemplate() const final { |
735 | return "The loop {0} does not start with 0, which is not supported" ; |
736 | } |
737 | |
738 | IRModule mod() const final { return mod_; } |
739 | Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; } |
740 | |
741 | IRModule mod_; |
742 | For loop_; |
743 | }; |
744 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
745 | if (!analyzer->CanProve(loop->min == 0)) { |
746 | throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop)); |
747 | } |
748 | } |
749 | |
750 | /******** Block-loop relation ********/ |
751 | |
752 | Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self, |
753 | const StmtSRef& parent_sref) { |
754 | Array<BlockRealize> child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); |
755 | Array<StmtSRef> child_block_srefs; |
756 | child_block_srefs.reserve(child_block_realize.size()); |
757 | |
758 | for (BlockRealize realize : child_block_realize) { |
759 | child_block_srefs.push_back(self->stmt2ref.at(realize->block.get())); |
760 | } |
761 | return child_block_srefs; |
762 | } |
763 | |
764 | Array<BlockRealize> GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { |
765 | struct Collector : public StmtVisitor { |
766 | static Array<BlockRealize> Collect(const Stmt& stmt) { |
767 | Collector collector; |
768 | collector(stmt); |
769 | return std::move(collector.result_); |
770 | } |
771 | |
772 | void VisitStmt_(const BlockRealizeNode* block_realize) final { |
773 | result_.push_back(GetRef<BlockRealize>(block_realize)); |
774 | } |
775 | |
776 | Array<BlockRealize> result_; |
777 | }; |
778 | |
779 | if (parent_sref->stmt->IsInstance<ForNode>()) { |
780 | const auto* loop = static_cast<const ForNode*>(parent_sref->stmt); |
781 | return Collector::Collect(loop->body); |
782 | } else if (parent_sref->stmt->IsInstance<BlockNode>()) { |
783 | const auto* block = static_cast<const BlockNode*>(parent_sref->stmt); |
784 | return Collector::Collect(block->body); |
785 | } |
786 | ICHECK(false) << "Unreachable" ; |
787 | throw; |
788 | } |
789 | |
790 | BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, |
791 | const StmtSRef& parent_sref) { |
792 | class NonSingleChildBlockError : public ScheduleError { |
793 | public: |
794 | explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) |
795 | : mod_(std::move(mod)), stmt_(GetRef<Stmt>(sref->stmt)) { |
796 | sref_type_ = stmt_.as<BlockNode>() != nullptr ? "block" : "loop" ; |
797 | } |
798 | |
799 | String FastErrorString() const final { |
800 | std::ostringstream os; |
801 | os << "ScheduleError: The " << sref_type_ << " is required to have only one child block" ; |
802 | return os.str(); |
803 | } |
804 | |
805 | String DetailRenderTemplate() const final { |
806 | std::ostringstream os; |
807 | os << "The " << sref_type_ << " {0} is required to have only one child block" ; |
808 | return os.str(); |
809 | } |
810 | |
811 | IRModule mod() const final { return mod_; } |
812 | Array<ObjectRef> LocationsOfInterest() const final { return {stmt_}; } |
813 | |
814 | IRModule mod_; |
815 | Stmt stmt_; |
816 | String sref_type_; |
817 | }; |
818 | |
819 | Array<BlockRealize> child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); |
820 | if (child_block_realize.size() != 1) { |
821 | throw NonSingleChildBlockError(self->mod, parent_sref); |
822 | } |
823 | return child_block_realize[0]; |
824 | } |
825 | |
826 | BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) { |
827 | struct BlockRealizeFinder : public StmtVisitor { |
828 | explicit BlockRealizeFinder(const BlockNode* target_block) |
829 | : target_block(target_block), result(nullptr) {} |
830 | |
831 | void VisitStmt(const Stmt& stmt) final { |
832 | if (result != nullptr) { |
833 | return; |
834 | } |
835 | StmtVisitor::VisitStmt(stmt); |
836 | } |
837 | |
838 | void VisitStmt_(const BlockRealizeNode* block_realize) final { |
839 | if (block_realize->block.get() == target_block) { |
840 | result = block_realize; |
841 | } |
842 | // No need to visit recursively, since the deeper BlockRealizes must not be the result. |
843 | } |
844 | |
845 | const BlockNode* target_block; |
846 | const BlockRealizeNode* result; |
847 | }; |
848 | |
849 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
850 | if (block_sref->parent == nullptr) { |
851 | const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); |
852 | return Downcast<BlockRealize>(func->body); |
853 | } else { |
854 | BlockRealizeFinder finder(block); |
855 | finder(GetRef<Stmt>(block_sref->parent->stmt)); |
856 | ICHECK(finder.result != nullptr) |
857 | << "InternalError: Cannot find the BlockRealize of block " << GetRef<Block>(block); |
858 | return GetRef<BlockRealize>(finder.result); |
859 | } |
860 | } |
861 | |
862 | IterVarType GetLoopIterType(const StmtSRef& loop_sref) { |
863 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
864 | const Var& loop_var = loop->loop_var; |
865 | int n_spatial = 0; |
866 | int n_reduce = 0; |
867 | int n_other = 0; |
868 | auto f_visit = [&loop_var, &n_spatial, &n_reduce, &n_other](const ObjectRef& obj) -> bool { |
869 | if (const auto* realize = obj.as<BlockRealizeNode>()) { |
870 | const BlockNode* block = realize->block.get(); |
871 | // Number of block vars and their bindings |
872 | ICHECK_EQ(realize->iter_values.size(), block->iter_vars.size()); |
873 | size_t n = realize->iter_values.size(); |
874 | for (size_t i = 0; i < n; ++i) { |
875 | const IterVar& iter_var = block->iter_vars[i]; |
876 | const PrimExpr& binding = realize->iter_values[i]; |
877 | // Categorize the current block var |
878 | int* ref = nullptr; |
879 | if (iter_var->iter_type == IterVarType::kDataPar) { |
880 | ref = &n_spatial; |
881 | } else if (iter_var->iter_type == IterVarType::kCommReduce) { |
882 | ref = &n_reduce; |
883 | } else { |
884 | ref = &n_other; |
885 | } |
886 | // Visit the binding to see if `loop_var` appears |
887 | PostOrderVisit(binding, [&ref, &loop_var](const ObjectRef& obj) -> void { |
888 | if (obj.same_as(loop_var)) { |
889 | (*ref) += 1; |
890 | } |
891 | }); |
892 | } |
893 | return false; |
894 | } |
895 | return true; |
896 | }; |
897 | PreOrderVisit(loop->body, f_visit); |
898 | if (n_other) { |
899 | return IterVarType::kOpaque; |
900 | } else if (n_spatial && n_reduce) { |
901 | return IterVarType::kOpaque; |
902 | } else if (n_reduce) { |
903 | return IterVarType::kCommReduce; |
904 | } else { |
905 | return IterVarType::kDataPar; |
906 | } |
907 | } |
908 | |
909 | StmtSRef GetSRefLowestCommonAncestor(const Array<StmtSRef>& srefs) { |
910 | CHECK(!srefs.empty()) << "ValueError: The input array is required to have at least one sref" ; |
911 | |
912 | std::unordered_map<const StmtSRefNode*, size_t> sref_visited_cnt; |
913 | for (const StmtSRef& sref : srefs) { |
914 | const StmtSRefNode* p = sref.get(); |
915 | while (p != nullptr) { |
916 | ++sref_visited_cnt[p]; |
917 | p = p->parent; |
918 | } |
919 | } |
920 | size_t n_sref = srefs.size(); |
921 | const StmtSRefNode* p = srefs[0].get(); |
922 | while (p != nullptr && sref_visited_cnt[p] != n_sref) { |
923 | p = p->parent; |
924 | } |
925 | ICHECK(p != nullptr); |
926 | return GetRef<StmtSRef>(p); |
927 | } |
928 | |
929 | bool HasBeenMultiLevelTiled(const StmtSRef& block_sref) { |
930 | return tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_tiling_structure).defined(); |
931 | } |
932 | |
933 | std::pair<Array<StmtSRef>, std::vector<int>> CollectComputeLocation(const ScheduleState& self, |
934 | const StmtSRef& block_sref) { |
935 | Array<StmtSRef> location_srefs; |
936 | std::vector<int> location_indices; |
937 | |
938 | // Step 1. Add the "compute-root" candidate. Add the "compute-inline" candidate if the block can |
939 | // be inlined. |
940 | if (CanComputeInline(self, block_sref)) { |
941 | location_srefs.push_back(StmtSRef::InlineMark()); |
942 | location_indices.push_back(-2); |
943 | } |
944 | location_srefs.push_back(StmtSRef::RootMark()); |
945 | location_indices.push_back(-1); |
946 | |
947 | // Step 2. If the block has no consumer, there is no more candidate. |
948 | Array<StmtSRef> consumers = GetConsumers(self, block_sref); |
949 | if (consumers.empty()) { |
950 | return std::make_pair(location_srefs, location_indices); |
951 | } |
952 | |
953 | // Step 3. Get the deepest loop that the input block can be computed at (namely "boundary"). If |
954 | // such a loop cannot be found, there is no more candidate and we just return. |
955 | StmtSRef loop_boundary = consumers.size() > 1 ? GetSRefLowestCommonAncestor(consumers) |
956 | : GetRef<StmtSRef>(consumers[0]->parent); |
957 | if (loop_boundary->StmtAs<ForNode>() == nullptr) { |
958 | return std::make_pair(location_srefs, location_indices); |
959 | } |
960 | |
961 | // Step 4. Collect the loops outside the first consumer and locate the boundary loop. The position |
962 | // of the boundary loop reveals the number of possible additional candidates. |
963 | Array<StmtSRef> loop_srefs = GetLoops(consumers[0]); |
964 | size_t lca_pos = |
965 | std::find(loop_srefs.begin(), loop_srefs.end(), loop_boundary) - loop_srefs.begin(); |
966 | ICHECK_LT(lca_pos, loop_srefs.size()); |
967 | size_t n_candidate = lca_pos + 1; |
968 | |
969 | // Step 5. Find the position of the deepest data-parallel loop among the candidate loops. This |
970 | // position is used for removing the unwanted candidates from the perspective of performance. |
971 | std::vector<IterVarType> loop_iter_types; |
972 | loop_iter_types.reserve(n_candidate); |
973 | int i_last_datapar = -1; |
974 | for (size_t i = 0; i < n_candidate; ++i) { |
975 | // TODO(siyuan): improve the performance |
976 | IterVarType iter_type = GetLoopIterType(loop_srefs[i]); |
977 | loop_iter_types.push_back(iter_type); |
978 | if (iter_type == IterVarType::kDataPar) { |
979 | i_last_datapar = i; |
980 | } |
981 | } |
982 | // Step 6. Check and add the candidates in turn according to the following rules: |
983 | // - skip the unit loops (loops with extent 1); |
984 | // - do not consider the data-parallel loops after a not-data-parallel loop; |
985 | // - do not consider the trailing not-data-parallel loops. |
986 | location_srefs.reserve(n_candidate + 2); |
987 | location_indices.reserve(n_candidate + 2); |
988 | bool visited_reduce = false; |
989 | for (size_t i = 0; i < n_candidate; ++i) { |
990 | const int64_t* loop_extent = GetLoopIntExtent(loop_srefs[i]); |
991 | if (loop_extent != nullptr && *loop_extent == 1) { |
992 | continue; |
993 | } |
994 | |
995 | if (loop_iter_types[i] == IterVarType::kDataPar) { |
996 | if (visited_reduce) { |
997 | break; |
998 | } |
999 | } else { |
1000 | visited_reduce = true; |
1001 | if (static_cast<int>(i) > i_last_datapar) { |
1002 | break; |
1003 | } |
1004 | } |
1005 | if (CanComputeAt(self, block_sref, loop_srefs[i], true)) { |
1006 | location_srefs.push_back(loop_srefs[i]); |
1007 | location_indices.push_back(i); |
1008 | } |
1009 | } |
1010 | |
1011 | return std::make_pair(location_srefs, location_indices); |
1012 | } |
1013 | |
1014 | /******** Producer-consumer relation ********/ |
1015 | |
1016 | Array<StmtSRef> GetProducers(const StmtSRef& block_sref, const BlockScope& scope) { |
1017 | Array<Dependency> edges = scope->GetDepsByDst(block_sref); |
1018 | Array<StmtSRef> results; |
1019 | std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> result_set; |
1020 | results.reserve(edges.size()); |
1021 | for (const Dependency& edge : edges) { |
1022 | if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) && |
1023 | !result_set.count(edge->src)) { |
1024 | results.push_back(edge->src); |
1025 | result_set.emplace(edge->src); |
1026 | } |
1027 | } |
1028 | return results; |
1029 | } |
1030 | |
1031 | Array<StmtSRef> GetConsumers(const StmtSRef& block_sref, const BlockScope& scope) { |
1032 | Array<Dependency> edges = scope->GetDepsBySrc(block_sref); |
1033 | Array<StmtSRef> results; |
1034 | std::unordered_set<StmtSRef, ObjectPtrHash, ObjectPtrEqual> result_set; |
1035 | results.reserve(edges.size()); |
1036 | for (const Dependency& edge : edges) { |
1037 | if ((edge->kind == DepKind::kRAW || edge->kind == DepKind::kWAW) && |
1038 | !result_set.count(edge->dst)) { |
1039 | results.push_back(edge->dst); |
1040 | result_set.emplace(edge->dst); |
1041 | } |
1042 | } |
1043 | return results; |
1044 | } |
1045 | |
1046 | ProducerConsumerSplit ProducerConsumerSplit::Find( |
1047 | const ScheduleState& self, const Array<Stmt>& subtrees, |
1048 | const Array<StmtSRef>& producer_block_srefs, const Array<StmtSRef>& consumer_block_srefs, |
1049 | std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) { |
1050 | class InsertionPointNotFoundError : public ScheduleError { |
1051 | public: |
1052 | explicit InsertionPointNotFoundError(IRModule mod, int last_producer_position, |
1053 | int first_consumer_position) |
1054 | : mod_(mod), |
1055 | last_producer_position_(last_producer_position), |
1056 | first_consumer_position_(first_consumer_position) {} |
1057 | |
1058 | String FastErrorString() const final { |
1059 | return "ScheduleError: Cannot find the insertion point that satisfies the producer-consumer " |
1060 | "constraint" ; |
1061 | } |
1062 | |
1063 | String DetailRenderTemplate() const final { |
1064 | return "Cannot find the insertion point that satisfies the producer-consumer constraint. In " |
1065 | "0-based indexing, the last producer appears in subtree " + |
1066 | std::to_string(last_producer_position_) + |
1067 | ", and the first consumer appears in subtree " + |
1068 | std::to_string(first_consumer_position_); |
1069 | } |
1070 | |
1071 | IRModule mod() const final { return mod_; } |
1072 | |
1073 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
1074 | |
1075 | private: |
1076 | IRModule mod_; |
1077 | int last_producer_position_; |
1078 | int first_consumer_position_; |
1079 | }; |
1080 | |
1081 | class Finder : public StmtVisitor { |
1082 | public: |
1083 | void VisitStmt_(const BlockRealizeNode* realize) final { |
1084 | const BlockNode* block = realize->block.get(); |
1085 | if (block2realize_) { |
1086 | block2realize_->emplace(block, realize); |
1087 | } |
1088 | if (producer_blocks_.count(block)) { |
1089 | ++this->n_producers_visited_; |
1090 | } |
1091 | if (consumer_blocks_.count(block)) { |
1092 | ++this->n_consumers_visited_; |
1093 | } |
1094 | } |
1095 | |
1096 | std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize_; |
1097 | std::unordered_set<const StmtNode*> producer_blocks_; |
1098 | std::unordered_set<const StmtNode*> consumer_blocks_; |
1099 | int n_producers_visited_ = 0; |
1100 | int n_consumers_visited_ = 0; |
1101 | }; |
1102 | |
1103 | Finder finder; |
1104 | finder.block2realize_ = block2realize; |
1105 | // Set up the lookup table for producers |
1106 | finder.producer_blocks_.reserve(producer_block_srefs.size()); |
1107 | for (const StmtSRef& block_sref : producer_block_srefs) { |
1108 | finder.producer_blocks_.insert(block_sref->stmt); |
1109 | } |
1110 | // Set up the lookup table for consumers |
1111 | finder.consumer_blocks_.reserve(consumer_block_srefs.size()); |
1112 | for (const StmtSRef& block_sref : consumer_block_srefs) { |
1113 | finder.consumer_blocks_.insert(block_sref->stmt); |
1114 | } |
1115 | // Visit the subtrees |
1116 | int n = subtrees.size(); |
1117 | int last_producer_position = -1; |
1118 | int first_consumer_position = n; |
1119 | for (int i = 0; i < n; ++i) { |
1120 | int n_producers_visited_before = finder.n_producers_visited_; |
1121 | int n_consumers_visited_before = finder.n_consumers_visited_; |
1122 | finder(subtrees[i]); |
1123 | // Check if the subtree contains at least a producer |
1124 | if (finder.n_producers_visited_ != n_producers_visited_before) { |
1125 | last_producer_position = i; |
1126 | } |
1127 | // Check if the subtree contains at least a consumer |
1128 | if (finder.n_consumers_visited_ != n_consumers_visited_before) { |
1129 | if (first_consumer_position == n) { |
1130 | first_consumer_position = i; |
1131 | } |
1132 | } |
1133 | } |
1134 | if (last_producer_position >= first_consumer_position) { |
1135 | throw InsertionPointNotFoundError(self->mod, last_producer_position, first_consumer_position); |
1136 | } |
1137 | return ProducerConsumerSplit{last_producer_position, // |
1138 | first_consumer_position, // |
1139 | finder.n_producers_visited_, // |
1140 | finder.n_consumers_visited_}; |
1141 | } |
1142 | |
1143 | /******** Block-buffer relation ********/ |
1144 | |
1145 | BufferRegion GetNthAccessBufferRegion(const ScheduleState& self, const Block& block, int n, |
1146 | BufferIndexType index_type) { |
1147 | class BufferIndexOutOfRangeError : public ScheduleError { |
1148 | public: |
1149 | explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, |
1150 | BufferIndexType index_type) |
1151 | : mod_(std::move(mod)), |
1152 | block_(std::move(block)), |
1153 | buffer_index_(buffer_index), |
1154 | index_type_(index_type) {} |
1155 | |
1156 | String FastErrorString() const final { |
1157 | if (index_type_ == BufferIndexType::kWrite) { |
1158 | return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " |
1159 | "range " |
1160 | "[0, num_write_regions) where `num_write_regions` is the number of buffer regions " |
1161 | "written by the block." ; |
1162 | } else { |
1163 | return "ScheduleError: The input `buffer_index` is out of range. It is required to be in " |
1164 | "range " |
1165 | "[0, num_read_regions) where `num_read_regions` is the number of buffer regions " |
1166 | "read by the block." ; |
1167 | } |
1168 | } |
1169 | |
1170 | String DetailRenderTemplate() const final { |
1171 | std::ostringstream os; |
1172 | size_t num = |
1173 | index_type_ == BufferIndexType::kWrite ? block_->writes.size() : block_->reads.size(); |
1174 | os << "The block {0} has " << num << " " << BufferIndexType2Str(index_type_) |
1175 | << " regions, so `buffer_index` is required to be in [0, " << num |
1176 | << "). However, the input `buffer_index` is " << buffer_index_ |
1177 | << ", which is out of the expected range." ; |
1178 | return os.str(); |
1179 | } |
1180 | |
1181 | IRModule mod() const final { return mod_; } |
1182 | Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } |
1183 | |
1184 | private: |
1185 | IRModule mod_; |
1186 | Block block_; |
1187 | int buffer_index_; |
1188 | BufferIndexType index_type_; |
1189 | }; |
1190 | |
1191 | const Array<BufferRegion>& access_region = |
1192 | index_type == BufferIndexType::kWrite ? block->writes : block->reads; |
1193 | |
1194 | if (n < 0 || static_cast<int>(access_region.size()) <= n) { |
1195 | throw BufferIndexOutOfRangeError(self->mod, block, n, index_type); |
1196 | } |
1197 | return access_region[n]; |
1198 | } |
1199 | |
1200 | Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, |
1201 | BufferIndexType index_type) { |
1202 | return GetNthAccessBufferRegion(self, block, n, index_type)->buffer; |
1203 | } |
1204 | |
1205 | std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_sref, |
1206 | const Buffer& buffer) { |
1207 | // Climb up along the sref tree, and find the block where `buffer` is in alloc_buffers or |
1208 | // match_buffers. |
1209 | const StmtSRefNode* defining_site_sref = block_sref.get(); |
1210 | while (defining_site_sref != nullptr) { |
1211 | const auto* block = defining_site_sref->StmtAs<BlockNode>(); |
1212 | // If this sref is not a block sref, skip it. |
1213 | if (block == nullptr) { |
1214 | defining_site_sref = defining_site_sref->parent; |
1215 | continue; |
1216 | } |
1217 | // Try to find the buffer in `allloc_buffers` |
1218 | for (const Buffer& alloc_buffer : block->alloc_buffers) { |
1219 | if (buffer.same_as(alloc_buffer)) { |
1220 | return {GetRef<StmtSRef>(defining_site_sref), true}; |
1221 | } |
1222 | } |
1223 | // We do not allow the buffer being defined in `match_buffer`. |
1224 | for (const MatchBufferRegion match_buffer : block->match_buffers) { |
1225 | if (buffer.same_as(match_buffer)) { |
1226 | return {GetRef<StmtSRef>(defining_site_sref), false}; |
1227 | } |
1228 | } |
1229 | defining_site_sref = defining_site_sref->parent; |
1230 | } |
1231 | // If we cannot find the defining site block, it means that the buffer must be in the function's |
1232 | // buffer_map, which isn't an intermediate buffer. |
1233 | return {NullOpt, false}; |
1234 | } |
1235 | |
1236 | /******** SRef Tree Related ********/ |
1237 | |
1238 | StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { |
1239 | const StmtSRefNode* p = sref.get(); |
1240 | for (; p->parent != nullptr; p = p->parent) { |
1241 | } |
1242 | return GetRef<StmtSRef>(p); |
1243 | } |
1244 | |
1245 | /******** Misc ********/ |
1246 | |
1247 | bool HasOp(const Stmt& stmt, const Array<Op>& ops) { |
1248 | std::unordered_set<const Object*> op_set; |
1249 | op_set.reserve(ops.size()); |
1250 | for (const Op& op : ops) { |
1251 | op_set.insert(op.operator->()); |
1252 | } |
1253 | bool found = false; |
1254 | PreOrderVisit(stmt, [&found, &op_set](const ObjectRef& obj) -> bool { |
1255 | if (found) { |
1256 | return false; |
1257 | } |
1258 | if (const auto* call = obj.as<CallNode>()) { |
1259 | if (op_set.count(call->op.operator->())) { |
1260 | found = true; |
1261 | } |
1262 | } |
1263 | return !found; |
1264 | }); |
1265 | return found; |
1266 | } |
1267 | |
1268 | bool HasIfThenElse(const Stmt& stmt) { |
1269 | bool has_branch = false; |
1270 | auto f_visit = [&has_branch](const ObjectRef& obj) -> bool { |
1271 | if (has_branch) { |
1272 | // stop visiting |
1273 | return false; |
1274 | } |
1275 | if (const auto* realize = obj.as<BlockRealizeNode>()) { |
1276 | // Case 1: BlockRealize |
1277 | if (!is_one(realize->predicate)) { |
1278 | has_branch = true; |
1279 | } |
1280 | } else if (obj->IsInstance<IfThenElseNode>() || obj->IsInstance<SelectNode>()) { |
1281 | // Case 2: IfThenElse / Select |
1282 | has_branch = true; |
1283 | } else if (const auto* call = obj.as<CallNode>()) { |
1284 | // Case 3: Call the `if_then_else` operator |
1285 | static const Op& op_if_then_else = Op::Get("tir.if_then_else" ); |
1286 | if (call->op.same_as(op_if_then_else)) { |
1287 | has_branch = true; |
1288 | } |
1289 | } |
1290 | return !has_branch; |
1291 | }; |
1292 | PreOrderVisit(stmt, f_visit); |
1293 | return has_branch; |
1294 | } |
1295 | |
1296 | std::tuple</*exists=*/bool, |
1297 | /*surjective=*/bool, |
1298 | /*injective=*/bool, |
1299 | /*ordered=*/bool, |
1300 | /*no_const_read=*/bool, |
1301 | /*no_shift_read=*/bool> |
1302 | AnalyzeReadWritePattern(const BufferRegion& read_region, const BufferRegion& write_region) { |
1303 | static constexpr const std::tuple<bool, bool, bool, bool, bool, bool> kNotExist = |
1304 | std::make_tuple(false, false, false, false, false, false); |
1305 | // Step 1. Extract the write indices |
1306 | int w_dim = write_region->buffer->shape.size(); |
1307 | std::unordered_map<const VarNode*, int> var2idx; |
1308 | var2idx.reserve(w_dim); |
1309 | for (int i = 0; i < w_dim; ++i) { |
1310 | const Range& dom = write_region->region[i]; |
1311 | if (as_const_int(dom->extent) == nullptr) { |
1312 | return kNotExist; |
1313 | } |
1314 | if (const auto* v = dom->min.as<VarNode>()) { |
1315 | var2idx.emplace(v, i); |
1316 | } else { |
1317 | return kNotExist; |
1318 | } |
1319 | } |
1320 | // Step 2. Map each read index to a write index |
1321 | bool no_const_read = true; |
1322 | bool no_shift_read = true; |
1323 | int r_dim = read_region->buffer->shape.size(); |
1324 | std::vector<int> mapped(r_dim, -1); |
1325 | for (int i = 0; i < r_dim; ++i) { |
1326 | const Range& dom = read_region->region[i]; |
1327 | if (as_const_int(dom->extent) == nullptr) { |
1328 | return kNotExist; |
1329 | } |
1330 | // Case 1. Read index is a constant |
1331 | if (as_const_int(dom->min) != nullptr) { |
1332 | no_const_read = false; |
1333 | continue; |
1334 | } |
1335 | // Case 2. Read index cannot be recognized as `var +/- const` |
1336 | // where `var` is a write index and `const` is an optional constant shift |
1337 | Optional<IntImm> opt_const = NullOpt; |
1338 | const VarNode* var = |
1339 | static_cast<const VarNode*>(AnalyzeVarWithShift(dom->min, &opt_const).get()); |
1340 | if (var == nullptr || !var2idx.count(var)) { |
1341 | return kNotExist; |
1342 | } |
1343 | // Case 3. Read index is `var +/- const` |
1344 | mapped[i] = var2idx.at(var); |
1345 | if (opt_const.defined()) { |
1346 | no_shift_read = false; |
1347 | } |
1348 | } |
1349 | // Step 3. Check if the mapping is ordered, and count how many times each var is mapped |
1350 | std::vector<int> mapped_counter(w_dim, 0); |
1351 | bool ordered = true; |
1352 | int last_mapped = -1; |
1353 | for (int i : mapped) { |
1354 | if (i != -1) { |
1355 | ++mapped_counter[i]; |
1356 | if (last_mapped != -1 && last_mapped > i) { |
1357 | ordered = false; |
1358 | } |
1359 | last_mapped = i; |
1360 | } |
1361 | } |
1362 | // Step 4. Check if the mapping is surjective or injective |
1363 | // Surjective: each write index is mapped at least once |
1364 | // Injective: each write index is mapped at most once |
1365 | bool surjective = true; |
1366 | bool injective = true; |
1367 | for (int cnt : mapped_counter) { |
1368 | if (cnt == 0) { |
1369 | surjective = false; |
1370 | } else if (cnt >= 2) { |
1371 | injective = false; |
1372 | } |
1373 | } |
1374 | return std::make_tuple(/*exist=*/true, surjective, injective, ordered, no_const_read, |
1375 | no_shift_read); |
1376 | } |
1377 | |
1378 | /******** Storage Scope ********/ |
1379 | |
1380 | void CheckStorageScope(const ScheduleState& self, String storage_scope) { |
1381 | class InvalidStorageScopeError : public ScheduleError { |
1382 | public: |
1383 | explicit InvalidStorageScopeError(IRModule mod, String storage_scope) |
1384 | : mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {} |
1385 | |
1386 | String FastErrorString() const final { |
1387 | return "ScheduleError: The input storage scope is invalid" ; |
1388 | } |
1389 | |
1390 | String DetailRenderTemplate() const final { |
1391 | return "The input storage scope \"" + storage_scope_ + "\" is invalid." ; |
1392 | } |
1393 | |
1394 | Array<ObjectRef> LocationsOfInterest() const final { return {}; } |
1395 | IRModule mod() const final { return mod_; } |
1396 | |
1397 | private: |
1398 | IRModule mod_; |
1399 | String storage_scope_; |
1400 | }; |
1401 | |
1402 | try { |
1403 | runtime::StorageScope::Create(std::string(storage_scope)); |
1404 | } catch (...) { |
1405 | throw InvalidStorageScopeError(self->mod, std::move(storage_scope)); |
1406 | } |
1407 | } |
1408 | |
1409 | bool IsSpatial(const StmtSRef& block_sref) { |
1410 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
1411 | for (const IterVar& iter_var : block->iter_vars) { |
1412 | if (iter_var->iter_type != IterVarType::kDataPar) { |
1413 | return false; |
1414 | } |
1415 | } |
1416 | return true; |
1417 | } |
1418 | |
1419 | bool IsTrivialBinding(const ScheduleState& self, const StmtSRef& block_sref) { |
1420 | TVM_SREF_TO_BLOCK(block_sref); |
1421 | Array<StmtSRef> loops = GetLoops(block_sref); |
1422 | Array<PrimExpr> binds = GetBlockRealize(self, block_sref)->iter_values; |
1423 | if (loops.size() != binds.size()) { |
1424 | return false; |
1425 | } |
1426 | for (int i = 0, n = loops.size(); i < n; ++i) { |
1427 | const ForNode* loop = TVM_SREF_TO_FOR(loops[i]); |
1428 | if (binds[i].get() != loop->loop_var.get()) { |
1429 | return false; |
1430 | } |
1431 | } |
1432 | return true; |
1433 | } |
1434 | |
1435 | bool NeedsMultiLevelTiling(const ScheduleState& self, const StmtSRef& block_sref) { |
1436 | if (HasBeenMultiLevelTiled(block_sref)) { |
1437 | return false; |
1438 | } |
1439 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
1440 | if (block->writes.size() != 1 || block->reads.empty() || IsSpatial(block_sref) || |
1441 | !IsTrivialBinding(self, block_sref)) { |
1442 | return false; |
1443 | } |
1444 | const BufferNode* write_buffer = block->writes[0]->buffer.get(); |
1445 | // Step 1. Sort out spatial block variables. Skip the block iters of domain [0, 1), since such |
1446 | // block iters distracts the following check of the unused block iters. |
1447 | std::vector<const VarNode*> spatial_block_vars; |
1448 | spatial_block_vars.reserve(block->iter_vars.size()); |
1449 | for (const IterVar& block_var : block->iter_vars) { |
1450 | const int64_t* dom_min = as_const_int(block_var->dom->min); |
1451 | const int64_t* dom_extent = as_const_int(block_var->dom->extent); |
1452 | bool has_trivial_dom = |
1453 | dom_min != nullptr && dom_extent != nullptr && *dom_min == 0 && *dom_extent == 1; |
1454 | if (block_var->iter_type == IterVarType::kDataPar && !has_trivial_dom) { |
1455 | spatial_block_vars.push_back(block_var->var.get()); |
1456 | } |
1457 | } |
1458 | // Step 2. Enumerate each read region, check the number of block vars that are not used |
1459 | // to index the read region |
1460 | int total_unused_block_vars = 0; |
1461 | std::unordered_set<const BufferNode*> read_buffers; |
1462 | read_buffers.reserve(block->reads.size()); |
1463 | for (const BufferRegion& buffer_region : block->reads) { |
1464 | const BufferNode* buffer = buffer_region->buffer.get(); |
1465 | const Array<Range>& regions = buffer_region->region; |
1466 | // Step 2.1. Duplication of read buffers are not allowed |
1467 | if (read_buffers.insert(buffer).second == false) { |
1468 | return false; |
1469 | } |
1470 | // Step 2.2. Skip the reduction buffer |
1471 | if (buffer == write_buffer) { |
1472 | continue; |
1473 | } |
1474 | // Step 2.3. Collect the block vars that are used to index the read region |
1475 | std::unordered_set<const VarNode*> vars; |
1476 | for (const Range& range : regions) { |
1477 | if (as_const_int(range->extent) == nullptr) { |
1478 | return false; |
1479 | } |
1480 | for (const Var& var : UndefinedVars(range->min)) { |
1481 | vars.insert(var.get()); |
1482 | } |
1483 | } |
1484 | // Step 2.4. Check if the block vars are not used to index the read region |
1485 | int n_unused_block_vars = 0; |
1486 | for (const VarNode* block_var : spatial_block_vars) { |
1487 | if (vars.count(block_var) == 0) { |
1488 | ++n_unused_block_vars; |
1489 | } |
1490 | } |
1491 | total_unused_block_vars += n_unused_block_vars; |
1492 | } |
1493 | return total_unused_block_vars >= 1; |
1494 | } |
1495 | |
1496 | bool IsSpatialPrimFunc(const PrimFunc& func) { |
1497 | bool result = true; |
1498 | PreOrderVisit(func->body, [&result](const ObjectRef& obj) { |
1499 | if (result == false) { |
1500 | return false; |
1501 | } |
1502 | if (const auto* block = obj.as<BlockNode>()) { |
1503 | for (const IterVar& iter_var : block->iter_vars) { |
1504 | if (iter_var->iter_type != IterVarType::kDataPar) { |
1505 | result = false; |
1506 | return false; |
1507 | } |
1508 | } |
1509 | } |
1510 | return true; |
1511 | }); |
1512 | return result; |
1513 | } |
1514 | |
1515 | std::pair<int64_t, int64_t> GetCumulativeSpaceAndReductionLength(const tir::ScheduleState& self, |
1516 | const tir::StmtSRef& block_sref) { |
1517 | Array<tir::StmtSRef> loops = tir::GetLoops(block_sref); |
1518 | int64_t cum_space_len = 1, cum_reduce_len = 1; |
1519 | /* |
1520 | * Return (-1, -1) if |
1521 | * 1. there is some loop with type other than kDataPar and kCommReduce; |
1522 | * 2. there is some loop which is dynamic. |
1523 | */ |
1524 | for (const tir::StmtSRef& loop_sref : loops) { |
1525 | tir::IterVarType type = GetLoopIterType(loop_sref); |
1526 | if (type == tir::kDataPar) { |
1527 | const int64_t* extent = GetLoopIntExtent(loop_sref); |
1528 | if (*extent != -1) { |
1529 | cum_space_len *= *extent; |
1530 | } else { |
1531 | return std::make_pair(-1, -1); |
1532 | } |
1533 | } else if (type == tir::kCommReduce) { |
1534 | const int64_t* extent = GetLoopIntExtent(loop_sref); |
1535 | if (*extent != -1) { |
1536 | cum_reduce_len *= *extent; |
1537 | } else { |
1538 | return std::make_pair(-1, -1); |
1539 | } |
1540 | } else { |
1541 | return std::make_pair(-1, -1); |
1542 | } |
1543 | } |
1544 | return std::make_pair(cum_space_len, cum_reduce_len); |
1545 | } |
1546 | |
1547 | bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // |
1548 | const tir::StmtSRef& block_sref, // |
1549 | int64_t max_parallel_extent, // |
1550 | int64_t max_parallel_basic) { |
1551 | const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); |
1552 | Array<tir::StmtSRef> loops = tir::GetLoops(block_sref); |
1553 | |
1554 | // Cond 1. The block must have at lease one write buffer |
1555 | if (block->writes.size() == 0) { |
1556 | return false; |
1557 | } |
1558 | |
1559 | // Cond 2. The block is a reduction block and has trivial binding. |
1560 | const StmtSRef& scope_sref = GetScopeRoot(self, block_sref, |
1561 | /*require_stage_pipeline=*/false); |
1562 | if (!IsReductionBlock(self, block_sref, scope_sref) // |
1563 | || !IsTrivialBinding(self, block_sref) // |
1564 | || HasBeenMultiLevelTiled(block_sref)) { |
1565 | return false; |
1566 | } |
1567 | |
1568 | // Cond 3. Every the loop axis must be either spatial axis or reduction axis. |
1569 | for (const tir::StmtSRef& loop_sref : loops) { |
1570 | const tir::IterVarType& type = GetLoopIterType(loop_sref); |
1571 | if (type != tir::kDataPar && type != tir::kCommReduce) { |
1572 | return false; |
1573 | } |
1574 | } |
1575 | |
1576 | // Cond 4. Whether there is at least one reduction loop. |
1577 | // Cond 5. The loops are continuous, and the body of the innermost loop is exactly the block. |
1578 | bool has_reduction_loop = false; |
1579 | for (size_t i = 0; i < loops.size(); ++i) { |
1580 | // Cond 4. |
1581 | if (GetLoopIterType(loops[i]) == tir::kCommReduce) { |
1582 | has_reduction_loop = true; |
1583 | } |
1584 | |
1585 | // Cond 5. |
1586 | const ForNode* loop_i = TVM_SREF_TO_FOR(loops[i]); |
1587 | if (i < loops.size() - 1) { |
1588 | const ForNode* loop_i1 = TVM_SREF_TO_FOR(loops[i + 1]); |
1589 | if (loop_i->body.get() != loop_i1) { |
1590 | return false; |
1591 | } |
1592 | } else { |
1593 | const auto* block_realize = loop_i->body.as<tir::BlockRealizeNode>(); |
1594 | if (!block_realize || block_realize->block.get() != block) { |
1595 | return false; |
1596 | } |
1597 | } |
1598 | } |
1599 | if (!has_reduction_loop) { |
1600 | return false; |
1601 | } |
1602 | |
1603 | // Cond 6. Can successfully calculating the cumulative loop length. |
1604 | auto [cum_space_len, cum_reduce_len] = GetCumulativeSpaceAndReductionLength(self, block_sref); |
1605 | if (cum_space_len == -1 || cum_reduce_len == -1) { |
1606 | return false; |
1607 | } |
1608 | |
1609 | // Cond 7. |
1610 | if (NeedsMultiLevelTiling(self, block_sref)) { |
1611 | // Do not use rfactor/cross-thread-reduction if we have enough parallelism on spatial loops. |
1612 | return !(cum_space_len >= cum_reduce_len || cum_space_len > max_parallel_extent); |
1613 | } else { |
1614 | // Always try rfactor/cross-thread-reduction for other reduction blocks. |
1615 | return cum_reduce_len > 1; |
1616 | } |
1617 | } |
1618 | |
1619 | PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) { |
1620 | auto simplified = analyzer->Simplify(expr); |
1621 | if (simplified->IsInstance<IntImmNode>()) { |
1622 | return expr; |
1623 | } else { |
1624 | return simplified; |
1625 | } |
1626 | } |
1627 | |
1628 | TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); |
1629 | |
1630 | /*! \brief Auxiliary data structure of information extracted from tensor intrin description */ |
1631 | struct TensorIntrinDescInfo { |
1632 | /*! \brief The block of the description function, which is the (unique) direct child of the root |
1633 | * block. |
1634 | */ |
1635 | const BlockRealizeNode* desc_block = nullptr; |
1636 | /*! \brief The loops of the description function, in the order from outer loops to inner ones. */ |
1637 | std::vector<const tir::ForNode*> desc_loops; |
1638 | /*! \brief The loop variables. */ |
1639 | std::unordered_set<const tir::VarNode*> desc_loop_vars; |
1640 | }; |
1641 | |
1642 | /*! |
1643 | * \brief Extract auxilary information from the tensor intrin description. |
1644 | * \param analyze The arithmetic analyzer |
1645 | * \param desc_func The description PrimFunc |
1646 | * \return The auxilary information |
1647 | */ |
1648 | TensorIntrinDescInfo (arith::Analyzer* analyzer, |
1649 | const PrimFunc& desc_func) { |
1650 | TensorIntrinDescInfo info; |
1651 | const auto* desc_scope_realize = desc_func->body.as<BlockRealizeNode>(); |
1652 | ICHECK(desc_scope_realize); |
1653 | { |
1654 | auto f_visit = [&](const ObjectRef& obj) -> bool { |
1655 | // Extract the block |
1656 | if (const auto* block = obj.as<BlockRealizeNode>()) { |
1657 | info.desc_block = block; |
1658 | return false; |
1659 | } |
1660 | // Extract the loops |
1661 | if (const auto* loop = obj.as<ForNode>()) { |
1662 | info.desc_loops.push_back(loop); |
1663 | info.desc_loop_vars.insert(loop->loop_var.get()); |
1664 | if (!analyzer->CanProve(loop->min == 0)) { |
1665 | return false; |
1666 | } |
1667 | } |
1668 | return true; |
1669 | }; |
1670 | tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); |
1671 | std::reverse(info.desc_loops.begin(), info.desc_loops.end()); |
1672 | ICHECK(info.desc_block); |
1673 | } |
1674 | return info; |
1675 | } |
1676 | |
1677 | Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self, |
1678 | const tir::StmtSRef& block_sref, |
1679 | const tir::PrimFunc& desc_func, |
1680 | bool allow_padding) { |
1681 | arith::Analyzer analyzer; |
1682 | const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); |
1683 | // Step 1. Analyze desc_func, extract its block, loops and loop vars |
1684 | TensorIntrinDescInfo desc_info = ExtractTensorIntrinDescInfo(&analyzer, desc_func); |
1685 | // Step 2. Collect loops from block_sref |
1686 | const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); |
1687 | TVM_SREF_TO_BLOCK(scope_sref); |
1688 | std::vector<const tir::ForNode*> block_loops; |
1689 | std::unordered_set<const tir::VarNode*> block_loop_vars; |
1690 | { |
1691 | for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { |
1692 | const auto* loop = loop_sref->StmtAs<tir::ForNode>(); |
1693 | if (loop == nullptr || loop->body->IsInstance<tir::SeqStmtNode>()) { |
1694 | break; |
1695 | } |
1696 | block_loops.push_back(loop); |
1697 | block_loop_vars.insert(loop->loop_var.get()); |
1698 | if (!analyzer.CanProve(loop->min == 0)) { |
1699 | return NullOpt; |
1700 | } |
1701 | } |
1702 | std::reverse(block_loops.begin(), block_loops.end()); |
1703 | } |
1704 | // Step 3. Map from block loops to desc block loops |
1705 | const std::vector<const ForNode*>& desc_loops = desc_info.desc_loops; |
1706 | const std::unordered_set<const VarNode*>& desc_loop_vars = desc_info.desc_loop_vars; |
1707 | const BlockRealizeNode* desc_block = desc_info.desc_block; |
1708 | ObjectPtr<TensorizeInfoNode> ret = make_object<TensorizeInfoNode>(); |
1709 | const int n_block_vars = block->iter_values.size(); |
1710 | const int n_desc_vars = desc_block->iter_values.size(); |
1711 | const int offset = n_block_vars - n_desc_vars; |
1712 | |
1713 | std::unordered_map<int, int> block_index_to_padding; // padding of each block iter if necessary |
1714 | |
1715 | if (offset < 0) { |
1716 | return NullOpt; |
1717 | } |
1718 | |
1719 | const std::vector<IterVarType> iter_types_block = GetBlockVarTypes(block_sref); |
1720 | const std::vector<IterVarType> iter_types_desc = GetBlockVarTypes(desc_block->block.get()); |
1721 | |
1722 | ICHECK(desc_loops.size() == static_cast<size_t>(n_desc_vars)); |
1723 | ICHECK(block_loops.size() == iter_types_block.size()); |
1724 | |
1725 | // We assume that the orders of iter_vars in the target and the desc block are consistent. |
1726 | // Based on that assumption, the following logic supports arbitrary permutations of a loop order, |
1727 | // such as |
1728 | |
1729 | // for k: |
1730 | // for i: |
1731 | // for j: |
1732 | // C[i, j] += A[i, k] * B[k, j] |
1733 | |
1734 | // or |
1735 | |
1736 | // for i: |
1737 | // for j: |
1738 | // for k: |
1739 | // C[i, j] += A[i, k] * B[k, j] |
1740 | |
1741 | int next_block_ind = block_loops.size() - 1; |
1742 | for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { |
1743 | // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc |
1744 | const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; |
1745 | const tir::ForNode* desc_loop = nullptr; |
1746 | IterVarType iter_type_desc = iter_types_desc[i_desc]; |
1747 | for (int i = 0, n = desc_loops.size(); i < n; ++i) { |
1748 | // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars |
1749 | PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); |
1750 | if (!UsesVar(residual, |
1751 | [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { |
1752 | desc_loop = desc_loops[i]; |
1753 | iter_type_desc = iter_types_desc[i]; |
1754 | break; |
1755 | } |
1756 | } |
1757 | if (desc_loop == nullptr || desc_loop->extent.as<IntImmNode>() == nullptr) { |
1758 | return NullOpt; |
1759 | } |
1760 | |
1761 | const IntImmNode* int_desc_extent = desc_loop->extent.as<IntImmNode>(); |
1762 | |
1763 | // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type |
1764 | PrimExpr block_bind; |
1765 | int current_block_ind = next_block_ind; |
1766 | for (; current_block_ind >= 0; --current_block_ind) { |
1767 | if (iter_types_block[current_block_ind] == iter_type_desc) { |
1768 | next_block_ind = current_block_ind - 1; |
1769 | block_bind = block->iter_values[current_block_ind]; |
1770 | break; |
1771 | } |
1772 | } |
1773 | |
1774 | if (!block_bind.defined()) return NullOpt; |
1775 | |
1776 | // Step 3.3. Find the corresponding loop of the target block |
1777 | for (int i = 0, n = block_loops.size(); i < n; ++i) { |
1778 | // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars |
1779 | const tir::ForNode* block_loop = block_loops[i]; |
1780 | const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; |
1781 | // Skip i-th loop if it has already been mapped |
1782 | if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue; |
1783 | |
1784 | PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var); |
1785 | if (UsesVar(residual, |
1786 | [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { |
1787 | continue; |
1788 | } |
1789 | // padding is allowed only when the block has trivial bindings |
1790 | if (allow_padding && !is_zero(residual)) { |
1791 | allow_padding = false; |
1792 | } |
1793 | |
1794 | const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>(); |
1795 | |
1796 | // Check divisibility |
1797 | if (!int_block_extent) { |
1798 | return NullOpt; |
1799 | } |
1800 | int64_t remainder = int_block_extent->value % int_desc_extent->value; |
1801 | if (remainder != 0) { |
1802 | if (allow_padding) { |
1803 | // If the block loop is not divisible by the desc loop, we pad the block loop to make it |
1804 | // divisible if padding is allowed. |
1805 | block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder; |
1806 | } else { |
1807 | return NullOpt; |
1808 | } |
1809 | } |
1810 | |
1811 | ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop)); |
1812 | break; |
1813 | } |
1814 | } |
1815 | |
1816 | for (int i = 0, n = desc_loops.size(); i < n; ++i) { |
1817 | ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i)); |
1818 | } |
1819 | if (!block_index_to_padding.empty()) { |
1820 | if (!allow_padding) { |
1821 | return NullOpt; |
1822 | } |
1823 | Array<Integer> paddings; |
1824 | for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) { |
1825 | const IterVar& iter_var = block->block->iter_vars[i]; |
1826 | if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) { |
1827 | paddings.push_back(IntImm(iter_var->var.dtype(), it->second)); |
1828 | } else { |
1829 | paddings.push_back(IntImm(iter_var->var.dtype(), 0)); |
1830 | } |
1831 | } |
1832 | ret->block_iter_paddings = std::move(paddings); |
1833 | } |
1834 | |
1835 | return TensorizeInfo(ret); |
1836 | } |
1837 | |
1838 | TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc" ).set_body_typed(IsSpatialPrimFunc); |
1839 | TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping" ) |
1840 | .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) { |
1841 | return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding); |
1842 | }); |
1843 | |
1844 | /******** Auto Tensorization ********/ |
1845 | |
1846 | /*! \brief IndexMap proposer for layout transformation in auto tensorization. */ |
1847 | class AutoTensorizeMappingProposer { |
1848 | public: |
1849 | static Array<IndexMap> ProposeMappings(const AutoTensorizeComparator* , |
1850 | arith::Analyzer* analyzer) { |
1851 | AutoTensorizeMappingProposer proposer(extractor, analyzer); |
1852 | proposer.CollectFeasibleSet(); |
1853 | return proposer.ProposeAllFuseMapping(); |
1854 | } |
1855 | |
1856 | private: |
1857 | explicit AutoTensorizeMappingProposer(const AutoTensorizeComparator* , |
1858 | arith::Analyzer* analyzer) |
1859 | : extractor_(extractor), analyzer_(analyzer) {} |
1860 | |
1861 | using VarSet = std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>; |
1862 | |
1863 | void CollectFeasibleSet() { |
1864 | // Collect the set of potential iter var mapping between the workload and the tensor intrin. |
1865 | // We analyze the appearance of each variable in the buffer indices of each buffer on LHS and |
1866 | // RHS. The appearance of a variable in the buffer indices is encoded as bit-masks (BufferMask). |
1867 | // Variables on the LHS and the RHS with the same bit-mask and the same iter type are potential |
1868 | // mappings. |
1869 | // |
1870 | // For example, consider the conv2d case. We will try to match the workload |
1871 | // conv2d[n, h, w, c] = sum_{rh, rw, rc} X[n, h + rh, w + rw, c + rc] * W[rh, rw, rc, c] |
1872 | // against a matmul tensor intrin |
1873 | // C[m, n] = sum_{k} A[m, k] * B[k, n] |
1874 | // First we extract the correspondence of the buffers: conv2d <=> C, A <=> X, B <=> W. |
1875 | // Then for each variable, we extract the buffers where it is used for indexing. |
1876 | // Take the variable m on the RHS as an example. m is used to index buffer A and C. On the LHS, |
1877 | // we will find the variables used to index only the exact corresponding buffers conv2d and X |
1878 | // (the variable is not allowed to index other buffers). In this case, n, h, w is used to index |
1879 | // both buffer conv2d and W, and not in other buffers. Therefore, {n, h, w} <=> m is a potential |
1880 | // mapping. |
1881 | |
1882 | // Note: the mapping is not unique when multiple variables on RHS has the same bit-mask. |
1883 | // This is currently not supported. |
1884 | |
1885 | using BufferMask = std::vector<bool>; |
1886 | |
1887 | // Step 1: Assign an index to each buffer in LHS and RHS |
1888 | std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual> rhs_buffer_index; |
1889 | std::unordered_map<Buffer, int, ObjectPtrHash, ObjectEqual> lhs_buffer_index; |
1890 | { |
1891 | int i = 0; |
1892 | for (const auto& kv : extractor_->rhs_buffer_map_) { |
1893 | const Buffer& rhs_buffer = kv.first; |
1894 | const Buffer& lhs_buffer = kv.second; |
1895 | rhs_buffer_index[rhs_buffer] = i; |
1896 | lhs_buffer_index[lhs_buffer] = i; |
1897 | ++i; |
1898 | } |
1899 | } |
1900 | |
1901 | // Step 2: Compute the buffer mask |
1902 | ICHECK_EQ(rhs_buffer_index.size(), lhs_buffer_index.size()); |
1903 | int num_buffers = rhs_buffer_index.size(); |
1904 | std::unordered_map<const VarNode*, std::vector<bool>> rhs_buffer_masks, lhs_buffer_masks; |
1905 | // helper function to initialize or update the buffer mask |
1906 | auto update_mask = [&](const VarNode* var, |
1907 | std::unordered_map<const VarNode*, std::vector<bool>>* masks, int i) { |
1908 | if (!masks->count(var)) { |
1909 | (*masks)[var].resize(num_buffers); |
1910 | } |
1911 | (*masks)[var][i] = true; |
1912 | }; |
1913 | |
1914 | for (const auto& it : extractor_->rhs_buffer_indices_map_) { |
1915 | const Buffer& rhs_buffer = it.first; |
1916 | for (const PrimExpr& rhs_index : it.second) { |
1917 | if (const VarNode* var_node = rhs_index.as<VarNode>()) { |
1918 | update_mask(var_node, &rhs_buffer_masks, rhs_buffer_index.at(rhs_buffer)); |
1919 | } else { |
1920 | LOG(FATAL) << "ValueError: Buffer index " << rhs_index |
1921 | << " other that variables in tensor intrinsics is not supported." ; |
1922 | } |
1923 | } |
1924 | |
1925 | auto lhs_buffer_it = extractor_->rhs_buffer_map_.find(rhs_buffer); |
1926 | ICHECK(lhs_buffer_it != extractor_->rhs_buffer_map_.end()); |
1927 | const Buffer& lhs_buffer = lhs_buffer_it->second; |
1928 | for (const PrimExpr& index : extractor_->lhs_buffer_indices_map_.at(lhs_buffer)) { |
1929 | PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { |
1930 | if (const VarNode* var = obj.as<VarNode>()) { |
1931 | update_mask(var, &lhs_buffer_masks, lhs_buffer_index.at(lhs_buffer)); |
1932 | } |
1933 | return true; |
1934 | }); |
1935 | } |
1936 | } |
1937 | |
1938 | // Step 3: Find variables on LHS and RHS with the same buffer mask. Ensure LHS and RHS vars |
1939 | // have the same iter type. |
1940 | std::unordered_map<BufferMask, VarSet> mask_to_rhs_vars; |
1941 | for (const auto& kv : rhs_buffer_masks) { |
1942 | const VarNode* rhs_var = kv.first; |
1943 | const BufferMask& mask = kv.second; |
1944 | mask_to_rhs_vars[mask].insert(GetRef<Var>(rhs_var)); |
1945 | } |
1946 | std::unordered_map<const VarNode*, IterVarType> rhs_var_iter_type; |
1947 | for (const auto& iter : extractor_->rhs_iters_) { |
1948 | rhs_var_iter_type.emplace(iter->var.get(), iter->iter_type); |
1949 | } |
1950 | for (const auto& iter : extractor_->lhs_iters_) { |
1951 | auto& potential_mappings = lhs_feasible_vars_[iter->var]; |
1952 | VarSet rhs_candidates = mask_to_rhs_vars[lhs_buffer_masks[iter->var.get()]]; |
1953 | std::copy_if( |
1954 | rhs_candidates.begin(), rhs_candidates.end(), |
1955 | std::inserter(potential_mappings, potential_mappings.begin()), |
1956 | [&](const Var& var) { return rhs_var_iter_type.at(var.get()) == iter->iter_type; }); |
1957 | } |
1958 | } |
1959 | |
1960 | Array<IndexMap> ProposeAllFuseMapping() { |
1961 | // Now we have calcuated potential mapping for each iter var on LHS. For iters on LHS mapped to |
1962 | // the same iter on RHS, they will be fused in the original order in LHS block iters. We will |
1963 | // generate IndexMap to represent such fusion on LHS. For example, if n, h, w on LHS are mapped |
1964 | // to the same iter var on RHS, we will produce index map `lambda n, h, w: fuse(n, h, w)`, where |
1965 | // fuse(v0, .., vn) = ((v0 * v1_extent + v1) + ... ) * vn_extent + vn |
1966 | |
1967 | // the parameters of the result index map, each parameter corresponds to a LHS iter |
1968 | Array<Var> index_map_src; |
1969 | // the outputs of the result index map |
1970 | Array<PrimExpr> index_map_tgt; |
1971 | |
1972 | // Step 1: Collect extents of LHS iters and prepare the initial indices of the IndexMap |
1973 | Map<Var, PrimExpr> lhs_iter_extents; |
1974 | for (const auto& iter : extractor_->lhs_iters_) { |
1975 | lhs_iter_extents.Set(iter->var, iter->dom->extent); |
1976 | index_map_src.push_back(iter->var.copy_with_suffix("" )); |
1977 | } |
1978 | |
1979 | // Step 2: Each iter on RHS has a group of corresponding iters on LHS. Initialize the fusion |
1980 | // result for each group of iters on LHS. |
1981 | Map<Var, PrimExpr> fused_lhs_iters; |
1982 | for (const auto& iter : extractor_->rhs_iters_) { |
1983 | fused_lhs_iters.Set(iter->var, 0); |
1984 | } |
1985 | |
1986 | // Step 3: Fuse LHS iters mapped to the same RHS iter |
1987 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_rhs_vars; |
1988 | for (size_t i = 0; i < extractor_->lhs_iters_.size(); ++i) { |
1989 | const Var& lhs_iter_var = extractor_->lhs_iters_[i]->var; |
1990 | const VarSet& rhs_candidates = lhs_feasible_vars_[lhs_iter_var]; |
1991 | if (rhs_candidates.empty()) { |
1992 | // put unmapped iters at the beginning |
1993 | index_map_tgt.push_back(index_map_src[i]); |
1994 | } else if (rhs_candidates.size() == 1) { |
1995 | Var rhs_var = *rhs_candidates.begin(); |
1996 | PrimExpr fused_lhs = fused_lhs_iters.at(rhs_var); |
1997 | PrimExpr updated_fused_lhs = |
1998 | fused_lhs * lhs_iter_extents.at(lhs_iter_var) + index_map_src[i]; |
1999 | fused_lhs_iters.Set(rhs_var, updated_fused_lhs); |
2000 | used_rhs_vars.insert(rhs_var); |
2001 | } else { |
2002 | // non-unique mapping is not supported |
2003 | return {}; |
2004 | } |
2005 | } |
2006 | for (const auto& iter : extractor_->rhs_iters_) { |
2007 | if (!used_rhs_vars.count(iter->var)) { |
2008 | return {}; |
2009 | } |
2010 | index_map_tgt.push_back(analyzer_->Simplify(fused_lhs_iters[iter->var])); |
2011 | } |
2012 | // At most one mapping is supported. |
2013 | return {IndexMap(index_map_src, index_map_tgt)}; |
2014 | } |
2015 | |
2016 | private: |
2017 | // The extractor that has extracted information for auto tensorization from the workload and the |
2018 | // tensor intrin. |
2019 | const AutoTensorizeComparator* ; |
2020 | // The arithmetic analyzer. |
2021 | arith::Analyzer* analyzer_; |
2022 | /*! \brief Potential mappings on RHS for each variable on LHS */ |
2023 | std::unordered_map<Var, VarSet, ObjectPtrHash, ObjectPtrEqual> lhs_feasible_vars_; |
2024 | }; |
2025 | |
2026 | bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRef& block_sref, |
2027 | const tir::PrimFunc& desc_func, |
2028 | AutoTensorizeComparator* ) { |
2029 | // Step 1. Analyze desc_func, extract its block, loops and loop vars |
2030 | // Step 2. Check if `desc_block` matches `block` |
2031 | // Ignore the scope of buffers when comparing, since we can do cache_read/write |
2032 | const BlockRealize& block = tir::GetBlockRealize(state, block_sref); |
2033 | arith::Analyzer analyzer; |
2034 | auto desc_info = tir::ExtractTensorIntrinDescInfo(&analyzer, desc_func); |
2035 | |
2036 | return extractor->VisitStmt(block->block, desc_info.desc_block->block); |
2037 | } |
2038 | |
2039 | bool CheckAutoTensorizeApplicable(const tir::Schedule& sch, const tir::BlockRV& block_rv, |
2040 | const tir::PrimFunc& desc_func) { |
2041 | AutoTensorizeComparator (sch->state()->mod); |
2042 | return CheckAutoTensorizeApplicable(sch->state(), sch->GetSRef(block_rv), desc_func, &extractor); |
2043 | } |
2044 | |
2045 | Optional<AutoTensorizeMappingInfo> GetAutoTensorizeMappingInfo(const tir::ScheduleState& self, |
2046 | const tir::StmtSRef& block_sref, |
2047 | const tir::PrimFunc& desc_func) { |
2048 | AutoTensorizeComparator (self->mod); |
2049 | if (!CheckAutoTensorizeApplicable(self, block_sref, desc_func, &extractor)) { |
2050 | return NullOpt; |
2051 | } |
2052 | arith::Analyzer analyzer; |
2053 | Array<IndexMap> mappings = AutoTensorizeMappingProposer::ProposeMappings(&extractor, &analyzer); |
2054 | if (mappings.empty()) { |
2055 | return NullOpt; |
2056 | } |
2057 | ObjectPtr<AutoTensorizeMappingInfoNode> ret = make_object<AutoTensorizeMappingInfoNode>(); |
2058 | ret->mappings = std::move(mappings); |
2059 | ret->lhs_buffer_map = std::move(extractor.lhs_buffer_map_); |
2060 | ret->rhs_buffer_indices = std::move(extractor.rhs_buffer_indices_map_); |
2061 | ret->lhs_iters = std::move(extractor.lhs_iters_); |
2062 | ret->rhs_iters = std::move(extractor.rhs_iters_); |
2063 | return AutoTensorizeMappingInfo(ret); |
2064 | } |
2065 | |
2066 | TVM_REGISTER_NODE_TYPE(AutoTensorizeMappingInfoNode); |
2067 | |
2068 | TVM_REGISTER_GLOBAL("tir.schedule.GetAutoTensorizeMappingInfo" ) |
2069 | .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { |
2070 | return GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block), desc_func); |
2071 | }); |
2072 | |
2073 | TVM_REGISTER_GLOBAL("tir.schedule.HasBlock" ).set_body_typed(HasBlock); |
2074 | |
2075 | } // namespace tir |
2076 | } // namespace tvm |
2077 | |