1 | #include <dispatch.h> |
2 | #include <instrumentation.h> |
3 | #include <ir_iostream.h> |
4 | #include <ir_utils.h> |
5 | #include <iter_visitor.h> |
6 | #include <lower2device.h> |
7 | #include <lower_trivial_reductions.h> |
8 | #include <lower_utils.h> |
9 | #include <root_domain_map.h> |
10 | |
11 | #include <unordered_set> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | namespace { |
19 | |
20 | bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id); |
21 | |
22 | // Checks the producer of tv to see if the |
23 | bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) { |
24 | TORCH_INTERNAL_ASSERT( |
25 | root_id->definition() == nullptr, "Not root IterDomain: " , root_id); |
26 | |
27 | auto def = tv->definition(); |
28 | |
29 | if (def == nullptr) { |
30 | // This is an input tensor, so no rfactor tensor to traverse. |
31 | return false; |
32 | } |
33 | |
34 | // Check the reduction expression that produces tv |
35 | if (!ir_utils::isReductionOp(def) || def->isA<MmaOp>()) { |
36 | return false; |
37 | } |
38 | |
39 | TORCH_INTERNAL_ASSERT( |
40 | def->inputs().size() == def->outputs().size(), |
41 | "This logic block assumes number of inputs is the same as number of outputs of reduction ops." ); |
42 | |
43 | // Reduction expr may have multiple inputs, just grab any TV |
44 | // input. Note that in theory it is possible that a |
45 | // GroupedReductionOp has rfactor inputs as well as non-rfactor |
46 | // inputs, so grabbing the one that actually corresponds to tv can |
47 | // be important. In reality, though, such a GroupedReductionOp |
48 | // should not happen as we do not group reductions of rfactor and |
49 | // non-rfactor tensor. |
50 | auto producer_tv = ir_utils::getTvInput(def); |
51 | |
52 | TORCH_INTERNAL_ASSERT(producer_tv != nullptr); |
53 | |
54 | if (!producer_tv->hasRFactor()) { |
55 | return false; |
56 | } |
57 | |
58 | auto c2p = PairwiseRootDomainMap(producer_tv, tv) |
59 | .mapConsumerToProducer(tv->domain(), producer_tv->domain()); |
60 | |
61 | auto producer_id_it = c2p.find(root_id); |
62 | if (producer_id_it == c2p.end()) { |
63 | // No matching producer is found. Stop traversing. |
64 | return false; |
65 | } |
66 | |
67 | auto producer_root_id = producer_id_it->second; |
68 | |
69 | return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id); |
70 | } |
71 | |
72 | bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) { |
73 | auto id_inputs = InputsOf::output(id->fusion(), id); |
74 | for (auto root_id : ir_utils::filterByType<IterDomain>(id_inputs)) { |
75 | if (root_id->isReduction() && root_id->extent()->isOneInt()) { |
76 | continue; |
77 | } |
78 | // If not possible to prove the root ID is trivial, see if the ID |
79 | // is derived from a rfactor tensor. This may mean that the iteration domain |
80 | // was merged or split in another expression through rfactor. Trace back |
81 | // through rfactor expressions to find original roots and determine there if |
82 | // trivial. |
83 | if (!traverseToRFactorTensor(tv, root_id)) { |
84 | return false; |
85 | } |
86 | } |
87 | return true; |
88 | } |
89 | |
90 | } // namespace |
91 | |
92 | void TrivialReductionInfo::build(Fusion* fusion) { |
93 | auto used_vals = fusion->usedMathVals(); |
94 | |
95 | for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) { |
96 | for (auto id : tv->domain()->domain()) { |
97 | if (analyzeIfDerivedFromTrivialReduction(tv, id)) { |
98 | // If id is a trivial reduction, all of its ancestor vals are |
99 | // also trivial reductions. |
100 | for (auto dep_id : DependencyCheck::getAllValsBetween( |
101 | std::unordered_set<Val*>( |
102 | tv->getRootDomain().begin(), tv->getRootDomain().end()), |
103 | {id})) { |
104 | domains_.insert(dep_id->as<IterDomain>()); |
105 | domains_derived_from_root_.insert(dep_id->as<IterDomain>()); |
106 | } |
107 | } else if (id->isReduction() && id->extent()->isOneInt()) { |
108 | // This happens when a leaf domain is trivial but its root |
109 | // axes are not. For example, consider a non-trivial domain |
110 | // split by one. The inner output axis is a trivial domain, |
111 | // whereas the outer output axis is not. Since the root axis |
112 | // is not trivial, a for-loop needs to be generated. |
113 | domains_.insert(id); |
114 | } |
115 | } |
116 | } |
117 | } |
118 | |
119 | bool TrivialReductionInfo::isDerived(IterDomain* id) const { |
120 | return domains_.find(id) != domains_.end(); |
121 | } |
122 | |
123 | } // namespace cuda |
124 | } // namespace fuser |
125 | } // namespace jit |
126 | } // namespace torch |
127 | |