// Embedded OpenCL C source for the reference reduction kernel.
// The kernel text is stored as one C string assembled from adjacent raw
// string literals (delimiter `==`), each followed by an explicit "\n".
// NOTE(review): this file looks machine-generated from a .cl source —
// confirm before editing by hand; only comments outside the literals are safe
// to change without altering the kernel text.
1 | namespace dnnl { |
2 | namespace impl { |
3 | namespace gpu { |
4 | namespace ocl { |
// ref_reduction_kernel: NUL-terminated text that defines the helper macros
// below and the `__kernel void ref_reduce(...)` entry point.
// NOTE(review): the license line that reads "http:" appears truncated —
// presumably the URL was removed by a comment stripper; confirm upstream.
5 | const char *ref_reduction_kernel = R"==(/******************************************************************************* )==" "\n" |
6 | R"==(* Copyright 2020-2022 Intel Corporation )==" "\n" |
7 | R"==(* )==" "\n" |
8 | R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n" |
9 | R"==(* you may not use this file except in compliance with the License. )==" "\n" |
10 | R"==(* You may obtain a copy of the License at )==" "\n" |
11 | R"==(* )==" "\n" |
12 | R"==(* http: )==" "\n" |
13 | R"==(* )==" "\n" |
14 | R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n" |
15 | R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n" |
16 | R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n" |
17 | R"==(* See the License for the specific language governing permissions and )==" "\n" |
18 | R"==(* limitations under the License. )==" "\n" |
19 | R"==(*******************************************************************************/ )==" "\n" |
20 | R"==(#include "gpu/ocl/ocl_post_ops.h" )==" "\n" |
21 | R"==(#include "gpu/ocl/ocl_types.h" )==" "\n" |
// INIT_ACC: starting accumulator value. DATA_MIN for IS_MAX, DATA_MAX for
// IS_MIN, DATA_ONE for IS_MUL, DATA_ZERO for every other algorithm.
22 | R"==(#if defined(IS_MAX) )==" "\n" |
23 | R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_MIN) )==" "\n" |
24 | R"==(#elif defined(IS_MIN) )==" "\n" |
25 | R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_MAX) )==" "\n" |
26 | R"==(#elif defined(IS_MUL) )==" "\n" |
27 | R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_ONE) )==" "\n" |
28 | R"==(#else )==" "\n" |
29 | R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_ZERO) )==" "\n" |
30 | R"==(#endif )==" "\n" |
// ACCUMULATE(x, y): per-element combine step. For IS_MAX / IS_MIN the
// max/min built-ins are used when SRC_DT_S8 or SRC_DT_U8 is defined and
// fmax/fmin otherwise. IS_MEAN / IS_SUM add, IS_MUL multiplies, and the
// fallback adds pow(fabs(y), POWER) to the accumulator.
31 | R"==(#if defined(IS_MAX) )==" "\n" |
32 | R"==(#if defined(SRC_DT_S8) || defined(SRC_DT_U8) )==" "\n" |
33 | R"==(#define ACCUMULATE(x, y) max(x, y) )==" "\n" |
34 | R"==(#else )==" "\n" |
35 | R"==(#define ACCUMULATE(x, y) fmax(x, y) )==" "\n" |
36 | R"==(#endif )==" "\n" |
37 | R"==(#elif defined(IS_MIN) )==" "\n" |
38 | R"==(#if defined(SRC_DT_S8) || defined(SRC_DT_U8) )==" "\n" |
39 | R"==(#define ACCUMULATE(x, y) min(x, y) )==" "\n" |
40 | R"==(#else )==" "\n" |
41 | R"==(#define ACCUMULATE(x, y) fmin(x, y) )==" "\n" |
42 | R"==(#endif )==" "\n" |
43 | R"==(#elif defined(IS_MEAN) || defined(IS_SUM) )==" "\n" |
44 | R"==(#define ACCUMULATE(x, y) (x + y) )==" "\n" |
45 | R"==(#elif defined(IS_MUL) )==" "\n" |
46 | R"==(#define ACCUMULATE(x, y) (x * y) )==" "\n" |
47 | R"==(#else )==" "\n" |
48 | R"==(#define ACCUMULATE(x, y) (x + pow(fabs(y), POWER)) )==" "\n" |
49 | R"==(#endif )==" "\n" |
// FINALIZE(x): applied once to the accumulated value after the loops.
// IS_MEAN divides by DIV; IS_LP_MAX / IS_LP_SUM take rootn(..., POWER) of
// fmax(x, EPS) / (x + EPS); IS_P_MAX / IS_P_SUM return fmax(x, EPS) /
// (x + EPS); everything else is the identity.
50 | R"==(#if defined(IS_MEAN) )==" "\n" |
51 | R"==(#define FINALIZE(x) (x / DIV) )==" "\n" |
52 | R"==(#elif defined(IS_LP_MAX) )==" "\n" |
53 | R"==(#define FINALIZE(x) rootn(fmax(x, EPS), POWER) )==" "\n" |
54 | R"==(#elif defined(IS_LP_SUM) )==" "\n" |
55 | R"==(#define FINALIZE(x) rootn(x + EPS, POWER) )==" "\n" |
56 | R"==(#elif defined(IS_P_MAX) )==" "\n" |
57 | R"==(#define FINALIZE(x) fmax(x, EPS) )==" "\n" |
58 | R"==(#elif defined(IS_P_SUM) )==" "\n" |
59 | R"==(#define FINALIZE(x) (x + EPS) )==" "\n" |
60 | R"==(#else )==" "\n" |
61 | R"==(#define FINALIZE(x) (x) )==" "\n" |
62 | R"==(#endif )==" "\n" |
// _SRC_OFF / _DST_OFF: six-argument offset wrappers. NDIMS == 6 forwards all
// six coordinates to OFF_MD; NDIMS == 1 uses x0 directly as the offset; any
// other NDIMS drops the third argument and forwards the remaining five to
// SRC_OFF / DST_OFF (defined outside this text).
63 | R"==(#if NDIMS == 6 )==" "\n" |
64 | R"==(#define _SRC_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(SRC, x0, x1, x2, x3, x4, x5) )==" "\n" |
65 | R"==(#define _DST_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(DST, x0, x1, x2, x3, x4, x5) )==" "\n" |
66 | R"==(#elif NDIMS == 1 )==" "\n" |
67 | R"==(#define _SRC_OFF(x0, x1, x2, x3, x4, x5) (x0) )==" "\n" |
68 | R"==(#define _DST_OFF(x0, x1, x2, x3, x4, x5) (x0) )==" "\n" |
69 | R"==(#else )==" "\n" |
70 | R"==(#define _SRC_OFF(x0, x1, ignore, x3, x4, x5) SRC_OFF(x0, x1, x3, x4, x5) )==" "\n" |
71 | R"==(#define _DST_OFF(x0, x1, ignore, x3, x4, x5) DST_OFF(x0, x1, x3, x4, x5) )==" "\n" |
72 | R"==(#endif )==" "\n" |
// The reduction loop over the third coordinate exists only when NDIMS == 6;
// otherwise ITERATE_OVER_REDUCTION_D2 expands to nothing and D2_OFF is 0.
73 | R"==(#if NDIMS == 6 )==" "\n" |
74 | R"==(#define ITERATE_OVER_REDUCTION_D2 \ )==" "\n" |
75 | R"==(for_(int d2_off = 0; d2_off < REDUCTION_D2; d2_off++) )==" "\n" |
76 | R"==(#define D2_OFF d2_off )==" "\n" |
77 | R"==(#else )==" "\n" |
78 | R"==(#define ITERATE_OVER_REDUCTION_D2 )==" "\n" |
79 | R"==(#define D2_OFF 0 )==" "\n" |
80 | R"==(#endif )==" "\n" |
// _DST_OFF_MODULO_DIM: destination offset with each used coordinate taken
// modulo the matching DST_D* extent. Written as a `({ ... })` statement
// expression that yields ret_val; which of x0..x5 are used, and which DST_D*
// they pair with, depends on NDIMS.
81 | R"==(#define _DST_OFF_MODULO_DIM(x0, x1, x2, x3, x4, x5) \ )==" "\n" |
82 | R"==(({ \ )==" "\n" |
83 | R"==(int ret_val; \ )==" "\n" |
84 | R"==(if (NDIMS == 1) \ )==" "\n" |
85 | R"==(ret_val = _DST_OFF(x0 % DST_D0, 0, 0, 0, 0, 0); \ )==" "\n" |
86 | R"==(else if (NDIMS == 2) \ )==" "\n" |
87 | R"==(ret_val = _DST_OFF(x0 % DST_D0, x1 % DST_D1, 0, 0, 0, 0); \ )==" "\n" |
88 | R"==(else if (NDIMS == 3) \ )==" "\n" |
89 | R"==(ret_val = _DST_OFF( \ )==" "\n" |
90 | R"==(x0 % DST_D0, x1 % DST_D1, 0, 0, 0, x5 % DST_D2); \ )==" "\n" |
91 | R"==(else if (NDIMS == 4) \ )==" "\n" |
92 | R"==(ret_val = _DST_OFF( \ )==" "\n" |
93 | R"==(x0 % DST_D0, x1 % DST_D1, 0, 0, x4 % DST_D2, x5 % DST_D3); \ )==" "\n" |
94 | R"==(else if (NDIMS == 5) \ )==" "\n" |
95 | R"==(ret_val = _DST_OFF(x0 % DST_D0, x1 % DST_D1, 0, x3 % DST_D2, \ )==" "\n" |
96 | R"==(x4 % DST_D3, x5 % DST_D4); \ )==" "\n" |
97 | R"==(else \ )==" "\n" |
98 | R"==(ret_val = _DST_OFF(x0 % DST_D0, x1 % DST_D1, x2 % DST_D2, \ )==" "\n" |
99 | R"==(x3 % DST_D3, x4 % DST_D4, x5 % DST_D5); \ )==" "\n" |
100 | R"==(ret_val; \ )==" "\n" |
101 | R"==(}) )==" "\n" |
// Kernel entry point ref_reduce(src, dst POST_OP_ARGS):
//   1. reads base coordinates d0..d5 from GWS_GET_D0..D5;
//   2. folds src elements at (d + d_off) over the REDUCTION_D* ranges into
//      `acc`, starting from INIT_ACC and combining with ACCUMULATE;
//   3. converts `acc` to float and applies FINALIZE;
//   4. computes two destination offsets: dst_off (modulo DST_D*) and
//      dst_off_pd (coordinates as-is);
//   5. runs APPLY_POST_OPS_SERIAL with coordinates remapped for NDIMS 4 / 5;
//   6. stores TO_DST(res) at dst_off_pd, after forcing res to 0 when the two
//      offsets differ.
// NOTE(review): the zeroing in step 6 presumably targets a padded region of
// dst — confirm against the dispatch/GWS configuration.
102 | R"==(__kernel void ref_reduce( )==" "\n" |
103 | R"==(__global SRC_DATA_T *src, __global DST_DATA_T *dst POST_OP_ARGS) { )==" "\n" |
104 | R"==(int d0 = GWS_GET_D0(); )==" "\n" |
105 | R"==(int d1 = GWS_GET_D1(); )==" "\n" |
106 | R"==(int d2 = GWS_GET_D2(); )==" "\n" |
107 | R"==(int d3 = GWS_GET_D3(); )==" "\n" |
108 | R"==(int d4 = GWS_GET_D4(); )==" "\n" |
109 | R"==(int d5 = GWS_GET_D5(); )==" "\n" |
110 | R"==(DEF_ACC_DATA_T acc = INIT_ACC; )==" "\n" |
111 | R"==(for_(int d0_off = 0; d0_off < REDUCTION_D0; d0_off++) )==" "\n" |
112 | R"==(for_(int d1_off = 0; d1_off < REDUCTION_D1; d1_off++) )==" "\n" |
113 | R"==(ITERATE_OVER_REDUCTION_D2 )==" "\n" |
114 | R"==(for_(int d3_off = 0; d3_off < REDUCTION_D3; d3_off++) )==" "\n" |
115 | R"==(for_(int d4_off = 0; d4_off < REDUCTION_D4; d4_off++) )==" "\n" |
116 | R"==(for_(int d5_off = 0; d5_off < REDUCTION_D5; d5_off++) )==" "\n" |
117 | R"==({ )==" "\n" |
118 | R"==(const int src_off = _SRC_OFF(d0 + d0_off, d1 + d1_off, d2 + D2_OFF, )==" "\n" |
119 | R"==(d3 + d3_off, d4 + d4_off, d5 + d5_off); )==" "\n" |
120 | R"==(acc = ACCUMULATE(acc, TO_DEF_ACC_DATA_T(src[src_off])); )==" "\n" |
121 | R"==(} )==" "\n" |
122 | R"==(float res = convert_float(acc); )==" "\n" |
123 | R"==(res = FINALIZE(res); )==" "\n" |
124 | R"==(const int dst_off = _DST_OFF_MODULO_DIM(d0, d1, d2, d3, d4, d5); )==" "\n" |
125 | R"==(const int dst_off_pd = _DST_OFF(d0, d1, d2, d3, d4, d5); )==" "\n" |
// NOTE(review): dst_val is assigned only under `#if WITH_SUM`; otherwise it
// is passed to APPLY_POST_OPS_SERIAL without a value being stored in it here.
// Presumably the macro ignores it when no sum post-op exists — confirm.
126 | R"==(float dst_val; )==" "\n" |
127 | R"==(#if WITH_SUM )==" "\n" |
128 | R"==(dst_val = DST_TO_REF(dst[dst_off]); )==" "\n" |
129 | R"==(#endif )==" "\n" |
// For NDIMS == 4 and NDIMS == 5 the coordinates handed to the post-ops macro
// are rewritten depending on which REDUCTION_D* differs from 1; d1..d5 are
// reused as scratch here, after both destination offsets have been computed.
130 | R"==(#if NDIMS == 4 )==" "\n" |
131 | R"==(#if REDUCTION_D1 != 1 )==" "\n" |
132 | R"==(d1 = 0; )==" "\n" |
133 | R"==(d2 = d4; )==" "\n" |
134 | R"==(d3 = d5; )==" "\n" |
135 | R"==(#elif REDUCTION_D4 != 1 )==" "\n" |
136 | R"==(d2 = 0; )==" "\n" |
137 | R"==(d3 = d5; )==" "\n" |
138 | R"==(#elif REDUCTION_D5 != 1 )==" "\n" |
139 | R"==(d2 = d4; )==" "\n" |
140 | R"==(d3 = 0; )==" "\n" |
141 | R"==(#endif )==" "\n" |
142 | R"==(APPLY_POST_OPS_SERIAL( )==" "\n" |
143 | R"==(res, float, dst_val, float, d0, 1, d1, 1, d2, 1, d3, 1, 0, 1, 0, 1); )==" "\n" |
144 | R"==(#elif NDIMS == 5 )==" "\n" |
145 | R"==(#if REDUCTION_D1 != 1 )==" "\n" |
146 | R"==(d1 = 0; )==" "\n" |
147 | R"==(#elif REDUCTION_D5 != 1 )==" "\n" |
148 | R"==(d5 = 0; )==" "\n" |
149 | R"==(#endif )==" "\n" |
150 | R"==(APPLY_POST_OPS_SERIAL(res, float, dst_val, float, d0, 1, d1, 1, d3, 1, d4, )==" "\n" |
151 | R"==(1, d5, 1, 0, 1); )==" "\n" |
152 | R"==(#else )==" "\n" |
153 | R"==(APPLY_POST_OPS_SERIAL(res, float, dst_val, float, d0, 1, d1, 1, d2, 1, d3, )==" "\n" |
154 | R"==(1, d4, 1, d5, 1); )==" "\n" |
155 | R"==(#endif )==" "\n" |
156 | R"==(if (dst_off_pd != dst_off) res = 0.f; )==" "\n" |
157 | R"==(dst[dst_off_pd] = TO_DST(res); )==" "\n" |
158 | R"==(} )==" "\n" |
// An empty trailing literal terminates the concatenated kernel string.
159 | R"==()==" ; |
// Close namespaces ocl, gpu, impl, dnnl (innermost first).
160 | } |
161 | } |
162 | } |
163 | } |