1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_UTILS_H_
20#define TVM_TIR_SCHEDULE_UTILS_H_
21
22#include <tvm/arith/analyzer.h>
23#include <tvm/arith/int_set.h>
24#include <tvm/arith/iter_affine_map.h>
25#include <tvm/node/serialization.h>
26#include <tvm/tir/analysis.h>
27#include <tvm/tir/function.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/schedule/instruction.h>
30#include <tvm/tir/schedule/schedule.h>
31#include <tvm/tir/schedule/state.h>
32#include <tvm/tir/schedule/trace.h>
33#include <tvm/tir/stmt_functor.h>
34
35#include <string>
36#include <unordered_map>
37#include <unordered_set>
38#include <utility>
39
40#include "../../arith/pattern_match.h"
41#include "../../node/attr_registry.h"
42#include "../../runtime/thread_storage_scope.h"
43#include "../../support/array.h"
44#include "../../support/nd_int_set.h"
45#include "./analysis.h"
46#include "./error.h"
47#include "./instruction_traits.h"
48#include "./primitive.h"
49#include "./transform.h"
50
51namespace tvm {
52namespace tir {
53
/*!
 * \brief A helper macro to convert an sref to the statement it points to,
 * then check if the downcasting succeeded.
 *
 * The expansion is two pieces: the `StmtAs<Type>()` call terminated by `;`, followed by an
 * unterminated `ICHECK(Result)`. It is therefore meant to be used as the right-hand side of
 * an initialization of `Result`, with the caller appending an optional `<< message` and the
 * final `;`.
 *
 * \param Result The result variable, used for checking
 * \param SRef The SRef to be cast
 * \param Type The type to be cast to, can be Block or For
 */
#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
  SRef->StmtAs<Type>(); \
  ICHECK(Result)
64
/*!
 * \brief A helper macro to convert an sref to the block it points to.
 *
 * Throws an internal error if downcasting fails. The variable name
 * in the parent scope is used for the error message.
 *
 * The macro expands to an immediately-invoked lambda that captures by reference and
 * evaluates to the result of `StmtAs<BlockNode>()`. The failure message reports the
 * type key of the statement actually held, or "None" when the sref holds no statement.
 *
 * \param SRef The SRef to be cast
 */
#define TVM_SREF_TO_BLOCK(SRef) \
  [&]() { \
    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode) \
        << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \
        << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \
    return result; \
  }()
80
/*!
 * \brief A helper macro to convert an sref to the for-loop it points to.
 *
 * Throws an internal error if downcasting fails. The variable name
 * in the parent scope is used for the error message.
 *
 * The macro expands to an immediately-invoked lambda that captures by reference and
 * evaluates to the result of `StmtAs<ForNode>()`. The failure message reports the
 * type key of the statement actually held, or "None" when the sref holds no statement.
 *
 * \param SRef The SRef to be cast
 */
#define TVM_SREF_TO_FOR(SRef) \
  [&]() { \
    auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \
        << "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \
        << ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \
    return result; \
  }()
96
/*!
 * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
 * then check if the downcasting succeeded.
 *
 * The expansion is the `as<Type>()` call terminated by `;`, followed by an unterminated
 * `ICHECK(Result)`, so the caller supplies the optional `<< message` and the final `;`.
 *
 * \param Result The result variable, used for checking
 * \param From The ObjectRef to be downcast
 * \param Type The type to be downcast to
 */
#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
  From.as<Type>(); \
  ICHECK(Result)
107
/*!
 * \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
 * throwing an internal error if downcast fails.
 *
 * The macro expands to an immediately-invoked lambda that captures by reference and
 * evaluates to the result of `as<Type>()`. The failure message names the expression passed
 * as `From`, the expected `Type::_type_key`, and the type key actually held (or "None"
 * when `From` is undefined).
 *
 * \param From The ObjectRef to be downcast
 * \param Type The type to be downcast to
 */
#define TVM_TYPE_AS(From, Type) \
  [&]() { \
    auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \
        << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
        << "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None"); \
    return result; \
  }()
122
123/*!
124 * \brief Convert an array of loop StmtSRefs to an array of loops
125 * \param loop_srefs The loop StmtSRefs to be converted
126 * \return The conversion result loops
127 */
128inline Array<For> LoopSRefs2Loops(const Array<StmtSRef>& loop_srefs) {
129 Array<For> loops;
130 loops.reserve(loop_srefs.size());
131 for (StmtSRef loop_sref : loop_srefs) {
132 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
133 loops.push_back(GetRef<For>(loop));
134 }
135 return loops;
136}
137
138/*!
139 * \brief Convert an array of block rvs to an array of block StmtSRefs
140 * \param sch The schedule used to evaluate the random variables
141 * \param block_rvs The random variables to be converted
142 * \return The conversion result srefs
143 */
144inline Array<StmtSRef> BlockRVs2StmtSRefs(const Schedule& sch, const Array<BlockRV>& block_rvs) {
145 Array<StmtSRef> block_srefs;
146 block_srefs.reserve(block_rvs.size());
147 for (const BlockRV& block_rv : block_rvs) {
148 block_srefs.push_back(sch->GetSRef(block_rv));
149 }
150 return block_srefs;
151}
152
153/******** Storage scope ********/
154
155/*!
156 * \brief Determine if iterators of a storage scope should be relaxed
157 * under a specific thread scope
158 * \param storage_scope The storage scope that the iterators are on
159 * \param thread_scope The thread scope to be relaxed
160 * \return A boolean indicating the result
161 */
162inline bool CanRelaxStorageUnderThread(const runtime::StorageScope& storage_scope,
163 const runtime::ThreadScope& thread_scope) {
164 if (storage_scope.rank == runtime::StorageRank::kWarp) {
165 // for warp memory, we only relax threadIdx.x
166 return thread_scope.rank == 1 && thread_scope.dim_index == 0;
167 }
168 return static_cast<int>(storage_scope.rank) <= static_cast<int>(thread_scope.rank);
169}
170
171/******** SeqStmt ********/
172
173/*!
174 * \brief Remove a specific Stmt from a SeqStmt. If a SeqStmt contains a BlockRealize,
175 * whose block is the Stmt to be removed, then remove that BlockRealize too.
176 * \param seq The SeqStmt to be removed from
177 * \param to_remove The Stmt to be removed
178 * \return The removal result
179 */
180inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) {
181 ICHECK_GT(seq->size(), 1);
182 Array<Stmt> new_stmts;
183 new_stmts.reserve(seq->size());
184 for (const Stmt& stmt : seq->seq) {
185 if (to_remove.same_as(stmt)) {
186 continue;
187 }
188 if (const auto* realize = stmt.as<BlockRealizeNode>()) {
189 if (to_remove.same_as(realize->block)) {
190 continue;
191 }
192 }
193 new_stmts.push_back(stmt);
194 }
195 return SeqStmt::Flatten(new_stmts);
196}
197
198/*!
199 * \brief Convert a Stmt to an Array.
200 * \param stmt The Stmt to be converted to
201 * \return If the Stmt is SeqStmt, then returns the sequence;
202 * Otherwise, returns a single-element Array with the Stmt inside.
203 */
204inline Array<Stmt> AsArray(const Stmt& stmt) {
205 if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
206 return seq_stmt->seq;
207 }
208 return {stmt};
209}
210
211/*!
212 * \brief Checks of a statement is a SeqStmt that contains multiple statements
213 * \param stmt The statement to be checked
214 * \return A boolean indicating the result
215 */
216inline bool IsSingleStmt(const Stmt& stmt) {
217 if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
218 return seq_stmt->seq.size() == 1;
219 }
220 return true;
221}
222
223/******** IterVar ********/
224
225/*!
226 * \brief Create a new IterVar for the input For loop, with specified name and type
227 * \param loop The loop to be created from
228 * \param name The name of the new IterVar
229 * \param iter_var_type The type of the new IterVar
230 * \return The newly created IterVar
231 */
232inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) {
233 return IterVar(Range::FromMinExtent(loop->min, loop->extent),
234 Var(std::move(name), loop->loop_var.dtype()), iter_var_type);
235}
236
237/*!
238 * \brief Get the thread scope bound to the specific loop
239 * \param loop The loop to be inspected
240 * \return The thread scope bound to the loop
241 */
242inline runtime::ThreadScope GetThreadScope(const ForNode* loop) {
243 if (loop->kind == ForKind::kThreadBinding) {
244 return runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag);
245 }
246 return runtime::ThreadScope{-1, -1};
247}
248
249/*!
250 * \brief Check if the thread scope is blockIdx
251 * \param thread_scope The thread scope to be checked
252 * \return True if the thread scope is blockIdx
253 */
254inline bool IsBlockIdx(const runtime::ThreadScope& thread_scope) {
255 return thread_scope.rank == 0; // The rank of blockIdx is 0
256}
257
258/*!
259 * \brief Check if the thread scope is threadIdx
260 * \param thread_scope The thread scope to be checked
261 * \return True if the thread scope is threadIdx
262 */
263inline bool IsThreadIdx(const runtime::ThreadScope& thread_scope) {
264 return thread_scope.rank == 1 && thread_scope.dim_index >= 0;
265}
266
267/**************** Loop extents ****************/
268
269/*!
270 * \brief Get the extents of a loop
271 * \param loop The loop to be queried
272 * \return The extent of the loop, nullptr if the extent is not constant
273 */
274inline const int64_t* GetLoopIntExtent(const ForNode* loop) { return as_const_int(loop->extent); }
275
276/*!
277 * \brief Get the extents of a loop
278 * \param loop_sref The loop to be queried
279 * \return The extent of the loop, nullptr if the extent is not constant
280 */
281inline const int64_t* GetLoopIntExtent(const StmtSRef& loop_sref) {
282 const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
283 return as_const_int(loop->extent);
284}
285
286/*!
287 * \brief Check if an expression consists of a single variable,
288 * or a variable plus/minus an constant integer shift
289 * \param expr The expression to be checked
290 * \return The single variable in the expression, or NullOpt if the expression is neither a variable
291 * or a constant shift from a variable
292 */
293inline Optional<Var> AnalyzeVarWithShift(const PrimExpr& expr, Optional<IntImm>* constant) {
294 if (const auto* var = expr.as<VarNode>()) {
295 *constant = NullOpt;
296 return GetRef<Var>(var);
297 }
298 arith::PVar<Var> var;
299 arith::PVar<IntImm> shift;
300 // match: "var + shift"
301 if ((var + shift).Match(expr) || (shift + var).Match(expr)) {
302 *constant = shift.Eval();
303 return var.Eval();
304 }
305 // match: "var - shift"
306 if ((var - shift).Match(expr)) {
307 IntImm result = shift.Eval();
308 *constant = IntImm(result->dtype, -result->value);
309 return var.Eval();
310 }
311 return NullOpt;
312}
313
314/******** Annotation ********/
315
316/*!
317 * \brief Get the annotation on a Block/For
318 * \tparam TObjectRef The type of the annotation value
319 * \param sref The sref to the block or the for loop
320 * \param ann_key The annotation key to be looked up
321 * \return NullOpt if not found; otherwise the annotation value
322 */
323template <class TObjectRef, class TStmtNode>
324inline Optional<TObjectRef> GetAnn(const TStmtNode* stmt, const String& ann_key) {
325 const Map<String, ObjectRef>* annotations = &stmt->annotations;
326 for (const auto& ann : *annotations) {
327 if (ann.first == ann_key) {
328 return Downcast<TObjectRef>(ann.second);
329 }
330 }
331 return NullOpt;
332}
333
334/*!
335 * \brief Get the annotation on a Block/For
336 * \tparam TObjectRef The type of the annotation value
337 * \param sref The sref to the block or the for loop
338 * \param ann_key The annotation key to be looked up
339 * \return NullOpt if not found; otherwise the annotation value
340 */
341template <class TObjectRef>
342inline Optional<TObjectRef> GetAnn(const StmtSRef& sref, const String& ann_key) {
343 if (const auto* loop = sref->StmtAs<ForNode>()) {
344 return GetAnn<TObjectRef, ForNode>(loop, ann_key);
345 } else if (const auto* block = sref->StmtAs<BlockNode>()) {
346 return GetAnn<TObjectRef, BlockNode>(block, ann_key);
347 } else {
348 LOG(FATAL) << "TypeError: Unknown type of sref: " << sref->stmt->GetTypeKey();
349 throw;
350 }
351}
352
353/*!
354 * \brief Check if a Block/For has a specific pair of annotation key and values
355 * \param sref The sref to the block or the for loop
356 * \param ann_key The annotation key to be checked
357 * \param ann_val The annotation value to be checked
358 * \return Whether a Block/For has a specific pair of annotation key and values
359 */
360inline bool HasAnn(const StmtSRef& sref, const String& ann_key, const String& ann_val) {
361 Optional<String> result = GetAnn<String>(sref, ann_key);
362 return result.defined() && result.value() == ann_val;
363}
364
365/*!
366 * \brief Check if a Block/For has a specific pair of annotation key and values
367 * \param sref The sref to the block or the for loop
368 * \param ann_key The annotation key to be checked
369 * \param ann_val The boolean annotation value to be checked
370 * \return Whether a Block/For has a specific pair of annotation key and values
371 */
372inline bool HasAnn(const StmtSRef& sref, const String& ann_key, bool ann_val) {
373 Optional<Bool> result = GetAnn<Bool>(sref, ann_key);
374 return result.defined() && result.value() == ann_val;
375}
376
377/********** Helper Functions for RuleAddRFactor and RuleCrossThreadReduction **********/
378
379/*!
380 * \brief Reorder the reduction loops to innermost positions if needed.
381 * \param sch The schedule
382 * \param block_rv The block where to apply the reorder
383 * \param fused_reduce_loop The fusion-generated loop to return.
384 * \param num_spatial_loops The number of spatial loops to return.
385 * \note Before invoking this helper function, make sure that the block has only spatial and
386 * reduction loop axes.
387 */
388inline void ReorderAndFuseReductionLoops(const tir::Schedule& sch, const tir::BlockRV& block_rv,
389 tir::LoopRV* fused_reduce_loop,
390 size_t* num_spatial_loops) {
391 Array<tir::LoopRV> loops = sch->GetLoops(block_rv);
392 Array<tir::StmtSRef> loop_srefs;
393 for (const tir::LoopRV& loop_rv : loops) {
394 loop_srefs.push_back(sch->GetSRef(loop_rv));
395 }
396
397 Array<tir::LoopRV> new_order;
398 // Step 1. Add spatial loops.
399 *num_spatial_loops = 0;
400 for (size_t i = 0; i < loops.size(); ++i) {
401 if (GetLoopIterType(loop_srefs[i]) == tir::kDataPar) {
402 new_order.push_back(loops[i]);
403 (*num_spatial_loops)++;
404 }
405 }
406 // Step 2. Add reduction loops.
407 Array<tir::LoopRV> reduction_loops;
408 for (size_t i = 0; i < loops.size(); ++i) {
409 if (GetLoopIterType(loop_srefs[i]) == tir::kCommReduce) {
410 new_order.push_back(loops[i]);
411 reduction_loops.push_back(loops[i]);
412 }
413 }
414 // Step 3. Apply reordering if new_order differs from the original order.
415 ICHECK_EQ(new_order.size(), loops.size());
416 for (size_t i = 0; i < loops.size(); ++i) {
417 if (!new_order[i].same_as(loops[i])) {
418 sch->Reorder(new_order);
419 break;
420 }
421 }
422 // Step 4. Fuse all the reduction loops if there are multiple reduction loops.
423 CHECK(!reduction_loops.empty()) << "ValueError: There should be at least one reduction loop";
424 if (reduction_loops.size() > 1) {
425 *fused_reduce_loop = sch->Fuse(reduction_loops);
426 } else {
427 *fused_reduce_loop = reduction_loops[0];
428 }
429}
430
431/******** Helper functions for enum conversion ********/
432
433/*!
434 * \brief Convert BufferIndexType to String
435 * \param buffer_index_type The BufferIndexType value to convert
436 * \return The string representation of BufferIndexType
437 */
438inline String BufferIndexType2Str(BufferIndexType buffer_index_type) {
439 if (buffer_index_type == BufferIndexType::kRead) {
440 return "read";
441 } else {
442 ICHECK(buffer_index_type == BufferIndexType::kWrite);
443 return "write";
444 }
445}
446
447/******** Utilities for retrieving information about blocks ********/
448
449/*! \brief Returns the names of the blocks in the provided module. */
450inline std::unordered_set<std::string> GetBlockNames(const IRModule& mod) {
451 struct BlockNameCollector : public tir::StmtVisitor {
452 void VisitStmt_(const tir::BlockNode* block) override {
453 block_names.insert(block->name_hint);
454 StmtVisitor::VisitStmt(block->body);
455 }
456 std::unordered_set<std::string> block_names;
457 };
458
459 if (auto prim_func = tir::FindEntryFunc(mod, nullptr)) {
460 BlockNameCollector collector;
461 collector(prim_func->body);
462 return collector.block_names;
463 }
464 return {};
465}
466
467/*! \brief Query if the given block name exists in the module associated with the schedule */
468inline bool HasBlock(const Schedule& sch, const std::string& block_name) {
469 auto block_names = GetBlockNames(sch->mod());
470 return block_names.count(block_name);
471}
472
/******** Utilities for trace application ********/
474
/*!
 * \brief Translate the input objects using the provided substitution map.
 * \param inputs The input objects.
 * \param rv_map The substitution map for variables, from object pointer to object pointer.
 * \return The transformed objects.
 * \note Only declared in this header; the definition is not part of this file.
 */
Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
                                   const std::unordered_map<const Object*, const Object*>& rv_map);
483
/*!
 * \brief Update the variable substitution map according to the new outputs.
 * \param old_outputs The previous outputs of a schedule instruction.
 * \param new_outputs The new outputs of the same schedule instruction.
 * \param rv_map The substitution map for variables; passed by pointer so it can be updated.
 * \note Only declared in this header; the definition is not part of this file.
 */
void TranslateAddOutputRVs(const Array<ObjectRef>& old_outputs, const Array<ObjectRef>& new_outputs,
                           std::unordered_map<const Object*, const Object*>* rv_map);
492
493} // namespace tir
494} // namespace tvm
495
496#endif // TVM_TIR_SCHEDULE_UTILS_H_
497