#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>

#include <lower_fused_reduction.h>

#include <algorithm>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

//! An instance of reduction patterns to fuse
class FusedReductionBroadcastInfo : public PolymorphicBase {
 public:
  FusedReductionBroadcastInfo(ReductionOp* reduction, bool with_broadcast)
      : reductions_({reduction}), with_broadcast_({with_broadcast}) {}

  FusedReductionBroadcastInfo(WelfordOp* welford, bool with_broadcast)
      : reductions_({welford}), with_broadcast_({with_broadcast}) {}

  FusedReductionBroadcastInfo(
      GroupedReductionOp* grouped_rop,
      bool with_broadcast)
      : reductions_({grouped_rop}), with_broadcast_({with_broadcast}) {}

  const std::vector<Expr*>& reductions() const {
    return reductions_;
  }

  const std::vector<bool>& withBroadcast() const {
    return with_broadcast_;
  }

 private:
  // Holds ReductionOp, WelfordOp or GroupedReductionOp.
  std::vector<Expr*> reductions_;
  // True if the corresponding reduction is also followed by a broadcast
  std::vector<bool> with_broadcast_;
};

//! Inspect a fusion to detect eligible sequences of expressions to
//! use the fused reduction kernel
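//! For example, a grid reduction such as tv1 = sum(tv0, ...) followed by
//! tv2 = broadcast(tv1), where the broadcast is parallelized on the same
//! thread/block dimensions the reduction reduces over, can be rewritten
//! to use a single allreduce-style fused reduction (tensor names here are
//! only illustrative).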
class FusionInspector : private IterVisitor {
 public:
  static std::vector<FusedReductionBroadcastInfo> run(Fusion* fusion) {
    FusionInspector inspector(fusion);
    return inspector.fusion_list_;
  }

 private:
  FusionInspector(Fusion* fusion) {
    traverse(fusion);
  }

  using IterVisitor::handle;

  void handle(ReductionOp* rop) final {
    // If it's a grid reduction, keep track of tensors that depend on
    // this reduction.
    // Only consider when out is on register as that is assumed in the
    // fused reduction kernel.
    auto out = ir_utils::getTvOutput(rop);
    if (out->getMemoryType() == MemoryType::Local &&
        out->domain()->hasGridReduction()) {
      reduction_dep_[out].insert(rop);
    }
  }

  void handle(WelfordOp* wop) final {
    // If it's a grid reduction, keep track of tensors that depend on
    // this reduction.
    // Only consider when out is on register as that is assumed in the
    // fused reduction kernel.
    auto out = ir_utils::getTvOutput(wop);
    if (out->getMemoryType() == MemoryType::Local &&
        out->domain()->hasGridReduction()) {
      reduction_dep_[out].insert(wop);
    }
  }

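  // Same as the ReductionOp and WelfordOp cases above: only grid
  // reductions whose outputs live in registers are tracked as fusion
  // candidates.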
  void handle(GroupedReductionOp* grouped_rop) final {
    auto out = ir_utils::getTvOutput(grouped_rop);
    if (out->getMemoryType() == MemoryType::Local &&
        out->domain()->hasGridReduction()) {
      reduction_dep_[out].insert(grouped_rop);
    }
  }

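  // Propagate pending grid reductions from the tensor inputs of an
  // expression to its tensor outputs, so that a broadcast consuming a
  // reduction result through intermediate expressions can still be
  // matched against that reduction.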
  void handle(Expr* expr) final {
    IterVisitor::handle(expr);
    for (auto in_tv : ir_utils::filterByType<TensorView>(expr->inputs())) {
      for (auto reduction_op : reduction_dep_[in_tv]) {
        if (fused_exprs_.find(reduction_op) != fused_exprs_.end()) {
          continue;
        }
        for (auto out_tv :
             ir_utils::filterByType<TensorView>(expr->outputs())) {
          reduction_dep_[out_tv].insert(reduction_op);
        }
      }
    }
  }

  // In the case of welford, use the fused broadcast reduction when at
  // least one of the outputs is broadcast.
  void handle(BroadcastOp* bop) final {
    // Detect a pattern where a reduction is followed by a broadcast
    auto bop_out = bop->out()->as<TensorView>();
    auto bop_in = bop->in()->as<TensorView>();

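    // For each pending grid reduction that (indirectly) produces the
    // broadcast input, check whether this broadcast can be fused with it.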
    for (Expr* preceding_expr : reduction_dep_[bop_in]) {
      auto parallel_reduction_axes =
          getReductionParallelTypeStates(preceding_expr);

      // If not matching, propagate the reduction further down to
      // subsequent expressions
      if (!isBroadcastFuseable(bop_out, parallel_reduction_axes)) {
        continue;
      }

      if (fused_exprs_.find(preceding_expr) != fused_exprs_.end()) {
        // Already added to the fusion list. This can happen with
        // welford as there can be multiple broadcast consumer
        // expressions.
        continue;
      }

      if (preceding_expr->isA<ReductionOp>()) {
        fusion_list_.emplace_back(preceding_expr->as<ReductionOp>(), true);
      } else if (preceding_expr->isA<GroupedReductionOp>()) {
        fusion_list_.emplace_back(
            preceding_expr->as<GroupedReductionOp>(), true);
      } else if (preceding_expr->isA<WelfordOp>()) {
        fusion_list_.emplace_back(preceding_expr->as<WelfordOp>(), true);
      } else {
        TORCH_INTERNAL_ASSERT(
            false, "Invalid preceding expr: ", preceding_expr->toString());
      }

      fused_exprs_.insert(preceding_expr);
    }
  }

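  // Collect the thread/block parallel types mapped to the reduction axes
  // of the given reduction expression's output.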
  ParallelTypeBitmap getReductionParallelTypeStates(Expr* expr) {
    ParallelTypeBitmap parallel_reduction_axes;

    for (auto id : ir_utils::getTvOutput(expr)->domain()->domain()) {
      auto pt = id->getParallelType();
      if (id->isReduction() && isParallelTypeThread(pt)) {
        parallel_reduction_axes.set(pt);
      }
    }

    return parallel_reduction_axes;
  }

  // Requires all parallel broadcast dimensions of the broadcast output to
  // be covered by the parallel reduction dimensions of the preceding
  // reduction
  bool isBroadcastFuseable(
      TensorView* broadcast_out,
      const ParallelTypeBitmap& parallel_reduction_axes) {
    const auto broadcast_parallel_types =
        GpuLower::current()->threadPredMap().getParallelBroadcastDomains(
            broadcast_out);

    // If no parallel broadcast, nothing to fuse
    if (broadcast_parallel_types.none()) {
      return false;
    }

    // Make sure the broadcast parallel types are the types reduced by
    // the preceding reduction op
    for (auto id : broadcast_out->domain()->domain()) {
      auto pt = id->getParallelType();
      if (!isParallelTypeThread(pt)) {
        continue;
      }
      // Parallel broadcast must be included in parallel_reduction_axes
      if (id->isBroadcast() && broadcast_parallel_types.get(pt)) {
        if (!parallel_reduction_axes.get(pt)) {
          return false;
        }
      }
    }

    return true;
  }

 private:
  //! List of expression sequences to fuse
  std::vector<FusedReductionBroadcastInfo> fusion_list_;
  //! Keep track of fused reduction/welford exprs to avoid duplication
  std::unordered_set<Expr*> fused_exprs_;
  //! Keep track of ReductionOp/GroupedReductionOp/WelfordOp expressions
  //! that are (indirectly) input to a tensor
  std::unordered_map<TensorView*, std::unordered_set<Expr*>> reduction_dep_;
};

//! Transform a fusion to use the fused reduction kernel.
class FusionTransformer {
 public:
  static void run(
      Fusion* fusion,
      const std::vector<FusedReductionBroadcastInfo>& fusion_list) {
    FusionTransformer transformer(fusion, fusion_list);
  }

 private:
  FusionTransformer(
      Fusion* fusion,
      const std::vector<FusedReductionBroadcastInfo>& fusion_list)
      : fusion_(fusion), fusion_list_(fusion_list) {
    transform();
  }

  void transform() {
    for (const auto& info : fusion_list_) {
      transform(info);
    }
    // If the thread predicate map is modified, rebuild the
    // map. build() only updates mappings that need to be updated.
    if (thread_pred_map_modified_) {
      GpuLower::current()->threadPredMap().build(fusion_);
    }
  }

  void transform(const FusedReductionBroadcastInfo& info) {
    TORCH_INTERNAL_ASSERT(
        info.reductions().size() == 1, "Horizontal fusion not supported yet");

    for (const auto i : c10::irange(info.reductions().size())) {
      const auto expr = info.reductions().at(i);
      const auto with_broadcast = info.withBroadcast().at(i);
      Expr* fused_expr = nullptr;

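      // Replace the original op with an equivalent one created with the
      // allreduce flag set, reusing the same outputs, inputs and init
      // values. The original expression is removed before the replacement
      // is created with the same outputs.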
      if (auto reduction = dynamic_cast<ReductionOp*>(expr)) {
        TORCH_INTERNAL_ASSERT(!reduction->isAllreduce());

        auto red_op_type = reduction->getReductionOpType();
        auto init = reduction->init();
        auto out = reduction->out();
        auto in = reduction->in();

        fusion_->removeExpr(reduction);

        fused_expr =
            IrBuilder::create<ReductionOp>(red_op_type, init, out, in, true);
      } else if (auto welford = dynamic_cast<WelfordOp*>(expr)) {
        TORCH_INTERNAL_ASSERT(!welford->isAllreduce());

        auto out_avg = welford->outAvg();
        auto out_var = welford->outVar();
        auto out_n = welford->outN();
        auto init_avg = welford->initAvg();
        auto init_var = welford->initVar();
        auto init_n = welford->initN();
        auto in_avg = welford->inAvg();
        auto in_var = welford->inVar();
        auto in_n = welford->inN();

        fusion_->removeExpr(welford);

        fused_expr = IrBuilder::create<WelfordOp>(
            WelfordTriplet{out_avg, out_var, out_n},
            WelfordTriplet{in_avg, in_var, in_n},
            WelfordTriplet{init_avg, init_var, init_n},
            true);
      } else if (auto grouped_rop = dynamic_cast<GroupedReductionOp*>(expr)) {
        TORCH_INTERNAL_ASSERT(!grouped_rop->isAllreduce());

        auto op_types = grouped_rop->getReductionOpTypes();
        auto init_vals = grouped_rop->initVals();
        auto outputs = grouped_rop->outputs();
        auto inputs = grouped_rop->inputs();

        fusion_->removeExpr(grouped_rop);

        fused_expr = IrBuilder::create<GroupedReductionOp>(
            op_types, init_vals, outputs, inputs, true);
      } else {
        TORCH_INTERNAL_ASSERT(false, "Invalid expr: ", expr->toString());
      }

      TORCH_INTERNAL_ASSERT(fused_expr != nullptr);

      // Do not remove the broadcast op; instead, just reset the thread
      // predicate of the broadcast. Since the fusion is applied only when
      // all parallel broadcast domains are also parallel reduction
      // domains, all parallel types can be reset.
      if (with_broadcast) {
        // It may be just fine to remove the broadcast expr, but
        // technically speaking that would violate the root domain mapping
        // as broadcast domains would appear in the consumer of the
        // broadcast output tensor without a broadcast expression.
        for (auto reduction_out :
             ir_utils::filterByType<TensorView>(fused_expr->outputs())) {
          for (auto id : reduction_out->domain()->domain()) {
            if (id->isReduction()) {
              GpuLower::current()->fusedReductionInfo().markAsAllreduce(id);
              GpuLower::current()->threadPredMap().markAsUpdated(
                  reduction_out);
              thread_pred_map_modified_ = true;
            }
          }
        }
      }
    }
  }

 private:
  Fusion* fusion_ = nullptr;
  const std::vector<FusedReductionBroadcastInfo>& fusion_list_;
  bool thread_pred_map_modified_ = false;
};

} // namespace

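// Entry point for this lowering pass: first detect eligible
// reduction(+broadcast) patterns with FusionInspector, then rewrite them
// with FusionTransformer to use the fused reduction kernel.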
void fuseReductionsAndBroadcasts(Fusion* fusion) {
  auto fusion_list = FusionInspector::run(fusion);
  FusionTransformer::run(fusion, fusion_list);
}

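// Record that the given reduction IterDomain is lowered as an allreduce,
// i.e., it belongs to a fused reduction followed by a broadcast.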
void FusedReductionInfo::markAsAllreduce(IterDomain* id) {
  allreduce_ids_.insert(id);
}

bool FusedReductionInfo::isAllreduce(IterDomain* id) const {
  return allreduce_ids_.find(id) != allreduce_ids_.end();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch