#include <ir_builder.h>
#include <ir_utils.h>
#include <root_domain_map.h>
#include <transform_iter.h>

#include <grouped_reduction.h>

#include <unordered_set>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13namespace {
14
15// Return if ref and other are transformed in the same way.
16bool hasMatchingTransformations(TensorView* ref, TensorView* other) {
17 std::unordered_map<IterDomain*, IterDomain*> ref_2_other;
18 for (const auto i : c10::irange(ref->getRootDomain().size())) {
19 ref_2_other.emplace(
20 ref->getRootDomain().at(i), other->getRootDomain().at(i));
21 }
22
23 auto replay =
24 BestEffortReplay(
25 other->domain()->domain(), ref->domain()->domain(), ref_2_other)
26 .getReplay();
27
28 for (const auto i : c10::irange(ref->nDims())) {
29 auto ref_id = ref->axis(i);
30 auto other_id = other->axis(i);
31 auto it = replay.find(ref_id);
32 if (it == replay.end() || it->second != other_id) {
33 return false;
34 }
35 }
36
37 return true;
38}
39
40// Validate grouping of reductions and return a new max producer position
41void validateReductionGrouping(
42 const std::vector<Val*>& inputs,
43 const std::vector<Val*>& outputs) {
44 TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size());
45 TORCH_INTERNAL_ASSERT(!inputs.empty());
46
47 auto fusion = dynamic_cast<Fusion*>(outputs[0]->container());
48 TORCH_INTERNAL_ASSERT(
49 fusion != nullptr, "Grouping of reductions must be done within a Fusion");
50
51 ExactRootDomainMap exact_map(fusion);
52
53 // Pick the first output TV as a reference and compare it with the
54 // rest. Do not allow grouping if any mismatch is detected.
55 auto ref_tv = outputs[0]->as<TensorView>();
56 const auto ref_domain = ref_tv->getRootDomain();
57 const auto num_root_dims = ref_domain.size();
58 const auto num_dims = ref_tv->nDims();
59 const auto ref_ca_pos = ref_tv->getComputeAtPosition();
60 for (const auto i : c10::irange(inputs.size())) {
61 auto output_tv = outputs.at(i)->as<TensorView>();
62 const auto& output_domain = output_tv->getRootDomain();
63 if (ref_tv == output_tv) {
64 continue;
65 }
66 TORCH_INTERNAL_ASSERT(
67 output_domain.size() == num_root_dims,
68 "Invalid grouped reduction due to mismatched number of root dimensions. "
69 "Expected: ",
70 num_root_dims,
71 ". Detected: ",
72 output_domain.size(),
73 ". Invalid output tensor: ",
74 output_tv->toString());
75 TORCH_INTERNAL_ASSERT(
76 output_tv->nDims() == num_dims,
77 "Invalid grouped reduction due to mismatched number of dimensions. "
78 "Expected: ",
79 num_dims,
80 ". Detected: ",
81 output_tv->nDims(),
82 ". Invalid output tensor: ",
83 output_tv->toString());
84 for (const auto i : c10::irange(num_root_dims)) {
85 auto ref_id = ref_domain.at(i);
86 auto output_id = output_domain.at(i);
87 // If an IterDomain is broadcast, require the other
88 // corresponding IterDomains are also broadcast. This may not be
89 // necessary but not completely certain.
90 TORCH_INTERNAL_ASSERT(
91 ref_id->isBroadcast() == output_id->isBroadcast(),
92 "Invalid grouped reduction due to mismatched broadcast root domains. ",
93 "Reference domain: ",
94 ref_id->toString(),
95 ". Mismatched domain: ",
96 output_id->toString(),
97 ". Invalid tensor: ",
98 output_tv->toString());
99 if (ref_id->isBroadcast()) {
100 continue;
101 }
102 TORCH_INTERNAL_ASSERT(
103 ref_id->isReduction() == output_id->isReduction(),
104 "Invalid grouped reduction due to mismatched reduction root domains. ",
105 "Reference domain: ",
106 ref_id->toString(),
107 ". Mismatched domain: ",
108 output_id->toString(),
109 ". Invalid tensor: ",
110 output_tv->toString());
111 TORCH_INTERNAL_ASSERT(
112 exact_map.areMapped(ref_id, output_id) || ref_id->sameAs(output_id),
113 "Invalid grouped reduction due to mismatched root domains. ",
114 "Reference domain: ",
115 ref_id->toString(),
116 ". Mismatched domain: ",
117 output_id->toString(),
118 ". Invalid tensor: ",
119 output_tv->toString());
120 }
121
122 TORCH_INTERNAL_ASSERT(
123 hasMatchingTransformations(ref_tv, output_tv),
124 "Invalid grouped reduction due to mismatched transformations. ",
125 "Reference tensor: ",
126 ref_tv->toString(),
127 ". Mismatched tensor: ",
128 output_tv->toString());
129
130 // Must have the same computeAt position
131 TORCH_INTERNAL_ASSERT(
132 output_tv->getComputeAtPosition() == ref_ca_pos,
133 "Invalid grouped reduction due to mismatched computeAt position. ",
134 "Reference tensor: ",
135 ref_tv->toString(),
136 ". Mismatched tensor: ",
137 output_tv->toString());
138 }
139
140 // Must not have any data dependency from outputs to inputs
141 const auto all_dep_vals = DependencyCheck::getAllValsBetween(
142 {outputs.begin(), outputs.end()}, inputs);
143 if (!all_dep_vals.empty()) {
144 std::stringstream ss;
145 ss << "Invalid dependency:";
146 for (auto val : all_dep_vals) {
147 ss << " " << val->toString();
148 }
149 TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str());
150 }
151}
152
153} // namespace
154
155void groupReductions(const std::vector<TensorView*>& reduction_outputs) {
156 TORCH_CHECK(!reduction_outputs.empty(), "No tensor is given");
157
158 auto container = reduction_outputs[0]->container();
159
160 const auto num_reductions = reduction_outputs.size();
161
162 std::vector<BinaryOpType> op_types(num_reductions);
163 std::vector<Val*> init_vals(num_reductions);
164 std::vector<Val*> outputs(num_reductions);
165 std::vector<Val*> inputs(num_reductions);
166
167 for (const auto i : c10::irange(num_reductions)) {
168 auto reduction_out = reduction_outputs.at(i);
169 TORCH_CHECK(
170 reduction_out->definition() != nullptr,
171 "Invalid tensor to group: ",
172 reduction_out->toString(),
173 ". Definition not found");
174 auto rop = dynamic_cast<ReductionOp*>(reduction_out->definition());
175 TORCH_CHECK(
176 rop != nullptr,
177 "Invalid tensor to group: ",
178 reduction_out->toString(),
179 ". Not an output of a ReductionOp: ",
180 reduction_out->definition()->toString());
181 // Fused reduction is only enabled during the lowering, so at this
182 // point it should be false.
183 TORCH_INTERNAL_ASSERT(
184 !rop->isAllreduce(), "Invalid ReductionOp: ", rop->toString());
185 op_types.at(i) = rop->getReductionOpType();
186 init_vals.at(i) = rop->init();
187 outputs.at(i) = rop->out();
188 inputs.at(i) = rop->in();
189 }
190
191 validateReductionGrouping(inputs, outputs);
192
193 IrBuilder::create<GroupedReductionOp>(
194 container, op_types, init_vals, outputs, inputs);
195
196 for (auto output : ir_utils::filterByType<TensorView>(outputs)) {
197 output->updateMaxProducerPosition();
198 }
199}
200
201} // namespace cuda
202} // namespace fuser
203} // namespace jit
204} // namespace torch
205