1 | |
2 | #include <lower_divisible_split.h> |
3 | |
4 | #include <disjoint_set.h> |
5 | #include <ir_utils.h> |
6 | |
7 | #include <unordered_set> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | std::unordered_set<Split*> getAllDivisibleSplits(Fusion* fusion) { |
15 | ComputeAtMap ca_map(fusion); |
16 | return getAllDivisibleSplits(fusion, &ca_map); |
17 | } |
18 | |
19 | std::unordered_set<Split*> getAllDivisibleSplits( |
20 | Fusion* fusion, |
21 | const ComputeAtMap* ca_map) { |
22 | std::unordered_set<Split*> all_divisible_splits; |
23 | |
24 | auto all_tvs = ir_utils::allTvs(fusion); |
25 | // Find all tensor views with a view like rfactor. Splits used in view |
26 | // transformations must be divisible by definition. |
27 | for (auto tv : all_tvs) { |
28 | auto rfactor_dom = tv->getMaybeRFactorDomain(); |
29 | // Not view if there's no rfactor axis |
30 | if (!tv->domain()->hasViewLikeRFactor()) { |
31 | continue; |
32 | } |
33 | |
34 | // Take the view transformations and add all the splits. Those splits are |
35 | // the only divisible splits. |
36 | auto view_exprs = |
37 | StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()}); |
38 | auto split_exprs = ir_utils::filterByType<Split>(view_exprs); |
39 | all_divisible_splits.insert(split_exprs.begin(), split_exprs.end()); |
40 | } |
41 | |
42 | // Vectorized dimensions are enforced to be a result of divisible splits. |
43 | // Gather vectorized splits. |
44 | for (auto tv : all_tvs) { |
45 | auto vec_id_it = std::find_if( |
46 | tv->domain()->domain().begin(), |
47 | tv->domain()->domain().end(), |
48 | [](IterDomain* id) { |
49 | return isParallelTypeVectorize(id->getParallelType()); |
50 | }); |
51 | |
52 | if (vec_id_it == tv->domain()->domain().end()) { |
53 | continue; |
54 | } |
55 | |
56 | // We could have a case technically like: |
57 | // [8, 2] where we do: |
58 | // split(0, 2) |
59 | // merge(1) |
60 | // so it ends up as [4, 4] |
61 | // split(0, 2) must be divisible, but for now we're not going to capture |
62 | // cases like this. Just look for direct split's producing a vectorize |
63 | // dimension. |
64 | auto vec_id = *vec_id_it; |
65 | if (vec_id->definition() != nullptr && vec_id->definition()->isA<Split>()) { |
66 | all_divisible_splits.emplace(vec_id->definition()->as<Split>()); |
67 | } |
68 | } |
69 | |
70 | // If there's no view like splits, there's nothing to find |
71 | if (all_divisible_splits.empty()) { |
72 | return all_divisible_splits; |
73 | } |
74 | |
75 | // Track the concrete id in the exact map of the outer output of the split |
76 | // expressions. This is how we'll check if there are matching splits. This |
77 | // also gets rid of any splits that already match (for processing). |
78 | std::unordered_map<IterDomain*, Expr*> outer_concrete_id_to_expr; |
79 | |
80 | for (auto split : all_divisible_splits) { |
81 | outer_concrete_id_to_expr[ca_map->getConcreteMappedID( |
82 | split->outer(), IdMappingMode::EXACT)] = split; |
83 | } |
84 | |
85 | std::unordered_set<Expr*> visited( |
86 | all_divisible_splits.begin(), all_divisible_splits.end()); |
87 | |
88 | // Find splits that match what we already have: |
89 | for (auto entry : outer_concrete_id_to_expr) { |
90 | auto concrete_id = entry.first; |
91 | auto original_view_split = entry.second; |
92 | |
93 | const auto& exact_mapped_ids = |
94 | ca_map->idGraph().exactNodes().getDisjointSetOf(concrete_id).vector(); |
95 | for (auto other_id : exact_mapped_ids) { |
96 | if (other_id->definition() == nullptr) { |
97 | continue; |
98 | } |
99 | |
100 | if (!visited.emplace(other_id->definition()).second) { |
101 | // Already visited |
102 | continue; |
103 | } |
104 | |
105 | if (IterDomainGraph::exprsMap( |
106 | original_view_split, |
107 | other_id->definition(), |
108 | false, |
109 | ca_map->idGraph().exactNodes())) { |
110 | all_divisible_splits.emplace(other_id->definition()->as<Split>()); |
111 | } |
112 | } |
113 | } |
114 | |
115 | return all_divisible_splits; |
116 | } |
117 | |
118 | } // namespace cuda |
119 | } // namespace fuser |
120 | } // namespace jit |
121 | } // namespace torch |
122 | |