1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file lower_cross_thread_reduction.cc |
22 | */ |
23 | #include <tvm/arith/analyzer.h> |
24 | #include <tvm/tir/analysis.h> |
25 | #include <tvm/tir/stmt_functor.h> |
26 | #include <tvm/tir/transform.h> |
27 | |
28 | #include "../schedule/analysis.h" |
29 | #include "./ir_utils.h" |
30 | |
31 | namespace tvm { |
32 | namespace tir { |
33 | |
34 | /*! |
35 | * \brief Checks if a loop is bound to threadIdx.x/y/z |
 * \param loop The loop to be checked
37 | * \return True if the loop is bound to threadIdx.x/y/z |
38 | */ |
39 | bool IsBoundToThreadIdx(const ForNode* loop) { |
40 | if (!loop->thread_binding.defined()) { |
41 | return false; |
42 | } |
43 | runtime::ThreadScope scope = |
44 | runtime::ThreadScope::Create(loop->thread_binding.value()->thread_tag); |
45 | return scope.rank == 1 && scope.dim_index >= 0; |
46 | } |
47 | |
48 | /*! |
49 | * \brief Check the dominant property of a block: |
50 | * the block is the only writer of its output, dominating the reader of its output buffers |
51 | * \param scope_block The scope block of the block to be checked |
52 | * \param block The block whose dominant property is to be checked |
53 | * \return A boolean indicating if the block is a dominant block |
54 | */ |
55 | bool IsDominantBlock(const Block& scope_block, const Block& block) { |
56 | // Step 1. Count the number of writers for each buffer written by the scope block. |
57 | std::unordered_map<const BufferNode*, int> buffer_writer_cnt; |
58 | PreOrderVisit(scope_block->body, [&buffer_writer_cnt](const ObjectRef& obj) { |
59 | if (const auto* block = obj.as<BlockNode>()) { |
60 | for (const BufferRegion& buffer_region : block->writes) { |
61 | ++buffer_writer_cnt[buffer_region->buffer.get()]; |
62 | } |
63 | return false; |
64 | } |
65 | return true; |
66 | }); |
67 | // Step 2. Check whether `block` is the only writer of its outputs. |
68 | for (const BufferRegion& buffer_region : block->writes) { |
69 | ICHECK(buffer_writer_cnt.count(buffer_region->buffer.get())); |
70 | if (buffer_writer_cnt[buffer_region->buffer.get()] != 1) { |
71 | return false; |
72 | } |
73 | } |
74 | return true; |
75 | } |
76 | |
77 | /*! |
78 | * \brief Check whether the input block is a reduction block. |
79 | * \param realize The block to be checked |
80 | * \param loop_range_map The mapping from the loop variables outside the input block to their ranges |
81 | * \param scope_block The scope block of the input block |
82 | * \param analyzer The analyzer |
83 | * \return A boolean indicating whether the input block is a reduction block. |
84 | * \note A similar check has been implemented in "src/tir/schedule/analysis.h", but that check is |
85 | * based on `tir.Schedule`. Here we have no schedule information, and thus we must implement the |
86 | * check again. |
87 | */ |
88 | bool IsReductionBlock(const BlockRealize& realize, const Map<Var, Range>& loop_range_map, |
89 | const Block& scope_block, arith::Analyzer* analyzer) { |
90 | const auto* block = realize->block.as<BlockNode>(); |
91 | // Cond 1. The block has the `init` statement. |
92 | if (!block->init.defined()) { |
93 | return false; |
94 | } |
95 | // Cond 2. All the block bindings are quasi-affine expressions. |
96 | if (!IsAffineBinding(realize, loop_range_map, analyzer)) { |
97 | return false; |
98 | } |
99 | // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, |
100 | // we collect all the reduction block vars. |
101 | if (!ContainsOnlyDataParAndReductionBlockIter(block->iter_vars)) { |
102 | return false; |
103 | } |
104 | // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its |
105 | // output buffers. |
106 | if (!IsDominantBlock(scope_block, GetRef<Block>(block))) { |
107 | return false; |
108 | } |
109 | // Cond 5. The reduction block vars are not used to index the output buffers. |
110 | return ReductionIterNotIndexOutputBuffer(GetRef<Block>(block)); |
111 | } |
112 | |
113 | /*! |
114 | * \brief Create intermediate buffers according to the input buffers and buffer kind |
115 | * \param reduction_buffers The old reduction buffers which provide the buffer names and data types |
116 | * \param is_cross_thread_buffer A boolean indicating whether to create buffers for the cross-thread |
 * computation results or not, which is used to determine the buffer name prefix
118 | * \return The created buffers |
119 | */ |
120 | Array<Buffer> MakeScratchpads(const Array<Buffer>& reduction_buffers, bool is_cross_thread_buffer) { |
121 | Array<Buffer> new_buffers; |
122 | new_buffers.reserve(reduction_buffers.size()); |
123 | for (const Buffer& buffer : reduction_buffers) { |
124 | String name = is_cross_thread_buffer ? "cross" : "in" ; |
125 | name = name + "_thread_" + buffer->name; |
126 | new_buffers.push_back(Buffer(/*ptr=*/Var(name, PointerType(PrimType(buffer->dtype), "local" )), |
127 | /*dtype=*/buffer->dtype, |
128 | /*shape=*/{Integer(1)}, |
129 | /*strides=*/{Integer(1)}, |
130 | /*elem_offset=*/PrimExpr{nullptr}, |
131 | /*name=*/name, |
132 | /*data_alignment=*/0, |
133 | /*offset_factor=*/0, |
134 | /*buffer_type=*/kDefault)); |
135 | } |
136 | return new_buffers; |
137 | } |
138 | |
139 | /*! |
140 | * \brief Substitute given source buffers with given target buffers respectively in the input |
141 | * statement |
142 | */ |
143 | class BufferReplacer : private StmtExprMutator { |
144 | public: |
145 | static Stmt Run(Array<Buffer> src_buffers, Array<Buffer> tgt_buffers, Stmt stmt) { |
146 | Map<Buffer, Buffer> buffer_map; |
147 | ICHECK_EQ(src_buffers.size(), tgt_buffers.size()); |
148 | int n_buffers = src_buffers.size(); |
149 | for (int i = 0; i < n_buffers; ++i) { |
150 | buffer_map.Set(src_buffers[i], tgt_buffers[i]); |
151 | } |
152 | return BufferReplacer(buffer_map)(std::move(stmt)); |
153 | } |
154 | |
155 | private: |
156 | explicit BufferReplacer(Map<Buffer, Buffer> buffer_map) : buffer_map_(std::move(buffer_map)) {} |
157 | |
158 | PrimExpr VisitExpr_(const BufferLoadNode* load) final { |
159 | auto it = buffer_map_.find(load->buffer); |
160 | return it != buffer_map_.end() ? BufferLoad((*it).second, {0}) : GetRef<BufferLoad>(load); |
161 | } |
162 | |
163 | Stmt VisitStmt_(const BufferStoreNode* store) final { |
164 | auto it = buffer_map_.find(store->buffer); |
165 | if (it != buffer_map_.end()) { |
166 | PrimExpr value = StmtExprMutator::VisitExpr(store->value); |
167 | return BufferStore((*it).second, std::move(value), {0}); |
168 | } else { |
169 | return StmtMutator::VisitStmt_(store); |
170 | } |
171 | } |
172 | |
173 | Map<Buffer, Buffer> buffer_map_; |
174 | }; |
175 | |
176 | /*! |
177 | * \brief Substitute a given source block with a given target block, or remove the source block |
178 | * branch from the AST if the target block is undefined |
179 | */ |
class InThreadReducerMaker : private StmtMutator {
 public:
  /*!
   * \brief Rewrite `stmt`, substituting `src_realize` with `tgt_realize`, or removing the
   * `src_realize` branch entirely when `tgt_realize` is NullOpt
   * \param src_realize The block-realize to be replaced/removed
   * \param tgt_realize The replacement block-realize, or NullOpt to request removal
   * \param stmt The statement to rewrite
   * \return The rewritten statement, or NullOpt if the whole statement was removed
   */
  static Optional<Stmt> Make(const BlockRealizeNode* src_realize,
                             Optional<BlockRealize> tgt_realize, Stmt stmt) {
    return InThreadReducerMaker(src_realize, std::move(tgt_realize))(std::move(stmt));
  }

 private:
  explicit InThreadReducerMaker(const BlockRealizeNode* src_realize,
                                Optional<BlockRealize> tgt_realize)
      : src_realize_(src_realize), tgt_realize_(tgt_realize) {}
  // Replace the source block-realize with the target, or with a null Stmt
  // (which the For/SeqStmt visitors below propagate as a removal).
  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
    if (realize == src_realize_) {
      return tgt_realize_.defined()  //
                 ? tgt_realize_.value()
                 : Stmt{nullptr};
    }
    return GetRef<BlockRealize>(realize);
  }

  // If the loop body survives, drop thread-bound loops (keeping only their bodies,
  // since the thread loops are re-created elsewhere); if the body was removed,
  // propagate the removal upward as a null Stmt.
  Stmt VisitStmt_(const ForNode* loop) final {
    if (Optional<For> opt_res = Downcast<Optional<For>>(StmtMutator::VisitStmt_(loop))) {
      For res = opt_res.value();
      if (res->thread_binding.defined()) {
        return res->body;
      } else {
        return std::move(res);
      }
    } else {
      return Stmt{nullptr};
    }
  }

  // Keep only the surviving children of the sequence; an empty result propagates
  // the removal upward as a null Stmt.
  Stmt VisitStmt_(const SeqStmtNode* seq) final {
    Array<Stmt> stmts;
    stmts.reserve(seq->size());
    for (const Stmt& stmt : seq->seq) {
      if (Optional<Stmt> opt_res = VisitStmt(stmt)) {
        stmts.push_back(opt_res.value());
      }
    }
    return stmts.empty() ? Stmt{nullptr} : SeqStmt::Flatten(stmts);
  }

  /*! \brief The block-realize to be substituted/removed */
  const BlockRealizeNode* src_realize_;
  /*! \brief The substitute block-realize; NullOpt means removal */
  Optional<BlockRealize> tgt_realize_;
};
227 | |
228 | /*! |
229 | * \brief Create the lowered allreduce block transformed from the input reduction block |
230 | * \param realize The block-realize which contains the old reduction block |
231 | * \param it_buffers The buffers to store in-thread reduction results |
232 | * \param ct_buffers The buffers to store cross-thread reduction results |
233 | * \param wb_buffers The buffers to store the final reduction results |
234 | * \param old_wb_indices The indices used to access the write-back buffers when storing the final |
235 | * reduction results into the write-back buffers |
236 | * \param reducer The reduction function |
237 | * \param combiner_rhs The RHS values of the combiner |
238 | * \param reduction_loops The reduction loops |
239 | */ |
240 | Stmt TransformReductionBlock(const BlockRealizeNode* realize, // |
241 | const Optional<Array<Buffer>>& it_buffers, // |
242 | const Array<Buffer>& ct_buffers, // |
243 | const Array<Buffer>& wb_buffers, // |
244 | const Array<PrimExpr>& old_wb_indices, // |
245 | const CommReducer& reducer, // |
246 | const Array<PrimExpr>& combiner_rhs, // |
247 | const std::vector<const ForNode*>& reduction_loops) { |
248 | int n_buffers = wb_buffers.size(); |
249 | const BlockNode* block = realize->block.get(); |
250 | |
251 | auto f_create_buffer_regions = [](Array<Buffer> buffers) { |
252 | Array<BufferRegion> regions; |
253 | regions.reserve(buffers.size()); |
254 | for (const Buffer& buffer : buffers) { |
255 | regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)})); |
256 | } |
257 | return regions; |
258 | }; |
259 | |
260 | Array<BufferRegion> ct_buffer_regions = f_create_buffer_regions(ct_buffers); |
261 | Optional<Array<BufferRegion>> it_buffer_regions = NullOpt; |
262 | if (it_buffers.defined()) { |
263 | it_buffer_regions = f_create_buffer_regions(it_buffers.value()); |
264 | } |
265 | // In total, the block is transformed into at most 4 statements |
266 | // - Stmt 1: initialize the buffer for in-thread reduction |
267 | // - Stmt 2: do in-thread reduction |
268 | // - Stmt 3: do cross-thread reduction |
269 | // - Stmt 4: write cross-thread reduction result to the original buffer |
270 | Array<Stmt> stmts; |
271 | stmts.reserve(4); |
272 | // Stmt 1: initialize the buffer for in-thread reduction |
273 | if (it_buffers.defined()) { |
274 | Array<Stmt> inits; |
275 | inits.reserve(n_buffers); |
276 | for (int i = 0; i < n_buffers; ++i) { |
277 | inits.push_back( |
278 | BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)})); |
279 | } |
280 | stmts.push_back(BlockRealize(/*iter_values=*/{}, |
281 | /*predicate=*/const_true(), |
282 | /*block=*/ |
283 | Block(/*iter_vars=*/{}, |
284 | /*reads=*/{}, |
285 | /*writes=*/it_buffer_regions.value(), |
286 | /*name_hint=*/block->name_hint + "_in_thread_init" , |
287 | /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0]))); |
288 | } |
289 | // Stmt 2: do in-thread reduction |
290 | { |
291 | Optional<BlockRealize> new_realize = NullOpt; |
292 | // If need to generate in-thread reduction, |
293 | // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize |
294 | // otherwise, directly remove given BlockRealize |
295 | if (it_buffers.defined()) { |
296 | ObjectPtr<BlockNode> new_block = make_object<BlockNode>(*block); |
297 | new_block->reads = std::move(new_block->reads); |
298 | new_block->writes = it_buffer_regions.value(); |
299 | new_block->name_hint = new_block->name_hint + "_in_thread" ; |
300 | new_block->body = |
301 | BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body)); |
302 | new_block->init = NullOpt; |
303 | ObjectPtr<BlockRealizeNode> n = make_object<BlockRealizeNode>(*realize); |
304 | n->block = Block(new_block); |
305 | new_realize = BlockRealize(n); |
306 | } |
307 | For loop = GetRef<For>(reduction_loops[0]); |
308 | if (Optional<Stmt> stmt = InThreadReducerMaker::Make(realize, new_realize, std::move(loop))) { |
309 | stmts.push_back(stmt.value()); |
310 | } |
311 | } |
312 | // Stmt 3: do cross-thread reduction |
313 | { |
314 | // Step 3.1. Create the parameters to the intrinsic |
315 | Array<PrimExpr> parameters; |
316 | parameters.reserve(reduction_loops.size() + 4); |
317 | // 1-st argument: number of buffers |
318 | parameters.push_back(make_const(DataType::UInt(32), n_buffers)); |
319 | // Next `n_buffers` arguments: sources |
320 | if (it_buffers.defined()) { |
321 | for (int i = 0; i < n_buffers; ++i) { |
322 | parameters.push_back(BufferLoad(it_buffers.value()[i], {Integer(0)})); |
323 | } |
324 | } else { |
325 | parameters.insert(parameters.end(), combiner_rhs.begin(), combiner_rhs.end()); |
326 | } |
327 | // Next argument: predicate |
328 | parameters.push_back(const_true()); |
329 | // Next `n_buffers` arguments: destinations |
330 | for (int i = 0; i < n_buffers; ++i) { |
331 | parameters.push_back(BufferLoad(ct_buffers[i], {0})); |
332 | } |
333 | // Next arguments: all the reduction threads |
334 | for (const ForNode* reduction_loop : reduction_loops) { |
335 | if (reduction_loop->thread_binding.defined()) { |
336 | parameters.push_back(reduction_loop->loop_var); |
337 | } |
338 | } |
339 | // Step 3.2. Create the block and the block-realize. |
340 | Array<IterVar> iter_vars{nullptr}; |
341 | Array<PrimExpr> bindings{nullptr}; |
342 | Array<BufferRegion> reads{nullptr}; |
343 | if (it_buffers.defined()) { |
344 | iter_vars = Array<IterVar>{}; |
345 | bindings = Array<PrimExpr>{}; |
346 | reads = it_buffer_regions.value(); |
347 | } else { |
348 | iter_vars = block->iter_vars; |
349 | bindings = realize->iter_values; |
350 | reads = block->reads; |
351 | } |
352 | stmts.push_back(BlockRealize( |
353 | /*iter_values=*/std::move(bindings), |
354 | /*predicate=*/const_true(), |
355 | /*block=*/ |
356 | Block(/*iter_vars=*/std::move(iter_vars), |
357 | /*reads=*/std::move(reads), |
358 | /*writes=*/ct_buffer_regions, |
359 | /*name_hint=*/block->name_hint + "_cross_thread" , |
360 | /*body=*/ |
361 | AttrStmt(/*node=*/reducer, |
362 | /*attr_key=*/tir::attr::reduce_scope, |
363 | /*value=*/make_zero(DataType::Handle()), |
364 | /*body=*/ |
365 | Evaluate(Call(/*dtype=*/DataType::Handle(), |
366 | /*op=*/tir::builtin::tvm_thread_allreduce(), |
367 | /*args=*/std::move(parameters))))))); |
368 | } |
369 | // Stmt 4: write cross-thread reduction result to the original buffer |
370 | { |
371 | ICHECK_EQ(block->iter_vars.size(), realize->iter_values.size()); |
372 | int n_iter = static_cast<int>(block->iter_vars.size()); |
373 | Array<IterVar> iter_vars; |
374 | Array<PrimExpr> bindings; |
375 | Map<Var, PrimExpr> var_map; |
376 | iter_vars.reserve(n_iter); |
377 | bindings.reserve(n_iter); |
378 | for (int i = 0; i < n_iter; ++i) { |
379 | const IterVar& iter_var = block->iter_vars[i]; |
380 | const PrimExpr& binding = realize->iter_values[i]; |
381 | if (iter_var->iter_type != kCommReduce) { |
382 | IterVar new_iter_var{nullptr}; |
383 | { |
384 | ObjectPtr<IterVarNode> n = make_object<IterVarNode>(*iter_var.get()); |
385 | ObjectPtr<VarNode> v = make_object<VarNode>(*iter_var->var.get()); |
386 | n->var = Var(v); |
387 | new_iter_var = IterVar(n); |
388 | } |
389 | iter_vars.push_back(new_iter_var); |
390 | bindings.push_back(binding); |
391 | var_map.Set(iter_var->var, new_iter_var->var); |
392 | } |
393 | } |
394 | Array<Stmt> wb_updates; |
395 | Array<BufferRegion> wb_regions; |
396 | wb_updates.reserve(n_buffers); |
397 | wb_regions.reserve(n_buffers); |
398 | int n_dim = static_cast<int>(old_wb_indices.size()); |
399 | Array<Range> region = Substitute(block->writes[0]->region, var_map); |
400 | Array<PrimExpr> wb_indices; |
401 | wb_indices.reserve(n_dim); |
402 | for (int d = 0; d < n_dim; ++d) { |
403 | wb_indices.push_back(Substitute(old_wb_indices[d], var_map)); |
404 | } |
405 | for (int i = 0; i < n_buffers; ++i) { |
406 | wb_updates.push_back( |
407 | BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices)); |
408 | wb_regions.push_back(BufferRegion(wb_buffers[i], region)); |
409 | } |
410 | stmts.push_back(BlockRealize( |
411 | /*iter_values=*/std::move(bindings), |
412 | /*predicate=*/const_true(), |
413 | /*block=*/ |
414 | Block(/*iter_vars=*/std::move(iter_vars), |
415 | /*reads=*/std::move(ct_buffer_regions), |
416 | /*writes=*/std::move(wb_regions), |
417 | /*name_hint=*/block->name_hint + "_write_back" , |
418 | /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0]))); |
419 | } |
420 | // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx |
421 | Stmt new_stmt = SeqStmt::Flatten(std::move(stmts)); |
422 | for (auto rit = reduction_loops.rbegin(); rit != reduction_loops.rend(); ++rit) { |
423 | const ForNode* loop = *rit; |
424 | if (loop->thread_binding.defined()) { |
425 | ObjectPtr<ForNode> n = make_object<ForNode>(*loop); |
426 | n->body = std::move(new_stmt); |
427 | new_stmt = For(n); |
428 | } |
429 | } |
430 | return new_stmt; |
431 | } |
432 | |
433 | /*! |
434 | * \brief Detect cross-thread reduction pattern and then transform |
435 | */ |
class CrossThreadReductionTransformer : public StmtMutator {
 private:
  // Check if the input block needs cross-thread reduction.
  // Returns the reduction-related loops above the block when at least one of them is
  // bound to a GPU thread axis, and an empty vector otherwise.
  std::vector<const ForNode*> NeedCrossThreadReduction(const BlockRealizeNode* realize) {
    // Step 0. If the block is the root block, just return.
    if (block_stack_.empty()) {
      return {};
    }

    // Step 1. If the block is not a reduction block, cross-thread reduction is not needed.
    if (!IsReductionBlock(GetRef<BlockRealize>(realize), loop_range_map_,
                          GetRef<Block>(block_stack_.back()), &analyzer_)) {
      return {};
    }

    // Step 2. Collect all the vars that appear in the bindings of reduction block iters.
    std::unordered_set<const VarNode*> reduction_vars;
    GetVarsTouchedByBlockIters(GetRef<BlockRealize>(realize), nullptr, &reduction_vars);

    // Step 3. Collect the loops whose loop vars appear in the bindings of reduction block iters.
    // We call these loops "reduction-related".
    // Step 4. See whether at least one reduction-related loop is bound to thread axis in GPU - if
    // so, cross-thread reduction is needed. If none of the reduction-related loops is bound to
    // thread axis, cross-thread reduction is not needed for the input block.
    bool need = false;
    std::vector<const ForNode*> reduction_loops;
    for (const ForNode* loop : loop_stack_) {
      if (reduction_vars.count(loop->loop_var.get())) {
        // Step 3. Collect the loop.
        reduction_loops.push_back(loop);
        // Step 4. See whether the loop is bound to some thread axis.
        if (loop->thread_binding.defined()) {
          need = true;
        }
      }
    }
    return need ? reduction_loops : std::vector<const ForNode*>{};
  }

  /*!
   * \brief Given that the input block needs cross-thread reduction, check if cross-thread reduction
   * can be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
   * reduction)
   * \param block The block to be checked
   * \param reduction_loops The reduction loops above the block
   * \return A tuple consisting of five elements:
   * - an integer which indicates the number of reduction loops that are bound to thread axes,
   * - the detected commutative reducer of the reduction,
   * - the reduction buffers which store the reduction results,
   * - the RHS values of the reduction updates,
   * - the indices which are used to access the reduction buffers when storing the reduction results
   */
  std::tuple<int, CommReducer, Array<Buffer>, Array<PrimExpr>, Array<PrimExpr>>
  CheckCanApplyCrossThreadReduction(const BlockNode* block,
                                    const std::vector<const ForNode*>& reduction_loops) const {
    // Condition 1. All the reduction-related loops should be the deepest among all statements
    // outside the block (ignoring SeqStmt here).
    // `rbegin() + 1` skips the block itself, which sits at the top of `statement_stack_`.
    int n_deepest_reduction_loops = 0;
    for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) {
      const StmtNode* stmt = *rit;
      if ((*rit)->IsInstance<SeqStmtNode>()) {
        // Skip SeqStmt.
        continue;
      }
      if (std::find(reduction_loops.begin(), reduction_loops.end(),
                    reinterpret_cast<const ForNode*>(stmt)) == reduction_loops.end()) {
        break;
      }
      ++n_deepest_reduction_loops;
    }
    CHECK_EQ(n_deepest_reduction_loops, reduction_loops.size())
        << "ValueError: Cross-thread reduction requires all the reduction-related loops to be the "
           "deepest among all statements outside the desired block. However, block "
        << block->name_hint
        << " needs cross-thread reduction, while the reduction-related loops outside of it are not "
           "the deepest statements, which violates the condition.";

    // Condition 2. All the reduction-related loops that are bound to thread axes should only be
    // bound to `threadIdx.x/y/z`.
    int n_bound_reduction_loops = 0;
    for (const ForNode* reduction_loop : reduction_loops) {
      if (reduction_loop->thread_binding.defined()) {
        ++n_bound_reduction_loops;
        CHECK(IsBoundToThreadIdx(reduction_loop))
            << "ValueError: Cross-thread reduction requires all the reduction-related loops that "
               "are bound to GPU thread axes to only be bound `threadIdx.x/y/z`. However, loop "
            << reduction_loop->loop_var->name_hint << " violates the condition.";
      }
    }

    // Condition 3. Get the identity values of the block init and the BufferStore block combiner
    // updates of the reduction. Extract the commutative reducer, combiner lhs and combiner rhs from
    // the reduction identities and the reduction combiner.
    Array<PrimExpr> init_values{nullptr};
    Array<BufferStore> updates{nullptr};
    CommReducer reducer{nullptr};
    Array<PrimExpr> combiner_lhs{nullptr};
    Array<PrimExpr> combiner_rhs{nullptr};
    std::tie(init_values, updates) =
        GetInitValuesAndUpdatesFromReductionBlock(NullOpt, GetRef<Block>(block));
    std::tie(reducer, combiner_lhs, combiner_rhs) =
        GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates);

    // The reduction buffers are the buffers written by the combiner updates.
    Array<Buffer> reduction_buffers;
    reduction_buffers.reserve(updates.size());
    for (const BufferStore& buf_store : updates) {
      reduction_buffers.push_back(buf_store->buffer);
    }

    // Condition 4. The block should be the last block under the first reduction-related loop.
    bool visit = false;
    PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
      if (const auto* realize = obj.as<BlockRealizeNode>()) {
        CHECK(!visit) << "ValueError: Cross-thread reduction cannot be applied when the reduction "
                         "block isn't the last block under its first reduction-related loop";
        if (realize->block.get() == block) {
          visit = true;
        }
        return false;
      }
      return true;
    });
    return std::make_tuple(n_bound_reduction_loops,  //
                           std::move(reducer),       //
                           std::move(reduction_buffers),  //
                           std::move(combiner_rhs),  //
                           updates[0]->indices);
  }

  // Maintain `statement_stack_` (the chain of statements enclosing the one being visited)
  // across the recursive traversal.
  Stmt VisitStmt(const Stmt& stmt) final {
    statement_stack_.push_back(stmt.get());
    Stmt result = StmtMutator::VisitStmt(stmt);
    statement_stack_.pop_back();
    return result;
  }

  // Maintain `loop_stack_`/`loop_range_map_` across the traversal, and splice in the
  // transformed reduction (pre-stored in `loop2new_stmt_`) when leaving the first
  // reduction-related loop.
  Stmt VisitStmt_(const ForNode* loop) final {
    loop_stack_.push_back(loop);
    loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    Stmt result = StmtMutator::VisitStmt_(loop);
    loop_stack_.pop_back();
    loop_range_map_.erase(loop->loop_var);

    // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
    auto it = loop2new_stmt_.find(loop);
    if (it != loop2new_stmt_.end()) {
      return it->second;
    } else {
      return result;
    }
  }

  // Maintain `block_stack_` and reset `loop_range_map_` for the new block scope; after
  // visiting, attach any scratchpad buffers created for descendants of this block.
  Stmt VisitStmt_(const BlockNode* block) final {
    Map<Var, Range> old_loop_range_map;

    block_stack_.push_back(block);
    // Loops outside this block are not visible inside it, so swap in an empty range map.
    std::swap(old_loop_range_map, loop_range_map_);
    Block new_block = Downcast<Block>(StmtMutator::VisitStmt_(block));
    block_stack_.pop_back();
    std::swap(old_loop_range_map, loop_range_map_);

    // Insert the new allocated buffers into the block's `alloc_buffers` field.
    auto it = block2new_buffers_.find(block);
    if (it != block2new_buffers_.end()) {
      BlockNode* p_new_block = new_block.CopyOnWrite();
      for (const Buffer& new_buffer : it->second) {
        if (new_buffer.defined()) {
          p_new_block->alloc_buffers.push_back(new_buffer);
        }
      }
    }
    return std::move(new_block);
  }

  // Detect a reduction block that needs cross-thread reduction, verify it is eligible,
  // build the transformed statement, and stash it in `loop2new_stmt_` to be spliced in
  // when the traversal returns to the first reduction-related loop.
  Stmt VisitStmt_(const BlockRealizeNode* realize) final {
    const BlockNode* block = realize->block.get();
    // Step 1. Check whether cross-thread reduction is needed. If no, skip this block.
    std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
    if (reduction_loops.empty()) {
      return StmtMutator::VisitStmt_(realize);
    }
    // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on
    // which condition the block violates.
    int n_bound_reduction_loops = 0;
    CommReducer reducer{nullptr};
    Array<Buffer> reduction_buffers{nullptr};
    Array<PrimExpr> combiner_rhs{nullptr};
    Array<PrimExpr> wb_indices{nullptr};
    std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) =
        CheckCanApplyCrossThreadReduction(block, reduction_loops);
    // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when
    //  - not all the reduction-related loops are bound to thread axes, or
    //  - the block-realize has a non-constant-true predicate.
    bool need_in_thread_reduction =
        n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
        !is_one(realize->predicate);
    // Step 4. Create intermediate buffers, storing them in `ct_buffers` and
    // `it_buffers`. Let the scope block allocate these new buffers.
    Array<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
    Array<Buffer> ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true);
    new_buffers.insert(new_buffers.end(), ct_buffers.begin(), ct_buffers.end());
    Optional<Array<Buffer>> it_buffers = NullOpt;
    if (need_in_thread_reduction) {
      it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
      new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
    }
    // Step 5. Transform.
    loop2new_stmt_[reduction_loops[0]] =
        TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                reducer, combiner_rhs, reduction_loops);
    // Step 6. Return an empty statement, because the transformation result will be inserted when
    // returning to the first reduction-related loop.
    return Stmt{nullptr};
  }

 private:
  /*! \brief The ancestor statements of the statement currently being visited */
  std::vector<const StmtNode*> statement_stack_;
  /*! \brief The ancestor loops of the statement currently being visited */
  std::vector<const ForNode*> loop_stack_;
  /*! \brief The ancestor blocks of the statement currently being visited */
  std::vector<const BlockNode*> block_stack_;
  /*! \brief Scratchpad buffers to be allocated under each scope block */
  std::unordered_map<const BlockNode*, Array<Buffer>> block2new_buffers_;
  /*! \brief The transformed statement to splice in at each first reduction-related loop */
  std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
  /*! \brief Ranges of the loops enclosing the current statement within the current block */
  Map<Var, Range> loop_range_map_;
  /*! \brief The analyzer used for affine-binding checks */
  arith::Analyzer analyzer_;
};
660 | |
661 | PrimFunc LowerCrossThreadReduction(PrimFunc f) { |
662 | // Only apply this pass to TIR that is not from TE schedules |
663 | if (!IsFromLegacyTESchedule(f)) { |
664 | PrimFuncNode* fptr = f.CopyOnWrite(); |
665 | fptr->body = CrossThreadReductionTransformer()(f->body); |
666 | return f; |
667 | } else { |
668 | return f; |
669 | } |
670 | } |
671 | |
672 | namespace transform { |
673 | |
674 | Pass LowerCrossThreadReduction() { |
675 | auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { |
676 | return LowerCrossThreadReduction(std::move(f)); |
677 | }; |
678 | return CreatePrimFuncPass(pass_func, 0, "tir.LowerCrossThreadReduction" , {}); |
679 | } |
680 | |
// Expose the pass through the global registry (reachable from the Python frontend as
// `tvm.tir.transform.LowerCrossThreadReduction`).
TVM_REGISTER_GLOBAL("tir.transform.LowerCrossThreadReduction")
    .set_body_typed(LowerCrossThreadReduction);
683 | |
684 | } // namespace transform |
685 | |
686 | } // namespace tir |
687 | } // namespace tvm |
688 | |