1 | #include <ATen/cuda/CUDAContext.h> |
2 | #include <ATen/cuda/CUDAGeneratorImpl.h> |
3 | #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h> |
4 | #include <ATen/native/cuda/jit_utils.h> |
5 | |
6 | #include <c10/util/irange.h> |
7 | |
8 | #include <contiguity.h> |
9 | #include <executor_utils.h> |
10 | #include <instrumentation.h> |
11 | #include <ir_all_nodes.h> |
12 | #include <ir_iostream.h> |
13 | #include <ir_utils.h> |
14 | #include <torch/csrc/jit/codegen/fuser/cuda/fused_kernel.h> |
15 | #include <torch/csrc/jit/resource_guard.h> |
16 | |
17 | #include <nvfuser_resources/PhiloxCudaStateRaw.h> |
18 | #include <nvfuser_resources/array.h> |
19 | #include <nvfuser_resources/bf16_support.h> |
20 | #include <nvfuser_resources/block_reduction.h> |
21 | #include <nvfuser_resources/block_sync_atomic.h> |
22 | #include <nvfuser_resources/block_sync_default.h> |
23 | #include <nvfuser_resources/broadcast.h> |
24 | #include <nvfuser_resources/fp16_support.h> |
25 | #include <nvfuser_resources/fused_reduction.h> |
26 | #include <nvfuser_resources/fused_welford_helper.h> |
27 | #include <nvfuser_resources/fused_welford_impl.h> |
28 | #include <nvfuser_resources/grid_broadcast.h> |
29 | #include <nvfuser_resources/grid_reduction.h> |
30 | #include <nvfuser_resources/grid_sync.h> |
31 | #include <nvfuser_resources/helpers.h> |
32 | #include <nvfuser_resources/index_utils.h> |
33 | #include <nvfuser_resources/memory.h> |
34 | #include <nvfuser_resources/random_numbers.h> |
35 | #include <nvfuser_resources/swizzle.h> |
36 | #include <nvfuser_resources/tensor.h> |
37 | #include <nvfuser_resources/tensorcore.h> |
38 | #include <nvfuser_resources/tuple.h> |
39 | #include <nvfuser_resources/type_traits.h> |
40 | #include <nvfuser_resources/warp.h> |
41 | #include <nvfuser_resources/welford.h> |
42 | |
43 | #ifdef USE_ROCM |
44 | #include <nvfuser_resources/array_rocm.h> |
45 | #include <nvfuser_resources/bf16_support_rocm.h> |
46 | #include <nvfuser_resources/block_sync_default_rocm.h> |
47 | #include <nvfuser_resources/warp_rocm.h> |
48 | #endif |
49 | |
50 | #ifndef USE_ROCM |
51 | #include <cuda_occupancy.h> |
52 | #endif |
53 | |
54 | #include <fstream> |
55 | |
56 | namespace torch { |
57 | namespace jit { |
58 | namespace fuser { |
59 | namespace cuda { |
60 | namespace executor_utils { |
61 | |
62 | std::string kernelPreamble() { |
63 | std::stringstream ss; |
64 | |
65 | #ifndef USE_ROCM |
66 | ss << nvfuser_resources::fp16_support_cu; |
67 | ss << nvfuser_resources::bf16_support_cu; |
68 | #else |
69 | ss << R"( |
70 | #ifndef __noinline__ |
71 | #define __noinline__ __attribute__((noinline)) |
72 | #endif |
73 | #ifndef __forceinline__ |
74 | #define __forceinline__ inline __attribute__((always_inline)) |
75 | #endif |
76 | #ifndef assert |
77 | #define assert(expr) ((void)0) |
78 | #endif |
79 | #ifndef __align__ |
80 | #define __align__(x) __attribute__((aligned(x))) |
81 | #endif |
82 | )" ; |
83 | // fp16 support is automatic, bf16 is not |
84 | ss << nvfuser_resources::bf16_support_rocm_cu; |
85 | #endif |
86 | |
87 | // Base classes and helpers |
88 | ss << nvfuser_resources::tensor_cu; |
89 | ss << nvfuser_resources::type_traits_cu; |
90 | #ifndef USE_ROCM |
91 | ss << nvfuser_resources::array_cu; |
92 | #else |
93 | ss << nvfuser_resources::array_rocm_cu; |
94 | #endif |
95 | ss << nvfuser_resources::random_numbers_cu; |
96 | ss << nvfuser_resources::helpers_cu; |
97 | ss << nvfuser_resources::index_utils_cu; |
98 | ss << nvfuser_resources::tuple_cu; |
99 | |
100 | // Synchronization classes |
101 | if (std::getenv("PYTORCH_NVFUSER_USE_BLOCK_SYNC_ATOMIC" )) { |
102 | ss << nvfuser_resources::block_sync_atomic_cu; |
103 | } else { |
104 | #ifndef USE_ROCM |
105 | ss << nvfuser_resources::block_sync_default_cu; |
106 | #else |
107 | ss << nvfuser_resources::block_sync_default_rocm_cu; |
108 | #endif |
109 | } |
110 | ss << nvfuser_resources::grid_sync_cu; |
111 | |
112 | // Communication classes |
113 | ss << nvfuser_resources::block_reduction_cu; |
114 | ss << nvfuser_resources::grid_reduction_cu; |
115 | ss << nvfuser_resources::grid_broadcast_cu; |
116 | ss << nvfuser_resources::broadcast_cu; |
117 | ss << nvfuser_resources::welford_cu; |
118 | #ifndef USE_ROCM |
119 | ss << nvfuser_resources::warp_cu; |
120 | ss << nvfuser_resources::tensorcore_cu; |
121 | ss << nvfuser_resources::memory_cu; |
122 | #else |
123 | ss << nvfuser_resources::warp_rocm_cu; |
124 | #endif |
125 | ss << nvfuser_resources::fused_welford_helper_cu; |
126 | ss << nvfuser_resources::fused_reduction_cu; |
127 | ss << nvfuser_resources::fused_welford_impl_cu; |
128 | ss << nvfuser_resources::swizzle_cu; |
129 | |
130 | // Random utilities |
131 | ss << nvfuser_resources::PhiloxCudaStateRaw_cu; |
132 | |
133 | return ss.str(); |
134 | } |
135 | |
136 | namespace { |
137 | |
// Returns false if the arg's type, number of dimensions, or device doesn't
// match the param and the provided c10::Device
140 | bool validateKernelArgTensor( |
141 | const at::Tensor& arg, |
142 | const Val* param, |
143 | const c10::Device& device, |
144 | std::stringstream& msg) { |
145 | // Arg is a tensor. Param must be a tensor too. |
146 | if (*param->getValType() != ValType::TensorView) { |
147 | msg << "Argument is a tensor, but the parameter is not.\n" ; |
148 | return false; |
149 | } |
150 | |
151 | if (is_cpu_scalar(arg) && !param->as<TensorView>()->isCpuScalar()) { |
152 | msg << "Argument is CPU Scalar Tensor, but parameter is not.\n" ; |
153 | return false; |
154 | } |
155 | |
156 | if (!is_cpu_scalar(arg) && !arg.is_cuda()) { |
157 | msg << "Argument is a CPU tensor which is not supported in fusions.\n" ; |
158 | return false; |
159 | } |
160 | |
161 | // Check the rank of the tensors. |
162 | size_t arg_dim = arg.dim(); |
163 | // Note: This requires current Fusion to be active. |
164 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
165 | size_t param_dim = TensorDomain::noReductions( |
166 | param->as<TensorView>()->getMaybeRFactorDomain()) |
167 | .size(); |
168 | // see [Note - broadcast support in integration] |
169 | // Because of broadcasting support handled in integration, we relax the rank |
170 | // check as necessary. |
171 | if (arg_dim > param_dim) { |
172 | msg << "Argument tensor's rank is " << arg_dim << ", but the parameter is " |
173 | << param_dim << "\n" ; |
174 | return false; |
175 | } |
176 | |
177 | if (!is_cpu_scalar(arg) && arg.device() != device) { |
178 | msg << "Argument is on device that is not compiled for." |
179 | << "\n" ; |
180 | return false; |
181 | } |
182 | // Check element type |
183 | at::ScalarType arg_data_type = arg.scalar_type(); |
184 | DataType param_data_type = *param->getDataType(); |
185 | bool match = false; |
  // TODO: replace this switch with `aten_to_data_type`
187 | switch (arg_data_type) { |
188 | case at::ScalarType::Double: |
189 | match = param_data_type == DataType::Double; |
190 | break; |
191 | case at::ScalarType::Half: |
192 | match = param_data_type == DataType::Half; |
193 | break; |
194 | case at::ScalarType::BFloat16: |
195 | match = param_data_type == DataType::BFloat16; |
196 | break; |
197 | case at::ScalarType::Float: |
198 | match = param_data_type == DataType::Float; |
199 | break; |
200 | case at::ScalarType::Long: |
201 | match = param_data_type == DataType::Int; |
202 | break; |
203 | case at::ScalarType::Int: |
204 | match = param_data_type == DataType::Int32; |
205 | break; |
206 | case at::ScalarType::Bool: |
207 | match = param_data_type == DataType::Bool; |
208 | break; |
209 | case at::ScalarType::ComplexFloat: |
210 | match = param_data_type == DataType::ComplexFloat; |
211 | break; |
212 | case at::ScalarType::ComplexDouble: |
213 | match = param_data_type == DataType::ComplexDouble; |
214 | break; |
215 | default: |
216 | msg << "Argument element type, " << arg_data_type << ", is not supported." |
217 | << "\n" ; |
218 | return false; |
219 | } |
220 | if (!match) |
221 | msg << "Argument element type is " << arg_data_type |
222 | << ", but the parameter is " << param_data_type << "\n" ; |
223 | return match; |
224 | } |
225 | |
226 | // Return false if arg_type doesn't match the type in param |
227 | bool validateKernelArgScalar( |
228 | const ArgAbstract* arg, |
229 | const Val* param, |
230 | std::stringstream& msg) { |
231 | TORCH_INTERNAL_ASSERT( |
232 | param->getDataType().has_value(), "kernel param should have data type" ); |
233 | DataType param_type = *param->getDataType(); |
234 | bool match = false; |
235 | switch (arg->type()) { |
236 | case ArgType::Long: |
237 | match = param_type == DataType::Int || param_type == DataType::Int32; |
238 | break; |
239 | case ArgType::Double: |
240 | match = param_type == DataType::Double || param_type == DataType::Float || |
241 | param_type == DataType::Half || param_type == DataType::BFloat16; |
242 | break; |
243 | case ArgType::Bool: |
244 | match = param_type == DataType::Bool; |
245 | break; |
246 | case ArgType::ComplexDouble: |
247 | match = param_type == DataType::ComplexDouble || |
248 | param_type == DataType::ComplexFloat; |
249 | break; |
250 | default: |
251 | // TODO: We need to verify that param is actually a scalar |
252 | msg << "Argument is not a scalar, but the parameter is." |
253 | << "\n" ; |
254 | return false; |
255 | } |
256 | if (!match) { |
257 | msg << "Argument type is " << argTypeToString(arg->type()) |
258 | << ", but the parameter is " << param_type << "\n" ; |
259 | } |
260 | return match; |
261 | } |
262 | |
// Returns false if arg and param don't match up, or if arg's device (when
// arg is a tensor) doesn't match the provided device
265 | bool validateKernelArg( |
266 | const ArgAbstract* arg, |
267 | const Val* param, |
268 | const c10::Device& device, |
269 | std::stringstream& msg) { |
270 | if (auto tensor_arg_abstract = dynamic_cast<const TensorArgAbstract*>(arg)) { |
271 | // TODO: don't use get tensor here. We would want to remove tensor reference |
272 | // for async compilation |
273 | return validateKernelArgTensor( |
274 | tensor_arg_abstract->getTensor(), param, device, msg); |
275 | } else if (arg->isType(ArgType::CpuScalarTensor)) { |
276 | // TODO: merge this one with above |
    // TODO: we need to check that the cpu scalar dtype matches param
278 | bool match = param->as<TensorView>()->isCpuScalar(); |
279 | if (!match) { |
280 | msg << "Argument is scalar type, but kernel parameter is not\n" ; |
281 | } |
282 | return match; |
283 | } else { |
284 | return validateKernelArgScalar(arg, param, msg); |
285 | } |
286 | } |
287 | |
288 | // Return true if all the tensors have the same stride, assumes all tensors are |
289 | // contiguous |
290 | bool checkSameStride(const std::vector<c10::IValue>& tensors) { |
291 | if (tensors.size() < 2) { |
292 | return true; |
293 | } |
294 | for (const auto idx : c10::irange(tensors.size() - 1)) { |
295 | auto current = tensors[idx]; |
296 | auto next = tensors[idx + 1]; |
297 | if (!current.isTensor() || !next.isTensor()) { |
298 | return false; |
299 | } |
300 | |
301 | const auto& current_tensor = current.toTensor(); |
302 | const auto& next_tensor = next.toTensor(); |
303 | if (current_tensor.ndimension() != next_tensor.ndimension()) { |
304 | return false; |
305 | } |
306 | |
307 | for (const auto i : c10::irange(current_tensor.ndimension())) { |
308 | if (current_tensor.stride(i) != next_tensor.stride(i)) { |
309 | return false; |
310 | } |
311 | } |
312 | } |
313 | return true; |
314 | } |
315 | |
316 | // Return true if all the tensors are contiguous and have the same striding |
317 | bool checkSameContiguity(const std::vector<c10::IValue>& tensors) { |
318 | if (tensors.size() < 2) { |
319 | return true; |
320 | } |
321 | |
322 | auto reference = tensors.front(); |
323 | if (!reference.isTensor()) { |
324 | return false; |
325 | } |
326 | |
327 | // Determine if the reference tensor is contiguous |
328 | const auto& reference_tensor = reference.toTensor(); |
329 | int64_t expected_stride = 1; |
330 | for (const auto i : c10::irange(1, reference_tensor.ndimension() + 1)) { |
331 | int64_t ind = reference_tensor.ndimension() - i; |
332 | if (reference_tensor.size(ind) == 1) { |
333 | continue; |
334 | } |
335 | if (reference_tensor.stride(ind) != expected_stride) { |
336 | return false; |
337 | } |
338 | expected_stride *= reference_tensor.size(ind); |
339 | } |
340 | |
341 | // Check if all the tensors have the same contiguity |
342 | return checkSameStride(tensors); |
343 | } |
344 | |
345 | bool checkValidMisalignedTensors( |
346 | const std::unordered_set<TensorView*>& inp_tv, |
347 | const std::unordered_set<TensorView*>& out_tv, |
348 | const std::vector<c10::IValue>& inp_tensors, |
349 | const std::vector<c10::IValue>& out_tensors) { |
350 | if (out_tv.empty()) { |
351 | // Only check input tensors |
352 | return checkSameStride(inp_tensors); |
353 | } else if (!out_tv.empty() && out_tensors.empty()) { |
354 | // out_tensors is empty unless outputs are given to runFusion. |
355 | // Assume out tensors are contiguous |
356 | return checkSameContiguity(inp_tensors); |
357 | } else { |
358 | // Only check input and output tensors |
359 | std::vector<c10::IValue> tensors; |
360 | tensors.insert(tensors.end(), inp_tensors.begin(), inp_tensors.end()); |
361 | tensors.insert(tensors.end(), out_tensors.begin(), out_tensors.end()); |
362 | return checkSameStride(tensors); |
363 | } |
364 | } |
365 | |
366 | } // namespace |
367 | |
368 | void validateKernelInputs( |
369 | Fusion* fusion, |
370 | const KernelArgumentHolder& args, |
371 | const c10::Device& device) { |
372 | FUSER_PERF_SCOPE("executor_utils::ValidateKernelInputs" ); |
373 | |
  // This is necessary as the checks below traverse the fusion graph
375 | FusionGuard fg(fusion); |
376 | // Check inputs |
377 | TORCH_INTERNAL_ASSERT( |
378 | args.size() == fusion->inputs().size(), "Wrong number of kernel inputs." ); |
379 | |
380 | std::stringstream msg; |
381 | bool mismatch = false; |
382 | for (const auto i : c10::irange(args.size())) { |
383 | const ArgAbstract* arg = args[i]; |
384 | const Val* param = fusion->inputs()[i]; |
385 | mismatch = !validateKernelArg(arg, param, device, msg) || mismatch; |
386 | } |
387 | TORCH_INTERNAL_ASSERT( |
388 | !mismatch, "Found one or more invalid arguments: " , msg.str()); |
389 | } |
390 | |
391 | void validateKernelOutputs( |
392 | Fusion* fusion, |
393 | const std::vector<at::Tensor>& outputs, |
394 | const c10::Device& device) { |
395 | FUSER_PERF_SCOPE("executor_utils::ValidateKernelOutputs" ); |
396 | |
397 | TORCH_INTERNAL_ASSERT( |
398 | fusion->outputs().size() != 0, |
399 | "Kernel should have at least one output tensor." ); |
400 | |
401 | TORCH_INTERNAL_ASSERT( |
402 | outputs.size() == fusion->outputs().size(), |
403 | "Wrong number of kernel outputs." ); |
404 | |
405 | std::stringstream msg; |
406 | bool mismatch = false; |
407 | for (const auto i : c10::irange(outputs.size())) { |
408 | const at::Tensor& arg = outputs[i]; |
409 | const Val* param = fusion->outputs()[i]; |
410 | mismatch = !validateKernelArgTensor(arg, param, device, msg) || mismatch; |
411 | } |
412 | TORCH_INTERNAL_ASSERT( |
413 | !mismatch, "Found one or more invalid arguments: " , msg.str()); |
414 | } |
415 | |
416 | namespace { |
417 | |
// Finds fusion input and/or output tensors whose strides need to be
// validated for vectorization.
// Returns pairs, each consisting of a flag indicating whether the tensor is
// a fusion input and an integer position within the input or output tensor
// list. For example, a producer that is fusion input 2 yields {true, 2}.
422 | std::vector<std::pair<bool, int>> getVectorizedFusionInputOutput( |
423 | TensorView* producer_tv, |
424 | TensorView* consumer_tv, |
425 | Fusion* fusion) { |
426 | std::vector<std::pair<bool, int>> vectorized_input_output; |
427 | |
428 | // When the producer is a fusion input, validate only the producer |
429 | // and assume the consumer is contiguous. Similarly, when the |
430 | // consumer is a fusion output, validate the consumer and assume the |
431 | // producer is contiguous. |
432 | |
433 | if (producer_tv->isFusionInput()) { |
434 | auto producer_it = std::find( |
435 | fusion->inputs().begin(), fusion->inputs().end(), producer_tv); |
436 | TORCH_INTERNAL_ASSERT( |
437 | producer_it != fusion->inputs().end(), |
438 | "Could not find " , |
439 | producer_tv, |
440 | " in fusion inputs." ); |
441 | auto pos = std::distance(fusion->inputs().begin(), producer_it); |
442 | vectorized_input_output.push_back( |
443 | std::make_pair<bool, int>(true, static_cast<int>(pos))); |
444 | } else { |
445 | // If not fusion input, assume it's fully contiguous, so nothing |
446 | // to check with respect to strides. |
447 | TORCH_INTERNAL_ASSERT( |
448 | std::all_of( |
449 | producer_tv->domain()->contiguity().begin(), |
450 | producer_tv->domain()->contiguity().end(), |
451 | [](bool contig) { return contig; }), |
452 | "Unsupported pattern of vectorization: " , |
453 | consumer_tv->definition()->toString()); |
454 | } |
455 | |
456 | if (consumer_tv->isFusionOutput()) { |
457 | auto consumer_it = std::find( |
458 | fusion->outputs().begin(), fusion->outputs().end(), consumer_tv); |
459 | TORCH_INTERNAL_ASSERT( |
460 | consumer_it != fusion->outputs().end(), |
461 | "Could not find " , |
462 | consumer_tv, |
463 | " in fusion outputs." ); |
464 | auto pos = std::distance(fusion->outputs().begin(), consumer_it); |
465 | vectorized_input_output.push_back( |
466 | std::make_pair<bool, int>(false, static_cast<int>(pos))); |
467 | } else { |
    // If not a fusion output, assume it's fully contiguous, so nothing
    // to check with respect to strides.
470 | TORCH_INTERNAL_ASSERT( |
471 | std::all_of( |
472 | consumer_tv->domain()->contiguity().begin(), |
473 | consumer_tv->domain()->contiguity().end(), |
474 | [](bool contig) { return contig; }), |
475 | "Unsupported pattern of vectorization: " , |
476 | consumer_tv->definition()->toString()); |
477 | } |
478 | |
479 | return vectorized_input_output; |
480 | } |
481 | |
482 | //! Returns the information of vectorized input/output tensors |
483 | //! in the given fusion. |
484 | std::unique_ptr<caching::VectorizedTensorInfo> getVectorizedTensorValidationInfo( |
485 | kir::Kernel* kernel) { |
486 | auto vectorized_tensor_info_ptr = |
487 | std::make_unique<caching::VectorizedTensorInfo>(); |
488 | |
489 | for (const auto& vector_info : kernel->summary().vectorized_set_info) { |
490 | auto consumer_tv = vector_info.consumer_tv; |
491 | auto producer_tv = vector_info.producer_tv; |
492 | |
493 | auto vector_dim = vector_info.vectorized_leaf_id; |
494 | const auto is_aligned = |
495 | vector_dim->getParallelType() == ParallelType::Vectorize; |
496 | |
497 | // Find fusion inputs and outputs that are used with misaligned |
498 | // vectorization. |
499 | if (!is_aligned) { |
500 | TORCH_INTERNAL_ASSERT( |
501 | producer_tv->isFusionInput() || consumer_tv->isFusionOutput(), |
502 | "MisalignedVectorize is assumed to be used with either input or output tensor" ); |
503 | if (consumer_tv->getMemoryType() == MemoryType::Global && |
504 | producer_tv->getMemoryType() == MemoryType::Local) { |
505 | vectorized_tensor_info_ptr->global_out_misaligned_tv.insert( |
506 | consumer_tv); |
507 | } else if ( |
508 | producer_tv->getMemoryType() == MemoryType::Global && |
509 | consumer_tv->getMemoryType() == MemoryType::Local) { |
510 | vectorized_tensor_info_ptr->global_inp_misaligned_tv.insert( |
511 | producer_tv); |
512 | } else { |
513 | TORCH_INTERNAL_ASSERT( |
514 | false, |
515 | "Unsupported memory configuration for misaligned vectorization." ); |
516 | } |
517 | } |
518 | |
519 | // Collect information on corresponding fusion input and output |
520 | // tensors to verify strides. |
521 | auto inp_or_out_info = |
522 | getVectorizedFusionInputOutput(producer_tv, consumer_tv, kernel); |
523 | |
524 | // If both producer and consumer are contig and intermediate, |
525 | // nothing to validate with respect to strides. |
526 | if (inp_or_out_info.empty()) { |
527 | continue; |
528 | } |
529 | |
530 | // Misaligned vectorize only allows from input to local or local |
531 | // to output |
532 | if (!is_aligned) { |
533 | TORCH_INTERNAL_ASSERT(inp_or_out_info.size() == 1); |
534 | } |
535 | |
536 | for (const auto& inp_or_out : inp_or_out_info) { |
537 | const bool is_input = inp_or_out.first; |
538 | const int pos = inp_or_out.second; |
539 | |
540 | if (is_aligned) { |
541 | auto& pos_list = is_input |
542 | ? vectorized_tensor_info_ptr->aligned_vectorized_inp_tensor_pos |
543 | : vectorized_tensor_info_ptr->aligned_vectorized_out_tensor_pos; |
544 | pos_list.push_back(pos); |
545 | } else { |
546 | auto& map = is_input |
547 | ? vectorized_tensor_info_ptr->inp_misaligned_tensors_pos |
548 | : vectorized_tensor_info_ptr->out_misaligned_tensors_pos; |
549 | map.emplace_back(pos); |
550 | } |
551 | } |
552 | } |
553 | |
554 | return vectorized_tensor_info_ptr; |
555 | } |
556 | |
557 | // Make sure the root domain(s) comprising the vectorized leaf domain |
558 | // have the (merged) extent that is divisible by the vectorization |
559 | // word size. |
560 | void validateAlignedVectorizeExtents( |
561 | const VectorizedSetInfo& info, |
562 | kir::ExpressionEvaluator& expr_eval) { |
563 | TORCH_INTERNAL_ASSERT( |
564 | !info.contig_root_ids.empty(), |
565 | "No root ID found for vectorization with " , |
566 | info.consumer_tv->toString(), |
567 | " and " , |
568 | info.producer_tv->toString()); |
569 | |
570 | int64_t vectorized_merged_domain_extent = 1; |
571 | for (auto id : info.contig_root_ids) { |
572 | auto extent_val = expr_eval.evaluate(id->extent()); |
573 | TORCH_INTERNAL_ASSERT( |
574 | extent_val.has_value(), |
575 | "Error vectorizing, " , |
576 | info.consumer_tv->toString(), |
577 | " as the extent of a vectorized root domain, " , |
578 | id->toString(), |
579 | ", is unknown." ); |
580 | vectorized_merged_domain_extent *= extent_val->as<int64_t>(); |
581 | } |
582 | |
583 | TORCH_INTERNAL_ASSERT( |
584 | vectorized_merged_domain_extent % info.word_size == 0, |
585 | "Error vectorizing, " , |
586 | info.consumer_tv->toString(), |
587 | " as the extent of the indexed domain, " , |
588 | vectorized_merged_domain_extent, |
589 | ", is not divisible by vector word size " , |
590 | info.word_size); |
591 | } |
592 | |
593 | void validateAlignedVectorizedFusionInputOutput( |
594 | const at::Tensor& aten_tensor, |
595 | int word_size, |
596 | TensorView* tv) { |
597 | TORCH_INTERNAL_ASSERT( |
598 | reinterpret_cast<size_t>(aten_tensor.data_ptr()) % |
599 | (word_size * aten_tensor.dtype().itemsize()) == |
600 | 0, |
601 | "Vectorization of " , |
602 | tv->toString(), |
603 | " not possible as the memory address is not aligned. " , |
604 | "Address: " , |
605 | aten_tensor.data_ptr(), |
606 | ", vector word size: " , |
607 | word_size, |
608 | ", data type: " , |
609 | aten_tensor.dtype()); |
610 | |
611 | // Traverse strides from the right-most domains. The rightmost |
612 | // domain must have stride 1. |
613 | int64_t cur_contig_stride = 1; |
614 | bool still_rightmost = true; |
615 | for (auto i = aten_tensor.ndimension() - 1; i >= 0; --i) { |
616 | const auto stride = aten_tensor.strides().at(i); |
617 | const auto size = aten_tensor.sizes().at(i); |
618 | // If this domain is contiguous or size == 1, then not necessary to check |
619 | // the stride. Otherwise, stride must be 1 if it's rightmost or |
620 | // divisible by word_size |
621 | TORCH_INTERNAL_ASSERT( |
622 | stride == cur_contig_stride || size == 1 || |
623 | (still_rightmost && stride == 1) || |
624 | (!still_rightmost && stride % word_size == 0), |
625 | "Vectorization of " , |
626 | tv->toString(), |
627 | " with word size " , |
628 | word_size, |
629 | " not possible due to invalid stride." , |
630 | " Domain: " , |
631 | tv->axis(i)->toString(), |
632 | ", stride: " , |
        stride);
634 | // If the domain is size-1, the next domain is still considered |
635 | // rightmost. |
636 | still_rightmost = still_rightmost && size == 1; |
637 | // We do not update cur_contig_stride for size==1 dimensions, |
638 | // since we have specialized vectorization stride check for them |
639 | if (size != 1) { |
640 | cur_contig_stride = stride * size; |
641 | } |
642 | } |
643 | } |
644 | |
645 | void validateAlignedVectorizedTensors( |
646 | kir::Kernel* kernel, |
647 | const KernelArgumentHolder& args, |
648 | const std::vector<at::Tensor>& outputs, |
649 | caching::ExecutorCompileTimeInfoCache* data_cache, |
650 | kir::ExpressionEvaluator& expr_eval) { |
651 | auto tensor_vectorization_validation_entry = |
652 | executor_utils::caching::ExecutorCompileTimeEntry< |
653 | executor_utils::caching::VectorizedTensorValidation>( |
654 | data_cache, [kernel]() { |
655 | return executor_utils::getVectorizedTensorValidationInfo(kernel); |
656 | }); |
657 | |
658 | // Verify extents of aligned vectorized tensors |
659 | for (const auto& vec_info : kernel->summary().vectorized_set_info) { |
660 | if (vec_info.vectorized_leaf_id->getParallelType() == |
661 | ParallelType::Vectorize) { |
662 | validateAlignedVectorizeExtents(vec_info, expr_eval); |
663 | } |
664 | } |
665 | |
  // Validate input and output tensors with aligned
  // vectorization.
668 | for (auto pos : tensor_vectorization_validation_entry.get() |
669 | .aligned_vectorized_inp_tensor_pos) { |
670 | auto tv = kernel->inputs().at(pos)->as<TensorView>(); |
671 | auto word_size = kernel->summary().vectorized_accesses.at(tv); |
672 | auto tensor_arg_abstract = |
673 | dynamic_cast<const TensorArgAbstract*>(args[pos]); |
    TORCH_INTERNAL_ASSERT(
        tensor_arg_abstract,
        "Expected a tensor argument for vectorization validation");
675 | validateAlignedVectorizedFusionInputOutput( |
676 | tensor_arg_abstract->getTensor(), word_size, tv); |
677 | } |
678 | if (!outputs.empty()) { |
679 | for (auto pos : tensor_vectorization_validation_entry.get() |
680 | .aligned_vectorized_out_tensor_pos) { |
681 | auto tv = kernel->outputs().at(pos)->as<TensorView>(); |
682 | auto word_size = kernel->summary().vectorized_accesses.at(tv); |
683 | validateAlignedVectorizedFusionInputOutput(outputs[pos], word_size, tv); |
684 | } |
685 | } |
686 | } |
687 | |
688 | // Misaligned vectorization check. Currently misaligned vectorization is limited |
689 | // to global-register and register-global load/store patterns. However, this |
690 | // could be improved to include shared memory. |
691 | void validateMisalignedVectorizedTensors( |
692 | kir::Kernel* kernel, |
693 | const KernelArgumentHolder& args, |
694 | const std::vector<at::Tensor>& outputs, |
695 | caching::ExecutorCompileTimeInfoCache* data_cache, |
696 | kir::ExpressionEvaluator& expr_eval) { |
697 | auto tensor_vectorization_validation_entry = |
698 | executor_utils::caching::ExecutorCompileTimeEntry< |
699 | executor_utils::caching::VectorizedTensorValidation>( |
700 | data_cache, [kernel]() { |
701 | return executor_utils::getVectorizedTensorValidationInfo(kernel); |
702 | }); |
703 | |
704 | std::vector<c10::IValue> inp_misaligned_tensors; |
705 | std::vector<c10::IValue> out_misaligned_tensors; |
706 | |
707 | const auto& inp_misaligned_tensors_pos = |
708 | tensor_vectorization_validation_entry.get().inp_misaligned_tensors_pos; |
709 | inp_misaligned_tensors.reserve(inp_misaligned_tensors_pos.size()); |
710 | std::transform( |
711 | inp_misaligned_tensors_pos.begin(), |
712 | inp_misaligned_tensors_pos.end(), |
713 | std::back_inserter(inp_misaligned_tensors), |
714 | [&args](int idx) { |
715 | auto tensor_arg_abstract = |
716 | dynamic_cast<const TensorArgAbstract*>(args[idx]); |
        TORCH_INTERNAL_ASSERT(
            tensor_arg_abstract,
            "Expected a tensor argument for vectorization validation");
719 | return tensor_arg_abstract->getTensor(); |
720 | }); |
721 | |
722 | const auto& out_misaligned_tensors_pos = |
723 | tensor_vectorization_validation_entry.get().out_misaligned_tensors_pos; |
724 | if (outputs.size() > 0) { |
725 | out_misaligned_tensors.reserve(out_misaligned_tensors_pos.size()); |
726 | std::transform( |
727 | out_misaligned_tensors_pos.begin(), |
728 | out_misaligned_tensors_pos.end(), |
729 | std::back_inserter(out_misaligned_tensors), |
730 | [&outputs](int idx) { return outputs[idx]; }); |
731 | } |
  // Fails when the global misaligned tensors don't all share the same
  // strides (and, when no output tensors are given, when the inputs are not
  // contiguous).
733 | TORCH_INTERNAL_ASSERT( |
734 | checkValidMisalignedTensors( |
735 | tensor_vectorization_validation_entry.get().global_inp_misaligned_tv, |
736 | tensor_vectorization_validation_entry.get().global_out_misaligned_tv, |
737 | inp_misaligned_tensors, |
738 | out_misaligned_tensors), |
739 | "All global tensors must have the same stride for misaligned vectorization." ); |
740 | } |
741 | |
742 | // Check if there's any split that is non-divisible and vectorized. If |
743 | // found, Vectorize is illegal. |
744 | void validateVectorizedSplits( |
745 | kir::Kernel* kernel, |
746 | kir::ExpressionEvaluator& expr_eval) { |
747 | for (const auto& extent_factor : kernel->summary().splits_to_validate) { |
748 | auto input_extent = expr_eval.evaluate(extent_factor.first); |
749 | auto split_factor = expr_eval.evaluate(extent_factor.second); |
750 | TORCH_INTERNAL_ASSERT( |
751 | input_extent.has_value(), |
752 | "Could not check if a split with vectorization is divisible because the extent, " , |
753 | extent_factor.first->toString(), |
754 | ", is not possible to evaluate." ); |
755 | TORCH_INTERNAL_ASSERT( |
        split_factor.has_value(),
757 | "Could not check if a split with vectorization is divisible because the split factor, " , |
758 | extent_factor.second->toString(), |
759 | ", is not possible to evaluate." ); |
760 | TORCH_INTERNAL_ASSERT( |
761 | input_extent.value() % split_factor.value() == 0, |
762 | "Non-divisible split with vectorization is detected. " , |
763 | "Extent: " , |
764 | input_extent.value(), |
765 | ". Factor: " , |
766 | split_factor.value()); |
767 | } |
768 | } |
769 | |
770 | } // namespace |
771 | |
772 | void validateVectorizedTensors( |
773 | kir::Kernel* kernel, |
774 | const KernelArgumentHolder& args, |
775 | const std::vector<at::Tensor>& outputs, |
776 | caching::ExecutorCompileTimeInfoCache* data_cache, |
777 | kir::ExpressionEvaluator& expr_eval) { |
778 | FUSER_PERF_SCOPE("FusionExecutor::validateVectorizedTensors" ); |
779 | |
780 | validateAlignedVectorizedTensors( |
781 | kernel, args, outputs, data_cache, expr_eval); |
782 | |
783 | validateMisalignedVectorizedTensors( |
784 | kernel, args, outputs, data_cache, expr_eval); |
785 | |
786 | validateVectorizedSplits(kernel, expr_eval); |
787 | } |
788 | |
789 | namespace { |
790 | |
791 | template <typename EXPR_EVALUATOR> |
792 | void bindInputForExprEvaluation( |
793 | Val* val, |
794 | const ArgAbstract* arg, |
795 | bool check_consistency, |
796 | EXPR_EVALUATOR& expr_eval) { |
797 | if (val->getValType() == ValType::TensorView) { |
798 | TensorView* cg_tensor = val->as<TensorView>(); |
799 | auto root_domain = |
800 | TensorDomain::noReductions(cg_tensor->getMaybeRFactorDomain()); |
801 | |
802 | if (root_domain.size() == 0) { |
803 | TORCH_INTERNAL_ASSERT( |
804 | arg->isType(ArgType::CpuScalarTensor) || |
805 | (arg->isType(ArgType::Tensor) && |
806 | dynamic_cast<const TensorArgAbstract*>(arg)->getRank() == 0), |
807 | "Something went wrong configuring launch. Inputs is not rank 0 tensor" ); |
808 | } else { |
809 | TORCH_INTERNAL_ASSERT( |
810 | arg->isType(ArgType::Tensor), |
811 | "Something went wrong configuring launch. Inputs do not match." ); |
812 | |
813 | auto tensor_arg_abstract = dynamic_cast<const TensorArgAbstract*>(arg); |
814 | TORCH_INTERNAL_ASSERT( |
815 | tensor_arg_abstract && |
816 | tensor_arg_abstract->getRank() == (int64_t)root_domain.size(), |
817 | "Something went wrong configuring launch. Inputs rank does not match." ); |
818 | |
819 | for (const auto dim : c10::irange(root_domain.size())) { |
820 | const auto tensor_arg_size = tensor_arg_abstract->getSize(dim); |
821 | const auto tensor_arg_stride = tensor_arg_abstract->getStride(dim); |
822 | const auto extent = root_domain[dim]->extent(); |
823 | if (root_domain[dim]->hasExpandedExtent()) { |
824 | TORCH_INTERNAL_ASSERT( |
825 | tensor_arg_stride == 0, |
826 | "Expecting an expanded dimension on dimension " , |
827 | dim, |
828 | " but found stride " , |
829 | tensor_arg_stride); |
          // A dynamic size on an expanded dimension could be supported, so
          // the expanded extent may not be inferable here. This check might
          // be better done once all values are bound.
833 | auto maybe_expanded_size = |
834 | expr_eval.evaluate(root_domain[dim]->expandedExtent()); |
835 | if (maybe_expanded_size.has_value()) { |
836 | TORCH_CHECK( |
837 | *maybe_expanded_size == tensor_arg_size, |
838 | "Expecting expanded extent of " , |
839 | *maybe_expanded_size, |
840 | " but received value of " , |
841 | tensor_arg_size); |
842 | } |
843 | } |
844 | |
845 | const auto value = |
846 | root_domain[dim]->hasExpandedExtent() ? 1 : tensor_arg_size; |
847 | bool should_bind = true; |
848 | if (check_consistency) { |
849 | const auto prev_value = expr_eval.evaluate(extent); |
850 | if (prev_value.has_value()) { |
851 | TORCH_CHECK( |
852 | *prev_value == value, |
853 | "Attempting to bind " , |
854 | extent->toString(), |
855 | " to " , |
856 | value, |
857 | " but it's already set to " , |
858 | *prev_value); |
859 | should_bind = false; |
860 | } |
861 | } |
862 | if (should_bind && !extent->isConstScalar()) { |
863 | expr_eval.bind(extent, value); |
864 | } |
865 | } |
866 | } |
867 | } else if (val->getValType().value() == ValType::Scalar) { |
868 | if (val->getDataType().value() == DataType::Int) { |
869 | TORCH_INTERNAL_ASSERT( |
870 | arg->isType(ArgType::Long), |
871 | "fusion expected Scalar Int inputs, but found " , |
872 | argTypeToString(arg->type())); |
873 | expr_eval.bind(val, *static_cast<const int64_t*>(arg->arg())); |
874 | } else if (val->getDataType().value() == DataType::Double) { |
875 | TORCH_INTERNAL_ASSERT( |
876 | arg->isType(ArgType::Double), |
877 | "fusion expected Scalar Double inputs, but found " , |
878 | argTypeToString(arg->type())); |
879 | expr_eval.bind(val, *static_cast<const double*>(arg->arg())); |
880 | } |
881 | } |
882 | } |
883 | |
884 | } // namespace |
885 | |
886 | kir::ExpressionEvaluator bindKernelInputs( |
887 | const KernelArgumentHolder& args, |
888 | kir::Kernel* kernel, |
889 | bool check_consistency) { |
890 | FUSER_PERF_SCOPE("executor_utils::BindKernelInputs" ); |
891 | |
892 | TORCH_INTERNAL_ASSERT( |
893 | kernel->inputs().size() == args.size(), |
894 | "Something went wrong configuring launch. Inputs no longer match." ); |
895 | |
896 | kir::ExpressionEvaluator expr_eval; |
897 | const auto& inputs = kernel->inputs(); |
898 | |
899 | for (const auto i : c10::irange(inputs.size())) { |
900 | bindInputForExprEvaluation( |
901 | inputs[i], args[i], check_consistency, expr_eval); |
902 | } |
903 | return expr_eval; |
904 | } |
905 | |
906 | ExpressionEvaluator bindFusionInputs( |
907 | const KernelArgumentHolder& args, |
908 | Fusion* fusion) { |
909 | FUSER_PERF_SCOPE("executor_utils::BindFusionInputs" ); |
910 | |
911 | auto inputs = fusion->inputs(); |
912 | TORCH_INTERNAL_ASSERT( |
913 | inputs.size() == args.size(), |
914 | "Something went wrong configuring launch. Inputs do not match.\n" , |
915 | "inputs: " , |
916 | ir_utils::toString(inputs), |
917 | " args size: " , |
918 | args.size()); |
919 | |
920 | ExpressionEvaluator expr_eval(fusion); |
921 | |
922 | // This should probably move to EvaluationContext as we may want to bind |
923 | // input values frequently. Bind fusion input values to runtime values. |
924 | for (const auto i : c10::irange(inputs.size())) { |
925 | bindInputForExprEvaluation(inputs[i], args[i], true, expr_eval); |
926 | } |
927 | return expr_eval; |
928 | } |
929 | |
930 | namespace { |
931 | |
932 | // Dump PTX or CUBIN to a file |
933 | #if CUDA_VERSION >= 11010 |
934 | void dumpCompiledCodeToFile( |
935 | const nvrtcProgram& program, |
936 | int fusion_id, |
937 | bool dump_cubin) { |
938 | const auto getSize = dump_cubin |
939 | ? at::globalContext().getNVRTC().nvrtcGetCUBINSize |
940 | : at::globalContext().getNVRTC().nvrtcGetPTXSize; |
941 | const auto getCode = dump_cubin ? at::globalContext().getNVRTC().nvrtcGetCUBIN |
942 | : at::globalContext().getNVRTC().nvrtcGetPTX; |
943 | size_t size = 0; |
944 | AT_CUDA_NVRTC_CHECK(getSize(program, &size)); |
945 | std::vector<char> code(size); |
946 | AT_CUDA_NVRTC_CHECK(getCode(program, code.data())); |
947 | std::stringstream file_name; |
948 | file_name << "__tmp_kernel" << fusion_id << "." |
949 | << (dump_cubin ? "cubin" : "ptx" ); |
950 | std::cout << "PRINTING: " << file_name.str() << std::endl; |
  // open in binary mode so CUBIN bytes are written unmodified
  std::ofstream out(file_name.str(), std::ios::out | std::ios::binary);
952 | TORCH_INTERNAL_ASSERT(out.is_open()); |
953 | out.write(code.data(), size); |
954 | out.close(); |
955 | } |
956 | #endif |
957 | |
958 | } // namespace |
959 | |
960 | std::pair<NvrtcFunction, std::string> nvrtcCompile( |
961 | const std::string& code, |
962 | const std::string& func_name, |
963 | int id, |
964 | c10::optional<int> opt_block_size) { |
965 | FUSER_PERF_SCOPE("executor_utils::NVRTC" ); |
966 | if (isOptionDisabled(DisableOption::ArchCheck)) { |
967 | TORCH_WARN( |
968 | "NVFuser Compile: arch check disabled, should not compile any kernel" ); |
969 | } |
970 | |
971 | at::cuda::jit::initializeCudaContext(); |
972 | |
973 | std::stringstream ptxas_log; |
974 | |
975 | const auto prop = at::cuda::getCurrentDeviceProperties(); |
976 | |
977 | int major = 0, minor = 0; |
978 | bool compile_to_sass = false; |
979 | codegenOutputQuery(prop, major, minor, compile_to_sass); |
980 | |
981 | nvrtcProgram program; // NOLINT(cppcoreguidelines-init-variables) |
982 | |
983 | { |
984 | std::stringstream ss; |
985 | ss << "__tmp_kernel" << id << ".cu" ; |
986 | std::string name = ss.str(); |
987 | FUSER_PERF_SCOPE("executor_utils::NvrtcCreateProgram" ); |
988 | AT_CUDA_NVRTC_CHECK(at::globalContext().getNVRTC().nvrtcCreateProgram( |
989 | &program, code.c_str(), name.c_str(), 0, nullptr, nullptr)); |
990 | } |
991 | |
992 | ResourceGuard holdProgram([&] { |
993 | FUSER_PERF_SCOPE("executor_utils::NvrtcDestroyProgram" ); |
994 | AT_CUDA_NVRTC_CHECK( |
995 | at::globalContext().getNVRTC().nvrtcDestroyProgram(&program)); |
996 | }); |
997 | |
998 | #ifdef USE_ROCM |
999 | std::vector<const char*> args = {"--std=c++17" }; |
1000 | #if ROCM_VERSION >= 40200 |
1001 | args.push_back("-hip-pch" ); |
1002 | #endif |
1003 | #else |
1004 | #if CUDA_VERSION < 11010 |
1005 | // compile to sass is not allowed prior to CUDA 11.1 |
1006 | compile_to_sass = false; |
1007 | #endif |
1008 | |
1009 | if (isOptionDisabled(DisableOption::CompileToSass)) { |
1010 | // Allows manually disabling compilation to sass |
1011 | // so the intermediate ptx could be checked. |
1012 | compile_to_sass = false; |
1013 | } |
  // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_),
  // which gives better backwards compatibility when running on older drivers
  // (since an older driver doesn't necessarily recognize PTX emitted by a
  // new toolkit).
  // Meanwhile, for forward compatibility (a future device with
  // `unsupported_arch==True`), since SASS is not necessarily compatible,
  // we fall back to PTX instead.
1021 | const std::string compute = std::string("--gpu-architecture=" ) + |
1022 | (compile_to_sass ? "sm_" : "compute_" ) + std::to_string(major) + |
1023 | std::to_string(minor); |
1024 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1025 | std::vector<const char*> args = { |
1026 | "--std=c++17" , compute.c_str(), "-default-device" }; |
1027 | #endif |
1028 | |
1029 | const bool disable_fma = isOptionDisabled(DisableOption::Fma); |
1030 | #ifdef USE_ROCM |
1031 | if (disable_fma) { |
1032 | TORCH_WARN_ONCE( |
1033 | "PYTORCH_CUDA_FUSER_DISABLE_FMA is not supported on ROCm, ignoring" ); |
1034 | } |
1035 | #else |
1036 | if (disable_fma) { |
1037 | args.push_back("--fmad=false" ); |
1038 | } else { |
1039 | args.push_back("--fmad=true" ); |
1040 | } |
1041 | #endif |
1042 | // Add line info to generated kernels |
1043 | if (isDebugDumpEnabled(DebugDumpOption::DebugInfo)) { |
1044 | args.push_back("-lineinfo" ); |
1045 | } |
1046 | #ifdef NDEBUG |
1047 | // Avoid excessive register usage from assertion |
1048 | args.push_back("-DNDEBUG" ); |
1049 | #endif |
1050 | |
1051 | if (isOptionEnabled(EnableOption::KernelProfile)) { |
1052 | args.push_back("-DPYTORCH_NVFUSER_PROFILE_KERNEL" ); |
1053 | } |
1054 | |
1055 | const char* ptxas_opt_level = getenv("PYTORCH_NVFUSER_JIT_OPT_LEVEL" ); |
1056 | std::string jit_opt_level = "-O" ; |
1057 | |
1058 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1059 | std::vector<CUjit_option> options; |
1060 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1061 | std::vector<void*> option_vals; |
1062 | std::vector<char> info_log; |
1063 | unsigned int log_size = 8196; |
1064 | |
1065 | if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog) || |
1066 | isDebugDumpEnabled(DebugDumpOption::PerfDebugVerbose)) { |
1067 | // show register usage in compilation log |
1068 | if (compile_to_sass) { |
1069 | args.push_back("--ptxas-options" ); |
1070 | args.push_back("--verbose" ); |
1071 | } else { |
1072 | options.push_back(CU_JIT_LOG_VERBOSE); |
1073 | option_vals.push_back((void*)1); |
      // resize (not reserve) so that writing through data() is well-defined
      info_log.resize(log_size);
1075 | |
1076 | options.push_back(CU_JIT_INFO_LOG_BUFFER); |
1077 | option_vals.push_back((void*)info_log.data()); |
1078 | |
1079 | options.push_back(CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES); |
1080 | option_vals.push_back((void*)(long)log_size); |
1081 | } |
1082 | } |
1083 | |
1084 | if (ptxas_opt_level) { |
1085 | int val = atoi(ptxas_opt_level); |
1086 | if (val <= 4 && val >= 0) { |
1087 | if (val < 4) { |
1088 | TORCH_WARN( |
1089 | "ptxas optimization level manually set as " , |
1090 | val, |
1091 | ", which could negatively affect performance. Try removing env variable PYTORCH_NVFUSER_JIT_OPT_LEVEL for optimal performance." ); |
1092 | } |
1093 | if (compile_to_sass) { |
1094 | jit_opt_level += std::to_string(val); |
1095 | args.push_back("--ptxas-options" ); |
1096 | args.push_back(jit_opt_level.c_str()); |
1097 | } else { |
1098 | options.push_back(CU_JIT_OPTIMIZATION_LEVEL); |
1099 | option_vals.push_back((void*)(intptr_t)val); |
1100 | } |
1101 | } else { |
1102 | TORCH_WARN_ONCE( |
1103 | "acceptable range for PYTORCH_NVFUSER_JIT_OPT_LEVEL is between 0 and 4, but received " , |
1104 | val, |
1105 | ", ignoring the option" ); |
1106 | } |
1107 | } |
1108 | |
1109 | #ifndef USE_ROCM |
  // keep the string outside the if-block so it outlives the raw pointer
  // stored in args
1111 | std::string max_register_usage = "--maxrregcount=" ; |
1112 | uint32_t max_register = 0; |
1113 | if (opt_block_size.has_value() && opt_block_size.value() > 0) { |
1114 | int num_partition = 0; |
1115 | int reg_allocation_granularity = 0; |
1116 | cudaOccDeviceProp occ_prop(*prop); |
1117 | cudaOccSubPartitionsPerMultiprocessor(&num_partition, &occ_prop); |
1118 | cudaOccRegAllocationGranularity(®_allocation_granularity, &occ_prop); |
1119 | int warp_size = prop->warpSize; |
1120 | int num_warps = ceilDiv(opt_block_size.value(), warp_size); |
1121 | |
    // warps could be distributed unevenly across partitions
    int max_warps_per_sm_partition = ceilDiv(num_warps, num_partition);
    // registers are evenly distributed across partitions; the partition with
    // the most warps determines the maximum registers available per warp
1126 | int max_reg_per_warp = |
1127 | prop->regsPerBlock / num_partition / max_warps_per_sm_partition; |
1128 | // clamp down to register allocation granularity at warp level |
1129 | int effective_max_reg_per_warp = max_reg_per_warp / |
1130 | reg_allocation_granularity * reg_allocation_granularity; |
1131 | // The maximum possible count allowed by ptxas is 255 |
1132 | max_register = static_cast<uint32_t>( |
1133 | std::min(effective_max_reg_per_warp / warp_size, 255)); |
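    // Worked example (illustrative numbers, not tied to any specific GPU):
    // with regsPerBlock = 65536, num_partition = 4, granularity = 256,
    // warp_size = 32, and a block size of 512: num_warps = 16,
    // max_warps_per_sm_partition = 4, max_reg_per_warp = 65536 / 4 / 4 =
    // 4096 (already a multiple of 256), so max_register =
    // min(4096 / 32, 255) = 128.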
1134 | if (compile_to_sass) { |
1135 | max_register_usage += std::to_string(max_register); |
1136 | args.push_back("--ptxas-options" ); |
1137 | args.push_back(max_register_usage.c_str()); |
1138 | } else { |
1139 | options.push_back(CU_JIT_MAX_REGISTERS); |
1140 | option_vals.push_back((void*)(intptr_t)max_register); |
1141 | } |
1142 | |
1143 | ptxas_log << "\nCompile options: " ; |
1144 | for (auto arg : args) { |
1145 | ptxas_log << arg << " " ; |
1146 | } |
1147 | ptxas_log << " ; block size=" << opt_block_size.value() << "\n" ; |
1148 | } |
1149 | #endif |
1150 | |
1151 | at::globalContext().getNVRTC().nvrtcAddNameExpression( |
1152 | program, func_name.c_str()); |
1153 | |
1154 | { |
1155 | FUSER_PERF_SCOPE("executor_utils::Nvrtc::CompileProgram" ); |
1156 | |
1157 | const auto result = at::globalContext().getNVRTC().nvrtcCompileProgram( |
1158 | program, args.size(), args.data()); |
1159 | |
1160 | size_t logsize = 0; |
1161 | at::globalContext().getNVRTC().nvrtcGetProgramLogSize(program, &logsize); |
1162 | |
1163 | std::vector<char> log(logsize); |
1164 | at::globalContext().getNVRTC().nvrtcGetProgramLog(program, log.data()); |
1165 | |
1166 | if (result != NVRTC_SUCCESS) { |
1167 | TORCH_INTERNAL_ASSERT( |
1168 | false, code.c_str(), "\nCUDA NVRTC compile error: " , log.data()); |
1169 | } |
1170 | |
1171 | ptxas_log << log.data() << std::endl; |
1172 | if (isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) { |
1173 | std::cout << log.data() << std::endl; |
1174 | } |
1175 | AT_CUDA_NVRTC_CHECK(result); |
1176 | } |
1177 | |
1178 | const char* lowered_kernel_name = nullptr; |
1179 | at::globalContext().getNVRTC().nvrtcGetLoweredName( |
1180 | program, func_name.c_str(), &lowered_kernel_name); |
1181 | |
1182 | size_t ptx_size = 0; |
1183 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
1184 | std::vector<char> ptx; |
1185 | |
1186 | { |
1187 | FUSER_PERF_SCOPE("executor_utils::Nvrtc::GetPTX" ); |
1188 | #if CUDA_VERSION >= 11010 |
1189 | // compile_to_sass determines whether we are generating SASS or PTX, hence |
1190 | // the different API. |
1191 | const auto getSize = compile_to_sass |
1192 | ? at::globalContext().getNVRTC().nvrtcGetCUBINSize |
1193 | : at::globalContext().getNVRTC().nvrtcGetPTXSize; |
1194 | const auto getFunc = compile_to_sass |
1195 | ? at::globalContext().getNVRTC().nvrtcGetCUBIN |
1196 | : at::globalContext().getNVRTC().nvrtcGetPTX; |
1197 | #else |
1198 | const auto getSize = at::globalContext().getNVRTC().nvrtcGetPTXSize; |
1199 | const auto getFunc = at::globalContext().getNVRTC().nvrtcGetPTX; |
1200 | #endif |
1201 | AT_CUDA_NVRTC_CHECK(getSize(program, &ptx_size)); |
1202 | ptx.resize(ptx_size); |
1203 | AT_CUDA_NVRTC_CHECK(getFunc(program, ptx.data())); |
1204 | } |
1205 | |
1206 | NvrtcFunction compiled_kernel_; |
1207 | |
1208 | #ifndef USE_ROCM |
1209 | |
1210 | #if CUDA_VERSION >= 11010 |
1211 | if (isDebugDumpEnabled(DebugDumpOption::Ptx)) { |
1212 | dumpCompiledCodeToFile(program, id, false); |
1213 | } |
1214 | |
1215 | if (isDebugDumpEnabled(DebugDumpOption::Cubin)) { |
1216 | TORCH_INTERNAL_ASSERT( |
1217 | compile_to_sass, |
1218 | "CUBIN not available as the kernel was compiled only to PTX" ); |
1219 | dumpCompiledCodeToFile(program, id, true); |
1220 | } |
1221 | #endif |
1222 | |
1223 | { |
1224 | FUSER_PERF_SCOPE("executor_utils::Nvrtc::LoadPTX" ); |
1225 | |
1226 | // load ptx or cubin directly |
1227 | AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadDataEx( |
1228 | &(compiled_kernel_.module), |
1229 | ptx.data(), |
1230 | options.size(), |
1231 | options.data(), |
1232 | option_vals.data())); |
1233 | |
1234 | if (!compile_to_sass && |
1235 | isDebugDumpEnabled(DebugDumpOption::PrintPtxasLog)) { |
1236 | std::cout << info_log.data() << std::endl; |
1237 | } |
1238 | } |
1239 | #else |
1240 | // load ptx directly |
1241 | AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleLoadData( |
1242 | &(compiled_kernel_.module), ptx.data())); |
1243 | |
1244 | #endif |
1245 | AT_CUDA_DRIVER_CHECK(at::globalContext().getNVRTC().cuModuleGetFunction( |
1246 | &(compiled_kernel_.function), |
1247 | compiled_kernel_.module, |
1248 | lowered_kernel_name)); |
1249 | |
1250 | TORCH_CHECK( |
1251 | !isOptionDisabled(DisableOption::ArchCheck), |
1252 | "NVFuser Compile: arch check disabled, should not return any compiled kernel" ); |
1253 | |
1254 | return {compiled_kernel_, ptxas_log.str()}; |
1255 | } |
1256 | |
1257 | namespace caching { |
1258 | |
//! CompileTimeInfo is the actual subclass of CompileTimeInfoBase that will
//!  be stored in the data cache. It internally owns a data_ member of the
//!  DataType defined within the entry class; the entry classes are listed
//!  in the header file.
1262 | template <typename EntryClass> |
1263 | class CompileTimeInfo : public CompileTimeInfoBase { |
1264 | public: |
1265 | CompileTimeInfo(std::unique_ptr<typename EntryClass::DataType> data) |
1266 | : CompileTimeInfoBase(EntryClass::EntryType), data_(std::move(data)) {} |
1267 | |
1268 | typename EntryClass::DataType* get() { |
1269 | return data_.get(); |
1270 | } |
1271 | |
1272 | private: |
1273 | std::unique_ptr<typename EntryClass::DataType> data_; |
1274 | }; |
1275 | |
1276 | void ExecutorCompileTimeInfoCache::insert(EntryOwningPtr new_entry) { |
1277 | // Just overwrite when insertion duplicates, equality not checked. |
1278 | entry_type_map_[new_entry->type()] = new_entry.get(); |
1279 | entries_.emplace_back(std::move(new_entry)); |
1280 | } |
1281 | |
1282 | template <typename EntryClass> |
1283 | ExecutorCompileTimeEntry<EntryClass>::ExecutorCompileTimeEntry( |
1284 | ExecutorCompileTimeInfoCache* data_cache, |
1285 | MakerFnType fn) { |
1286 | using InfoType = CompileTimeInfo<EntryClass>; |
1287 | |
1288 | if (!data_cache || !data_cache->has(EntryClass::EntryType)) { |
1289 | owned_data_ = fn(); |
1290 | data_ptr_ = owned_data_.get(); |
1291 | |
1292 | if (data_cache) { |
1293 | std::unique_ptr<CompileTimeInfoBase> new_entry = |
1294 | std::make_unique<InfoType>(std::move(owned_data_)); |
1295 | data_cache->insert(std::move(new_entry)); |
1296 | } |
1297 | } else { |
1298 | data_ptr_ = |
1299 | data_cache->at(EntryClass::EntryType)->template as<InfoType>()->get(); |
1300 | } |
1301 | } |
1302 | |
1303 | // Template instantiation |
1304 | template class ExecutorCompileTimeEntry<ParallelBindingIterDomains>; |
1305 | template class ExecutorCompileTimeEntry<ParallelIterExtentMap>; |
1306 | template class ExecutorCompileTimeEntry<SimplifiedParallelIterExtentMap>; |
1307 | template class ExecutorCompileTimeEntry<WarpPaddedParallelExtents>; |
1308 | template class ExecutorCompileTimeEntry<VectorizedTensorValidation>; |
1309 | template class ExecutorCompileTimeEntry<InputAliasIndices>; |
1310 | template class ExecutorCompileTimeEntry<OutputAliasIndices>; |
1311 | |
1312 | } // namespace caching |
1313 | |
1314 | std::vector<IterDomain*> getParallelBindingsIterDomains( |
1315 | GpuLower* lower, |
1316 | const std::vector<TensorView*>& used_tvs) { |
1317 | std::vector<IterDomain*> parallel_ids; |
1318 | for (auto tv : used_tvs) { |
1319 | for (auto id : tv->domain()->domain()) { |
1320 | if (id->isThread()) { |
1321 | if (id->isBroadcast()) { |
1322 | // Want to keep the broadcast dimensions if they are not resolved |
1323 | // TODO: piping down the parallel dimension map here would |
1324 | // be helpful |
1325 | if (lower->caMap()->getConcreteMappedID(id, IdMappingMode::LOOP) == |
1326 | id) { |
1327 | parallel_ids.push_back(id); |
1328 | } |
1329 | } else { |
1330 | // Non broadcast ids are directly added to the binding |
1331 | // ids. |
1332 | parallel_ids.push_back(id); |
1333 | } |
1334 | } |
1335 | } |
1336 | } |
1337 | return parallel_ids; |
1338 | } |
1339 | |
1340 | namespace { |
1341 | |
1342 | void insertParallelExtent( |
1343 | IterDomain* binding_id, |
1344 | const std::unique_ptr<ParallelExtentMap>& parallel_iter_extents_ptr) { |
1345 | auto extent = binding_id->extent(); |
1346 | const auto it = |
1347 | parallel_iter_extents_ptr->find(binding_id->getParallelType()); |
1348 | if (it != parallel_iter_extents_ptr->end()) { |
1349 | it->second.push_back(extent); |
1350 | } else { |
1351 | parallel_iter_extents_ptr->operator[](binding_id->getParallelType()) = { |
1352 | extent}; |
1353 | } |
1354 | } |
1355 | |
1356 | } // namespace |
1357 | |
1358 | std::unique_ptr<ParallelExtentMap> getParallelIterExtents( |
1359 | std::vector<IterDomain*>& parallel_binding_ids) { |
1360 | auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>(); |
1361 | for (auto id : parallel_binding_ids) { |
1362 | insertParallelExtent(id, parallel_iter_extents_ptr); |
1363 | } |
1364 | |
1365 | return parallel_iter_extents_ptr; |
1366 | } |
1367 | |
1368 | std::unique_ptr<ParallelExtentMap> getSimplifiedParallelIterExtents( |
1369 | GpuLower* lower, |
1370 | std::vector<IterDomain*>& parallel_binding_ids) { |
1371 | auto parallel_iter_extents_ptr = std::make_unique<ParallelExtentMap>(); |
1372 | const auto& ca_map = lower->caMap(); |
1373 | std::vector<IterDomain*> mapped; |
1374 | bool is_tidx_warp_padded = lower->getWarpPaddedParallelInfo().is_tidx_padded; |
1375 | |
1376 | for (auto id : parallel_binding_ids) { |
1377 | if (std::any_of( |
1378 | mapped.begin(), mapped.end(), [id, &ca_map](IterDomain* mapped_id) { |
1379 | return ca_map->areMapped(mapped_id, id, IdMappingMode::LOOP); |
1380 | })) { |
1381 | if (id->getParallelType() != ParallelType::TIDx || !is_tidx_warp_padded) { |
1382 | continue; |
1383 | } |
1384 | } |
1385 | |
1386 | insertParallelExtent( |
1387 | ca_map->getConcreteMappedID(id, IdMappingMode::LOOP), |
1388 | parallel_iter_extents_ptr); |
1389 | mapped.push_back(id); |
1390 | } |
1391 | |
1392 | return parallel_iter_extents_ptr; |
1393 | } |
1394 | |
1395 | std::unique_ptr<caching::WarpPaddedExtentsInfo> getWarpPaddedExtentsInfo( |
1396 | kir::Kernel* kernel, |
1397 | std::vector<IterDomain*>& parallel_binding_ids) { |
1398 | auto warp_padded_extent_info_ptr = |
1399 | std::make_unique<caching::WarpPaddedExtentsInfo>(); |
1400 | auto& warp_padded_extent_set = |
1401 | warp_padded_extent_info_ptr->warp_padded_extent_set; |
1402 | auto& warp_padded_constant = |
1403 | warp_padded_extent_info_ptr->warp_padded_constant; |
1404 | bool has_warp_reduction = |
1405 | kernel->getWarpPaddedParallelInfo().has_warp_reduction; |
1406 | |
1407 | for (auto id : parallel_binding_ids) { |
    // Apply warp padding only when there are warp reductions in
    // the kernel.
1410 | if (has_warp_reduction) { |
1411 | if (id->hasPaddingToMultipleOfWarp() || |
1412 | kernel->isParallelTypePadded(id->getParallelType())) { |
1413 | auto extent = id->extent(); |
1414 | warp_padded_extent_set.insert(extent); |
1415 | auto padded_value = id->getMaybeSizeAfterPadding(); |
1416 | if (padded_value.has_value()) { |
1417 | warp_padded_constant[extent] = padded_value.value(); |
1418 | } |
1419 | } |
1420 | } |
1421 | } |
1422 | return warp_padded_extent_info_ptr; |
1423 | } |
1424 | |
1425 | } // namespace executor_utils |
1426 | } // namespace cuda |
1427 | } // namespace fuser |
1428 | } // namespace jit |
1429 | } // namespace torch |
1430 | |