1
2#include <lower_divisible_split.h>
3
4#include <disjoint_set.h>
5#include <ir_utils.h>
6
7#include <unordered_set>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
14std::unordered_set<Split*> getAllDivisibleSplits(Fusion* fusion) {
15 ComputeAtMap ca_map(fusion);
16 return getAllDivisibleSplits(fusion, &ca_map);
17}
18
19std::unordered_set<Split*> getAllDivisibleSplits(
20 Fusion* fusion,
21 const ComputeAtMap* ca_map) {
22 std::unordered_set<Split*> all_divisible_splits;
23
24 auto all_tvs = ir_utils::allTvs(fusion);
25 // Find all tensor views with a view like rfactor. Splits used in view
26 // transformations must be divisible by definition.
27 for (auto tv : all_tvs) {
28 auto rfactor_dom = tv->getMaybeRFactorDomain();
29 // Not view if there's no rfactor axis
30 if (!tv->domain()->hasViewLikeRFactor()) {
31 continue;
32 }
33
34 // Take the view transformations and add all the splits. Those splits are
35 // the only divisible splits.
36 auto view_exprs =
37 StmtSort::getExprs(fusion, {rfactor_dom.begin(), rfactor_dom.end()});
38 auto split_exprs = ir_utils::filterByType<Split>(view_exprs);
39 all_divisible_splits.insert(split_exprs.begin(), split_exprs.end());
40 }
41
42 // Vectorized dimensions are enforced to be a result of divisible splits.
43 // Gather vectorized splits.
44 for (auto tv : all_tvs) {
45 auto vec_id_it = std::find_if(
46 tv->domain()->domain().begin(),
47 tv->domain()->domain().end(),
48 [](IterDomain* id) {
49 return isParallelTypeVectorize(id->getParallelType());
50 });
51
52 if (vec_id_it == tv->domain()->domain().end()) {
53 continue;
54 }
55
56 // We could have a case technically like:
57 // [8, 2] where we do:
58 // split(0, 2)
59 // merge(1)
60 // so it ends up as [4, 4]
61 // split(0, 2) must be divisible, but for now we're not going to capture
62 // cases like this. Just look for direct split's producing a vectorize
63 // dimension.
64 auto vec_id = *vec_id_it;
65 if (vec_id->definition() != nullptr && vec_id->definition()->isA<Split>()) {
66 all_divisible_splits.emplace(vec_id->definition()->as<Split>());
67 }
68 }
69
70 // If there's no view like splits, there's nothing to find
71 if (all_divisible_splits.empty()) {
72 return all_divisible_splits;
73 }
74
75 // Track the concrete id in the exact map of the outer output of the split
76 // expressions. This is how we'll check if there are matching splits. This
77 // also gets rid of any splits that already match (for processing).
78 std::unordered_map<IterDomain*, Expr*> outer_concrete_id_to_expr;
79
80 for (auto split : all_divisible_splits) {
81 outer_concrete_id_to_expr[ca_map->getConcreteMappedID(
82 split->outer(), IdMappingMode::EXACT)] = split;
83 }
84
85 std::unordered_set<Expr*> visited(
86 all_divisible_splits.begin(), all_divisible_splits.end());
87
88 // Find splits that match what we already have:
89 for (auto entry : outer_concrete_id_to_expr) {
90 auto concrete_id = entry.first;
91 auto original_view_split = entry.second;
92
93 const auto& exact_mapped_ids =
94 ca_map->idGraph().exactNodes().getDisjointSetOf(concrete_id).vector();
95 for (auto other_id : exact_mapped_ids) {
96 if (other_id->definition() == nullptr) {
97 continue;
98 }
99
100 if (!visited.emplace(other_id->definition()).second) {
101 // Already visited
102 continue;
103 }
104
105 if (IterDomainGraph::exprsMap(
106 original_view_split,
107 other_id->definition(),
108 false,
109 ca_map->idGraph().exactNodes())) {
110 all_divisible_splits.emplace(other_id->definition()->as<Split>());
111 }
112 }
113 }
114
115 return all_divisible_splits;
116}
117
118} // namespace cuda
119} // namespace fuser
120} // namespace jit
121} // namespace torch
122