#pragma once

#include <c10/macros/Export.h>

#include <instrumentation.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <root_domain_map.h>

#include <unordered_map>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

// TODO: Replace with a mutator, as IndexLowering replaces exprs with
// versions that perform the actual indexing.
class TORCH_CUDA_CU_API IndexLowering : private OptOutConstDispatch {
 public:
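  // Entry point: run index lowering over the given expressions and
  // return the lowered list. Usage sketch (the caller is typically the
  // lowering pipeline in GpuLower):
  //
  //   std::vector<Expr*> exprs = ...;
  //   exprs = IndexLowering::getIndexedExprs(exprs);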
  static std::vector<Expr*> getIndexedExprs(
      std::vector<Expr*> incoming_exprs) {
    FUSER_PERF_SCOPE("GpuLower::Lower::IndexLowering::getIndexedExprs");
    IndexLowering il;
    il.generate(incoming_exprs);
    return il.lowered_exprs_;
  }

 private:
  IndexLowering() = default;

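  // Append an expression to the active scope, or to the top-level
  // lowered expressions when no scope is active.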
  void pushBack(Expr*);

  // Return the most recently inserted expression in the current active
  // scope or the global scope.
  Expr* back() const;

  // Insert an expression before the current top-level expression.
  void insertAtTopLevel(Expr* expr);

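  // Handlers for fusion IR expressions. Each replaces the visited
  // expression with an equivalent one whose tensor inputs and outputs
  // have been indexed (see the TODO above regarding a mutator).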
  void handle(const FullOp*) final;
  void handle(const ARangeOp*) final;
  void handle(const EyeOp*) final;
  void handle(const ViewAsScalar*) final;
  void handle(const UnaryOp*) final;

  void handle(const BinaryOp*) final;
  void handle(const TernaryOp*) final;
  void handle(const RNGOp*) final;
  void handle(const ReductionOp*) final;
  void handle(const GroupedReductionOp*) final;
  void handle(const WelfordOp*) final;
  void handle(const GroupedWelfordOp*) final;
  void handle(const LoadStoreOp*) final;
  void handle(const MmaOp*) final;
  void handle(const BroadcastOp*) final;

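  // Handlers for kernel IR nodes. ForLoop and IfThenElse open their
  // scopes and recurse into the contained expressions; the remaining
  // nodes are lowered in place.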
  void handle(const kir::ForLoop*) final;
  void handle(const kir::IfThenElse*) final;
  void handle(const kir::Allocate*) final;
  void handle(const kir::BlockSync*) final;
  void handle(const kir::GridSync*) final;
  void handle(const kir::CpAsyncWait*) final;
  void handle(const kir::CpAsyncCommit*) final;

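  // Dispatch over all given top-level expressions, populating
  // lowered_exprs_.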
  void generate(const std::vector<Expr*>& exprs);

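  // Index val for reading, as an input of the expression whose output
  // is dst.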
  Val* lowerSrcIndex(Val* val, Val* dst) const;

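  // Index dst for writing, as the output of its defining expression.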
  Val* lowerDstIndex(Val* dst) const;

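  // Lower a ReductionOp with the indexed output and input, emitting
  // block- or grid-parallel reduction kernel IR as appropriate.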
  void handleBlockReduction(const ReductionOp* rop, Val* out, Val* in);
  void handleGridReduction(const ReductionOp* rop, Val* out, Val* in);

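  // Grouped variants handling multiple reductions at once
  // (GroupedReductionOp).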
  void handleBlockReduction(
      const GroupedReductionOp* rop,
      const std::vector<Val*>& outputs,
      const std::vector<Val*>& inputs);
  void handleGridReduction(
      const GroupedReductionOp* rop,
      const std::vector<Val*>& outputs,
      const std::vector<Val*>& inputs);

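  // Lower a WelfordOp that requires a grid reduction.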
  void handleGridWelford(WelfordOp* new_wop);

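  // Lower a GroupedWelfordOp given the indexed (avg, var, N) triplets
  // for its outputs, inputs, and initial values.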
  void handleGroupedBlockWelford(
      const GroupedWelfordOp* wop,
      const std::vector<WelfordTriplet>& output_vals,
      const std::vector<WelfordTriplet>& input_vals,
      const std::vector<WelfordTriplet>& init_vals);
  void handleGroupedGridWelford(
      const GroupedWelfordOp* wop,
      const std::vector<WelfordTriplet>& output_vals,
      const std::vector<WelfordTriplet>& input_vals,
      const std::vector<WelfordTriplet>& init_vals);

  // Allocate a unique buffer for grid reductions and broadcasts. A
  // buffer is allocated uniquely for each output tensor of an
  // expression.
  kir::Allocate* allocateUniqueBuffer(
      Val* buffer_size,
      DataType dtype,
      bool zero_init,
      TensorView* out_tv,
      std::unordered_map<TensorView*, kir::Allocate*>& alloc_map);

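  // Allocate work buffers for each value of the given Welford triplets
  // (avg, var, N).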
  std::vector<kir::Allocate*> allocateWelfordWorkBuffer(
      const std::vector<WelfordTriplet>& triplets,
      WelfordTriplet::ValName name,
      Val* buffer_size);

  // Allocate a fused reduction object uniquely for a given
  // TensorView. Parameter expr is the expression corresponding to the
  // fused reduction.
  void allocateUniqueFusedReduction(Expr* expr, TensorView* out_tv);

 private:
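  // The lowered expressions produced so far; returned to the caller by
  // getIndexedExprs.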
  std::vector<Expr*> lowered_exprs_;

  // This is a slight workaround, since "scope" has two meanings here:
  // the kir::Scope inside a ForLoop/IfThenElse, which is really just a
  // wrapper around std::vector<Expr*>, and the ForLoop/IfThenElse node
  // itself. We need the kir::Scope because an expression pushed back
  // may go into either the then-body or the else-body of an
  // IfThenElse, but we also want to track the nesting of the
  // IfThenElse/ForLoop nodes themselves.
  kir::Scope* active_scope_ = nullptr;

  // Track for-loops to send to indexing. Similar to what's done in
  // kir::IrVisitor.
  std::vector<kir::ForLoop*> for_loops_;

  // Maps to keep track of allocated buffers and objects that must be
  // allocated only once.
  std::unordered_map<TensorView*, kir::Allocate*> sync_buffer_map_;
  std::unordered_map<TensorView*, kir::Allocate*> work_buffer_map_;
  std::unordered_map<TensorView*, kir::AllocateFusedReduction*>
      fused_reduction_map_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch