1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source for the reference PReLU kernels, embedded as a C string.
//
// Layout of the literal: each raw-string piece R"==(...)==" holds one line of
// kernel source followed by a single space, and the adjacent "\n" literal
// restores the line break. The compiler concatenates all adjacent pieces into
// one NUL-terminated string. The bytes of that string ARE the kernel source,
// so any change to the pieces (not to these C++ comments) changes the kernel.
// NOTE(review): this looks like output of a tool that embeds a .cl file; if
// so, regenerate it from the .cl source instead of hand-editing — confirm
// against the build.
//
// The kernel text #includes "gpu/ocl/ocl_eltwise.h" and "gpu/ocl/ocl_types.h".
// It relies on build-time macros that are not defined here: IS_FWD,
// GWS_GET_D0..D5, OFF_MD, the SRC_* / WEI_* / DIFF_WEI_* (P)D0..5 dims, the
// *_TO_REF and TO_* converters, and DIFF_WEI_DT_F32.
// Exactly one kernel is compiled, selected by IS_FWD:
//   * ref_prelu_fwd(src, weights, dst)
//   * ref_prelu_bwd(src, weights, diff_dst, diff_src, diff_weights)
5const char *ref_prelu_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2020-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
// NOTE(review): the license URL on the next line reads only "http:".
// Presumably the embedding tool stripped everything from "//" onward. The
// line sits inside an OpenCL /* ... */ comment, so the kernel is unaffected.
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_eltwise.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Forward kernel. Each work-item handles the element at the 6-D coordinate
// (d0..d5) obtained from GWS_GET_D0..D5, and computes
//   dst = relu_fwd(src, w).
// w is read at each coordinate taken modulo the weight dims WEI_Dk, so a
// weight dim of size 1 broadcasts.
// dst is written at the offset computed from the SRC descriptor, so dst is
// assumed to share src's layout — NOTE(review): confirm on the host side.
// relu_fwd comes from ocl_eltwise.h. The backward formula below implies it
// returns s for s > 0 and s * w otherwise — confirm.
22R"==(#if IS_FWD )==""\n"
23R"==(__kernel void ref_prelu_fwd(const __global SRC_DATA_T *src, )==""\n"
24R"==(const __global WEI_DATA_T *weights, __global DST_DATA_T *dst) { )==""\n"
25R"==(const int d0 = GWS_GET_D0(); )==""\n"
26R"==(const int d1 = GWS_GET_D1(); )==""\n"
27R"==(const int d2 = GWS_GET_D2(); )==""\n"
28R"==(const int d3 = GWS_GET_D3(); )==""\n"
29R"==(const int d4 = GWS_GET_D4(); )==""\n"
30R"==(const int d5 = GWS_GET_D5(); )==""\n"
31R"==(const unsigned data_off = OFF_MD(SRC, d0, d1, d2, d3, d4, d5); )==""\n"
32R"==(const unsigned wei_off = OFF_MD(WEI, d0 % WEI_D0, d1 % WEI_D1, d2 % WEI_D2, )==""\n"
33R"==(d3 % WEI_D3, d4 % WEI_D4, d5 % WEI_D5); )==""\n"
34R"==(const float src_data = SRC_TO_REF(src[data_off]); )==""\n"
35R"==(const float wei_data = WEI_TO_REF(weights[wei_off]); )==""\n"
36R"==(const float res_data = relu_fwd(src_data, wei_data); )==""\n"
37R"==(dst[data_off] = TO_DST(res_data); )==""\n"
38R"==(} )==""\n"
39R"==(#else )==""\n"
// Backward kernel. For each work-item at coordinate (d0..d5):
//   diff_src = src > 0 ? diff_dst : diff_dst * w
//   diff_w   = src > 0 ? 0        : diff_dst * src
// Each value is forced to 0 when the offset built from coord % *_Dk differs
// from the one built from coord % *_PDk, which presumably marks the padding
// region. Each value is stored only when all coordinates are below that
// tensor's padded dims (*_PDk).
// diff_dst is read at the SRC-descriptor offset, so it is assumed to share
// src's layout — confirm.
// NOTE(review): diff_weights receives a plain store of this element's
// contribution, not an accumulation. No reduction over broadcast dims happens
// in this kernel, so presumably the host reduces afterwards — confirm.
40R"==(__kernel void ref_prelu_bwd(const __global SRC_DATA_T *src, )==""\n"
41R"==(const __global WEI_DATA_T *weights, const __global DST_DATA_T *diff_dst, )==""\n"
42R"==(__global DIFF_SRC_DATA_T *diff_src, )==""\n"
43R"==(__global DIFF_WEI_DATA_T *diff_weights) { )==""\n"
44R"==(const int d0 = GWS_GET_D0(); )==""\n"
45R"==(const int d1 = GWS_GET_D1(); )==""\n"
46R"==(const int d2 = GWS_GET_D2(); )==""\n"
47R"==(const int d3 = GWS_GET_D3(); )==""\n"
48R"==(const int d4 = GWS_GET_D4(); )==""\n"
49R"==(const int d5 = GWS_GET_D5(); )==""\n"
50R"==(const unsigned data_off = OFF_MD(SRC, d0 % SRC_D0, d1 % SRC_D1, d2 % SRC_D2, )==""\n"
51R"==(d3 % SRC_D3, d4 % SRC_D4, d5 % SRC_D5); )==""\n"
52R"==(const unsigned data_off_pd = OFF_MD(SRC, d0 % SRC_PD0, d1 % SRC_PD1, )==""\n"
53R"==(d2 % SRC_PD2, d3 % SRC_PD3, d4 % SRC_PD4, d5 % SRC_PD5); )==""\n"
54R"==(const unsigned wei_off = OFF_MD(WEI, d0 % WEI_D0, d1 % WEI_D1, d2 % WEI_D2, )==""\n"
55R"==(d3 % WEI_D3, d4 % WEI_D4, d5 % WEI_D5); )==""\n"
56R"==(const float src_data = SRC_TO_REF(src[data_off]); )==""\n"
57R"==(const float diff_dst_data = DST_TO_REF(diff_dst[data_off]); )==""\n"
58R"==(const float wei_data = WEI_TO_REF(weights[wei_off]); )==""\n"
59R"==(float diff_src_data )==""\n"
60R"==(= src_data > 0 ? diff_dst_data : diff_dst_data * wei_data; )==""\n"
61R"==(if (data_off != data_off_pd) diff_src_data = 0.f; )==""\n"
// COORDINATES_ARE_IN_RANGE(P) is true when every dk is below P_PDk. It is a
// GNU statement expression ({ ... }) and reads d0..d5 from the enclosing
// kernel scope. It is never #undef'd.
// NOTE(review): every continuation line ends with "\ " followed by the "\n"
// piece (backslash, space, newline). This relies on the OpenCL compiler
// accepting whitespace between the backslash and the newline as a line splice
// (clang-based front ends typically do, with a warning) — confirm for all
// target compilers.
62R"==(#define COORDINATES_ARE_IN_RANGE(mem_prefix) \ )==""\n"
63R"==(({ \ )==""\n"
64R"==(bool is_in_range = d0 < CONCAT2(mem_prefix, _PD0) \ )==""\n"
65R"==(&& d1 < CONCAT2(mem_prefix, _PD1) \ )==""\n"
66R"==(&& d2 < CONCAT2(mem_prefix, _PD2) \ )==""\n"
67R"==(&& d3 < CONCAT2(mem_prefix, _PD3) \ )==""\n"
68R"==(&& d4 < CONCAT2(mem_prefix, _PD4) \ )==""\n"
69R"==(&& d5 < CONCAT2(mem_prefix, _PD5); \ )==""\n"
70R"==(is_in_range; \ )==""\n"
71R"==(}) )==""\n"
// NOTE(review): diff_src is declared DIFF_SRC_DATA_T, but the next store
// converts with TO_SRC. This is only correct if diff_src and src share a data
// type — confirm.
72R"==(if (COORDINATES_ARE_IN_RANGE(SRC)) )==""\n"
73R"==(diff_src[data_off_pd] = TO_SRC(diff_src_data); )==""\n"
// Work-items outside diff_weights' padded dims stop here, after the diff_src
// store above.
74R"==(if (!COORDINATES_ARE_IN_RANGE(DIFF_WEI)) return; )==""\n"
75R"==(const unsigned diff_wei_off = OFF_MD(DIFF_WEI, d0 % DIFF_WEI_D0, )==""\n"
76R"==(d1 % DIFF_WEI_D1, d2 % DIFF_WEI_D2, d3 % DIFF_WEI_D3, )==""\n"
77R"==(d4 % DIFF_WEI_D4, d5 % DIFF_WEI_D5); )==""\n"
78R"==(const unsigned diff_wei_off_pd = OFF_MD(DIFF_WEI, d0 % DIFF_WEI_PD0, )==""\n"
79R"==(d1 % DIFF_WEI_PD1, d2 % DIFF_WEI_PD2, d3 % DIFF_WEI_PD3, )==""\n"
80R"==(d4 % DIFF_WEI_PD4, d5 % DIFF_WEI_PD5); )==""\n"
81R"==(float diff_wei_data = src_data > 0 ? 0 : diff_dst_data * src_data; )==""\n"
82R"==(if (diff_wei_off != diff_wei_off_pd) diff_wei_data = 0.f; )==""\n"
// f32 diff_weights are stored as-is; any other data type goes through
// TO_DIFF_WEI.
83R"==(#if DIFF_WEI_DT_F32 )==""\n"
84R"==(diff_weights[diff_wei_off_pd] = diff_wei_data; )==""\n"
85R"==(#else )==""\n"
86R"==(diff_weights[diff_wei_off_pd] = TO_DIFF_WEI(diff_wei_data); )==""\n"
87R"==(#endif )==""\n"
88R"==(} )==""\n"
89R"==(#endif )==""\n"
90R"==()==";
91} // namespace ocl
92} // namespace gpu
93} // namespace impl
94} // namespace dnnl