1 | #include <arith.h> |
2 | #include <fusion.h> |
3 | #include <ir_builder.h> |
4 | #include <ir_iostream.h> |
5 | #include <ir_utils.h> |
6 | #include <lower_utils.h> |
7 | |
8 | #include <set> |
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | namespace fuser { |
13 | namespace cuda { |
14 | namespace ir_utils { |
15 | |
16 | std::vector<int64_t> normalizeNew2Old( |
17 | const std::vector<int64_t>& new2old_in, |
18 | size_t ndims) { |
19 | TORCH_CHECK( |
20 | new2old_in.size() == ndims, |
      "There must be a transpose mapping for each dimension in domain");
22 | |
23 | // Canonicalize dimensions by wrapping each dim for the given ndims |
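  // e.g. with ndims == 3, a new2old entry of -1 wraps to 2, so {1, -1, 0}
  // becomes {1, 2, 0}.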
24 | std::vector<int64_t> new2old; |
25 | std::transform( |
26 | new2old_in.begin(), |
27 | new2old_in.end(), |
28 | std::inserter(new2old, new2old.begin()), |
29 | [ndims](int64_t entry) { return entry < 0 ? entry + ndims : entry; }); |
30 | |
31 | // Check if any adjusted values are < 0, or >= nDims, which are invalid |
32 | TORCH_CHECK( |
33 | std::none_of( |
34 | new2old.begin(), |
35 | new2old.end(), |
36 | [ndims](int64_t entry) { |
37 | return entry < 0 || (unsigned int)entry >= ndims; |
38 | }), |
      "New2Old axes are not within the number of dimensions of the provided domain.\t",
40 | new2old); |
41 | |
  // Use a set to check whether the map contains any duplicate values.
43 | std::set<int64_t> old_pos_set; |
44 | std::transform( |
45 | new2old.begin(), |
46 | new2old.end(), |
47 | std::inserter(old_pos_set, old_pos_set.begin()), |
48 | [](int64_t entry) { return entry; }); |
49 | |
50 | // Error out if duplicate values are found. |
51 | TORCH_CHECK( |
52 | new2old.size() == ndims && old_pos_set.size() == new2old.size(), |
      "Duplicate entries in transformation map.");
54 | |
55 | // END VALIDATION CHECKS |
56 | return new2old; |
57 | } |
58 | |
59 | std::vector<int> normalizeOld2New( |
60 | const std::unordered_map<int, int>& old2new_in, |
61 | size_t ndims) { |
  // Adjust based on negative values (any negative value gets ndims added to
  // it)
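  // e.g. with ndims == 3, an entry {0, -1} becomes {0, 2}.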
64 | std::unordered_map<int, int> old2new; |
65 | std::transform( |
66 | old2new_in.begin(), |
67 | old2new_in.end(), |
68 | std::inserter(old2new, old2new.begin()), |
69 | [ndims](std::unordered_map<int, int>::value_type entry) { |
70 | return std::unordered_map<int, int>::value_type({ |
71 | entry.first < 0 ? entry.first + ndims : entry.first, |
72 | entry.second < 0 ? entry.second + ndims : entry.second, |
73 | }); |
74 | }); |
75 | |
76 | // Check if any adjusted values are < 0, or >= nDims, which are invalid |
77 | |
78 | TORCH_CHECK( |
79 | std::none_of( |
80 | old2new.begin(), |
81 | old2new.end(), |
82 | [ndims](std::unordered_map<int, int>::value_type entry) { |
83 | return entry.first < 0 || (unsigned int)entry.first >= ndims || |
84 | entry.second < 0 || (unsigned int)entry.second >= ndims; |
85 | }), |
      "Reorder axes are not within the number of dimensions of the provided domain.");
87 | |
  // Use sets to check whether the map contains any duplicate keys or values.
89 | |
90 | std::set<int> old_pos_set; |
91 | std::transform( |
92 | old2new.begin(), |
93 | old2new.end(), |
94 | std::inserter(old_pos_set, old_pos_set.begin()), |
95 | [](std::unordered_map<int, int>::value_type entry) { |
96 | return entry.first; |
97 | }); |
98 | |
99 | std::set<int> new_pos_set; |
100 | std::transform( |
101 | old2new.begin(), |
102 | old2new.end(), |
103 | std::inserter(new_pos_set, new_pos_set.begin()), |
104 | [](std::unordered_map<int, int>::value_type entry) { |
105 | return entry.second; |
106 | }); |
107 | |
108 | // Error out if duplicate values are found. |
109 | TORCH_CHECK( |
110 | old_pos_set.size() == old2new.size() && |
111 | new_pos_set.size() == old2new.size(), |
      "Duplicate entries in transformation map sent to TensorView reorder.");
113 | |
114 | // END VALIDATION CHECKS |
115 | |
116 | std::vector<int> new2old(ndims, -1); |
117 | |
  // Go through each (old, new) position pair and fill in the new2old mapping;
  // bounds were already validated above.
119 | for (std::pair<int, int> elem : old2new) { |
120 | int old_pos = elem.first; |
121 | int new_pos = elem.second; |
122 | new2old[new_pos] = old_pos; |
123 | } |
124 | |
125 | // old_positions that already have a new position |
126 | std::set<int> old_positions(new2old.begin(), new2old.end()); |
127 | old_positions.erase(-1); |
128 | |
  // All positions; used below to find which old positions are still unassigned
130 | std::set<int> all_positions; |
131 | for (decltype(ndims) i{0}; i < ndims; i++) |
132 | all_positions.insert(i); |
133 | |
134 | // Check what positions haven't been specified. |
135 | std::set<int> positions_left; |
136 | std::set_difference( |
137 | all_positions.begin(), |
138 | all_positions.end(), |
139 | old_positions.begin(), |
140 | old_positions.end(), |
141 | std::inserter(positions_left, positions_left.end())); |
142 | |
143 | // Fill in positions that weren't specified, in relative order, |
144 | // in empty spots in the set of new positions. |
145 | // new2old[new_position] = old_position |
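  // e.g. with ndims == 3 and old2new == {{0, 2}}, new2old starts as
  // {-1, -1, 0}; the remaining old positions {1, 2} fill the empty slots in
  // order, giving {1, 2, 0}.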
146 | auto it = positions_left.begin(); // old positions left |
147 | std::transform( |
148 | new2old.begin(), new2old.end(), new2old.begin(), [&it](int i) -> int { |
149 | return i == -1 ? *it++ : i; |
150 | }); |
151 | |
152 | return new2old; |
153 | } |
154 | |
155 | namespace ValReplacement { |
// Creates a new Expr identical to the given one, with every use of
// `reference` replaced by `substitute`.
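// Illustrative example (names are hypothetical): given an expression
// `c = add(a, b)`, calling
//   SubstituteInExpr::substitute(c->definition(), b, b2)
// returns a new BinaryOp computing `add(a, b2)` that writes to `c`.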
158 | struct SubstituteInExpr : public OptInDispatch { |
159 | public: |
  static Expr* substitute(Expr* expr, Val* reference, Val* substitute) {
161 | TORCH_INTERNAL_ASSERT( |
162 | expr != nullptr && reference != nullptr && substitute != nullptr, |
        "Nullptr arg found.");
164 | SubstituteInExpr sie(reference, substitute); |
165 | sie.handle(expr); |
    TORCH_INTERNAL_ASSERT(
        sie.expr_ != nullptr,
        "Failed to substitute ",
        reference,
        " with ",
        substitute);
172 | return sie.expr_; |
173 | } |
174 | |
175 | private: |
176 | explicit SubstituteInExpr(Val* reference, Val* substitute) |
177 | : reference_(reference), substitute_(substitute) {} |
178 | |
179 | void handle(Expr* expr) final { |
180 | OptInDispatch::handle(expr); |
181 | } |
182 | |
183 | void handle(FullOp* full_expr) final { |
184 | auto out = reference_->sameAs(full_expr->output(0)) ? substitute_ |
185 | : full_expr->output(0); |
186 | expr_ = IrBuilder::create<FullOp>( |
187 | full_expr->container(), |
188 | out, |
189 | full_expr->getFillValue(), |
190 | full_expr->dtype()); |
191 | } |
192 | |
193 | void handle(ARangeOp* arange_expr) final { |
194 | auto start = reference_->sameAs(arange_expr->start()) |
195 | ? substitute_ |
196 | : arange_expr->start(); |
197 | auto end = reference_->sameAs(arange_expr->end()) ? substitute_ |
198 | : arange_expr->end(); |
199 | auto step = reference_->sameAs(arange_expr->step()) ? substitute_ |
200 | : arange_expr->step(); |
201 | auto out = reference_->sameAs(arange_expr->output(0)) |
202 | ? substitute_ |
203 | : arange_expr->output(0); |
204 | expr_ = IrBuilder::create<ARangeOp>( |
205 | arange_expr->container(), |
206 | out, |
207 | start, |
208 | end, |
209 | step, |
210 | arange_expr->dtype(), |
211 | arange_expr->getLinearLogicalIndex()); |
212 | } |
213 | |
214 | void handle(EyeOp* eye_expr) final { |
215 | auto out = reference_->sameAs(eye_expr->output(0)) ? substitute_ |
216 | : eye_expr->output(0); |
217 | expr_ = IrBuilder::create<EyeOp>( |
218 | eye_expr->container(), |
219 | out, |
220 | eye_expr->dtype(), |
221 | eye_expr->getIndex1(), |
222 | eye_expr->getIndex2()); |
223 | } |
224 | |
225 | void handle(UnaryOp* unary_expr) final { |
226 | auto in = |
227 | reference_->sameAs(unary_expr->in()) ? substitute_ : unary_expr->in(); |
228 | auto out = |
229 | reference_->sameAs(unary_expr->out()) ? substitute_ : unary_expr->out(); |
230 | expr_ = IrBuilder::create<UnaryOp>( |
231 | unary_expr->container(), unary_expr->getUnaryOpType(), out, in); |
232 | } |
233 | |
234 | void handle(BinaryOp* binary_expr) final { |
235 | auto lhs = reference_->sameAs(binary_expr->lhs()) ? substitute_ |
236 | : binary_expr->lhs(); |
237 | auto rhs = reference_->sameAs(binary_expr->rhs()) ? substitute_ |
238 | : binary_expr->rhs(); |
239 | auto out = reference_->sameAs(binary_expr->out()) ? substitute_ |
240 | : binary_expr->out(); |
241 | |
242 | expr_ = IrBuilder::create<BinaryOp>( |
243 | binary_expr->container(), |
244 | binary_expr->getBinaryOpType(), |
245 | out, |
246 | lhs, |
247 | rhs); |
248 | } |
249 | |
250 | void handle(TernaryOp* ternary_expr) final { |
251 | auto in1 = reference_->sameAs(ternary_expr->in1()) ? substitute_ |
252 | : ternary_expr->in1(); |
253 | auto in2 = reference_->sameAs(ternary_expr->in2()) ? substitute_ |
254 | : ternary_expr->in2(); |
255 | auto in3 = reference_->sameAs(ternary_expr->in3()) ? substitute_ |
256 | : ternary_expr->in3(); |
257 | auto out = reference_->sameAs(ternary_expr->out()) ? substitute_ |
258 | : ternary_expr->out(); |
259 | expr_ = IrBuilder::create<TernaryOp>( |
260 | ternary_expr->container(), |
261 | ternary_expr->getTernaryOpType(), |
262 | out, |
263 | in1, |
264 | in2, |
265 | in3); |
266 | } |
267 | |
268 | void handle(RNGOp* rng_expr) final { |
    std::vector<Val*> substituted_params;
270 | for (auto v : rng_expr->getParameters()) { |
      substituted_params.emplace_back(reference_->sameAs(v) ? substitute_ : v);
272 | } |
273 | auto out = reference_->sameAs(rng_expr->output(0)) ? substitute_ |
274 | : rng_expr->output(0); |
275 | expr_ = IrBuilder::create<RNGOp>( |
276 | rng_expr->container(), |
277 | rng_expr->getRNGOpType(), |
278 | out, |
279 | rng_expr->dtype(), |
        substituted_params,
281 | rng_expr->getRNGOffset(), |
282 | rng_expr->getPhiloxIndex()); |
283 | } |
284 | |
285 | void handle(ReductionOp* reduction_expr) final { |
286 | auto init = reference_->sameAs(reduction_expr->init()) |
287 | ? substitute_ |
288 | : reduction_expr->init(); |
289 | auto out = reference_->sameAs(reduction_expr->out()) |
290 | ? substitute_ |
291 | : reduction_expr->out(); |
292 | auto in = reference_->sameAs(reduction_expr->in()) ? substitute_ |
293 | : reduction_expr->in(); |
294 | |
295 | expr_ = IrBuilder::create<ReductionOp>( |
296 | reduction_expr->container(), |
297 | reduction_expr->getReductionOpType(), |
298 | init, |
299 | out, |
300 | in); |
301 | } |
302 | |
303 | void handle(GroupedReductionOp* grouped_reduction_expr) final { |
304 | std::vector<Val*> outputs; |
305 | std::transform( |
306 | grouped_reduction_expr->outputs().begin(), |
307 | grouped_reduction_expr->outputs().end(), |
308 | std::back_inserter(outputs), |
309 | [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; }); |
310 | |
311 | std::vector<Val*> inputs; |
312 | std::transform( |
313 | grouped_reduction_expr->inputs().begin(), |
314 | grouped_reduction_expr->inputs().end(), |
315 | std::back_inserter(inputs), |
316 | [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; }); |
317 | |
318 | std::vector<Val*> init_vals; |
319 | std::transform( |
320 | grouped_reduction_expr->initVals().begin(), |
321 | grouped_reduction_expr->initVals().end(), |
322 | std::back_inserter(init_vals), |
323 | [&](Val* val) { return reference_->sameAs(val) ? substitute_ : val; }); |
324 | |
325 | expr_ = IrBuilder::create<GroupedReductionOp>( |
326 | grouped_reduction_expr->container(), |
327 | grouped_reduction_expr->getReductionOpTypes(), |
328 | init_vals, |
329 | outputs, |
330 | inputs); |
331 | } |
332 | |
333 | void handle(BroadcastOp* broadcast_expr) final { |
334 | auto out = reference_->sameAs(broadcast_expr->out()) |
335 | ? substitute_ |
336 | : broadcast_expr->out(); |
337 | auto in = reference_->sameAs(broadcast_expr->in()) ? substitute_ |
338 | : broadcast_expr->in(); |
339 | |
340 | expr_ = IrBuilder::create<BroadcastOp>( |
341 | broadcast_expr->container(), |
342 | out, |
343 | in, |
344 | broadcast_expr->getBroadcastDimFlags()); |
345 | } |
346 | |
347 | void handle(TransposeOp* transpose_expr) final { |
348 | TORCH_INTERNAL_ASSERT( |
349 | substitute_->isA<TensorView>(), |
        "All args to transpose must be TensorView, but received a non-TensorView for replacement: ",
351 | substitute_); |
352 | auto out = reference_->sameAs(transpose_expr->out()) |
353 | ? substitute_->as<TensorView>() |
354 | : transpose_expr->out(); |
355 | auto in = reference_->sameAs(transpose_expr->in()) |
356 | ? substitute_->as<TensorView>() |
357 | : transpose_expr->in(); |
358 | expr_ = IrBuilder::create<TransposeOp>( |
359 | transpose_expr->container(), out, in, transpose_expr->new2old()); |
360 | } |
361 | |
362 | void handle(ExpandOp* expand_expr) final { |
363 | auto out = reference_->sameAs(expand_expr->out()) |
364 | ? substitute_->as<TensorView>() |
365 | : expand_expr->out(); |
366 | auto in = reference_->sameAs(expand_expr->in()) |
367 | ? substitute_->as<TensorView>() |
368 | : expand_expr->in(); |
369 | |
370 | auto expanded_extents = expand_expr->expanded_extents(); |
    // If the substitute is a scalar, also apply the substitution to any
    // expanded extent that matches the reference.
    if (substitute_->isA<Int>()) {
      for (auto i : c10::irange(expanded_extents.size())) {
        if (reference_->sameAs(expanded_extents[i])) {
          expanded_extents[i] = substitute_;
        }
      }
    }
378 | expr_ = IrBuilder::create<ExpandOp>( |
379 | expand_expr->container(), out, in, expanded_extents); |
380 | } |
381 | |
382 | void handle(ShiftOp* shift_expr) final { |
383 | auto out = |
384 | reference_->sameAs(shift_expr->out()) ? substitute_ : shift_expr->out(); |
385 | auto in = |
386 | reference_->sameAs(shift_expr->in()) ? substitute_ : shift_expr->in(); |
387 | |
388 | expr_ = IrBuilder::create<ShiftOp>( |
389 | shift_expr->container(), |
390 | out, |
391 | in, |
392 | shift_expr->offsets(), |
393 | shift_expr->padWidth()); |
394 | } |
395 | |
396 | void handle(GatherOp* gather_expr) final { |
397 | auto out = reference_->sameAs(gather_expr->out()) ? substitute_ |
398 | : gather_expr->out(); |
399 | auto in = |
400 | reference_->sameAs(gather_expr->in()) ? substitute_ : gather_expr->in(); |
401 | |
402 | expr_ = IrBuilder::create<GatherOp>( |
403 | gather_expr->container(), |
404 | out, |
405 | in, |
406 | gather_expr->windowShape(), |
407 | gather_expr->padWidth()); |
408 | } |
409 | |
410 | void handle(ViewAsScalar* expr) final { |
411 | TORCH_INTERNAL_ASSERT( |
412 | substitute_->isA<TensorView>(), |
        "All args to ViewAsScalar must be TensorView, but received a non-TensorView for replacement: ",
414 | substitute_); |
415 | auto in = reference_->sameAs(expr->in()) ? substitute_->as<TensorView>() |
416 | : expr->in(); |
417 | auto out = reference_->sameAs(expr->out()) ? substitute_->as<TensorView>() |
418 | : expr->out(); |
419 | expr_ = IrBuilder::create<ViewAsScalar>( |
420 | expr->container(), out, in, expr->vector_id(), expr->index()); |
421 | } |
422 | |
423 | void handle(ViewOp* view_expr) final { |
424 | TORCH_INTERNAL_ASSERT( |
425 | substitute_->isA<TensorView>(), |
        "All args to view must be TensorView, but received a non-TensorView for replacement: ",
427 | substitute_); |
428 | auto in = reference_->sameAs(view_expr->in()) |
429 | ? substitute_->as<TensorView>() |
430 | : view_expr->in(); |
431 | auto out = reference_->sameAs(view_expr->out()) |
432 | ? substitute_->as<TensorView>() |
433 | : view_expr->out(); |
434 | expr_ = IrBuilder::create<ViewOp>(view_expr->container(), out, in); |
435 | } |
436 | |
437 | void handle(WelfordOp* welford_expr) final { |
438 | auto out_avg = reference_->sameAs(welford_expr->outAvg()) |
439 | ? substitute_->as<TensorView>() |
440 | : welford_expr->outAvg(); |
441 | auto out_var = reference_->sameAs(welford_expr->outVar()) |
442 | ? substitute_->as<TensorView>() |
443 | : welford_expr->outVar(); |
444 | auto out_N = reference_->sameAs(welford_expr->outN()) |
445 | ? substitute_->as<TensorView>() |
446 | : welford_expr->outN(); |
447 | auto in_avg = reference_->sameAs(welford_expr->inAvg()) |
448 | ? substitute_->as<TensorView>() |
449 | : welford_expr->inAvg(); |
450 | auto in_var = |
451 | welford_expr->inVar() && reference_->sameAs(welford_expr->inVar()) |
452 | ? substitute_->as<TensorView>() |
453 | : welford_expr->inVar(); |
454 | auto in_N = reference_->sameAs(welford_expr->inN()) ? substitute_ |
455 | : welford_expr->inN(); |
456 | auto init_avg = |
457 | welford_expr->initAvg() && reference_->sameAs(welford_expr->initAvg()) |
458 | ? substitute_->as<TensorView>() |
459 | : welford_expr->initAvg(); |
460 | auto init_var = |
461 | welford_expr->initVar() && reference_->sameAs(welford_expr->initVar()) |
462 | ? substitute_->as<TensorView>() |
463 | : welford_expr->initVar(); |
464 | auto init_N = |
465 | welford_expr->initN() && reference_->sameAs(welford_expr->initN()) |
466 | ? substitute_ |
467 | : welford_expr->initN(); |
468 | expr_ = IrBuilder::create<WelfordOp>( |
469 | welford_expr->container(), |
470 | out_avg, |
471 | out_var, |
472 | out_N, |
473 | in_avg, |
474 | in_var, |
475 | in_N, |
476 | init_avg, |
477 | init_var, |
478 | init_N, |
479 | welford_expr->isAllreduce()); |
480 | } |
481 | |
482 | void handle(LoadStoreOp* ldst_expr) final { |
483 | TORCH_INTERNAL_ASSERT( |
484 | substitute_->isA<TensorView>(), |
        "All args to LoadStoreOp must be TensorView, but received a non-TensorView for replacement: ",
486 | substitute_); |
487 | auto in = reference_->sameAs(ldst_expr->in()) |
488 | ? substitute_->as<TensorView>() |
489 | : ldst_expr->in(); |
490 | auto out = reference_->sameAs(ldst_expr->out()) |
491 | ? substitute_->as<TensorView>() |
492 | : ldst_expr->out(); |
493 | expr_ = IrBuilder::create<LoadStoreOp>( |
494 | ldst_expr->container(), ldst_expr->opType(), out, in); |
495 | } |
496 | |
497 | void handle(MmaOp* mma_expr) final { |
498 | TORCH_INTERNAL_ASSERT( |
499 | substitute_->isA<TensorView>(), |
        "All args to MmaOp must be TensorView, but received a non-TensorView for replacement: ",
501 | substitute_); |
502 | auto in_a = reference_->sameAs(mma_expr->inA()) |
503 | ? substitute_->as<TensorView>() |
504 | : mma_expr->inA(); |
505 | auto in_b = reference_->sameAs(mma_expr->inB()) |
506 | ? substitute_->as<TensorView>() |
507 | : mma_expr->inB(); |
508 | auto out = reference_->sameAs(mma_expr->out()) |
509 | ? substitute_->as<TensorView>() |
510 | : mma_expr->out(); |
511 | auto init = reference_->sameAs(mma_expr->init()) |
512 | ? substitute_->as<TensorView>() |
513 | : mma_expr->init(); |
514 | expr_ = IrBuilder::create<MmaOp>( |
515 | mma_expr->container(), out, in_a, in_b, init, mma_expr->options()); |
516 | } |
517 | |
518 | private: |
519 | Val* reference_ = nullptr; |
520 | Val* substitute_ = nullptr; |
521 | Expr* expr_ = nullptr; |
522 | }; |
523 | |
524 | } // namespace ValReplacement |
525 | |
526 | Expr* replaceValInExpr(Expr* expr, Val* reference, Val* substitute) { |
527 | FusionGuard fg(expr->fusion()); |
  return ValReplacement::SubstituteInExpr::substitute(
      expr, reference, substitute);
530 | } |
531 | |
532 | TensorView* rfactorHelper( |
533 | TensorView* reduction_tv, |
534 | const std::vector<int>& axes) { |
535 | TORCH_INTERNAL_ASSERT(reduction_tv->definition() != nullptr); |
536 | const bool has_multiple_tvs = reduction_tv->definition()->inputs().size() > 1; |
537 | if (!has_multiple_tvs) { |
538 | return reduction_tv->rFactor(axes); |
539 | } |
540 | |
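  // Multi-output reduction (e.g. a WelfordOp producing avg/var/N): rFactor all
  // sibling outputs together, then return the rfactored tensor that
  // corresponds to reduction_tv.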
541 | std::vector<TensorView*> out_tvs; |
542 | std::transform( |
543 | reduction_tv->definition()->outputs().begin(), |
544 | reduction_tv->definition()->outputs().end(), |
545 | std::back_inserter(out_tvs), |
546 | [](Val* val) { return val->as<TensorView>(); }); |
547 | |
548 | auto rf_tvs = reduction_tv->rFactor(axes, out_tvs); |
549 | |
550 | return rf_tvs.at(std::distance( |
551 | out_tvs.begin(), |
552 | std::find(out_tvs.begin(), out_tvs.end(), reduction_tv))); |
553 | } |
554 | |
555 | namespace { |
556 | |
557 | template <typename T> |
std::vector<T*> uniqueEntries(const std::vector<T*>& tv_deque) {
559 | std::vector<T*> unique_entries; |
560 | std::unordered_set<T*> inserted; |
  for (auto tv_entry : tv_deque) {
562 | if (inserted.emplace(tv_entry).second) { |
563 | unique_entries.emplace_back(tv_entry); |
564 | } |
565 | } |
566 | return unique_entries; |
567 | } |
568 | |
569 | } // namespace |
570 | |
571 | // Return immediate producers of val |
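// e.g. for `c = add(a, b)`, producerValsOf(c) returns {a, b}.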
572 | std::vector<Val*> producerValsOf(Val* val) { |
573 | if (val->definition() == nullptr) { |
574 | return {}; |
575 | } |
576 | auto producer_vals = val->definition()->inputs(); |
577 | return uniqueEntries<Val>({producer_vals.begin(), producer_vals.end()}); |
578 | } |
579 | |
580 | // Return immediate consumers of val |
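// e.g. for `c = add(a, b)`, consumerValsOf(a) returns {c} (assuming `a` has no
// other uses).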
581 | std::vector<Val*> consumerValsOf(Val* val) { |
582 | std::vector<Val*> consumer_vals; |
583 | for (auto use_expr : val->uses()) { |
584 | auto outputs = use_expr->outputs(); |
585 | consumer_vals.insert(consumer_vals.end(), outputs.begin(), outputs.end()); |
586 | } |
587 | return uniqueEntries<Val>(consumer_vals); |
588 | } |
589 | |
590 | // Return immediate siblings of val |
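// e.g. for a Welford op producing (avg, var, N), siblingValsOf(avg) returns
// {var, N}.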
591 | std::vector<Val*> siblingValsOf(Val* val) { |
592 | std::vector<Val*> sibling_vals; |
593 | auto def = val->definition(); |
594 | if (def != nullptr) { |
595 | auto outs = def->outputs(); |
596 | for (auto sibling_val : outs) { |
597 | if (sibling_val == val) { |
598 | continue; |
599 | } |
600 | sibling_vals.emplace_back(sibling_val); |
601 | } |
602 | } |
603 | return sibling_vals; |
604 | } |
605 | |
// Return immediate producers of the given vals
607 | std::vector<Val*> producerValsOf(const std::vector<Val*>& vals) { |
608 | std::vector<Val*> all_producer_vals; |
609 | for (auto val : vals) { |
610 | auto producer_vals = producerValsOf(val); |
611 | all_producer_vals.insert( |
612 | all_producer_vals.end(), producer_vals.begin(), producer_vals.end()); |
613 | } |
614 | |
615 | return uniqueEntries<Val>(all_producer_vals); |
616 | } |
617 | |
// Return immediate consumers of the given vals
619 | std::vector<Val*> consumerValsOf(const std::vector<Val*>& vals) { |
620 | std::vector<Val*> all_consumer_vals; |
621 | for (auto val : vals) { |
622 | auto consumer_vals = consumerValsOf(val); |
623 | all_consumer_vals.insert( |
624 | all_consumer_vals.end(), consumer_vals.begin(), consumer_vals.end()); |
625 | } |
626 | |
627 | return uniqueEntries<Val>(all_consumer_vals); |
628 | } |
629 | |
630 | std::vector<TensorView*> producerTvsOf(TensorView* tv) { |
631 | auto producer_vals = producerValsOf(tv); |
632 | auto producer_tvs = ir_utils::filterByType<TensorView>(producer_vals); |
633 | return {producer_tvs.begin(), producer_tvs.end()}; |
634 | } |
635 | |
636 | std::vector<TensorView*> consumerTvsOf(TensorView* tv) { |
637 | auto consumer_vals = consumerValsOf(tv); |
638 | auto consumer_tvs = ir_utils::filterByType<TensorView>(consumer_vals); |
639 | return {consumer_tvs.begin(), consumer_tvs.end()}; |
640 | } |
641 | |
642 | std::vector<TensorView*> siblingTvsOf(TensorView* tv) { |
643 | auto sibling_vals = siblingValsOf(tv); |
644 | auto sibling_tvs = ir_utils::filterByType<TensorView>(sibling_vals); |
645 | return {sibling_tvs.begin(), sibling_tvs.end()}; |
646 | } |
647 | |
648 | std::vector<TensorView*> producerTvsOf(const std::vector<TensorView*>& tvs) { |
649 | std::vector<TensorView*> all_producer_tvs; |
650 | for (auto tv : tvs) { |
651 | auto producer_tvs = producerTvsOf(tv); |
652 | all_producer_tvs.insert( |
653 | all_producer_tvs.end(), producer_tvs.begin(), producer_tvs.end()); |
654 | } |
655 | |
656 | return uniqueEntries<TensorView>(all_producer_tvs); |
657 | } |
658 | |
659 | std::vector<TensorView*> consumerTvsOf(const std::vector<TensorView*>& tvs) { |
660 | std::vector<TensorView*> all_consumer_tvs; |
661 | for (auto tv : tvs) { |
662 | auto consumer_tvs = consumerTvsOf(tv); |
663 | all_consumer_tvs.insert( |
664 | all_consumer_tvs.end(), consumer_tvs.begin(), consumer_tvs.end()); |
665 | } |
666 | |
667 | return uniqueEntries<TensorView>(all_consumer_tvs); |
668 | } |
669 | |
670 | std::vector<TensorView*> inputTvsOf(TensorView* tv) { |
671 | return inputTvsOf(std::vector<TensorView*>{tv}); |
672 | } |
673 | |
674 | std::vector<TensorView*> outputTvsOf(TensorView* tv) { |
675 | return outputTvsOf(std::vector<TensorView*>{tv}); |
676 | } |
677 | |
678 | std::vector<TensorView*> inputTvsOf(std::vector<TensorView*> tvs) { |
679 | auto inp_vals = IterVisitor::getInputsTo({tvs.begin(), tvs.end()}); |
680 | auto filtered = ir_utils::filterByType<TensorView>(inp_vals); |
681 | std::vector<TensorView*> inp_tvs(filtered.begin(), filtered.end()); |
682 | return uniqueEntries<TensorView>(inp_tvs); |
683 | } |
684 | |
685 | std::vector<TensorView*> outputTvsOf(std::vector<TensorView*> tvs) { |
686 | auto out_vals = DependencyCheck::getAllOutputsOf({tvs.begin(), tvs.end()}); |
687 | auto filtered = ir_utils::filterByType<TensorView>(out_vals); |
688 | std::vector<TensorView*> out_tvs(filtered.begin(), filtered.end()); |
689 | return uniqueEntries<TensorView>(out_tvs); |
690 | } |
691 | |
692 | std::vector<TensorView*> allTvs(Fusion* fusion) { |
693 | auto used_vals = fusion->usedMathVals(); |
694 | auto used_tvs = ir_utils::filterByType<TensorView>(used_vals); |
695 | |
  // This shouldn't be necessary, but in FusionSegmentIoAlias_CUDA aliasing can
  // leave an input disconnected from the outputs, and the iter domains of such
  // inputs are still checked by compute-at maps in the scheduling logic.
  // Including them here shouldn't hurt, AFAICT.
700 | auto tv_inputs = ir_utils::filterByType<TensorView>(fusion->inputs()); |
701 | |
702 | std::vector<TensorView*> all_tvs({used_tvs.begin(), used_tvs.end()}); |
  // Sometimes inputs are not connected to outputs; we still include them here
  // because they are registered as fusion inputs.
705 | all_tvs.insert(all_tvs.end(), tv_inputs.begin(), tv_inputs.end()); |
706 | |
  // all_tvs may contain duplicates; deduplicate it before returning
708 | return uniqueEntries<TensorView>(all_tvs); |
709 | } |
710 | |
711 | std::vector<TensorView*> allTvsExcept( |
712 | Fusion* fusion, |
713 | const std::unordered_set<TensorView*>& except) { |
714 | auto all_tvs = allTvs(fusion); |
715 | std::vector<TensorView*> result; |
716 | for (auto tv : all_tvs) { |
717 | if (except.count(tv) == 0) { |
718 | result.emplace_back(tv); |
719 | } |
720 | } |
721 | return result; |
722 | } |
723 | |
724 | std::vector<Expr*> getReductionOps(Fusion* fusion, bool ignore_trivial) { |
725 | std::vector<Expr*> red_ops; |
726 | |
727 | auto isReduction = [&ignore_trivial](Val* out_val) { |
728 | if (out_val == nullptr || !out_val->isA<TensorView>()) { |
729 | return false; |
730 | } |
731 | auto out_tv = out_val->as<TensorView>(); |
732 | return std::any_of( |
733 | out_tv->getRootDomain().begin(), |
734 | out_tv->getRootDomain().end(), |
735 | [&ignore_trivial](IterDomain* id) { |
736 | return id->isReduction() && |
737 | !(ignore_trivial && id->isTrivialReduction()); |
738 | }); |
739 | }; |
740 | |
741 | for (auto expr : fusion->exprs()) { |
742 | bool is_reduction = false; |
743 | if (expr->isA<ReductionOp>()) { |
744 | is_reduction = isReduction(expr->as<ReductionOp>()->out()); |
745 | } else if (expr->isA<GroupedReductionOp>()) { |
746 | is_reduction = std::any_of( |
747 | expr->as<GroupedReductionOp>()->outputs().begin(), |
748 | expr->as<GroupedReductionOp>()->outputs().end(), |
749 | isReduction); |
750 | } else if (expr->isA<WelfordOp>()) { |
751 | is_reduction = isReduction(expr->as<WelfordOp>()->outAvg()); |
752 | } |
753 | if (is_reduction) { |
754 | red_ops.push_back(expr); |
755 | } |
756 | } |
757 | |
758 | return red_ops; |
759 | } |
760 | |
761 | namespace { |
762 | |
763 | class ValReplacementMutator : private OptOutMutator { |
764 | public: |
765 | ValReplacementMutator( |
766 | Fusion* fusion, |
767 | const std::unordered_map<Val*, Val*>& replacement_map) |
768 | : replacement_map_(replacement_map) { |
769 | FusionGuard fg(fusion); |
770 | |
    // Welford makes this a little tricky since it produces a count that is
    // typically not used by anything else. If we don't visit that count, it
    // would be a TensorView whose extents never get updated. Therefore, first
    // grab all leaves towards the outputs and collect statements from there.
775 | auto stmts = StmtSort::getStmts(fusion, allLeafOuts(fusion), true); |
776 | |
    // Some fusions, such as a standalone rand_like, can have a disconnected
    // DAG, so we need some mechanism to make sure our replacement set is as
    // complete as possible.
780 | // TODO: I think we need a more general mechanism to support disconnected |
781 | // DAG |
782 | std::vector<Val*> more; |
783 | for (auto v : fusion->inputs()) { |
784 | if (std::find(stmts.begin(), stmts.end(), v) == stmts.end()) { |
785 | more.emplace_back(v); |
786 | } |
787 | } |
788 | auto more_stmts = StmtSort::getStmts(fusion, more, true); |
789 | more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); |
790 | |
791 | for (auto stmt : more_stmts) { |
792 | mutate(stmt); |
793 | } |
794 | } |
795 | |
796 | private: |
797 | using OptOutMutator::mutate; |
798 | |
799 | void mutate(Val* val) final { |
800 | if (replacement_map_.find(val) == replacement_map_.end()) { |
801 | return OptOutMutator::mutate(val); |
802 | } |
803 | auto replaced_val = replacement_map_.at(val); |
804 | registerMutation(val, replaced_val); |
805 | } |
806 | |
807 | std::vector<Val*> allLeafOuts(Fusion* fusion) { |
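    // Collect, in traversal order, the expression outputs that are never
    // consumed by another expression, i.e. the leaves of the DAG.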
808 | auto exprs = StmtSort::getExprs(fusion, true); |
809 | std::unordered_set<Val*> inputs; |
810 | std::unordered_set<Val*> outputs; |
811 | std::vector<Val*> ordered_outputs; |
812 | for (auto expr : exprs) { |
813 | inputs.insert(expr->inputs().begin(), expr->inputs().end()); |
814 | outputs.insert(expr->outputs().begin(), expr->outputs().end()); |
815 | ordered_outputs.insert( |
816 | ordered_outputs.end(), |
817 | expr->outputs().begin(), |
818 | expr->outputs().end()); |
819 | } |
820 | for (auto input : inputs) { |
821 | outputs.erase(input); |
822 | } |
823 | |
824 | std::vector<Val*> ordered_leaf_outs; |
825 | for (auto out : ordered_outputs) { |
826 | if (outputs.find(out) != outputs.end()) { |
827 | ordered_leaf_outs.push_back(out); |
828 | } |
829 | } |
830 | return ordered_leaf_outs; |
831 | } |
832 | |
833 | const std::unordered_map<Val*, Val*>& replacement_map_; |
834 | }; |
835 | |
836 | } // namespace |
837 | |
838 | void replaceValue( |
839 | Fusion* fusion, |
840 | const std::unordered_map<Val*, Val*>& replacement_map) { |
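  // Constructing the mutator performs the replacement as a side effect.
  // Illustrative example (names are hypothetical):
  //   replaceValue(fusion, {{old_extent, new_extent}});
  // rewrites every use of old_extent in the fusion to use new_extent instead.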
841 | ValReplacementMutator(fusion, replacement_map); |
842 | } |
843 | |
844 | Val* getReductionInitValOf(TensorView* tv) { |
845 | auto def = tv->definition(); |
846 | if (def == nullptr) { |
847 | return nullptr; |
848 | } |
849 | |
850 | Val* init = nullptr; |
851 | if (auto rop = dynamic_cast<ReductionOp*>(def)) { |
852 | init = rop->init(); |
853 | } else if (auto grop = dynamic_cast<GroupedReductionOp*>(def)) { |
854 | int output_idx = grop->getExprIndexOfOutput(tv); |
855 | init = grop->initVal(output_idx); |
856 | } else if (auto wop = dynamic_cast<WelfordOp*>(def)) { |
857 | return wop->getInitValOfOutput(tv); |
858 | } else if (auto gwop = dynamic_cast<GroupedWelfordOp*>(def)) { |
859 | init = gwop->getInitValOfOutput(tv); |
860 | } else if (auto mma = dynamic_cast<MmaOp*>(def)) { |
861 | init = mma->init(); |
862 | } |
863 | |
864 | return init; |
865 | } |
866 | |
867 | // TODO: Should mma be in here? Should we return true if it's a trivial |
868 | // reduction? |
869 | bool isReductionOp(const Expr* expr) { |
870 | // Note that GridReduction inherits ReductionOp |
871 | return expr->isA<ReductionOp>() || expr->isA<GroupedReductionOp>() || |
872 | expr->isA<WelfordOp>() || expr->isA<GroupedWelfordOp>() || |
873 | expr->isA<kir::GridWelford>() || expr->isA<kir::GroupedGridWelford>(); |
874 | } |
875 | |
876 | bool isReductionTvOp(const Expr* expr) { |
877 | return ir_utils::isTvOp(expr) && isReductionOp(expr); |
878 | } |
879 | |
880 | std::vector<ViewOp*> getViewOps(Fusion* fusion) { |
881 | auto all_exprs = fusion->exprs(); |
882 | |
883 | auto all_view_ops = ir_utils::filterByType<ViewOp>(all_exprs); |
884 | |
885 | std::vector<ViewOp*> view_ops; |
886 | |
887 | std::copy_if( |
888 | all_view_ops.begin(), |
889 | all_view_ops.end(), |
890 | std::back_inserter(view_ops), |
891 | [](ViewOp* view) { |
892 | return std::any_of( |
893 | view->outputs().begin(), view->outputs().end(), [](Val* v) { |
894 | if (!v->isA<TensorView>()) { |
895 | return false; |
896 | } |
897 | return v->as<TensorView>()->hasRFactor(); |
898 | }); |
899 | }); |
900 | |
901 | return view_ops; |
902 | } |
903 | |
904 | namespace { |
905 | |
906 | struct ReplaceValInIndexVal : public OptInDispatch { |
907 | public: |
  //! Apply replacements to index as specified in replacement_map.
  //! index is assumed to consist only of Int, NamedScalar and
  //! kir::IntPair vals.
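  //! Illustrative example (names are hypothetical): with
  //! replacement_map == {{tidx, zero}}, an index expression such as
  //! `i * 4 + tidx` is rebuilt as `i * 4 + 0`, cloning the mul and add
  //! expressions while leaving the original index expression untouched.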
911 | static Val* replace( |
912 | Val* index, |
913 | const std::unordered_map<Val*, Val*>& replacement_map) { |
914 | ReplaceValInIndexVal replace_index_val(replacement_map); |
915 | replace_index_val.handle(index); |
    // Return the replacement if one was applied; otherwise return the
    // original index
917 | if (replace_index_val.is_replaced_) { |
918 | return replace_index_val.last_visited_val_; |
919 | } else { |
920 | return index; |
921 | } |
922 | } |
923 | |
924 | private: |
925 | ReplaceValInIndexVal(const std::unordered_map<Val*, Val*>& replacement_map) |
926 | : replacement_map_(replacement_map) {} |
927 | |
928 | using OptOutDispatch::handle; |
929 | |
930 | void handle(Val* val) override { |
931 | TORCH_INTERNAL_ASSERT( |
932 | val->isA<Int>() || val->isA<NamedScalar>() || val->isA<kir::IntPair>(), |
        "Invalid Val type: ",
934 | val->toString()); |
935 | |
    // If val appears in the replacement map, stop traversing and use the
    // replacement as the current val
938 | auto it = replacement_map_.find(val); |
939 | if (it != replacement_map_.end()) { |
940 | last_visited_val_ = it->second; |
941 | is_replaced_ = true; |
942 | return; |
943 | } |
944 | |
945 | // Recursively traverse its defining expr |
946 | auto def = val->definition(); |
947 | if (def != nullptr) { |
948 | switch (def->etype()) { |
949 | case ExprType::UnaryOp: |
950 | case ExprType::BinaryOp: |
951 | case ExprType::Swizzle2DInt: |
952 | case ExprType::PairSelect: |
953 | handle(val->definition()); |
954 | break; |
955 | default: |
          TORCH_INTERNAL_ASSERT(
              false, "Unexpected definition: ", def->toString());
958 | } |
959 | // last_visited_val_ is set in the expr handlers |
960 | } else { |
961 | last_visited_val_ = val; |
962 | } |
963 | } |
964 | |
  // Clone expression after recursively replacing inputs
966 | void handle(UnaryOp* uop) override { |
967 | handle(uop->in()); |
968 | auto inp = last_visited_val_; |
969 | TORCH_INTERNAL_ASSERT(uop->out()->isA<Int>()); |
970 | auto out = IrBuilder::create<Int>(c10::nullopt); |
971 | IrBuilder::create<UnaryOp>(uop->getUnaryOpType(), out, inp); |
972 | last_visited_val_ = out; |
973 | } |
974 | |
  // Clone expression after recursively replacing inputs
976 | void handle(BinaryOp* bop) override { |
977 | handle(bop->lhs()); |
978 | auto lhs = last_visited_val_; |
979 | handle(bop->rhs()); |
980 | auto rhs = last_visited_val_; |
981 | TORCH_INTERNAL_ASSERT(bop->out()->isA<Int>()); |
982 | auto out = IrBuilder::create<Int>(c10::nullopt); |
983 | IrBuilder::create<BinaryOp>(bop->getBinaryOpType(), out, lhs, rhs); |
984 | last_visited_val_ = out; |
985 | } |
986 | |
  // Clone expression after recursively replacing inputs
988 | void handle(kir::Swizzle2DInt* swizzle_2d) override { |
989 | handle(swizzle_2d->inX()); |
990 | auto in_x = last_visited_val_; |
991 | handle(swizzle_2d->inY()); |
992 | auto in_y = last_visited_val_; |
993 | auto out = IrBuilder::create<kir::IntPair>(); |
994 | |
995 | // Extents are assumed constant in swizzle so no need to |
996 | // duplicate their graphs. |
997 | IrBuilder::create<kir::Swizzle2DInt>( |
998 | out, |
999 | in_x, |
1000 | in_y, |
1001 | swizzle_2d->extentX(), |
1002 | swizzle_2d->extentY(), |
1003 | swizzle_2d->swizzleType()); |
1004 | last_visited_val_ = out; |
1005 | } |
1006 | |
1007 | void handle(kir::PairSelect* pair_select) override { |
1008 | handle(pair_select->in()->asVal()); |
1009 | auto in = last_visited_val_; |
1010 | TORCH_INTERNAL_ASSERT(pair_select->out()->isA<Int>()); |
1011 | auto out = IrBuilder::create<Int>(c10::nullopt); |
1012 | IrBuilder::create<kir::PairSelect>( |
1013 | out, in->as<kir::IntPair>(), pair_select->selection()); |
1014 | last_visited_val_ = out; |
1015 | } |
1016 | |
1017 | private: |
1018 | const std::unordered_map<Val*, Val*>& replacement_map_; |
1019 | Val* last_visited_val_ = nullptr; |
1020 | bool is_replaced_ = false; |
1021 | }; |
1022 | |
1023 | } // namespace |
1024 | |
1025 | Val* replaceValInIndexVal( |
1026 | Val* index, |
1027 | const std::unordered_map<Val*, Val*>& replacement_map) { |
1028 | return ReplaceValInIndexVal::replace(index, replacement_map); |
1029 | } |
1030 | |
1031 | } // namespace ir_utils |
1032 | } // namespace cuda |
1033 | } // namespace fuser |
1034 | } // namespace jit |
1035 | } // namespace torch |
1036 | |