namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source for the reference eltwise kernels (ref_eltwise_fwd and
// ref_eltwise_bwd), stored as a single NUL-terminated string.
//
// The string is assembled from adjacent string-literal pieces, one per
// kernel source line. Each piece is the line text, a single trailing space,
// and "\n". The C++ comments placed between pieces are not part of the
// string, so they do not change the kernel text.
//
// The kernel text is not self-contained:
//  - it #includes gpu/ocl/*.h headers;
//  - it relies on macros such as IS_FWD, USE_GWS_GET, GWS0..GWS2,
//    DATA_D0..DATA_D5, DT_*, WITH_SUM, SUB_GROUP_SIZE and POST_OP_ARGS.
// Presumably the host code that builds the program supplies these; they are
// not visible in this file.
//
// NOTE(review): this looks generator-produced from a .cl file. Indentation
// is stripped, and text after "//" appears to have been removed: the Apache
// license URL below reads just "http:". If so, edit the .cl source or the
// generator rather than this file.
const char *ref_eltwise_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2021 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_eltwise.h" )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#define DATA_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(DATA, x0, x1, x2, x3, x4, x5) )==""\n"
// NOTE(review): the next piece ends in a backslash, then a space, then
// "\n". It therefore relies on the OpenCL compiler accepting
// backslash-whitespace-newline as a line continuation. Clang-based
// compilers accept this with a warning; confirm for the target compilers.
R"==(#define DIFF_DATA_OFF(x0, x1, x2, x3, x4, x5) \ )==""\n"
R"==(OFF_MD(DIFF_DATA, x0, x1, x2, x3, x4, x5) )==""\n"
R"==(#define KERNEL_ATTR __attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
// Forward kernel: dst = post-ops applied to fwd_eltwise(src, alpha, beta).
R"==(#if IS_FWD )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst, )==""\n"
R"==(float alpha, float beta POST_OP_ARGS) { )==""\n"
// USE_GWS_GET path: coordinates come from GWS_GET_D0..D5(). A work-item
// whose coordinate reaches DATA_D* writes 0 to dst and returns; presumably
// this zeroes padding.
R"==(#if USE_GWS_GET )==""\n"
R"==(int d0 = GWS_GET_D0(); )==""\n"
R"==(int d1 = GWS_GET_D1(); )==""\n"
R"==(int d2 = GWS_GET_D2(); )==""\n"
R"==(int d3 = GWS_GET_D3(); )==""\n"
R"==(int d4 = GWS_GET_D4(); )==""\n"
R"==(int d5 = GWS_GET_D5(); )==""\n"
R"==(const size_t data_off = DATA_OFF(d0, d1, d2, d3, d4, d5); )==""\n"
R"==(if (d0 >= DATA_D0 || d1 >= DATA_D1 || d2 >= DATA_D2 || d3 >= DATA_D3 )==""\n"
R"==(|| d4 >= DATA_D4 || d5 >= DATA_D5) { )==""\n"
R"==(dst[data_off] = CONVERT_DATA_T(0.f); )==""\n"
R"==(return; )==""\n"
R"==(} )==""\n"
// Linear path: the offset comes straight from the global ids, and d0..d5
// are all 0, so the post-ops never see real coordinates.
// NOTE(review): presumably the host takes this path only when the post-ops
// do not depend on coordinates; confirm.
R"==(#else )==""\n"
R"==(const size_t data_off = get_global_id(0) )==""\n"
R"==(#if GWS1 > 1 )==""\n"
R"==(+ get_global_id(1) * GWS0 )==""\n"
R"==(#endif )==""\n"
R"==(#if GWS2 > 1 )==""\n"
R"==(+ get_global_id(2) * GWS0 * GWS1 )==""\n"
R"==(#endif )==""\n"
R"==(; )==""\n"
R"==(const int d0 = 0; )==""\n"
R"==(const int d1 = 0; )==""\n"
R"==(const int d2 = 0; )==""\n"
R"==(const int d3 = 0; )==""\n"
R"==(const int d4 = 0; )==""\n"
R"==(const int d5 = 0; )==""\n"
R"==(#endif )==""\n"
R"==(#if DT_F16 == 1 )==""\n"
R"==(float tmp_s = CONVERT_FLOAT_T(src[data_off]); )==""\n"
R"==(#else )==""\n"
R"==(float tmp_s = DATA_TO_REF(src[data_off]); )==""\n"
R"==(#endif )==""\n"
R"==(tmp_s = fwd_eltwise(tmp_s, alpha, beta, 1.0f); )==""\n"
// NOTE(review): dst_data is assigned only under WITH_SUM, yet it is passed
// to APPLY_POST_OPS_SERIAL either way. Presumably the macro reads it only
// for a sum post-op; confirm in gpu/ocl/ocl_post_ops.h.
R"==(float dst_data; )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(dst_data = convert_float(DATA_TO_REF(dst[data_off])); )==""\n"
R"==(#endif )==""\n"
R"==(APPLY_POST_OPS_SERIAL(tmp_s, float, dst_data, float, d0, 1, d1, 1, d2, 1, )==""\n"
R"==(d3, 1, d4, 1, d5, 1); )==""\n"
R"==(dst[data_off] = CONVERT_DATA_T(tmp_s); )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
// Backward kernel, emitted only for f32/bf16:
// diff_src = bwd_eltwise(diff_dst, src, alpha, beta).
// src is addressed with DATA_OFF and the diff tensors with DIFF_DATA_OFF.
// Out-of-range coordinates write 0 to diff_src and return.
R"==(#if DT_F32 == 1 || DT_BF16 == 1 )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src, )==""\n"
R"==(__global DATA_T *diff_dst, float alpha, float beta) { )==""\n"
R"==(int d0 = GWS_GET_D0(); )==""\n"
R"==(int d1 = GWS_GET_D1(); )==""\n"
R"==(int d2 = GWS_GET_D2(); )==""\n"
R"==(int d3 = GWS_GET_D3(); )==""\n"
R"==(int d4 = GWS_GET_D4(); )==""\n"
R"==(int d5 = GWS_GET_D5(); )==""\n"
R"==(const size_t data_off = DATA_OFF(d0, d1, d2, d3, d4, d5); )==""\n"
R"==(const size_t diff_data_off = DIFF_DATA_OFF(d0, d1, d2, d3, d4, d5); )==""\n"
R"==(if (d0 >= DATA_D0 || d1 >= DATA_D1 || d2 >= DATA_D2 || d3 >= DATA_D3 )==""\n"
R"==(|| d4 >= DATA_D4 || d5 >= DATA_D5) { )==""\n"
R"==(diff_src[diff_data_off] = CONVERT_DATA_T(0.f); )==""\n"
R"==(return; )==""\n"
R"==(} )==""\n"
R"==(POST_OP_DATA_T tmp_dd = DATA_TO_REF(diff_dst[diff_data_off]); )==""\n"
R"==(POST_OP_DATA_T tmp_s = DATA_TO_REF(src[data_off]); )==""\n"
R"==(diff_src[diff_data_off] )==""\n"
R"==(= CONVERT_DATA_T(bwd_eltwise(tmp_dd, tmp_s, alpha, beta)); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==()==";

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl