1// Generated from "/code/pytorch/third_party/nvfuser/runtime/fused_welford_helper.cu"
2// 2023-02-12 08:01:26
3
4namespace nvfuser_resources {
5
6constexpr const char* fused_welford_helper_cu = R"(
7namespace fused_reduction {
8
// WelfordTripletTuple groups the three Welford statistics avg, var and N.
// Each statistic is stored in the tuple type MakeTuple<NumVals, T>::type.
//
// Template parameters:
// - NumVals: Size parameter forwarded to MakeTuple (presumably the number
//   of values held by each tuple)
// - DataTypeT: Element type of the avg and var tuples
// - IndexTypeT: Element type of the N tuple
// - MakeTuple: Template template parameter that defines the tuple types
//   through its nested ::type (e.g., MakeLocalTuple)
template <
    int NumVals,
    typename DataTypeT,
    typename IndexTypeT,
    template <int, typename>
    typename MakeTuple>
struct WelfordTripletTuple {
  static constexpr int num_vals = NumVals;

  using DataType = DataTypeT;
  using IndexType = IndexTypeT;

  // Tuple types produced by the MakeTuple factory for each element type.
  using DataTuple = typename MakeTuple<NumVals, DataType>::type;
  using IndexTuple = typename MakeTuple<NumVals, IndexType>::type;

  // Declaration order (avg, var, N) is also the member initialization order.
  DataTuple avg;
  DataTuple var;
  IndexTuple N;

  // Initializes each member tuple from the corresponding argument.
  WelfordTripletTuple(
      const DataTuple& init_avg,
      const DataTuple& init_var,
      const IndexTuple& init_N)
      : avg(init_avg), var(init_var), N(init_N) {}
};
39
// Convenience aliases that instantiate WelfordTripletTuple with each of the
// tuple factories: MakeLocalTuple, MakeRefTuple, MakeConstRefTuple and
// MakeVolatilePtrTuple. Judging by the factory names, these presumably yield
// by-value, reference, const-reference and volatile-pointer tuples
// respectively (see the factory definitions to confirm).
//
// All four aliases use the same template parameter names: NumVals, DataType
// and IndexType.
template <int NumVals, typename DataType, typename IndexType>
using LocalWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataType, IndexType, MakeLocalTuple>;

template <int NumVals, typename DataType, typename IndexType>
using RefWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataType, IndexType, MakeRefTuple>;

template <int NumVals, typename DataType, typename IndexType>
using ConstRefWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataType, IndexType, MakeConstRefTuple>;

template <int NumVals, typename DataType, typename IndexType>
using VolatilePtrWelfordTripletTuple =
    WelfordTripletTuple<NumVals, DataType, IndexType, MakeVolatilePtrTuple>;
55
// Advances the pointer offsets of a WelfordTripletTuple by adding offset to
// each of its avg, var and N member tuples. Only valid when the tuple values
// are pointers; the per-tuple += used here must be provided for the member
// tuple types elsewhere.
//
// The trailing return type restricts this template (via SFINAE) to types
// that actually have avg, var and N members. An unconstrained
// template <typename T> operator+=(T&, nvfuser_index_t) would otherwise be
// a candidate for every class-typed `x += offset` expression in this
// namespace or reachable through ADL. As an exact match it can win overload
// resolution there, and it then fails to compile on types without these
// members. For triplet types, overload resolution is unchanged.
template <typename WelfordTripletTupleType>
__inline__ __device__ static auto operator+=(
    WelfordTripletTupleType& triplet,
    nvfuser_index_t offset)
    -> decltype(void(triplet.avg), void(triplet.var), void(triplet.N)) {
  triplet.avg += offset;
  triplet.var += offset;
  triplet.N += offset;
}
66
// Copies the avg, var and N member tuples of src into the matching members
// of dst. dst_offset and src_offset are forwarded unchanged to copyTuple for
// each of the three members. src_offset defaults to 0.
template <typename DstTriplet, typename SrcTriplet>
__inline__ __device__ static void copyWelfordTripletTuple(
    DstTriplet& dst,
    nvfuser_index_t dst_offset,
    const SrcTriplet& src,
    nvfuser_index_t src_offset = 0) {
  // Members are copied one at a time, in declaration order: avg, var, N.
  copyTuple(
      dst.avg, dst_offset,
      src.avg, src_offset);
  copyTuple(
      dst.var, dst_offset,
      src.var, src_offset);
  copyTuple(
      dst.N, dst_offset,
      src.N, src_offset);
}
78
// Overload of copyWelfordTripletTuple without a destination offset: copies
// the avg, var and N member tuples of src into dst, writing dst at offset 0
// and reading src at src_offset (which defaults to 0).
template <typename DstTriplet, typename SrcTriplet>
__inline__ __device__ static void copyWelfordTripletTuple(
    DstTriplet& dst,
    const SrcTriplet& src,
    nvfuser_index_t src_offset = 0) {
  // Delegate to the overload that takes an explicit destination offset.
  constexpr nvfuser_index_t kZeroDstOffset = 0;
  copyWelfordTripletTuple(dst, kZeroDstOffset, src, src_offset);
}
87
// Conditionally copies the avg, var and N member tuples of src into dst.
// pred is passed unchanged to copyTupleIf for each member, so the exact
// predicate semantics (presumably a per-element condition) are those of
// copyTupleIf.
template <typename DstTriplet, typename SrcTriplet, typename Predicate>
__inline__ __device__ static void copyWelfordTripletTupleIf(
    DstTriplet& dst,
    const SrcTriplet& src,
    const Predicate& pred) {
  // Same member order as copyWelfordTripletTuple: avg, var, N.
  copyTupleIf(
      dst.avg, src.avg, pred);
  copyTupleIf(
      dst.var, src.var, pred);
  copyTupleIf(
      dst.N, src.N, pred);
}
98
99} // namespace fused_reduction
100)";
101
102} // namespace nvfuser_resources
103