1#pragma once
2
3#include <fusion.h>
4#include <ir_all_nodes.h>
5#include <ir_container.h>
6
7namespace torch {
8namespace jit {
9namespace fuser {
10namespace cuda {
11
// Forward declarations of types whose definitions live outside this header.
namespace kir {
class Kernel;
} // namespace kir

// Only used through a pointer below (see IrBuilder::clone).
class IrCloner;
17
18// Passkey for builder to register properties with statements, and to call
19// functions in IrContainer
20class TORCH_CUDA_CU_API IrBuilderPasskey {
21 friend class IrBuilder;
22
23 public:
24 // TODO: Collapse ir_container and Kernel once Kernel inherits from
25 // IrContainer
26 IrContainer* const ir_container_ = nullptr;
27
28 private:
29 explicit IrBuilderPasskey(IrContainer* ir_container);
30};
31
32//! IR builder interface
33class TORCH_CUDA_CU_API IrBuilder {
34 public:
35 //! Allocate a new IR node, forwarding the arguments to the appropriate
36 //! constructor and registering with the container
37 template <class T, class... Args>
38 static T* create(Args&&... args) {
39 auto container = FusionGuard::getCurFusion();
40 // return create<T>(container, std::forward<Args>(args)...);
41 TORCH_INTERNAL_ASSERT(
42 container != nullptr, "Need an active container to build IR.");
43 T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
44
45 container->registerStmt(IrBuilderPasskey(container), node);
46
47 return node;
48 }
49
50 //! Allocate a new IR node, forwarding the arguments to the appropriate
51 //! constructor and registering with the container
52 template <class T, class... Args>
53 static T* create(IrContainer* container, Args&&... args) {
54 TORCH_INTERNAL_ASSERT(
55 container != nullptr, "Need an active container to build IR.");
56 T* node = new T(IrBuilderPasskey(container), std::forward<Args>(args)...);
57
58 container->registerStmt(IrBuilderPasskey(container), node);
59
60 return node;
61 }
62
63 //! Clone an IR node, forwarding the arguments to the IrCloner constructor.
64 //! Register clones with IrCloner's target container.
65 template <class T>
66 static T* clone(const T* src, IrCloner* ir_cloner);
67
68 // Unary operations
69 static Val* negExpr(Val* val);
70 static Val* notExpr(Val* val);
71 static Val* setExpr(Val* val);
72 static Val* setExprNamedScalar(const std::string& name, Val* val);
73 static Val* addressExprNamedScalar(const std::string& name, Val* val);
74
75 // Binary operations
76 static Val* andExpr(Val* lhs, Val* rhs);
77 static Val* eqExpr(Val* lhs, Val* rhs);
78 static Val* gtExpr(Val* lhs, Val* rhs);
79 static Val* ltExpr(Val* lhs, Val* rhs);
80 static Val* leExpr(Val* lhs, Val* rhs);
81 static Val* geExpr(Val* lhs, Val* rhs);
82 static Val* addExpr(Val* lhs, Val* rhs);
83 static Val* subExpr(Val* lhs, Val* rhs);
84 static Val* mulExpr(Val* lhs, Val* rhs);
85 static Val* divExpr(Val* lhs, Val* rhs);
86 static Val* ceilDivExpr(Val* lhs, Val* rhs);
87 static Val* modExpr(Val* lhs, Val* rhs);
88 static Val* maxExpr(Val* lhs, Val* rhs);
89 static Val* minExpr(Val* lhs, Val* rhs);
90
91 // Ternary operations
92 static Val* whereExpr(Val* pred, Val* lhs, Val* rhs);
93
94 // Swizzle operations
95 static Val* swizzle2DIntExpr(
96 Val* x,
97 Val* y,
98 Val* extent_x,
99 Val* extent_y,
100 Swizzle2DType swizzle_type);
101 static Val* pairSelectExpr(Val* in, kir::PairSelect::Selection sel);
102
103 private:
104 static Val* newResult(DataType dtype);
105 static Val* newArithmeticExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
106 static Val* newLogicExpr(BinaryOpType op_type, Val* lhs, Val* rhs);
107};
108
109//! A wrapper builder with static expression simplification
110//!
111//! Example:
112//! - addExpr(new Int(1), new Int(2)) -> Int(3)
113//! - addExpr(new Int(0), new NamedScalar("foo")) -> NamedScalar("foo")
114//!
115//! Designed to be used to simplify predicate and index expressions in
116//! generated code. Also, the shift validation may fail without
117//! this simplification.
118class TORCH_CUDA_CU_API SimplifyingIrBuilder : public IrBuilder {
119 public:
120 static Val* negExpr(Val* val);
121 static Val* notExpr(Val* val);
122
123 static Val* addExpr(Int* lhs, Int::ScalarType rhs);
124 static Val* addExpr(Val* lhs, Int::ScalarType rhs);
125 static Val* addExpr(Int* lhs, Int* rhs);
126 static Val* addExpr(Val* lhs, Val* rhs);
127 static Val* subExpr(Val* lhs, Val* rhs);
128 static Val* mulExpr(Int* lhs, Int::ScalarType rhs);
129 static Val* mulExpr(Val* lhs, Int::ScalarType rhs);
130 static Val* mulExpr(Int* lhs, Int* rhs);
131 static Val* mulExpr(Val* lhs, Val* rhs);
132 static Val* andExpr(Val* lhs, Val* rhs);
133 static Val* maxExpr(Val* lhs, Val* rhs);
134 static Val* minExpr(Val* lhs, Val* rhs);
135};
136
137} // namespace cuda
138} // namespace fuser
139} // namespace jit
140} // namespace torch
141