1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// OpenCL C source text of the conv_bwd_data_mb_block_x8s8x8 kernel, stored as
// a single string constant. The text is assembled from adjacent raw string
// literals, one per kernel source line, each followed by an explicit newline
// escape, so the pieces concatenate into one multi-line program.
//
// What the embedded kernel does, as far as this text shows:
//   - reads blocks from dst and wei (and from bias when WITH_BIAS is set),
//   - accumulates them with mmad8x8 into int8 accumulators,
//   - converts the accumulators with TO_SRC and writes the result through src.
//
// The upper-case names used below (SUB_GROUP_SIZE, LWS_*, IC_BLOCK, OC_BLOCK,
// MB_BLOCK, IC_GROUP, SP_GROUP, MB_GROUP, IC_NCHUNK, OC_NCHUNK, G, MB, the
// I*/O*/K*/P*/S*/D* shape names, WITH_BIAS, DATA_T) are not defined in this
// string; presumably they arrive as build options or from the two included
// headers -- TODO confirm against the code that compiles this kernel.
// TO_SRC, AS_INT8_T, INT8_T and mmad8x8 are likewise defined elsewhere.
//
// NOTE(review): the license URL line inside the string reads only
// `* http: `; the rest of the URL looks truncated, possibly by whatever step
// generated this file -- confirm against the original .cl source.
5const char *xe_lp_conv_bwd_data_mb_block_x8s8x8_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2021 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Sub-group block-read helpers. BLOCK_READ_DST reads 8 uints from current_dst
// (a local pointer declared inside the kw loop), BLOCK_READ_WHT reads 8 uints
// from wei, BLOCK_READ_BIA reads 4 uints from bias reinterpreted as float4.
22R"==(#define BLOCK_READ_DST(data, idx) \ )==""\n"
23R"==(data = AS_INT8_T( \ )==""\n"
24R"==(intel_sub_group_block_read8((__global uint *)&current_dst[idx])); )==""\n"
25R"==(#define BLOCK_READ_WHT(data, idx) \ )==""\n"
26R"==(data = as_int8(intel_sub_group_block_read8((__global uint *)&wei[idx])); )==""\n"
27R"==(#define BLOCK_READ_BIA(data, idx) \ )==""\n"
28R"==(data = as_float4(intel_sub_group_block_read4((__global uint *)&bias[idx])); )==""\n"
// Kernel entry point. src is declared const in the signature, yet it is the
// buffer written at the end of the kernel (through a cast that drops const).
29R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
30R"==(__attribute__((reqd_work_group_size(LWS_0, LWS_1, LWS_2))) __kernel void )==""\n"
31R"==(conv_bwd_data_mb_block_x8s8x8(const __global uchar *src, )==""\n"
32R"==(const __global char *wei, const __global float *bias, )==""\n"
33R"==(__global DATA_T *dst) { )==""\n"
// Index decomposition:
//   group id 0 -> group_ic (scaled by IC_GROUP)
//   group id 1 -> group_sp (scaled by SP_GROUP)
//   group id 2 -> group_mb (id * MB_GROUP / mb_blocks) and mb (id % mb_blocks)
//   sub-group id -> ic (id % IC_GROUP) and sp (id / IC_GROUP)
// group_sp and sp are then split into id / ih / iw using IW_PADDED and IH.
34R"==(const int mb_blocks = 2; )==""\n"
35R"==(const int group_ic = get_group_id(0) * IC_GROUP; )==""\n"
36R"==(const int group_mb = get_group_id(2) * MB_GROUP / mb_blocks; )==""\n"
37R"==(const int group_sp = get_group_id(1) * SP_GROUP; )==""\n"
38R"==(const int sub_group_id = get_sub_group_id(); )==""\n"
39R"==(const int mb = get_group_id(2) % mb_blocks; )==""\n"
40R"==(const int ic = (sub_group_id % IC_GROUP); )==""\n"
41R"==(const int sp = (sub_group_id / IC_GROUP); )==""\n"
42R"==(const int g = (group_ic + ic) / IC_NCHUNK; )==""\n"
43R"==(const int group_oc = OC_NCHUNK * g; )==""\n"
44R"==(const int gid = group_sp / (IW_PADDED * IH); )==""\n"
45R"==(const int gihw = group_sp % (IW_PADDED * IH); )==""\n"
46R"==(const int gih = gihw / IW_PADDED; )==""\n"
47R"==(const int giw = gihw % IW_PADDED; )==""\n"
48R"==(const int local_ih = sp / IW_PADDED; )==""\n"
49R"==(const int local_iw = sp % IW_PADDED; )==""\n"
50R"==(const int id = gid; )==""\n"
51R"==(const int iw = giw + local_iw; )==""\n"
52R"==(const int ih = gih + local_ih; )==""\n"
// Positions with iw >= IW do no work. iw is derived only from group and
// sub-group ids, so the whole sub-group takes the same branch here.
53R"==(if (iw >= IW) return; )==""\n"
// Advance src, dst and wei to the start of the data handled by this
// sub-group. The mb term selects one half of an MB block (MB_BLOCK / 2 * mb).
54R"==(src += IC_BLOCK * ID * IH * IW * MB_BLOCK * (group_ic + ic); )==""\n"
55R"==(src += IC_BLOCK * ID * IH * IW * IC_NCHUNK * G * MB_BLOCK * group_mb; )==""\n"
56R"==(src += IC_BLOCK * MB_BLOCK / 2 * mb; )==""\n"
57R"==(src += IC_BLOCK * MB_BLOCK * (IW * IH * id + IW * ih + iw); )==""\n"
58R"==(dst += OC_BLOCK * OD * OH * OW * MB_BLOCK * group_oc; )==""\n"
59R"==(dst += OC_BLOCK * OD * OH * OW * OC_NCHUNK * G * MB_BLOCK * group_mb; )==""\n"
60R"==(dst += OC_BLOCK * MB_BLOCK / 2 * mb; )==""\n"
61R"==(wei += OC_BLOCK * KD * KH * KW * IC_BLOCK * (group_ic + ic) * OC_NCHUNK; )==""\n"
// Accumulators: C00..C03 collect products of D0 with W0..W3, C10..C13 collect
// products of D1 with W0..W3 (the latter only when MB > 8).
62R"==(int8 C00 = 0, C01 = 0, C02 = 0, C03 = 0; )==""\n"
63R"==(int8 C10 = 0, C11 = 0, C12 = 0, C13 = 0; )==""\n"
// Main reduction over OC_NCHUNK chunks; dst advances by one chunk per
// iteration (see the increment just before the loop closes).
64R"==(__attribute__((opencl_unroll_hint)) for (int oc_chunk = 0; )==""\n"
65R"==(oc_chunk < OC_NCHUNK; oc_chunk++) { )==""\n"
// Minibatch tail guard: when MB is not a multiple of MB_BLOCK, leave the loop
// if this half-block starts at or beyond MB. The condition does not depend on
// oc_chunk, so it either breaks on the first iteration or never.
66R"==(if (MB % MB_BLOCK != 0 )==""\n"
67R"==(&& (group_mb * MB_BLOCK + mb * MB_BLOCK / mb_blocks) >= MB) { )==""\n"
68R"==(break; )==""\n"
69R"==(} )==""\n"
70R"==(INT8_T D0, D1; )==""\n"
71R"==(int8 W0, W1, W2, W3; )==""\n"
// Walk the KD x KH x KW filter window. For each axis the output coordinate is
// (i + P - k * (1 + D)) / S; taps where that expression is not divisible by
// the stride, or where the coordinate falls outside [0, O), are skipped.
// wei is still advanced past every skipped tap so it stays in step.
72R"==(for (int kd = 0; kd < KD; kd++) { )==""\n"
73R"==(if ((id + PD - kd * (1 + DD)) % SD != 0) { )==""\n"
74R"==(wei += IC_BLOCK * OC_BLOCK * KH * KW; )==""\n"
75R"==(continue; )==""\n"
76R"==(} )==""\n"
77R"==(const int od = (id + PD - kd * (1 + DD)) / SD; )==""\n"
78R"==(if (od < 0 || od >= OD) { )==""\n"
79R"==(wei += IC_BLOCK * OC_BLOCK * KH * KW; )==""\n"
80R"==(continue; )==""\n"
81R"==(} )==""\n"
82R"==(for (int kh = 0; kh < KH; kh++) { )==""\n"
83R"==(if ((ih + PH - kh * (1 + DH)) % SH != 0) { )==""\n"
84R"==(wei += IC_BLOCK * OC_BLOCK * KW; )==""\n"
85R"==(continue; )==""\n"
86R"==(} )==""\n"
87R"==(const int oh = (ih + PH - kh * (1 + DH)) / SH; )==""\n"
88R"==(if (oh < 0 || oh >= OH) { )==""\n"
89R"==(wei += IC_BLOCK * OC_BLOCK * KW; )==""\n"
90R"==(continue; )==""\n"
91R"==(} )==""\n"
// Innermost (kw) loop uses nested ifs instead of continue; wei is advanced
// unconditionally at the bottom of each iteration.
92R"==(__attribute__((opencl_unroll_hint)) for (int kw = 0; kw < KW; )==""\n"
93R"==(kw++) { )==""\n"
94R"==(if ((iw + PW - kw * (1 + DW)) % SW == 0) { )==""\n"
95R"==(const int ow = (iw + PW - kw * (1 + DW)) / SW; )==""\n"
96R"==(if (ow >= 0 && ow < OW) { )==""\n"
97R"==(__global DATA_T *current_dst = dst )==""\n"
98R"==(+ OC_BLOCK * MB_BLOCK )==""\n"
99R"==(* (OW * OH * od + OW * oh + ow); )==""\n"
// Load one (or, when MB > 8, two) dst blocks and four weight blocks, then
// accumulate every dst/weight pairing with mmad8x8.
100R"==(BLOCK_READ_DST(D0, 0); )==""\n"
101R"==(#if MB > 8 )==""\n"
102R"==(BLOCK_READ_DST(D1, 8 * IC_BLOCK); )==""\n"
103R"==(#endif )==""\n"
104R"==(BLOCK_READ_WHT(W0, 0); )==""\n"
105R"==(BLOCK_READ_WHT(W1, 8 * IC_BLOCK); )==""\n"
106R"==(BLOCK_READ_WHT(W2, 16 * IC_BLOCK); )==""\n"
107R"==(BLOCK_READ_WHT(W3, 24 * IC_BLOCK); )==""\n"
108R"==(C00 = mmad8x8(D0, W0, C00); )==""\n"
109R"==(C01 = mmad8x8(D0, W1, C01); )==""\n"
110R"==(C02 = mmad8x8(D0, W2, C02); )==""\n"
111R"==(C03 = mmad8x8(D0, W3, C03); )==""\n"
112R"==(#if MB > 8 )==""\n"
113R"==(C10 = mmad8x8(D1, W0, C10); )==""\n"
114R"==(C11 = mmad8x8(D1, W1, C11); )==""\n"
115R"==(C12 = mmad8x8(D1, W2, C12); )==""\n"
116R"==(C13 = mmad8x8(D1, W3, C13); )==""\n"
117R"==(#endif )==""\n"
118R"==(} )==""\n"
119R"==(} )==""\n"
120R"==(wei += IC_BLOCK * OC_BLOCK; )==""\n"
121R"==(} )==""\n"
122R"==(} )==""\n"
123R"==(} )==""\n"
124R"==(dst += OC_BLOCK * MB_BLOCK * OD * OH * OW; )==""\n"
125R"==(} )==""\n"
// Epilogue helper. With WITH_BIAS it adds BIA to the accumulator in float and
// converts with TO_SRC; without it, it only converts.
// NOTE(review): despite the name, no ReLU is applied in either variant, and
// the DST argument is never used in the macro body.
126R"==(#if WITH_BIAS )==""\n"
127R"==(#define BIAS_SUM_RELU(RES, TMP, ACC, BIA, DST) \ )==""\n"
128R"==(TMP = (float)ACC + BIA; \ )==""\n"
129R"==(RES = TO_SRC(TMP); )==""\n"
130R"==(#else )==""\n"
131R"==(#define BIAS_SUM_RELU(RES, TMP, ACC, BIA, DST) RES = TO_SRC((float)ACC); )==""\n"
132R"==(#endif )==""\n"
// PACK(idx) converts lane idx of the eight accumulators into the four bytes
// of D00 and D01, then stores each uchar4 as one uint lane of T0 / T1.
// S00 and S01 are passed as the unused DST argument and are never declared;
// b0..b3 exist only under WITH_BIAS, where the macro actually uses BIA.
133R"==(#define PACK(idx) \ )==""\n"
134R"==(BIAS_SUM_RELU(D00[0], T00, C00[idx], b0, S00[0]); \ )==""\n"
135R"==(BIAS_SUM_RELU(D00[1], T01, C01[idx], b1, S00[1]); \ )==""\n"
136R"==(BIAS_SUM_RELU(D00[2], T02, C02[idx], b2, S00[2]); \ )==""\n"
137R"==(BIAS_SUM_RELU(D00[3], T03, C03[idx], b3, S00[3]); \ )==""\n"
138R"==(T0[idx] = as_uint(D00); \ )==""\n"
139R"==(BIAS_SUM_RELU(D01[0], T10, C10[idx], b0, S01[0]); \ )==""\n"
140R"==(BIAS_SUM_RELU(D01[1], T11, C11[idx], b1, S01[1]); \ )==""\n"
141R"==(BIAS_SUM_RELU(D01[2], T12, C12[idx], b2, S01[2]); \ )==""\n"
142R"==(BIAS_SUM_RELU(D01[3], T13, C13[idx], b3, S01[3]); \ )==""\n"
143R"==(T1[idx] = as_uint(D01); )==""\n"
// Bias for this sub-group is read once, at offset (group_ic + ic) * IC_BLOCK.
144R"==(#if WITH_BIAS )==""\n"
145R"==(float4 bia; )==""\n"
146R"==(BLOCK_READ_BIA(bia, (group_ic + ic) * IC_BLOCK); )==""\n"
147R"==(float b0 = bia[0]; )==""\n"
148R"==(float b1 = bia[1]; )==""\n"
149R"==(float b2 = bia[2]; )==""\n"
150R"==(float b3 = bia[3]; )==""\n"
151R"==(#endif )==""\n"
152R"==(uchar4 D00, D01; )==""\n"
153R"==(uint8 T0, T1; )==""\n"
154R"==(float T00, T01, T02, T03; )==""\n"
155R"==(float T10, T11, T12, T13; )==""\n"
// Convert all eight accumulator lanes. The C1x accumulators are packed here
// even when the MB > 8 sections above were compiled out (they are then still
// zero-initialized, plus bias if enabled).
156R"==(PACK(0); )==""\n"
157R"==(PACK(1); )==""\n"
158R"==(PACK(2); )==""\n"
159R"==(PACK(3); )==""\n"
160R"==(PACK(4); )==""\n"
161R"==(PACK(5); )==""\n"
162R"==(PACK(6); )==""\n"
163R"==(PACK(7); )==""\n"
// Store the packed results with four sub-group block writes of 16 bytes each,
// at offsets 0, 4, 8 and 12 times IC_BLOCK from src. The cast removes the
// const qualifier that src carries in the kernel signature.
164R"==(intel_sub_group_block_write_uc16( )==""\n"
165R"==((__global uchar *)&src[0 * IC_BLOCK], as_uchar16(T0.s0123)); )==""\n"
166R"==(intel_sub_group_block_write_uc16( )==""\n"
167R"==((__global uchar *)&src[4 * IC_BLOCK], as_uchar16(T0.s4567)); )==""\n"
168R"==(intel_sub_group_block_write_uc16( )==""\n"
169R"==((__global uchar *)&src[8 * IC_BLOCK], as_uchar16(T1.s0123)); )==""\n"
170R"==(intel_sub_group_block_write_uc16( )==""\n"
171R"==((__global uchar *)&src[12 * IC_BLOCK], as_uchar16(T1.s4567)); )==""\n"
172R"==(} )==""\n"
173R"==()==";
174}
175}
176}
177}