namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the `xe_lp_x8s8x_compensation` kernel, embedded as a
// single C string. The string is built by concatenating adjacent raw-string
// pieces; every piece holds one kernel-source line followed by " \n" (the
// trailing space comes from the piece, the newline from the "\n" literal).
// The resulting text is consumed at runtime as OpenCL program source, so each
// byte inside the raw-string pieces is program input: comments may be placed
// between pieces (as done here), but the pieces themselves must not change.
//
// Kernel interface (from the source below): reads source zero points
// (`src_zpoints`, int32) and int8 weights (`wei`), and writes int32 results
// to `dst`; each sub-group writes 32 consecutive int32 values. Two bodies are
// provided, guarded by the WEI_4O8I8O4I and WEI_32G macros. Both define a
// kernel with the same name, so at most one of these macros is expected to be
// enabled per build (assumption - confirm where the build options are set).
// IC, OC, KD, KH, KW, IC_BLOCK and WITH_SRC_ZPOINTS_PER_IC are not defined in
// this text; presumably they are supplied as compiler build options.
//
// NOTE(review): this file looks machine-generated from a .cl kernel source;
// presumably that .cl file should be edited and this file regenerated rather
// than edited by hand. The "* http: " line of the license header appears to
// be a truncated URL, presumably because the generator strips `//` sequences;
// confirm against the original .cl file.
const char *xe_lp_x8s8x_compensation_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2021 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#include "gpu/ocl/ocl_zero_points.h" )==""\n"
// Variant WEI_4O8I8O4I: weights stored in 32 (oc) x 32 (ic) blocks of
// WEI_BLOCK = 32 * 32 bytes. Sub-group (and work-group) size is 8; global id 1
// selects the 32-wide oc block and global id 2 the group. For every ic block
// and every one of the KD*KH*KW taps, four block reads at
// wei + {0, 8, 16, 24} * IC_BLOCK each fetch 8 uints per lane. The results go
// into acc.s0..acc.s3, which the final intel_sub_group_block_write4 stores, so
// component j of lane l lands at dst[l + 8 * j] (per cl_intel_subgroups
// block-write layout).
// - Scalar zero point: the packed int8 weights are accumulated with
//   idot4(0x01010101, w, acc), presumably a 4 x int8 dot product, i.e. a sum
//   of the four bytes. The sum is multiplied by the scalar zero point `z` once
//   at the end.
// - Per-ic zero points (WITH_SRC_ZPOINTS_PER_IC): the zero points for each ic
//   block are read with read_src_zero_points_32c. They are combined with the
//   weights by calc_src_compensation_x32, presumably from the included
//   ocl_zero_points.h; its exact semantics are not visible here.
R"==(#if WEI_4O8I8O4I )==""\n"
R"==(#define OCB ((OC + 31) / 32) )==""\n"
R"==(#define ICB ((IC + 31) / 32) )==""\n"
R"==(#define KDHW (KD * KH * KW) )==""\n"
R"==(#define WEI_BLOCK (32 * 32) )==""\n"
R"==(__attribute__((intel_reqd_sub_group_size(8))) )==""\n"
R"==(__attribute__((reqd_work_group_size(8, 1, 1))) __kernel void )==""\n"
R"==(xe_lp_x8s8x_compensation(const __global int *src_zpoints, )==""\n"
R"==(const __global char *wei, __global int *dst) { )==""\n"
R"==(const int oc_block_idx = get_global_id(1); )==""\n"
R"==(const int g = get_global_id(2); )==""\n"
R"==(wei += g * OCB * ICB * KDHW * WEI_BLOCK; )==""\n"
R"==(wei += oc_block_idx * ICB * KDHW * WEI_BLOCK; )==""\n"
R"==(dst += g * OCB * 32; )==""\n"
R"==(dst += oc_block_idx * 32; )==""\n"
R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
R"==(src_zpoints += g * IC; )==""\n"
R"==(#else )==""\n"
R"==(const int z = read_src_zero_point(src_zpoints); )==""\n"
R"==(#endif )==""\n"
R"==(int4 acc = 0; )==""\n"
R"==(for (uint icb = 0; icb < ICB; ++icb) { )==""\n"
R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
R"==(const int4 z = read_src_zero_points_32c(src_zpoints, icb * IC_BLOCK); )==""\n"
R"==(#endif )==""\n"
R"==(for (uint k = 0; k < KDHW; ++k) { )==""\n"
R"==(const int8 w0 = as_int8(intel_sub_group_block_read8( )==""\n"
R"==((__global uint *)(wei + 0 * IC_BLOCK))); )==""\n"
R"==(const int8 w1 = as_int8(intel_sub_group_block_read8( )==""\n"
R"==((__global uint *)(wei + 8 * IC_BLOCK))); )==""\n"
R"==(const int8 w2 = as_int8(intel_sub_group_block_read8( )==""\n"
R"==((__global uint *)(wei + 16 * IC_BLOCK))); )==""\n"
R"==(const int8 w3 = as_int8(intel_sub_group_block_read8( )==""\n"
R"==((__global uint *)(wei + 24 * IC_BLOCK))); )==""\n"
R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
R"==(acc.s0 += calc_src_compensation_x32(z, w0); )==""\n"
R"==(acc.s1 += calc_src_compensation_x32(z, w1); )==""\n"
R"==(acc.s2 += calc_src_compensation_x32(z, w2); )==""\n"
R"==(acc.s3 += calc_src_compensation_x32(z, w3); )==""\n"
R"==(#else )==""\n"
R"==(unroll_for(uint i = 0; i < 8; ++i) { )==""\n"
R"==(acc.s0 = idot4(0x01010101, w0[i], acc.s0); )==""\n"
R"==(acc.s1 = idot4(0x01010101, w1[i], acc.s1); )==""\n"
R"==(acc.s2 = idot4(0x01010101, w2[i], acc.s2); )==""\n"
R"==(acc.s3 = idot4(0x01010101, w3[i], acc.s3); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(wei += WEI_BLOCK; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#if !WITH_SRC_ZPOINTS_PER_IC )==""\n"
R"==(acc = z * acc; )==""\n"
R"==(#endif )==""\n"
R"==(intel_sub_group_block_write4((__global uint *)(dst), as_uint4(acc)); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// Variant WEI_32G: weights blocked by 32 groups (WEI_BLOCK = 32 bytes per
// KD*KH*KW tap). Sub-group (and work-group) size is 16; global id 1 selects
// the 32-group block. On each tap, every lane reads 2 int8 weights with
// intel_sub_group_block_read_uc2, converts them to int2 and adds them into
// `acc`. After all taps, `acc` is multiplied by the zero point. That is a
// scalar from read_src_zero_point, or int2 per-lane values from
// read_src_zero_points_32g when WITH_SRC_ZPOINTS_PER_IC is set. The 32 int32
// results are then stored with intel_sub_group_block_write2.
R"==(#if WEI_32G )==""\n"
R"==(#define KDHW (KD * KH * KW) )==""\n"
R"==(#define WEI_BLOCK 32 )==""\n"
R"==(__attribute__((intel_reqd_sub_group_size(16))) )==""\n"
R"==(__attribute__((reqd_work_group_size(16, 1, 1))) __kernel void )==""\n"
R"==(xe_lp_x8s8x_compensation(const __global int *src_zpoints, )==""\n"
R"==(const __global char *wei, __global int *dst) { )==""\n"
R"==(const int g_block_idx = get_global_id(1); )==""\n"
R"==(wei += g_block_idx * KDHW * WEI_BLOCK; )==""\n"
R"==(dst += g_block_idx * WEI_BLOCK; )==""\n"
R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
R"==(const int2 z )==""\n"
R"==(= read_src_zero_points_32g(src_zpoints, g_block_idx * WEI_BLOCK); )==""\n"
R"==(#else )==""\n"
R"==(const int z = read_src_zero_point(src_zpoints); )==""\n"
R"==(#endif )==""\n"
R"==(int2 acc = 0; )==""\n"
// NOTE(review): the WITH_SRC_ZPOINTS_PER_IC conditional in the next two pieces
// has an empty body (dead preprocessor code, no effect on the kernel). It
// looks like a leftover that could be dropped from the .cl source.
R"==(#if WITH_SRC_ZPOINTS_PER_IC )==""\n"
R"==(#endif )==""\n"
R"==(for (uint k = 0; k < KDHW; ++k) { )==""\n"
R"==(const int2 w0 = convert_int2(as_char2( )==""\n"
R"==(intel_sub_group_block_read_uc2((const __global uchar *)(wei)))); )==""\n"
R"==(acc += w0; )==""\n"
R"==(wei += WEI_BLOCK; )==""\n"
R"==(} )==""\n"
R"==(acc = z * acc; )==""\n"
R"==(intel_sub_group_block_write2((__global uint *)(dst), as_uint2(acc)); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl