#pragma once

#include <unordered_map>

#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
12 | class ContigIDs; |
13 | |
14 | void validateIr(Fusion* fusion); |
15 | |
16 | //! Validate vectorization and collect information on vectorization |
17 | //! used in code generation as well as runtime validation. |
18 | void validateAndCollectVectorizeInfo(Fusion* fusion); |
19 | |
20 | //! Find the contig root domains that a vectorized leaf domain |
21 | //! of a consumer TV depends on. Required for runtime validation. |
22 | void fillConsumerVectorizedContigRootDomains( |
23 | const TensorView* consumer_tv, |
24 | const ContigIDs& contig_finder); |
25 | |
26 | //! Find the contig root domains that a vectorized leaf domain |
27 | //! of a producer TV depends on. Required for runtime validation. |
28 | //! Producer must be transformed as consumer. |
29 | void fillProducerVectorizedContigRootDomains( |
30 | const TensorView* producer_tv, |
31 | const TensorView* consumer_tv, |
32 | const std::unordered_map<IterDomain*, IterDomain*>& c2p_map, |
33 | const ContigIDs& contig_finder); |
34 | |
35 | //! Validates partial split expressions. Partial split only uses an |
36 | //! inner subdomain specified by start and stop offsets, ignoring the |
37 | //! values outside the range. It's designed to be used with non-padded |
38 | //! shift, which introduces non-zero start and stop smaller than the |
39 | //! extent. This function makes sure all tensors have all values |
40 | //! calculated that are necessary for output values. |
41 | void validatePartialSplit(Fusion* fusion); |
42 | |
43 | //! Validate data format and GPU arch compatibility of scheduled |
44 | //! mma operators on the fusion. |
45 | void validateMma(Fusion* fusion); |
46 | |
47 | //! Validates swizzle ops to ensure consistent indexing: |
48 | //! - Currently only allow swizzle ops on the right of CA axis, |
49 | //! - (Except ZShape) All swizzle ops have to be on const sized ids |
50 | //! - Xor and Transpose swizzle have to have equal dimensions on the |
51 | //! participating ids. |
52 | void validateSwizzle(Fusion* fusion); |
53 | |
54 | //! Validate use of ParallelType::Group. It is currently only allowed |
55 | //! in ReductionOp and not in WelfordOp. Group has similar constraints |
56 | //! as Vectorize, e.g., it can only be used with IterDomains with |
57 | //! static extents. Differences are, e.g, it has no constraints on |
58 | //! alignments and predicates. Each individual reduction has its own |
59 | //! predicate, so it is possile for only part of grouped reductions to |
60 | //! be executed. |
61 | //! |
62 | //! Also, grouping is only enabled for persistent grid reductions, in |
63 | //! other words, grid allreduces. Note that no grid reduction without |
64 | //! broadcast is persistent anymore. |
65 | //! |
66 | //! Validated ReductionOp with ParallelType::Group is converted to |
67 | //! GroupedReductionOp. |
68 | void validateAndConvertIterDomainGrouping(Fusion* fusion); |
69 | |
70 | //! Validate the number of grouped reductions is within the limit |
71 | void validateGroupedReductions(Fusion* fusion); |
72 | |
73 | } // namespace cuda |
74 | } // namespace fuser |
75 | } // namespace jit |
76 | } // namespace torch |
77 | |