1 | // Generated from "/code/pytorch/third_party/nvfuser/runtime/fused_welford_helper.cu" |
2 | // 2023-02-12 08:01:26 |
3 | |
4 | namespace nvfuser_resources { |
5 | |
6 | constexpr const char* fused_welford_helper_cu = R"( |
7 | namespace fused_reduction { |
8 | |
// Bundle of the three Welford quantities: avg, var and N. Each quantity
// is stored as its own tuple.
//
// Template parameters:
// - NumVals: Integer argument forwarded to MakeTuple; also exposed as
//   num_vals
// - DataTypeT: Type of avg and var
// - IndexTypeT: Type of N
// - MakeTuple: Template template parameter whose nested ::type defines
//   the tuple types (e.g., MakeLocalTuple)
template <
    int NumVals,
    typename DataTypeT,
    typename IndexTypeT,
    template <int, typename> class MakeTuple>
struct WelfordTripletTuple {
  // Re-export the template arguments so generic code can query them.
  using DataType = DataTypeT;
  using IndexType = IndexTypeT;
  static constexpr int num_vals = NumVals;

  // Tuple types obtained from the MakeTuple factory.
  using DataTuple = typename MakeTuple<NumVals, DataTypeT>::type;
  using IndexTuple = typename MakeTuple<NumVals, IndexTypeT>::type;

  // Members are declared (and therefore initialized) in this order.
  DataTuple avg;
  DataTuple var;
  IndexTuple N;

  // Initializes each member from the corresponding argument.
  WelfordTripletTuple(
      const DataTuple& avg_init,
      const DataTuple& var_init,
      const IndexTuple& n_init)
      : avg(avg_init), var(var_init), N(n_init) {}
};
39 | |
// Convenience aliases that fix the MakeTuple argument of
// WelfordTripletTuple to one of the tuple factories.

// Triplet whose tuples are produced by MakeLocalTuple.
template <int NumVals, typename DataTypeT, typename IndexTypeT>
using LocalWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataTypeT, IndexTypeT, MakeLocalTuple>;

// Triplet whose tuples are produced by MakeRefTuple.
template <int NumVals, typename DataTypeT, typename IndexTypeT>
using RefWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataTypeT, IndexTypeT, MakeRefTuple>;

// Triplet whose tuples are produced by MakeConstRefTuple.
template <int NumVals, typename DataTypeT, typename IndexTypeT>
using ConstRefWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataTypeT, IndexTypeT, MakeConstRefTuple>;

// Triplet whose tuples are produced by MakeVolatilePtrTuple.
template <int NumVals, typename DataTypeT, typename IndexTypeT>
using VolatilePtrWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataTypeT, IndexTypeT, MakeVolatilePtrTuple>;
55 | |
// Advance pointer offsets of a WelfordTripletTuple. Only valid when the
// values are pointer values.
//
// Applies `+= offset` to the avg, var and N members in that order.
template <typename WelfordTripletTupleType>
static __inline__ __device__ void operator+=(
    WelfordTripletTupleType& triplet,
    nvfuser_index_t offset) {
  // The same offset is applied to all three members.
  triplet.avg += offset;
  triplet.var += offset;
  triplet.N += offset;
}
66 | |
// Copy the avg, var and N tuples of src into the matching tuples of dst.
//
// Each member is copied with a separate copyTuple call; dst_offset and
// src_offset are passed through to copyTuple unchanged. src_offset
// defaults to 0.
template <typename DstType, typename SrcType>
static __inline__ __device__ void copyWelfordTripletTuple(
    DstType& dst,
    nvfuser_index_t dst_offset,
    const SrcType& src,
    nvfuser_index_t src_offset = 0) {
  // Member-wise copy, in declaration order: avg, var, N.
  copyTuple(dst.avg, dst_offset, src.avg, src_offset);
  copyTuple(dst.var, dst_offset, src.var, src_offset);
  copyTuple(dst.N, dst_offset, src.N, src_offset);
}
78 | |
// Copy the avg, var and N tuples of src into dst, using a destination
// offset of zero.
//
// Forwards to the four-argument copyWelfordTripletTuple overload.
// src_offset defaults to 0.
template <typename DstType, typename SrcType>
static __inline__ __device__ void copyWelfordTripletTuple(
    DstType& dst,
    const SrcType& src,
    nvfuser_index_t src_offset = 0) {
  // This overload has no destination offset parameter, so pass zero.
  const nvfuser_index_t zero_dst_offset = 0;
  copyWelfordTripletTuple(dst, zero_dst_offset, src, src_offset);
}
87 | |
// Predicated copy of the avg, var and N tuples of src into dst.
//
// Each member is handed to copyTupleIf together with pred; how pred
// gates the copy is determined by copyTupleIf.
template <typename DstType, typename SrcType, typename PredType>
static __inline__ __device__ void copyWelfordTripletTupleIf(
    DstType& dst,
    const SrcType& src,
    const PredType& pred) {
  // Member-wise predicated copy, in declaration order: avg, var, N.
  copyTupleIf(dst.avg, src.avg, pred);
  copyTupleIf(dst.var, src.var, pred);
  copyTupleIf(dst.N, src.N, pred);
}
98 | |
99 | } // namespace fused_reduction |
100 | )" ; |
101 | |
102 | } // namespace nvfuser_resources |
103 | |