#include <arith.h>
#include <index_compute.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <predicate_compute.h>

#include <lower_index.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

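// Indexing helper: if src is a TensorView, compute its producer index
// relative to the consumer dst using the current loop nest; non-tensor
// values are returned unchanged.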
Val* IndexLowering::lowerSrcIndex(Val* src, Val* dst) const {
  if (auto tv = dynamic_cast<TensorView*>(src)) {
    TORCH_INTERNAL_ASSERT(dst->isA<TensorView>());
    return Index::getProducerIndex(tv, dst->as<TensorView>(), for_loops_);
  } else {
    return src;
  }
}

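// Indexing helper: if dst is a TensorView, compute its consumer index using
// the current loop nest; non-tensor values are returned unchanged.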
Val* IndexLowering::lowerDstIndex(Val* dst) const {
  if (auto tv = dynamic_cast<TensorView*>(dst)) {
    return Index::getConsumerIndex(tv, for_loops_);
  } else {
    return dst;
  }
}

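// Append an expression to the currently active scope, or to the top-level
// list of lowered expressions when no scope is active.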
void IndexLowering::pushBack(Expr* expr) {
  if (active_scope_ == nullptr) {
    lowered_exprs_.push_back(expr);
  } else {
    active_scope_->push_back(expr);
  }
}

Expr* IndexLowering::back() const {
  if (active_scope_ == nullptr) {
    TORCH_INTERNAL_ASSERT(
        !lowered_exprs_.empty(), "IndexLowering::back: empty scope.");
    return lowered_exprs_.back();
  }
  TORCH_INTERNAL_ASSERT(
      !active_scope_->empty(), "IndexLowering::back: empty scope.");
  return active_scope_->exprs().back();
}

void IndexLowering::insertAtTopLevel(Expr* expr) {
  TORCH_INTERNAL_ASSERT(!lowered_exprs_.empty());
  lowered_exprs_.insert(lowered_exprs_.end() - 1, expr);
}

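// Recreate an IfThenElse in the lowered IR and recursively lower the
// expressions in its then- and else-bodies within the corresponding scopes.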
void IndexLowering::handle(const kir::IfThenElse* ite) {
  const auto prev_scope = active_scope_;

  auto new_ite = IrBuilder::create<kir::IfThenElse>(ite->predicate());
  pushBack(new_ite);

  active_scope_ = &new_ite->thenBody();

  for (auto expr : ite->thenBody().exprs()) {
    OptOutConstDispatch::handle(expr);
  }

  active_scope_ = &new_ite->elseBody();

  for (auto expr : ite->elseBody().exprs()) {
    OptOutConstDispatch::handle(expr);
  }

  active_scope_ = prev_scope;
}

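// Recreate a ForLoop in the lowered IR and recursively lower its body. The
// new loop is pushed onto for_loops_ so that nested expressions are indexed
// with the full loop nest.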
void IndexLowering::handle(const kir::ForLoop* for_loop) {
  const auto prev_scope = active_scope_;

  auto new_for_loop = IrBuilder::create<kir::ForLoop>(for_loop);
  pushBack(new_for_loop);

  active_scope_ = &new_for_loop->body();
  for_loops_.push_back(new_for_loop);

  for (auto expr : for_loop->body().exprs()) {
    OptOutConstDispatch::handle(expr);
  }

  for_loops_.pop_back();
  active_scope_ = prev_scope;
}

void IndexLowering::handle(const RNGOp* rop) {
  // Write random tensor indices into the consumer
  // tensor index if the output is a tensor.
  auto out_tv = dynamic_cast<TensorView*>(rop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "rand scalar not yet supported");

  // TensorIndex for philox subsequence and component.
  auto philox_index = SimplifyingIrBuilder::create<kir::TensorIndex>(
      out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_));

  // TensorIndex for writing rand_like output.
  const auto out = lowerDstIndex(out_tv);

  auto lowered = IrBuilder::create<RNGOp>(
      rop->getRNGOpType(),
      out,
      rop->dtype(),
      rop->getParameters(),
      rop->getRNGOffset(),
      philox_index);

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(rop, back());
}

void IndexLowering::handle(const FullOp* fop) {
  auto out_tv = dynamic_cast<TensorView*>(fop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr);

  // TensorIndex for writing output.
  const auto out = lowerDstIndex(out_tv);
  auto lowered =
      IrBuilder::create<FullOp>(out, fop->getFillValue(), fop->dtype());

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(fop, back());
}

void IndexLowering::handle(const ARangeOp* aop) {
  // Write linear tensor indices into the consumer
  // tensor index if the output is a tensor.
  auto out_tv = dynamic_cast<TensorView*>(aop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr);

  // linear index for computing arange output
  auto linear_index = SimplifyingIrBuilder::create<kir::TensorIndex>(
      out_tv, Index::getLinearLogicalIndex(out_tv, for_loops_));

  // TensorIndex for writing arange output.
  const auto out = lowerDstIndex(out_tv);
  auto lowered = IrBuilder::create<ARangeOp>(
      out, aop->start(), aop->end(), aop->step(), aop->dtype(), linear_index);

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(aop, back());
}

void IndexLowering::handle(const EyeOp* eop) {
  auto out_tv = dynamic_cast<TensorView*>(eop->output(0));
  TORCH_INTERNAL_ASSERT(out_tv != nullptr);

  // per-dimension indices for computing eye output
  auto indices = Index::getPerDimLogicalIndex(out_tv, for_loops_);
  TORCH_INTERNAL_ASSERT(indices.size() == 2);
  auto index1 = indices[0];
  auto index2 = indices[1];

  // TensorIndex for writing eye output.
  const auto out = lowerDstIndex(out_tv);
  auto lowered = IrBuilder::create<EyeOp>(out, eop->dtype(), index1, index2);

  pushBack(lowered);
  GpuLower::current()->propagateExprInfo(eop, back());
}

void IndexLowering::handle(const UnaryOp* uop) {
  const auto in = lowerSrcIndex(uop->in(), uop->out());
  const auto out = lowerDstIndex(uop->out());
  pushBack(IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, in));
  GpuLower::current()->propagateExprInfo(uop, back());
}

void IndexLowering::handle(const BinaryOp* bop) {
  const auto lhs = lowerSrcIndex(bop->lhs(), bop->out());
  const auto rhs = lowerSrcIndex(bop->rhs(), bop->out());
  const auto out = lowerDstIndex(bop->out());
  pushBack(IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs));
  GpuLower::current()->propagateExprInfo(bop, back());
}

void IndexLowering::handle(const TernaryOp* top) {
  const auto in1 = lowerSrcIndex(top->in1(), top->out());
  const auto in2 = lowerSrcIndex(top->in2(), top->out());
  const auto in3 = lowerSrcIndex(top->in3(), top->out());
  const auto out = lowerDstIndex(top->out());
  pushBack(IrBuilder::create<TernaryOp>(
      top->getTernaryOpType(), out, in1, in2, in3));
  GpuLower::current()->propagateExprInfo(top, back());
}

void IndexLowering::handle(const ViewAsScalar* uop) {
  const auto in = lowerSrcIndex(uop->in(), uop->out());
  const auto out = lowerDstIndex(uop->out());
  for (auto loop : for_loops_) {
    if (GpuLower::current()->caMap()->areMapped(
            loop->iter_domain(),
            uop->vector_id()->as<IterDomain>(),
            IdMappingMode::LOOP)) {
      Val* index = loop->index();
      pushBack(
          IrBuilder::create<ViewAsScalar>(out, in, uop->vector_id(), index));
      GpuLower::current()->propagateExprInfo(uop, back());
      return;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Cannot find index for vector dim");
}

namespace {

struct GridCommWorkBufferSizeInfo {
  // Size of overall buffer. Can be expanded for privatization
  Val* size_of_privatized_buffer = nullptr;
  // Size of single buffer.
  Val* buffer_stride = nullptr;
};

// Get the size of the temporary work buffer for grid communication; this can
// be a grid reduction, broadcast, or grid welford.
// The buffer is expanded for privatization when not persistent or grouped.
GridCommWorkBufferSizeInfo getGridCommWorkBufferSize(
    const TensorDomain* td,
    const std::vector<kir::ForLoop*>& for_loops,
    bool is_persistent) {
  // The buffer size is the number of thread blocks multiplied by the
  // number of threads not used for reduction domains.
  // Note: Previously it was calculated based on the shape of the
  // tensor, but it makes more sense to compute the size based on the
  // shape of the thread block and grid since this buffer is used for
  // communications among them. Both methods should result in the same
  // size if the parallel dimensions are exact, but otherwise, just
  // computing the buffer size based on the tensor shape isn't
  // sufficient since there could be extra threads/blocks.
  Val* size_of_single_buffer = GpuLower::current()->kernel()->oneVal();
  for (auto pt : kParallelTypeThreads) {
    auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    if (isParallelTypeThreadDim(pt) &&
        std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) {
          return out_id->getParallelType() == pt &&
              (out_id->isReduction() || out_id->isBroadcast());
        })) {
      continue;
    }
    size_of_single_buffer =
        SimplifyingIrBuilder::mulExpr(size_of_single_buffer, pt_dim);
  }

  // Expand the buffer for privatization. The buffer is expanded so
  // that each non-reduction IterDomain uses a different part of the
  // buffer. For persistent mode, this expansion is only done for
  // grouped IterDomains.

  Val* size_of_privatized_buffer = size_of_single_buffer;

  // In persistent mode, if a non-grouped non-reduction domain is used,
  // double the buffer size to save a final grid sync.
  bool is_doubled = false;

  for (auto fl : for_loops) {
    // Buffer sizes of parallelized domains are already taken care of.
    if (fl->isTrivial() || fl->iter_domain()->isReduction() ||
        fl->iter_domain()->isThread()) {
      continue;
    }
    // If persistent, i.e., allreduce, only IterDomains with
    // ParallelType::Group are privatized
    if (!is_persistent ||
        fl->iter_domain()->getParallelType() == ParallelType::Group) {
      size_of_privatized_buffer = SimplifyingIrBuilder::mulExpr(
          size_of_privatized_buffer, fl->iter_domain()->extent());
    } else if (is_persistent) {
      is_doubled = true;
    }
  }

  if (is_doubled) {
    size_of_privatized_buffer = SimplifyingIrBuilder::mulExpr(
        size_of_privatized_buffer, IrBuilder::create<Int>(2));
  }

  GridCommWorkBufferSizeInfo info;
  info.size_of_privatized_buffer = size_of_privatized_buffer;
  info.buffer_stride = size_of_single_buffer;
  if (is_doubled) {
    info.buffer_stride = SimplifyingIrBuilder::mulExpr(
        info.buffer_stride, IrBuilder::create<Int>(2));
  }

  return info;
}

Val* getGridSyncBufferSize(
    const TensorDomain* td,
    const std::vector<kir::ForLoop*>& for_loops,
    bool is_persistent) {
  // See the comment above for getGridCommWorkBufferSize.
  Val* buffer_size = GpuLower::current()->kernel()->oneVal();
  for (auto pt : kParallelTypeBIDs) {
    auto pt_dim = GpuLower::current()->parallelDimensionMap().get(pt);
    if (pt_dim == nullptr || pt_dim->isOneInt()) {
      continue;
    }
    if (std::any_of(td->domain().begin(), td->domain().end(), [&](auto out_id) {
          return out_id->getParallelType() == pt &&
              (out_id->isReduction() || out_id->isBroadcast());
        })) {
      continue;
    }
    buffer_size = SimplifyingIrBuilder::mulExpr(buffer_size, pt_dim);
  }

  // If not persistent, all iteration domains require a separate
  // semaphore for re-entrant grid reductions
  if (!is_persistent) {
    for (auto fl : for_loops) {
      if (fl->isTrivial()) {
        continue;
      }
      if (fl->iter_domain()->isThread()) {
        // already accounted for.
        continue;
      }

      buffer_size = SimplifyingIrBuilder::mulExpr(
          buffer_size, fl->iter_domain()->extent());
    }
  }

  return buffer_size;
}

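// Count how many times a grid reduction is entered: the product of the
// extents of all non-trivial, non-parallelized loops surrounding it.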
Val* getEntranceCountGridReduce(std::vector<kir::ForLoop*>& for_loops) {
  Val* grid_reduction_entrances = GpuLower::current()->kernel()->oneVal();

  for (const auto loop : for_loops) {
    if (loop->isTrivial()) {
      continue;
    }
    if (loop->iter_domain()->isThread()) {
      // already accounted for.
      continue;
    }
    // TODO: Does this work for shift/gather?
    grid_reduction_entrances = SimplifyingIrBuilder::mulExpr(
        grid_reduction_entrances, loop->iter_domain()->extent());
  }
  return grid_reduction_entrances;
}

// Linear indexing of for loops for multiple entrances into grid reduce
// TODO: What happens if there's a broadcast that's resolved (not present in
// the grid reduce) but the global buffer isn't expanded?
Val* getEntranceLinIndGridReduce(std::vector<kir::ForLoop*>& for_loops) {
  Val* linear_index = GpuLower::current()->kernel()->zeroVal();

  for (const auto loop : for_loops) {
    if (loop->isTrivial()) {
      continue;
    }
    if (loop->iter_domain()->isThread()) {
      // already accounted for.
      continue;
    }
    // TODO: Does this work for shift/gather?
    linear_index = SimplifyingIrBuilder::addExpr(
        SimplifyingIrBuilder::mulExpr(
            linear_index, loop->iter_domain()->extent()),
        loop->index());
  }
  return linear_index;
}

} // namespace

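// Lower a ReductionOp. Grid and block reductions are handled by dedicated
// helpers; a serial reduction lowers to a plain BinaryOp that accumulates
// into the output.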
void IndexLowering::handle(const ReductionOp* rop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(rop));

  const auto out_tv = rop->out()->as<TensorView>();
  const auto out_domain = out_tv->domain();

  const bool has_block_reduce = out_domain->hasBlockReduction();
  const bool has_grid_reduce = out_domain->hasGridReduction();

  const auto out = lowerDstIndex(rop->out());
  const auto in = lowerSrcIndex(rop->in(), rop->out());

  if (has_grid_reduce) {
    handleGridReduction(rop, out, in);
  } else if (has_block_reduce) {
    handleBlockReduction(rop, out, in);
  } else {
    pushBack(
        IrBuilder::create<BinaryOp>(rop->getReductionOpType(), out, out, in));
    GpuLower::current()->propagateExprInfo(rop, back());
  }
}

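// Lower a block (intra-CTA) reduction: recreate the ReductionOp with indexed
// operands and carry over any read/write predicates.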
void IndexLowering::handleBlockReduction(
    const ReductionOp* rop,
    Val* out,
    Val* in) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(rop));

  ReductionOp* indexed_rop = IrBuilder::create<ReductionOp>(
      rop->getReductionOpType(), rop->init(), out, in, rop->isAllreduce());
  if (rop->predicate()) {
    indexed_rop =
        indexed_rop->withPredicate(rop->predicate())->as<ReductionOp>();
  }
  if (rop->writePredicate()) {
    indexed_rop = indexed_rop->withWritePredicate(rop->writePredicate())
                      ->as<ReductionOp>();
  }

  pushBack(indexed_rop);
  GpuLower::current()->propagateExprInfo(rop, back());
}

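// Lower a grid (inter-CTA) reduction: allocate the global work and sync
// buffers, set up entrance bookkeeping and the thread predicate, and emit a
// kir::GridReduction.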
void IndexLowering::handleGridReduction(
    const ReductionOp* rop,
    Val* out,
    Val* in) {
  const auto out_tv = out->as<kir::TensorIndex>()->view();
  const auto out_domain = out_tv->domain();

  TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction());

  // If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
  TORCH_INTERNAL_ASSERT(
      std::none_of(
          out_domain->domain().begin(),
          out_domain->domain().end(),
          [](IterDomain* id) {
            return !id->isThread() && id->isReduction() &&
                !id->extent()->isOneInt();
          }),
      "Found a reduction stage that has both a non-parallelized ",
      "reduction and a grid reduction. This is not supported, ",
      "please use rfactor to do the serialized reduction first, ",
      "then the grid reduction. ",
      rop->toString());

  // Use a unique buffer for the work and sync flag when called within a
  // loop, unless the reduction is persistent. A grid allreduce implies
  // persistence, but not being a grid allreduce does not imply
  // non-persistence. Currently, if a cooperative grid reduction is
  // required anywhere in the kernel, all grid reductions are done in a
  // persistent manner, so all grid reductions need to be consulted.
  // TODO: fix this
  const bool is_persistent = rop->isAllreduce();
  const auto buffer_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  auto work_buffer = allocateUniqueBuffer(
      buffer_size_info.size_of_privatized_buffer,
      out_tv->dtype(),
      false,
      out_tv,
      work_buffer_map_);

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate for GridReduction needs to be set
  // separately from the main predicate. Do not combine them like
  // other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto grid_reduction = IrBuilder::create<kir::GridReduction>(
      rop->getReductionOpType(),
      rop->init(),
      out,
      in,
      work_buffer,
      sync_buffer,
      entrance_ind,
      n_entrances,
      rop->isAllreduce());

  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);

  if (rop->predicate()) {
    grid_reduction = grid_reduction->withPredicate(rop->predicate())
                         ->as<kir::GridReduction>();
  }
  if (rop->writePredicate()) {
    grid_reduction = grid_reduction->withWritePredicate(rop->writePredicate())
                         ->as<kir::GridReduction>();
  }

  pushBack(grid_reduction);
  GpuLower::current()->propagateExprInfo(rop, back());

  if (rop->isAllreduce()) {
    allocateUniqueFusedReduction(grid_reduction, out_tv);
  }
}

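// Lower a GroupedReductionOp: index each output/input pair, then dispatch to
// the grid or block helper, or lower to per-expression serial BinaryOps.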
void IndexLowering::handle(const GroupedReductionOp* grouped_rop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_rop));

  const auto out_tv = ir_utils::getTvOutput(grouped_rop);
  const auto out_domain = out_tv->domain();

  const bool has_block_reduce = out_domain->hasBlockReduction();
  const bool has_grid_reduce = out_domain->hasGridReduction();

  std::vector<Val*> indexed_outputs(grouped_rop->numExprs());
  std::vector<Val*> indexed_inputs(grouped_rop->numExprs());

  for (const auto i : c10::irange(grouped_rop->numExprs())) {
    indexed_outputs.at(i) = lowerDstIndex(grouped_rop->output(i));
    indexed_inputs.at(i) =
        lowerSrcIndex(grouped_rop->input(i), grouped_rop->output(i));
  }

  if (has_grid_reduce) {
    handleGridReduction(grouped_rop, indexed_outputs, indexed_inputs);
  } else if (has_block_reduce) {
    handleBlockReduction(grouped_rop, indexed_outputs, indexed_inputs);
  } else {
    for (const auto i : c10::irange(grouped_rop->numExprs())) {
      pushBack(IrBuilder::create<BinaryOp>(
          grouped_rop->getReductionOpType(i),
          indexed_outputs.at(i),
          indexed_outputs.at(i),
          indexed_inputs.at(i)));
    }
  }
}

void IndexLowering::handleBlockReduction(
    const GroupedReductionOp* grouped_rop,
    const std::vector<Val*>& outputs,
    const std::vector<Val*>& inputs) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_rop));

  GroupedReductionOp* indexed_rop = IrBuilder::create<GroupedReductionOp>(
      grouped_rop->getReductionOpTypes(),
      grouped_rop->initVals(),
      outputs,
      inputs,
      grouped_rop->isAllreduce());
  if (grouped_rop->predicate()) {
    indexed_rop = indexed_rop->withPredicate(grouped_rop->predicate())
                      ->as<GroupedReductionOp>();
  }
  if (grouped_rop->writePredicate()) {
    indexed_rop = indexed_rop->withWritePredicate(grouped_rop->writePredicate())
                      ->as<GroupedReductionOp>();
  }

  pushBack(indexed_rop);
  GpuLower::current()->propagateExprInfo(grouped_rop, back());
}

void IndexLowering::handleGridReduction(
    const GroupedReductionOp* grouped_rop,
    const std::vector<Val*>& outputs,
    const std::vector<Val*>& inputs) {
  const auto out_tv = ir_utils::getTvOutput(grouped_rop);
  const auto out_domain = out_tv->domain();

  TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction());

  // If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
  TORCH_INTERNAL_ASSERT(
      std::none_of(
          out_domain->domain().begin(),
          out_domain->domain().end(),
          [](IterDomain* id) {
            return !id->isThread() && id->isReduction() &&
                !id->extent()->isOneInt();
          }),
      "Found a reduction stage that has both a non-parallelized ",
      "reduction and a grid reduction. This is not supported, ",
      "please use rfactor to do the serialized reduction first, ",
      "then the grid reduction.");

  const bool is_persistent = grouped_rop->isAllreduce();
  auto work_buf_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  std::vector<kir::Allocate*> work_buffers;
  std::transform(
      outputs.begin(),
      outputs.end(),
      std::back_inserter(work_buffers),
      [&](Val* output) {
        return allocateUniqueBuffer(
            work_buf_size_info.size_of_privatized_buffer,
            output->dtype(),
            false,
            output->as<kir::TensorIndex>()->view(),
            work_buffer_map_);
      });

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate for GridReduction needs to be set
  // separately from the main predicate. Do not combine them like
  // other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto grid_reduction = IrBuilder::create<kir::GroupedGridReduction>(
      grouped_rop->getReductionOpTypes(),
      grouped_rop->initVals(),
      outputs,
      inputs,
      work_buffers,
      sync_buffer,
      entrance_ind,
      n_entrances,
      work_buf_size_info.buffer_stride,
      grouped_rop->isAllreduce());

  grid_reduction = grid_reduction->withThreadPredicate(thread_pred);

  if (grouped_rop->predicate()) {
    grid_reduction = grid_reduction->withPredicate(grouped_rop->predicate())
                         ->as<kir::GroupedGridReduction>();
  }
  if (grouped_rop->writePredicate()) {
    grid_reduction =
        grid_reduction->withWritePredicate(grouped_rop->writePredicate())
            ->as<kir::GroupedGridReduction>();
  }

  pushBack(grid_reduction);
  GpuLower::current()->propagateExprInfo(grouped_rop, back());

  if (grouped_rop->isAllreduce()) {
    allocateUniqueFusedReduction(grid_reduction, out_tv);
  }
}

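// Lower a WelfordOp: index the avg/var/N inputs and outputs, attach
// predicates, and emit either the serial/block op or a grid welford.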
void IndexLowering::handle(const WelfordOp* wop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(wop));

  const auto out_tv = wop->outAvg()->as<TensorView>();
  const auto out_domain = out_tv->domain();

  const bool has_block_reduce = out_domain->hasBlockReduction();
  const bool has_grid_reduce = out_domain->hasGridReduction();

  if (has_grid_reduce) {
    TORCH_INTERNAL_ASSERT(
        std::none_of(
            out_domain->domain().begin(),
            out_domain->domain().end(),
            [](IterDomain* id) {
              return !id->isThread() && id->isReduction();
            }),
        "Found a reduction stage that has both a non-parallelized ",
        "reduction and a grid reduction. This is not supported, ",
        "please use rfactor to do the serialized reduction first, ",
        "then the grid reduction.");
  }

  // lower IO tensors
  const auto in_var =
      wop->inVar() ? lowerSrcIndex(wop->inVar(), wop->outAvg()) : nullptr;
  const auto in_avg = lowerSrcIndex(wop->inAvg(), wop->outAvg());
  auto in_N = wop->inN();

  // In the rfactor-ed case, the input N is actually a TV.
  if (!in_N->isScalar()) {
    in_N = lowerSrcIndex(in_N, wop->outN());
  }

  auto out_avg = lowerDstIndex(wop->outAvg());
  auto out_var = lowerDstIndex(wop->outVar());
  auto out_N = lowerDstIndex(wop->outN());

  WelfordOp* indexed_wop = IrBuilder::create<WelfordOp>(
      out_avg,
      out_var,
      out_N,
      in_avg,
      in_var,
      in_N,
      wop->initAvg(),
      wop->initVar(),
      wop->initN(),
      wop->isAllreduce());

  if (wop->predicate()) {
    indexed_wop = indexed_wop->withPredicate(wop->predicate())->as<WelfordOp>();
  }
  if (wop->writePredicate()) {
    indexed_wop =
        indexed_wop->withWritePredicate(wop->writePredicate())->as<WelfordOp>();
  }

  // Serial welford
  if (!has_block_reduce && !has_grid_reduce) {
    pushBack(indexed_wop);
    GpuLower::current()->propagateExprInfo(wop, back());
    return;
  }

  // Block-only welford
  if (!has_grid_reduce) {
    pushBack(indexed_wop);
    GpuLower::current()->propagateExprInfo(wop, back());
    return;
  }

  handleGridWelford(indexed_wop);
}

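// Lower an already indexed WelfordOp into a kir::GridWelford: allocate work
// buffers for avg, var, and N plus the sync buffer, and set up predicates.
// When the block reduction is separated from the grid step, the indexed
// WelfordOp itself is also emitted before the GridWelford.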
void IndexLowering::handleGridWelford(WelfordOp* indexed_wop) {
  const auto out_tv = indexed_wop->out()->as<kir::TensorIndex>()->view();
  const auto out_domain = out_tv->domain();

  // TODO: See the comment on the same variable in handleGridReduction
  const bool is_persistent = indexed_wop->isAllreduce();
  const auto buffer_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  const auto work_buffer_size = buffer_size_info.size_of_privatized_buffer;
  auto out_avg_buffer = allocateUniqueBuffer(
      work_buffer_size,
      indexed_wop->outAvg()->dtype(),
      false,
      indexed_wop->outAvg()->as<kir::TensorIndex>()->view(),
      work_buffer_map_);
  auto out_var_buffer = allocateUniqueBuffer(
      work_buffer_size,
      indexed_wop->outVar()->dtype(),
      false,
      indexed_wop->outVar()->as<kir::TensorIndex>()->view(),
      work_buffer_map_);
  auto out_N_buffer = allocateUniqueBuffer(
      work_buffer_size,
      indexed_wop->outN()->dtype(),
      false,
      indexed_wop->outN()->as<kir::TensorIndex>()->view(),
      work_buffer_map_);

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate for GridReduction needs to be set
  // separately from the main predicate. Do not combine them like
  // other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto grid_welford = IrBuilder::create<kir::GridWelford>(
      indexed_wop,
      out_var_buffer,
      out_avg_buffer,
      out_N_buffer,
      sync_buffer,
      entrance_ind,
      n_entrances);

  grid_welford = grid_welford->withThreadPredicate(thread_pred);

  const bool block_reduce_separated =
      out_domain->hasBlockReduction() && !indexed_wop->isAllreduce();

  if (indexed_wop->predicate()) {
    if (block_reduce_separated) {
      grid_welford = grid_welford
                         ->withPredicate(IrBuilder::create<kir::Predicate>(
                             GpuLower::current()->kernel()->trueVal()))
                         ->as<kir::GridWelford>();
    } else {
      grid_welford = grid_welford->withPredicate(indexed_wop->predicate())
                         ->as<kir::GridWelford>();
    }
  }

  if (indexed_wop->writePredicate()) {
    grid_welford =
        grid_welford->withWritePredicate(indexed_wop->writePredicate())
            ->as<kir::GridWelford>();
  }

  if (block_reduce_separated) {
    pushBack(indexed_wop);
    GpuLower::current()->propagateExprInfo(indexed_wop, back());
  }

  pushBack(grid_welford);
  GpuLower::current()->propagateExprInfo(indexed_wop, back());

  if (indexed_wop->isAllreduce()) {
    // When using the fused reduction, allocate the reduction object at
    // the outer-most scope
    allocateUniqueFusedReduction(grid_welford, out_tv);
  }
}

void IndexLowering::handle(const GroupedWelfordOp* grouped_wop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(grouped_wop));

  const auto out_tv = ir_utils::getTvOutput(grouped_wop);
  const auto out_domain = out_tv->domain();

  const bool has_grid_reduce = out_domain->hasGridReduction();

  std::vector<WelfordTriplet> indexed_outputs(grouped_wop->numExprs());
  std::vector<WelfordTriplet> indexed_inputs(grouped_wop->numExprs());

  for (const auto i : c10::irange(grouped_wop->numExprs())) {
    const auto& output = grouped_wop->outputVals().at(i);
    const auto& input = grouped_wop->inputVals().at(i);
    WelfordTriplet indexed_output;
    WelfordTriplet indexed_input;
    for (const auto j : c10::irange(3)) {
      indexed_output.get(j) = lowerDstIndex(output.get(j));
      indexed_input.get(j) = lowerSrcIndex(input.get(j), output.get(j));
    }
    indexed_outputs[i] = indexed_output;
    indexed_inputs[i] = indexed_input;
  }

  if (has_grid_reduce) {
    handleGroupedGridWelford(
        grouped_wop, indexed_outputs, indexed_inputs, grouped_wop->initVals());
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Only grid welford is supported. Validation should have caught non-grid welford grouping.");
  }
}

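// Allocate one global work buffer per Welford triplet for the given value
// name (avg, var, or N), reusing existing allocations where possible.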
std::vector<kir::Allocate*> IndexLowering::allocateWelfordWorkBuffer(
    const std::vector<WelfordTriplet>& triplets,
    WelfordTriplet::ValName name,
    Val* buffer_size) {
  std::vector<kir::Allocate*> work_buffers;

  std::transform(
      triplets.begin(),
      triplets.end(),
      std::back_inserter(work_buffers),
      [&](const WelfordTriplet& output) {
        return allocateUniqueBuffer(
            buffer_size,
            output.get(name)->dtype(),
            false,
            output.get(name)->as<TensorView>(),
            work_buffer_map_);
      });

  return work_buffers;
}

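// Lower a GroupedWelfordOp into a kir::GroupedGridWelford: allocate per-group
// work buffers for avg, var, and N, the sync buffer, and the predicates.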
void IndexLowering::handleGroupedGridWelford(
    const GroupedWelfordOp* op,
    const std::vector<WelfordTriplet>& output_vals,
    const std::vector<WelfordTriplet>& input_vals,
    const std::vector<WelfordTriplet>& init_vals) {
  const auto out_tv = ir_utils::getTvOutput(op);
  const auto out_domain = out_tv->domain();

  TORCH_INTERNAL_ASSERT(out_domain->hasGridReduction());

  // If we do a grid reduction we can't have a reduction axis that is not bound
  // to a grid or block dim.
  TORCH_INTERNAL_ASSERT(
      std::none_of(
          out_domain->domain().begin(),
          out_domain->domain().end(),
          [](IterDomain* id) {
            return !id->isThread() && id->isReduction() &&
                !id->extent()->isOneInt();
          }),
      "Found a reduction stage that has both a non-parallelized ",
      "reduction and a grid reduction. This is not supported, ",
      "please use rfactor to do the serialized reduction first, ",
      "then the grid reduction.");

  const bool is_persistent = op->isAllreduce();
  auto work_buf_size_info =
      getGridCommWorkBufferSize(out_domain, for_loops_, is_persistent);

  const auto work_buffers_avg = allocateWelfordWorkBuffer(
      op->outputVals(),
      WelfordTriplet::ValName::Avg,
      work_buf_size_info.size_of_privatized_buffer);
  const auto work_buffers_var = allocateWelfordWorkBuffer(
      op->outputVals(),
      WelfordTriplet::ValName::Var,
      work_buf_size_info.size_of_privatized_buffer);
  const auto work_buffers_N = allocateWelfordWorkBuffer(
      op->outputVals(),
      WelfordTriplet::ValName::N,
      work_buf_size_info.size_of_privatized_buffer);

  auto sync_buffer_size =
      getGridSyncBufferSize(out_domain, for_loops_, is_persistent);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  const auto entrance_ind = !is_persistent
      ? getEntranceLinIndGridReduce(for_loops_)
      : GpuLower::current()->kernel()->zeroVal();
  const auto n_entrances = !is_persistent
      ? getEntranceCountGridReduce(for_loops_)
      : GpuLower::current()->kernel()->oneVal();

  // The thread predicate needs to be set separately from the main
  // predicate. Do not combine them like other expressions.
  const auto& thread_pred =
      GpuLower::current()->threadPredMap().getPredicatedParallelTypes(out_tv);

  auto indexed_op = IrBuilder::create<kir::GroupedGridWelford>(
      output_vals,
      input_vals,
      init_vals,
      std::array<std::vector<kir::Allocate*>, 3>{
          work_buffers_avg, work_buffers_var, work_buffers_N},
      sync_buffer,
      entrance_ind,
      n_entrances,
      work_buf_size_info.buffer_stride,
      op->isAllreduce());

  indexed_op = indexed_op->withThreadPredicate(thread_pred);

  if (op->predicate()) {
    indexed_op = indexed_op->withPredicate(op->predicate())
                     ->as<kir::GroupedGridWelford>();
  }
  if (op->writePredicate()) {
    indexed_op = indexed_op->withWritePredicate(op->writePredicate())
                     ->as<kir::GroupedGridWelford>();
  }

  pushBack(indexed_op);
  GpuLower::current()->propagateExprInfo(op, back());

  if (op->isAllreduce()) {
    allocateUniqueFusedReduction(indexed_op, out_tv);
  }
}

void IndexLowering::handle(const LoadStoreOp* ldst) {
  const auto in = lowerSrcIndex(ldst->in(), ldst->out());
  const auto out = lowerDstIndex(ldst->out());
  auto new_ldst = IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in)
                      ->withPredicate(ldst->predicate());
  pushBack(new_ldst);
  GpuLower::current()->propagateExprInfo(ldst, back());
}

void IndexLowering::handle(const MmaOp* mma) {
  const auto a = lowerSrcIndex(mma->inA(), mma->out());
  const auto b = lowerSrcIndex(mma->inB(), mma->out());
  const auto out = lowerDstIndex(mma->out());
  auto mma_indexed =
      IrBuilder::create<MmaOp>(out, a, b, mma->init(), mma->options());
  pushBack(mma_indexed);
  GpuLower::current()->propagateExprInfo(mma, back());
}

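// Lower a BroadcastOp. If the broadcast is parallelized across grid
// dimensions, allocate work and sync buffers and wrap the indexed op in a
// kir::GridBroadcast.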
void IndexLowering::handle(const BroadcastOp* bop) {
  TORCH_INTERNAL_ASSERT(ir_utils::isTvOp(bop));

  const auto out_tv = bop->out()->as<TensorView>();

  const auto out = lowerDstIndex(bop->out());
  const auto in = lowerSrcIndex(bop->in(), bop->out());
  auto indexed_expr =
      IrBuilder::create<BroadcastOp>(out, in, bop->getBroadcastDimFlags());

  const ParallelTypeBitmap parallel_bitmap =
      GpuLower::current()->threadPredMap().getParallelBroadcastDomains(out_tv);

  const bool block_x = parallel_bitmap.get(ParallelType::BIDx);
  const bool block_y = parallel_bitmap.get(ParallelType::BIDy);
  const bool block_z = parallel_bitmap.get(ParallelType::BIDz);

  if (bop->predicate()) {
    indexed_expr =
        indexed_expr->withPredicate(bop->predicate())->as<BroadcastOp>();
  }

  const bool grid_broadcast_needed = block_x || block_y || block_z;
  if (!grid_broadcast_needed) {
    pushBack(indexed_expr);
    GpuLower::current()->propagateExprInfo(bop, back());
    return;
  }

  // Grid broadcast
  const auto out_domain = out_tv->domain();
  const auto work_buffer_size =
      getGridCommWorkBufferSize(out_domain, for_loops_, true)
          .size_of_privatized_buffer;

  auto work_buffer = allocateUniqueBuffer(
      work_buffer_size, out->dtype(), false, out_tv, work_buffer_map_);

  auto sync_buffer_size = getGridSyncBufferSize(out_domain, for_loops_, true);
  auto sync_buffer = allocateUniqueBuffer(
      sync_buffer_size, DataType::Int, true, out_tv, sync_buffer_map_);

  auto grid_broadcast = IrBuilder::create<kir::GridBroadcast>(
      indexed_expr, work_buffer, sync_buffer);

  if (bop->predicate()) {
    grid_broadcast = grid_broadcast->withPredicate(bop->predicate())
                         ->as<kir::GridBroadcast>();
  }

  pushBack(grid_broadcast);
  GpuLower::current()->propagateExprInfo(bop, back());
}

void IndexLowering::handle(const kir::Allocate* allocate) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::Allocate*>(allocate)); // NOLINT
}

void IndexLowering::handle(const kir::BlockSync* sync) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::BlockSync*>(sync)); // NOLINT
}

void IndexLowering::handle(const kir::GridSync* sync) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::GridSync*>(sync)); // NOLINT
}

void IndexLowering::handle(const kir::CpAsyncWait* wait) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::CpAsyncWait*>(wait)); // NOLINT
}

void IndexLowering::handle(const kir::CpAsyncCommit* commit) {
  // TODO(kir): remove the need for const_cast
  pushBack(const_cast<kir::CpAsyncCommit*>(commit)); // NOLINT
}

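// Entry point: lower each top-level expression in order.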
void IndexLowering::generate(const std::vector<Expr*>& exprs) {
  for (auto expr : exprs) {
    OptOutConstDispatch::handle(expr);
  }
}

kir::Allocate* IndexLowering::allocateUniqueBuffer(
    Val* buffer_size,
    DataType dtype,
    bool zero_init,
    TensorView* out_tv,
    std::unordered_map<TensorView*, kir::Allocate*>& alloc_map) {
  // Return an existing allocation if one exists
  auto it = alloc_map.find(out_tv);
  if (it != alloc_map.end()) {
    return it->second;
  }

  // No existing allocation found. Create a new one
  auto new_buffer =
      lower_utils::allocGlobalBufferForGridComm(buffer_size, dtype, zero_init);

  // Keep track of the allocation
  alloc_map.emplace(out_tv, new_buffer);

  // A buffer may be used in both of the unswitched paths, so it must be
  // placed outside of the current scope. Simply placing it at the
  // top-level scope should work.
  insertAtTopLevel(new_buffer);

  return new_buffer;
}

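// Allocate the AllocateFusedReduction object for an allreduce expression at
// the top-level scope, creating it at most once per output TensorView.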
void IndexLowering::allocateUniqueFusedReduction(
    Expr* expr,
    TensorView* out_tv) {
  auto it = fused_reduction_map_.find(out_tv);
  if (it != fused_reduction_map_.end()) {
    return;
  }

  kir::AllocateFusedReduction* fused_reduction_alloc_reduction = nullptr;
  switch (expr->getExprType().value()) {
    case ExprType::GridReduction:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GridReduction>());
      break;
    case ExprType::GridWelford:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GridWelford>());
      break;
    case ExprType::GroupedGridReduction:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GroupedGridReduction>());
      break;
    case ExprType::GroupedGridWelford:
      fused_reduction_alloc_reduction =
          IrBuilder::create<kir::AllocateFusedReduction>(
              expr->as<kir::GroupedGridWelford>());
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
  }

  fused_reduction_map_.emplace(out_tv, fused_reduction_alloc_reduction);

  // When using the fused reduction, allocate the reduction object at
  // the outer-most scope
  insertAtTopLevel(fused_reduction_alloc_reduction);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch