#pragma once
#include <executor_launch_params.h>
#include <executor_utils.h>
#include <fusion.h>
#include <ir_all_nodes.h>
#include <ir_cloner.h>
#include <ir_printer.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>
#include <utils.h>

#include <c10/core/DeviceType.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

TORCH_CUDA_CU_API bool shouldFillAllocationWithNan();
TORCH_CUDA_CU_API void setFillAllocationWithNan(bool value);
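
// Debugging sketch (illustrative only): presumably, filling new output
// allocations with NaN helps expose reads of uninitialized memory in a
// generated kernel.
//
//   setFillAllocationWithNan(true);   // enable for subsequent fusion runs
//   ...                               // run the suspect fusion
//   setFillAllocationWithNan(false);  // restore the default behavior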

// TODO: Should this actually be in launch params?
struct TORCH_CUDA_CU_API CompileOptions {
  c10::Device device = c10::Device(c10::DeviceType::CUDA, 0);
  KernelIndexMode index_mode = KernelIndexMode::INT64;
};
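
// Minimal CompileOptions sketch (illustrative only; assumes a second CUDA
// device is present, that a KernelIndexMode::INT32 enumerator is available,
// and that 32-bit indexing is valid for the problem size):
//
//   CompileOptions options;
//   options.device = c10::Device(c10::DeviceType::CUDA, 1);
//   options.index_mode = KernelIndexMode::INT32;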

class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
 public:
  // Unsafe compilation that's useful for debugging kernels; allows iterating
  // over slight modifications of a generated kernel.
  void debugCompileFusionFromStr(
      Fusion* fusion,
      const std::string& code,
      const std::string& name,
      int id,
      CompileOptions options = CompileOptions());

  //! Infers output sizes by returning a non-allocated KernelArgumentHolder.
  //! This function is useful for async compilation of segmented fusions.
  KernelArgumentHolder inferOutputSizes(
      const KernelArgumentHolder& args,
      const LaunchParams& launch_constraints);

  void compileFusion(
      Fusion* fusion,
      const KernelArgumentHolder& args,
      const LaunchParams& launch_constraints = LaunchParams());

  // TODO: merge it with the overload above.
  //! This API is merely here so we don't have to go back and update all cpp
  //! tests.
  void compileFusion(
      Fusion* fusion,
      const at::ArrayRef<IValue>& inputs = {},
      const LaunchParams& launch_constraints = LaunchParams()) {
    KernelArgumentHolder args =
        KernelArgumentHolder::createKernelArgumentHolder(inputs);
    compileFusion(fusion, args, launch_constraints);
  }

  std::vector<at::Tensor> runFusion(
      KernelArgumentHolder& args,
      const LaunchParams& launch_constraints = LaunchParams(),
      const std::vector<at::Tensor>& outputs = {});

  std::vector<at::Tensor> runFusion(
      const at::ArrayRef<IValue>& inputs,
      const std::vector<at::Tensor>& outputs,
      const LaunchParams& launch_constraints = LaunchParams(),
      const c10::optional<size_t>& opt_code = c10::nullopt) {
    KernelArgumentHolder args =
        KernelArgumentHolder::createKernelArgumentHolder(inputs);
    if (opt_code.has_value()) {
      args.setCacheId(*opt_code);
    }
    return runFusion(args, launch_constraints, outputs);
  }

  std::vector<at::Tensor> runFusion(
      const at::ArrayRef<IValue>& inputs,
      const LaunchParams& launch_constraints = LaunchParams(),
      const c10::optional<size_t>& opt_code = c10::nullopt) {
    return runFusion(inputs, {}, launch_constraints, opt_code);
  }
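
  // Typical compile-then-run flow (illustrative sketch only; assumes `fusion`
  // has been built and scheduled by the caller, and that `t0`/`t1` are CUDA
  // tensors matching the fusion inputs):
  //
  //   FusionExecutor fe;
  //   fe.compileFusion(fusion, {t0, t1});
  //   std::vector<at::Tensor> outputs = fe.runFusion({t0, t1});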

  // Function to query whether a `FusionExecutor` has a compiled kernel to
  // execute.
  bool compiled() const {
    return fusion_id_ != -1 && lowered_;
  };

  void evictCache(size_t cache_id) {
    executor_entry_lookup_.erase(cache_id);
  }

  // Struct used to hold the necessary information to launch a compiled kernel
  // on a given input set.
  //
  // TODO: strides would also be important when we handle permutations in
  // codegen.
  //
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  struct ExecutorEntry {
    bool init = false;
    LaunchParams launch_params;
    std::vector<std::pair<int, int>> io_alias_indices;
    std::vector<std::vector<int64_t>> output_sizes;
    std::vector<std::vector<int64_t>> output_strides;
    std::vector<at::ScalarType> output_types;
    std::vector<std::vector<int64_t>> buffer_sizes;
    std::vector<at::ScalarType> buffer_types;
    std::vector<bool> buffer_zero_init;
    uint64_t rand_offset;
  };

  using ExecutorCompileTimeInfoCache =
      executor_utils::caching::ExecutorCompileTimeInfoCache;

  kir::Kernel* kernel() const {
    TORCH_INTERNAL_ASSERT(lowered_);
    return lowered_->kernel();
  }

  //! Internal knob used for debugging/profiling only
  void setExecuteKernelFlag(bool execute_kernel) {
    execute_kernel_ = execute_kernel;
  }

  //! Internal knob used for debugging/profiling only
  void setMeasureKernelTimeFlag(bool measure_kernel_time) {
    measure_kernel_time_ = measure_kernel_time;
  }

  //! Returns the last kernel execution time, in milliseconds
  //!
  //! \note The kernel time is only tracked if enabled by calling
  //!   setMeasureKernelTimeFlag(true)
  //!
  float kernelTimeMs() const {
    return measure_kernel_time_ ? kernel_time_ms_ : 0;
  }
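
  // Profiling sketch (illustrative only; assumes `fe` has already compiled the
  // fusion for these inputs):
  //
  //   fe.setMeasureKernelTimeFlag(true);
  //   fe.runFusion({t0, t1});
  //   float ms = fe.kernelTimeMs(); // last kernel execution time, in ms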

  //! Returns the number of bytes processed in the last kernel execution
  int64_t bytesProcessed() const {
    return bytes_processed_;
  }

  //! Returns the launch parameters from the last kernel execution
  LaunchParams lastLaunchParams() const {
    return launch_params_;
  }

  //! Returns the string of the compiled kernel
  std::string kernelString() const {
    return kernel_code_;
  }

  //! Returns the latest compile log
  std::string compilerLog() const {
    return last_compiler_log_;
  }

  std::string kernelName() const {
    std::stringstream ss;
    ss << "kernel" << fusion_id_;
    return ss.str();
  }

  //! Internal tests only. Compiles CUDA code with NVRTC directly from a
  //! string. This util provides a path to test runtime code, i.e. the resource
  //! strings.
  void compileRtc(
      const std::string& code,
      const std::string& name,
      bool structured = false,
      CompileOptions options = CompileOptions());

  //! Internal tests only. Runs the compiled CUDA kernel from compileRtc.
  void runRtc(
      const LaunchParams& launch_params,
      const std::vector<at::Tensor>& args);
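
  // Internal-testing sketch (illustrative only; `kernel_src` is a hypothetical
  // raw CUDA kernel string supplied by the test, and the launch params /
  // tensor arguments must match that kernel's signature):
  //
  //   FusionExecutor fe;
  //   fe.compileRtc(kernel_src, "my_test_kernel", /*structured=*/false);
  //   fe.runRtc(launch_params, {t0, t1});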

  //! Internal knob used for debugging/profiling only
  void disableLaunchParamCache() {
    disable_parameter_cache_ = true;
  }

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  struct GlobalBuffers {
    std::vector<at::Tensor> buffers;
    std::vector<bool> zero_init;
    at::Tensor profile_buffer;
  };

  static std::string kernelNamespace() {
    return "CudaCodeGen";
  }

  // Add preamble and wrap in namespace
  std::string getStructuredCode(const std::string& kernel);

  LaunchParams computeLaunchParams(
      const LaunchParams& launch_constraints,
      kir::ExpressionEvaluator& expr_eval,
      const int warp_size);

  uint64_t computeSharedMemory(
      kir::ExpressionEvaluator& expr_eval,
      const std::vector<const kir::Allocate*>& buffers,
      bool align_padding = false,
      uint64_t total = 0);

  // Returns a pair of vectors of tensors, where tensors in the first vector
  // are not initialized, while the second vector contains zero-initialized
  // tensors.
  GlobalBuffers allocGlobalVals(kir::ExpressionEvaluator& expr_eval);

  // alias_indices: indices of outputs that are aliases of inputs; we should
  // skip allocating real storage for those, but still keep their spots to
  // maintain the indexing from output aliases to inputs.
  std::vector<at::Tensor> allocOutputs(
      const KernelArgumentHolder& args,
      kir::ExpressionEvaluator& expr_eval,
      const std::unordered_set<int>& alias_indices = {});

  void setUsedTVs();

  const std::vector<TensorView*>& getUsedTVs() const {
    return used_tvs_;
  };

  ExecutorCompileTimeInfoCache* compileTimeDataCache() {
    return &compile_time_info_cache_;
  }

  //! Returns a KernelArgumentHolder representing the output sizes from kernel
  //! execution. Note: 1. this API ignores aliased outputs and instead pushes a
  //! scalar int 0 as a placeholder; 2. this API doesn't actually allocate
  //! outputs in memory; it is only used to infer output sizes.
  KernelArgumentHolder evaluateOutputSizes(
      const KernelArgumentHolder& args,
      kir::ExpressionEvaluator& expr_eval,
      const std::unordered_set<int>& alias_indices = {});

 private:
  CompileOptions options_;

  //! Current configured total shared mem size from cudaDeviceProp
  size_t configured_device_smem_ = std::numeric_limits<size_t>().max();

  //! Available shared memory space for dynamic allocation for the current
  //! compiled kernel at the current shared memory/L1 configuration
  c10::optional<size_t> maybe_available_dynamic_smem_ = c10::nullopt;

  //! Absolute limit of all available shared mem space from cudaDeviceProp
  size_t device_smem_limit_ = std::numeric_limits<size_t>().max();

  // Assuming sm70 or above, the limit of statically allocated smem is 48 KB.
  // See:
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-7-x
  // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#shared-memory-8-x
  const uint64_t max_static_smem_ = 48 << 10;
  int warp_size_ = 0;
  executor_utils::NvrtcFunction compiled_kernel_;

  // TensorViews actually used in the kernel.
  std::vector<TensorView*> used_tvs_;

  // Counter to be used for kernel name.
  int fusion_id_ = -1;
  static int fusion_id_counter_;

  std::unique_ptr<GpuLower> lowered_;
  // Copy of lowered_->kernel()
  Fusion* fusion_ = nullptr;

  // Track the block size this kernel was compiled with. If the block size
  // increases, recompile to adjust the max register count.
  int64_t block_size_high_water_mark = 1;

  // Lookup table used as a shortcut to retrieve recorded information in order
  // to launch kernels without re-inferring parameters.
  std::unordered_map<size_t, ExecutorEntry> executor_entry_lookup_;

  // Compile-time information caching. This is used for shape inference
  // support. The cache stores graph information that is available without
  // shape information, so that each shape inference call does not need to
  // recompute it.
  ExecutorCompileTimeInfoCache compile_time_info_cache_;

  // Cached expr eval
  std::unique_ptr<KernelPrecomputedValues> evaluator_precomputed_values_ =
      nullptr;

  // Profiling support: knob to control whether we actually execute the kernel
  // on the GPU or not
  bool execute_kernel_ = true;

  // Profiling support: knob to enable measuring kernel execution time
  bool measure_kernel_time_ = false;

  // Profiling support: the last kernel execution time, if measure_kernel_time_
  // is true
  float kernel_time_ms_ = 0;

  // Profiling support: the number of bytes processed by the last kernel
  int64_t bytes_processed_ = 0;

  // Profiling support: the last launch params used
  LaunchParams launch_params_;

  // Profiling support: disable caching of launch params and output allocation.
  // Output allocation caching is also disabled when output sizes depend on
  // runtime scalar inputs, such as in the case of tensor factories. See
  // https://github.com/csarofeen/pytorch/issues/2002
  bool disable_parameter_cache_ = false;

  // Profiling support: kept copy of the generated CUDA kernel
  std::string kernel_code_;

  // Profiling support: nvrtc log for debugging
  std::string last_compiler_log_;
};

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch