#include <lower_utils.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <arith.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_thread_predicate.h>
#include <root_domain_map.h>

#include <algorithm>

// TODO: refactor this file (one per namespace)

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace scope_utils {

//! Create an **empty** ForLoop and copy the metadata.
kir::ForLoop* cloneForLoop(kir::ForLoop* for_loop) {
  return IrBuilder::create<kir::ForLoop>(for_loop);
}

//! Create an **empty** IfThenElse and copy the metadata.
kir::IfThenElse* cloneIfThenElse(kir::IfThenElse* ite) {
  return IrBuilder::create<kir::IfThenElse>(ite->predicate());
}

} // namespace scope_utils

namespace ir_utils {

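//! RAII helper: the constructor swaps in the given TensorDomain for tv and
//! records the previous domain; the destructor restores it (unless the
//! guard has been moved from).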
TVDomainGuard::TVDomainGuard(TensorView* tv, TensorDomain* td)
    : tv_(tv), prev_domain_(tv_->domain()) {
  tv_->setDomain(td);
}

TVDomainGuard::TVDomainGuard(TVDomainGuard&& guard)
    : tv_(nullptr), prev_domain_(guard.prev_domain_) {
  std::swap(tv_, guard.tv_);
}

TVDomainGuard::~TVDomainGuard() {
  if (tv_ != nullptr) {
    tv_->setDomain(prev_domain_);
  }
}

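//! Returns a TVDomainGuard that temporarily overrides the contiguity of all
//! of tv's (maybe rfactor) root domains with the given value. Illustrative
//! usage (the surrounding code is only a sketch):
//!
//!   {
//!     auto guard = ir_utils::overrideContiguityGuard(tv, true);
//!     // tv is treated as fully contiguous within this scope
//!   } // the original domain, and thus contiguity, is restored here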
ir_utils::TVDomainGuard overrideContiguityGuard(
    TensorView* tv,
    bool contiguity) {
  // Use a domain guard to temporarily override the contiguity of the
  // given tv.
  TensorDomain* domain_with_specified_contiguity = nullptr;
  std::vector<bool> contiguity_vector(
      tv->getMaybeRFactorDomain().size(), contiguity);
  if (tv->hasRFactor()) {
    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
        tv->getRootDomain(),
        tv->getRFactorDomain(),
        tv->domain()->domain(),
        contiguity_vector);
  } else {
    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
        tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
  }

  return ir_utils::TVDomainGuard(tv, domain_with_specified_contiguity);
}

std::vector<IterDomain*> iterDomainInputsOf(
    const std::vector<IterDomain*>& input_ids,
    const std::vector<IterDomain*>& all_inputs) {
  auto inputs = IterVisitor::getInputsTo(
      {input_ids.begin(), input_ids.end()},
      {all_inputs.begin(), all_inputs.end()});
  std::vector<IterDomain*> id_inputs(
      ir_utils::filterByType<IterDomain>(inputs).begin(),
      ir_utils::filterByType<IterDomain>(inputs).end());
  return id_inputs;
}

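//! Same as iterDomainInputsOf, but the returned inputs are ordered as they
//! appear in `order`.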
std::vector<IterDomain*> iterDomainInputsOfOrderedAs(
    const std::vector<IterDomain*>& of,
    const std::vector<IterDomain*>& order) {
  auto inputs_vec = iterDomainInputsOf(of, order);

  std::unordered_set<IterDomain*> inputs_set(
      inputs_vec.begin(), inputs_vec.end());

  std::vector<IterDomain*> ordered_inputs;
  std::copy_if(
      order.begin(),
      order.end(),
      std::back_inserter(ordered_inputs),
      [&inputs_set](const auto& id) {
        return inputs_set.find(id) != inputs_set.end();
      });

  return ordered_inputs;
}

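//! Returns true if val is a TensorView or a TensorIndex.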
bool isTV(const Val* val) {
  return val->getValType().value() == ValType::TensorView ||
      val->getValType().value() == ValType::TensorIndex;
}

// Check if expr is a TensorView op that we can generate code for.
bool isTvOp(const Expr* expr) {
  if (std::any_of(
          expr->outputs().begin(),
          expr->outputs().end(),
          [](Val* v) { return isTV(v); }) &&
      (expr->getExprType().value() == ExprType::UnaryOp ||
       expr->getExprType().value() == ExprType::BinaryOp ||
       expr->getExprType().value() == ExprType::TernaryOp ||
       expr->getExprType().value() == ExprType::RNGOp ||
       expr->getExprType().value() == ExprType::FullOp ||
       expr->getExprType().value() == ExprType::ARangeOp ||
       expr->getExprType().value() == ExprType::EyeOp ||
       expr->getExprType().value() == ExprType::ReductionOp ||
       expr->getExprType().value() == ExprType::GroupedReductionOp ||
       expr->getExprType().value() == ExprType::WelfordOp ||
       expr->getExprType().value() == ExprType::GroupedWelfordOp ||
       expr->getExprType().value() == ExprType::LoadStoreOp ||
       expr->getExprType().value() == ExprType::MmaOp ||
       expr->getExprType().value() == ExprType::BroadcastOp ||
       expr->getExprType().value() == ExprType::TransposeOp ||
       expr->getExprType().value() == ExprType::ExpandOp ||
       expr->getExprType().value() == ExprType::ShiftOp ||
       expr->getExprType().value() == ExprType::GatherOp ||
       expr->getExprType().value() == ExprType::ViewAsScalar ||
       expr->getExprType().value() == ExprType::ViewOp ||
       expr->getExprType().value() == ExprType::GridReduction ||
       expr->getExprType().value() == ExprType::GroupedGridReduction ||
       expr->getExprType().value() == ExprType::GridBroadcast ||
       expr->getExprType().value() == ExprType::GridWelford ||
       expr->getExprType().value() == ExprType::GroupedGridWelford)) {
    return true;
  }
  return false;
}

bool isLdMatrixOp(const Expr* expr) {
  if (auto ldst = dynamic_cast<const LoadStoreOp*>(expr)) {
    return ldst->opType() == LoadStoreOpType::LdMatrix ||
        ldst->opType() == LoadStoreOpType::LdMatrixTranspose;
  }
  return false;
}

bool isCpAsyncOp(const Expr* expr) {
  if (auto ldst = dynamic_cast<const LoadStoreOp*>(expr)) {
    return ldst->opType() == LoadStoreOpType::CpAsync;
  }
  return false;
}

bool isTensorScalarFillOp(const Expr* expr) {
  // Check that the input is a single scalar.
  if (expr->inputs().size() == 1 && expr->input(0)->isScalar()) {
    // Any load-store op with a single scalar input is a scalar filling
    // op. Semantically it means `Store`'ing a scalar into a tensor.
    if (expr->isA<LoadStoreOp>()) {
      return true;
    }
    // A unary copy (Set) op is also a scalar filling op.
    if (auto uop = dynamic_cast<const UnaryOp*>(expr)) {
      return uop->getUnaryOpType() == UnaryOpType::Set;
    }
  }
  // Ideally any scalar expression that outputs to a tensor should be
  // considered here, but since we currently limit the scope to
  // initialization patterns, other scalar exprs are low priority and are
  // excluded to avoid confusion.
  return false;
}

TensorView* getTv(Val* val) {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
  return const_cast<TensorView*>(getTv(const_cast<const Val*>(val)));
}

const TensorView* getTv(const Val* val) {
  if (val->isA<TensorView>()) {
    return val->as<TensorView>();
  } else if (val->isA<kir::TensorIndex>()) {
    return val->as<kir::TensorIndex>()->view();
  }
  return nullptr;
}

std::vector<TensorView*> getTvs(const std::vector<Val*>& vals) {
  std::vector<TensorView*> tvs;
  for (auto val : vals) {
    auto tv = ir_utils::getTv(val);
    if (tv) {
      tvs.emplace_back(tv);
    }
  }
  return tvs;
}

TensorView* getTvOutput(const Expr* expr) {
  for (auto out : expr->outputs()) {
    if (auto tv = getTv(out)) {
      return tv;
    }
  }
  return nullptr;
}

TensorView* getTvInput(const Expr* expr) {
  for (auto inp : expr->inputs()) {
    if (auto tv = getTv(inp)) {
      return tv;
    }
  }
  return nullptr;
}

bool isScalarOp(const Expr* expr) {
  for (auto out : expr->outputs())
    if (!out->isScalar())
      return false;
  return true;
}

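//! If the reduction producing `output` from `input` can be lowered as a warp
//! reduction, returns the TIDx-parallelized reduction IterDomain, otherwise
//! nullopt. Both tensors must live in registers, TIDx must be the only
//! thread parallel type used on reduction axes, the reduction domain must
//! start at zero, and its extent must either be padded to a multiple of the
//! warp size or be a constant multiple of it.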
c10::optional<IterDomain*> getMaybeWarpReductionDim(
    const Val* output,
    const Val* input) {
  auto tv_out = getTv(output);
  if (tv_out == nullptr) {
    return c10::nullopt;
  }

  auto tv_in = getTv(input);
  // only support reducing to registers for now.
  if (tv_in->getMemoryType() != MemoryType::Local ||
      tv_out->getMemoryType() != MemoryType::Local) {
    return c10::nullopt;
  }

  IterDomain* reduction_on_xdim = nullptr;
  for (auto id : tv_out->domain()->domain()) {
    // Currently warp reduction only allows
    // serial and block.x parallel reductions
    if (id->isReduction() && id->isParallelized()) {
      if (id->getParallelType() == ParallelType::TIDx) {
        reduction_on_xdim = id;
      } else if (id->isThread()) {
        return c10::nullopt;
      }
    }
  }
  if (!reduction_on_xdim) {
    return c10::nullopt;
  }

  if (!reduction_on_xdim->start()->isZeroInt()) {
    return c10::nullopt;
  }

  if (reduction_on_xdim->hasPaddingToMultipleOfWarp()) {
    return c10::optional<IterDomain*>(reduction_on_xdim);
  }

  if (reduction_on_xdim->extent()->isConstInt()) {
    auto extent_value = reduction_on_xdim->extent()->evaluateInt();
    if (extent_value % at::cuda::warp_size() == 0) {
      return c10::optional<IterDomain*>(reduction_on_xdim);
    }
  }

  return c10::nullopt;
}

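//! Returns true if any root IterDomain that `axis` derives from is also an
//! input of tv's computeAt axes.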
bool derivedFromRootCAAxes(const TensorView* tv, IterDomain* axis) {
  std::vector<IterDomain*> ca_axes(
      tv->domain()->domain().begin(),
      tv->domain()->domain().begin() + tv->getComputeAtPosition());

  auto ca_root_vals = IterVisitor::getInputsTo(
      std::vector<Val*>(ca_axes.begin(), ca_axes.end()));

  auto root_vals = IterVisitor::getInputsTo({axis});

  return std::any_of(
      root_vals.begin(), root_vals.end(), [&ca_root_vals](auto root) {
        return std::find(ca_root_vals.begin(), ca_root_vals.end(), root) !=
            ca_root_vals.end();
      });
}

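//! Maps each thread/block parallel type used in the domain of val (a
//! TensorView or TensorIndex) to the IterDomain parallelized with it.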
std::unordered_map<ParallelType, IterDomain*, TypeHash> getParallelDomains(
    const Val* val) {
  const TensorView* tv = nullptr;
  if (val->isA<TensorView>()) {
    tv = val->as<TensorView>();
  } else if (val->isA<kir::TensorIndex>()) {
    tv = val->as<kir::TensorIndex>()->view();
  } else {
    TORCH_INTERNAL_ASSERT(
        false, "Provided val is not TensorIndex or TensorView.");
  }

  std::unordered_map<ParallelType, IterDomain*, TypeHash> parallel_domains;
  for (auto d : tv->domain()->domain()) {
    if (d->isThread()) {
      parallel_domains.insert(std::make_pair(d->getParallelType(), d));
    }
  }
  return parallel_domains;
}

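//! Returns true if expr initializes the output buffer of a cp.async load,
//! i.e., it is a scalar fill whose output tensor is defined by a cp.async op.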
bool isCpAsyncInit(const Expr* expr) {
  return isTensorScalarFillOp(expr) &&
      // FIXME:
      //  We'd need to add a flag to all the init
      //  exprs so we could robustly detect initialization
      //  in all cases.
      isCpAsyncOp(getTvOutput(expr)->definition());
}

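//! If expr is an IfThenElse with an empty else-branch and exactly one
//! expression in its then-branch, returns that expression.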
c10::optional<Expr*> getMaybePredicatedSingleton(Expr* expr) {
  if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
    if (ite->elseBody().empty()) {
      if (ite->thenBody().size() == 1) {
        return ite->thenBody().exprs()[0];
      }
    }
  }
  return c10::nullopt;
}

//! Short-cut for checking if the expression loads from global memory.
bool isGlobalLoad(const Expr* expr) {
  if (expr->isA<LoadStoreOp>() ||
      (expr->isA<UnaryOp>() &&
       expr->as<UnaryOp>()->getUnaryOpType() == UnaryOpType::Set)) {
    if (auto in_tv = getTv(expr->input(0))) {
      return in_tv->getMemoryType() == MemoryType::Global;
    }
  }
  return false;
}

//! Short-cut for checking if the given expression initializes buffers
//! for global memory load.
bool isGlobalLoadInit(const Expr* expr) {
  if (auto uop = dynamic_cast<const UnaryOp*>(expr)) {
    if (uop->in()->isScalar()) {
      // FIXME:
      //  We'd need to add a flag to all the init
      //  exprs so we could robustly detect initialization
      //  in all cases.
      if (isGlobalLoad(getTvOutput(uop)->definition())) {
        return true;
      }
    }
  }
  return false;
}

namespace {

class ExprFlattener : private kir::IrVisitor {
 private:
  using kir::IrVisitor::handle;

  void handle(Expr* expr) final {
    if (expr->isA<kir::ForLoop>() || expr->isA<kir::IfThenElse>()) {
      kir::IrVisitor::handle(expr);
    } else {
      flat_exprs_.push_back(expr);
    }
  }

 private:
  std::vector<Expr*> flat_exprs_;

 public:
  //! Flattens scopes, extracting a single ordered list of exprs.
  static std::vector<Expr*> flatten(const std::vector<Expr*>& loop_nests) {
    ExprFlattener flattener;
    for (auto expr : loop_nests) {
      flattener.handle(expr);
    }
    return flattener.flat_exprs_;
  }
};

} // namespace

std::vector<Expr*> flattenScopedExprs(const std::vector<Expr*>& loop_nests) {
  return ExprFlattener::flatten(loop_nests);
}

namespace {

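//! Expression mutator that rebuilds tensor ops whose inputs appear in a
//! given Val-to-Val replacement map, substituting the mapped values and
//! carrying over the original (write) predicates.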
class ReplaceExprInput : private kir::ExprMutator {
 public:
  static std::vector<Expr*> replace(
      const std::vector<Expr*>& exprs,
      const std::unordered_map<Val*, Val*>& replacement_map) {
    ReplaceExprInput replacer(replacement_map);
    replacer.traverseAndInsert(exprs);
    return replacer.exprs_;
  }

 private:
  ReplaceExprInput(const std::unordered_map<Val*, Val*>& replacement_map)
      : replacement_map_(replacement_map) {}

  using kir::ExprMutator::handle;

  c10::optional<std::unordered_map<Val*, Val*>> getMaybeInputReplacementMap(
      Expr* expr) {
    bool need_replacement = false;

    std::unordered_map<Val*, Val*> replaced_val;
    for (auto in : expr->inputs()) {
      auto replace_it = replacement_map_.find(in);
      if (replace_it != replacement_map_.end()) {
        need_replacement = true;
        replaced_val[in] = replace_it->second;
      } else {
        replaced_val[in] = in;
      }
    }
    if (need_replacement) {
      return c10::optional<std::unordered_map<Val*, Val*>>(replaced_val);
    } else {
      return c10::nullopt;
    }
  }

  // Copy predicates and register expression replacement
  void registerReplaceWithPredicate(Expr* old_expr, Expr* new_expr) {
    new_expr = new_expr->withPredicate(old_expr->predicate())
                   ->withWritePredicate(old_expr->writePredicate());
    registerReplace(old_expr, new_expr);
  }

  void handle(UnaryOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<UnaryOp>(
          node->getUnaryOpType(), node->out(), replaced_inputs->at(node->in()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(BinaryOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<BinaryOp>(
          node->getBinaryOpType(),
          node->out(),
          replaced_inputs->at(node->lhs()),
          replaced_inputs->at(node->rhs()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(TernaryOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<TernaryOp>(
          node->getTernaryOpType(),
          node->out(),
          replaced_inputs->at(node->in1()),
          replaced_inputs->at(node->in2()),
          replaced_inputs->at(node->in3()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(RNGOp* node) final {
    // RNGOp has no input
    return;
  }

  void handle(ReductionOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<ReductionOp>(
          node->getReductionOpType(),
          node->init(),
          node->out(),
          replaced_inputs->at(node->in()),
          node->isAllreduce());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(GroupedReductionOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      const auto& map = replaced_inputs.value();
      auto inputs = node->inputs();
      for (auto& input : inputs) {
        auto it = map.find(input);
        if (it != map.end()) {
          input = it->second;
        }
      }
      auto replacement = IrBuilder::create<GroupedReductionOp>(
          node->getReductionOpTypes(),
          node->initVals(),
          node->outputs(),
          inputs,
          node->isAllreduce());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(BroadcastOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<BroadcastOp>(
          node->out(),
          replaced_inputs->at(node->in()),
          node->getBroadcastDimFlags());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(WelfordOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<WelfordOp>(
          node->outAvg(),
          node->outVar(),
          node->outN(),
          node->initAvg(),
          node->initVar(),
          node->initN(),
          replaced_inputs->at(node->inAvg()),
          replaced_inputs->at(node->inVar()),
          replaced_inputs->at(node->inN()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(MmaOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<MmaOp>(
          node->out(),
          replaced_inputs->at(node->inA()),
          replaced_inputs->at(node->inB()),
          node->init(),
          node->options());
      registerReplaceWithPredicate(node, replacement);
    }
  }

  void handle(LoadStoreOp* node) final {
    auto replaced_inputs = getMaybeInputReplacementMap(node);
    if (replaced_inputs.has_value()) {
      auto replacement = IrBuilder::create<LoadStoreOp>(
          node->opType(), node->out(), replaced_inputs->at(node->in()));
      registerReplaceWithPredicate(node, replacement);
    }
  }

 private:
  const std::unordered_map<Val*, Val*>& replacement_map_;
};

} // namespace

std::vector<Expr*> replaceInputsInExpr(
    const std::vector<Expr*>& exprs,
    const std::unordered_map<Val*, Val*>& replacement_map) {
  return ReplaceExprInput::replace(exprs, replacement_map);
}

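//! Collects all Swizzle2D exprs on the paths from the `from` IterDomains to
//! the `to` IterDomains.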
std::vector<Expr*> getAllSwizzlesBetween(
    std::vector<IterDomain*> from,
    std::vector<IterDomain*> to) {
  auto all_expr = DependencyCheck::getAllExprsBetween(
      {from.begin(), from.end()}, {to.begin(), to.end()});

  std::vector<Expr*> all_swizzles;

  std::copy_if(
      all_expr.begin(),
      all_expr.end(),
      std::back_inserter(all_swizzles),
      [](Expr* expr) {
        return expr->getExprType().has_value() &&
            (expr->etype() == ExprType::Swizzle2D);
      });

  return all_swizzles;
}

} // namespace ir_utils

namespace lower_utils {

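//! Returns true if lowering expr requires a block synchronization: either
//! expr is an explicit kir::BlockSync, or it is a tensor-op reduction with a
//! block/grid reduction axis, or a broadcast with a parallelized broadcast
//! domain.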
bool hasBlockSync(const Expr* expr, const ThreadPredicateMap& pred_map) {
  if (expr->isA<kir::BlockSync>()) {
    return true;
  }

  if (!ir_utils::isTvOp(expr)) {
    return false;
  }

  if (!(ir_utils::isReductionOp(expr) || expr->isA<BroadcastOp>() ||
        expr->isA<kir::GridBroadcast>())) {
    return false;
  }

  // GroupedReductionOp can have multiple output TVs, but they must be
  // parallelized in the same way, so just checking one of them is enough.
  auto tv = ir_utils::getTvOutput(expr);

  if (tv->hasBlockReduction() || tv->hasGridReduction()) {
    return true;
  } else if (expr->isA<BroadcastOp>()) {
    const ParallelTypeBitmap pt_map =
        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(tv);
    return pt_map.any();
  }

  return false;
}

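//! Allocates a 1D global-memory buffer of `buffer_size` elements of `dtype`
//! for grid communication, optionally zero-initialized.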
kir::Allocate* allocGlobalBufferForGridComm(
    Val* buffer_size,
    DataType dtype,
    bool zero_init) {
  const std::vector<IterDomain*> new_buffer_ids = {
      IrBuilder::create<IterDomain>(IterDomainBuilder(
          GpuLower::current()->kernel()->zeroVal(), buffer_size))};
  const auto buffer_domain = IrBuilder::create<TensorDomain>(new_buffer_ids);
  const auto buffer_tv =
      IrBuilder::create<TensorView>(buffer_domain, dtype, MemoryType::Global);
  return IrBuilder::create<kir::Allocate>(
      buffer_tv, buffer_tv->getMemoryType(), nullptr, zero_init);
}

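//! Walks the given loop nest to decide where tv should be allocated and
//! initialized. The allocation position advances while the loops map to
//! tv's computeAt axes and stops at unrolled loops or reduction axes; shared
//! memory under unswitch, global memory, and double/circular buffered
//! tensors keep their allocation at an outer loop.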
BasicAllocInfo getAllocInformation(
    const TensorView* tv,
    const std::vector<kir::ForLoop*>& for_loops,
    const std::unordered_map<IterDomain*, IterDomain*>& id_map,
    bool use_id_map) {
  BasicAllocInfo info;
  auto gpu_lower = GpuLower::current();

  bool outer_alloc_found = false;

  for (auto fl : for_loops) {
    if (info.alloc_pos == tv->getComputeAtPosition()) {
      break;
    }

    if (tv->axis(info.alloc_pos)->isReduction()) {
      const auto outputs = FusionGuard::getCurFusion()->getTerminatingOutputs();
      TORCH_INTERNAL_ASSERT(
          std::find(outputs.begin(), outputs.end(), tv) != outputs.end(),
          "Invalid computeAt of T",
          tv->name(),
          ". A reduction axis is detected outside of the computeAt point even though it is not an output tensor.");
      break;
    }

    auto fl_id = fl->iter_domain();

    if (fl_id->getParallelType() == ParallelType::Unroll) {
      break;
    }

    // Shared memory must be allocated outside of unswitched
    // domains. See issue #1133.
    if (fl_id->getParallelType() == ParallelType::Unswitch &&
        tv->getMemoryType() == MemoryType::Shared) {
      outer_alloc_found = true;
    }

    // Assume global memory is allocated at the outermost scope.
    if (tv->getMemoryType() == MemoryType::Global) {
      outer_alloc_found = true;
    }

    // Allocation of a double buffered tensor is placed outside its
    // double buffer axis.
    if ((tv->isDoubleBuffered() || tv->isCircularBuffered()) &&
        tv->axis(info.alloc_pos) ==
            gpu_lower->doubleBufferInfo().getDoubleBufferAxis(tv)) {
      outer_alloc_found = true;
    }

    auto local_id = tv->axis(info.alloc_pos);

    if (use_id_map) {
      auto id_it = id_map.find(local_id);
      if (id_it != id_map.end()) {
        local_id = id_it->second;
      }
    }

    if (GpuLower::current()->caMap()->areMapped(
            local_id, fl_id, IdMappingMode::PERMISSIVE)) {
      info.alloc_pos++;
    }

    info.init_for_loop = fl;

    if (!outer_alloc_found) {
      info.alloc_for_loop = fl;
    }
  }

  return info;
}

//! Implementing this here to avoid including too many headers
//! in type.cpp. Conceptually this should be a generic definition
//! rather than a util.
bool supportInlinePredicate(Expr* expr) {
  if (ir_utils::isCpAsyncOp(expr)) {
    return true;
  }
  // TODO: build out support.
  return false;
}

} // namespace lower_utils

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch