#include <lower_validation.h>

#include <contiguity.h>
#include <expr_evaluator.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <transform_iter.h>
#include <transform_replay.h>
#include <type.h>

#include <ATen/cuda/CUDAContext.h>
#include <limits>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

//! Validate that multiple output tensors of the same expression, i.e.,
//! siblings, have valid domains and parallel types. Since siblings
//! are placed in the same loop nest, they must be parallelized the
//! same way. Serial parallel types are inferred and modified when
//! other outputs are parallelized, so that the user doesn't have to
//! specify the same parallelization for every sibling. Throws if
//! conflicts are detected, e.g., TIDx vs BIDx.
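//!
//! A minimal sketch of the promotion behavior (hypothetical schedule;
//! Welford produces sibling outputs avg/var_sum/n):
//!
//!   auto wr = Welford(tv0, {1});
//!   wr.avg->axis(0)->parallelize(ParallelType::TIDx);
//!   // After validation, axis(0) of wr.var_sum and wr.n is also
//!   // TIDx. Requesting BIDx on one sibling and TIDx on another
//!   // would throw instead.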
class ValidateSiblings : public IterVisitor {
 public:
  static void validate(Fusion* fusion) {
    ValidateSiblings validator;
    validator.traverse(fusion);
  }

 private:
  using IterVisitor::handle;

  void handle(Expr* expr) final {
    if (!ir_utils::isTvOp(expr) || expr->outputs().size() < 2) {
      IterVisitor::handle(expr);
      return;
    }

    auto ref_output = expr->outputs().at(0)->as<TensorView>();
    auto ref_ndims = ref_output->nDims();
    const auto& ref_root = ref_output->getRootDomain();
    std::unordered_map<IterDomain*, IterDomain*> id_map;

    for (const auto sibling :
         ir_utils::filterByType<TensorView>(expr->outputs())) {
      if (ref_output == sibling) {
        continue;
      }

      TORCH_INTERNAL_ASSERT(
          sibling->nDims() == ref_ndims,
          "Mismatched dimensionality detected. Expr: ",
          expr->toString(),
          ". Ref output: ",
          ref_output->toString(),
          ". Sibling: ",
          sibling->toString());

      for (const auto i : c10::irange(ref_ndims)) {
        validateParallelTypes(ref_output->axis(i), sibling->axis(i));
      }

      for (const auto i : c10::irange(ref_root.size())) {
        id_map[ref_root[i]] = sibling->getRootDomain().at(i);
      }

      BestEffortReplay replay(
          sibling->domain()->domain(), ref_output->domain()->domain(), id_map);
      for (const auto i : c10::irange(ref_ndims)) {
        auto it = replay.getReplay().find(ref_output->axis(i));
        TORCH_INTERNAL_ASSERT(
            it != replay.getReplay().end(),
            "Matching sibling ID not found. Expr: ",
            expr->toString(),
            ". Ref ID: ",
            ref_output->axis(i)->toString());
        auto sibling_id = it->second;
        TORCH_INTERNAL_ASSERT(
            sibling->axis(i) == sibling_id,
            "Invalid matching sibling ID detected. Expr: ",
            expr->toString(),
            ". Sibling ID: ",
            sibling_id->toString());
      }
    }
  }

  // Parallelize id1 and id0 consistently if one is serial and the other isn't
  void validateParallelTypes(IterDomain* id0, IterDomain* id1) {
    const auto ptype0 = id0->getParallelType();
    const auto ptype1 = id1->getParallelType();

    if (ptype0 == ParallelType::Vectorize ||
        ptype1 == ParallelType::Vectorize) {
      auto other_type = ptype0 == ParallelType::Vectorize ? ptype1 : ptype0;
      TORCH_INTERNAL_ASSERT(
          other_type == ParallelType::Vectorize ||
              (!isParallelTypeThreadDim(other_type) &&
               !isParallelTypeBlockDim(other_type)),
          "Vectorize type was parallelized inconsistently. ",
          "Detected during promoting parallel types.");
      return;
    }

    if (ptype0 != ptype1) {
      TORCH_CHECK(
          ptype0 == ParallelType::Serial || ptype1 == ParallelType::Serial,
          "Error promoting parallel types");
      if (ptype0 == ParallelType::Serial) {
        id0->parallelize(ptype1);
      }
      if (ptype1 == ParallelType::Serial) {
        id1->parallelize(ptype0);
      }
    }
  }
};

// Make sure each IterDomain is used by only one
// TensorView. Several mappings from IterDomains are
// created during lowering, which rely on the unique usage of
// IterDomains.
void validateIterDomainUsage(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::validateIterDomainUse");
  FusionGuard fg(fusion);

  auto used_vals = fusion->usedMathVals();
  std::unordered_map<IterDomain*, TensorView*> domain_use_map;

  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    std::unordered_set<Val*> root_domains;
    std::copy(
        tv->getRootDomain().begin(),
        tv->getRootDomain().end(),
        std::inserter(root_domains, root_domains.begin()));

    std::vector<Val*> leaf_domains;
    std::copy(
        tv->domain()->domain().begin(),
        tv->domain()->domain().end(),
        std::back_inserter(leaf_domains));

    auto all_domain_vals =
        DependencyCheck::getAllValsBetween(root_domains, leaf_domains);

    for (auto id : ir_utils::filterByType<IterDomain>(all_domain_vals)) {
      auto it = domain_use_map.find(id);
      TORCH_INTERNAL_ASSERT(
          it == domain_use_map.end(),
          "Multiple use of ",
          id,
          " detected.",
          " Used in both TV",
          tv->name(),
          " and TV",
          it->second->name());
      domain_use_map.insert({id, tv});
    }
  }
}

} // namespace

void validateIr(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::validateIr");

  FusionGuard fg(fusion);

  fusion->validateInputs();

  // Validate Parallelization
  ValidateSiblings::validate(fusion);

  validateIterDomainUsage(fusion);
}

namespace {

// Check contiguity for all root domains associated with the
// MisalignedVectorize parallel type.
void checkContiguity(
    const std::unordered_set<IterDomain*>& domains,
    TensorView* tv) {
  TORCH_INTERNAL_ASSERT(tv->getMemoryType() == MemoryType::Global);

  for (const auto idx : c10::irange(tv->getRootDomain().size())) {
    auto root = tv->getRootDomain()[idx];
    if (domains.find(root) != domains.end()) {
      TORCH_INTERNAL_ASSERT(
          !root->isBroadcast(),
          "Misaligned vectorization prohibits merging broadcast domains.",
          " Issue found in: ",
          tv);
      TORCH_INTERNAL_ASSERT(
          tv->domain()->contiguity()[idx],
          "Cannot merge non-contiguous root domains with misaligned vectorization.",
          " Issue found in: ",
          tv);
    }
  }
}

// Check all root iter domains in consumer that are present in domain, making
// sure they're contiguous. Map these domains to producer and make sure they
// are also contiguous in producer. The producer-consumer relationship is
// assumed to be through a set operation.
void checkContiguity(
    const std::unordered_set<IterDomain*>& domains,
    TensorView* consumer,
    TensorView* producer) {
  // This seems not quite right, shouldn't we be able to reverse this?
  TORCH_INTERNAL_ASSERT(consumer->getMemoryType() == MemoryType::Local);
  TORCH_INTERNAL_ASSERT(producer->getMemoryType() == MemoryType::Global);

  auto root_c2p =
      PairwiseRootDomainMap(producer, consumer)
          .mapConsumerToProducer(consumer->domain(), producer->domain());

  std::unordered_map<IterDomain*, bool> producer_domain_contiguity;
  for (const auto idx : c10::irange(producer->getMaybeRFactorDomain().size())) {
    auto root = producer->getMaybeRFactorDomain()[idx];
    auto contiguity = producer->domain()->contiguity()[idx];
    producer_domain_contiguity.insert({root, contiguity});
  }

  for (auto consumer_root : consumer->getMaybeRFactorDomain()) {
    if (domains.find(consumer_root) != domains.end()) {
      TORCH_INTERNAL_ASSERT(root_c2p.find(consumer_root) != root_c2p.end());
      auto producer_root = root_c2p[consumer_root];
      TORCH_INTERNAL_ASSERT(
          producer_domain_contiguity.find(producer_root) !=
          producer_domain_contiguity.end());

      TORCH_INTERNAL_ASSERT(
          !consumer_root->isBroadcast() || !producer_root->isBroadcast(),
          "Misaligned vectorization prohibits merging broadcast domains.",
          " Issue found in: ",
          consumer);

      TORCH_INTERNAL_ASSERT(
          producer_domain_contiguity[producer_root],
          "Cannot merge non-contiguous root domains with misaligned vectorization.",
          " Issue found in: ",
          consumer);
    }
  }
}

class VectorizeValidator : public OptInDispatch {
 private:
  // Initially, vectorized_id_ is the IterDomain with the Vectorize
  // parallel type. After processing all merge and split operations,
  // vectorized_id_ is the corresponding root domain.
  VectorizeValidator(IterDomain* vectorized_id)
      : vectorized_id_(vectorized_id) {}

  using OptInDispatch::handle;

  void handle(Split* s) final {
    if (s->outer() == vectorized_id_) {
      is_valid = false;
    } else if (s->inner() == vectorized_id_) {
      vectorized_id_ = s->in();
    }
    domains_.insert(s->outer());
    domains_.insert(s->inner());
  }

  void handle(Merge* m) final {
    if (m->out() == vectorized_id_) {
      if (m->inner()->isBroadcast() && !m->outer()->isBroadcast()) {
        vectorized_id_ = m->outer();
      } else {
        vectorized_id_ = m->inner();
      }
    }
    domains_.insert(m->outer());
    domains_.insert(m->inner());
  }
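
  // A sketch of how the backward traversal above resolves the
  // vectorized root domain (hypothetical schedule):
  //
  //   tv: root [I0, I1]
  //   tv->merge(0);    // [I0*I1]
  //   tv->split(0, 4); // [(I0*I1)/4, 4]
  //   tv->axis(1)->parallelize(ParallelType::Vectorize);
  //
  // Walking the transforms in reverse, the vectorized leaf (extent 4)
  // is the inner output of the split, so it maps back to I0*I1, and
  // the merge then maps it to the inner input I1. Vectorizing the
  // outer split output instead would set is_valid to false.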

  // The producer tensor is indexed as if it were transformed like the
  // consumer. So, to find its contig merged domain, use the
  // consumer TensorDomain with the producer contiguity info.
  static std::vector<bool> mapProducerContiguity(
      TensorView* producer_tv,
      TensorView* consumer_tv) {
    const auto c2p = PairwiseRootDomainMap(producer_tv, consumer_tv)
                         .mapConsumerToProducer(
                             consumer_tv->domain(), producer_tv->domain());

    std::vector<bool> producer_contiguity;

    for (auto consumer_root_id : consumer_tv->getRootDomain()) {
      auto producer_root_id = c2p.at(consumer_root_id);
      auto producer_root_it = std::find(
          producer_tv->getMaybeRFactorDomain().begin(),
          producer_tv->getMaybeRFactorDomain().end(),
          producer_root_id);
      TORCH_INTERNAL_ASSERT(
          producer_root_it != producer_tv->getMaybeRFactorDomain().end());
      auto producer_root_id_offset = std::distance(
          producer_tv->getMaybeRFactorDomain().begin(), producer_root_it);
      producer_contiguity.push_back(
          producer_tv->domain()->contiguity().at(producer_root_id_offset));
    }

    return producer_contiguity;
  }

 private:
  std::unordered_set<IterDomain*> domains_;
  IterDomain* vectorized_id_ = nullptr;
  bool is_valid = true;

 public:
  static void validate(TensorView* tv) {
    // Make sure there's only one vectorized ID
    IterDomain* v_id = nullptr;
    bool misaligned_vectorize = false;
    for (auto id : tv->domain()->domain()) {
      if (id->getParallelType() == ParallelType::Vectorize ||
          id->getParallelType() == ParallelType::MisalignedVectorize) {
        TORCH_INTERNAL_ASSERT(
            v_id == nullptr,
            "Found two vectorized domains in ",
            tv,
            "; only one is allowed.");
        v_id = id;
        misaligned_vectorize =
            id->getParallelType() == ParallelType::MisalignedVectorize;
      }
    }

    // If no vectorized IDs are found, simply return. If the vectorized
    // access is a broadcast, it won't generate an actual vector
    // instruction, so it can safely be ignored.
    if (v_id == nullptr || v_id->isBroadcast()) {
      return;
    }

    auto fusion = FusionGuard::getCurFusion();

    TORCH_CHECK(
        v_id->extent()->isConstScalar(),
        "Vectorizing a domain requires a constant size.");

    ExpressionEvaluator const_expr_eval(fusion);

    auto vector_size_optional = const_expr_eval.evaluate(v_id->extent());

    TORCH_CHECK(
        vector_size_optional.has_value(),
        "Could not evaluate constant value bound to vectorized dim.");

    auto vector_size = ((int64_t)dataTypeSize(tv->getDataType().value())) *
        vector_size_optional.value();

    // Allow half2, float2, float4 and same sized vtypes.
    std::array<int64_t, 4> allowed_vector_sizes = {2, 4, 8, 16}; // NOLINT

    TORCH_CHECK(
        std::find(
            allowed_vector_sizes.begin(),
            allowed_vector_sizes.end(),
            vector_size) != allowed_vector_sizes.end(),
        "Tried to vectorize a dim resulting in a word size of ",
        vector_size,
        "; however, only vector sizes up to and including 16 bytes are supported.");
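
    // Word-size arithmetic for reference: vectorizing 4 fp32 elements
    // is 4 * 4 = 16 bytes (accepted, a float4-equivalent access), and
    // 8 fp16 elements is 8 * 2 = 16 bytes (also accepted), while
    // 8 fp32 elements would be 32 bytes and is rejected.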

    auto replay_exprs = DependencyCheck::getAllExprsBetween(
        {tv->getMaybeRFactorDomain().begin(),
         tv->getMaybeRFactorDomain().end()},
        {v_id});

    VectorizeValidator validator(v_id);

    for (auto expr_it = replay_exprs.rbegin(); expr_it != replay_exprs.rend();
         ++expr_it) {
      auto expr = *expr_it;
      validator.handle(expr);
    }

    TORCH_CHECK(
        validator.is_valid,
        "Invalid vectorized pattern found; vectorized iter domains must be descendants of the inner-most dimension.",
        " Issue found in: ",
        tv,
        "\n");

    if (misaligned_vectorize) {
      if (tv->getMemoryType() == MemoryType::Global) {
        checkContiguity(validator.domains_, tv);
      } else if (
          tv->definition()->getExprType() == ExprType::UnaryOp &&
          tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
              UnaryOpType::Set) {
        auto input = tv->definition()->input(0);
        TORCH_INTERNAL_ASSERT(input->isA<TensorView>());
        auto input_tv = input->as<TensorView>();
        checkContiguity(validator.domains_, tv, input_tv);
      }
    }

    TORCH_INTERNAL_ASSERT(validator.vectorized_id_ != nullptr);

    // Contiguity is based on the rfactor domain.
    IterDomain* last_root_dim = nullptr;
    int last_root_dim_pos = -1;
    for (size_t i = tv->getMaybeRFactorDomain().size(); i > 0; i--) {
      auto r_id = tv->getMaybeRFactorDomain()[i - 1];
      if (r_id->isReduction() || r_id->isBroadcast()) {
        continue;
      }
      last_root_dim = r_id;
      last_root_dim_pos = (int)i - 1;
      break;
    }

    if (last_root_dim == nullptr) {
      // Should never get here, but that would mean there are no concrete dims,
      // so we should be fine.
      return;
    }

    TORCH_CHECK(
        last_root_dim == validator.vectorized_id_ &&
            tv->domain()->contiguity()[last_root_dim_pos],
        "Vectorized dim has to be from a contiguous inner-most position: ",
        tv,
        "\n");

    // Save info required for lowering and runtime validation
    auto consumer_word_size_it =
        GpuLower::current()->vectorizedAccesses().find(tv);
    if (consumer_word_size_it !=
        GpuLower::current()->vectorizedAccesses().end()) {
      consumer_word_size_it->second = std::max(
          (int)vector_size_optional.value(), consumer_word_size_it->second);
    } else {
      GpuLower::current()->vectorizedAccesses().emplace(
          tv, (int)vector_size_optional.value());
    }
    auto producer_tv = tv->definition()->inputs().at(0)->as<TensorView>();
    auto producer_word_size_it =
        GpuLower::current()->vectorizedAccesses().find(producer_tv);
    if (producer_word_size_it !=
        GpuLower::current()->vectorizedAccesses().end()) {
      producer_word_size_it->second = std::max(
          (int)vector_size_optional.value(), producer_word_size_it->second);
    } else {
      GpuLower::current()->vectorizedAccesses().emplace(
          producer_tv, (int)vector_size_optional.value());
    }
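
    // Both branches above keep the maximum word size seen per tensor;
    // e.g., a tensor read with word size 2 in one vectorized set and
    // word size 4 in another is recorded with 4, which is what the
    // runtime alignment validation checks against.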

    VectorizedSetInfo vectorized_set_info;
    vectorized_set_info.consumer_tv = tv;
    vectorized_set_info.producer_tv = producer_tv;
    // Note that VectorizedSetInfo is about each instance of
    // vectorized set operations, so the word size is the size of this
    // specific vectorized set.
    vectorized_set_info.word_size = (int)vector_size_optional.value();
    vectorized_set_info.vectorized_leaf_id = v_id;
    vectorized_set_info.vectorized_root_id = validator.vectorized_id_;
    // For aligned vectorize, the extent of a vectorized domain must
    // be divisible by the vector word size. The domain is usually
    // just one of the root domains, but can be a merged domain of
    // contiguous domains. Those domains are saved in
    // VectorizedSetInfo.contig_root_ids at the time of indexing.
    GpuLower::current()->vectorizedSetInfo().emplace_back(vectorized_set_info);
  }
};

} // namespace

// Uses ContigIDs to find root contig domains that a vectorized domain
// depends on. As ContigIDs depends on HaloInfo, this must be done
// after HaloInfo is created.
void validateAndCollectVectorizeInfo(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::validateVectorize");
  FusionGuard fg(fusion);

  auto used_vals = fusion->usedMathVals();

  std::unordered_set<TensorView*> used_tvs;

  for (auto val : used_vals) {
    if (ir_utils::isTV(val)) {
      used_tvs.emplace(val->as<TensorView>());
    }
  }

  for (auto tv : used_tvs) {
    bool has_vectorize_dim = false;
    bool has_misaligned_vectorize_dim = false;

    for (const auto i : c10::irange(tv->nDims())) {
      IterDomain* id = tv->axis(i);
      IterDomain* concrete_id =
          GpuLower::current()->caMap()->getConcreteMappedID(
              id, IdMappingMode::LOOP);

      auto ptype = concrete_id->getParallelType();

      if (ptype == ParallelType::Vectorize) {
        // If we want to do this check up front, we would have to do 2 things:
        // (1) Check that the tensor view with vectorize being set on it is
        // getting set outside the local compute at position.
        // (2) Check any producers of the tensor view with vectorize being set
        // on it to make sure their compute at position isn't to the right of
        // the vectorize dim.
        TORCH_INTERNAL_ASSERT(
            i >= tv->getComputeAtPosition(),
            "IterDomains to the left of the compute at point cannot be vectorized: ",
            tv,
            "\n");
        has_vectorize_dim = true;
      }

      if (concrete_id->getParallelType() == ParallelType::MisalignedVectorize) {
        TORCH_INTERNAL_ASSERT(
            !tv->hasComputeAt() ||
                tv->getComputeAtPosition() == tv->nDims() - 1,
            "Only allow misaligned vectorization in the -2 computeAt position.");
        TORCH_INTERNAL_ASSERT(
            tv->getMemoryType() == MemoryType::Local ||
                tv->getMemoryType() == MemoryType::Global,
            "Only allow misaligned vectorization between global and local memory.");
        has_misaligned_vectorize_dim = true;
      }
    }
    if (has_vectorize_dim) {
      TORCH_INTERNAL_ASSERT(
          tv->definition() == nullptr ||
              (tv->definition()->isA<UnaryOp>() &&
               tv->definition()->as<UnaryOp>()->getUnaryOpType() ==
                   UnaryOpType::Set) ||
              tv->definition()->isA<LoadStoreOp>(),
          "Vectorized accesses cannot be inlined with computation; they are only supported with a Set operation.",
          " TensorView: ",
          tv);
    }
    // Validate that the vectorized domain maps to the innermost domain
    // of tv. Note that we don't need to validate its producer tv as
    // both Vectorize and MisalignedVectorize can only be used with
    // UnaryOpType::Set.
    if (has_vectorize_dim || has_misaligned_vectorize_dim) {
      VectorizeValidator::validate(tv);
    }
  }
}

namespace {

void fillVectorizedContigRootDomains(
    const TensorView* tv,
    const ContigIDs& contig_finder,
    IterDomain* vectorized_root_id,
    VectorizedSetInfo& info) {
  const auto& root_dom = tv->getMaybeRFactorDomain();

  // Find the root domains that are dependencies of the merged contig
  // domain.
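  //
  // For example (hypothetical case): with contiguous root domains
  // [I0, I1] that are merged and then vectorized through the merged
  // domain, indexing goes through the single contig domain I0*I1, so
  // both I0 and I1 end up in indexed_root_ids below.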

  auto consumer_indexed_it =
      contig_finder.rootToIndexedID().find(vectorized_root_id);
  TORCH_INTERNAL_ASSERT(
      consumer_indexed_it != contig_finder.rootToIndexedID().end(),
      "Contiguity information not found for root domain: ",
      vectorized_root_id->toString());
  auto consumer_indexed_id = consumer_indexed_it->second;

  // Actual indexed root domains for this root domain. If a
  // contig merge is done, multiple root domains are included.
  std::unordered_set<IterDomain*> indexed_root_ids;

  if (consumer_indexed_id == vectorized_root_id) {
    // Indexed domain is equal to the root domain, meaning no contig
    // merge is involved.
    indexed_root_ids.insert(vectorized_root_id);
  } else {
    auto consumer_within_contig_it =
        contig_finder.withinContigIDs().find(consumer_indexed_id);
    TORCH_INTERNAL_ASSERT(
        consumer_within_contig_it != contig_finder.withinContigIDs().end());
    const auto& within_ids = consumer_within_contig_it->second;
    std::copy_if(
        root_dom.begin(),
        root_dom.end(),
        std::inserter(indexed_root_ids, indexed_root_ids.end()),
        [&](IterDomain* root_id) {
          return within_ids.find(root_id) != within_ids.end();
        });
  }

  // Store the contig merged root domains. If it is already set, pick
  // the smaller one as it is used for validating divisibility of the
  // merged extent.
  if (info.contig_root_ids.empty() ||
      indexed_root_ids.size() < info.contig_root_ids.size()) {
    info.contig_root_ids = indexed_root_ids;
  }
}

} // namespace

void fillConsumerVectorizedContigRootDomains(
    const TensorView* consumer_tv,
    const ContigIDs& contig_finder) {
  auto& info_vector = GpuLower::current()->vectorizedSetInfo();
  auto it = std::find_if(
      info_vector.begin(), info_vector.end(), [&consumer_tv](auto& info) {
        return info.consumer_tv == consumer_tv;
      });
  if (it == info_vector.end()) {
    return;
  }

  VectorizedSetInfo& info = *it;

  // info.vectorized_root_id is validated at this point to be the
  // last concrete root domain in the consumer.
  auto consumer_root_id = info.vectorized_root_id;

  fillVectorizedContigRootDomains(
      consumer_tv, contig_finder, consumer_root_id, info);
}

void fillProducerVectorizedContigRootDomains(
    const TensorView* producer_tv,
    const TensorView* consumer_tv,
    const std::unordered_map<IterDomain*, IterDomain*>& c2p_map,
    const ContigIDs& contig_finder) {
  auto& info_vector = GpuLower::current()->vectorizedSetInfo();
  auto it = std::find_if(
      info_vector.begin(),
      info_vector.end(),
      [&producer_tv, &consumer_tv](auto& info) {
        return info.consumer_tv == consumer_tv &&
            info.producer_tv == producer_tv;
      });
  if (it == info_vector.end()) {
    return;
  }

  VectorizedSetInfo& info = *it;

  // info.vectorized_root_id is validated at this point to be the
  // last concrete root domain in the consumer.
  auto consumer_root_id = info.vectorized_root_id;

  auto root_id = c2p_map.at(consumer_root_id);

  fillVectorizedContigRootDomains(producer_tv, contig_finder, root_id, info);
}

namespace {

// Backward propagation of partial ranges from outputs to
// inputs. Necessary to determine the required ranges to compute.
//
// Example:
//  tv0: [0:N]
//  tv1: shift(tv0, {1}) -> [1:N]
//  tv2: shift(tv0, {-1}) -> [0:N-1]
//  tv3: tv1 + tv2 -> [1:N-1]
//
// In this case, the valid range of tv3 starts at 1 and ends at
// N-1. This means that not all of the values of tv1 and tv2 are
// actually necessary. Specifically, tv1[0] and tv2[N-1] aren't used
// for tv3. This function calculates the required minimum range of
// each tensor that needs to be computed.
std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> getLiveRangeOffsets(
    Fusion* fusion) {
  auto exprs = StmtSort::getExprs(fusion);

  std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> map;

  ExpressionEvaluator ee(fusion);

  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
    auto expr = *it;
    for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
      for (auto consumer_root : consumer->getRootDomain()) {
        auto consumer_start_offset = ee.evaluate(consumer_root->start());
        auto consumer_stop_offset = ee.evaluate(consumer_root->stopOffset());
        TORCH_INTERNAL_ASSERT(
            consumer_start_offset.has_value(),
            "Can't evaluate start value of ",
            consumer_root->start());
        TORCH_INTERNAL_ASSERT(
            consumer_stop_offset.has_value(),
            "Can't evaluate stop value of ",
            consumer_root->stopOffset());
        auto it = map.find(consumer_root);
        if (it == map.end() || consumer->isFusionOutput()) {
          // No range set for this root domain, which means this
          // consumer tensor is an output tensor or the consumer_root
          // domain is a reduction domain. In either case, the
          // required range is simply defined by the start and stop
          // offsets of the root domain.
          // Also, when the consumer is an output, even if it's not
          // terminating, the range to compute must not be affected by
          // how it's used by its consumers because an output tensor
          // is visible to the outside of the fusion.
          map.insert(
              {consumer_root,
               {consumer_start_offset->as<int64_t>(),
                consumer_stop_offset->as<int64_t>()}});
        } else {
          // When the range of this root domain is already set, it
          // must be set by its consumers. Make sure the required
          // range by the consumers is covered by the defined range of
          // this root domain.
          auto& consumer_range = it->second;
          TORCH_INTERNAL_ASSERT(
              consumer_start_offset->as<int64_t>() <= consumer_range.first);
          TORCH_INTERNAL_ASSERT(
              consumer_stop_offset->as<int64_t>() <= consumer_range.second);
        }
      }

      // Propagate the range information from consumers to the
      // producers. Note that the effect on the range by shift and
      // gather is not considered here but taken care of by halo regions.
      for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
        auto c2p =
            PairwiseRootDomainMap(producer, consumer)
                .mapConsumerToProducer(consumer->domain(), producer->domain());
        for (auto consumer_root : consumer->getRootDomain()) {
          auto producer_it = c2p.find(consumer_root);
          if (producer_it == c2p.end()) {
            continue;
          }
          auto producer_root = producer_it->second;
          auto& consumer_range = map.at(consumer_root);
          const std::pair<int64_t, int64_t> init_range{
              std::numeric_limits<int64_t>::max(),
              std::numeric_limits<int64_t>::max()};
          auto& producer_range =
              map.insert({producer_root, init_range}).first->second;
          producer_range.first =
              std::min(producer_range.first, consumer_range.first);
          producer_range.second =
              std::min(producer_range.second, consumer_range.second);
        }
      }
    }
  }

  return map;
}

// Make sure that a partial split with split_offset does not violate
// the required range defined by domain_offset. Consider checking the
// start side of a root domain: only positions at split_offset or
// larger are going to be computed, and all positions starting at
// domain_offset must be computed, thus split_offset must be smaller
// than or equal to domain_offset. The same condition must hold for
// the end side of the domain.
//
// In order to validate this condition, the split offset is assumed to
// be a statically known constant value. This is not a hard
// requirement, but otherwise a runtime check would be needed.
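//
// A small numeric example (hypothetical values): if a root domain must
// be computed starting from position 1 (domain_offset = 1), a partial
// split with a start offset of 1 is valid, while a start offset of 2
// would skip position 1 and fail the check below.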
void validateSplit(
    Val* split_offset,
    int64_t domain_offset,
    const std::string& err_msg_prefix) {
  ExpressionEvaluator ee(split_offset->fusion());

  TORCH_INTERNAL_ASSERT(split_offset->isA<Int>());
  auto split_offset_value = ee.evaluate(split_offset);
  TORCH_INTERNAL_ASSERT(
      split_offset_value.has_value(),
      err_msg_prefix,
      ": Unknown offset of split: ",
      split_offset);

  TORCH_INTERNAL_ASSERT(
      split_offset_value.value() <= domain_offset,
      err_msg_prefix,
      ": Split offset is larger than the domain offset.",
      " Split offset: ",
      split_offset_value.value(),
      ". Domain offset: ",
      domain_offset);
}

} // namespace

void validatePartialSplit(Fusion* fusion) {
  FUSER_PERF_SCOPE("GpuLower::Lower::validatePartialSplit");
  FusionGuard fg(fusion);

  // If a root domain is partially split, only the sub range defined
  // by the start and stop offsets of the partial split is
  // computed. That sub range must cover the required range of the
  // domain. So, the first thing to do is to determine the required
  // minimum range of each root domain. Then, check if any partial
  // split could result in a smaller range than the required range.

  // Compute the required range of each root domain
  auto range_info = getLiveRangeOffsets(fusion);

  for (auto tv : ir_utils::allTvs(fusion)) {
    auto exprs = StmtSort::getExprs(
        tv->fusion(),
        {tv->domain()->domain().begin(), tv->domain()->domain().end()});
    for (auto split : ir_utils::filterByType<Split>(exprs)) {
      // When the start and stop offsets are not zero, make sure the
      // range defined by the split includes the required range to
      // compute. If both of the split offsets are zero, this
      // condition is obviously true. Also, this validation only needs
      // to be done with root domains. Since the start and stop
      // offsets of non-root domains must be just zero, they are
      // skipped at this point.
      if (split->startOffset()->isZeroInt() &&
          split->stopOffset()->isZeroInt()) {
        continue;
      }
      auto root_domain = split->in();
      std::stringstream err_msg_prefix;
      err_msg_prefix << "Error with " << root_domain << " in T" << tv->name();
      TORCH_INTERNAL_ASSERT(range_info.find(root_domain) != range_info.end());
      const auto& valid_range = range_info.at(root_domain);
      // Check the start offset. If it's zero, no validation regarding
      // the required range is needed.
      if (!split->startOffset()->isZeroInt()) {
        validateSplit(
            split->startOffset(), valid_range.first, err_msg_prefix.str());
      }
      // Same for the stop offset.
      if (!split->stopOffset()->isZeroInt()) {
        validateSplit(
            split->stopOffset(), valid_range.second, err_msg_prefix.str());
      }
    }
  }
}

namespace {

//! Utility to make sure the targeted GPU capability is
//! at least the provided major.minor version.
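//!
//! For example, validateMinimumArch(7, 5) passes on sm_75, sm_80, and
//! sm_86 devices, and fails an internal assert on sm_70.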
void validateMinimumArch(int major, int minor) {
  // Skip checking the arch if disabled.
  if (isOptionDisabled(DisableOption::ArchCheck)) {
    return;
  }

  auto prop = at::cuda::getCurrentDeviceProperties();
  TORCH_INTERNAL_ASSERT(prop->major >= major);
  if (prop->major == major) {
    TORCH_INTERNAL_ASSERT(prop->minor >= minor);
  }
}

//! Validates that the operand and result tensors
//! of mma ops are swizzled and also validates
//! specialization of tidx as lane id.
void validateMmaTensors(MmaOp* mma) {
  bool tidx_validated = false;
  std::vector<TensorView*> to_validate = {
      mma->inA()->as<TensorView>(),
      mma->inB()->as<TensorView>(),
      mma->out()->as<TensorView>()};

  for (auto tv : to_validate) {
    for (auto id : tv->domain()->domain()) {
      auto ptype = id->getParallelType();
      if (ptype == ParallelType::TIDx) {
        TORCH_INTERNAL_ASSERT(
            id->isMmaSwizzled(),
            "TIDx for mma input/output must be set by WarpMmaSwizzler: ",
            id,
            tv);
        if (!tidx_validated) {
          // Check that TIDx is exactly the lane id
          const auto& parallel_dim_map =
              GpuLower::current()->parallelDimensionMap();
          TORCH_INTERNAL_ASSERT(
              parallel_dim_map.isExact(ptype) &&
                  parallel_dim_map.get(ptype)->isConstInt() &&
                  parallel_dim_map.get(ptype)->evaluateInt() ==
                      at::cuda::warp_size(),
              "TIDx is reserved for lane id in mma kernels, and it needs to be exactly a warp");
          tidx_validated = true;
        }
      }
    }
  }

  // Note: this check will be relaxed in a follow up.
  auto validate_operand = [](const TensorView* tv) {
    TORCH_INTERNAL_ASSERT(
        tv->getMemoryType() == MemoryType::Local,
        "Only register inputs are supported for mma ops; up to sm80 all mma ops have to take register inputs.");

    TORCH_INTERNAL_ASSERT(
        std::all_of(
            tv->domain()->domain().begin() + tv->getComputeAtPosition(),
            tv->domain()->domain().end(),
            [](IterDomain* id) {
              return id->isMmaSwizzled() ||
                  // MMA instructions can only take inputs from registers,
                  // so we always assume mma op inputs are located on
                  // registers.
                  // Currently requiring that serial ids on the right of the
                  // CA axis are constant sized to ensure early detection of
                  // invalid mma schedules.
                  ((id->isBroadcast() || id->extent()->isConstInt()) &&
                   id->getParallelType() == ParallelType::Serial);
            }),
        "All ids to the right of the CA pos need to be mma-swizzled by WarpMmaSwizzler\n",
        tv);
  };

  validate_operand(mma->inA()->as<TensorView>());
  validate_operand(mma->inB()->as<TensorView>());

  // Additionally validate that mma is not directly taking a double buffered
  // register input as the double buffer indexing is currently not compatible
  // with fragment iteration. Would need to require a cache stage in this case.
  TORCH_INTERNAL_ASSERT(
      !mma->inA()->as<TensorView>()->isDoubleBuffered(),
      "MMA op cannot directly take double buffered register input; put a set stage before.");
  TORCH_INTERNAL_ASSERT(
      !mma->inB()->as<TensorView>()->isDoubleBuffered(),
      "MMA op cannot directly take double buffered register input; put a set stage before.");
}

//! Note and TODO:
//!   Currently relying on ldmatrix to
//!   obtain the correct data layout for Turing/Ampere
//!   mma's.
//!   This restriction will eventually not
//!   be necessary once the scatter swizzle is ready.
void validateTuringMmaInput(TensorView* tv) {
  // Pattern matching here to make sure the input is in the right format
  // for LDMatrix. The format is set up through swizzling in the scheduling,
  // and we check that swizzling to make sure it's correctly set up for
  // LDMatrix. We could in theory support patterns LDMatrix doesn't support,
  // but that would also mean the MMA isn't supported and
  // so we would have to lower to something completely different.

  // MemCpy async is a more generic utility that we can use.
  // Currently the only allowed input paths are:
  //   ldmatrix -> mma or
  //   ldmatrix -> broadcast -> mma
  // We actually wouldn't want too much flexibility here since
  // this path is very perf critical. But the check itself
  // can be made cleaner once we have the correct swizzle
  // labeling.
  // The most generic support would involve building out
  // support for any pointwise op that does not change the
  // data layout.
  auto tv_def = tv->definition();
  TORCH_INTERNAL_ASSERT(tv_def);
  if (tv_def->isA<BroadcastOp>()) {
    tv_def = tv_def->input(0)->definition();
  }
  TORCH_INTERNAL_ASSERT(tv_def);
  TORCH_INTERNAL_ASSERT(ir_utils::isLdMatrixOp(tv_def));
}

// The output of ldmatrix is swizzled with the mma format, so it
// currently should not be fused with any pointwise ops. This
// check is to protect against these cases.
// This would also not be needed once the scatter swizzle is ready; it
// should just become a swizzle format check if we wanted to fuse ldmatrix
// with any op other than mma.
void validateLdMatrixOutput(TensorView* tv) {
  const auto& out_uses = tv->fusion()->unordered_uses(tv);
  if (out_uses.empty()) {
    return;
  }
  // TODO: restricting to single-use pipelines for now, which
  // is true for the matmul mainloop. This could be relaxed to
  // support more complex mma usage.
  TORCH_INTERNAL_ASSERT(out_uses.size() == 1);
  auto out_use = *(out_uses.begin());

  if (out_use->isA<BroadcastOp>()) {
    validateLdMatrixOutput(out_use->output(0)->as<TensorView>());
    return;
  }

  TORCH_INTERNAL_ASSERT(
      out_use->isA<MmaOp>(),
      "validateLdMatrixOutput: currently only supports a single mma use for ldmatrix",
      out_use);
}

// Checks that the memory ops are supported on the targeted GPU
void validateArchMemoryOp(LoadStoreOp* ldst) {
  switch (ldst->opType()) {
    case LoadStoreOpType::LdMatrix:
    case LoadStoreOpType::LdMatrixTranspose:
      validateMinimumArch(7, 5);
      validateLdMatrixOutput(ldst->out()->as<TensorView>());
      return;
    case LoadStoreOpType::CpAsync:
      validateMinimumArch(8, 0);
      return;
    default:
      return;
  }
}

} // namespace

//! Validate data format and GPU arch compatibility of scheduled
//! mma operators on the fusion.
void validateMma(Fusion* fusion) {
  auto exprs = StmtSort::getExprs(fusion);

  for (auto expr : exprs) {
    if (auto mma = dynamic_cast<MmaOp*>(expr)) {
      validateMmaTensors(mma);

      switch (mma->options().macro) {
        case MmaOptions::MacroType::Volta_16_16_4:
          validateMinimumArch(7, 0);
          break;
        case MmaOptions::MacroType::Turing_16_8_16:
        case MmaOptions::MacroType::Turing_16_16_16:
          validateMinimumArch(7, 5);

          // Check that operands come from ldmatrix; this can be
          // relaxed once swizzles can be labeled on iterdomains.
          validateTuringMmaInput(mma->inA()->as<TensorView>());
          validateTuringMmaInput(mma->inB()->as<TensorView>());
          break;
        case MmaOptions::MacroType::Ampere_16_8_16:
        case MmaOptions::MacroType::Ampere_16_16_16:
          validateMinimumArch(8, 0);

          // Check that operands come from ldmatrix; this can be
          // relaxed once swizzles can be labeled on iterdomains.
          validateTuringMmaInput(mma->inA()->as<TensorView>());
          validateTuringMmaInput(mma->inB()->as<TensorView>());
          break;
        default:
          TORCH_INTERNAL_ASSERT(false, "validate mma: unsupported macro");
          break;
      }
    }
    if (auto ldst = dynamic_cast<LoadStoreOp*>(expr)) {
      validateArchMemoryOp(ldst);
    }
  }
}

namespace {

// Utility function to validate a loop swizzle:
//  1. Throws an error if any output of the swizzle is not in the
//     leaf_domains set.
//  2. Warns if any output of the swizzle is not the concrete id of
//     the loop map.
// The second case would make the codegen ignore this swizzle, as if
// it were not there at all.
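//
// A sketch of a schedule that reaches this check (hypothetical; uses
// the Swizzle2D-era TensorView::swizzle API):
//
//   tv: leaf domains [I0, I1]
//   tv->swizzle(Swizzle2DType::ZShape, 0, 1, SwizzleMode::Loop);
//
// Both swizzle outputs are leaf domains, so case 1 passes; if an
// output were not the concrete loop-mapped id, only the case 2
// warning would fire and codegen would ignore the swizzle.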
void validateLoopSwizzle(
    Expr* swizzle_expr,
    std::unordered_set<IterDomain*>& leaf_domains) {
  for (auto out_id :
       ir_utils::filterByType<IterDomain>(swizzle_expr->outputs())) {
    TORCH_INTERNAL_ASSERT(
        leaf_domains.count(out_id),
        "Loop swizzle can only be a direct producer of leaf domains.");
    if (GpuLower::current()->caMap()->getConcreteMappedID(
            out_id, IdMappingMode::LOOP) != out_id) {
      TORCH_WARN_ONCE("Ignored loop swizzle: ", swizzle_expr->toString());
    }
  }
}

} // namespace

void validateSwizzle(Fusion* fusion) {
  auto used_vals = fusion->usedMathVals();
  for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
    if (tv->hasSwizzleOp()) {
      std::unordered_set<IterDomain*> tv_leaf_domain_set(
          tv->domain()->domain().begin(), tv->domain()->domain().end());

      // Make sure no swizzle op is inlined:
      auto inlined_swizzles = ir_utils::getAllSwizzlesBetween(
          tv->getMaybeRFactorDomain(),
          {tv->domain()->domain().begin(),
           tv->domain()->domain().begin() + tv->getComputeAtPosition()});

      auto not_inlined_swizzles = ir_utils::getAllSwizzlesBetween(
          tv->getMaybeRFactorDomain(),
          {tv->domain()->domain().begin() + tv->getComputeAtPosition(),
           tv->domain()->domain().end()});

      // Check inlined swizzles: only loop swizzles can be inlined
      // currently, as inlining data swizzles would require additional
      // support for an unswizzle operator, which currently doesn't have
      // important use cases.
      for (auto swizzle_expr : inlined_swizzles) {
        TORCH_INTERNAL_ASSERT(
            swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop,
            "Only support inlining loop swizzles");
        validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
      }

      std::unordered_set<Expr*> inlined_swizzle_set(
          inlined_swizzles.begin(), inlined_swizzles.end());

      // Check not-inlined swizzles:
      // Apply the loop swizzle check where it applies, and also make
      // sure that no swizzle here is also in the inlined_swizzle set.
      // The latter would mean that one output of the swizzle is inlined
      // while the other is not. Such a case is not supported.
      for (auto swizzle_expr : not_inlined_swizzles) {
        TORCH_INTERNAL_ASSERT(
            !inlined_swizzle_set.count(swizzle_expr),
            "Cannot partially inline across swizzle domains.",
            swizzle_expr->toString());
        if (swizzle_expr->as<Swizzle2D>()->swizzleMode() == SwizzleMode::Loop) {
          validateLoopSwizzle(swizzle_expr, tv_leaf_domain_set);
        }
      }
    }
  }
}

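// Validates every use of ParallelType::Group and, when valid, converts
// the defining ReductionOp/WelfordOp to its grouped counterpart.
//
// A sketch of a qualifying schedule (hypothetical; assumes tv1 is
// already scheduled as a grid allreduce):
//
//   tv1 = sum(tv0, {0});
//   tv1->split(-1, 8);
//   tv1->axis(-1)->parallelize(ParallelType::Group);
//
// The grouped axis must be a plain iteration domain with a static
// extent (8 here), sit to the right of the computeAt and producer
// positions, and carry no halo.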
void validateAndConvertIterDomainGrouping(Fusion* fusion) {
  for (auto tv : ir_utils::allTvs(fusion)) {
    bool is_grouped = false;
    for (const auto id_idx : c10::irange(tv->nDims())) {
      const auto id = tv->axis(id_idx);
      auto ptype = GpuLower::current()
                       ->caMap()
                       ->getConcreteMappedID(id, IdMappingMode::LOOP)
                       ->getParallelType();
      if (ptype != ParallelType::Group) {
        // Not a grouped ID
        continue;
      }

      // Remember if a grouped ID is found
      is_grouped = true;

      // Grouping only makes sense for the normal iteration type
      TORCH_CHECK(
          id->getIterType() == IterType::Iteration,
          "Invalid use of ParallelType::Group.",
          " Grouping of ",
          id->getIterType(),
          " is not allowed. ",
          tv->toString());

      // Extent must be static
      TORCH_CHECK(
          id->extent()->getInt().has_value(),
          "Invalid use of ParallelType::Group.",
          " IterDomain must have a static extent: ",
          id->toString());

      // The CA position must be to the left of any grouped ID
      TORCH_CHECK(
          tv->getComputeAtPosition() <= id_idx,
          "Invalid use of ParallelType::Group.",
          " ComputeAt position must be left of grouped IDs: ",
          tv->toString());

      // Similarly, the produce-at position must be to the left of any
      // grouped ID
      TORCH_CHECK(
          tv->getMaxProducerPosition() <= id_idx,
          "Invalid use of ParallelType::Group.",
          " ProduceAt position must be left of grouped IDs: ",
          tv->toString());

      // Halo is not allowed
      TORCH_CHECK(
          GpuLower::current()->haloInfo()->getExtent(id) == nullptr,
          "Invalid use of ParallelType::Group.",
          " Grouping of halo-extended IterDomain, ",
          id->toString(),
          ", is not supported. ",
          tv->toString());
    }

    if (!is_grouped) {
      continue;
    }

    // Must be defined by a reduction-type op
    auto def = tv->definition();
    TORCH_CHECK(
        def != nullptr,
        "Invalid use of ParallelType::Group.",
        " Definition of tv with ParallelType::Group not found. ",
        tv->toString());

    TORCH_CHECK(
        tv->definition()->isA<ReductionOp>() ||
            tv->definition()->isA<GroupedReductionOp>() ||
            tv->definition()->isA<WelfordOp>() ||
            tv->definition()->isA<GroupedWelfordOp>(),
        "Invalid use of ParallelType::Group. Only ReductionOp, GroupedReductionOp, WelfordOp and GroupedWelfordOp are allowed. ",
        tv->definition()->toString());

    // Convert ReductionOp to GroupedReductionOp
    if (tv->definition()->isA<ReductionOp>()) {
      auto rop = def->as<ReductionOp>();
      auto is_allreduce = rop->isAllreduce();

      TORCH_CHECK(
          is_allreduce,
          "Invalid use of ParallelType::Group.",
          " Only enabled for allreduce reductions: ",
          rop->toString());

      TORCH_CHECK(
          tv->domain()->hasGridReduction(),
          "Invalid use of ParallelType::Group.",
          " Only enabled for grid reductions: ",
          rop->toString());

      std::vector<BinaryOpType> op_types({rop->getReductionOpType()});
      std::vector<Val*> init_vals({rop->init()});
      std::vector<Val*> outputs({rop->out()});
      std::vector<Val*> inputs({rop->in()});

      fusion->removeExpr(rop);
      IrBuilder::create<GroupedReductionOp>(
          static_cast<IrContainer*>(fusion),
          op_types,
          init_vals,
          outputs,
          inputs,
          is_allreduce);
    } else if (tv->definition()->isA<WelfordOp>()) {
      // Convert WelfordOp to GroupedWelfordOp
      auto wop = def->as<WelfordOp>();
      auto is_allreduce = wop->isAllreduce();

      TORCH_CHECK(
          is_allreduce,
          "Invalid use of ParallelType::Group.",
          " Only enabled for allreduce reductions: ",
          wop->toString());

      TORCH_CHECK(
          tv->domain()->hasGridReduction(),
          "Invalid use of ParallelType::Group.",
          " Only enabled for grid reductions: ",
          wop->toString());

      std::vector<WelfordTriplet> output_vals(
          {{wop->outAvg(), wop->outVar(), wop->outN()}});
      std::vector<WelfordTriplet> input_vals(
          {{wop->inAvg(), wop->inVar(), wop->inN()}});
      std::vector<WelfordTriplet> init_vals(
          {{wop->initAvg(), wop->initVar(), wop->initN()}});
      fusion->removeExpr(wop);
      IrBuilder::create<GroupedWelfordOp>(
          static_cast<IrContainer*>(fusion),
          output_vals,
          input_vals,
          init_vals,
          is_allreduce);
    }
  }
}

void validateGroupedReductions(Fusion* fusion) {
  for (auto expr : StmtSort::getExprs(fusion)) {
    if (auto grouped_reduction_op = dynamic_cast<GroupedReductionOp*>(expr)) {
      const auto num_exprs = grouped_reduction_op->numExprs();
      int num_grouped_iterations = 1;
      auto out_tv = ir_utils::getTvOutput(grouped_reduction_op);
      for (auto axis : out_tv->domain()->domain()) {
        if (axis->getParallelType() == ParallelType::Group) {
          num_grouped_iterations *= axis->extent()->getInt().value();
        }
      }
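      // The effective group count is the number of fused reduction
      // exprs times the product of all grouped extents; e.g., 2 exprs
      // with a grouped axis of extent 4 count as 8 against
      // kMaxNumGroupedReductions.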
      TORCH_CHECK(
          num_exprs * num_grouped_iterations <= kMaxNumGroupedReductions,
          "Too many grouped reductions: ",
          grouped_reduction_op->toString(),
          ". Up to ",
          kMaxNumGroupedReductions,
          " reductions are allowed.");
    }
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch