1 | #pragma once |
2 | #include <c10/macros/Export.h> |
3 | |
4 | #include <ir_all_nodes.h> |
5 | #include <kernel_ir.h> |
6 | |
7 | #include <vector> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | namespace fuser { |
12 | namespace cuda { |
13 | |
14 | class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor { |
15 | public: |
16 | void build(Fusion* fusion); |
17 | |
18 | //! True if expr does not need a predicate |
19 | //! |
20 | //! \param expr Tensor expression |
21 | bool canOmitPredicate(const Expr* expr) const; |
22 | |
23 | //! Value to initialize out-of-bound regions |
24 | Val* getInitValue(TensorView* tv) const; |
25 | |
26 | //! Dump to string for debugging |
27 | std::string toString() const; |
28 | |
29 | // A utility to set removal info of `to` the same as `from`. |
30 | // See issue #1641 |
31 | // We build predicate info before lowering but more expressions |
32 | // are created during lowering that this class also need to |
33 | // keep track of to make sure correct predicate removal is |
34 | // applied. |
35 | // This utility is a quick patch for the missing information |
36 | // since it might be better just to recompute predicate info |
37 | // if all expressions were mutated, but that'd take much more |
38 | // global info to reliably track. |
39 | void propagateRemovalInfo(const Expr* from, const Expr* to); |
40 | |
41 | private: |
42 | using IterVisitor::handle; |
43 | |
44 | void handle(Expr* expr) final; |
45 | |
46 | //! Set a value to initialize out-of-bound regions |
47 | bool setDefaultInitValue(TensorView* tv); |
48 | //! Set a value to initialize out-of-bound regions of reduction tensors |
49 | bool setReductionInitValue(TensorView* tv, Val* reduction_init); |
50 | |
51 | //! Check if expr needs to be predicated |
52 | bool needsPredicate(Expr* expr) const; |
53 | |
54 | private: |
55 | //! Expressions that are found to be safe without predicates |
56 | std::unordered_set<const Expr*> non_predicated_exprs_; |
57 | //! Tensors and their initialization values |
58 | std::unordered_map<TensorView*, Val*> init_value_map_; |
59 | }; |
60 | |
61 | } // namespace cuda |
62 | } // namespace fuser |
63 | } // namespace jit |
64 | } // namespace torch |
65 | |