#include <predicate_compute.h>

#include <arith.h>
#include <expr_evaluator.h>
#include <fusion.h>
#include <index_compute.h>
#include <instrumentation.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <transform_iter.h>

#include <c10/util/irange.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

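// Returns true if the first output of the expression is a
// kir::TensorIndex, i.e., the expression has already been indexed.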
bool isTensorIndexOp(Expr* expr) {
  const auto& outputs = expr->outputs();
  return outputs.size() >= 1 && outputs[0]->isA<kir::TensorIndex>();
}

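// Returns true if every TensorView output of the expression lives in
// local memory (registers).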
bool isOutputLocal(const Expr* expr) {
  return std::all_of(
      expr->outputs().begin(), expr->outputs().end(), [](const Val* output) {
        return !output->isA<TensorView>() ||
            output->as<TensorView>()->getMemoryType() == MemoryType::Local;
      });
}

} // namespace

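// Record the EXACT concrete domain mapped to id. Returns true if the
// domain was not already tracked, false otherwise.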
bool ParallelizedDomainPredicate::PredicateInfo::addDomain(IterDomain* id) {
  auto concrete_id = GpuLower::current()->caMap()->getConcreteMappedID(
      id, IdMappingMode::EXACT);
  if (std::find(ids_.begin(), ids_.end(), concrete_id) == ids_.end()) {
    ids_.push_back(concrete_id);
    return true;
  } else {
    return false;
  }
}

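// Build the conjunction of (thread_index < extent) predicates over all
// recorded concrete domains of this parallel type.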
Bool* ParallelizedDomainPredicate::PredicateInfo::getPredicate() const {
  Bool* pred = nullptr;

  auto index = SimplifyingIrBuilder::create<NamedScalar>(
      stringifyThread(pt_), DataType::Int);

  for (const auto& pred_id : ids()) {
    // Just sanity check that pred_id is concrete
    TORCH_INTERNAL_ASSERT(
        pred_id ==
        GpuLower::current()->caMap()->getConcreteMappedID(
            pred_id, IdMappingMode::EXACT));
    auto new_pred = SimplifyingIrBuilder::ltExpr(index, pred_id->extent());
    pred = SimplifyingIrBuilder::andExpr(pred, new_pred)->as<Bool>();
  }

  return pred;
}

namespace {

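// Collect the EXACT concrete root domains of the loops outside the
// unswitched loop, i.e., loops[0] to loops[unswitched_loop_index - 1].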
std::unordered_set<Val*> getNonUnswitchedRootDomains(
    const std::vector<kir::ForLoop*>& loops,
    size_t unswitched_loop_index) {
  std::vector<Val*> non_unswitched_leaf_domains;
  std::transform(
      loops.begin(),
      loops.begin() + unswitched_loop_index,
      std::back_inserter(non_unswitched_leaf_domains),
      [&](kir::ForLoop* loop) { return loop->iter_domain(); });

  auto non_unswitched_inputs =
      IterVisitor::getInputsTo(non_unswitched_leaf_domains);

  auto non_unswitched_root_doms =
      ir_utils::filterByType<IterDomain>(non_unswitched_inputs);

  std::unordered_set<Val*> non_unswitched_concrete_root_domains;

  std::transform(
      non_unswitched_root_doms.begin(),
      non_unswitched_root_doms.end(),
      std::inserter(
          non_unswitched_concrete_root_domains,
          non_unswitched_concrete_root_domains.end()),
      [&](auto root_dom) {
        return GpuLower::current()->caMap()->getConcreteMappedID(
            root_dom, IdMappingMode::EXACT);
      });

  return non_unswitched_concrete_root_domains;
}

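// Returns true if loop_id does not share any root domain with the
// non-unswitched loops, i.e., the loop domain is fully within the
// unswitched scope.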
bool isFullyUnswitched(
    IterDomain* loop_id,
    const std::unordered_set<Val*>& non_unswitched_root_domains) {
  auto root_vals = IterVisitor::getInputsTo({loop_id});

  auto root_domains = ir_utils::filterByType<IterDomain>(root_vals);

  return std::none_of(
      root_domains.begin(), root_domains.end(), [&](auto root_dom) {
        auto concrete_root_dom =
            GpuLower::current()->caMap()->getConcreteMappedID(
                root_dom, IdMappingMode::EXACT);
        return non_unswitched_root_domains.count(concrete_root_dom) > 0;
      });
}

} // namespace

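// For each non-exact thread parallel type used by the loop nest, collect
// the output-tensor domains that need an explicit thread/block index
// predicate. When unswitched_loop is given, loop domains that are fully
// unswitched are skipped.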
std::unordered_map<
    ParallelType,
    ParallelizedDomainPredicate::PredicateInfo,
    TypeHash>
ParallelizedDomainPredicate::getPredicateMap(
    const Expr* expr,
    const std::vector<kir::ForLoop*>& loops,
    kir::ForLoop* unswitched_loop) {
  const auto gpu_lower = GpuLower::current();
  auto output_tvs = ir_utils::getTvs(expr->outputs());

  if (output_tvs.empty()) {
    return {};
  }

  // Initialize a map with empty predicate info
  std::unordered_map<ParallelType, PredicateInfo, TypeHash> map;
  for (auto pt : kParallelTypeThreads) {
    map.insert({pt, PredicateInfo(pt)});
  }

  // For each loop, check if it's parallelized by a non-exact
  // threading dimension. If so, and the domain is used in the given
  // expr, it needs to be protected by a predicate on the thread/block
  // index.

  bool within_unswitch = false;
  std::unordered_set<Val*> non_unswitched_root_domains;

  for (const auto i : c10::irange(loops.size())) {
    auto loop = loops[i];

    // Detect the unswitched loop and record the root domains of the
    // loops outside of it.
    if (loop == unswitched_loop) {
      within_unswitch = true;
      non_unswitched_root_domains = getNonUnswitchedRootDomains(loops, i);
    }

    auto loop_id = loop->iter_domain();
    auto loop_ptype = loop_id->getParallelType();

    // Not necessary to add a predicate if the parallel type is exact
    if (!isParallelTypeThread(loop_ptype) ||
        gpu_lower->parallelDimensionMap().isExact(loop_ptype)) {
      continue;
    }

    // Parallel dimensions need not be predicated if fully unswitched.
    if (within_unswitch &&
        isFullyUnswitched(loop_id, non_unswitched_root_domains)) {
      continue;
    }

    for (auto tv : output_tvs) {
      // Check if the loop domain is used by the output tensor
      auto it = std::find_if(
          tv->domain()->domain().begin(),
          tv->domain()->domain().end(),
          [&](auto tv_id) {
            return gpu_lower->caMap()->areMapped(
                loop_id, tv_id, IdMappingMode::EXACT);
          });
      if (it == tv->domain()->domain().end()) {
        continue;
      }

      IterDomain* tv_id = *it;

      // If the corresponding domain is a broadcast, it's not really used.
      if (tv_id->isBroadcast()) {
        continue;
      }

      // If it's a root domain, it should be covered by the root
      // predicates, so no extra predicate is required.
      if (std::find(
              tv->domain()->getRootDomain().begin(),
              tv->domain()->getRootDomain().end(),
              tv_id) != tv->domain()->getRootDomain().end()) {
        continue;
      }

      // tv_id needs to be predicated. Add it to the PredicateInfo map.
      auto& info = map.at(loop_ptype);
      info.addDomain(tv_id);
    }
  }

  return map;
}

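// Conjoin the predicates of all non-exact parallel types used by the
// expression into a single Bool value.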
Bool* ParallelizedDomainPredicate::getPredicate(
    const Expr* expr,
    const std::vector<kir::ForLoop*>& loops) {
  auto pred_map = getPredicateMap(expr, loops);

  Val* pred = GpuLower::current()->kernel()->trueVal();

  for (auto pt : kParallelTypeThreads) {
    auto pred_info_it = pred_map.find(pt);
    if (pred_info_it != pred_map.end()) {
      const auto& pred_info = pred_info_it->second;
      auto tid_pred = pred_info.getPredicate();
      pred = SimplifyingIrBuilder::andExpr(pred, tid_pred);
    }
  }

  TORCH_INTERNAL_ASSERT(pred != nullptr);
  return pred->as<Bool>();
}

UnswitchPredicateKey::UnswitchPredicateKey()
    : predicated_concrete_id_(nullptr) {
  for (auto pt : kParallelTypeThreads) {
    parallel_concrete_ids_.insert({pt, nullptr});
  }
}

// For a predicated concrete domain, id, find which thread parallel
// types are used. For each used parallel type, find the concrete
// domain that the parallel type is associated with. The parallelized
// concrete domains are used to uniquely collect all necessary
// unswitch predicates.
UnswitchPredicateKey::UnswitchPredicateKey(
    IterDomain* predicated_consumer_id,
    TensorView* consumer_tv,
    IterDomain* predicated_concrete_id)
    : predicated_concrete_id_(predicated_concrete_id) {
  // Initialize the parallelized domain map
  for (auto pt : kParallelTypeThreads) {
    parallel_concrete_ids_.insert({pt, nullptr});
  }

  std::vector<Val*> all_parallelized_consumer_leaf_ids;
  std::copy_if(
      consumer_tv->domain()->domain().begin(),
      consumer_tv->domain()->domain().end(),
      std::back_inserter(all_parallelized_consumer_leaf_ids),
      [](IterDomain* x) { return isParallelTypeThread(x->getParallelType()); });

  // If the consumer domains are not parallelized at all, no need to
  // differentiate keys based on how the predicated id is parallelized
  if (all_parallelized_consumer_leaf_ids.empty()) {
    return;
  }

  // All domains that are parallelized descendants of predicated_consumer_id
  auto all_parallelized_consumer_ids = DependencyCheck::getAllValsBetween(
      {predicated_consumer_id}, all_parallelized_consumer_leaf_ids);
  // Just pick leaf domains
  std::vector<IterDomain*> parallelized_consumer_leaf_ids;
  std::copy_if(
      consumer_tv->domain()->domain().begin(),
      consumer_tv->domain()->domain().end(),
      std::back_inserter(parallelized_consumer_leaf_ids),
      [&](IterDomain* x) {
        return std::find(
                   all_parallelized_consumer_ids.begin(),
                   all_parallelized_consumer_ids.end(),
                   x) != all_parallelized_consumer_ids.end();
      });

  if (parallelized_consumer_leaf_ids.empty()) {
    // None of the parallelized leaf domains are derived from
    // predicated_consumer_id
    return;
  }

  // Find the corresponding concrete id for each parallel type
  for (auto consumer_leaf : parallelized_consumer_leaf_ids) {
    auto pt = consumer_leaf->getParallelType();
    auto concrete_leaf = GpuLower::current()->caMap()->getConcreteMappedID(
        consumer_leaf, IdMappingMode::EXACT);
    parallel_concrete_ids_.at(pt) = concrete_leaf;
  }
}

std::string UnswitchPredicateKey::toString() const {
  std::stringstream ss;
  ss << "Predicated domain: ";
  if (predicatedId() != nullptr) {
    ss << predicatedId();
  } else {
    ss << "null";
  }
  for (auto pt : kParallelTypeThreads) {
    auto pid = parallelId(pt);
    ss << ", " << pt << ": ";
    if (pid) {
      ss << pid;
    } else {
      ss << "null";
    }
  }
  return ss.str();
}

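// Hash the key by combining the predicated concrete domain with the
// concrete domain recorded for each thread parallel type.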
std::size_t UnswitchPredicateKeyHash::operator()(
    const UnswitchPredicateKey& key) const {
  auto h = std::hash<const IterDomain*>{}(key.predicatedId());
  for (auto pt : kParallelTypeThreads) {
    h = h ^ std::hash<const IterDomain*>{}(key.parallelId(pt));
  }
  return h;
}

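// Generate the inline predicate for an expression by combining the
// thread predicate, the parallelized-domain predicates, and the
// start/stop predicates of the output tensor's root domains. For
// ReductionWrite predicates, reduction axes may be skipped (see below).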
Bool* PredicateCompute::getInlinePredicate(
    const Expr* expr,
    const std::vector<kir::ForLoop*>& loops,
    Bool* thread_pred,
    PredicateType pred_type) {
  FUSER_PERF_SCOPE("GpuLower::Lower::getInlinePredicate");

  const auto gpu_lower = GpuLower::current();

  // If outputs are registers, no need to predicate for threads
  if (isOutputLocal(expr)) {
    thread_pred = gpu_lower->kernel()->trueVal();
  }

  if (loops.empty()) {
    TORCH_INTERNAL_ASSERT(thread_pred != nullptr);
    return thread_pred;
  }

  auto out_tv = ir_utils::getTvOutput(expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

  // Predicates for non-exact parallel dimensions must be used even
  // when PredicateElimination::canOmitPredicate is true.
  auto parallel_dom_pred =
      ParallelizedDomainPredicate::getPredicate(expr, loops);
  TORCH_INTERNAL_ASSERT(parallel_dom_pred != nullptr);

  if (gpu_lower->predicateElimination().canOmitPredicate(expr)) {
    return SimplifyingIrBuilder::andExpr(thread_pred, parallel_dom_pred)
        ->as<Bool>();
  }

  auto pred_info_vec = Index::getReferenceRootPredicates(
      out_tv, loops, nullptr, pred_type == PredicateType::Padding);

  std::vector<Bool*> preds;

  // When pred_type is ReductionWrite, filter out predicates for
  // reduction axes. For blockReduce, this is necessary when reduction
  // axes start at non-zero offsets and are parallelized with TID since
  // blockReduce returns a valid output only at offset-zero
  // threads. Similarly, for gridReduce, the last block to store the
  // output may be predicated out with the read predicate, so the
  // write predicate needs to ignore the reduction axes.
  bool non_zero_start_found = false;
  for (const auto& pred_info : pred_info_vec) {
    if (pred_type == PredicateType::ReductionWrite) {
      const auto& consumer_ids = pred_info.rootIds();
      bool pred_for_reduction_axis = false;
      for (auto consumer_id : consumer_ids) {
        if (consumer_id->isReduction()) {
          if (!consumer_id->start()->isZeroInt()) {
            non_zero_start_found = true;
          }
          pred_for_reduction_axis = true;
          break;
        }
      }
      // Don't add the predicate if it corresponds to a reduction axis
      if (pred_for_reduction_axis) {
        continue;
      }
    }
    preds.push_back(pred_info.startPredicate());
    preds.push_back(pred_info.stopPredicate());
  }

  // When generating a predicate for blockReduce writes and not for
  // gridReduce, if all reduction axes start at zero, the read
  // predicate can be reused as-is, so nullptr is returned.
  if (pred_type == PredicateType::ReductionWrite && !non_zero_start_found &&
      !out_tv->domain()->hasGridReduction()) {
    return nullptr;
  }

  preds.push_back(parallel_dom_pred);

  if (thread_pred != nullptr) {
    preds.push_back(thread_pred);
  }

  if (preds.empty()) {
    return GpuLower::current()->kernel()->trueVal();
  }

  Val* cond = preds[0];
  for (const auto i : c10::irange(1, preds.size())) {
    cond = SimplifyingIrBuilder::andExpr(cond, preds[i]);
  }

  return cond->as<Bool>();
}

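// Entry point for unswitch predicate generation: traverse the unrolled
// loop and conjoin all collected predicates into a single Bool.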
Bool* UnswitchPredicate::get(
    const std::vector<kir::ForLoop*>& outer_loops,
    kir::ForLoop* unrolled_loop) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::get");

  UnswitchPredicate up(outer_loops, unrolled_loop);

  Val* unswitch_pred = GpuLower::current()->kernel()->trueVal();
  for (auto pred : up.predicates_) {
    unswitch_pred = SimplifyingIrBuilder::andExpr(unswitch_pred, pred);
  }

  return unswitch_pred->as<Bool>();
}

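// Collect the unswitch predicates required by a single tensor
// expression, merging them with the predicates gathered so far.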
void UnswitchPredicate::predicateOn(Expr* tv_expr) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::predicateOn");

  if (for_loops_.empty()) {
    return;
  }

  const auto gpu_lower = GpuLower::current();

  // FIXME:
  // The predicate of the cp.async initialization needs to be kept so
  // that the inverted predicate can be generated; see
  // [Predicate Inversion for CpAsync]. In a follow-up, both this part and
  // [Predicate Inversion for CpAsync] should be cleaned up together.
  if (gpu_lower->predicateElimination().canOmitPredicate(tv_expr) &&
      !ir_utils::isCpAsyncInit(tv_expr)) {
    addParallelizedDomainPredicates(tv_expr);
    return;
  }

  auto out_tv = ir_utils::getTvOutput(tv_expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Missing TensorView output");

  auto ref_pred_info = Index::getReferenceRootPredicates(
      out_tv, for_loops_, unrolled_loop_, false);

  // If RootPredicateInfo has a static predicate that is more
  // restrictive than the current one, replace the current with the
  // new one. If it has a dynamic predicate, add it to the dynamic
  // predicate list. Since the final static predicate can't be
  // determined until all expressions are analyzed, predicates are
  // temporarily placed in the predicated_keys map and the final
  // predicates are generated in the finalize function.

  for (const auto& pred_info : ref_pred_info) {
    TORCH_INTERNAL_ASSERT(pred_info.startPredicate() != nullptr);
    TORCH_INTERNAL_ASSERT(pred_info.stopPredicate() != nullptr);

    const auto& root_ids = pred_info.rootIds();

    bool add_pred = false;

    // Used to find a matching existing MergedPredicates
    UnswitchPredicateKey first_key;
    bool first_key_set = false;

    for (auto root_id : root_ids) {
      auto concrete_root_id = gpu_lower->caMap()->getConcreteMappedID(
          root_id, IdMappingMode::EXACT);

      if (root_id->isBroadcast()) {
        continue;
      }

      UnswitchPredicateKey key(root_id, out_tv, concrete_root_id);
      auto inserted = predicated_keys_.insert(key).second;
      add_pred = add_pred || inserted;

      if (!first_key_set) {
        first_key = key;
        first_key_set = true;
      }
    }

    if (!first_key_set) {
      // No predicate generated
      continue;
    }

    // The start and stop offsets may need to be merged to avoid
    // redundant predicates. When these offsets are zero, nothing is
    // done. When non-zero, find the corresponding MergedPredicates
    // and merge both the start and stop offsets. Note that when the
    // offsets are non-zero, the predicates must be generated at a
    // root domain, so root_ids.size() must be one. That unique root
    // domain is used as a key to find the corresponding
    // MergedPredicate.

    // Initialize with an invalid iterator to signal no corresponding
    // MergedPredicates is found yet.
    auto merged_pred_it = pending_predicates_.end();

    if (add_pred) {
      // This is a new predicate for the root domain. Initialize a new
      // MergedPredicates and add it to the pending list.
      UnswitchPredicate::MergedPredicates merged_pred;

      // To look up this MergedPredicates for other predicates
      // generated for the same predicate key
      if (root_ids.size() == 1) {
        merged_pred.predicate_key = first_key;
      }

      pending_predicates_.push_back(merged_pred);

      merged_pred_it =
          pending_predicates_.begin() + pending_predicates_.size() - 1;
    } else if (root_ids.size() == 1) {
      // If not new, try to find a corresponding MergedPredicates.
      merged_pred_it = std::find_if(
          pending_predicates_.begin(),
          pending_predicates_.end(),
          [&first_key](const auto& merged_predicates) {
            return merged_predicates.predicate_key == first_key;
          });
      // Note: It is possible that no matching merged predicate info
      // is found. Since add_pred is false here, the root domain is
      // already predicated. It must mean that the root domain
      // is included in a contiguous merged domain, which means there
      // must be no halo-extended domain involved.
    }

    // If a corresponding MergedPredicates is found, merge both the
    // start and stop offsets.
    if (merged_pred_it != pending_predicates_.end()) {
      mergeUnswitchPredicateOffsets(
          pred_info.startPredicate(),
          pred_info.startOffset(),
          merged_pred_it->start,
          true);

      mergeUnswitchPredicateOffsets(
          pred_info.stopPredicate(),
          pred_info.stopOffset(),
          merged_pred_it->stop,
          false);
    }
  }

  addParallelizedDomainPredicates(tv_expr);
}

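// Add predicates for non-exact parallelized domains used by the
// expression, skipping domains that have already been predicated by a
// previous expression.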
void UnswitchPredicate::addParallelizedDomainPredicates(Expr* tv_expr) {
  auto pred_map = ParallelizedDomainPredicate::getPredicateMap(
      tv_expr, for_loops_, unrolled_loop_);
  for (auto pt : kParallelTypeThreads) {
    auto pred_info_it = pred_map.find(pt);
    if (pred_info_it == pred_map.end()) {
      continue;
    }
    const auto& new_info = pred_info_it->second;
    auto& predicated =
        parallelized_dom_predicates_
            .insert({pt, ParallelizedDomainPredicate::PredicateInfo{pt}})
            .first->second;
    for (auto id : new_info.ids()) {
      if (predicated.addDomain(id)) {
        predicates_.push_back(new_info.getPredicate());
      }
    }
  }
}

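// Recursively walk a for-loop body, predicating tensor ops and
// descending into nested loops and if-then-else scopes.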
void UnswitchPredicate::openLoop(kir::ForLoop* fl) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openLoop");

  for_loops_.push_back(fl);

  for (auto expr : fl->body().exprs()) {
    if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) {
      predicateOn(expr);
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      openIte(ite);
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      openLoop(for_loop);
    }
  }

  for_loops_.pop_back();
}

void UnswitchPredicate::openIte(kir::IfThenElse* ite) {
  FUSER_PERF_SCOPE("GpuLower::Lower::UnswitchPredicate::openIte");

  // Only expand the "then" body of the IfThenElse
  for (auto expr : ite->thenBody().exprs()) {
    if (ir_utils::isTvOp(expr) || isTensorIndexOp(expr)) {
      predicateOn(expr);
    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
      openIte(ite);
    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
      openLoop(for_loop);
    }
  }
}

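// Move the merged static and dynamic predicates collected in
// pending_predicates_ into the final predicate list.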
void UnswitchPredicate::finalize() {
  for (const auto& merged_pred : pending_predicates_) {
    const auto& start_info = merged_pred.start;
    if (start_info.static_pred) {
      predicates_.push_back(start_info.static_pred);
    }
    for (auto dynamic_pred : start_info.dynamic_preds) {
      predicates_.push_back(dynamic_pred);
    }
    const auto& stop_info = merged_pred.stop;
    if (stop_info.static_pred) {
      predicates_.push_back(stop_info.static_pred);
    }
    for (auto dynamic_pred : stop_info.dynamic_preds) {
      predicates_.push_back(dynamic_pred);
    }
  }
}

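// Merge a start or stop predicate into the accumulated info for its
// root domain. A smaller constant start offset or a larger constant
// stop offset is considered more restrictive and replaces the current
// static predicate; non-constant offsets go to the dynamic list.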
void UnswitchPredicate::mergeUnswitchPredicateOffsets(
    Bool* predicate,
    Val* offset,
    MergedPredicates::Info& merged_predicate_info,
    bool is_start) {
  auto is_more_restrictive = [&is_start](int64_t new_val, int64_t current_val) {
    if (is_start) {
      return new_val < current_val;
    } else {
      return new_val > current_val;
    }
  };

  auto offset_int = dynamic_cast<Int*>(offset);
  // If it's a static predicate, replace the current one if it's
  // more restrictive. If it's dynamic, just add it to the dynamic
  // predicate list.
  if (offset_int && offset_int->isConst()) {
    auto offset_const = offset_int->value().value();
    auto& static_pred = merged_predicate_info.static_pred;
    auto& static_offset = merged_predicate_info.static_offset;
    if (static_pred == nullptr ||
        is_more_restrictive(offset_const, static_offset)) {
      static_pred = predicate;
      static_offset = offset_const;
    }
  } else {
    merged_predicate_info.dynamic_preds.push_back(predicate);
  }
}

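// outer_loops are the loops enclosing the unswitched loop; traversal
// starts from unrolled_loop and finalize() flushes the pending
// predicates.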
UnswitchPredicate::UnswitchPredicate(
    std::vector<kir::ForLoop*> outer_loops,
    kir::ForLoop* unrolled_loop)
    : for_loops_(std::move(outer_loops)), unrolled_loop_(unrolled_loop) {
  openLoop(unrolled_loop);
  finalize();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch