1#include <lower_validation.h>
2
3#include <contiguity.h>
4#include <expr_evaluator.h>
5#include <instrumentation.h>
6#include <ir_iostream.h>
7#include <ir_utils.h>
8#include <iter_visitor.h>
9#include <lower2device.h>
10#include <lower_utils.h>
11#include <transform_iter.h>
12#include <transform_replay.h>
13#include <type.h>
14
15#include <ATen/cuda/CUDAContext.h>
16#include <limits>
17
18namespace torch {
19namespace jit {
20namespace fuser {
21namespace cuda {
22
23namespace {
24
//! Validate that multiple output tensors of the same expression,
//! i.e., siblings, have valid domains and parallel types. Since
//! siblings are placed in the same loop nest, they must be
//! parallelized the same way. Serial parallel types are inferred and
//! promoted when another output is parallelized, so the user doesn't
//! have to repeat the same parallelization for every output. Throws
//! if conflicting parallel types are detected, e.g., TIDx vs BIDx.
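//! For example, the avg/var/N outputs of a Welford op are siblings;
//! parallelizing one of them is propagated to the others here.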
32class ValidateSiblings : public IterVisitor {
33 public:
34 static void validate(Fusion* fusion) {
35 ValidateSiblings validator;
36 validator.traverse(fusion);
37 }
38
39 private:
40 using IterVisitor::handle;
41
42 void handle(Expr* expr) final {
43 if (!ir_utils::isTvOp(expr) || expr->outputs().size() < 2) {
44 IterVisitor::handle(expr);
45 return;
46 }
47
48 auto ref_output = expr->outputs().at(0)->as<TensorView>();
49 auto ref_ndims = ref_output->nDims();
50 const auto& ref_root = ref_output->getRootDomain();
51 std::unordered_map<IterDomain*, IterDomain*> id_map;
52
53 for (const auto sibling :
54 ir_utils::filterByType<TensorView>(expr->outputs())) {
55 if (ref_output == sibling) {
56 continue;
57 }
58
59 TORCH_INTERNAL_ASSERT(
60 sibling->nDims() == ref_ndims,
61 "Mismatched dimensionality detected. Expr: ",
62 expr->toString(),
63 "Ref output: ",
64 ref_output->toString(),
65 ". Sibling: ",
66 sibling->toString());
67
68 for (const auto i : c10::irange(ref_ndims)) {
69 validateParallelTypes(ref_output->axis(i), sibling->axis(i));
70 }
71
72 for (const auto i : c10::irange(ref_root.size())) {
73 id_map[ref_root[i]] = sibling->getRootDomain().at(i);
74 }
75
76 BestEffortReplay replay(
77 sibling->domain()->domain(), ref_output->domain()->domain(), id_map);
78 for (const auto i : c10::irange(ref_ndims)) {
79 auto it = replay.getReplay().find(ref_output->axis(i));
80 TORCH_INTERNAL_ASSERT(
81 it != replay.getReplay().end(),
82 "Matching sibling ID not found. Expr: ",
83 expr->toString(),
84 "Ref ID: ",
85 ref_output->axis(i)->toString());
86 auto sibling_id = it->second;
87 TORCH_INTERNAL_ASSERT(
88 sibling->axis(i) == sibling_id,
89 "Invalid matching sibling ID detected. Expr: ",
90 expr->toString(),
91 "Sibling ID: ",
92 sibling_id->toString());
93 }
94 }
95 }
96
97 // Parallelize id1 and id0 consistently if one is serial and the other isn't
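  // E.g., if id0 is TIDx and id1 is Serial, id1 is promoted to TIDx,
  // while conflicting non-serial types such as TIDx vs BIDx throw.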
98 void validateParallelTypes(IterDomain* id0, IterDomain* id1) {
99 const auto ptype0 = id0->getParallelType();
100 const auto ptype1 = id1->getParallelType();
101
102 if (ptype0 == ParallelType::Vectorize ||
103 ptype1 == ParallelType::Vectorize) {
104 auto other_type = ptype0 == ParallelType::Vectorize ? ptype1 : ptype0;
105 TORCH_INTERNAL_ASSERT(
106 other_type == ParallelType::Vectorize ||
107 (!isParallelTypeThreadDim(other_type) &&
108 !isParallelTypeBlockDim(other_type)),
          "Vectorize type was parallelized inconsistently. ",
          "Detected while promoting parallel types.");
111 return;
112 }
113
114 if (ptype0 != ptype1) {
115 TORCH_CHECK(
116 ptype0 == ParallelType::Serial || ptype1 == ParallelType::Serial,
117 "Error promoting parallel types");
118 if (ptype0 == ParallelType::Serial) {
119 id0->parallelize(ptype1);
120 }
121 if (ptype1 == ParallelType::Serial) {
122 id1->parallelize(ptype0);
123 }
124 }
125 }
126};
127
// Make sure each IterDomain is used by only one TensorView. Several
// mappings from IterDomains are created during lowering, and they
// rely on IterDomains being used uniquely.
132void validateIterDomainUsage(Fusion* fusion) {
133 FUSER_PERF_SCOPE("GpuLower::Lower::validateIterDomainUse");
134 FusionGuard fg(fusion);
135
136 auto used_vals = fusion->usedMathVals();
137 std::unordered_map<IterDomain*, TensorView*> domain_use_map;
138
139 for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
140 std::unordered_set<Val*> root_domains;
141 std::copy(
142 tv->getRootDomain().begin(),
143 tv->getRootDomain().end(),
144 std::inserter(root_domains, root_domains.begin()));
145
146 std::vector<Val*> leaf_domains;
147 std::copy(
148 tv->domain()->domain().begin(),
149 tv->domain()->domain().end(),
150 std::back_inserter(leaf_domains));
151
152 auto all_domain_vals =
153 DependencyCheck::getAllValsBetween(root_domains, leaf_domains);
154
155 for (auto id : ir_utils::filterByType<IterDomain>(all_domain_vals)) {
156 auto it = domain_use_map.find(id);
157 TORCH_INTERNAL_ASSERT(
158 it == domain_use_map.end(),
159 "Multiple use of ",
160 id,
161 " detected.",
162 " Used in both TV",
163 tv->name(),
164 " and TV",
165 it->second->name());
166 domain_use_map.insert({id, tv});
167 }
168 }
169}
170
171} // namespace
172
173void validateIr(Fusion* fusion) {
174 FUSER_PERF_SCOPE("GpuLower::Lower::validateIr");
175
176 FusionGuard fg(fusion);
177
178 fusion->validateInputs();
179
180 // Validate Parallelization
181 ValidateSiblings::validate(fusion);
182
183 validateIterDomainUsage(fusion);
184}
185
186namespace {
187
188// Check contiguity for all root domains associated with Misaligned Vectorize
189// ParallelType
190void checkContiguity(
191 const std::unordered_set<IterDomain*>& domains,
192 TensorView* tv) {
193 TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global);
194
195 for (const auto idx : c10::irange(tv->getRootDomain().size())) {
196 auto root = tv->getRootDomain()[idx];
197 if (domains.find(root) != domains.end()) {
198 TORCH_INTERNAL_ASSERT(
199 !root->isBroadcast(),
200 "Misaligned vectorization prohibits merging broadcast domains.",
201 "Issue found in, ",
202 tv);
203 TORCH_INTERNAL_ASSERT(
204 tv->domain()->contiguity()[idx],
205 "Cannot merge non-contiguous root domains with misaligned vectorization.",
206 "Issue found in, ",
207 tv);
208 }
209 }
210}
211
212// Check all root iter domains in consumer that are present in domain, making
213// sure they're contiguous. Map these domains to producer and make sure they are
214// also contiguous in producer. Producer-consumer relationship is assumed to be
215// through a set operation.
216void checkContiguity(
217 const std::unordered_set<IterDomain*>& domains,
218 TensorView* consumer,
219 TensorView* producer) {
220 // This seems not quite right, shouldn't we be able to reverse this?
221 TORCH_INTERNAL_ASSERT(consumer->getMemoryType() == MemoryType::Local);
222 TORCH_INTERNAL_ASSERT(producer->getMemoryType() == MemoryType::Global);
223
224 auto root_c2p =
225 PairwiseRootDomainMap(producer, consumer)
226 .mapConsumerToProducer(consumer->domain(), producer->domain());
227
228 std::unordered_map<IterDomain*, bool> producer_domain_contiguity;
229 for (const auto idx : c10::irange(producer->getMaybeRFactorDomain().size())) {
230 auto root = producer->getMaybeRFactorDomain()[idx];
231 auto contiguity = producer->domain()->contiguity()[idx];
232 producer_domain_contiguity.insert({root, contiguity});
233 }
234
235 for (auto consumer_root : consumer->getMaybeRFactorDomain()) {
236 if (domains.find(consumer_root) != domains.end()) {
      TORCH_INTERNAL_ASSERT(root_c2p.find(consumer_root) != root_c2p.end());
      auto producer_root = root_c2p.at(consumer_root);
      TORCH_INTERNAL_ASSERT(
          producer_domain_contiguity.find(producer_root) !=
          producer_domain_contiguity.end());

      TORCH_INTERNAL_ASSERT(
          !consumer_root->isBroadcast() || !producer_root->isBroadcast(),
          "Misaligned vectorization prohibits merging broadcast domains.",
          "Issue found in, ",
          consumer);
249
250 TORCH_INTERNAL_ASSERT(
251 producer_domain_contiguity[producer_root],
252 "Cannot merge non-contiguous root domains with misaligned vectorization.",
253 "Issue found in, ",
254 consumer);
255 }
256 }
257}
258
259class VectorizeValidator : public OptInDispatch {
260 private:
261 // Initially, vectorized_id is the IterDomain with Vectorize ParallelType
262 // After processing all merge and split operations,
263 // vectorized_id is the corresponding root domain
264 VectorizeValidator(IterDomain* vectorized_id)
265 : vectorized_id_(vectorized_id) {}
266
267 using OptInDispatch::handle;
268
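  // Trace the vectorized ID back through a split: only the inner
  // output can stay vectorized (it maps back to the split input);
  // vectorizing the outer output would make the accesses strided.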
269 void handle(Split* s) final {
270 if (s->outer() == vectorized_id_) {
271 is_valid = false;
272 } else if (s->inner() == vectorized_id_) {
273 vectorized_id_ = s->in();
274 }
275 domains_.insert(s->outer());
276 domains_.insert(s->inner());
277 }
278
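  // Trace the vectorized ID back through a merge: follow the inner
  // input, unless the inner input is a broadcast merged with a
  // non-broadcast outer, in which case follow the outer input.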
279 void handle(Merge* m) final {
280 if (m->out() == vectorized_id_) {
281 if (m->inner()->isBroadcast() && !m->outer()->isBroadcast()) {
282 vectorized_id_ = m->outer();
283 } else {
284 vectorized_id_ = m->inner();
285 }
286 }
287 domains_.insert(m->outer());
288 domains_.insert(m->inner());
289 }
290
  // The producer tensor is indexed as if it had been transformed like
  // the consumer. So, to find its contig merged domain, use the
  // consumer TensorDomain with the producer contiguity info.
294 static std::vector<bool> mapProducerContiguity(
295 TensorView* producer_tv,
296 TensorView* consumer_tv) {
297 const auto c2p = PairwiseRootDomainMap(producer_tv, consumer_tv)
298 .mapConsumerToProducer(
299 consumer_tv->domain(), producer_tv->domain());
300
301 std::vector<bool> producer_contiguity;
302
303 for (auto consumer_root_id : consumer_tv->getRootDomain()) {
304 auto producer_root_id = c2p.at(consumer_root_id);
305 auto producer_root_it = std::find(
306 producer_tv->getMaybeRFactorDomain().begin(),
307 producer_tv->getMaybeRFactorDomain().end(),
308 producer_root_id);
309 TORCH_INTERNAL_ASSERT(
310 producer_root_it != producer_tv->getMaybeRFactorDomain().end());
311 auto producer_root_id_offset = std::distance(
312 producer_tv->getMaybeRFactorDomain().begin(), producer_root_it);
313 producer_contiguity.push_back(
314 producer_tv->domain()->contiguity().at(producer_root_id_offset));
315 }
316
317 return producer_contiguity;
318 }
319
320 private:
321 std::unordered_set<IterDomain*> domains_;
322 IterDomain* vectorized_id_ = nullptr;
323 bool is_valid = true;
324
325 public:
326 static void validate(TensorView* tv) {
327 // Make sure there's only one vectorized ID
328 IterDomain* v_id = nullptr;
329 bool misaligned_vectorize = false;
330 for (auto id : tv->domain()->domain()) {
331 if (id->getParallelType() == ParallelType::Vectorize ||
332 id->getParallelType() == ParallelType::MisalignedVectorize) {
333 TORCH_INTERNAL_ASSERT(
334 v_id == nullptr,
335 "Found two vectorized domains in ",
336 tv,
337 " only one is allowed.");
338 v_id = id;
339 misaligned_vectorize =
340 id->getParallelType() == ParallelType::MisalignedVectorize;
341 }
342 }
343
    // If no vectorized id is found, simply return. If the vectorized
    // access is a broadcast, it won't generate an actual vector
    // instruction, so it can safely be ignored.
347 if (v_id == nullptr || v_id->isBroadcast()) {
348 return;
349 }
350
351 auto fusion = FusionGuard::getCurFusion();
352
353 TORCH_CHECK(
354 v_id->extent()->isConstScalar(),
355 "Vectorizing a domain requires a constant size.");
356
357 ExpressionEvaluator const_expr_eval(fusion);
358
359 auto vector_size_optional = const_expr_eval.evaluate(v_id->extent());
360
361 TORCH_CHECK(
362 vector_size_optional.has_value(),
363 "Could not evaluate constant value bound to vectorized dim.");
364
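    // Vectorization width in bytes: element size times the extent of
    // the vectorized domain.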
365 auto vector_size = ((int64_t)dataTypeSize(tv->getDataType().value())) *
366 vector_size_optional.value();
367
368 // Allow half2, float2, float4 and same sized vtypes.
369 std::array<int64_t, 4> allowed_vector_sizes = {2, 4, 8, 16}; // NOLINT
370
371 TORCH_CHECK(
372 std::find(
373 allowed_vector_sizes.begin(),
374 allowed_vector_sizes.end(),
375 vector_size) != allowed_vector_sizes.end(),
376 "Tried to vectorize a dim resulting in a word size of ",
377 vector_size,
        " however, only vector sizes up to and including 16 bytes are supported.");
379
380 auto replay_exprs = DependencyCheck::getAllExprsBetween(
381 {tv->getMaybeRFactorDomain().begin(),
382 tv->getMaybeRFactorDomain().end()},
383 {v_id});
384
385 VectorizeValidator validator(v_id);
386
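    // Walk from the vectorized leaf ID back toward the rfactor/root
    // domain, tracing vectorized_id_ to its originating root ID and
    // recording the IDs touched by each split/merge in domains_.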
387 for (auto expr_it = replay_exprs.rbegin(); expr_it != replay_exprs.rend();
388 ++expr_it) {
389 auto expr = *expr_it;
390 validator.handle(expr);
391 }
392
393 TORCH_CHECK(
394 validator.is_valid,
395 "Invalid vectorized pattern found, vectorization iter domains must be descendants of inner-most dimension.",
396 "Issue found in, ",
397 tv,
398 "\n");
399
400 if (misaligned_vectorize) {
401 if (tv->getMemoryType() == MemoryType::Global) {
402 checkContiguity(validator.domains_, tv);
403 } else if (
404 tv->definition()->getExprType() == ExprType::UnaryOp &&
405 tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
406 UnaryOpType::Set) {
407 auto input = tv->definition()->input(0);
408 TORCH_INTERNAL_ASSERT(input->isA<TensorView>());
409 auto input_tv = input->as<TensorView>();
410 checkContiguity(validator.domains_, tv, input_tv);
411 }
412 }
413
414 TORCH_INTERNAL_ASSERT(validator.vectorized_id_ != nullptr);
415
416 // Contiguity is based on rfactor domain.
417 IterDomain* last_root_dim = nullptr;
418 int last_root_dim_pos = -1;
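    // Find the innermost rfactor domain that is neither a reduction
    // nor a broadcast; the vectorized root ID must be this domain and
    // it must be contiguous.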
419 for (size_t i = tv->getMaybeRFactorDomain().size(); i > 0; i--) {
420 auto r_id = tv->getMaybeRFactorDomain()[i - 1];
421 if (r_id->isReduction() || r_id->isBroadcast()) {
422 continue;
423 }
424 last_root_dim = r_id;
425 last_root_dim_pos = (int)i - 1;
426 break;
427 }
428
429 if (last_root_dim == nullptr) {
430 // Should never get here, but that would mean there are no concrete dims,
431 // so we should be fine.
432 return;
433 }
434
435 TORCH_CHECK(
436 last_root_dim == validator.vectorized_id_ &&
437 tv->domain()->contiguity()[last_root_dim_pos],
438 "Vectorized dim has to be from a contiguous inner most position: ",
439 tv,
440 "\n");
441
    // Save info required for lowering and runtime validation
443 auto consumer_word_size_it =
444 GpuLower::current()->vectorizedAccesses().find(tv);
445 if (consumer_word_size_it !=
446 GpuLower::current()->vectorizedAccesses().end()) {
447 consumer_word_size_it->second = std::max(
448 (int)vector_size_optional.value(), consumer_word_size_it->second);
449 } else {
450 GpuLower::current()->vectorizedAccesses().emplace(
451 tv, (int)vector_size_optional.value());
452 }
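    // The producer of the vectorized set is accessed with the same
    // width, so record (or widen) its vectorized word size as well.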
453 auto producer_tv = tv->definition()->inputs().at(0)->as<TensorView>();
454 auto producer_word_size_it =
455 GpuLower::current()->vectorizedAccesses().find(producer_tv);
456 if (producer_word_size_it !=
457 GpuLower::current()->vectorizedAccesses().end()) {
458 producer_word_size_it->second = std::max(
459 (int)vector_size_optional.value(), producer_word_size_it->second);
460 } else {
461 GpuLower::current()->vectorizedAccesses().emplace(
462 producer_tv, (int)vector_size_optional.value());
463 }
464
465 VectorizedSetInfo vectorized_set_info;
466 vectorized_set_info.consumer_tv = tv;
467 vectorized_set_info.producer_tv = producer_tv;
468 // Note that VectorizedSetInfo is about each instance of
469 // vectorized set operations, so the word size is the size of this
470 // specific vectorized set.
471 vectorized_set_info.word_size = (int)vector_size_optional.value();
472 vectorized_set_info.vectorized_leaf_id = v_id;
473 vectorized_set_info.vectorized_root_id = validator.vectorized_id_;
474 // For aligned vectorize, the extent of a vectorized domain must
475 // be divisible by the vector word size. The domain is usually
476 // just one of the root domains, but can be a merged domain of
477 // contiguous domains. Those domains are saved in
478 // VectorizedSetInfo.contig_root_ids at the time of indexing.
479 GpuLower::current()->vectorizedSetInfo().emplace_back(vectorized_set_info);
480 }
481};
482
483} // namespace
484
485// Uses ContigIDs to find root contig domains that a vectorized domain
486// depends on. As ContigIDs depends on HaloInfo, this must be done
487// after HaloInfo is created.
488void validateAndCollectVectorizeInfo(Fusion* fusion) {
489 FUSER_PERF_SCOPE("GpuLower::Lower::validateVectorize");
490 FusionGuard fg(fusion);
491
492 auto used_vals = fusion->usedMathVals();
493
494 std::unordered_set<TensorView*> used_tvs;
495
496 for (auto val : used_vals) {
497 if (ir_utils::isTV(val)) {
498 used_tvs.emplace(val->as<TensorView>());
499 }
500 }
501
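  // For each used TensorView, look for (misaligned) vectorized leaf
  // IDs and validate their placement relative to the computeAt
  // position before running the full VectorizeValidator.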
502 for (auto tv : used_tvs) {
503 bool has_vectorize_dim = false;
504 bool has_misaligned_vectorize_dim = false;
505
506 for (const auto i : c10::irange(tv->nDims())) {
507 IterDomain* id = tv->axis(i);
508 IterDomain* concrete_id =
509 GpuLower::current()->caMap()->getConcreteMappedID(
510 id, IdMappingMode::LOOP);
511
512 auto ptype = concrete_id->getParallelType();
513
514 if (ptype == ParallelType::Vectorize) {
515 // If we want to do this check up front we would have to do 2 things:
516 // (1) Check that the tensor view with vectorize being set on it is
517 // getting set outside the local compute at position
518 // (2) Check any producers of the tensor view with vectorize being set
519 // on it to make sure their compute at position isn't to the right of
520 // the vectorize dim.
521 TORCH_INTERNAL_ASSERT(
522 i >= tv->getComputeAtPosition(),
523 "IterDomains to the left of the compute at point cannot be vectorized: ",
524 tv,
525 "\n");
526 has_vectorize_dim = true;
527 }
528
529 if (concrete_id->getParallelType() == ParallelType::MisalignedVectorize) {
530 TORCH_INTERNAL_ASSERT(
531 !tv->hasComputeAt() ||
532 tv->getComputeAtPosition() == tv->nDims() - 1,
533 "Only allow misaligned vectorization in the -2 computeAt position.");
534 TORCH_INTERNAL_ASSERT(
535 tv->getMemoryType() == MemoryType::Local ||
536 tv->getMemoryType() == MemoryType::Global,
537 "Only allow misaligned vectorization between global and local memory.");
538 has_misaligned_vectorize_dim = true;
539 }
540 }
541 if (has_vectorize_dim) {
542 TORCH_INTERNAL_ASSERT(
543 tv->definition() == nullptr ||
544 (tv->definition()->isA<UnaryOp>() &&
545 tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
546 UnaryOpType::Set) ||
547 tv->definition()->isA<LoadStoreOp>(),
          "Vectorized accesses cannot be inlined with computation; they are only supported with a Set operation.",
549 "TensorView: ",
550 tv);
551 }
552 // Validate the vectorized domain maps to the innermost domain of
553 // tv. Note that we don't need to validate its producer tv as
554 // both Vectorize and MisalignedVectorize can only be used with
555 // UnaryOp::Set.
556 if (has_vectorize_dim || has_misaligned_vectorize_dim) {
557 VectorizeValidator::validate(tv);
558 }
559 }
560}
561
562namespace {
563
564void fillVectorizedContigRootDomains(
565 const TensorView* tv,
566 const ContigIDs& contig_finder,
567 IterDomain* vectorized_root_id,
568 VectorizedSetInfo& info) {
569 const auto& root_dom = tv->getMaybeRFactorDomain();
570
  // Find the root domains that the merged contiguous domain depends on.
573
574 auto consumer_indexed_it =
575 contig_finder.rootToIndexedID().find(vectorized_root_id);
576 TORCH_INTERNAL_ASSERT(
577 consumer_indexed_it != contig_finder.rootToIndexedID().end(),
578 "Contiguity information not found for root domain: ",
579 vectorized_root_id->toString());
580 auto consumer_indexed_id = consumer_indexed_it->second;
581
582 // Actual indexed root domains for this root domain. If
583 // contig merge is done, multiple root domains are included.
584 std::unordered_set<IterDomain*> indexed_root_ids;
585
586 if (consumer_indexed_id == vectorized_root_id) {
587 // Indexed domain is equal to the root domain, meaning no contig
588 // merge is involved.
589 indexed_root_ids.insert(vectorized_root_id);
590 } else {
591 auto consumer_within_contig_it =
592 contig_finder.withinContigIDs().find(consumer_indexed_id);
593 TORCH_INTERNAL_ASSERT(
594 consumer_within_contig_it != contig_finder.withinContigIDs().end());
595 const auto& within_ids = consumer_within_contig_it->second;
596 std::copy_if(
597 root_dom.begin(),
598 root_dom.end(),
599 std::inserter(indexed_root_ids, indexed_root_ids.end()),
600 [&](IterDomain* root_id) {
601 return within_ids.find(root_id) != within_ids.end();
602 });
603 }
604
605 // Store the contig merged root domains. If it is already set, pick
606 // the smaller one as it is used for validating divisibility of the
607 // merged extent.
608 if (info.contig_root_ids.empty() ||
609 indexed_root_ids.size() < info.contig_root_ids.size()) {
610 info.contig_root_ids = indexed_root_ids;
611 }
612}
613
614} // namespace
615
616void fillConsumerVectorizedContigRootDomains(
617 const TensorView* consumer_tv,
618 const ContigIDs& contig_finder) {
619 auto& info_vector = GpuLower::current()->vectorizedSetInfo();
620 auto it = std::find_if(
621 info_vector.begin(), info_vector.end(), [&consumer_tv](auto& info) {
622 return info.consumer_tv == consumer_tv;
623 });
624 if (it == info_vector.end()) {
625 return;
626 }
627
628 VectorizedSetInfo& info = *it;
629
630 // info.vectorized_root_id is validated at this point to be the
631 // last concrete root domain in consumer.
632 auto consumer_root_id = info.vectorized_root_id;
633
634 fillVectorizedContigRootDomains(
635 consumer_tv, contig_finder, consumer_root_id, info);
636}
637
638void fillProducerVectorizedContigRootDomains(
639 const TensorView* producer_tv,
640 const TensorView* consumer_tv,
641 const std::unordered_map<IterDomain*, IterDomain*>& c2p_map,
642 const ContigIDs& contig_finder) {
643 auto& info_vector = GpuLower::current()->vectorizedSetInfo();
644 auto it = std::find_if(
645 info_vector.begin(),
646 info_vector.end(),
647 [&producer_tv, &consumer_tv](auto& info) {
648 return info.consumer_tv == consumer_tv &&
649 info.producer_tv == producer_tv;
650 });
651 if (it == info_vector.end()) {
652 return;
653 }
654
655 VectorizedSetInfo& info = *it;
656
657 // info.vectorized_root_id is validated at this point to be the
658 // last concrete root domain in consumer.
659 auto consumer_root_id = info.vectorized_root_id;
660
661 auto root_id = c2p_map.at(consumer_root_id);
662
663 fillVectorizedContigRootDomains(producer_tv, contig_finder, root_id, info);
664}
665
666namespace {
667
668// Backward propagation of partial ranges from outputs to
669// inputs. Necessary to determine required ranges to compute.
670//
671// Example:
672// tv0: [0:N]
673// tv1: shift(tv0, {1}) -> [1:N]
674// tv2: shift(tv0, {-1}) -> [0:N-1]
675// tv3: tv1 + tv2 -> [1:N-1]
676//
677// In this case, the valid range of tv3 starts at 1 and ends at
678// N-1. This means that not all of the values of tv1 and tv2 are
679// actually necessary. Specifically, tv1[0] and tv2[N-1] aren't used
680// for tv3. This function calculates the required minimum range of
681// each tensor that needs to be computed.
682std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> getLiveRangeOffsets(
683 Fusion* fusion) {
684 auto exprs = StmtSort::getExprs(fusion);
685
686 std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> map;
687
688 ExpressionEvaluator ee(fusion);
689
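  // Visit expressions in reverse topological order so that the
  // required range of each consumer is known before its producers are
  // processed.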
690 for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
691 auto expr = *it;
692 for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
693 for (auto consumer_root : consumer->getRootDomain()) {
694 auto consumer_start_offset = ee.evaluate(consumer_root->start());
695 auto consumer_stop_offset = ee.evaluate(consumer_root->stopOffset());
696 TORCH_INTERNAL_ASSERT(
697 consumer_start_offset.has_value(),
698 "Can't evaluate start value of ",
699 consumer_root->start());
700 TORCH_INTERNAL_ASSERT(
701 consumer_stop_offset.has_value(),
702 "Can't evaluate stop value of ",
703 consumer_root->stopOffset());
704 auto it = map.find(consumer_root);
705 if (it == map.end() || consumer->isFusionOutput()) {
706 // No range set for this root domain, which means this
707 // consumer_tensor is an output tensor or the consumer_root
708 // domain is a reduction domain. In either case, the
709 // required range is simply defined by the start and stop
710 // offsets of the root domain.
711 // Also, when consumer is an output, even if it's not
712 // terminating, the range to compute must not be affected by
713 // how it's used by its consumers because an output tensor
714 // is visible to outside of the fusion.
715 map.insert(
716 {consumer_root,
717 {consumer_start_offset->as<int64_t>(),
718 consumer_stop_offset->as<int64_t>()}});
719 } else {
720 // When the range of this root domain is already set, it
721 // must be set by its consumers. Make sure the required
722 // range by the consumers is covered by the defined range of
723 // this root domain.
724 auto& consumer_range = it->second;
725 TORCH_INTERNAL_ASSERT(
726 consumer_start_offset->as<int64_t>() <= consumer_range.first);
727 TORCH_INTERNAL_ASSERT(
728 consumer_stop_offset->as<int64_t>() <= consumer_range.second);
729 }
730 }
731
      // Propagate the range information from consumers to
      // producers. Note that the effect of shift and gather on the
      // range is not considered here; it is taken care of by halo
      // regions.
735 for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
736 auto c2p =
737 PairwiseRootDomainMap(producer, consumer)
738 .mapConsumerToProducer(consumer->domain(), producer->domain());
739 for (auto consumer_root : consumer->getRootDomain()) {
740 auto producer_it = c2p.find(consumer_root);
741 if (producer_it == c2p.end()) {
742 continue;
743 }
744 auto producer_root = producer_it->second;
745 auto& consumer_range = map.at(consumer_root);
746 const std::pair<int64_t, int64_t> init_range{
747 std::numeric_limits<int64_t>::max(),
748 std::numeric_limits<int64_t>::max()};
749 auto& producer_range =
750 map.insert({producer_root, init_range}).first->second;
751 producer_range.first =
752 std::min(producer_range.first, consumer_range.first);
753 producer_range.second =
754 std::min(producer_range.second, consumer_range.second);
755 }
756 }
757 }
758 }
759
760 return map;
761}
762
// Make sure that a partial split with split_offset does not violate
// the required range defined by domain_offset. Consider the start
// side of a root domain: only positions at split_offset or larger are
// going to be computed, and all positions starting at domain_offset
// must be computed, so split_offset must be less than or equal to
// domain_offset. The same condition must hold for the end side of the
// domain.
770//
771// In order to validate this condition, the split offset is assumed to
772// be a statically known constant value. This is not a hard
773// requirement, but otherwise a runtime check would be needed.
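//
// For example, with a required start-side range beginning at
// domain_offset = 1, a split with start offset 2 would skip a needed
// position and is rejected, while offsets 0 or 1 are accepted.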
774void validateSplit(
775 Val* split_offset,
776 int64_t domain_offset,
777 const std::string& err_msg_prefix) {
778 ExpressionEvaluator ee(split_offset->fusion());
779
780 TORCH_INTERNAL_ASSERT(split_offset->isA<Int>());
781 auto split_offset_value = ee.evaluate(split_offset);
782 TORCH_INTERNAL_ASSERT(
783 split_offset_value.has_value(),
784 err_msg_prefix,
785 ": Unknown offset of split: ",
786 split_offset);
787
788 TORCH_INTERNAL_ASSERT(
789 split_offset_value.value() <= domain_offset,
790 err_msg_prefix,
791 ": Split offset is larger than the domain offset.",
792 " Split offset: ",
793 split_offset_value.value(),
794 ". Domain offset: ",
795 domain_offset);
796}
797
798} // namespace
799
800void validatePartialSplit(Fusion* fusion) {
801 FUSER_PERF_SCOPE("GpuLower::Lower::validatePartialSplit");
802 FusionGuard fg(fusion);
803
804 // If a root domain is partially split, only the sub range defined
805 // by the start and stop offsets of the partial split is
806 // computed. That sub range must cover the required range of the
807 // domain. So, the first thing to do is to determine the required
808 // minimum range of each root domain. Then, check if any partial
809 // split could result in a smaller range than the required range.
810
811 // Compute the required range of each root domain
812 auto range_info = getLiveRangeOffsets(fusion);
813
814 for (auto tv : ir_utils::allTvs(fusion)) {
815 auto exprs = StmtSort::getExprs(
816 tv->fusion(),
817 {tv->domain()->domain().begin(), tv->domain()->domain().end()});
818 for (auto split : ir_utils::filterByType<Split>(exprs)) {
819 // When the start and stop offsets are not zero, make sure the
820 // range defined by the split includes the required range to
821 // compute. If both of the split offsets are zero, this
822 // condition is obviously true. Also, this validation only needs
823 // to be done with root domains. Since the start and stop
824 // offsets of non-root domains must be just zero, they are
825 // skipped at this point.
826 if (split->startOffset()->isZeroInt() &&
827 split->stopOffset()->isZeroInt()) {
828 continue;
829 }
830 auto root_domain = split->in();
831 std::stringstream err_msg_prefix;
832 err_msg_prefix << "Error with " << root_domain << " in T" << tv->name();
833 TORCH_INTERNAL_ASSERT(range_info.find(root_domain) != range_info.end());
834 const auto& valid_range = range_info.at(root_domain);
      // Check the start offset. If it's zero, no validation of the
      // required range is needed.
837 if (!split->startOffset()->isZeroInt()) {
838 validateSplit(
839 split->startOffset(), valid_range.first, err_msg_prefix.str());
840 }
841 // Same for the stop offset.
842 if (!split->stopOffset()->isZeroInt()) {
843 validateSplit(
844 split->stopOffset(), valid_range.second, err_msg_prefix.str());
845 }
846 }
847 }
848}
849
850namespace {
851
//! Utility to make sure the targeted GPU capability is at least the
//! provided major.minor.
854void validateMinimumArch(int major, int minor) {
855 // Skip checking arch if disabled.
856 if (isOptionDisabled(DisableOption::ArchCheck)) {
857 return;
858 }
859
860 auto prop = at::cuda::getCurrentDeviceProperties();
861 TORCH_INTERNAL_ASSERT(prop->major >= major);
862 if (prop->major == major) {
863 TORCH_INTERNAL_ASSERT(prop->minor >= minor);
864 }
865}
866
//! Validates that the operand and result tensors
//! of mma ops are swizzled and also validates the
//! specialization of TIDx as the lane id.
870void validateMmaTensors(MmaOp* mma) {
871 bool tidx_validated = false;
872 std::vector<TensorView*> to_validate = {
873 mma->inA()->as<TensorView>(),
874 mma->inB()->as<TensorView>(),
875 mma->out()->as<TensorView>()};
876
877 for (auto tv : to_validate) {
878 for (auto id : tv->domain()->domain()) {
879 auto ptype = id->getParallelType();
880 if (ptype == ParallelType::TIDx) {
881 TORCH_INTERNAL_ASSERT(
882 id->isMmaSwizzled(),
883 "TIDx for mma input/output must be set by WarpMmaSwizzler",
884 id,
885 tv);
886 if (!tidx_validated) {
          // Check that TIDx is exactly the lane id
          const auto& parallel_dim_map =
              GpuLower::current()->parallelDimensionMap();
          TORCH_INTERNAL_ASSERT(
              parallel_dim_map.isExact(ptype) &&
                  parallel_dim_map.get(ptype)->isConstInt() &&
                  parallel_dim_map.get(ptype)->evaluateInt() ==
894 at::cuda::warp_size(),
895 "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp");
896 tidx_validated = true;
897 }
898 }
899 }
900 }
901
902 // Note: this check will be relaxed in a follow up.
903 auto validate_operand = [](const TensorView* tv) {
904 TORCH_INTERNAL_ASSERT(
905 tv->getMemoryType() == MemoryType::Local,
906 "Only supporting register input for mma ops, up to sm80 all mma ops have to take register inputs.");
907
908 TORCH_INTERNAL_ASSERT(
909 std::all_of(
910 tv->domain()->domain().begin() + tv->getComputeAtPosition(),
911 tv->domain()->domain().end(),
912 [](IterDomain* id) {
913 return id->isMmaSwizzled() ||
914 // MMA instructions can only take inputs from registers,
915 // so we always assume mma op inputs are located on
916 // registers.
917 // Currently requiring that serial ids on the right of the
918 // CA axis are constant sized to ensure early detection of
919 // invalid mma schedules.
920 ((id->isBroadcast() || id->extent()->isConstInt()) &&
921 id->getParallelType() == ParallelType::Serial);
922 }),
        "All ids to the right of the CA pos need to be mma-swizzled by WarpMmaSwizzler\n",
924 tv);
925 };
926
927 validate_operand(mma->inA()->as<TensorView>());
928 validate_operand(mma->inB()->as<TensorView>());
929
930 // Additionally validate that mma is not directly taking a double buffered
931 // register input as the double buffer indexing is currently not compatible
932 // with fragment iteration. Would need to require a cache stage in this case.
933 TORCH_INTERNAL_ASSERT(
934 !mma->inA()->as<TensorView>()->isDoubleBuffered(),
935 "MMA op cannot directly take double buffered register input, put a set stage before.");
936 TORCH_INTERNAL_ASSERT(
937 !mma->inB()->as<TensorView>()->isDoubleBuffered(),
938 "MMA op cannot directly take double buffered register input, put a set stage before.");
939}
940
941//! Note and TODO:
942//! Currently relying on ldmatrix to
943//! obtain the correct data layout for turing/ampere
944//! mma's.
945//! This restriction will eventually not
946//! be necessary once the scatter swizzle is ready.
947void validateTuringMmaInput(TensorView* tv) {
948 // Pattern matching here to make sure LDMatrix is the right format.
949 // Format is done through swizzling in the scheduling and
950 // we check that swizzling to make sure it's correctly setup for LDMatrix.
951 // We could in theory support patterns LDMatrix doesn't support,
952 // but that would also mean the MMA isn't supported and
953 // so we would have to lower to something completely different.
954
955 // MemCpy async is a more generic utility that we can use.
956 // Currently only allowed input paths are:
957 // ldmatrix -> mma or
958 // ldmatrix -> broadcast -> mma
959 // We actually wouldn't want too much flexibility here since
960 // this path is very perf critical. But the check itself
961 // can be made cleaner once we have the correct swizzle
962 // labeling.
  // The most generic support would involve building out support for
  // any pointwise op that does not change the data layout.
966 auto tv_def = tv->definition();
967 TORCH_INTERNAL_ASSERT(tv_def);
968 if (tv_def->isA<BroadcastOp>()) {
969 tv_def = tv_def->input(0)->definition();
970 }
971 TORCH_INTERNAL_ASSERT(tv_def);
972 TORCH_INTERNAL_ASSERT(ir_utils::isLdMatrixOp(tv_def));
973}
974
975// Output of ldmatrix is swizzled with the mma format, so it
976// currently should not be fused with any pointwise ops. This
977// check is to protect against these cases.
// This would also not be needed once the scatter swizzle is ready; it
// should just become a swizzle format check if we wanted to fuse
// ldmatrix with any op other than mma.
981void validateLdMatrixOutput(TensorView* tv) {
982 const auto& out_uses = tv->fusion()->unordered_uses(tv);
983 if (out_uses.empty()) {
984 return;
985 }
  // TODO: restricting to single-use pipelines for now, which holds
  // for the matmul main loop. This could be relaxed to support more
  // complex mma usage.
989 TORCH_INTERNAL_ASSERT(out_uses.size() == 1);
990 auto out_use = *(out_uses.begin());
991
992 if (out_use->isA<BroadcastOp>()) {
993 validateLdMatrixOutput(out_use->output(0)->as<TensorView>());
994 return;
995 }
996
997 TORCH_INTERNAL_ASSERT(
998 out_use->isA<MmaOp>(),
999 "validateLdMatrixOutput: currently only supports single mma use for ldmatrix",
1000 out_use);
1001}
1002
1003// Checks that the memory ops are supported on the targeted GPU
1004void validateArchMemoryOp(LoadStoreOp* ldst) {
1005 switch (ldst->opType()) {
1006 case LoadStoreOpType::LdMatrix:
1007 case LoadStoreOpType::LdMatrixTranspose:
1008 validateMinimumArch(7, 5);
1009 validateLdMatrixOutput(ldst->out()->as<TensorView>());
1010 return;
1011 case LoadStoreOpType::CpAsync:
1012 validateMinimumArch(8, 0);
1013 return;
1014 default:
1015 return;
1016 }
1017}
1018
1019} // namespace
1020
1021//! Validate data format and GPU arch compatibility of scheduled
1022//! mma operators on the fusion.
1023void validateMma(Fusion* fusion) {
1024 auto exprs = StmtSort::getExprs(fusion);
1025
1026 for (auto expr : exprs) {
1027 if (auto mma = dynamic_cast<MmaOp*>(expr)) {
1028 validateMmaTensors(mma);
1029
1030 switch (mma->options().macro) {
1031 case MmaOptions::MacroType::Volta_16_16_4:
1032 validateMinimumArch(7, 0);
1033 break;
1034 case MmaOptions::MacroType::Turing_16_8_16:
1035 case MmaOptions::MacroType::Turing_16_16_16:
1036 validateMinimumArch(7, 5);
1037
1038 // Check that operands come from ldmatrix, can be
1039 // relaxed once swizzles can be labeled on iterdomains.
1040 validateTuringMmaInput(mma->inA()->as<TensorView>());
1041 validateTuringMmaInput(mma->inB()->as<TensorView>());
1042 break;
1043 case MmaOptions::MacroType::Ampere_16_8_16:
1044 case MmaOptions::MacroType::Ampere_16_16_16:
1045 validateMinimumArch(8, 0);
1046
1047 // Check that operands come from ldmatrix, can be
1048 // relaxed once swizzles can be labeled on iterdomains.
1049 validateTuringMmaInput(mma->inA()->as<TensorView>());
1050 validateTuringMmaInput(mma->inB()->as<TensorView>());
1051 break;
1052 default:
1053 TORCH_INTERNAL_ASSERT(false, "validate mma: unsupported macro");
1054 break;
1055 }
1056 }
1057 if (auto ldst = dynamic_cast<LoadStoreOp*>(expr)) {
1058 validateArchMemoryOp(ldst);
1059 }
1060 }
1061}
1062
1063namespace {
1064
1065// Utility function to validate a loop swizzle:
1066// 1. Throws an error if any output of the swizzle is not in leaf_domain set.
1067// 2. Warns if any output of the swizzle is not the concrete id of the loop
1068// map.
1069// The second case would make the codegen ignore this swizzle, as if it was not
1070// there at all.
1071void validateLoopSwizzle(
1072 Expr* swizzle_expr,
1073 std::unordered_set<IterDomain*>& leaf_domains) {
1074 for (auto out_id :
1075 ir_utils::filterByType<IterDomain>(swizzle_expr->outputs())) {
1076 TORCH_INTERNAL_ASSERT(
1077 leaf_domains.count(out_id),
1078 "Loop swizzle can only be direct producer of leaf domains.");
1079 if (GpuLower::current()->caMap()->getConcreteMappedID(
1080 out_id, IdMappingMode::LOOP) != out_id) {
1081 TORCH_WARN_ONCE("Ignored loop swizzle :", swizzle_expr->toString());
1082 }
1083 }
1084}
1085
1086} // namespace
1087
1088void validateSwizzle(Fusion* fusion) {
1089 auto used_vals = fusion->usedMathVals();
1090 for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
1091 if (tv->hasSwizzleOp()) {
1092 std::unordered_set<IterDomain*> tv_leaf_domain_set(
1093 tv->domain()->domain().begin(), tv->domain()->domain().end());
1094
1095 // Make sure no swizzle op is inlined:
1096 auto inlined_swizzles = ir_utils::getAllSwizzlesBetween(
1097 tv->getMaybeRFactorDomain(),
1098 {tv->domain()->domain().begin(),
1099 tv->domain()->domain().begin() + tv->getComputeAtPosition()});
1100
1101 auto not_inlined_swizzles = ir_utils::getAllSwizzlesBetween(
1102 tv->getMaybeRFactorDomain(),
1103 {tv->domain()->domain().begin() + tv->getComputeAtPosition(),
1104 tv->domain()->domain().end()});
1105
      // Check inlined swizzles: only loop swizzles can be inlined
      // currently, as inlining data swizzles would require additional
      // support for an unswizzle operator, which currently doesn't
      // have important use cases.
1109 for (auto swizzle_expr : inlined_swizzles) {
1110 TORCH_INTERNAL_ASSERT(
1111 swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop,
1112 "Only support inlining loop swizzles");
1113 validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
1114 }
1115
1116 std::unordered_set<Expr*> inlined_swizzle_set(
1117 inlined_swizzles.begin(), inlined_swizzles.end());
1118
      // Check not-inlined swizzles:
      // Apply the loop swizzle check where it applies, and also make
      // sure that no swizzle expr is in the inlined_swizzle set. The
      // latter would mean that one output of the swizzle is inlined
      // while the other is not. Such cases are not supported.
1124 for (auto swizzle_expr : not_inlined_swizzles) {
1125 TORCH_INTERNAL_ASSERT(
1126 !inlined_swizzle_set.count(swizzle_expr),
1127 "Cannot partially inline across swizzle domains.",
1128 swizzle_expr->toString());
1129 if (swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop) {
1130 validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
1131 }
1132 }
1133 }
1134 }
1135}
1136
1137void validateAndConvertIterDomainGrouping(Fusion* fusion) {
1138 for (auto tv : ir_utils::allTvs(fusion)) {
1139 bool is_grouped = false;
1140 for (const auto id_idx : c10::irange(tv->nDims())) {
1141 const auto id = tv->axis(id_idx);
1142 auto ptype = GpuLower::current()
1143 ->caMap()
1144 ->getConcreteMappedID(id, IdMappingMode::LOOP)
1145 ->getParallelType();
1146 if (ptype != ParallelType::Group) {
1147 // Not a grouped ID
1148 continue;
1149 }
1150
1151 // Remember if a grouped ID is found
1152 is_grouped = true;
1153
1154 // Grouping only makes sense for the normal iteration type
1155 TORCH_CHECK(
1156 id->getIterType() == IterType::Iteration,
1157 "Invalid use of ParallelType::Group.",
1158 " Grouping of ",
1159 id->getIterType(),
1160 " is not allowed. ",
1161 tv->toString());
1162
1163 // Extent must be static
1164 TORCH_CHECK(
1165 id->extent()->getInt().has_value(),
1166 "Invalid use of ParallelType::Group.",
1167 " IterDomain must have a static extent: ",
1168 id->toString());
1169
1170 // The CA position must be left of any grouped ID
1171 TORCH_CHECK(
1172 tv->getComputeAtPosition() <= id_idx,
1173 "Invalid use of ParallelType::Group.",
1174 " ComputeAt position must be left of grouped IDs: ",
1175 tv->toString());
1176
1177 // Similarly, the produce-at position must be left of any grouped ID
1178 TORCH_CHECK(
1179 tv->getMaxProducerPosition() <= id_idx,
1180 "Invalid use of ParallelType::Group.",
1181 " ProduceAt position must be left of grouped IDs: ",
1182 tv->toString());
1183
1184 // Halo is not allowed
1185 TORCH_CHECK(
1186 GpuLower::current()->haloInfo()->getExtent(id) == nullptr,
1187 "Invalid use of ParallelType::Group.",
1188 " Grouping of halo-extended IterDomain, ",
1189 id->toString(),
1190 ", is not supported. ",
1191 tv->toString());
1192 }
1193
1194 if (!is_grouped) {
1195 continue;
1196 }
1197
1198 // Must be defined by ReductionOp
1199 auto def = tv->definition();
1200 TORCH_CHECK(
1201 def != nullptr,
1202 "Invalid use of ParallelType::Group.",
1203 " Definition of tv with ParallelType::Group not found. ",
1204 tv->toString());
1205
1206 TORCH_CHECK(
1207 tv->definition()->isA<ReductionOp>() ||
1208 tv->definition()->isA<GroupedReductionOp>() ||
1209 tv->definition()->isA<WelfordOp>() ||
1210 tv->definition()->isA<GroupedWelfordOp>(),
1211 "Invalid use of ParallelType::Group. Only ReductionOp, GroupedReductionOp, WelfordOp and GroupedWelfordOp are allowed. ",
1212 tv->definition()->toString());
1213
1214 // Convert ReductionOp to GroupedReductionOp
1215 if (tv->definition()->isA<ReductionOp>()) {
1216 auto rop = def->as<ReductionOp>();
1217 auto is_allreduce = rop->isAllreduce();
1218
1219 TORCH_CHECK(
1220 is_allreduce,
1221 "Invalid use of ParallelType::Group.",
1222 " Only enabled for allreduce reductions: ",
1223 rop->toString());
1224
1225 TORCH_CHECK(
1226 tv->domain()->hasGridReduction(),
1227 "Invalid use of ParallelType::Group.",
1228 " Only enabled for grid reductions: ",
1229 rop->toString());
1230
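      // Replace the ReductionOp with an equivalent GroupedReductionOp
      // holding a single reduction expression.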
1231 std::vector<BinaryOpType> op_types({rop->getReductionOpType()});
1232 std::vector<Val*> init_vals({rop->init()});
1233 std::vector<Val*> outputs({rop->out()});
1234 std::vector<Val*> inputs({rop->in()});
1235
1236 fusion->removeExpr(rop);
1237 IrBuilder::create<GroupedReductionOp>(
1238 static_cast<IrContainer*>(fusion),
1239 op_types,
1240 init_vals,
1241 outputs,
1242 inputs,
1243 is_allreduce);
1244 } else if (tv->definition()->isA<WelfordOp>()) {
1245 // Convert WelfordOp to GroupedWelfordOp
1246 auto wop = def->as<WelfordOp>();
1247 auto is_allreduce = wop->isAllreduce();
1248
1249 TORCH_CHECK(
1250 is_allreduce,
1251 "Invalid use of ParallelType::Group.",
1252 " Only enabled for allreduce reductions: ",
1253 wop->toString());
1254
1255 TORCH_CHECK(
1256 tv->domain()->hasGridReduction(),
1257 "Invalid use of ParallelType::Group.",
1258 " Only enabled for grid reductions: ",
1259 wop->toString());
1260
1261 std::vector<WelfordTriplet> output_vals(
1262 {{wop->outAvg(), wop->outVar(), wop->outN()}});
1263 std::vector<WelfordTriplet> input_vals(
1264 {{wop->inAvg(), wop->inVar(), wop->inN()}});
1265 std::vector<WelfordTriplet> init_vals(
1266 {{wop->initAvg(), wop->initVar(), wop->initN()}});
1267 fusion->removeExpr(wop);
1268 IrBuilder::create<GroupedWelfordOp>(
1269 static_cast<IrContainer*>(fusion),
1270 output_vals,
1271 input_vals,
1272 init_vals,
1273 is_allreduce);
1274 }
1275 }
1276}
1277
1278void validateGroupedReductions(Fusion* fusion) {
1279 for (auto expr : StmtSort::getExprs(fusion)) {
1280 if (auto grouped_reduction_op = dynamic_cast<GroupedReductionOp*>(expr)) {
1281 const auto num_exprs = grouped_reduction_op->numExprs();
1282 int num_grouped_iterations = 1;
1283 auto out_tv = ir_utils::getTvOutput(grouped_reduction_op);
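      // The effective group count is the number of grouped expressions
      // times the product of the extents of all Group-parallelized
      // axes of the output.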
1284 for (auto axis : out_tv->domain()->domain()) {
1285 if (axis->getParallelType() == ParallelType::Group) {
1286 num_grouped_iterations *= axis->extent()->getInt().value();
1287 }
1288 }
1289 TORCH_CHECK(
1290 num_exprs * num_grouped_iterations <= kMaxNumGroupedReductions,
1291 "Too many grouped reductions: ",
1292 grouped_reduction_op->toString(),
1293 ". Up to ",
1294 kMaxNumGroupedReductions,
1295 " reductions are allowed.");
1296 }
1297 }
1298}
1299
1300} // namespace cuda
1301} // namespace fuser
1302} // namespace jit
1303} // namespace torch
1304