#include <ir_builder.h>
#include <kernel.h>
#include <kernel_expr_evaluator.h>
#include <kernel_ir.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <type.h>

#include <iostream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace kir {

Predicate::Predicate(
    IrBuilderPasskey passkey,
    PredicateType ptype,
    const Expr* expr,
    Bool* thread_pred)
    : Val(passkey, ValType::Predicate, DataType::Bool),
      ptype_(ptype),
      expr_(expr),
      thread_pred_(thread_pred) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(
      ptype != PredicateType::Unswitch && ptype != PredicateType::Manual);
}

Predicate::Predicate(IrBuilderPasskey passkey, ForLoop* unrolled_loop)
    : Val(passkey, ValType::Predicate, DataType::Bool),
      ptype_(PredicateType::Unswitch),
      unrolled_loop_(unrolled_loop) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(unrolled_loop != nullptr);
}

Predicate::Predicate(IrBuilderPasskey passkey, Bool* value)
    : Val(passkey, ValType::Predicate, DataType::Bool),
      ptype_(PredicateType::Manual),
      value_(value) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(value != nullptr);
}
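
// Illustrative sketch (assumption, not part of the original logic): the three
// constructors above correspond to the three ways a predicate is typically
// built via IrBuilder while lowering into a Kernel container, e.g.
//   auto pred = IrBuilder::create<Predicate>(ptype, expr, thread_pred);
//   auto unswitch_pred = IrBuilder::create<Predicate>(unrolled_loop);
//   auto manual_pred = IrBuilder::create<Predicate>(some_bool_val);
// The first form rejects Unswitch and Manual predicate types, which must go
// through their dedicated constructors.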

TensorIndex::TensorIndex(
    IrBuilderPasskey passkey,
    const TensorView* view,
    std::vector<Val*> indices)
    : Val(passkey, ValType::TensorIndex, view->getDataType().value()),
      view_(view),
      indices_(indices) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(
      std::all_of(
          indices.begin(),
          indices.end(),
          [](Val* v) { return v->dtype() == DataType::Int; }),
      "Cannot index with a value other than an int.");
  indices_.erase(
      std::remove_if(
          indices_.begin(),
          indices_.end(),
          [](Val* index) { return index->isZeroInt(); }),
      indices_.end());
  // If indices becomes empty, just put one ZeroInt
  if (indices_.empty()) {
    indices_.push_back(FusionGuard::getCurFusion()->zeroVal());
  }
}
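
// Illustrative example (assumption, for exposition only): given indices
// {0, i1, 0, i3}, the compile-time-zero entries are pruned and the stored
// indices become {i1, i3}; if every index is a compile-time zero, a single
// ZeroInt is kept so the TensorIndex is never zero-dimensional.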

Val* TensorIndex::index(int i) const {
  TORCH_INTERNAL_ASSERT(
      nDims() > 0, "Tried to get an index of a 0-dim TensorIndex");
  if (i < 0)
    i += nDims();
  TORCH_INTERNAL_ASSERT(i >= 0 && i < int(nDims()));
  return indices_[i];
}

BlockSync::BlockSync(IrBuilderPasskey passkey, bool war_sync)
    : Expr(passkey, ExprType::BlockSync), war_sync_(war_sync) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* BlockSync::shallowCopy() const {
  auto result = IrBuilder::create<BlockSync>(war_sync_);
  result->copyPredicatesFrom(this);
  return result;
}

GridSync::GridSync(
    IrBuilderPasskey passkey,
    ParallelTypeBitmap sync_dims,
    Val* sync_buffer)
    : Expr(passkey, ExprType::GridSync),
      sync_dims_(sync_dims),
      sync_buffer_(sync_buffer) {}

Expr* GridSync::shallowCopy() const {
  auto result = IrBuilder::create<GridSync>(sync_dims_, sync_buffer_);
  result->copyPredicatesFrom(this);
  return result;
}

CpAsyncWait::CpAsyncWait(IrBuilderPasskey passkey, unsigned int keep_stages)
    : Expr(passkey, ExprType::CpAsyncWait), keep_stages_(keep_stages) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* CpAsyncWait::shallowCopy() const {
  auto result = IrBuilder::create<CpAsyncWait>(keep_stages_);
  result->copyPredicatesFrom(this);
  return result;
}

CpAsyncCommit::CpAsyncCommit(IrBuilderPasskey passkey)
    : Expr(passkey, ExprType::CpAsyncCommit) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* CpAsyncCommit::shallowCopy() const {
  auto result = IrBuilder::create<CpAsyncCommit>();
  result->copyPredicatesFrom(this);
  return result;
}

InitMagicZero::InitMagicZero(IrBuilderPasskey passkey)
    : Expr(passkey, ExprType::InitMagicZero) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* InitMagicZero::shallowCopy() const {
  auto result = IrBuilder::create<InitMagicZero>();
  result->copyPredicatesFrom(this);
  return result;
}

UpdateMagicZero::UpdateMagicZero(IrBuilderPasskey passkey)
    : Expr(passkey, ExprType::UpdateMagicZero) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* UpdateMagicZero::shallowCopy() const {
  auto result = IrBuilder::create<UpdateMagicZero>();
  result->copyPredicatesFrom(this);
  return result;
}

namespace {

bool isIntegralScalar(const Val* val) {
  return val->isScalar() && val->getDataType().has_value() &&
      isIntegralType(val->getDataType().value());
}

} // namespace

IntPair::IntPair(IrBuilderPasskey passkey)
    : Val(passkey, ValType::IntPair, DataType::Index) {}

PairSelect::PairSelect(
    IrBuilderPasskey passkey,
    Val* out,
    IntPair* in,
    PairSelect::Selection selection)
    : Expr(passkey, ExprType::PairSelect),
      out_{out},
      in_{in},
      selection_(selection) {
  addOutput(out);
  addInput(in);
  TORCH_INTERNAL_ASSERT(isIntegralScalar(out), "Integer only for this op");
}

Expr* PairSelect::shallowCopy() const {
  auto result = IrBuilder::create<PairSelect>(out_, in_, selection_);
  result->copyPredicatesFrom(this);
  return result;
}

Swizzle2DInt::Swizzle2DInt(
    IrBuilderPasskey passkey,
    IntPair* out,
    Val* in_x,
    Val* in_y,
    Val* extent_x,
    Val* extent_y,
    Swizzle2DType swizzle_type)
    : Expr(passkey, ExprType::Swizzle2DInt),
      out_{out},
      in_x_{in_x},
      in_y_{in_y},
      extent_x_(extent_x),
      extent_y_(extent_y),
      swizzle_type_(swizzle_type) {
  TORCH_INTERNAL_ASSERT(isIntegralScalar(in_x), "Integer only for this op");
  TORCH_INTERNAL_ASSERT(isIntegralScalar(in_y), "Integer only for this op");

  addOutput(out);
  addInput(in_x);
  addInput(in_y);
  addInput(extent_x);
  addInput(extent_y);
}

Expr* Swizzle2DInt::shallowCopy() const {
  auto result = IrBuilder::create<Swizzle2DInt>(
      out_, in_x_, in_y_, extent_x_, extent_y_, swizzle_type_);
  result->copyPredicatesFrom(this);
  return result;
}

void Scope::insert(std::vector<Expr*>::const_iterator pos, Expr* expr) {
  exprs_.insert(pos, expr);
}

void Scope::insert_before(Expr* ref, Expr* expr) {
  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
  TORCH_INTERNAL_ASSERT(
      it != exprs_.end(),
      "Tried to insert ",
      expr,
      " before the reference: ",
      ref,
      " however the reference was not found in this scope.");
  insert(it, expr);
}

void Scope::insert_after(Expr* ref, Expr* expr) {
  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
  TORCH_INTERNAL_ASSERT(
      it != exprs_.end(),
      "Tried to insert ",
      expr,
      " after the reference: ",
      ref,
      " however the reference was not found in this scope.");
  insert(it + 1, expr);
}

void Scope::insert(size_t pos, Expr* expr) {
  const auto it = exprs_.begin() + pos;
  insert(it, expr);
}

void Scope::erase(std::vector<Expr*>::const_iterator pos) {
  // Remove the scope of the expr if this is the scope
  C10_UNUSED auto expr = *pos;
  exprs_.erase(pos);
}

void Scope::erase(Expr* ref) {
  const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
  if (it != exprs_.end()) {
    erase(it);
  }
}

void Scope::erase(size_t pos) {
  TORCH_INTERNAL_ASSERT(pos < size());
  erase(exprs_.begin() + pos);
}

bool Scope::contains(Expr* expr) const {
  const auto it = std::find(exprs_.begin(), exprs_.end(), expr);
  return it != exprs_.end();
}

void Scope::clear() {
  exprs_.clear();
}
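
// Illustrative usage sketch (assumption, for exposition only): a lowering pass
// that needs to place a block sync right after an expression in a scope could
// write something like
//   scope.insert_after(some_expr, IrBuilder::create<BlockSync>(/*war_sync=*/false));
// insert_before/insert_after assert that the reference expression is actually
// present in the scope, while erase(Expr*) is simply a no-op when it is not.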

ForLoop::ForLoop(
    IrBuilderPasskey passkey,
    IterDomain* iter_domain,
    Val* index,
    Val* start,
    Val* stop,
    Val* step,
    bool vectorize,
    Val* vectorize_shift,
    bool unroll_required,
    DoubleBufferLoopStage double_buffer_loop_stage)
    : Expr(passkey, ExprType::ForLoop),
      iter_domain_{iter_domain},
      index_(index),
      start_(start),
      stop_(stop),
      step_(step),
      vectorize_(vectorize),
      vectorize_shift_(vectorize_shift),
      unroll_required_(unroll_required),
      body_(this),
      double_buffer_loop_stage_(double_buffer_loop_stage) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  TORCH_INTERNAL_ASSERT(index->dtype() == DataType::Int);
  addInput(index);
  addInput(iter_domain);
  if (start_ == nullptr && iter_domain->isThread()) {
    start_ = NamedScalar::getParallelIndex(iter_domain->getParallelType());
  }
  if (step_ == nullptr) {
    if (iter_domain->isThread()) {
      step_ = NamedScalar::getParallelDim(iter_domain->getParallelType());
    } else {
      step_ = FusionGuard::getCurFusion()->oneVal();
    }
  }
}
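
// Illustrative note (assumption, for exposition only): for an iter domain
// parallelized on TIDx, the defaults filled in above correspond to a loop of
// the form
//   for (int i = threadIdx.x; i < stop; i += blockDim.x) { ... }
// whereas a serial iter domain defaults to a unit step:
//   for (int i = start; i < stop; i += 1) { ... }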

ForLoop::ForLoop(IrBuilderPasskey passkey, IterDomain* iter_domain)
    : ForLoop(
          passkey,
          iter_domain,
          GpuLower::current()->caMap()->getIndexVariable(iter_domain),
          nullptr,
          nullptr,
          nullptr,
          !iter_domain->isBroadcast() &&
              isParallelTypeVectorize(iter_domain->getParallelType()),
          nullptr,
          false,
          DoubleBufferLoopStage::NotApplicable) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

ForLoop::ForLoop(IrBuilderPasskey passkey, const ForLoop* other)
    : ForLoop(
          passkey,
          other->iter_domain(),
          other->index(),
          other->start(),
          other->stop(),
          other->step(),
          other->vectorize(),
          other->vectorize_shift(),
          other->isUnrollRequired(),
          other->doubleBufferLoopStage()) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* ForLoop::shallowCopy() const {
  auto result = IrBuilder::create<ForLoop>(
      iter_domain_,
      index_,
      start_,
      stop_,
      step_,
      vectorize_,
      vectorize_shift_,
      unroll_required_,
      double_buffer_loop_stage_);
  result->body_ = body_;
  result->copyPredicatesFrom(this);
  return result;
}

bool ForLoop::isUnrollable() const {
  // Start and stop must be constant, must not be a broadcast
  // dimension, cannot be bound to a parallel dimension, must not be
  // vectorized.
  return start()->isConstScalar() && stop()->isConstScalar() &&
      !iter_domain()->isThread() && !iter_domain()->isBroadcast() &&
      !vectorize();
}

bool ForLoop::isUnrolled() const {
  if (isUnrollRequired() && !isUnrollable()) {
    TORCH_WARN(
        "Unroll required but not possible. Register allocation disabled. Loop index: ",
        index_->toString());
    return false;
  }

  // Size-one loop will not be materialized as a loop, so return false
  if (start()->isZeroInt() && stop()->isOneInt()) {
    return false;
  }

  // Unroll if required.
  if (isUnrollRequired()) {
    return true;
  }

  // Don't unroll if not possible
  if (!isUnrollable()) {
    return false;
  }

  // Unrolling is technically possible but avoided
  if (iter_domain()->getParallelType() == ParallelType::Unswitch) {
    // Use ParallelType::Unroll if unrolling is desired. Note that
    // unswitched size-one loops are not unrolled as they are not
    // materialized as actual for-loops.
    return false;
  }

  return true;
}

Val* ForLoop::start() const {
  if (start_ != nullptr) {
    return start_;
  } else {
    // clang-tidy complains without this
    TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
    return iter_domain_->start();
  }
}

Val* ForLoop::stop() const {
  if (stop_ != nullptr) {
    return stop_;
  } else {
    // clang-tidy complains without this
    TORCH_INTERNAL_ASSERT(iter_domain_ != nullptr);
    return iter_domain_->extent();
  }
}

Val* ForLoop::step() const {
  TORCH_INTERNAL_ASSERT(step_ != nullptr);
  return step_;
}

bool ForLoop::isTrivial() const {
  // These loops are not materialized
  if (vectorize() || iter_domain()->isBroadcast() ||
      iter_domain()->isStride() || iter_domain()->isMma()) {
    return true;
  }

  // By default, a parallelized loop would look like:
  //
  //   for (int x = threadIdx.x; x < stop; x += blockDim.x) {
  //     do_some_comp(x);
  //   }
  //
  // When stop is guaranteed to be smaller or equal to the number of
  // threads, the for-loop is not necessary. In the above case, we
  // would just generate the loop body without the for clause but
  // references to the loop index replaced by the loop start value.
  //
  // When the loop end is the same as the IterDomain extent, the
  // assumption can be safely made. This is more conservative than
  // necessary since the loop stop value just needs to be <= the
  // IterDomain extent. However, at this point, this conservative
  // analysis seems sufficient.
  if (stop() == iter_domain()->extent() && iter_domain()->isThread()) {
    return true;
  }

  // Extent-1 loop: for (int i = 0; i < 1; ++i) {
  if (start()->isZeroInt() && stop()->isOneInt() && step()->isOneInt()) {
    return true;
  }

  // Another extent-1 loop: for (int i = N - 1; i < N; ++i) {
  if (start()->definition() != nullptr &&
      start()->definition()->isA<BinaryOp>() &&
      start()->definition()->as<BinaryOp>()->getBinaryOpType() ==
          BinaryOpType::Sub &&
      start()->definition()->as<BinaryOp>()->lhs() == stop() &&
      start()->definition()->as<BinaryOp>()->rhs()->isOneInt()) {
    return true;
  }

  return false;
}

IfThenElse::IfThenElse(IrBuilderPasskey passkey, Predicate* cond)
    : Expr(passkey, ExprType::IfThenElse), then_body_(this), else_body_(this) {
  setPredicate(cond);
  addInput(cond);
}

Expr* IfThenElse::shallowCopy() const {
  auto result = IrBuilder::create<IfThenElse>(predicate());
  result->then_body_ = then_body_;
  result->else_body_ = else_body_;
  result->setWritePredicate(writePredicate());
  return result;
}

Allocate::Allocate(
    IrBuilderPasskey passkey,
    Val* buffer,
    MemoryType memory_type,
    std::vector<Val*> shape,
    bool zero_init)
    : Expr(passkey, ExprType::Allocate),
      buffer_(buffer),
      memory_type_(memory_type),
      shape_(std::move(shape)),
      zero_init_(zero_init) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
  if (!shape_.empty()) {
    TORCH_INTERNAL_ASSERT(
        (shape_.size() == 1 && shape_[0]->isOneInt()) ||
        buffer_->isA<TensorView>());
  } else {
    TORCH_INTERNAL_ASSERT(buffer_->isA<TensorView>());
    TORCH_INTERNAL_ASSERT(
        buffer_->as<TensorView>()->getMemoryType() == memory_type_);
    const auto domain = buffer_->as<TensorView>()->domain();
    for (auto axis : domain->noReductions()) {
      shape_.push_back(axis->extent());
    }
  }

  for (auto s : shape_) {
    if (size_ == nullptr) {
      size_ = s;
    } else {
      size_ = IrBuilder::mulExpr(size_, s);
    }
  }

  if (size_ == nullptr) {
    size_ = FusionGuard::getCurFusion()->oneVal();
  }

  addInput(size_);
}
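
// Worked example (assumption, for exposition only): allocating a TensorView
// whose non-reduction axes have extents {4, 8} yields shape_ = {4, 8} and
// size_ = mulExpr(4, 8), i.e. 32 elements; if the buffer has no non-reduction
// axes at all, size_ falls back to oneVal(), i.e. a single element.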

Allocate::Allocate(
    IrBuilderPasskey passkey,
    Val* buffer,
    MemoryType memory_type,
    Val* size,
    bool zero_init)
    : Allocate(
          passkey,
          buffer,
          memory_type,
          size == nullptr ? std::vector<Val*>{} : std::vector<Val*>{size},
          zero_init) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* Allocate::shallowCopy() const {
  auto result =
      IrBuilder::create<Allocate>(buffer_, memory_type_, shape_, zero_init_);
  result->copyPredicatesFrom(this);
  return result;
}

GridReduction::GridReduction(
    IrBuilderPasskey passkey,
    BinaryOpType reduction_op_type,
    Val* init,
    Val* out,
    Val* in,
    Allocate* reduction_buffer,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances,
    bool is_allreduce)
    : ReductionOp(
          passkey,
          reduction_op_type,
          init,
          out,
          in,
          is_allreduce,
          ExprType::GridReduction),
      reduction_buffer_(reduction_buffer),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GridReduction::shallowCopy() const {
  auto result = IrBuilder::create<GridReduction>(
      getReductionOpType(),
      init(),
      out(),
      in(),
      reduction_buffer_,
      sync_buffer_,
      entrance_index_,
      entrances_,
      isAllreduce());
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

GroupedGridReduction::GroupedGridReduction(
    IrBuilderPasskey passkey,
    std::vector<BinaryOpType> reduction_op_types,
    std::vector<Val*> init_vals,
    std::vector<Val*> outputs,
    std::vector<Val*> inputs,
    std::vector<Allocate*> reduction_buffers,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances,
    Val* buffer_stride,
    bool is_allreduce)
    : GroupedReductionOp(
          passkey,
          std::move(reduction_op_types),
          std::move(init_vals),
          std::move(outputs),
          std::move(inputs),
          is_allreduce,
          ExprType::GroupedGridReduction),
      reduction_buffers_(std::move(reduction_buffers)),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances),
      buffer_stride_(buffer_stride) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GroupedGridReduction::shallowCopy() const {
  auto result = IrBuilder::create<GroupedGridReduction>(
      getReductionOpTypes(),
      initVals(),
      outputs(),
      inputs(),
      reduction_buffers_,
      sync_buffer_,
      entrance_index_,
      entrances_,
      buffer_stride_,
      isAllreduce());
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

GridBroadcast::GridBroadcast(
    IrBuilderPasskey passkey,
    BroadcastOp* broadcast_op,
    Allocate* broadcast_buffer,
    Allocate* sync_buffer)
    : Expr(passkey, ExprType::GridBroadcast),
      broadcast_op_(broadcast_op),
      broadcast_buffer_(broadcast_buffer),
      sync_buffer_(sync_buffer) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GridBroadcast::shallowCopy() const {
  auto result = IrBuilder::create<GridBroadcast>(
      broadcast_op_, broadcast_buffer_, sync_buffer_);
  result->copyPredicatesFrom(this);
  return result;
}

GridWelford::GridWelford(
    IrBuilderPasskey passkey,
    WelfordOp* welford_op,
    Allocate* var_buffer,
    Allocate* avg_buffer,
    Allocate* n_buffer,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances)
    : Expr(passkey, ExprType::GridWelford),
      welford_op_(welford_op),
      var_buffer_(var_buffer),
      avg_buffer_(avg_buffer),
      n_buffer_(n_buffer),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GridWelford::shallowCopy() const {
  auto result = IrBuilder::create<GridWelford>(
      welford_op_,
      var_buffer_,
      avg_buffer_,
      n_buffer_,
      sync_buffer_,
      entrance_index_,
      entrances_);
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

GroupedGridWelford::GroupedGridWelford(
    IrBuilderPasskey passkey,
    std::vector<WelfordTriplet> output_vals,
    std::vector<WelfordTriplet> input_vals,
    std::vector<WelfordTriplet> init_vals,
    std::array<std::vector<Allocate*>, 3> reduction_buffers,
    Allocate* sync_buffer,
    Val* entrance_index,
    Val* entrances,
    Val* buffer_stride,
    bool is_allreduce)
    : GroupedWelfordOp(
          passkey,
          std::move(output_vals),
          std::move(input_vals),
          std::move(init_vals),
          is_allreduce,
          ExprType::GroupedGridWelford),
      reduction_buffers_(std::move(reduction_buffers)),
      sync_buffer_(sync_buffer),
      entrance_index_(entrance_index),
      entrances_(entrances),
      buffer_stride_(buffer_stride) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* GroupedGridWelford::shallowCopy() const {
  auto result = IrBuilder::create<GroupedGridWelford>(
      outputVals(),
      inputVals(),
      initVals(),
      reduction_buffers_,
      sync_buffer_,
      entrance_index_,
      entrances_,
      buffer_stride_,
      isAllreduce());
  result->copyPredicatesFrom(this);
  result->thread_predicate_ = thread_predicate_;
  return result;
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GridReduction* grid_reduction)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grid_reduction) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GridWelford* grid_welford)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grid_welford) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GroupedGridReduction* grouped_grid_reduction)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grouped_grid_reduction) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

AllocateFusedReduction::AllocateFusedReduction(
    IrBuilderPasskey passkey,
    GroupedGridWelford* grouped_grid_welford)
    : Expr(passkey, ExprType::AllocateFusedReduction),
      grid_expr_(grouped_grid_welford) {
  TORCH_INTERNAL_ASSERT(
      passkey.ir_container_->isA<kir::Kernel>(),
      "IR type only valid for Kernel container.");
}

Expr* AllocateFusedReduction::shallowCopy() const {
  if (grid_expr_->isA<GridReduction>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GridReduction>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  } else if (grid_expr_->isA<GridWelford>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GridWelford>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  } else if (grid_expr_->isA<GroupedGridReduction>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GroupedGridReduction>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  } else if (grid_expr_->isA<GroupedGridWelford>()) {
    auto result = IrBuilder::create<AllocateFusedReduction>(
        grid_expr_->as<GroupedGridWelford>());
    result->setPredicate(predicate());
    result->setWritePredicate(writePredicate());
    return result;
  }
  TORCH_INTERNAL_ASSERT(
      false, "Unknown reduction type in AllocateFusedReduction::shallowCopy");
}

TensorIndex* AllocateFusedReduction::out() const {
  TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr);
  if (grid_expr_->isA<GridReduction>() ||
      grid_expr_->isA<GroupedGridReduction>()) {
    return grid_expr_->outputs().at(0)->as<kir::TensorIndex>();
  } else if (auto grid_welford = dynamic_cast<GridWelford*>(grid_expr_)) {
    return grid_welford->welford_op()->out()->as<kir::TensorIndex>();
  } else if (
      auto grouped_grid_welford =
          dynamic_cast<GroupedGridWelford*>(grid_expr_)) {
    return grouped_grid_welford->out(0)->as<kir::TensorIndex>();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Invalid grid expression: ", grid_expr_->toString());
  }
}

const ParallelTypeBitmap& AllocateFusedReduction::threadPredicate() const {
  TORCH_INTERNAL_ASSERT(grid_expr_ != nullptr);
  if (auto grid_reduction = dynamic_cast<GridReduction*>(grid_expr_)) {
    return grid_reduction->threadPredicate();
  } else if (auto grid_welford = dynamic_cast<GridWelford*>(grid_expr_)) {
    return grid_welford->threadPredicate();
  } else if (
      auto grouped_grid_reduction =
          dynamic_cast<GroupedGridReduction*>(grid_expr_)) {
    return grouped_grid_reduction->threadPredicate();
  } else if (
      auto grouped_grid_welford =
          dynamic_cast<GroupedGridWelford*>(grid_expr_)) {
    return grouped_grid_welford->threadPredicate();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Invalid grid expression: ", grid_expr_->toString());
  }
}

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch