namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source for the reference LRN (local response normalization)
// kernels `ref_lrn_fwd` and `ref_lrn_bwd`, kept as a string so it can be
// handed to the OpenCL compiler at runtime (NOTE(review): presumably by the
// reference LRN primitive -- confirm the consumer).
//
// NOTE(review): the layout (one raw-string piece per kernel line, leading
// indentation stripped) looks like generator output; if this file is
// regenerated from a .cl source, port the changes below there too.
//
// The kernel text relies on macros that must be defined when it is built:
//   IS_FWD / IS_BWD        select which kernel is compiled;
//   IS_TRAINING            forward kernel also takes and writes `ws`;
//   ACROSS_CHANNEL         nonzero: window runs over channels,
//                          zero: window runs over spatial d/h/w;
//   LOCAL_SIZE, PADDING    window length and its leading offset;
//   LRN_K, LRN_ALPHA, LRN_BETA, NUM_ELEMENTS_DIV  normalization constants;
//   IC, ID, IH, IW         extents used to clip the window;
//   SRC_OFF, DST_OFF       element-offset helpers for src / (diff_)dst;
//   GWS_GET_*              per-work-item mb/ic/id/ih/iw coordinates;
//   DATA_T, DEF_ACC_DATA_T, TO_DATA_T, TO_DEF_ACC_DATA_T  storage and
//                          accumulation types plus their conversions.
//
// Forward, per element x with S = sum of squares of src over the window:
//   base = LRN_K + LRN_ALPHA * S * NUM_ELEMENTS_DIV
//   dst  = x * base^(-LRN_BETA)      (ws = base when IS_TRAINING == 1)
// Backward, with B = sum over the window of src * diff_dst / ws^(LRN_BETA+1):
//   diff_src = diff_dst * ws^(-LRN_BETA)
//              - 2 * LRN_ALPHA * LRN_BETA * NUM_ELEMENTS_DIV * src * B
//
// Fixes relative to the previous text:
//  * the across-channel neighbour index is `(int)ic + j - PADDING`; the old
//    `j + ic - PADDING` was computed in unsigned arithmetic (ic is uint) and
//    only turned negative via an implementation-defined unsigned->int
//    conversion, which the `z_idx < 0` test depended on;
//  * the within-channel backward loop reads diff_dst through DST_OFF, like
//    every other diff_dst access in the kernel (it used SRC_OFF);
//  * both backward branches spell the exponent (DEF_ACC_DATA_T)(LRN_BETA + 1);
//  * the truncated license URL is restored.
//
// NOTE(review): the forward kernel writes `ws` at DST_OFF offsets while the
// backward kernel reads it at SRC_OFF offsets, i.e. it assumes the workspace
// shares the src layout -- confirm this on the host side.
const char *ref_lrn_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2020 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#if IS_FWD == 1 )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_lrn_fwd(__global const DATA_T *src, )==""\n"
R"==(#if IS_TRAINING == 1 )==""\n"
R"==(__global DEF_ACC_DATA_T *ws, )==""\n"
R"==(#endif )==""\n"
R"==(__global DATA_T *dst) { )==""\n"
R"==(const uint mb = GWS_GET_MB(); )==""\n"
R"==(const uint ic = GWS_GET_IC(); )==""\n"
R"==(const uint id = GWS_GET_ID(); )==""\n"
R"==(const uint ih = GWS_GET_IH(); )==""\n"
R"==(const uint iw = GWS_GET_IW(); )==""\n"
R"==(const uint src_index = SRC_OFF(mb, ic, id, ih, iw); )==""\n"
R"==(const uint dst_index = DST_OFF(mb, ic, id, ih, iw); )==""\n"
R"==(DEF_ACC_DATA_T sum = 0.0f; )==""\n"
R"==(#if ACROSS_CHANNEL )==""\n"
R"==(for (int j = 0; j < LOCAL_SIZE; j++) { )==""\n"
R"==(const int z_idx = (int)ic + j - PADDING; )==""\n"
R"==(bool zero = (z_idx < 0 || z_idx >= IC); )==""\n"
R"==(DEF_ACC_DATA_T val = zero )==""\n"
R"==(? 0.0f )==""\n"
R"==(: TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, z_idx, id, ih, iw)]); )==""\n"
R"==(sum += val * val; )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(const int d = (int)id - PADDING; )==""\n"
R"==(const int h = (int)ih - PADDING; )==""\n"
R"==(const int w = (int)iw - PADDING; )==""\n"
R"==(const int d_start = max(d, 0); )==""\n"
R"==(const int h_start = max(h, 0); )==""\n"
R"==(const int w_start = max(w, 0); )==""\n"
R"==(const int d_end = min(d + LOCAL_SIZE, ID); )==""\n"
R"==(const int h_end = min(h + LOCAL_SIZE, IH); )==""\n"
R"==(const int w_end = min(w + LOCAL_SIZE, IW); )==""\n"
R"==(for (int k = d_start; k < d_end; ++k) { )==""\n"
R"==(for (int j = h_start; j < h_end; ++j) { )==""\n"
R"==(for (int i = w_start; i < w_end; ++i) { )==""\n"
R"==(DEF_ACC_DATA_T val )==""\n"
R"==(= TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, ic, k, j, i)]); )==""\n"
R"==(sum += val * val; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(const DEF_ACC_DATA_T num_elements_div = NUM_ELEMENTS_DIV; )==""\n"
R"==(const DEF_ACC_DATA_T base = (DEF_ACC_DATA_T)LRN_K )==""\n"
R"==(+ (DEF_ACC_DATA_T)LRN_ALPHA * sum * num_elements_div; )==""\n"
R"==(const DEF_ACC_DATA_T normalization_factor )==""\n"
R"==(= native_powr(base, (DEF_ACC_DATA_T)(-LRN_BETA)); )==""\n"
R"==(const DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T(src[src_index]); )==""\n"
R"==(const DEF_ACC_DATA_T normres = val * normalization_factor; )==""\n"
R"==(#if IS_TRAINING == 1 )==""\n"
R"==(ws[dst_index] = base; )==""\n"
R"==(#endif )==""\n"
R"==(dst[dst_index] = TO_DATA_T(normres); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#if IS_BWD == 1 )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_lrn_bwd(__global const DATA_T *src, )==""\n"
R"==(__global const DATA_T *diff_dst, __global DEF_ACC_DATA_T *ws, )==""\n"
R"==(__global DATA_T *diff_src) { )==""\n"
R"==(const uint mb = GWS_GET_MB(); )==""\n"
R"==(const uint ic = GWS_GET_IC(); )==""\n"
R"==(const uint id = GWS_GET_ID(); )==""\n"
R"==(const uint ih = GWS_GET_IH(); )==""\n"
R"==(const uint iw = GWS_GET_IW(); )==""\n"
R"==(const uint src_index = SRC_OFF(mb, ic, id, ih, iw); )==""\n"
R"==(const uint dst_index = DST_OFF(mb, ic, id, ih, iw); )==""\n"
R"==(const DEF_ACC_DATA_T num_elements_div = NUM_ELEMENTS_DIV; )==""\n"
R"==(DEF_ACC_DATA_T B = 0; )==""\n"
R"==(#if ACROSS_CHANNEL )==""\n"
R"==(for (int j = 0; j < LOCAL_SIZE; j++) { )==""\n"
R"==(const int z_idx = (int)ic + j - PADDING; )==""\n"
R"==(bool zero = (z_idx < 0 || z_idx >= IC); )==""\n"
R"==(if (!zero) { )==""\n"
R"==(DEF_ACC_DATA_T val )==""\n"
R"==(= TO_DEF_ACC_DATA_T(src[SRC_OFF(mb, z_idx, id, ih, iw)]); )==""\n"
R"==(DEF_ACC_DATA_T omega = ws[SRC_OFF(mb, z_idx, id, ih, iw)]; )==""\n"
R"==(DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T)1.0f )==""\n"
R"==(/ native_powr(omega, (DEF_ACC_DATA_T)(LRN_BETA + 1)); )==""\n"
R"==(B += tmp * val )==""\n"
R"==(* TO_DEF_ACC_DATA_T( )==""\n"
R"==(diff_dst[DST_OFF(mb, z_idx, id, ih, iw)]); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(const int d = (int)id - PADDING; )==""\n"
R"==(const int h = (int)ih - PADDING; )==""\n"
R"==(const int w = (int)iw - PADDING; )==""\n"
R"==(const int d_start = max(d, 0); )==""\n"
R"==(const int h_start = max(h, 0); )==""\n"
R"==(const int w_start = max(w, 0); )==""\n"
R"==(const int d_end = min(d + LOCAL_SIZE, ID); )==""\n"
R"==(const int h_end = min(h + LOCAL_SIZE, IH); )==""\n"
R"==(const int w_end = min(w + LOCAL_SIZE, IW); )==""\n"
R"==(for (int k = d_start; k < d_end; ++k) { )==""\n"
R"==(for (int j = h_start; j < h_end; ++j) { )==""\n"
R"==(for (int i = w_start; i < w_end; ++i) { )==""\n"
R"==(int data_off = SRC_OFF(mb, ic, k, j, i); )==""\n"
R"==(DEF_ACC_DATA_T val = TO_DEF_ACC_DATA_T(src[data_off]); )==""\n"
R"==(DEF_ACC_DATA_T omega = ws[data_off]; )==""\n"
R"==(DEF_ACC_DATA_T tmp = (DEF_ACC_DATA_T)1.0f )==""\n"
R"==(/ native_powr(omega, (DEF_ACC_DATA_T)(LRN_BETA + 1)); )==""\n"
R"==(B += tmp * val )==""\n"
R"==(* TO_DEF_ACC_DATA_T(diff_dst[DST_OFF(mb, ic, k, j, i)]); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(const DEF_ACC_DATA_T A )==""\n"
R"==(= native_powr(ws[src_index], (DEF_ACC_DATA_T)-LRN_BETA) )==""\n"
R"==(* TO_DEF_ACC_DATA_T(diff_dst[dst_index]); )==""\n"
R"==(diff_src[src_index] = TO_DATA_T(A )==""\n"
R"==(- TO_DEF_ACC_DATA_T(src[src_index]) * 2 * (DEF_ACC_DATA_T)LRN_ALPHA )==""\n"
R"==(* (DEF_ACC_DATA_T)LRN_BETA * num_elements_div * B); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl