#include <executor.h>

#include <codegen.h>
#include <executor_kernel_arg.h>
#include <executor_utils.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_ir.h>
#include <lower_bank_conflict.h>
#include <utils.h>

#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/llvm_jit_strings.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/native/cuda/jit_utils.h>
#include <c10/core/DeviceGuard.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/irange.h>

#include <cmath>
#include <fstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

int FusionExecutor::fusion_id_counter_ = 0; // NOLINT

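// Debugging aid: when enabled via setFillAllocationWithNan below, freshly
// allocated, non-zero-initialized buffers are filled with NaN (or a poison
// value for integral types) so that reads of uninitialized memory show up as
// obviously wrong results instead of silently reusing stale data.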
bool fill_allocation_with_nan_ = false;

bool shouldFillAllocationWithNan() {
  return fill_allocation_with_nan_;
}

void setFillAllocationWithNan(bool value) {
  fill_allocation_with_nan_ = value;
}

namespace {

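// Returns the typedef that defines nvfuser_index_t as either a 32-bit or a
// 64-bit integer, matching the kernel's index mode.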
static const char* defineIndexMode(KernelIndexMode index_mode) {
  switch (index_mode) {
    case KernelIndexMode::INT32:
      return "typedef int nvfuser_index_t;\n";
    case KernelIndexMode::INT64:
      return "typedef int64_t nvfuser_index_t;\n";
    default:
      break;
  }

  TORCH_INTERNAL_ASSERT(false, "unknown indexing mode");
  return "";
}

static const char* defineIntegerTypes() {
  return R"(
typedef signed char int8_t;
typedef unsigned char uint8_t;
typedef short int int16_t;
typedef unsigned short int uint16_t;
typedef int int32_t;
typedef unsigned int uint32_t;
typedef long long int int64_t;
typedef unsigned long long int uint64_t;
)";
}

static const std::string& defineComplexTypes() {
  static std::string result = std::string(R"ESCAPE(
#define POS_INFINITY __int_as_float(0x7f800000)
#define INFINITY POS_INFINITY
#define NEG_INFINITY __int_as_float(0xff800000)
#define NAN __int_as_float(0x7fffffff)
)ESCAPE") +
      at::cuda::get_traits_string() + at::cuda::get_complex_body_string() +
      at::cuda::get_cmath_string() + at::cuda::get_complex_math_string();
  return result;
}

} // namespace

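// Assembles the full translation unit handed to NVRTC: the generated kernel
// is wrapped in the fuser namespace together with the integer, index, and
// complex type definitions and the kernel preamble. Also handles the debug
// dump and dump-to-file options.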
std::string FusionExecutor::getStructuredCode(const std::string& kernel) {
  // Generate the CUDA code.
  std::string code = "";
#ifdef USE_ROCM
#if ROCM_VERSION < 40200
  code += std::string("#include <hip/hip_runtime.h>\n") +
      std::string("#include <hip/hip_bf16.h>\n") +
      std::string("#include <hip/hip_fp16.h>\n");
#endif
  code += std::string("#pragma clang force_cuda_host_device begin\n");
#endif
  code += std::string("namespace ") + FusionExecutor::kernelNamespace() +
      " {\n" + defineIntegerTypes() + defineIndexMode(options_.index_mode) +
      defineComplexTypes() + executor_utils::kernelPreamble() + kernel + "}\n";
#ifdef USE_ROCM
  code += std::string("#pragma clang force_cuda_host_device end\n");
#endif

  if (isDebugDumpEnabled(DebugDumpOption::CudaKernel)) {
    std::cout << "\n======= Codegen output for kernel: " << kernelName()
              << " =======\n\n"
              << kernel << "\n======================================\n\n";
  } else if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
    std::cout << "\n======= Codegen output for kernel: " << kernelName()
              << " =======\n\n"
              << code << "\n======================================\n\n";
  }
  if (isDebugDumpEnabled(DebugDumpOption::CudaToFile) ||
      isDebugDumpEnabled(DebugDumpOption::DebugInfo)) {
    std::stringstream file_name;
    file_name << "__tmp_kernel" << fusion_id_ << ".cu";
    std::cout << "PRINTING: " << file_name.str() << std::endl;
    std::ofstream out(file_name.str());
    out << code << std::endl;
    out.close();
  }

  return code;
}

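// Compiles a fusion directly from a caller-provided CUDA source string.
// Intended for debugging: the caller supplies the code, kernel name, and
// fusion id, bypassing the normal codegen path.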
// TODO: come up with a more user friendly interface
void FusionExecutor::debugCompileFusionFromStr(
    Fusion* fusion,
    const std::string& code,
    const std::string& name,
    int id,
    CompileOptions options) {
  options_ = options;

  if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
    fusion->print();
  } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
    fusion->printMath();
  }

  if (isDebugDumpEnabled(DebugDumpOption::CudaFull)) {
    std::cout << "\n==== codegen output for kernel: " << kernelName()
              << " ====" << std::endl
              << code << std::endl
              << "======================================\n"
              << std::endl;
  }

  lowered_ = std::make_unique<GpuLower>(fusion);
  const auto kernel = lowered_->kernel();
  fusion_ = lowered_->kernel();

  fusion_id_ = id;
  setUsedTVs();

  if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
    kernel->print();
  }

  const auto& kernel_summary = kernel->summary();

  if (!kernel_summary.static_smem_allocations.empty()) {
    kir::ExpressionEvaluator static_evaluator;
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    const auto static_smem_size = computeSharedMemory(
        static_evaluator, kernel_summary.static_smem_allocations);
    TORCH_INTERNAL_ASSERT(
        static_smem_size < max_static_smem_,
        "The static shared memory allocation is larger than available memory.");
  }

  std::tie(compiled_kernel_, last_compiler_log_) =
      executor_utils::nvrtcCompile(code, name, fusion_id_);
  TORCH_INTERNAL_ASSERT(
      fusion_id_ > 0, "Assigning a fusion_id_ <= 0 is not accepted.");
}

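// Main compilation entry point. Lowers the fusion to kernel IR, generates the
// CUDA source, validates static shared memory usage, and compiles the kernel
// with NVRTC. When input arguments are provided, the inferred block size is
// used to seed the block-size high water mark for the compiled kernel.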
void FusionExecutor::compileFusion(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints) {
  FUSER_PERF_SCOPE("compileFusion");

  TORCH_INTERNAL_ASSERT(
      !fusion->outputs().empty(), "No output found for this kernel, aborting.");

  for (auto out : fusion->outputs()) {
    TORCH_INTERNAL_ASSERT(
        out->getValType() == ValType::TensorView,
        "Output types from fusions that are not tensors are not supported at this point.");

    const auto maybe_rfactor_domain =
        out->as<TensorView>()->getMaybeRFactorDomain();
    // Walk through outputs to see if output shapes are dependent on
    // non-tensor inputs. In that case, output allocation must be disabled,
    // since the caching id only looks at tensor shapes.
    // See issue https://github.com/csarofeen/pytorch/issues/2002
    std::vector<Val*> output_extents;
    for (const auto id : maybe_rfactor_domain) {
      Val* extent = nullptr;
      if (id->isReduction() || id->isStride()) {
        continue;
      } else if (id->isBroadcast() && id->hasExpandedExtent()) {
        extent = id->expandedExtent();
      } else {
        extent = id->extent();
      }
      output_extents.emplace_back(extent);
    }
    auto dependencies = InputsOf::outputs(fusion, output_extents);
    if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) {
          return val->isFusionInput();
        })) {
      // TODO: the parameter cache is too big a hammer here. We should consider
      // separating the caching logic of output sizes & launch params, since an
      // output size dependency should only invalidate the output sizes.
      disable_parameter_cache_ = true;
      break;
    }
  }

  if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {
    fusion->print();
  } else if (isDebugDumpEnabled(DebugDumpOption::FusionIrMath)) {
    fusion->printMath();
  }

  // TODO: refactor the options_ passed through
  options_.device = c10::Device(c10::DeviceType::CUDA, args.getDeviceIndex());
  options_.index_mode = args.getIndexMode();
  c10::DeviceGuard dg(options_.device);

  TORCH_INTERNAL_ASSERT(
      options_.device.is_cuda(), "Provided device to CUDA fuser is the CPU.");
  auto properties = at::cuda::getDeviceProperties(options_.device.index());
  configured_device_smem_ = properties->sharedMemPerBlock;
#ifndef USE_ROCM
  device_smem_limit_ = properties->sharedMemPerBlockOptin;
#else
  // don't know if rocm supports opt-in shared memory reconfiguration
  device_smem_limit_ = properties->sharedMemPerBlock;
#endif
  warp_size_ = properties->warpSize;

  lowered_ = std::make_unique<GpuLower>(
      fusion,
      options_.index_mode == KernelIndexMode::INT64 ? DataType::Int
                                                    : DataType::Int32);
  const auto kernel = lowered_->kernel();
  fusion_ = lowered_->kernel()->as<Fusion>();

  fusion_id_ = ++fusion_id_counter_;
  setUsedTVs();

  if (isDebugDumpEnabled(DebugDumpOption::KernelIr)) {
    kernel->print();
  }

  if (isDebugDumpEnabled(DebugDumpOption::BankConflictInfo)) {
    auto bank_conflict_info = getBankConflictInfo(kernel);
    if (bank_conflict_info.empty()) {
      std::cout << "===== No bank conflicts =====" << std::endl;
    } else {
      std::cout << "======= Bank conflicts =======" << std::endl;
      for (auto info : bank_conflict_info) {
        std::cout << "Expr: " << info.first->toString() << std::endl;
        auto conflict = info.second;
        if (conflict.first > 1) {
          std::cout << "input conflict: " << conflict.first << " way, ";
        }
        if (conflict.second > 1) {
          std::cout << "output conflict: " << conflict.second << " way";
        }
        std::cout << std::endl;
      }
      std::cout << "================================" << std::endl;
    }
  }

  kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
  const auto structured_code = getStructuredCode(kernel_code_);

  const auto& kernel_summary = kernel->summary();

  // We currently shouldn't allocate any more shared mem
  // tensors statically but could keep this path if
  // needed in later development.
  if (!kernel_summary.static_smem_allocations.empty()) {
    kir::ExpressionEvaluator static_evaluator;
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    const auto static_smem_size = computeSharedMemory(
        static_evaluator, kernel_summary.static_smem_allocations);
    TORCH_INTERNAL_ASSERT(
        static_smem_size < max_static_smem_,
        "The static shared memory allocation is larger than available memory.");
  }

  if (kernel_summary.has_dynamic_local_memory_allocations) {
    std::stringstream ss;
    ss << "Allocations must be based on constant integers for local memory. However, found: ";
    for (auto alloc : kernel_summary.dynamic_lmem_allocations) {
      ss << alloc->buffer()->toString() << ", ";
    }
    ss << " have dynamic allocations but are placed in local memory.";
    TORCH_INTERNAL_ASSERT(false, ss.str());
  }

  // TODO: pass block_size here;
  c10::optional<int> block_size = c10::nullopt;
  if (!args.empty()) {
    auto expr_eval = executor_utils::bindKernelInputs(args, kernel);
    auto launch_params =
        computeLaunchParams(launch_constraints, expr_eval, warp_size_);
    block_size = launch_params.nThreads();
    TORCH_INTERNAL_ASSERT(
        block_size > 0, "launch param inferred block size <= 0");
  }

  // TODO: the high water mark should be computed via the occupancy API after
  // compilation.

  // We essentially set the high water mark to 1 when no args are provided for
  // compilation; the resulting kernel just gets ditched at the first real
  // run - not great. We should have better heuristics.
  block_size_high_water_mark = std::max<int64_t>(
      (block_size.has_value() ? block_size.value() : 1),
      block_size_high_water_mark);
  std::tie(compiled_kernel_, last_compiler_log_) = executor_utils::nvrtcCompile(
      structured_code,
      (kernelNamespace() + "::" + kernelName()).c_str(),
      fusion_id_,
      block_size);
  TORCH_INTERNAL_ASSERT(
      fusion_id_ > 0, "failed to assign a fusion_id_ after compilation.");

#ifndef USE_ROCM
  // The driver API call requires an int argument.
  int max_dynamic_smem = 0;
  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuFuncGetAttribute(
      &max_dynamic_smem,
      CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
      compiled_kernel_.function));
  maybe_available_dynamic_smem_ = max_dynamic_smem;
#endif
}

namespace {

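// Fills a tensor with a "poison" value matching its dtype: NaN for floating
// point and complex types, the maximum representable value for integral
// types, and true for bool. Used by the fill-allocation-with-NaN debug mode.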
void fillTensorWithNan(at::Tensor& t) {
  switch (t.scalar_type()) {
    case at::ScalarType::Byte:
      t.fill_(0xFF);
      break;
    case at::ScalarType::Char:
      t.fill_(0x7F);
      break;
    case at::ScalarType::Short:
      t.fill_(0x7FFF);
      break;
    case at::ScalarType::Int:
      t.fill_(0x7FFFFFFF);
      break;
    case at::ScalarType::Long:
      t.fill_(0x7FFFFFFFFFFFFFFFL);
      break;
    case at::ScalarType::Bool:
      t.fill_(true);
      break;
    case at::ScalarType::Half:
    case at::ScalarType::Float:
    case at::ScalarType::Double:
    case at::ScalarType::BFloat16:
      t.fill_(std::nan(""));
      break;
    case at::ScalarType::ComplexHalf:
    case at::ScalarType::ComplexFloat:
    case at::ScalarType::ComplexDouble:
      t.fill_(c10::complex<double>(std::nan(""), std::nan("")));
      break;
    default:
      TORCH_INTERNAL_ASSERT(false, "Unknown dtype");
  }
}

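// Evaluates the given extent expressions with the expression evaluator and
// allocates a tensor of the inferred shape on the compile-time device. When
// expanded broadcast dimensions are present, the tensor is allocated with the
// pre-expansion shape and then expanded, so no memory is materialized for the
// expanded dimensions.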
at::Tensor inferAndAlloc(
    const TensorView* tv,
    const std::vector<Val*>& sizes,
    kir::ExpressionEvaluator& expr_eval,
    // Map from dim -> expanded size of TV if any expanded broadcast dimensions
    // exist
    std::unordered_map<int, Val*> expanded_map,
    const CompileOptions& options,
    bool zero_init = false) {
  FUSER_PERF_SCOPE("inferAndAlloc");

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  // Going to infer all the sizes of the TensorView
  std::vector<int64_t> inferred_sizes;
  // Expanded sizes is at most the same length as inferred_sizes, as you could
  // have a fully broadcasted tensor that's being expanded
  std::vector<int64_t> expanded_sizes;
  bool expanded_dim = false;
  for (const auto size : sizes) {
    const auto inferred_val = expr_eval.evaluate(size);
    TORCH_INTERNAL_ASSERT(
        inferred_val.has_value(),
        "Could not launch kernel as program could not infer ",
        size->toString(),
        "(",
        size->name(),
        ") for the buffer ",
        tv->toString());
    inferred_sizes.push_back(inferred_val->as<int64_t>());
    if (expanded_map.count(expanded_sizes.size())) {
      auto expanded_size = expanded_map.at(expanded_sizes.size());
      const auto inferred_expanded_size = expr_eval.evaluate(expanded_size);
      TORCH_INTERNAL_ASSERT(
          inferred_expanded_size.has_value(),
          "Could not launch kernel as program could not infer the expanded extent ",
          expanded_size->toString(),
          "(",
          expanded_size->name(),
          ") for the buffer ",
          tv->toString());
      if (inferred_val.value() != 1) {
        TORCH_INTERNAL_ASSERT(
            inferred_val.value() == inferred_expanded_size.value(),
            "Attempted an expand on a non-broadcasted dimension,",
            " but the expand doesn't match the dimension's size.");
      } else {
        expanded_dim = true;
      }
      expanded_sizes.push_back(inferred_expanded_size->as<int64_t>());
    } else {
      expanded_sizes.push_back(inferred_val->as<int64_t>());
    }
  }

  const auto at_type = data_type_to_aten(tv->dtype());
  const auto tensor_options =
      at::TensorOptions().dtype(at_type).device(options.device);
  c10::IntArrayRef isizes(inferred_sizes);

  if (zero_init) {
    auto zeros = at::zeros(isizes, tensor_options);
    if (expanded_dim) {
      return zeros.expand(expanded_sizes);
    }
    return zeros;
  } else {
    // Non Variable type guard for empty_cuda call
    at::AutoDispatchBelowADInplaceOrView non_variable_type_mode;
    auto empty = at::empty(isizes, tensor_options);
    if (shouldFillAllocationWithNan()) {
      fillTensorWithNan(empty);
    }
    if (expanded_dim) {
      return empty.expand(expanded_sizes);
    }
    return empty;
  }
}

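// Collects the (possibly rfactored) root domain extents of an output
// TensorView, skipping reduction and stride dimensions, and delegates to
// inferAndAlloc, forwarding expanded broadcast extents via the expand map.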
at::Tensor inferAndAllocOutput(
    const TensorView* tv,
    kir::ExpressionEvaluator& expr_eval,
    const CompileOptions& options,
    bool zero_init = false) {
  const auto domain = tv->domain();
  const auto maybe_rfactor_domain = domain->hasRFactor()
      ? domain->getRFactorDomain()
      : domain->getRootDomain();

  std::vector<Val*> sizes;
  std::unordered_map<int, Val*> expand_map;

  for (const auto id : maybe_rfactor_domain) {
    if (id->isReduction() || id->isStride()) {
      continue;
    }
    sizes.push_back(id->extent());
    if (id->isBroadcast() && id->hasExpandedExtent()) {
      expand_map[sizes.size() - 1] = id->expandedExtent();
    }
  }
  return inferAndAlloc(tv, sizes, expr_eval, expand_map, options, zero_init);
}

} // namespace

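// Sums up the shared memory required by the given allocations, skipping
// buffers that alias another buffer. With align_padding set, the running
// total is padded to the dynamic shared memory alignment before each buffer
// is added.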
uint64_t FusionExecutor::computeSharedMemory(
    kir::ExpressionEvaluator& expr_eval,
    const std::vector<const kir::Allocate*>& buffers,
    bool align_padding,
    uint64_t total) {
  FUSER_PERF_SCOPE("computeSharedMemory");
  for (auto smem_alloc : buffers) {
    // If this buffer aliases another buffer,
    // then do not allocate memory for this buffer.
    if (smem_alloc->alias() == nullptr) {
      const auto inferred_val = expr_eval.evaluate(smem_alloc->size());
      if (inferred_val.has_value()) {
        const uint64_t data_size = dataTypeSize(smem_alloc->buffer()->dtype());
        // Add padding to align dynamic shared memory
        if (align_padding) {
#ifndef USE_ROCM
          const int align_size = 16; // always align to 16B/128b.
#else
          const int align_size = 8; // see codegen.cpp for HIP
#endif
          total = ceilDiv(total, align_size) * align_size;
        }
        total += inferred_val->as<int64_t>() * data_size;
      } else {
        TORCH_INTERNAL_ASSERT(
            false,
            "Failed to evaluate the size ",
            smem_alloc->size(),
            " of shared memory buffer - T",
            smem_alloc->buffer()->name());
      }
    }
  }
  return total;
}

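// Derives the launch configuration for the compiled kernel. Parallel extents
// are either taken from the launch constraints or inferred with the
// expression evaluator (padding block dimensions to a multiple of the warp
// size where the kernel requests it), and the dynamic shared memory size is
// computed from the kernel's reduction/broadcast workspace plus its dynamic
// shared memory allocations.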
LaunchParams FusionExecutor::computeLaunchParams(
    const LaunchParams& launch_constraints,
    kir::ExpressionEvaluator& expr_eval,
    const int warp_size) {
  FUSER_PERF_SCOPE("FusionExecutor::ComputeLaunchParams");
  TORCH_INTERNAL_ASSERT(warp_size > 0, "WARP_SIZE should be larger than 0");

  LaunchParams launch_params;

  auto data_cache = compileTimeDataCache();

  auto lower = lowered_.get();
  auto& used_tvs = getUsedTVs();
  auto parallel_binding_ids_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::ParallelBindingIterDomains>(
          data_cache, [&used_tvs, &lower]() {
            return std::make_unique<std::vector<IterDomain*>>(
                executor_utils::getParallelBindingsIterDomains(
                    lower, used_tvs));
          });
  auto& parallel_binding_ids = parallel_binding_ids_entry.get();

  auto parallel_iter_extent_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::ParallelIterExtentMap>(
          data_cache, [&parallel_binding_ids]() {
            return executor_utils::getParallelIterExtents(parallel_binding_ids);
          });
  auto& parallel_iter_extents = parallel_iter_extent_entry.get();

  auto simplified_parallel_iter_extent_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::SimplifiedParallelIterExtentMap>(
          data_cache, [&parallel_binding_ids, &lower]() {
            return executor_utils::getSimplifiedParallelIterExtents(
                lower, parallel_binding_ids);
          });
  auto& simplified_parallel_iter_extents =
      simplified_parallel_iter_extent_entry.get();

  auto warp_padded_parallel_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::WarpPaddedParallelExtents>(
          data_cache, [&parallel_binding_ids, &lower]() {
            return executor_utils::getWarpPaddedExtentsInfo(
                lower->kernel(), parallel_binding_ids);
          });
  auto& warp_padded_extent_set =
      warp_padded_parallel_entry.get().warp_padded_extent_set;
  auto& warp_padded_constant =
      warp_padded_parallel_entry.get().warp_padded_constant;

  // TODO: Need to redesign this part a bit to
  //   find the right place to trigger evaluate
  if (expr_eval.precomputedValues()) {
    expr_eval.precomputedValues()->bindParallelExtents(
        parallel_iter_extents, launch_constraints);
    expr_eval.precomputedValues()->evaluate();
  }

  // If any dimension was set in the launch constraints we need to run through
  // the IterDomains that have been parallelized and bind those values, or, if
  // they could be inferred, make sure the inference matches what was set.
  for (auto& entry : parallel_iter_extents) {
    auto p_type = entry.first;
    if (launch_constraints.hasDim(p_type)) {
      auto parallel_extents = entry.second;
      for (auto extent : parallel_extents) {
        auto inferred_val = expr_eval.evaluate(extent);
        if (inferred_val.has_value()) {
          // This value could have been inferred, make sure it was set right.
          bool valid =
              inferred_val.value() == launch_constraints.getDim(p_type) ||
              launch_constraints.getRawVal(p_type) == -1;
          if (!useFallback() && !valid) {
            TORCH_WARN_ONCE(
                "Cannot validate parallelization scheme, "
                "this may be due to mixed broadcast axes that are parallelized.");
          }
        } else if (!expr_eval.precomputedValues()) {
          expr_eval.bind(extent, launch_constraints.getDim(p_type));
        }
        if (!launch_params.hasDim(p_type)) {
          // Bind the launch constraint into our evaluation context
          launch_params.bind(launch_constraints.getDim(p_type), p_type);
          // Makes sure the p-types bound to evaluators are the
          // final values that will become the actual launch
          // param size to ensure accurate smem buffer size
          // computation.
          expr_eval.bind(p_type, launch_constraints.getDim(p_type));
        }
      }
    }
  }

  // Run through the rest of the parallel IterDomains and infer their size
  for (auto& entry : simplified_parallel_iter_extents) {
    FUSER_PERF_SCOPE("FusionExecutor::ParallelBindingResolution");
    auto p_type = entry.first;
    auto parallel_extents = entry.second;
    // Select the maximum value out of all the parallel extents
    int64_t maximum_value = std::numeric_limits<int64_t>::min();
    for (auto extent : parallel_extents) {
      auto val = expr_eval.evaluate(extent);
      TORCH_INTERNAL_ASSERT(
          val.has_value(),
          "Tried to evaluate the extent, ",
          extent->toInlineString(),
          " for the ptype: ",
          p_type,
          " to set launch bounds but could not.");

      // Apply padding to the extent if needed
      if (warp_padded_extent_set.count(extent)) {
        // Check if the extent has a const value
        auto padded_constant_it = warp_padded_constant.find(extent);

        if (padded_constant_it != warp_padded_constant.end()) {
          // If already specified as padded to a constant, check that the
          // runtime value is not over the constant bound
          TORCH_INTERNAL_ASSERT(*val <= padded_constant_it->second);
          *val = padded_constant_it->second;
        } else {
          // If no constant is specified, pad to the smallest multiple of the
          // warp size above the value.
          auto padded_number_of_warps = (*val + warp_size - 1) / warp_size;
          *val = warp_size * padded_number_of_warps;
        }
        TORCH_INTERNAL_ASSERT(
            *val <= 1024, "padded dimension larger than max block size");
      }
      maximum_value = std::max(maximum_value, val->as<int64_t>());
    }
    // Protect against size-0 tensors; they still have a value, so we would
    // rather bind nothing than 0.
    if (maximum_value > 0) {
      expr_eval.bind(p_type, maximum_value);
      launch_params.bind(maximum_value, p_type);
    }
  }

  // Re-run the integer machine with all
  // the thread sizes now determined.
  if (expr_eval.precomputedValues()) {
    expr_eval.precomputedValues()->evaluate();
  }

  const auto kernel = lowered_->kernel();
  const auto& kernel_summary = kernel->summary();

  // Calculate the dynamic shared memory size:
  // add workspace for reduction and broadcast
  uint64_t reduction_broadcast_workspace = 0;
  const bool has_workspace = kernel_summary.has_block_reductions ||
      kernel_summary.has_grid_reductions ||
      kernel_summary.has_block_broadcasts || kernel_summary.has_grid_broadcasts;
  if (has_workspace &&
      kernel_summary.largest_smem_data_type != DataType::Null) {
    // Not using nThreads here since it does not handle uninitialized values

    // TODO: here is an optimization opportunity since welford uses int64_t for
    // N while the data type is not necessarily double. But it may need more
    // work on the alignment
    const int welford_factor =
        kernel_summary.has_block_welford || kernel_summary.has_grid_welford ? 3
                                                                            : 1;
    reduction_broadcast_workspace =
        dataTypeSize(kernel_summary.largest_smem_data_type) * welford_factor *
        launch_params.bdimx() * launch_params.bdimy() * launch_params.bdimz();
  }

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  const uint64_t dynamic_smem_size = computeSharedMemory(
      expr_eval,
      kernel_summary.dynamic_smem_allocations,
      true,
      reduction_broadcast_workspace);

  // Check that the requested smem size can be dynamically allocated.
  // This check is only done once a kernel has been compiled, since
  // maybe_available_dynamic_smem_ needs to be evaluated on
  // a compiled kernel.
  if (maybe_available_dynamic_smem_.has_value()) {
    // Dynamic shared memory space that we can allocate without
    // carving more space from L1.
    const uint64_t available_dynamic_smem_without_reconfiguration =
        maybe_available_dynamic_smem_.value();
    // Maximum additional shared memory size we could request
    // if we do re-configuration.
    const uint64_t additional_dynamic_smem_available_through_reconfiguration =
        device_smem_limit_ - configured_device_smem_;

    TORCH_INTERNAL_ASSERT(
        (dynamic_smem_size) <
            (available_dynamic_smem_without_reconfiguration +
             additional_dynamic_smem_available_through_reconfiguration),
        "The total shared memory allocation is larger than available memory.",
        " Dynamic size: ",
        dynamic_smem_size,
        ". Available size: ",
        maybe_available_dynamic_smem_.value(),
        ". Configured smem size: ",
        configured_device_smem_,
        ". Device limit size: ",
        device_smem_limit_);
  }

  launch_params.setSmem(dynamic_smem_size);

  return launch_params;
}

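// Allocates the intermediate global buffers the kernel needs (e.g. reduction
// workspaces and semaphores), zero-initializing them when the kernel IR
// requests it, and records which buffer is used for kernel profiling when
// that option is enabled.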
FusionExecutor::GlobalBuffers FusionExecutor::allocGlobalVals(
    kir::ExpressionEvaluator& expr_eval) {
  FUSER_PERF_SCOPE("FusionExecutor::AllocGlobalVals");
  GlobalBuffers global_buffers;
  const auto kernel = lowered_->kernel();
  const auto& kernel_summary = kernel->summary();
  for (auto alloc : kernel_summary.global_allocations) {
    TORCH_INTERNAL_ASSERT(
        alloc->buffer()->isA<TensorView>(),
        "Cannot allocate global buffers that are not tensors.");
    auto tv = alloc->buffer()->as<TensorView>();
    if (tv->isFusionOutput()) {
      continue;
    }
    if (alloc->zeroInit()) {
      global_buffers.buffers.push_back(
          inferAndAlloc(tv, alloc->shape(), expr_eval, {}, options_, true));
      global_buffers.zero_init.push_back(true);
    } else {
      global_buffers.buffers.push_back(
          inferAndAlloc(tv, alloc->shape(), expr_eval, {}, options_, false));
      global_buffers.zero_init.push_back(false);
    }
    // Remember the tensor buffer used for storing the kernel profile
    if (isOptionEnabled(EnableOption::KernelProfile) &&
        tv == kernel->profile().getBuffer()) {
      global_buffers.profile_buffer = global_buffers.buffers.back();
    }
  }

  return global_buffers;
}

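// Allocates the fusion outputs for a kernel launch. Outputs that trivially
// forward an input get a placeholder empty tensor (handled at integration
// time), outputs that alias an input are left as empty placeholders to be
// filled in by the caller, and everything else is allocated from the inferred
// output shape.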
std::vector<at::Tensor> FusionExecutor::allocOutputs(
    const KernelArgumentHolder& args,
    kir::ExpressionEvaluator& expr_eval,
    const std::unordered_set<int>& alias_indices) {
  FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs");
  const auto kernel = lowered_->kernel();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Tensor> outputs;
  TORCH_INTERNAL_ASSERT(
      args.size() == kernel->inputs().size(),
      "kernel arguments length does not match runtime arguments.");
  for (const auto out_i : c10::irange(kernel->outputs().size())) {
    if (kernel->outputs()[out_i]->isFusionInput()) {
      // Push an empty tensor for trivial forwarding, since we handle this in
      // integration; see step 1 - note [trivial forwarding]
      c10::Device device(c10::DeviceType::CUDA, args.getDeviceIndex());
      const auto tensor_options =
          at::TensorOptions().dtype(at::kFloat).device(device);
      outputs.emplace_back(at::empty({0}, tensor_options));
    } else {
      TORCH_INTERNAL_ASSERT(
          kernel->outputs()[out_i]->isA<TensorView>(),
          "Cannot allocate outputs that are not tensors.");
      auto output = kernel->outputs()[out_i]->as<TensorView>();
      if (alias_indices.count(out_i) != 0) {
        // Aliasing an input, no need to allocate a real output, just push an
        // empty tensor here.
        outputs.emplace_back();
      } else {
        outputs.push_back(
            inferAndAllocOutput(output, expr_eval, options_, false));
      }
    }
  }
  return outputs;
}

void FusionExecutor::setUsedTVs() {
  auto used_vals = fusion_->usedMathVals();
  auto used_tvs = ir_utils::filterByType<TensorView>(used_vals);
  used_tvs_.clear();
  used_tvs_.insert(used_tvs_.begin(), used_tvs.begin(), used_tvs.end());
}

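// Meta-device counterpart of allocOutputs: evaluates output shapes without
// touching the GPU by allocating on the meta device, so only sizes, strides,
// and dtypes are produced. Used by inferOutputSizes for pre-allocation.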
KernelArgumentHolder FusionExecutor::evaluateOutputSizes(
    const KernelArgumentHolder& args,
    kir::ExpressionEvaluator& expr_eval,
    const std::unordered_set<int>& alias_indices) {
  FUSER_PERF_SCOPE("FusionExecutor::AllocOutputs");
  const auto kernel = lowered_->kernel();

  KernelArgumentHolder ret(args.getIndexMode());
  ret.setDeviceIndex(args.getDeviceIndex());

  CompileOptions meta_options = options_;
  meta_options.device = c10::Device(DeviceType::Meta, 0);

  for (const auto out_i : c10::irange(kernel->outputs().size())) {
    // If the output is just trivially the input, just "copy" it over, see note
    // [trivial forwarding]
    if (kernel->outputs()[out_i]->isFusionInput()) {
      for (auto inp_i : c10::irange(kernel->inputs().size())) {
        if (kernel->inputs()[inp_i] == kernel->outputs()[out_i]) {
          TORCH_INTERNAL_ASSERT(
              inp_i < args.size(),
              "Issue with an input showing up as output, couldn't find input.");

          auto tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(args[inp_i]);
          TORCH_INTERNAL_ASSERT(
              tensor_arg_abstract,
              "Cannot register a scalar as an output in a fusion.");
          ret.push(tensor_arg_abstract);
          break;
        }
      }
    } else {
      TORCH_INTERNAL_ASSERT(
          kernel->outputs()[out_i]->isA<TensorView>(),
          "Cannot allocate outputs that are not tensors.");
      auto output = kernel->outputs()[out_i]->as<TensorView>();
      if (alias_indices.count(out_i) != 0) {
        // Aliasing an input, no need to allocate a real output,
        // but we still need to push an entry here.
        ret.push(int64_t(0));
      } else {
        // TODO: we are using meta here, which is bad since it doesn't account
        // for devices. Switch to fake tensor instead
        ret.push(inferAndAllocOutput(output, expr_eval, meta_options, false));
      }
    }
  }
  return ret;
}

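// Pre-execution path used to infer output sizes for a compiled fusion: binds
// the provided arguments, computes launch params (which also binds the thread
// extents needed for shape inference), and returns output metadata via
// meta-device allocations, swapping in the aliased inputs where the fusion
// declares input/output aliasing.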
KernelArgumentHolder FusionExecutor::inferOutputSizes(
    const KernelArgumentHolder& args,
    const LaunchParams& launch_constraints) {
  FUSER_PERF_SCOPE("FusionExecutor::RunFusion");

  ExecutorEntry* executor_entry = nullptr;
  c10::optional<size_t> opt_code = args.getCacheId();
  if (opt_code.has_value()) {
    executor_entry = &executor_entry_lookup_[*opt_code];
  }

  at::cuda::jit::initializeCudaContext();
  TORCH_INTERNAL_ASSERT(lowered_);

  TORCH_INTERNAL_ASSERT(
      !executor_entry || !executor_entry->init,
      "compile kernel shouldn't hit a pre-existing cache");
  FUSER_PERF_SCOPE("ExecutorRunFusion::ValidateAndInitialize");
  // TODO: validating kernel inputs currently won't be happy, since our fusion
  // args are mapped with `meta` tensors instead of `cuda` tensors; check if
  // this would be resolved with FakeTensor
  // executor_utils::validateKernelInputs(fusion_, args, options_.device);

  if (!evaluator_precomputed_values_) {
    evaluator_precomputed_values_ =
        std::make_unique<KernelPrecomputedValues>(lowered_->kernel());
  }

  kir::ExpressionEvaluator expr_eval;
  evaluator_precomputed_values_->bindKernelInputs(lowered_->kernel(), args);
  expr_eval.precomputedValues() = evaluator_precomputed_values_.get();

  // I think this binds something to expr_eval, so even though we are not using
  // launch_params_, we still need this in order to infer output shapes.
  launch_params_ =
      computeLaunchParams(launch_constraints, expr_eval, warp_size_);

  executor_utils::validateVectorizedTensors(
      lowered_.get()->kernel(), args, {}, compileTimeDataCache(), expr_eval);

  auto alias_indices_entry = executor_utils::caching::ExecutorCompileTimeEntry<
      executor_utils::caching::InputAliasIndices>(
      compileTimeDataCache(), [&]() {
        return std::make_unique<std::vector<std::pair<int, int>>>(
            fusion_->getInputAliasIndices());
      });

  auto& alias_indices = alias_indices_entry.get();

  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto output_alias_indices_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::OutputAliasIndices>(
          compileTimeDataCache(), [&]() {
            return std::make_unique<std::unordered_set<int>>(
                fusion_->getOutputAliasIndices());
          });

  auto& output_alias_indices = output_alias_indices_entry.get();

  auto ret = evaluateOutputSizes(args, expr_eval, output_alias_indices);

  for (const auto& entry : alias_indices) {
    auto aliased_output_index = entry.first;
    auto aliased_input_index = entry.second;
    TORCH_INTERNAL_ASSERT(
        args[aliased_input_index]->isType(ArgType::Tensor),
        "alias io only supports tensor");
    ret.swap(aliased_output_index, args[aliased_input_index]);
  }

  return ret;
}

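// Executes the compiled fusion. The fast path reuses a cached executor entry
// (launch params, output/buffer shapes) keyed by the arguments' cache id; the
// slow path validates inputs, computes launch params, recompiles if the block
// size grew past the high water mark, allocates outputs and intermediate
// buffers, and then launches the kernel (cooperatively when grid reductions
// require it), optionally timing it and reporting effective bandwidth.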
std::vector<at::Tensor> FusionExecutor::runFusion(
    KernelArgumentHolder& args,
    const LaunchParams& launch_constraints,
    const std::vector<at::Tensor>& outputs) {
  FUSER_PERF_SCOPE("FusionExecutor::RunFusion");
  TORCH_INTERNAL_ASSERT(compiled());
  TORCH_INTERNAL_ASSERT(
      fusion_id_ > 0, "Cannot run fusion, it was not compiled.");
  TORCH_INTERNAL_ASSERT(
      !args.getCacheId().has_value() || outputs.empty(),
      "short cut input cache is not compatible with pre-allocated output");

  size_t num_inputs = args.size();

  if (isDebugDumpEnabled(DebugDumpOption::FusionArgs)) {
    std::cout << "Arguments for fusion" << fusion_id_ << ":" << std::endl
              << "Inputs:" << std::endl;
    for (auto i : c10::irange(args.size())) {
      args[i]->print();
    }
    std::cout << "Outputs:" << std::endl;
    for (const auto& output : outputs) {
      std::cout << "  " << output.scalar_type() << " " << output.sizes()
                << " (strides = " << output.strides() << ")" << std::endl;
    }
    std::cout << launch_constraints.toString();
  }

  ExecutorEntry* executor_entry = nullptr;
  if (args.getCacheId().has_value()) {
    executor_entry = &executor_entry_lookup_[*args.getCacheId()];
  }

  c10::DeviceGuard dg(options_.device);
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::jit::initializeCudaContext();
  TORCH_INTERNAL_ASSERT(lowered_);
  launch_params_ = LaunchParams();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<at::Tensor> allocated_outputs;
  GlobalBuffers global_buffers;
  uint64_t rand_offset = 0;

  if (executor_entry && executor_entry->init && !disable_parameter_cache_) {
    {
      // RAII guard to disable autograd for the `empty_cuda` calls below
      at::AutoDispatchBelowADInplaceOrView non_variable_type_mode;
      // take the short-cut for launch if we see a recorded input set again
      launch_params_ = executor_entry->launch_params;
      // only allocate outputs when not given
      if (outputs.empty()) {
        FUSER_PERF_SCOPE("ExecutorRunFusion::OutputAlloc");
        for (const auto i : c10::irange(executor_entry->output_sizes.size())) {
          allocated_outputs.push_back(at::native::empty_strided_cuda(
              executor_entry->output_sizes[i],
              executor_entry->output_strides[i],
              executor_entry->output_types[i],
              c10::nullopt,
              options_.device,
              c10::nullopt));
          if (shouldFillAllocationWithNan()) {
            fillTensorWithNan(allocated_outputs.back());
          }
        }
        // Note: aliased outputs are not returned as outputs. But we still need
        // them for kernel execution, so we push them to args.
        for (const auto& entry : executor_entry->io_alias_indices) {
          auto aliased_output_index = entry.first;
          auto aliased_input_index = entry.second;
          auto tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(args[aliased_input_index]);
          TORCH_INTERNAL_ASSERT(
              tensor_arg_abstract, "alias io only supports tensor");
          allocated_outputs[aliased_output_index] =
              tensor_arg_abstract->getTensor();
        }
        args.push(allocated_outputs);
      } else {
        TORCH_INTERNAL_ASSERT(
            outputs.size() == fusion_->outputs().size(),
            __func__,
            " provided number of outputs does not match fusion outputs");
        allocated_outputs = outputs;
        args.push(outputs);
      }

      {
        FUSER_PERF_SCOPE("ExecutorRunFusion::IntermediateBufferAlloc");
        for (const auto i : c10::irange(executor_entry->buffer_sizes.size())) {
          if (executor_entry->buffer_zero_init[i]) {
            global_buffers.buffers.push_back(at::zeros(
                executor_entry->buffer_sizes[i],
                at::TensorOptions()
                    .dtype(executor_entry->buffer_types[i])
                    .device(options_.device)));
            global_buffers.zero_init.push_back(true);
          } else {
            global_buffers.buffers.push_back(at::native::empty_cuda(
                executor_entry->buffer_sizes[i],
                executor_entry->buffer_types[i],
                c10::nullopt,
                options_.device,
                c10::nullopt));
            if (shouldFillAllocationWithNan()) {
              fillTensorWithNan(global_buffers.buffers.back());
            }
            global_buffers.zero_init.push_back(false);
          }
        }
      }
    }
    rand_offset = executor_entry->rand_offset;
  } else {
    FUSER_PERF_SCOPE("ExecutorRunFusion::ValidateAndInitialize");
    // code path to take when either:
    // 1. no opt_code is provided or
    // 2. `executor_entry` is not initialized
    executor_utils::validateKernelInputs(fusion_, args, options_.device);

    if (!evaluator_precomputed_values_) {
      evaluator_precomputed_values_ =
          std::make_unique<KernelPrecomputedValues>(lowered_->kernel());
    }

    kir::ExpressionEvaluator expr_eval;
    evaluator_precomputed_values_->bindKernelInputs(lowered_->kernel(), args);
    expr_eval.precomputedValues() = evaluator_precomputed_values_.get();

    launch_params_ =
        computeLaunchParams(launch_constraints, expr_eval, warp_size_);

    // Recompile the kernel if the number of threads in the block has increased
    if (launch_params_.nThreads() > block_size_high_water_mark) {
      const auto kernel = lowered_->kernel();
      kernel_code_ = codegen::generateCudaKernel(kernel, kernelName());
      const auto structured_code = getStructuredCode(kernel_code_);
      block_size_high_water_mark = launch_params_.nThreads();

      std::tie(compiled_kernel_, last_compiler_log_) =
          executor_utils::nvrtcCompile(
              structured_code,
              (kernelNamespace() + "::" + kernelName()).c_str(),
              fusion_id_,
              block_size_high_water_mark);
    }

    if (kernel()->summary().has_cooperative_grid_reduction) {
#ifndef USE_ROCM
      int num_blocks_per_SM = -1;
      at::globalContext().getNVRTC().cuOccupancyMaxActiveBlocksPerMultiprocessor(
          &num_blocks_per_SM,
          compiled_kernel_.function,
          (int)(launch_params_.bdimx() * launch_params_.bdimy() * launch_params_.bdimz()),
          (size_t)launch_params_.smem());

      TORCH_INTERNAL_ASSERT(
          (int64_t)(
              num_blocks_per_SM *
              at::cuda::getDeviceProperties(options_.device.index())
                  ->multiProcessorCount) >= launch_params_.gdimx() *
              launch_params_.gdimy() * launch_params_.gdimz(),
          "Wanted to launch a cooperative kernel, however the number of blocks is greater than ",
          "what can be resident on the GPU at once. Need: ",
          launch_params_.gdimx() * launch_params_.gdimy() *
              launch_params_.gdimz(),
          " (",
          launch_params_.gdimx(),
          " * ",
          launch_params_.gdimy(),
          " * ",
          launch_params_.gdimz(),
          ") but limited to ",
          num_blocks_per_SM,
          " * ",
          at::cuda::getDeviceProperties(options_.device.index())
              ->multiProcessorCount);
#else
      TORCH_INTERNAL_ASSERT(
          false, "Cross grid communication not supported with HIP.");
#endif
    }

    executor_utils::validateVectorizedTensors(
        lowered_.get()->kernel(),
        args,
        outputs,
        compileTimeDataCache(),
        expr_eval);

    auto alias_indices_entry =
        executor_utils::caching::ExecutorCompileTimeEntry<
            executor_utils::caching::InputAliasIndices>(
            compileTimeDataCache(), [&]() {
              return std::make_unique<std::vector<std::pair<int, int>>>(
                  fusion_->getInputAliasIndices());
            });

    auto& alias_indices = alias_indices_entry.get();

    // NOLINTNEXTLINE(bugprone-branch-clone)
    if (outputs.empty()) {
      auto output_alias_indices_entry =
          executor_utils::caching::ExecutorCompileTimeEntry<
              executor_utils::caching::OutputAliasIndices>(
              compileTimeDataCache(), [&]() {
                return std::make_unique<std::unordered_set<int>>(
                    fusion_->getOutputAliasIndices());
              });

      auto& output_alias_indices = output_alias_indices_entry.get();

      allocated_outputs = allocOutputs(args, expr_eval, output_alias_indices);

      for (const auto& entry : alias_indices) {
        auto aliased_output_index = entry.first;
        auto aliased_input_index = entry.second;
        auto tensor_arg_abstract =
            dynamic_cast<const TensorArgAbstract*>(args[aliased_input_index]);
        TORCH_INTERNAL_ASSERT(
            tensor_arg_abstract, "alias io only supports tensor");
        allocated_outputs[aliased_output_index] =
            tensor_arg_abstract->getTensor();
      }
      args.push(allocated_outputs);
    } else {
      allocated_outputs = outputs;
      args.push(outputs);
      executor_utils::validateKernelOutputs(
          fusion_, allocated_outputs, options_.device);
    }

    global_buffers = allocGlobalVals(expr_eval);

    if (kernel()->summary().max_rng_offsets >= 0) {
      // NOTE: this is how we map offsets to pointwise kernels in order to get
      // a random number generator that matches native PyTorch results. But it
      // doesn't really work, as it makes assumptions about how threads are
      // bound, which is not generally how the scheduler handles that. Refer
      // to `Philox` in the generated kernel to understand how the mapping
      // works.
      rand_offset = (kernel()->summary().max_rng_offsets + 1) * 4;
    }

    // This is the entry when we have provided `opt_code` but the entry has not
    // been initialized yet.
    if (executor_entry) {
      FUSER_PERF_SCOPE("ExecutorRunFusion::FillCacheEntry");
      // record the short-cut executor entry for the given input set
      executor_entry->launch_params = launch_params_;
      executor_entry->io_alias_indices = alias_indices;
      for (const auto& output : allocated_outputs) {
        executor_entry->output_sizes.push_back(output.sizes().vec());
        executor_entry->output_strides.push_back(output.strides().vec());
        executor_entry->output_types.push_back(output.scalar_type());
      }

      for (const auto& i : c10::irange(global_buffers.buffers.size())) {
        executor_entry->buffer_sizes.push_back(
            global_buffers.buffers[i].sizes().vec());
        executor_entry->buffer_types.push_back(
            global_buffers.buffers[i].scalar_type());
        executor_entry->buffer_zero_init.push_back(global_buffers.zero_init[i]);
      }
      executor_entry->rand_offset = rand_offset;
      executor_entry->init = true;
    }
  }

  // push back global buffers
  args.push(global_buffers.buffers);

  // push back the RNG state if needed
  if (lowered_->kernel()->summary().max_rng_offsets >= 0) {
    args.appendPhiloxRNGSeed(rand_offset);
  }

  if (isDebugDumpEnabled(DebugDumpOption::LaunchParam)) {
    launch_params_.print();
  }

  if (isDebugDumpEnabled(DebugDumpOption::KernelArgs)) {
    std::cout << "Arguments for kernel" << fusion_id_ << ":" << std::endl
              << "Inputs:" << std::endl;
    for (auto i : c10::irange(args.size())) {
      args[i]->print();
    }
    std::cout << "Outputs:" << std::endl;
    // note: add aliased outputs here.
    for (const auto& output : allocated_outputs) {
      std::cout << "  " << output.scalar_type() << " " << output.sizes()
                << " (strides = " << output.strides()
                << ", address = " << output.data_ptr() << ")" << std::endl;
    }
    std::cout << "Reduction and semaphore buffers:" << std::endl;
    TORCH_INTERNAL_ASSERT(
        global_buffers.buffers.size() == global_buffers.zero_init.size(),
        "global_buffer buffer & zero_init containers should have identical sizes");
    for (const auto i : c10::irange(global_buffers.buffers.size())) {
      const auto& buffer = global_buffers.buffers[i];
      const auto& zero_init = global_buffers.zero_init[i];
      std::cout << "  " << buffer.scalar_type() << " " << buffer.sizes()
                << " is_zero_initialized: " << zero_init << std::endl;
    }
  }

  cudaEvent_t start_event = {};
  cudaEvent_t finish_event = {};

  if (measure_kernel_time_ ||
      isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth) ||
      isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    C10_CUDA_CHECK(cudaEventCreate(&start_event));
    C10_CUDA_CHECK(cudaEventCreate(&finish_event));
    C10_CUDA_CHECK(cudaEventRecord(start_event));
  }

  if (execute_kernel_) {
    if (maybe_available_dynamic_smem_.has_value() &&
        size_t(launch_params_.smem()) > maybe_available_dynamic_smem_.value()) {
#ifndef USE_ROCM
      // Increase the limit of dynamic shared memory if needed.
      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuFuncSetAttribute(
          compiled_kernel_.function,
          CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
          launch_params_.smem()));
#else
      TORCH_INTERNAL_ASSERT(
          false, "cuFuncSetAttribute not supported with HIP.");
#endif
    }
    if (!kernel()->summary().has_cooperative_grid_reduction) {
      FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchKernel");
      AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel(
          compiled_kernel_.function,
          launch_params_.gdimx(),
          launch_params_.gdimy(),
          launch_params_.gdimz(),
          launch_params_.bdimx(),
          launch_params_.bdimy(),
          launch_params_.bdimz(),
          launch_params_.smem(),
          stream,
          args.getBuffer(),
          nullptr));
    } else {
#ifndef USE_ROCM
      FUSER_PERF_SCOPE("ExecutorRunFusion::cuLaunchCooperativeKernel");
      AT_CUDA_DRIVER_CHECK(
          at::globalContext().getNVRTC().cuLaunchCooperativeKernel(
              compiled_kernel_.function,
              launch_params_.gdimx(),
              launch_params_.gdimy(),
              launch_params_.gdimz(),
              launch_params_.bdimx(),
              launch_params_.bdimy(),
              launch_params_.bdimz(),
              launch_params_.smem(),
              stream,
              args.getBuffer()));
#else
      TORCH_INTERNAL_ASSERT(
          false, "Cross grid communication not supported with HIP.");
#endif
    }
  }

  if (measure_kernel_time_ ||
      isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth) ||
      isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    C10_CUDA_CHECK(cudaEventRecord(finish_event));
    C10_CUDA_CHECK(cudaEventSynchronize(start_event));
    C10_CUDA_CHECK(cudaEventSynchronize(finish_event));
    C10_CUDA_CHECK(
        cudaEventElapsedTime(&kernel_time_ms_, start_event, finish_event));
    C10_CUDA_CHECK(cudaEventDestroy(start_event));
    C10_CUDA_CHECK(cudaEventDestroy(finish_event));

    bytes_processed_ = 0;
    // Figure out how many bytes are inputs, outputs, and temporary buffers
    for (auto i : c10::irange(num_inputs)) {
      if (auto tensor_arg_abstract =
              dynamic_cast<const TensorArgAbstract*>(args[i])) {
        bytes_processed_ += tensor_arg_abstract->numel() *
            dataTypeSize(tensor_arg_abstract->getDataType());
      }
    }
    for (const auto& output : allocated_outputs) {
      bytes_processed_ += output.numel() *
          dataTypeSize(aten_to_data_type(output.scalar_type()));
    }

    if (isDebugDumpEnabled(DebugDumpOption::EffectiveBandwidth)) {
      double gb_per_s =
          ((double)bytes_processed_ / ((double)kernel_time_ms_ / 1000)) /
          (double)1.0e9;
      std::cout << "kernel" << fusion_id_ << " run in " << kernel_time_ms_
                << " ms, achieved: " << gb_per_s << " GB/s" << std::endl;
    }
  }

  if (isOptionEnabled(EnableOption::KernelProfile)) {
    std::cout << kernel()->profile().toString(global_buffers.profile_buffer);
  }

  return allocated_outputs;
}

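// Compiles an arbitrary kernel string with NVRTC ("runtime compilation"),
// wrapping it with getStructuredCode first when the caller passes
// unstructured code. Useful for hand-written kernels, e.g. in tests.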
void FusionExecutor::compileRtc(
    const std::string& code,
    const std::string& name,
    bool structured,
    CompileOptions options) {
  FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc");
  std::string scode;
  if (!structured) {
    scode = getStructuredCode(code);
  } else {
    scode = code;
  }
  fusion_id_ = 1;
  options_ = options;

  std::tie(compiled_kernel_, last_compiler_log_) =
      executor_utils::nvrtcCompile(scode, name, fusion_id_);
}

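// Launches a kernel previously built with compileRtc using the provided
// launch parameters and tensor arguments on the current CUDA stream.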
void FusionExecutor::runRtc(
    const LaunchParams& launch_params,
    const std::vector<at::Tensor>& args) {
  FUSER_PERF_SCOPE("runFusion");

  c10::DeviceGuard dg(options_.device);
  auto stream = at::cuda::getCurrentCUDAStream();

  KernelArgumentHolder kernel_arguments(options_.index_mode);
  kernel_arguments.push(args);
  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuLaunchKernel(
      compiled_kernel_.function,
      launch_params.gdimx(),
      launch_params.gdimy(),
      launch_params.gdimz(),
      launch_params.bdimx(),
      launch_params.bdimy(),
      launch_params.bdimz(),
      launch_params.smem(),
      stream,
      kernel_arguments.getBuffer(),
      nullptr));
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch