1 | #include <type.h> |
2 | |
3 | #include <ATen/cuda/CUDAContext.h> |
4 | |
5 | #include <stdexcept> |
6 | #include <unordered_map> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | DataType indexModeToDtype(KernelIndexMode index_mode) { |
14 | switch (index_mode) { |
15 | case KernelIndexMode::INT32: |
16 | return DataType::Int32; |
17 | case KernelIndexMode::INT64: |
18 | return DataType::Int; |
19 | default: |
20 | TORCH_CHECK(false, "Invalid kernel index mode type." ); |
21 | } |
22 | } |
23 | |
24 | bool isFloatingPointType(DataType dtype) { |
25 | switch (dtype) { |
26 | case DataType::Double: |
27 | case DataType::Float: |
28 | case DataType::Half: |
29 | case DataType::BFloat16: |
30 | return true; |
31 | case DataType::Bool: |
32 | case DataType::Index: |
33 | case DataType::Int: |
34 | case DataType::Int32: |
35 | case DataType::ComplexFloat: |
36 | case DataType::ComplexDouble: |
37 | return false; |
38 | case DataType::Null: |
39 | TORCH_CHECK( |
40 | false, "Null type is not a valid argument to isFloatingPointType" ); |
41 | default: |
42 | TORCH_CHECK(false, "Type not supported in isFloatingPointType" ); |
43 | } |
44 | } |
45 | |
46 | bool isBooleanType(DataType dtype) { |
47 | switch (dtype) { |
48 | case DataType::Bool: |
49 | return true; |
50 | case DataType::Double: |
51 | case DataType::Float: |
52 | case DataType::Half: |
53 | case DataType::BFloat16: |
54 | case DataType::ComplexFloat: |
55 | case DataType::ComplexDouble: |
56 | case DataType::Index: |
57 | case DataType::Int: |
58 | case DataType::Int32: |
59 | return false; |
60 | case DataType::Null: |
61 | TORCH_CHECK(false, "Null type is not a valid argument to isBooleanType" ); |
62 | default: |
63 | TORCH_CHECK(false, "Type not supported in isBooleanType" ); |
64 | } |
65 | } |
66 | |
67 | bool isIntegralType(DataType dtype) { |
68 | switch (dtype) { |
69 | case DataType::Bool: |
70 | case DataType::Double: |
71 | case DataType::Float: |
72 | case DataType::Half: |
73 | case DataType::BFloat16: |
74 | case DataType::ComplexFloat: |
75 | case DataType::ComplexDouble: |
76 | return false; |
77 | case DataType::Index: |
78 | case DataType::Int: |
79 | case DataType::Int32: |
80 | return true; |
81 | case DataType::Null: |
82 | TORCH_CHECK(false, "Null type is not a valid argument to isIntegralType" ); |
83 | default: |
84 | TORCH_CHECK(false, "Type not supported in isIntegralType" ); |
85 | } |
86 | } |
87 | |
88 | bool isComplexType(DataType dtype) { |
89 | switch (dtype) { |
90 | case DataType::ComplexFloat: |
91 | case DataType::ComplexDouble: |
92 | return true; |
93 | case DataType::Bool: |
94 | case DataType::Double: |
95 | case DataType::Float: |
96 | case DataType::Half: |
97 | case DataType::BFloat16: |
98 | case DataType::Int: |
99 | case DataType::Index: |
100 | case DataType::Int32: |
101 | return false; |
102 | case DataType::Null: |
103 | TORCH_CHECK(false, "Null type is not a valid argument to isComplexType" ); |
104 | default: |
105 | TORCH_CHECK(false, "Type not supported in isComplexType" ); |
106 | } |
107 | } |
108 | |
109 | bool isVectorType(DataType dtype) { |
110 | switch (dtype) { |
111 | case DataType::Float_2: |
112 | case DataType::Double_2: |
113 | return true; |
114 | default: |
115 | return false; |
116 | } |
117 | } |
118 | |
119 | DataType getVectorType(DataType dtype, size_t vec_size) { |
120 | switch (dtype) { |
121 | case DataType::Float: |
122 | TORCH_INTERNAL_ASSERT(vec_size == 2, "Not supported vectorized type" ); |
123 | return DataType::Float_2; |
124 | case DataType::Double: |
125 | TORCH_INTERNAL_ASSERT(vec_size == 2, "Not supported vectorized type" ); |
126 | return DataType::Double_2; |
127 | default: |
128 | TORCH_INTERNAL_ASSERT( |
129 | false, "Not supported vectorized type:" , dtype, " and " , vec_size); |
130 | } |
131 | } |
132 | |
133 | int getVectorSizeFromType(DataType dtype) { |
134 | switch (dtype) { |
135 | case DataType::Float_2: |
136 | case DataType::Double_2: |
137 | return 2; |
138 | default: |
139 | TORCH_INTERNAL_ASSERT(false, "Not a vector type:" , dtype); |
140 | } |
141 | } |
142 | |
143 | DataType getTypeFromVectorType(DataType dtype) { |
144 | switch (dtype) { |
145 | case DataType::Float_2: |
146 | return DataType::Float; |
147 | case DataType::Double_2: |
148 | return DataType::Double; |
149 | default: |
150 | TORCH_INTERNAL_ASSERT(false, "Not a vector type:" , dtype); |
151 | } |
152 | } |
153 | |
154 | DataType getTypeFromComplexType(DataType dtype) { |
155 | switch (dtype) { |
156 | case DataType::ComplexFloat: |
157 | return DataType::Float; |
158 | case DataType::ComplexDouble: |
159 | return DataType::Double; |
160 | default: |
161 | TORCH_INTERNAL_ASSERT(false, "Not a vector type:" , dtype); |
162 | } |
163 | } |
164 | |
165 | bool isSupportedTypeByDevice(DataType dtype) { |
166 | auto prop = at::cuda::getCurrentDeviceProperties(); |
167 | auto major_ver = prop->major; |
168 | switch (dtype) { |
169 | case DataType::BFloat16: |
170 | return major_ver >= 8; |
171 | default: |
172 | return true; |
173 | } |
174 | } |
175 | |
176 | bool isIntegerOp(const BinaryOpType bopt) { |
177 | return bopt >= BinaryOpType::Mod && bopt <= BinaryOpType::Rshift; |
178 | } |
179 | |
180 | bool isLogicalOp(const BinaryOpType bopt) { |
181 | return bopt >= BinaryOpType::Eq && bopt <= BinaryOpType::NE; |
182 | } |
183 | |
184 | bool alsoBooleanOperator(const BinaryOpType bopt) { |
185 | return bopt >= BinaryOpType::And && bopt <= BinaryOpType::Xor; |
186 | } |
187 | |
188 | bool alsoBooleanOperator(const UnaryOpType uopt) { |
189 | return uopt >= UnaryOpType::Not && uopt <= UnaryOpType::Not; |
190 | } |
191 | |
192 | // Return highest on list (smallest enum val) |
193 | DataType promote_type(const DataType& t1, const DataType& t2) { |
194 | TORCH_CHECK( |
195 | DataType::Null != t1 && DataType::Null != t2, |
196 | "Expected promotable DataTypes but got: " , |
197 | t1, |
198 | " and " , |
199 | t2); |
200 | return aten_to_data_type( |
201 | c10::promoteTypes(data_type_to_aten(t1), data_type_to_aten(t2))); |
202 | } |
203 | |
204 | // Return highest on list (smallest enum val) |
205 | ValType promote_type(const ValType& t1, const ValType& t2) { |
206 | if (t1 == ValType::TensorView || t2 == ValType::TensorView) { |
207 | return ValType::TensorView; |
208 | } |
209 | if (t1 == ValType::Scalar && |
210 | (t2 == ValType::Scalar || t2 == ValType::NamedScalar)) { |
211 | return ValType::Scalar; |
212 | } |
213 | if (t2 == ValType::Scalar && |
214 | (t1 == ValType::Scalar || t1 == ValType::NamedScalar)) { |
215 | return ValType::Scalar; |
216 | } |
217 | if (t1 == ValType::NamedScalar && t2 == ValType::NamedScalar) { |
218 | return ValType::Scalar; |
219 | } |
220 | TORCH_CHECK(false, "Expected promotable ValTypes but got: " , t1, " and " , t2); |
221 | } |
222 | |
223 | static const char* data_type2string(DataType t) { |
224 | switch (t) { |
225 | case DataType::Bool: |
226 | return "bool" ; |
227 | case DataType::Double: |
228 | return "double" ; |
229 | case DataType::Float: |
230 | return "float" ; |
231 | case DataType::Half: |
232 | return "__half" ; |
233 | case DataType::BFloat16: |
234 | return "__bfloat" ; |
235 | case DataType::Int: |
236 | return "int64_t" ; |
237 | case DataType::Index: |
238 | return "nvfuser_index_t" ; |
239 | case DataType::Int32: |
240 | return "int" ; |
241 | case DataType::ComplexFloat: |
242 | return "std::complex<float>" ; |
243 | case DataType::ComplexDouble: |
244 | return "std::complex<double>" ; |
245 | case DataType::Double_2: |
246 | return "Array<double, 2, 1>" ; |
247 | case DataType::Float_2: |
248 | return "Array<float, 2, 1>" ; |
249 | case DataType::Null: |
250 | return "null_type" ; |
251 | default: |
252 | break; |
253 | } |
254 | TORCH_INTERNAL_ASSERT(false, "No string found for data type." ); |
255 | return nullptr; |
256 | } |
257 | |
258 | static const char* val_type2string(ValType t) { |
259 | switch (t) { |
260 | case ValType::TensorView: |
261 | return "TensorView" ; |
262 | case ValType::TensorDomain: |
263 | return "TensorDomain" ; |
264 | case ValType::IterDomain: |
265 | return "IterDomain" ; |
266 | case ValType::Scalar: |
267 | return "Scalar" ; |
268 | case ValType::NamedScalar: |
269 | return "NamedScalar" ; |
270 | case ValType::Predicate: |
271 | return "Predicate" ; |
272 | case ValType::TensorIndex: |
273 | return "TensorIndex" ; |
274 | case ValType::IntPair: |
275 | return "IntPair" ; |
276 | default: |
277 | TORCH_INTERNAL_ASSERT(false, "No string found for val type." ); |
278 | } |
279 | } |
280 | |
281 | static const char* predicate_type2string(PredicateType t) { |
282 | switch (t) { |
283 | case PredicateType::Manual: |
284 | return "Manual" ; |
285 | case PredicateType::Inline: |
286 | return "Inline" ; |
287 | case PredicateType::Unswitch: |
288 | return "Unswitch" ; |
289 | case PredicateType::Vectorize: |
290 | return "Vectorize" ; |
291 | case PredicateType::Misaligned: |
292 | return "Misaligned" ; |
293 | case PredicateType::Shift: |
294 | return "Shift" ; |
295 | case PredicateType::Padding: |
296 | return "Padding" ; |
297 | case PredicateType::ReductionWrite: |
298 | return "ReductionWrite" ; |
299 | default: |
300 | TORCH_INTERNAL_ASSERT(false, "No string found for predicate type." ); |
301 | } |
302 | } |
303 | |
304 | static const char* expr_type2string(ExprType t) { |
305 | switch (t) { |
306 | case ExprType::FullOp: |
307 | return "FullOp" ; |
308 | case ExprType::ARangeOp: |
309 | return "ARangeOp" ; |
310 | case ExprType::EyeOp: |
311 | return "EyeOp" ; |
312 | case ExprType::UnaryOp: |
313 | return "UnaryOp" ; |
314 | case ExprType::BinaryOp: |
315 | return "BinaryOp" ; |
316 | case ExprType::TernaryOp: |
317 | return "TernaryOp" ; |
318 | case ExprType::RNGOp: |
319 | return "RNGOp" ; |
320 | case ExprType::ReductionOp: |
321 | return "ReductionOp" ; |
322 | case ExprType::GroupedReductionOp: |
323 | return "GroupedReductionOp" ; |
324 | case ExprType::BroadcastOp: |
325 | return "BroadcastOp" ; |
326 | case ExprType::WelfordOp: |
327 | return "WelfordOp" ; |
328 | case ExprType::GroupedWelfordOp: |
329 | return "GroupedWelfordOp" ; |
330 | case ExprType::LoadStoreOp: |
331 | return "LoadStoreOp" ; |
332 | case ExprType::MmaOp: |
333 | return "MmaOp" ; |
334 | case ExprType::TransposeOp: |
335 | return "TransposeOp" ; |
336 | case ExprType::ExpandOp: |
337 | return "ExpandOp" ; |
338 | case ExprType::ShiftOp: |
339 | return "ShiftOp" ; |
340 | case ExprType::GatherOp: |
341 | return "GatherOp" ; |
342 | case ExprType::ViewAsScalar: |
343 | return "ViewAsScalar" ; |
344 | case ExprType::ViewOp: |
345 | return "ViewOp" ; |
346 | case ExprType::Split: |
347 | return "Split" ; |
348 | case ExprType::Merge: |
349 | return "Merge" ; |
350 | case ExprType::Allocate: |
351 | return "Allocate" ; |
352 | case ExprType::BlockSync: |
353 | return "BlockSync" ; |
354 | case ExprType::GridSync: |
355 | return "GridSync" ; |
356 | case ExprType::CpAsyncWait: |
357 | return "CpAsyncWait" ; |
358 | case ExprType::CpAsyncCommit: |
359 | return "CpAsyncCommit" ; |
360 | case ExprType::InitMagicZero: |
361 | return "InitMagicZero" ; |
362 | case ExprType::UpdateMagicZero: |
363 | return "UpdateMagicZero" ; |
364 | case ExprType::ForLoop: |
365 | return "ForLoop" ; |
366 | case ExprType::IfThenElse: |
367 | return "IfThenElse" ; |
368 | case ExprType::GridReduction: |
369 | return "GridReduction" ; |
370 | case ExprType::GroupedGridReduction: |
371 | return "GroupedGridReduction" ; |
372 | case ExprType::GridBroadcast: |
373 | return "GridBroadcast" ; |
374 | case ExprType::GridWelford: |
375 | return "GridWelford" ; |
376 | case ExprType::GroupedGridWelford: |
377 | return "GroupedGridWelford" ; |
378 | case ExprType::Swizzle2D: |
379 | return "Swizzle2D" ; |
380 | case ExprType::Swizzle2DInt: |
381 | return "Swizzle2DInt" ; |
382 | case ExprType::PairSelect: |
383 | return "PairSelect" ; |
384 | default: |
385 | TORCH_INTERNAL_ASSERT(false, "No string found for expr type." ); |
386 | } |
387 | } |
388 | |
389 | bool needFloatSuffix(UnaryOpType t) { |
390 | switch (t) { |
391 | case UnaryOpType::Abs: |
392 | case UnaryOpType::Cast: |
393 | case UnaryOpType::Frac: |
394 | case UnaryOpType::Gelu: |
395 | case UnaryOpType::Silu: |
396 | case UnaryOpType::BitCast: |
397 | case UnaryOpType::Neg: |
398 | case UnaryOpType::Relu: |
399 | case UnaryOpType::Reciprocal: |
400 | case UnaryOpType::Set: |
401 | case UnaryOpType::Sigmoid: |
402 | case UnaryOpType::IsFinite: |
403 | case UnaryOpType::IsInf: |
404 | case UnaryOpType::IsNan: |
405 | case UnaryOpType::IsNegInf: |
406 | case UnaryOpType::IsPosInf: |
407 | case UnaryOpType::IsReal: |
408 | case UnaryOpType::Print: |
409 | return false; |
410 | default: |
411 | return true; |
412 | } |
413 | } |
414 | |
415 | bool needFloatSuffix(RNGOpType t) { |
416 | return true; |
417 | } |
418 | |
419 | static const char* unary_op_type2string(UnaryOpType t) { |
420 | switch (t) { |
421 | case UnaryOpType::Abs: |
422 | return "abs" ; |
423 | case UnaryOpType::Acos: |
424 | return "acos" ; |
425 | case UnaryOpType::Asin: |
426 | return "asin" ; |
427 | case UnaryOpType::Atan: |
428 | return "atan" ; |
429 | case UnaryOpType::Atanh: |
430 | return "atanh" ; |
431 | case UnaryOpType::Cast: |
432 | return "cast" ; |
433 | case UnaryOpType::Ceil: |
434 | return "ceil" ; |
435 | case UnaryOpType::Cos: |
436 | return "cos" ; |
437 | case UnaryOpType::Cosh: |
438 | return "cosh" ; |
439 | case UnaryOpType::Exp: |
440 | return "exp" ; |
441 | case UnaryOpType::Expm1: |
442 | return "expm1" ; |
443 | case UnaryOpType::Erf: |
444 | return "erf" ; |
445 | case UnaryOpType::Erfc: |
446 | return "erfc" ; |
447 | case UnaryOpType::Floor: |
448 | return "floor" ; |
449 | case UnaryOpType::Frac: |
450 | return "frac" ; |
451 | case UnaryOpType::Silu: |
452 | return "silu" ; |
453 | case UnaryOpType::Lgamma: |
454 | return "lgamma" ; |
455 | case UnaryOpType::Log: |
456 | return "log" ; |
457 | case UnaryOpType::Log10: |
458 | return "log10" ; |
459 | case UnaryOpType::Log1p: |
460 | return "log1p" ; |
461 | case UnaryOpType::Log2: |
462 | return "log2" ; |
463 | case UnaryOpType::BitCast: |
464 | return "erase_type" ; |
465 | case UnaryOpType::Neg: |
466 | return "neg" ; |
467 | case UnaryOpType::Not: |
468 | return "not" ; |
469 | case UnaryOpType::Print: |
470 | return "print" ; |
471 | case UnaryOpType::Reciprocal: |
472 | return "reciprocal" ; |
473 | case UnaryOpType::Relu: |
474 | return "relu" ; |
475 | case UnaryOpType::Rsqrt: |
476 | return "rsqrt" ; |
477 | case UnaryOpType::Round: |
478 | return "nearbyint" ; |
479 | case UnaryOpType::Set: |
480 | return "set" ; |
481 | case UnaryOpType::Sigmoid: |
482 | return "sigmoid" ; |
483 | case UnaryOpType::Sin: |
484 | return "sin" ; |
485 | case UnaryOpType::Sinh: |
486 | return "sinh" ; |
487 | case UnaryOpType::Sqrt: |
488 | return "sqrt" ; |
489 | case UnaryOpType::Tan: |
490 | return "tan" ; |
491 | case UnaryOpType::Tanh: |
492 | return "tanh" ; |
493 | case UnaryOpType::Trunc: |
494 | return "trunc" ; |
495 | case UnaryOpType::IsFinite: |
496 | return "isfinite" ; |
497 | case UnaryOpType::IsInf: |
498 | return "isinf" ; |
499 | case UnaryOpType::IsNan: |
500 | return "isnan" ; |
501 | case UnaryOpType::IsNegInf: |
502 | return "isneginf" ; |
503 | case UnaryOpType::IsPosInf: |
504 | return "isposinf" ; |
505 | case UnaryOpType::IsReal: |
506 | return "isreal" ; |
507 | case UnaryOpType::Real: |
508 | return "std::real" ; |
509 | case UnaryOpType::Imag: |
510 | return "std::imag" ; |
511 | default: |
512 | TORCH_INTERNAL_ASSERT(false, "No string found for unary op type." ); |
513 | } |
514 | } |
515 | |
516 | std::string stringifyBooleanOp(const UnaryOpType uopt) { |
517 | TORCH_INTERNAL_ASSERT( |
518 | uopt == UnaryOpType::Not, uopt, " is not a boolean operator." ); |
519 | return "!" ; |
520 | } |
521 | |
522 | static const char* unary_op_type_inline_op2string(UnaryOpType t) { |
523 | switch (t) { |
524 | case UnaryOpType::Neg: |
525 | return "-" ; |
526 | case UnaryOpType::Not: |
527 | return "~" ; |
528 | case UnaryOpType::Set: |
529 | return "" ; |
530 | case UnaryOpType::Address: |
531 | return "(int64_t) &" ; |
532 | default: |
533 | break; |
534 | } |
535 | return nullptr; |
536 | } |
537 | |
538 | bool needFloatSuffix(BinaryOpType t) { |
539 | switch (t) { |
540 | case BinaryOpType::Atan2: |
541 | case BinaryOpType::Div: |
542 | case BinaryOpType::Fmod: |
543 | return true; |
544 | default: |
545 | return false; |
546 | } |
547 | } |
548 | |
549 | static const char* binary_op_type2string(BinaryOpType t) { |
550 | switch (t) { |
551 | case BinaryOpType::Add: |
552 | return "add" ; |
553 | case BinaryOpType::Atan2: |
554 | return "atan2" ; |
555 | case BinaryOpType::Div: |
556 | return "div" ; |
557 | case BinaryOpType::Fmod: |
558 | return "fmod" ; |
559 | case BinaryOpType::Max: |
560 | return "fmax" ; |
561 | case BinaryOpType::Min: |
562 | return "fmin" ; |
563 | case BinaryOpType::Mul: |
564 | return "mul" ; |
565 | case BinaryOpType::Pow: |
566 | return "pow" ; |
567 | case BinaryOpType::Remainder: |
568 | return "remainder" ; |
569 | case BinaryOpType::Sub: |
570 | return "sub" ; |
571 | |
572 | // Integer Ops |
573 | case BinaryOpType::Mod: |
574 | return "mod" ; |
575 | case BinaryOpType::CeilDiv: |
576 | return "ceilDiv" ; |
577 | case BinaryOpType::Lshift: |
578 | return "lshift" ; |
579 | case BinaryOpType::Rshift: |
580 | return "rshift" ; |
581 | |
582 | // Logical Ops |
583 | case BinaryOpType::And: |
584 | return "and" ; |
585 | case BinaryOpType::Eq: |
586 | return "equal" ; |
587 | case BinaryOpType::GE: |
588 | return "greaterThanOrEqual" ; |
589 | case BinaryOpType::GT: |
590 | return "greaterThan" ; |
591 | case BinaryOpType::LE: |
592 | return "lessThanOrEqual" ; |
593 | case BinaryOpType::LT: |
594 | return "lessThan" ; |
595 | case BinaryOpType::NE: |
596 | return "notEqual" ; |
597 | default: |
598 | TORCH_INTERNAL_ASSERT(false, "No string found for binary op type." ); |
599 | } |
600 | } |
601 | |
602 | static const char* binary_op_integer_op2string(BinaryOpType t) { |
603 | switch (t) { |
604 | case BinaryOpType::Max: |
605 | return "max" ; |
606 | case BinaryOpType::Min: |
607 | return "min" ; |
608 | case BinaryOpType::Fmod: |
609 | return "fmod" ; |
610 | default: |
611 | break; |
612 | } |
613 | return nullptr; |
614 | } |
615 | |
616 | static const char* binary_op_bool_op2string(BinaryOpType t) { |
617 | switch (t) { |
618 | case BinaryOpType::Max: |
619 | return "max" ; |
620 | case BinaryOpType::Min: |
621 | return "min" ; |
622 | default: |
623 | break; |
624 | } |
625 | return nullptr; |
626 | } |
627 | |
628 | static const char* binary_op_type_inline_op2string(BinaryOpType t) { |
629 | switch (t) { |
630 | case BinaryOpType::Add: |
631 | return "+" ; |
632 | case BinaryOpType::Div: |
633 | return "/" ; |
634 | case BinaryOpType::Mul: |
635 | return "*" ; |
636 | case BinaryOpType::Sub: |
637 | return "-" ; |
638 | |
639 | // Integer ops |
640 | case BinaryOpType::Mod: |
641 | return "%" ; |
642 | case BinaryOpType::Lshift: |
643 | return "<<" ; |
644 | case BinaryOpType::Rshift: |
645 | return ">>" ; |
646 | // Logical Ops |
647 | case BinaryOpType::Eq: |
648 | return "==" ; |
649 | case BinaryOpType::GE: |
650 | return ">=" ; |
651 | case BinaryOpType::GT: |
652 | return ">" ; |
653 | case BinaryOpType::LE: |
654 | return "<=" ; |
655 | case BinaryOpType::LT: |
656 | return "<" ; |
657 | case BinaryOpType::NE: |
658 | return "!=" ; |
659 | // Assume bitwise, otherwise use stringifyBooleanOp |
660 | case BinaryOpType::And: |
661 | return "&" ; |
662 | case BinaryOpType::Or: |
663 | return "|" ; |
664 | case BinaryOpType::Xor: |
665 | return "^" ; |
666 | default: |
667 | break; |
668 | } |
669 | return nullptr; |
670 | } |
671 | |
672 | static const char* rng_op_type_inline_op2string(RNGOpType t) { |
673 | switch (t) { |
674 | case RNGOpType::Uniform: |
675 | return "rng_uniform" ; |
676 | case RNGOpType::UniformRange: |
677 | return "rng_uniform_range" ; |
678 | default: |
679 | break; |
680 | } |
681 | return nullptr; |
682 | } |
683 | |
684 | std::string stringifyBooleanOp(const BinaryOpType bopt) { |
685 | switch (bopt) { |
686 | case BinaryOpType::And: |
687 | return "&&" ; |
688 | case BinaryOpType::Or: |
689 | return "||" ; |
690 | case BinaryOpType::Xor: |
691 | return "!=" ; |
692 | default: |
693 | TORCH_INTERNAL_ASSERT(false, bopt, " is not a boolean operator." ) |
694 | } |
695 | } |
696 | |
697 | static const char* ternary_op_type2string(TernaryOpType t) { |
698 | switch (t) { |
699 | case TernaryOpType::Clamp: |
700 | return "clamp" ; |
701 | case TernaryOpType::Lerp: |
702 | return "lerp" ; |
703 | case TernaryOpType::Threshold: |
704 | return "threshold" ; |
705 | case TernaryOpType::Where: |
706 | return "where" ; |
707 | default: |
708 | TORCH_INTERNAL_ASSERT(false, "Unexpected TernaryOpType" ); |
709 | } |
710 | } |
711 | |
712 | static const char* rng_op_type2string(RNGOpType t) { |
713 | switch (t) { |
714 | case RNGOpType::Uniform: |
715 | return "rng_uniform" ; |
716 | case RNGOpType::UniformRange: |
717 | return "rng_uniform_range" ; |
718 | default: |
719 | TORCH_INTERNAL_ASSERT(false, "Unexpected RNGOpType" ); |
720 | } |
721 | } |
722 | |
723 | static const char* parallel_type2string(ParallelType t) { |
724 | switch (t) { |
725 | case ParallelType::BIDz: |
726 | return "blockIdx.z" ; |
727 | case ParallelType::BIDy: |
728 | return "blockIdx.y" ; |
729 | case ParallelType::BIDx: |
730 | return "blockIdx.x" ; |
731 | case ParallelType::TIDz: |
732 | return "threadIdx.z" ; |
733 | case ParallelType::TIDy: |
734 | return "threadIdx.y" ; |
735 | case ParallelType::TIDx: |
736 | return "threadIdx.x" ; |
737 | case ParallelType::Vectorize: |
738 | return "V" ; |
739 | case ParallelType::MisalignedVectorize: |
740 | return "MV" ; |
741 | case ParallelType::Unroll: |
742 | return "UR" ; |
743 | case ParallelType::Unswitch: |
744 | return "US" ; |
745 | case ParallelType::Mma: |
746 | return "MMA" ; |
747 | case ParallelType::Group: |
748 | return "G" ; |
749 | case ParallelType::Serial: |
750 | return "S" ; |
751 | default: |
752 | TORCH_INTERNAL_ASSERT(false, "Unexpected ParallelType" ); |
753 | } |
754 | } |
755 | |
756 | std::unordered_set<ParallelType> allParallelTypesExcept( |
757 | const std::unordered_set<ParallelType>& except) { |
758 | std::unordered_set<ParallelType> result = { |
759 | ParallelType::BIDz, |
760 | ParallelType::BIDy, |
761 | ParallelType::BIDx, |
762 | ParallelType::TIDz, |
763 | ParallelType::TIDy, |
764 | ParallelType::TIDx, |
765 | ParallelType::Vectorize, |
766 | ParallelType::MisalignedVectorize, |
767 | ParallelType::Unroll, |
768 | ParallelType::Unswitch, |
769 | ParallelType::Mma, |
770 | ParallelType::Group, |
771 | ParallelType::Serial}; |
772 | for (auto t : except) { |
773 | result.erase(t); |
774 | } |
775 | return result; |
776 | } |
777 | |
778 | static const char* memory_type2string(MemoryType t) { |
779 | switch (t) { |
780 | case MemoryType::Local: |
781 | return "register" ; |
782 | case MemoryType::Shared: |
783 | return "shared" ; |
784 | case MemoryType::Global: |
785 | return "global" ; |
786 | default: |
787 | TORCH_INTERNAL_ASSERT(false, "Unexpected MemoryType" ); |
788 | } |
789 | } |
790 | |
791 | static const char* id_map_mode_type2string(IdMappingMode t) { |
792 | switch (t) { |
793 | case IdMappingMode::PERMISSIVE: |
794 | return "permissive" ; |
795 | case IdMappingMode::EXACT: |
796 | return "exact" ; |
797 | case IdMappingMode::LOOP: |
798 | return "loop" ; |
799 | default: |
800 | // Don't try to print t as it would recursively call this function |
801 | TORCH_INTERNAL_ASSERT(false, "Unexpected IdMappingMode Type." ); |
802 | } |
803 | } |
804 | |
805 | static const char* iter_type2string(IterType t) { |
806 | switch (t) { |
807 | case IterType::Iteration: |
808 | return "i" ; |
809 | case IterType::Reduction: |
810 | return "r" ; |
811 | case IterType::Broadcast: |
812 | return "b" ; |
813 | case IterType::Gather: |
814 | return "g" ; |
815 | case IterType::Stride: |
816 | return "s" ; |
817 | case IterType::VectorComponent: |
818 | return "v" ; |
819 | default: |
820 | // Don't try to print t as it would recursively call this function |
821 | TORCH_INTERNAL_ASSERT(false, "Unexpected IterType" ); |
822 | } |
823 | } |
824 | |
825 | static const char* thread_size2string(ParallelType t) { |
826 | switch (t) { |
827 | case ParallelType::BIDz: |
828 | return "gridDim.z" ; |
829 | case ParallelType::BIDy: |
830 | return "gridDim.y" ; |
831 | case ParallelType::BIDx: |
832 | return "gridDim.x" ; |
833 | case ParallelType::TIDz: |
834 | return "blockDim.z" ; |
835 | case ParallelType::TIDy: |
836 | return "blockDim.y" ; |
837 | case ParallelType::TIDx: |
838 | return "blockDim.x" ; |
839 | default: |
840 | TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type" ); |
841 | } |
842 | } |
843 | |
844 | static const char* load_store_type2string(LoadStoreOpType t) { |
845 | switch (t) { |
846 | case LoadStoreOpType::LdMatrix: |
847 | return "LdMatrix" ; |
848 | case LoadStoreOpType::LdMatrixTranspose: |
849 | return "LdMatrixTranspose" ; |
850 | case LoadStoreOpType::CpAsync: |
851 | return "CpAsync" ; |
852 | default: |
853 | TORCH_INTERNAL_ASSERT(false, "Unexpected parallel type" ); |
854 | } |
855 | } |
856 | |
857 | const unsigned int _WORD_SHIFT = 16; |
858 | constexpr unsigned int supported_switch_pair(DataType t1, DataType t2) { |
859 | return ((unsigned int)t1 << _WORD_SHIFT) + (unsigned int)t2; |
860 | } |
861 | |
862 | static const char* supported_casts2string( |
863 | const std::pair<DataType, DataType>& t) { |
864 | switch (supported_switch_pair(std::get<0>(t), std::get<1>(t))) { |
865 | case supported_switch_pair(DataType::Index, DataType::Float): |
866 | case supported_switch_pair(DataType::Int, DataType::Float): |
867 | case supported_switch_pair(DataType::Int32, DataType::Float): |
868 | case supported_switch_pair(DataType::Double, DataType::Float): |
869 | case supported_switch_pair(DataType::Bool, DataType::Float): |
870 | return "(float)" ; |
871 | case supported_switch_pair(DataType::ComplexFloat, DataType::Float): |
872 | case supported_switch_pair(DataType::ComplexDouble, DataType::Float): |
873 | return "(float)std::real" ; |
874 | case supported_switch_pair(DataType::Index, DataType::Int): |
875 | case supported_switch_pair(DataType::Int32, DataType::Int): |
876 | case supported_switch_pair(DataType::Float, DataType::Int): |
877 | case supported_switch_pair(DataType::Double, DataType::Int): |
878 | case supported_switch_pair(DataType::Bool, DataType::Int): |
879 | return "(int64_t)" ; |
880 | case supported_switch_pair(DataType::ComplexFloat, DataType::Int): |
881 | case supported_switch_pair(DataType::ComplexDouble, DataType::Int): |
882 | return "(int64_t)std::real" ; |
883 | case supported_switch_pair(DataType::Index, DataType::Int32): |
884 | case supported_switch_pair(DataType::Int, DataType::Int32): |
885 | case supported_switch_pair(DataType::Float, DataType::Int32): |
886 | case supported_switch_pair(DataType::Double, DataType::Int32): |
887 | case supported_switch_pair(DataType::Bool, DataType::Int32): |
888 | return "(int32_t)" ; |
889 | case supported_switch_pair(DataType::ComplexFloat, DataType::Int32): |
890 | case supported_switch_pair(DataType::ComplexDouble, DataType::Int32): |
891 | return "(int32_t)std::real" ; |
892 | case supported_switch_pair(DataType::Int, DataType::Index): |
893 | case supported_switch_pair(DataType::Int32, DataType::Index): |
894 | case supported_switch_pair(DataType::Float, DataType::Index): |
895 | case supported_switch_pair(DataType::Double, DataType::Index): |
896 | return "(nvfuser_index_t)" ; |
897 | case supported_switch_pair(DataType::ComplexFloat, DataType::Index): |
898 | case supported_switch_pair(DataType::ComplexDouble, DataType::Index): |
899 | return "(nvfuser_index_t)std::real" ; |
900 | case supported_switch_pair(DataType::Index, DataType::Double): |
901 | case supported_switch_pair(DataType::Int, DataType::Double): |
902 | case supported_switch_pair(DataType::Int32, DataType::Double): |
903 | case supported_switch_pair(DataType::Float, DataType::Double): |
904 | case supported_switch_pair(DataType::Bool, DataType::Double): |
905 | return "(double)" ; |
906 | case supported_switch_pair(DataType::ComplexFloat, DataType::Double): |
907 | case supported_switch_pair(DataType::ComplexDouble, DataType::Double): |
908 | return "(double)std::real" ; |
909 | case supported_switch_pair(DataType::Float, DataType::Bool): |
910 | case supported_switch_pair(DataType::Double, DataType::Bool): |
911 | case supported_switch_pair(DataType::Int32, DataType::Bool): |
912 | case supported_switch_pair(DataType::Int, DataType::Bool): |
913 | return "(bool)" ; |
914 | case supported_switch_pair(DataType::ComplexFloat, DataType::Bool): |
915 | case supported_switch_pair(DataType::ComplexDouble, DataType::Bool): |
916 | return "(bool)std::real" ; |
917 | case supported_switch_pair(DataType::Index, DataType::ComplexDouble): |
918 | case supported_switch_pair(DataType::Int, DataType::ComplexDouble): |
919 | case supported_switch_pair(DataType::Int32, DataType::ComplexDouble): |
920 | case supported_switch_pair(DataType::Double, DataType::ComplexDouble): |
921 | case supported_switch_pair(DataType::Float, DataType::ComplexDouble): |
922 | case supported_switch_pair(DataType::Bool, DataType::ComplexDouble): |
923 | case supported_switch_pair(DataType::ComplexFloat, DataType::ComplexDouble): |
924 | return "(std::complex<double>)" ; |
925 | case supported_switch_pair(DataType::Index, DataType::ComplexFloat): |
926 | case supported_switch_pair(DataType::Int, DataType::ComplexFloat): |
927 | case supported_switch_pair(DataType::Int32, DataType::ComplexFloat): |
928 | case supported_switch_pair(DataType::Double, DataType::ComplexFloat): |
929 | case supported_switch_pair(DataType::Float, DataType::ComplexFloat): |
930 | case supported_switch_pair(DataType::Bool, DataType::ComplexFloat): |
931 | case supported_switch_pair(DataType::ComplexDouble, DataType::ComplexFloat): |
932 | return "(std::complex<float>)" ; |
933 | case supported_switch_pair(DataType::Float, DataType::Half): |
934 | return "__float2half" ; |
935 | case supported_switch_pair(DataType::Double, DataType::Half): |
936 | return "__double2half" ; |
937 | case supported_switch_pair(DataType::Float, DataType::BFloat16): |
938 | return "__float2bfloat" ; |
939 | case supported_switch_pair(DataType::Half, DataType::Float): |
940 | return "__half2float" ; |
941 | case supported_switch_pair(DataType::Half, DataType::Double): |
942 | return "__half2double" ; |
943 | case supported_switch_pair(DataType::BFloat16, DataType::Float): |
944 | return "__bfloat2float" ; |
945 | default: |
946 | return nullptr; |
947 | } |
948 | } |
949 | |
950 | DataType aten_to_data_type(const at::ScalarType& scalar_type) { |
951 | switch (scalar_type) { |
952 | case at::ScalarType::Bool: |
953 | return DataType::Bool; |
954 | case at::ScalarType::Double: |
955 | return DataType::Double; |
956 | case at::ScalarType::Float: |
957 | return DataType::Float; |
958 | case at::ScalarType::Half: |
959 | return DataType::Half; |
960 | case at::ScalarType::BFloat16: |
961 | return DataType::BFloat16; |
962 | case at::ScalarType::Long: |
963 | return DataType::Int; |
964 | case at::ScalarType::Int: |
965 | return DataType::Int32; |
966 | case at::ScalarType::ComplexFloat: |
967 | return DataType::ComplexFloat; |
968 | case at::ScalarType::ComplexDouble: |
969 | return DataType::ComplexDouble; |
970 | default: |
971 | return DataType::Null; |
972 | } |
973 | } |
974 | |
975 | at::ScalarType data_type_to_aten(const DataType& data_type) { |
976 | switch (data_type) { |
977 | case DataType::Bool: |
978 | return at::ScalarType::Bool; |
979 | case DataType::Double: |
980 | return at::ScalarType::Double; |
981 | case DataType::Float: |
982 | return at::ScalarType::Float; |
983 | case DataType::Half: |
984 | return at::ScalarType::Half; |
985 | case DataType::BFloat16: |
986 | return at::ScalarType::BFloat16; |
987 | case DataType::Int: |
988 | return at::ScalarType::Long; |
989 | case DataType::Index: |
990 | TORCH_INTERNAL_ASSERT( |
991 | false, |
992 | "Index is determined at compile time," , |
993 | " to convert from an aten type you need to have the compiled information. " , |
994 | "This information is passed to GpuLower at compile time, and then copied to kerned." , |
995 | "There's also this information in FusionExecutorCache and the Registry system." ); |
996 | case DataType::Int32: |
997 | return at::ScalarType::Int; |
998 | case DataType::ComplexFloat: |
999 | return at::ScalarType::ComplexFloat; |
1000 | case DataType::ComplexDouble: |
1001 | return at::ScalarType::ComplexDouble; |
1002 | default: |
1003 | TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type." ); |
1004 | } |
1005 | } |
1006 | |
1007 | std::ostream& operator<<(std::ostream& out, const ValType vtype) { |
1008 | return out << val_type2string(vtype); |
1009 | } |
1010 | |
1011 | std::ostream& operator<<(std::ostream& out, const PredicateType ptype) { |
1012 | return out << predicate_type2string(ptype); |
1013 | } |
1014 | |
1015 | std::ostream& operator<<(std::ostream& out, const DataType dtype) { |
1016 | return out << data_type2string(dtype); |
1017 | } |
1018 | |
1019 | std::ostream& operator<<(std::ostream& out, const ExprType etype) { |
1020 | return out << expr_type2string(etype); |
1021 | } |
1022 | |
1023 | std::ostream& operator<<(std::ostream& out, const UnaryOpType uotype) { |
1024 | return out << unary_op_type2string(uotype); |
1025 | } |
1026 | |
1027 | std::ostream& operator<<(std::ostream& out, const BinaryOpType botype) { |
1028 | return out << binary_op_type2string(botype); |
1029 | } |
1030 | |
1031 | std::ostream& operator<<(std::ostream& out, const TernaryOpType totype) { |
1032 | return out << ternary_op_type2string(totype); |
1033 | } |
1034 | |
1035 | std::ostream& operator<<(std::ostream& out, const RNGOpType rngtype) { |
1036 | return out << rng_op_type2string(rngtype); |
1037 | } |
1038 | |
1039 | std::ostream& operator<<(std::ostream& out, const ParallelType ptype) { |
1040 | return out << stringifyThread(ptype); |
1041 | } |
1042 | |
1043 | std::ostream& operator<<(std::ostream& out, const MemoryType mtype) { |
1044 | return out << memory_type2string(mtype); |
1045 | } |
1046 | |
1047 | std::ostream& operator<<(std::ostream& out, const IdMappingMode immtype) { |
1048 | return out << id_map_mode_type2string(immtype); |
1049 | } |
1050 | |
1051 | std::ostream& operator<<( |
1052 | std::ostream& out, |
1053 | const LoadStoreOpType load_store_type) { |
1054 | return out << load_store_type2string(load_store_type); |
1055 | } |
1056 | |
1057 | std::ostream& operator<<(std::ostream& out, const IterType bt) { |
1058 | return out << iter_type2string(bt); |
1059 | } |
1060 | |
1061 | std::ostream& operator<<(std::ostream& os, const Swizzle2DType& swizzle) { |
1062 | switch (swizzle) { |
1063 | case Swizzle2DType::NoSwizzle: |
1064 | os << "NoSwizzle" ; |
1065 | break; |
1066 | case Swizzle2DType::ZShape: |
1067 | os << "ZShape" ; |
1068 | break; |
1069 | case Swizzle2DType::Transpose: |
1070 | os << "Transpose" ; |
1071 | break; |
1072 | case Swizzle2DType::XOR: |
1073 | os << "Xor" ; |
1074 | break; |
1075 | case Swizzle2DType::Scatter: |
1076 | os << "Scatter" ; |
1077 | break; |
1078 | default: |
1079 | TORCH_INTERNAL_ASSERT(false, "undefined 2D swizzle" ); |
1080 | break; |
1081 | } |
1082 | return os; |
1083 | } |
1084 | |
1085 | std::ostream& operator<<(std::ostream& os, const SwizzleMode& swizzle) { |
1086 | switch (swizzle) { |
1087 | case SwizzleMode::NoSwizzle: |
1088 | os << "NoSwizzle" ; |
1089 | break; |
1090 | case SwizzleMode::Loop: |
1091 | os << "Loop" ; |
1092 | break; |
1093 | case SwizzleMode::Data: |
1094 | os << "Data" ; |
1095 | break; |
1096 | default: |
1097 | TORCH_INTERNAL_ASSERT(false, "undefined 2D swizzle" ); |
1098 | break; |
1099 | } |
1100 | return os; |
1101 | } |
1102 | |
1103 | c10::optional<std::string> inline_op_str(const UnaryOpType uotype) { |
1104 | const char* str = unary_op_type_inline_op2string(uotype); |
1105 | return str != nullptr ? c10::optional<std::string>(std::string(str)) |
1106 | : c10::nullopt; |
1107 | } |
1108 | |
1109 | c10::optional<std::string> inline_op_str(const BinaryOpType botype) { |
1110 | const char* str = binary_op_type_inline_op2string(botype); |
1111 | return str != nullptr ? c10::optional<std::string>(std::string(str)) |
1112 | : c10::nullopt; |
1113 | } |
1114 | |
1115 | c10::optional<std::string> inline_op_str(const RNGOpType rngtype) { |
1116 | const char* str = rng_op_type_inline_op2string(rngtype); |
1117 | return str != nullptr ? c10::optional<std::string>(std::string(str)) |
1118 | : c10::nullopt; |
1119 | } |
1120 | |
1121 | c10::optional<std::string> integer_op_str(const BinaryOpType botype) { |
1122 | const char* str = binary_op_integer_op2string(botype); |
1123 | return str != nullptr ? c10::optional<std::string>(std::string(str)) |
1124 | : c10::nullopt; |
1125 | } |
1126 | |
1127 | c10::optional<std::string> bool_op_str(const BinaryOpType botype) { |
1128 | const char* str = binary_op_bool_op2string(botype); |
1129 | return str != nullptr ? c10::optional<std::string>(std::string(str)) |
1130 | : c10::nullopt; |
1131 | } |
1132 | |
1133 | std::string stringifyThreadSize(const ParallelType ptype) { |
1134 | return thread_size2string(ptype); |
1135 | } |
1136 | |
1137 | std::string stringifyThread(const ParallelType ptype) { |
1138 | return parallel_type2string(ptype); |
1139 | } |
1140 | |
1141 | std::string typePrefix(const DataType data_type) { |
1142 | switch (data_type) { |
1143 | case DataType::Bool: |
1144 | return "b" ; |
1145 | case DataType::Double: |
1146 | return "d" ; |
1147 | case DataType::Float: |
1148 | case DataType::Half: |
1149 | case DataType::BFloat16: |
1150 | return "f" ; |
1151 | case DataType::Index: |
1152 | case DataType::Int: |
1153 | case DataType::Int32: |
1154 | return "i" ; |
1155 | case DataType::ComplexFloat: |
1156 | case DataType::ComplexDouble: |
1157 | return "c" ; |
1158 | default: |
1159 | TORCH_INTERNAL_ASSERT(false, "No data type found for scalar type." ); |
1160 | } |
1161 | } |
1162 | |
1163 | bool isParallelTypeThreadDim(ParallelType ptype) { |
1164 | return ptype == ParallelType::TIDx || ptype == ParallelType::TIDy || |
1165 | ptype == ParallelType::TIDz; |
1166 | } |
1167 | |
1168 | bool isParallelTypeBlockDim(ParallelType ptype) { |
1169 | return ptype == ParallelType::BIDx || ptype == ParallelType::BIDy || |
1170 | ptype == ParallelType::BIDz; |
1171 | } |
1172 | |
1173 | bool isParallelTypeThread(ParallelType ptype) { |
1174 | return isParallelTypeBlockDim(ptype) || isParallelTypeThreadDim(ptype); |
1175 | } |
1176 | |
1177 | bool isParallelTypeVectorize(ParallelType ptype) { |
1178 | return ptype == ParallelType::Vectorize || |
1179 | ptype == ParallelType::MisalignedVectorize; |
1180 | } |
1181 | |
1182 | c10::optional<std::string> cast_func_str( |
1183 | const std::pair<DataType, DataType>& cast) { |
1184 | const char* str = supported_casts2string(cast); |
1185 | return str != nullptr ? c10::optional<std::string>(std::string(str)) |
1186 | : c10::nullopt; |
1187 | } |
1188 | |
1189 | size_t dataTypeSize(DataType type) { |
1190 | switch (type) { |
1191 | case DataType::Bool: |
1192 | return sizeof(bool); |
1193 | case DataType::ComplexDouble: |
1194 | return sizeof(std::complex<double>); |
1195 | case DataType::ComplexFloat: |
1196 | return sizeof(std::complex<float>); |
1197 | case DataType::Double: |
1198 | return sizeof(double); |
1199 | case DataType::Float: |
1200 | return sizeof(float); |
1201 | case DataType::Half: |
1202 | return sizeof(at::Half); |
1203 | case DataType::BFloat16: |
1204 | return sizeof(at::BFloat16); |
1205 | case DataType::Index: |
1206 | TORCH_INTERNAL_ASSERT( |
1207 | false, "The actual type of Index is only known at compile time." ); |
1208 | case DataType::Int: |
1209 | return sizeof(uint64_t); |
1210 | case DataType::Int32: |
1211 | return sizeof(uint32_t); |
1212 | case DataType::Double_2: |
1213 | return sizeof(double) * 2; |
1214 | case DataType::Float_2: |
1215 | return sizeof(float) * 2; |
1216 | default: |
1217 | TORCH_INTERNAL_ASSERT(false, "Size undefined for data type, " , type); |
1218 | } |
1219 | } |
1220 | |
1221 | size_t dataTypeSize(DataType type, DataType index_type) { |
1222 | if (type == DataType::Index) { |
1223 | TORCH_INTERNAL_ASSERT( |
1224 | index_type == DataType::Int32 || index_type == DataType::Int, |
1225 | "Invalid index type of " , |
1226 | index_type); |
1227 | return dataTypeSize(index_type); |
1228 | } |
1229 | return dataTypeSize(type); |
1230 | } |
1231 | |
1232 | std::ostream& operator<<( |
1233 | std::ostream& os, |
1234 | const DoubleBufferLoopStage loop_stage) { |
1235 | switch (loop_stage) { |
1236 | case DoubleBufferLoopStage::NotApplicable: |
1237 | break; |
1238 | case DoubleBufferLoopStage::Prolog: |
1239 | os << "{DoubleBufferProlog}" ; |
1240 | break; |
1241 | case DoubleBufferLoopStage::Main: |
1242 | os << "{DoubleBufferMainLoop}" ; |
1243 | break; |
1244 | case DoubleBufferLoopStage::Epilog: |
1245 | os << "{DoubleBufferEpilog}" ; |
1246 | break; |
1247 | default: |
1248 | TORCH_INTERNAL_ASSERT(false, "unknown double buffer stage" ); |
1249 | } |
1250 | return os; |
1251 | } |
1252 | |
1253 | } // namespace cuda |
1254 | } // namespace fuser |
1255 | } // namespace jit |
1256 | } // namespace torch |
1257 | |