1 | #pragma once |
2 | |
#include <c10/macros/Export.h>

#include <fusion.h>
#include <ir_base_nodes.h>
#include <ir_builder.h>
#include <lower_sync_information.h>
#include <lower_warp_reduce.h>
#include <parallel_dimension_map.h>
#include <utils.h>
#include <vectorization_info.h>

#include <array>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | namespace fuser { |
22 | namespace cuda { |
23 | namespace kir { |
24 | |
//! Summary of interesting facts about the kernel
//!
//! Populated by analysis passes after lowering (see Kernel::analyze());
//! consumed by code generation and runtime launch setup.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct KernelSummary {
  //! Count of WAR (write-after-read) hazard barriers
  int war_hazard_syncs_count = 0;

  //! List of global buffers
  std::vector<const kir::Allocate*> global_allocations;

  //! List of dynamic shared memory buffers
  std::vector<const kir::Allocate*> dynamic_smem_allocations;

  //! List of static shared memory buffers
  std::vector<const kir::Allocate*> static_smem_allocations;

  //! Indicates the need to generate random numbers.
  //! -1 means the kernel needs no RNG; otherwise this tracks the
  //! maximum RNG offset used.
  int max_rng_offsets = -1;

  //! Do we have any block reductions?
  bool has_block_reductions = false;

  //! Do we have any grid reductions?
  bool has_grid_reductions = false;

  //! Do we have any grid reduction in a loop, or grid reductions dependent on
  //! grid reductions
  bool has_cooperative_grid_reduction = false;

  //! Do we have any block broadcasts?
  bool has_block_broadcasts = false;

  //! Do we have any grid broadcasts?
  bool has_grid_broadcasts = false;

  //! Do we have any welford op?
  bool has_welford = false;

  //! Do we have any block-level welford op?
  bool has_block_welford = false;

  //! Do we have any grid-level welford op?
  bool has_grid_welford = false;

  //! Largest shared memory buffer base type
  DataType largest_smem_data_type = DataType::Null;

  //! Do we have allocations of dynamic local memory?
  bool has_dynamic_local_memory_allocations = false;

  //! List of dynamic local memory buffers.
  //! Only used for debugging.
  std::vector<const kir::Allocate*> dynamic_lmem_allocations;

  //! ceilDiv extents that must be divisible
  std::vector<std::pair<const Val*, const Val*>> splits_to_validate;

  //! Effective ParallelTypes of broadcast ops
  std::unordered_map<const BroadcastOp*, ParallelTypeBitmap>
      broadcast_parallel_types;

  //! Track which tensor views are inputs or outputs of a vectorized operation
  //! and their maximum vectorized access size
  std::unordered_map<TensorView*, int> vectorized_accesses;

  //! Sync map is needed to figure out if global memory buffers need to be
  //! marked as volatile because they're used for communication.
  SyncMap sync_map;

  //! Parallel dimension map needed to set the correct properties of grid
  //! buffers (is a dim inactive)
  ParallelDimensionMap parallel_dimension_map_;

  //! Track information on vectorized set operations for runtime validation
  std::vector<VectorizedSetInfo> vectorized_set_info;
};
100 | |
//! Bookkeeping for per-expression kernel performance profiling.
//!
//! Assigns each registered expression a slot in an Nx2 integer backing
//! buffer (cycles spent, execution count) and maps expressions to their
//! slot offsets.
class TORCH_CUDA_CU_API KernelPerformanceProfile {
 public:
  //! Register an expression to profile
  void registerExpr(const Expr* expr);

  //! Query if an expression is profiled
  bool isProfiled(const Expr* expr) const;

  //! Get the number of profiled expressions
  int getNumberOfProfileEntries() const {
    return num_profile_entries_;
  }

  //! Set the backing buffer of profile.
  //! The buffer is stored as a raw pointer; presumably it is owned by the
  //! enclosing kernel/fusion container — confirm at the call site.
  void setBuffer(TensorView* buffer) {
    buffer_ = buffer;
  }

  //! Get the backing buffer (nullptr if setBuffer has not been called)
  TensorView* getBuffer() const {
    return buffer_;
  }

  //! Get the indices of the profile of an expression in the backing buffer
  std::array<int, 2> getIndicesInProfileBuffer(const Expr* expr) const;

  //! Produce a human-readable report from a materialized profile buffer
  std::string toString(const at::Tensor& buffer) const;

 private:
  //! Get the new profile index
  int getNewIndex();

  //! Get the profile index, or nullopt when expr is not registered
  c10::optional<int> getIndex(const Expr* expr) const;

 private:
  //! Number of profile slots handed out so far
  int num_profile_entries_ = 0;

  //! Backing buffer of Nx2 integer tensor, where N is the number of profiled
  //! regions. Each region has two integer values, one representing
  //! the cycles spent, and another the count.
  TensorView* buffer_ = nullptr;

  //! Map profiled expressions to profile entry offsets
  std::unordered_map<const Expr*, int> expr_entry_map_;

  // TODO: Allow profiling of ForLoops
  //! Map profiled ForLoop to profile entry offsets
  // std::unordered_map<const kir::ForLoop*, int> loop_entry_map_;
};
151 | |
152 | class KernelInternalProxy; |
153 | |
//! Container for a lowered Kernel IR
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Kernel final : public Fusion {
  //! Grants test/debug code mutable access to internals
  friend KernelInternalProxy;

 public:
  // Kernel starts by grabbing all the nodes from the provided fusion.
  // Kernel is not SSA, if a definition is not set, we should update it, but
  // not remove previous definition if it is set. This is primarily because
  // when we do something like generate an initialization statement for a
  // reduction TV, we may want to continue to do fusion like analysis on the
  // original expression.
  // TODO: Assert index type is int or int32
  Kernel(Fusion* fusion, DataType index_type = DataType::Int)
      : Fusion(*fusion), index_type_(index_type) {}

  Kernel() = delete;

  // No move or copy semantics. Note: deleting the copy operations also
  // suppresses the implicitly-declared move operations.
  Kernel(const Kernel&) = delete;
  Kernel& operator=(const Kernel&) = delete;

  //! Finalize a kernel definition
  //!
  //! At this point we have a complete kernel definition and we can
  //! run analysis passes to build a KernelSummary.
  void finalize(std::vector<Expr*> top_level_exprs);

  //! Top-level expressions of the lowered kernel
  const std::vector<Expr*>& topLevelExprs() const {
    return top_level_exprs_;
  }

  //! Summary of interesting kernel facts cached during finalization
  const KernelSummary& summary() const {
    return summary_;
  }

  //! Index type this kernel is compiled with (int32 or int64);
  //! required to resolve DataType::Index
  DataType indexType() const {
    return index_type_;
  }

  //! Checks if parallel type is padded.
  //! Note: only TIDx padding is tracked here; all other parallel types
  //! return false.
  bool isParallelTypePadded(ParallelType ptype) const {
    return ptype == ParallelType::TIDx &&
        warp_padded_parallel_info_.is_tidx_padded;
  }

  //! Full warp-padding info collected during lowering
  const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
    return warp_padded_parallel_info_;
  }

  //! Performance-profiling bookkeeping for this kernel
  const KernelPerformanceProfile& profile() const {
    return profile_;
  }

  //! Debug dump of the Kernel IR
  void print() const;

 protected:
  //! Register the Val with this fusion
  void registerVal(Val* val) override;

  //! Register expr with this fusion.
  //! When we register an expression, we want to update the dependency tracking
  //! of Vals. We add expr to our general expr_set_,
  void registerExpr(Expr* expr) override;

 private:
  // Analyze the kernel IR and caches the summary of interesting data
  void analyze();

  // Top level statements
  std::vector<Expr*> top_level_exprs_;

  // Summary of interesting kernel data
  KernelSummary summary_;

  // Is this kernel being compiled with int32 or int64 indexing. This
  // information is required to resolve DataType::Index
  DataType index_type_ = DataType::Int;

  // Warp-padding parallelization info, queried by isParallelTypePadded()
  WarpPaddedParallelInfo warp_padded_parallel_info_;

  // Per-expression performance profile, exposed via profile()
  KernelPerformanceProfile profile_;
};
239 | |
240 | //! A special debugging proxy for Kernel. |
241 | //! |
242 | //! Should not be used for other than testing and debugging. |
243 | class TORCH_CUDA_CU_API KernelInternalProxy { |
244 | public: |
245 | KernelInternalProxy(Kernel* kernel) : kernel_(kernel) {} |
246 | |
247 | std::vector<Expr*>& topLevelExprs(); |
248 | |
249 | private: |
250 | Kernel* kernel_ = nullptr; |
251 | }; |
252 | |
253 | } // namespace kir |
254 | } // namespace cuda |
255 | } // namespace fuser |
256 | } // namespace jit |
257 | } // namespace torch |
258 | |