#include <lower_predicate_elimination.h>

#include <arith.h>
#include <expr_evaluator.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <lower2device.h>
#include <lower_shift.h>
#include <lower_utils.h>
#include <predicate_compute.h>
#include <transform_iter.h>
#include <transform_replay.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

// Warp primitives are currently limited to un-predicated usage;
// predicating these ops would require extra steps to ensure that
// the whole warp gets the same value.
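// For example, ldmatrix and mma are executed collectively by all
// threads of a warp, so masking off a subset of the threads with a
// predicate would leave some lanes with undefined register contents.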
void assertOnWarpOps(const Expr* expr) {
  TORCH_INTERNAL_ASSERT(
      !ir_utils::isLdMatrixOp(expr),
      "Predicate elimination: cannot eliminate pred for ldmatrix, use exact parallel dims",
      expr->toString());
  TORCH_INTERNAL_ASSERT(
      !expr->isA<MmaOp>(),
      "Mma op: cannot eliminate predicate for mma op, tiling not valid. ",
      expr->toString());
}

} // namespace

namespace {

// Utility to check if the scheduled domain of the given
// TensorView represents an exact shared mem access, meaning
// that all the thread parallel dimensions on the leaf nodes
// are exact so that the shared mem read/write would not
// run out of bounds because of thread over-subscription.
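// For example, if the block is launched with blockDim.x = 128 to
// satisfy another expression, but the TIDx axis of this tensor only
// has extent 100, threads 100-127 would read or write shared memory
// out of bounds unless the access is predicated.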
bool isExactParallelSharedMemAccess(TensorView* tv) {
  auto& parallel_dimension_map = GpuLower::current()->parallelDimensionMap();
  for (auto id : tv->domain()->domain()) {
    if (id->isThreadDim()) {
      auto ptype = id->getParallelType();
      // Need to predicate to avoid out of bound access
      // because of over-subscribed block size.
      if (!parallel_dimension_map.isExact(ptype)) {
        return false;
      }
    }
  }
  return true;
}

class PredicateAnalyzer : public OptOutDispatch {
 public:
  //! Checks if a predicate is needed to avoid out-of-bound accesses.
  //!
  //! Due to the way we allocate local-memory tensors, there should
  //! never be out-of-bound accesses with consumer tensors when allocated on
  //! local memory. However, accessing producer tensors still may
  //! result in out-of-bound accesses as they are replayed as consumers.
  static bool needsPredicate(TensorView* producer, TensorView* consumer) {
    // Both tensors must be on local or shared memory. Global tensors must be
    // predicated as allocation is done based on root domains. Smem
    // and local tensors are allocated based on leaf domains.
    // However, when a smem tensor is parallelized, which is highly likely,
    // the size of the parallelized axis is the actual size of the axis,
    // not the number of threads. This is currently actively checked to avoid
    // out-of-bound shared mem accesses by out-of-bound threads.
    if (producer->getMemoryType() == MemoryType::Global ||
        consumer->getMemoryType() == MemoryType::Global) {
      return true;
    }

    auto pairwise_map = PairwiseRootDomainMap(producer, consumer);
    auto c2p =
        BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map)
            .getReplay();

    PredicateAnalyzer analyzer(c2p);

    for (auto id : consumer->domain()->domain()) {
      if (analyzer.needsPredicate(id)) {
        return true;
      }
    }

    return false;
  }

 private:
  PredicateAnalyzer(const std::unordered_map<IterDomain*, IterDomain*>& c2p_map)
      : c2p_map_(c2p_map) {}

  // Returns true if out-of-bound accesses could occur with the
  // producer, i.e., the consumer domain needs a predicate
  bool needsPredicate(IterDomain* consumer_id) {
    needs_predicate_ = false;
    handle(consumer_id);
    return needs_predicate_;
  }

  void handle(IterDomain* consumer_id) override {
    // The traversal should have ended if needs_predicate_ was true
    TORCH_INTERNAL_ASSERT(!needs_predicate_);

    // If consumer_id is not going to be materialized as a loop (e.g.,
    // broadcast), no need to predicate
    if (consumer_id->isBroadcast() ||
        GpuLower::current()->trivialReductionInfo().isDerived(consumer_id)) {
      return;
    }

    // If the producer has a matching domain, it should not cause
    // out-of-bound accesses
    if (c2p_map_.find(consumer_id) != c2p_map_.end()) {
      return;
    }

    // If no definition exists, stop traversing
    if (consumer_id->definition() == nullptr) {
      return;
    }

    OptOutDispatch::handle(consumer_id->definition());
  }

  // If it splits the input axis evenly, proceed to check the input
  // axis. Otherwise, we can't skip predication as it might cause
  // out-of-bound accesses with the producer tensor.
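  // For example, splitting an axis of extent 16 by a factor of 4 is
  // exact, whereas splitting an extent of 10 by 4 produces
  // ceilDiv(10, 4) * 4 = 12 iterations, so indices 10 and 11 would
  // fall outside the producer's allocation without a predicate.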
  void handle(Split* split) override {
    auto factor = split->factor()->getInt();
    if (!factor.has_value()) {
      needs_predicate_ = true;
      return;
    }

    ExpressionEvaluator ee(split->fusion());
    const auto in_extent = ee.evaluate(split->in()->extent());

    if (!in_extent.has_value() || ((in_extent.value() % factor.value()) != 0)) {
      needs_predicate_ = true;
      return;
    }

    handle(split->in());
  }

  void handle(Merge* merge) override {
    handle(merge->inner());
    if (needs_predicate_) {
      return;
    }
    handle(merge->outer());
  }

 private:
  //! BestEffort map from consumer IDs to producer IDs
  const std::unordered_map<IterDomain*, IterDomain*>& c2p_map_;
  bool needs_predicate_ = false;
};

class PredicateChecker : public IterVisitor {
 public:
  static bool needsPredicate(
      Expr* expr,
      const std::unordered_set<const Expr*>& non_predicated_exprs) {
    if (!ir_utils::isTvOp(expr)) {
      return false;
    }

    PredicateChecker checker(non_predicated_exprs);
    checker.handle(expr);
    return checker.needs_predicate_;
  }

 private:
  PredicateChecker(const std::unordered_set<const Expr*>& non_predicated_exprs)
      : non_predicated_exprs_(non_predicated_exprs) {}

  using IterVisitor::handle;

  void handle(Expr* expr) final {
    needs_predicate_ = predicateIntDiv(expr) ||
        predicateMisalignedVectorize(expr) || predicateShift(expr) ||
        predicateSharedMemAccess(expr) || predicateProducerConsumerPair(expr) ||
        predicateNonDivisibleRootDomains(expr) ||
        predicateNonDivisibleSplit(expr) || predicateExpandReduce(expr);

    // A cp.async op would need a predicate for either the global
    // input or its shared mem output, or both.
    // Due to the WAR discussed in [Predicate Inversion for CpAsync],
    // we currently cannot support use cases where both the gmem read
    // and the smem write need to be predicated.
    // The check here makes the exclusion of such cases as precise as
    // possible and avoids duplicating the predicateSharedMemAccess
    // logic. This part, along with [Predicate Inversion for CpAsync],
    // should be cleaned up all together when we extend predicate/masking
    // logic to cover this usage.
    TORCH_INTERNAL_ASSERT(
        !(ir_utils::isCpAsyncOp(expr) && predicateSharedMemAccess(expr)),
        "predicate removal: unsupported use case of cp.async");

    if (needs_predicate_) {
      return;
    }

    // Check ExprType-specific conditions
    IterVisitor::handle(expr);
  }

  // All "predicateXYZ" functions return true if an expr needs to be
  // predicated.

  // Always predicate integer division and related ops as we don't
  // know what values are in the out-of-bound region and they may
  // cause exceptions
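  // For example, with t2 = t1 % t0, the out-of-bound region of t0 may
  // happen to hold zero, and what an integer division by zero does is
  // platform-dependent, so these ops are always predicated.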
  bool predicateIntDiv(Expr* expr) const {
    auto dt = expr->outputs()[0]->getDataType().value();
    return (
        (dt == DataType::Int || dt == DataType::Int32) &&
        expr->isA<BinaryOp>() &&
        (expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Div ||
         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Mod ||
         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::Remainder ||
         expr->as<BinaryOp>()->getBinaryOpType() == BinaryOpType::CeilDiv));
  }

  // If we're reducing an expanded domain, we need to be careful to predicate
  // it or we could end up reducing a broadcasted value too many times.
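  // For example, a broadcast domain expanded to extent 8 is backed by
  // a single element, so an unpredicated reduction loop that runs past
  // the expanded extent would re-read that same valid value and
  // accumulate it extra times, silently corrupting the result.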
  bool predicateExpandReduce(Expr* expr) const {
    if (!ir_utils::isReductionOp(expr)) {
      return false;
    }
    auto tv_inputs = ir_utils::getTvs(expr->inputs());
    TORCH_INTERNAL_ASSERT(
        tv_inputs.size() > 0,
        "Should never have a reduction op without a tensor view input.");
    bool found_expand = false;
    for (auto tv_input : tv_inputs) {
      found_expand |= std::any_of(
          tv_input->getMaybeRFactorDomain().begin(),
          tv_input->getMaybeRFactorDomain().end(),
          [](IterDomain* id) { return id->hasExpandedExtent(); });
    }

    if (!found_expand) {
      return false;
    }

    auto tv_outputs = ir_utils::getTvs(expr->outputs());
    if (expr->isA<WelfordOp>() && tv_inputs.size() != tv_outputs.size()) {
      tv_outputs = std::vector<TensorView*>(tv_inputs.size(), tv_outputs[0]);
    }

    TORCH_INTERNAL_ASSERT(
        tv_outputs.size() == tv_inputs.size(),
        "Was expecting matching number of inputs and outputs for expression: ",
        expr->toString());

    for (auto i : c10::irange(tv_inputs.size())) {
      const auto root_p2c =
          PairwiseRootDomainMap(tv_inputs[i], tv_outputs[i])
              .mapProducerToConsumer(
                  tv_inputs[i]->domain(), tv_outputs[i]->domain());
      for (auto entry : root_p2c) {
        auto p_id = entry.first;
        auto c_id = entry.second;
        if (p_id->hasExpandedExtent() && c_id->isReduction()) {
          return true;
        }
      }
    }
    return false;
  }

  // Skip if MisalignedVectorize is involved for now. This could be
  // relaxed.
  bool predicateMisalignedVectorize(Expr* expr) const {
    std::vector<const std::vector<Val*>*> inputs_and_outputs = {
        &(expr->inputs()), &(expr->outputs())};
    for (const auto& inputs_or_outputs : inputs_and_outputs) {
      for (auto tv : ir_utils::filterByType<TensorView>(*inputs_or_outputs)) {
        if (std::any_of(
                tv->domain()->domain().begin(),
                tv->domain()->domain().end(),
                [](IterDomain* axis) {
                  return axis->getParallelType() ==
                      ParallelType::MisalignedVectorize;
                })) {
          return true;
        }
      }
    }
    return false;
  }

  // Shift is not supported yet.
  bool predicateShift(Expr* expr) const {
    auto halo_info = GpuLower::current()->haloInfo();
    auto input_tvs = ir_utils::filterByType<TensorView>(expr->inputs());
    return halo_info->needsShiftPredicate(expr) ||
        std::any_of(input_tvs.begin(), input_tvs.end(), [&](auto input_tv) {
          return input_tv->definition() != nullptr &&
              halo_info->needsShiftPredicate(input_tv->definition());
        });
  }

  // Predicates the expression if any producer-consumer pair of the
  // expression needs to be predicated
  bool predicateProducerConsumerPair(Expr* expr) const {
    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
      for (auto input : ir_utils::filterByType<TensorView>(expr->inputs())) {
        if (PredicateAnalyzer::needsPredicate(input, output)) {
          return true;
        }
      }
    }
    return false;
  }

  bool predicateSharedMemAccess(Expr* expr) const {
    // This is an initial step to gradually remove predicates around
    // shared mem access in suitable situations.
    // An additional variable could be used to track the reasons the
    // predicate around shared mem cannot be removed.
    for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
      for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
        if (producer->getMemoryType() == MemoryType::Shared ||
            consumer->getMemoryType() == MemoryType::Shared) {
          if (needSharedMemPredicate(producer, consumer)) {
            return true;
          }
        }
      }
    }

    return false;
  }

  // Check for conditions where the predicate cannot be removed
  // when either producer or consumer is in shared memory.
  bool needSharedMemPredicate(TensorView* producer, TensorView* consumer)
      const {
    // Indexing is based on consumer leaf ids so check the consumer.

    // If the consumer schedule contains inexact thread parallel
    // dimensions, we need to predicate against out-of-bound
    // shared memory accesses by out-of-bound threads.
    if (!isExactParallelSharedMemAccess(consumer)) {
      return true;
    }

    // TODO: This is a directed WAR on FusionPersistentNormLocalShared.
    // This use case, along with other previous issues, motivates a
    // joint optimization of predicate removal and buffer reuse.
    // In this particular case:
    //   __shared__ T0[10], T1[10]
    //   for i in ...
    //     if (pred)
    //       T1[i] = T0[i] + ... // exp0
    //   T2 = 0; // init for exp1
    //   if (pred)
    //     T2 = T1 ... // exp1
    // If we remove the pred around exp1, as the way the pred removal
    // pass is set up, the init for exp1 will be pushed up to
    // initialize T1 instead.
    // However, if we initialize T1, the code will look like:
    //   for i in ...
    //     T1[i] = 0;
    //   for i in ...
    //     if (pred)
    //       T1[i] = T0[i] + ...
    // Note that we'd be able to reuse the buffer of T0 for T1, but
    // if we initialize T1 we cannot do that, and thus the
    // kernel would not fit on smaller devices.
    if (producer->getMemoryType() == MemoryType::Shared) {
      if (auto producer_def = producer->definition()) {
        if (std::any_of(
                producer_def->inputs().begin(),
                producer_def->inputs().end(),
                [](Val* val) {
                  if (auto tv = ir_utils::getTv(val)) {
                    return tv->getMemoryType() == MemoryType::Shared;
                  }
                  return false;
                })) {
          // Disable predicate removal for shared memory producers that
          // are consumers of another shared memory tensor. The
          // initialization would break a potential opportunity to
          // reuse the shared mem buffer.
          return true;
        }
      }
    }

    for (auto id : consumer->domain()->domain()) {
      // TODO: (Enable in a follow up)
      // Smem predicate removal with init would break unroll and unswitch,
      // e.g. as in issue 1133, so this removal pattern is disabled for now.
      if (id->getParallelType() == ParallelType::Unroll ||
          id->getParallelType() == ParallelType::Unswitch) {
        return true;
      }

      // TODO: (Enable in a follow up)
      // This cannot yet be removed since smem initialization needs to be
      // handled specially, e.g. as in the smem_reduce test. Will be able to
      // lift this one once the generic pred removal pass with fusion
      // traversal is ready.
      auto consumer_def = consumer->definition();
      if (ir_utils::isReductionOp(consumer_def)) {
        if (producer->getMemoryType() == MemoryType::Shared) {
          return true;
        }
      }
    }

    return false;
  }

  // Utility to find the leaf iterdomains of the given
  // tensor view that will be treated as "zero loops"
  // in the indexing pass.
  // For details on zero loops, see indexMapFromTV in
  // the lower index pass.
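  // For example, for a Local tensor with leaf domain [iS0, TIDx, iS2]
  // and compute-at position 1, iS0 (left of the CA axes) and TIDx
  // (thread-parallel on local memory) are zero loops, while iS2 is
  // indexed normally.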
  std::vector<Val*> getZeroLeafIds(const TensorView* tv) const {
    TORCH_INTERNAL_ASSERT(
        tv->getMemoryType() == MemoryType::Local ||
            tv->getMemoryType() == MemoryType::Shared,
        "Local or shared memory tensor is assumed: ",
        tv->toString());
    bool is_shared_mem = tv->getMemoryType() == MemoryType::Shared;
    std::vector<Val*> zero_leaf_ids;
    for (const auto i : c10::irange(tv->nDims())) {
      auto leaf_id = tv->axis(i);
      if (is_shared_mem && leaf_id->isThreadDim()) {
        // Thread parallel axes on shared mem are never
        // zero loops as each thread owns its share
        // of the shared mem space.
        continue;
      }
      if (
          // Non-thread parallel dimensions on the left
          // of the CA axes are zero loops.
          i < tv->getComputeAtPosition() ||
          // Parallel axes on local mem are zero loops.
          // Grid axes on shared mem are zero loops.
          leaf_id->isThread() ||
          // Mma axes, similar to vectorization, are
          // implicit in hardware intrinsics, and thus
          // will be treated as zero loops.
          leaf_id->isMma()) {
        zero_leaf_ids.push_back(leaf_id);
      }
    }

    return zero_leaf_ids;
  }

  // An index can exceed the logical extent of the indexed domain if
  // it's split. It can cause a reduction op to reduce the same value
  // multiple times. Even a pointwise op can be a problem if the
  // consumer is an alias of the producer. This check excludes such
  // expressions from predicate elimination.
  //
  // This is not an issue if the index includes a zero domain (as defined in
  // index_compute.cpp): the extent is then calculated by multiplying the
  // split output domains, so it never crosses the domain boundary.
  // So, if a root domain is split and none of its descendants is a
  // zero domain, the expr needs to be predicated. See
  // FusionPredicateElimination6 for a concrete example.
  //
  // It would also be possible to avoid register aliasing instead of
  // giving up predicate elimination. Since this condition should be
  // rather uncommon, either would be fine as long as correctness is
  // provided.
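  // For example (cf. FusionPredicateElimination6), with a root domain
  // of extent 10 split by a factor of 3, the leaf loops cover
  // ceilDiv(10, 3) * 3 = 12 indices, two of which exceed the root
  // extent, so the expr must be predicated if none of the split
  // outputs is a zero domain.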
  bool predicateNonDivisibleRootDomains(Expr* expr) const {
    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
      const auto all_exprs = DependencyCheck::getAllExprsBetween(
          {output->getMaybeRFactorDomain().begin(),
           output->getMaybeRFactorDomain().end()},
          {output->domain()->domain().begin(),
           output->domain()->domain().end()});
      std::unordered_set<Val*> split_root;
      std::copy_if(
          output->getMaybeRFactorDomain().begin(),
          output->getMaybeRFactorDomain().end(),
          std::inserter(split_root, split_root.end()),
          [&](auto rf_root) {
            if (rf_root->isBroadcast() ||
                GpuLower::current()->trivialReductionInfo().isDerived(
                    rf_root)) {
              return false;
            }
            for (Expr* use : rf_root->uses()) {
              if (std::find(all_exprs.begin(), all_exprs.end(), use) ==
                  all_exprs.end()) {
                continue;
              }
              return use->isA<Split>();
            }
            return false;
          });
      // If no root domain is split, no need to predicate
      if (split_root.empty()) {
        continue;
      }
      const auto zero_leaf_ids = getZeroLeafIds(output);
      if (zero_leaf_ids.empty()) {
        return true;
      }
      const auto vals =
          DependencyCheck::getAllValsBetween(split_root, zero_leaf_ids);
      if (std::any_of(
              split_root.begin(),
              split_root.end(),
              [&vals](auto split_root_id) {
                return std::find(vals.begin(), vals.end(), split_root_id) ==
                    vals.end();
              })) {
        return true;
      }
    }
    return false;
  }

  // Always predicate if a non-divisible split is found. It may be
  // possible to make this less conservative.
  // See FusionPredicateElimination7 for a concrete example.
  bool predicateNonDivisibleSplit(Expr* expr) const {
    const auto& non_divisible_split_info =
        GpuLower::current()->nonDivisibleSplitInfo();
    for (auto output : ir_utils::filterByType<TensorView>(expr->outputs())) {
      if (non_divisible_split_info.splitsToPredicate().find(output) !=
          non_divisible_split_info.splitsToPredicate().end()) {
        return true;
      }
    }
    return false;
  }

  // If this is a reduction, and if we omit the predicate for the
  // input, the input may have a garbage value, which must not be used
  // for this reduction. However, it is still legal to omit its
  // predicate when: 1) the predicate of the input is not omitted and
  // 2) the input can be initialized to the init value of this
  // reduction. When the input is the output of another reduction, the
  // input is initialized to the init value of that reduction, so the
  // two reductions must use the same init value.
  // See FusionPredicateElimination3 and FusionPredicateElimination4
  // for concrete examples.
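  // For example, if t1 = sum(t0) and t2 = sum(t1) both use 0 as the
  // init value, the predicate of t2 can be omitted as long as t1 is
  // either predicated or fully initialized to 0. If t2 were instead a
  // max reduction with init -inf, the init values would differ and the
  // predicate must stay.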
  void handle(ReductionOp* rop) final {
    auto input = rop->inputs()[0]->as<TensorView>();
    auto input_def = input->definition();
    // When input_def is null, input must be an input to the fusion,
    // so it must be allocated on global memory. Since we don't omit
    // predication for expressions involving global memory, this
    // should never occur.
    TORCH_INTERNAL_ASSERT(
        input_def != nullptr, "Inconsistent input found: ", input->toString());

    // The input needs to be initialized to the init value to omit
    // the predicate, so if the input has its own init value, i.e.,
    // produced by another reduction, they must use the same init
    // value.
    Val* input_init = ir_utils::getReductionInitValOf(input);
    if (input_init != nullptr && !rop->init()->sameAs(input_init)) {
      needs_predicate_ = true;
      return;
    }

    // If the input is not predicated, the out-of-bound value may be
    // overwritten by a garbage value. However, it doesn't matter if
    // the input is also produced by another reduction. If the preceding
    // reduction omits the predicate, it means its input must be
    // initialized to its init value, so no predicate should be
    // needed in either of the two reduction ops if they use the same
    // init value, which is guaranteed by the above check, and the
    // same reduction op.
    if (auto input_def_rop = dynamic_cast<ReductionOp*>(input_def)) {
      if (rop->getReductionOpType() != input_def_rop->getReductionOpType() &&
          non_predicated_exprs_.find(input_def) !=
              non_predicated_exprs_.end()) {
        needs_predicate_ = true;
        return;
      }
    } else if (
        non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) {
      needs_predicate_ = true;
      return;
    }
  }

  // Welford. See FusionPredicateElimination5.
  void handle(WelfordOp* wop) final {
    for (const auto i : c10::irange(3)) {
      auto init = wop->getInitVals()[i];

      // Welford input can be a scalar. A predicate is required unless
      // the scalar value is equal to the init value.
      auto input = wop->inputs().at(i);
      if (input->isScalar()) {
        if (!input->sameAs(init)) {
          needs_predicate_ = true;
          return;
        }
        continue;
      }

      auto input_tv = dynamic_cast<TensorView*>(input);
      TORCH_INTERNAL_ASSERT(input_tv != nullptr);

      auto input_def = input->definition();

      // When input_def is null, input must be an input to the fusion,
      // so it must be allocated on global memory. Since we don't omit
      // predication for expressions involving global memory, this
      // should never occur.
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr,
          "Inconsistent input found: ",
          input->toString());

      // The input needs to be initialized to the init value to omit
      // the predicate, so if the input has its own init value, i.e.,
      // produced by another reduction, they must use the same init
      // value.
      Val* input_init = ir_utils::getReductionInitValOf(input_tv);
      if (input_init != nullptr && !init->sameAs(input_init)) {
        needs_predicate_ = true;
        return;
      }

      // If the input is not predicated, the out-of-bound value may be
      // overwritten by a garbage value. However, it doesn't matter if
      // the input is also produced by another Welford op.
      if (!input_def->isA<WelfordOp>() && !input_def->isA<GroupedWelfordOp>() &&
          non_predicated_exprs_.find(input_def) !=
              non_predicated_exprs_.end()) {
        needs_predicate_ = true;
        return;
      }
    }
  }

  void handle(GroupedReductionOp* grouped_rop) final {
    for (const auto i : c10::irange(grouped_rop->numExprs())) {
      auto input = grouped_rop->input(i)->as<TensorView>();
      auto input_def = input->definition();
      // When input_def is null, input must be an input to the fusion,
      // so it must be allocated on global memory. Since we don't omit
      // predication for expressions involving global memory, this
      // should never occur.
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr,
          "Inconsistent input found: ",
          input->toString());

      // The input needs to be initialized to the init value to omit
      // the predicate, so if the input has its own init value, i.e.,
      // produced by another reduction, they must use the same init
      // value.
      Val* input_init = ir_utils::getReductionInitValOf(input);
      if (input_init != nullptr &&
          !grouped_rop->initVal(i)->sameAs(input_init)) {
        needs_predicate_ = true;
        return;
      }

      // If the input is not predicated, the out-of-bound value may be
      // overwritten by a garbage value. However, it doesn't matter if
      // the input is also produced by another reduction. If the preceding
      // reduction omits the predicate, it means its input must be
      // initialized to its init value, so no predicate should be
      // needed in either of the two reduction ops if they use the same
      // init value, which is guaranteed by the above check, and the
      // same reduction op.
      if (auto input_def_rop = dynamic_cast<ReductionOp*>(input_def)) {
        if (grouped_rop->getReductionOpType(i) !=
                input_def_rop->getReductionOpType() &&
            non_predicated_exprs_.find(input_def) !=
                non_predicated_exprs_.end()) {
          needs_predicate_ = true;
          return;
        }
      } else if (
          auto input_def_grouped_rop =
              dynamic_cast<GroupedReductionOp*>(input_def)) {
        auto input_index_as_output =
            input_def_grouped_rop->getExprIndexOfOutput(input);
        if (grouped_rop->getReductionOpType(i) !=
                input_def_grouped_rop->getReductionOpType(
                    input_index_as_output) &&
            non_predicated_exprs_.find(input_def) !=
                non_predicated_exprs_.end()) {
          needs_predicate_ = true;
          return;
        }
      } else if (
          non_predicated_exprs_.find(input_def) !=
          non_predicated_exprs_.end()) {
        needs_predicate_ = true;
        return;
      }
    }
  }

  void handle(GroupedWelfordOp* grouped_wop) final {
    for (const auto expr_idx : c10::irange(grouped_wop->numExprs())) {
      for (const auto val_idx : c10::irange(3)) {
        auto init = grouped_wop->initVals().at(expr_idx).get(val_idx);

        // Welford input can be a scalar. A predicate is required unless
        // the scalar value is equal to the init value.
        auto input = grouped_wop->inputVals().at(expr_idx).get(val_idx);
        if (input->isScalar()) {
          if (!input->sameAs(init)) {
            needs_predicate_ = true;
            return;
          }
          continue;
        }

        auto input_tv = dynamic_cast<TensorView*>(input);
        TORCH_INTERNAL_ASSERT(input_tv != nullptr);

        auto input_def = input->definition();

        // When input_def is null, input must be an input to the fusion,
        // so it must be allocated on global memory. Since we don't omit
        // predication for expressions involving global memory, this
        // should never occur.
        TORCH_INTERNAL_ASSERT(
            input_def != nullptr,
            "Inconsistent input found: ",
            input->toString());

        // The input needs to be initialized to the init value to omit
        // the predicate, so if the input has its own init value, i.e.,
        // produced by another reduction, they must use the same init
        // value.
        Val* input_init = ir_utils::getReductionInitValOf(input_tv);
        if (input_init != nullptr && !init->sameAs(input_init)) {
          needs_predicate_ = true;
          return;
        }

        // If the input is not predicated, the out-of-bound value may be
        // overwritten by a garbage value. However, it doesn't matter if
        // the input is also produced by another reduction op, as it
        // must be initialized and its initialized value has already
        // been found to be equal to the init value of this op.
        if (!input_def->isA<WelfordOp>() &&
            !input_def->isA<GroupedWelfordOp>() &&
            non_predicated_exprs_.find(input_def) !=
                non_predicated_exprs_.end()) {
          needs_predicate_ = true;
          return;
        }
      }
    }
  }

  // Similar to the above reduction constraint but for MMA
  void handle(MmaOp* mma) final {
    for (auto input : ir_utils::filterByType<TensorView>(mma->inputs())) {
      auto input_def = input->definition();
      TORCH_INTERNAL_ASSERT(
          input_def != nullptr,
          "Inconsistent input found: ",
          input->toString());

      Val* input_init = ir_utils::getReductionInitValOf(input);
      if (input_init != nullptr && !mma->init()->sameAs(input_init)) {
        needs_predicate_ = true;
        return;
      }

      if (non_predicated_exprs_.find(input_def) !=
          non_predicated_exprs_.end()) {
        // If the producer of the mma is non-predicated and initialized
        // with the same value, the mma should not need a
        // predicate. In fact this is the only way we can
        // use mma at the moment, since we could not predicate
        // mma ops without guaranteeing warp-uniform results.
        auto input_init =
            GpuLower::current()->predicateElimination().getInitValue(input);

        // TODO:
        // Clean this up to support more generic prolog fusion.
        // Will need additional analysis passes on initialization
        // propagation and further predicate placement on top.
        // More TODO:
        // Even when the producer is initialized, it is still generally
        // not safe to remove the predicate around reduction ops if the
        // producer is not predicated.
        // On the other side, we do have patterns like ldmatrix->mma where
        // neither producer nor consumer can be safely predicated without
        // guaranteeing warp-uniform results.
        // This is currently a WAR and relies on the validation pass to
        // exclude complex prolog patterns in mma-based matmul kernels. Will
        // definitely need to revisit and build out a predicate and
        // initialization analysis pass to better handle this case.
        if (input_init != nullptr && !input_init->sameAs(mma->init())) {
          // This is a WAR at the moment. We would need to propagate
          // initialization information from the PredicateElimination
          // pass to most accurately detect if the input is
          // initialized correctly.
          // This could also be fixed when we have the traversal-based
          // predicate elimination and initialization pass ready. It
          // would be easy to clean up this part at that point.
          needs_predicate_ = true;
          return;
        }
      }
    }
  }

 private:
  const std::unordered_set<const Expr*>& non_predicated_exprs_;
  bool needs_predicate_ = false;
};

} // namespace

bool PredicateElimination::needsPredicate(Expr* expr) const {
  return PredicateChecker::needsPredicate(expr, non_predicated_exprs_);
}

void PredicateElimination::handle(Expr* expr) {
  if (!ir_utils::isTvOp(expr)) {
    return;
  }

  if (needsPredicate(expr)) {
    assertOnWarpOps(expr);
    return;
  }

  non_predicated_exprs_.insert(expr);

  // Ensure all inputs have some values set at the out-of-bound
  // regions
  for (const auto i : c10::irange(expr->inputs().size())) {
    auto input = dynamic_cast<TensorView*>(expr->inputs()[i]);
    if (input == nullptr) {
      continue;
    }
    auto input_def = input->definition();
    // When input_def is null, input must be an input to the fusion,
    // so it must be allocated on global memory. Since we don't omit
    // predication for expressions involving global memory, this
    // should never occur.
    TORCH_INTERNAL_ASSERT(
        input_def != nullptr, "Inconsistent input found: ", input->toString());

    // If input is an output of a reduction, it should be fully
    // initialized as it's allocated on local memory.
    if (ir_utils::isReductionOp(input_def)) {
      continue;
    }

    if (expr->isA<ReductionOp>()) {
      setReductionInitValue(input, expr->as<ReductionOp>()->init());
      continue;
    } else if (expr->isA<GroupedReductionOp>()) {
      setReductionInitValue(input, expr->as<GroupedReductionOp>()->initVal(i));
      continue;
    } else if (auto wop = dynamic_cast<WelfordOp*>(expr)) {
      Val* init = wop->getInitVals().at(i);
      setReductionInitValue(input, init);
      continue;
    } else if (expr->isA<MmaOp>()) {
      setReductionInitValue(input, expr->as<MmaOp>()->init());
      continue;
    } else if (
        non_predicated_exprs_.find(input_def) != non_predicated_exprs_.end()) {
      // If an input does not need a predicate either, then it should
      // have some value, so no need to set a default value
      continue;
    } else {
      // Make sure input is initialized
      setDefaultInitValue(input);
    }
  }
}

bool PredicateElimination::setDefaultInitValue(TensorView* tv) {
  auto it = init_value_map_.find(tv);
  // If there's already a mapping for tv, it should be mapped to a
  // zero val or a reduction init. In either case, no need to modify
  // the existing mapping.
  if (it == init_value_map_.end()) {
    init_value_map_.insert({tv, nullptr});
  }
  return true;
}

bool PredicateElimination::setReductionInitValue(
    TensorView* tv,
    Val* reduction_init) {
  TORCH_INTERNAL_ASSERT(tv != nullptr);

  auto it = init_value_map_.find(tv);
  if (it == init_value_map_.end()) {
    init_value_map_.insert({tv, reduction_init});
    return true;
  }

  auto existing_val = it->second;
  if (existing_val == nullptr) {
    // If the existing mapping returns nullptr, it means that a
    // default init was set before. Overwrite with the reduction
    // init val.
    init_value_map_[tv] = reduction_init;
    return true;
  } else if (existing_val->sameAs(reduction_init)) {
    return true;
  } else {
    TORCH_INTERNAL_ASSERT(
        false,
        "Inconsistent setting of initialization value for T",
        tv->name(),
        ". Prev: ",
        existing_val->toString(),
        ", New: ",
        reduction_init->toString());
    return false;
  }
}

bool PredicateElimination::canOmitPredicate(const Expr* expr) const {
  // Predicate elimination can be disabled with
  // PYTORCH_NVFUSER_DISABLE=predicate_elimination
  if (isOptionDisabled(DisableOption::PredicateElimination)) {
    assertOnWarpOps(expr);
    return false;
  }

  TORCH_INTERNAL_ASSERT(expr != nullptr);
  const auto out_tv = ir_utils::getTvOutput(expr);
  TORCH_INTERNAL_ASSERT(out_tv != nullptr, "Not a tensor expression");

  if (ir_utils::isTensorScalarFillOp(expr)) {
    if (out_tv->getMemoryType() == MemoryType::Local) {
      // Filling a local tensor with a scalar shouldn't
      // need any predicate currently.
      return true;
    } else if (out_tv->getMemoryType() == MemoryType::Shared) {
      // A shared memory initialization should be the same, except
      // that we'd need a predicate to guard against out-of-bound
      // accesses by out-of-bound threads.
      return isExactParallelSharedMemAccess(out_tv);
    }
  }

  if (non_predicated_exprs_.find(expr) != non_predicated_exprs_.end()) {
    return true;
  }

  assertOnWarpOps(expr);
  return false;
}

void PredicateElimination::propagateRemovalInfo(
    const Expr* from,
    const Expr* to) {
  if (non_predicated_exprs_.count(from)) {
    non_predicated_exprs_.insert(to);
  }
}

Val* PredicateElimination::getInitValue(TensorView* tv) const {
  auto it = init_value_map_.find(tv);
  if (it == init_value_map_.end()) {
    return nullptr;
  }
  auto init_val = it->second;
  if (init_val == nullptr) {
    // No reduction restriction. Just use zero.
    return GpuLower::current()->kernel()->zeroVal();
  } else {
    return init_val;
  }
}

void PredicateElimination::build(Fusion* fusion) {
  traverseTo(fusion, fusion->outputs());
}

std::string PredicateElimination::toString() const {
  std::stringstream ss;
  ss << "Tensors that do not need predication:";
  for (auto expr : non_predicated_exprs_) {
    for (auto out : expr->outputs()) {
      TORCH_INTERNAL_ASSERT(out->isA<TensorView>());
      ss << " T" << out->name();
    }
  }
  ss << "\n";
  ss << "Init values:";
  for (auto kv : init_value_map_) {
    ss << " T" << kv.first->name() << "->";
    if (kv.second == nullptr) {
      ss << "<default(0)>";
    } else {
      ss << kv.second;
    }
  }
  ss << "\n";
  return ss.str();
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch