1 | #pragma once |
2 | |
3 | #include <fusion.h> |
4 | #include <ir_all_nodes.h> |
5 | #include <ir_container.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | namespace fuser { |
10 | namespace cuda { |
11 | |
12 | namespace kir { |
13 | class Kernel; |
14 | } |
15 | |
16 | class IrCloner; |
17 | |
18 | // Passkey for builder to register properties with statements, and to call |
19 | // functions in IrContainer |
20 | class TORCH_CUDA_CU_API IrBuilderPasskey { |
21 | friend class IrBuilder; |
22 | |
23 | public: |
24 | // TODO: Collapse ir_container and Kernel once Kernel inherits from |
25 | // IrContainer |
26 | IrContainer* const ir_container_ = nullptr; |
27 | |
28 | private: |
29 | explicit IrBuilderPasskey(IrContainer* ir_container); |
30 | }; |
31 | |
32 | //! IR builder interface |
33 | class TORCH_CUDA_CU_API IrBuilder { |
34 | public: |
35 | //! Allocate a new IR node, forwarding the arguments to the appropriate |
36 | //! constructor and registering with the container |
37 | template <class T, class... Args> |
38 | static T* create(Args&&... args) { |
39 | auto container = FusionGuard::getCurFusion(); |
40 | // return create<T>(container, std::forward<Args>(args)...); |
41 | TORCH_INTERNAL_ASSERT( |
42 | container != nullptr, "Need an active container to build IR." ); |
43 | T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...); |
44 | |
45 | container->registerStmt(IrBuilderPasskey(container), node); |
46 | |
47 | return node; |
48 | } |
49 | |
50 | //! Allocate a new IR node, forwarding the arguments to the appropriate |
51 | //! constructor and registering with the container |
52 | template <class T, class... Args> |
53 | static T* create(IrContainer* container, Args&&... args) { |
54 | TORCH_INTERNAL_ASSERT( |
55 | container != nullptr, "Need an active container to build IR." ); |
56 | T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...); |
57 | |
58 | container->registerStmt(IrBuilderPasskey(container), node); |
59 | |
60 | return node; |
61 | } |
62 | |
63 | //! Clone an IR node, forwarding the arguments to the IrCloner constructor. |
64 | //! Register clones with IrCloner's target container. |
65 | template <class T> |
66 | static T* clone(const T* src, IrCloner* ir_cloner); |
67 | |
68 | // Unary operations |
69 | static Val* negExpr(Val* val); |
70 | static Val* notExpr(Val* val); |
71 | static Val* setExpr(Val* val); |
72 | static Val* setExprNamedScalar(const std::string& name, Val* val); |
73 | static Val* addressExprNamedScalar(const std::string& name, Val* val); |
74 | |
75 | // Binary operations |
76 | static Val* andExpr(Val* lhs, Val* rhs); |
77 | static Val* eqExpr(Val* lhs, Val* rhs); |
78 | static Val* gtExpr(Val* lhs, Val* rhs); |
79 | static Val* ltExpr(Val* lhs, Val* rhs); |
80 | static Val* leExpr(Val* lhs, Val* rhs); |
81 | static Val* geExpr(Val* lhs, Val* rhs); |
82 | static Val* addExpr(Val* lhs, Val* rhs); |
83 | static Val* subExpr(Val* lhs, Val* rhs); |
84 | static Val* mulExpr(Val* lhs, Val* rhs); |
85 | static Val* divExpr(Val* lhs, Val* rhs); |
86 | static Val* ceilDivExpr(Val* lhs, Val* rhs); |
87 | static Val* modExpr(Val* lhs, Val* rhs); |
88 | static Val* maxExpr(Val* lhs, Val* rhs); |
89 | static Val* minExpr(Val* lhs, Val* rhs); |
90 | |
91 | // Ternary operations |
92 | static Val* whereExpr(Val* pred, Val* lhs, Val* rhs); |
93 | |
94 | // Swizzle operations |
95 | static Val* swizzle2DIntExpr( |
96 | Val* x, |
97 | Val* y, |
98 | Val* extent_x, |
99 | Val* extent_y, |
100 | Swizzle2DType swizzle_type); |
101 | static Val* pairSelectExpr(Val* in, kir::PairSelect::Selection sel); |
102 | |
103 | private: |
104 | static Val* newResult(DataType dtype); |
105 | static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs); |
106 | static Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs); |
107 | }; |
108 | |
109 | //! A wrapper builder with static expression simplification |
110 | //! |
111 | //! Example: |
112 | //! - addExpr(new Int(1), new Int(2)) -> Int(3) |
113 | //! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo") |
114 | //! |
115 | //! Designed to be used to simplify predicate and index expressions in |
116 | //! generated code. Also, the shift validation may fail without |
117 | //! this simplification. |
118 | class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder { |
119 | public: |
120 | static Val* negExpr(Val* val); |
121 | static Val* notExpr(Val* val); |
122 | |
123 | static Val* addExpr(Int* lhs, Int::ScalarType rhs); |
124 | static Val* addExpr(Val* lhs, Int::ScalarType rhs); |
125 | static Val* addExpr(Int* lhs, Int* rhs); |
126 | static Val* addExpr(Val* lhs, Val* rhs); |
127 | static Val* subExpr(Val* lhs, Val* rhs); |
128 | static Val* mulExpr(Int* lhs, Int::ScalarType rhs); |
129 | static Val* mulExpr(Val* lhs, Int::ScalarType rhs); |
130 | static Val* mulExpr(Int* lhs, Int* rhs); |
131 | static Val* mulExpr(Val* lhs, Val* rhs); |
132 | static Val* andExpr(Val* lhs, Val* rhs); |
133 | static Val* maxExpr(Val* lhs, Val* rhs); |
134 | static Val* minExpr(Val* lhs, Val* rhs); |
135 | }; |
136 | |
137 | } // namespace cuda |
138 | } // namespace fuser |
139 | } // namespace jit |
140 | } // namespace torch |
141 | |