1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/tuple.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
4 | namespace nvfuser_resources { |
5 | |
6 | constexpr const char* tuple_cu = R"( |
// Fixed-arity, std::tuple-like aggregate holding its elements as members
// val0, val1, ... Only explicit specializations for one through eight element
// types exist; the primary template itself is never defined.
template <typename... ElementTypes>
struct Tuple;
10 | |
// Helper for the Tuple specializations' operator+=: advances member val<idx>
// by `offset` elements. Expects T<idx>, val<idx> and `offset` to be in scope
// at the expansion site, and static_asserts that T<idx> is a pointer type.
#define TUPLE_INCREMENT_PTR(idx)                                      \
  do {                                                                \
    val##idx += offset;                                               \
    static_assert(                                                    \
        IsPointerType<T##idx>::value, "Invalid for non-pointer types"); \
  } while (false)
17 | |
// Single-element Tuple.
template <typename T0>
struct Tuple<T0> {
  T0 val0;

  Tuple() = default;

  __device__ Tuple(T0 v0)
      : val0(v0) {}

  // Advances the stored pointer by `offset` elements. Instantiating this for
  // a non-pointer element type fails the static_assert in the macro.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
  }
};
31 | |
// Two-element Tuple.
template <typename T0, typename T1>
struct Tuple<T0, T1> {
  T0 val0;
  T1 val1;

  Tuple() = default;

  __device__ Tuple(T0 v0, T1 v1)
      : val0(v0),
        val1(v1) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
  }
};
47 | |
// Three-element Tuple.
template <typename T0, typename T1, typename T2>
struct Tuple<T0, T1, T2> {
  T0 val0;
  T1 val1;
  T2 val2;

  Tuple() = default;

  __device__ Tuple(T0 v0, T1 v1, T2 v2)
      : val0(v0),
        val1(v1),
        val2(v2) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
  }
};
66 | |
// Four-element Tuple.
template <typename T0, typename T1, typename T2, typename T3>
struct Tuple<T0, T1, T2, T3> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;

  Tuple() = default;

  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3)
      : val0(v0),
        val1(v1),
        val2(v2),
        val3(v3) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
  }
};
87 | |
// Five-element Tuple.
template <typename T0, typename T1, typename T2, typename T3, typename T4>
struct Tuple<T0, T1, T2, T3, T4> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;

  Tuple() = default;

  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4)
      : val0(v0),
        val1(v1),
        val2(v2),
        val3(v3),
        val4(v4) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
  }
};
110 | |
// Six-element Tuple.
template <typename T0, typename T1, typename T2,
          typename T3, typename T4, typename T5>
struct Tuple<T0, T1, T2, T3, T4, T5> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;
  T5 val5;

  Tuple() = default;

  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5)
      : val0(v0), val1(v1), val2(v2), val3(v3), val4(v4), val5(v5) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
    TUPLE_INCREMENT_PTR(5);
  }
};
146 | |
// Seven-element Tuple.
template <typename T0, typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6>
struct Tuple<T0, T1, T2, T3, T4, T5, T6> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;
  T5 val5;
  T6 val6;

  Tuple() = default;

  __device__ Tuple(T0 v0, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6)
      : val0(v0), val1(v1), val2(v2), val3(v3),
        val4(v4), val5(v5), val6(v6) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
    TUPLE_INCREMENT_PTR(5);
    TUPLE_INCREMENT_PTR(6);
  }
};
193 | |
// Eight-element Tuple (the largest arity provided).
template <typename T0, typename T1, typename T2, typename T3,
          typename T4, typename T5, typename T6, typename T7>
struct Tuple<T0, T1, T2, T3, T4, T5, T6, T7> {
  T0 val0;
  T1 val1;
  T2 val2;
  T3 val3;
  T4 val4;
  T5 val5;
  T6 val6;
  T7 val7;

  Tuple() = default;

  __device__ Tuple(
      T0 v0, T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7)
      : val0(v0), val1(v1), val2(v2), val3(v3),
        val4(v4), val5(v5), val6(v6), val7(v7) {}

  // Advances every stored pointer by `offset` elements. Instantiating this
  // when any element type is not a pointer fails a static_assert.
  __device__ void operator+=(nvfuser_index_t offset) {
    TUPLE_INCREMENT_PTR(0);
    TUPLE_INCREMENT_PTR(1);
    TUPLE_INCREMENT_PTR(2);
    TUPLE_INCREMENT_PTR(3);
    TUPLE_INCREMENT_PTR(4);
    TUPLE_INCREMENT_PTR(5);
    TUPLE_INCREMENT_PTR(6);
    TUPLE_INCREMENT_PTR(7);
  }
};
245 | |
// TUPLE_INCREMENT_PTR is only needed by the Tuple specializations; undefine it
// so the helper macro does not leak into the code that follows.
#undef TUPLE_INCREMENT_PTR
247 | |
// Compile-time accessor for Tuple members: get<i>()(t) returns a reference to
// t.val<i>. Explicit specializations exist for indices 0-7 only; the primary
// template is declared but never defined, so any other index fails to compile.
template <int idx>
struct get;

// The call operators are const-qualified because the functor is stateless
// and is normally invoked on a temporary, e.g. get<0>()(vals). Overload
// resolution between the two operators is unchanged: both carry the same
// implicit-object qualifier, so the const-argument overload is still picked
// for const tuples.
#define DEFINE_TUPLE_GET(idx)                                     \
  template <>                                                     \
  struct get<idx> {                                               \
    template <typename Tuple>                                     \
    __device__ auto& operator()(Tuple& vals) const {              \
      return vals.val##idx;                                       \
    }                                                             \
    template <typename Tuple>                                     \
    __device__ const auto& operator()(const Tuple& vals) const {  \
      return vals.val##idx;                                       \
    }                                                             \
  };

// The macro expansion already ends with the struct's closing `};`, so the
// invocations take no trailing semicolon (it would be a stray empty
// declaration at namespace scope).
DEFINE_TUPLE_GET(0)
DEFINE_TUPLE_GET(1)
DEFINE_TUPLE_GET(2)
DEFINE_TUPLE_GET(3)
DEFINE_TUPLE_GET(4)
DEFINE_TUPLE_GET(5)
DEFINE_TUPLE_GET(6)
DEFINE_TUPLE_GET(7)
#undef DEFINE_TUPLE_GET
274 | |
// Forward declarations: the tuple classes that follow call these helpers
// from their constructors and assignment operators before the definitions
// appear. Default arguments live here, on the first declaration only.

// Element-wise copy of `src` into `dst`; both offsets are forwarded to the
// tuples' val() accessors.
template <typename DstType, typename SrcType>
__inline__ __device__ static void copyTuple(
    DstType& dst, nvfuser_index_t dst_offset,
    const SrcType& src, nvfuser_index_t src_offset = 0);

// As above, with the destination offset fixed to 0.
template <typename DstType, typename SrcType>
__inline__ __device__ static void copyTuple(
    DstType& dst, const SrcType& src, nvfuser_index_t src_offset = 0);

// Assigns `src` to every element of `dst`.
template <typename DstType>
__inline__ __device__ static void setTuple(
    DstType& dst, typename DstType::template ValType<0> src);
292 | |
// Tuple that stores its elements by value in a Tuple<Types...>.
template <typename... Types>
class LocalTuple {
 public:
  static constexpr int num_vals = sizeof...(Types);
  using ValTypes = TypeList<Types...>;

  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  LocalTuple() = default;

  __device__ explicit LocalTuple(Types... values) : vals_(values...) {}

  __device__ LocalTuple(const LocalTuple& src) : vals_(src.vals_) {}

  // Element-wise conversion from another tuple class template taking only
  // type parameters with the same element types (e.g. RefTuple).
  template <template <typename...> typename TupleType>
  __device__ LocalTuple(const TupleType<Types...>& src) {
    copyTuple(*this, src);
  }

  // Both assignment operators copy element by element through copyTuple.
  template <template <typename...> typename TupleType>
  __device__ LocalTuple& operator=(const TupleType<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  __device__ LocalTuple& operator=(const LocalTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Reference to element `val_idx`. `ptr_offset` is ignored; it only exists
  // so the signature matches PtrTupleBase::val.
  template <int val_idx>
  __device__ auto& val(nvfuser_index_t ptr_offset = 0) {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

  // Const counterpart of the accessor above.
  template <int val_idx>
  __device__ const auto& val(nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

 private:
  Tuple<Types...> vals_;
};
339 | |
// Tuple of pointers (one Types* per element). val<i>(off) returns a reference
// to the off-th object pointed to by the i-th pointer. When is_volatile is
// true, elements are accessed through MaybeVolatile<T, true>::type.
//
// Copy construction copies the pointers (shallow); copy assignment copies the
// pointed-to elements (deep).
template <bool is_volatile, typename... Types>
class PtrTupleBase {
 public:
  static constexpr int num_vals = sizeof...(Types);
  using ValTypes = TypeList<Types...>;

  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  template <int val_idx>
  using TypeIMaybeVolatile = typename MaybeVolatile<
      typename TypeSelector<val_idx, Types...>::type,
      is_volatile>::type;

  __device__ PtrTupleBase(Types*... pointers) : vals_(pointers...) {}

  __device__ PtrTupleBase(const PtrTupleBase& src) : vals_(src.vals_) {}

  // Deep copy: writes each pointed-to element of `rhs` through this tuple's
  // pointers.
  __device__ PtrTupleBase& operator=(
      const PtrTupleBase<is_volatile, Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  template <template <typename...> typename TupleType>
  __device__ PtrTupleBase& operator=(const TupleType<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  template <int val_idx>
  __device__ TypeIMaybeVolatile<val_idx>& val(nvfuser_index_t ptr_offset = 0) {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    auto* base = (TypeIMaybeVolatile<val_idx>*)get<val_idx>()(vals_);
    return base[ptr_offset];
  }

  template <int val_idx>
  __device__ const TypeIMaybeVolatile<val_idx>& val(
      nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    auto* base = (TypeIMaybeVolatile<val_idx>*)get<val_idx>()(vals_);
    return base[ptr_offset];
  }

  // Advances every stored pointer by `ptr_offset` elements.
  __device__ void operator+=(nvfuser_index_t ptr_offset) {
    vals_ += ptr_offset;
  }

 private:
  Tuple<Types*...> vals_;
};
389 | |
// Tuple of (non-const) references to existing objects. Copy construction
// copies the references, so both tuples alias the same objects. Assignment
// writes through the references, element by element, via copyTuple.
template <typename... Types>
class RefTuple {
 public:
  static constexpr int num_vals = sizeof...(Types);
  using ValTypes = TypeList<Types...>;

  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  __device__ RefTuple(Types&... refs) : vals_(refs...) {}

  __device__ RefTuple(const RefTuple& src) : vals_(src.vals_) {}

  // NOTE(review): vals_ holds references, so Tuple's defaulted default
  // constructor is deleted. This constructor does not initialize vals_, so
  // it appears ill-formed if ever instantiated. Kept for interface
  // compatibility; TODO confirm and consider removing.
  template <template <typename...> typename TupleType>
  __device__ RefTuple(const TupleType<Types...>& src) {
    copyTuple(*this, src);
  }

  __device__ RefTuple& operator=(const RefTuple<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  template <template <typename...> typename TupleType>
  __device__ RefTuple& operator=(const TupleType<Types...>& rhs) {
    copyTuple(*this, rhs);
    return *this;
  }

  // Referenced object for element `val_idx`. `ptr_offset` is ignored; it only
  // exists so the signature matches PtrTupleBase::val.
  template <int val_idx>
  __device__ auto& val(nvfuser_index_t ptr_offset = 0) {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

  template <int val_idx>
  __device__ const auto& val(nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

 private:
  Tuple<Types&...> vals_;
};
433 | |
// Compile-time recursive element-wise copy. Copies elements [0, num_vals) of
// `src` into `dst` in index order. The offsets are forwarded to each tuple's
// val() accessor; only PtrTupleBase uses them.
template <typename DstType, typename SrcType, int num_vals>
struct TupleCopy {
  __inline__ __device__ static void copy(
      DstType& dst,
      nvfuser_index_t dst_offset,
      const SrcType& src,
      nvfuser_index_t src_offset) {
    static_assert(
        IsSameType<typename DstType::ValTypes, typename SrcType::ValTypes>::
            value,
        "Invalid value types");
    TupleCopy<DstType, SrcType, num_vals - 1>::copy(
        dst, dst_offset, src, src_offset);
    // dst and src have dependent types. The `template` disambiguator is
    // required for val<...> to parse as a member-template call rather than
    // as `<` / `>` comparisons; strict front ends reject the bare form.
    dst.template val<num_vals - 1>(dst_offset) =
        src.template val<num_vals - 1>(src_offset);
  }
};
450 | |
// Recursion terminator for TupleCopy: no elements left to copy.
template <typename DstType, typename SrcType>
struct TupleCopy<DstType, SrcType, 0> {
  __inline__ __device__ static void copy(
      DstType&, nvfuser_index_t, const SrcType&, nvfuser_index_t) {}
};
459 | |
// Element-wise tuple copy: assigns src.val<i>(src_offset) to
// dst.val<i>(dst_offset) for every i, in index order. Both tuples must have
// identical value type lists. src_offset defaults to 0 through the forward
// declaration, so no default argument may be repeated here.
template <typename DstType, typename SrcType>
__inline__ __device__ static void copyTuple(
    DstType& dst,
    nvfuser_index_t dst_offset,
    const SrcType& src,
    nvfuser_index_t src_offset) {
  static_assert(
      IsSameType<typename DstType::ValTypes, typename SrcType::ValTypes>::value,
      "Invalid value types");
  TupleCopy<DstType, SrcType, DstType::num_vals>::copy(
      dst, dst_offset, src, src_offset);
}
472 | |
// Overload without a destination offset; equivalent to
// copyTuple(dst, 0, src, src_offset). src_offset defaults to 0 through the
// forward declaration.
template <typename DstType, typename SrcType>
__inline__ __device__ static void copyTuple(
    DstType& dst,
    const SrcType& src,
    nvfuser_index_t src_offset) {
  copyTuple<DstType, SrcType>(dst, 0, src, src_offset);
}
480 | |
// Compile-time recursive broadcast: assigns `src` to dst.val<i>(dst_offset)
// for every i in [0, num_vals), in index order. Every element type must equal
// ValType<0>.
template <typename DstType, int num_vals>
struct TupleSet {
  __inline__ __device__ static void set(
      DstType& dst,
      nvfuser_index_t dst_offset,
      typename DstType::template ValType<0> src) {
    static_assert(
        IsSameType<
            typename DstType::template ValType<num_vals - 1>,
            typename DstType::template ValType<0>>::value,
        "Invalid value types");
    TupleSet<DstType, num_vals - 1>::set(dst, dst_offset, src);
    // dst has a dependent type, so the `template` disambiguator is required
    // for val<...> to parse as a member-template call.
    dst.template val<num_vals - 1>(dst_offset) = src;
  }
};
496 | |
// Recursion terminator for TupleSet: no elements left to assign.
template <typename DstType>
struct TupleSet<DstType, 0> {
  __inline__ __device__ static void set(
      DstType&, nvfuser_index_t, typename DstType::template ValType<0>) {}
};
504 | |
// Assigns `src` to dst.val<i>(dst_offset) for every element i. All of dst's
// element types must equal ValType<0>; TupleSet checks this.
template <typename DstType>
__inline__ __device__ static void setTuple(
    DstType& dst,
    nvfuser_index_t dst_offset,
    typename DstType::template ValType<0> src) {
  TupleSet<DstType, DstType::num_vals>::set(dst, dst_offset, src);
}
512 | |
// Overload without an offset; equivalent to setTuple(dst, 0, src).
template <typename DstType>
__inline__ __device__ static void setTuple(
    DstType& dst,
    typename DstType::template ValType<0> src) {
  setTuple(dst, 0, src);
}
519 | |
// Compile-time recursive predicated copy. For each i in [0, num_vals), in
// index order, copies src.val<i>(src_offset) into dst.val<i>(dst_offset)
// only when pred.val<i>(0) is true. Every predicate element must be bool.
template <typename DstType, typename SrcType, typename PredType, int num_vals>
struct PredicatedTupleCopy {
  __inline__ __device__ static void copy(
      DstType& dst,
      nvfuser_index_t dst_offset,
      const SrcType& src,
      nvfuser_index_t src_offset,
      const PredType& pred) {
    static_assert(
        IsSameType<typename PredType::template ValType<num_vals - 1>, bool>::
            value,
        "Invalid predicate type");
    PredicatedTupleCopy<DstType, SrcType, PredType, num_vals - 1>::copy(
        dst, dst_offset, src, src_offset, pred);
    // pred, dst and src all have dependent types, so the `template`
    // disambiguator is required for val<...> to parse as a member-template
    // call.
    if (pred.template val<num_vals - 1>(0)) {
      dst.template val<num_vals - 1>(dst_offset) =
          src.template val<num_vals - 1>(src_offset);
    }
  }
};
539 | |
// Recursion terminator for PredicatedTupleCopy: no elements left.
template <typename DstType, typename SrcType, typename PredType>
struct PredicatedTupleCopy<DstType, SrcType, PredType, 0> {
  __inline__ __device__ static void copy(
      DstType&,
      nvfuser_index_t,
      const SrcType&,
      nvfuser_index_t,
      const PredType&) {}
};
549 | |
// Element-wise predicated copy: for each i where pred.val<i>(0) is true,
// copies src.val<i>(src_offset) into dst.val<i>(dst_offset). dst and src
// must have identical value types, and pred must have one bool per element.
template <typename DstType, typename SrcType, typename PredType>
__inline__ __device__ static void copyTupleIf(
    DstType& dst,
    nvfuser_index_t dst_offset,
    const SrcType& src,
    nvfuser_index_t src_offset,
    const PredType& pred) {
  static_assert(
      IsSameType<typename DstType::ValTypes, typename SrcType::ValTypes>::value,
      "Invalid value types");
  static_assert(
      PredType::num_vals == DstType::num_vals, "Invalid predicate type");
  PredicatedTupleCopy<DstType, SrcType, PredType, DstType::num_vals>::copy(
      dst, dst_offset, src, src_offset, pred);
}
565 | |
// Predicated copy with the destination offset fixed to 0.
template <typename DstType, typename SrcType, typename PredType>
__inline__ __device__ static void copyTupleIf(
    DstType& dst,
    const SrcType& src,
    nvfuser_index_t src_offset,
    const PredType& pred) {
  copyTupleIf(dst, 0, src, src_offset, pred);
}
574 | |
// Predicated copy with both the destination and source offsets fixed to 0.
template <typename DstType, typename SrcType, typename PredType>
__inline__ __device__ static void copyTupleIf(
    DstType& dst,
    const SrcType& src,
    const PredType& pred) {
  copyTupleIf(dst, 0, src, 0, pred);
}
582 | |
583 | )" |
584 | R"( |
// Tuple of const references to existing objects.
// Open question: could a single generic RefTuple cover both the const and the
// non-const cases?
template <typename... Types>
class ConstRefTuple {
 public:
  static constexpr int num_vals = sizeof...(Types);
  using ValTypes = TypeList<Types...>;

  // Per-element type accessor, matching LocalTuple, PtrTupleBase and
  // RefTuple. Generic helpers such as PredicatedTupleCopy query
  // `typename T::template ValType<i>`, so without this alias a ConstRefTuple
  // could not be used with them (e.g. as a predicate tuple).
  template <int idx>
  using ValType = typename TypeSelector<idx, Types...>::type;

  __device__ ConstRefTuple(const Types&... args) : vals_(args...) {}

  __device__ ConstRefTuple(const ConstRefTuple& other) : vals_(other.vals_) {}

  // NOTE(review): vals_ holds references, so Tuple's defaulted default
  // constructor is deleted, and copyTuple would have to assign through const
  // references. This constructor therefore appears ill-formed if ever
  // instantiated. Kept for interface compatibility; TODO confirm and
  // consider removing.
  template <template <typename...> typename TupleType>
  __device__ ConstRefTuple(const TupleType<Types...>& other) {
    copyTuple(*this, other);
  }

  // Const reference to element `val_idx`. `ptr_offset` is ignored; it only
  // exists so the signature matches PtrTupleBase::val.
  template <int val_idx>
  __device__ const auto& val(nvfuser_index_t ptr_offset = 0) const {
    static_assert(val_idx < num_vals, "Out-of-range value index");
    return get<val_idx>()(vals_);
  }

 private:
  Tuple<const Types&...> vals_;
};
610 | |
// Pointer tuple whose elements are accessed through plain (non-volatile) types.
template <typename... Ts>
using PtrTuple = PtrTupleBase</*is_volatile=*/false, Ts...>;

// Pointer tuple whose elements are accessed through volatile-qualified types.
template <typename... Ts>
using VolatilePtrTuple = PtrTupleBase</*is_volatile=*/true, Ts...>;
616 | |
// MakeLocalTuple<N, T>::type is LocalTuple<T, ..., T> with N copies of T.
// Only N = 1..8 is provided; the primary template is left undefined.
template <int NumVals, typename Type>
struct MakeLocalTuple;

template <typename T> struct MakeLocalTuple<1, T> { using type = LocalTuple<T>; };
template <typename T> struct MakeLocalTuple<2, T> { using type = LocalTuple<T, T>; };
template <typename T> struct MakeLocalTuple<3, T> { using type = LocalTuple<T, T, T>; };
template <typename T> struct MakeLocalTuple<4, T> { using type = LocalTuple<T, T, T, T>; };
template <typename T> struct MakeLocalTuple<5, T> { using type = LocalTuple<T, T, T, T, T>; };
template <typename T> struct MakeLocalTuple<6, T> { using type = LocalTuple<T, T, T, T, T, T>; };
template <typename T> struct MakeLocalTuple<7, T> { using type = LocalTuple<T, T, T, T, T, T, T>; };
template <typename T> struct MakeLocalTuple<8, T> { using type = LocalTuple<T, T, T, T, T, T, T, T>; };
660 | |
// MakeRefTuple<N, T>::type is RefTuple<T, ..., T> with N copies of T.
// Only N = 1..8 is provided; the primary template is left undefined.
template <int NumVals, typename Type>
struct MakeRefTuple;

template <typename T> struct MakeRefTuple<1, T> { using type = RefTuple<T>; };
template <typename T> struct MakeRefTuple<2, T> { using type = RefTuple<T, T>; };
template <typename T> struct MakeRefTuple<3, T> { using type = RefTuple<T, T, T>; };
template <typename T> struct MakeRefTuple<4, T> { using type = RefTuple<T, T, T, T>; };
template <typename T> struct MakeRefTuple<5, T> { using type = RefTuple<T, T, T, T, T>; };
template <typename T> struct MakeRefTuple<6, T> { using type = RefTuple<T, T, T, T, T, T>; };
template <typename T> struct MakeRefTuple<7, T> { using type = RefTuple<T, T, T, T, T, T, T>; };
template <typename T> struct MakeRefTuple<8, T> { using type = RefTuple<T, T, T, T, T, T, T, T>; };
703 | |
// MakeConstRefTuple<N, T>::type is ConstRefTuple<T, ..., T> with N copies of
// T. Only N = 1..8 is provided; the primary template is left undefined.
template <int NumVals, typename Type>
struct MakeConstRefTuple;

template <typename T> struct MakeConstRefTuple<1, T> { using type = ConstRefTuple<T>; };
template <typename T> struct MakeConstRefTuple<2, T> { using type = ConstRefTuple<T, T>; };
template <typename T> struct MakeConstRefTuple<3, T> { using type = ConstRefTuple<T, T, T>; };
template <typename T> struct MakeConstRefTuple<4, T> { using type = ConstRefTuple<T, T, T, T>; };
template <typename T> struct MakeConstRefTuple<5, T> { using type = ConstRefTuple<T, T, T, T, T>; };
template <typename T> struct MakeConstRefTuple<6, T> { using type = ConstRefTuple<T, T, T, T, T, T>; };
template <typename T> struct MakeConstRefTuple<7, T> { using type = ConstRefTuple<T, T, T, T, T, T, T>; };
template <typename T> struct MakeConstRefTuple<8, T> { using type = ConstRefTuple<T, T, T, T, T, T, T, T>; };
746 | |
// MakeVolatilePtrTuple<N, T>::type is VolatilePtrTuple<T, ..., T> with N
// copies of T. Only N = 1..8 is provided; the primary template is left
// undefined.
template <int NumVals, typename Type>
struct MakeVolatilePtrTuple;

template <typename T> struct MakeVolatilePtrTuple<1, T> { using type = VolatilePtrTuple<T>; };
template <typename T> struct MakeVolatilePtrTuple<2, T> { using type = VolatilePtrTuple<T, T>; };
template <typename T> struct MakeVolatilePtrTuple<3, T> { using type = VolatilePtrTuple<T, T, T>; };
template <typename T> struct MakeVolatilePtrTuple<4, T> { using type = VolatilePtrTuple<T, T, T, T>; };
template <typename T> struct MakeVolatilePtrTuple<5, T> { using type = VolatilePtrTuple<T, T, T, T, T>; };
template <typename T> struct MakeVolatilePtrTuple<6, T> { using type = VolatilePtrTuple<T, T, T, T, T, T>; };
template <typename T> struct MakeVolatilePtrTuple<7, T> { using type = VolatilePtrTuple<T, T, T, T, T, T, T>; };
template <typename T> struct MakeVolatilePtrTuple<8, T> { using type = VolatilePtrTuple<T, T, T, T, T, T, T, T>; };
789 | |
// Utility definitions. Currently only used with LocalTuple.

// Compile-time recursive element-wise binary operation over LocalTuples.
// For each i in [0, idx), in index order, it performs
//   result.val<i>(0) = func(lhs.val<i>(0), rhs.val<i>(0)).
template <int idx, typename BinaryFunc, typename... DataTypes>
struct TupleBinaryOp {
  static __inline__ __device__ void apply(
      BinaryFunc func,
      const LocalTuple<DataTypes...>& lhs,
      const LocalTuple<DataTypes...>& rhs,
      LocalTuple<DataTypes...>& result) {
    TupleBinaryOp<idx - 1, BinaryFunc, DataTypes...>::apply(
        func, lhs, rhs, result);
    // The tuples' types depend on DataTypes, so the `template` disambiguator
    // is required for val<...> to parse as a member-template call rather
    // than as `<` / `>` comparisons.
    result.template val<idx - 1>(0) =
        func(lhs.template val<idx - 1>(0), rhs.template val<idx - 1>(0));
  }
};
804 | |
// Recursion terminator for TupleBinaryOp: no elements left to process.
template <typename BinaryFunc, typename... DataTypes>
struct TupleBinaryOp<0, BinaryFunc, DataTypes...> {
  static __inline__ __device__ void apply(
      BinaryFunc,
      const LocalTuple<DataTypes...>&,
      const LocalTuple<DataTypes...>&,
      LocalTuple<DataTypes...>&) {}
};
813 | |
// Applies `func` element-wise to two LocalTuples with identical element types
// and returns a new LocalTuple holding the per-element results.
template <typename BinaryFunc, typename... DataTypes>
__inline__ __device__ LocalTuple<DataTypes...> apply(
    BinaryFunc func,
    const LocalTuple<DataTypes...>& lhs,
    const LocalTuple<DataTypes...>& rhs) {
  constexpr int kNumVals = sizeof...(DataTypes);
  // Start from a copy of lhs; every element is then overwritten in index order.
  LocalTuple<DataTypes...> out(lhs);
  TupleBinaryOp<kNumVals, BinaryFunc, DataTypes...>::apply(func, lhs, rhs, out);
  return out;
}
824 | |
// Element-wise logical AND of two LocalTuples. Each element is converted to
// bool, combined, and stored back into the result tuple.
template <typename... BoolTypes>
__inline__ __device__ LocalTuple<BoolTypes...> operator&&(
    const LocalTuple<BoolTypes...>& lhs,
    const LocalTuple<BoolTypes...>& rhs) {
  auto logical_and = [](bool a, bool b) { return a && b; };
  return apply(logical_and, lhs, rhs);
}
831 | |
// Scalar-on-the-left AND: broadcasts `lhs` to every element, then ANDs with
// `rhs` element-wise. All element types must be identical (setTuple requires
// this).
template <typename... BoolTypes>
__inline__ __device__ LocalTuple<BoolTypes...> operator&&(
    bool lhs,
    const LocalTuple<BoolTypes...>& rhs) {
  LocalTuple<BoolTypes...> broadcast_lhs;
  setTuple(broadcast_lhs, lhs);
  return broadcast_lhs && rhs;
}
840 | |
// Scalar-on-the-right AND: broadcasts `rhs` to every element, then ANDs with
// `lhs` element-wise. All element types must be identical (setTuple requires
// this).
template <typename... BoolTypes>
__inline__ __device__ LocalTuple<BoolTypes...> operator&&(
    const LocalTuple<BoolTypes...>& lhs,
    bool rhs) {
  LocalTuple<BoolTypes...> broadcast_rhs;
  setTuple(broadcast_rhs, rhs);
  return lhs && broadcast_rhs;
}
849 | )" ; |
850 | |
851 | } // namespace nvfuser_resources |
852 | |