#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/native/cuda/jit_utils.h>

#include <c10/util/irange.h>

#include <contiguity.h>
#include <executor_utils.h>
#include <instrumentation.h>
#include <ir_all_nodes.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h>
#include <torch/csrc/jit/resource_guard.h>

#include <nvfuser_resources/PhiloxCudaStateRaw.h>
#include <nvfuser_resources/array.h>
#include <nvfuser_resources/bf16_support.h>
#include <nvfuser_resources/block_reduction.h>
#include <nvfuser_resources/block_sync_atomic.h>
#include <nvfuser_resources/block_sync_default.h>
#include <nvfuser_resources/broadcast.h>
#include <nvfuser_resources/fp16_support.h>
#include <nvfuser_resources/fused_reduction.h>
#include <nvfuser_resources/fused_welford_helper.h>
#include <nvfuser_resources/fused_welford_impl.h>
#include <nvfuser_resources/grid_broadcast.h>
#include <nvfuser_resources/grid_reduction.h>
#include <nvfuser_resources/grid_sync.h>
#include <nvfuser_resources/helpers.h>
#include <nvfuser_resources/index_utils.h>
#include <nvfuser_resources/memory.h>
#include <nvfuser_resources/random_numbers.h>
#include <nvfuser_resources/swizzle.h>
#include <nvfuser_resources/tensor.h>
#include <nvfuser_resources/tensorcore.h>
#include <nvfuser_resources/tuple.h>
#include <nvfuser_resources/type_traits.h>
#include <nvfuser_resources/warp.h>
#include <nvfuser_resources/welford.h>

#ifdef USE_ROCM
#include <nvfuser_resources/array_rocm.h>
#include <nvfuser_resources/bf16_support_rocm.h>
#include <nvfuser_resources/block_sync_default_rocm.h>
#include <nvfuser_resources/warp_rocm.h>
#endif

#ifndef USE_ROCM
#include <cuda_occupancy.h>
#endif

#include <fstream>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
namespace executor_utils {
std::string kernelPreamble() {
  std::stringstream ss;

#ifndef USE_ROCM
  ss << nvfuser_resources::fp16_support_cu;
  ss << nvfuser_resources::bf16_support_cu;
#else
  ss << R"(
#ifndef __noinline__
#define __noinline__ __attribute__((noinline))
#endif
#ifndef __forceinline__
#define __forceinline__ inline __attribute__((always_inline))
#endif
#ifndef assert
#define assert(expr) ((void)0)
#endif
#ifndef __align__
#define __align__(x) __attribute__((aligned(x)))
#endif
  )";
  // fp16 support is automatic, bf16 is not
  ss << nvfuser_resources::bf16_support_rocm_cu;
#endif

  // Base classes and helpers
  ss << nvfuser_resources::tensor_cu;
  ss << nvfuser_resources::type_traits_cu;
#ifndef USE_ROCM
  ss << nvfuser_resources::array_cu;
#else
  ss << nvfuser_resources::array_rocm_cu;
#endif
  ss << nvfuser_resources::random_numbers_cu;
  ss << nvfuser_resources::helpers_cu;
  ss << nvfuser_resources::index_utils_cu;
  ss << nvfuser_resources::tuple_cu;

  // Synchronization classes
  if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC")) {
    ss << nvfuser_resources::block_sync_atomic_cu;
  } else {
#ifndef USE_ROCM
    ss << nvfuser_resources::block_sync_default_cu;
#else
    ss << nvfuser_resources::block_sync_default_rocm_cu;
#endif
  }
  ss << nvfuser_resources::grid_sync_cu;

  // Communication classes
  ss << nvfuser_resources::block_reduction_cu;
  ss << nvfuser_resources::grid_reduction_cu;
  ss << nvfuser_resources::grid_broadcast_cu;
  ss << nvfuser_resources::broadcast_cu;
  ss << nvfuser_resources::welford_cu;
#ifndef USE_ROCM
  ss << nvfuser_resources::warp_cu;
  ss << nvfuser_resources::tensorcore_cu;
  ss << nvfuser_resources::memory_cu;
#else
  ss << nvfuser_resources::warp_rocm_cu;
#endif
  ss << nvfuser_resources::fused_welford_helper_cu;
  ss << nvfuser_resources::fused_reduction_cu;
  ss << nvfuser_resources::fused_welford_impl_cu;
  ss << nvfuser_resources::swizzle_cu;

  // Random utilities
  ss << nvfuser_resources::PhiloxCudaStateRaw_cu;

  return ss.str();
}

namespace {

// Return false if arg's type, number of dimensions, or device doesn't match
// param and the provided c10::Device.
bool validateKernelArgTensor(
    const at::Tensor& arg,
    const Val* param,
    const c10::Device& device,
    std::stringstream& msg) {
  // Arg is a tensor. Param must be a tensor too.
  if (*param->getValType() != ValType::TensorView) {
    msg << "Argument is a tensor, but the parameter is not.\n";
    return false;
  }

  if (is_cpu_scalar(arg) && !param->as<TensorView>()->isCpuScalar()) {
    msg << "Argument is CPU Scalar Tensor, but parameter is not.\n";
    return false;
  }

  if (!is_cpu_scalar(arg) && !arg.is_cuda()) {
    msg << "Argument is a CPU tensor which is not supported in fusions.\n";
    return false;
  }

  // Check the rank of the tensors.
  size_t arg_dim = arg.dim();
  // Note: This requires current Fusion to be active.
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  size_t param_dim = TensorDomain::noReductions(
                         param->as<TensorView>()->getMaybeRFactorDomain())
                         .size();
  // see [Note - broadcast support in integration]
  // Because of broadcasting support handled in integration, we relax the rank
  // check as necessary.
  if (arg_dim > param_dim) {
    msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is "
        << param_dim << "\n";
    return false;
  }

  if (!is_cpu_scalar(arg) && arg.device() != device) {
178 msg << "Argument is on device that is not compiled for."
179 << "\n";
    return false;
  }
  // Check element type
  at::ScalarType arg_data_type = arg.scalar_type();
  DataType param_data_type = *param->getDataType();
  bool match = false;
  // TODO: remove this switch with `aten_to_data_type`
  switch (arg_data_type) {
    case at::ScalarType::Double:
      match = param_data_type == DataType::Double;
      break;
    case at::ScalarType::Half:
      match = param_data_type == DataType::Half;
      break;
    case at::ScalarType::BFloat16:
      match = param_data_type == DataType::BFloat16;
      break;
    case at::ScalarType::Float:
      match = param_data_type == DataType::Float;
      break;
    case at::ScalarType::Long:
      match = param_data_type == DataType::Int;
      break;
    case at::ScalarType::Int:
      match = param_data_type == DataType::Int32;
      break;
    case at::ScalarType::Bool:
      match = param_data_type == DataType::Bool;
      break;
    case at::ScalarType::ComplexFloat:
      match = param_data_type == DataType::ComplexFloat;
      break;
    case at::ScalarType::ComplexDouble:
      match = param_data_type == DataType::ComplexDouble;
      break;
    default:
      msg << "Argument element type, " << arg_data_type << ", is not supported."
          << "\n";
      return false;
  }
  if (!match)
    msg << "Argument element type is " << arg_data_type
        << ", but the parameter is " << param_data_type << "\n";
  return match;
}

// Return false if arg_type doesn't match the type in param
bool validateKernelArgScalar(
    const ArgAbstract* arg,
    const Val* param,
    std::stringstream& msg) {
  TORCH_INTERNAL_ASSERT(
      param->getDataType().has_value(), "kernel param should have data type");
  DataType param_type = *param->getDataType();
  bool match = false;
  switch (arg->type()) {
    case ArgType::Long:
      match = param_type == DataType::Int || param_type == DataType::Int32;
      break;
    case ArgType::Double:
      match = param_type == DataType::Double || param_type == DataType::Float ||
          param_type == DataType::Half || param_type == DataType::BFloat16;
      break;
    case ArgType::Bool:
      match = param_type == DataType::Bool;
      break;
    case ArgType::ComplexDouble:
      match = param_type == DataType::ComplexDouble ||
          param_type == DataType::ComplexFloat;
      break;
    default:
      // TODO: We need to verify that param is actually a scalar
      msg << "Argument is not a scalar, but the parameter is."
          << "\n";
      return false;
  }
  if (!match) {
    msg << "Argument type is " << argTypeToString(arg->type())
        << ", but the parameter is " << param_type << "\n";
  }
  return match;
}

// Return false if arg and param don't match up and if arg's device (if a
// tensor) doesn't match provided device
bool validateKernelArg(
    const ArgAbstract* arg,
    const Val* param,
    const c10::Device& device,
    std::stringstream& msg) {
  if (auto tensor_arg_abstract = dynamic_cast<const TensorArgAbstract*>(arg)) {
    // TODO: don't use get tensor here. We would want to remove tensor reference
    // for async compilation
    return validateKernelArgTensor(
        tensor_arg_abstract->getTensor(), param, device, msg);
  } else if (arg->isType(ArgType::CpuScalarTensor)) {
    // TODO: merge this one with above
    // TODO: we need to check that the cpu scalar dtype matches param
    bool match = param->as<TensorView>()->isCpuScalar();
    if (!match) {
      msg << "Argument is a CPU scalar tensor, but the kernel parameter is not\n";
    }
    return match;
  } else {
    return validateKernelArgScalar(arg, param, msg);
  }
}

// Return true if all the tensors have the same stride, assumes all tensors are
// contiguous
bool checkSameStride(const std::vector<c10::IValue>& tensors) {
  if (tensors.size() < 2) {
    return true;
  }
  for (const auto idx : c10::irange(tensors.size() - 1)) {
    auto current = tensors[idx];
    auto next = tensors[idx + 1];
    if (!current.isTensor() || !next.isTensor()) {
      return false;
    }

    const auto& current_tensor = current.toTensor();
    const auto& next_tensor = next.toTensor();
    if (current_tensor.ndimension() != next_tensor.ndimension()) {
      return false;
    }

    for (const auto i : c10::irange(current_tensor.ndimension())) {
      if (current_tensor.stride(i) != next_tensor.stride(i)) {
        return false;
      }
    }
  }
  return true;
}

// Return true if all the tensors are contiguous and have the same striding
bool checkSameContiguity(const std::vector<c10::IValue>& tensors) {
  if (tensors.size() < 2) {
    return true;
  }

  auto reference = tensors.front();
  if (!reference.isTensor()) {
    return false;
  }

  // Determine if the reference tensor is contiguous
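  // For example, a contiguous tensor of sizes [2, 3, 4] has strides
  // [12, 4, 1]: walking from the innermost dimension outwards, each stride
  // equals the product of the sizes to its right (size-1 dimensions are
  // skipped below since their stride is irrelevant).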
  const auto& reference_tensor = reference.toTensor();
  int64_t expected_stride = 1;
  for (const auto i : c10::irange(1, reference_tensor.ndimension() + 1)) {
    int64_t ind = reference_tensor.ndimension() - i;
    if (reference_tensor.size(ind) == 1) {
      continue;
    }
    if (reference_tensor.stride(ind) != expected_stride) {
      return false;
    }
    expected_stride *= reference_tensor.size(ind);
  }

  // Check if all the tensors have the same contiguity
  return checkSameStride(tensors);
}

bool checkValidMisalignedTensors(
    const std::unordered_set<TensorView*>& inp_tv,
    const std::unordered_set<TensorView*>& out_tv,
    const std::vector<c10::IValue>& inp_tensors,
    const std::vector<c10::IValue>& out_tensors) {
  if (out_tv.empty()) {
    // Only check input tensors
    return checkSameStride(inp_tensors);
  } else if (!out_tv.empty() && out_tensors.empty()) {
    // out_tensors is empty unless outputs are given to runFusion.
    // Assume out tensors are contiguous
    return checkSameContiguity(inp_tensors);
  } else {
    // Only check input and output tensors
    std::vector<c10::IValue> tensors;
    tensors.insert(tensors.end(), inp_tensors.begin(), inp_tensors.end());
    tensors.insert(tensors.end(), out_tensors.begin(), out_tensors.end());
    return checkSameStride(tensors);
  }
}

} // namespace

void validateKernelInputs(
    Fusion* fusion,
    const KernelArgumentHolder& args,
    const c10::Device& device) {
  FUSER_PERF_SCOPE("executor_utils::ValidateKernelInputs");

  // This is necessary as we traverse the fusion graph later in the check.
  FusionGuard fg(fusion);
  // Check inputs
  TORCH_INTERNAL_ASSERT(
      args.size() == fusion->inputs().size(), "Wrong number of kernel inputs.");

  std::stringstream msg;
  bool mismatch = false;
  for (const auto i : c10::irange(args.size())) {
    const ArgAbstract* arg = args[i];
    const Val* param = fusion->inputs()[i];
    mismatch = !validateKernelArg(arg, param, device, msg) || mismatch;
  }
  TORCH_INTERNAL_ASSERT(
      !mismatch, "Found one or more invalid arguments: ", msg.str());
}

void validateKernelOutputs(
    Fusion* fusion,
    const std::vector<at::Tensor>& outputs,
    const c10::Device& device) {
  FUSER_PERF_SCOPE("executor_utils::ValidateKernelOutputs");

  TORCH_INTERNAL_ASSERT(
      fusion->outputs().size() != 0,
      "Kernel should have at least one output tensor.");

  TORCH_INTERNAL_ASSERT(
      outputs.size() == fusion->outputs().size(),
      "Wrong number of kernel outputs.");

  std::stringstream msg;
  bool mismatch = false;
  for (const auto i : c10::irange(outputs.size())) {
    const at::Tensor& arg = outputs[i];
    const Val* param = fusion->outputs()[i];
    mismatch = !validateKernelArgTensor(arg, param, device, msg) || mismatch;
  }
  TORCH_INTERNAL_ASSERT(
      !mismatch, "Found one or more invalid arguments: ", msg.str());
}

namespace {

// Finds a fusion input or output tensor to validate its strides
// for vectorization.
// Returns (is_fusion_input, position) pairs, where the position indexes
// into the fusion input or output list.
std::vector<std::pair<bool, int>> getVectorizedFusionInputOutput(
    TensorView* producer_tv,
    TensorView* consumer_tv,
    Fusion* fusion) {
  std::vector<std::pair<bool, int>> vectorized_input_output;

  // When the producer is a fusion input, validate only the producer
  // and assume the consumer is contiguous. Similarly, when the
  // consumer is a fusion output, validate the consumer and assume the
  // producer is contiguous.

  if (producer_tv->isFusionInput()) {
    auto producer_it = std::find(
        fusion->inputs().begin(), fusion->inputs().end(), producer_tv);
    TORCH_INTERNAL_ASSERT(
        producer_it != fusion->inputs().end(),
        "Could not find ",
        producer_tv,
        " in fusion inputs.");
    auto pos = std::distance(fusion->inputs().begin(), producer_it);
    vectorized_input_output.push_back(
        std::make_pair<bool, int>(true, static_cast<int>(pos)));
  } else {
    // If not fusion input, assume it's fully contiguous, so nothing
    // to check with respect to strides.
    TORCH_INTERNAL_ASSERT(
        std::all_of(
            producer_tv->domain()->contiguity().begin(),
            producer_tv->domain()->contiguity().end(),
            [](bool contig) { return contig; }),
        "Unsupported pattern of vectorization: ",
        consumer_tv->definition()->toString());
  }

  if (consumer_tv->isFusionOutput()) {
    auto consumer_it = std::find(
        fusion->outputs().begin(), fusion->outputs().end(), consumer_tv);
    TORCH_INTERNAL_ASSERT(
        consumer_it != fusion->outputs().end(),
        "Could not find ",
        consumer_tv,
        " in fusion outputs.");
    auto pos = std::distance(fusion->outputs().begin(), consumer_it);
    vectorized_input_output.push_back(
        std::make_pair<bool, int>(false, static_cast<int>(pos)));
  } else {
    // If not a fusion output, assume it's fully contiguous, so nothing
    // to check with respect to strides.
    TORCH_INTERNAL_ASSERT(
        std::all_of(
            consumer_tv->domain()->contiguity().begin(),
            consumer_tv->domain()->contiguity().end(),
            [](bool contig) { return contig; }),
        "Unsupported pattern of vectorization: ",
        consumer_tv->definition()->toString());
  }

  return vectorized_input_output;
}

//! Returns the information of vectorized input/output tensors
//! in the given fusion.
std::unique_ptr<caching::VectorizedTensorInfo> getVectorizedTensorValidationInfo(
    kir::Kernel* kernel) {
  auto vectorized_tensor_info_ptr =
      std::make_unique<caching::VectorizedTensorInfo>();

  for (const auto& vector_info : kernel->summary().vectorized_set_info) {
    auto consumer_tv = vector_info.consumer_tv;
    auto producer_tv = vector_info.producer_tv;

    auto vector_dim = vector_info.vectorized_leaf_id;
    const auto is_aligned =
        vector_dim->getParallelType() == ParallelType::Vectorize;

    // Find fusion inputs and outputs that are used with misaligned
    // vectorization.
    if (!is_aligned) {
      TORCH_INTERNAL_ASSERT(
          producer_tv->isFusionInput() || consumer_tv->isFusionOutput(),
          "MisalignedVectorize is assumed to be used with either input or output tensor");
      if (consumer_tv->getMemoryType() == MemoryType::Global &&
          producer_tv->getMemoryType() == MemoryType::Local) {
        vectorized_tensor_info_ptr->global_out_misaligned_tv.insert(
            consumer_tv);
      } else if (
          producer_tv->getMemoryType() == MemoryType::Global &&
          consumer_tv->getMemoryType() == MemoryType::Local) {
        vectorized_tensor_info_ptr->global_inp_misaligned_tv.insert(
            producer_tv);
      } else {
        TORCH_INTERNAL_ASSERT(
            false,
            "Unsupported memory configuration for misaligned vectorization.");
      }
    }

    // Collect information on corresponding fusion input and output
    // tensors to verify strides.
    auto inp_or_out_info =
        getVectorizedFusionInputOutput(producer_tv, consumer_tv, kernel);

    // If both producer and consumer are contig and intermediate,
    // nothing to validate with respect to strides.
    if (inp_or_out_info.empty()) {
      continue;
    }

    // Misaligned vectorize only allows from input to local or local
    // to output
    if (!is_aligned) {
      TORCH_INTERNAL_ASSERT(inp_or_out_info.size() == 1);
    }

    for (const auto& inp_or_out : inp_or_out_info) {
      const bool is_input = inp_or_out.first;
      const int pos = inp_or_out.second;

      if (is_aligned) {
        auto& pos_list = is_input
            ? vectorized_tensor_info_ptr->aligned_vectorized_inp_tensor_pos
            : vectorized_tensor_info_ptr->aligned_vectorized_out_tensor_pos;
        pos_list.push_back(pos);
      } else {
        auto& map = is_input
            ? vectorized_tensor_info_ptr->inp_misaligned_tensors_pos
            : vectorized_tensor_info_ptr->out_misaligned_tensors_pos;
        map.emplace_back(pos);
      }
    }
  }

  return vectorized_tensor_info_ptr;
}

// Make sure the root domain(s) comprising the vectorized leaf domain
// have the (merged) extent that is divisible by the vectorization
// word size.
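// For example, if root domains of extents 3 and 8 are merged and the merged
// domain is vectorized by a word size of 4, the merged extent 3 * 8 = 24 is
// divisible by 4 and the check passes.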
void validateAlignedVectorizeExtents(
    const VectorizedSetInfo& info,
    kir::ExpressionEvaluator& expr_eval) {
  TORCH_INTERNAL_ASSERT(
      !info.contig_root_ids.empty(),
      "No root ID found for vectorization with ",
      info.consumer_tv->toString(),
      " and ",
      info.producer_tv->toString());

  int64_t vectorized_merged_domain_extent = 1;
  for (auto id : info.contig_root_ids) {
    auto extent_val = expr_eval.evaluate(id->extent());
    TORCH_INTERNAL_ASSERT(
        extent_val.has_value(),
        "Error vectorizing, ",
        info.consumer_tv->toString(),
        " as the extent of a vectorized root domain, ",
        id->toString(),
        ", is unknown.");
    vectorized_merged_domain_extent *= extent_val->as<int64_t>();
  }

  TORCH_INTERNAL_ASSERT(
      vectorized_merged_domain_extent % info.word_size == 0,
      "Error vectorizing, ",
      info.consumer_tv->toString(),
      " as the extent of the indexed domain, ",
      vectorized_merged_domain_extent,
      ", is not divisible by vector word size ",
      info.word_size);
}

void validateAlignedVectorizedFusionInputOutput(
    const at::Tensor& aten_tensor,
    int word_size,
    TensorView* tv) {
  TORCH_INTERNAL_ASSERT(
      reinterpret_cast<size_t>(aten_tensor.data_ptr()) %
              (word_size * aten_tensor.dtype().itemsize()) ==
          0,
      "Vectorization of ",
      tv->toString(),
      " not possible as the memory address is not aligned. ",
      "Address: ",
      aten_tensor.data_ptr(),
      ", vector word size: ",
      word_size,
      ", data type: ",
      aten_tensor.dtype());

  // Traverse strides from the right-most domains. The rightmost
  // domain must have stride 1.
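  // For example, a [4, 8] tensor with strides [8, 1] and a word size of 4
  // passes: the rightmost stride is 1 and the outer stride 8 is divisible
  // by 4, so each row starts at a vector-aligned offset.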
  int64_t cur_contig_stride = 1;
  bool still_rightmost = true;
  for (auto i = aten_tensor.ndimension() - 1; i >= 0; --i) {
    const auto stride = aten_tensor.strides().at(i);
    const auto size = aten_tensor.sizes().at(i);
    // If this domain is contiguous or size == 1, it is not necessary to check
    // the stride. Otherwise, the stride must be 1 for the rightmost domain
    // and divisible by word_size for the others.
    TORCH_INTERNAL_ASSERT(
        stride == cur_contig_stride || size == 1 ||
            (still_rightmost && stride == 1) ||
            (!still_rightmost && stride % word_size == 0),
        "Vectorization of ",
        tv->toString(),
        " with word size ",
        word_size,
        " not possible due to invalid stride.",
        " Domain: ",
        tv->axis(i)->toString(),
        ", stride: ",
        stride);
    // If the domain is size-1, the next domain is still considered
    // rightmost.
    still_rightmost = still_rightmost && size == 1;
    // We do not update cur_contig_stride for size==1 dimensions,
    // since we have specialized vectorization stride check for them
    if (size != 1) {
      cur_contig_stride = stride * size;
    }
  }
}

void validateAlignedVectorizedTensors(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args,
    const std::vector<at::Tensor>& outputs,
    caching::ExecutorCompileTimeInfoCache* data_cache,
    kir::ExpressionEvaluator& expr_eval) {
  auto tensor_vectorization_validation_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::VectorizedTensorValidation>(
          data_cache, [kernel]() {
            return executor_utils::getVectorizedTensorValidationInfo(kernel);
          });

  // Verify extents of aligned vectorized tensors
  for (const auto& vec_info : kernel->summary().vectorized_set_info) {
    if (vec_info.vectorized_leaf_id->getParallelType() ==
        ParallelType::Vectorize) {
      validateAlignedVectorizeExtents(vec_info, expr_eval);
    }
  }

  // Validate input and output tensors with aligned
  // vectorization.
  for (auto pos : tensor_vectorization_validation_entry.get()
                      .aligned_vectorized_inp_tensor_pos) {
    auto tv = kernel->inputs().at(pos)->as<TensorView>();
    auto word_size = kernel->summary().vectorized_accesses.at(tv);
    auto tensor_arg_abstract =
        dynamic_cast<const TensorArgAbstract*>(args[pos]);
    TORCH_INTERNAL_ASSERT(tensor_arg_abstract, "alias io only supports tensor");
    validateAlignedVectorizedFusionInputOutput(
        tensor_arg_abstract->getTensor(), word_size, tv);
  }
  if (!outputs.empty()) {
    for (auto pos : tensor_vectorization_validation_entry.get()
                        .aligned_vectorized_out_tensor_pos) {
      auto tv = kernel->outputs().at(pos)->as<TensorView>();
      auto word_size = kernel->summary().vectorized_accesses.at(tv);
      validateAlignedVectorizedFusionInputOutput(outputs[pos], word_size, tv);
    }
  }
}

// Misaligned vectorization check. Currently misaligned vectorization is
// limited to global-register and register-global load/store patterns.
// However, this could be improved to include shared memory.
void validateMisalignedVectorizedTensors(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args,
    const std::vector<at::Tensor>& outputs,
    caching::ExecutorCompileTimeInfoCache* data_cache,
    kir::ExpressionEvaluator& expr_eval) {
  auto tensor_vectorization_validation_entry =
      executor_utils::caching::ExecutorCompileTimeEntry<
          executor_utils::caching::VectorizedTensorValidation>(
          data_cache, [kernel]() {
            return executor_utils::getVectorizedTensorValidationInfo(kernel);
          });

  std::vector<c10::IValue> inp_misaligned_tensors;
  std::vector<c10::IValue> out_misaligned_tensors;

  const auto& inp_misaligned_tensors_pos =
      tensor_vectorization_validation_entry.get().inp_misaligned_tensors_pos;
  inp_misaligned_tensors.reserve(inp_misaligned_tensors_pos.size());
  std::transform(
      inp_misaligned_tensors_pos.begin(),
      inp_misaligned_tensors_pos.end(),
      std::back_inserter(inp_misaligned_tensors),
      [&args](int idx) {
        auto tensor_arg_abstract =
            dynamic_cast<const TensorArgAbstract*>(args[idx]);
        TORCH_INTERNAL_ASSERT(
            tensor_arg_abstract, "alias io only supports tensor");
        return tensor_arg_abstract->getTensor();
      });

  const auto& out_misaligned_tensors_pos =
      tensor_vectorization_validation_entry.get().out_misaligned_tensors_pos;
  if (outputs.size() > 0) {
    out_misaligned_tensors.reserve(out_misaligned_tensors_pos.size());
    std::transform(
        out_misaligned_tensors_pos.begin(),
        out_misaligned_tensors_pos.end(),
        std::back_inserter(out_misaligned_tensors),
        [&outputs](int idx) { return outputs[idx]; });
  }
  // If input stride is non-contiguous + no outputs, return false
  TORCH_INTERNAL_ASSERT(
      checkValidMisalignedTensors(
          tensor_vectorization_validation_entry.get().global_inp_misaligned_tv,
          tensor_vectorization_validation_entry.get().global_out_misaligned_tv,
          inp_misaligned_tensors,
          out_misaligned_tensors),
      "All global tensors must have the same stride for misaligned vectorization.");
}

// Check if there's any split that is non-divisible and vectorized. If
// found, Vectorize is illegal.
void validateVectorizedSplits(
    kir::Kernel* kernel,
    kir::ExpressionEvaluator& expr_eval) {
  for (const auto& extent_factor : kernel->summary().splits_to_validate) {
    auto input_extent = expr_eval.evaluate(extent_factor.first);
    auto split_factor = expr_eval.evaluate(extent_factor.second);
    TORCH_INTERNAL_ASSERT(
        input_extent.has_value(),
        "Could not check if a split with vectorization is divisible because the extent, ",
        extent_factor.first->toString(),
        ", is not possible to evaluate.");
    TORCH_INTERNAL_ASSERT(
        split_factor.has_value(),
        "Could not check if a split with vectorization is divisible because the split factor, ",
        extent_factor.second->toString(),
        ", is not possible to evaluate.");
    TORCH_INTERNAL_ASSERT(
        input_extent.value() % split_factor.value() == 0,
        "Non-divisible split with vectorization is detected. ",
        "Extent: ",
        input_extent.value(),
        ". Factor: ",
        split_factor.value());
  }
}

} // namespace

void validateVectorizedTensors(
    kir::Kernel* kernel,
    const KernelArgumentHolder& args,
    const std::vector<at::Tensor>& outputs,
    caching::ExecutorCompileTimeInfoCache* data_cache,
    kir::ExpressionEvaluator& expr_eval) {
  FUSER_PERF_SCOPE("FusionExecutor::validateVectorizedTensors");

  validateAlignedVectorizedTensors(
      kernel, args, outputs, data_cache, expr_eval);

  validateMisalignedVectorizedTensors(
      kernel, args, outputs, data_cache, expr_eval);

  validateVectorizedSplits(kernel, expr_eval);
}

namespace {

template <typename EXPR_EVALUATOR>
void bindInputForExprEvaluation(
    Val* val,
    const ArgAbstract* arg,
    bool check_consistency,
    EXPR_EVALUATOR& expr_eval) {
  if (val->getValType() == ValType::TensorView) {
    TensorView* cg_tensor = val->as<TensorView>();
    auto root_domain =
        TensorDomain::noReductions(cg_tensor->getMaybeRFactorDomain());

    if (root_domain.size() == 0) {
      TORCH_INTERNAL_ASSERT(
          arg->isType(ArgType::CpuScalarTensor) ||
              (arg->isType(ArgType::Tensor) &&
               dynamic_cast<const TensorArgAbstract*>(arg)->getRank() == 0),
807 "Something went wrong configuring launch. Inputs is not rank 0 tensor");
    } else {
      TORCH_INTERNAL_ASSERT(
          arg->isType(ArgType::Tensor),
          "Something went wrong configuring launch. Inputs do not match.");

      auto tensor_arg_abstract = dynamic_cast<const TensorArgAbstract*>(arg);
      TORCH_INTERNAL_ASSERT(
          tensor_arg_abstract &&
              tensor_arg_abstract->getRank() == (int64_t)root_domain.size(),
          "Something went wrong configuring launch. Inputs rank does not match.");

      for (const auto dim : c10::irange(root_domain.size())) {
        const auto tensor_arg_size = tensor_arg_abstract->getSize(dim);
        const auto tensor_arg_stride = tensor_arg_abstract->getStride(dim);
        const auto extent = root_domain[dim]->extent();
        if (root_domain[dim]->hasExpandedExtent()) {
          TORCH_INTERNAL_ASSERT(
              tensor_arg_stride == 0,
              "Expecting an expanded dimension on dimension ",
              dim,
              " but found stride ",
              tensor_arg_stride);
          // Could support dynamic size on expanded dimension, so may not have
          // an inferable expanded extent here. This check might be better to do
          // once all values are bound.
          auto maybe_expanded_size =
              expr_eval.evaluate(root_domain[dim]->expandedExtent());
          if (maybe_expanded_size.has_value()) {
            TORCH_CHECK(
                *maybe_expanded_size == tensor_arg_size,
                "Expecting expanded extent of ",
                *maybe_expanded_size,
                " but received value of ",
                tensor_arg_size);
          }
        }

        const auto value =
            root_domain[dim]->hasExpandedExtent() ? 1 : tensor_arg_size;
        bool should_bind = true;
        if (check_consistency) {
          const auto prev_value = expr_eval.evaluate(extent);
          if (prev_value.has_value()) {
            TORCH_CHECK(
                *prev_value == value,
                "Attempting to bind ",
                extent->toString(),
                " to ",
                value,
                " but it's already set to ",
                *prev_value);
            should_bind = false;
          }
        }
        if (should_bind && !extent->isConstScalar()) {
          expr_eval.bind(extent, value);
        }
      }
    }
  } else if (val->getValType().value() == ValType::Scalar) {
    if (val->getDataType().value() == DataType::Int) {
      TORCH_INTERNAL_ASSERT(
          arg->isType(ArgType::Long),
          "fusion expected Scalar Int inputs, but found ",
          argTypeToString(arg->type()));
      expr_eval.bind(val, *static_cast<const int64_t*>(arg->arg()));
    } else if (val->getDataType().value() == DataType::Double) {
      TORCH_INTERNAL_ASSERT(
          arg->isType(ArgType::Double),
          "fusion expected Scalar Double inputs, but found ",
          argTypeToString(arg->type()));
      expr_eval.bind(val, *static_cast<const double*>(arg->arg()));
    }
  }
}

} // namespace

kir::ExpressionEvaluator bindKernelInputs(
    const KernelArgumentHolder& args,
    kir::Kernel* kernel,
    bool check_consistency) {
  FUSER_PERF_SCOPE("executor_utils::BindKernelInputs");

  TORCH_INTERNAL_ASSERT(
      kernel->inputs().size() == args.size(),
      "Something went wrong configuring launch. Inputs no longer match.");

  kir::ExpressionEvaluator expr_eval;
  const auto& inputs = kernel->inputs();

  for (const auto i : c10::irange(inputs.size())) {
    bindInputForExprEvaluation(
        inputs[i], args[i], check_consistency, expr_eval);
  }
  return expr_eval;
}

ExpressionEvaluator bindFusionInputs(
    const KernelArgumentHolder& args,
    Fusion* fusion) {
  FUSER_PERF_SCOPE("executor_utils::BindFusionInputs");

  auto inputs = fusion->inputs();
  TORCH_INTERNAL_ASSERT(
      inputs.size() == args.size(),
      "Something went wrong configuring launch. Inputs do not match.\n",
      "inputs: ",
      ir_utils::toString(inputs),
      " args size: ",
      args.size());

  ExpressionEvaluator expr_eval(fusion);

  // This should probably move to EvaluationContext as we may want to bind
  // input values frequently. Bind fusion input values to runtime values.
  for (const auto i : c10::irange(inputs.size())) {
    bindInputForExprEvaluation(inputs[i], args[i], true, expr_eval);
  }
  return expr_eval;
}

namespace {

// Dump PTX or CUBIN to a file
#if CUDA_VERSION >= 11010
void dumpCompiledCodeToFile(
    const nvrtcProgram& program,
    int fusion_id,
    bool dump_cubin) {
  const auto getSize = dump_cubin
      ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
      : at::globalContext().getNVRTC().nvrtcGetPTXSize;
  const auto getCode = dump_cubin ? at::globalContext().getNVRTC().nvrtcGetCUBIN
                                  : at::globalContext().getNVRTC().nvrtcGetPTX;
  size_t size = 0;
  AT_CUDA_NVRTC_CHECK(getSize(program, &size));
  std::vector<char> code(size);
  AT_CUDA_NVRTC_CHECK(getCode(program, code.data()));
  std::stringstream file_name;
  file_name << "__tmp_kernel" << fusion_id << "."
            << (dump_cubin ? "cubin" : "ptx");
  std::cout << "PRINTING: " << file_name.str() << std::endl;
  std::ofstream out(file_name.str());
  TORCH_INTERNAL_ASSERT(out.is_open());
  out.write(code.data(), size);
  out.close();
}
#endif

} // namespace

std::pair<NvrtcFunction, std::string> nvrtcCompile(
    const std::string& code,
    const std::string& func_name,
    int id,
    c10::optional<int> opt_block_size) {
  FUSER_PERF_SCOPE("executor_utils::NVRTC");
  if (isOptionDisabled(DisableOption::ArchCheck)) {
    TORCH_WARN(
        "NVFuser Compile: arch check disabled, should not compile any kernel");
  }

  at::cuda::jit::initializeCudaContext();

  std::stringstream ptxas_log;

  const auto prop = at::cuda::getCurrentDeviceProperties();

  int major = 0, minor = 0;
  bool compile_to_sass = false;
  codegenOutputQuery(prop, major, minor, compile_to_sass);

  nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables)

  {
    std::stringstream ss;
    ss << "__tmp_kernel" << id << ".cu";
    std::string name = ss.str();
    FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram");
    AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram(
        &program, code.c_str(), name.c_str(), 0, nullptr, nullptr));
  }

  ResourceGuard holdProgram([&] {
    FUSER_PERF_SCOPE("executor_utils::NvrtcDestroyProgram");
    AT_CUDA_NVRTC_CHECK(
        at::globalContext().getNVRTC().nvrtcDestroyProgram(&program));
  });

#ifdef USE_ROCM
  std::vector<const char*> args = {"--std=c++17"};
#if ROCM_VERSION >= 40200
  args.push_back("-hip-pch");
#endif
#else
#if CUDA_VERSION < 11010
  // compile to sass is not allowed prior to CUDA 11.1
  compile_to_sass = false;
#endif

  if (isOptionDisabled(DisableOption::CompileToSass)) {
    // Allows manually disabling compilation to sass
    // so the intermediate ptx could be checked.
    compile_to_sass = false;
  }
  // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_),
  // which gives better backward compatibility on older drivers
  // (since an older driver doesn't necessarily recognize PTX emitted by a
  // newer toolkit).
  // Meanwhile, for forward compatibility (a future device with
  // `unsupported_arch == true`), since SASS is not necessarily compatible,
  // we fall back to PTX instead.
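  // For example, on an A100 (major 8, minor 0) this produces
  // "--gpu-architecture=sm_80" when compiling to SASS, or
  // "--gpu-architecture=compute_80" when targeting PTX.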
  const std::string compute = std::string("--gpu-architecture=") +
      (compile_to_sass ? "sm_" : "compute_") + std::to_string(major) +
      std::to_string(minor);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<const char*> args = {
      "--std=c++17", compute.c_str(), "-default-device"};
#endif

  const bool disable_fma = isOptionDisabled(DisableOption::Fma);
#ifdef USE_ROCM
  if (disable_fma) {
    TORCH_WARN_ONCE(
        "PYTORCH_CUDA_FUSER_DISABLE_FMA is not supported on ROCm, ignoring");
  }
#else
  if (disable_fma) {
    args.push_back("--fmad=false");
  } else {
    args.push_back("--fmad=true");
  }
#endif
  // Add line info to generated kernels
  if (isDebugDumpEnabled(DebugDumpOption::DebugInfo)) {
    args.push_back("-lineinfo");
  }
#ifdef NDEBUG
  // Avoid excessive register usage from assertion
  args.push_back("-DNDEBUG");
#endif

  if (isOptionEnabled(EnableOption::KernelProfile)) {
    args.push_back("-DPYTORCH_NVFUSER_PROFILE_KERNEL");
  }

  const char* ptxas_opt_level = getenv("PYTORCH_NVFUSER_JIT_OPT_LEVEL");
  std::string jit_opt_level = "-O";

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<CUjit_option> options;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<void*> option_vals;
  std::vector<char> info_log;
  unsigned int log_size = 8196;

  if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog) ||
      isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) {
    // show register usage in compilation log
    if (compile_to_sass) {
      args.push_back("--ptxas-options");
      args.push_back("--verbose");
    } else {
      options.push_back(CU_JIT_LOG_VERBOSE);
      option_vals.push_back((void*)1);
      info_log.reserve(log_size);

      options.push_back(CU_JIT_INFO_LOG_BUFFER);
      option_vals.push_back((void*)info_log.data());

      options.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES);
      option_vals.push_back((void*)(long)log_size);
    }
  }

  if (ptxas_opt_level) {
    int val = atoi(ptxas_opt_level);
    if (val <= 4 && val >= 0) {
      if (val < 4) {
        TORCH_WARN(
            "ptxas optimization level manually set as ",
            val,
            ", which could negatively affect performance. Try removing env variable PYTORCH_NVFUSER_JIT_OPT_LEVEL for optimal performance.");
      }
      if (compile_to_sass) {
        jit_opt_level += std::to_string(val);
        args.push_back("--ptxas-options");
        args.push_back(jit_opt_level.c_str());
      } else {
        options.push_back(CU_JIT_OPTIMIZATION_LEVEL);
        option_vals.push_back((void*)(intptr_t)val);
      }
    } else {
      TORCH_WARN_ONCE(
          "acceptable range for PYTORCH_NVFUSER_JIT_OPT_LEVEL is between 0 and 4, but received ",
          val,
          ", ignoring the option");
    }
  }

#ifndef USE_ROCM
  // Keep this string alive until compilation is done, since args stores a
  // pointer into it.
  std::string max_register_usage = "--maxrregcount=";
  uint32_t max_register = 0;
  if (opt_block_size.has_value() && opt_block_size.value() > 0) {
    int num_partition = 0;
    int reg_allocation_granularity = 0;
    cudaOccDeviceProp occ_prop(*prop);
    cudaOccSubPartitionsPerMultiprocessor(&num_partition, &occ_prop);
    cudaOccRegAllocationGranularity(&reg_allocation_granularity, &occ_prop);
    int warp_size = prop->warpSize;
    int num_warps = ceilDiv(opt_block_size.value(), warp_size);

    // warps could be distributed unevenly across partitions
    int max_warps_per_sm_partition = ceilDiv(num_warps, num_partition);
    // registers are evenly distributed across partitions; the partition with
    // the most warps determines the maximum registers available per warp
    int max_reg_per_warp =
        prop->regsPerBlock / num_partition / max_warps_per_sm_partition;
    // clamp down to the register allocation granularity at warp level
    int effective_max_reg_per_warp = max_reg_per_warp /
        reg_allocation_granularity * reg_allocation_granularity;
    // The maximum possible count allowed by ptxas is 255
    max_register = static_cast<uint32_t>(
        std::min(effective_max_reg_per_warp / warp_size, 255));
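    // Illustrative example (hypothetical device values): with
    // regsPerBlock = 65536, 4 sub-partitions, a granularity of 256,
    // warpSize = 32, and a block size of 256: num_warps = 8,
    // max_warps_per_sm_partition = 2, max_reg_per_warp = 65536 / 4 / 2 = 8192,
    // and the per-thread cap is 8192 / 32 = 256, clamped to 255.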
    if (compile_to_sass) {
      max_register_usage += std::to_string(max_register);
      args.push_back("--ptxas-options");
      args.push_back(max_register_usage.c_str());
    } else {
      options.push_back(CU_JIT_MAX_REGISTERS);
      option_vals.push_back((void*)(intptr_t)max_register);
    }

    ptxas_log << "\nCompile options: ";
    for (auto arg : args) {
      ptxas_log << arg << " ";
    }
    ptxas_log << " ; block size=" << opt_block_size.value() << "\n";
  }
#endif

  at::globalContext().getNVRTC().nvrtcAddNameExpression(
      program, func_name.c_str());

  {
    FUSER_PERF_SCOPE("executor_utils::Nvrtc::CompileProgram");

    const auto result = at::globalContext().getNVRTC().nvrtcCompileProgram(
        program, args.size(), args.data());

    size_t logsize = 0;
    at::globalContext().getNVRTC().nvrtcGetProgramLogSize(program, &logsize);

    std::vector<char> log(logsize);
    at::globalContext().getNVRTC().nvrtcGetProgramLog(program, log.data());

    if (result != NVRTC_SUCCESS) {
      TORCH_INTERNAL_ASSERT(
          false, code.c_str(), "\nCUDA NVRTC compile error: ", log.data());
    }

    ptxas_log << log.data() << std::endl;
    if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
      std::cout << log.data() << std::endl;
    }
    AT_CUDA_NVRTC_CHECK(result);
  }

  const char* lowered_kernel_name = nullptr;
  at::globalContext().getNVRTC().nvrtcGetLoweredName(
      program, func_name.c_str(), &lowered_kernel_name);

  size_t ptx_size = 0;
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  std::vector<char> ptx;

  {
    FUSER_PERF_SCOPE("executor_utils::Nvrtc::GetPTX");
#if CUDA_VERSION >= 11010
    // compile_to_sass determines whether we are generating SASS or PTX, hence
    // the different API.
    const auto getSize = compile_to_sass
        ? at::globalContext().getNVRTC().nvrtcGetCUBINSize
        : at::globalContext().getNVRTC().nvrtcGetPTXSize;
    const auto getFunc = compile_to_sass
        ? at::globalContext().getNVRTC().nvrtcGetCUBIN
        : at::globalContext().getNVRTC().nvrtcGetPTX;
#else
    const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize;
    const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX;
#endif
    AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size));
    ptx.resize(ptx_size);
    AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data()));
  }

  NvrtcFunction compiled_kernel_;

#ifndef USE_ROCM

#if CUDA_VERSION >= 11010
  if (isDebugDumpEnabled(DebugDumpOption::Ptx)) {
    dumpCompiledCodeToFile(program, id, false);
  }

  if (isDebugDumpEnabled(DebugDumpOption::Cubin)) {
    TORCH_INTERNAL_ASSERT(
        compile_to_sass,
        "CUBIN not available as the kernel was compiled only to PTX");
    dumpCompiledCodeToFile(program, id, true);
  }
#endif

  {
    FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX");

    // load ptx or cubin directly
    AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx(
        &(compiled_kernel_.module),
        ptx.data(),
        options.size(),
        options.data(),
        option_vals.data()));

    if (!compile_to_sass &&
        isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) {
      std::cout << info_log.data() << std::endl;
    }
  }
#else
  // load ptx directly
  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData(
      &(compiled_kernel_.module), ptx.data()));

#endif
  AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction(
      &(compiled_kernel_.function),
      compiled_kernel_.module,
      lowered_kernel_name));

  TORCH_CHECK(
      !isOptionDisabled(DisableOption::ArchCheck),
      "NVFuser Compile: arch check disabled, should not return any compiled kernel");

  return {compiled_kernel_, ptxas_log.str()};
}

namespace caching {

//! CompileTimeInfo is the actual subclass of CompileTimeInfoBase that will
//! be stored in the data cache. It internally owns a data_ state of the
//! DataType defined within the entry class; the entry classes are listed in
//! the header file.
template <typename EntryClass>
class CompileTimeInfo : public CompileTimeInfoBase {
 public:
  CompileTimeInfo(std::unique_ptr<typename EntryClass::DataType> data)
      : CompileTimeInfoBase(EntryClass::EntryType), data_(std::move(data)) {}

  typename EntryClass::DataType* get() {
    return data_.get();
  }

 private:
  std::unique_ptr<typename EntryClass::DataType> data_;
};

void ExecutorCompileTimeInfoCache::insert(EntryOwningPtr new_entry) {
  // Just overwrite when an insertion duplicates an entry; equality is not
  // checked.
  entry_type_map_[new_entry->type()] = new_entry.get();
  entries_.emplace_back(std::move(new_entry));
}

template <typename EntryClass>
ExecutorCompileTimeEntry<EntryClass>::ExecutorCompileTimeEntry(
    ExecutorCompileTimeInfoCache* data_cache,
    MakerFnType fn) {
  using InfoType = CompileTimeInfo<EntryClass>;

  if (!data_cache || !data_cache->has(EntryClass::EntryType)) {
    owned_data_ = fn();
    data_ptr_ = owned_data_.get();

    if (data_cache) {
      std::unique_ptr<CompileTimeInfoBase> new_entry =
          std::make_unique<InfoType>(std::move(owned_data_));
      data_cache->insert(std::move(new_entry));
    }
  } else {
    data_ptr_ =
        data_cache->at(EntryClass::EntryType)->template as<InfoType>()->get();
  }
}

// Template instantiation
template class ExecutorCompileTimeEntry<ParallelBindingIterDomains>;
template class ExecutorCompileTimeEntry<ParallelIterExtentMap>;
template class ExecutorCompileTimeEntry<SimplifiedParallelIterExtentMap>;
template class ExecutorCompileTimeEntry<WarpPaddedParallelExtents>;
template class ExecutorCompileTimeEntry<VectorizedTensorValidation>;
template class ExecutorCompileTimeEntry<InputAliasIndices>;
template class ExecutorCompileTimeEntry<OutputAliasIndices>;

} // namespace caching

std::vector<IterDomain*> getParallelBindingsIterDomains(
    GpuLower* lower,
    const std::vector<TensorView*>& used_tvs) {
  std::vector<IterDomain*> parallel_ids;
  for (auto tv : used_tvs) {
    for (auto id : tv->domain()->domain()) {
      if (id->isThread()) {
        if (id->isBroadcast()) {
          // Want to keep the broadcast dimensions if they are not resolved
          // TODO: piping down the parallel dimension map here would
          // be helpful
          if (lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP) ==
              id) {
            parallel_ids.push_back(id);
          }
        } else {
          // Non broadcast ids are directly added to the binding
          // ids.
          parallel_ids.push_back(id);
        }
      }
    }
  }
  return parallel_ids;
}

namespace {

void insertParallelExtent(
    IterDomain* binding_id,
    const std::unique_ptr<ParallelExtentMap>& parallel_iter_extents_ptr) {
  auto extent = binding_id->extent();
  const auto it =
      parallel_iter_extents_ptr->find(binding_id->getParallelType());
  if (it != parallel_iter_extents_ptr->end()) {
    it->second.push_back(extent);
  } else {
    parallel_iter_extents_ptr->operator[](binding_id->getParallelType()) = {
        extent};
  }
}

} // namespace

std::unique_ptr<ParallelExtentMap> getParallelIterExtents(
    std::vector<IterDomain*>& parallel_binding_ids) {
  auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
  for (auto id : parallel_binding_ids) {
    insertParallelExtent(id, parallel_iter_extents_ptr);
  }

  return parallel_iter_extents_ptr;
}

std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents(
    GpuLower* lower,
    std::vector<IterDomain*>& parallel_binding_ids) {
  auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>();
  const auto& ca_map = lower->caMap();
  std::vector<IterDomain*> mapped;
  bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded;

  for (auto id : parallel_binding_ids) {
    if (std::any_of(
            mapped.begin(), mapped.end(), [id, &ca_map](IterDomain* mapped_id) {
              return ca_map->areMapped(mapped_id, id, IdMappingMode::LOOP);
            })) {
      if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) {
        continue;
      }
    }

    insertParallelExtent(
        ca_map->getConcreteMappedID(id, IdMappingMode::LOOP),
        parallel_iter_extents_ptr);
    mapped.push_back(id);
  }

  return parallel_iter_extents_ptr;
}

std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo(
    kir::Kernel* kernel,
    std::vector<IterDomain*>& parallel_binding_ids) {
  auto warp_padded_extent_info_ptr =
      std::make_unique<caching::WarpPaddedExtentsInfo>();
  auto& warp_padded_extent_set =
      warp_padded_extent_info_ptr->warp_padded_extent_set;
  auto& warp_padded_constant =
      warp_padded_extent_info_ptr->warp_padded_constant;
  bool has_warp_reduction =
      kernel->getWarpPaddedParallelInfo().has_warp_reduction;

  for (auto id : parallel_binding_ids) {
    // Apply warp padding only when there are warp reductions in
    // the kernel.
    if (has_warp_reduction) {
      if (id->hasPaddingToMultipleOfWarp() ||
          kernel->isParallelTypePadded(id->getParallelType())) {
        auto extent = id->extent();
        warp_padded_extent_set.insert(extent);
        auto padded_value = id->getMaybeSizeAfterPadding();
        if (padded_value.has_value()) {
          warp_padded_constant[extent] = padded_value.value();
        }
      }
    }
  }
  return warp_padded_extent_info_ptr;
}

} // namespace executor_utils
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch