namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// Embedded OpenCL C source text for the gen9 global pooling kernels.
//
// The string is assembled from adjacent raw string literals, one per line of
// kernel source, each followed by an explicit "\n". It defines two kernels,
// selected by preprocessor switches inside the kernel text:
//   - gen9_global_pooling_fwd (under IS_FWD)
//   - gen9_global_pooling_bwd (under IS_BWD)
//
// Macros such as DATA_T, DST_DATA_T, SRC_OFF, DST_OFF, C, MB, ID, IH, IW,
// ALG_MAX, ALG_AVG_NP, ALG_AVG_P, IS_TRAINING, DT_BF16 and the GWS_GET_*
// accessors are not defined in this text; presumably they come from
// "gpu/ocl/ocl_types.h" and the kernel build options (confirm against the
// host-side code that builds this kernel).
//
// NOTE: comments are placed between the literals (C++ side) on purpose so the
// kernel text itself carries no extra content.
const char *gen9_global_pooling_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2021-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#define ALG_AVG (ALG_AVG_NP || ALG_AVG_P) )==""\n"
// ---------------------------------------------------------------------------
// Forward kernel: each work-item derives (mb, oc) from get_global_id(0) and
// reduces the whole ID x IH x IW extent of that slice into one dst element.
// Max pooling additionally records the flat spatial index of the maximum in
// ws when IS_TRAINING is set; otherwise the sum is divided by ID*IH*IW.
// ---------------------------------------------------------------------------
R"==(#if IS_FWD )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void gen9_global_pooling_fwd( )==""\n"
R"==(__global DATA_T *src, __global int *ws, __global DST_DATA_T *dst) { )==""\n"
R"==(const int mb = get_global_id(0) / C; )==""\n"
R"==(const int oc = get_global_id(0) % C; )==""\n"
R"==(const uint dst_off = DST_OFF(mb, oc, 0, 0, 0); )==""\n"
R"==(#if ALG_MAX )==""\n"
// Seed the running maximum with the first element of THIS (mb, oc) slice and
// seed its index with 0. Seeding from src[0] (the first element of the whole
// tensor) could leak a value from another slice into dst, and a seed index of
// -1 would survive whenever no later element is strictly greater, leaving the
// backward pass with no position to route the gradient to.
R"==(const uint first_off = SRC_OFF(mb, oc, 0, 0, 0); )==""\n"
R"==(#if DT_BF16 )==""\n"
R"==(DEF_ACC_DATA_T dst_val = DATA_TO_REF(src[first_off]); )==""\n"
R"==(#else )==""\n"
R"==(float dst_val = DATA_TO_REF(src[first_off]); )==""\n"
R"==(#endif )==""\n"
R"==(#if IS_TRAINING )==""\n"
R"==(int max_idx = 0; )==""\n"
R"==(#endif )==""\n"
R"==(#else )==""\n"
R"==(#if DT_BF16 )==""\n"
R"==(DEF_ACC_DATA_T dst_val = DATA_TO_REF(0.f); )==""\n"
R"==(#else )==""\n"
R"==(float dst_val = 0.f; )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==(for (int id = 0; id < ID; id++) { )==""\n"
R"==(for (int ih = 0; ih < IH; ih++) { )==""\n"
R"==(for (int iw = 0; iw < IW; iw++) { )==""\n"
R"==(uint src_off = SRC_OFF(mb, oc, id, ih, iw); )==""\n"
R"==(#if DT_BF16 )==""\n"
R"==(DEF_ACC_DATA_T val = DATA_TO_REF(src[src_off]); )==""\n"
R"==(#else )==""\n"
R"==(float val = DATA_TO_REF(src[src_off]); )==""\n"
R"==(#endif )==""\n"
R"==(#if ALG_MAX )==""\n"
R"==(if (val > dst_val) { )==""\n"
R"==(dst_val = val; )==""\n"
R"==(#if IS_TRAINING )==""\n"
R"==(max_idx = id * IH * IW + ih * IW + iw; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(dst_val += val; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#if ALG_MAX )==""\n"
R"==(dst[dst_off] = TO_DST(dst_val); )==""\n"
R"==(#if IS_TRAINING )==""\n"
R"==(ws[dst_off] = max_idx; )==""\n"
R"==(#endif )==""\n"
R"==(#else )==""\n"
R"==(dst[dst_off] = TO_DST(dst_val / ID / IH / IW); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// ---------------------------------------------------------------------------
// Backward kernel: each work-item handles one (mb, c) pair and a chunk of
// SPATIAL_CHUNK consecutive flat spatial positions starting at `spatial`,
// clipped to SPATIAL_DIM. Positions flagged as padding (mb >= MB or c >= C
// while NEED_ZERO_PADDING is set) are written as DATA_ZERO. For max pooling
// only the position whose flat index equals the stored workspace value gets
// diff_dst; for average pooling every position gets diff_dst / SPATIAL_DIM.
// ---------------------------------------------------------------------------
R"==(#if IS_BWD )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void gen9_global_pooling_bwd(__global DATA_T *diff_src, )==""\n"
R"==(__global int *ws, __global DATA_T *diff_dst) { )==""\n"
R"==(const int mb = GWS_GET_MB(); )==""\n"
R"==(const int c = GWS_GET_C(); )==""\n"
R"==(const int spatial = GWS_GET_SPATIAL(); )==""\n"
R"==(const bool is_in_padded_area = NEED_ZERO_PADDING && (mb >= MB || c >= C); )==""\n"
R"==(const int dst_off = DST_OFF(mb, c, 0, 0, 0); )==""\n"
R"==(#if ALG_AVG )==""\n"
R"==(const DATA_T dst_val = diff_dst[dst_off]; )==""\n"
R"==(#endif )==""\n"
// The workspace value is only consumed by the max-pooling branch below, so
// the load is compiled in only for ALG_MAX; average pooling never touches ws.
R"==(#if ALG_MAX )==""\n"
R"==(int ws_val = ws[dst_off]; )==""\n"
R"==(#endif )==""\n"
R"==(for (int sp_idx = spatial; )==""\n"
R"==(sp_idx < min(spatial + SPATIAL_CHUNK, SPATIAL_DIM); sp_idx++) { )==""\n"
// Decompose the flat spatial index into (id, ih, iw) with iw fastest.
R"==(const int iw = sp_idx % IW; )==""\n"
R"==(const int ih = ((sp_idx - iw) % (IH * IW)) / IW; )==""\n"
R"==(const int id = (sp_idx - iw - ih * IW) / (IH * IW); )==""\n"
R"==(DATA_T val_to_write; )==""\n"
R"==(if (is_in_padded_area) )==""\n"
R"==(val_to_write = DATA_ZERO; )==""\n"
R"==(else { )==""\n"
R"==(#if ALG_MAX )==""\n"
R"==(const int current_input_idx = id * IH * IW + ih * IW + iw; )==""\n"
R"==(if (current_input_idx == ws_val) { )==""\n"
R"==(val_to_write = diff_dst[dst_off]; )==""\n"
R"==(} else { )==""\n"
R"==(val_to_write = DATA_ZERO; )==""\n"
R"==(} )==""\n"
R"==(#else )==""\n"
R"==(float dst_val_f = DATA_TO_REF(dst_val) / SPATIAL_DIM; )==""\n"
R"==(val_to_write = CONVERT_DATA_T(dst_val_f); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(const int src_off = SRC_OFF(mb, c, id, ih, iw); )==""\n"
R"==(diff_src[src_off] = val_to_write; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl