#pragma once

#include <c10/macros/Export.h>

#include <fusion.h>
#include <ir_base_nodes.h>
#include <ir_builder.h>
#include <lower_sync_information.h>
#include <lower_warp_reduce.h>
#include <parallel_dimension_map.h>
#include <utils.h>
#include <vectorization_info.h>

#include <array>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace kir {

//! Summary of interesting facts about the kernel
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct KernelSummary {
  //! Count of WAR (write-after-read) hazard barriers
  int war_hazard_syncs_count = 0;

  //! List of global buffers
  std::vector<const kir::Allocate*> global_allocations;

  //! List of dynamic shared memory buffers
  std::vector<const kir::Allocate*> dynamic_smem_allocations;

  //! List of static shared memory buffers
  std::vector<const kir::Allocate*> static_smem_allocations;

  //! Largest RNG offset required; -1 indicates no random numbers are needed
  int max_rng_offsets = -1;

  //! Do we have any block reductions?
  bool has_block_reductions = false;

  //! Do we have any grid reductions?
  bool has_grid_reductions = false;

  //! Do we have any grid reduction in a loop, or grid reductions dependent on
  //! grid reductions
  bool has_cooperative_grid_reduction = false;

  //! Do we have any block broadcasts?
  bool has_block_broadcasts = false;

  //! Do we have any grid broadcasts?
  bool has_grid_broadcasts = false;

  //! Do we have any welford op?
  bool has_welford = false;

  //! Do we have any block welford ops?
  bool has_block_welford = false;

  //! Do we have any grid welford ops?
  bool has_grid_welford = false;

  //! Largest shared memory buffer base type
  DataType largest_smem_data_type = DataType::Null;

  //! Do we have allocations of dynamic local memory?
  bool has_dynamic_local_memory_allocations = false;

  //! List of dynamic local memory buffers.
  //! Only used for debugging.
  std::vector<const kir::Allocate*> dynamic_lmem_allocations;

  //! ceilDiv extents that must be divisible
  std::vector<std::pair<const Val*, const Val*>> splits_to_validate;

  //! Effective ParallelTypes of broadcast ops
  std::unordered_map<const BroadcastOp*, ParallelTypeBitmap>
      broadcast_parallel_types;

  //! Track which tensor views are inputs or outputs of a vectorized operation
  //! and their maximum vectorized access size
  std::unordered_map<TensorView*, int> vectorized_accesses;

  // Sync map is needed to figure out if global memory buffers need to be
  // marked as volatile because they're used for communication.
  SyncMap sync_map;

  // Parallel dimension map needed to set the correct properties of grid
  // buffers (is a dim inactive)
  ParallelDimensionMap parallel_dimension_map_;

  //! Track information on vectorized set operations for runtime validation
  std::vector<VectorizedSetInfo> vectorized_set_info;
};
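
// A minimal sketch of how a consumer might inspect the summary after lowering.
// Illustrative only; `kernel` is a hypothetical, already finalized kir::Kernel
// and launch-parameter plumbing is omitted:
//
//   const KernelSummary& summary = kernel->summary();
//   if (summary.has_cooperative_grid_reduction) {
//     // The kernel needs a cooperative launch (e.g. via
//     // cudaLaunchCooperativeKernel) so grid-wide synchronization is valid.
//   }
//   if (!summary.dynamic_smem_allocations.empty()) {
//     // Size the dynamic shared memory argument from these allocations,
//     // aligning to summary.largest_smem_data_type.
//   }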

class TORCH_CUDA_CU_API KernelPerformanceProfile {
 public:
  //! Register an expression to profile
  void registerExpr(const Expr* expr);

  //! Query if an expression is profiled
  bool isProfiled(const Expr* expr) const;

  //! Get the number of profiled expressions
  int getNumberOfProfileEntries() const {
    return num_profile_entries_;
  }

  //! Set the backing buffer of the profile
  void setBuffer(TensorView* buffer) {
    buffer_ = buffer;
  }

  //! Get the backing buffer
  TensorView* getBuffer() const {
    return buffer_;
  }

  //! Get the indices into the backing buffer of an expression's profile entry
  std::array<int, 2> getIndicesInProfileBuffer(const Expr* expr) const;

  std::string toString(const at::Tensor& buffer) const;

 private:
  //! Get the new profile index
  int getNewIndex();

  //! Get the profile index
  c10::optional<int> getIndex(const Expr* expr) const;

 private:
  int num_profile_entries_ = 0;

  //! Backing buffer of Nx2 integer tensor, where N is the number of profiled
  //! regions. Each region has two integer values, one representing
  //! the cycles spent, and another the count.
  TensorView* buffer_ = nullptr;

  //! Map profiled expressions to profile entry offsets
  std::unordered_map<const Expr*, int> expr_entry_map_;

  // TODO: Allow profiling of ForLoops
  //! Map profiled ForLoop to profile entry offsets
  // std::unordered_map<const kir::ForLoop*, int> loop_entry_map_;
};
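
// Illustrative sketch of consuming the profile after a kernel run. The names
// are hypothetical: `profile` is a populated KernelPerformanceProfile,
// `profile_tensor` the Nx2 integer at::Tensor backing it, and `exprs` the
// expressions of interest:
//
//   for (const Expr* expr : exprs) {
//     if (!profile.isProfiled(expr)) {
//       continue;
//     }
//     // The two indices locate this expression's cycle and count entries in
//     // the backing buffer.
//     const std::array<int, 2> indices =
//         profile.getIndicesInProfileBuffer(expr);
//   }
//   // Or render everything at once:
//   const std::string report = profile.toString(profile_tensor);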

class KernelInternalProxy;

//! Container for a lowered Kernel IR
//!
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
class TORCH_CUDA_CU_API Kernel final : public Fusion {
  friend KernelInternalProxy;

 public:
  // Kernel starts by grabbing all the nodes from the provided fusion.
  // Kernel is not SSA: if a definition is not set, we should update it, but
  // not remove a previous definition if one is set. This is primarily because
  // when we do something like generate an initialization statement for a
  // reduction TV, we may want to continue to do fusion-like analysis on the
  // original expression.
  // TODO: Assert index type is int or int32
  Kernel(Fusion* fusion, DataType index_type = DataType::Int)
      : Fusion(*fusion), index_type_(index_type) {}

  Kernel() = delete;

  // No move or copy semantics
  Kernel(const Kernel&) = delete;
  Kernel& operator=(const Kernel&) = delete;

  //! Finalize a kernel definition
  //!
  //! At this point we have a complete kernel definition and we can
  //! run analysis passes to build a KernelSummary.
  void finalize(std::vector<Expr*> top_level_exprs);

  const std::vector<Expr*>& topLevelExprs() const {
    return top_level_exprs_;
  }

  const KernelSummary& summary() const {
    return summary_;
  }

  DataType indexType() const {
    return index_type_;
  }

  //! Checks if parallel type is padded
  bool isParallelTypePadded(ParallelType ptype) const {
    return ptype == ParallelType::TIDx &&
        warp_padded_parallel_info_.is_tidx_padded;
  }

  const WarpPaddedParallelInfo& getWarpPaddedParallelInfo() const {
    return warp_padded_parallel_info_;
  }

  const KernelPerformanceProfile& profile() const {
    return profile_;
  }

  //! Debug dump of the Kernel IR
  void print() const;

 protected:
  //! Register the Val with this fusion
  void registerVal(Val* val) override;

  //! Register expr with this fusion.
  //! When we register an expression, we want to update the dependency tracking
  //! of Vals. We add expr to our general expr_set_.
  void registerExpr(Expr* expr) override;

 private:
  // Analyze the kernel IR and cache the summary of interesting data
  void analyze();

  // Top level statements
  std::vector<Expr*> top_level_exprs_;

  // Summary of interesting kernel data
  KernelSummary summary_;

  // Is this kernel being compiled with int32 or int64 indexing. This
  // information is required to resolve DataType::Index
  DataType index_type_ = DataType::Int;

  WarpPaddedParallelInfo warp_padded_parallel_info_;

  KernelPerformanceProfile profile_;
};
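
// A minimal usage sketch for Kernel; this is not the actual lowering entry
// point. `fusion` and `lowered_exprs` are hypothetical stand-ins for a Fusion
// being lowered and the Kernel IR expressions produced for it:
//
//   auto kernel = std::make_unique<kir::Kernel>(fusion, DataType::Int32);
//   kernel->finalize(std::move(lowered_exprs));
//   // finalize() runs the analysis passes, so the summary is now valid.
//   if (kernel->summary().has_block_reductions) {
//     // e.g. make sure block-reduction runtime helpers are emitted.
//   }
//   kernel->print(); // debug dump of the Kernel IR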

//! A special debugging proxy for Kernel.
//!
//! Should not be used for anything other than testing and debugging.
class TORCH_CUDA_CU_API KernelInternalProxy {
 public:
  KernelInternalProxy(Kernel* kernel) : kernel_(kernel) {}

  std::vector<Expr*>& topLevelExprs();

 private:
  Kernel* kernel_ = nullptr;
};
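
// Sketch of the proxy's intended (test/debug only) use: mutating the kernel's
// top-level expressions directly. `kernel` and `new_expr` are hypothetical:
//
//   KernelInternalProxy proxy(kernel.get());
//   std::vector<Expr*>& exprs = proxy.topLevelExprs();
//   exprs.push_back(new_expr); // bypasses normal lowering; tests/debug only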

} // namespace kir
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch