1#include <dispatch.h>
2#include <instrumentation.h>
3#include <ir_iostream.h>
4#include <ir_utils.h>
5#include <iter_visitor.h>
6#include <lower2device.h>
7#include <lower_trivial_reductions.h>
8#include <lower_utils.h>
9#include <root_domain_map.h>
10
11#include <unordered_set>
12
13namespace torch {
14namespace jit {
15namespace fuser {
16namespace cuda {
17
18namespace {
19
20bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id);
21
// Checks the producer of tv to see if the given root IterDomain of tv is
// derived from a trivial reduction in an rfactor producer tensor. Returns
// true if it can be proven trivial by traversing through rfactor exprs.
23bool traverseToRFactorTensor(TensorView* tv, IterDomain* root_id) {
24 TORCH_INTERNAL_ASSERT(
25 root_id->definition() == nullptr, "Not root IterDomain: ", root_id);
26
27 auto def = tv->definition();
28
29 if (def == nullptr) {
30 // This is an input tensor, so no rfactor tensor to traverse.
31 return false;
32 }
33
34 // Check the reduction expression that produces tv
35 if (!ir_utils::isReductionOp(def) || def->isA<MmaOp>()) {
36 return false;
37 }
38
39 TORCH_INTERNAL_ASSERT(
40 def->inputs().size() == def->outputs().size(),
41 "This logic block assumes number of inputs is the same as number of outputs of reduction ops.");
42
43 // Reduction expr may have multiple inputs, just grab any TV
44 // input. Note that in theory it is possible that a
45 // GroupedReductionOp has rfactor inputs as well as non-rfactor
46 // inputs, so grabbing the one that actually corresponds to tv can
47 // be important. In reality, though, such a GroupedReductionOp
48 // should not happen as we do not group reductions of rfactor and
49 // non-rfactor tensor.
50 auto producer_tv = ir_utils::getTvInput(def);
51
52 TORCH_INTERNAL_ASSERT(producer_tv != nullptr);
53
54 if (!producer_tv->hasRFactor()) {
55 return false;
56 }
57
58 auto c2p = PairwiseRootDomainMap(producer_tv, tv)
59 .mapConsumerToProducer(tv->domain(), producer_tv->domain());
60
61 auto producer_id_it = c2p.find(root_id);
62 if (producer_id_it == c2p.end()) {
63 // No matching producer is found. Stop traversing.
64 return false;
65 }
66
67 auto producer_root_id = producer_id_it->second;
68
69 return analyzeIfDerivedFromTrivialReduction(producer_tv, producer_root_id);
70}
71
72bool analyzeIfDerivedFromTrivialReduction(TensorView* tv, IterDomain* id) {
73 auto id_inputs = InputsOf::output(id->fusion(), id);
74 for (auto root_id : ir_utils::filterByType<IterDomain>(id_inputs)) {
75 if (root_id->isReduction() && root_id->extent()->isOneInt()) {
76 continue;
77 }
78 // If not possible to prove the root ID is trivial, see if the ID
79 // is derived from a rfactor tensor. This may mean that the iteration domain
80 // was merged or split in another expression through rfactor. Trace back
81 // through rfactor expressions to find original roots and determine there if
82 // trivial.
83 if (!traverseToRFactorTensor(tv, root_id)) {
84 return false;
85 }
86 }
87 return true;
88}
89
90} // namespace
91
92void TrivialReductionInfo::build(Fusion* fusion) {
93 auto used_vals = fusion->usedMathVals();
94
95 for (auto tv : ir_utils::filterByType<TensorView>(used_vals)) {
96 for (auto id : tv->domain()->domain()) {
97 if (analyzeIfDerivedFromTrivialReduction(tv, id)) {
98 // If id is a trivial reduction, all of its ancestor vals are
99 // also trivial reductions.
100 for (auto dep_id : DependencyCheck::getAllValsBetween(
101 std::unordered_set<Val*>(
102 tv->getRootDomain().begin(), tv->getRootDomain().end()),
103 {id})) {
104 domains_.insert(dep_id->as<IterDomain>());
105 domains_derived_from_root_.insert(dep_id->as<IterDomain>());
106 }
107 } else if (id->isReduction() && id->extent()->isOneInt()) {
108 // This happens when a leaf domain is trivial but its root
109 // axes are not. For example, consider a non-trivial domain
110 // split by one. The inner output axis is a trivial domain,
111 // whereas the outer output axis is not. Since the root axis
112 // is not trivial, a for-loop needs to be generated.
113 domains_.insert(id);
114 }
115 }
116 }
117}
118
119bool TrivialReductionInfo::isDerived(IterDomain* id) const {
120 return domains_.find(id) != domains_.end();
121}
122
123} // namespace cuda
124} // namespace fuser
125} // namespace jit
126} // namespace torch
127