namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the element-wise ("eltwise") helper header, stored
// as one NUL-terminated string. The value is assembled from adjacent string
// literals: each embedded source line is a raw string R"==(...)==" followed by
// an ordinary "\n", so the result is the header text with one line per piece.
//
// What the embedded text contains:
//   * an include guard (GPU_OCL_OCL_ELTWISE_H) and a WITH_ELTWISE guard around
//     everything else;
//   * a fallback definition of DATA_MAX chosen from DT_F16 / DT_S8 / DT_U8;
//   * scalar float helpers `<alg>_fwd`, `<alg>_bwd` and, for some algorithms,
//     `<alg>_bwd_use_dst`;
//   * the dispatchers fwd_eltwise_common(), fwd_eltwise() and bwd_eltwise(),
//     which switch on an algorithm id (ELTWISE_ALG for the latter two) and
//     fall back to returning their first argument for unknown ids.
//
// The algorithm ids used in the switch statements (RELU, LINEAR, ...,
// CLIP_V2_DST) and the DT_* / WITH_ELTWISE / ELTWISE_ALG macros are not
// defined in this text; they have to be supplied by whatever compiles it.
//
// NOTE(review): this looks like generated output (comment-free body, one
// literal per source line) -- presumably it should be regenerated rather than
// edited by hand; confirm against the build scripts.
// NOTE(review): the license line inside the text reads "* http: " with the
// rest of the URL missing; presumably a "//" comment-stripping step truncated
// it. Left untouched here because it is part of the runtime string.
const char *ocl_eltwise_header = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2020-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#ifndef GPU_OCL_OCL_ELTWISE_H )==""\n"
R"==(#define GPU_OCL_OCL_ELTWISE_H )==""\n"
R"==(#if WITH_ELTWISE )==""\n"
R"==(#if DT_F16 == 1 )==""\n"
R"==(#pragma OPENCL EXTENSION cl_khr_fp16 : enable )==""\n"
R"==(#endif )==""\n"
R"==(#ifndef DATA_MAX )==""\n"
R"==(#if DT_F16 == 1 )==""\n"
R"==(#define DATA_MAX HALF_MAX )==""\n"
R"==(#elif DT_S8 == 1 )==""\n"
R"==(#define DATA_MAX CHAR_MAX )==""\n"
R"==(#elif DT_U8 == 1 )==""\n"
R"==(#define DATA_MAX UCHAR_MAX )==""\n"
R"==(#else )==""\n"
R"==(#define DATA_MAX FLT_MAX )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==(float relu_fwd(float s, float alpha) { )==""\n"
R"==(return s > 0 ? s : ((alpha == 0) ? 0 : s * alpha); )==""\n"
R"==(} )==""\n"
R"==(float relu_bwd(float dd, float s, float alpha) { )==""\n"
R"==(return s > 0 ? dd : dd * alpha; )==""\n"
R"==(} )==""\n"
R"==(float relu_bwd_use_dst(float dd, float d, float alpha) { )==""\n"
R"==(return d > 0 ? dd : dd * alpha; )==""\n"
R"==(} )==""\n"
R"==(float linear_fwd(float s, float alpha, float beta) { )==""\n"
R"==(return alpha * s + beta; )==""\n"
R"==(} )==""\n"
R"==(float linear_bwd(float dd, float alpha) { )==""\n"
R"==(return dd * alpha; )==""\n"
R"==(} )==""\n"
R"==(float soft_relu_fwd(float s, float alpha) { )==""\n"
R"==(s = alpha * s; )==""\n"
R"==(float v = (s < log((float)DATA_MAX) ? log1p(exp(s)) : s); )==""\n"
R"==(return v / alpha; )==""\n"
R"==(} )==""\n"
R"==(float soft_relu_bwd(float dd, float s, float alpha) { )==""\n"
R"==(s = alpha * s; )==""\n"
R"==(return dd / (1 + exp(-s)); )==""\n"
R"==(} )==""\n"
R"==(float logistic_fwd(float s) { )==""\n"
R"==(return 1.0f / (1.0f + exp(-s)); )==""\n"
R"==(} )==""\n"
R"==(float logistic_bwd(float dd, float s) { )==""\n"
R"==(float v = logistic_fwd(s); )==""\n"
R"==(return dd * v * (1 - v); )==""\n"
R"==(} )==""\n"
R"==(float logistic_bwd_use_dst(float dd, float d) { )==""\n"
R"==(return dd * d * (1 - d); )==""\n"
R"==(} )==""\n"
R"==(float square_fwd(float s) { )==""\n"
R"==(return s * s; )==""\n"
R"==(} )==""\n"
R"==(float square_bwd(float dd, float s) { )==""\n"
R"==(return dd * 2 * s; )==""\n"
R"==(} )==""\n"
R"==(float sqrt_fwd(float s) { )==""\n"
R"==(return sqrt(s); )==""\n"
R"==(} )==""\n"
R"==(float sqrt_bwd(float dd, float s) { )==""\n"
R"==(return dd / (2 * sqrt(s)); )==""\n"
R"==(} )==""\n"
R"==(float sqrt_bwd_use_dst(float dd, float d) { )==""\n"
R"==(return dd / (2 * d); )==""\n"
R"==(} )==""\n"
R"==(float abs_fwd(float s) { )==""\n"
R"==(return s > 0 ? s : -s; )==""\n"
R"==(} )==""\n"
R"==(float abs_bwd(float dd, float s) { )==""\n"
R"==(return s > 0 ? dd : s < 0 ? -dd : 0; )==""\n"
R"==(} )==""\n"
R"==(float tanh_fwd(float s) { )==""\n"
R"==(return tanh(s); )==""\n"
R"==(} )==""\n"
R"==(float tanh_bwd(float dd, float s) { )==""\n"
R"==(float e = tanh_fwd(s); )==""\n"
R"==(return dd * (1 - e) * (1 + e); )==""\n"
R"==(} )==""\n"
R"==(float tanh_bwd_use_dst(float dd, float d) { )==""\n"
R"==(return dd * (1 - d) * (1 + d); )==""\n"
R"==(} )==""\n"
R"==(float mish_fwd(float s) { )==""\n"
R"==(return s * tanh_fwd(soft_relu_fwd(s, 1.f)); )==""\n"
R"==(} )==""\n"
R"==(float mish_bwd(float dd, float s) { )==""\n"
R"==(const float tanh = tanh_fwd(soft_relu_fwd(s, 1.f)); )==""\n"
R"==(const float srelu_bwd = soft_relu_bwd(1.f, s, 1.f); )==""\n"
R"==(const float derivative = tanh + s * srelu_bwd * (1 - pow(tanh, 2.0f)); )==""\n"
R"==(return dd * derivative; )==""\n"
R"==(} )==""\n"
R"==(float elu_fwd(float s, float alpha) { )==""\n"
R"==(return s > 0 ? s : alpha * expm1(s); )==""\n"
R"==(} )==""\n"
R"==(float elu_bwd(float dd, float s, float alpha) { )==""\n"
R"==(return dd * (s > 0 ? 1 : alpha * exp(s)); )==""\n"
R"==(} )==""\n"
R"==(float elu_bwd_use_dst(float dd, float d, float alpha) { )==""\n"
R"==(return dd * (d > 0 ? 1 : d + alpha); )==""\n"
R"==(} )==""\n"
R"==(float exp_fwd(float s) { )==""\n"
R"==(return exp(s); )==""\n"
R"==(} )==""\n"
R"==(float exp_bwd(float dd, float s) { )==""\n"
R"==(return dd * exp_fwd(s); )==""\n"
R"==(} )==""\n"
R"==(float exp_bwd_use_dst(float dd, float d) { )==""\n"
R"==(return dd * d; )==""\n"
R"==(} )==""\n"
R"==(float gelu_tanh_fwd(float s) { )==""\n"
R"==(const float sqrt_2_over_pi = 0.79788458347320556640625f; )==""\n"
R"==(const float fitting_const = 0.044715f; )==""\n"
R"==(const float g = sqrt_2_over_pi * s * (1.f + fitting_const * s * s); )==""\n"
R"==(return (0.5f * s * (1.f + tanh_fwd(g))); )==""\n"
R"==(} )==""\n"
R"==(float gelu_tanh_bwd(float dd, float s) { )==""\n"
R"==(const float sqrt_2_over_pi = 0.79788458347320556640625f; )==""\n"
R"==(const float fitting_const = 0.044715f; )==""\n"
R"==(const float g = sqrt_2_over_pi * s * (1.f + fitting_const * s * s); )==""\n"
R"==(const float dg = sqrt_2_over_pi * (1.f + 3.f * fitting_const * s * s); )==""\n"
R"==(const float v = tanh_fwd(g); )==""\n"
R"==(return dd * 0.5f * (1.f + v) * (1.f + s * (1.f - v) * dg); )==""\n"
R"==(} )==""\n"
R"==(float swish_fwd(float s, float alpha) { )==""\n"
R"==(float w = -alpha * s; )==""\n"
R"==(return s / (1.0f + exp(w)); )==""\n"
R"==(} )==""\n"
R"==(float swish_bwd(float dd, float s, float alpha) { )==""\n"
R"==(float v = logistic_fwd(alpha * s); )==""\n"
R"==(return dd * (v + s * alpha * v * (1.0f - v)); )==""\n"
R"==(} )==""\n"
R"==(float log_fwd(float s) { )==""\n"
R"==(return log(s); )==""\n"
R"==(} )==""\n"
R"==(float log_bwd(float dd, float s) { )==""\n"
R"==(return dd / s; )==""\n"
R"==(} )==""\n"
R"==(float clip_fwd(float s, float alpha, float beta) { )==""\n"
R"==(s = s > alpha ? s : alpha; )==""\n"
R"==(return s > beta ? beta : s; )==""\n"
R"==(} )==""\n"
R"==(float clip_bwd(float dd, float s, float alpha, float beta) { )==""\n"
R"==(return dd * (alpha < s && s <= beta ? 1 : 0); )==""\n"
R"==(} )==""\n"
R"==(float clip_v2_fwd(float s, float alpha, float beta) { )==""\n"
R"==(s = s > alpha ? s : alpha; )==""\n"
R"==(return s < beta ? s : beta; )==""\n"
R"==(} )==""\n"
R"==(float clip_v2_bwd(float dd, float s, float alpha, float beta) { )==""\n"
R"==(return dd * (alpha < s && s < beta ? 1 : 0); )==""\n"
R"==(} )==""\n"
R"==(float clip_v2_bwd_use_dst(float dd, float d, float alpha, float beta) { )==""\n"
R"==(return dd * (alpha < d && d < beta ? 1 : 0); )==""\n"
R"==(} )==""\n"
R"==(float pow_fwd(float s, float alpha, float beta) { )==""\n"
R"==(return alpha * pow(s, beta); )==""\n"
R"==(} )==""\n"
R"==(float pow_bwd(float dd, float s, float alpha, float beta) { )==""\n"
R"==(if (beta == 0) return 0; )==""\n"
R"==(float v = pow_fwd(s, alpha * beta, beta - 1); )==""\n"
R"==(return dd * v; )==""\n"
R"==(} )==""\n"
R"==(float gelu_erf_fwd(float s) { )==""\n"
R"==(const float sqrt_2_over_2 = 0.707106769084930419921875f; )==""\n"
R"==(float v = s * sqrt_2_over_2; )==""\n"
R"==(return 0.5f * s * (1.f + erf(v)); )==""\n"
R"==(} )==""\n"
R"==(float gelu_erf_bwd(float dd, float s) { )==""\n"
R"==(const float two_over_sqrt_pi = 1.12837922573089599609375f; )==""\n"
R"==(const float sqrt_2_over_2 = 0.707106769084930419921875f; )==""\n"
R"==(float v = s * sqrt_2_over_2; )==""\n"
R"==(return dd * 0.5f * (1.f + erf(v) + v * two_over_sqrt_pi * exp(-v * v)); )==""\n"
R"==(} )==""\n"
R"==(float round_fwd(float s) { )==""\n"
R"==(return (float)rint((float)s); )==""\n"
R"==(} )==""\n"
R"==(float hardsigmoid_fwd(float s, float alpha, float beta) { )==""\n"
R"==(float v = alpha * s + beta; )==""\n"
R"==(return v <= 0.f ? 0.f : v >= 1.f ? 1.f : v; )==""\n"
R"==(} )==""\n"
R"==(float hardsigmoid_bwd(float dd, float s, float alpha, float beta) { )==""\n"
R"==(float v = alpha * s + beta; )==""\n"
R"==(return v <= 0.f ? 0.f : v >= 1.f ? 0.f : dd * alpha; )==""\n"
R"==(} )==""\n"
R"==(float hardswish_fwd(float s, float alpha, float beta) { )==""\n"
R"==(return s * hardsigmoid_fwd(s, alpha, beta); )==""\n"
R"==(} )==""\n"
R"==(float hardswish_bwd(float dd, float s, float alpha, float beta) { )==""\n"
R"==(float v = alpha * s + beta; )==""\n"
R"==(float w = 2.f * alpha * s + beta; )==""\n"
R"==(return (v <= 0.f ? 0.f : v >= 1.f ? dd : dd * w); )==""\n"
R"==(} )==""\n"
R"==(float fwd_eltwise_common( )==""\n"
R"==(int eltwise_alg, float x, float alpha_, float beta_, float scale_) { )==""\n"
R"==(switch (eltwise_alg) { )==""\n"
R"==(case RELU: return scale_ * relu_fwd(x, alpha_); break; )==""\n"
R"==(case LINEAR: return scale_ * linear_fwd(x, alpha_, beta_); break; )==""\n"
R"==(case SOFT_RELU: return scale_ * soft_relu_fwd(x, alpha_); break; )==""\n"
R"==(case MISH: return scale_ * mish_fwd(x); break; )==""\n"
R"==(case LOGISTIC: return scale_ * logistic_fwd(x); break; )==""\n"
R"==(case TANH: return scale_ * tanh_fwd(x); break; )==""\n"
R"==(case ELU: return scale_ * elu_fwd(x, alpha_); break; )==""\n"
R"==(case SQUARE: return scale_ * square_fwd(x); break; )==""\n"
R"==(case SQRT: return scale_ * sqrt_fwd(x); break; )==""\n"
R"==(case ABS: return scale_ * abs_fwd(x); break; )==""\n"
R"==(case EXP: return scale_ * exp_fwd(x); break; )==""\n"
R"==(case GELU_TANH: return scale_ * gelu_tanh_fwd(x); break; )==""\n"
R"==(case SWISH: return scale_ * swish_fwd(x, alpha_); break; )==""\n"
R"==(case LOG: return scale_ * log_fwd(x); break; )==""\n"
R"==(case CLIP: return scale_ * clip_fwd(x, alpha_, beta_); break; )==""\n"
R"==(case CLIP_V2: return scale_ * clip_v2_fwd(x, alpha_, beta_); break; )==""\n"
R"==(case POW: return scale_ * pow_fwd(x, alpha_, beta_); break; )==""\n"
R"==(case GELU_ERF: return scale_ * gelu_erf_fwd(x); break; )==""\n"
R"==(case ROUND: return scale_ * round_fwd(x); break; )==""\n"
R"==(case HARDSWISH: return scale_ * hardswish_fwd(x, alpha_, beta_); break; )==""\n"
R"==(case HARDSIGMOID: )==""\n"
R"==(return scale_ * hardsigmoid_fwd(x, alpha_, beta_); )==""\n"
R"==(break; )==""\n"
R"==(case RELU_DST: return scale_ * relu_fwd(x, alpha_); break; )==""\n"
R"==(case LOGISTIC_DST: return scale_ * logistic_fwd(x); break; )==""\n"
R"==(case TANH_DST: return scale_ * tanh_fwd(x); break; )==""\n"
R"==(case ELU_DST: return scale_ * elu_fwd(x, alpha_); break; )==""\n"
R"==(case SQRT_DST: return scale_ * sqrt_fwd(x); break; )==""\n"
R"==(case EXP_DST: return scale_ * exp_fwd(x); break; )==""\n"
R"==(case CLIP_V2_DST: return scale_ * clip_v2_fwd(x, alpha_, beta_); break; )==""\n"
R"==(default: return x; break; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(float fwd_eltwise(float x, float alpha_, float beta_, float scale_) { )==""\n"
R"==(#ifdef ELTWISE_ALG )==""\n"
R"==(return fwd_eltwise_common(ELTWISE_ALG, x, alpha_, beta_, scale_); )==""\n"
R"==(#else )==""\n"
R"==(return x; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(float bwd_eltwise(float x, float y, float alpha_, float beta_) { )==""\n"
R"==(#ifdef ELTWISE_ALG )==""\n"
R"==(switch (ELTWISE_ALG) { )==""\n"
R"==(case RELU: return relu_bwd(x, y, alpha_); break; )==""\n"
R"==(case LINEAR: return linear_bwd(x, alpha_); break; )==""\n"
R"==(case SOFT_RELU: return soft_relu_bwd(x, y, alpha_); break; )==""\n"
R"==(case MISH: return mish_bwd(x, y); break; )==""\n"
R"==(case LOGISTIC: return logistic_bwd(x, y); break; )==""\n"
R"==(case TANH: return tanh_bwd(x, y); break; )==""\n"
R"==(case ELU: return elu_bwd(x, y, alpha_); break; )==""\n"
R"==(case SQUARE: return square_bwd(x, y); break; )==""\n"
R"==(case SQRT: return sqrt_bwd(x, y); break; )==""\n"
R"==(case ABS: return abs_bwd(x, y); break; )==""\n"
R"==(case EXP: return exp_bwd(x, y); break; )==""\n"
R"==(case GELU_TANH: return gelu_tanh_bwd(x, y); break; )==""\n"
R"==(case SWISH: return swish_bwd(x, y, alpha_); break; )==""\n"
R"==(case LOG: return log_bwd(x, y); break; )==""\n"
R"==(case CLIP: return clip_bwd(x, y, alpha_, beta_); break; )==""\n"
R"==(case CLIP_V2: return clip_v2_bwd(x, y, alpha_, beta_); break; )==""\n"
R"==(case POW: return pow_bwd(x, y, alpha_, beta_); break; )==""\n"
R"==(case GELU_ERF: return gelu_erf_bwd(x, y); break; )==""\n"
R"==(case HARDSWISH: return hardswish_bwd(x, y, alpha_, beta_); break; )==""\n"
R"==(case HARDSIGMOID: return hardsigmoid_bwd(x, y, alpha_, beta_); break; )==""\n"
R"==(case RELU_DST: return relu_bwd_use_dst(x, y, alpha_); break; )==""\n"
R"==(case LOGISTIC_DST: return logistic_bwd_use_dst(x, y); break; )==""\n"
R"==(case TANH_DST: return tanh_bwd_use_dst(x, y); break; )==""\n"
R"==(case ELU_DST: return elu_bwd_use_dst(x, y, alpha_); break; )==""\n"
R"==(case SQRT_DST: return sqrt_bwd_use_dst(x, y); break; )==""\n"
R"==(case EXP_DST: return exp_bwd_use_dst(x, y); break; )==""\n"
R"==(case CLIP_V2_DST: )==""\n"
R"==(return clip_v2_bwd_use_dst(x, y, alpha_, beta_); )==""\n"
R"==(break; )==""\n"
R"==(default: return x; break; )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(return x; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==()==";
}
}
}
}