1 | #pragma once |
2 | |
3 | #include <c10/macros/Export.h> |
4 | |
5 | #include <ir_all_nodes.h> |
6 | #include <iter_visitor.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | //! If an IterDomain is split and its inner output domain is |
14 | //! eventually split too, the second split must be divisible or the |
15 | //! inner domain must be predicated. This class finds Split |
16 | //! expressions that need to be divisible or predicated. |
17 | //! |
18 | //! Second splits are not limited to just direct output domains of |
19 | //! first splits but also indirect descendent domains as well. |
20 | //! |
21 | //! Predicating non-divisible split domains does not work if split |
22 | //! output domains are vectorized where ParallelType::Vectorize is |
23 | //! applied to an inner domain of splits. If it's non-divisible, |
24 | //! predicating the input domain of the non-divisible split results in |
25 | //! a vectoried operation is predicated out entirely since we do not |
26 | //! generate a fall-back non-vectorized else path. Runtime check is |
27 | //! done for those domains. |
28 | class TORCH_CUDA_CU_API NonDivisibleSplitInfo : public IterVisitor { |
29 | public: |
30 | void build(Fusion* fusion); |
31 | |
32 | const auto& splitsToPredicate() const { |
33 | return splits_to_predicate_; |
34 | } |
35 | |
36 | const auto& splitsToValidate() const { |
37 | return splits_to_validate_; |
38 | } |
39 | |
40 | private: |
41 | using IterVisitor::handle; |
42 | |
43 | void handle(Split* split) override; |
44 | |
45 | void handle(Merge* merge) override; |
46 | |
47 | //! True if reachable from inner domains of splits |
48 | bool isReachableFromInnerDomains(IterDomain* id) const; |
49 | |
50 | //! Forward propagate the reachability information |
51 | void propagateReachability(Split* split, bool is_protected); |
52 | |
53 | //! Forward propagate the reachability information |
54 | void propagateReachability(Merge* merge); |
55 | |
56 | void clearReachability(); |
57 | |
58 | //! Returns the extent of a split output domain if it's not proven to |
59 | //! be divisible. |
60 | Val* getMaybeNonDivisibleExtent(Split* split) const; |
61 | |
62 | //! Remove redundant predicates as divisibility may be validated at |
63 | //! run time |
64 | void removeRedundancy(); |
65 | |
66 | private: |
67 | //! Split expressions whose input domain must be predicated |
68 | std::unordered_map<TensorView*, std::vector<Split*>> splits_to_predicate_; |
69 | //! Split expressions whose divisibility must be validated at run time |
70 | std::unordered_set<Split*> splits_to_validate_; |
71 | |
72 | //! Temporarily used for analyzing each tensor |
73 | TensorView* current_tv_ = nullptr; |
74 | std::unordered_set<IterDomain*> inner_domains_; |
75 | }; |
76 | |
77 | } // namespace cuda |
78 | } // namespace fuser |
79 | } // namespace jit |
80 | } // namespace torch |
81 | |