1 | #include <ir_utils.h> |
2 | #include <iter_visitor.h> |
3 | #include <kernel_ir_dispatch.h> |
4 | #include <lower2device.h> |
5 | |
6 | #include <lower_fused_reduction.h> |
7 | |
8 | #include <algorithm> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | |
15 | namespace { |
16 | |
17 | //! An instance of reduction patterns to fuse |
18 | class FusedReductionBroadcastInfo : public PolymorphicBase { |
19 | public: |
20 | FusedReductionBroadcastInfo(ReductionOp* reduction, bool with_broadcast) |
21 | : reductions_({reduction}), with_broadcast_({with_broadcast}) {} |
22 | |
23 | FusedReductionBroadcastInfo(WelfordOp* welford, bool with_broadcast) |
24 | : reductions_({welford}), with_broadcast_({with_broadcast}) {} |
25 | |
26 | FusedReductionBroadcastInfo( |
27 | GroupedReductionOp* grouped_rop, |
28 | bool with_broadcast) |
29 | : reductions_({grouped_rop}), with_broadcast_({with_broadcast}) {} |
30 | |
31 | const std::vector<Expr*>& reductions() const { |
32 | return reductions_; |
33 | } |
34 | |
35 | const std::vector<bool>& withBroadcast() const { |
36 | return with_broadcast_; |
37 | } |
38 | |
39 | private: |
40 | // Holds ReductionOp, WelfordOp or GroupedReductionOp. |
41 | std::vector<Expr*> reductions_; |
42 | // True each reduction also broadcasts |
43 | std::vector<bool> with_broadcast_; |
44 | }; |
45 | |
//! Inspect a fusion to detect eligible sequences of expressions to
//! use the fused reduction kernel. A sequence is a grid reduction
//! (ReductionOp, WelfordOp or GroupedReductionOp) whose output is
//! eventually consumed by a broadcast over exactly the parallel
//! types the reduction reduces.
class FusionInspector : private IterVisitor {
 public:
  //! Traverse the fusion and return all detected fusable patterns.
  static std::vector<FusedReductionBroadcastInfo> run(Fusion* fusion) {
    FusionInspector inspector(fusion);
    return inspector.fusion_list_;
  }

 private:
  FusionInspector(Fusion* fusion) {
    traverse(fusion);
  }

  using IterVisitor::handle;

  void handle(ReductionOp* rop) final {
    // If it's a grid reduction, keep track of tensors that depend on
    // this reduction.
    // Only consider when out is on register as that is assumed in the
    // fused reduction kernel.
    auto out = ir_utils::getTvOutput(rop);
    if (out->getMemoryType() == MemoryType::Local &&
        out->domain()->hasGridReduction()) {
      reduction_dep_[out].insert(rop);
    }
  }

  void handle(WelfordOp* wop) final {
    // Same criteria as ReductionOp: only track grid reductions whose
    // output is on register.
    auto out = ir_utils::getTvOutput(wop);
    if (out->getMemoryType() == MemoryType::Local &&
        out->domain()->hasGridReduction()) {
      reduction_dep_[out].insert(wop);
    }
  }

  void handle(GroupedReductionOp* grouped_rop) final {
    // Same criteria as ReductionOp: only track grid reductions whose
    // output is on register.
    auto out = ir_utils::getTvOutput(grouped_rop);
    if (out->getMemoryType() == MemoryType::Local &&
        out->domain()->hasGridReduction()) {
      reduction_dep_[out].insert(grouped_rop);
    }
  }

  //! Generic handler: propagate tracked reductions from the inputs of
  //! an expression to its outputs, so a broadcast reachable through
  //! intermediate expressions is still matched with its reduction.
  void handle(Expr* expr) final {
    IterVisitor::handle(expr);
    for (auto in_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
      for (auto reduction_op : reduction_dep_[in_tv]) {
        // Already fused reductions no longer need to be propagated.
        if (fused_exprs_.find(reduction_op) != fused_exprs_.end()) {
          continue;
        }
        for (auto out_tv :
             ir_utils::filterByType<TensorView>(expr->outputs())) {
          reduction_dep_[out_tv].insert(reduction_op);
        }
      }
    }
  }

  // In the case of welford, use the fused broadcast reduction when at
  // least one of the outputs is broadcast.
  void handle(BroadcastOp* bop) final {
    // Detect a pattern where a reduction is followed by a broadcast
    auto bop_out = bop->out()->as<TensorView>();
    auto bop_in = bop->in()->as<TensorView>();

    for (Expr* preceding_expr : reduction_dep_[bop_in]) {
      auto parallel_reduction_axes =
          getReductionParallelTypeStates(preceding_expr);

      // If not matching, propagate the reduction further down to
      // subsequent expressions
      if (!isBroadcastFuseable(bop_out, parallel_reduction_axes)) {
        continue;
      }

      if (fused_exprs_.find(preceding_expr) != fused_exprs_.end()) {
        // Already added to the fusion list. This can happen with
        // welford as there can be multiple broadcast consumer
        // expressions.
        continue;
      }

      if (preceding_expr->isA<ReductionOp>()) {
        fusion_list_.emplace_back(preceding_expr->as<ReductionOp>(), true);
      } else if (preceding_expr->isA<GroupedReductionOp>()) {
        fusion_list_.emplace_back(
            preceding_expr->as<GroupedReductionOp>(), true);
      } else if (preceding_expr->isA<WelfordOp>()) {
        fusion_list_.emplace_back(preceding_expr->as<WelfordOp>(), true);
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Invalid preceding expr: ", preceding_expr->toString());
      }

      fused_exprs_.insert(preceding_expr);
    }
  }

  //! Collect the thread/block parallel types used by the reduction
  //! axes of the output of the given reduction expression.
  ParallelTypeBitmap getReductionParallelTypeStates(Expr* expr) {
    ParallelTypeBitmap parallel_reduction_axes;

    for (auto id : ir_utils::getTvOutput(expr)->domain()->domain()) {
      auto pt = id->getParallelType();
      if (id->isReduction() && isParallelTypeThread(pt)) {
        parallel_reduction_axes.set(pt);
      }
    }

    return parallel_reduction_axes;
  }

  // Requires reduction parallel dimensions to exactly match parallel broadcast
  // dimensions
  bool isBroadcastFuseable(
      TensorView* broadcast_out,
      const ParallelTypeBitmap& parallel_reduction_axes) {
    const auto broadcast_parallel_types =
        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(
            broadcast_out);

    // If no parallel broadcast, nothing to fuse
    if (broadcast_parallel_types.none()) {
      return false;
    }

    // Make sure the broadcast parallel types are the types reduced by
    // the preceding reduction op
    for (auto id : broadcast_out->domain()->domain()) {
      auto pt = id->getParallelType();
      if (!isParallelTypeThread(pt)) {
        continue;
      }
      // Parallel broadcast must be included in reduction_states
      if (id->isBroadcast() && broadcast_parallel_types.get(pt)) {
        if (!parallel_reduction_axes.get(pt)) {
          return false;
        }
      }
    }

    return true;
  }

 private:
  //! List of expression sequences to fuse
  std::vector<FusedReductionBroadcastInfo> fusion_list_;
  //! Keep track of fused reduction/welford exprs to avoid duplication
  std::unordered_set<Expr*> fused_exprs_;
  //! Keep track of ReductionOp/WelfordOp expressions that are
  //! (indirectly) input to a tensor
  std::unordered_map<TensorView*, std::unordered_set<Expr*>> reduction_dep_;
};
203 | |
//! Transform a fusion to use the fused reduction kernel.
class FusionTransformer {
 public:
  //! Apply every pattern detected by FusionInspector to the fusion.
  static void run(
      Fusion* fusion,
      const std::vector<FusedReductionBroadcastInfo>& fusion_list) {
    FusionTransformer transformer(fusion, fusion_list);
  }

 private:
  FusionTransformer(
      Fusion* fusion,
      const std::vector<FusedReductionBroadcastInfo>& fusion_list)
      : fusion_(fusion), fusion_list_(fusion_list) {
    transform();
  }

  //! Apply all detected patterns, then refresh the thread predicate
  //! map if any broadcast predicate was reset along the way.
  void transform() {
    for (const auto& info : fusion_list_) {
      transform(info);
    }
    // If the thread predicate map is modified, rebuild the
    // map. build() only updates mappings that need to be updated.
    if (thread_pred_map_modified_) {
      GpuLower::current()->threadPredMap().build(fusion_);
    }
  }

  //! Replace the reduction expression of a single pattern with an
  //! equivalent allreduce version (the trailing `true` of each
  //! IrBuilder::create call below) and, when a broadcast is part of
  //! the pattern, reset the thread predicates of the broadcast.
  void transform(const FusedReductionBroadcastInfo& info) {
    TORCH_INTERNAL_ASSERT(
        info.reductions().size() == 1, "Horizontal fusion not supported yet");

    for (const auto i : c10::irange(info.reductions().size())) {
      const auto expr = info.reductions().at(i);
      const auto with_broadcast = info.withBroadcast().at(i);
      Expr* fused_expr = nullptr;

      if (auto reduction = dynamic_cast<ReductionOp*>(expr)) {
        TORCH_INTERNAL_ASSERT(!reduction->isAllreduce());

        // Capture the operands before removing the original expr, then
        // recreate it with the allreduce flag set.
        auto red_op_type = reduction->getReductionOpType();
        auto init = reduction->init();
        auto out = reduction->out();
        auto in = reduction->in();

        fusion_->removeExpr(reduction);

        fused_expr =
            IrBuilder::create<ReductionOp>(red_op_type, init, out, in, true);
      } else if (auto welford = dynamic_cast<WelfordOp*>(expr)) {
        TORCH_INTERNAL_ASSERT(!welford->isAllreduce());

        // Same as ReductionOp, but a welford op carries an
        // avg/var/N triplet for each of output, input and init.
        auto out_avg = welford->outAvg();
        auto out_var = welford->outVar();
        auto out_n = welford->outN();
        auto init_avg = welford->initAvg();
        auto init_var = welford->initVar();
        auto init_n = welford->initN();
        auto in_avg = welford->inAvg();
        auto in_var = welford->inVar();
        auto in_n = welford->inN();

        fusion_->removeExpr(welford);

        fused_expr = IrBuilder::create<WelfordOp>(
            WelfordTriplet{out_avg, out_var, out_n},
            WelfordTriplet{in_avg, in_var, in_n},
            WelfordTriplet{init_avg, init_var, init_n},
            true);
      } else if (auto grouped_rop = dynamic_cast<GroupedReductionOp*>(expr)) {
        TORCH_INTERNAL_ASSERT(!grouped_rop->isAllreduce());

        auto op_types = grouped_rop->getReductionOpTypes();
        auto init_vals = grouped_rop->initVals();
        auto outputs = grouped_rop->outputs();
        auto inputs = grouped_rop->inputs();

        fusion_->removeExpr(grouped_rop);

        fused_expr = IrBuilder::create<GroupedReductionOp>(
            op_types, init_vals, outputs, inputs, true);
      } else {
        TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
      }

      TORCH_INTERNAL_ASSERT(fused_expr != nullptr);

      // Do not just remove the broadcast but just reset the thread
      // predicate of the broadcast op. Since fusion is applied only
      // when all parallel broadcast domains are to be parallel
      // reduction, all parallel types can be reset.
      if (with_broadcast) {
        // It may be just fine to remove the broadcast expr, but
        // technically speaking that would violate the root domain mapping
        // as broadcast domains would appear in the consumer of the
        // broadcast output tensor without a broadcast expression.
        for (auto reduction_out :
             ir_utils::filterByType<TensorView>(fused_expr->outputs())) {
          for (auto id : reduction_out->domain()->domain()) {
            if (id->isReduction()) {
              GpuLower::current()->fusedReductionInfo().markAsAllreduce(id);
              GpuLower::current()->threadPredMap().markAsUpdated(reduction_out);
              thread_pred_map_modified_ = true;
            }
          }
        }
      }
    }
  }

 private:
  Fusion* fusion_ = nullptr;
  //! Patterns detected by FusionInspector; referenced, not copied.
  const std::vector<FusedReductionBroadcastInfo>& fusion_list_;
  //! Set when any broadcast thread predicate was reset, triggering a
  //! rebuild of the thread predicate map in transform().
  bool thread_pred_map_modified_ = false;
};
319 | |
320 | } // namespace |
321 | |
322 | void fuseReductionsAndBroadcasts(Fusion* fusion) { |
323 | auto fusion_list = FusionInspector::run(fusion); |
324 | FusionTransformer::run(fusion, fusion_list); |
325 | } |
326 | |
327 | void FusedReductionInfo::markAsAllreduce(IterDomain* id) { |
328 | allreduce_ids_.insert(id); |
329 | } |
330 | |
331 | bool FusedReductionInfo::isAllreduce(IterDomain* id) const { |
332 | return allreduce_ids_.find(id) != allreduce_ids_.end(); |
333 | } |
334 | |
335 | } // namespace cuda |
336 | } // namespace fuser |
337 | } // namespace jit |
338 | } // namespace torch |
339 | |