1// Generated from "/code/pytorch/third_party/nvfuser/runtime/tuple.cu"
2// 2023-02-12 08:01:26
3
4namespace nvfuser_resources {
5
6constexpr const char* tuple_cu = R"(
// Lightweight stand-in for std::tuple. Only the explicit specializations
// for one through eight element types are defined; their members are plain
// fields named val0, val1, ...
template <typename... Ts>
struct Tuple;
10
// Helper for the Tuple::operator+= overloads. Expands to a single statement
// that adds `offset` (which must be in scope at the expansion site) to member
// val<n>, after a compile-time check that IsPointerType accepts T<n>.
// The do/while(0) wrapper lets the expansion be followed by a semicolon.
#define TUPLE_INCREMENT_PTR(n) \
  do { \
    static_assert( \
        IsPointerType<T##n>::value, "Invalid for non-pointer types"); \
    val##n += offset; \
  } while (0)
17
// Tuple holding a single value.
template <typename T0>
struct Tuple<T0> {
  T0 val0;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0) : val0(v0) {}

  // Advances the stored pointer by `offset`. The helper macro static_asserts
  // that T0 is a pointer type, so this only compiles for pointer tuples.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
  }
};
31
// Tuple holding two values.
template <typename T0, typename T1>
struct Tuple<T0, T1> {
  T0 val0;
  T1 val1;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1) : val0(v0), val1(v1) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
  }
};
47
// Tuple holding three values.
template <typename T0, typename T1, typename T2>
struct Tuple<T0, T1, T2> {
  T0 val0;
  T1 val1;
  T2 val2;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1, T2 v2) : val0(v0), val1(v1), val2(v2) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
  }
};
66
// Tuple holding four values.
template <typename T0, typename T1, typename T2, typename T3>
struct Tuple<T0, T1, T2, T3> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3)
      : val0(v0), val1(v1), val2(v2), val3(v3) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
  }
};
87
// Tuple holding five values.
template <typename T0, typename T1, typename T2, typename T3, typename T4>
struct Tuple<T0, T1, T2, T3, T4> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4)
      : val0(v0), val1(v1), val2(v2), val3(v3), val4(v4) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
  }
};
110
// Tuple holding six values.
template <
    typename T0,
    typename T1,
    typename T2,
    typename T3,
    typename T4,
    typename T5>
struct Tuple<T0, T1, T2, T3, T4, T5> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;
  T5 val5;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5)
      : val0(v0), val1(v1), val2(v2), val3(v3), val4(v4), val5(v5) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
    TUPLE_INCREMENT_PTR(5);
  }
};
146
// Tuple holding seven values.
template <
    typename T0,
    typename T1,
    typename T2,
    typename T3,
    typename T4,
    typename T5,
    typename T6>
struct Tuple<T0, T1, T2, T3, T4, T5, T6> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;
  T5 val5;
  T6 val6;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6)
      : val0(v0),
        val1(v1),
        val2(v2),
        val3(v3),
        val4(v4),
        val5(v5),
        val6(v6) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
    TUPLE_INCREMENT_PTR(5);
    TUPLE_INCREMENT_PTR(6);
  }
};
193
// Tuple holding eight values (the largest arity that is specialized).
template <
    typename T0,
    typename T1,
    typename T2,
    typename T3,
    typename T4,
    typename T5,
    typename T6,
    typename T7>
struct Tuple<T0, T1, T2, T3, T4, T5, T6, T7> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;
  T5 val5;
  T6 val6;
  T7 val7;

  Tuple() = default;

  // Element-wise constructor.
  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7)
      : val0(v0),
        val1(v1),
        val2(v2),
        val3(v3),
        val4(v4),
        val5(v5),
        val6(v6),
        val7(v7) {}

  // Advances every stored pointer by `offset`. The helper macro
  // static_asserts that each element type is a pointer type.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
    TUPLE_INCREMENT_PTR(5);
    TUPLE_INCREMENT_PTR(6);
    TUPLE_INCREMENT_PTR(7);
  }
};

// The increment helper is only meant for the Tuple specializations.
#undef TUPLE_INCREMENT_PTR
247
// Compile-time indexed accessor for Tuple: get<N>()(t) yields a reference to
// t.valN. Only indices 0 through 7 are specialized.
template <int idx>
struct get;

// Emits the get<n> specialization with a mutable and a const overload of
// operator(). The closing semicolon of the struct is supplied at the use site.
#define DEFINE_TUPLE_GET(n) \
  template <> \
  struct get<n> { \
    template <typename TupleT> \
    __device__ auto& operator()(TupleT& tuple) { \
      return tuple.val##n; \
    } \
    template <typename TupleT> \
    __device__ const auto& operator()(const TupleT& tuple) { \
      return tuple.val##n; \
    } \
  }

DEFINE_TUPLE_GET(0);
DEFINE_TUPLE_GET(1);
DEFINE_TUPLE_GET(2);
DEFINE_TUPLE_GET(3);
DEFINE_TUPLE_GET(4);
DEFINE_TUPLE_GET(5);
DEFINE_TUPLE_GET(6);
DEFINE_TUPLE_GET(7);
#undef DEFINE_TUPLE_GET
274
// Forward declarations of the helpers that the tuple classes call from their
// constructors and assignment operators. The default arguments are given
// here; the definitions come after the classes.

// Copies every value of src, read at src_offset, into dst at dst_offset.
template <typename Dst, typename Src>
__inline__ __device__ static void copyTuple(
    Dst& dst,
    nvfuser_index_t dst_offset,
    const Src& src,
    nvfuser_index_t src_offset = 0);

// Same as above with the destination offset fixed to zero.
template <typename Dst, typename Src>
__inline__ __device__ static void copyTuple(
    Dst& dst,
    const Src& src,
    nvfuser_index_t src_offset = 0);

// Assigns the single value src to every element of dst.
template <typename Dst>
__inline__ __device__ static void setTuple(
    Dst& dst,
    typename Dst::template ValType<0> src);
292
// Tuple that owns its values directly; they live in a Tuple<Types...> member.
template <typename... Types>
class LocalTuple {
 public:
  // Number of values held.
  static constexpr int num_vals = sizeof...(Types);
  // Type list used by the copy helpers to check element-type compatibility.
  using ValTypes = TypeList<Types...>;

  // Type of the idx-th value.
  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  LocalTuple() = default;

  // Builds the tuple from one value per element type.
  __device__ explicit LocalTuple(Types... init) : vals_(init...) {}

  __device__ LocalTuple(const LocalTuple& rhs) : vals_(rhs.vals_) {}

  // Converting constructor from another tuple template instantiated with the
  // same element types; values are copied element by element via copyTuple.
  template <template <typename...> typename OtherTuple>
  __device__ LocalTuple(const OtherTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
  }

  // Element-wise assignment from another LocalTuple.
  __device__ LocalTuple& operator=(const LocalTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Element-wise assignment from another tuple template with the same
  // element types.
  template <template <typename...> typename OtherTuple>
  __device__ LocalTuple& operator=(const OtherTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Returns a mutable reference to the val_idx-th value. ptr_offset is
  // accepted but not used by this class; the generic copy helpers always
  // pass an offset to val().
  template <int val_idx>
  __device__ auto& val(nvfuser_index_t ptr_offset = 0) {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

  // Read-only counterpart of the accessor above.
  template <int val_idx>
  __device__ const auto& val(nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

 private:
  Tuple<Types...> vals_;
};
339
// Tuple of pointers, one per element type. Values are accessed by indexing
// the stored pointer with an offset. The element type seen through val() is
// MaybeVolatile<T, is_volatile>::type (presumably volatile-qualified when
// is_volatile is true; MaybeVolatile is defined elsewhere -- confirm).
template <bool is_volatile, typename... Types>
class PtrTupleBase {
 public:
  // Number of pointers held.
  static constexpr int num_vals = sizeof...(Types);
  // Type list used by the copy helpers to check element-type compatibility.
  using ValTypes = TypeList<Types...>;
  // Plain type of the idx-th value.
  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;
  // Type of the val_idx-th value as exposed by val().
  template <int val_idx>
  using TypeIMaybeVolatile = typename MaybeVolatile<
      typename TypeSelector<val_idx, Types...>::type,
      is_volatile>::type;

  // Stores one pointer per element type.
  __device__ PtrTupleBase(Types*... ptrs) : vals_(ptrs...) {}

  // Copy construction copies the pointers themselves.
  __device__ PtrTupleBase(const PtrTupleBase& rhs) : vals_(rhs.vals_) {}

  // Note: assignment is a deep copy. The pointed-to values are copied through
  // copyTuple; the stored pointers are left unchanged.
  __device__ PtrTupleBase& operator=(
      const PtrTupleBase<is_volatile, Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Deep-copy assignment from another tuple template with the same element
  // types.
  template <template <typename...> typename OtherTuple>
  __device__ PtrTupleBase& operator=(const OtherTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Returns a reference to the element ptr_offset positions past the
  // val_idx-th stored pointer.
  template <int val_idx>
  __device__ TypeIMaybeVolatile<val_idx>& val(nvfuser_index_t ptr_offset = 0) {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    auto* base = (TypeIMaybeVolatile<val_idx>*)get<val_idx>()(vals_);
    return base[ptr_offset];
  }

  // Read-only counterpart of the accessor above.
  template <int val_idx>
  __device__ const TypeIMaybeVolatile<val_idx>& val(
      nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    auto* base = (TypeIMaybeVolatile<val_idx>*)get<val_idx>()(vals_);
    return base[ptr_offset];
  }

  // Advances every stored pointer by ptr_offset elements.
  __device__ void operator+=(nvfuser_index_t ptr_offset) {
    vals_ += ptr_offset;
  }

 private:
  Tuple<Types*...> vals_;
};
389
// Tuple of mutable references, one per element type.
template <typename... Types>
class RefTuple {
 public:
  // Number of references held.
  static constexpr int num_vals = sizeof...(Types);
  // Type list used by the copy helpers to check element-type compatibility.
  using ValTypes = TypeList<Types...>;
  // Type of the idx-th referenced value.
  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  // Binds one reference per element type.
  __device__ RefTuple(Types&... refs) : vals_(refs...) {}

  // Copy construction rebinds to the same referenced objects.
  __device__ RefTuple(const RefTuple& rhs) : vals_(rhs.vals_) {}

  // NOTE(review): this constructor gives vals_ no initializer although its
  // members are references, so it looks like it cannot be instantiated --
  // confirm whether it is used anywhere.
  template <template <typename...> typename OtherTuple>
  __device__ RefTuple(const OtherTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
  }

  // Assignment writes through the references, element by element.
  __device__ RefTuple& operator=(const RefTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Write-through assignment from another tuple template with the same
  // element types.
  template <template <typename...> typename OtherTuple>
  __device__ RefTuple& operator=(const OtherTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Returns the val_idx-th reference. ptr_offset is accepted but not used by
  // this class; the generic copy helpers always pass an offset to val().
  template <int val_idx>
  __device__ auto& val(nvfuser_index_t ptr_offset = 0) {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

  // Read-only counterpart of the accessor above.
  template <int val_idx>
  __device__ const auto& val(nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

 private:
  Tuple<Types&...> vals_;
};
433
// Recursive element-wise copy between two tuple-like objects.
//
// TupleCopy<Dst, Src, N>::copy first recurses into TupleCopy<Dst, Src, N - 1>
// and then copies value N - 1, so values are assigned in ascending index
// order. Both tuple types must expose the same ValTypes list.
//
//   dst         tuple written to; each value is accessed at dst_offset
//   dst_offset  offset forwarded to dst.val()
//   src         tuple read from; each value is accessed at src_offset
//   src_offset  offset forwarded to src.val()
template <typename DstType, typename SrcType, int num_vals>
struct TupleCopy {
  __inline__ __device__ static void copy(
      DstType& dst,
      nvfuser_index_t dst_offset,
      const SrcType& src,
      nvfuser_index_t src_offset) {
    static_assert(
        IsSameType<typename DstType::ValTypes, typename SrcType::ValTypes>::
            value,
        "Invalid value types");
    TupleCopy<DstType, SrcType, num_vals - 1>::copy(
        dst, dst_offset, src, src_offset);
    // dst and src have dependent types, so the `template` keyword is needed
    // for `val<...>` to be parsed as a member template call rather than as a
    // chain of comparisons.
    dst.template val<num_vals - 1>(dst_offset) =
        src.template val<num_vals - 1>(src_offset);
  }
};

// Recursion terminator: no values left to copy.
template <typename DstType, typename SrcType>
struct TupleCopy<DstType, SrcType, 0> {
  __inline__ __device__ static void copy(
      DstType& dst,
      nvfuser_index_t dst_offset,
      const SrcType& src,
      nvfuser_index_t src_offset) {}
};
459
// Copies every value of src (read at src_offset) into dst (written at
// dst_offset). Fails to compile unless both tuples expose identical ValTypes.
template <typename DstType, typename SrcType>
__inline__ __device__ static void copyTuple(
    DstType& dst,
    nvfuser_index_t dst_offset,
    const SrcType& src,
    nvfuser_index_t src_offset) {
  using DstVals = typename DstType::ValTypes;
  using SrcVals = typename SrcType::ValTypes;
  static_assert(IsSameType<DstVals, SrcVals>::value, "Invalid value types");

  // Start the recursion at the full value count of the destination.
  using Copier = TupleCopy<DstType, SrcType, DstType::num_vals>;
  Copier::copy(dst, dst_offset, src, src_offset);
}
472
// Convenience overload: copies src (read at src_offset) into dst with the
// destination offset fixed to zero.
template <typename DstType, typename SrcType>
__inline__ __device__ static void copyTuple(
    DstType& dst,
    const SrcType& src,
    nvfuser_index_t src_offset) {
  copyTuple<DstType, SrcType>(dst, /*dst_offset=*/0, src, src_offset);
}
480
// Recursive helper that assigns one value to every element of a tuple.
//
// TupleSet<Dst, N>::set first recurses into TupleSet<Dst, N - 1> and then
// assigns src to value N - 1. Each value type must be identical to the type
// of value 0, which is also the type of src.
//
//   dst         tuple written to; each value is accessed at dst_offset
//   dst_offset  offset forwarded to dst.val()
//   src         value stored into every element
template <typename DstType, int num_vals>
struct TupleSet {
  __inline__ __device__ static void set(
      DstType& dst,
      nvfuser_index_t dst_offset,
      typename DstType::template ValType<0> src) {
    static_assert(
        IsSameType<
            typename DstType::template ValType<num_vals - 1>,
            typename DstType::template ValType<0>>::value,
        "Invalid value types");
    TupleSet<DstType, num_vals - 1>::set(dst, dst_offset, src);
    // dst has a dependent type, so the `template` keyword is needed for
    // `val<...>` to be parsed as a member template call.
    dst.template val<num_vals - 1>(dst_offset) = src;
  }
};

// Recursion terminator: no values left to set.
template <typename DstType>
struct TupleSet<DstType, 0> {
  __inline__ __device__ static void set(
      DstType& dst,
      nvfuser_index_t dst_offset,
      typename DstType::template ValType<0> src) {}
};
504
// Stores src into every value of dst, accessing each value at dst_offset.
template <typename DstType>
__inline__ __device__ static void setTuple(
    DstType& dst,
    nvfuser_index_t dst_offset,
    typename DstType::template ValType<0> src) {
  // Start the recursion at the full value count of the destination.
  using Setter = TupleSet<DstType, DstType::num_vals>;
  Setter::set(dst, dst_offset, src);
}
512
// Convenience overload: stores src into every value of dst at offset zero.
template <typename DstType>
__inline__ __device__ static void setTuple(
    DstType& dst,
    typename DstType::template ValType<0> src) {
  setTuple(dst, /*dst_offset=*/0, src);
}
519
// Recursive element-wise copy where each element has its own predicate.
//
// PredicatedTupleCopy<Dst, Src, Pred, N>::copy first recurses into the
// N - 1 case and then copies value N - 1 only if predicate value N - 1 is
// true. Every predicate value type must be bool.
//
//   dst         tuple written to; each value is accessed at dst_offset
//   dst_offset  offset forwarded to dst.val()
//   src         tuple read from; each value is accessed at src_offset
//   src_offset  offset forwarded to src.val()
//   pred        tuple of bools, always read at offset zero
template <typename DstType, typename SrcType, typename PredType, int num_vals>
struct PredicatedTupleCopy {
  __inline__ __device__ static void copy(
      DstType& dst,
      nvfuser_index_t dst_offset,
      const SrcType& src,
      nvfuser_index_t src_offset,
      const PredType& pred) {
    static_assert(
        IsSameType<typename PredType::template ValType<num_vals - 1>, bool>::
            value,
        "Invalid predicate type");
    PredicatedTupleCopy<DstType, SrcType, PredType, num_vals - 1>::copy(
        dst, dst_offset, src, src_offset, pred);
    // pred, dst and src have dependent types, so the `template` keyword is
    // needed for `val<...>` to be parsed as a member template call.
    if (pred.template val<num_vals - 1>(0)) {
      dst.template val<num_vals - 1>(dst_offset) =
          src.template val<num_vals - 1>(src_offset);
    }
  }
};

// Recursion terminator: no values left to copy.
template <typename DstType, typename SrcType, typename PredType>
struct PredicatedTupleCopy<DstType, SrcType, PredType, 0> {
  __inline__ __device__ static void copy(
      DstType& dst,
      nvfuser_index_t dst_offset,
      const SrcType& src,
      nvfuser_index_t src_offset,
      const PredType& pred) {}
};
549
// Copies each value of src (read at src_offset) into dst (written at
// dst_offset) when the matching value of pred is true. Fails to compile
// unless dst and src expose identical ValTypes and pred holds as many values
// as dst.
template <typename DstType, typename SrcType, typename PredType>
__inline__ __device__ static void copyTupleIf(
    DstType& dst,
    nvfuser_index_t dst_offset,
    const SrcType& src,
    nvfuser_index_t src_offset,
    const PredType& pred) {
  using DstVals = typename DstType::ValTypes;
  using SrcVals = typename SrcType::ValTypes;
  static_assert(IsSameType<DstVals, SrcVals>::value, "Invalid value types");
  static_assert(
      PredType::num_vals == DstType::num_vals, "Invalid predicate type");

  // Start the recursion at the full value count of the destination.
  using Copier =
      PredicatedTupleCopy<DstType, SrcType, PredType, DstType::num_vals>;
  Copier::copy(dst, dst_offset, src, src_offset, pred);
}
565
// Convenience overload: predicated copy with the destination offset fixed to
// zero.
template <typename DstType, typename SrcType, typename PredType>
__inline__ __device__ static void copyTupleIf(
    DstType& dst,
    const SrcType& src,
    nvfuser_index_t src_offset,
    const PredType& pred) {
  copyTupleIf(dst, /*dst_offset=*/0, src, src_offset, pred);
}
574
// Convenience overload: predicated copy with both offsets fixed to zero.
template <typename DstType, typename SrcType, typename PredType>
__inline__ __device__ static void copyTupleIf(
    DstType& dst,
    const SrcType& src,
    const PredType& pred) {
  copyTupleIf(dst, /*dst_offset=*/0, src, /*src_offset=*/0, pred);
}
582
583)"
584R"(
// Tuple of const references, one per element type; the read-only counterpart
// of RefTuple.
// Open question: can a generic const and non-const RefTuple be defined?
template <typename... Types>
class ConstRefTuple {
 public:
  // Number of references held.
  static constexpr int num_vals = sizeof...(Types);
  // Type list used by the copy helpers to check element-type compatibility.
  using ValTypes = TypeList<Types...>;
  // Type of the idx-th referenced value. Provided for parity with the other
  // tuple classes, whose ValType alias setTuple and copyTupleIf look up on
  // their template arguments.
  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  // Binds one const reference per element type.
  __device__ ConstRefTuple(const Types&... refs) : vals_(refs...) {}

  // Copy construction rebinds to the same referenced objects.
  __device__ ConstRefTuple(const ConstRefTuple& other) : vals_(other.vals_) {}

  // NOTE(review): this constructor gives vals_ no initializer although its
  // members are references, and copyTuple assigns through a non-const val()
  // that this class does not provide, so it looks like it cannot be
  // instantiated -- confirm whether it is used anywhere.
  template <template <typename...> typename TupleType>
  __device__ ConstRefTuple(const TupleType<Types...>& other) {
    copyTuple(*this, other);
  }

  // Returns the val_idx-th const reference. ptr_offset is accepted but not
  // used by this class.
  template <int val_idx>
  __device__ const auto& val(nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

 private:
  Tuple<const Types&...> vals_;
};
610
// Pointer tuple whose is_volatile flag is false.
template <typename... Ts>
using PtrTuple = PtrTupleBase<false, Ts...>;

// Pointer tuple whose is_volatile flag is true.
template <typename... Ts>
using VolatilePtrTuple = PtrTupleBase<true, Ts...>;
616
// Maps a value count and an element type to the LocalTuple holding that many
// values of the type, e.g. MakeLocalTuple<3, T>::type is LocalTuple<T, T, T>.
// Only counts 1 through 8 are specialized; the primary template is left
// undefined.
template <int NumVals, typename Type>
struct MakeLocalTuple;

template <typename T>
struct MakeLocalTuple<1, T> {
  using type = LocalTuple<T>;
};

template <typename T>
struct MakeLocalTuple<2, T> {
  using type = LocalTuple<T, T>;
};

template <typename T>
struct MakeLocalTuple<3, T> {
  using type = LocalTuple<T, T, T>;
};

template <typename T>
struct MakeLocalTuple<4, T> {
  using type = LocalTuple<T, T, T, T>;
};

template <typename T>
struct MakeLocalTuple<5, T> {
  using type = LocalTuple<T, T, T, T, T>;
};

template <typename T>
struct MakeLocalTuple<6, T> {
  using type = LocalTuple<T, T, T, T, T, T>;
};

template <typename T>
struct MakeLocalTuple<7, T> {
  using type = LocalTuple<T, T, T, T, T, T, T>;
};

template <typename T>
struct MakeLocalTuple<8, T> {
  using type = LocalTuple<T, T, T, T, T, T, T, T>;
};
660
// Maps a value count and an element type to the RefTuple referencing that
// many values of the type, e.g. MakeRefTuple<2, T>::type is RefTuple<T, T>.
// Only counts 1 through 8 are specialized; the primary template is left
// undefined.
template <int NumVals, typename Type>
struct MakeRefTuple;

template <typename T>
struct MakeRefTuple<1, T> {
  using type = RefTuple<T>;
};

template <typename T>
struct MakeRefTuple<2, T> {
  using type = RefTuple<T, T>;
};

template <typename T>
struct MakeRefTuple<3, T> {
  using type = RefTuple<T, T, T>;
};

template <typename T>
struct MakeRefTuple<4, T> {
  using type = RefTuple<T, T, T, T>;
};

template <typename T>
struct MakeRefTuple<5, T> {
  using type = RefTuple<T, T, T, T, T>;
};

template <typename T>
struct MakeRefTuple<6, T> {
  using type = RefTuple<T, T, T, T, T, T>;
};

template <typename T>
struct MakeRefTuple<7, T> {
  using type = RefTuple<T, T, T, T, T, T, T>;
};

template <typename T>
struct MakeRefTuple<8, T> {
  using type = RefTuple<T, T, T, T, T, T, T, T>;
};
703
// Maps a value count and an element type to the ConstRefTuple referencing
// that many values of the type, e.g. MakeConstRefTuple<2, T>::type is
// ConstRefTuple<T, T>. Only counts 1 through 8 are specialized; the primary
// template is left undefined.
template <int NumVals, typename Type>
struct MakeConstRefTuple;

template <typename T>
struct MakeConstRefTuple<1, T> {
  using type = ConstRefTuple<T>;
};

template <typename T>
struct MakeConstRefTuple<2, T> {
  using type = ConstRefTuple<T, T>;
};

template <typename T>
struct MakeConstRefTuple<3, T> {
  using type = ConstRefTuple<T, T, T>;
};

template <typename T>
struct MakeConstRefTuple<4, T> {
  using type = ConstRefTuple<T, T, T, T>;
};

template <typename T>
struct MakeConstRefTuple<5, T> {
  using type = ConstRefTuple<T, T, T, T, T>;
};

template <typename T>
struct MakeConstRefTuple<6, T> {
  using type = ConstRefTuple<T, T, T, T, T, T>;
};

template <typename T>
struct MakeConstRefTuple<7, T> {
  using type = ConstRefTuple<T, T, T, T, T, T, T>;
};

template <typename T>
struct MakeConstRefTuple<8, T> {
  using type = ConstRefTuple<T, T, T, T, T, T, T, T>;
};
746
// Maps a value count and an element type to the VolatilePtrTuple pointing at
// that many values of the type, e.g. MakeVolatilePtrTuple<2, T>::type is
// VolatilePtrTuple<T, T>. Only counts 1 through 8 are specialized; the
// primary template is left undefined.
template <int NumVals, typename Type>
struct MakeVolatilePtrTuple;

template <typename T>
struct MakeVolatilePtrTuple<1, T> {
  using type = VolatilePtrTuple<T>;
};

template <typename T>
struct MakeVolatilePtrTuple<2, T> {
  using type = VolatilePtrTuple<T, T>;
};

template <typename T>
struct MakeVolatilePtrTuple<3, T> {
  using type = VolatilePtrTuple<T, T, T>;
};

template <typename T>
struct MakeVolatilePtrTuple<4, T> {
  using type = VolatilePtrTuple<T, T, T, T>;
};

template <typename T>
struct MakeVolatilePtrTuple<5, T> {
  using type = VolatilePtrTuple<T, T, T, T, T>;
};

template <typename T>
struct MakeVolatilePtrTuple<6, T> {
  using type = VolatilePtrTuple<T, T, T, T, T, T>;
};

template <typename T>
struct MakeVolatilePtrTuple<7, T> {
  using type = VolatilePtrTuple<T, T, T, T, T, T, T>;
};

template <typename T>
struct MakeVolatilePtrTuple<8, T> {
  using type = VolatilePtrTuple<T, T, T, T, T, T, T, T>;
};
789
// Utility definitions. Currently only used with LocalTuple

// Recursive element-wise binary operation on LocalTuples.
//
// TupleBinaryOp<N, F, Ts...>::apply first recurses into the N - 1 case and
// then stores func(lhs value N - 1, rhs value N - 1) into result value N - 1,
// so elements are processed in ascending index order.
//
//   func    callable invoked with one value from lhs and one from rhs
//   lhs     left operands
//   rhs     right operands
//   result  tuple receiving the per-element results
template <int idx, typename BinaryFunc, typename... DataTypes>
struct TupleBinaryOp {
  static __inline__ __device__ void apply(
      BinaryFunc func,
      const LocalTuple<DataTypes...>& lhs,
      const LocalTuple<DataTypes...>& rhs,
      LocalTuple<DataTypes...>& result) {
    TupleBinaryOp<idx - 1, BinaryFunc, DataTypes...>::apply(
        func, lhs, rhs, result);
    // The tuples have dependent types, so the `template` keyword is needed
    // for `val<...>` to be parsed as a member template call.
    result.template val<idx - 1>(0) =
        func(lhs.template val<idx - 1>(0), rhs.template val<idx - 1>(0));
  }
};

// Recursion terminator: no elements left to process.
template <typename BinaryFunc, typename... DataTypes>
struct TupleBinaryOp<0, BinaryFunc, DataTypes...> {
  static __inline__ __device__ void apply(
      BinaryFunc func,
      const LocalTuple<DataTypes...>& lhs,
      const LocalTuple<DataTypes...>& rhs,
      LocalTuple<DataTypes...>& result) {}
};
813
// Applies func element by element to lhs and rhs and returns the results as
// a new LocalTuple. The output starts as a copy of lhs and doubles as the
// left operand, so each element is overwritten with func(element, rhs value).
template <typename BinaryFunc, typename... DataTypes>
__inline__ __device__ LocalTuple<DataTypes...> apply(
    BinaryFunc func,
    const LocalTuple<DataTypes...>& lhs,
    const LocalTuple<DataTypes...>& rhs) {
  LocalTuple<DataTypes...> out = lhs;
  using Op = TupleBinaryOp<sizeof...(DataTypes), BinaryFunc, DataTypes...>;
  Op::apply(func, out, rhs, out);
  return out;
}
824
// Element-wise logical AND of two LocalTuples.
template <typename... BoolTypes>
__inline__ __device__ LocalTuple<BoolTypes...> operator&&(
    const LocalTuple<BoolTypes...>& lhs,
    const LocalTuple<BoolTypes...>& rhs) {
  auto logical_and = [](bool a, bool b) { return a && b; };
  return apply(logical_and, lhs, rhs);
}
831
// Logical AND of a single bool with every element of a LocalTuple. The
// scalar is first broadcast into a tuple of the same shape.
template <typename... BoolTypes>
__inline__ __device__ LocalTuple<BoolTypes...> operator&&(
    bool lhs,
    const LocalTuple<BoolTypes...>& rhs) {
  LocalTuple<BoolTypes...> broadcast_lhs;
  setTuple(broadcast_lhs, lhs);
  return broadcast_lhs && rhs;
}
840
// Logical AND of every element of a LocalTuple with a single bool. The
// scalar is first broadcast into a tuple of the same shape.
template <typename... BoolTypes>
__inline__ __device__ LocalTuple<BoolTypes...> operator&&(
    const LocalTuple<BoolTypes...>& lhs,
    bool rhs) {
  LocalTuple<BoolTypes...> broadcast_rhs;
  setTuple(broadcast_rhs, rhs);
  return lhs && broadcast_rhs;
}
849)";
850
851} // namespace nvfuser_resources
852