1#include <type.h>
2
3#include <ATen/cuda/CUDAContext.h>
4
5#include <stdexcept>
6#include <unordered_map>
7
8namespace torch {
9namespace jit {
10namespace fuser {
11namespace cuda {
12
13DataType indexModeToDtype(KernelIndexMode index_mode) {
14 switch (index_mode) {
15 case KernelIndexMode::INT32:
16 return DataType::Int32;
17 case KernelIndexMode::INT64:
18 return DataType::Int;
19 default:
20 TORCH_CHECK(false, "Invalid kernel index mode type.");
21 }
22}
23
24bool isFloatingPointType(DataType dtype) {
25 switch (dtype) {
26 case DataType::Double:
27 case DataType::Float:
28 case DataType::Half:
29 case DataType::BFloat16:
30 return true;
31 case DataType::Bool:
32 case DataType::Index:
33 case DataType::Int:
34 case DataType::Int32:
35 case DataType::ComplexFloat:
36 case DataType::ComplexDouble:
37 return false;
38 case DataType::Null:
39 TORCH_CHECK(
40 false, "Null type is not a valid argument to isFloatingPointType");
41 default:
42 TORCH_CHECK(false, "Type not supported in isFloatingPointType");
43 }
44}
45
46bool isBooleanType(DataType dtype) {
47 switch (dtype) {
48 case DataType::Bool:
49 return true;
50 case DataType::Double:
51 case DataType::Float:
52 case DataType::Half:
53 case DataType::BFloat16:
54 case DataType::ComplexFloat:
55 case DataType::ComplexDouble:
56 case DataType::Index:
57 case DataType::Int:
58 case DataType::Int32:
59 return false;
60 case DataType::Null:
61 TORCH_CHECK(false, "Null type is not a valid argument to isBooleanType");
62 default:
63 TORCH_CHECK(false, "Type not supported in isBooleanType");
64 }
65}
66
67bool isIntegralType(DataType dtype) {
68 switch (dtype) {
69 case DataType::Bool:
70 case DataType::Double:
71 case DataType::Float:
72 case DataType::Half:
73 case DataType::BFloat16:
74 case DataType::ComplexFloat:
75 case DataType::ComplexDouble:
76 return false;
77 case DataType::Index:
78 case DataType::Int:
79 case DataType::Int32:
80 return true;
81 case DataType::Null:
82 TORCH_CHECK(false, "Null type is not a valid argument to isIntegralType");
83 default:
84 TORCH_CHECK(false, "Type not supported in isIntegralType");
85 }
86}
87
88bool isComplexType(DataType dtype) {
89 switch (dtype) {
90 case DataType::ComplexFloat:
91 case DataType::ComplexDouble:
92 return true;
93 case DataType::Bool:
94 case DataType::Double:
95 case DataType::Float:
96 case DataType::Half:
97 case DataType::BFloat16:
98 case DataType::Int:
99 case DataType::Index:
100 case DataType::Int32:
101 return false;
102 case DataType::Null:
103 TORCH_CHECK(false, "Null type is not a valid argument to isComplexType");
104 default:
105 TORCH_CHECK(false, "Type not supported in isComplexType");
106 }
107}
108
109bool isVectorType(DataType dtype) {
110 switch (dtype) {
111 case DataType::Float_2:
112 case DataType::Double_2:
113 return true;
114 default:
115 return false;
116 }
117}
118
119DataType getVectorType(DataType dtype, size_t vec_size) {
120 switch (dtype) {
121 case DataType::Float:
122 TORCH_INTERNAL_ASSERT(vec_size == 2, "Not supported vectorized type");
123 return DataType::Float_2;
124 case DataType::Double:
125 TORCH_INTERNAL_ASSERT(vec_size == 2, "Not supported vectorized type");
126 return DataType::Double_2;
127 default:
128 TORCH_INTERNAL_ASSERT(
129 false, "Not supported vectorized type:", dtype, " and ", vec_size);
130 }
131}
132
133int getVectorSizeFromType(DataType dtype) {
134 switch (dtype) {
135 case DataType::Float_2:
136 case DataType::Double_2:
137 return 2;
138 default:
139 TORCH_INTERNAL_ASSERT(false, "Not a vector type:", dtype);
140 }
141}
142
143DataType getTypeFromVectorType(DataType dtype) {
144 switch (dtype) {
145 case DataType::Float_2:
146 return DataType::Float;
147 case DataType::Double_2:
148 return DataType::Double;
149 default:
150 TORCH_INTERNAL_ASSERT(false, "Not a vector type:", dtype);
151 }
152}
153
154DataType getTypeFromComplexType(DataType dtype) {
155 switch (dtype) {
156 case DataType::ComplexFloat:
157 return DataType::Float;
158 case DataType::ComplexDouble:
159 return DataType::Double;
160 default:
161 TORCH_INTERNAL_ASSERT(false, "Not a vector type:", dtype);
162 }
163}
164
165bool isSupportedTypeByDevice(DataType dtype) {
166 auto prop = at::cuda::getCurrentDeviceProperties();
167 auto major_ver = prop->major;
168 switch (dtype) {
169 case DataType::BFloat16:
170 return major_ver >= 8;
171 default:
172 return true;
173 }
174}
175
176bool isIntegerOp(const BinaryOpType bopt) {
177 return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift;
178}
179
180bool isLogicalOp(const BinaryOpType bopt) {
181 return bopt >= BinaryOpType::Eq && bopt <= BinaryOpType::NE;
182}
183
184bool alsoBooleanOperator(const BinaryOpType bopt) {
185 return bopt >= BinaryOpType::And && bopt <= BinaryOpType::Xor;
186}
187
188bool alsoBooleanOperator(const UnaryOpType uopt) {
189 return uopt >= UnaryOpType::Not && uopt <= UnaryOpType::Not;
190}
191
192// Return highest on list (smallest enum val)
193DataType promote_type(const DataType& t1, const DataType& t2) {
194 TORCH_CHECK(
195 DataType::Null != t1 && DataType::Null != t2,
196 "Expected promotable DataTypes but got: ",
197 t1,
198 " and ",
199 t2);
200 return aten_to_data_type(
201 c10::promoteTypes(data_type_to_aten(t1), data_type_to_aten(t2)));
202}
203
204// Return highest on list (smallest enum val)
205ValType promote_type(const ValType& t1, const ValType& t2) {
206 if (t1 == ValType::TensorView || t2 == ValType::TensorView) {
207 return ValType::TensorView;
208 }
209 if (t1 == ValType::Scalar &&
210 (t2 == ValType::Scalar || t2 == ValType::NamedScalar)) {
211 return ValType::Scalar;
212 }
213 if (t2 == ValType::Scalar &&
214 (t1 == ValType::Scalar || t1 == ValType::NamedScalar)) {
215 return ValType::Scalar;
216 }
217 if (t1 == ValType::NamedScalar && t2 == ValType::NamedScalar) {
218 return ValType::Scalar;
219 }
220 TORCH_CHECK(false, "Expected promotable ValTypes but got: ", t1, " and ", t2);
221}
222
223static const char* data_type2string(DataType t) {
224 switch (t) {
225 case DataType::Bool:
226 return "bool";
227 case DataType::Double:
228 return "double";
229 case DataType::Float:
230 return "float";
231 case DataType::Half:
232 return "__half";
233 case DataType::BFloat16:
234 return "__bfloat";
235 case DataType::Int:
236 return "int64_t";
237 case DataType::Index:
238 return "nvfuser_index_t";
239 case DataType::Int32:
240 return "int";
241 case DataType::ComplexFloat:
242 return "std::complex<float>";
243 case DataType::ComplexDouble:
244 return "std::complex<double>";
245 case DataType::Double_2:
246 return "Array<double, 2, 1>";
247 case DataType::Float_2:
248 return "Array<float, 2, 1>";
249 case DataType::Null:
250 return "null_type";
251 default:
252 break;
253 }
254 TORCH_INTERNAL_ASSERT(false, "No string found for data type.");
255 return nullptr;
256}
257
258static const char* val_type2string(ValType t) {
259 switch (t) {
260 case ValType::TensorView:
261 return "TensorView";
262 case ValType::TensorDomain:
263 return "TensorDomain";
264 case ValType::IterDomain:
265 return "IterDomain";
266 case ValType::Scalar:
267 return "Scalar";
268 case ValType::NamedScalar:
269 return "NamedScalar";
270 case ValType::Predicate:
271 return "Predicate";
272 case ValType::TensorIndex:
273 return "TensorIndex";
274 case ValType::IntPair:
275 return "IntPair";
276 default:
277 TORCH_INTERNAL_ASSERT(false, "No string found for val type.");
278 }
279}
280
281static const char* predicate_type2string(PredicateType t) {
282 switch (t) {
283 case PredicateType::Manual:
284 return "Manual";
285 case PredicateType::Inline:
286 return "Inline";
287 case PredicateType::Unswitch:
288 return "Unswitch";
289 case PredicateType::Vectorize:
290 return "Vectorize";
291 case PredicateType::Misaligned:
292 return "Misaligned";
293 case PredicateType::Shift:
294 return "Shift";
295 case PredicateType::Padding:
296 return "Padding";
297 case PredicateType::ReductionWrite:
298 return "ReductionWrite";
299 default:
300 TORCH_INTERNAL_ASSERT(false, "No string found for predicate type.");
301 }
302}
303
304static const char* expr_type2string(ExprType t) {
305 switch (t) {
306 case ExprType::FullOp:
307 return "FullOp";
308 case ExprType::ARangeOp:
309 return "ARangeOp";
310 case ExprType::EyeOp:
311 return "EyeOp";
312 case ExprType::UnaryOp:
313 return "UnaryOp";
314 case ExprType::BinaryOp:
315 return "BinaryOp";
316 case ExprType::TernaryOp:
317 return "TernaryOp";
318 case ExprType::RNGOp:
319 return "RNGOp";
320 case ExprType::ReductionOp:
321 return "ReductionOp";
322 case ExprType::GroupedReductionOp:
323 return "GroupedReductionOp";
324 case ExprType::BroadcastOp:
325 return "BroadcastOp";
326 case ExprType::WelfordOp:
327 return "WelfordOp";
328 case ExprType::GroupedWelfordOp:
329 return "GroupedWelfordOp";
330 case ExprType::LoadStoreOp:
331 return "LoadStoreOp";
332 case ExprType::MmaOp:
333 return "MmaOp";
334 case ExprType::TransposeOp:
335 return "TransposeOp";
336 case ExprType::ExpandOp:
337 return "ExpandOp";
338 case ExprType::ShiftOp:
339 return "ShiftOp";
340 case ExprType::GatherOp:
341 return "GatherOp";
342 case ExprType::ViewAsScalar:
343 return "ViewAsScalar";
344 case ExprType::ViewOp:
345 return "ViewOp";
346 case ExprType::Split:
347 return "Split";
348 case ExprType::Merge:
349 return "Merge";
350 case ExprType::Allocate:
351 return "Allocate";
352 case ExprType::BlockSync:
353 return "BlockSync";
354 case ExprType::GridSync:
355 return "GridSync";
356 case ExprType::CpAsyncWait:
357 return "CpAsyncWait";
358 case ExprType::CpAsyncCommit:
359 return "CpAsyncCommit";
360 case ExprType::InitMagicZero:
361 return "InitMagicZero";
362 case ExprType::UpdateMagicZero:
363 return "UpdateMagicZero";
364 case ExprType::ForLoop:
365 return "ForLoop";
366 case ExprType::IfThenElse:
367 return "IfThenElse";
368 case ExprType::GridReduction:
369 return "GridReduction";
370 case ExprType::GroupedGridReduction:
371 return "GroupedGridReduction";
372 case ExprType::GridBroadcast:
373 return "GridBroadcast";
374 case ExprType::GridWelford:
375 return "GridWelford";
376 case ExprType::GroupedGridWelford:
377 return "GroupedGridWelford";
378 case ExprType::Swizzle2D:
379 return "Swizzle2D";
380 case ExprType::Swizzle2DInt:
381 return "Swizzle2DInt";
382 case ExprType::PairSelect:
383 return "PairSelect";
384 default:
385 TORCH_INTERNAL_ASSERT(false, "No string found for expr type.");
386 }
387}
388
389bool needFloatSuffix(UnaryOpType t) {
390 switch (t) {
391 case UnaryOpType::Abs:
392 case UnaryOpType::Cast:
393 case UnaryOpType::Frac:
394 case UnaryOpType::Gelu:
395 case UnaryOpType::Silu:
396 case UnaryOpType::BitCast:
397 case UnaryOpType::Neg:
398 case UnaryOpType::Relu:
399 case UnaryOpType::Reciprocal:
400 case UnaryOpType::Set:
401 case UnaryOpType::Sigmoid:
402 case UnaryOpType::IsFinite:
403 case UnaryOpType::IsInf:
404 case UnaryOpType::IsNan:
405 case UnaryOpType::IsNegInf:
406 case UnaryOpType::IsPosInf:
407 case UnaryOpType::IsReal:
408 case UnaryOpType::Print:
409 return false;
410 default:
411 return true;
412 }
413}
414
415bool needFloatSuffix(RNGOpType t) {
416 return true;
417}
418
419static const char* unary_op_type2string(UnaryOpType t) {
420 switch (t) {
421 case UnaryOpType::Abs:
422 return "abs";
423 case UnaryOpType::Acos:
424 return "acos";
425 case UnaryOpType::Asin:
426 return "asin";
427 case UnaryOpType::Atan:
428 return "atan";
429 case UnaryOpType::Atanh:
430 return "atanh";
431 case UnaryOpType::Cast:
432 return "cast";
433 case UnaryOpType::Ceil:
434 return "ceil";
435 case UnaryOpType::Cos:
436 return "cos";
437 case UnaryOpType::Cosh:
438 return "cosh";
439 case UnaryOpType::Exp:
440 return "exp";
441 case UnaryOpType::Expm1:
442 return "expm1";
443 case UnaryOpType::Erf:
444 return "erf";
445 case UnaryOpType::Erfc:
446 return "erfc";
447 case UnaryOpType::Floor:
448 return "floor";
449 case UnaryOpType::Frac:
450 return "frac";
451 case UnaryOpType::Silu:
452 return "silu";
453 case UnaryOpType::Lgamma:
454 return "lgamma";
455 case UnaryOpType::Log:
456 return "log";
457 case UnaryOpType::Log10:
458 return "log10";
459 case UnaryOpType::Log1p:
460 return "log1p";
461 case UnaryOpType::Log2:
462 return "log2";
463 case UnaryOpType::BitCast:
464 return "erase_type";
465 case UnaryOpType::Neg:
466 return "neg";
467 case UnaryOpType::Not:
468 return "not";
469 case UnaryOpType::Print:
470 return "print";
471 case UnaryOpType::Reciprocal:
472 return "reciprocal";
473 case UnaryOpType::Relu:
474 return "relu";
475 case UnaryOpType::Rsqrt:
476 return "rsqrt";
477 case UnaryOpType::Round:
478 return "nearbyint";
479 case UnaryOpType::Set:
480 return "set";
481 case UnaryOpType::Sigmoid:
482 return "sigmoid";
483 case UnaryOpType::Sin:
484 return "sin";
485 case UnaryOpType::Sinh:
486 return "sinh";
487 case UnaryOpType::Sqrt:
488 return "sqrt";
489 case UnaryOpType::Tan:
490 return "tan";
491 case UnaryOpType::Tanh:
492 return "tanh";
493 case UnaryOpType::Trunc:
494 return "trunc";
495 case UnaryOpType::IsFinite:
496 return "isfinite";
497 case UnaryOpType::IsInf:
498 return "isinf";
499 case UnaryOpType::IsNan:
500 return "isnan";
501 case UnaryOpType::IsNegInf:
502 return "isneginf";
503 case UnaryOpType::IsPosInf:
504 return "isposinf";
505 case UnaryOpType::IsReal:
506 return "isreal";
507 case UnaryOpType::Real:
508 return "std::real";
509 case UnaryOpType::Imag:
510 return "std::imag";
511 default:
512 TORCH_INTERNAL_ASSERT(false, "No string found for unary op type.");
513 }
514}
515
516std::string stringifyBooleanOp(const UnaryOpType uopt) {
517 TORCH_INTERNAL_ASSERT(
518 uopt == UnaryOpType::Not, uopt, " is not a boolean operator.");
519 return "!";
520}
521
522static const char* unary_op_type_inline_op2string(UnaryOpType t) {
523 switch (t) {
524 case UnaryOpType::Neg:
525 return "-";
526 case UnaryOpType::Not:
527 return "~";
528 case UnaryOpType::Set:
529 return "";
530 case UnaryOpType::Address:
531 return "(int64_t) &";
532 default:
533 break;
534 }
535 return nullptr;
536}
537
538bool needFloatSuffix(BinaryOpType t) {
539 switch (t) {
540 case BinaryOpType::Atan2:
541 case BinaryOpType::Div:
542 case BinaryOpType::Fmod:
543 return true;
544 default:
545 return false;
546 }
547}
548
549static const char* binary_op_type2string(BinaryOpType t) {
550 switch (t) {
551 case BinaryOpType::Add:
552 return "add";
553 case BinaryOpType::Atan2:
554 return "atan2";
555 case BinaryOpType::Div:
556 return "div";
557 case BinaryOpType::Fmod:
558 return "fmod";
559 case BinaryOpType::Max:
560 return "fmax";
561 case BinaryOpType::Min:
562 return "fmin";
563 case BinaryOpType::Mul:
564 return "mul";
565 case BinaryOpType::Pow:
566 return "pow";
567 case BinaryOpType::Remainder:
568 return "remainder";
569 case BinaryOpType::Sub:
570 return "sub";
571
572 // Integer Ops
573 case BinaryOpType::Mod:
574 return "mod";
575 case BinaryOpType::CeilDiv:
576 return "ceilDiv";
577 case BinaryOpType::Lshift:
578 return "lshift";
579 case BinaryOpType::Rshift:
580 return "rshift";
581
582 // Logical Ops
583 case BinaryOpType::And:
584 return "and";
585 case BinaryOpType::Eq:
586 return "equal";
587 case BinaryOpType::GE:
588 return "greaterThanOrEqual";
589 case BinaryOpType::GT:
590 return "greaterThan";
591 case BinaryOpType::LE:
592 return "lessThanOrEqual";
593 case BinaryOpType::LT:
594 return "lessThan";
595 case BinaryOpType::NE:
596 return "notEqual";
597 default:
598 TORCH_INTERNAL_ASSERT(false, "No string found for binary op type.");
599 }
600}
601
602static const char* binary_op_integer_op2string(BinaryOpType t) {
603 switch (t) {
604 case BinaryOpType::Max:
605 return "max";
606 case BinaryOpType::Min:
607 return "min";
608 case BinaryOpType::Fmod:
609 return "fmod";
610 default:
611 break;
612 }
613 return nullptr;
614}
615
616static const char* binary_op_bool_op2string(BinaryOpType t) {
617 switch (t) {
618 case BinaryOpType::Max:
619 return "max";
620 case BinaryOpType::Min:
621 return "min";
622 default:
623 break;
624 }
625 return nullptr;
626}
627
628static const char* binary_op_type_inline_op2string(BinaryOpType t) {
629 switch (t) {
630 case BinaryOpType::Add:
631 return "+";
632 case BinaryOpType::Div:
633 return "/";
634 case BinaryOpType::Mul:
635 return "*";
636 case BinaryOpType::Sub:
637 return "-";
638
639 // Integer ops
640 case BinaryOpType::Mod:
641 return "%";
642 case BinaryOpType::Lshift:
643 return "<<";
644 case BinaryOpType::Rshift:
645 return ">>";
646 // Logical Ops
647 case BinaryOpType::Eq:
648 return "==";
649 case BinaryOpType::GE:
650 return ">=";
651 case BinaryOpType::GT:
652 return ">";
653 case BinaryOpType::LE:
654 return "<=";
655 case BinaryOpType::LT:
656 return "<";
657 case BinaryOpType::NE:
658 return "!=";
659 // Assume bitwise, otherwise use stringifyBooleanOp
660 case BinaryOpType::And:
661 return "&";
662 case BinaryOpType::Or:
663 return "|";
664 case BinaryOpType::Xor:
665 return "^";
666 default:
667 break;
668 }
669 return nullptr;
670}
671
672static const char* rng_op_type_inline_op2string(RNGOpType t) {
673 switch (t) {
674 case RNGOpType::Uniform:
675 return "rng_uniform";
676 case RNGOpType::UniformRange:
677 return "rng_uniform_range";
678 default:
679 break;
680 }
681 return nullptr;
682}
683
684std::string stringifyBooleanOp(const BinaryOpType bopt) {
685 switch (bopt) {
686 case BinaryOpType::And:
687 return "&&";
688 case BinaryOpType::Or:
689 return "||";
690 case BinaryOpType::Xor:
691 return "!=";
692 default:
693 TORCH_INTERNAL_ASSERT(false, bopt, " is not a boolean operator.")
694 }
695}
696
697static const char* ternary_op_type2string(TernaryOpType t) {
698 switch (t) {
699 case TernaryOpType::Clamp:
700 return "clamp";
701 case TernaryOpType::Lerp:
702 return "lerp";
703 case TernaryOpType::Threshold:
704 return "threshold";
705 case TernaryOpType::Where:
706 return "where";
707 default:
708 TORCH_INTERNAL_ASSERT(false, "Unexpected TernaryOpType");
709 }
710}
711
712static const char* rng_op_type2string(RNGOpType t) {
713 switch (t) {
714 case RNGOpType::Uniform:
715 return "rng_uniform";
716 case RNGOpType::UniformRange:
717 return "rng_uniform_range";
718 default:
719 TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType");
720 }
721}
722
723static const char* parallel_type2string(ParallelType t) {
724 switch (t) {
725 case ParallelType::BIDz:
726 return "blockIdx.z";
727 case ParallelType::BIDy:
728 return "blockIdx.y";
729 case ParallelType::BIDx:
730 return "blockIdx.x";
731 case ParallelType::TIDz:
732 return "threadIdx.z";
733 case ParallelType::TIDy:
734 return "threadIdx.y";
735 case ParallelType::TIDx:
736 return "threadIdx.x";
737 case ParallelType::Vectorize:
738 return "V";
739 case ParallelType::MisalignedVectorize:
740 return "MV";
741 case ParallelType::Unroll:
742 return "UR";
743 case ParallelType::Unswitch:
744 return "US";
745 case ParallelType::Mma:
746 return "MMA";
747 case ParallelType::Group:
748 return "G";
749 case ParallelType::Serial:
750 return "S";
751 default:
752 TORCH_INTERNAL_ASSERT(false, "Unexpected ParallelType");
753 }
754}
755
756std::unordered_set<ParallelType> allParallelTypesExcept(
757 const std::unordered_set<ParallelType>& except) {
758 std::unordered_set<ParallelType> result = {
759 ParallelType::BIDz,
760 ParallelType::BIDy,
761 ParallelType::BIDx,
762 ParallelType::TIDz,
763 ParallelType::TIDy,
764 ParallelType::TIDx,
765 ParallelType::Vectorize,
766 ParallelType::MisalignedVectorize,
767 ParallelType::Unroll,
768 ParallelType::Unswitch,
769 ParallelType::Mma,
770 ParallelType::Group,
771 ParallelType::Serial};
772 for (auto t : except) {
773 result.erase(t);
774 }
775 return result;
776}
777
778static const char* memory_type2string(MemoryType t) {
779 switch (t) {
780 case MemoryType::Local:
781 return "register";
782 case MemoryType::Shared:
783 return "shared";
784 case MemoryType::Global:
785 return "global";
786 default:
787 TORCH_INTERNAL_ASSERT(false, "Unexpected MemoryType");
788 }
789}
790
791static const char* id_map_mode_type2string(IdMappingMode t) {
792 switch (t) {
793 case IdMappingMode::PERMISSIVE:
794 return "permissive";
795 case IdMappingMode::EXACT:
796 return "exact";
797 case IdMappingMode::LOOP:
798 return "loop";
799 default:
800 // Don't try to print t as it would recursively call this function
801 TORCH_INTERNAL_ASSERT(false, "Unexpected IdMappingMode Type.");
802 }
803}
804
805static const char* iter_type2string(IterType t) {
806 switch (t) {
807 case IterType::Iteration:
808 return "i";
809 case IterType::Reduction:
810 return "r";
811 case IterType::Broadcast:
812 return "b";
813 case IterType::Gather:
814 return "g";
815 case IterType::Stride:
816 return "s";
817 case IterType::VectorComponent:
818 return "v";
819 default:
820 // Don't try to print t as it would recursively call this function
821 TORCH_INTERNAL_ASSERT(false, "Unexpected IterType");
822 }
823}
824
825static const char* thread_size2string(ParallelType t) {
826 switch (t) {
827 case ParallelType::BIDz:
828 return "gridDim.z";
829 case ParallelType::BIDy:
830 return "gridDim.y";
831 case ParallelType::BIDx:
832 return "gridDim.x";
833 case ParallelType::TIDz:
834 return "blockDim.z";
835 case ParallelType::TIDy:
836 return "blockDim.y";
837 case ParallelType::TIDx:
838 return "blockDim.x";
839 default:
840 TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type");
841 }
842}
843
844static const char* load_store_type2string(LoadStoreOpType t) {
845 switch (t) {
846 case LoadStoreOpType::LdMatrix:
847 return "LdMatrix";
848 case LoadStoreOpType::LdMatrixTranspose:
849 return "LdMatrixTranspose";
850 case LoadStoreOpType::CpAsync:
851 return "CpAsync";
852 default:
853 TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type");
854 }
855}
856
857const unsigned int _WORD_SHIFT = 16;
858constexpr unsigned int supported_switch_pair(DataType t1, DataType t2) {
859 return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2;
860}
861
862static const char* supported_casts2string(
863 const std::pair<DataType, DataType>& t) {
864 switch (supported_switch_pair(std::get<0>(t), std::get<1>(t))) {
865 case supported_switch_pair(DataType::Index, DataType::Float):
866 case supported_switch_pair(DataType::Int, DataType::Float):
867 case supported_switch_pair(DataType::Int32, DataType::Float):
868 case supported_switch_pair(DataType::Double, DataType::Float):
869 case supported_switch_pair(DataType::Bool, DataType::Float):
870 return "(float)";
871 case supported_switch_pair(DataType::ComplexFloat, DataType::Float):
872 case supported_switch_pair(DataType::ComplexDouble, DataType::Float):
873 return "(float)std::real";
874 case supported_switch_pair(DataType::Index, DataType::Int):
875 case supported_switch_pair(DataType::Int32, DataType::Int):
876 case supported_switch_pair(DataType::Float, DataType::Int):
877 case supported_switch_pair(DataType::Double, DataType::Int):
878 case supported_switch_pair(DataType::Bool, DataType::Int):
879 return "(int64_t)";
880 case supported_switch_pair(DataType::ComplexFloat, DataType::Int):
881 case supported_switch_pair(DataType::ComplexDouble, DataType::Int):
882 return "(int64_t)std::real";
883 case supported_switch_pair(DataType::Index, DataType::Int32):
884 case supported_switch_pair(DataType::Int, DataType::Int32):
885 case supported_switch_pair(DataType::Float, DataType::Int32):
886 case supported_switch_pair(DataType::Double, DataType::Int32):
887 case supported_switch_pair(DataType::Bool, DataType::Int32):
888 return "(int32_t)";
889 case supported_switch_pair(DataType::ComplexFloat, DataType::Int32):
890 case supported_switch_pair(DataType::ComplexDouble, DataType::Int32):
891 return "(int32_t)std::real";
892 case supported_switch_pair(DataType::Int, DataType::Index):
893 case supported_switch_pair(DataType::Int32, DataType::Index):
894 case supported_switch_pair(DataType::Float, DataType::Index):
895 case supported_switch_pair(DataType::Double, DataType::Index):
896 return "(nvfuser_index_t)";
897 case supported_switch_pair(DataType::ComplexFloat, DataType::Index):
898 case supported_switch_pair(DataType::ComplexDouble, DataType::Index):
899 return "(nvfuser_index_t)std::real";
900 case supported_switch_pair(DataType::Index, DataType::Double):
901 case supported_switch_pair(DataType::Int, DataType::Double):
902 case supported_switch_pair(DataType::Int32, DataType::Double):
903 case supported_switch_pair(DataType::Float, DataType::Double):
904 case supported_switch_pair(DataType::Bool, DataType::Double):
905 return "(double)";
906 case supported_switch_pair(DataType::ComplexFloat, DataType::Double):
907 case supported_switch_pair(DataType::ComplexDouble, DataType::Double):
908 return "(double)std::real";
909 case supported_switch_pair(DataType::Float, DataType::Bool):
910 case supported_switch_pair(DataType::Double, DataType::Bool):
911 case supported_switch_pair(DataType::Int32, DataType::Bool):
912 case supported_switch_pair(DataType::Int, DataType::Bool):
913 return "(bool)";
914 case supported_switch_pair(DataType::ComplexFloat, DataType::Bool):
915 case supported_switch_pair(DataType::ComplexDouble, DataType::Bool):
916 return "(bool)std::real";
917 case supported_switch_pair(DataType::Index, DataType::ComplexDouble):
918 case supported_switch_pair(DataType::Int, DataType::ComplexDouble):
919 case supported_switch_pair(DataType::Int32, DataType::ComplexDouble):
920 case supported_switch_pair(DataType::Double, DataType::ComplexDouble):
921 case supported_switch_pair(DataType::Float, DataType::ComplexDouble):
922 case supported_switch_pair(DataType::Bool, DataType::ComplexDouble):
923 case supported_switch_pair(DataType::ComplexFloat, DataType::ComplexDouble):
924 return "(std::complex<double>)";
925 case supported_switch_pair(DataType::Index, DataType::ComplexFloat):
926 case supported_switch_pair(DataType::Int, DataType::ComplexFloat):
927 case supported_switch_pair(DataType::Int32, DataType::ComplexFloat):
928 case supported_switch_pair(DataType::Double, DataType::ComplexFloat):
929 case supported_switch_pair(DataType::Float, DataType::ComplexFloat):
930 case supported_switch_pair(DataType::Bool, DataType::ComplexFloat):
931 case supported_switch_pair(DataType::ComplexDouble, DataType::ComplexFloat):
932 return "(std::complex<float>)";
933 case supported_switch_pair(DataType::Float, DataType::Half):
934 return "__float2half";
935 case supported_switch_pair(DataType::Double, DataType::Half):
936 return "__double2half";
937 case supported_switch_pair(DataType::Float, DataType::BFloat16):
938 return "__float2bfloat";
939 case supported_switch_pair(DataType::Half, DataType::Float):
940 return "__half2float";
941 case supported_switch_pair(DataType::Half, DataType::Double):
942 return "__half2double";
943 case supported_switch_pair(DataType::BFloat16, DataType::Float):
944 return "__bfloat2float";
945 default:
946 return nullptr;
947 }
948}
949
950DataType aten_to_data_type(const at::ScalarType& scalar_type) {
951 switch (scalar_type) {
952 case at::ScalarType::Bool:
953 return DataType::Bool;
954 case at::ScalarType::Double:
955 return DataType::Double;
956 case at::ScalarType::Float:
957 return DataType::Float;
958 case at::ScalarType::Half:
959 return DataType::Half;
960 case at::ScalarType::BFloat16:
961 return DataType::BFloat16;
962 case at::ScalarType::Long:
963 return DataType::Int;
964 case at::ScalarType::Int:
965 return DataType::Int32;
966 case at::ScalarType::ComplexFloat:
967 return DataType::ComplexFloat;
968 case at::ScalarType::ComplexDouble:
969 return DataType::ComplexDouble;
970 default:
971 return DataType::Null;
972 }
973}
974
975at::ScalarType data_type_to_aten(const DataType& data_type) {
976 switch (data_type) {
977 case DataType::Bool:
978 return at::ScalarType::Bool;
979 case DataType::Double:
980 return at::ScalarType::Double;
981 case DataType::Float:
982 return at::ScalarType::Float;
983 case DataType::Half:
984 return at::ScalarType::Half;
985 case DataType::BFloat16:
986 return at::ScalarType::BFloat16;
987 case DataType::Int:
988 return at::ScalarType::Long;
989 case DataType::Index:
990 TORCH_INTERNAL_ASSERT(
991 false,
992 "Index is determined at compile time,",
993 " to convert from an aten type you need to have the compiled information. ",
994 "This information is passed to GpuLower at compile time, and then copied to kerned.",
995 "There's also this information in FusionExecutorCache and the Registry system.");
996 case DataType::Int32:
997 return at::ScalarType::Int;
998 case DataType::ComplexFloat:
999 return at::ScalarType::ComplexFloat;
1000 case DataType::ComplexDouble:
1001 return at::ScalarType::ComplexDouble;
1002 default:
1003 TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
1004 }
1005}
1006
1007std::ostream& operator<<(std::ostream& out, const ValType vtype) {
1008 return out << val_type2string(vtype);
1009}
1010
1011std::ostream& operator<<(std::ostream& out, const PredicateType ptype) {
1012 return out << predicate_type2string(ptype);
1013}
1014
1015std::ostream& operator<<(std::ostream& out, const DataType dtype) {
1016 return out << data_type2string(dtype);
1017}
1018
1019std::ostream& operator<<(std::ostream& out, const ExprType etype) {
1020 return out << expr_type2string(etype);
1021}
1022
1023std::ostream& operator<<(std::ostream& out, const UnaryOpType uotype) {
1024 return out << unary_op_type2string(uotype);
1025}
1026
1027std::ostream& operator<<(std::ostream& out, const BinaryOpType botype) {
1028 return out << binary_op_type2string(botype);
1029}
1030
1031std::ostream& operator<<(std::ostream& out, const TernaryOpType totype) {
1032 return out << ternary_op_type2string(totype);
1033}
1034
1035std::ostream& operator<<(std::ostream& out, const RNGOpType rngtype) {
1036 return out << rng_op_type2string(rngtype);
1037}
1038
1039std::ostream& operator<<(std::ostream& out, const ParallelType ptype) {
1040 return out << stringifyThread(ptype);
1041}
1042
1043std::ostream& operator<<(std::ostream& out, const MemoryType mtype) {
1044 return out << memory_type2string(mtype);
1045}
1046
1047std::ostream& operator<<(std::ostream& out, const IdMappingMode immtype) {
1048 return out << id_map_mode_type2string(immtype);
1049}
1050
1051std::ostream& operator<<(
1052 std::ostream& out,
1053 const LoadStoreOpType load_store_type) {
1054 return out << load_store_type2string(load_store_type);
1055}
1056
1057std::ostream& operator<<(std::ostream& out, const IterType bt) {
1058 return out << iter_type2string(bt);
1059}
1060
1061std::ostream& operator<<(std::ostream& os, const Swizzle2DType& swizzle) {
1062 switch (swizzle) {
1063 case Swizzle2DType::NoSwizzle:
1064 os << "NoSwizzle";
1065 break;
1066 case Swizzle2DType::ZShape:
1067 os << "ZShape";
1068 break;
1069 case Swizzle2DType::Transpose:
1070 os << "Transpose";
1071 break;
1072 case Swizzle2DType::XOR:
1073 os << "Xor";
1074 break;
1075 case Swizzle2DType::Scatter:
1076 os << "Scatter";
1077 break;
1078 default:
1079 TORCH_INTERNAL_ASSERT(false, "undefined 2D swizzle");
1080 break;
1081 }
1082 return os;
1083}
1084
1085std::ostream& operator<<(std::ostream& os, const SwizzleMode& swizzle) {
1086 switch (swizzle) {
1087 case SwizzleMode::NoSwizzle:
1088 os << "NoSwizzle";
1089 break;
1090 case SwizzleMode::Loop:
1091 os << "Loop";
1092 break;
1093 case SwizzleMode::Data:
1094 os << "Data";
1095 break;
1096 default:
1097 TORCH_INTERNAL_ASSERT(false, "undefined 2D swizzle");
1098 break;
1099 }
1100 return os;
1101}
1102
1103c10::optional<std::string> inline_op_str(const UnaryOpType uotype) {
1104 const char* str = unary_op_type_inline_op2string(uotype);
1105 return str != nullptr ? c10::optional<std::string>(std::string(str))
1106 : c10::nullopt;
1107}
1108
1109c10::optional<std::string> inline_op_str(const BinaryOpType botype) {
1110 const char* str = binary_op_type_inline_op2string(botype);
1111 return str != nullptr ? c10::optional<std::string>(std::string(str))
1112 : c10::nullopt;
1113}
1114
1115c10::optional<std::string> inline_op_str(const RNGOpType rngtype) {
1116 const char* str = rng_op_type_inline_op2string(rngtype);
1117 return str != nullptr ? c10::optional<std::string>(std::string(str))
1118 : c10::nullopt;
1119}
1120
1121c10::optional<std::string> integer_op_str(const BinaryOpType botype) {
1122 const char* str = binary_op_integer_op2string(botype);
1123 return str != nullptr ? c10::optional<std::string>(std::string(str))
1124 : c10::nullopt;
1125}
1126
1127c10::optional<std::string> bool_op_str(const BinaryOpType botype) {
1128 const char* str = binary_op_bool_op2string(botype);
1129 return str != nullptr ? c10::optional<std::string>(std::string(str))
1130 : c10::nullopt;
1131}
1132
1133std::string stringifyThreadSize(const ParallelType ptype) {
1134 return thread_size2string(ptype);
1135}
1136
1137std::string stringifyThread(const ParallelType ptype) {
1138 return parallel_type2string(ptype);
1139}
1140
1141std::string typePrefix(const DataType data_type) {
1142 switch (data_type) {
1143 case DataType::Bool:
1144 return "b";
1145 case DataType::Double:
1146 return "d";
1147 case DataType::Float:
1148 case DataType::Half:
1149 case DataType::BFloat16:
1150 return "f";
1151 case DataType::Index:
1152 case DataType::Int:
1153 case DataType::Int32:
1154 return "i";
1155 case DataType::ComplexFloat:
1156 case DataType::ComplexDouble:
1157 return "c";
1158 default:
1159 TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type.");
1160 }
1161}
1162
1163bool isParallelTypeThreadDim(ParallelType ptype) {
1164 return ptype == ParallelType::TIDx || ptype == ParallelType::TIDy ||
1165 ptype == ParallelType::TIDz;
1166}
1167
1168bool isParallelTypeBlockDim(ParallelType ptype) {
1169 return ptype == ParallelType::BIDx || ptype == ParallelType::BIDy ||
1170 ptype == ParallelType::BIDz;
1171}
1172
1173bool isParallelTypeThread(ParallelType ptype) {
1174 return isParallelTypeBlockDim(ptype) || isParallelTypeThreadDim(ptype);
1175}
1176
1177bool isParallelTypeVectorize(ParallelType ptype) {
1178 return ptype == ParallelType::Vectorize ||
1179 ptype == ParallelType::MisalignedVectorize;
1180}
1181
1182c10::optional<std::string> cast_func_str(
1183 const std::pair<DataType, DataType>& cast) {
1184 const char* str = supported_casts2string(cast);
1185 return str != nullptr ? c10::optional<std::string>(std::string(str))
1186 : c10::nullopt;
1187}
1188
1189size_t dataTypeSize(DataType type) {
1190 switch (type) {
1191 case DataType::Bool:
1192 return sizeof(bool);
1193 case DataType::ComplexDouble:
1194 return sizeof(std::complex<double>);
1195 case DataType::ComplexFloat:
1196 return sizeof(std::complex<float>);
1197 case DataType::Double:
1198 return sizeof(double);
1199 case DataType::Float:
1200 return sizeof(float);
1201 case DataType::Half:
1202 return sizeof(at::Half);
1203 case DataType::BFloat16:
1204 return sizeof(at::BFloat16);
1205 case DataType::Index:
1206 TORCH_INTERNAL_ASSERT(
1207 false, "The actual type of Index is only known at compile time.");
1208 case DataType::Int:
1209 return sizeof(uint64_t);
1210 case DataType::Int32:
1211 return sizeof(uint32_t);
1212 case DataType::Double_2:
1213 return sizeof(double) * 2;
1214 case DataType::Float_2:
1215 return sizeof(float) * 2;
1216 default:
1217 TORCH_INTERNAL_ASSERT(false, "Size undefined for data type, ", type);
1218 }
1219}
1220
1221size_t dataTypeSize(DataType type, DataType index_type) {
1222 if (type == DataType::Index) {
1223 TORCH_INTERNAL_ASSERT(
1224 index_type == DataType::Int32 || index_type == DataType::Int,
1225 "Invalid index type of ",
1226 index_type);
1227 return dataTypeSize(index_type);
1228 }
1229 return dataTypeSize(type);
1230}
1231
1232std::ostream& operator<<(
1233 std::ostream& os,
1234 const DoubleBufferLoopStage loop_stage) {
1235 switch (loop_stage) {
1236 case DoubleBufferLoopStage::NotApplicable:
1237 break;
1238 case DoubleBufferLoopStage::Prolog:
1239 os << "{DoubleBufferProlog}";
1240 break;
1241 case DoubleBufferLoopStage::Main:
1242 os << "{DoubleBufferMainLoop}";
1243 break;
1244 case DoubleBufferLoopStage::Epilog:
1245 os << "{DoubleBufferEpilog}";
1246 break;
1247 default:
1248 TORCH_INTERNAL_ASSERT(false, "unknown double buffer stage");
1249 }
1250 return os;
1251}
1252
1253} // namespace cuda
1254} // namespace fuser
1255} // namespace jit
1256} // namespace torch
1257