namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the gen9 global pooling kernels, kept as one string
// built from adjacent raw string literals (one literal per kernel source line,
// each followed by an explicit "\n").
//
// The text defines two kernels, selected by preprocessor switches:
//   * gen9_global_pooling_fwd (under IS_FWD): reduces the whole ID x IH x IW
//     volume of one (mb, oc) pair to a single dst value (max or average).
//   * gen9_global_pooling_bwd (under IS_BWD): writes diff_src for a chunk of
//     spatial positions of one (mb, c) pair.
//
// Macros such as C, MB, ID, IH, IW, SRC_OFF, DST_OFF, DATA_T, ALG_MAX,
// IS_TRAINING, SPATIAL_DIM, SPATIAL_CHUNK and the GWS_GET_* helpers are not
// defined in this text; presumably they come from "gpu/ocl/ocl_types.h" and
// the kernel build options -- confirm against the host-side implementation.
const char *gen9_global_pooling_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2021-2022 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
R"==(* http: )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#include "gpu/ocl/ocl_types.h" )==" "\n"
R"==(#define ALG_AVG (ALG_AVG_NP || ALG_AVG_P) )==" "\n"
// ---- Forward kernel: global id 0 is split into (mb, oc) via C. ----
R"==(#if IS_FWD )==" "\n"
R"==(KERNEL_ATTR )==" "\n"
R"==(__kernel void gen9_global_pooling_fwd( )==" "\n"
R"==(__global DATA_T *src, __global int *ws, __global DST_DATA_T *dst) { )==" "\n"
R"==(const int mb = get_global_id(0) / C; )==" "\n"
R"==(const int oc = get_global_id(0) % C; )==" "\n"
R"==(const uint dst_off = DST_OFF(mb, oc, 0, 0, 0); )==" "\n"
// Max pooling: seed the running maximum with the FIRST ELEMENT OF THIS
// (mb, oc) SLICE and its index 0. Seeding with src[0] (the first element of
// the whole tensor) yields a wrong maximum for every slice whose values are
// all below src[0], and leaves max_idx at an index that matches no position.
R"==(#if ALG_MAX )==" "\n"
R"==(#if DT_BF16 )==" "\n"
R"==(DEF_ACC_DATA_T dst_val = DATA_TO_REF(src[SRC_OFF(mb, oc, 0, 0, 0)]); )==" "\n"
R"==(#else )==" "\n"
R"==(float dst_val = DATA_TO_REF(src[SRC_OFF(mb, oc, 0, 0, 0)]); )==" "\n"
R"==(#endif )==" "\n"
R"==(#if IS_TRAINING )==" "\n"
R"==(int max_idx = 0; )==" "\n"
R"==(#endif )==" "\n"
// Average pooling: start the accumulator at zero.
R"==(#else )==" "\n"
R"==(#if DT_BF16 )==" "\n"
R"==(DEF_ACC_DATA_T dst_val = DATA_TO_REF(0.f); )==" "\n"
R"==(#else )==" "\n"
R"==(float dst_val = 0.f; )==" "\n"
R"==(#endif )==" "\n"
R"==(#endif )==" "\n"
// Reduce over the full ID x IH x IW volume of the slice.
R"==(for (int id = 0; id < ID; id++) { )==" "\n"
R"==(for (int ih = 0; ih < IH; ih++) { )==" "\n"
R"==(for (int iw = 0; iw < IW; iw++) { )==" "\n"
R"==(uint src_off = SRC_OFF(mb, oc, id, ih, iw); )==" "\n"
R"==(#if DT_BF16 )==" "\n"
R"==(DEF_ACC_DATA_T val = DATA_TO_REF(src[src_off]); )==" "\n"
R"==(#else )==" "\n"
R"==(float val = DATA_TO_REF(src[src_off]); )==" "\n"
R"==(#endif )==" "\n"
// Strict '>' keeps the first occurrence of the maximum; max_idx is the
// linearized (id, ih, iw) position that the backward kernel compares against.
R"==(#if ALG_MAX )==" "\n"
R"==(if (val > dst_val) { )==" "\n"
R"==(dst_val = val; )==" "\n"
R"==(#if IS_TRAINING )==" "\n"
R"==(max_idx = id * IH * IW + ih * IW + iw; )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(#else )==" "\n"
R"==(dst_val += val; )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
// Store the result; for max pooling in training mode also store the index of
// the maximum into ws at the same offset as dst.
R"==(#if ALG_MAX )==" "\n"
R"==(dst[dst_off] = TO_DST(dst_val); )==" "\n"
R"==(#if IS_TRAINING )==" "\n"
R"==(ws[dst_off] = max_idx; )==" "\n"
R"==(#endif )==" "\n"
R"==(#else )==" "\n"
R"==(dst[dst_off] = TO_DST(dst_val / ID / IH / IW); )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
// ---- Backward kernel: each work-item handles up to SPATIAL_CHUNK ----
// ---- consecutive linearized spatial positions of one (mb, c) pair. ----
R"==(#if IS_BWD )==" "\n"
R"==(KERNEL_ATTR )==" "\n"
R"==(__kernel void gen9_global_pooling_bwd(__global DATA_T *diff_src, )==" "\n"
R"==(__global int *ws, __global DATA_T *diff_dst) { )==" "\n"
R"==(const int mb = GWS_GET_MB(); )==" "\n"
R"==(const int c = GWS_GET_C(); )==" "\n"
R"==(const int spatial = GWS_GET_SPATIAL(); )==" "\n"
R"==(const bool is_in_padded_area = NEED_ZERO_PADDING && (mb >= MB || c >= C); )==" "\n"
R"==(const int dst_off = DST_OFF(mb, c, 0, 0, 0); )==" "\n"
R"==(#if ALG_AVG )==" "\n"
R"==(const DATA_T dst_val = diff_dst[dst_off]; )==" "\n"
R"==(#endif )==" "\n"
// ws_val is consumed only by the max-pooling branch below, so the workspace
// is read only there. NOTE(review): presumably no valid workspace buffer is
// bound for average pooling -- confirm against the host-side dispatch.
R"==(#if ALG_MAX )==" "\n"
R"==(int ws_val = ws[dst_off]; )==" "\n"
R"==(#endif )==" "\n"
R"==(for (int sp_idx = spatial; )==" "\n"
R"==(sp_idx < min(spatial + SPATIAL_CHUNK, SPATIAL_DIM); sp_idx++) { )==" "\n"
// Decompose the linear spatial index into (id, ih, iw).
R"==(const int iw = sp_idx % IW; )==" "\n"
R"==(const int ih = ((sp_idx - iw) % (IH * IW)) / IW; )==" "\n"
R"==(const int id = (sp_idx - iw - ih * IW) / (IH * IW); )==" "\n"
R"==(DATA_T val_to_write; )==" "\n"
// Positions with mb >= MB or c >= C are written as zero when
// NEED_ZERO_PADDING is set.
R"==(if (is_in_padded_area) )==" "\n"
R"==(val_to_write = DATA_ZERO; )==" "\n"
R"==(else { )==" "\n"
// Max: only the position recorded in ws receives diff_dst, all others zero.
R"==(#if ALG_MAX )==" "\n"
R"==(const int current_input_idx = id * IH * IW + ih * IW + iw; )==" "\n"
R"==(if (current_input_idx == ws_val) { )==" "\n"
R"==(val_to_write = diff_dst[dst_off]; )==" "\n"
R"==(} else { )==" "\n"
R"==(val_to_write = DATA_ZERO; )==" "\n"
R"==(} )==" "\n"
// Average: every position receives diff_dst / SPATIAL_DIM.
R"==(#else )==" "\n"
R"==(float dst_val_f = DATA_TO_REF(dst_val) / SPATIAL_DIM; )==" "\n"
R"==(val_to_write = CONVERT_DATA_T(dst_val_f); )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(const int src_off = SRC_OFF(mb, c, id, ih, iw); )==" "\n"
R"==(diff_src[src_off] = val_to_write; )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
R"==()==" ;
}
}
}
}