1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_TIR_SCHEDULE_UTILS_H_ |
20 | #define TVM_TIR_SCHEDULE_UTILS_H_ |
21 | |
22 | #include <tvm/arith/analyzer.h> |
23 | #include <tvm/arith/int_set.h> |
24 | #include <tvm/arith/iter_affine_map.h> |
25 | #include <tvm/node/serialization.h> |
26 | #include <tvm/tir/analysis.h> |
27 | #include <tvm/tir/function.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/schedule/instruction.h> |
30 | #include <tvm/tir/schedule/schedule.h> |
31 | #include <tvm/tir/schedule/state.h> |
32 | #include <tvm/tir/schedule/trace.h> |
33 | #include <tvm/tir/stmt_functor.h> |
34 | |
35 | #include <string> |
36 | #include <unordered_map> |
37 | #include <unordered_set> |
38 | #include <utility> |
39 | |
40 | #include "../../arith/pattern_match.h" |
41 | #include "../../node/attr_registry.h" |
42 | #include "../../runtime/thread_storage_scope.h" |
43 | #include "../../support/array.h" |
44 | #include "../../support/nd_int_set.h" |
45 | #include "./analysis.h" |
46 | #include "./error.h" |
47 | #include "./instruction_traits.h" |
48 | #include "./primitive.h" |
49 | #include "./transform.h" |
50 | |
51 | namespace tvm { |
52 | namespace tir { |
53 | |
/*!
 * \brief A helper macro that casts an sref to the statement node it points to and
 *  opens a check on the outcome of the cast.
 *
 * The macro expands to TWO statements: the `StmtAs<Type>()` call, which finishes the
 * initializer the caller wrote in front of the macro, followed by an open `ICHECK(Result)`
 * to which the caller may append an error message with `<<`. Intended usage:
 *
 *   auto node = TVM_SREF_AS_OR_ERR(node, sref, ForNode) << "message";
 *
 * The `SRef` argument is parenthesized in the expansion so that any expression
 * can be passed safely.
 *
 * \param Result The variable that receives the cast result; it is what the ICHECK tests
 * \param SRef The SRef to be cast
 * \param Type The node type to be cast to, e.g. BlockNode or ForNode
 */
#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
  (SRef)->StmtAs<Type>();                      \
  ICHECK(Result)
64 | |
/*!
 * \brief A helper macro to convert an sref to the block it points to.
 *
 * Throws an internal error (via ICHECK) if downcasting fails. The text of the
 * argument as written by the caller is embedded in the error message, together
 * with the type key of the statement actually referenced ("None" if the sref
 * holds no statement).
 *
 * The macro expands to an immediately-invoked lambda capturing by reference.
 * Its local variable uses a deliberately unusual name so that an argument
 * expression mentioning a caller variable such as `result` is not captured by
 * the macro's own local.
 *
 * \param SRef The SRef to be cast
 * \return A `const BlockNode*` pointing to the block
 */
#define TVM_SREF_TO_BLOCK(SRef)                                                          \
  [&]() {                                                                                \
    auto tvm_block_ptr_ =                                                                \
        TVM_SREF_AS_OR_ERR(tvm_block_ptr_, (SRef), ::tvm::tir::BlockNode)                \
        << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \
        << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");                         \
    return tvm_block_ptr_;                                                               \
  }()
80 | |
/*!
 * \brief A helper macro to convert an sref to the for-loop it points to.
 *
 * Throws an internal error (via ICHECK) if downcasting fails. The text of the
 * argument as written by the caller is embedded in the error message, together
 * with the type key of the statement actually referenced ("None" if the sref
 * holds no statement).
 *
 * The macro expands to an immediately-invoked lambda capturing by reference.
 * Its local variable uses a deliberately unusual name so that an argument
 * expression mentioning a caller variable such as `result` is not captured by
 * the macro's own local.
 *
 * \param SRef The SRef to be cast
 * \return A `const ForNode*` pointing to the loop
 */
#define TVM_SREF_TO_FOR(SRef)                                                           \
  [&]() {                                                                               \
    auto tvm_for_ptr_ =                                                                 \
        TVM_SREF_AS_OR_ERR(tvm_for_ptr_, (SRef), ::tvm::tir::ForNode)                   \
        << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \
        << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None");                        \
    return tvm_for_ptr_;                                                                \
  }()
96 | |
/*!
 * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
 *  then open a check on whether the downcasting succeeded.
 *
 * The macro expands to TWO statements: the `as<Type>()` call, which finishes the
 * initializer the caller wrote in front of the macro, followed by an open `ICHECK(Result)`
 * to which the caller may append an error message with `<<`. Intended usage:
 *
 *   auto node = TVM_TYPE_AS_OR_ERR(node, obj, SomeNode) << "message";
 *
 * The `From` argument is parenthesized in the expansion so that any expression
 * can be passed safely.
 *
 * \param Result The variable that receives the cast result; it is what the ICHECK tests
 * \param From The ObjectRef to be downcast
 * \param Type The type to be downcast to
 */
#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
  (From).as<Type>();                           \
  ICHECK(Result)
107 | |
/*!
 * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
 *  throwing an internal error (via ICHECK) if the downcast fails.
 *
 * The error message contains the text of the `From` argument as written by the
 * caller, the expected `Type::_type_key`, and the type key of the object actually
 * held ("None" if `From` is not defined).
 *
 * The macro expands to an immediately-invoked lambda capturing by reference.
 * Its local variable uses a deliberately unusual name so that an argument
 * expression mentioning a caller variable such as `result` is not captured by
 * the macro's own local.
 *
 * \param From The ObjectRef to be downcast
 * \param Type The type to be downcast to
 * \return The result of `From.as<Type>()`, guaranteed to pass the ICHECK
 */
#define TVM_TYPE_AS(From, Type)                                                       \
  [&]() {                                                                             \
    auto tvm_cast_ptr_ = TVM_TYPE_AS_OR_ERR(tvm_cast_ptr_, (From), Type)              \
        << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key   \
        << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None");     \
    return tvm_cast_ptr_;                                                             \
  }()
122 | |
123 | /*! |
124 | * \brief Convert an array of loop StmtSRefs to an array of loops |
125 | * \param loop_srefs The loop StmtSRefs to be converted |
126 | * \return The conversion result loops |
127 | */ |
128 | inline Array<For> LoopSRefs2Loops(const Array<StmtSRef>& loop_srefs) { |
129 | Array<For> loops; |
130 | loops.reserve(loop_srefs.size()); |
131 | for (StmtSRef loop_sref : loop_srefs) { |
132 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
133 | loops.push_back(GetRef<For>(loop)); |
134 | } |
135 | return loops; |
136 | } |
137 | |
138 | /*! |
139 | * \brief Convert an array of block rvs to an array of block StmtSRefs |
140 | * \param sch The schedule used to evaluate the random variables |
141 | * \param block_rvs The random variables to be converted |
142 | * \return The conversion result srefs |
143 | */ |
144 | inline Array<StmtSRef> BlockRVs2StmtSRefs(const Schedule& sch, const Array<BlockRV>& block_rvs) { |
145 | Array<StmtSRef> block_srefs; |
146 | block_srefs.reserve(block_rvs.size()); |
147 | for (const BlockRV& block_rv : block_rvs) { |
148 | block_srefs.push_back(sch->GetSRef(block_rv)); |
149 | } |
150 | return block_srefs; |
151 | } |
152 | |
153 | /******** Storage scope ********/ |
154 | |
155 | /*! |
156 | * \brief Determine if iterators of a storage scope should be relaxed |
157 | * under a specific thread scope |
158 | * \param storage_scope The storage scope that the iterators are on |
159 | * \param thread_scope The thread scope to be relaxed |
160 | * \return A boolean indicating the result |
161 | */ |
162 | inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scope, |
163 | const runtime::ThreadScope& thread_scope) { |
164 | if (storage_scope.rank == runtime::StorageRank::kWarp) { |
165 | // for warp memory, we only relax threadIdx.x |
166 | return thread_scope.rank == 1 && thread_scope.dim_index == 0; |
167 | } |
168 | return static_cast<int>(storage_scope.rank) <= static_cast<int>(thread_scope.rank); |
169 | } |
170 | |
171 | /******** SeqStmt ********/ |
172 | |
173 | /*! |
174 | * \brief Remove a specific Stmt from a SeqStmt. If a SeqStmt contains a BlockRealize, |
175 | * whose block is the Stmt to be removed, then remove that BlockRealize too. |
176 | * \param seq The SeqStmt to be removed from |
177 | * \param to_remove The Stmt to be removed |
178 | * \return The removal result |
179 | */ |
180 | inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { |
181 | ICHECK_GT(seq->size(), 1); |
182 | Array<Stmt> new_stmts; |
183 | new_stmts.reserve(seq->size()); |
184 | for (const Stmt& stmt : seq->seq) { |
185 | if (to_remove.same_as(stmt)) { |
186 | continue; |
187 | } |
188 | if (const auto* realize = stmt.as<BlockRealizeNode>()) { |
189 | if (to_remove.same_as(realize->block)) { |
190 | continue; |
191 | } |
192 | } |
193 | new_stmts.push_back(stmt); |
194 | } |
195 | return SeqStmt::Flatten(new_stmts); |
196 | } |
197 | |
198 | /*! |
199 | * \brief Convert a Stmt to an Array. |
200 | * \param stmt The Stmt to be converted to |
201 | * \return If the Stmt is SeqStmt, then returns the sequence; |
202 | * Otherwise, returns a single-element Array with the Stmt inside. |
203 | */ |
204 | inline Array<Stmt> AsArray(const Stmt& stmt) { |
205 | if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) { |
206 | return seq_stmt->seq; |
207 | } |
208 | return {stmt}; |
209 | } |
210 | |
211 | /*! |
212 | * \brief Checks of a statement is a SeqStmt that contains multiple statements |
213 | * \param stmt The statement to be checked |
214 | * \return A boolean indicating the result |
215 | */ |
216 | inline bool IsSingleStmt(const Stmt& stmt) { |
217 | if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) { |
218 | return seq_stmt->seq.size() == 1; |
219 | } |
220 | return true; |
221 | } |
222 | |
223 | /******** IterVar ********/ |
224 | |
225 | /*! |
226 | * \brief Create a new IterVar for the input For loop, with specified name and type |
227 | * \param loop The loop to be created from |
228 | * \param name The name of the new IterVar |
229 | * \param iter_var_type The type of the new IterVar |
230 | * \return The newly created IterVar |
231 | */ |
232 | inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) { |
233 | return IterVar(Range::FromMinExtent(loop->min, loop->extent), |
234 | Var(std::move(name), loop->loop_var.dtype()), iter_var_type); |
235 | } |
236 | |
237 | /*! |
238 | * \brief Get the thread scope bound to the specific loop |
239 | * \param loop The loop to be inspected |
240 | * \return The thread scope bound to the loop |
241 | */ |
242 | inline runtime::ThreadScope GetThreadScope(const ForNode* loop) { |
243 | if (loop->kind == ForKind::kThreadBinding) { |
244 | return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); |
245 | } |
246 | return runtime::ThreadScope{-1, -1}; |
247 | } |
248 | |
249 | /*! |
250 | * \brief Check if the thread scope is blockIdx |
251 | * \param thread_scope The thread scope to be checked |
252 | * \return True if the thread scope is blockIdx |
253 | */ |
254 | inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) { |
255 | return thread_scope.rank == 0; // The rank of blockIdx is 0 |
256 | } |
257 | |
258 | /*! |
259 | * \brief Check if the thread scope is threadIdx |
260 | * \param thread_scope The thread scope to be checked |
261 | * \return True if the thread scope is threadIdx |
262 | */ |
263 | inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) { |
264 | return thread_scope.rank == 1 && thread_scope.dim_index >= 0; |
265 | } |
266 | |
267 | /**************** Loop extents ****************/ |
268 | |
269 | /*! |
270 | * \brief Get the extents of a loop |
271 | * \param loop The loop to be queried |
272 | * \return The extent of the loop, nullptr if the extent is not constant |
273 | */ |
274 | inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); } |
275 | |
276 | /*! |
277 | * \brief Get the extents of a loop |
278 | * \param loop_sref The loop to be queried |
279 | * \return The extent of the loop, nullptr if the extent is not constant |
280 | */ |
281 | inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) { |
282 | const ForNode* loop = TVM_SREF_TO_FOR(loop_sref); |
283 | return as_const_int(loop->extent); |
284 | } |
285 | |
286 | /*! |
287 | * \brief Check if an expression consists of a single variable, |
288 | * or a variable plus/minus an constant integer shift |
289 | * \param expr The expression to be checked |
290 | * \return The single variable in the expression, or NullOpt if the expression is neither a variable |
291 | * or a constant shift from a variable |
292 | */ |
293 | inline Optional<Var> AnalyzeVarWithShift(const PrimExpr& expr, Optional<IntImm>* constant) { |
294 | if (const auto* var = expr.as<VarNode>()) { |
295 | *constant = NullOpt; |
296 | return GetRef<Var>(var); |
297 | } |
298 | arith::PVar<Var> var; |
299 | arith::PVar<IntImm> shift; |
300 | // match: "var + shift" |
301 | if ((var + shift).Match(expr) || (shift + var).Match(expr)) { |
302 | *constant = shift.Eval(); |
303 | return var.Eval(); |
304 | } |
305 | // match: "var - shift" |
306 | if ((var - shift).Match(expr)) { |
307 | IntImm result = shift.Eval(); |
308 | *constant = IntImm(result->dtype, -result->value); |
309 | return var.Eval(); |
310 | } |
311 | return NullOpt; |
312 | } |
313 | |
314 | /******** Annotation ********/ |
315 | |
316 | /*! |
317 | * \brief Get the annotation on a Block/For |
318 | * \tparam TObjectRef The type of the annotation value |
319 | * \param sref The sref to the block or the for loop |
320 | * \param ann_key The annotation key to be looked up |
321 | * \return NullOpt if not found; otherwise the annotation value |
322 | */ |
323 | template <class TObjectRef, class TStmtNode> |
324 | inline Optional<TObjectRef> GetAnn(const TStmtNode* stmt, const String& ann_key) { |
325 | const Map<String, ObjectRef>* annotations = &stmt->annotations; |
326 | for (const auto& ann : *annotations) { |
327 | if (ann.first == ann_key) { |
328 | return Downcast<TObjectRef>(ann.second); |
329 | } |
330 | } |
331 | return NullOpt; |
332 | } |
333 | |
334 | /*! |
335 | * \brief Get the annotation on a Block/For |
336 | * \tparam TObjectRef The type of the annotation value |
337 | * \param sref The sref to the block or the for loop |
338 | * \param ann_key The annotation key to be looked up |
339 | * \return NullOpt if not found; otherwise the annotation value |
340 | */ |
341 | template <class TObjectRef> |
342 | inline Optional<TObjectRef> GetAnn(const StmtSRef& sref, const String& ann_key) { |
343 | if (const auto* loop = sref->StmtAs<ForNode>()) { |
344 | return GetAnn<TObjectRef, ForNode>(loop, ann_key); |
345 | } else if (const auto* block = sref->StmtAs<BlockNode>()) { |
346 | return GetAnn<TObjectRef, BlockNode>(block, ann_key); |
347 | } else { |
348 | LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey(); |
349 | throw; |
350 | } |
351 | } |
352 | |
353 | /*! |
354 | * \brief Check if a Block/For has a specific pair of annotation key and values |
355 | * \param sref The sref to the block or the for loop |
356 | * \param ann_key The annotation key to be checked |
357 | * \param ann_val The annotation value to be checked |
358 | * \return Whether a Block/For has a specific pair of annotation key and values |
359 | */ |
360 | inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) { |
361 | Optional<String> result = GetAnn<String>(sref, ann_key); |
362 | return result.defined() && result.value() == ann_val; |
363 | } |
364 | |
365 | /*! |
366 | * \brief Check if a Block/For has a specific pair of annotation key and values |
367 | * \param sref The sref to the block or the for loop |
368 | * \param ann_key The annotation key to be checked |
369 | * \param ann_val The boolean annotation value to be checked |
370 | * \return Whether a Block/For has a specific pair of annotation key and values |
371 | */ |
372 | inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) { |
373 | Optional<Bool> result = GetAnn<Bool>(sref, ann_key); |
374 | return result.defined() && result.value() == ann_val; |
375 | } |
376 | |
377 | /********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/ |
378 | |
379 | /*! |
380 | * \brief Reorder the reduction loops to innermost positions if needed. |
381 | * \param sch The schedule |
382 | * \param block_rv The block where to apply the reorder |
383 | * \param fused_reduce_loop The fusion-generated loop to return. |
384 | * \param num_spatial_loops The number of spatial loops to return. |
385 | * \note Before invoking this helper function, make sure that the block has only spatial and |
386 | * reduction loop axes. |
387 | */ |
388 | inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv, |
389 | tir::LoopRV* fused_reduce_loop, |
390 | size_t* num_spatial_loops) { |
391 | Array<tir::LoopRV> loops = sch->GetLoops(block_rv); |
392 | Array<tir::StmtSRef> loop_srefs; |
393 | for (const tir::LoopRV& loop_rv : loops) { |
394 | loop_srefs.push_back(sch->GetSRef(loop_rv)); |
395 | } |
396 | |
397 | Array<tir::LoopRV> new_order; |
398 | // Step 1. Add spatial loops. |
399 | *num_spatial_loops = 0; |
400 | for (size_t i = 0; i < loops.size(); ++i) { |
401 | if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) { |
402 | new_order.push_back(loops[i]); |
403 | (*num_spatial_loops)++; |
404 | } |
405 | } |
406 | // Step 2. Add reduction loops. |
407 | Array<tir::LoopRV> reduction_loops; |
408 | for (size_t i = 0; i < loops.size(); ++i) { |
409 | if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) { |
410 | new_order.push_back(loops[i]); |
411 | reduction_loops.push_back(loops[i]); |
412 | } |
413 | } |
414 | // Step 3. Apply reordering if new_order differs from the original order. |
415 | ICHECK_EQ(new_order.size(), loops.size()); |
416 | for (size_t i = 0; i < loops.size(); ++i) { |
417 | if (!new_order[i].same_as(loops[i])) { |
418 | sch->Reorder(new_order); |
419 | break; |
420 | } |
421 | } |
422 | // Step 4. Fuse all the reduction loops if there are multiple reduction loops. |
423 | CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop" ; |
424 | if (reduction_loops.size() > 1) { |
425 | *fused_reduce_loop = sch->Fuse(reduction_loops); |
426 | } else { |
427 | *fused_reduce_loop = reduction_loops[0]; |
428 | } |
429 | } |
430 | |
431 | /******** Helper functions for enum conversion ********/ |
432 | |
433 | /*! |
434 | * \brief Convert BufferIndexType to String |
435 | * \param buffer_index_type The BufferIndexType value to convert |
436 | * \return The string representation of BufferIndexType |
437 | */ |
438 | inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { |
439 | if (buffer_index_type == BufferIndexType::kRead) { |
440 | return "read" ; |
441 | } else { |
442 | ICHECK(buffer_index_type == BufferIndexType::kWrite); |
443 | return "write" ; |
444 | } |
445 | } |
446 | |
447 | /******** Utilities for retrieving information about blocks ********/ |
448 | |
449 | /*! \brief Returns the names of the blocks in the provided module. */ |
450 | inline std::unordered_set<std::string> GetBlockNames(const IRModule& mod) { |
451 | struct BlockNameCollector : public tir::StmtVisitor { |
452 | void VisitStmt_(const tir::BlockNode* block) override { |
453 | block_names.insert(block->name_hint); |
454 | StmtVisitor::VisitStmt(block->body); |
455 | } |
456 | std::unordered_set<std::string> block_names; |
457 | }; |
458 | |
459 | if (auto prim_func = tir::FindEntryFunc(mod, nullptr)) { |
460 | BlockNameCollector collector; |
461 | collector(prim_func->body); |
462 | return collector.block_names; |
463 | } |
464 | return {}; |
465 | } |
466 | |
467 | /*! \brief Query if the given block name exists in the module associated with the schedule */ |
468 | inline bool HasBlock(const Schedule& sch, const std::string& block_name) { |
469 | auto block_names = GetBlockNames(sch->mod()); |
470 | return block_names.count(block_name); |
471 | } |
472 | |
/******** Utilities for trace application ********/
474 | |
/*!
 * \brief Translate the input objects using the provided substitution map.
 * \param inputs The input objects.
 * \param rv_map The substitution map for variables, keyed and valued by raw object pointers.
 * \return The transformed objects.
 * \note Only declared in this header; the definition is provided elsewhere.
 */
Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
                                   const std::unordered_map<const Object*, const Object*>& rv_map);
483 | |
/*!
 * \brief Update the variable substitution map according to the new outputs.
 * \param old_outputs The previous outputs of a schedule instruction.
 * \param new_outputs The new outputs of the same schedule instruction.
 * \param rv_map The substitution map for variables, passed by pointer so that it can be
 *  updated (presumably in place -- confirm against the definition).
 * \note Only declared in this header; the definition is provided elsewhere.
 */
void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs,
                           std::unordered_map<const Object*, const Object*>* rv_map);
492 | |
493 | } // namespace tir |
494 | } // namespace tvm |
495 | |
496 | #endif // TVM_TIR_SCHEDULE_UTILS_H_ |
497 | |