/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | #include "../utils.h" |
20 | |
21 | namespace tvm { |
22 | namespace meta_schedule { |
23 | |
24 | class CrossThreadReductionNode : public ScheduleRuleNode { |
25 | public: |
26 | // Inherited from ScheduleRuleNode |
27 | void InitializeWithTuneContext(const TuneContext& context) final { |
28 | ICHECK(context->target.defined()); |
29 | Target target = context->target.value(); |
30 | |
31 | Optional<Integer> opt_max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block" ); |
32 | Optional<Integer> opt_warp_size = target->GetAttr<Integer>("thread_warp_size" ); |
33 | |
34 | if (!opt_max_threads_per_block.defined()) { |
35 | TVM_PY_LOG(WARNING, context->logger) |
36 | << "Target does not have attribute \"max_threads_per_block\", therefore the " |
37 | "rule CrossThreadReduction will not be applied" ; |
38 | } |
39 | if (!opt_warp_size.defined()) { |
40 | TVM_PY_LOG(WARNING, context->logger) |
41 | << "Target does not have attribute \"thread_warp_size\", therefore the rule " |
42 | "CrossThreadReduction will not be applied" ; |
43 | } |
44 | max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value; |
45 | warp_size = opt_warp_size.value_or(Integer(-1))->value; |
46 | } |
47 | |
48 | // Inherited from ScheduleRuleNode |
49 | Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final { |
50 | // Step 0. Check the conditions of this rule. |
51 | if (max_threads_per_block == -1 || warp_size == -1) { |
52 | return {sch}; |
53 | } |
54 | const tir::StmtSRef& block_sref = sch->GetSRef(block_rv); |
55 | if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block, |
56 | warp_size)) { |
57 | return {sch}; |
58 | } |
59 | |
60 | // Step 1. Make a copy of the original schedule. The new copy is used for scheduling. |
61 | tir::Schedule tmp_sch = sch->Copy(); |
62 | tmp_sch->Seed(sch->ForkSeed()); |
63 | |
64 | // Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the |
65 | // block to its consumers. We want to fuse as much as possible because it results in |
66 | // significantly faster schedule. |
67 | // `target_loop` is the loop position where the input block will be computed at. |
68 | // `target_block` is the consumer block that we want to compute-at the input block to. |
69 | // `tgt_block_innermost_loop` is the innermost loop outside the target block. |
70 | |
71 | auto [fusible, target_loop, target_block, tgt_block_innermost_loop] = |
72 | GetComputeTargetLoopAndBlock(tmp_sch, block_rv); |
73 | |
74 | // Step 3. Try block fusion. |
75 | int n_candidate = static_cast<int>(thread_extents.size()); |
76 | Array<FloatImm> probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate)); |
77 | tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); |
78 | if (fusible) { |
79 | ICHECK(target_block.defined()); |
80 | ICHECK(target_loop.defined()); |
81 | |
82 | // Step 3.1. |
83 | // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should first |
84 | // bound the innermost outer loop of `target_block` to threadIdx. Possibly we need to split |
85 | // the loop before binding. |
86 | // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor. |
87 | if (!InThreadScope(tmp_sch, target_block)) { |
88 | const Array<tir::LoopRV>& split_res = |
89 | tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent}); |
90 | tmp_sch->Bind(split_res[1], "threadIdx.x" ); |
91 | if (tgt_block_innermost_loop.same_as(target_loop)) { |
92 | target_loop = split_res[0]; |
93 | } |
94 | } else { |
95 | thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value()); |
96 | } |
97 | // Step 3.2. Do the compute-at. |
98 | tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true); |
99 | // Step 3.3. Set the storage scope of the output buffer to shared memory. |
100 | tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared" ); |
101 | } |
102 | |
103 | // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering, |
104 | // fuse all the reduction loops. |
105 | size_t num_spatial_loops; |
106 | tir::LoopRV fused_reduce_loop; |
107 | ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops); |
108 | // Step 5. Split the fused reduction loop and bind the inner one to threadIdx. |
109 | const Array<tir::LoopRV>& split_res = |
110 | tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent}); |
111 | tmp_sch->Bind(split_res[1], "threadIdx.x" ); |
112 | |
113 | return {tmp_sch, sch}; |
114 | } |
115 | |
116 | // Inherited from ScheduleRuleNode |
117 | ScheduleRule Clone() const final { |
118 | ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>(*this); |
119 | return ScheduleRule(n); |
120 | } |
121 | |
122 | private: |
123 | /*! |
124 | * \brief Check whether the input block is in thread scope, i.e., some of its outer loop is |
125 | * bound to threadIdx. |
126 | * \param sch The TensorIR schedule |
127 | * \param block The block to be checked |
128 | * \return A boolean indicating whether the block is in thread scope. |
129 | */ |
130 | bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) { |
131 | const Array<tir::LoopRV>& axes = sch->GetLoops(block); |
132 | for (const tir::LoopRV& loop_rv : axes) { |
133 | const tir::For& loop = sch->Get(loop_rv); |
134 | runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get()); |
135 | if (tir::IsThreadIdx(thread_scope)) { |
136 | return true; |
137 | } |
138 | } |
139 | return false; |
140 | } |
141 | |
142 | /*! |
143 | * \brief Get the ExprRV which used to define the extent of a given loop. |
144 | * \param trace The trace of the schedule, where the extent is to be found |
145 | * \param loop The loop whose extent is to be found |
146 | * \param extent The finding result |
147 | * \return Whether the find is successful. |
148 | */ |
149 | bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop, |
150 | tir::ExprRV* extent) { |
151 | for (const tir::Instruction& inst : trace->insts) { |
152 | if (inst->kind->name == "Split" ) { |
153 | int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin(); |
154 | CHECK(inst->inputs[1 + i].defined()) |
155 | << "ValueError: Extracting an extent which needs inference is not supported so far" ; |
156 | *extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]); |
157 | return true; |
158 | } |
159 | } |
160 | return false; |
161 | } |
162 | |
163 | /*! |
164 | * \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace. |
165 | * \param trace The trace of the schedule, where the extent is to be found |
166 | * \return The extent of "threadIdx.x" in the input schedule |
167 | */ |
168 | tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) { |
169 | tir::ExprRV extent{nullptr}; |
170 | for (const tir::Instruction& inst : trace->insts) { |
171 | if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x" ) { |
172 | if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) { |
173 | return extent; |
174 | } |
175 | } |
176 | } |
177 | CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"" ; |
178 | throw; |
179 | } |
180 | |
181 | /*! |
182 | * \brief Get the compute-at target loop and the first block under the target loop. |
183 | * \param sch The TensorIR schedule |
184 | * \param block_rv The block whose compute-at target loop is queried |
185 | * \return A tuple consisting of |
186 | * 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible); |
187 | * 2. the compute-at target loop when fusible, or a null loop random variable; |
188 | * 3. the first block under the target loop when fusible, or a null block random variable; |
189 | * 4. the innermost loop outside the target block when fusible, or a null block random variable. |
190 | */ |
191 | std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock( |
192 | const tir::Schedule& sch, const tir::BlockRV& block_rv) { |
193 | // Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing |
194 | // a tuple reduction, fusion is temporarily not supported. |
195 | if (sch->Get(block_rv)->writes.size() != 1) { |
196 | return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, |
197 | tir::LoopRV{nullptr}); |
198 | } |
199 | |
200 | // Step 1. Get all the consumers of the input block. |
201 | Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv); |
202 | |
203 | // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is |
204 | // not fusible. |
205 | if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) { |
206 | return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, |
207 | tir::LoopRV{nullptr}); |
208 | } |
209 | |
210 | // Step 3. Calculate the lowest common ancestor of all the consumers. |
211 | // - If the lowest common ancestor is a block: |
212 | // - if there is only one consumer, the target block is that consumer; |
213 | // - if there are multiple consumers, they must not share a common loop, and the case is not |
214 | // fusible; |
215 | // - If the lowest common ancestor is a loop, the target block is also the first consumer. |
216 | const tir::StmtSRef& lca_sref = |
217 | tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers)); |
218 | if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) { |
219 | return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, |
220 | tir::LoopRV{nullptr}); |
221 | } |
222 | |
223 | // Step 4. Get the outer loops of the target block, and get the compute-at position index. |
224 | Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]); |
225 | int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref); |
226 | |
227 | // Step 5. A negative position index means not fusible, and vice-versa. |
228 | if (pos < 0) { |
229 | return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, |
230 | tir::LoopRV{nullptr}); |
231 | } else { |
232 | return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back()); |
233 | } |
234 | } |
235 | |
236 | /*! |
237 | * \brief Get the compute-at position index of the input block, according to |
238 | * 1. the loops outside the input block; |
239 | * 2. the loops outside the target block; |
240 | * 3. the lowest common ancestor of all the consumers of the input block. |
241 | * \param sch The TensorIR schedule |
242 | * \param block_loops The loops outside the input block |
243 | * \param tgt_block_loops The loops outside the target block |
244 | * \param lca_sref The lowest common ancestor of all the consumers of the input block |
245 | * \return The compute-at position index of the input block |
246 | */ |
247 | int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops, |
248 | const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) { |
249 | int n_block_loop = static_cast<int>(block_loops.size()); |
250 | int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size()); |
251 | |
252 | for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) { |
253 | if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) { |
254 | return i - 1; |
255 | } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) { |
256 | // If the lowest common ancestor is a loop, the compute location of the input block should |
257 | // not be deeper than the LCA loop. |
258 | return i; |
259 | } |
260 | } |
261 | return std::min(n_block_loop, n_tgt_block_loop) - 1; |
262 | } |
263 | |
264 | public: |
265 | /*! \brief The maximum number of threads allowed in a thread block */ |
266 | int max_threads_per_block; |
267 | /*! \brief The number of threads per warp */ |
268 | int warp_size; |
269 | /*! \brief Candidates of thread axis extent (values are required to be positive). */ |
270 | Array<Integer> thread_extents; |
271 | |
272 | void VisitAttrs(tvm::AttrVisitor* v) { |
273 | v->Visit("max_threads_per_block" , &max_threads_per_block); |
274 | v->Visit("warp_size" , &warp_size); |
275 | v->Visit("thread_extents" , &thread_extents); |
276 | } |
277 | |
278 | static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction" ; |
279 | TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); |
280 | }; |
281 | |
282 | ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) { |
283 | for (const Integer& extent : thread_extents) { |
284 | CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive" ; |
285 | } |
286 | ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>(); |
287 | n->thread_extents = std::move(thread_extents); |
288 | return ScheduleRule(n); |
289 | } |
290 | |
// Register the node type for reflection/serialization, and expose the rule
// constructor to the Python frontend under "meta_schedule.ScheduleRuleCrossThreadReduction".
TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
    .set_body_typed(ScheduleRule::CrossThreadReduction);
294 | |
295 | } // namespace meta_schedule |
296 | } // namespace tvm |
297 | |