1#pragma once
#include <c10/macros/Export.h>

#include <ir_all_nodes.h>
#include <kernel_ir.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
8
9namespace torch {
10namespace jit {
11namespace fuser {
12namespace cuda {
13
14class TORCH_CUDA_CU_API PredicateElimination : public IterVisitor {
15 public:
16 void build(Fusion* fusion);
17
18 //! True if expr does not need a predicate
19 //!
20 //! \param expr Tensor expression
21 bool canOmitPredicate(const Expr* expr) const;
22
23 //! Value to initialize out-of-bound regions
24 Val* getInitValue(TensorView* tv) const;
25
26 //! Dump to string for debugging
27 std::string toString() const;
28
29 // A utility to set removal info of `to` the same as `from`.
30 // See issue #1641
31 // We build predicate info before lowering but more expressions
32 // are created during lowering that this class also need to
33 // keep track of to make sure correct predicate removal is
34 // applied.
35 // This utility is a quick patch for the missing information
36 // since it might be better just to recompute predicate info
37 // if all expressions were mutated, but that'd take much more
38 // global info to reliably track.
39 void propagateRemovalInfo(const Expr* from, const Expr* to);
40
41 private:
42 using IterVisitor::handle;
43
44 void handle(Expr* expr) final;
45
46 //! Set a value to initialize out-of-bound regions
47 bool setDefaultInitValue(TensorView* tv);
48 //! Set a value to initialize out-of-bound regions of reduction tensors
49 bool setReductionInitValue(TensorView* tv, Val* reduction_init);
50
51 //! Check if expr needs to be predicated
52 bool needsPredicate(Expr* expr) const;
53
54 private:
55 //! Expressions that are found to be safe without predicates
56 std::unordered_set<const Expr*> non_predicated_exprs_;
57 //! Tensors and their initialization values
58 std::unordered_map<TensorView*, Val*> init_value_map_;
59};
60
61} // namespace cuda
62} // namespace fuser
63} // namespace jit
64} // namespace torch
65