1 | #include <ir_builder.h> |
2 | #include <ir_utils.h> |
3 | #include <root_domain_map.h> |
4 | #include <transform_iter.h> |
5 | |
6 | #include <grouped_reduction.h> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | namespace { |
14 | |
15 | // Return if ref and other are transformed in the same way. |
16 | bool hasMatchingTransformations(TensorView* ref, TensorView* other) { |
17 | std::unordered_map<IterDomain*, IterDomain*> ref_2_other; |
18 | for (const auto i : c10::irange(ref->getRootDomain().size())) { |
19 | ref_2_other.emplace( |
20 | ref->getRootDomain().at(i), other->getRootDomain().at(i)); |
21 | } |
22 | |
23 | auto replay = |
24 | BestEffortReplay( |
25 | other->domain()->domain(), ref->domain()->domain(), ref_2_other) |
26 | .getReplay(); |
27 | |
28 | for (const auto i : c10::irange(ref->nDims())) { |
29 | auto ref_id = ref->axis(i); |
30 | auto other_id = other->axis(i); |
31 | auto it = replay.find(ref_id); |
32 | if (it == replay.end() || it->second != other_id) { |
33 | return false; |
34 | } |
35 | } |
36 | |
37 | return true; |
38 | } |
39 | |
40 | // Validate grouping of reductions and return a new max producer position |
41 | void validateReductionGrouping( |
42 | const std::vector<Val*>& inputs, |
43 | const std::vector<Val*>& outputs) { |
44 | TORCH_INTERNAL_ASSERT(inputs.size() == outputs.size()); |
45 | TORCH_INTERNAL_ASSERT(!inputs.empty()); |
46 | |
47 | auto fusion = dynamic_cast<Fusion*>(outputs[0]->container()); |
48 | TORCH_INTERNAL_ASSERT( |
49 | fusion != nullptr, "Grouping of reductions must be done within a Fusion" ); |
50 | |
51 | ExactRootDomainMap exact_map(fusion); |
52 | |
53 | // Pick the first output TV as a reference and compare it with the |
54 | // rest. Do not allow grouping if any mismatch is detected. |
55 | auto ref_tv = outputs[0]->as<TensorView>(); |
56 | const auto ref_domain = ref_tv->getRootDomain(); |
57 | const auto num_root_dims = ref_domain.size(); |
58 | const auto num_dims = ref_tv->nDims(); |
59 | const auto ref_ca_pos = ref_tv->getComputeAtPosition(); |
60 | for (const auto i : c10::irange(inputs.size())) { |
61 | auto output_tv = outputs.at(i)->as<TensorView>(); |
62 | const auto& output_domain = output_tv->getRootDomain(); |
63 | if (ref_tv == output_tv) { |
64 | continue; |
65 | } |
66 | TORCH_INTERNAL_ASSERT( |
67 | output_domain.size() == num_root_dims, |
68 | "Invalid grouped reduction due to mismatched number of root dimensions. " |
69 | "Expected: " , |
70 | num_root_dims, |
71 | ". Detected: " , |
72 | output_domain.size(), |
73 | ". Invalid output tensor: " , |
74 | output_tv->toString()); |
75 | TORCH_INTERNAL_ASSERT( |
76 | output_tv->nDims() == num_dims, |
77 | "Invalid grouped reduction due to mismatched number of dimensions. " |
78 | "Expected: " , |
79 | num_dims, |
80 | ". Detected: " , |
81 | output_tv->nDims(), |
82 | ". Invalid output tensor: " , |
83 | output_tv->toString()); |
84 | for (const auto i : c10::irange(num_root_dims)) { |
85 | auto ref_id = ref_domain.at(i); |
86 | auto output_id = output_domain.at(i); |
87 | // If an IterDomain is broadcast, require the other |
88 | // corresponding IterDomains are also broadcast. This may not be |
89 | // necessary but not completely certain. |
90 | TORCH_INTERNAL_ASSERT( |
91 | ref_id->isBroadcast() == output_id->isBroadcast(), |
92 | "Invalid grouped reduction due to mismatched broadcast root domains. " , |
93 | "Reference domain: " , |
94 | ref_id->toString(), |
95 | ". Mismatched domain: " , |
96 | output_id->toString(), |
97 | ". Invalid tensor: " , |
98 | output_tv->toString()); |
99 | if (ref_id->isBroadcast()) { |
100 | continue; |
101 | } |
102 | TORCH_INTERNAL_ASSERT( |
103 | ref_id->isReduction() == output_id->isReduction(), |
104 | "Invalid grouped reduction due to mismatched reduction root domains. " , |
105 | "Reference domain: " , |
106 | ref_id->toString(), |
107 | ". Mismatched domain: " , |
108 | output_id->toString(), |
109 | ". Invalid tensor: " , |
110 | output_tv->toString()); |
111 | TORCH_INTERNAL_ASSERT( |
112 | exact_map.areMapped(ref_id, output_id) || ref_id->sameAs(output_id), |
113 | "Invalid grouped reduction due to mismatched root domains. " , |
114 | "Reference domain: " , |
115 | ref_id->toString(), |
116 | ". Mismatched domain: " , |
117 | output_id->toString(), |
118 | ". Invalid tensor: " , |
119 | output_tv->toString()); |
120 | } |
121 | |
122 | TORCH_INTERNAL_ASSERT( |
123 | hasMatchingTransformations(ref_tv, output_tv), |
124 | "Invalid grouped reduction due to mismatched transformations. " , |
125 | "Reference tensor: " , |
126 | ref_tv->toString(), |
127 | ". Mismatched tensor: " , |
128 | output_tv->toString()); |
129 | |
130 | // Must have the same computeAt position |
131 | TORCH_INTERNAL_ASSERT( |
132 | output_tv->getComputeAtPosition() == ref_ca_pos, |
133 | "Invalid grouped reduction due to mismatched computeAt position. " , |
134 | "Reference tensor: " , |
135 | ref_tv->toString(), |
136 | ". Mismatched tensor: " , |
137 | output_tv->toString()); |
138 | } |
139 | |
140 | // Must not have any data dependency from outputs to inputs |
141 | const auto all_dep_vals = DependencyCheck::getAllValsBetween( |
142 | {outputs.begin(), outputs.end()}, inputs); |
143 | if (!all_dep_vals.empty()) { |
144 | std::stringstream ss; |
145 | ss << "Invalid dependency:" ; |
146 | for (auto val : all_dep_vals) { |
147 | ss << " " << val->toString(); |
148 | } |
149 | TORCH_INTERNAL_ASSERT(all_dep_vals.empty(), ss.str()); |
150 | } |
151 | } |
152 | |
153 | } // namespace |
154 | |
155 | void groupReductions(const std::vector<TensorView*>& reduction_outputs) { |
156 | TORCH_CHECK(!reduction_outputs.empty(), "No tensor is given" ); |
157 | |
158 | auto container = reduction_outputs[0]->container(); |
159 | |
160 | const auto num_reductions = reduction_outputs.size(); |
161 | |
162 | std::vector<BinaryOpType> op_types(num_reductions); |
163 | std::vector<Val*> init_vals(num_reductions); |
164 | std::vector<Val*> outputs(num_reductions); |
165 | std::vector<Val*> inputs(num_reductions); |
166 | |
167 | for (const auto i : c10::irange(num_reductions)) { |
168 | auto reduction_out = reduction_outputs.at(i); |
169 | TORCH_CHECK( |
170 | reduction_out->definition() != nullptr, |
171 | "Invalid tensor to group: " , |
172 | reduction_out->toString(), |
173 | ". Definition not found" ); |
174 | auto rop = dynamic_cast<ReductionOp*>(reduction_out->definition()); |
175 | TORCH_CHECK( |
176 | rop != nullptr, |
177 | "Invalid tensor to group: " , |
178 | reduction_out->toString(), |
179 | ". Not an output of a ReductionOp: " , |
180 | reduction_out->definition()->toString()); |
181 | // Fused reduction is only enabled during the lowering, so at this |
182 | // point it should be false. |
183 | TORCH_INTERNAL_ASSERT( |
184 | !rop->isAllreduce(), "Invalid ReductionOp: " , rop->toString()); |
185 | op_types.at(i) = rop->getReductionOpType(); |
186 | init_vals.at(i) = rop->init(); |
187 | outputs.at(i) = rop->out(); |
188 | inputs.at(i) = rop->in(); |
189 | } |
190 | |
191 | validateReductionGrouping(inputs, outputs); |
192 | |
193 | IrBuilder::create<GroupedReductionOp>( |
194 | container, op_types, init_vals, outputs, inputs); |
195 | |
196 | for (auto output : ir_utils::filterByType<TensorView>(outputs)) { |
197 | output->updateMaxProducerPosition(); |
198 | } |
199 | } |
200 | |
201 | } // namespace cuda |
202 | } // namespace fuser |
203 | } // namespace jit |
204 | } // namespace torch |
205 | |