/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include "../utils.h"

namespace tvm {
namespace meta_schedule {

class CrossThreadReductionNode : public ScheduleRuleNode {
 public:
  // Inherited from ScheduleRuleNode
  void InitializeWithTuneContext(const TuneContext& context) final {
    ICHECK(context->target.defined());
    Target target = context->target.value();

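    // For orientation only (illustrative values, not guaranteed by any target): a typical CUDA
    // target reports max_threads_per_block = 1024 and thread_warp_size = 32. Both attributes come
    // from the target definition and may be absent, which is handled below.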
    Optional<Integer> opt_max_threads_per_block = target->GetAttr<Integer>("max_threads_per_block");
    Optional<Integer> opt_warp_size = target->GetAttr<Integer>("thread_warp_size");

    if (!opt_max_threads_per_block.defined()) {
      TVM_PY_LOG(WARNING, context->logger)
          << "Target does not have attribute \"max_threads_per_block\", therefore the "
             "rule CrossThreadReduction will not be applied";
    }
    if (!opt_warp_size.defined()) {
      TVM_PY_LOG(WARNING, context->logger)
          << "Target does not have attribute \"thread_warp_size\", therefore the rule "
             "CrossThreadReduction will not be applied";
    }
    max_threads_per_block = opt_max_threads_per_block.value_or(Integer(-1))->value;
    warp_size = opt_warp_size.value_or(Integer(-1))->value;
  }

  // Inherited from ScheduleRuleNode
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    // Step 0. Check the conditions of this rule.
    if (max_threads_per_block == -1 || warp_size == -1) {
      return {sch};
    }
    const tir::StmtSRef& block_sref = sch->GetSRef(block_rv);
    if (!NeedsRFactorOrCrossThreadReduction(sch->state(), block_sref, max_threads_per_block,
                                            warp_size)) {
      return {sch};
    }

    // Step 1. Make a copy of the original schedule. The new copy is used for scheduling.
    tir::Schedule tmp_sch = sch->Copy();
    tmp_sch->Seed(sch->ForkSeed());

    // Step 2. Check the opportunity for block fusion. We say the block is "fusible" if it can be
    // computed at one of its consumers. We want to fuse as much as possible because fusion
    // usually results in a significantly faster schedule.
    // `target_loop` is the loop at which the input block will be computed.
    // `target_block` is the consumer block under whose loop nest the input block will be computed.
    // `tgt_block_innermost_loop` is the innermost loop outside the target block.

    auto [fusible, target_loop, target_block, tgt_block_innermost_loop] =
        GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

    // Step 3. Try block fusion.
    int n_candidate = static_cast<int>(thread_extents.size());
    Array<FloatImm> probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate));
    tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs);
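    // For example (illustrative values): with thread_extents = {32, 64, 128, 256}, each candidate
    // extent is drawn with probability 0.25 (a uniform categorical distribution). The sampling
    // decision is recorded in the schedule trace, so later stages of the search can replay it.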
    if (fusible) {
      ICHECK(target_block.defined());
      ICHECK(target_loop.defined());

      // Step 3.1.
      // - If the outer loops of `target_block` haven't been bound to "threadIdx.x", we should
      //   first bind the innermost outer loop of `target_block` to threadIdx; we may need to
      //   split the loop before binding.
      // - Otherwise, we search for the extent of "threadIdx.x" and use it as the split factor.
      if (!InThreadScope(tmp_sch, target_block)) {
        const Array<tir::LoopRV>& split_res =
            tmp_sch->Split(tgt_block_innermost_loop, {NullOpt, thread_extent});
        tmp_sch->Bind(split_res[1], "threadIdx.x");
        if (tgt_block_innermost_loop.same_as(target_loop)) {
          target_loop = split_res[0];
        }
      } else {
        thread_extent = GetThreadIdxExtentFromTrace(tmp_sch->trace().value());
      }
      // Step 3.2. Do the compute-at.
      tmp_sch->ComputeAt(block_rv, target_loop, /*preserve_unit_loops=*/true);
      // Step 3.3. Set the storage scope of the output buffer to shared memory.
      tmp_sch->SetScope(block_rv, /*buffer_index=*/0, /*storage_scope=*/"shared");
    }
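
    // A rough sketch of the effect when fusion applies (loop and block names are illustrative):
    // the reduction block now lives under `target_loop` of its consumer, its output buffer is
    // staged in "shared" memory, and the consumer's innermost loop is split so that both blocks
    // share the same "threadIdx.x" extent.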

    // Step 4. Reorder the loop axes if reduction loops are not innermost. After the reordering,
    // fuse all the reduction loops.
    size_t num_spatial_loops;
    tir::LoopRV fused_reduce_loop;
    ReorderAndFuseReductionLoops(tmp_sch, block_rv, &fused_reduce_loop, &num_spatial_loops);
    // Step 5. Split the fused reduction loop and bind the inner one to threadIdx.
    const Array<tir::LoopRV>& split_res =
        tmp_sch->Split(fused_reduce_loop, {NullOpt, thread_extent});
    tmp_sch->Bind(split_res[1], "threadIdx.x");

    return {tmp_sch, sch};
  }

  // Inherited from ScheduleRuleNode
  ScheduleRule Clone() const final {
    ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>(*this);
    return ScheduleRule(n);
  }

 private:
  /*!
   * \brief Check whether the input block is in thread scope, i.e., at least one of its outer
   * loops is bound to threadIdx.
   * \param sch The TensorIR schedule
   * \param block The block to be checked
   * \return A boolean indicating whether the block is in thread scope.
   */
  bool InThreadScope(const tir::Schedule& sch, const tir::BlockRV& block) {
    const Array<tir::LoopRV>& axes = sch->GetLoops(block);
    for (const tir::LoopRV& loop_rv : axes) {
      const tir::For& loop = sch->Get(loop_rv);
      runtime::ThreadScope thread_scope = tir::GetThreadScope(loop.get());
      if (tir::IsThreadIdx(thread_scope)) {
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Get the ExprRV that is used to define the extent of a given loop.
   * \param trace The trace of the schedule, where the extent is to be found
   * \param loop The loop whose extent is to be found
   * \param extent The result of the search
   * \return Whether the search is successful.
   */
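  // Illustrative example (shown in trace notation): for an instruction recorded as
  //   _, l1 = Split(loop=l0, factors=[None, v0])
  // the instruction's inputs are [l0, None, v0] and its outputs are the resulting loops, so the
  // factor that defines outputs[i] is located at inputs[1 + i].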
  bool GetLoopRVExtentSource(const tir::Trace& trace, const tir::LoopRV& loop,
                             tir::ExprRV* extent) {
    for (const tir::Instruction& inst : trace->insts) {
      // Only a `Split` instruction that actually produced `loop` can provide its extent;
      // otherwise `1 + i` below could index past the end of `inst->inputs`.
      if (inst->kind->name == "Split" &&
          std::find(inst->outputs.begin(), inst->outputs.end(), loop) != inst->outputs.end()) {
        int i = std::find(inst->outputs.begin(), inst->outputs.end(), loop) - inst->outputs.begin();
        CHECK(inst->inputs[1 + i].defined())
            << "ValueError: Extracting an extent which needs inference is not supported so far";
        *extent = Downcast<tir::ExprRV>(inst->inputs[1 + i]);
        return true;
      }
    }
    return false;
  }

  /*!
   * \brief Get the ExprRV extent of "threadIdx.x" in the given schedule trace.
   * \param trace The trace of the schedule, where the extent is to be found
   * \return The extent of "threadIdx.x" in the input schedule
   */
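  // Note: a `Bind` instruction in the trace carries the bound loop as inputs[0] and the thread
  // axis name (e.g. "threadIdx.x") as attrs[0]; the scan below relies on exactly this layout.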
  tir::ExprRV GetThreadIdxExtentFromTrace(const tir::Trace& trace) {
    tir::ExprRV extent{nullptr};
    for (const tir::Instruction& inst : trace->insts) {
      if (inst->kind->name == "Bind" && Downcast<String>(inst->attrs[0]) == "threadIdx.x") {
        if (GetLoopRVExtentSource(trace, Downcast<tir::LoopRV>(inst->inputs[0]), &extent)) {
          return extent;
        }
      }
    }
    CHECK(false) << "ValueError: Unable to get the extent of \"threadIdx.x\"";
    throw;
  }

  /*!
   * \brief Get the compute-at target loop and the first block under the target loop.
   * \param sch The TensorIR schedule
   * \param block_rv The block whose compute-at target loop is queried
   * \return A tuple consisting of
   * 1. a boolean indicating whether the block can be computed at some target loop (a.k.a. fusible);
   * 2. the compute-at target loop when fusible, or a null loop random variable;
   * 3. the first block under the target loop when fusible, or a null block random variable;
   * 4. the innermost loop outside the target block when fusible, or a null loop random variable.
   */
  std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock(
      const tir::Schedule& sch, const tir::BlockRV& block_rv) {
    // Step 0. Due to technical limitations of some primitives (e.g., compute-at), fusion is
    // temporarily not supported when the block is doing a tuple reduction.
    if (sch->Get(block_rv)->writes.size() != 1) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 1. Get all the consumers of the input block.
    Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);

    // Step 2. If the block has no consumer or the first consumer needs multi-level tiling, it is
    // not fusible.
    if (consumers.empty() || tir::NeedsMultiLevelTiling(sch->state(), sch->GetSRef(consumers[0]))) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 3. Calculate the lowest common ancestor of all the consumers.
    // - If the lowest common ancestor is a block:
    //   - if there is only one consumer, the target block is that consumer;
    //   - if there are multiple consumers, they do not share a common loop, so the case is not
    //     fusible;
    // - If the lowest common ancestor is a loop, the target block is also the first consumer.
    const tir::StmtSRef& lca_sref =
        tir::GetSRefLowestCommonAncestor(tir::BlockRVs2StmtSRefs(sch, consumers));
    if (consumers.size() > 1 && lca_sref->StmtAs<tir::BlockNode>() != nullptr) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    }

    // Step 4. Get the outer loops of the target block, and get the compute-at position index.
    Array<tir::LoopRV> tgt_block_loops = sch->GetLoops(consumers[0]);
    int pos = GetComputePosition(sch, sch->GetLoops(block_rv), tgt_block_loops, lca_sref);

    // Step 5. A negative position index means the block is not fusible.
    if (pos < 0) {
      return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
                             tir::LoopRV{nullptr});
    } else {
      return std::make_tuple(true, tgt_block_loops[pos], consumers[0], tgt_block_loops.back());
    }
  }
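
  // A concrete instance where this fusion can fire (illustrative, e.g. a softmax-style workload):
  // the block computing a row-wise reduction is consumed by an element-wise normalization block,
  // so the reduction can be computed at the consumer's row loop and share its thread block.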

  /*!
   * \brief Get the compute-at position index of the input block, according to
   * 1. the loops outside the input block;
   * 2. the loops outside the target block;
   * 3. the lowest common ancestor of all the consumers of the input block.
   * \param sch The TensorIR schedule
   * \param block_loops The loops outside the input block
   * \param tgt_block_loops The loops outside the target block
   * \param lca_sref The lowest common ancestor of all the consumers of the input block
   * \return The compute-at position index of the input block
   */
  int GetComputePosition(const tir::Schedule& sch, const Array<tir::LoopRV>& block_loops,
                         const Array<tir::LoopRV>& tgt_block_loops, const tir::StmtSRef& lca_sref) {
    int n_block_loop = static_cast<int>(block_loops.size());
    int n_tgt_block_loop = static_cast<int>(tgt_block_loops.size());

    for (int i = 0; i < n_block_loop && i < n_tgt_block_loop; ++i) {
      if (tir::GetLoopIterType(sch->GetSRef(block_loops[i])) != tir::IterVarType::kDataPar) {
        return i - 1;
      } else if (sch->GetSRef(tgt_block_loops[i]).same_as(lca_sref)) {
        // If the lowest common ancestor is a loop, the compute location of the input block should
        // not be deeper than the LCA loop.
        return i;
      }
    }
    return std::min(n_block_loop, n_tgt_block_loop) - 1;
  }
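
  // Worked example (illustrative loop names): with block_loops = (i0, i1, r) where `r` is a
  // reduction loop, and tgt_block_loops = (j0, j1, j2) none of which is the LCA, the scan stops
  // at i = 2 because `r` is not data-parallel and returns 2 - 1 = 1, i.e. the input block is
  // computed at tgt_block_loops[1] and shares the two outermost data-parallel loops.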

 public:
  /*! \brief The maximum number of threads allowed in a thread block */
  int max_threads_per_block;
  /*! \brief The number of threads per warp */
  int warp_size;
  /*! \brief Candidates of thread axis extent (values are required to be positive). */
  Array<Integer> thread_extents;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("max_threads_per_block", &max_threads_per_block);
    v->Visit("warp_size", &warp_size);
    v->Visit("thread_extents", &thread_extents);
  }

  static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction";
  TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode);
};

ScheduleRule ScheduleRule::CrossThreadReduction(Array<Integer> thread_extents) {
  for (const Integer& extent : thread_extents) {
    CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive";
  }
  ObjectPtr<CrossThreadReductionNode> n = make_object<CrossThreadReductionNode>();
  n->thread_extents = std::move(thread_extents);
  return ScheduleRule(n);
}
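
// A minimal construction sketch (the candidate extents below are illustrative, not a prescribed
// default):
//   ScheduleRule rule = ScheduleRule::CrossThreadReduction(
//       /*thread_extents=*/{Integer(4), Integer(8), Integer(16), Integer(32)});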

TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction")
    .set_body_typed(ScheduleRule::CrossThreadReduction);

}  // namespace meta_schedule
}  // namespace tvm